In [13]:
import h5py
import os

root = '/home/yigit/projects/mbcnp/data/raw/mocapact/'
files = []

# Iterate directory
for file_path in os.listdir(root):
    if file_path.endswith('.hdf5') and os.path.isfile(os.path.join(root, file_path)):
        if 'CMU_016_40.hdf5' in file_path or 'CMU_007_12.hdf5' in file_path:
            # add filename to list
            files.append(file_path)
print(files)

['CMU_016_40.hdf5', 'CMU_007_12.hdf5']


In [2]:
import numpy as np

desired_observables = ['actuator_activation', 'joints_pos', 'joints_vel', 'sensors_gyro', 'end_effectors_pos', 
                       'sensors_torque', 'sensors_touch', 'sensors_velocimeter']

def get_obs_indices(path):
    indices = []

    f = h5py.File(path, 'r+')
    walker_obs_dict = f['observable_indices']['walker']
    for k in walker_obs_dict.keys():
        if k in desired_observables:
            dum = walker_obs_dict[k][:]
            indices.extend(dum)
    f.close()

    return np.array(indices)

# Get indices
indices = get_obs_indices(os.path.join(root, 'CMU_016_40.hdf5'))

In [3]:
#region read mocapact data
full_obs, full_act = [], []

for file in files:
    fp = os.path.join(root, file)
    # Open file
    f = h5py.File(fp, 'r+')

    demos = {}

    num_start_rollouts = f['n_start_rollouts'][()]  # concatenate snippets to create this many rollouts
    for i in range(num_start_rollouts):
        demos.update({i: {}})
        demos[i].update({'obs': {}})
        demos[i].update({'act': {}})
    
    num_snippets = 0
    for key in f.keys():
        if key.startswith('CMU_'):
            num_snippets += 1

    for key in f.keys():
        if key.startswith('CMU_'):
            start, end = int(key.split('-')[-2]), int(key.split('-')[-1])
            for i in range(num_start_rollouts):
                obs = np.array(f[key][str(i)]['observations']['proprioceptive'])
                act = np.array(f[key][str(i)]['actions'])
                for j in range(len(act)):
                    demos[i]['obs'].update({start+j: obs[j, indices]})
                    demos[i]['act'].update({start+j: act[j]})

    for key in f.keys():
        for i in range(num_start_rollouts):
            if key.startswith('CMU_') and f[key]['early_termination'][i] == True:
                if i in demos.keys():
                    demos.pop(i)

    for key in demos.keys():
        full_obs.append(np.array(list(demos[key]['obs'].values())))
        full_act.append(np.array(list(demos[key]['act'].values())))

    f.close()

print(len(full_obs), len(full_act))
#endregion
min_length = 1000
for i in range(len(full_obs)):
    if len(full_obs[i]) < min_length:
        min_length = len(full_obs[i])

processed_obs, processed_act = [], []
for i in range(len(full_obs)):
    processed_obs.append(full_obs[i][np.linspace(0, len(full_obs[i])-1, min_length, dtype=int)])
    processed_act.append(full_act[i][np.linspace(0, len(full_obs[i])-1, min_length, dtype=int)])


20 20


In [4]:
from models.wta_cnp import WTA_CNP
import torch

def get_available_gpu_with_most_memory():
    gpu_memory = []
    for i in range(torch.cuda.device_count()):
        torch.cuda.set_device(i)  # Switch to the GPU to accurately measure memory
        gpu_memory.append((i, torch.cuda.memory_stats()['reserved_bytes.all.current'] / (1024 ** 2)))

    gpu_memory.sort(key=lambda x: x[1], reverse=True)

    return gpu_memory[0][0]

if torch.cuda.is_available():
    available_gpu = get_available_gpu_with_most_memory()
    if available_gpu == 0:
        device = torch.device("cuda:0")
    else:
        device = torch.device(f"cuda:{available_gpu}")
else:
    device = torch.device("cpu")

print("Device :", device)

###

torch.set_float32_matmul_precision('high')

Device : cuda:0


In [9]:
batch_size = 2
n_max_obs, n_max_tar = 5, 5

t_steps = min_length
num_val = 4
num_demos = len(full_obs)-num_val
num_classes = 2
num_indiv = num_demos//num_classes  # number of demos per class

dx, dy = 1, len(full_act[0][0])

num_val_indiv = num_val//num_classes

colors = ['tomato', 'aqua']

In [11]:
x = torch.zeros(num_demos, t_steps, dx, device=device)
y = torch.zeros(num_demos, t_steps, dy, device=device)
vx = torch.zeros(num_val, t_steps, dx, device=device)
vy = torch.zeros(num_val, t_steps, dy, device=device)

ind = torch.randperm(len(full_obs))
vind = torch.cat((torch.randint(0, num_indiv, (num_val_indiv, 1)), torch.randint(num_indiv, num_demos, (num_val_indiv, 1))), dim=0)
tr_ctr, val_ctr = 0, 0

for i in range(len(full_obs)):
    if i in vind:
        # vx[val_ctr] = torch.tensor(processed_obs[i], dtype=torch.float32)
        vx[val_ctr] = torch.linspace(0, 1, t_steps, dtype=torch.float32).view_as(vx[val_ctr])
        vy[val_ctr] = torch.tensor(processed_act[i], dtype=torch.float32)
        val_ctr += 1
    else:
        # x[tr_ctr] = torch.tensor(processed_obs[i], dtype=torch.float32)
        x[tr_ctr] = torch.linspace(0, 1, t_steps, dtype=torch.float32).view_as(x[tr_ctr])
        y[tr_ctr] = torch.tensor(processed_act[i], dtype=torch.float32)
        tr_ctr += 1

print("X:", x.shape, "Y:", y.shape, "VX:", vx.shape, "VY:", vy.shape)

X: torch.Size([16, 45, 1]) Y: torch.Size([16, 45, 56]) VX: torch.Size([4, 45, 1]) VY: torch.Size([4, 45, 56])


In [12]:
def get_batch(x, y, traj_ids, device=device):
    n_o = torch.randint(1, n_max_obs, (1,)).item()
    n_t = torch.randint(1, n_max_tar, (1,)).item()
    
    tar = torch.zeros(batch_size, n_t, dx, device=device)
    tar_val = torch.zeros(batch_size, n_t, dy, device=device)
    obs = torch.zeros(batch_size, n_o, dx+dy, device=device)

    for i in range(len(traj_ids)):
        random_query_ids = torch.randperm(t_steps)
        
        o_ids = random_query_ids[:n_o]
        t_ids = random_query_ids[n_o:n_o+n_t]

        obs[i, :, :] = torch.cat((x[traj_ids[i], o_ids], y[traj_ids[i], o_ids]), dim=-1)
        tar[i, :, :] = x[traj_ids[i], t_ids]
        tar_val[i, :, :] = y[traj_ids[i], t_ids]

    return obs, tar, tar_val

def get_validation_batch(vx, vy, traj_ids, device=device):
    num_obs = torch.randint(1, n_max_obs, (1,)).item()

    obs = torch.zeros(batch_size, num_obs, dx+dy, device=device)
    tar = torch.zeros(batch_size, t_steps, dx, device=device)
    tar_val = torch.zeros(batch_size, t_steps, dy, device=device)

    for i in range(len(traj_ids)):
        random_query_ids = torch.randperm(t_steps)
        o_ids = random_query_ids[:num_obs]

        obs[i, :, :] = torch.cat((vx[traj_ids[i], o_ids], vy[traj_ids[i], o_ids]), dim=-1)
        tar[i, :, :] = vx[traj_ids[i]]
        tar_val[i, :, :] = vy[traj_ids[i]]

    return obs, tar, tar_val

In [8]:
model_wta_ = WTA_CNP(dx, dy, n_max_obs, n_max_tar, [1024, 1024, 1024], num_decoders=2, decoder_hidden_dims=[512, 512, 512], batch_size=batch_size, scale_coefs=True).to(device)
optimizer_wta = torch.optim.Adam(lr=1e-4, params=model_wta_.parameters())

if torch.__version__ >= "2.0":
    model_wta = torch.compile(model_wta_)

In [9]:
import time
import os

timestamp = int(time.time())
root_folder = f'outputs/experimental/{dy}D/{str(timestamp)}/'

if not os.path.exists(root_folder):
    os.makedirs(root_folder)

if not os.path.exists(f'{root_folder}saved_models/'):
    os.makedirs(f'{root_folder}saved_models/')

# if not os.path.exists(f'{root_folder}img/'):
#     os.makedirs(f'{root_folder}img/')

torch.save(y, f'{root_folder}y.pt')


epochs = 5_000_000
epoch_iter = num_demos//batch_size  # number of batches per epoch (e.g. 100//32 = 3)
v_epoch_iter = num_val//batch_size  # number of batches per validation (e.g. 100//32 = 3)
avg_loss_wta = 0

val_per_epoch = 1000
min_val_loss_wta = 1000000

mse_loss = torch.nn.MSELoss()

training_loss_wta, validation_error_wta = [], []

wta_tr_loss_path = f'{root_folder}wta_training_loss.pt'
wta_val_err_path = f'{root_folder}wta_validation_error.pt'

for epoch in range(epochs):
    epoch_loss_wta = 0

    # traj_ids = torch.randperm(x.shape[0])[:batch_size*epoch_iter].chunk(epoch_iter)  # [:batch_size*epoch_iter] because nof_trajectories may be indivisible by batch_size
    traj_ids, v_traj_ids = [], []
    inds = torch.randperm(num_indiv)
    vinds = torch.randperm(num_val_indiv)
    for i in inds:
        traj_ids.append([inds[i], num_demos-inds[i]-1])

    for i in vinds:
        v_traj_ids.append([vinds[i], num_val-vinds[i]-1])

    for i in range(epoch_iter):
        optimizer_wta.zero_grad()

        obs_wta, tar_x_wta, tar_y_wta = get_batch(x, y, traj_ids[i], device)
        pred_wta, gate_wta = model_wta(obs_wta, tar_x_wta)
        loss_wta, wta_nll = model_wta.loss(pred_wta, gate_wta, tar_y_wta)
        loss_wta.backward()
        optimizer_wta.step()

        epoch_loss_wta += wta_nll.item()

    training_loss_wta.append(epoch_loss_wta)

    if epoch % val_per_epoch == 0:
        with torch.no_grad():
            # v_traj_ids = torch.randperm(vx.shape[0])[:batch_size*v_epoch_iter].chunk(v_epoch_iter)
            val_loss_wta = 0

            for j in range(v_epoch_iter):
                o_wta, t_wta, tr_wta = get_validation_batch(vx, vy, v_traj_ids[j], device=device)

                p_wta, g_wta = model_wta(o_wta, t_wta)
                dec_id = torch.argmax(g_wta.squeeze(1), dim=-1)
                vp_means = p_wta[dec_id, torch.arange(batch_size), :, :dy]
                val_loss_wta += mse_loss(vp_means, tr_wta).item()

            validation_error_wta.append(val_loss_wta)
            if val_loss_wta < min_val_loss_wta:
                min_val_loss_wta = val_loss_wta
                print(f'(WTA)New best: {min_val_loss_wta}')
                torch.save(model_wta_.state_dict(), f'{root_folder}saved_models/wta_on_synth.pt')
  
        # if epoch % (val_per_epoch*10) == 0:
        #     draw_val_plot(root_folder, epoch)


    avg_loss_wta += epoch_loss_wta

    if epoch % val_per_epoch == 0:
        print("Epoch: {}, WTA-Loss: {}".format(epoch, avg_loss_wta/val_per_epoch))
        avg_loss_wta = 0

torch.save(torch.Tensor(training_loss_wta), wta_tr_loss_path)
torch.save(torch.Tensor(validation_error_wta), wta_val_err_path)

(WTA)New best: 0.7975544929504395
Epoch: 0, WTA-Loss: 0.008055461645126343
(WTA)New best: 0.061571910977363586
Epoch: 1000, WTA-Loss: -3.4550922681065277
(WTA)New best: 0.05285011604428291
Epoch: 2000, WTA-Loss: -6.927817085117102
(WTA)New best: 0.05199756287038326
Epoch: 3000, WTA-Loss: -8.845476897746325
Epoch: 4000, WTA-Loss: -10.397426395207644
Epoch: 5000, WTA-Loss: -11.668039237648248
Epoch: 6000, WTA-Loss: -12.753860163927078
Epoch: 7000, WTA-Loss: -13.696185856878758
Epoch: 8000, WTA-Loss: -14.404736883461476
Epoch: 9000, WTA-Loss: -15.055532764852048
Epoch: 10000, WTA-Loss: -15.657686552882195
Epoch: 11000, WTA-Loss: -16.117637781977653
Epoch: 12000, WTA-Loss: -16.529323966920376
Epoch: 13000, WTA-Loss: -16.91634295386076
Epoch: 14000, WTA-Loss: -17.29427516967058
Epoch: 15000, WTA-Loss: -17.576423897475003
Epoch: 16000, WTA-Loss: -17.88244921001792
Epoch: 17000, WTA-Loss: -18.067207051753996
Epoch: 18000, WTA-Loss: -18.36862338787317
Epoch: 19000, WTA-Loss: -18.57625954842567

KeyboardInterrupt: 

In [16]:
import numpy as np

In [17]:
import torch
from models.wta_cnp import WTA_CNP

root_folder = f'outputs/experimental/56D/1701871167/'
wta_model_path = f'{root_folder}saved_models/wta_on_synth.pt'

y = torch.load(f'{root_folder}y.pt').cpu()
num_samples, t_steps, dy = y.shape
dx = 205
batch_size = 1
n_max_obs, n_max_tar = 6, 6

wta = WTA_CNP(dx, dy, n_max_obs, n_max_tar, [1024, 1024, 1024], num_decoders=2, decoder_hidden_dims=[512, 512, 512], batch_size=batch_size, scale_coefs=True).to(device)

wta.load_state_dict(torch.load(wta_model_path))
wta.eval()

WTA_CNP(
  (encoder): Sequential(
    (0): Linear(in_features=261, out_features=1024, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1024, out_features=1024, bias=True)
    (3): ReLU()
    (4): Linear(in_features=1024, out_features=1024, bias=True)
  )
  (decoders): ModuleList(
    (0-1): 2 x Sequential(
      (0): Linear(in_features=1229, out_features=512, bias=True)
      (1): ReLU()
      (2): Linear(in_features=512, out_features=512, bias=True)
      (3): ReLU()
      (4): Linear(in_features=512, out_features=112, bias=True)
    )
  )
  (gate): Sequential(
    (0): Linear(in_features=1024, out_features=2, bias=True)
    (1): Softmax(dim=-1)
  )
)

In [56]:
import time
from dm_control import viewer
from dm_control import composer
from dm_control.locomotion import arenas
from dm_control.locomotion.tasks import go_to_target
from dm_control.locomotion.walkers import cmu_humanoid
from dm_control_wrapper import StandInitializer
from dm_control.composer import ObservationPadding

seed = 0
dind = 13

def prepare_obs(obs, ind, a):
    vv = []
    for k in obs.keys():
        real_key = k.split('/')[1]
        if real_key in desired_observables:
            vals = obs[k].flatten()
            vv.extend([vals])
    if ind == 0:
        vv.extend([y[dind, ind].numpy()])
    else:
        vv.extend([a])
    v = np.concatenate(vv).reshape(-1)
    return torch.from_numpy(v).view(1, 1, dx+dy).float().to(device)


def prepare_tar(ind):
    return x[dind, ind].view(1, 1, dx).float().to(device)

initializer = StandInitializer()
walker = cmu_humanoid.CMUHumanoidPositionControlledV2020(initializer=initializer)

# Build an empty arena.
arena = arenas.Floor()

# Build a task that rewards the agent for tracking motion capture reference
# data.
task = go_to_target.GoToTarget(walker=walker, arena=arena, physics_timestep=0.005, control_timestep=0.03)
env = composer.Environment(task=task, random_state=seed)
# print(env.control_timestep())

ind = -1
inst_a = None

def tst(ts):
    global ind, inst_a
    ind += 1
    # dm_obs, dm_tar = prepare_obs(ts.observation, ind, inst_a), prepare_tar(ind+1)
    # p_wta, g_wta = wta(dm_obs, dm_tar)
    # # print(g_wta.squeeze(1))
    # inst_a = p_wta[torch.argmax(g_wta.squeeze(1), dim=-1), 0, 0, :dy].cpu().detach().numpy().squeeze()
    # return inst_a
    return y[dind, ind].numpy()
    # return full_act[dind][ind]

# Viewer for visualization
viewer.launch(env, policy=tst)

# Close the environment
env.close()

In [None]:
# dm_obs = prepare_obs(env.reset()[3], 0)
# dm_tar = prepare_tar(1)

# for i in range(1, t_steps):
#     # p_wta, g_wta = wta(dm_obs, dm_tar)
#     # a = p_wta[torch.argmax(g_wta.squeeze(1), dim=-1), 0, 0, :dy].cpu().detach().numpy().squeeze()
#     # print(a)
#     a = y[0, i].numpy()
#     s = env.step(a)
#     # dm_obs, dm_tar = prepare_obs(s.observation, i-1), prepare_tar(i)

numpy.ndarray

In [None]:
for k in f['observable_indices']['walker']:
    if 'reference' not in k:
        print(k, type(f['observable_indices']['walker'][k][()]))

actuator_activation <class 'numpy.ndarray'>
appendages_pos <class 'numpy.ndarray'>
body_height <class 'numpy.ndarray'>
end_effectors_pos <class 'numpy.ndarray'>
gyro_anticlockwise_spin <class 'numpy.ndarray'>
gyro_backward_roll <class 'numpy.ndarray'>
gyro_control <class 'numpy.ndarray'>
gyro_rightward_roll <class 'numpy.ndarray'>
head_height <class 'numpy.ndarray'>
joints_pos <class 'numpy.ndarray'>
joints_vel <class 'numpy.ndarray'>
joints_vel_control <class 'numpy.ndarray'>
orientation <class 'numpy.ndarray'>
position <class 'numpy.ndarray'>
sensors_accelerometer <class 'numpy.ndarray'>
sensors_gyro <class 'numpy.ndarray'>
sensors_torque <class 'numpy.ndarray'>
sensors_touch <class 'numpy.ndarray'>
sensors_velocimeter <class 'numpy.ndarray'>
time_in_clip <class 'numpy.ndarray'>
torso_xvel <class 'numpy.ndarray'>
torso_yvel <class 'numpy.ndarray'>
veloc_forward <class 'numpy.ndarray'>
veloc_strafe <class 'numpy.ndarray'>
veloc_up <class 'numpy.ndarray'>
velocimeter_control <class 'nu

In [None]:
for k in s[3].keys():
    print(k.replace('/', '.').split('.')[-1], s[3][k].shape)

actuator_activation (1, 56)
appendages_pos (1, 15)
body_height (1,)
end_effectors_pos (1, 12)
joints_pos (1, 56)
joints_vel (1, 56)
sensors_accelerometer (1, 3)
sensors_force (1, 0)
sensors_gyro (1, 3)
sensors_torque (1, 6)
sensors_touch (1, 10)
sensors_velocimeter (1, 3)
world_zaxis (1, 3)
target (1, 3)


In [None]:
# def transform_data(data):
#     num_dimensions = data.shape[2]

#     transformation_matrix = torch.zeros((num_dimensions, 2))
#     transformed_data = torch.zeros_like(data)

#     # Apply transformations to each dimension
#     for i in range(num_dimensions):
#         dim_data = data[:, :, i]

#         min_val = dim_data.min()
#         max_val = dim_data.max()

#         transformation_matrix[i, 0] = min_val
#         transformation_matrix[i, 1] = max_val

#         interval = max_val - min_val
#         if interval < 1e-6:
#             interval = 1

#         transformed_dim = 2 * (dim_data - min_val) / interval - 1
#         transformed_data[:, :, i] = transformed_dim

#     return transformed_data, transformation_matrix

# def reconstruct_data(transformed_data, transformation_matrix):
#     num_dimensions = transformed_data.shape[2]

#     reconstructed_data = torch.zeros_like(transformed_data)

#     for i in range(num_dimensions):
#         transformed_dim = transformed_data[:, :, i]
#         min_val, max_val = transformation_matrix[i, 0], transformation_matrix[i, 1]

#         reconstructed_dim = ((transformed_dim + 1) / 2) * (max_val - min_val) + min_val
#         reconstructed_data[:, :, i] = reconstructed_dim

#     return reconstructed_data

# y = data.clone().to(device)
# x = torch.unsqueeze(torch.linspace(0, 1, t_steps).repeat(num_demos, 1), -1).to(device)

# vx = x.clone()
# noise = torch.clamp(torch.randn(x.shape)*1e-4**0.5, min=0).to(device)
# vy = y.clone() + noise

# print("X:", x.shape, "Y:", y.shape, "VX:", vx.shape, "VY:", vy.shape)

X: torch.Size([1, 781, 1]) Y: torch.Size([1, 781, 62]) VX: torch.Size([1, 781, 1]) VY: torch.Size([1, 781, 62])
