In [7]:
import torch
import numpy as np

from models.cnp import CNP

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# device = torch.device("cpu")

In [8]:
x = torch.from_numpy(np.load('/home/yigit/projects/mbcnp/data/raw/dataverse_files/Sines_full/constrained_torques.npy'))[:50]
y = torch.from_numpy(np.load('/home/yigit/projects/mbcnp/data/raw/dataverse_files/Sines_full/measured_angles.npy'))[:50]
print(x.shape, y.shape)


torch.Size([50, 15000, 3]) torch.Size([50, 15000, 3])


In [9]:
batch_size = 8

model = CNP(3, 256, 3, 100, 100, 3, batch_size).to(device)
optimizer = torch.optim.Adam(lr=1e-4, params=model.parameters())

In [10]:
def get_batch(x, y, traj_ids):
    dx, dy = x.shape[-1], y.shape[-1]
    n = x.shape[1]
    n_t = torch.randint(1, model.n_max_tar, (1,)).item()
    n_o = torch.randint(1, model.n_max_obs, (1,)).item()

    obs = torch.zeros(batch_size, n_o, dx+dy).to(device)
    tar = torch.zeros(batch_size, n_t, dx).to(device)
    tar_val = torch.zeros(batch_size, n_t, dy).to(device)

    for i in range(len(traj_ids)):
        random_query_ids = torch.randperm(n)
        o_ids = random_query_ids[:n_o]
        t_ids = random_query_ids[n_o:n_o+n_t]

        # print(x.shape, traj_ids[i], o_ids, t_ids)

        obs[i, :, :] = torch.cat((x[traj_ids[i], o_ids], y[traj_ids[i], o_ids]), dim=-1)
        tar[i, :, :] = x[traj_ids[i], t_ids]
        tar_val[i, :, :] = y[traj_ids[i], t_ids]

    return obs, tar, tar_val

In [11]:
epochs = 250000
epoch_iter = 6
min_loss = 1000000
avg_loss = 0

for epoch in range(epochs):
    epoch_loss = 0

    traj_ids = torch.randperm(x.shape[0])[:batch_size*epoch_iter].chunk(epoch_iter)  # [:batch_size*epoch_iter] because nof_trajectories may be indivisible by batch_size

    for i in range(epoch_iter):
        optimizer.zero_grad()
        obs, tar_x, tar_y = get_batch(x, y, traj_ids[i])
        pred, encoded_rep = model(obs, tar_x)
        loss = model.loss(pred, tar_y)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    if epoch_loss < min_loss and epoch > 5e4:
        min_loss = epoch_loss
        torch.save(model.state_dict(), f'best_test_on_rrd.pt')

    avg_loss += epoch_loss

    if epoch % 100 == 0:
        print("Epoch: {}, Loss: {}".format(epoch, avg_loss/100))
        avg_loss = 0

Epoch: 0, Loss: 0.1699453377723694
Epoch: 100, Loss: 1.5677354704914614
Epoch: 200, Loss: -0.1693252184025914
Epoch: 300, Loss: -0.7814577378658578
Epoch: 400, Loss: -0.8889546933770179
Epoch: 500, Loss: -0.9440661490894854
Epoch: 600, Loss: -0.9624732788652182
Epoch: 700, Loss: -0.9981548550724983
Epoch: 800, Loss: -1.0088177287951112
Epoch: 900, Loss: -1.0381607383489608
Epoch: 1000, Loss: -1.03293911293149
Epoch: 1100, Loss: -1.042836830811575
Epoch: 1200, Loss: -1.0622142812050879
Epoch: 1300, Loss: -1.0799645626172423
Epoch: 1400, Loss: -1.0896420384943486
Epoch: 1500, Loss: -1.0958647190220654
Epoch: 1600, Loss: -1.1194428786123172
Epoch: 1700, Loss: -1.1391463088430465
Epoch: 1800, Loss: -1.1473143334686755
Epoch: 1900, Loss: -1.1483034303691237
Epoch: 2000, Loss: -1.1514849009900354
Epoch: 2100, Loss: -1.1745781982690096
Epoch: 2200, Loss: -1.1674170801788568
Epoch: 2300, Loss: -1.1780679830443115
Epoch: 2400, Loss: -1.1873683694563806
Epoch: 2500, Loss: -1.2005730479955674
Epo