In [1]:
from models.wta_cnp import WTA_CNP
import torch

def get_available_gpu_with_most_memory():
    gpu_memory = []
    for i in range(torch.cuda.device_count()):
        torch.cuda.set_device(i)  # Switch to the GPU to accurately measure memory
        gpu_memory.append((i, torch.cuda.memory_stats()['reserved_bytes.all.current'] / (1024 ** 2)))

    gpu_memory.sort(key=lambda x: x[1], reverse=True)

    return gpu_memory[0][0]

if torch.cuda.is_available():
    available_gpu = get_available_gpu_with_most_memory()
    if available_gpu == 0:
        device = torch.device("cuda:0")
    else:
        device = torch.device(f"cuda:{available_gpu}")
else:
    device = torch.device("cpu")

print("Device :", device)

###

torch.set_float32_matmul_precision('high')

Device : cpu


  return torch._C._cuda_getDeviceCount() > 0


In [2]:
import numpy as np

def normalize_data(data):
    # Calculate the mean and standard deviation across all data points
    mean = np.mean(data, axis=(0, 1))
    std = np.std(data, axis=(0, 1))
    
    # Handle the case where std is zero (to avoid division by zero)
    std_replaced = np.where(std == 0, 1, std)
    
    # Normalize the data
    normalized_data = (data - mean) / std_replaced
    return normalized_data, mean, std_replaced


walk_heavy_actions, (walk_heavy_observations, _, _) = np.load("data/mocapact/awh.npy"), normalize_data(np.load("data/mocapact/owh.npy"))
run_circle_actions, (run_circle_observations, _, _) = np.load("data/mocapact/arc.npy"), normalize_data(np.load("data/mocapact/orc.npy"))


In [3]:
batch_size = 2
n_max_obs, n_max_tar = 32, 32

num_indiv, t_steps, dx = walk_heavy_observations.shape
_, _, dy = walk_heavy_actions.shape
num_indiv -= 6

num_val = 12
num_classes = 2
num_demos = num_indiv*num_classes

num_val_indiv = num_val//num_classes

colors = ['tomato', 'aqua']

In [4]:
x = torch.zeros(num_demos, t_steps, dx, device=device)
y = torch.zeros(num_demos, t_steps, dy, device=device)
vx = torch.zeros(num_val, t_steps, dx, device=device)
vy = torch.zeros(num_val, t_steps, dy, device=device)

vind = torch.randint(0, num_indiv+num_val_indiv, (num_val_indiv, 1))
tr_ctr, val_ctr = 0, 0
for cur_vind in vind:
    vx[val_ctr] = torch.from_numpy(walk_heavy_observations[cur_vind]).to(device)
    vx[val_ctr+1] = torch.from_numpy(run_circle_observations[cur_vind]).to(device)
    vy[val_ctr] = torch.from_numpy(walk_heavy_actions[cur_vind]).to(device)
    vy[val_ctr+1] = torch.from_numpy(run_circle_actions[cur_vind]).to(device)

    val_ctr += 2

i=0
while i*2 < num_val_indiv+num_indiv:
    if i in vind:
       pass
    else:
        x[tr_ctr] = torch.from_numpy(walk_heavy_observations[i]).to(device)
        y[tr_ctr] = torch.from_numpy(walk_heavy_actions[i]).to(device)
        x[tr_ctr+num_indiv] = torch.from_numpy(run_circle_observations[i]).to(device)
        y[tr_ctr+num_indiv] = torch.from_numpy(run_circle_actions[i]).to(device)
        tr_ctr += 1
    i += 1

print("X:", x.shape, "Y:", y.shape, "VX:", vx.shape, "VY:", vy.shape)

X: torch.Size([62, 208, 287]) Y: torch.Size([62, 208, 56]) VX: torch.Size([12, 208, 287]) VY: torch.Size([12, 208, 56])


In [5]:
def get_batch(x, y, traj_ids, device=device):
    n_o = torch.randint(1, n_max_obs, (1,)).item()
    n_t = torch.randint(1, n_max_tar, (1,)).item()
    
    tar = torch.zeros(batch_size, n_t, dx, device=device)
    tar_val = torch.zeros(batch_size, n_t, dy, device=device)
    obs = torch.zeros(batch_size, n_o, dx+dy, device=device)

    for i in range(len(traj_ids)):
        random_query_ids = torch.randperm(t_steps)
        
        o_ids = random_query_ids[:n_o]
        t_ids = random_query_ids[n_o:n_o+n_t]

        obs[i, :, :] = torch.cat((x[traj_ids[i], o_ids], y[traj_ids[i], o_ids]), dim=-1)
        tar[i, :, :] = x[traj_ids[i], t_ids]
        tar_val[i, :, :] = y[traj_ids[i], t_ids]

    return obs, tar, tar_val

def get_validation_batch(vx, vy, traj_ids, device=device):
    num_obs = torch.randint(1, n_max_obs, (1,)).item()

    obs = torch.zeros(batch_size, num_obs, dx+dy, device=device)
    tar = torch.zeros(batch_size, t_steps, dx, device=device)
    tar_val = torch.zeros(batch_size, t_steps, dy, device=device)

    for i in range(len(traj_ids)):
        random_query_ids = torch.randperm(t_steps)
        o_ids = random_query_ids[:num_obs]

        obs[i, :, :] = torch.cat((vx[traj_ids[i], o_ids], vy[traj_ids[i], o_ids]), dim=-1)
        tar[i, :, :] = vx[traj_ids[i]]
        tar_val[i, :, :] = vy[traj_ids[i]]

    return obs, tar, tar_val

In [6]:
model_wta_ = WTA_CNP(dx, dy, n_max_obs, n_max_tar, [1024, 512, 256], num_decoders=2, decoder_hidden_dims=[360, 360, 360], batch_size=batch_size, scale_coefs=True).to(device)
optimizer_wta = torch.optim.Adam(lr=3e-5, params=model_wta_.parameters())

def get_parameter_count(model):
    total_num = 0
    for param in model.parameters():
        total_num += param.shape.numel()
    return total_num

print("WTA-CNP:", get_parameter_count(model_wta_))

if torch.__version__ >= "2.0":
    model_wta = torch.compile(model_wta_)

WTA-CNP: 1741362


In [7]:
import time
import os

timestamp = int(time.time())
root_folder = f'outputs/mocapact/{dy}D/{str(timestamp)}/'

if not os.path.exists(root_folder):
    os.makedirs(root_folder)

if not os.path.exists(f'{root_folder}saved_models/'):
    os.makedirs(f'{root_folder}saved_models/')

# if not os.path.exists(f'{root_folder}img/'):
#     os.makedirs(f'{root_folder}img/')

torch.save(y, f'{root_folder}y.pt')


epochs = 1_000_000
epoch_iter = num_demos//batch_size  # number of batches per epoch (e.g. 100//32 = 3)
v_epoch_iter = num_val//batch_size  # number of batches per validation (e.g. 100//32 = 3)
avg_loss_wta = 0

val_per_epoch = 1000
min_val_loss_wta = 1000000

mse_loss = torch.nn.MSELoss()

training_loss_wta, validation_error_wta = [], []

wta_tr_loss_path = f'{root_folder}wta_training_loss.pt'
wta_val_err_path = f'{root_folder}wta_validation_error.pt'

for epoch in range(epochs):
    epoch_loss_wta = 0

    # traj_ids = torch.randperm(x.shape[0])[:batch_size*epoch_iter].chunk(epoch_iter)  # [:batch_size*epoch_iter] because nof_trajectories may be indivisible by batch_size
    traj_ids, v_traj_ids = [], []
    inds = torch.randperm(num_indiv)
    vinds = torch.randperm(num_val)[:num_val_indiv]

    for i in range(inds.shape[0]):
        first = inds[i] * torch.randint(1,3,(1,1)).item()  # randint changes the order if it returns 2. for input randomization
        second = num_demos-first-1
        traj_ids.append([first, second])

    for i in range(vinds.shape[0]):
        v_traj_ids.append([vinds[i], num_val-vinds[i]-1])

    for i in range(epoch_iter):
        optimizer_wta.zero_grad()

        obs_wta, tar_x_wta, tar_y_wta = get_batch(x, y, traj_ids[i], device)
        pred_wta, gate_wta = model_wta(obs_wta, tar_x_wta)
        loss_wta, wta_nll = model_wta.loss(pred_wta, gate_wta, tar_y_wta)
        loss_wta.backward()
        optimizer_wta.step()

        epoch_loss_wta += wta_nll.item()

    training_loss_wta.append(epoch_loss_wta)

    if epoch % val_per_epoch == 0:
        with torch.no_grad():
            # v_traj_ids = torch.randperm(vx.shape[0])[:batch_size*v_epoch_iter].chunk(v_epoch_iter)
            val_loss_wta = 0

            for j in range(v_epoch_iter):
                o_wta, t_wta, tr_wta = get_validation_batch(vx, vy, v_traj_ids[j], device=device)

                p_wta, g_wta = model_wta(o_wta, t_wta)
                dec_id = torch.argmax(g_wta.squeeze(1), dim=-1)
                vp_means = p_wta[dec_id, torch.arange(batch_size), :, :dy]
                val_loss_wta += mse_loss(vp_means, tr_wta).item()

            val_loss_wta /= v_epoch_iter
            validation_error_wta.append(val_loss_wta)
            print(f'(WTA)Validation loss: {val_loss_wta}')
            if val_loss_wta < min_val_loss_wta and epoch > 1e5:
                min_val_loss_wta = val_loss_wta
                print(f'(WTA)New best: {min_val_loss_wta}')
                torch.save(model_wta_.state_dict(), f'{root_folder}saved_models/wta_on_synth.pt')
  
        # if epoch % (val_per_epoch*10) == 0:
        #     draw_val_plot(root_folder, epoch)


    avg_loss_wta += epoch_loss_wta

    if epoch % val_per_epoch == 0:
        print("Epoch: {}, WTA-Loss: {}".format(epoch, avg_loss_wta/val_per_epoch))
        avg_loss_wta = 0

torch.save(torch.Tensor(training_loss_wta), wta_tr_loss_path)
torch.save(torch.Tensor(validation_error_wta), wta_val_err_path)

No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'
  torch.has_cuda,
  torch.has_cudnn,
  torch.has_mps,
  torch.has_mkldnn,


(WTA)Validation loss: 0.4004123757282893
Epoch: 0, WTA-Loss: 0.03878955090045929
(WTA)Validation loss: 0.020542152225971222
Epoch: 1000, WTA-Loss: 17.698264050126078
(WTA)Validation loss: 0.019522906591494877
Epoch: 2000, WTA-Loss: 17.396631286382675
(WTA)Validation loss: 0.01881197684754928
Epoch: 3000, WTA-Loss: 17.35177454453707
(WTA)Validation loss: 0.018003832238415878
Epoch: 4000, WTA-Loss: 17.326031472444534




(WTA)Validation loss: 0.01806862683345874
Epoch: 5000, WTA-Loss: 17.30832500445843
(WTA)Validation loss: 0.017290004373838503
Epoch: 6000, WTA-Loss: 17.294547849714757
(WTA)Validation loss: 0.01782176100338499
Epoch: 7000, WTA-Loss: 17.28356212031841
(WTA)Validation loss: 0.020119949554403622
Epoch: 8000, WTA-Loss: 17.274070960342883
(WTA)Validation loss: 0.01832819435124596
Epoch: 9000, WTA-Loss: 17.265990211308
(WTA)Validation loss: 0.019862435292452574
Epoch: 10000, WTA-Loss: 17.258642950177194
(WTA)Validation loss: 0.02109148974219958
Epoch: 11000, WTA-Loss: 17.252478099286556
(WTA)Validation loss: 0.02088058413937688
Epoch: 12000, WTA-Loss: 17.2469499669075
(WTA)Validation loss: 0.02078925088668863
Epoch: 13000, WTA-Loss: 17.241587978124617
(WTA)Validation loss: 0.02217521828909715
Epoch: 14000, WTA-Loss: 17.237081035494803
(WTA)Validation loss: 0.02007465964804093
Epoch: 15000, WTA-Loss: 17.232904402375222
(WTA)Validation loss: 0.021700883905092876
Epoch: 16000, WTA-Loss: 17.2291

KeyboardInterrupt: 