<center><a href="https://www.nvidia.com/en-us/training/"><img src="https://dli-lms.s3.amazonaws.com/assets/general/DLI_Header_White.png" width="400" height="186" /></a></center>

# 2a. Intermediate Fusion



In the last lab, we successfully created a multimodal model. However, the performance was not much better than a single modal model. In this notebook, we will explore a more challenging dataset and different fusion techniques to overcome this challenge.

These experiments will take a moment to run. While these models are training, please continue to read on and try to predict what the results will be.

#### Learning Objectives

The goals of this notebook are to:
* Compare four types of multimodal fusion techniques:
  * Early fusion
  * Late fusion
  * Intermediate fusion with concatenation
  * Intermediate fusion with matrix multiplication

Let's begin by loading the libraries necessary for this lab.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from PIL import Image
from IPython.display import Image as IPy_img

import utils

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.is_available()

## 2.1 The Dataset

In the previous lab, our dataset was composed of three different shapes: a cube, a sphere, and a torus. Because the center of each object can be calculated more directly from the LiDAR data, the RGB data was effectively ignored by our multimodal models.

In this lab, we will use three of the same shape: a red cube, a green cube, and a blue cube. Since LiDAR systems cannot see color, a model based only on LiDAR data would have a hard time distinguishing which cube is which. On the other hand, RGB data would be able to distinguish between the cubes, but an RGB model would have a harder time calculating position in general. Let's see if we can create a model with the best of both worlds.

Before we create any multimodal models, let's see how single modal models would perform on this dataset. In the interest of time, we've already run a training simulation and collected the results. If you would like to verify them, rerun the previous notebook with the data directory changed to `data/replicator_data_cubes/`.

In [None]:
single_mode_data = np.genfromtxt("data/cubes_only_single_mode_results.csv", delimiter=',', skip_header=1)

plot_x = range(len(single_mode_data))
plt.xlabel("Epoch")
plt.ylabel("Average Loss")
plt.plot(plot_x, single_mode_data[:, 1], "green", label = "RGB Train Loss")
plt.plot(plot_x, single_mode_data[:, 2], "darkgreen", label = "RGB Valid Loss")
plt.plot(plot_x, single_mode_data[:, 3], "orchid", label = "XYZ Train Loss")
plt.plot(plot_x, single_mode_data[:, 4], "darkorchid", label = "XYZ Valid Loss")
plt.title("Cubes Only Single Mode Results")
plt.legend()
plt.show()

With this dataset, a model based on RGB images has a validation loss of about 6. We have 3 objects, so the predicted position of each object is off by about 2 units. Omniverse position units are relative, meaning they are not tied to a real-world unit of measurement like feet or centimeters. Many 3D modelers and applications suggest that 1 unit should be 1 meter.

On the other hand, the LiDAR model achieved a loss of less than 1 on the train dataset, but over 8 on the validation dataset. This is a classic case of overfitting. Because the LiDAR model can't find the logical pattern to calculate the position of these objects, it instead attempted to memorize the train dataset.

To put things into perspective, each object is placed at a random position within a 10-unit range along each of the X, Y, and Z axes. In other words, the objects spawn in a volume of 1000 cubic units, and the diagonal of that volume is about 17.3 units. In this case, the loss is the same as our error. If the average error for each object is 2 units, that's about 11.5% of the length of the diagonal. Not bad, but not great either.
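
The arithmetic in this paragraph can be checked in a few lines. This is only a sketch of the calculation described above; the variable names are ours, and the 2-unit error is the rough per-object figure quoted in the text.

```python
import math

side = 10.0   # range along each of the X, Y, and Z axes, in Omniverse units
error = 2.0   # rough per-object error of the RGB model, taken from the text

volume = side ** 3                 # cubic units in which objects can spawn
diagonal = math.sqrt(3) * side     # longest straight line inside that cube
relative_error = error / diagonal  # error as a fraction of the diagonal

print(f"volume:   {volume:.0f} cubic units")
print(f"diagonal: {diagonal:.1f} units")
print(f"error:    {relative_error:.1%} of the diagonal")
```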

Let's try to visualize this better. Below is the first set of positions for the red, green, and blue cubes. The lighter points represent the correct positions. The darker points represent those same points offset by 2 units.

In [3]:
%%capture
x = [3.431835889816284, -1.59285056591033, 2.45306801795959, 3.13458254, -2.49218181591033, 2.997167998]
y = [2.27847838401794, 3.99227786064147, -1.98521077632904, 1.551694174, 4.414836401, -1.190015916]
z = [2.73768424987792, -4.99803066253662, -3.022789478302, 2.12207423, -4.875654503, -2.755063558]
c = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [.5, 0, 0], [0, .5, 0], [0, 0, .5]]

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.set_xlim(-6, 6)
ax.set_ylim(-6, 6)
ax.set_zlim(-6, 6)

file_path = 'images/errors.gif'
utils.animate_3D_points(file_path, fig, ax, x, y, z, c)

In [None]:
IPy_img(open(file_path,'rb').read())

Our dataset has the same structure as last time: the first set of values are the `rgb_img`s, the second set is the `lidar_xyza` data, and the third set of values are the object positions. Let's verify that the shapes are what we expect.

In [None]:
train_data, train_dataloader, valid_data, valid_dataloader = utils.get_replicator_dataloaders("data/replicator_data_cubes/")

for i, sample in enumerate(train_data):
    print(i, *(x.shape for x in sample))
    if i == 5:
        break

## 2.2 Model Composition

Time to get these experiments rolling. We're going to test the performance of four models:
* `early_net`: A multimodal model using early fusion
* `late_net`: A multimodal model using late fusion
* `cat_net`: A multimodal model using intermediate fusion with [concatenation](https://pytorch.org/docs/main/generated/torch.cat.html)
* `matmul_net`: A multimodal model using intermediate fusion with [matrix multiplication](https://pytorch.org/docs/stable/generated/torch.matmul.html)

### 2.2.1 EarlyNet

First, let's make a generic convolutional neural network that we can use with both `early_net` and `late_net`.

In [6]:
num_positions = 9

class Net(nn.Module):
    def __init__(self, in_ch):
        kernel_size = 3
        super().__init__()
        flattened_size = 200 * 8 * 8
        self.conv1 = nn.Conv2d(in_ch, 50, kernel_size, padding=1)
        self.conv2 = nn.Conv2d(50, 100, kernel_size, padding=1)
        self.conv3 = nn.Conv2d(100, 200, kernel_size, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(flattened_size, 1000)
        self.fc2 = nn.Linear(1000, 100)
        self.fc3 = nn.Linear(100, num_positions)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
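
Where does `flattened_size = 200 * 8 * 8` come from? The sketch below traces the spatial size through the three convolution and pooling stages using the standard output-size formulas. The `64 x 64` input size is an assumption on our part: it is inferred from the `8 * 8` in `flattened_size`, not stated in this notebook, and the helper names are hypothetical.

```python
def conv_out(size, kernel_size=3, padding=1, stride=1):
    """Spatial size after a square convolution (standard formula)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def pool_out(size, window=2):
    """Spatial size after MaxPool2d(window) with its default stride."""
    return size // window


size = 64  # ASSUMPTION: 64 x 64 inputs, inferred from flattened_size
for layer in range(3):
    # A 3 x 3 kernel with padding=1 keeps the size; each pool halves it.
    size = pool_out(conv_out(size))
    print(f"after conv{layer + 1} + pool: {size} x {size}")

out_channels = 200  # number of kernels in conv3
flattened_size = out_channels * size * size
print("flattened_size:", flattened_size)
```

Because the padded `3 x 3` convolutions preserve height and width, only the three pooling layers shrink the feature maps, each by a factor of 2.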

We've created a function in `utils.py` called `train_model` to train each of our experiments. It still needs a function to extract the appropriate inputs from each batch of the dataloader. In the case of `early_net`, the RGB and XYZ values should be [concatenated](https://pytorch.org/docs/main/generated/torch.cat.html) along the channel dimension before feeding the data into the model.

In [7]:
def get_early_inputs(batch):
    inputs_rgb = batch[0].to(device)
    inputs_xyz = batch[1].to(device)
    inputs_mm_early = torch.cat((inputs_rgb, inputs_xyz), 1)
    return (inputs_mm_early,)
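
Why does `get_early_inputs` return a one-element tuple? `utils.train_model` is not shown in this notebook, so the following is only a guess at how such a function might use the callback: it unpacks the tuple into the model's `forward` call, which lets the same loop serve models with one input or two. Everything here is hypothetical, including the name `run_epoch`, the choice of MSE loss, and the assumption that the target positions are the third item of each batch.

```python
import torch
import torch.nn as nn


def run_epoch(model, optimizer, get_inputs, dataloader, loss_fn=None):
    """One hypothetical training epoch; NOT the actual utils.train_model."""
    loss_fn = loss_fn or nn.MSELoss()  # ASSUMPTION: mean squared error
    model.train()
    total = 0.0
    for batch in dataloader:
        inputs = get_inputs(batch)     # tuple holding one or more tensors
        targets = batch[2]             # ASSUMPTION: positions come third
        optimizer.zero_grad()
        outputs = model(*inputs)       # the tuple is unpacked into forward()
        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()
        total += loss.item()
    return total / len(dataloader)


def early_inputs(batch):
    """Early fusion: stack RGB and XYZ along the channel dimension."""
    return (torch.cat((batch[0], batch[1]), 1),)


# Tiny random stand-in data so the sketch runs without the real dataset.
torch.manual_seed(0)
fake_batches = [
    (torch.rand(2, 4, 8, 8), torch.rand(2, 4, 8, 8), torch.rand(2, 9))
    for _ in range(3)
]
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(8 * 8 * 8, 9))
toy_opt = torch.optim.Adam(toy_model.parameters(), lr=1e-3)
avg_loss = run_epoch(toy_model, toy_opt, early_inputs, fake_batches)
print(avg_loss)
```

Two 4-channel inputs become one 8-channel input after `torch.cat`, which is why `early_net` is built with `Net(8)`.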

Time for the first experiment. Please run the cell below and continue reading. Training will take some time to complete.

In [None]:
epochs = 20
early_net = Net(8).to(device)
early_opt = Adam(early_net.parameters(), lr=0.0001)
early_train_losses, early_valid_losses = utils.train_model(
    early_net,
    early_opt,
    get_early_inputs,
    epochs,
    train_dataloader,
    valid_dataloader
)

### 2.2.2 LateNet

Next, let's create `late_net`. This will use two of the generic `Net`s that we created in the last section. Unlike the last lab, these nets will not be pretrained, allowing for a fairer comparison with `early_net`. However, since there are two `Net`s, `late_net` has nearly twice as many parameters as `early_net`, so a perfectly fair comparison would be difficult to achieve. It's still worthwhile to try new things. Let's see how they compare.

In [9]:
rgb_net = Net(4).to(device)
xyz_net = Net(4).to(device)

class LateNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.rgb = rgb_net
        self.xyz = xyz_net
        self.fc1 = nn.Linear(num_positions * 2, num_positions * 10)
        self.fc2 = nn.Linear(num_positions * 10, num_positions)

    def forward(self, x_rgb, x_xyz):
        x_rgb = self.rgb(x_rgb)
        x_xyz = self.xyz(x_xyz)
        x = torch.cat((x_rgb, x_xyz), 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
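
The "nearly twice as many parameters" claim can be checked by hand. The sketch below counts weights and biases from the layer sizes in `Net` and `LateNet`; the helper functions are ours and simply apply the usual formulas for `nn.Conv2d` and `nn.Linear` with bias enabled (the default).

```python
def conv_params(in_ch, out_ch, kernel_size=3):
    """Weights plus biases of an nn.Conv2d layer."""
    return (in_ch * kernel_size * kernel_size + 1) * out_ch


def linear_params(in_features, out_features):
    """Weights plus biases of an nn.Linear layer."""
    return (in_features + 1) * out_features


def net_params(in_ch, num_positions=9):
    """Parameter count of the generic Net for a given input channel count."""
    return (
        conv_params(in_ch, 50)
        + conv_params(50, 100)
        + conv_params(100, 200)
        + linear_params(200 * 8 * 8, 1000)
        + linear_params(1000, 100)
        + linear_params(100, num_positions)
    )


early_total = net_params(8)
late_total = (
    2 * net_params(4)          # rgb_net and xyz_net
    + linear_params(18, 90)    # LateNet.fc1: num_positions * 2 -> * 10
    + linear_params(90, 9)     # LateNet.fc2: num_positions * 10 -> num_positions
)
ratio = late_total / early_total

print(f"early_net: {early_total:,}")
print(f"late_net:  {late_total:,}")
print(f"ratio:     {ratio:.3f}")
```

Most of the parameters sit in the first dense layer of each `Net`, so duplicating the whole network roughly doubles the total, while the extra four input channels of `early_net` add very little.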

For the rest of these experiments, we will use the same `get_inputs` function. `late_net`, like `cat_net` and `matmul_net`, has two different paths for the data to flow through the network.

In [10]:
def get_inputs(batch):
    inputs_rgb = batch[0].to(device)
    inputs_xyz = batch[1].to(device)
    return (inputs_rgb, inputs_xyz)

Finally, let's define `late_net` and run the experiment. Please run the cell below and read on. Because `late_net` has more parameters, it will also take longer to train.

In [None]:
late_net = LateNet().to(device)
late_opt = Adam(late_net.parameters(), lr=0.0001)
late_train_losses, late_valid_losses = utils.train_model(
    late_net,
    late_opt,
    get_inputs,
    epochs,
    train_dataloader,
    valid_dataloader,
)

### 2.2.3 CatNet

For the next experiment, we'll try something completely new. Instead of joining our data streams at the beginning or end of the model, we'll mix them together somewhere in the middle. In this case, we'll have two convolution paths: one for RGB data, and one for XYZ data. We can almost think of a convolution function as taking in some sheets of paper and outputting a different number of sheets of paper. It's like we're taking these output sheets from our two data streams and stacking them on top of each other.

Before moving on, take a moment to look at the architecture below. Can you see where the two pathways meet? It doesn't appear in the `__init__` function. That `cat` operation occurs in the `forward` pass. In the generic `Net` architecture, the last convolution had `200` kernels. In this case we have `2` final convolutions with `100` kernels each.

In [12]:
class ConcatIntermediateNet(nn.Module):
    def __init__(self, rgb_ch, xyz_ch):
        kernel_size = 3
        num_positions = 9
        super().__init__()
        self.rgb_conv1 = nn.Conv2d(rgb_ch, 25, kernel_size, padding=1)
        self.rgb_conv2 = nn.Conv2d(25, 50, kernel_size, padding=1)
        self.rgb_conv3 = nn.Conv2d(50, 100, kernel_size, padding=1)
        
        self.xyz_conv1 = nn.Conv2d(xyz_ch, 25, kernel_size, padding=1)
        self.xyz_conv2 = nn.Conv2d(25, 50, kernel_size, padding=1)
        self.xyz_conv3 = nn.Conv2d(50, 100, kernel_size, padding=1)
        
        self.pool = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(200 * 8 * 8, 1000)
        self.fc2 = nn.Linear(1000, 100)
        self.fc3 = nn.Linear(100, num_positions)

    def forward(self, x_rgb, x_xyz):
        x_rgb = self.pool(F.relu(self.rgb_conv1(x_rgb)))
        x_rgb = self.pool(F.relu(self.rgb_conv2(x_rgb)))
        x_rgb = self.pool(F.relu(self.rgb_conv3(x_rgb)))
        
        x_xyz = self.pool(F.relu(self.xyz_conv1(x_xyz)))
        x_xyz = self.pool(F.relu(self.xyz_conv2(x_xyz)))
        x_xyz = self.pool(F.relu(self.xyz_conv3(x_xyz)))
        
        x = torch.cat((x_rgb, x_xyz), 1)
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
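
To make the "stacking sheets of paper" picture concrete, here is a small shape check of the `cat` step in isolation. The random tensors are stand-ins for the pooled outputs of `rgb_conv3` and `xyz_conv3`; the `8 x 8` map size is an assumption taken from the `200 * 8 * 8` input size of `fc1`.

```python
import torch

batch_size = 2
# Stand-ins for the two convolution paths (batch, channels, height, width).
# ASSUMPTION: 8 x 8 feature maps, matching the 200 * 8 * 8 input of fc1.
x_rgb = torch.rand(batch_size, 100, 8, 8)
x_xyz = torch.rand(batch_size, 100, 8, 8)

stacked = torch.cat((x_rgb, x_xyz), 1)  # join along the channel dimension
flat = torch.flatten(stacked, 1)        # flatten all dimensions except batch

print(stacked.shape)
print(flat.shape)

# The first 100 channels are the RGB maps and the last 100 are the XYZ maps.
rgb_preserved = torch.equal(stacked[:, :100], x_rgb)
xyz_preserved = torch.equal(stacked[:, 100:], x_xyz)
print(rgb_preserved, xyz_preserved)
```

Concatenation does not mix the two modalities by itself. It only places them side by side so that `fc1` can learn how to combine them.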

The top of this architecture looks like late fusion, but the bottom of it looks like early fusion. That means `cat_net` has fewer parameters than `late_net`. However, more parameters are not always better. Let's see how `cat_net` performs by running the cell below. This experiment will take some time. Please continue to read on.

In [None]:
epochs = 20
cat_net = ConcatIntermediateNet(4, 4).to(device)
cat_net_opt = Adam(cat_net.parameters(), lr=0.0001)
cat_net_train_losses, cat_net_valid_losses = utils.train_model(
    cat_net,
    cat_net_opt,
    get_inputs,
    epochs,
    train_dataloader,
    valid_dataloader,
)

### 2.2.4 MatmulNet

We're finally at our last experiment, `matmul_net`. This architecture is very similar to `cat_net`, but with one key difference. Instead of concatenating the outputs of the two final convolutions, we're going to [matrix multiply](https://en.wikipedia.org/wiki/Matrix_multiplication) them together, one pair of feature maps at a time (a batched matrix multiplication). Since each pair consists of two square matrices of the same shape, this is a valid operation.

The [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_(matrices)), where we multiply the matrices elementwise, would also be valid. However, matrix multiplication allows for different parts of the matrices to interact with each other, increasing the chance our network finds a useful multiplication operation.

<center><img src="images/matmul.png" width="50%" /></center>

Since we're matrix multiplying instead of concatenating, our `2` x `100` convolution output gets reduced down to `100` matrices. As they are still matrices, they should be `flattened` before passing the information to the dense `linear` section of the network.

In [14]:
class MatmulIntermediateNet(nn.Module):
    def __init__(self, rgb_ch, xyz_ch):
        kernel_size = 3
        num_positions = 9
        super().__init__()
        self.rgb_conv1 = nn.Conv2d(rgb_ch, 25, kernel_size, padding=1)
        self.rgb_conv2 = nn.Conv2d(25, 50, kernel_size, padding=1)
        self.rgb_conv3 = nn.Conv2d(50, 100, kernel_size, padding=1)
        
        self.xyz_conv1 = nn.Conv2d(xyz_ch, 25, kernel_size, padding=1)
        self.xyz_conv2 = nn.Conv2d(25, 50, kernel_size, padding=1)
        self.xyz_conv3 = nn.Conv2d(50, 100, kernel_size, padding=1)
        
        self.pool = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(100 * 8 * 8, 1000)
        self.fc2 = nn.Linear(1000, 100)
        self.fc3 = nn.Linear(100, num_positions)

    def forward(self, x_rgb, x_xyz):
        x_rgb = self.pool(F.relu(self.rgb_conv1(x_rgb)))
        x_rgb = self.pool(F.relu(self.rgb_conv2(x_rgb)))
        x_rgb = self.pool(F.relu(self.rgb_conv3(x_rgb)))
        
        x_xyz = self.pool(F.relu(self.xyz_conv1(x_xyz)))
        x_xyz = self.pool(F.relu(self.xyz_conv2(x_xyz)))
        x_xyz = self.pool(F.relu(self.xyz_conv3(x_xyz)))
        
        x = torch.matmul(x_rgb, x_xyz)
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
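
To see what `torch.matmul` does with 4-dimensional tensors like the ones in this `forward` pass, here is a small check on random stand-in data. It is only a sketch: the `8 x 8` map size is an assumption taken from the `100 * 8 * 8` input size of `fc1`, and the channel and position indices are arbitrary.

```python
import torch

torch.manual_seed(0)
x_rgb = torch.rand(2, 100, 8, 8)  # (batch, channels, height, width)
x_xyz = torch.rand(2, 100, 8, 8)

fused = torch.matmul(x_rgb, x_xyz)  # one 8 x 8 matrix product per map pair
hadamard = x_rgb * x_xyz            # elementwise alternative, same shape

print(fused.shape)

# Each output map is the ordinary matrix product of the matching input maps.
same_as_2d = torch.allclose(fused[0, 5], x_rgb[0, 5] @ x_xyz[0, 5])
print(same_as_2d)

# Entry (i, j) mixes all of row i from RGB with all of column j from XYZ,
# whereas the Hadamard product only pairs entries at the same position.
i, j = 3, 6
manual = (x_rgb[0, 5, i, :] * x_xyz[0, 5, :, j]).sum()
matches_manual = torch.allclose(fused[0, 5, i, j], manual)
print(matches_manual)

flat = torch.flatten(fused, 1)
print(flat.shape)
```

The fused tensor keeps `100` channels instead of `200`, which is why `fc1` in `MatmulIntermediateNet` takes `100 * 8 * 8` inputs.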

In [None]:
epochs = 20
matmul_net = MatmulIntermediateNet(4, 4).to(device)
matmul_net_opt = Adam(matmul_net.parameters(), lr=0.0001)
matmul_net_train_losses, matmul_net_valid_losses = utils.train_model(
    matmul_net,
    matmul_net_opt,
    get_inputs,
    epochs,
    train_dataloader,
    valid_dataloader,
)

## 2.3 Results

Congrats on making it to the end. There's a good chance that, at this point, the experiments are still running. In the meantime, please move on to the [next notebook](02b_Contrastive_Pretraining.ipynb). Return when they are finished to see the comparison graph below.

In [None]:
plot_x = range(len(early_valid_losses))  # one point per recorded epoch
plt.xlabel("Epoch")
plt.ylabel("Average Loss")
plt.plot(plot_x, early_valid_losses, "goldenrod", label = "EarlyNet")
plt.plot(plot_x, late_valid_losses, "green", label = "LateNet")
plt.plot(plot_x, cat_net_valid_losses, "blue", label = "CatNet")
plt.plot(plot_x, matmul_net_valid_losses, "orchid", label = "MatmulNet")
plt.legend()
plt.show()

## Next

We learned a few interesting fusion techniques in this notebook, but we're not done yet. In the next section, we'll learn how to project the embeddings of one model into another.

In [None]:
import IPython
app = IPython.Application.instance()
app.kernel.do_shutdown(True)
