### Neural Probabilistic Physical Engine

In [1]:
import os
import numpy as np

In [47]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from torch_scatter import scatter_mean
from torch_geometric.nn import MetaLayer
from torch_geometric.data import Data

In [32]:
def create_fully_connected_edge_index(n_objects):
    """Build the edge list of a fully connected directed graph without self-loops.

    Args:
        n_objects: number of nodes in the graph.

    Returns:
        A ``torch.long`` tensor of shape ``[2, n_objects * (n_objects - 1)]``.
        Column ``k`` holds the k-th ordered pair ``(i, j)`` with ``i != j``,
        enumerated with ``i`` as the outer index and ``j`` as the inner one;
        ``i`` is stored in row 0 and ``j`` in row 1.
    """
    ordered_pairs = [
        (first, second)
        for first in range(n_objects)
        for second in range(n_objects)
        if first != second
    ]
    edge_index = torch.zeros((2, n_objects * (n_objects - 1)), dtype=torch.long)
    for slot, (first, second) in enumerate(ordered_pairs):
        edge_index[0, slot] = first
        edge_index[1, slot] = second
    return edge_index

In [33]:
# Scratch cell: edge list of a fully connected 3-node graph (shape [2, 6]).
edge_index = create_fully_connected_edge_index(3)

In [13]:
# Scratch cell: all-zero placeholder tensor of shape [3, 5].
x = torch.zeros((3, 5))

In [19]:
# Scratch cell: bundle `x` and `edge_index` into a torch_geometric Data object.
data = Data(x=x, edge_index=edge_index)

In [38]:
class EdgeModel(nn.Module):
    """Edge update: a two-layer MLP over [source node, destination node, edge] features."""

    def __init__(self, node_dims, edge_dims, u_dims, hidden_size=32):
        """Create the edge MLP.

        Args:
            node_dims: feature width of one node.
            edge_dims: feature width of one edge (input and output).
            u_dims: accepted for a uniform constructor signature; not used here.
            hidden_size: width of the hidden layer.
        """
        super().__init__()
        self.edge_mlp = Seq(
            Lin(node_dims * 2 + edge_dims, hidden_size),
            ReLU(),
            Lin(hidden_size, edge_dims),
        )

    def forward(self, src, dest, edge_attr, u, batch):
        """Return updated edge features.

        Args:
            src, dest: features of each edge's two endpoints, [E, F_x].
            edge_attr: current edge features, [E, F_e].
            u: global features; ignored by this model.
            batch: graph assignment vector; ignored by this model.

        Returns:
            Tensor whose last dimension is ``edge_dims``.
        """
        edge_inputs = torch.cat((src, dest, edge_attr), dim=-1)
        return self.edge_mlp(edge_inputs)

class NodeModel(torch.nn.Module):
    """Node update: average the messages arriving at each node, then update the node.

    A message is computed per edge from the features of the node at
    ``edge_index[0]`` and the edge features; messages are mean-aggregated at
    the node given by ``edge_index[1]`` and concatenated with that node's own
    features before the final MLP.
    """

    def __init__(self, node_dims, edge_dims, u_dims, hidden_size=32):
        """Create the message MLP and the node-update MLP.

        Args:
            node_dims: feature width of one node (input and output).
            edge_dims: feature width of one edge.
            u_dims: accepted for a uniform constructor signature; not used here.
            hidden_size: width of the hidden layers and of the aggregated message.
        """
        super(NodeModel, self).__init__()
        mlp_1_input_size = node_dims+edge_dims
        self.node_mlp_1 = Seq(Lin(mlp_1_input_size, hidden_size), ReLU(), Lin(hidden_size, hidden_size))
        mlp_2_input_size = node_dims+hidden_size
        self.node_mlp_2 = Seq(Lin(mlp_2_input_size, hidden_size), ReLU(), Lin(hidden_size, node_dims))

    def forward(self, x, edge_index, edge_attr, u, batch):
        """Return updated node features with the same leading shape as ``x``.

        Args:
            x: node features, [N, F_x] (or [B, N, F_x] for a dense batch that
                shares one ``edge_index``).
            edge_index: [2, E] with max entry N - 1.
            edge_attr: edge features, [E, F_e] (or [B, E, F_e]).
            u: global features; ignored by this model.
            batch: graph assignment vector; ignored by this model.
        """
        row, col = edge_index
        # Nodes live on the second-to-last axis in both supported layouts.
        # Gather and scatter must use the SAME axis: the previous code gathered
        # along axis 0 (x[row]) but scattered along axis 1.
        node_axis = x.dim() - 2
        out = torch.cat([x.index_select(node_axis, row), edge_attr], dim=-1)
        out = self.node_mlp_1(out)
        # dim_size keeps one output row per node even if a node receives no edge.
        out = scatter_mean(out, col, dim=node_axis, dim_size=x.size(node_axis))
        out = torch.cat([x, out], dim=-1)
        return self.node_mlp_2(out)

class GlobalModel(torch.nn.Module):
    """Global update: mean-pool node features per graph and pass them through an MLP."""

    def __init__(self, node_dims, u_dims, hidden_size=32):
        """Create the global MLP.

        Args:
            node_dims: feature width of one node.
            u_dims: width of the produced global feature vector.
            hidden_size: width of the hidden layer.
        """
        super(GlobalModel, self).__init__()
        input_size = node_dims
        self.global_mlp = Seq(Lin(input_size, hidden_size), ReLU(), Lin(hidden_size, u_dims))

    def forward(self, x, edge_index, edge_attr, u, batch):
        """Return one global feature vector per graph.

        Args:
            x: node features. Either flat [N, F_x] (nodes of all graphs stacked,
                with ``batch`` telling which graph each node belongs to) or a
                dense batch [B, N, F_x].
            edge_index: ignored by this model.
            edge_attr: ignored by this model.
            u: ignored by this model.
            batch: [N] with max entry B - 1, or None. Only consulted for flat
                2-D ``x``; None means all nodes belong to a single graph.

        Returns:
            [B, u_dims] (B == 1 for flat input without ``batch``).
        """
        if x.dim() != 2:
            # Dense layout: nodes are on axis 1 (unchanged original behaviour).
            out = torch.mean(x, dim=1, keepdim=False)
        elif batch is None:
            # Flat layout, single graph: pool over the node axis, not the
            # feature axis (dim=1 here would average features away).
            out = torch.mean(x, dim=0, keepdim=True)
        else:
            # Flat layout, several graphs: pool nodes per graph id.
            out = scatter_mean(x, batch, dim=0)
        return self.global_mlp(out)

In [77]:
class GNNLSTMRecgnition(nn.Module):
    """Recognition network: one graph-network pass per time step, then LSTMs over time.

    Every (experiment, step) pair is treated as a fully connected graph. A
    single ``MetaLayer`` pass yields node, edge and global features per step;
    three LSTMs summarise each node / edge / graph over time; linear heads map
    the last LSTM output to a 2-dimensional mean and log-variance.
    """

    def __init__(self, hidden_size=16):
        """Build the graph layer, the three LSTMs and the six linear heads.

        Args:
            hidden_size: width of the GNN hidden layers, of the edge and global
                features, and of the LSTM hidden states.
        """
        super().__init__()
        ## GNN ##
        self.node_dims = 7+2 #X+y
        self.hidden_size=hidden_size
        self.edge_dims=hidden_size
        self.u_dims=hidden_size
        edge_model = EdgeModel(self.node_dims, self.edge_dims, self.u_dims, self.hidden_size)
        node_model = NodeModel(self.node_dims, self.edge_dims, self.u_dims, self.hidden_size)
        global_model = GlobalModel(self.node_dims, self.u_dims, self.hidden_size)
        self.op = MetaLayer(edge_model, node_model, global_model)
        ## LSTM ##
        self.edge_rnn = nn.LSTM(self.edge_dims, self.hidden_size, batch_first=True)
        self.node_rnn = nn.LSTM(self.node_dims, self.hidden_size, batch_first=True)
        self.global_rnn = nn.LSTM(self.u_dims, self.hidden_size, batch_first=True)
        ## Linear ##
        # *_fc_1 produce the means, *_fc_2 the log-variances (see forward).
        self.edge_fc_1 = nn.Linear(self.hidden_size, 2)
        self.edge_fc_2 = nn.Linear(self.hidden_size, 2)
        self.node_fc_1 = nn.Linear(self.hidden_size, 2)
        self.node_fc_2 = nn.Linear(self.hidden_size, 2)
        self.global_fc_1 = nn.Linear(self.hidden_size, 2)
        self.global_fc_2 = nn.Linear(self.hidden_size, 2)

    def forward(self, x):
        """Encode a batch of trajectories into latent statistics.

        Args:
            x: [n_experiments, steps, nodes, node_dims]; ``node_dims`` must
                equal ``self.node_dims``.

        Returns:
            ``[z_node_mu, z_node_logvar, z_edge_mu, z_edge_logvar,
            z_global_mu, z_global_logvar]`` with shapes [bs, nodes, 2],
            [bs, nodes, 2], [bs, edges, 2], [bs, edges, 2], [bs, 2], [bs, 2].
        """
        bs, steps, nodes, node_dims = x.size()
        n_graphs = bs * steps
        # Keep the index tensors on the same device as the data.
        edge_index = create_fully_connected_edge_index(nodes).to(x.device)
        edges = edge_index.size(1)

        # MetaLayer indexes nodes along dim 0, so the bs*steps graphs are packed
        # as one disjoint union: node ids of graph g are shifted by g * nodes and
        # `batch` records the graph each node belongs to.
        # NOTE(review): this requires the edge/node/global models to accept the
        # flat [N, F] layout together with a `batch` vector.
        graph_ids = torch.arange(n_graphs, device=x.device)
        node_offsets = (graph_ids * nodes).view(1, -1, 1)
        batched_edge_index = (edge_index.unsqueeze(1) + node_offsets).reshape(2, -1)
        batch = graph_ids.repeat_interleave(nodes)
        x_flat = x.reshape(-1, node_dims)
        # The input carries no edge features; start from zeros so the edge
        # model has something to concatenate with the endpoint features.
        init_edge_attr = x.new_zeros(n_graphs * edges, self.edge_dims)

        node_attr, edge_attr, global_attr = self.op(
            x_flat, batched_edge_index, edge_attr=init_edge_attr, u=None, batch=batch)

        # Regroup as one time series per node / edge / graph. After permute the
        # tensors are non-contiguous, so reshape (not view) is required.
        node_attr = node_attr.view(bs, steps, nodes, self.node_dims).permute(0,2,1,3)
        edge_attr = edge_attr.view(bs, steps, edges, self.edge_dims).permute(0,2,1,3)
        global_attr = global_attr.view(bs, steps, self.u_dims)
        node_attr = node_attr.reshape(-1, steps, self.node_dims)
        edge_attr = edge_attr.reshape(-1, steps, self.edge_dims)
        # RNN forward; only the output at the last time step is used.
        node_out, _ = self.node_rnn(node_attr)
        edge_out, _ = self.edge_rnn(edge_attr)
        global_out, _ = self.global_rnn(global_attr)
        node_out  = node_out[:,-1]
        edge_out = edge_out[:,-1]
        global_out = global_out[:,-1]
        # FC forward
        z_node_mu = self.node_fc_1(node_out).view(bs, nodes, -1)
        z_node_logvar = self.node_fc_2(node_out).view(bs, nodes, -1)
        z_edge_mu = self.edge_fc_1(edge_out).view(bs, edges, -1)
        z_edge_logvar = self.edge_fc_2(edge_out).view(bs, edges, -1)
        z_global_mu = self.global_fc_1(global_out).view(bs, -1)
        z_global_logvar = self.global_fc_2(global_out).view(bs, -1)
        return [z_node_mu, z_node_logvar, z_edge_mu, z_edge_logvar, z_global_mu, z_global_logvar]

In [78]:
class RelationalModel(nn.Module):
    """Four-layer ReLU MLP applied independently to every relation."""

    def __init__(self, input_size, output_size, hidden_size):
        """Stack four Linear layers, each followed by a ReLU.

        Args:
            input_size: feature width of one relation.
            output_size: width of the produced effect vector.
            hidden_size: width of the three hidden layers.
        """
        super().__init__()
        self.output_size = output_size

        widths = [input_size, hidden_size, hidden_size, hidden_size, output_size]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())  # the output layer is rectified as well
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        '''
        Args:
            x: [batch_size, n_relations, input_size]
        Returns:
            [batch_size, n_relations, output_size]
        '''
        n_batch, n_rel, n_feat = x.size()
        # Fold batch and relation axes together so the MLP sees plain rows.
        effects = self.layers(x.view(-1, n_feat))
        return effects.view(n_batch, n_rel, self.output_size)

In [79]:
class ObjectModel(nn.Module):
    """Two-layer MLP that maps per-object features to a 2-vector (speedX, speedY)."""

    def __init__(self, input_size, hidden_size):
        """Create the object MLP.

        Args:
            input_size: feature width of one object.
            hidden_size: width of the hidden layer.
        """
        super().__init__()

        to_hidden = nn.Linear(input_size, hidden_size)
        to_speed = nn.Linear(hidden_size, 2)  # speedX and speedY
        self.layers = nn.Sequential(to_hidden, nn.ReLU(), to_speed)

    def forward(self, x):
        '''
        Args:
            x: [batch_size, n_objects, input_size]
        Returns:
            [batch_size * n_objects, 2] speedX and speedY
        '''
        # Fold batch and object axes together; the result stays flattened.
        per_object = x.view(-1, x.size(2))
        return self.layers(per_object)

In [80]:
# Scratch cell: a 3x3 zero matrix, and a 1x3 row of ones made by tiling a
# 1x1 tensor three times along dim 1.
a = torch.zeros((3,3))
b = torch.ones((1,1)).repeat(1,3)

In [76]:
class InteractionNetworkGenerator(nn.Module):
    """Generator: an interaction-network step conditioned on node/edge/global latents.

    For every (experiment, step) graph the relational model turns each ordered
    node pair plus its edge latent into an effect; effects are summed at the
    second node of the pair, and the object model maps
    [node features, global latent, summed effects] to a 2-vector per node.
    """

    def __init__(self, effect_dims=4, hidden_size=128):
        """Build the relational and object models.

        Args:
            effect_dims: width of the per-relation effect vector.
            hidden_size: hidden width of both sub-models.
        """
        super().__init__()
        # setup the two linear transformations used
        self.node_dims = 7+2 #X+z_node
        self.z_edge_dims = 2
        self.z_u_dims = 2
        self.effect_dims = effect_dims
        self.hidden_size = hidden_size
        self.relational_model = RelationalModel(2*self.node_dims + self.z_edge_dims,
                                                self.effect_dims, self.hidden_size)
        self.object_model     = ObjectModel(self.node_dims+self.effect_dims+self.z_u_dims,
                                            self.hidden_size)

    def forward(self, x, z):
        """Predict a 2-vector per node and step.

        Args:
            x: [bs, steps, nodes, F]; F plus the node-latent width must equal
                ``self.node_dims``.
            z: sequence ``(z_node, z_edge, z_global)`` with shapes
                [bs, nodes, 2], [bs, edges, 2] and [bs, 2]; edges follow the
                order of ``create_fully_connected_edge_index(nodes)``.

        Returns:
            [bs, steps, nodes, 2]
        """
        bs, steps, nodes, _ = x.size()
        edge_index = create_fully_connected_edge_index(nodes).to(x.device)
        z_node, z_edge, z_global = z

        # The latents are constant over time: broadcast them to every step.
        # Graphs are flattened as g = b * steps + s, so each experiment's latent
        # must be repeated `steps` times consecutively (repeat_interleave);
        # tiling with repeat(steps, ...) would pair latents with the wrong
        # experiment whenever bs > 1.
        z_node = z_node.unsqueeze(1).expand(-1, steps, -1, -1)
        z_edge = z_edge.repeat_interleave(steps, dim=0)
        z_global = z_global.unsqueeze(1).repeat_interleave(steps, dim=0).expand(-1, nodes, -1)
        x = torch.cat((x, z_node), dim=-1).reshape(bs*steps, nodes, -1)

        # [E, N] selection matrices for the first / second node of each pair.
        # one_hot returns int64, so cast to x's dtype before multiplying;
        # num_classes pins the width to `nodes`.
        sender_onehot = F.one_hot(edge_index[0], num_classes=nodes).to(x.dtype)
        receiver_onehot = F.one_hot(edge_index[1], num_classes=nodes).to(x.dtype)

        # matmul broadcasts the 2-D matrices over the bs*steps graphs, so they
        # do not need to be materialised once per graph.
        senders = torch.matmul(sender_onehot, x)        # [bs*steps, E, node_dims]
        receivers = torch.matmul(receiver_onehot, x)    # [bs*steps, E, node_dims]
        effects = self.relational_model(torch.cat([senders, receivers, z_edge], -1))
        # Sum the effects of all relations whose second node is a given node.
        effect_receivers = torch.matmul(receiver_onehot.t(), effects)
        # predicted shape [bs*steps*nodes, 2]
        predicted = self.object_model(torch.cat([x, z_global, effect_receivers], -1))
        return predicted.view(bs, steps, nodes, 2)

In [None]:
class PhysicsVAE(nn.Module):
    """Variational auto-encoder pairing GNNLSTMRecgnition with InteractionNetworkGenerator.

    Args:
        z_dim: stored as ``self.z_dim``; only read by ``model`` below.
        use_cuda: when True, move all parameters to the GPU.
    """
    def __init__(self, z_dim=2, use_cuda=False):
        super().__init__()
        # create the encoder and decoder networks
        self.encoder = GNNLSTMRecgnition()
        self.decoder = InteractionNetworkGenerator()

        if use_cuda:
            # calling cuda() here will put all the parameters of
            # the encoder and decoder networks into gpu memory
            self.cuda()
        self.use_cuda = use_cuda
        self.z_dim = z_dim

    def encode(self, xy):
        """Run the encoder; returns its list of six latent statistics."""
        return self.encoder(xy)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, x, z):
        """Run the decoder on inputs ``x`` and latent samples ``z``."""
        return self.decoder(x, z)

    def forward(self, x, y):
        """Encode ``(x, y)``, sample the latents and decode.

        Args:
            x: decoder inputs; concatenated with ``y`` along the last axis to
                form the encoder input.
            y: targets, same leading shape as ``x``.

        Returns:
            ``(prediction, mu, logvar)`` where ``mu`` and ``logvar`` are lists
            ordered ``[node, edge, global]``.
        """
        xy = torch.cat((x, y), dim=-1)
        (z_node_mu, z_node_logvar,
         z_edge_mu, z_edge_logvar,
         z_global_mu, z_global_logvar) = self.encode(xy)
        z_node = self.reparameterize(z_node_mu, z_node_logvar)
        z_edge = self.reparameterize(z_edge_mu, z_edge_logvar)
        z_global = self.reparameterize(z_global_mu, z_global_logvar)
        z_sample = [z_node, z_edge, z_global]
        # `mu` / `logvar` were previously undefined names here (NameError);
        # return the statistics of all three latent groups instead.
        mu = [z_node_mu, z_edge_mu, z_global_mu]
        logvar = [z_node_logvar, z_edge_logvar, z_global_logvar]
        return self.decode(x, z_sample), mu, logvar

    # NOTE(review): `model`, `guide` and `reconstruct_img` below use `pyro` and
    # `dist`, which are not imported in the visible part of this file, call the
    # decoder with a single argument, unpack two values from an encoder that
    # returns six, and reshape to 784 features. They look like leftovers from an
    # image-VAE example and do not match this class's encoder/decoder -- confirm
    # whether they should be adapted or removed.

    # define the model p(x|z)p(z)
    def model(self, x):
        # register PyTorch module `decoder` with Pyro
        pyro.module("decoder", self.decoder)
        with pyro.plate("data", x.shape[0]):
            # setup hyperparameters for prior p(z)
            z_loc = x.new_zeros(torch.Size((x.shape[0], self.z_dim)))
            z_scale = x.new_ones(torch.Size((x.shape[0], self.z_dim)))
            # sample from prior (value will be sampled by guide when computing the ELBO)
            z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
            # decode the latent code z
            loc_img = self.decoder.forward(z)
            # score against actual images
            pyro.sample("obs", dist.Bernoulli(loc_img).to_event(1), obs=x.reshape(-1, 784))

    # define the guide (i.e. variational distribution) q(z|x)
    def guide(self, x):
        # register PyTorch module `encoder` with Pyro
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", x.shape[0]):
            # use the encoder to get the parameters used to define q(z|x)
            z_loc, z_scale = self.encoder.forward(x)
            # sample the latent code z
            pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

    # define a helper function for reconstructing images
    def reconstruct_img(self, x):
        # encode image x
        z_loc, z_scale = self.encoder(x)
        # sample in latent space
        z = dist.Normal(z_loc, z_scale).sample()
        # decode the image (note we don't sample in image space)
        loc_img = self.decoder(z)
        return loc_img