In [1]:
from models.cnp import CNP
from data.data_generators import *
import torch

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# device = torch.device("cpu")

In [2]:
# Data generation
import matplotlib.pyplot as plt

dx = 1
t_steps = 200
num_train_per_class, num_val_per_class = 32, 8
num_classes = 4

x = torch.linspace(0, 1, t_steps).view(-1, 1)

generator_functions = [generate_sin, generate_cos, generate_cx_sigm, 
    generate_reverse_cx_sigm]
y = []
for i in range(num_train_per_class + num_val_per_class):
    for j in range(num_classes):
        y.append(generator_functions[j](x))

colors = ["b", "r", "g", "y"]
# for i, y_i in enumerate(y):
#     plt.plot(y_i, alpha=0.5, c=colors[i%num_classes])

x = x.unsqueeze(0).repeat(len(y), 1, 1).to(device)
y = torch.stack(y, dim=0).to(device)

vx, vy = x[num_train_per_class*num_classes:], y[num_train_per_class*num_classes:]
x, y = x[:num_train_per_class*num_classes], y[:num_train_per_class*num_classes]

print(x.shape, y.shape)
print(vx.shape, vy.shape)

torch.Size([128, 200, 1]) torch.Size([128, 200, 1])
torch.Size([32, 200, 1]) torch.Size([32, 200, 1])


In [3]:
batch_size = 32

model = CNP(input_dim=1, hidden_dim=287, output_dim=1, n_max_obs=10, n_max_tar=10, num_layers=2, batch_size=batch_size).to(device)
optimizer = torch.optim.Adam(lr=1e-4, params=model.parameters())

In [4]:
def get_batch(x, y, traj_ids):
    dx, dy = x.shape[-1], y.shape[-1]
    n = x.shape[1]
    n_t = torch.randint(1, model.n_max_tar, (1,)).item()
    n_o = torch.randint(1, model.n_max_obs, (1,)).item()

    obs = torch.zeros(batch_size, n_o, dx+dy).to(device)
    tar = torch.zeros(batch_size, n_t, dx).to(device)
    tar_val = torch.zeros(batch_size, n_t, dy).to(device)

    for i in range(len(traj_ids)):
        random_query_ids = torch.randperm(n)
        o_ids = random_query_ids[:n_o]
        t_ids = random_query_ids[n_o:n_o+n_t]

        # print(x.shape, traj_ids[i], o_ids, t_ids)

        obs[i, :, :] = torch.cat((x[traj_ids[i], o_ids], y[traj_ids[i], o_ids]), dim=-1)
        tar[i, :, :] = x[traj_ids[i], t_ids]
        tar_val[i, :, :] = y[traj_ids[i], t_ids]

    return obs, tar, tar_val

In [5]:
def get_validation_batch(vx, vy, o_ids=[0, -1]):
    obs = torch.cat((vx[:, o_ids, :], vy[:, o_ids, :]), dim=-1)
    tar = vx[:, torch.arange(t_steps)]
    tar_val= vy[:, torch.arange(t_steps)]

    return obs, tar, tar_val

In [6]:
import time

file_name = int(time.time())

epochs = 5000000
epoch_iter = 4

avg_loss = 0

val_per_epoch = 1000
min_val_loss = 1000000

mse_loss = torch.nn.MSELoss()

for epoch in range(epochs):
    epoch_loss = 0

    traj_ids = torch.randperm(x.shape[0])[:batch_size*epoch_iter].chunk(epoch_iter)

    for i in range(epoch_iter):
        optimizer.zero_grad()
        obs, tar_x, tar_y = get_batch(x, y, traj_ids[i])
        pred, encoded_rep = model(obs, tar_x)
        loss = model.loss(pred, tar_y)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    if epoch % val_per_epoch == 0:
        with torch.no_grad():
            obs, tar_x, tar_y = get_validation_batch(vx, vy)
            pred, encoded_rep = model(obs, tar_x)
            val_loss = mse_loss(pred[:, :, :model.output_dim], tar_y)
            if val_loss < min_val_loss:
                min_val_loss = val_loss
                print(f'New best: {min_val_loss}')
                torch.save(model.state_dict(), f'saved_models/cnp_synth_{file_name}.pt')

    avg_loss += epoch_loss

    if epoch % 100 == 0:
        print("Epoch: {}, Loss: {}".format(epoch, avg_loss/100))
        avg_loss = 0

New best: 0.27631792426109314
Epoch: 0, Loss: 0.03474356770515442
Epoch: 100, Loss: 1.4071188947558404
Epoch: 200, Loss: 1.2113347433507442
Epoch: 300, Loss: 0.9125959388911724
Epoch: 400, Loss: 0.46387615349842237
Epoch: 500, Loss: -0.011540079317055642
Epoch: 600, Loss: -0.44601027956523465
Epoch: 700, Loss: -0.8791002567857504
Epoch: 800, Loss: -1.2255767969461158
Epoch: 900, Loss: -1.6564016634738072
New best: 0.03659152612090111
Epoch: 1000, Loss: -1.7543905587866901
Epoch: 1100, Loss: -2.0628511653188615
Epoch: 1200, Loss: -2.2742507974943145
Epoch: 1300, Loss: -2.457329126242548
Epoch: 1400, Loss: -2.6690104281157256
Epoch: 1500, Loss: -2.8692777062254025
Epoch: 1600, Loss: -3.1266692296788094
Epoch: 1700, Loss: -3.2261343839392067
Epoch: 1800, Loss: -3.2712085209600628
Epoch: 1900, Loss: -3.500520045682788
New best: 0.024662518873810768
Epoch: 2000, Loss: -3.683808723752154
Epoch: 2100, Loss: -3.9435343439131976
Epoch: 2200, Loss: -4.085511664934456
Epoch: 2300, Loss: -4.086601

KeyboardInterrupt: 