In [21]:
import torch
import torch.nn as nn

import os
from argparse import ArgumentParser
from glob import glob

import cv2
import numpy as np

from fiery.models.encoder import Encoder
from fiery.models.temporal_model import TemporalModelIdentity, TemporalModel
from fiery.models.distributions import DistributionModule
from fiery.models.future_prediction import FuturePrediction
from fiery.models.decoder import Decoder
from fiery.utils.network import pack_sequence_dim, unpack_sequence_dim, set_bn_momentum
from fiery.utils.geometry import cumulative_warp_features, calculate_birds_eye_view_parameters, VoxelsSumming
import visualise

from fiery.trainer import TrainingModule
from fiery.utils.network import NormalizeInverse
from fiery.utils.instance import predict_instance_segmentation_and_trajectories
from fiery.utils.visualisation import plot_instance_map, generate_instance_colours, make_contour, convert_figure_numpy

## Model I/O

**image**: torch.Tensor float (B, T, N, 3, H, W) - normalised camera images, with B the batch size, T the sequence length, and N the number of cameras.

**intrinsics**: torch.Tensor float (B, T, N, 3, 3) - intrinsics containing resizing and cropping parameters.

**extrinsics**: torch.Tensor float (B, T, N, 4, 4) - 6 DoF pose from world coordinates to camera coordinates (note: `get_geometry` below applies this matrix to map camera-frame points into the ego frame).

**future_egomotion**: torch.Tensor float (B, T, 6) - 6 DoF egomotion where t -> t+1.

In [22]:
device = torch.device('cpu')

# Restore the pretrained Lightning module, move it onto `device` and put it in
# inference mode; the final expression displays the module structure.
trainer = TrainingModule.load_from_checkpoint('fiery.ckpt', strict=True).to(device)
trainer.eval()

Loaded pretrained weights for efficientnet-b4


  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "
                not been set for this class (IntersectionOverUnion). The property determines if `update` by
                default needs access to the full metric state. If this is not the case, significant speedups can be
                achieved and we recommend setting this to `False`.
                We provide an checking function
                `from torchmetrics.utilities import check_forward_full_state_property`
                that can be used to check if the `full_state_update=True` (old and potential slower behaviour,
                default for now) or if `full_state_update=False` can be used safely.
                
                not been set for this class (PanopticMetric). The property determines if `update` by
                default needs access to the full metric state. If this is not the case, significant speedups can be
                achieved and we recommend se

TrainingModule(
  (model): Fiery(
    (encoder): Encoder(
      (backbone): EfficientNet(
        (_conv_stem): Conv2dStaticSamePadding(
          3, 48, kernel_size=(3, 3), stride=(2, 2), bias=False
          (static_padding): ZeroPad2d((1, 1, 1, 1))
        )
        (_bn0): BatchNorm2d(48, eps=0.001, momentum=0.05, affine=True, track_running_stats=True)
        (_blocks): ModuleList(
          (0): MBConvBlock(
            (_depthwise_conv): Conv2dStaticSamePadding(
              48, 48, kernel_size=(3, 3), stride=[1, 1], groups=48, bias=False
              (static_padding): ZeroPad2d((1, 1, 1, 1))
            )
            (_bn1): BatchNorm2d(48, eps=0.001, momentum=0.05, affine=True, track_running_stats=True)
            (_se_reduce): Conv2dStaticSamePadding(
              48, 12, kernel_size=(1, 1), stride=(1, 1)
              (static_padding): Identity()
            )
            (_se_expand): Conv2dStaticSamePadding(
              12, 48, kernel_size=(1, 1), stride=(1, 1)
     

In [23]:
# Download and extract example input data
visualise.download_example_data()

EXAMPLE_DATA_PATH = 'example_data/example_1.npz'

# `np.load` on an .npz archive returns an NpzFile that keeps the file open; the `with`
# block closes it once every array has been read into memory and wrapped in a tensor.
with np.load(EXAMPLE_DATA_PATH) as data:
    image = torch.from_numpy(data['image']).to(device)
    intrinsics = torch.from_numpy(data['intrinsics']).to(device)
    extrinsics = torch.from_numpy(data['extrinsics']).to(device)
    future_egomotions = torch.from_numpy(data['future_egomotion']).to(device)

print(f"The Image shape is {image.shape}")
print(f"The Intrinsics shape is {intrinsics.shape}")
print(f"The Extrinsics shape is {extrinsics.shape}")
print(f"The Future Egomotions shape is {future_egomotions.shape}")

The Image shape is 'torch.Size([1, 3, 6, 3, 224, 480])
The Intrinsics shape is 'torch.Size([1, 3, 6, 3, 3])
The Extrinsics shape is 'torch.Size([1, 3, 6, 4, 4])
The Future Egomotions shape is 'torch.Size([1, 3, 6])


## Frustum Creation

First, a frustum is created. A frustum is a three-dimensional shape obtained by cutting the top off a pyramid with a plane parallel to its base, so it has a flat top and base and four trapezoidal sides. The frustum is defined in terms of a grid in the image plane, and has three dimensions: left-right, top-bottom, and depth.

The function first defines the height and width of the image plane, as well as the downsampled versions of these values. It then creates a depth grid by creating a 1D tensor of depth values between the bounds specified in the configuration (D_bound), and reshapes and expands this tensor into a 3D tensor with dimensions (n_depth_slices, downsampled_h, downsampled_w).

Next, the function creates x and y grids that are also 3D tensors with dimensions (n_depth_slices, downsampled_h, downsampled_w). These grids contain the x and y coordinates of each point in the image plane, respectively.

Finally, the function stacks these three grids along the last dimension to create a frustum tensor with dimensions (n_depth_slices, downsampled_h, downsampled_w, 3). This tensor contains the x, y, and depth coordinates of each point in the frustum. The frustum tensor is then wrapped in a PyTorch nn.Parameter and returned. 

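**_note_**: I noticed, however, that the x and y coordinates of the final tensor are the same for every depth slice, so the grid is a rectangular box rather than a frustum. The frustum shape appears downstream in the 'get_geometry' method, where the x and y pixel coordinates are multiplied by the depth and then transformed with the inverse camera intrinsics, so the sampled points spread out as the depth increases.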

In [24]:
image_dim = (224, 480)
encoder_downsample = 8
D_bound = [2.0, 50.0, 1.0]


def create_frustum(dim=None, downsample=None, depth_bound=None):
    """Build the grid of (x, y, depth) samples used to lift image features into 3D.

    Each entry holds the column (x) and row (y) pixel coordinate of a downsampled
    feature cell, expressed in full-resolution image pixels, plus a candidate depth.

    Args:
        dim: ``(height, width)`` of the input image. Defaults to the notebook-level
            ``image_dim`` (read at call time).
        downsample: encoder downsampling factor. Defaults to the notebook-level
            ``encoder_downsample``.
        depth_bound: ``[start, stop, step]`` passed to ``torch.arange`` for the depth
            bins. Defaults to the notebook-level ``D_bound``.

    Returns:
        Non-trainable ``nn.Parameter`` of shape
        ``(n_depth_slices, height // downsample, width // downsample, 3)``.
    """
    # Fall back to the notebook config at call time so that edits to it are honoured.
    h, w = image_dim if dim is None else dim
    factor = encoder_downsample if downsample is None else downsample
    bounds = D_bound if depth_bound is None else depth_bound
    downsampled_h, downsampled_w = h // factor, w // factor

    # Depth grid: one slice per value of torch.arange(*bounds), broadcast over the plane
    depth_grid = torch.arange(*bounds, dtype=torch.float)
    depth_grid = depth_grid.view(-1, 1, 1).expand(-1, downsampled_h, downsampled_w)
    n_depth_slices = depth_grid.shape[0]

    # x and y grids: full-resolution pixel coordinates sampled at the downsampled size
    x_grid = torch.linspace(0, w - 1, downsampled_w, dtype=torch.float)
    x_grid = x_grid.view(1, 1, downsampled_w).expand(n_depth_slices, downsampled_h, downsampled_w)
    y_grid = torch.linspace(0, h - 1, downsampled_h, dtype=torch.float)
    y_grid = y_grid.view(1, downsampled_h, 1).expand(n_depth_slices, downsampled_h, downsampled_w)

    # Dimension (n_depth_slices, downsampled_h, downsampled_w, 3)
    # containing data points in the image: left-right, top-bottom, depth
    frustum = torch.stack((x_grid, y_grid, depth_grid), -1)
    return nn.Parameter(frustum, requires_grad=False)

In [25]:
# Depth grid creation, reusing the D_bound / image_dim / encoder_downsample values from
# the config cell above instead of redefining them here (duplicated config can drift).
h, w = image_dim
downsampled_h, downsampled_w = h // encoder_downsample, w // encoder_downsample

# One depth value per slice, broadcast over the downsampled image plane.
depth_grid = torch.arange(*D_bound, dtype=torch.float)
depth_grid = depth_grid.view(-1, 1, 1).expand(-1, downsampled_h, downsampled_w)
n_depth_slices = depth_grid.shape[0]
print(f"The depth grid shape is {depth_grid.shape}")
print(depth_grid)

The depth grid shape is'torch.Size([48, 28, 60])
tensor([[[ 2.,  2.,  2.,  ...,  2.,  2.,  2.],
         [ 2.,  2.,  2.,  ...,  2.,  2.,  2.],
         [ 2.,  2.,  2.,  ...,  2.,  2.,  2.],
         ...,
         [ 2.,  2.,  2.,  ...,  2.,  2.,  2.],
         [ 2.,  2.,  2.,  ...,  2.,  2.,  2.],
         [ 2.,  2.,  2.,  ...,  2.,  2.,  2.]],

        [[ 3.,  3.,  3.,  ...,  3.,  3.,  3.],
         [ 3.,  3.,  3.,  ...,  3.,  3.,  3.],
         [ 3.,  3.,  3.,  ...,  3.,  3.,  3.],
         ...,
         [ 3.,  3.,  3.,  ...,  3.,  3.,  3.],
         [ 3.,  3.,  3.,  ...,  3.,  3.,  3.],
         [ 3.,  3.,  3.,  ...,  3.,  3.,  3.]],

        [[ 4.,  4.,  4.,  ...,  4.,  4.,  4.],
         [ 4.,  4.,  4.,  ...,  4.,  4.,  4.],
         [ 4.,  4.,  4.,  ...,  4.,  4.,  4.],
         ...,
         [ 4.,  4.,  4.,  ...,  4.,  4.,  4.],
         [ 4.,  4.,  4.,  ...,  4.,  4.,  4.],
         [ 4.,  4.,  4.,  ...,  4.,  4.,  4.]],

        ...,

        [[47., 47., 47.,  ..., 47., 47., 47

In [26]:
# x and y grids: full-resolution pixel coordinates of each downsampled cell,
# repeated for every depth slice.
x_grid = torch.linspace(0, w - 1, downsampled_w, dtype=torch.float)
x_grid = x_grid.view(1, 1, downsampled_w).expand(n_depth_slices, downsampled_h, downsampled_w)
print(f"The x grid shape is {x_grid.shape}")
y_grid = torch.linspace(0, h - 1, downsampled_h, dtype=torch.float)
y_grid = y_grid.view(1, downsampled_h, 1).expand(n_depth_slices, downsampled_h, downsampled_w)
print(f"The y grid shape is {y_grid.shape}")

The x grid grid shape is torch.Size([48, 28, 60])
The y grid shape is torch.Size([48, 28, 60])


In [27]:
# Stack the grids into a tensor of dimension (n_depth_slices, downsampled_h, downsampled_w, 3)
# holding (x, y, depth) for every sample. x varies only along the width axis and y only
# along the height axis, and both are identical in every depth slice, so this is still a
# rectangular box in (pixel, pixel, depth) space. The frustum shape only appears later in
# `get_geometry`, where x and y are multiplied by depth and un-projected with the
# inverse intrinsics.
frustum = torch.stack((x_grid, y_grid, depth_grid), -1)
print(f"The frustum shape is {frustum.shape}")

The y frustum shape is torch.Size([48, 28, 60, 3])


## Forward Step 1: Lifting and Projecting Images to BEV

The method that handles this operation is 'calculate_birds_eye_view_features', which:
1. Packs the sequence dimension into the batch dimension to process the images in a time-agnostic manner.
2. Applies the intrinsic and extrinsic transformations to the frustums to transform them to the ego frame.
3. Passes the images through the encoder to extract 2D features coupled with depth probabilities.
4. Projects the images to BEV using the transformed frustums and the images' features,
   using the 'splat' method of the Lift-Splat-Shoot paper.
5. Unpacks the sequence dimension.

In [28]:
# Running the whole BEV feature computation in a single call crashed the machine this was
# developed on, so it stays commented out; the cells below reproduce it step by step.
#trainer.model.calculate_birds_eye_view_features(image, intrinsics, extrinsics)

### Pack sequence dimensions

In [29]:
# The first step in this function is to pack the sequence dimension together with the batch
# dimension into one consolidated dimension, e.g. (1, 3, 6, 4, 4) -> (3, 6, 4, 4).
b, s, n, c, h, w = image.shape
x = pack_sequence_dim(image)

# `intrinsics` and `extrinsics` are rebound to their packed versions, so only pack tensors
# that still carry the separate batch and sequence axes. This keeps the cell safe to re-run.
if intrinsics.dim() == 5:
    intrinsics = pack_sequence_dim(intrinsics)
if extrinsics.dim() == 5:
    extrinsics = pack_sequence_dim(extrinsics)

print(extrinsics.shape)

torch.Size([3, 6, 4, 4])


### Transform from Camera Frustums to Ego Frame

In [30]:
# The camera intrinsics and extrinsics are used to convert the images to the ego vehicle's reference frame
# in the 'get_geometry' method as follows:

def get_geometry(self, intrinsics, extrinsics):
    """Calculate the (x, y, z) 3D position of the features in the ego frame.

    Args:
        self: object exposing ``frustum``, a (D, H, W, 3) grid of (x, y, depth) samples
            such as the one built by ``create_frustum``.
        intrinsics: (B, N, 3, 3) camera intrinsic matrices.
        extrinsics: (B, N, 4, 4) camera poses; the top-left 3x3 block is used as the
            rotation and the first three entries of the last column as the translation.

    Returns:
        Tensor of shape (B, N, D, H, W, 3) with the ego-frame coordinates of every
        frustum sample.
    """
    rotation, translation = extrinsics[..., :3, :3], extrinsics[..., :3, 3]
    B, N, _ = translation.shape
    # Add batch, camera dimension, and a dummy dimension at the end.
    # Use the frustum owned by `self` rather than the notebook's global `trainer`,
    # so the function works with whichever model it is called on.
    points = self.frustum.unsqueeze(0).unsqueeze(0).unsqueeze(-1)

    # Camera to ego reference frame: (x, y, d) -> (x*d, y*d, d), then R @ K^-1, then + t
    points = torch.cat((points[:, :, :, :, :, :2] * points[:, :, :, :, :, 2:3], points[:, :, :, :, :, 2:3]), 5)
    combined_transformation = rotation.matmul(torch.inverse(intrinsics))
    points = combined_transformation.view(B, N, 1, 1, 1, 3, 3).matmul(points).squeeze(-1)
    points += translation.view(B, N, 1, 1, 1, 3)

    # The 3 dimensions in the ego reference frame are: (forward, sides, height)
    return points



In [31]:
# Let's dive into the function some more: split the extrinsics into their rotation and
# translation parts, and give the frustum leading batch/camera axes plus a trailing axis.
rotation = extrinsics[..., :3, :3]
translation = extrinsics[..., :3, 3]
B, N, _ = translation.shape
points = trainer.model.frustum[None, None, ..., None]

print(f'The rotation shape is: {rotation.shape}')
print(f'The translation shape is: {translation.shape}')
print(f'The points shape is: {points.shape}')

The rotation shape is: torch.Size([3, 6, 3, 3])
The translation shape is: torch.Size([3, 6, 3])
The points shape is: torch.Size([1, 1, 48, 28, 60, 3, 1])


In [32]:
# Camera to ego reference frame, step 1: start again from the untransformed frustum.
points = trainer.model.frustum.unsqueeze(0).unsqueeze(0).unsqueeze(-1)

# Axis 5 holds the (x, y, depth) triple of every sample; split it into the pixel
# coordinates and the depth.
x_y, depth = points[..., :2, :], points[..., 2:3, :]

# Scale the pixel coordinates by depth and re-attach the depth along axis 5, giving
# (x*d, y*d, d). Once the inverse intrinsics are applied in the next cell, points at
# larger depth spread further apart, which is what produces the frustum shape.
points = torch.cat((x_y * depth, depth), dim=5)
print(f'The frustum shape is {points.shape}')

The frustum shape is torch.Size([1, 1, 48, 28, 60, 3, 1])


In [33]:
# Here we transform the frustum from the respective cameras to the ego frame of reference:
# un-project with the inverse intrinsics, rotate, then shift by each camera's translation.
# The final 3 dimensions in the ego reference frame are: (forward, sides, height).
# NOTE: `points` is rebound, so this cell must run directly after the previous one.
combined_transformation = rotation @ torch.inverse(intrinsics)
points = combined_transformation.view(B, N, 1, 1, 1, 3, 3) @ points
points = points.squeeze(-1) + translation.view(B, N, 1, 1, 1, 3)
print(f'The transformed frustums are returned transformed in the ego frame as: {points.shape}')

The transformed frustums are returned transformed in the ego frame as: torch.Size([3, 6, 48, 28, 60, 3])


### Encode Images

EfficientNet (E) is chosen as the encoder for this model, with a final convolutional layer that encodes the feature channels and depth of each image (I), with (C) the number of feature channels, (D) the number of discrete depth values and (H, W) the feature spatial size.
This feature is split into two: $e^{k}_{t}=(e^{k}_{t,C},e^{k}_{t,D})$ such that $e^{k}_{t,C}\in\mathbb{R}^{C\times H\times W}$ and $e^{k}_{t,D}\in\mathbb{R}^{D\times H\times W}$, and their outer product $u^{k}_{t}=e^{k}_{t,D}\otimes e^{k}_{t,C}\in\mathbb{R}^{C\times D\times H\times W}$ is the final output tensor of the encoder.
The depth probabilities act as a form of self-attention, modulating the features according to which depth plane they are predicted to belong to. 

Fiery takes the output of the encoder as defined above and manipulates the tensors such that the final output is a tensor $u^{k}_{t}\in\mathbb{R}^{B\times N\times D\times H\times W\times C}$.

In [34]:

# Encode the packed images. `x` is rebound to the encoder output, so run this cell only once.
x = trainer.model.encoder_forward(x)
print('The encoder return tensor of shape : {}'.format(x.shape))

The encoder return tensor of shape : torch.Size([3, 6, 48, 28, 60, 64])


### Project to BEV

Following the method 'trainer.model.projection_to_birds_eye_view(x, points)', which is primarily adapted from lift-splat-shoot at [this](https://github.com/nv-tlabs/lift-splat-shoot/blob/master/src/models.py#L200) point. 

The input data x is expected to be a tensor with dimensions B x N x D x H x W x C, where B is the batch size, N is the number of cameras, D is the number of depth bins, H and W are the height and width, and C is the number of channels. The geometry represents the tensor containing the frustums, which is a tensor of 3D coordinates in the ego frame, with dimensions B x N x D x H x W x 3.

We first establish the number of 3D points per batch element from the encoder features: $N_{total} = N \times D \times H \times W$.

We then loop through each batch and perform the following:

1. First, the x_b tensor is flattened to an [$N_{total}$, C] tensor so that the points from all the cameras are concatenated together. The geometry_b tensor is also flattened to the shape [$N_{total}$, 3].
2. Map the geometry_b positions to integer indices on the BEV grid.
3. A mask is computed from the converted geometry_b and applied to geometry_b and x_b to remove points that fall outside the BEV grid.
4. Ranks are assigned to the geometry_b and x_b points and used to sort them so that consecutive indices lie within the same voxel; this ordering is an optimization that simplifies the voxel summing downstream.
5. The features of each voxel are summed with `VoxelsSumming` (the cumulative sum trick) and scattered into a dense BEV tensor, and the singleton z dimension is removed.



In [35]:
def projection_to_birds_eye_view(self, x, geometry):
    """Splat lifted camera features onto the bird's-eye-view (BEV) grid.

    Adapted from https://github.com/nv-tlabs/lift-splat-shoot/blob/master/src/models.py#L200

    Args:
        x: features shaped (batch, n_cameras, depth, height, width, channels).
        geometry: ego-frame coordinates of every feature, shaped like ``x`` but with a
            trailing axis of size 3 instead of ``channels``.

    Returns:
        Tensor of shape (batch, channels, bev_dimension[0], bev_dimension[1]) holding the
        summed features of each voxel. The z axis is dropped with ``squeeze(0)``, which
        presumably relies on ``bev_dimension[2] == 1`` — TODO confirm for other configs.
    """
    batch, n, d, h, w, c = x.shape
    size_x, size_y, size_z = self.bev_dimension[0], self.bev_dimension[1], self.bev_dimension[2]
    output = torch.zeros((batch, c, size_x, size_y), dtype=torch.float, device=x.device)

    # Lower edge of the grid: the start position shifted back by half a voxel.
    grid_origin = self.bev_start_position - self.bev_resolution / 2.0
    n_points = n * d * h * w

    for batch_idx in range(batch):
        features = x[batch_idx].reshape(n_points, c)
        voxel = ((geometry[batch_idx] - grid_origin) / self.bev_resolution).view(n_points, 3).long()

        # Drop points whose voxel index falls outside [0, size) on any axis.
        inside = torch.ones(n_points, dtype=torch.bool, device=voxel.device)
        for axis, size in enumerate((size_x, size_y, size_z)):
            inside &= (voxel[:, axis] >= 0) & (voxel[:, axis] < size)
        features, voxel = features[inside], voxel[inside]

        # Flatten (x, y, z) voxel indices into one rank so that points of the same
        # voxel become adjacent after sorting.
        ranks = voxel[:, 0] * (size_y * size_z) + voxel[:, 1] * size_z + voxel[:, 2]
        order = ranks.argsort()
        features, voxel, ranks = features[order], voxel[order], ranks[order]

        # Sum the features inside each voxel (cumulative sum trick).
        features, voxel = VoxelsSumming.apply(features, voxel, ranks)

        grid = torch.zeros((size_z, size_x, size_y, c), device=features.device)
        grid[voxel[:, 2], voxel[:, 0], voxel[:, 1]] = features

        # (z, X, Y, C) -> (z, C, X, Y), then drop the leading z axis.
        output[batch_idx] = grid.permute((0, 3, 1, 2)).squeeze(0)

    return output

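Let's start by having a look at the BEV dimensional attributes, where:

bev_resolution: size of one BEV voxel along x, y and z

bev_start_position: position of the first BEV voxel (the code subtracts half a resolution from it to obtain the lower edge of the grid)

bev_dimension: number of voxels along x, y and z (the bird's-eye view tensor spatial dimension)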

In [36]:
# Inspect the BEV grid parameters stored on the model.
print('BEV resolution : {}'.format(trainer.model.bev_resolution))
print('BEV start position : {}'.format(trainer.model.bev_start_position))
print('BEV dimensions x, y, z : {}, {}, {}'.format(*trainer.model.bev_dimension[:3]))

BEV resolution : Parameter containing:
tensor([ 0.5000,  0.5000, 20.0000])
BEV start position : Parameter containing:
tensor([-49.7500, -49.7500,   0.0000])
BEV dimensions x, y, z : 200, 200, 1


Let's skip to the end of step 3 to understand the ranking system, using batch element 0 as a sample.

In [37]:
# Align the notebook variables with the names used inside `projection_to_birds_eye_view`:
# inspect batch element 0, using the ego-frame frustum points as the geometry.
b, geometry = 0, points


In [38]:
# Replay the body of `projection_to_birds_eye_view` for the single batch element `b`.
batch, n, d, h, w, c = x.shape
output = torch.zeros(
    (batch, c, *trainer.model.bev_dimension[:2]), dtype=torch.float, device=x.device
)

# Number of 3D points per batch element
N = n * d * h * w

# Flatten the features and turn the ego-frame positions into integer voxel indices,
# measured from the lower edge of the grid (start position minus half a voxel).
x_b = x[b].reshape(N, c)
geometry_b = (
    (geometry[b] - (trainer.model.bev_start_position - trainer.model.bev_resolution / 2.0))
    / trainer.model.bev_resolution
)
geometry_b = geometry_b.view(N, 3).long()

# Keep only the points whose voxel index lies in [0, bev_dimension[axis]) on every axis.
mask = torch.stack(
    [
        (geometry_b[:, axis] >= 0) & (geometry_b[:, axis] < trainer.model.bev_dimension[axis])
        for axis in range(3)
    ]
).all(dim=0)
x_b, geometry_b = x_b[mask], geometry_b[mask]

print(f'x_b : {x_b.shape}')
print(f'geometry_b : {geometry_b.shape}')

x_b : torch.Size([452675, 64])
geometry_b : torch.Size([452675, 3])


It first assigns a rank to each of the points in geometry_b by taking the dot product of (geometry_b[:,0], geometry_b[:,1], geometry_b[:,2]) with (self.bev_dimension[1] * self.bev_dimension[2], self.bev_dimension[2], 1). This gives a unique rank for each voxel based on its position within the BEV grid, and points in the same voxel share the same rank. 

The bev_dimension[1] * bev_dimension[2] term is used to account for the number of voxels in the y-z plane and is multiplied by the x-coordinate of the voxel to ensure that the x-coordinate has the greatest weight in the sorting. The bev_dimension[2] term is used to account for the number of voxels in the z-axis and is multiplied by the y-coordinate of the voxel. And finally, the z-coordinate of the voxel is added to the result. This way, the resulting 'ranks' variable will be a unique and consistent value for each voxel, and the indices of the sorted 'ranks' variable will correspond to the indices of the points sorted by their voxel location. 

It then uses the argsort() function to create an array called ranks_indices, which contains the indices that would sort the ranks array in ascending order.

Finally, it uses these indices to sort the original tensors represented by the geometry_b array, as well as the x_b array and ranks array in the same order. This ensures that tensors that are located within the same voxel in the BEV grid are now consecutive in all three arrays.

This makes it easy to process the groups of points located in the same voxel. It also speeds up processing, since all voxels can be handled in a single pass over the sorted points.

In [39]:
# Sort tensors so that those within the same voxel are consecutive.
# rank = x * (Y * Z) + y * Z + z, which is unique for every voxel of the grid.
ranks = geometry_b[:, 0] * (trainer.model.bev_dimension[1] * trainer.model.bev_dimension[2])
ranks = ranks + geometry_b[:, 1] * trainer.model.bev_dimension[2] + geometry_b[:, 2]

ranks_indices = torch.argsort(ranks)
x_b = x_b[ranks_indices]
geometry_b = geometry_b[ranks_indices]
ranks = ranks[ranks_indices]

We now perform the voxel summing operation which leverages what is called a 'cumulative sum trick'

Taken as an excerpt from lift-splat-shoot:

_The “cumulative sum trick” is the observation that sum pooling can be
performed by sorting all points according to bin id, performing a cumulative sum
over all features, then subtracting the cumulative sum values at the boundaries
of the bin sections. Instead of relying on autograd to backprop through all three
steps, the analytic gradient for the module as a whole can be derived, speeding up
training by 2x. We call the layer “Frustum Pooling” because it handles converting
the frustums produced by n images into a fixed dimensional H × W x C tensor (in fiery's case) 
independent of the number of cameras n._

__side note__: The implementation of this frustum pooling is shown below as the 'VoxelsSumming' subclass, which inherits from 'torch.autograd.Function'. This is the base class for custom operations with a user-defined gradient in PyTorch and is a fundamental building block of the PyTorch autograd system. Calling `VoxelsSumming.apply(...)` runs the forward pass and records the operation in the autograd graph, so that its backward method is used to compute the gradients with respect to the inputs during backpropagation.

There are two main methods that need to be implemented for any torch.autograd.Function subclass:

1. forward(ctx, *inputs): a static method that computes the forward pass. Tensors needed later are stored with `ctx.save_for_backward(...)` (here, the voxel-change mask), and outputs that should not receive gradients can be flagged with `ctx.mark_non_differentiable(...)`.
2. backward(ctx, *grad_outputs): a static method that receives the gradients with respect to the outputs and returns the gradients with respect to each input of forward, in the same order (`None` for inputs that need no gradient, such as geometry and ranks here).

A common usage of torch.autograd.Function is when you want to implement a custom operation that is not natively supported by PyTorch or you want to add some custom behavior to the autograd system, for example for a custom loss function or custom layer in a neural network, in our case we are leveraging this function to capture our “cumulative sum trick”.

Back to the implementation:

A boolean tensor called mask is created with one entry per point (length x.shape[0]). It is filled with True and is of type torch.bool.
The second line of code assigns the result of the comparison ranks[1:] != ranks[:-1] to every element of mask except the last one. The comparison checks each element of ranks[1:] against the corresponding element of ranks[:-1], returning a torch.bool tensor of the same shape as ranks[1:] that is True where the two differ. The result is that mask[i] is True exactly when point i is the last point of its voxel. The final element stays True because the last point always closes a voxel.
Because x has already been cumulatively summed, indexing x and geometry with this mask keeps the running total at the end of each voxel. Subtracting consecutive kept totals (x[1:] - x[:-1], with the first total kept as-is) then yields the sum of the features in each voxel.

The masked geometry_b tensor is kept for the final assignment of x_b to geometry_b features, which occurs after the voxel summing, shown below. 

In [40]:
class VoxelsSumming(torch.autograd.Function):
    """Sum the features of points that share a voxel (lift-splat-shoot "cumulative sum trick").

    Adapted from https://github.com/nv-tlabs/lift-splat-shoot/blob/master/src/tools.py#L193
    """

    @staticmethod
    def forward(ctx, x, geometry, ranks):
        """Reduce ``x`` to one summed feature row per voxel.

        ``x``, ``geometry`` and ``ranks`` must already be sorted by ``ranks`` so that the
        points of each voxel are contiguous. Returns the per-voxel sums together with the
        geometry row of the last point of every voxel.
        """
        running_total = x.cumsum(0)

        # True at the last point of each run of equal ranks; the final point always ends a run.
        is_run_end = torch.ones(running_total.shape[0], device=running_total.device, dtype=torch.bool)
        is_run_end[:-1] = ranks[1:] != ranks[:-1]

        totals_at_end = running_total[is_run_end]
        kept_geometry = geometry[is_run_end]
        # Differences of consecutive running totals give each voxel's own sum.
        voxel_sums = torch.cat((totals_at_end[:1], totals_at_end[1:] - totals_at_end[:-1]))

        ctx.save_for_backward(is_run_end)
        ctx.mark_non_differentiable(kept_geometry)

        return voxel_sums, kept_geometry

    @staticmethod
    def backward(ctx, grad_x, grad_geometry):
        """Send each voxel's gradient back to every point that was summed into it."""
        (is_run_end,) = ctx.saved_tensors
        # The cumulative count of run ends, minus one at the run ends themselves, is the
        # index of the voxel each point belongs to.
        voxel_index = is_run_end.cumsum(0)
        voxel_index[is_run_end] -= 1

        return grad_x[voxel_index], None, None

Finally, we complete the BEV projection by applying `VoxelsSumming` and using the 'geometry_b' tensor to select the appropriate positions in the bev_feature tensor via advanced indexing. Specifically, the values in geometry_b[:, 2], geometry_b[:, 0] and geometry_b[:, 1] are used as the indices into the bev_feature tensor, and the values of the x_b tensor are assigned to those positions. 

We then permute the result to (z, C, X, Y), drop the singleton z dimension with `squeeze(0)`, and store the resulting (C, X, Y) feature map in `output[b]`.

In [1]:
# Sum the features inside each voxel (cumulative sum trick).
x_b, geometry_b = VoxelsSumming.apply(x_b, geometry_b, ranks)

# Scatter every voxel's summed feature into a dense (z, X, Y, C) grid.
bev_feature = torch.zeros(
    (
        trainer.model.bev_dimension[2],
        trainer.model.bev_dimension[0],
        trainer.model.bev_dimension[1],
        c,
    ),
    device=x_b.device,
)
bev_feature[geometry_b[:, 2], geometry_b[:, 0], geometry_b[:, 1]] = x_b

# (z, X, Y, C) -> (z, C, X, Y); squeeze(0) then drops z when it has size 1.
bev_feature = bev_feature.permute((0, 3, 1, 2)).squeeze(0)
output[b] = bev_feature



NameError: name 'VoxelsSumming' is not defined