In [1]:
from models.cnp import CNP
import torch

def get_available_gpu_with_most_memory():
    gpu_memory = []
    for i in range(torch.cuda.device_count()):
        torch.cuda.set_device(i)  # Switch to the GPU to accurately measure memory
        gpu_memory.append((i, torch.cuda.memory_stats()['reserved_bytes.all.current'] / (1024 ** 2)))

    gpu_memory.sort(key=lambda x: x[1], reverse=True)

    return gpu_memory[0][0]

if torch.cuda.is_available():
    available_gpu = get_available_gpu_with_most_memory()
    if available_gpu == 0:
        device = torch.device("cuda:0")
    else:
        device = torch.device(f"cuda:{available_gpu}")
else:
    device = torch.device("cpu")

print("Device :", device)

###

torch.set_float32_matmul_precision('high')

Device : cpu


  return torch._C._cuda_getDeviceCount() > 0


In [2]:
import numpy as np

def normalize_data(data):
    # Calculate the mean and standard deviation across all data points
    mean = np.mean(data, axis=(0, 1))
    std = np.std(data, axis=(0, 1))
    
    # Handle the case where std is zero (to avoid division by zero)
    std_replaced = np.where(std == 0, 1, std)
    
    # Normalize the data
    normalized_data = (data - mean) / std_replaced
    return normalized_data, mean, std_replaced


walk_heavy_actions, (walk_heavy_observations, _, _) = np.load("data/mocapact/awh.npy"), normalize_data(np.load("data/mocapact/owh.npy"))
run_circle_actions, (run_circle_observations, _, _) = np.load("data/mocapact/arc.npy"), normalize_data(np.load("data/mocapact/orc.npy"))

In [3]:
batch_size = 2
n_max_obs, n_max_tar = 32, 32

num_indiv, t_steps, dx = walk_heavy_observations.shape
_, _, dy = walk_heavy_actions.shape
num_indiv -= 6

num_val = 12
num_classes = 2
num_demos = num_indiv*num_classes

num_val_indiv = num_val//num_classes

colors = ['tomato', 'aqua']

In [4]:
x = torch.zeros(num_demos, t_steps, dx, device=device)
y = torch.zeros(num_demos, t_steps, dy, device=device)
vx = torch.zeros(num_val, t_steps, dx, device=device)
vy = torch.zeros(num_val, t_steps, dy, device=device)

vind = torch.randint(0, num_indiv+num_val_indiv, (num_val_indiv, 1))
tr_ctr, val_ctr = 0, 0
for cur_vind in vind:
    vx[val_ctr] = torch.from_numpy(walk_heavy_observations[cur_vind]).to(device)
    vx[val_ctr+1] = torch.from_numpy(run_circle_observations[cur_vind]).to(device)
    vy[val_ctr] = torch.from_numpy(walk_heavy_actions[cur_vind]).to(device)
    vy[val_ctr+1] = torch.from_numpy(run_circle_actions[cur_vind]).to(device)

    val_ctr += 2

i=0
while i*2 < num_val_indiv+num_indiv:
    if i in vind:
       pass
    else:
        x[tr_ctr] = torch.from_numpy(walk_heavy_observations[i]).to(device)
        y[tr_ctr] = torch.from_numpy(walk_heavy_actions[i]).to(device)
        x[tr_ctr+num_indiv] = torch.from_numpy(run_circle_observations[i]).to(device)
        y[tr_ctr+num_indiv] = torch.from_numpy(run_circle_actions[i]).to(device)
        tr_ctr += 1
    i += 1

print("X:", x.shape, "Y:", y.shape, "VX:", vx.shape, "VY:", vy.shape)

X: torch.Size([62, 208, 287]) Y: torch.Size([62, 208, 56]) VX: torch.Size([12, 208, 287]) VY: torch.Size([12, 208, 56])


In [5]:
def get_batch(x, y, traj_ids, device=device):
    n_o = torch.randint(1, n_max_obs, (1,)).item()
    n_t = torch.randint(1, n_max_tar, (1,)).item()
    
    tar = torch.zeros(batch_size, n_t, dx, device=device)
    tar_val = torch.zeros(batch_size, n_t, dy, device=device)
    obs = torch.zeros(batch_size, n_o, dx+dy, device=device)

    for i in range(len(traj_ids)):
        random_query_ids = torch.randperm(t_steps)
        
        o_ids = random_query_ids[:n_o]
        t_ids = random_query_ids[n_o:n_o+n_t]

        obs[i, :, :] = torch.cat((x[traj_ids[i], o_ids], y[traj_ids[i], o_ids]), dim=-1)
        tar[i, :, :] = x[traj_ids[i], t_ids]
        tar_val[i, :, :] = y[traj_ids[i], t_ids]

    return obs, tar, tar_val

def get_validation_batch(vx, vy, traj_ids, device=device):
    num_obs = torch.randint(1, n_max_obs, (1,)).item()

    obs = torch.zeros(batch_size, num_obs, dx+dy, device=device)
    tar = torch.zeros(batch_size, t_steps, dx, device=device)
    tar_val = torch.zeros(batch_size, t_steps, dy, device=device)

    for i in range(len(traj_ids)):
        random_query_ids = torch.randperm(t_steps)
        o_ids = random_query_ids[:num_obs]

        obs[i, :, :] = torch.cat((vx[traj_ids[i], o_ids], vy[traj_ids[i], o_ids]), dim=-1)
        tar[i, :, :] = vx[traj_ids[i]]
        tar_val[i, :, :] = vy[traj_ids[i]]

    return obs, tar, tar_val

In [6]:
model_cnp_ = CNP(input_dim=dx, hidden_dim=512, output_dim=dy, n_max_obs=n_max_obs, n_max_tar=n_max_tar, num_layers=4, batch_size=batch_size).to(device)
optimizer_cnp = torch.optim.Adam(lr=3e-5, params=model_cnp_.parameters())

def get_parameter_count(model):
    total_num = 0
    for param in model.parameters():
        total_num += param.shape.numel()
    return total_num

print("CNP:", get_parameter_count(model_cnp_))


if torch.__version__ >= "2.0":
    model_cnp = torch.compile(model_cnp_)

CNP: 1956464


In [7]:
import time
import os

timestamp = int(time.time())
root_folder = f'outputs/mocapact/{dy}D/{str(timestamp)}/'

if not os.path.exists(root_folder):
    os.makedirs(root_folder)

if not os.path.exists(f'{root_folder}saved_models/'):
    os.makedirs(f'{root_folder}saved_models/')

# if not os.path.exists(f'{root_folder}img/'):
#     os.makedirs(f'{root_folder}img/')

torch.save(y, f'{root_folder}y.pt')


epochs = 1_000_000
epoch_iter = num_demos//batch_size  # number of batches per epoch (e.g. 100//32 = 3)
v_epoch_iter = num_val//batch_size  # number of batches per validation (e.g. 100//32 = 3)
avg_loss_cnp = 0

val_per_epoch = 1000
min_val_loss_cnp = 1000000

mse_loss = torch.nn.MSELoss()

training_loss_cnp, validation_error_cnp = [], []

cnp_tr_loss_path = f'{root_folder}cnp_training_loss.pt'
cnp_val_err_path = f'{root_folder}cnp_validation_error.pt'

for epoch in range(epochs):
    epoch_loss_cnp = 0

    # traj_ids = torch.randperm(x.shape[0])[:batch_size*epoch_iter].chunk(epoch_iter)  # [:batch_size*epoch_iter] because nof_trajectories may be indivisible by batch_size
    traj_ids, v_traj_ids = [], []
    inds = torch.randperm(num_indiv)
    vinds = torch.randperm(num_val)[:num_val_indiv]

    for i in range(inds.shape[0]):
        first = inds[i] * torch.randint(1,3,(1,1)).item()  # randint changes the order if it returns 2. for input randomization
        second = num_demos-first-1
        traj_ids.append([first, second])

    for i in range(vinds.shape[0]):
        v_traj_ids.append([vinds[i], num_val-vinds[i]-1])

    for i in range(epoch_iter):
        optimizer_cnp.zero_grad()

        obs_cnp, tar_x_cnp, tar_y_cnp = get_batch(x, y, traj_ids[i], device)
        pred_cnp, _ = model_cnp(obs_cnp, tar_x_cnp)
        loss_cnp = model_cnp.loss(pred_cnp, tar_y_cnp)
        loss_cnp.backward()
        optimizer_cnp.step()

        epoch_loss_cnp += loss_cnp.item()

    training_loss_cnp.append(epoch_loss_cnp)

    if epoch % val_per_epoch == 0:
        with torch.no_grad():
            # v_traj_ids = torch.randperm(vx.shape[0])[:batch_size*v_epoch_iter].chunk(v_epoch_iter)
            val_loss_cnp = 0

            for j in range(v_epoch_iter):
                o_cnp, t_cnp, tr_cnp = get_validation_batch(vx, vy, v_traj_ids[j], device=device)

                p_cnp, _ = model_cnp(o_cnp, t_cnp)
                val_loss_cnp += mse_loss(p_cnp[:, :, :model_cnp.output_dim], tr_cnp)

            val_loss_cnp /= num_val
            validation_error_cnp.append(val_loss_cnp)
            print(f'Validation loss: {val_loss_cnp}')
            if val_loss_cnp < min_val_loss_cnp and epoch > 1e5:
                min_val_loss_cnp = val_loss_cnp
                print(f'New best: {min_val_loss_cnp}')
                torch.save(model_cnp_.state_dict(), f'{root_folder}saved_models/cnp_on_synth.pt')
  
        # if epoch % (val_per_epoch*10) == 0:
        #     draw_val_plot(root_folder, epoch)


    avg_loss_cnp += epoch_loss_cnp

    if epoch % val_per_epoch == 0:
        print("Epoch: {}, cnp-Loss: {}".format(epoch, avg_loss_cnp/val_per_epoch))
        avg_loss_cnp = 0

torch.save(torch.Tensor(training_loss_cnp), cnp_tr_loss_path)
torch.save(torch.Tensor(validation_error_cnp), cnp_val_err_path)

No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'
  torch.has_cuda,
  torch.has_cudnn,
  torch.has_mps,
  torch.has_mkldnn,


Validation loss: 0.20279599726200104
Epoch: 0, cnp-Loss: 0.03854569447040558
Validation loss: 0.010029243305325508
Epoch: 1000, cnp-Loss: 17.658447007417678
Validation loss: 0.009185771457850933
Epoch: 2000, cnp-Loss: 17.32927520042658
Validation loss: 0.009483774192631245
Epoch: 3000, cnp-Loss: 17.275837319433688
Validation loss: 0.009795695543289185
Epoch: 4000, cnp-Loss: 17.241886497676372
Validation loss: 0.01035606674849987
Epoch: 5000, cnp-Loss: 17.218094841182232
Validation loss: 0.010936410166323185
Epoch: 6000, cnp-Loss: 17.199477543771266
Validation loss: 0.011406578123569489
Epoch: 7000, cnp-Loss: 17.18525534069538




Validation loss: 0.012042392045259476
Epoch: 8000, cnp-Loss: 17.174116916954517
Validation loss: 0.012432021088898182
Epoch: 9000, cnp-Loss: 17.16534583592415
Validation loss: 0.012798115611076355
Epoch: 10000, cnp-Loss: 17.158465910375117
Validation loss: 0.013360819779336452
Epoch: 11000, cnp-Loss: 17.152914523780346
Validation loss: 0.013575556688010693
Epoch: 12000, cnp-Loss: 17.1484868516922
Validation loss: 0.01390207652002573
Epoch: 13000, cnp-Loss: 17.144867695569992
Validation loss: 0.014200146310031414
Epoch: 14000, cnp-Loss: 17.142008959412575
Validation loss: 0.014411796815693378
Epoch: 15000, cnp-Loss: 17.13975729238987
Validation loss: 0.014683328568935394
Epoch: 16000, cnp-Loss: 17.137812535464764
Validation loss: 0.01483162958174944
Epoch: 17000, cnp-Loss: 17.13644407939911
Validation loss: 0.015044915489852428
Epoch: 18000, cnp-Loss: 17.135176597595215
Validation loss: 0.01521748024970293
Epoch: 19000, cnp-Loss: 17.134154054522515
Validation loss: 0.015250771306455135


KeyboardInterrupt: 