In [38]:
import os
from datetime import datetime
from tkinter import ALL

import minerl
import gym
import numpy as np
import tqdm
from kmodes.kprototypes import KPrototypes

import numpy as np
from numpy import array, float32
import pandas as pd

from pathlib import Path

In [39]:
# Hyperparameters
NUM_CLUSTERS = 6 # Number of Macro Actions we want to extract

NUM_EPOCHS = 2
BATCH_SIZE = 10
MAX_ACTIONS = 5000

ENVIRONMENT = 'MineRLObtainDiamond-v0'

In [106]:
# Initial setup
data_path = str(Path().absolute().parent.parent.joinpath('data'))

if not os.path.exists(data_path):
    os.mkdir(data_path)

os.environ['MINERL_DATA_ROOT'] = data_path # Important

# Downloading environment data if it doesn't exist
env_data_path = os.path.join(data_path, ENVIRONMENT)

if not os.path.exists(env_data_path):
    # os.mkdir(f'data/{ENVIRONMENT}')
    minerl.data.download(environment = ENVIRONMENT)

In [65]:
def decode_action(obj):
    proc = {}

    for k, v in obj.items():
        v = v[0]
        if isinstance(v, np.ndarray) and v.size > 1:
            for i, dim in enumerate(v):
                proc[f"{k}{i}"] = dim
        else:
            proc[k] = v.tolist() if isinstance(v, np.ndarray) else v
        
    return proc

def encode_action(obj):
    proc = {}

    for k, v in obj.items():
        if 'camera' not in k:
            try:
                proc[k] = array(int(round(float(v))))
            except:
                proc[k] = v
    
    proc['camera'] = array([obj.get('camera0'), obj.get('camera0')], dtype=float32)
    return proc

In [42]:
from minerl.data import BufferedBatchIter
data = minerl.data.make(ENVIRONMENT)
iterator = BufferedBatchIter(data)
i = 0
collected_actions = []
for current_state, action, reward, next_state, done in iterator.buffered_batch_iter(batch_size=BATCH_SIZE, num_epochs=NUM_EPOCHS):
    collected_actions.append(decode_action(action))
    
    i += 1
    if i == MAX_ACTIONS:
        break

df = pd.DataFrame(collected_actions)

100%|██████████| 3208/3208 [00:00<00:00, 39601.75it/s]
100%|██████████| 13009/13009 [00:00<00:00, 26267.13it/s]
100%|██████████| 46648/46648 [00:01<00:00, 27705.16it/s]
100%|██████████| 13196/13196 [00:00<00:00, 34547.61it/s]
100%|██████████| 6581/6581 [00:00<00:00, 31494.63it/s]
100%|██████████| 11572/11572 [00:00<00:00, 18606.58it/s]


In [44]:
df

Unnamed: 0,attack,back,camera0,camera1,craft,equip,forward,jump,left,nearbyCraft,nearbySmelt,place,right,sneak,sprint
0,1,0,0.000000,0.000000,none,none,0,0,0,none,none,none,0,0,0
1,0,0,-0.000000,0.285040,none,stone_pickaxe,0,0,0,none,none,none,0,0,0
2,0,0,0.000000,0.000000,none,none,0,0,0,none,none,none,0,0,0
3,1,0,0.000000,0.000000,none,none,0,0,0,none,none,none,0,0,0
4,0,0,-1.350002,-0.600006,none,none,0,0,0,none,none,none,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4995,0,0,0.000000,0.000000,none,none,0,0,0,none,none,none,0,0,0
4996,1,0,-2.250000,-0.150024,none,none,0,0,0,none,none,none,0,0,0
4997,1,0,0.000000,0.000000,none,none,0,0,0,none,none,none,0,1,0
4998,1,0,0.000000,0.000000,none,none,0,0,0,none,none,none,0,0,0


In [45]:
mark_array=df.values

In [46]:
categorical_features_idx = [4,5,9,10,11]

In [47]:
kproto = KPrototypes(n_clusters=NUM_CLUSTERS, max_iter=200).fit(mark_array, categorical=categorical_features_idx)

In [48]:
actions_list = ['attack', 'back', 'camera0', 'camera1', 
    'forward', 'jump', 'left',  'right', 'sneak','sprint', 
    'craft', 'equip', 'nearbyCraft', 'nearbySmelt', 'place']

In [67]:
extracted_actions = []
for cluster in kproto.cluster_centroids_:
    extracted_actions.append(encode_action({actions_list[i]: cluster[i] for i in range(len(cluster))}))

### Test on ENV

In [66]:
REPEAT_ACTION = 100
MAX_RENDER = 100000
env = gym.make(ENVIRONMENT)
done = False
obs = env.reset()

i =  0
CURR_ACTION = 0

while not done:
    if i % REPEAT_ACTION == 0:
        CURR_ACTION = np.random.randint(len(extracted_actions))
    
    action = encode_action(extracted_actions[CURR_ACTION])
    obs, reward, done, info = env.step(action)
    env.render()
    i += 1

    if i == MAX_RENDER:
        done = True



IndexError: list index out of range