In [1]:


import torch
import torch.nn as nn

from fiery.models.encoder import Encoder
from fiery.models.temporal_model import TemporalModelIdentity, TemporalModel
from fiery.models.distributions import DistributionModule
from fiery.models.future_prediction import FuturePrediction
from fiery.models.decoder import Decoder
from fiery.utils.network import pack_sequence_dim, unpack_sequence_dim, set_bn_momentum
from fiery.utils.geometry import cumulative_warp_features, calculate_birds_eye_view_parameters, VoxelsSumming

Frustum Creation

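The `create_frustum` function below builds a frustum: a pyramid whose apex has been cut off by a plane parallel to its base, leaving two parallel rectangular faces (near and far) joined by four trapezoidal sides. Here it represents the region in front of the camera between the nearest and the farthest sampled depth. The frustum is sampled on a grid defined in the image plane, and has three dimensions: left-right, top-bottom, and depth.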

The function first reads the height and width of the input image and computes their downsampled versions by integer-dividing them by `encoder_downsample`. It then creates a 1D tensor of depth values with `torch.arange(*D_bound)`, where `D_bound = [start, stop, step]` and `stop` itself is excluded, and broadcasts this tensor (using `view` and `expand`) into a 3D tensor with dimensions (n_depth_slices, downsampled_h, downsampled_w).

Next, the function creates x and y grids that are also 3D tensors with dimensions (n_depth_slices, downsampled_h, downsampled_w). These grids contain the x and y pixel coordinates of each grid point, evenly spaced from 0 to w - 1 and from 0 to h - 1 in the original (full-resolution) image, respectively.

Finally, the function stacks these three grids along the last dimension to create a frustum tensor with dimensions (n_depth_slices, downsampled_h, downsampled_w, 3). This tensor contains the x, y, and depth coordinates of each point in the frustum. The frustum tensor is then wrapped in a non-trainable PyTorch `nn.Parameter` (`requires_grad=False`) and returned.

In [10]:
# Notebook-level configuration used as the defaults of create_frustum().
image_dim = (224, 480)  # input image size as (height, width)
encoder_downsample = 8  # height and width are integer-divided by this factor
D_bound = [2.0, 50.0, 1.0]  # depth range as [start, stop (exclusive), step]


def create_frustum(image_size=None, downsample=None, depth_bound=None):
    """Build the grid of (x, y, depth) sample points of the camera frustum.

    Every argument is optional. An argument left as ``None`` falls back to the
    notebook-level setting of the same meaning (``image_dim``,
    ``encoder_downsample``, ``D_bound``), looked up when the function is
    called, so ``create_frustum()`` keeps its original behaviour.

    Args:
        image_size: ``(height, width)`` of the input image.
        downsample: Factor by which height and width are integer-divided to
            obtain the size of the sampling grid.
        depth_bound: ``(start, stop, step)`` forwarded to ``torch.arange``;
            ``stop`` is excluded.

    Returns:
        A ``nn.Parameter`` with ``requires_grad=False`` and shape
        ``(n_depth_slices, downsampled_h, downsampled_w, 3)``. The last axis
        holds, in this order, the x (left-right) coordinate, the y
        (top-bottom) coordinate and the depth of each sample point.
    """
    # Resolve defaults at call time so later changes to the notebook-level
    # configuration are still picked up by a plain create_frustum() call.
    if image_size is None:
        image_size = image_dim
    if downsample is None:
        downsample = encoder_downsample
    if depth_bound is None:
        depth_bound = D_bound

    # Size of the sampling grid in the image plane
    h, w = image_size
    downsampled_h, downsampled_w = h // downsample, w // downsample

    # One depth value per slice
    depth_values = torch.arange(*depth_bound, dtype=torch.float)
    n_depth_slices = depth_values.shape[0]
    grid_shape = (n_depth_slices, downsampled_h, downsampled_w)

    # Broadcast every 1D coordinate axis to the full 3D grid.
    depth_grid = depth_values.view(-1, 1, 1).expand(*grid_shape)

    # x and y coordinates span the full-resolution image: 0..w-1 and 0..h-1
    x_values = torch.linspace(0, w - 1, downsampled_w, dtype=torch.float)
    x_grid = x_values.view(1, 1, downsampled_w).expand(*grid_shape)
    y_values = torch.linspace(0, h - 1, downsampled_h, dtype=torch.float)
    y_grid = y_values.view(1, downsampled_h, 1).expand(*grid_shape)

    # Dimension (n_depth_slices, downsampled_h, downsampled_w, 3)
    # containing data points in the image: left-right, top-bottom, depth
    frustum = torch.stack((x_grid, y_grid, depth_grid), -1)
    return nn.Parameter(frustum, requires_grad=False)

In [28]:
# Depth grid creation
# Depth range as [start, stop (exclusive), step]; restated here so the
# walkthrough cell shows the values it uses.
D_bound = [2.0, 50.0, 1.0]
h, w = image_dim
downsampled_h, downsampled_w = h // encoder_downsample, w // encoder_downsample

# 1D tensor with one depth value per slice
depth_grid = torch.arange(*D_bound, dtype=torch.float)
# Broadcast each depth value over the (downsampled_h, downsampled_w) grid
depth_grid = depth_grid.view(-1, 1, 1).expand(-1, downsampled_h, downsampled_w)
n_depth_slices = depth_grid.shape[0]
print(f"The depth grid shape is {depth_grid.shape}")
print(depth_grid)

The depth grid shape is'torch.Size([48, 28, 60])
tensor([[[ 2.,  2.,  2.,  ...,  2.,  2.,  2.],
         [ 2.,  2.,  2.,  ...,  2.,  2.,  2.],
         [ 2.,  2.,  2.,  ...,  2.,  2.,  2.],
         ...,
         [ 2.,  2.,  2.,  ...,  2.,  2.,  2.],
         [ 2.,  2.,  2.,  ...,  2.,  2.,  2.],
         [ 2.,  2.,  2.,  ...,  2.,  2.,  2.]],

        [[ 3.,  3.,  3.,  ...,  3.,  3.,  3.],
         [ 3.,  3.,  3.,  ...,  3.,  3.,  3.],
         [ 3.,  3.,  3.,  ...,  3.,  3.,  3.],
         ...,
         [ 3.,  3.,  3.,  ...,  3.,  3.,  3.],
         [ 3.,  3.,  3.,  ...,  3.,  3.,  3.],
         [ 3.,  3.,  3.,  ...,  3.,  3.,  3.]],

        [[ 4.,  4.,  4.,  ...,  4.,  4.,  4.],
         [ 4.,  4.,  4.,  ...,  4.,  4.,  4.],
         [ 4.,  4.,  4.,  ...,  4.,  4.,  4.],
         ...,
         [ 4.,  4.,  4.,  ...,  4.,  4.,  4.],
         [ 4.,  4.,  4.,  ...,  4.,  4.,  4.],
         [ 4.,  4.,  4.,  ...,  4.,  4.,  4.]],

        ...,

        [[47., 47., 47.,  ..., 47., 47., 47

In [27]:
# x and y grids
# x coordinates: downsampled_w values evenly spaced over 0..w-1, repeated for
# every depth slice and every row
x_grid = torch.linspace(0, w - 1, downsampled_w, dtype=torch.float)
x_grid = x_grid.view(1, 1, downsampled_w).expand(n_depth_slices, downsampled_h, downsampled_w)
print(f"The x grid shape is {x_grid.shape}")
# y coordinates: downsampled_h values evenly spaced over 0..h-1, repeated for
# every depth slice and every column
y_grid = torch.linspace(0, h - 1, downsampled_h, dtype=torch.float)
y_grid = y_grid.view(1, downsampled_h, 1).expand(n_depth_slices, downsampled_h, downsampled_w)
print(f"The y grid shape is {y_grid.shape}")

The x grid grid shape is torch.Size([48, 28, 60])
The y grid shape is torch.Size([48, 28, 60])


In [47]:

# Create frustum of Dimension (n_depth_slices, downsampled_h, downsampled_w, 3)
# containing data points in the image: x, y, depth
frustum = torch.stack((x_grid, y_grid, depth_grid), -1)
# Sanity check: stacking adds a trailing axis of size 3 (x, y, depth)
assert frustum.shape == (n_depth_slices, downsampled_h, downsampled_w, 3)
print(f"The frustum shape is {frustum.shape}")

The y frustum shape is torch.Size([48, 28, 60, 3])
