## Dataset

We used a subset of the WaterDrop dataset from DeepMind. Its simulations only cover the specific case of a water droplet in vacuum, but that is fine for us, as that is exactly what we want to model!

### Converting TFRecord to torch tensors

Unfortunately, the dataset is distributed as TFRecord files, a TensorFlow format that PyTorch cannot read directly. We therefore parse the records with TensorFlow and convert them to NumPy arrays, which can then be turned into torch tensors.

In [1]:
import numpy as np
import tensorflow as tf

raw_dataset = tf.data.TFRecordDataset("./dataset/WaterDropSample/test.tfrecord")

for raw_record in raw_dataset.take(1):
    # NOTE: the WaterDrop records appear to be tf.train.SequenceExample protos.
    # Parsing them as tf.train.Example keeps only the context features (key,
    # particle_type); the per-step positions in the feature lists are dropped.
    example = tf.train.Example()
    example.ParseFromString(raw_record.numpy())

    result = {}
    # example.features.feature maps feature names to tf.train.Feature objects.
    # Each Feature holds exactly one of three fields (its `kind`):
    # bytes_list, float_list or int64_list.
    for key, feature in example.features.feature.items():
        kind = feature.WhichOneof('kind')
        result[key] = np.array(getattr(feature, kind).value)

        # float_list and int64_list values already come out as numeric arrays;
        # bytes_list values (e.g. particle_type) are decoded here into a uint8 array.
        # Caveats: DeepMind's reference reader decodes particle_type as int64, so
        # uint8 splits every id into 8 bytes (the 5, 0, ..., 0 pattern below), and
        # NumPy drops trailing NUL bytes when reading elements of the 'S' array.
        if result[key].dtype.type == np.bytes_:
            arr = np.frombuffer(b''.join(result[key]), dtype=np.uint8)
            arr = arr.reshape((-1, len(result[key][0])))
            result[key] = arr

    print(result)


{'key': array([0]), 'particle_type': array([[5, 0, 0, ..., 0, 0, 5]], dtype=uint8)}
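The printed `particle_type` interleaves each type id with runs of zero bytes. A likely explanation, based on our reading of DeepMind's `learning_to_simulate` reader, is that `particle_type` stores little-endian int64 ids and each step of the `position` feature list stores float32 coordinates. The minimal NumPy sketch below decodes both under those assumptions; `decode_particle_type`, `decode_positions` and the 2D default are our own illustrative names. With TensorFlow, the raw values would come from `seq = tf.train.SequenceExample.FromString(raw_record.numpy())`, namely `seq.context.feature['particle_type'].bytes_list.value` and `[f.bytes_list.value[0] for f in seq.feature_lists.feature_list['position'].feature]`.

```python
from __future__ import annotations

import numpy as np


def decode_particle_type(values: list[bytes]) -> np.ndarray:
    """Decode the particle_type bytes feature, assumed to hold little-endian int64 ids.

    Each raw bytes object is decoded directly with np.frombuffer. Wrapping the values
    in np.array() first would create a fixed-width 'S' array, and NumPy strips trailing
    NUL bytes when reading its elements, which corrupts binary payloads.
    """
    return np.concatenate([np.frombuffer(v, dtype="<i8") for v in values])


def decode_positions(step_values: list[bytes], dim: int = 2) -> np.ndarray:
    """Decode one float32 blob per time step into a [num_steps, num_particles, dim] array.

    Assumes each blob stores all particle coordinates of that step in row-major order.
    """
    steps = [np.frombuffer(v, dtype="<f4").reshape(-1, dim) for v in step_values]
    return np.stack(steps)
```

The resulting arrays can be converted with `torch.from_numpy(...)` (call `.copy()` first, since `np.frombuffer` returns read-only views).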


## Preprocessing

1. Add noise to the training inputs to mitigate error accumulation over long rollouts. We use a simple approach to make the model more robust to noisy inputs: during training we corrupt the input velocities of the model with random-walk noise $\mathcal{N}(0, \sigma_v = 0.0003)$ and adjust the input positions accordingly, so that the training distribution is closer to the distribution generated during rollouts.
2. Normalize all input and target vectors elementwise to zero mean and unit variance, using statistics computed online during training. Preliminary experiments showed that normalization led to faster training, though converged performance was not noticeably improved.
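A minimal NumPy sketch of both preprocessing steps, under our own assumptions: position windows are laid out as `[num_particles, num_steps, dim]`, the noise is returned separately so it can be added to the input positions, and the per-step velocity noise is scaled so that the accumulated noise on the last velocity has standard deviation $\sigma_v$ (our reading of DeepMind's reference code). `random_walk_position_noise` and `OnlineNormalizer` are illustrative names; the normalizer keeps running sums, which is simple but less numerically stable than Welford's algorithm.

```python
from __future__ import annotations

import numpy as np


def random_walk_position_noise(positions: np.ndarray, noise_std_last_step: float = 3e-4,
                               rng: np.random.Generator | None = None) -> np.ndarray:
    """Random-walk noise to add to a [num_particles, num_steps, dim] position window."""
    rng = np.random.default_rng() if rng is None else rng
    num_particles, num_steps, dim = positions.shape
    num_velocities = num_steps - 1
    # Independent per-step velocity increments, scaled so that after accumulating
    # num_velocities of them the last velocity's noise has std noise_std_last_step.
    increments = rng.normal(0.0, noise_std_last_step / np.sqrt(num_velocities),
                            size=(num_particles, num_velocities, dim))
    velocity_noise = np.cumsum(increments, axis=1)  # random walk over the velocities
    # Integrate the velocity noise into position noise; the first frame stays unperturbed.
    first = np.zeros((num_particles, 1, dim))
    return np.concatenate([first, np.cumsum(velocity_noise, axis=1)], axis=1)


class OnlineNormalizer:
    """Elementwise zero-mean / unit-variance normalizer with running statistics."""

    def __init__(self, size: int, eps: float = 1e-8):
        self.count = 0
        self.total = np.zeros(size)
        self.total_sq = np.zeros(size)
        self.eps = eps

    def update(self, batch: np.ndarray) -> None:
        # Accumulate sums over all leading dimensions (e.g. all particles in a batch).
        batch = batch.reshape(-1, self.total.shape[0])
        self.count += batch.shape[0]
        self.total += batch.sum(axis=0)
        self.total_sq += (batch ** 2).sum(axis=0)

    def __call__(self, x: np.ndarray) -> np.ndarray:
        n = max(self.count, 1)
        mean = self.total / n
        std = np.sqrt(np.maximum(self.total_sq / n - mean ** 2, 0.0))
        return (x - mean) / (std + self.eps)
```

During training one would call `update` on every batch before normalizing it, and stop updating once training ends so rollouts use fixed statistics.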

## GNN model

### MLP

MLPs are used in many places throughout the architecture; most notably, the encoder and the decoder are both MLPs. We define a reusable class for them.

All MLPs have two hidden layers (with ReLU activations), followed by a non-activated output layer; each hidden layer has 128 units. All MLPs except the output decoder are followed by a LayerNorm layer.

In [None]:
import math

import torch
from torch import nn
import torch.nn.functional as F

class MLP(nn.Module):
    """Two ReLU hidden layers, a non-activated output layer and an optional LayerNorm."""

    def __init__(self, input_dim, output_dim, hidden_dim=128, layer_norm=True):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, hidden_dim)
        self.layer3 = nn.Linear(hidden_dim, output_dim)
        # The LayerNorm module (with its learnable parameters) must be created once here,
        # not instantiated inside forward().
        self.norm = nn.LayerNorm(output_dim) if layer_norm else None
        self.reset_parameters()

    def reset_parameters(self):
        # Weights are drawn from N(0, std = 1/sqrt(fan_in)) (LeCun-style initialization).
        # Scaling by 1/sqrt(layer.in_features) keeps the variance of each layer's outputs
        # close to that of its inputs, which helps prevent activations from exploding or
        # vanishing during training (see also Glorot & Bengio, 2010).
        # Biases start at 0 so the network learns appropriate values during training.
        for layer in (self.layer1, self.layer2, self.layer3):
            nn.init.normal_(layer.weight, mean=0.0, std=1.0 / math.sqrt(layer.in_features))
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        x = self.layer3(x)  # non-activated output layer
        if self.norm is not None:
            x = self.norm(x)
        return x
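As a quick cross-check of the sizes described above, the same architecture can be written with `nn.Sequential`. This is a sketch for inspection only (`build_mlp` is our own helper, not used elsewhere in the notebook); the input size 20 corresponds to `2 * simulation_dim + emb_dim` with the `GNS` defaults used below (2D, 16-dimensional embedding).

```python
import torch
from torch import nn


def build_mlp(input_dim: int, output_dim: int, hidden_dim: int = 128,
              layer_norm: bool = True) -> nn.Sequential:
    """Two ReLU hidden layers, a linear output layer and an optional LayerNorm."""
    layers = [
        nn.Linear(input_dim, hidden_dim), nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        nn.Linear(hidden_dim, output_dim),
    ]
    if layer_norm:
        layers.append(nn.LayerNorm(output_dim))
    return nn.Sequential(*layers)


encoder = build_mlp(20, 128)                   # node encoder, followed by LayerNorm
decoder = build_mlp(128, 2, layer_norm=False)  # 2D output decoder, no LayerNorm
num_params = sum(p.numel() for p in encoder.parameters())
```

For the encoder this gives (20·128 + 128) + 2·(128·128 + 128) + 2·128 = 35,968 trainable parameters, most of them in the two 128×128 layers.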

### GNN Layer
Here we implement an Interaction Network layer.\
paper: https://proceedings.neurips.cc/paper_files/paper/2016/file/3147da8ab4a0437c15ef51a5cc7f2dc4-Paper.pdf

We use the MLP defined above to compute the messages and the node and edge updates.

Updated features for nodes, $v_i$, and edges, $e_{ij}$:

$$
v_i^{k+1} = v_i^k + \mathrm{MLP}_n\Big(v_i^k, \sum_{v_j \in \mathcal{N}(v_i)} \mathrm{MLP}_e\big(v_i^k, v_j^k, e_{ij}^k\big)\Big)
$$
$$
e_{ij}^{k+1} = e_{ij}^k + \mathrm{MLP}_e\big(v_i^k, v_j^k, e_{ij}^k\big)
$$

where $\mathrm{MLP}_e(\cdot, \cdot, \cdot)$ is computed only once per edge and its output is used in both updates.
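To make the two update rules concrete, here is a small NumPy sketch of one layer on a three-node graph. The toy `edge_fn` and `node_fn` are hypothetical stand-ins for $\mathrm{MLP}_e$ and $\mathrm{MLP}_n$ (the toy node function simply returns the aggregated messages). Edge `k` goes from sender `j = senders[k]` to receiver `i = receivers[k]`, matching the source/target convention of PyTorch Geometric used below.

```python
import numpy as np


def interaction_step(v, e, senders, receivers, edge_fn, node_fn):
    """One residual Interaction Network update with sum aggregation.

    v: [num_nodes, d] node features, e: [num_edges, d_e] edge features.
    """
    # MLP_e is evaluated once per edge on (v_i, v_j, e_ij) ...
    messages = edge_fn(v[receivers], v[senders], e)
    # ... the messages are summed into their receiving node ...
    aggregated = np.zeros((v.shape[0], messages.shape[1]), dtype=messages.dtype)
    np.add.at(aggregated, receivers, messages)
    # ... and the same messages feed both residual updates.
    v_new = v + node_fn(v, aggregated)
    e_new = e + messages
    return v_new, e_new


v = np.array([[1.0], [2.0], [4.0]])
senders = np.array([0, 2, 1])
receivers = np.array([1, 1, 0])
e = np.zeros((3, 1))
v_new, e_new = interaction_step(
    v, e, senders, receivers,
    edge_fn=lambda vi, vj, eij: vj - vi + eij,  # toy message: feature difference
    node_fn=lambda vi, agg: agg,                # toy node update: aggregated messages
)
```

Node 1 receives messages -1 (from node 0) and 2 (from node 2), so it moves from 2 to 3; node 2 receives nothing and keeps its value 4.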

In [None]:
from torch_geometric.nn.conv import MessagePassing
from torch_scatter import scatter
from torch import cat

class InteractionNetwork(MessagePassing):
    def __init__(self, hidden_dim):
        super().__init__()
        self.node_msg = MLP(2 * hidden_dim, hidden_dim, hidden_dim)
        self.edge_msg = MLP(3 * hidden_dim, hidden_dim, hidden_dim)

    def forward(self, x, edge_index, edge_feature):
        # propagate() calls message() and then aggregate(); since aggregate() returns
        # (messages, aggregated), propagate() returns that tuple as well.
        edge_out, aggr = self.propagate(edge_index, x=(x, x), edge_feature=edge_feature)
        # Residual edge update: current edge feature + the message computed for this edge.
        edge_out = edge_feature + edge_out

        # Node update: the current node feature and the sum of incoming messages are
        # concatenated and passed through the node MLP; the result is added back to the
        # current feature (residual connection).
        node_out = x + self.node_msg(cat((x, aggr), dim=-1))

        return node_out, edge_out

    def message(self, x_i, x_j, edge_feature):
        # Each message is the output of the edge MLP applied to the edge feature and the
        # features of the two nodes it connects (x_i: target node, x_j: source node).
        x = self.edge_msg(cat((x_i, x_j, edge_feature), dim=-1))

        return x

    def aggregate(self, source, index, dim_size=None):
        # Sum all incoming messages per target node; these sums update the node features.
        # The raw messages are returned as well, for the edge update in forward().
        out = scatter(source, index, dim_size=dim_size, dim=self.node_dim, reduce="sum")

        return (source, out)
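`torch_scatter` is an extra compiled dependency. If it is unavailable, the sum in `aggregate` can be reproduced with plain PyTorch's `Tensor.index_add_`, shown below as a sketch (`scatter_sum` is our own helper name). It assumes the node dimension is dimension 0, which is what `self.node_dim` (default `-2`) resolves to for 2D `[num_edges, hidden_dim]` message tensors.

```python
import torch


def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim_size: int) -> torch.Tensor:
    """Sum the rows of src into dim_size buckets selected by index (scatter 'sum')."""
    out = torch.zeros((dim_size,) + tuple(src.shape[1:]), dtype=src.dtype, device=src.device)
    # index_add_ accumulates src[k] into out[index[k]]; repeated indices are summed.
    return out.index_add_(0, index, src)


messages = torch.tensor([[1.0], [2.0], [4.0]])
target_nodes = torch.tensor([1, 1, 0])
aggregated = scatter_sum(messages, target_nodes, dim_size=3)
```

Node 0 receives 4, node 1 receives 1 + 2 = 3, and node 2, with no incoming edges, gets a zero row, the same result `scatter(..., reduce="sum")` produces.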

In [None]:
from torch.nn import Embedding, ModuleList

class GNS(nn.Module):
    def __init__(
        self,
        hidden_dim = 128,
        num_types = 9,
        emb_dim = 16,
        num_gnn_layers = 5,
        simulation_dim = 2
    ):
        super().__init__()
        # IMPORTANT: this is the input dimension of the node features. The model
        # receives positions from the 2 previous frames (2 * simulation_dim values)
        # plus the particle-type embedding (emb_dim values).
        node_input_dim = 2 * simulation_dim + emb_dim

        # classic torch.nn Embedding: one learnable vector per particle type
        self.embedding = Embedding(num_types, emb_dim)

        # node encoder and decoder; the MLP signature is (input_dim, output_dim, hidden_dim).
        # The output decoder is the only MLP without LayerNorm.
        self.node_input = MLP(node_input_dim, hidden_dim, hidden_dim)
        self.node_outpt = MLP(hidden_dim, simulation_dim, hidden_dim, layer_norm=False)

        # edge encoder: simulation_dim + 1 input features per edge
        # (e.g. relative displacement and distance)
        self.edge_input = MLP(simulation_dim + 1, hidden_dim, hidden_dim)

        # initialize gnn layers as InteractionNetwork layers
        self.gnns = ModuleList([InteractionNetwork(hidden_dim) for _ in range(num_gnn_layers)])

        # save the number of layers for later use
        self.num_gnns = num_gnn_layers

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.embedding.weight)

    def forward(self, data):
        # Assumed fields of `data`: x = particle type ids (long, [num_nodes]),
        # pos = [num_nodes, 2 * simulation_dim], edge_attr = [num_edges, simulation_dim + 1],
        # edge_index = [2, num_edges].
        # First, encode nodes and edges into hidden features.
        node_features = self.node_input(cat((self.embedding(data.x), data.pos), dim=-1))
        edge_features = self.edge_input(data.edge_attr)

        # Then propagate them through the message-passing layers.
        for gnn in self.gnns:
            node_features, edge_features = gnn(x=node_features, edge_index=data.edge_index, edge_feature=edge_features)

        # Finally, decode a per-node output with simulation_dim components (x, y in our 2D case).
        node_output = self.node_outpt(node_features)

        return node_output
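`forward` expects `data` to already contain `pos`, `edge_index` and `edge_attr`. One hypothetical way to build them for a single two-frame window is sketched below in NumPy: a brute-force radius graph on the latest positions, with edge features given by the sender-minus-receiver displacement divided by the radius plus its length, which is our reading of DeepMind's reference code. The function name and the radius value in the example are illustrative assumptions, not dataset metadata.

```python
from __future__ import annotations

import numpy as np


def build_graph_inputs(window: np.ndarray, radius: float) -> dict[str, np.ndarray]:
    """Node and edge inputs for one window of positions with shape [2, num_particles, dim]."""
    prev, curr = window[0], window[1]
    # Node position features: both frames side by side -> [num_particles, 2 * dim].
    pos = np.concatenate([prev, curr], axis=-1)

    # Brute-force radius graph on the current positions (O(N^2), fine for a sketch).
    n = curr.shape[0]
    rel = curr[None, :, :] - curr[:, None, :]  # rel[i, j] = x_j - x_i
    dist = np.linalg.norm(rel, axis=-1)
    mask = (dist < radius) & ~np.eye(n, dtype=bool)  # no self-loops
    receivers, senders = np.nonzero(mask)            # mask[i, j] means edge j -> i
    # PyG convention: edge_index[0] = source (sender), edge_index[1] = target (receiver).
    edge_index = np.stack([senders, receivers])

    # Edge features: normalized displacement (sender - receiver) and its length -> dim + 1.
    disp = rel[receivers, senders] / radius
    edge_attr = np.concatenate([disp, np.linalg.norm(disp, axis=-1, keepdims=True)], axis=-1)
    return {"pos": pos, "edge_index": edge_index, "edge_attr": edge_attr}


curr = np.array([[0.0, 0.0], [0.01, 0.0], [0.05, 0.0]])
graph = build_graph_inputs(np.stack([curr - 0.001, curr]), radius=0.015)
```

The arrays can be wrapped in a `torch_geometric.data.Data(x=..., pos=..., edge_index=..., edge_attr=..., y=...)` object after converting them with `torch.from_numpy` (`.long()` for ids and indices, `.float()` for features). The GNS paper feeds velocities rather than absolute positions; passing `curr - prev` would be a natural variant.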

## Training and testing

We first define the arguments.

In [3]:
args = {
        'input_dim': 3,
        'output_dim': 3,
        'hidden_dim': 128,
        'layer_norm': True,
        'lr': 0.01,
        'weight_decay': 5e-3,
        'batch_size': 2,
        'epochs': 1,
        'dropout': 0.5,
        'opt': 'adam',
        'validate_interval': 1000,
        'save_model': True,
        'model_path': './model.pt'
}

In [None]:
import torch_geometric as pyg

@torch.no_grad()
def validate_onestep(model: torch.nn.Module, data_loader, device: torch.device) -> float:
    """Average one-step MSE of the model over all batches of a validation loader."""
    model.eval()
    total_loss = 0.0
    num_batches = 0
    for data in data_loader:
        data = data.to(device)
        out = model(data)
        total_loss += F.mse_loss(out, data.y).item()
        num_batches += 1
    # guard against an empty loader instead of relying on a loop variable
    return total_loss / max(num_batches, 1)
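`validate_onestep` only measures single-step error, while the training noise above targets error accumulation over long rollouts. If the model is trained to predict per-particle accelerations, as in the GNS paper (Sanchez-Gonzalez et al., 2020), a rollout feeds its own predictions back in and integrates them with semi-implicit Euler. Below is a minimal NumPy sketch under that assumption; `predict_acceleration` is a hypothetical callable wrapping graph construction, the model and any de-normalization, and the time step is folded into the units.

```python
from typing import Callable

import numpy as np


def rollout(predict_acceleration: Callable[[np.ndarray], np.ndarray],
            initial_positions: np.ndarray, num_steps: int) -> np.ndarray:
    """Autoregressive rollout from a [2, num_particles, dim] window of frames t-1, t."""
    window = initial_positions.copy()
    trajectory = []
    for _ in range(num_steps):
        acceleration = predict_acceleration(window)
        velocity = window[1] - window[0] + acceleration  # v_{t+1} = v_t + a_t
        next_pos = window[1] + velocity                  # x_{t+1} = x_t + v_{t+1}
        trajectory.append(next_pos)
        window = np.stack([window[1], next_pos])         # slide the window by one frame
    return np.stack(trajectory)


init = np.array([[[0.0]], [[1.0]]])  # one particle in 1D at x = 0, then x = 1
constant_velocity = rollout(lambda w: np.zeros_like(w[1]), init, num_steps=3)
constant_accel = rollout(lambda w: np.ones_like(w[1]), init, num_steps=3)
```

With zero acceleration the particle continues at constant velocity (2, 3, 4); with a constant acceleration of 1 it reaches 3, 6 and 10. Comparing such a trajectory with the ground truth gives a rollout MSE that complements the one-step loss.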

In [None]:
from typing import Any
from tqdm import tqdm


def train(args: dict[str, Any], model: torch.nn.Module, train_loader, valid_loader):

    # Set the device to GPU if available, otherwise CPU, and move the model there
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # init optimizer
    if args['opt'] == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])
    elif args['opt'] == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])
    else:
        raise ValueError('Unknown optimizer: {}'.format(args['opt']))

    # Decay the learning rate by 10x once per epoch. Stepping it after every batch
    # would divide it by 10 on each step and stall training within a few iterations.
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.1)

    # loss function
    loss_fn = nn.MSELoss()

    # track the losses to be able to plot the learning curve
    train_loss = []
    validate_loss = []

    # track the total number of steps
    steps = 0

    # main train loop
    for epoch in range(args['epochs']):
        model.train()
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch}')

        # keep track of the total loss and the number of batches
        total_loss = 0
        batch_count = 0

        for data in progress_bar:
            data = data.to(device)

            # forward pass, backward pass and parameter update
            optimizer.zero_grad()
            out = model(data)
            loss = loss_fn(out, data.y)
            loss.backward()
            optimizer.step()

            # update progress bar
            total_loss += loss.item()
            batch_count += 1
            progress_bar.set_postfix({"loss": loss.item(), "avg_loss": total_loss / batch_count})
            steps += 1
            train_loss.append((steps, loss.item()))

            # periodic one-step evaluation on the validation set
            if steps % args["validate_interval"] == 0:
                valid_loss = validate_onestep(model, valid_loader, device)
                validate_loss.append((steps, valid_loss))
                tqdm.write(f"\nEval: Loss: {valid_loss}")
                model.train()

        scheduler.step()

    if args['save_model']:
        torch.save(model.state_dict(), args['model_path'])

    return train_loss, validate_loss
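Stepping an `ExponentialLR` with `gamma=0.1` after every batch divides the learning rate by ten per step, so the model effectively stops learning after a handful of batches. As an alternative per-step schedule, the GNS paper decays the learning rate exponentially from 1e-4 to 1e-6. The sketch below follows our reading of DeepMind's reference code (a factor of 0.1 every 5·10^6 steps on the gap above the final rate); the constants come from that setup and are not tuned for this notebook.

```python
import math


def gns_learning_rate(step: int, initial_lr: float = 1e-4, final_lr: float = 1e-6,
                      decay_steps: int = 5_000_000) -> float:
    """Exponential decay from initial_lr towards final_lr, 10x per decay_steps steps."""
    return final_lr + (initial_lr - final_lr) * 0.1 ** (step / decay_steps)
```

To use it in `train`, create the optimizer with `lr=1e-4`, build `torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: gns_learning_rate(step) / 1e-4)` (the lambda returns a multiplier of the base rate), and call `scheduler.step()` after every `optimizer.step()`.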

Now we initialize the data and the model, and train the model.

In [None]:
from torch_geometric.loader import DataLoader

# here the actual training takes place
train_dataset = None # TODO: init train dataset
valid_dataset = None # TODO: init valid dataset

train_loader = DataLoader(train_dataset, batch_size=args['batch_size'], drop_last=True, shuffle=True, pin_memory=True, num_workers=4)
valid_loader = DataLoader(valid_dataset, batch_size=args['batch_size'], drop_last=True, shuffle=False, pin_memory=True, num_workers=4)

model = None # TODO: init model to GNS

train_loss, validate_loss = train(args, model, train_loader, valid_loader)
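The datasets above are still `TODO`s. As one hypothetical starting point, assuming the model is trained to predict accelerations (the target `data.y` is not yet fixed in this notebook), each trajectory can be cut into overlapping two-frame windows, matching the two frames of positions that `GNS` expects per node, with the finite-difference acceleration of the next frame as the target. `make_training_samples` is our own illustrative helper; each window would still need to be turned into a graph and wrapped in a `torch_geometric.data.Data` object.

```python
from __future__ import annotations

import numpy as np


def make_training_samples(trajectory: np.ndarray) -> list[tuple[np.ndarray, np.ndarray]]:
    """Split a [num_steps, num_particles, dim] trajectory into (window, target) pairs.

    window: positions at frames t-1 and t, shape [2, num_particles, dim].
    target: acceleration x_{t+1} - 2 x_t + x_{t-1}, shape [num_particles, dim].
    """
    samples = []
    for t in range(1, trajectory.shape[0] - 1):
        window = trajectory[t - 1:t + 1]
        acceleration = trajectory[t + 1] - 2 * trajectory[t] + trajectory[t - 1]
        samples.append((window, acceleration))
    return samples


# One particle in 1D moving as x_t = t^2, i.e. with constant acceleration 2.
samples = make_training_samples((np.arange(5, dtype=float) ** 2).reshape(5, 1, 1))
```

A trajectory with T frames yields T - 2 samples, since the first frame has no predecessor and the last has no successor.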

### Plot the loss curve

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

# visualize the loss curve
plt.figure()
plt.plot(*zip(*train_loss), label="train")
plt.plot(*zip(*validate_loss), label="valid")
plt.title('Loss')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.show()