In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D
import torch.optim as optim
import numpy as np

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print(device.type)

# ---

torch.manual_seed(42)
np.random.seed(42)

cpu


  return torch._C._cuda_getDeviceCount() > 0


In [2]:
data_path="../data/cnp_data/"

X, Y, G = np.load(f'{data_path}input/dx.npy'), np.load(f'{data_path}input/dy.npy'), np.load(f'{data_path}input/dg.npy')
vX, vY, vG = np.load(f'{data_path}input/vdx.npy'), np.load(f'{data_path}input/vdy.npy'), np.load(f'{data_path}input/vdg.npy')

In [3]:
n_max = 20  # n < n_max random points
d_size, t_size = X.shape[0], X.shape[1]

dx, dy, dg = X.shape[-1], Y.shape[-1], G.shape[-1]
vd_size = vX.shape[0]

enc_layers = [256, 256, 256]
dec_layers = [256, 256, 256, dy*2]


def sample_training_demonstration():
    rand_traj_ind = np.random.randint(0, d_size)
    n = np.random.randint(1, n_max+1)

    observation_indices = np.random.choice(np.arange(t_size), n+1, replace=False) # n+1: +1 is for sampling the target
    
    observations = torch.from_numpy(np.concatenate((X[rand_traj_ind, observation_indices[:-1], :], 
                                                    G[rand_traj_ind, observation_indices[:-1], :], 
                                                    Y[rand_traj_ind, observation_indices[:-1], :]), axis=1))

    targetXG = torch.unsqueeze(torch.from_numpy(np.concatenate((X[rand_traj_ind, observation_indices[-1], :], 
                                                    G[rand_traj_ind, observation_indices[-1], :]))), 0)
    targetY = torch.unsqueeze(torch.from_numpy(Y[rand_traj_ind, observation_indices[-1], :]), 0)
    if device.type == 'cuda':
        return observations.float().cuda(), targetXG.float().cuda(), targetY.float().cuda()
    else:
        return observations.float(), targetXG.float(), targetY.float()


def sample_validation_demonstration(i):    
    n = np.random.randint(1, n_max+1)
    observation_indices = np.random.choice(np.arange(t_size), n+1, replace=False) # n+1: same reason as above
    
    observations = torch.from_numpy(np.concatenate((vX[i, observation_indices[:-1], :], 
                                                    vG[i, observation_indices[:-1], :], 
                                                    vY[i, observation_indices[:-1], :]), axis=1))
    targetXG = torch.unsqueeze(torch.from_numpy(np.concatenate((vX[i, observation_indices[-1], :], 
                                                                vG[i, observation_indices[-1], :]))), 0)
    targetY = torch.unsqueeze(torch.from_numpy(vY[i, observation_indices[-1], :]), 0)
    
    if device.type == 'cuda':
        return observations.float().cuda(), targetXG.float().cuda(), targetY.float().cuda()
    else:
        return observations.float(), targetXG.float(), targetY.float()


In [4]:
class CNP(nn.Module):
    def __init__(self):
        super(CNP, self).__init__()        
        self.encoder = nn.Sequential(
            nn.Linear(dx+dy+dg,enc_layers[0]),
            nn.ReLU(),
            nn.Linear(enc_layers[0], enc_layers[1]),
            nn.ReLU(),
            nn.Linear(enc_layers[1], enc_layers[2])
        )
        
        self.query = nn.Sequential(
            nn.Linear(dec_layers[0]+dx+dg, dec_layers[1]),
            nn.ReLU(),
            nn.Linear(dec_layers[1], dec_layers[2]),
            nn.ReLU(),
            nn.Linear(dec_layers[2], dec_layers[3])
        )

    def forward(self, observations, target):
        r = self.encoder(observations)
        r_avg = torch.mean(r, dim=0)
        r_avgs = r_avg.repeat(target.shape[0], 1)  # repeating the same r_avg for each target
        r_avg_target = torch.cat((r_avgs, target), 1)
        query_out = self.query(r_avg_target)
        
        return query_out

    
def log_prob_loss(output, target):
    mean, sigma = output.chunk(2, dim = -1)
    sigma = F.softplus(sigma)
    dist = D.Independent(D.Normal(loc=mean, scale=sigma), 1)
    return -torch.mean(dist.log_prob(target))

# def initialize_weights(m):
#     if isinstance(m, nn.Conv2d):
#         nn.init.kaiming_uniform_(m.weight.data,nonlinearity='relu')
#         if m.bias is not None:
#             nn.init.constant_(m.bias.data, 0)
#     elif isinstance(m, nn.BatchNorm2d):
#         nn.init.constant_(m.weight.data, 1)
#         nn.init.constant_(m.bias.data, 0)
#     elif isinstance(m, nn.Linear):
#         nn.init.kaiming_uniform_(m.weight.data)
#         nn.init.constant_(m.bias.data, 0)


def validate():
    vloss=np.zeros(vd_size)
    with torch.no_grad():
        for i in range(vd_size):
            obss, tx, ty = sample_validation_demonstration(i)
            ty_pred = model(obss, tx)
            vloss[i] = log_prob_loss(ty_pred, ty)
            
    return np.mean(vloss)

In [5]:
from tqdm import tqdm

model = CNP()
model.to(device)

optimizer = torch.optim.Adam(lr=1e-4, params=model.parameters(), betas=(0.9, 0.999), amsgrad=True)

epoch = 2000000
val_after_epoch = 2000

losses = np.zeros(int(epoch/val_after_epoch))
min_loss = 1e6

for i in range(epoch):
    obss, tx, ty = sample_training_demonstration()
    
    optimizer.zero_grad()
    ty_pred = model(obss, tx)
    loss = log_prob_loss(ty, ty_pred)
    
    loss.backward()
    optimizer.step()
    
    print(i, end="\r")
    if i%val_after_epoch == 0:
        val_loss = validate()
        print(f"{i}: {val_loss}", end='')
    
        if val_loss < min_loss:
            min_loss = val_loss
            torch.save(model.state_dict(), f'{data_path}output/best_model.pt')
            print(' - val best')
        else:
            print()

0: 1.3618967533111572 - val best
2000: 1.5191482702891033
4000: 1.647318720817566
6000: 1.7100950876871746
8000: 1.6617879072825115
10000: 1.7195289532343547
12000: 1.8232763608296711
14000: 1.7101085583368938
16000: 1.8100399573644002
18000: 1.505852500597636
20000: 1.5977237224578857
22000: 1.5463158289591472
24000: 1.6218249003092449
26000: 1.6647044817606609
28000: 1.7568193276723225
30000: 1.7508166233698528
32000: 1.568538784980774
34000: 1.5901918013890584
36000: 1.7632366021474202
38000: 1.556598385175069
40000: 1.4904919862747192
42000: 1.994827111562093
44000: 1.8233261108398438
46000: 1.7381186087926228
48000: 1.3869424263636272
50000: 1.6322482029596965
52000: 1.3939286470413208
54000: 1.5958702961603801
56000: 1.488824446996053
58000: 1.7050193548202515
60000: 1.4263322750727336
62000: 1.4433401028315227
64000: 1.465017278989156
66000: 1.5840741793314617
68000: 1.548019568125407
70000: 1.4067023197809856
72000: 1.8477022250493367
74000: 1.4783775806427002
76000: 1.89800016

612000: 1.3860507011413574
614000: 1.7432927290598552
616000: 1.9947329759597778
618000: 1.7092089255650837
620000: 1.8033054669698079
622000: 1.6613590717315674
624000: 1.4436704715092976
626000: 1.5980887810389202
628000: 1.7368818124135335
630000: 1.6273386478424072
632000: 1.5199368000030518
634000: 1.6805615027745564
636000: 1.5206234852472942
638000: 1.9154207706451416
640000: 1.3531436522801716
642000: 1.3579129775365193
644000: 1.5619808038075764
646000: 1.7708371082941692
648000: 1.772602637608846
650000: 1.8426658312479656
652000: 1.8456332286198933
654000: 1.7074519395828247
656000: 1.5330214500427246
658000: 1.432846983273824
660000: 1.3760617971420288
662000: 1.5639921029408772
664000: 1.4297815163930256
666000: 1.7094218333562214
668000: 1.3956358035405476
670000: 1.581854025522868
672000: 1.4177109003067017
674000: 1.9554882049560547
676000: 1.5172642469406128
678000: 1.6107936302820842
680000: 1.7620027860005696
682000: 1.3515175580978394
684000: 1.9613429705301921
6860

1216000: 1.550467610359192
1218000: 1.8298861583073933
1220000: 1.6645521720250447
1222000: 1.5452589988708496
1224000: 1.5709007183710735
1226000: 1.5330223242441814
1228000: 1.6650228103001912
1230000: 1.4093022346496582
1232000: 1.38447101910909
1234000: 1.580965240796407
1236000: 1.7680285374323528
1238000: 1.7897279659907024
1240000: 1.4403772751490276
1242000: 1.8207648992538452
1244000: 1.4418437480926514
1246000: 1.635948379834493
1248000: 1.4981165329615276
1250000: 1.5376842419306438
1252000: 1.5565301179885864
1254000: 1.6203471422195435
1256000: 1.5821536779403687
1258000: 1.482298771540324
1260000: 1.4556392431259155
1262000: 1.531586766242981
1264000: 1.7359424432118733
1266000: 1.5624825954437256
1268000: 1.5898463726043701
1270000: 1.611224373181661
1272000: 1.6005394856135051
1274000: 1.3352302710215251
1276000: 1.7049362659454346
1278000: 1.38390318552653
1280000: 1.4049312671025593
1282000: 1.7115575869878132
1284000: 1.5917864243189495
1286000: 1.6631762186686199
12

1806000: 1.7356325387954712
1808000: 1.4820435047149658
1810000: 1.4402831395467122
1812000: 1.4766218662261963
1814000: 1.7184959252675374
1816000: 1.5267955859502156
1818000: 1.3769879738489788
1820000: 1.4874823888142903
1822000: 1.5786739190419514
1824000: 1.6514935493469238
1826000: 1.5121560096740723
1828000: 1.658669392267863
1830000: 1.5212384462356567
1832000: 1.390324314435323
1834000: 1.3395167986551921
1836000: 1.6616651217142742
1838000: 1.7902676264444988
1840000: 1.3165669441223145
1842000: 1.5103785594304402
1844000: 1.394514520963033
1846000: 1.879367192586263
1848000: 1.549653172492981
1850000: 1.5990111629168193
1852000: 1.4749711354573567
1854000: 1.8156641324361165
1856000: 1.4659690062204997
1858000: 1.708293040593465
1860000: 1.698261022567749
1862000: 1.4943118492762248
1864000: 1.702469825744629
1866000: 1.619176944096883
1868000: 1.4644729296366374
1870000: 1.4953944683074951
1872000: 1.5345164934794109
1874000: 1.4719874858856201
1876000: 1.6012298266092937
1