## Neural Probablistic Physics Engine

In [24]:
from IPython.core.debugger import set_trace

In [271]:
import os
import numpy as np 

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from torch_scatter import scatter_mean

In [3]:
from torch.utils.data import Dataset

In [183]:
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = 'cuda'

In [64]:
def create_fully_connected_edge_index(n_objects):
    n_relations  = n_objects * (n_objects - 1)
    edge_index = torch.zeros((2, n_relations), dtype=torch.long)
    count = 0
    for i in range(n_objects):
        for j in range(n_objects):
            if(i != j):
                edge_index[0, count] = i
                edge_index[1, count] = j
                count += 1
    return edge_index

In [294]:
class BallSimulationDataset(Dataset):
    def __init__(self, root, raw_folder_name, processed_folder_name, use_cuda=False):
        super().__init__()
        self.root = root
        self.raw_folder = os.path.join(root, raw_folder_name)
        self.processed_folder = os.path.join(root, processed_folder_name)
        if not os.path.exists(self.raw_folder):
            os.mkdir(self.raw_folder)
        if not os.path.exists(self.processed_folder):
            os.mkdir(self.processed_folder)
        self.filenames = [fn.split(".")[0] for fn in os.listdir(self.raw_folder) if "npz" in fn]
        self._process_raw_files()
        if use_cuda:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = 'cpu'
        
    def _process_raw_files(self):
        for filename in self.filenames:
            pt_fn = os.path.join(self.processed_folder, filename+'.pt')
            if not os.path.exists(pt_fn):
                raw_file_path = os.path.join(self.raw_folder, filename+'.npz')
                batch = np.load(raw_file_path)
                batch_x = torch.Tensor(batch["X"])
                batch_y = torch.Tensor(batch["Y"])
                n_objects = batch_x.shape[2]
                edge_index = create_fully_connected_edge_index(n_objects)
                torch.save((batch_x, batch_y, edge_index), pt_fn)
    
    def __getitem__(self, idx):
        filename = self.filenames[idx]
        pt_fn = os.path.join(self.processed_folder, filename+'.pt')
        batch = torch.load(pt_fn)
        batch = [item.to(self.device) for item in batch]
        return batch
    
    def __len__(self):
        return len(self.filenames)

In [201]:
root = "/media/data/pymunk_dataset/"
raw_folder_name = "raw"
processed_folder_name = "processed"
dataset = BallSimulationDataset(root, raw_folder_name, processed_folder_name, use_cuda=True)

In [202]:
batch_x, batch_y, edge_index = dataset[1]

In [188]:
class EdgeModel(nn.Module):
    def __init__(self, node_dims, edge_dims, u_dims, hidden_size=32):
        super().__init__()
        input_size = 2*node_dims
        self.edge_mlp = Seq(Lin(input_size, hidden_size), ReLU(), Lin(hidden_size, edge_dims))

    def forward(self, src, dest):
        # source, target: [E, F_x], where E is the number of edges.
        # edge_attr: [E, F_e]
        # u: [B, F_u], where B is the number of graphs.
        # batch: [E] with max entry B - 1.
        out = torch.cat([src, dest], dim=-1)
        return self.edge_mlp(out)

class NodeModel(torch.nn.Module):
    def __init__(self, node_dims, edge_dims, u_dims, hidden_size=32):
        super().__init__()
        mlp_1_input_size = node_dims+edge_dims
        self.node_mlp_1 = Seq(Lin(mlp_1_input_size, hidden_size), ReLU(), Lin(hidden_size, hidden_size))
        mlp_2_input_size = node_dims+hidden_size
        self.node_mlp_2 = Seq(Lin(mlp_2_input_size, hidden_size), ReLU(), Lin(hidden_size, node_dims))

    def forward(self, x, edge_index, edge_attr):
        # x: [bs, N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [bs, E, F_e]
        # u: [B, F_u]
        # batch: [N] with max entry B - 1.
        row, col = edge_index
        out = torch.cat([x[:,row,:], edge_attr], dim=-1)
        out = self.node_mlp_1(out)
        out = scatter_mean(out, col, dim=1, dim_size=x.size(1))
        out = torch.cat([x, out], dim=-1)
        return self.node_mlp_2(out)

class GlobalModel(torch.nn.Module):
    def __init__(self, node_dims, u_dims, hidden_size=32):
        super().__init__()
        input_size = node_dims
        self.global_mlp = Seq(Lin(input_size, hidden_size), ReLU(), Lin(hidden_size, u_dims))

    def forward(self, x):
        # x: [N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [E, F_e]
        # u: [B, F_u]
        # batch: [N] with max entry B - 1.
        out = torch.mean(x, dim=1, keepdim=False)
        #out = scatter_mean(x, batch, dim=0)
        return self.global_mlp(out)

In [189]:
class MetaLayer(torch.nn.Module):
    """A meta layer for building any kind of graph network, inspired by the
    Relational Inductive Biases, Deep Learning, and Graph Networks
    <https://arxiv.org/abs/1806.01261>`_ paper.

    Args:
        edge_model (Module, optional): A callable which updates a graph's edge
            features based on its source and target node features, its current
            edge features and its global features. (default: :obj:`None`)
        node_model (Module, optional): A callable which updates a graph's node
            features based on its current node features, its graph
            connectivity, its edge features and its global features.
            (default: :obj:`None`)
        global_model (Module, optional): A callable which updates a graph's
            global features based on its node features, its graph connectivity,
            its edge features and its current global features.
    """
    def __init__(self, edge_model=None, node_model=None, global_model=None):
        super().__init__()
        self.edge_model = edge_model
        self.node_model = node_model
        self.global_model = global_model

        self.reset_parameters()

    def reset_parameters(self):
        for item in [self.node_model, self.edge_model, self.global_model]:
            if hasattr(item, 'reset_parameters'):
                item.reset_parameters()

    def forward(self, x, edge_index):
        """"""
        row, col = edge_index
        if self.edge_model is not None:
            edge_attr = self.edge_model(x[:,row,:], x[:,col,:])
        
        if self.node_model is not None:
            x = self.node_model(x, edge_index, edge_attr)
        if self.global_model is not None:
            u = self.global_model(x)
        return x, edge_attr, u

    def __repr__(self):
        return ('{}(\n'
                '    edge_model={},\n'
                '    node_model={},\n'
                '    global_model={}\n'
                ')').format(self.__class__.__name__, self.edge_model,
                            self.node_model, self.global_model)

In [272]:
class GNNLSTMRecgnition(nn.Module):
    def __init__(self, hidden_size=16, latent_size=2):
        super().__init__()
        ## GNN ##
        self.node_dims = 7+2 #X+y
        self.hidden_size=hidden_size
        self.edge_dims=hidden_size
        self.u_dims=hidden_size 
        edge_model = EdgeModel(self.node_dims, self.edge_dims, self.u_dims, self.hidden_size)
        node_model = NodeModel(self.node_dims, self.edge_dims, self.u_dims, self.hidden_size)
        global_model = GlobalModel(self.node_dims, self.u_dims, self.hidden_size)
        self.op = MetaLayer(edge_model, node_model, global_model)
        ## LSTM ##
        self.edge_rnn = nn.LSTM(self.edge_dims, self.hidden_size, batch_first=True)
        self.node_rnn = nn.LSTM(self.node_dims, self.hidden_size, batch_first=True)
        self.global_rnn = nn.LSTM(self.u_dims, self.hidden_size, batch_first=True)
        ## Linear ##
        self.edge_fc_1 = nn.Linear(self.hidden_size, latent_size)
        self.edge_fc_2 = nn.Linear(self.hidden_size, latent_size)
        self.node_fc_1 = nn.Linear(self.hidden_size, latent_size)
        self.node_fc_2 = nn.Linear(self.hidden_size, latent_size)
        self.global_fc_1 = nn.Linear(self.hidden_size, latent_size)
        self.global_fc_2 = nn.Linear(self.hidden_size, latent_size)

    def forward(self, x, edge_index):
        # define the forward computation on the latent z
        # x shape: [n_experiments, steps, nodes, node_dims]
        bs, steps, nodes, node_dims = x.size()
        _, edges = edge_index.size()
        x_reshape = x.view(-1, nodes, node_dims)
        node_attr, edge_attr, global_attr = self.op(x_reshape, edge_index)
        node_attr = node_attr.view(bs, steps, nodes, self.node_dims).permute(0,2,1,3)
        edge_attr = edge_attr.view(bs, steps, edges, self.edge_dims).permute(0,2,1,3)
        global_attr = global_attr.view(bs, steps, self.u_dims)
        node_attr = node_attr.reshape(-1, steps, self.node_dims)
        edge_attr = edge_attr.reshape(-1, steps, self.edge_dims)
        # RNN forward
        node_out, _ = self.node_rnn(node_attr)
        edge_out, _ = self.edge_rnn(edge_attr)
        global_out, _ = self.global_rnn(global_attr)
        # get the outputs of the last time step
        node_out  = node_out[:, -1, :]
        edge_out = edge_out[:, -1, :]
        global_out = global_out[:, -1, :]
        # FC forward
        z_node_mu = self.node_fc_1(node_out).view(bs, nodes, -1)
        z_node_logvar = self.node_fc_2(node_out).view(bs, nodes, -1)
        z_edge_mu = self.edge_fc_1(edge_out).view(bs, edges, -1)
        z_edge_logvar = self.edge_fc_2(edge_out).view(bs, edges, -1)
        z_global_mu = self.global_fc_1(global_out).view(bs, -1)
        z_global_logvar = self.global_fc_2(global_out).view(bs, -1)
        return [z_node_mu, z_node_logvar, z_edge_mu, z_edge_logvar, z_global_mu, z_global_logvar]

In [273]:
class RelationalModel(nn.Module):
    def __init__(self, input_size, output_size, hidden_size):
        super().__init__()
        
        self.output_size = output_size
        
        self.layers = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
            nn.ReLU()
        )
    
    def forward(self, x):
        '''
        Args:
            x: [batch_size, n_relations, input_size]
        Returns:
            [batch_size, n_relations, output_size]
        '''
        batch_size, n_relations, input_size = x.size()
        x = x.view(-1, input_size)
        x = self.layers(x)
        x = x.view(batch_size, n_relations, self.output_size)
        return x

In [274]:
class DynamicsModel(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        
        self.layers = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 2), #speedX and speedY
        )
        
    def forward(self, x):
        '''
        Args:
            x: [batch_size, n_objects, input_size]
        Returns:
            [batch_size * n_objects, 2] speedX and speedY
        '''
        input_size = x.size(2)
        x = x.view(-1, input_size)
        return self.layers(x)

In [275]:
class InteractionNetworkGenerator(nn.Module):
    def __init__(self, effect_dims=4, hidden_size=128, latent_size=2):
        super().__init__()
        # setup the two linear transformations used
        self.node_dims = 7+latent_size #X+z_node
        self.z_edge_dims = latent_size
        self.z_u_dims = latent_size
        self.effect_dims = effect_dims
        self.hidden_size = hidden_size
        self.relational_model = RelationalModel(2*self.node_dims + self.z_edge_dims, 
                                                self.effect_dims, self.hidden_size)
        self.object_model     = DynamicsModel(self.node_dims+self.effect_dims+self.z_u_dims,
                                              self.hidden_size)
    def forward(self, x, z, edge_index):
        # x shape: [bs, steps, nodes, node_dims]
        bs, steps, nodes, node_dims = x.size()
        sender_relations = F.one_hot(edge_index[0]).T.unsqueeze(0).repeat(bs*steps,1,1).type(torch.float)
        receiver_relations = F.one_hot(edge_index[1]).T.unsqueeze(0).repeat(bs*steps,1,1).type(torch.float)
        _, edges = edge_index.size()
        row, col = edge_index
        # define the forward computation on the latent z
        z_node, z_edge, z_global = z
        z_node = torch.unsqueeze(z_node, dim=1).repeat(1, steps, 1, 1)
        z_edge = z_edge.repeat(steps, 1, 1)
        z_global = torch.unsqueeze(z_global, dim=1).repeat(steps, nodes, 1)
        x = torch.cat((x, z_node), dim=-1).view(bs*steps, nodes, -1)
        senders   = sender_relations.permute(0, 2, 1).bmm(x)
        receivers = receiver_relations.permute(0, 2, 1).bmm(x)
        effects = self.relational_model(torch.cat([senders, receivers, z_edge], -1))
        effect_receivers = receiver_relations.bmm(effects)
        # predicted shape [bs*steps, nodes, 2]
        predicted = self.object_model(torch.cat([x, z_global,effect_receivers], -1))
        return predicted.view(bs, steps, nodes, 2)

In [282]:
class PhysicsVAE(nn.Module):
    # by default our latent space is 50-dimensional
    # and we use 400 hidden units
    def __init__(self, z_dim=128, effect_dims=128, use_cuda=False):
        super().__init__()
        # create the encoder and decoder networks
        self.encoder = GNNLSTMRecgnition(hidden_size=128, latent_size=z_dim)
        self.decoder = InteractionNetworkGenerator(effect_dims=effect_dims, hidden_size=128, latent_size=z_dim)

        if use_cuda:
            # calling cuda() here will put all the parameters of
            # the encoder and decoder networks into gpu memory
            self.cuda()
        self.use_cuda = use_cuda
        self.z_dim = z_dim
    
    def encode(self, xy, edge_index):
        return self.encoder(xy, edge_index)
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std
    
    def decode(self, x, z, edge_index):
        return self.decoder(x, z, edge_index)

    def forward(self, x, y, edge_index):
        xy = torch.cat((x, y), dim=-1)
        self.z_stats = self.encode(xy, edge_index)
        z_node_mu, z_node_logvar, z_edge_mu, z_edge_logvar, z_global_mu, z_global_logvar = self.z_stats
        z_node = self.reparameterize(z_node_mu, z_node_logvar)
        z_edge = self.reparameterize(z_edge_mu, z_edge_logvar)
        z_global = self.reparameterize(z_global_mu, z_global_logvar)
        z_sample = [z_node, z_edge, z_global]
        return self.decode(x, z_sample, edge_index), self.z_stats
    
    def inference(self, x_test, edge_index):
        # x_test shape [1, steps, n_objects, node_dims]
        with torch.no_grad():
            z_node_mu, z_node_logvar, z_edge_mu, z_edge_logvar, z_global_mu, z_global_logvar = self.z_stats
            ez_node_logvar = torch.mean(z_node_logvar, dim=0)
            ez_edge_mu = torch.mean(z_edge_mu, dim=0)
            ez_edge_logvar = torch.mean(z_edge_logvar, dim=0)
            ez_global_mu = torch.mean(z_global_mu, dim=0)
            ez_global_logvar = torch.mean(z_global_logvar, dim=0)
            z_node = self.reparameterize(ez_node_mu, ez_node_logvar).unsqueeze(0)
            z_edge = self.reparameterize(ez_edge_mu, ez_edge_logvar).unsqueeze(0)
            z_global = self.reparameterize(ez_global_mu, ez_global_logvar).unsqueeze(0)
            z_sample = [z_node, z_edge, z_global]
            pred = self.decode(x_test, z_sample, edge_index)
        return pred

In [283]:
vae = PhysicsVAE(use_cuda=True)

In [285]:
def train(epoch, use_cuda=False):
    # Load dataset
    root = "/media/data/pymunk_dataset/"
    raw_folder_name = "raw"
    processed_folder_name = "processed"
    dataset = BallSimulationDataset(root, raw_folder_name, processed_folder_name, 
                                    use_cuda=use_cuda)
    # Create model
    model = PhysicsVAE(use_cuda=use_cuda)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    model.train()
    # Training
    for e_idx in range(epoch):
        train_loss = 0
        for batch_idx, (batch_x, batch_y, edge_index) in enumerate(dataset):
            optimizer.zero_grad()
            recon_batch, z_stats = model(batch_x, batch_y, edge_index)
            loss, reconstr_loss = loss_function(recon_batch, batch_y, z_stats)
            loss.backward()
            train_loss += reconstr_loss.item()
            optimizer.step()
        if e_idx % 50 == 0:
            print(e_idx, train_loss)
    
    return model

def test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            if i == 0:
                n = min(data.size(0), 8)
                comparison = torch.cat([data[:n],
                                      recon_batch.view(args.batch_size, 1, 28, 28)[:n]])
                save_image(comparison.cpu(),
                         'results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

In [292]:
def loss_function(recon_x, x, z_stats):
    beta = 1
    theta = 1e+20
    MSE = F.mse_loss(recon_x, x, reduction='sum')
    mmse = F.mse_loss(recon_x, x, reduction='mean')
    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    z_node_mu, z_node_logvar, z_edge_mu, z_edge_logvar, z_global_mu, z_global_logvar = z_stats
    KLD_node = -0.5 * torch.sum(1 + z_node_logvar - z_node_mu.pow(2) - z_node_logvar.exp())
    KLD_edge = -0.5 * torch.sum(1 + z_edge_logvar - z_edge_mu.pow(2) - z_edge_logvar.exp())
    KLD_global = -0.5 * torch.sum(1 + z_global_logvar - z_global_mu.pow(2) - z_global_logvar.exp())
    KLD = KLD_node+KLD_edge+KLD_global
    # latent discrepancy within batch
    MSE_node = torch.sum(torch.var(z_node_mu, dim=0))+torch.sum(torch.var(z_node_logvar, dim=0))
    MSE_edge = torch.sum(torch.var(z_edge_mu, dim=0))+torch.sum(torch.var(z_edge_logvar, dim=0))
    MSE_global = torch.sum(torch.var(z_global_mu, dim=0))+torch.sum(torch.var(z_global_logvar, dim=0))
    MSED = MSE_node+MSE_edge+MSE_global
    return MSE + beta*KLD + theta*MSED, mmse

In [293]:
model = train(200, use_cuda=True)

0 216.96520233154297
50 33.0487003326416
100 21.85648822784424
150 16.11427593231201
