In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D
import torch.optim as optim
import numpy as np

**Torch, Numpy stuff**

In [2]:
if False:#torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print(device.type)

# ---

torch.manual_seed(42)
np.random.seed(42)

cpu


---
**Training Data**

Demonstration data contains a list of **d** scenes:

- ~~Each scene contains a list of trajectories of **p** people, where p is **not constant**~~
- A trajectory is a list of **t** states, where t = 400
- A state is a **s** = 4 dimensional variable
 - State = (d<sub>goal<sub>x</sub></sub>, d<sub>goal<sub>y</sub></sub>, v<sub>x</sub>, v<sub>y</sub>)
    
> Shape of data is (d, ~~p,~~ t, s) => (d, ~~p,~~ 400, 4)

---

Train, test, val split: 0.8, 0.01, 0.19

In [3]:
# from sklearn.model_selection import train_test_split

# def data_from_demonstrations(obs_dims=[0,1], tar_dims=[2,3], path="../data/input/"):
#     data = np.load(f"{path}/states_processed.npy")
#     X, Y = data[:,:,:,obs_dims], data[:,:,:,tar_dims]
#     return X, Y


# # for testing purposes
# def synthetic_data(n=10):
#     tlen = 400
#     X = np.zeros((n, tlen, 1))
#     Y = np.zeros((n, tlen, 1))
#     for i in range(n):
#         X[i] = np.random.uniform(0, 1, tlen).reshape(tlen, 1)
#         Y[i] = np.sin(X[i]*2*np.pi)/n + i/n
#     return X, Y

# #X, Y = data_from_demonstrations()
# X, Y = synthetic_data()

# X_train, X_rem, Y_train, Y_rem = train_test_split(X, Y, train_size=0.8)

# # 0.2 * 0.05 = 0.01 (test size in entire data)
# test_sz = 0.05
# X_val, X_test, Y_val, Y_test = train_test_split(X_rem, Y_rem, test_size=test_sz)

In [4]:
# import matplotlib.pyplot as plt

# def plot_trajectories(X, Y):
#     d, t, s = X.shape
#     for i in range(d):
#         plt.plot(X[i], Y[i], ".")


# plot_trajectories(X_train, Y_train)

---

### Preparing the data

**n <= n<sub>max</sub>** random number of random observations on a random trajectory

**get image**

In [42]:
from PIL import Image
import torchvision.transforms as T

def get_frames(path, demonstration_id, observation_ids):
    frames_path = f'{path}{demonstration_id}/'
    transform = T.Compose([T.ToTensor()])
    frames = []
    for i in observation_ids:
        frames.append(transform(Image.open(f"{frames_path}{i}.jpg")))
                      
    frames = torch.stack(frames, 0)
    return frames

In [56]:
n_max = 10
d_size = 10000  # nof demonstrations
t_size = 400  # length of trajectories
path="../data/processed/input/"

def sample_training_demonstration():
#     d, t, s = X_train.shape
    
    rand_traj_ind = np.random.randint(0, d_size)
    n = np.random.randint(1, n_max+1)
#    rand_traj = X_train[rand_traj_ind]
#     rand_out = Y_train[rand_traj_ind]

    rand_traj = np.load(f"{path}{rand_traj_ind}/states.npy")

    observation_indices = np.random.choice(np.arange(t_size), n+1, replace=False) # n+1: +1 is for sampling the target
    
    frames = get_frames(path, rand_traj_ind, observation_indices[:-1])
    
    observations = torch.from_numpy(rand_traj[observation_indices[:-1], :])
    targetX = torch.unsqueeze(torch.from_numpy(rand_traj[observation_indices[-1], 0:2]), 0)
    targetY = torch.unsqueeze(torch.from_numpy(rand_traj[observation_indices[-1], 2:]), 0)
    
    return frames.float(), observations.float(), targetX.float(), targetY.float()

---
### Model

In [58]:
class CNP(nn.Module):
    def __init__(self):
        super(CNP, self).__init__()
        
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 4, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(4, 8, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(8, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 32, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        
        self.encoder = nn.Sequential(
            nn.Linear(4+32*4*4,1024),
            nn.ReLU(),
            nn.Linear(1024,1024),
            nn.ReLU(),
            nn.Linear(1024,1024)
        )
        
        self.query = nn.Sequential(
            nn.Linear(1024+2,1024),
            nn.ReLU(),
            nn.Linear(1024,512),
            nn.ReLU(),
            nn.Linear(512,256),
            nn.ReLU(),
            nn.Linear(256,2*2)
        )

    def forward(self, frames, observations, target):
        # n < n_max frames of a scene along with momentary observations are concatenated to constitute input
        scene_encodings = torch.flatten(self.cnn(frames))
        scene_encodings = self.cnn(frames).view(frames.shape[0], 32*4*4)
        encoder_in = torch.cat((observations, scene_encodings), 1)
        r = self.encoder(encoder_in)
        
        r_avg = torch.mean(r, dim=0)
        r_avgs = r_avg.repeat(target.shape[0], 1)  # repeating the same r_avg for each target
        r_avg_target = torch.cat((r_avgs, target), 1)
        query_out = self.query(r_avg_target)
        
        return query_out

    
def log_prob_loss(output, target):
    mean, sigma = output.chunk(2, dim = -1)
    sigma = F.softplus(sigma)
    dist = D.Independent(D.Normal(loc=mean, scale=sigma), 1)
    return -torch.mean(dist.log_prob(target))

In [None]:
model = CNP()
model.to(device)

optimizer = torch.optim.Adam(lr=1e-3, params=model.parameters(), betas=(0.9, 0.999), amsgrad=True)

epoch = 1000
for i in range(epoch):
    fs, obss, tx, ty = sample_training_demonstration()
    
    optimizer.zero_grad()
    ty_pred = model(fs, obss, tx)
    loss = log_prob_loss(ty, ty_pred)
    
    loss.backward()
    optimizer.step()
    
    print(loss)
    
#     enc = model(images, torch.rand(2, 4), torch.rand(2, 2))
#     print(enc.shape)

tensor(5.6259, grad_fn=<NegBackward0>)
tensor(6.1686, grad_fn=<NegBackward0>)
tensor(5.5382, grad_fn=<NegBackward0>)
tensor(1.2591, grad_fn=<NegBackward0>)
tensor(24.0867, grad_fn=<NegBackward0>)
tensor(3.0986, grad_fn=<NegBackward0>)
tensor(3.7591, grad_fn=<NegBackward0>)
tensor(5.5022, grad_fn=<NegBackward0>)
tensor(4.9500, grad_fn=<NegBackward0>)
tensor(5.4900, grad_fn=<NegBackward0>)
tensor(4.6404, grad_fn=<NegBackward0>)
tensor(4.5358, grad_fn=<NegBackward0>)
tensor(4.8668, grad_fn=<NegBackward0>)
tensor(7.3401, grad_fn=<NegBackward0>)
tensor(4.1209, grad_fn=<NegBackward0>)
tensor(5.2485, grad_fn=<NegBackward0>)
tensor(1.8495, grad_fn=<NegBackward0>)
tensor(1.6569, grad_fn=<NegBackward0>)
tensor(3.1767, grad_fn=<NegBackward0>)
tensor(19.6081, grad_fn=<NegBackward0>)
tensor(1.8366, grad_fn=<NegBackward0>)
tensor(0.6381, grad_fn=<NegBackward0>)
tensor(5.1956, grad_fn=<NegBackward0>)
tensor(4.2625, grad_fn=<NegBackward0>)
tensor(3.2711, grad_fn=<NegBackward0>)
tensor(4.9956, grad_fn=

tensor(1.6779, grad_fn=<NegBackward0>)
tensor(2.9727, grad_fn=<NegBackward0>)
tensor(1.9856, grad_fn=<NegBackward0>)
tensor(3.8853, grad_fn=<NegBackward0>)
tensor(1.6599, grad_fn=<NegBackward0>)
tensor(3.4203, grad_fn=<NegBackward0>)
tensor(-0.2541, grad_fn=<NegBackward0>)
tensor(1.9480, grad_fn=<NegBackward0>)
tensor(3.1590, grad_fn=<NegBackward0>)
tensor(1.4512, grad_fn=<NegBackward0>)
tensor(1.6882, grad_fn=<NegBackward0>)
tensor(1.7867, grad_fn=<NegBackward0>)
tensor(2.3787, grad_fn=<NegBackward0>)
tensor(0.9974, grad_fn=<NegBackward0>)
tensor(3.9506, grad_fn=<NegBackward0>)
tensor(1.3902, grad_fn=<NegBackward0>)
tensor(2.5898, grad_fn=<NegBackward0>)
tensor(2.0534, grad_fn=<NegBackward0>)
tensor(2.3196, grad_fn=<NegBackward0>)
tensor(2.7344, grad_fn=<NegBackward0>)
tensor(1.8916, grad_fn=<NegBackward0>)
tensor(0.9395, grad_fn=<NegBackward0>)
tensor(3.6456, grad_fn=<NegBackward0>)
tensor(3.5701, grad_fn=<NegBackward0>)
tensor(2.2708, grad_fn=<NegBackward0>)
tensor(3.9659, grad_fn=<

tensor(1.6943, grad_fn=<NegBackward0>)
tensor(3.0633, grad_fn=<NegBackward0>)
tensor(1.4660, grad_fn=<NegBackward0>)
tensor(3.5104, grad_fn=<NegBackward0>)
tensor(2.1122, grad_fn=<NegBackward0>)
tensor(1.2924, grad_fn=<NegBackward0>)
tensor(0.7707, grad_fn=<NegBackward0>)
tensor(2.8142, grad_fn=<NegBackward0>)
tensor(2.9057, grad_fn=<NegBackward0>)
tensor(2.7133, grad_fn=<NegBackward0>)
tensor(3.2982, grad_fn=<NegBackward0>)
tensor(2.3174, grad_fn=<NegBackward0>)
tensor(3.2414, grad_fn=<NegBackward0>)
tensor(3.3906, grad_fn=<NegBackward0>)
tensor(1.5904, grad_fn=<NegBackward0>)
tensor(3.4677, grad_fn=<NegBackward0>)
tensor(3.3860, grad_fn=<NegBackward0>)
tensor(2.6922, grad_fn=<NegBackward0>)
tensor(1.1967, grad_fn=<NegBackward0>)
tensor(2.5827, grad_fn=<NegBackward0>)
tensor(4.1040, grad_fn=<NegBackward0>)
tensor(3.4805, grad_fn=<NegBackward0>)
tensor(0.1557, grad_fn=<NegBackward0>)
tensor(0.4052, grad_fn=<NegBackward0>)
tensor(1.6730, grad_fn=<NegBackward0>)
tensor(0.6650, grad_fn=<N