<center><a href="https://www.nvidia.com/en-us/training/"><img src="https://dli-lms.s3.amazonaws.com/assets/general/DLI_Header_White.png" width="400" height="186" /></a></center>

# 1a. Early and Late Fusion



Multimodal models are simple in concept but surprisingly complex in practice. Let's compare different data types and how we can analyze them in a multimodal model, using a robotics use case.

#### Learning Objectives

The goals of this notebook are to:
* Explore the properties of LiDAR data
* Compare single modal models
  * Construct an RGB image model
  * Construct a LiDAR model
* Compare multimodal models
  * Construct a late fusion model
  * Construct an early fusion model

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from PIL import Image
from IPython.display import Image as IPy_img

import utils
import os

In PyTorch, we can run operations on our GPU by setting the [device](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) to `cuda`. The function `torch.cuda.is_available()` confirms whether PyTorch can recognize the GPU.

In [None]:
!nvidia-smi

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.is_available()

## 1.1 The Dataset

Many fields of data science have famous "Hello World" style datasets that can be used to demonstrate a model architecture. Unfortunately, there is no such standard benchmark for multimodal models, so we decided to make such a dataset using [NVIDIA Omniverse](https://www.nvidia.com/en-us/omniverse/).

Real-world data collection can be challenging. NVIDIA Omniverse is a platform for creating and simulating virtual copies of real-life objects and locations. These virtual copies are called [Digital Twins](https://www.nvidia.com/en-us/omniverse/solutions/digital-twins/). Digital twins can be used to simulate and capture data that would be difficult, or even impossible, to collect with real-world instruments. The process of creating this data is called [Synthetic Data Generation](https://www.nvidia.com/en-us/use-cases/synthetic-data/).

For this first lab, we randomly generated the positions of three objects: a cube, a sphere, and a torus (a.k.a. a donut). What kind of tools could we use to determine the position of the center of each object? Let's look at our options.

Each data type below was generated from the same object positions. Files with the same number, in this case `0163`, share the same positions.

### 1.1.1 RGB Camera Data
The first kind of data are [RGB color images](https://en.wikipedia.org/wiki/RGB_color_model). This could be captured using a camera, such as the one on the back of most smartphones. Let's look at an example.

In [None]:
data_img = Image.open('data/replicator_data_parallel/rgb/0163.png')
plt.imshow(data_img)

Here, we have an image that's 64 pixels tall and 64 pixels wide. Image data like this is generally analyzed using a [convolutional neural network](https://developer.nvidia.com/discover/convolutional-neural-network). We will be using the same technique in our multimodal models.

### 1.1.2 Distance Image

Since the data is simulated, Omniverse can calculate the distance from the camera to the surface seen in each pixel. RGB cameras work by recording the light that reflects off the objects in front of them, so it can be difficult to gather usable data in low-light situations. By using distance instead, we can see better in the dark.

However, few real-world instruments can capture per-pixel distance data like the image below. Because of this, we will not use it in our multimodal models, but the data is available to help verify our analysis.

In [None]:
data = np.load('data/replicator_data_parallel/distance/0163.npy')
plt.imshow(data, cmap=utils.cmap)

### 1.1.3 LiDAR Data

One real-world instrument that can capture distance data is a [LiDAR](https://oceanservice.noaa.gov/facts/lidar.html) sensor. LiDAR sensors work by emitting multiple beams of light invisible to the human eye. Thanks to [Ole Rømer](https://en.wikipedia.org/wiki/Ole_R%C3%B8mer), who first demonstrated that light travels at a finite, measurable speed, we can measure how far away something is based on how long the emitted light takes to return to the sensor.
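As a minimal sketch of this time-of-flight idea: a pulse travels to the object and back, so the one-way distance is the speed of light times the round-trip time, divided by two. The function name and the 200 ns timing are illustrative only; they are not part of this lab's data pipeline, which already provides distances.

```python
SPEED_OF_LIGHT = 299_792_458.0  # metres per second (exact by SI definition)


def time_of_flight_distance(round_trip_seconds: float) -> float:
    """Hypothetical helper: one-way distance (m) for a pulse's round-trip time (s)."""
    # The pulse covers the sensor-to-object distance twice, so halve the total path.
    return SPEED_OF_LIGHT * round_trip_seconds / 2.0


# A return after 200 nanoseconds corresponds to roughly 30 metres.
print(time_of_flight_distance(200e-9))
```

The same relationship explains why timing precision matters: an error of 1 ns in the round trip shifts the estimate by about 15 cm.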

There are many types of LiDAR devices available on the market, but to make a better comparison with RGB data, we simulated a 64 by 64 grid of laser beams emitted from the sensor. Because of this shape, we can view the data like an image. There will be a small discrepancy because LiDAR data is collected along a curve, but this 2D projection can still be useful for debugging.

In [None]:
data = np.load('data/replicator_data_parallel/lidar/0163.npy')
plt.imshow(data, cmap=utils.cmap)
data.shape

Hmm, the objects look familiar, but they appear rotated compared to our RGB and distance images. That's because each light beam is emitted at a different angle, so to align the LiDAR data with the RGB information, we need the angle information. With the angles, we would not only have the distance each laser beam traveled; we could also calculate the position where it landed.

Because this is a 3-dimensional space, we need two angles, in addition to the distance, to calculate the position. Thankfully, we have this data. The [azimuth](https://en.wikipedia.org/wiki/Azimuth) represents the horizontal rotation from the LiDAR sensor's forward direction, and the [zenith](https://en.wikipedia.org/wiki/Zenith) represents the vertical rotation.

<center><img src="images/lidar_math.png" width="70%"/></center>

This particular LiDAR system has the following properties to consider:
* The azimuth angle corresponds to the rows, and the zenith angle corresponds to the columns.
* The angles use a [left-handed coordinate system](https://www.scratchapixel.com/lessons/mathematics-physics-for-computer-graphics/geometry/coordinate-systems.html#:~:text=The%20differentiation%20between%20left%2Dhanded,a%20right%2Dhand%20coordinate%20system.) with Z as the vertical axis.
* Azimuth and zenith are measured relative to the sensor's forward direction, so a beam fired straight ahead has both angles equal to 0.

These conventions are not universal, so the calculations below may differ for other LiDAR systems. Still, let's break down the steps, as much of the logic applies to most azimuth/zenith systems.

First, we'll need the azimuth of each laser.

In [None]:
azimuth = np.load("data/replicator_data_parallel/azimuth.npy")
print(azimuth.shape)
azimuth

Then, we'll need the zenith.

In [None]:
zenith = np.load("data/replicator_data_parallel/zenith.npy")
print(zenith.shape)
zenith

Here's a handy trick to remember whether to use sine or cosine: sin(0) = 0 and cos(0) = 1. We're using an XYZ coordinate system where X is the horizontal position, Y is the forward and backward position, and Z is the vertical position. If we fire our laser straight out in front of us, the point where it lands should be (0, d, 0), where d is the distance the laser travelled.

Based on how this particular LiDAR system collected its data, the sine of the azimuth will help us calculate our X positions and make sure x is 0 when the azimuth is 0. We will add a dimension with `None` so we can [broadcast](https://numpy.org/doc/stable/user/basics.broadcasting.html) the data into the same shape as the lidar depth data.
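To see what `[:, None]` and `[None, :]` do, here is a toy version on a 3 by 4 grid instead of 64 by 64. The shapes and angle values are made up for illustration, and the `toy_` prefixes keep the cell from overwriting the notebook's real `azimuth`, `zenith`, and surface variables. Adding an axis turns the azimuth vector into a column, so every entry in a row shares that row's azimuth; the zenith becomes a row, so every entry in a column shares that column's zenith.

```python
import numpy as np

toy_depth = np.ones((3, 4))                     # stand-in for a 3x4 LiDAR depth image
toy_azimuth = np.array([-0.1, 0.0, 0.1])        # one angle per row (radians)
toy_zenith = np.array([-0.2, -0.1, 0.1, 0.2])   # one angle per column (radians)

print(toy_azimuth[:, None].shape)  # (3, 1): a column vector
print(toy_zenith[None, :].shape)   # (1, 4): a row vector

# Broadcasting stretches the column across all 4 columns...
toy_x = toy_depth * np.sin(-toy_azimuth[:, None])
# ...and stretches the row across all 3 rows.
toy_z = toy_depth * np.sin(-toy_zenith[None, :])

print(toy_x.shape, toy_z.shape)  # (3, 4) (3, 4)
# Every entry in a row of toy_x shares that row's azimuth.
print(np.allclose(toy_x[:, 0], toy_x[:, 3]))  # True
```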

Notice the `-` sign by `azimuth`? Come back later and remove it to see how it affects the results.

In [None]:
x_surface = np.ones_like(data) * np.sin(-azimuth[:, None])
plt.imshow(x_surface, cmap=utils.cmap)

Aha! This is why our LiDAR data looks rotated. For this data, the azimuth changes from top to bottom, not from left to right. We see a similar pattern when we calculate the zenith. Using the same logic as above, we will use the sine of the zenith to make sure z is 0 when the zenith is 0.

In [None]:
z_surface = np.ones_like(data) * np.sin(-zenith[None, :])
plt.imshow(z_surface, cmap=utils.cmap)

We've calculated our Xs and our Zs, so finally, it's the Ys. When we fire a laser directly in front of us, the y position should be the same as the distance the laser traveled. However, the further the laser's angle deviates from straight ahead, the smaller the y position becomes relative to the distance traveled. To calculate this, we will use the cosine of both the azimuth and the zenith.

In [None]:
y_surface = np.ones_like(data) * np.cos(-azimuth[:, None]) * np.cos(-zenith[None, :])
plt.imshow(y_surface, cmap=utils.cmap)

There's one more thing to consider about LiDAR data. Many LiDAR sensors have a maximum effective range. For this sensor system, if no returning light is detected, the sensor records the maximum distance for that laser. To make better visualizations, let's create a mask that identifies where a laser has travelled its maximum distance.

In [None]:
a = data != data.max()  # True where the laser hit something before reaching its maximum range
plt.imshow(a, cmap=utils.cmap)

Now that we've calculated how much the azimuth and the zenith alter the distance travelled, all we have to do is multiply our X, Y, and Z surfaces by the lidar depth to extract the position of each laser collision.
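For reference, here is the arithmetic in the cell below written out as equations. The symbols are chosen for this summary: $d_{ij}$ is the depth of the laser in row $i$ and column $j$, $\alpha_i$ is that row's azimuth, and $\zeta_j$ is that column's zenith.

$$
\begin{aligned}
x_{ij} &= d_{ij}\,\sin(-\alpha_i) \\
y_{ij} &= d_{ij}\,\cos(-\alpha_i)\,\cos(-\zeta_j) \\
z_{ij} &= d_{ij}\,\sin(-\zeta_j)
\end{aligned}
$$

Setting $\alpha_i = \zeta_j = 0$ gives $(0, d_{ij}, 0)$, the straight-ahead case. Unlike the textbook spherical-to-Cartesian conversion, $x_{ij}$ is not also multiplied by $\cos(-\zeta_j)$; the two agree whenever either angle is 0 and differ only slightly for small angles. This is the convention this lab uses throughout, including in the PyTorch version later on.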

In [None]:
lidar_depth = data

x = lidar_depth * x_surface
z = lidar_depth * z_surface
y = lidar_depth * y_surface

plt.clf()
plt.subplot(1, 4, 1)
plt.imshow(x, cmap=utils.cmap)
plt.subplot(1, 4, 2)
plt.imshow(z, cmap=utils.cmap)
plt.subplot(1, 4, 3)
plt.imshow(y, cmap=utils.cmap)
plt.subplot(1, 4, 4)
plt.imshow(a, cmap=utils.cmap)
plt.show()

In many LiDAR visualizations, points are colored based on their vertical position. Let's do the same. We'll use `c` to represent how to color a collision.

In [None]:
c = np.copy(z)
c_min = np.min(c)
c_max = np.max(c)
c = (c - c_min) / (c_max - c_min)

Now that we have all our position information, let's plot it. How does it look? Is the LiDAR data finally aligned with the RGB image?

In [None]:
# Create a 3D figure
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Scatter plot the points
x_scatter = x[a == 1]
y_scatter = y[a == 1]
z_scatter = z[a == 1]
c_scatter = c[a == 1]
ax.scatter(x_scatter, y_scatter, z_scatter, c=c_scatter, cmap="rainbow", marker='o')

# Set labels
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.set_xlim(-6, 6)
ax.set_ylim(19, 31)
ax.set_zlim(-6, 6)

ax.view_init(elev=0., azim=270)
plt.show()

<img src="data/replicator_data_parallel/rgb/0163.png" width="300px" style="image-rendering:pixelated;"/>

Looking much better! Since this is position information, we can take our visualization further. Let's create an animation that rotates the points in a looping GIF.

This may take a few minutes. It's worth it, we promise.

In [None]:
def init():
    ax.scatter(x_scatter, y_scatter, z_scatter, c=c_scatter, cmap="rainbow", marker='o')
    return fig,

def animate(i):
    ax.view_init(elev=30., azim=i)
    return fig,

file_path = 'pointcloud.gif'
if not os.path.exists(file_path):
    anim = animation.FuncAnimation(
        fig, animate, init_func=init, frames=360, interval=20, blit=True
    )
    anim.save(file_path, fps=30)
IPy_img(filename=file_path)

## 1.2 Model Comparison

Now that we've explored our data, it's time to create our models. Before making a multimodal model, let's see how a single modal model fares. We'll make one model that only uses RGB images and another that only uses LiDAR data.

### 1.2.1 PyTorch Data

Before we can create our models, we should convert our data into a [PyTorch Dataset](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html). First, let's create a function that will calculate the XYZ position data from the lidar depth data. We will use the letter `a` to represent whether the depth is at maximum or not.

**TODO:** Create a function to convert lidar depth data into position data by replacing the `FIXME`s below. *Hint*: The PyTorch implementation is similar to the NumPy implementation. Click the `...` for the solution.

In [None]:
def get_torch_xyza(lidar_depth, azimuth, zenith):
    x = lidar_depth * torch.sin(FIXME)
    z = lidar_depth * torch.sin(FIXME)
    y = lidar_depth * torch.cos(FIXME) * torch.cos(FIXME)
    a = torch.where(lidar_depth < 50.0, torch.ones_like(lidar_depth), torch.zeros_like(lidar_depth))
    xyza = torch.stack((x, y, z, a))
    return xyza

In [None]:
def get_torch_xyza(lidar_depth, azimuth, zenith):
    x = lidar_depth * torch.sin(-azimuth[:, None])
    z = lidar_depth * torch.sin(-zenith[None, :])
    y = lidar_depth * torch.cos(-azimuth[:, None]) * torch.cos(-zenith[None, :])
    a = torch.where(lidar_depth < 50.0, torch.ones_like(lidar_depth), torch.zeros_like(lidar_depth))
    xyza = torch.stack((x, y, z, a))
    return xyza

Now that we have our XYZ data in tensor format, let's create a PyTorch Dataset. Normally, this is where we would apply data augmentation. However, we will also use this dataset for our multimodal models, so any augmentation, such as cropping or flipping, would need to be applied consistently to both an image and its corresponding LiDAR data. The math for that gets a little complicated, so let's keep things simple and use the data as is for demonstration purposes.

In [None]:
IMG_SIZE = 64
BATCH_SIZE = 32
VALID_BATCHES = 10
N = 9999

img_transforms = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),  # Scales data into [0,1]
])

In [None]:
class MyDataset(Dataset):
    def __init__(self, root_dir, start_idx, stop_idx):
        self.root_dir = root_dir
        self.imgs = []
        self.lidar_depths = []
        self.positions = np.genfromtxt(
            root_dir + "positions.csv", delimiter=",", skip_header=1
        )[start_idx:stop_idx]

        self.azimuth = torch.from_numpy(azimuth).to(device)
        self.zenith = torch.from_numpy(zenith).to(device)

        for idx in range(start_idx, stop_idx):
            file_number = "{:04d}".format(idx)
            rbg_img = Image.open(self.root_dir + "rgb/" + file_number + ".png")
            rbg_img = img_transforms(rbg_img).to(device)
            self.imgs.append(rbg_img)

            lidar_depth = np.load(self.root_dir + "lidar/" + file_number + ".npy")
            lidar_depth = torch.from_numpy(lidar_depth).to(torch.float32).to(device)
            self.lidar_depths.append(lidar_depth)

    def __len__(self):
        return len(self.positions)

    def __getitem__(self, idx):
        rbg_img = self.imgs[idx]
        lidar_depth = self.lidar_depths[idx]
        lidar_xyza = get_torch_xyza(lidar_depth, self.azimuth, self.zenith)

        position = self.positions[idx]
        position = torch.from_numpy(position).to(torch.float32).to(device)

        return rbg_img, lidar_xyza, position

Before we feed this into our neural networks for training, let's test it out. Is the data in the shape and size we would expect?

In [None]:
train_data = MyDataset("data/replicator_data_parallel/", 0, N-VALID_BATCHES*BATCH_SIZE)
train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
valid_data = MyDataset("data/replicator_data_parallel/", N-VALID_BATCHES*BATCH_SIZE, N)
valid_dataloader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

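The order is `rbg_img`, `lidar_xyza`, and then `position`. The `rbg_img` and `lidar_xyza` tensors should have the same shape, `[4, 64, 64]`, so that we can compare them: the LiDAR tensor stacks x, y, z, and a, and the PNG images load with four channels (RGBA), which is why both single modal models below are built with `Net(4)`. We have 3 objects, and we're trying to find the x, y, and z positions of the center of each one. Therefore, the position information should have 9 numbers.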

In [None]:
for i, sample in enumerate(train_data):
    print(i, sample[0].shape, sample[1].shape, sample[2].shape)

    if i == 5:
        break

### 1.2.2 Single Modal Model Architecture

Before we make a multimodal model, let's see how a single modal model will perform. There are many ways we could create this model, so it would be difficult to prove one data type is better than the other in all situations. That said, it's still worthwhile to hypothesize why some data might be better than others. In addition, we can use our single modal models as a baseline for our multimodal models.

Given the shape for both the image and LiDAR data is the same, we can use a convolutional neural network for both of them.

In [None]:
num_positions = 9

class Net(nn.Module):
    def __init__(self, in_ch):
        super().__init__()
        kernel_size = 3
        self.conv1 = nn.Conv2d(in_ch, 50, kernel_size, padding=1)
        self.conv2 = nn.Conv2d(50, 100, kernel_size, padding=1)
        self.conv3 = nn.Conv2d(100, 200, kernel_size, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(200 * 8 * 8, 1000)
        self.fc2 = nn.Linear(1000, 100)
        self.fc3 = nn.Linear(100, num_positions)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
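Where does `200 * 8 * 8` in `fc1` come from? Each 3×3 convolution with `padding=1` keeps the spatial size unchanged, and each `MaxPool2d(2)` halves it, so the three conv/pool stages take 64 → 32 → 16 → 8. The sketch below does that arithmetic with the standard output-size formula, floor((n + 2p − k) / s) + 1. The helper names are hypothetical (they are not part of `utils.py`); they are only a way to check `fc1`'s input size if you change `IMG_SIZE` or the layers.

```python
def conv_out(size: int, kernel: int, padding: int = 0, stride: int = 1) -> int:
    """Spatial output size of a Conv2d or MaxPool2d layer (dilation 1, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


def net_flat_features(img_size: int = 64, channels=(50, 100, 200)) -> int:
    """Number of features entering fc1 for the Net architecture above."""
    size = img_size
    for _ in channels:
        size = conv_out(size, kernel=3, padding=1)  # 3x3 conv, padding=1: size unchanged
        size = conv_out(size, kernel=2, stride=2)   # MaxPool2d(2): size halved
    return channels[-1] * size * size


print(net_flat_features(64))  # 12800, i.e. 200 * 8 * 8
```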

What kind of loss function would be the best? Given that we're trying to predict position information, why not use the [Pythagorean Theorem](https://en.wikipedia.org/wiki/Pythagorean_theorem) to calculate the distance between our predicted positions and the actual positions? Then we can try to minimize that distance through gradient descent.

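Good news! The [Mean Squared Error](https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html) loss function does nearly that for us. It squares the difference between each predicted and actual value, then averages over all values; this is the squared Pythagorean distance divided by the number of values, without the square root. Because the square root is monotonic, minimizing the MSE also minimizes the root mean squared error (RMSE), so we can use the MSE loss function.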
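Here is a small numeric check of that relationship. The prediction and target vectors are made up, and the per-object step assumes, for illustration only, that the 9 values are laid out as three consecutive (x, y, z) triples; check the column order of `positions.csv` before relying on that.

```python
import numpy as np

# Hypothetical prediction and ground truth for one sample: 3 objects x (x, y, z).
toy_pred = np.array([1.0, 20.0, 0.0, -2.0, 25.0, 1.0, 3.0, 28.0, -1.0])
toy_target = np.array([1.5, 21.0, 0.0, -2.0, 24.0, 1.0, 3.0, 28.0, 0.0])

toy_mse = np.mean((toy_pred - toy_target) ** 2)      # what nn.MSELoss computes (mean reduction)
toy_sq_dist = np.sum((toy_pred - toy_target) ** 2)   # squared Euclidean distance over all 9 values
print(np.isclose(toy_mse, toy_sq_dist / toy_pred.size))  # True

# Per-object Pythagorean distance, assuming [x1, y1, z1, x2, y2, z2, ...] ordering.
per_object = np.linalg.norm((toy_pred - toy_target).reshape(3, 3), axis=1)
print(per_object)
```

The per-object distances are often easier to interpret than the loss value itself, since they are in the same units as the positions.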

In [None]:
loss_func = nn.MSELoss()

Below is our training loop. We've moved some of the helper functions into [utils.py](utils.py) for brevity.

In [None]:
def train_model(model, optimizer, inputs_idx, epochs=20):
    train_losses = []
    valid_losses = []
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for step, batch in enumerate(train_dataloader):
            optimizer.zero_grad()
            outputs, target = utils.get_outputs(model, batch, inputs_idx)
            loss = loss_func(outputs, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss = train_loss / (step + 1)
        train_losses.append(train_loss)
        utils.print_loss(epoch, train_loss, outputs, target, is_train=True)
        
        model.eval()
        valid_loss = 0
        with torch.no_grad():  # gradients aren't needed for validation
            for step, batch in enumerate(valid_dataloader):
                outputs, target = utils.get_outputs(model, batch, inputs_idx)
                valid_loss += loss_func(outputs, target).item()
        valid_loss = valid_loss / (step + 1)
        valid_losses.append(valid_loss)
        utils.print_loss(epoch, valid_loss, outputs, target, is_train=False)
    return train_losses, valid_losses

Time to get this experiment started! This may take a little while, so please press the fast forward button ![image](images/fast_forward.png) above to run the rest of the notebook. While the experiment is running, let's take a moment to think. Which data type do you think would be more effective at predicting the location of our objects? Can you give a reason why?

Once you have an answer, please continue reading the rest of the notebook, even if the next cell is still running.

In [None]:
epochs = 20

rgb_net = Net(4).to(device)
rgb_opt = Adam(rgb_net.parameters(), lr=0.0001)

xyz_net = Net(4).to(device)
xyz_opt = Adam(xyz_net.parameters(), lr=0.0001)

print("Training rgb_net")
rgb_train_loss, rgb_valid_loss = train_model(rgb_net, rgb_opt, 0, epochs)  # input 0: RGB images
print("Training xyz_net")
xyz_train_loss, xyz_valid_loss = train_model(xyz_net, xyz_opt, 1, epochs)  # input 1: LiDAR xyza

In [None]:
plot_x = range(epochs)
plt.xlabel("Epoch")
plt.ylabel("Average Loss")
plt.plot(plot_x, rgb_train_loss, "green", label = "RGB Train Loss")
plt.plot(plot_x, rgb_valid_loss, "darkgreen", label = "RGB Valid Loss")
plt.plot(plot_x, xyz_train_loss, "orchid", label = "XYZ Train Loss")
plt.plot(plot_x, xyz_valid_loss, "darkorchid", label = "XYZ Valid Loss")
plt.legend()
plt.show()

### 1.2.3 Multimodal Model Architecture

If you've reached this point while the experiment is still running, we won't spoil the surprise. In the meantime, let's get another experiment up and running. Once we have pretrained models for both the image data and the LiDAR data, there's a relatively simple way to turn them into a multimodal model: feed the outputs of both models into a small network that produces the final output. In this way, it's almost like we're creating an [ensemble model](https://scikit-learn.org/1.5/modules/ensemble.html), where each model has a weighted vote in the final result.

In [None]:
networks = [rgb_net, xyz_net]

for network in networks:
    for param in network.parameters():
        param.requires_grad = False

class MyMultimodalModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.rgb_net = rgb_net
        self.xyz_net = xyz_net
        self.fc1 = nn.Linear(num_positions * len(networks), num_positions * 10)
        self.fc2 = nn.Linear(num_positions * 10, num_positions)

    def forward(self, x_img, x_xyz):
        x_rgb = self.rgb_net(x_img)
        x_xyz = self.xyz_net(x_xyz)
        x = torch.cat((x_rgb, x_xyz), 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

This type of multimodal model is called `late fusion` because we analyze each data type separately and only combine the two data streams at the very end. In our [utils.py](utils.py) file, we've created a flexible function similar to `train_model` above to train this model. It still needs a function to extract the network inputs from the dataloader, so let's define that.
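To make the late-fusion data flow concrete, the sketch below traces shapes through the fusion head with NumPy. Random arrays stand in for the two frozen networks' outputs, and random weights stand in for `fc1` and `fc2`; nothing here is trained, and the weights are stored as (in, out) whereas `nn.Linear` stores its weight transposed. Each network emits 9 numbers per sample, concatenation gives 18, `fc1` maps those to 90, and `fc2` maps back to 9.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, n_pos = 32, 9

# Stand-ins for rgb_net(x_img) and xyz_net(x_xyz): one 9-value prediction per sample.
out_rgb = rng.normal(size=(batch_size, n_pos))
out_xyz = rng.normal(size=(batch_size, n_pos))

# Late fusion: the two streams meet only after each model has made its prediction.
fused = np.concatenate((out_rgb, out_xyz), axis=1)  # (32, 18)

# Random weights shaped like fc1 (18 -> 90) and fc2 (90 -> 9).
w1, b1 = rng.normal(size=(2 * n_pos, 10 * n_pos)), np.zeros(10 * n_pos)
w2, b2 = rng.normal(size=(10 * n_pos, n_pos)), np.zeros(n_pos)

hidden = np.maximum(fused @ w1 + b1, 0.0)  # ReLU, (32, 90)
prediction = hidden @ w2 + b2              # (32, 9)
print(fused.shape, hidden.shape, prediction.shape)
```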

In [None]:
def get_mm_late_inputs(batch):
    inputs_rgb = batch[0].to(device)
    inputs_xyz = batch[1].to(device)
    return (inputs_rgb, inputs_xyz)

Ok, let's run another experiment. Will this be better or worse than the single modal models?

In [None]:
mm_late_net = MyMultimodalModel().to(device)
mm_late_opt = Adam(mm_late_net.parameters(), lr=0.0001)
mm_late_train_loss, mm_late_valid_loss = utils.train_model(
    mm_late_net,
    mm_late_opt,
    get_mm_late_inputs,
    epochs,
    train_dataloader,
    valid_dataloader
)

We've got one more experiment to run in this lab. We have late fusion, so why not try early fusion? Where late fusion combines the data pathways at the end of the model, early fusion does it right at the beginning. In fact, we will use the same model architecture as our `rgb_net` and `xyz_net`. Here, we'll stack the two data types on top of each other and treat them like one data type with twice the channels.

In [None]:
def get_mm_early_inputs(batch):
    inputs_rgb = batch[0].to(device)
    inputs_xyz = batch[1].to(device)
    inputs_mm_early = torch.cat((inputs_rgb, inputs_xyz), 1)
    return (inputs_mm_early,)
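By contrast with late fusion, `get_mm_early_inputs` fuses the raw inputs. The sketch below uses NumPy arrays of random values as stand-ins for one batch, with the same (batch, channels, height, width) shapes as the dataloader's tensors and assuming the four-channel images implied by `Net(4)`. Concatenating along dimension 1 stacks the 4 image channels and the 4 x/y/z/a channels into one 8-channel input, which is why the early fusion model below is built with `Net(8)`.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-ins for batch[0] (images) and batch[1] (lidar_xyza): (batch, channels, height, width).
imgs = rng.random((32, 4, 64, 64))
xyza = rng.random((32, 4, 64, 64))

# Early fusion: concatenate along the channel axis, before any layer runs.
early = np.concatenate((imgs, xyza), axis=1)
print(early.shape)  # (32, 8, 64, 64)

# Channels 0-3 hold the image and channels 4-7 hold x, y, z, a.
print(np.array_equal(early[:, 4:], xyza))  # True
```

Because every convolution in the first layer sees all 8 channels at once, the network can mix image and LiDAR information from the very first layer, rather than only at the end.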

So what do you think? Between early fusion and late fusion, which one is more effective in this case?

In [None]:
mm_early_net = Net(8).to(device)
mm_early_opt = Adam(mm_early_net.parameters(), lr=0.0001)
mm_early_train_loss, mm_early_valid_loss = utils.train_model(
    mm_early_net,
    mm_early_opt,
    get_mm_early_inputs,
    epochs,
    train_dataloader,
    valid_dataloader
)

In [None]:
plt.xlabel("Epoch")
plt.ylabel("Average Loss")
plt.plot(plot_x, xyz_train_loss, "goldenrod", label = "LiDAR Train Loss")
plt.plot(plot_x, xyz_valid_loss, "darkgoldenrod", label = "LiDAR Valid Loss")
plt.plot(plot_x, mm_late_train_loss, "green", label = "Late Fusion Train Loss")
plt.plot(plot_x, mm_late_valid_loss, "darkgreen", label = "Late Fusion Valid Loss")
plt.plot(plot_x, mm_early_train_loss, "orchid", label = "Early Fusion Train Loss")
plt.plot(plot_x, mm_early_valid_loss, "darkorchid", label = "Early Fusion Valid Loss")
plt.legend()
plt.show()

## Next

If you've gotten this far and are still waiting for the experiment to complete, we've created a bonus notebook that explores more interesting data types while training finishes.

If the experiment is over, what were the results? Were they what you expected? Let's explore the theory in the next slide deck.

In [None]:
import IPython
app = IPython.Application.instance()
app.kernel.do_shutdown(True)
