In [38]:
import torch
import torch.nn as nn

import os
from argparse import ArgumentParser
from glob import glob

import cv2
import numpy as np

from fiery.models.encoder import Encoder
from fiery.models.temporal_model import TemporalModelIdentity, TemporalModel
from fiery.models.distributions import DistributionModule
from fiery.models.future_prediction import FuturePrediction
from fiery.models.decoder import Decoder
from fiery.utils.network import pack_sequence_dim, unpack_sequence_dim, set_bn_momentum
from fiery.utils.geometry import cumulative_warp_features, calculate_birds_eye_view_parameters, VoxelsSumming
import visualise

from fiery.trainer import TrainingModule
from fiery.utils.network import NormalizeInverse
from fiery.utils.instance import predict_instance_segmentation_and_trajectories
from fiery.utils.visualisation import plot_instance_map, generate_instance_colours, make_contour, convert_figure_numpy

## Model I/O

**image**: torch.Tensor float (T, N, 3, H, W) - normalised cameras images with T the sequence length, and N the number of cameras.

**intrinsics**: torch.Tensor float (T, N, 3, 3) - intrinsics containing resizing and cropping parameters.

**extrinsics**: torch.Tensor float  (T, N, 4, 4) - 6 DoF pose from world coordinates to camera coordinates.

**future_egomotion**: torch.Tensor float (T, 6) - 6 DoF egomotion where  t -> t+1

In [42]:
trainer = TrainingModule.load_from_checkpoint('fiery.ckpt', strict=True)

device = torch.device('cuda:0')
trainer = trainer.to(device)
trainer.eval()

Loaded pretrained weights for efficientnet-b4


  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "
                not been set for this class (IntersectionOverUnion). The property determines if `update` by
                default needs access to the full metric state. If this is not the case, significant speedups can be
                achieved and we recommend setting this to `False`.
                We provide an checking function
                `from torchmetrics.utilities import check_forward_full_state_property`
                that can be used to check if the `full_state_update=True` (old and potential slower behaviour,
                default for now) or if `full_state_update=False` can be used safely.
                
                not been set for this class (PanopticMetric). The property determines if `update` by
                default needs access to the full metric state. If this is not the case, significant speedups can be
                achieved and we recommend se

TrainingModule(
  (model): Fiery(
    (encoder): Encoder(
      (backbone): EfficientNet(
        (_conv_stem): Conv2dStaticSamePadding(
          3, 48, kernel_size=(3, 3), stride=(2, 2), bias=False
          (static_padding): ZeroPad2d((1, 1, 1, 1))
        )
        (_bn0): BatchNorm2d(48, eps=0.001, momentum=0.05, affine=True, track_running_stats=True)
        (_blocks): ModuleList(
          (0): MBConvBlock(
            (_depthwise_conv): Conv2dStaticSamePadding(
              48, 48, kernel_size=(3, 3), stride=[1, 1], groups=48, bias=False
              (static_padding): ZeroPad2d((1, 1, 1, 1))
            )
            (_bn1): BatchNorm2d(48, eps=0.001, momentum=0.05, affine=True, track_running_stats=True)
            (_se_reduce): Conv2dStaticSamePadding(
              48, 12, kernel_size=(1, 1), stride=(1, 1)
              (static_padding): Identity()
            )
            (_se_expand): Conv2dStaticSamePadding(
              12, 48, kernel_size=(1, 1), stride=(1, 1)
     

In [44]:
# Download and extract example input data
visualise.download_example_data()

EXAMPLE_DATA_PATH = 'example_data/example_1.npz'

data = np.load(EXAMPLE_DATA_PATH)
image = torch.from_numpy(data['image']).to(device)
intrinsics = torch.from_numpy(data['intrinsics']).to(device)
extrinsics = torch.from_numpy(data['extrinsics']).to(device)
future_egomotions = torch.from_numpy(data['future_egomotion']).to(device)

print(f"The Image shape is '{image.shape}")
print(f"The Intrinsics shape is '{intrinsics.shape}")
print(f"The Extrinsics shape is '{extrinsics.shape}")
print(f"The Future Egomotions shape is '{future_egomotions.shape}")

The Image shape is 'torch.Size([1, 3, 6, 3, 224, 480])
The Intrinsics shape is 'torch.Size([1, 3, 6, 3, 3])
The Extrinsics shape is 'torch.Size([1, 3, 6, 4, 4])
The Future Egomotions shape is 'torch.Size([1, 3, 6])


## Forward Step 1: Lifting and Projecting Images to BEV

The method that handles this operation is in the function 'calculate_birds_eye_view_features'. This method packs the temporal input with the batch size to process the images in a time-agnostic manner 

In [45]:

trainer.model.calculate_birds_eye_view_features(image, intrinsics, extrinsics)

tensor([[[[[-4.0861e-01,  0.0000e+00, -2.8302e-01,  ...,  0.0000e+00,
             0.0000e+00,  0.0000e+00],
           [ 0.0000e+00, -2.2157e-02, -7.2931e-03,  ...,  0.0000e+00,
             0.0000e+00,  0.0000e+00],
           [ 5.1881e-03, -1.8833e-02,  0.0000e+00,  ...,  0.0000e+00,
             0.0000e+00,  0.0000e+00],
           ...,
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
             0.0000e+00,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
             0.0000e+00,  0.0000e+00],
           [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
             0.0000e+00,  0.0000e+00]],

          [[-5.4252e-01,  0.0000e+00, -4.5819e-01,  ...,  0.0000e+00,
             0.0000e+00,  0.0000e+00],
           [ 0.0000e+00, -1.2307e-02,  1.6408e-03,  ...,  0.0000e+00,
             0.0000e+00,  0.0000e+00],
           [-2.6655e-04, -1.8383e-02,  0.0000e+00,  ...,  0.0000e+00,
             0.0000e+00,  0.0000e+00],
 

## Frustum Creation

This function creates a frustum, which is a three-dimensional shape that is a pyramid with a flat top and base, and four triangular sides. The frustum is defined in terms of a grid in the image plane, and has three dimensions: left-right, top-bottom, and depth.

The function first defines the height and width of the image plane, as well as the downsampled versions of these values. It then creates a depth grid by creating a 1D tensor of depth values between the bounds specified in the configuration (D_bound), and reshapes this tensor into a 3D tensor with dimensions (n_depth_slices, downsampled_h, downsampled_w).

Next, the function creates x and y grids that are also 3D tensors with dimensions (n_depth_slices, downsampled_h, downsampled_w). These grids contain the x and y coordinates of each point in the image plane, respectively.

Finally, the function stacks these three grids along the last dimension to create a frustum tensor with dimensions (n_depth_slices, downsampled_h, downsampled_w, 3). This tensor contains the x, y, and depth coordinates of each point in the frustum. The frustum tensor is then wrapped in a PyTorch nn.Parameter and returned.

In [10]:
image_dim = (224, 480)
encoder_downsample = 8
D_bound = [2.0, 50.0, 1.0]
def create_frustum():
    # Create grid in image plane
    h, w = image_dim
    downsampled_h, downsampled_w = h // encoder_downsample, w // encoder_downsample

    # Depth grid
    depth_grid = torch.arange(*D_bound, dtype=torch.float)
    depth_grid = depth_grid.view(-1, 1, 1).expand(-1, downsampled_h, downsampled_w)
    n_depth_slices = depth_grid.shape[0]

    # x and y grids
    x_grid = torch.linspace(0, w - 1, downsampled_w, dtype=torch.float)
    x_grid = x_grid.view(1, 1, downsampled_w).expand(n_depth_slices, downsampled_h, downsampled_w)
    y_grid = torch.linspace(0, h - 1, downsampled_h, dtype=torch.float)
    y_grid = y_grid.view(1, downsampled_h, 1).expand(n_depth_slices, downsampled_h, downsampled_w)

    # Dimension (n_depth_slices, downsampled_h, downsampled_w, 3)
    # containing data points in the image: left-right, top-bottom, depth
    frustum = torch.stack((x_grid, y_grid, depth_grid), -1)
    return nn.Parameter(frustum, requires_grad=False)

In [28]:
# Depth grid creation
D_bound = [2.0, 50.0, 1.0]
h, w = image_dim
downsampled_h, downsampled_w = h // encoder_downsample, w // encoder_downsample

depth_grid = torch.arange(*D_bound, dtype=torch.float)
depth_grid = depth_grid.view(-1, 1, 1).expand(-1, downsampled_h, downsampled_w)
n_depth_slices = depth_grid.shape[0]
print(f"The depth grid shape is'{depth_grid.shape}")
print(depth_grid)

The depth grid shape is'torch.Size([48, 28, 60])
tensor([[[ 2.,  2.,  2.,  ...,  2.,  2.,  2.],
         [ 2.,  2.,  2.,  ...,  2.,  2.,  2.],
         [ 2.,  2.,  2.,  ...,  2.,  2.,  2.],
         ...,
         [ 2.,  2.,  2.,  ...,  2.,  2.,  2.],
         [ 2.,  2.,  2.,  ...,  2.,  2.,  2.],
         [ 2.,  2.,  2.,  ...,  2.,  2.,  2.]],

        [[ 3.,  3.,  3.,  ...,  3.,  3.,  3.],
         [ 3.,  3.,  3.,  ...,  3.,  3.,  3.],
         [ 3.,  3.,  3.,  ...,  3.,  3.,  3.],
         ...,
         [ 3.,  3.,  3.,  ...,  3.,  3.,  3.],
         [ 3.,  3.,  3.,  ...,  3.,  3.,  3.],
         [ 3.,  3.,  3.,  ...,  3.,  3.,  3.]],

        [[ 4.,  4.,  4.,  ...,  4.,  4.,  4.],
         [ 4.,  4.,  4.,  ...,  4.,  4.,  4.],
         [ 4.,  4.,  4.,  ...,  4.,  4.,  4.],
         ...,
         [ 4.,  4.,  4.,  ...,  4.,  4.,  4.],
         [ 4.,  4.,  4.,  ...,  4.,  4.,  4.],
         [ 4.,  4.,  4.,  ...,  4.,  4.,  4.]],

        ...,

        [[47., 47., 47.,  ..., 47., 47., 47

In [27]:
# x and y grids
x_grid = torch.linspace(0, w - 1, downsampled_w, dtype=torch.float)
x_grid = x_grid.view(1, 1, downsampled_w).expand(n_depth_slices, downsampled_h, downsampled_w)
print(f"The x grid grid shape is {x_grid.shape}")
y_grid = torch.linspace(0, h - 1, downsampled_h, dtype=torch.float)
y_grid = y_grid.view(1, downsampled_h, 1).expand(n_depth_slices, downsampled_h, downsampled_w)
print(f"The y grid shape is {y_grid.shape}")

The x grid grid shape is torch.Size([48, 28, 60])
The y grid grid shape is torch.Size([48, 28, 60])


In [47]:

# Create frustum of Dimension (n_depth_slices, downsampled_h, downsampled_w, 3)
# containing data points in the image: x, y, depth
frustum = torch.stack((x_grid, y_grid, depth_grid), -1)
print(f"The y frustum shape is {frustum.shape}")
#print(frustum)
#print(frustum[1,27])

The y frustum shape is torch.Size([48, 28, 60, 3])


In [46]:
frustum = torch.rand([48,28,60,3])
print(dfrustum.shape)

torch.Size([48, 28, 60, 3])
