<a href="https://colab.research.google.com/github/timlacroix/nips2018-agent/blob/master/Relational_Neural_Inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Neural Relational Inference

In this session, we implement the ideas described in [Neural Relational Inference for Interacting Systems](https://arxiv.org/pdf/1802.04687.pdf) (Kipf et al., 2018).
Most of the code in the solution has been adapted from https://github.com/ethanfetaya/NRI .

First, add this Drive folder to your own Google Drive account:
https://drive.google.com/open?id=10Awx22Z8vah5MxBrCSgQGuQaS2HdG2ae

Then run the setup cells below. The `ls` output should show one data folder and a `utils.py` file.

In [0]:
## SETUP
# Mount Google Drive under /content/drive so the shared course folder
# (data folder + utils.py) can be reached from this notebook.
from google.colab import drive
drive.mount('/content/drive')


In [0]:
# Work from the shared course folder so `from utils import ...` resolves and the
# data folder is visible; `!ls` lists its contents as a sanity check.
%cd /content/drive/My\ Drive/Summer_School
!ls

## Data, baselines and evaluations

In [0]:
from utils import load_data

# Batch size 1 is enough for the plotting cell below; the configuration cell
# reloads the data with the training batch size.
loaders, location_range, velocity_range = load_data(suffix='_springs5', batch_size=1)

### Plotting the input data

<span style="color:red">Find a good way to display the input data. Display both the particles and the relation matrix </span>

In [0]:
from matplotlib import pyplot as plt
import numpy as np

# The 'seaborn-notebook' style was renamed 'seaborn-v0_8-notebook' in
# matplotlib 3.6 and the old name was later removed; use whichever exists.
for style_name in ('seaborn-v0_8-notebook', 'seaborn-notebook'):
    if style_name in plt.style.available:
        plt.style.use(style_name)
        break

num_atoms = 5  # NOTE: a later cell also reads this global

# First batch of the training split. Trajectories are indexed
# [batch, atom, time step, dim]; relations hold one 0/1 entry per off-diagonal
# (i, j) pair of the atom grid, in row-major order.
trajectories, relations = next(iter(loaders['train']))
traj = np.asarray(trajectories[0])  # first simulation of the batch
num_timesteps = traj.shape[1]

# Rebuild the dense relation matrix (zero diagonal) from the off-diagonal list.
# Boolean-mask assignment fills the off-diagonal cells in row-major order.
interactions = np.zeros((num_atoms, num_atoms))
interactions[~np.eye(num_atoms, dtype=bool)] = np.asarray(relations[0])

fig, (ax_traj, ax_rel) = plt.subplots(1, 2, figsize=(10, 4.5), dpi=150)

# Marker size grows with time so the direction of motion is visible.
marker_sizes = 3 * np.sqrt(np.arange(num_timesteps))
for atom in range(num_atoms):
    ax_traj.scatter(traj[atom, :, 0], traj[atom, :, 1],
                    s=marker_sizes, alpha=0.5, label=f'atom {atom}')

# Draw a faint segment between every interacting pair at every time step.
for atom_a in range(num_atoms):
    for atom_b in range(atom_a + 1, num_atoms):
        if interactions[atom_a, atom_b] == 1:
            for pa, pb in zip(traj[atom_a], traj[atom_b]):
                ax_traj.plot([pa[0], pb[0]], [pa[1], pb[1]],
                             'k-', linewidth=1, alpha=0.2)
ax_traj.set(title='Trajectories (dims 0 and 1)', xlabel='x', ylabel='y')
ax_traj.legend(fontsize='small')

ax_rel.imshow(interactions, cmap='Greys', vmin=0, vmax=1)
ax_rel.set(title='Relation matrix', xlabel='atom', ylabel='atom',
           xticks=range(num_atoms), yticks=range(num_atoms))
plt.show()

## LSTM Baseline

In [0]:
from __future__ import division
from __future__ import print_function

import time
import pickle
import os
import datetime

import torch.optim as optim
from torch.optim import lr_scheduler

# --- Experiment configuration ---------------------------------------------
suffix = '_springs5'    # dataset suffix passed to load_data below
n_atoms =  5            # number of particles per simulation

epochs = 1              # epoch count for the LSTM baseline training loop

n_hidden = 256          # hidden size intended for the LSTM baseline
n_layers = 2            # LSTM layer count intended for the LSTM baseline

batch_size = 128
learning_rate = 5e-4    # Adam learning rate used by every model below
dropout = 0             # NOTE(review): the LSTM baseline cell passes 0.2 explicitly
temp = 0.5              # Gumbel-softmax temperature (NRI training loop)

timesteps = 49          # time steps per trajectory (encoder input = 4 * timesteps)
prediction_steps = 10   # burn-in used everywhere is timesteps - prediction_steps
valid_freq = 1          # evaluate on the test split every `valid_freq` epochs

var = 5e-5              # variance of the Gaussian likelihood in nll_gaussian

# Reload the data with the training batch size (the plotting cell used batch_size=1).
loaders, location_range, velocity_range = load_data(batch_size=batch_size, suffix=suffix)

### Model
<span style="color: red">
    Complete the following skeleton to implement the LSTM baseline described in the appendix of the paper.
    <ul>
        <li> step: outputs $x_{t+1} = x_t + \delta$ and the new hidden state </li>
        <li> forward: run step for $b$ <em>burn-in</em> steps with true data as input. Then predict the rest of the sequence  </li>
    </ul>
</span>

In [0]:
import torch
from torch import nn
# Import the public module; `from torch.functional import F` only worked because
# torch/functional.py happens to import it internally.
import torch.nn.functional as F


class RecurrentBaseline(nn.Module):
    """LSTM model for joint trajectory prediction.

    Each atom is encoded by a shared MLP. The encodings of all atoms are
    concatenated and fed to a single LSTM. A second MLP decodes the LSTM output
    into a per-atom delta, which is added to the current input.

    Args:
        n_in: input features per atom and time step.
        n_hid: hidden size per atom (the LSTM width is ``n_atoms * n_hid``).
        n_out: output features per atom. Because of the residual ``x + ins``
            in :meth:`step`, this must broadcast against ``n_in`` (normally
            equal to it).
        n_atoms: number of atoms in every simulation.
        n_layers: number of stacked LSTM layers.
        do_prob: dropout probability inside the input encoder.
    """

    def __init__(self, n_in, n_hid, n_out, n_atoms, n_layers, do_prob=0.):
        super(RecurrentBaseline, self).__init__()

        # Per-atom encoder: [..., n_in] -> [..., n_hid]
        self.pos_encoder = nn.Sequential(
            nn.Linear(n_in, n_hid),
            nn.ReLU(),
            nn.Dropout(p=do_prob),
            nn.Linear(n_hid, n_hid),
            nn.ReLU()
        )

        # One LSTM over the concatenation of all atoms' encodings:
        # n_atoms * n_hid -> n_atoms * n_hid, n_layers deep.
        self.rnn = nn.LSTM(n_atoms * n_hid, n_atoms * n_hid, n_layers)

        # Joint decoder: n_atoms * n_hid -> n_atoms * n_out (the per-atom delta).
        self.pos_decoder = nn.Sequential(
            nn.Linear(n_atoms * n_hid, n_atoms * n_hid),
            nn.ReLU(),
            nn.Linear(n_atoms * n_hid, n_atoms * n_out)
        )

    def step(self, ins, hidden=None):
        """Predict the next configuration from the current one.

        Args:
            ins: current state, [num_sims, n_atoms, n_in].
            hidden: LSTM ``(h, c)`` state from the previous step, or None.

        Returns:
            ``(ins + delta, new_hidden)`` where ``delta`` is viewed as
            [num_sims, n_atoms, n_out].
        """
        # [num_sims, n_atoms, n_hid] -> length-1 sequence [1, num_sims, n_atoms * n_hid]
        x = self.pos_encoder(ins).view(1, ins.size(0), -1)

        x, hidden = self.rnn(x, hidden)
        x = x[0, :, :]

        # Decode a delta per atom and add it to the input (residual prediction).
        x = self.pos_decoder(x).view(ins.size(0), ins.size(1), -1)
        x = x + ins

        return x, hidden

    def forward(self, inputs, burn_in_steps=1):
        """Roll the model over a sequence.

        Args:
            inputs: [num_sims, n_atoms, num_timesteps, n_in]; needs at least 2
                time steps.
            burn_in_steps: steps ``0 .. burn_in_steps`` (inclusive) are fed the
                true input; later steps are fed the previous prediction.

        Returns:
            [num_sims, n_atoms, num_timesteps - 1, n_out]; entry ``t`` is the
            prediction of time step ``t + 1``.

        Raises:
            ValueError: if ``inputs`` has fewer than 2 time steps.
        """
        if inputs.size(2) < 2:
            raise ValueError(
                "RecurrentBaseline needs at least 2 time steps, got %d" % inputs.size(2))

        outputs = []
        hidden = None

        for step in range(inputs.size(2) - 1):
            if step <= burn_in_steps:
                ins = inputs[:, :, step, :]
            else:
                ins = outputs[step - 1]

            output, hidden = self.step(ins, hidden)
            outputs.append(output)

        return torch.stack(outputs, dim=2)

### Eval

In [0]:
def test(data_loader, model):
    total_mse = 0
    counter = 0

    model.eval()
    for batch_idx, (inputs, relations) in enumerate(data_loader):
        inputs = inputs.cuda()
        output = model(inputs, burn_in_steps=timesteps-prediction_steps)
        target = inputs[:, :, 1:, :]

        output = output[:, :, timesteps-prediction_steps:timesteps, :]
        target = target[:, :, timesteps-prediction_steps:timesteps, :]
        total_mse += ((target - output) ** 2).sum().item()
        counter += inputs.shape[0]

    return total_mse / counter
    

In [0]:
# `tqdm_notebook` / `tnrange` are deprecated aliases; tqdm.notebook provides the
# same widgets under the names used in this notebook.
from tqdm.notebook import tqdm as tqdm_notebook, trange as tnrange

# 4 input/output features per atom. Sizes come from the configuration cell
# rather than being repeated as literals.
# NOTE(review): dropout 0.2 differs from the `dropout = 0` config value.
model = RecurrentBaseline(4, n_hidden, 4, n_atoms, n_layers, 0.2).cuda()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


def nll_gaussian(preds, target, variance, add_const=False):
    """Negative log-likelihood of ``target`` under N(preds, variance * I).

    Per element, the Gaussian NLL is
    ``(preds - target)**2 / (2 * variance) + 0.5 * log(2 * pi * variance)``.
    The log term does not depend on ``preds``, so it is dropped by default.
    Pass ``add_const=True`` to include it.

    The sum over all elements is divided by ``target.size(0) * target.size(1)``.
    For the [batch, atoms, time, dims] tensors in this notebook, that is the
    batch size times the number of atoms.

    Args:
        preds: predicted means.
        target: observed values, same shape as ``preds``.
        variance: fixed positive variance (a float).
        add_const: include the normalisation constant.
    """
    neg_log_p = (preds - target) ** 2 / (2 * variance)
    if add_const:
        import math
        neg_log_p = neg_log_p + 0.5 * math.log(2 * math.pi * variance)
    return neg_log_p.sum() / (target.size(0) * target.size(1))

# One epoch of training
def train(epoch):
    """Run one training epoch of the LSTM baseline.

    Reads the notebook-level ``model``, ``optimizer``, ``loaders``,
    ``timesteps``, ``prediction_steps`` and ``var``.

    Args:
        epoch: epoch index (unused; kept for the call signature).

    Returns:
        ``(mean Gaussian NLL, mean MSE)`` over the batches of the epoch.
    """
    loss_train = []
    mse_train = []
    device = next(model.parameters()).device

    model.train()
    # Default bar format: colorama's `Fore` is not imported in this notebook,
    # so a colored bar_format would raise NameError.
    with tqdm_notebook(loaders['train'], desc='training') as progress:
        for data, _relations in progress:
            data = data.to(device)

            output = model(data, burn_in_steps=timesteps - prediction_steps)

            target = data[:, :, 1:, :]
            loss = nll_gaussian(output, target, var)
            mse = F.mse_loss(output, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_train.append(loss.item())
            mse_train.append(mse.item())

            progress.set_postfix(loss=loss.item(), mse=mse.item())

    return np.mean(loss_train), np.mean(mse_train)

# Train the LSTM baseline. Disabled by default so that "Run All" goes straight
# to the NRI part; set RUN_LSTM_BASELINE = True to train and evaluate it.
RUN_LSTM_BASELINE = False
if RUN_LSTM_BASELINE:
    for epoch in tnrange(epochs):
        train_nll, train_mse = train(epoch)
        print(f"train : {train_nll} (NLL) / {train_mse} (MSE)")
        if (epoch + 1) % valid_freq == 0:
            test_mse = test(loaders['test'], model)
            print(f"  test : {test_mse} (MSE)")

## Neural Relational Inference model

The model has two parts: an encoder that estimates the relation matrix, and a decoder that produces a sequence given an estimate of the relation matrix.

### The encoder
The encoder equations from the paper are:
$${\bf h}^1_j = f_{emb}({\bf x}_j)$$
$$v\rightarrow e:\quad {\bf h}^1_{(i,j)} = f_e^1([{\bf h}^1_i, {\bf h}^1_j])$$
$$e\rightarrow v:\quad{\bf h}^2_j = f_v^1\big(\sum_{i \neq j}{\bf h}^1_{(i,j)}\big)$$
$$v\rightarrow e:\quad{\bf h}^2_{(i,j)} = f_e^2([{\bf h}_i^2, {\bf h}_j^2])$$

Finally, we do a logistic regression on ${\bf h}^2_{(i,j)}$ to obtain the probabilities of edge / non-edge.

We will represent all functions as multi-layer perceptrons.

Let $f$ be the matrix of features such that row $f_i$ is the feature vector for node $i$. The implementation challenge in the encoder is to efficiently concatenate the $f_i$, $f_j$. We do this using ```index_select(input, dim, indices)```.

Given an input of dimension $atoms \times d$, create two index tensors such that for
```python
    x = torch.index_select(input, 0, id1)
    y = torch.index_select(input, 0, id2)
```
we have $x_{i \cdot atoms + j} = input_i$ and $y_{i \cdot atoms + j} = input_j$.

In [0]:
n_atoms = 5
d = 2
# Row i of `features` is the constant vector [i, ..., i] of length d, which
# makes it easy to read off which node a gathered row came from.
features = torch.arange(n_atoms, dtype=torch.float32).unsqueeze(1).repeat(1, d)

# id1 = [0,...,0, 1,...,1, ...]: each node index repeated n_atoms times.
id1 = torch.arange(n_atoms).repeat_interleave(n_atoms)
# id2 = [0,1,...,n-1, 0,1,...,n-1, ...]: the node range tiled n_atoms times.
id2 = torch.arange(n_atoms).repeat(n_atoms)

We can now easily write the concatenation in the $v\rightarrow e$ step:

In [0]:
def v_to_e(x, id1, id2):
    """Build edge features by pairing rows of the node feature matrix.

    Row k of the result is ``x[id1[k]]`` followed by ``x[id2[k]]``
    (concatenated along dim 1).
    """
    left_rows = x.index_select(0, id1)
    right_rows = x.index_select(0, id2)
    return torch.cat((left_rows, right_rows), dim=1)

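Read and understand this implementation of the aggregation in the $e \rightarrow v$ step. Note that self-loops are included here; we will fix that later. Also note that the aggregator averages the edge features (each weight is $1/atoms$) instead of summing them as in the equation. For a fixed number of atoms, this only rescales the messages by a constant.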

In [0]:
# Row r gives weight 1 / n_atoms to columns r * n_atoms ... (r + 1) * n_atoms - 1.
# With the id1/id2 ordering above, these are the edges (r, j) for every j,
# self-edge (r, r) included. Multiplying by this matrix averages those edge rows.
aggregator = torch.eye(n_atoms).repeat_interleave(n_atoms, dim=1) * (1. / n_atoms)


def e_to_v(x, matrix):
    """Aggregate edge rows into node rows: returns ``matrix @ x``.

    With ``matrix = aggregator`` and ``x`` of shape [n_atoms**2, d], the
    result has shape [n_atoms, d].
    """
    return matrix @ x

To remove self-loops, we use another `index_select`. Given a tensor produced by the `v_to_e` function above, build an index tensor `id3` such that `index_select` with it returns all edges except the self-edges.

In [0]:
# In the row-major n_atoms x n_atoms edge grid, the self-edge (j, j) has flat
# index j * n_atoms + j = j * (n_atoms + 1). Keep every other index.
id3 = torch.LongTensor([
    idx for idx in range(n_atoms * n_atoms) if idx % (n_atoms + 1) != 0
])

# Sanity check: torch.index_select(v_to_e(features, id1, id2), 0, id3)
# should list every ordered pair (i, j) with i != j.

In [0]:
import torch.nn.functional as F

class MLP(nn.Module):
    """Two-layer fully-connected ELU network followed by batch normalisation.

    Args:
        n_in: input feature size.
        n_hid: hidden layer size.
        n_out: output feature size (also the BatchNorm1d width).
        do_prob: dropout probability applied after the first layer.
    """

    def __init__(self, n_in, n_hid, n_out, do_prob=0.):
        super().__init__()
        self.fc1 = nn.Linear(n_in, n_hid)
        self.fc2 = nn.Linear(n_hid, n_out)
        self.bn = nn.BatchNorm1d(n_out)
        self.dropout_prob = do_prob

    def batch_norm(self, inputs):
        """Apply BatchNorm1d to the last dim of a 3-D tensor by flattening the first two."""
        dim_a, dim_b = inputs.size(0), inputs.size(1)
        flat = inputs.view(dim_a * dim_b, -1)
        return self.bn(flat).view(dim_a, dim_b, -1)

    def forward(self, inputs):
        """[num_sims, num_things, num_features] -> [num_sims, num_things, n_out]."""
        hid = F.dropout(F.elu(self.fc1(inputs)), self.dropout_prob, training=self.training)
        return self.batch_norm(F.elu(self.fc2(hid)))



def ids_and_agg(n_atoms, no_self_edges=False, device='cuda'):
    """Index tensors and the averaging matrix for a fully connected graph.

    Args:
        n_atoms: number of nodes.
        no_self_edges: if True, the aggregator is sized for the
            ``n_atoms * (n_atoms - 1)`` edges left after selecting with
            ``id3``. Otherwise it is sized for all ``n_atoms ** 2`` edges.
        device: device of the returned tensors. The default ``'cuda'`` keeps
            the previous behaviour; pass ``'cpu'`` to run without a GPU.

    Returns:
        ``(id1, id2, aggregator, id3)``:

        * ``id1``: int64 [n_atoms**2]; each node index repeated n_atoms times.
        * ``id2``: int64 [n_atoms**2]; ``0 .. n_atoms-1`` tiled n_atoms times.
        * ``aggregator``: float32 [n_atoms, n_atoms * n_for_agg]. Row r holds
          ``1 / n_for_agg`` on columns ``r*n_for_agg .. (r+1)*n_for_agg - 1``
          and 0 elsewhere.
        * ``id3``: int64 flat indices of the off-diagonal (i != j) cells of the
          row-major n_atoms x n_atoms edge grid.
    """
    n_for_agg = (n_atoms - 1) if no_self_edges else n_atoms

    nodes = torch.arange(n_atoms, device=device)
    id1 = nodes.repeat_interleave(n_atoms)
    id2 = nodes.repeat(n_atoms)

    weight = 1. / n_for_agg if n_for_agg > 0 else 0.
    aggregator = torch.eye(n_atoms, device=device).repeat_interleave(n_for_agg, dim=1) * weight

    # Self-edge (j, j) sits at flat index j * (n_atoms + 1).
    flat = torch.arange(n_atoms * n_atoms, device=device)
    id3 = flat[flat % (n_atoms + 1) != 0]

    return id1, id2, aggregator, id3

class MLPEncoder(nn.Module):
    """NRI encoder: node MLP -> edge MLP -> node MLP -> edge MLP -> edge logits.

    Args:
        n_atoms: number of nodes per graph.
        n_in: flattened per-node input size (num_timesteps * num_dims).
        n_hid: hidden size of every MLP.
        n_out: number of logits per edge (edge types).
        do_prob: dropout probability inside each MLP.
    """

    def __init__(self, n_atoms, n_in, n_hid, n_out, do_prob=0.):
        super(MLPEncoder, self).__init__()
        self.mlp1 = MLP(n_in, n_hid, n_hid, do_prob)       # f_emb
        self.mlp2 = MLP(2 * n_hid, n_hid, n_hid, do_prob)  # f_e^1
        self.mlp3 = MLP(n_hid, n_hid, n_hid, do_prob)      # f_v^1
        self.mlp4 = MLP(3 * n_hid, n_hid, n_hid, do_prob)  # f_e^2, with skip input
        self.fc_out = nn.Linear(n_hid, n_out)

        # Plain attributes rather than registered buffers, so .to()/.cuda()
        # do not move them.
        (self.id1, self.id2,
         self.aggregator, self.id3) = ids_and_agg(n_atoms)

    def tile(self, x):
        """Edge features: ``x[:, id1]`` and ``x[:, id2]`` concatenated on dim 2."""
        pairs = (x.index_select(1, self.id1), x.index_select(1, self.id2))
        return torch.cat(pairs, 2)

    def aggregate(self, x):
        """Average edge rows per node with the precomputed aggregator."""
        return torch.matmul(self.aggregator, x)

    def forward(self, inputs):
        """[num_sims, num_atoms, num_timesteps, num_dims] -> per-edge logits.

        The last step drops the self-edges (``id3``), so the edge axis of the
        output has ``num_atoms * (num_atoms - 1)`` entries.
        """
        num_sims, num_nodes = inputs.size(0), inputs.size(1)
        node_h = self.mlp1(inputs.view(num_sims, num_nodes, -1))           # eq 1
        edge_skip = self.mlp2(self.tile(node_h))                            # eq 2
        node_h = self.mlp3(self.aggregate(edge_skip))                       # eq 3
        edge_h = self.mlp4(torch.cat((self.tile(node_h), edge_skip), dim=2))  # eq 4 (+ skip)
        return self.fc_out(edge_h).index_select(1, self.id3)

Before adding the decoder, let's verify that this MLP Encoder can at least overfit the training data.

In [0]:
from tqdm import tqdm_notebook, tnrange
from torch.nn import BCEWithLogitsLoss
n_atoms = 5
# Each atom's whole trajectory is flattened into the encoder input:
# 4 features per time step (n_dims in the final cell) times `timesteps` steps.
model = MLPEncoder(n_atoms, n_in=int(4 * timesteps), n_hid=256, n_out=2, do_prob=0.2).cuda()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

def edge_accuracy(preds, target):
    """Fraction of edges whose arg-max edge type equals the ground truth.

    :param preds: edge logits (or probabilities); the last dim indexes the edge type
    :param target: ground-truth edge types, viewable as ``preds`` without its last dim
    :return: accuracy in [0, 1] as a Python float
    """
    _, predicted = preds.max(-1)
    correct = predicted.float().eq(target.float().view_as(predicted)).sum().item()
    # Builtin float: the `np.float` alias was removed in NumPy 1.24.
    return float(correct) / (target.size(0) * target.size(1))

# One epoch of training
def train(epoch):
    """Run one supervised epoch of the encoder against the true relations.

    Reads the notebook-level ``model``, ``optimizer`` and ``loaders``.

    Args:
        epoch: epoch index (unused; kept for the call signature).

    Returns:
        ``(mean BCE loss, mean edge accuracy)`` over the batches.
    """
    loss_train = []
    acc_train = []
    criterion = nn.BCEWithLogitsLoss(reduction='mean')
    model.train()
    with tqdm_notebook(loaders['train'], 'training') as progress:
        for data, relations in progress:
            data, relations = data.cuda(), relations.cuda()
            logits = model(data)

            # One-hot target over the two edge types:
            # [..., 0] = "no edge" (relation 0), [..., 1] = "edge" (relation 1).
            target = torch.stack((relations == 0, relations == 1), dim=2).float()
            batch_loss = criterion(logits, target)

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            edge_acc = edge_accuracy(logits, relations)

            loss_train.append(batch_loss.item())
            acc_train.append(edge_acc)

            progress.set_postfix(loss=batch_loss.item(), acc=edge_acc)

    return np.mean(loss_train), np.mean(acc_train)

# Train the encoder alone for a few epochs as an overfitting sanity check.
encoder_epochs = 2
for epoch in tnrange(encoder_epochs):
    train_loss, train_acc = train(epoch)
    print(f"epoch {epoch}: loss {train_loss:.4f} / edge accuracy {train_acc:.3f}")

### The decoder

The equations of the recurrent decoder are:
$$
\begin{aligned} v \rightarrow e : \tilde{\mathbf{h}}_{(i, j)}^{t} &=\sum_{k} z_{i j, k} \tilde{f}_{e}^{k}\left(\left[\tilde{\mathbf{h}}_{i}^{t}, \tilde{\mathbf{h}}_{j}^{t}\right]\right) \\ e \rightarrow v : \operatorname{MSG}_{j}^{t} &=\sum_{i \neq j} \tilde{\mathbf{h}}_{(i, j)}^{t} \\ \tilde{\mathbf{h}}_{j}^{t+1} &=\operatorname{GRU}\left(\left[\operatorname{MSG}_{j}^{t}, \mathbf{x}_{j}^{t}\right], \tilde{\mathbf{h}}_{j}^{t}\right) \\ \boldsymbol{\mu}_{j}^{t+1} &=\mathbf{x}_{j}^{t}+f_{\text { out }}\left(\tilde{\mathbf{h}}_{j}^{t+1}\right) \\ p\left(\mathbf{x}^{t+1} | \mathbf{x}^{t}, \mathbf{z}\right) &=\mathcal{N}\left(\boldsymbol{\mu}^{t+1}, \sigma^{2} \mathbf{I}\right) \end{aligned}
$$

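We will use only one edge type for simplicity: only $z_{ij,1}$ ("edge present") weights the messages $\tilde{f}_e([\tilde{\mathbf{h}}_i^t, \tilde{\mathbf{h}}_j^t])$, and the "no edge" type sends no message. The last equation is handled by the loss. Complete the following skeleton code for the `RNNDecoder` module.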

In [0]:
class RNNDecoder(nn.Module):
    """Recurrent NRI decoder using a single (non-trivial) edge type.

    Args:
        n_dims: features per atom and time step.
        n_hid: size of the messages and of the GRU state.
        do_prob: dropout probability inside the edge MLP.
        num_atoms: number of atoms. Defaults to the notebook-level ``n_atoms``
            (the previous behaviour).
    """

    def __init__(self, n_dims, n_hid, do_prob=0., num_atoms=None):
        super(RNNDecoder, self).__init__()
        # f_e: Linear, Tanh, Dropout, Linear, Tanh
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * n_hid, n_hid), nn.Tanh(), nn.Dropout(do_prob),
            nn.Linear(n_hid, n_hid), nn.Tanh()
        )
        # GRU cell: input = x_j^t, hidden = aggregated message of node j
        self.gru = nn.GRUCell(input_size=n_dims, hidden_size=n_hid)
        # f_out: Linear, ReLU, Linear, ReLU, Linear
        self.decoder = nn.Sequential(
            nn.Linear(n_hid, n_hid), nn.ReLU(),
            nn.Linear(n_hid, n_hid), nn.ReLU(),
            nn.Linear(n_hid, n_dims)
        )
        self.n_dims = n_dims
        self.n_hid = n_hid

        atoms = n_atoms if num_atoms is None else num_atoms
        self.id1, self.id2, self.aggregator, self.id3 = ids_and_agg(atoms, True)

    def single_step_forward(self, inputs, edges, hidden):
        """One decoder step.

        Args:
            inputs: x^t, [batch, n_atoms, n_dims].
            edges: edge-type weights, [batch, n_atoms * (n_atoms - 1), 2].
                Only column 1 ("edge present") is used.
            hidden: h^t, [batch, n_atoms, n_hid].

        Returns:
            ``(x^t + f_out(h^{t+1}), h^{t+1})``.
        """
        # [h_i, h_j] for every ordered pair, then drop the self-pairs.
        pairs = torch.cat([
            torch.index_select(hidden, 1, self.id1),
            torch.index_select(hidden, 1, self.id2)
        ], 2)
        pairs = torch.index_select(pairs, 1, self.id3)

        # v -> e: weight the *output* of f_e by z_ij, as in the equation.
        # Weighting its input instead would let absent edges still send
        # tanh(bias) messages.
        messages = self.edge_mlp(pairs) * edges[:, :, 1:2]

        # e -> v: for each node r, average the messages of the edges (r, j), j != r.
        agg_msg = self.aggregator @ messages

        next_hidden = self.gru(
            inputs.contiguous().view(-1, self.n_dims),
            agg_msg.reshape(-1, self.n_hid)
        ).reshape(hidden.shape)

        output = self.decoder(next_hidden) + inputs

        return output, next_hidden

    def forward(self, data, edges, burn_in_steps=1):
        """Roll the decoder over a sequence.

        Args:
            data: [batch, n_atoms, n_timesteps, n_dims] ground-truth trajectories.
            edges: [batch, n_atoms * (n_atoms - 1), 2] edge-type weights.
            burn_in_steps: steps ``0 .. burn_in_steps`` (inclusive) are fed the
                ground truth; later steps are fed the previous prediction.

        Returns:
            [batch, n_atoms, n_timesteps - 1, n_dims]; entry t predicts step t + 1.
        """
        inputs = data.transpose(1, 2).contiguous()

        # Zero initial state on the data's device/dtype (was a hard-coded .cuda()).
        hidden = inputs.new_zeros(inputs.size(0), inputs.size(2), self.n_hid)
        pred_all = []

        for step in range(inputs.size(1) - 1):
            if step <= burn_in_steps:
                ins = inputs[:, step, :, :]
            else:
                ins = pred_all[step - 1]

            pred, hidden = self.single_step_forward(ins, edges, hidden)
            pred_all.append(pred)

        preds = torch.stack(pred_all, dim=1)

        return preds.transpose(1, 2)


### Minimizing the negative ELBO

Our goal is to minimize the negative ELBO:
$$
\mathcal{L}=-\mathbb{E}_{q_{\phi}(\mathbf{z} | \mathbf{x})}\left[\log p_{\theta}(\mathbf{x} | \mathbf{z})\right]+\mathrm{KL}\left[q_{\phi}(\mathbf{z} | \mathbf{x}) \| p_{\theta}(\mathbf{z})\right]
$$

Complete the following skeleton code to compute the loss.

### Putting it all together

In [0]:
from torch.nn.functional import gumbel_softmax, softmax

def kl_categorical_uniform(preds, num_atoms, add_const=False, eps=1e-16):
    """KL(q || uniform) for categorical edge posteriors, normalised like nll_gaussian.

    For one edge with probabilities p over K edge types,
    ``KL(p || U_K) = sum_k p_k log p_k + log K``. The ``log K`` term does not
    depend on ``preds``, so it is omitted unless ``add_const=True``; it is then
    added once per edge. The total is divided by ``num_atoms * preds.size(0)``.

    Args:
        preds: edge-type probabilities; the last dim indexes the K edge types.
        num_atoms: number of atoms (used only for normalisation).
        add_const: include ``log K`` per edge so the result is the true KL.
        eps: added inside the log to avoid log(0).
    """
    kl_div = preds * torch.log(preds + eps)
    total = kl_div.sum()
    if add_const:
        import math
        num_types = preds.size(-1)
        num_edges = preds.numel() // num_types
        total = total + num_edges * math.log(num_types)
    return total / (num_atoms * preds.size(0))

def train(data_loader, optimizer, encoder, decoder):
    """One epoch of NRI training on the negative ELBO (reconstruction NLL + KL).

    Reads the notebook-level ``temp``, ``timesteps``, ``prediction_steps`` and
    ``var``.

    Args:
        data_loader: yields ``(trajectories, relations)`` batches.
        optimizer: optimizer over the encoder and decoder parameters.
        encoder: maps trajectories to edge logits with 2 edge types.
        decoder: reconstructs the trajectories from the data and edge weights.

    Returns:
        Mean loss over the epoch.
    """
    epoch_losses = []

    encoder.train()
    decoder.train()
    burn_in = timesteps - prediction_steps
    with tqdm_notebook(data_loader, desc='training') as bar:
        for data, relations in bar:
            data = data.cuda()
            relations = relations.cuda()

            logits = encoder(data)

            # Relaxed (soft) Gumbel-softmax sample of the edge types, one row per edge.
            flat_edges = gumbel_softmax(logits.view(-1, 2), tau=temp, hard=False)
            edges = flat_edges.view(logits.shape)

            output = decoder(data, edges, burn_in_steps=burn_in)

            nll = nll_gaussian(output, data[:, :, 1:, :], var)
            kl = kl_categorical_uniform(softmax(logits, 2), output.shape[1])
            loss = nll + kl

            # NOTE(review): the encoder gets no label supervision here, so its
            # edge types need not line up with the labels' 0/1 convention.
            edge_acc = edge_accuracy(logits, relations)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_losses.append(loss.item())
            bar.set_postfix(loss=loss.item(), nll=nll.item(), kl=kl.item(), acc=edge_acc)

    return np.mean(epoch_losses)

dropout = 0
n_dims = 4
hidden = 256

# Use the notebook's `n_atoms`. The previous `num_atoms` only existed as a side
# effect of the plotting cell, so this cell failed whenever that cell was skipped.
encoder = MLPEncoder(n_atoms, int(n_dims * timesteps), hidden, 2, dropout).cuda()
decoder = RNNDecoder(n_dims, hidden, dropout).cuda()

optimizer = optim.Adam(
    list(encoder.parameters()) + list(decoder.parameters()),
    lr=learning_rate
)

nri_epochs = 10
for e in range(nri_epochs):
    loss = train(loaders['train'], optimizer, encoder, decoder)
    print(f"epoch {e}: {loss}")