# K-Means Clustering for Macro-Actions

### Imports

In [None]:
import os
from datetime import datetime

import minerl
import numpy as np
import tqdm
from minerl.data import BufferedBatchIter
from sklearn.cluster import KMeans


### Creating Data Dir (Local)

In [None]:
# Keep MineRL datasets (and our saved action sets) under ./data in the working directory.
data_path = os.path.join(os.getcwd(), "data")
# makedirs(exist_ok=True) is race-free, unlike an exists() check followed by mkdir().
os.makedirs(data_path, exist_ok=True)

# Important: minerl.data.make() is called later without a data_dir argument and is
# expected to locate the dataset through this variable, so set it before that call.
os.environ['MINERL_DATA_ROOT'] = data_path

### Globals

In [None]:
# Obfuscated-vector MineRL environments (candidates for ENVIRONMENT below).
OBF_ENVS = [
    "MineRLTreechopVectorObf-v0",
    "MineRLObtainDiamondVectorObf-v0",
]

# Environment whose demonstration actions are clustered (the Treechop entry).
ENVIRONMENT = OBF_ENVS[0]

# Number of k-means cluster centres, i.e. the size of the extracted action set.
NUM_CLUSTERS = 32
# Number of action batches sampled from the dataset before clustering.
NUM_BATCHES = 1000

#### Data Download

In [None]:
# Fetch the demonstration data for ENVIRONMENT only if its folder is not already
# present under data_path.
env_data_path = os.path.join(data_path, ENVIRONMENT)
already_downloaded = os.path.exists(env_data_path)
if not already_downloaded:
    # Careful: this is a network download and may take a long time.
    minerl.data.download(data_path, environment=ENVIRONMENT)

### Main
Samples `NUM_BATCHES` batches of actions from the demonstration dataset, then runs k-means clustering on them 
to find `NUM_CLUSTERS` representative actions, which become the discrete action set for our agent. 

In [None]:
data = minerl.data.make(environment=ENVIRONMENT)

# Positional arguments for data.batch_iter: batch size, sequence length, epochs.
BATCH_SIZE = 16
SEQ_LEN = 32
NUM_EPOCHS = 2
# Upper bound on the number of action rows passed to KMeans (bounds fitting time).
# Only the first MAX_KMEANS_SAMPLES rows of the collected actions are used.
MAX_KMEANS_SAMPLES = 100_000

# Load the dataset, storing exactly NUM_BATCHES batches of actions.
act_vectors = []
batches = data.batch_iter(BATCH_SIZE, SEQ_LEN, NUM_EPOCHS, preload_buffer_size=20)
for _, act, _, _, _ in tqdm.tqdm(batches, total=NUM_BATCHES):
    act_vectors.append(act['vector'])
    # `>=` (not `>`) so we stop after NUM_BATCHES batches rather than NUM_BATCHES + 1.
    if len(act_vectors) >= NUM_BATCHES:
        break

if not act_vectors:
    raise RuntimeError(
        f"No action batches were loaded for {ENVIRONMENT}; check the downloaded data."
    )

# Flatten every batch into one row per action. Each batch is presumably shaped
# (batch, seq_len, action_dim) -- TODO confirm; the last axis is the action vector,
# so read its width from the data instead of hard-coding it.
action_dim = act_vectors[0].shape[-1]
acts = np.concatenate(act_vectors).reshape(-1, action_dim)
kmeans_acts = acts[:MAX_KMEANS_SAMPLES]

# Cluster the demonstrated actions. n_init is pinned because scikit-learn changed
# its default (10 -> 'auto'), which would otherwise make results version-dependent.
kmeans = KMeans(n_clusters=NUM_CLUSTERS, n_init=10, random_state=0).fit(kmeans_acts)

In [None]:
# The learned action set: sklearn stores one centre per cluster, so this array
# has NUM_CLUSTERS rows, each one a candidate action vector.
action_set = kmeans.cluster_centers_
action_set

In [None]:
# Sampling a random action from our n actions
# kmeans.cluster_centers_[np.random.choice(NUM_CLUSTERS)]

# Save the action set (one row per cluster centre) under <data_path>/action_sets.
action_set_dir = os.path.join(data_path, "action_sets")
# np.save does not create missing parent directories.
os.makedirs(action_set_dir, exist_ok=True)
# Month, day, hour, minute: including the hour keeps same-day runs from overwriting
# each other when they happen at the same minute of different hours.
date_suffix = datetime.now().strftime('%m%d%H%M')
filename = os.path.join(
    action_set_dir, f"action_set_{ENVIRONMENT}_{NUM_CLUSTERS}_{date_suffix}.npy"
)
np.save(filename, kmeans.cluster_centers_)

# Load action set
# action_set = np.load(filename)
