# In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import math

def log_prob_loss(output, y_target):
    """Return the negative mean Gaussian log-likelihood of ``y_target``.

    The last dimension of ``output`` is split into two equal halves: the first
    half is the mean, the second an unconstrained value that ``softplus`` maps
    to a positive standard deviation. The log-density of every target element
    is averaged and negated, so lower is better.
    """
    mu, raw_sigma = torch.chunk(output, 2, dim=-1)
    likelihood = D.Normal(loc=mu, scale=F.softplus(raw_sigma))
    log_p = likelihood.log_prob(y_target)
    return -log_p.mean()

class CNMP(nn.Module):
    """Conditional Neural Movement Primitive with a Gaussian output head.

    A per-point encoder maps every observation row ``(x, y)`` to a 256-wide
    representation. The representations selected by ``mask`` are combined
    (averaged, for the block masks built by the training sampler), then
    concatenated with a context vector and each query ``x`` and decoded into
    ``2 * d_SM`` values per query point.

    Args:
        d_x: Width of the input/query coordinate ``x``.
        d_y: Width of the observed value ``y``.
        d_SM: Number of predicted output dimensions; the decoder emits
            ``2 * d_SM`` values per query (the training loss splits them into
            means and raw scales).
        d_C: Width of the context vector passed to ``forward``. Defaults to 15,
            the width that was previously hard-coded into the decoder.
    """

    def __init__(self, d_x, d_y, d_SM, d_C=15):
        super().__init__()
        r_dim = 256  # width of the encoder's output representation

        # Layer creation order is unchanged so seeded initialization and
        # state_dict keys (encoder.N.*, decoder.N.*) stay the same.
        self.encoder = nn.Sequential(
            nn.Linear(d_x + d_y, 64), nn.LayerNorm(64), nn.ReLU(),
            nn.Linear(64, 64), nn.LayerNorm(64), nn.ReLU(),
            nn.Linear(64, 128), nn.LayerNorm(128), nn.ReLU(),
            nn.Linear(128, 256), nn.LayerNorm(256), nn.ReLU(),
            nn.Linear(256, r_dim),
        )

        # Decoder input is [representation | context | query x].
        self.decoder = nn.Sequential(
            nn.Linear(r_dim + d_C + d_x, 512), nn.LayerNorm(512), nn.ReLU(),
            nn.Linear(512, 512), nn.LayerNorm(512), nn.ReLU(),
            nn.Linear(512, 512), nn.LayerNorm(512), nn.ReLU(),
            nn.Linear(512, 256), nn.LayerNorm(256), nn.ReLU(),
            nn.Linear(256, 128), nn.LayerNorm(128), nn.ReLU(),
            nn.Linear(128, 2 * d_SM),  # 2 * d_SM values per query point
        )

    def forward(self, obs, context, mask, x_tar):
        """Predict the output-distribution parameters at the query points.

        Args:
            obs: Observations ``(B, M, d_x + d_y)``; rows not selected by
                ``mask`` may be zero padding.
            context: Context vectors ``(B, d_C)``.
            mask: ``(B, M, M)`` weights selecting observed rows; the training
                sampler fills the top-left ``n x n`` block with ones.
            x_tar: Query points ``(B, T, d_x)``.

        Returns:
            ``(output, r_avg)``: ``output`` is ``(B, T, 2 * d_SM)`` and
            ``r_avg`` is the combined representation repeated to ``(B, T, 256)``.
        """
        n_tar = x_tar.shape[1]
        r = self.encoder(obs)  # (B, M, 256)

        # For an n x n block mask of ones, bmm puts sum(r[:n]) into each of the
        # first n rows, so the row-sum is n * sum(r[:n]); dividing by the mask
        # total n**2 gives the mean of the n observed representations.
        masked_r_sum = torch.sum(torch.bmm(mask, r), dim=1, keepdim=True)  # (B, 1, 256)
        mask_total = torch.sum(mask, dim=[1, 2], keepdim=True)  # (B, 1, 1)
        # An all-zero mask would give 0/0 = NaN; divide by 1 there instead so
        # "no observations" maps to a zero representation.
        mask_total = mask_total.masked_fill(mask_total == 0, 1)
        r_avg = (masked_r_sum / mask_total).repeat(1, n_tar, 1)  # (B, T, 256)

        context = context.unsqueeze(1).repeat(1, n_tar, 1)  # (B, T, d_C)
        concat = torch.cat((r_avg, context, x_tar), dim=-1)
        output = self.decoder(concat)  # (B, T, 2 * d_SM)
        return output, r_avg

def get_training_sample(d_SM, batch_size):
    """Draw a random CNMP training batch from the module-level dataset.

    ``batch_size`` distinct trajectories are picked without replacement. For
    each one, a random number ``n`` in ``[1, OBS_MAX]`` of time steps become
    observations, and one further, different time step becomes the target.

    Reads the module-level globals ``X``, ``Y``, ``C``, ``OBS_MAX``, ``d_N``,
    ``d_x``, ``d_y`` and ``time_len``.

    Args:
        d_SM: Number of leading ``Y`` dimensions used as the target.
        batch_size: Number of trajectories in the batch (at most ``d_N``).

    Returns:
        Tuple of float64 tensors ``(observations, context, target_X, target_Y,
        mask)`` with shapes ``(B, OBS_MAX, d_x + d_y)``, ``(B, C.shape[-1])``,
        ``(B, 1, d_x)``, ``(B, 1, d_SM)`` and ``(B, OBS_MAX, OBS_MAX)``.
        Observation rows beyond ``n`` are zero, and ``mask`` has ones exactly
        in its top-left ``n x n`` block.

    Raises:
        ValueError: If ``batch_size > d_N`` (not enough distinct trajectories)
            or ``time_len <= OBS_MAX`` (no time step left over for the target).
    """
    if batch_size > d_N:
        raise ValueError(
            f'batch_size ({batch_size}) exceeds the number of trajectories ({d_N})')
    if time_len <= OBS_MAX:
        raise ValueError(
            f'time_len ({time_len}) must exceed OBS_MAX ({OBS_MAX}) to leave a target step')

    # RNG calls happen in the same order as before, so seeded runs reproduce.
    n = np.random.randint(0, OBS_MAX, batch_size) + 1  # observations per sample, 1..OBS_MAX
    traj = np.random.permutation(d_N)[:batch_size]  # distinct trajectory indices

    observations = np.zeros((batch_size, OBS_MAX, d_x + d_y))
    context = np.zeros((batch_size, C.shape[-1]))  # follows C instead of a hard-coded 15
    target_X = np.zeros((batch_size, 1, d_x))
    target_Y = np.zeros((batch_size, 1, d_SM))
    mask = np.zeros((batch_size, OBS_MAX, OBS_MAX))

    for i in range(batch_size):
        steps = np.random.permutation(time_len)
        obs_idx, tar_idx = steps[:n[i]], steps[n[i]]  # target step is never observed
        observations[i, :n[i], :d_x] = X[traj[i], obs_idx]
        observations[i, :n[i], d_x:d_x + d_y] = Y[traj[i], obs_idx]
        context[i, :] = C[traj[i]]
        target_X[i, 0] = X[traj[i], tar_idx]
        target_Y[i, 0] = Y[traj[i], tar_idx, :d_SM]
        mask[i, :n[i], :n[i]] = 1

    return (torch.from_numpy(observations), torch.from_numpy(context),
            torch.from_numpy(target_X), torch.from_numpy(target_Y),
            torch.from_numpy(mask))

# In [None]:
import validation
import importlib
importlib.reload(validation)

# Training cell: load the reach-arm demonstrations, build min-max normalized
# CNMP inputs, and train, checkpointing whenever a new lowest loss is logged.
import os

time_len = 451

action_data = np.load('data/reach_arm_actions_v1.npy')
observation_data = np.load('data/reach_arm_observations_v1.npy')
# NOTE(review): `Y[:, 1:] = action_data` below needs action_data to broadcast
# into (num_data, time_len - 1, 5), and filling C[:, 9:] needs
# observation_data[..., 42:] to broadcast into 6 columns — confirm the shapes
# of these files.
print('Action data shape:', action_data.shape)
num_data = action_data.shape[0]

# Shared phase variable in [0, 1] for every trajectory: (num_data, time_len, 1).
X = np.tile(np.linspace(0, 1, time_len).reshape((1, time_len, 1)), (num_data, 1, 1))
# Y[:, 0] stays zero; the recorded actions fill time steps 1..time_len-1.
Y = np.zeros((num_data, time_len, 5))
Y[:, 1:] = action_data
# Per-trajectory context from the first observation step:
# features 30:39 followed by features 42: of observation_data.
C = np.zeros((num_data, 15))
C[:, :9] = observation_data[:, 0, 30:39]
C[:, 9:] = observation_data[:, 0, 42:]

# Min-max normalize each feature of Y (over trajectories and time) and of C
# (over trajectories); 1e-8 guards constant features. The stats are kept for
# every feature (shapes (1, 1, d_y) and (1, 15)) so predictions can be mapped
# back with `y * (Y_max - Y_min + 1e-8) + Y_min`.
Y_min = Y.min(axis=(0, 1), keepdims=True)
Y_max = Y.max(axis=(0, 1), keepdims=True)
Y = (Y - Y_min) / (Y_max - Y_min + 1e-8)
C_min = C.min(axis=0, keepdims=True)
C_max = C.max(axis=0, keepdims=True)
C = (C - C_min) / (C_max - C_min + 1e-8)

OBS_MAX = 10  # maximum number of observed points per training sample
d_x = X.shape[-1]
d_y = Y.shape[-1]
d_SM = d_y
d_N = Y.shape[0]
batch_size = 8

model = CNMP(d_x, d_y, d_SM).double()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

losses = []
errors = []  # would hold validation errors; validation is currently disabled
val_indices = [0, 1]

save_dir = 'save/best_models_reach_arm_v1'
os.makedirs(save_dir, exist_ok=True)  # torch.save does not create missing directories

for i in range(50_000):
    obs, context, x_tar, y_tar, mask = get_training_sample(d_SM, batch_size)

    optimizer.zero_grad()
    output, _ = model(obs, context, mask, x_tar)
    loss = log_prob_loss(output, y_tar)
    loss.backward()
    optimizer.step()

    if i % 10000 == 0:
        print(f'Iteration {i}')
    if i % 200 == 0:
        # Validation is disabled; re-enable with
        # errors.append(validation.val(model, VAL_Y, VAL_C, d_x, d_y, d_SM)).
        loss_value = loss.item()
        losses.append(loss_value)
        if min(losses) == loss_value:  # new best (ties included) among logged losses
            print(f'Iteration {i} - Loss: {loss_value:.4f}')
            print('Saving model...')
            torch.save(model.state_dict(), f'{save_dir}/model_{i}.pth')