In [1]:
import torch
import numpy as np

from models.wta_cnp import WTA_CNP
from data.data_generators import *

# Run on the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
# Hyperparameters
batch_size = 32
n_max_obs, n_max_tar = 10, 10  # limits on the number of observed / target points sampled per batch

t_steps = 200      # time steps per trajectory
num_demos = 100
num_classes = 4
# The data cell fills y in num_classes contiguous slices of num_indiv rows; an uneven
# split would leave trailing rows of y at zero and give x fewer rows than y.
assert num_demos % num_classes == 0, "num_demos must be divisible by num_classes"
num_indiv = num_demos // num_classes  # number of demos per class
noise_clip = 0.0
dx, dy = 1, 1  # input and output dimensionality

num_val = 32
num_val_indiv = num_val // num_classes

In [3]:
# Generating the data: num_classes groups of num_indiv noisy 1-D trajectories,
# all sampled on the same grid of t_steps points in [0, 1].
x = torch.linspace(0, 1, t_steps).repeat(num_indiv, 1)  # (num_indiv, t_steps)
y = torch.zeros(num_demos, t_steps, dy)

generator_functions = [generate_sin, generate_cos, generate_cx_sigm, generate_reverse_cx_sigm]

for i in range(num_classes):
    # Gaussian noise with std 1e-2 (1e-4 ** 0.5), clamped at 0 and then shifted by -noise_clip.
    # NOTE(review): clamping with min=0 discards all negative noise, so the noise is one-sided —
    # confirm this is intended rather than a symmetric clip to [-noise_clip, noise_clip].
    noise = torch.clamp(torch.randn(x.shape)*1e-4**0.5, min=0) - noise_clip
    # assumes each generator returns a tensor shaped like x — TODO confirm in data.data_generators
    y[i*num_indiv:(i+1)*num_indiv] = torch.unsqueeze(generator_functions[i](x) + noise, 2)

x = torch.unsqueeze(x.repeat(num_classes, 1), 2)  # (num_classes*num_indiv, t_steps, 1); since dx = 1
print("X:", x.shape, "Y:", y.shape)
x, y = x.to(device), y.to(device)

X: torch.Size([100, 200, 1]) Y: torch.Size([100, 200, 1])


In [5]:
def get_batch(x, y, traj_ids):
    """Sample one random observation/target split for a batch of trajectories.

    A single number of observation points ``n_o`` (uniform in ``[1, n_max_obs]``) and
    target points ``n_t`` (uniform in ``[1, n_max_tar]``), both inclusive, is drawn per
    call and shared by every trajectory in the batch. Within a trajectory the observation
    and target time indices are disjoint, drawn from one random permutation.

    Args:
        x: inputs indexed as ``x[traj, t]`` — assumes shape (num_demos, t_steps, dx).
        y: outputs indexed the same way — assumes shape (num_demos, t_steps, dy).
        traj_ids: indices of the trajectories that form the batch (at most ``batch_size``).

    Returns:
        obs: (batch_size, n_o, dx+dy) concatenated (x, y) observation pairs.
        tar: (batch_size, n_t, dx) target inputs.
        tar_val: (batch_size, n_t, dy) target outputs.
        The tensors are always allocated with the global ``batch_size`` rows; rows past
        ``len(traj_ids)`` are left as zeros.

    Raises:
        ValueError: if ``n_o + n_t`` exceeds the trajectory length.
    """
    # torch.randint's upper bound is exclusive; +1 makes n_max_tar / n_max_obs reachable.
    n_t = torch.randint(1, n_max_tar + 1, (1,)).item()
    n_o = torch.randint(1, n_max_obs + 1, (1,)).item()

    num_steps = x.shape[1]  # trajectory length of the data actually passed in
    if n_o + n_t > num_steps:
        raise ValueError(
            f"cannot draw {n_o} observations and {n_t} disjoint targets from {num_steps} time steps"
        )

    obs = torch.zeros(batch_size, n_o, dx+dy, device=device)
    tar = torch.zeros(batch_size, n_t, dx, device=device)
    tar_val = torch.zeros(batch_size, n_t, dy, device=device)

    for row, traj in enumerate(traj_ids):
        # One permutation per trajectory: its head gives the observations and the next n_t
        # entries give the targets, so the two index sets never overlap.
        perm = torch.randperm(num_steps)
        o_ids = perm[:n_o]
        t_ids = perm[n_o:n_o+n_t]

        obs[row] = torch.cat((x[traj, o_ids], y[traj, o_ids]), dim=-1)
        tar[row] = x[traj, t_ids]
        tar_val[row] = y[traj, t_ids]

    return obs, tar, tar_val

Obs: torch.Size([32, 6, 2]) Tar: torch.Size([32, 9, 1]) Tar_val: torch.Size([32, 9, 1])


(tensor([[[2.6633e-01, 7.5395e-01],
          [4.1709e-01, 9.8171e-01],
          [1.0050e-02, 4.0944e-02],
          [7.3869e-01, 7.3177e-01],
          [7.8894e-01, 6.1552e-01],
          [5.4774e-01, 9.8894e-01]],
 
         [[1.5578e-01, 4.7009e-01],
          [8.4925e-01, 4.5610e-01],
          [7.5879e-01, 7.2161e-01],
          [9.3467e-01, 2.0539e-01],
          [2.4121e-01, 6.9262e-01],
          [8.5427e-01, 4.6558e-01]],
 
         [[4.0201e-02, 1.2596e-01],
          [9.3970e-01, 1.9953e-01],
          [3.2663e-01, 8.8306e-01],
          [9.6482e-01, 1.3594e-01],
          [2.6633e-01, 7.5582e-01],
          [7.5377e-02, 2.5219e-01]],
 
         [[9.4472e-01, 1.8309e-01],
          [7.0352e-01, 8.1556e-01],
          [1.1055e-01, 3.7386e-01],
          [1.1558e-01, 3.5517e-01],
          [3.2161e-01, 8.6764e-01],
          [6.1307e-01, 9.5219e-01]],
 
         [[7.5377e-01, 6.9868e-01],
          [1.4070e-01, 4.2778e-01],
          [5.6784e-01, 9.9268e-01],
          [4.371

In [2]:
# Build the WTA-CNP on the selected device and optimise all of its parameters with Adam (lr = 1e-4).
model = WTA_CNP(batch_size=batch_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [None]:
epochs = 500000
epoch_iter = num_demos//batch_size  # number of batches per epoch (e.g. 100//32 = 3)
assert epoch_iter > 0, "num_demos must be at least batch_size to form one full batch"
min_loss = 1000000
avg_loss = 0

log_every = 100        # epochs between loss reports
save_after = 50000     # no checkpointing until this many epochs have passed
epochs_since_log = 0   # number of epochs currently summed into avg_loss

for epoch in range(epochs):
    epoch_loss = 0  # sum of the per-batch losses over this epoch

    # Shuffle the trajectories and split them into epoch_iter equal batches. The slice drops
    # the remainder because the number of trajectories may not be divisible by batch_size.
    traj_ids = torch.randperm(x.shape[0])[:batch_size*epoch_iter].chunk(epoch_iter)

    for batch_ids in traj_ids:
        optimizer.zero_grad()
        obs, tar_x, tar_y = get_batch(x, y, batch_ids)
        pred, encoded_rep = model(obs, tar_x)
        loss = model.loss(pred, tar_y)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # Keep the checkpoint with the lowest epoch loss seen after the warm-up period.
    if epoch_loss < min_loss and epoch > save_after:
        min_loss = epoch_loss
        torch.save(model.state_dict(), 'best_test_on_rrd.pt')

    avg_loss += epoch_loss
    epochs_since_log += 1

    if epoch % log_every == 0:
        # Divide by the epochs actually accumulated: only 1 at epoch 0, log_every afterwards.
        print("Epoch: {}, Loss: {}".format(epoch, avg_loss/epochs_since_log))
        avg_loss = 0
        epochs_since_log = 0

In [5]:
# Smoke test: one random observation (last dim 2, presumably dx+dy) and one random target
# input (last dim 1, presumably dx) passed through the model.
o = torch.rand(1, 1, 2, device=device)
t = torch.rand(1, 1, 1, device=device)
p = model(o, t)
# NOTE(review): the training cell unpacks model(...) as (pred, encoded_rep); if that is the
# current return type, p is a tuple and p.shape fails — confirm against WTA_CNP.forward.
p.shape

torch.Size([4, 32, 1, 2])