In [10]:
import numpy as np
import torch

from src.grid_world import GridWorld

# Select the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")

TODO

- Sample (state, action) pairs and split them into context and target sets.
- Figure out how to split them correctly, and what the differences between the candidate splits are.
- Train a neural process using the HIIT library

### Simulator

Global Fixed Parameters:
- Grid_size (int): side length of the grid
- Agent_view_size (int): side length of the agent's view area

User-specific parameters

Fixed

- Agent_pos (int, int): Agent's initial position
- Goal_pos (int, int): Goal's initial position

Change each trial

- Mode_densities: 

Keep fixed:
Grid size
Agent_view_size
Mode_densities

Sample:
Goal_pos,
Agent_pos

User parameters:
Number of belief modes
Distribution of the belief modes (Dirichlet)

In [11]:
def sample_mode_densities(max_modes = 3, total_density = 0.9):
  '''
  Draw a random set of mode densities that guide a simulated user.

  The number of modes is drawn uniformly from {1, ..., max_modes}. The
  densities are a flat-Dirichlet sample rescaled so they sum to total_density.

  Returns:
    list with one (numpy float) density per mode.
  '''
  n_modes = np.random.randint(1, max_modes + 1)
  weights = np.random.dirichlet(np.ones(n_modes))
  scaled = weights * total_density
  return list(scaled)

# Sanity check: the sampled densities should add up to total_density.
test_sample = sample_mode_densities(total_density = 0.5, max_modes = 5)
print(test_sample)
print(np.sum(test_sample))

[0.5]
0.5


In [12]:
def generate_user_parameters(grid_size, mode_params = None):
  '''
  Generate a set of parameters that define a simulated user's behaviour.

  Args:
    grid_size: side length of the grid; each mode position is drawn
      uniformly from [0, grid_size) on both axes.
    mode_params: optional settings for sample_mode_densities.
      - None: use sample_mode_densities' defaults.
      - dict (e.g. {'max_modes': 5, 'total_density': 0.5}): forwarded as
        keyword arguments.
      - any other value: passed positionally, i.e. as max_modes.

  Returns:
    dict with
      'mode_densities': list of per-mode densities,
      'mode_positions': int array of shape (num_modes, 2).
  '''
  if mode_params is None:
    mode_densities = sample_mode_densities()
  elif isinstance(mode_params, dict):
    # A dict of settings was previously passed positionally as max_modes,
    # which fails inside np.random.randint; unpack it as keywords instead.
    mode_densities = sample_mode_densities(**mode_params)
  else:
    mode_densities = sample_mode_densities(mode_params)

  mode_positions = np.random.randint(0, grid_size, (len(mode_densities), 2))

  return {
    'mode_densities': mode_densities,
    'mode_positions': mode_positions,
  }
  
# Inspect one sampled user on a 10x10 grid.
test_param = generate_user_parameters(grid_size = 10)
print(test_param)

{'mode_densities': [0.05038294008326661, 0.8496170599167334], 'mode_positions': array([[5, 5],
       [0, 8]])}


In [13]:
def generate_user_trajectories(num_trajectories, grid_size, agent_view_size, user_params, traj_length = None, max_attempts = None):
  '''
  Roll out a simulated user in a GridWorld and record (state, action) pairs.

  Actions come from env.max_neighboring_reward(). Each recorded pair is
  (obs['agent_pos'] before the step, one-hot encoding of the action taken
  from that position).

  Args:
    num_trajectories: number of trajectories to return.
    grid_size: side length of the grid.
    agent_view_size: side length of the agent's view area.
    user_params: dict with 'mode_densities' and 'mode_positions', as returned
      by generate_user_parameters.
    traj_length: if given, every returned trajectory has exactly this many
      pairs; episodes that end earlier are discarded and re-sampled. If None,
      each episode runs until the env reports done/truncated and is kept at
      whatever length it reaches (NOTE(review): this relies on the env
      eventually truncating).
    max_attempts: optional cap on the number of episodes rolled out; None
      (default) means no cap.

  Returns:
    list of trajectories, each a list of (state, action_onehot) tuples.

  Raises:
    RuntimeError: if max_attempts episodes were rolled out without collecting
      num_trajectories usable trajectories.
  '''
  mode_densities = user_params['mode_densities']
  mode_positions = user_params['mode_positions']

  env = GridWorld(render_mode = "rgb_array", size = grid_size, agent_view_size = agent_view_size, mode_densities = mode_densities, mode_positions=mode_positions)
  n_actions = env.action_space.n

  trajectories = []
  attempts = 0

  while len(trajectories) < num_trajectories:
    if max_attempts is not None and attempts >= max_attempts:
      raise RuntimeError(
        f"collected only {len(trajectories)}/{num_trajectories} trajectories "
        f"of length {traj_length} in {max_attempts} episodes"
      )
    attempts += 1

    trajectory = []

    obs = env.reset()
    state = obs[0]['agent_pos']

    done = False

    # `traj_length is None` means "no cap": comparing an int with None would raise.
    while not done and (traj_length is None or len(trajectory) < traj_length):

      action = env.max_neighboring_reward()

      action_onehot = [0] * n_actions
      action_onehot[action] = 1

      trajectory.append((state, action_onehot))

      next_obs, _, terminated, truncated, _ = env.step(action)
      state = next_obs['agent_pos']

      done = terminated or truncated

    # The final state is deliberately not recorded. No action is taken from
    # it, and pairing it with the last action would create a mislabelled
    # training example.

    if traj_length is not None and len(trajectory) < traj_length:
      continue

    trajectories.append(trajectory)

  return trajectories

# Roll out trajectories for one sampled user and check that they all have the requested length.
num_traj, grid_size, agent_view_size, traj_length = 200, 10, 5, 10

user_params = generate_user_parameters(grid_size)
trajectories = generate_user_trajectories(
  num_traj, grid_size, agent_view_size, user_params, traj_length = traj_length
)

print(all(len(traj) == traj_length for traj in trajectories))
print(trajectories[0])

True
[((7, 9), [0, 0, 1, 0]), ((7, 8), [0, 0, 1, 0]), ((7, 7), [0, 0, 1, 0]), ((7, 6), [0, 0, 1, 0]), ((7, 5), [0, 1, 0, 0]), ((8, 5), [1, 0, 0, 0]), ((7, 5), [1, 0, 0, 0]), ((6, 5), [1, 0, 0, 0]), ((5, 5), [1, 0, 0, 0]), ((4, 5), [1, 0, 0, 0])]


In [14]:
def split_trajectory_half(trajectory):
  '''
  Cut a trajectory at its midpoint.

  Returns (first_half, second_half). For odd lengths the extra element goes
  to the second half.
  '''
  midpoint = len(trajectory) // 2
  first_half = trajectory[:midpoint]
  second_half = trajectory[midpoint:]
  return first_half, second_half

In [15]:
from itertools import permutations

def context_and_target(dataset, max_orderings = 5):
    '''
    Build CNP (context, target) tasks from one user's trajectories.

    For every trajectory i, its first half is extra context and its second
    half is the target set. The context set is all *other* trajectories,
    concatenated in a random order, followed by that first half. Up to
    `max_orderings` distinct random orderings are drawn per target, and each
    ordering yields one task.

    NOTE: an encoder that is permutation-invariant over the context set (e.g.
    one that sums or averages per-point features) sees identical input for
    every ordering, so for such models these tasks are duplicates.

    Args:
        dataset: list of trajectories, each a list of (state, action_onehot)
            pairs. All trajectories must have the same length so that the
            per-task tensors can be stacked.
        max_orderings: maximum number of distinct orderings (tasks) per
            target trajectory; the actual number is min(max_orderings, m!)
            with m = len(dataset) - 1.

    Returns:
        (xc, yc, xt, yt) float32 tensors:
            xc: (num_tasks, context_size, state_dim)
            yc: (num_tasks, context_size, num_actions)
            xt: (num_tasks, target_size, state_dim)
            yt: (num_tasks, target_size, num_actions)

    Raises:
        ValueError: if dataset is empty or max_orderings < 1.
    '''
    import math

    num_traj = len(dataset)
    if num_traj == 0:
        raise ValueError("context_and_target needs at least one trajectory")
    if max_orderings < 1:
        raise ValueError("max_orderings must be at least 1")

    xc, yc, xt, yt = [], [], [], []

    for i in range(num_traj):
        # Pick one as target and split it
        context_part, target_part = split_trajectory_half(dataset[i])

        # The target tensors do not depend on the context ordering; build them once.
        target_s = torch.tensor([state for state, _ in target_part], dtype = torch.float32)
        target_a = torch.tensor([action for _, action in target_part], dtype = torch.float32)

        # Ids of the other trajectories, used as past context
        past_context_ids = list(range(i)) + list(range(i + 1, num_traj))

        # Sample distinct random orderings directly. Materialising every
        # permutation costs (num_traj - 1)! tuples (362,880 for 10
        # trajectories) for each target trajectory.
        num_orderings = min(max_orderings, math.factorial(len(past_context_ids)))
        orderings, seen = [], set()
        while len(orderings) < num_orderings:
            perm = np.random.permutation(len(past_context_ids))
            ordering = tuple(past_context_ids[k] for k in perm)
            if ordering not in seen:
                seen.add(ordering)
                orderings.append(ordering)

        # Generate one task per ordering, all sharing the same target set
        for ordering in orderings:
            full_context = [dataset[j] for j in ordering] + [context_part]

            context_s = torch.tensor([state for traj in full_context for state, _ in traj], dtype = torch.float32)
            context_a = torch.tensor([action for traj in full_context for _, action in traj], dtype = torch.float32)

            xc.append(context_s)
            yc.append(context_a)
            xt.append(target_s)
            yt.append(target_a)

    xc = torch.stack(xc, dim = 0)
    yc = torch.stack(yc, dim = 0)
    xt = torch.stack(xt, dim = 0)
    yt = torch.stack(yt, dim = 0)

    return xc, yc, xt, yt

In [16]:
def generate_batch(grid_size = 10, agent_view_size = 5, traj_length = 10, min_trajectories = 1, max_trajectories = 10):
  '''
  Sample one batch of CNP tasks from a freshly sampled user.

  The function does three things:
  - draws a new user with generate_user_parameters;
  - rolls out a uniformly random number of trajectories in
    [min_trajectories, max_trajectories] (inclusive) for that user;
  - converts them into (xc, yc, xt, yt) tensors with context_and_target.

  Raises:
    ValueError: if the trajectory-count bounds are not 1 <= min <= max.
  '''
  if not 1 <= min_trajectories <= max_trajectories:
    raise ValueError(
      f"need 1 <= min_trajectories <= max_trajectories, got {min_trajectories}, {max_trajectories}"
    )

  user_params = generate_user_parameters(grid_size)

  # randint's upper bound is exclusive, hence the +1.
  num_trajectories = np.random.randint(min_trajectories, max_trajectories + 1)

  trajectories = generate_user_trajectories(num_trajectories, grid_size, agent_view_size, user_params, traj_length)

  return context_and_target(trajectories)

In [26]:
# Data-generation settings shared by the cells below.
GRID_SIZE = 10        # side length of the grid
AGENT_VIEW_SIZE = 3   # side length of the agent's view area
N_USERS = 100         # not referenced by the cells in this section
TRAJ_LENGTH = 10      # (state, action) pairs per trajectory

batch = generate_batch(
    traj_length = TRAJ_LENGTH,
    agent_view_size = AGENT_VIEW_SIZE,
    grid_size = GRID_SIZE,
)

In [27]:
batch[0].size()

torch.Size([20, 35, 2])

In [22]:
import numpy as np
import numpy.random as npr
import torch
import torch.nn as nn
import torch.optim as optim
import collections
import matplotlib.pyplot as plt
import datetime
import torch.nn.functional as F

In [162]:
# Training configuration.
train_CNP = True
USE_TEMPERATURE = True

TRAINING_ITERATIONS = 10_000
PLOT_AFTER = 1_000

# Conditional Neural Process

CNPs take in context pairs (x, y) and predict y at new target inputs x.

## The following is a copy from the relational_neural_process GitHub repository

Model

In [163]:
class CNPDeterministicEncoder(nn.Module):
    """
    MLP encoder that maps a context set to a single representation vector.

    Each (x, y) context pair is concatenated and passed through the MLP given
    by `sizes` (ReLU between layers, no activation after the last one). The
    per-point outputs are then aggregated over the set dimension.
    """

    _AGGREGATIONS = ("sum", "mean")

    def __init__(self, sizes, aggregation = "sum"):
        """
        Args:
            sizes: layer widths [d_x + d_y, hidden..., representation_size].
            aggregation: "sum" (default, original behaviour) or "mean". With
                "sum" the representation's magnitude grows with the number of
                context points; "mean" keeps it independent of set size.

        Raises:
            ValueError: if aggregation is not a supported option.
        """
        super(CNPDeterministicEncoder, self).__init__()
        if aggregation not in self._AGGREGATIONS:
            raise ValueError(f"aggregation must be one of {self._AGGREGATIONS}, got {aggregation!r}")
        self.aggregation = aggregation
        self.linears = nn.ModuleList(
            nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])
        )

    def forward(self, context_x, context_y):
        """
        Encode the context set as one vector representation.

        Args:
            context_x: batch_size x set_size x feature_dim_x
            context_y: batch_size x set_size x feature_dim_y

        Returns:
            representation: batch_size x representation_size
        """
        x = torch.cat((context_x, context_y), dim = -1)
        # nn.Linear acts on the last dimension, so no flatten/unflatten is needed.
        for linear in self.linears[:-1]:
            x = torch.relu(linear(x))
        x = self.linears[-1](x)
        if self.aggregation == "mean":
            return x.mean(dim = 1)
        return x.sum(dim = 1)
            
class CNPDeterministicDecoder(nn.Module):
    """
    MLP decoder producing a categorical distribution at each target input.

    The layer widths are given by `sizes`, with ReLU between layers and no
    activation after the last one. The final layer's outputs are the logits.
    """

    def __init__(self, sizes):
        super(CNPDeterministicDecoder, self).__init__()
        self.linears = nn.ModuleList(
            nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])
        )

    def forward(self, representation, target_x):
        """
        Decode the context representation at each target input.

        Args:
            representation: batch_size x representation_size, or None to
                decode from target_x alone.
            target_x: batch_size x set_size x d

        Returns:
            dist: Categorical distribution with batch shape
                (batch_size, set_size).
            probs: softmax probabilities, batch_size x set_size x num_classes.
            logits: raw decoder outputs, same shape as probs.
        """
        set_size = target_x.shape[1]

        if representation is None:
            decoder_input = target_x
        else:
            # expand() broadcasts the representation to every target point without copying.
            tiled = representation.unsqueeze(1).expand(-1, set_size, -1)
            decoder_input = torch.cat((tiled, target_x), dim = -1)

        x = decoder_input
        for linear in self.linears[:-1]:
            x = torch.relu(linear(x))
        logits = self.linears[-1](x)
        probs = F.softmax(logits, dim = -1)

        # Build the distribution from logits, not probs. Categorical(probs=...)
        # clamps probabilities to [eps, 1 - eps] before taking the log. That
        # caps log_prob at about -16 (float32) and zeroes its gradient when
        # the model is confidently wrong.
        dist = torch.distributions.categorical.Categorical(logits = logits)
        return dist, probs, logits

class CNPDeterministicModel(nn.Module):
    """
    Conditional Neural Process.

    A CNPDeterministicEncoder summarises the context set and a
    CNPDeterministicDecoder predicts class distributions at the targets.
    """

    def __init__(self, encoder_size, decoder_size):
        super(CNPDeterministicModel, self).__init__()
        self._encoder = CNPDeterministicEncoder(encoder_size)
        self._decoder = CNPDeterministicDecoder(decoder_size)

    def forward(self, query, target_y = None):
        """
        Args:
            query: ((context_x, context_y), target_x).
            target_y: optional one-hot targets; the class index is taken
                with argmax over the last dimension.

        Returns:
            (log_p, probs, logits), where log_p is None when target_y is None.
        """
        (ctx_x, ctx_y), tgt_x = query
        summary = self._encoder(ctx_x, ctx_y)
        dist, probs, logits = self._decoder(summary, tgt_x)

        if target_y is None:
            return None, probs, logits

        # Undo the one-hot encoding so Categorical.log_prob receives class indices.
        class_idx = target_y.argmax(dim = -1)
        return dist.log_prob(class_idx), probs, logits

In [174]:
import neuralprocesses.torch as nps

# States are 2-D grid positions; actions are 4-way one-hot vectors.
d_x, d_out = 2, 4
d_in = d_x + d_out              # the encoder consumes concatenated (state, action) pairs
hidden_size = 128
representation_size = 258       # NOTE(review): possibly intended as 256 — confirm

encoder_sizes = [d_in] + [hidden_size] * 3 + [representation_size]
decoder_sizes = [representation_size + d_x] + [hidden_size] * 3 + [d_out]

model = CNPDeterministicModel(encoder_size = encoder_sizes, decoder_size = decoder_sizes)

### Training loop 

In [175]:
LOG_EVERY = 100  # iterations between printed loss averages

# `device` is chosen in the first cell (GPU if available). Move the model
# there before building the optimiser so the parameters live on that device.
model.to(device)
opt = torch.optim.Adam(model.parameters(), 1e-3)

total_loss = []

for i in range(TRAINING_ITERATIONS):
    opt.zero_grad()

    xc, yc, xt, yt = (
        t.to(device)
        for t in generate_batch(grid_size = GRID_SIZE, agent_view_size = AGENT_VIEW_SIZE, traj_length = TRAJ_LENGTH)
    )

    query = (xc, yc), xt

    log_p, prob, logits = model(query, yt)

    # Negative mean log-likelihood of the target actions.
    loss = -log_p.mean()
    loss.backward()
    opt.step()

    total_loss.append(loss.item())

    if i % LOG_EVERY == 0:
        avg_loss = np.mean(total_loss)
        print(f"iter: {i}, avg_loss = {avg_loss}")
        total_loss = []

    
    

iter: 0, avg_loss = 1.4590564966201782


KeyboardInterrupt: 

### Testing the `neuralprocesses` library

In [49]:
import neuralprocesses.torch as nps

# Smoke test of neuralprocesses' GNP on one of our batches.
# NOTE(review): likelihood "het" is a heteroscedastic Gaussian, while our
# targets are one-hot action labels, so this is only a shape/API check.
cnp = nps.construct_gnp(dim_x = 2, dim_y = 4, likelihood = "het")

xc, yc, xt, yt = generate_batch(grid_size = GRID_SIZE, agent_view_size = AGENT_VIEW_SIZE, traj_length = TRAJ_LENGTH)
xc = xc.unsqueeze(0)
yc = yc.unsqueeze(0)
xt = xt.unsqueeze(0)
yt = yt.unsqueeze(0)

print(xc.shape)

# Our tensors are (..., points, channels); the reshape below targets
# (..., channels, points). Swap the last two axes with transpose(). view()
# would only reinterpret memory and interleave the coordinates of different
# points.
xc = xc.transpose(-1, -2)
yc = yc.transpose(-1, -2)
xt = xt.transpose(-1, -2)
yt = yt.transpose(-1, -2)

dist = cnp(xc, yc, xt)




torch.Size([1, 20, 35, 2])
tensor([[[[ 0.0329,  0.0324,  0.0328,  0.0325,  0.0327],
          [ 0.0405,  0.0395,  0.0402,  0.0392,  0.0399],
          [-0.0140, -0.0133, -0.0139, -0.0137, -0.0139],
          [ 0.0102,  0.0101,  0.0101,  0.0101,  0.0101]],

         [[ 0.0331,  0.0325,  0.0330,  0.0326,  0.0329],
          [ 0.0401,  0.0392,  0.0399,  0.0390,  0.0396],
          [-0.0141, -0.0133, -0.0140, -0.0138, -0.0140],
          [ 0.0110,  0.0108,  0.0109,  0.0108,  0.0109]],

         [[ 0.0321,  0.0326,  0.0322,  0.0326,  0.0324],
          [ 0.0412,  0.0395,  0.0407,  0.0391,  0.0401],
          [-0.0121, -0.0117, -0.0121, -0.0120, -0.0121],
          [ 0.0122,  0.0118,  0.0121,  0.0120,  0.0120]],

         [[ 0.0322,  0.0322,  0.0323,  0.0324,  0.0324],
          [ 0.0408,  0.0395,  0.0403,  0.0391,  0.0399],
          [-0.0121, -0.0119, -0.0120, -0.0121, -0.0120],
          [ 0.0116,  0.0110,  0.0114,  0.0111,  0.0113]],

         [[ 0.0325,  0.0322,  0.0324,  0.0323,  0.032

In [36]:
# Check how view() regroups a repeated tensor: each copy of [1..6] becomes [[1, 2, 3], [4, 5, 6]].
import torch
tensor = torch.arange(1, 7, dtype = torch.float32).repeat(6, 1).view(6, 2, -1)
tensor

tensor([[[1., 2., 3.],
         [4., 5., 6.]],

        [[1., 2., 3.],
         [4., 5., 6.]],

        [[1., 2., 3.],
         [4., 5., 6.]],

        [[1., 2., 3.],
         [4., 5., 6.]],

        [[1., 2., 3.],
         [4., 5., 6.]],

        [[1., 2., 3.],
         [4., 5., 6.]]])