In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D
import torch.optim as optim
import numpy as np

**Torch, Numpy stuff**

In [2]:
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print(device.type)

# ---

torch.manual_seed(42)
np.random.seed(42)

cuda


---
**Training Data**

Demonstration data contains a list of **d** scenes:

- ~~Each scene contains a list of trajectories of **p** people, where p is **not constant**~~
- A trajectory is a list of **t** states, where t = 400
- A state is a **s** = 4 dimensional variable
 - State = (d<sub>goal<sub>x</sub></sub>, d<sub>goal<sub>y</sub></sub>, v<sub>x</sub>, v<sub>y</sub>)
    
> Shape of data is (d, ~~p,~~ t, s) => (d, ~~p,~~ 400, 4)

---

Train, test, val split: 0.8, 0.01, 0.19

In [3]:
# from sklearn.model_selection import train_test_split

# def data_from_demonstrations(obs_dims=[0,1], tar_dims=[2,3], path="../data/input/"):
#     data = np.load(f"{path}/states_processed.npy")
#     X, Y = data[:,:,:,obs_dims], data[:,:,:,tar_dims]
#     return X, Y


# # for testing purposes
# def synthetic_data(n=10):
#     tlen = 400
#     X = np.zeros((n, tlen, 1))
#     Y = np.zeros((n, tlen, 1))
#     for i in range(n):
#         X[i] = np.random.uniform(0, 1, tlen).reshape(tlen, 1)
#         Y[i] = np.sin(X[i]*2*np.pi)/n + i/n
#     return X, Y

# #X, Y = data_from_demonstrations()
# X, Y = synthetic_data()

# X_train, X_rem, Y_train, Y_rem = train_test_split(X, Y, train_size=0.8)

# # 0.2 * 0.05 = 0.01 (test size in entire data)
# test_sz = 0.05
# X_val, X_test, Y_val, Y_test = train_test_split(X_rem, Y_rem, test_size=test_sz)

In [4]:
# import matplotlib.pyplot as plt

# def plot_trajectories(X, Y):
#     d, t, s = X.shape
#     for i in range(d):
#         plt.plot(X[i], Y[i], ".")


# plot_trajectories(X_train, Y_train)

---

### Preparing the data

**n <= n<sub>max</sub>** random number of random observations on a random trajectory

**get image**

In [5]:
from PIL import Image
import torchvision.transforms as T

def get_frames(path, demonstration_id, observation_ids):
    frames_path = f'{path}{demonstration_id}/'
    transform = T.Compose([T.ToTensor()])
    frames = []
    for i in observation_ids:
        frames.append(transform(Image.open(f"{frames_path}{i}.jpg")))
                      
    frames = torch.stack(frames, 0)
    return frames

In [6]:
n_max = 20
d_size = 10000  # nof demonstrations
t_size = 400  # length of trajectories
path="../data/processed/input/"

def sample_training_demonstration():
#     d, t, s = X_train.shape
    
    rand_traj_ind = np.random.randint(0, d_size)
    n = np.random.randint(1, n_max+1)
#    rand_traj = X_train[rand_traj_ind]
#     rand_out = Y_train[rand_traj_ind]

    rand_traj = np.load(f"{path}{rand_traj_ind}/states.npy")

    observation_indices = np.random.choice(np.arange(t_size), n+1, replace=False) # n+1: +1 is for sampling the target
    
    frames = get_frames(path, rand_traj_ind, observation_indices[:-1])
    
    observations = torch.from_numpy(rand_traj[observation_indices[:-1], :])
    targetX = torch.unsqueeze(torch.from_numpy(rand_traj[observation_indices[-1], 0:2]), 0)
    targetY = torch.unsqueeze(torch.from_numpy(rand_traj[observation_indices[-1], 2:]), 0)
    
    return frames.float().cuda(), observations.float().cuda(), targetX.float().cuda(), targetY.float().cuda()

---
### Model

In [7]:
class CNP(nn.Module):
    def __init__(self):
        super(CNP, self).__init__()
        
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 8, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(8, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 32, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        
        self.encoder = nn.Sequential(
            nn.Linear(4+64*4*4,1024),
            nn.ReLU(),
            nn.Linear(1024,1024),
            nn.ReLU(),
            nn.Linear(1024,1024)
        )
        
        self.query = nn.Sequential(
            nn.Linear(1024+2,1024),
            nn.ReLU(),
            nn.Linear(1024,1024),
            nn.ReLU(),
            nn.Linear(1024,1024),
            nn.ReLU(),
            nn.Linear(1024,2*2)
        )

    def forward(self, frames, observations, target):
        # n < n_max frames of a scene along with momentary observations are concatenated to constitute input
        scene_encodings = torch.flatten(self.cnn(frames))
        scene_encodings = self.cnn(frames).view(frames.shape[0], 64*4*4)
        encoder_in = torch.cat((observations, scene_encodings), 1)
        r = self.encoder(encoder_in)
        
        r_avg = torch.mean(r, dim=0)
        r_avgs = r_avg.repeat(target.shape[0], 1)  # repeating the same r_avg for each target
        r_avg_target = torch.cat((r_avgs, target), 1)
        query_out = self.query(r_avg_target)
        
        return query_out

    
def log_prob_loss(output, target):
    mean, sigma = output.chunk(2, dim = -1)
    sigma = F.softplus(sigma)
    dist = D.Independent(D.Normal(loc=mean, scale=sigma), 1)
    return -torch.mean(dist.log_prob(target))

In [None]:
from tqdm import tqdm

model = CNP()
nn.init.uniform_(model.weight)
model.to(device)

optimizer = torch.optim.Adam(lr=1e-4, params=model.parameters(), betas=(0.9, 0.999), amsgrad=True)

epoch = 100000000

losses = np.zeros(epoch)
min_loss = 1e6

for i in range(epoch):
    fs, obss, tx, ty = sample_training_demonstration()
    
    optimizer.zero_grad()
    ty_pred = model(fs, obss, tx)
    loss = log_prob_loss(ty, ty_pred)
    
    loss.backward()
    optimizer.step()
    
    cur_loss = loss.data
    losses[i] = cur_loss
    
    if cur_loss < min_loss:
        min_loss = cur_loss
        print(f"{i}: {cur_loss}")
        torch.save(model.state_dict(), f'{path}../best_model.pt')


0: 8.28280258178711
1: 7.037246227264404
3: 5.927742004394531
5: 5.488465785980225
9: 4.784801483154297
13: 4.704505920410156
14: 4.675440311431885
16: 2.7072501182556152
17: 2.3558242321014404
18: 1.8004528284072876
21: 0.7765635251998901
39: 0.6775355935096741
90: 0.6664631366729736
99: -0.217495858669281
298: -0.49229395389556885
474: -0.5864773988723755
1411: -0.6199854612350464
1537: -0.6524862051010132
1663: -0.6573606729507446
2039: -1.1794071197509766
10535: -1.1954172849655151
46936: -1.2206614017486572
89900: -1.3846296072006226
91760: -1.4742966890335083
214934: -1.6921050548553467


In [None]:
import matplotlib.pyplot as plt

plt.plot(range(len(losses)), losses)