In [11]:
from models.cnp import CNP
from models.wta_cnp import WTA_CNP

from data.data_generators import *
import torch


def get_available_gpu_with_most_memory():
    gpu_memory = []
    for i in range(torch.cuda.device_count()):
        torch.cuda.set_device(i)  # Switch to the GPU to accurately measure memory
        gpu_memory.append((i, torch.cuda.memory_stats()['reserved_bytes.all.current'] / (1024 ** 2)))
    
    gpu_memory.sort(key=lambda x: x[1], reverse=True)
    
    return gpu_memory[0][0]

if torch.cuda.is_available():
    available_gpu = get_available_gpu_with_most_memory()
    if available_gpu == 0:
        device_wta = torch.device("cuda:0")
        device_cnp = torch.device("cuda:0")
    else:
        device_wta = torch.device(f"cuda:{available_gpu}")
        device_cnp = torch.device(f"cuda:{available_gpu}")
else:
    device_wta = torch.device("cpu")
    device_cnp = torch.device("cpu")

print("Device WTA:", device_wta, "Device CNP:", device_cnp)

###

torch.set_float32_matmul_precision('high')

Device WTA: cpu Device CNP: cpu


In [12]:
batch_size = 2
n_max_obs, n_max_tar = 6, 6

t_steps = 200
num_demos = 2
num_classes = 2
num_indiv = num_demos//num_classes  # number of demos per class
noise_clip = 0.0
dx, dy = 1, 7

num_val = 2
num_val_indiv = num_val//num_classes

colors = ['tomato', 'aqua']

In [13]:
import csv

root_path = 'data/baxter/start_end/'
files = ['demo_2023-10-13_10-33-32.csv', 'demo_2023-10-13_10-34-28.csv']

raw_data = []

for i in range(len(files)):
    temp_data = []
    f_rel_path = root_path + files[i]
    with open(f_rel_path, 'r') as f:
        for i, line in enumerate(csv.reader(f)):
            if i > 0:
                # temp_data.append([float(line[-7]), float(line[-6]), float(line[-5]), float(line[-4]), float(line[-3]), float(line[-2]), float(line[-1])])  # joint space
                temp_data.append([float(line[3]), float(line[4]), float(line[5]), float(line[6]), float(line[7]), float(line[8]), float(line[9])])  # cartesian space

    raw_data.append(temp_data)

demonstration = []
for i in range(len(raw_data)):
    traj = raw_data[i]
    idx = torch.linspace(0, len(traj)-1, t_steps, dtype=int)
    processed_traj = []
    for ind in idx:
        processed_traj.append(traj[ind])

    demonstration.append(processed_traj)

y = torch.tensor(demonstration, dtype=torch.float32, device=device_wta)
x = torch.unsqueeze(torch.linspace(0, 1, t_steps).repeat(num_classes, 1), -1).to(device_wta)

vx = x.clone()
noise = torch.clamp(torch.randn(x.shape)*1e-4**0.5, min=0).to(device_wta)
vy = y.clone() + noise

print("X:", x.shape, "Y:", y.shape, "VX:", vx.shape, "VY:", vy.shape)

X: torch.Size([2, 200, 1]) Y: torch.Size([2, 200, 7]) VX: torch.Size([2, 200, 1]) VY: torch.Size([2, 200, 7])


In [14]:
from matplotlib import pyplot as plt

# fig, axs = plt.subplots((dy+1)//2, 2, figsize=(9, 8))
# for i in range(num_demos):
#     for j in range(dy):
#         ax0, ax1 = j//2, j%2
#         axs[ax0, ax1].plot(x[i, :, 0].cpu(), y[i, :, j].cpu(), label=f'Demonstration {i+1}')
#         axs[ax0, ax1].grid(True)
#         axs[ax0, ax1].legend()

#         axs[ax0, ax1].set_xlabel('Time (t)')
#         axs[ax0, ax1].set_ylabel('SM(t)')

x0, y0 = x.to(device_wta), y.to(device_wta)
x1, y1 = x.to(device_cnp), y.to(device_cnp)

In [15]:
def get_batch(x, y, traj_ids, device=device_wta):
    n_o = torch.randint(1, n_max_obs, (1,)).item()
    n_t = torch.randint(1, n_max_tar, (1,)).item()
    
    tar = torch.zeros(batch_size, n_t, dx, device=device)
    tar_val = torch.zeros(batch_size, n_t, dy, device=device)
    obs = torch.zeros(batch_size, n_o, dx+dy, device=device)

    for i in range(len(traj_ids)):
        random_query_ids = torch.randperm(t_steps)
        
        o_ids = random_query_ids[:n_o]
        t_ids = random_query_ids[n_o:n_o+n_t]

        obs[i, :, :] = torch.cat((x[traj_ids[i], o_ids], y[traj_ids[i], o_ids]), dim=-1)
        tar[i, :, :] = x[traj_ids[i], t_ids]
        tar_val[i, :, :] = y[traj_ids[i], t_ids]

    return obs, tar, tar_val

def get_validation_batch(vx, vy, traj_ids, device=device_wta):
    num_obs = torch.randint(1, n_max_obs, (1,)).item()

    obs = torch.zeros(batch_size, num_obs, dx+dy, device=device)
    tar = torch.zeros(batch_size, t_steps, dx, device=device)
    tar_val = torch.zeros(batch_size, t_steps, dy, device=device)

    for i in range(len(traj_ids)):
        random_query_ids = torch.randperm(t_steps)
        o_ids = random_query_ids[:num_obs]

        obs[i, :, :] = torch.cat((vx[traj_ids[i], o_ids], vy[traj_ids[i], o_ids]), dim=-1)
        tar[i, :, :] = vx[traj_ids[i]]
        tar_val[i, :, :] = vy[traj_ids[i]]

    return obs, tar, tar_val

In [16]:
model_wta_ = WTA_CNP(1, dy, n_max_obs, n_max_tar, [128, 128, 128], num_decoders=2, decoder_hidden_dims=[128, 128, 128], batch_size=batch_size, scale_coefs=True).to(device_wta)
optimizer_wta = torch.optim.Adam(lr=1e-4, params=model_wta_.parameters())

model_cnp_ = CNP(input_dim=1, hidden_dim=158, output_dim=dy, n_max_obs=n_max_obs, n_max_tar=n_max_tar, num_layers=3, batch_size=batch_size).to(device_cnp)
optimizer_cnp = torch.optim.Adam(lr=1e-4, params=model_cnp_.parameters())

# print("WTA Model:", model_wta)

In [17]:
def get_parameter_count(model):
    total_num = 0
    for param in model.parameters():
        total_num += param.shape.numel()
    return total_num

print("WTA-CNP:", get_parameter_count(model_wta_))
print("CNP:", get_parameter_count(model_cnp_))

if torch.__version__ >= "2.0":
    model_cnp, model_wta = torch.compile(model_cnp_), torch.compile(model_wta_)

WTA-CNP: 104350
CNP: 104294


In [18]:
from matplotlib.lines import Line2D


def draw_val_plot(root_folder, epoch):
    plt_y_lim = torch.max(vy) + 0.1

    obs = torch.zeros((model_wta.num_decoders, 1, 1, 2)).to(device_wta)
    for i in range(batch_size):
        obs[i] = torch.Tensor([x[i, 80, 0], y[i, 80, 0]]).unsqueeze(0).unsqueeze(0).to(device_wta)

    tar = torch.linspace(0, 1, 200).unsqueeze(0).unsqueeze(-1).to(device_wta)

    with torch.no_grad():
        for i in range(batch_size):
            pred_cnp, _ = model_cnp(obs[i], tar)
            pred_wta, gate = model_wta(obs[i], tar)

            plt.ylim((-plt_y_lim, plt_y_lim))
            plt.scatter(obs[i,:,:,0].cpu(), obs[i,:,:,1].cpu(), c='k')
            for j in range(batch_size):
                plt.plot(torch.linspace(0, 1, 200), pred_wta[j,0,:,0].cpu(), colors[j], alpha=max(0.2, gate[0, 0, j].item()))  # wta pred
            plt.plot(torch.linspace(0, 1, 200), pred_cnp[:, :, :model_cnp.output_dim].squeeze(0).cpu(), 'b')  # cnp pred
            handles = []
            for j in range(batch_size):
                plt.plot(torch.linspace(0, 1, 200), vy[j].squeeze(-1).cpu(), 'k', alpha=0.05 if j!=i else 0.35)  # data
                handles.append(Line2D([0], [0], label=f'gate{j}: {gate[0, 0, j].item():.4f}', color=colors[j]))

            plt.legend(handles=handles, loc='upper right')

            plt.savefig(f'{root_folder}img/{i}_{epoch}.png')
            plt.close()

In [19]:
import time
import os

timestamp = int(time.time())
root_folder = f'outputs/baxter/{dy}D/{str(timestamp)}/'

if not os.path.exists(root_folder):
    os.makedirs(root_folder)

if not os.path.exists(f'{root_folder}saved_models/'):
    os.makedirs(f'{root_folder}saved_models/')

if not os.path.exists(f'{root_folder}img/'):
    os.makedirs(f'{root_folder}img/')

torch.save(y, f'{root_folder}y.pt')


epochs = 1_000_000
epoch_iter = num_demos//batch_size  # number of batches per epoch (e.g. 100//32 = 3)
v_epoch_iter = num_val//batch_size  # number of batches per validation (e.g. 100//32 = 3)
avg_loss_wta, avg_loss_cnp = 0, 0

val_per_epoch = 1000
min_val_loss_wta, min_val_loss_cnp = 1000000, 1000000

mse_loss = torch.nn.MSELoss()

training_loss_wta, validation_error_wta = [], []
training_loss_cnp, validation_error_cnp = [], []

wta_tr_loss_path = f'{root_folder}wta_training_loss.pt'
wta_val_err_path = f'{root_folder}wta_validation_error.pt'
cnp_tr_loss_path = f'{root_folder}cnp_training_loss.pt'
cnp_val_err_path = f'{root_folder}cnp_validation_error.pt'

for epoch in range(epochs):
    epoch_loss_wta, epoch_loss_cnp = 0, 0

    traj_ids = torch.randperm(x.shape[0])[:batch_size*epoch_iter].chunk(epoch_iter)  # [:batch_size*epoch_iter] because nof_trajectories may be indivisible by batch_size

    for i in range(epoch_iter):
        optimizer_wta.zero_grad()
        optimizer_cnp.zero_grad()

        obs_wta, tar_x_wta, tar_y_wta = get_batch(x, y, traj_ids[i], device_wta)
        obs_cnp, tar_x_cnp, tar_y_cnp = obs_wta.clone(), tar_x_wta.clone(), tar_y_wta.clone()

        pred_wta, gate_wta = model_wta(obs_wta, tar_x_wta)
        pred_cnp, encoded_rep_cnp = model_cnp(obs_cnp, tar_x_cnp)

        loss_wta, wta_nll = model_wta.loss(pred_wta, gate_wta, tar_y_wta)

        loss_wta.backward()
        optimizer_wta.step()

        loss_cnp = model_cnp.loss(pred_cnp, tar_y_cnp)
        loss_cnp.backward()
        optimizer_cnp.step()

        epoch_loss_wta += wta_nll.item()
        epoch_loss_cnp += loss_cnp.item()

    training_loss_wta.append(epoch_loss_wta)
    training_loss_cnp.append(epoch_loss_cnp)

    if epoch % val_per_epoch == 0:
        with torch.no_grad():
            v_traj_ids = torch.randperm(vx.shape[0])[:batch_size*v_epoch_iter].chunk(v_epoch_iter)
            val_loss_wta, val_loss_cnp = 0, 0

            for j in range(v_epoch_iter):
                o_wta, t_wta, tr_wta = get_validation_batch(vx, vy, v_traj_ids[j], device=device_wta)
                o_cnp, t_cnp, tr_cnp = o_wta.clone(), t_wta.clone(), tr_wta.clone()

                p_wta, g_wta = model_wta(o_wta, t_wta)
                dec_id = torch.argmax(g_wta.squeeze(1), dim=-1)
                vp_means = p_wta[dec_id, torch.arange(batch_size), :, :dy]
                val_loss_wta += mse_loss(vp_means, tr_wta).item()

                pred_cnp, encoded_rep = model_cnp(o_cnp, t_cnp)
                val_loss_cnp += mse_loss(pred_cnp[:, :, :model_cnp.output_dim], tr_cnp)


            validation_error_wta.append(val_loss_wta)
            if val_loss_wta < min_val_loss_wta:
                min_val_loss_wta = val_loss_wta
                print(f'(WTA)New best: {min_val_loss_wta}')
                torch.save(model_wta.state_dict(), f'{root_folder}saved_models/wta_on_synth.pt')

            validation_error_cnp.append(val_loss_cnp.item())
            if val_loss_cnp < min_val_loss_cnp:
                min_val_loss_cnp = val_loss_cnp
                print(f'(CNP)New best: {min_val_loss_cnp}')
                torch.save(model_cnp.state_dict(), f'{root_folder}saved_models/cnp_on_synth.pt')
  
        # if epoch % (val_per_epoch*10) == 0:
        #     draw_val_plot(root_folder, epoch)


    avg_loss_wta += epoch_loss_wta
    avg_loss_cnp += epoch_loss_cnp

    if epoch % val_per_epoch == 0:
        print("Epoch: {}, WTA-Loss: {}, CNP-Loss: {}".format(epoch, avg_loss_wta/val_per_epoch, avg_loss_cnp/val_per_epoch))
        avg_loss_wta, avg_loss_cnp = 0, 0

torch.save(torch.Tensor(training_loss_wta), wta_tr_loss_path)
torch.save(torch.Tensor(validation_error_wta), wta_val_err_path)
torch.save(torch.Tensor(training_loss_cnp), cnp_tr_loss_path)
torch.save(torch.Tensor(validation_error_cnp), cnp_val_err_path)

(WTA)New best: 0.2329220175743103
(CNP)New best: 0.2546366751194
Epoch: 0, WTA-Loss: 0.0008105372190475464, CNP-Loss: 0.0008176359534263611
(WTA)New best: 0.007430635392665863
(CNP)New best: 0.007286987267434597
Epoch: 1000, WTA-Loss: -1.3400213189311325, CNP-Loss: -1.37302557128435
(WTA)New best: 0.00697508966550231
(CNP)New best: 0.005469901021569967
Epoch: 2000, WTA-Loss: -1.7129488697648048, CNP-Loss: -1.734872560441494
(WTA)New best: 0.005614872556179762
(CNP)New best: 0.0031276815570890903
Epoch: 3000, WTA-Loss: -1.7604880098104476, CNP-Loss: -1.902714949965477
(WTA)New best: 0.0035970790777355433
(CNP)New best: 0.0025927380193024874
Epoch: 4000, WTA-Loss: -1.8695622432231902, CNP-Loss: -2.0076179438829422
(WTA)New best: 0.0030390869360417128
(CNP)New best: 0.002050904557108879
Epoch: 5000, WTA-Loss: -2.0331111538410185, CNP-Loss: -2.073670869469643
(CNP)New best: 0.001745983143337071
Epoch: 6000, WTA-Loss: -2.0708330528736116, CNP-Loss: -2.141706053137779
(WTA)New best: 0.002518

In [None]:
open(f'{root_folder}fin', 'w').close()