In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D
import torch.optim as optim
import numpy as np

**Torch, Numpy stuff**

In [2]:
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print(device.type)

# ---

torch.manual_seed(42)
np.random.seed(42)

cuda


---
**Training Data**

Demonstration data contains a list of **d** scenes:

- ~~Each scene contains a list of trajectories of **p** people, where p is **not constant**~~
- A trajectory is a list of **t** states, where t = 400
- A state is a **s** = 4 dimensional variable
 - State = (d<sub>goal<sub>x</sub></sub>, d<sub>goal<sub>y</sub></sub>, v<sub>x</sub>, v<sub>y</sub>)
    
> Shape of data is (d, ~~p,~~ t, s) => (d, ~~p,~~ 400, 4)

---

Train, test, val split: 0.8, 0.01, 0.19

---

### Preparing the data

**n <= n<sub>max</sub>** random number of random observations on a random trajectory

**get image**

In [3]:
from PIL import Image
import torchvision.transforms as T

def get_frames(path, demonstration_id, observation_ids):
    frames_path = f'{path}{demonstration_id}/'
    transform = T.Compose([T.ToTensor()])
    frames = []
    for i in observation_ids:
        frames.append(transform(Image.open(f"{frames_path}{i}.jpg")))
                      
    frames = torch.stack(frames, 0)
    return frames

**get all input components**

In [4]:
n_max = 30
d_size = 9500  # nof demonstrations
t_size = 400  # length of trajectories
path="../data/processed/input/train/"  # .../input/train/ contains folders from 0 to 9499

def sample_training_demonstration():
#     d, t, s = X_train.shape
    
    rand_traj_ind = np.random.randint(0, d_size)
    n = np.random.randint(1, n_max+1)

    rand_traj = np.load(f"{path}{rand_traj_ind}/states.npy")

    observation_indices = np.random.choice(np.arange(t_size), n+1, replace=False) # n+1: +1 is for sampling the target
    
    frames = get_frames(path, rand_traj_ind, observation_indices[:-1])
    
    observations = torch.from_numpy(rand_traj[observation_indices[:-1], :])
    targetX = torch.unsqueeze(torch.from_numpy(rand_traj[observation_indices[-1], 0:2]), 0)
    targetY = torch.unsqueeze(torch.from_numpy(rand_traj[observation_indices[-1], 2:]), 0)
    if device.type == 'cuda':
        return frames.float().cuda(), observations.float().cuda(), targetX.float().cuda(), targetY.float().cuda()
    else:
        return frames.float(), observations.float(), targetX.float(), targetY.float()

The same for validation samples

In [5]:
dv_size = 490  # nof validation demonstrations
vpath = "../data/processed/input/val/"  # .../input/val/ contains folders from 9500 to 9989

def get_validation_demonstration(i):
    demostration_ind = i + d_size

    traj = np.load(f"{vpath}{demostration_ind}/states.npy")
    
    n = np.random.randint(1, n_max+1)
    observation_indices = np.random.choice(np.arange(t_size), n+1, replace=False) # n+1: +1 is for sampling the target
    
    frames = get_frames(vpath, demostration_ind, observation_indices[:-1])
    
    observations = torch.from_numpy(traj[observation_indices[:-1], :])
    targetX = torch.unsqueeze(torch.from_numpy(traj[observation_indices[-1], 0:2]), 0)
    targetY = torch.unsqueeze(torch.from_numpy(traj[observation_indices[-1], 2:]), 0)
    
    if device.type == 'cuda':
        return frames.float().cuda(), observations.float().cuda(), targetX.float().cuda(), targetY.float().cuda()
    else:
        return frames.float(), observations.float(), targetX.float(), targetY.float()

---
### Model

In [6]:
class CNP(nn.Module):
    def __init__(self):
        super(CNP, self).__init__()
        
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 8, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(8, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 32, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        
        self.encoder = nn.Sequential(
            nn.Linear(4+64*4*4,1024),
            nn.ReLU(),
            nn.Linear(1024,1024),
            nn.ReLU(),
            nn.Linear(1024,1024)
        )
        
        self.query = nn.Sequential(
            nn.Linear(1024+2,1024),
            nn.ReLU(),
            nn.Linear(1024,1024),
            nn.ReLU(),
            nn.Linear(1024,1024),
            nn.ReLU(),
            nn.Linear(1024,2*2)
        )

    def forward(self, frames, observations, target):
        # n < n_max frames of a scene along with momentary observations are concatenated to constitute input
        scene_encodings = torch.flatten(self.cnn(frames))
        scene_encodings = self.cnn(frames).view(frames.shape[0], 64*4*4)
        encoder_in = torch.cat((observations, scene_encodings), 1)
        r = self.encoder(encoder_in)
        
        r_avg = torch.mean(r, dim=0)
        r_avgs = r_avg.repeat(target.shape[0], 1)  # repeating the same r_avg for each target
        r_avg_target = torch.cat((r_avgs, target), 1)
        query_out = self.query(r_avg_target)
        
        return query_out

    
def log_prob_loss(output, target):
    mean, sigma = output.chunk(2, dim = -1)
    sigma = F.softplus(sigma)
    dist = D.Independent(D.Normal(loc=mean, scale=sigma), 1)
    return -torch.mean(dist.log_prob(target))

def initialize_weights(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_uniform_(m.weight.data,nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight.data, 1)
        nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight.data)
        nn.init.constant_(m.bias.data, 0)

### Training

#### 1- Methods and Variables for Validation 

In [7]:
def validate():
    vloss=np.zeros(dv_size)
    with torch.no_grad():
        for i in range(dv_size):
            fs, obss, tx, ty = get_validation_demonstration(i)
            ty_pred = model(fs, obss, tx)
            vloss[i] = log_prob_loss(ty, ty_pred)
            
    return np.mean(vloss)

In [8]:
val_after_epoch = 2500

#### 2- Actual Training

In [None]:
from tqdm import tqdm

model = CNP()
model.apply(initialize_weights)
model.to(device)

optimizer = torch.optim.Adam(lr=1e-4, params=model.parameters(), betas=(0.9, 0.999), amsgrad=True)

epoch = 20000000

losses = np.zeros(int(epoch/val_after_epoch))
min_loss = 1e6

for i in range(epoch):
    fs, obss, tx, ty = sample_training_demonstration()
    
    optimizer.zero_grad()
    ty_pred = model(fs, obss, tx)
    loss = log_prob_loss(ty, ty_pred)
    
    loss.backward()
    optimizer.step()
    
    print(i, end="\r")
    if i%val_after_epoch == 0:
        val_loss = validate()
        print(f"{i}: {val_loss}")
    
        if val_loss < min_loss:
            min_loss = val_loss
            torch.save(model.state_dict(), f'{path}../../best_model.pt')


0: 4.075285734084188
2500: 2.2678578089694588
5000: 2.2613347497521614
7500: 2.2547294201899546
10000: 2.3133686401406113
12500: 2.2391293708159
15000: 2.231107589420007
17500: 2.293084923831784
20000: 2.238102569750377
22500: 2.2273242915163234
25000: 2.219854364711411
27500: 2.231999917054663
30000: 2.2133188623554854
32500: 2.2215249553018688
35000: 2.22607035758544
37500: 2.2120101659881826
40000: 2.2119695234055423
42500: 2.2156157277068314
45000: 2.2289829221307014
47500: 2.2485376045411947
50000: 2.225163831759472
52500: 2.227179199578811
55000: 2.2096407909782565
57500: 2.194340138167751
60000: 2.227353668821101
62500: 2.195209777355194
65000: 2.221446590034329
67500: 2.2490262280921547
70000: 2.2210747645825757
72500: 2.1816903952433138
75000: 2.2131019665270437
77500: 2.1954483079666995
80000: 2.2173403680324553
82500: 2.23465002361609
85000: 2.248476248493
87500: 2.274792685922311
90000: 2.2351584418695802
92500: 2.265560816015516
95000: 2.228611462335197
97500: 2.1956218235

777500: 2.204035087142672
780000: 2.220785693003207
782500: 2.225836912588197
785000: 2.1732594677380153
787500: 2.2085006577628
790000: 2.2218373650190784
792500: 2.2054481208324432
795000: 2.1948633540649802
797500: 2.184769683103172
800000: 2.2194122056571803
802500: 2.2042861949424353
805000: 2.2118662857279485
807500: 2.1687325970250733
810000: 2.2204815831719613
812500: 2.2052855770198665
815000: 2.205632370710373
817500: 2.225977864922309
820000: 2.2225239200251443
822500: 2.1916063986262495
825000: 2.183949003535874
827500: 2.1784898649673075
830000: 2.188420902709572
832500: 2.2184329517033636
835000: 2.207422328360227
837500: 2.2052404406119366
840000: 2.1745822933255408
842500: 2.173651716417196
845000: 2.193048302859676
847500: 2.170434959202397
850000: 2.2296082207134793
852500: 2.198954556061297
855000: 2.209611473764692
857500: 2.208892130486819
860000: 2.186390075756579
862500: 2.2140824476066903
865000: 2.2060709253865847
867500: 2.2083931219821076
870000: 2.2377232433

1530000: 2.213674840878467
1532500: 2.235715466494463
1535000: 2.2110882652049164
1537500: 2.198549707081853
1540000: 2.2066728084671254
1542500: 2.226641423969853
1545000: 2.195820205430595
1547500: 2.1907332498200085
1550000: 2.212312020331013
1552500: 2.219773230747301
1555000: 2.1948961118046117
1557500: 2.2061255340673487
1560000: 2.1962544104274437
1562500: 2.19483700200003
1565000: 2.2144704386896015
1567500: 2.1975606248086814
1570000: 2.180394594036803
1572500: 2.1851704866302257
1575000: 2.1703689733330083
1577500: 2.1954470198981615
1580000: 2.21357995113548
1582500: 2.198752536214128
1585000: 2.2440083826074795
1587500: 2.1973088894571577
1590000: 2.199623828153221
1592500: 2.2093701161900343
1595000: 2.2197623831885203
1597500: 2.2182949512588737
1600000: 2.1979239522194374
1602500: 2.211290513860936
1605000: 2.23398989232219
1607500: 2.2004966865996924
1610000: 2.2263521256495493
1612500: 2.2162918192999705
1615000: 2.1754634095697987
1617500: 2.1892012498816666
1620000: 

2275000: 2.197050829444613
2277500: 2.213581740126318
2280000: 2.209341602057827
2282500: 2.2158962946765275
2285000: 2.191322331404199
2287500: 2.192608999232857
2290000: 2.1654347737224735
2292500: 2.2102401115456405
2295000: 2.216118237315392
2297500: 2.169810398500793
2300000: 2.2088176462115072
2302500: 2.196545871909784
2305000: 2.1940039913265075
2307500: 2.1877142631277744
2310000: 2.1875886355127605
2312500: 2.1884167255187523
2315000: 2.2040339190132765
2317500: 2.191119027867609
2320000: 2.181717251636544
2322500: 2.230873760033627
2325000: 2.21841046457388
2327500: 2.176453795238417
2330000: 2.1673748747426638
2332500: 2.2152563445422113
2335000: 2.2283927933293946
2337500: 2.1974884124434726
2340000: 2.182648011616298
2342500: 2.1891169314481775
2345000: 2.183554219226448
2347500: 2.222038264298926
2350000: 2.2242424753247474
2352500: 2.1954569474774965
2355000: 2.2019312672469082
2357500: 2.1838267726557596
2360000: 2.2066589702148827
2362500: 2.207317368959894
2365000: 2

In [None]:
import matplotlib.pyplot as plt

plt.plot(range(len(losses)), losses)