# Homework 3 - Frustum PointNet for 3D Object Detection
![image](http://stanford.edu/~rqi/frustum-pointnets/images/teaser.jpg)

В рамках этого домашнего задания вам надо будет реализовать архитектуру Frustum PointNet для 3D детекции объектов. Для выполнения очень рекомендуется машина с какой-то GPU.

Референсная статья - [Frustum PointNets for 3D Object Detection from RGB-D Data](https://arxiv.org/abs/1711.08488)

Этот метод использует как вход и лидарное облако, и картинку, но внутри декомпозирует задачу детекции объекта в 3D на две подзадачи, каждая из которых работает только с одним типом данных:
- задетектировать объект в 2D на картинке, используя какой-либо готовый детектор объектов;
- предсказать 3д параметры (положение, ориентацию, размеры) по куску лидарного облака, которое при репроекции в картинку попало бы в 2D бокс объекта.

Для решения первой подзадачи возьмем готовый 2D детектор объектов из зоопарка моделей torchvision. Для решения второй задачи потребуется обучить две сетки:
- pointnet для сегментации лидарных облаков. Он будет принимать на вход лидарные точки, проецирующиеся в 2D бокс (т.н. frustum - усеченная пирамида), и сегментировать их на два класса: принадлежащие объекту или являющиеся фоном;
- pointnet для регрессии параметров бокса. Он принимает на вход лидарные точки, которые были отсегментированы предыдущей сеткой как принадлежащие объекту, и возвращает параметры 3D бокса - координаты центра, размеры и ориентацию.

Сети будем обучать последовательно. Для ускорения сеть будет обучаться на подмножестве датасета [kitti](https://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=3d).

Для получения полного балла за домашку надо будет дописать недостающий код, обучить сетки и побить бейзлайн.

Перед решением домашки рекомендуется посмотреть третий семинар, во второй части которого рассказывается про математику проецирования лидарных точек в картинку.

In [1]:
import torch
# Sanity check: confirm that PyTorch can see a CUDA device (later cells call .cuda() directly)
torch.cuda.is_available()

True

In [2]:
# check imports; if something fails, install it via pip3
import shapely
import scipy.optimize
import plotly
import matplotlib
import sklearn.metrics  # pip3 install scikit-learn
import numpy as np
import torchvision
import tqdm

## 1. Архитектуры PointNet'ов для сегментации и регрессии (2 балла)

![image](https://stanford.edu/~rqi/pointnet/images/pointnet.jpg)

In [3]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F

Класс `SpatialTransformerNetworkKDim` - реализация spatial transformer network (T-net со схемы выше) для предсказания как надо трансформировать облако (входное с координатами xyz или уже переведенное в некоторое новое пространство 1D свертками) в некоторое стандартизованное представление, из которого будет проще решать целевую задачу (см. рассказ про pointnet в лекции или [статью про pointnet](https://arxiv.org/abs/1612.00593))

In [4]:
class SpatialTransformerNetworkKDim(nn.Module):
    """
    Spatial transformer network (T-Net) that predicts one k x k transform per cloud.

    Point-wise 1x1 convolutions lift every point to 1024 channels, a max over the
    points axis pools them into a single vector, and three fully connected layers
    regress the k*k matrix entries. The identity matrix is added to the regressed
    values, so a zero output of the last layer corresponds to "no transform".
    """
    def __init__(self, k=64):
        """
        k - dimensionality of the input points (and side of the predicted square transform)
        """
        super().__init__()
        self.conv1 = torch.nn.Conv1d(k, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k*k)
        # Not used by forward (which calls F.relu); kept so the attribute stays available
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

        self.k = k

    def forward(self, x):
        """
        x - tensor of shape [bs, self.k, n_points]
        returns tensor of shape [bs, self.k, self.k] with transforms for each sample
        """
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=True)[0]  # maxpool over the points axis
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x).view(-1, self.k, self.k)

        # Build the identity directly on the activations' device and dtype, so the
        # module works on any device (not only the default CUDA one) without numpy.
        # The [k, k] identity broadcasts over the batch dimension.
        identity_transform = torch.eye(self.k, dtype=x.dtype, device=x.device)
        return x + identity_transform

`PointNetFeatureExtractor` - класс, реализующий основной backbone для вытаскивания фичей из облака. Выдает либо глобальный вектор фичей для облака (для задач регрессии и классификации; при `global_feat= True`), либо фичи для каждой точки (для задачи сегментации; при `global_feat = False`)

In [5]:
class PointNetFeatureExtractor(nn.Module):
    """
    PointNet backbone that turns a 4-channel point cloud into features.

    With global_feat=True it yields one 1024-d vector per cloud; otherwise that
    vector is tiled over the points and concatenated with the 64-d per-point
    features, giving 1088 channels per point.
    """

    def __init__(self, global_feat=True, feature_transform=False):
        super().__init__()
        self.stn = SpatialTransformerNetworkKDim(k=4)
        self.conv1 = torch.nn.Conv1d(4, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.global_feat = global_feat
        self.feature_transform = feature_transform
        if self.feature_transform:
            # Optional second transformer acting on the 64-d intermediate features
            self.fstn = SpatialTransformerNetworkKDim(k=64)

    def forward(self, x):
        """
        x - tensor of shape [bs, 4, n_points]
        Returns tuple (features, input_stn_transform, interm_stn_transform), where
        - features - [bs, 1024] global vector if self.global_feat,
          else [bs, 1088, n_points] per-point features
        - input_stn_transform - [bs, 4, 4] transform predicted for the raw input
        - interm_stn_transform - [bs, 64, 64] transform predicted for the intermediate
          features, or None when feature_transform is disabled
        """
        num_points = x.shape[2]

        # Align the raw input: bmm wants points as rows, convolutions want them as columns
        input_transform = self.stn(x)
        aligned = torch.bmm(x.transpose(2, 1), input_transform).transpose(2, 1)
        per_point = F.relu(self.bn1(self.conv1(aligned)))

        interm_transform = None
        if self.feature_transform:
            interm_transform = self.fstn(per_point)
            per_point = torch.bmm(per_point.transpose(2, 1), interm_transform).transpose(2, 1)

        hidden = F.relu(self.bn2(self.conv2(per_point)))
        hidden = self.bn3(self.conv3(hidden))
        # Max over the points axis pools the cloud into one vector
        global_vector = hidden.max(dim=2, keepdim=True)[0].view(-1, 1024)

        if self.global_feat:
            return global_vector, input_transform, interm_transform

        tiled = global_vector.view(-1, 1024, 1).repeat(1, 1, num_points)
        return torch.cat([tiled, per_point], 1), input_transform, interm_transform

In [6]:
# Smoke test: run a random batch through both modes of the feature extractor
sim_data = torch.rand(32, 4, 2500)

pointfeat = PointNetFeatureExtractor(global_feat=True)
out = pointfeat(sim_data)[0]
print('global feat', out.size())

pointfeat = PointNetFeatureExtractor(global_feat=False)
out = pointfeat(sim_data)[0]
print('point feat', out.size())

global feat torch.Size([32, 1024])
point feat torch.Size([32, 1088, 2500])


`SegmentationPointNet` - сегментационный PointNet. Принимает на вход облако и возвращает для каждой точки логиты (ненормированные оценки) по классам; в вероятности они превращаются уже внутри `nn.functional.cross_entropy` или явным `softmax`.

**(1 балл)** допишите недостающие куски

In [None]:
class SegmentationPointNet(nn.Module):
    """
    Per-point segmentation network built on PointNetFeatureExtractor (assignment template).

    The sections marked YOUR CODE are intentionally unimplemented and raise
    NotImplementedError until they are filled in.
    """
    def __init__(self, k = 2, feature_transform=False):
        """
        k - number of output channels of the final 1x1 convolution (one per class)
        feature_transform - forwarded to PointNetFeatureExtractor
        """
        super().__init__()
        self.k = k
        self.feature_transform = feature_transform
        self.feat = PointNetFeatureExtractor(global_feat=False, feature_transform=feature_transform)
        # ============YOUR CODE ===================
        # define subnetwork that will transform point features from PointNetFeatureExtractor
        # (1088 channels per point) down to 128 channels, e.g.
        # Conv1dBnRelu(1088, 512) + Conv1dBnRelu(512, 256) + Conv1dBnRelu(256, 128).
        # The final Conv1d(128, k) is already provided below as self.logits.
        raise NotImplementedError()
        # ==========================================
        self.logits = torch.nn.Conv1d(128, self.k, 1)

    def forward(self, x):
        """
        x - pointcloud, tensor of shape [batchsize, n_points, 4]
        Returns tuple (x, trans), where
        - x - output of the final 1x1 convolution self.logits (no softmax is applied here)
        - trans - input transform predicted by the feature extractor's STN
        """
        x = torch.moveaxis(x, -1, -2)  # swap 'dim' and 'points' axis
        x, trans, _ = self.feat(x)
        # ======= YOUR CODE HERE ====================
        # apply your subnetwork to x
        x = ...
        raise NotImplementedError()
        # ===========================================
        x = self.logits(x)
        return x, trans

In [None]:
# Smoke test: a random [batch, n_points, 4] cloud through a 3-class segmentation net
sim_data = torch.rand(32, 2500, 4)
seg = SegmentationPointNet(k=3)
out = seg(sim_data)[0]
print('seg', out.size())

`PointNetDetector` - PointNet для предсказания параметров 3D бокса. Принимает на вход облако с точками, которые сегментационная сеть определила как часть объекта. Параметры центра коробки будем предсказывать напрямую. Параметры размеров и ориентации коробки будем предсказывать как классификацию бина с характерными параметрами и регрессию поправок к параметрам в бине.

**(1 балл)** допишите недостающие куски

In [None]:
class PointNetDetector(nn.Module):
    """
    PointNet that regresses 3D box parameters from an object point cloud (assignment template).

    The box centre is regressed directly; size and heading are each predicted as a
    classification over bins plus per-bin regression outputs. The sections marked
    YOUR CODE are intentionally unimplemented.
    """
    def __init__(self, feature_transform=False, num_heading_bins=12, num_size_clusters=8):
        """
        feature_transform - forwarded to PointNetFeatureExtractor
        num_heading_bins - number of heading bins (outputs of both heading heads)
        num_size_clusters - number of size clusters (size_reg_head emits 3 values per cluster)
        """
        super(PointNetDetector, self).__init__()
        self.feature_transform = feature_transform
        self.feat = PointNetFeatureExtractor(global_feat=True, feature_transform=feature_transform)
        # ============YOUR CODE ===================
        # define subnetwork that will transform global features for cloud from PointNetFeatureExtractor
        # into the 256-d vector consumed by the heads below.
        # it should be like (Linear(1024, 512) + BN + Relu + Linear(512, 256) + Dropout(0.3) + BN + Relu + heads)
        raise NotImplementedError()

        self.dropout = nn.Dropout(p=0.3)
        # ==========================================

        # Every head reads the same 256-d feature vector
        self.center_reg_head = nn.Linear(256, 3)

        self.size_class_head = nn.Linear(256, num_size_clusters)
        self.size_reg_head = nn.Linear(256, num_size_clusters * 3)

        self.heading_class_head = nn.Linear(256, num_heading_bins)
        self.heading_reg_head = nn.Linear(256, num_heading_bins)

    def forward(self, x):
        """
        x - pointcloud, tensor of shape [batchsize, n_points, 4]
        Returns tuple (center_reg, size_class, size_reg, heading_class, heading_reg, trans):
        the outputs of the five heads followed by the input transform predicted by the
        feature extractor's STN.
        """
        x = torch.moveaxis(x, -1, -2)  # swap 'dim' and 'points' axis

        x, trans, _ = self.feat(x)
        # ======= YOUR CODE HERE ====================
        # apply your subnetwork to x
        x = ...
        # ===========================================
        center_reg = self.center_reg_head(x)
        size_class = self.size_class_head(x)
        size_reg = self.size_reg_head(x)
        heading_class = self.heading_class_head(x)
        heading_reg = self.heading_reg_head(x)

        return center_reg, size_class, size_reg, heading_class, heading_reg, \
                trans  # input STN transform, returned so the caller can regularize it


def feature_transform_regularizer(trans):
    """
    Penalize predicted transforms for deviating from orthonormal matrices.

    trans - tensor of shape [bs, d, d] with transforms of d-dim points into new space
    Returns a scalar tensor: the batch mean of the Frobenius norm of
    (trans @ trans^T - I), which is zero exactly when every transform is orthonormal.
    """
    d = trans.size()[1]
    # Create the identity on the same device and dtype as `trans`, so this works on
    # any device; the [d, d] matrix broadcasts over the batch dimension.
    identity = torch.eye(d, dtype=trans.dtype, device=trans.device)
    gram = torch.bmm(trans, trans.transpose(2, 1))
    loss = torch.mean(torch.norm(gram - identity, dim=(1, 2)))
    return loss

In [None]:
# Smoke test: output shapes of the 4-d STN, the 64-d STN and the detector on random input
print('stn: ')
sim_data_stn_4d = torch.rand(32, 4, 2500)
trans = SpatialTransformerNetworkKDim(k=4)
out = trans(sim_data_stn_4d)
print('stn out', out.size())
print('loss', feature_transform_regularizer(out))

print('\n stn 64d: ')
sim_data_stn_64d = torch.rand(32, 64, 2500)
trans = SpatialTransformerNetworkKDim(k=64)
out = trans(sim_data_stn_64d)
print('stn64d out', out.size())
print('loss', feature_transform_regularizer(out))

print('\n detector: ')
sim_data = torch.rand(32, 2500, 4)
detector = PointNetDetector()
out = detector(sim_data)
# The detector returns its outputs in exactly this order
output_names = ('center_reg', 'size class', 'size reg', 'heading class', 'heading reg', 'input transform')
for output_name, output_value in zip(output_names, out):
    print(output_name, output_value.size())

## 2. Подготовка данных для обучения (2 балла)

Загрузим сабсэт kitti

In [None]:
from yfile import download_from_yadisk
import os
# NOTE(review): machine-specific absolute path — point it at a directory on your own machine
TARGET_DIR = '/home/vyurchenko/data/shad/kitti_subset'
FILENAME = "kitti_subset.zip"
# Download only when the archive is not already present in TARGET_DIR
if not os.path.exists(os.path.join(TARGET_DIR, FILENAME)):
    # the archive is about 2.7 GB, so downloading will take some time
    download_from_yadisk(
        short_url='https://disk.yandex.ru/d/7fdMyxhreW_SiA',
        filename=FILENAME,
        target_dir=TARGET_DIR
    )
    # alternative way:
    #from gfile import download_list
    #download_list(url='https://drive.google.com/file/d/1VdLawplRcdT3UfvwNsihgzjJ0ks-Cmn0',
    #               filename=FILENAME, target_dir=TARGET_DIR)
# Guard against a truncated download: the archive must be larger than 1 GiB
filesize = os.path.getsize(os.path.join(TARGET_DIR, FILENAME))
GB = 2**30
assert filesize > 1 * GB, f"{filesize} is too small, something wrong with downloading"

In [None]:
! unzip -q kitti_subset.zip

Класс для представления датасета KITTI реализован за вас. По сути он хранит внутри сэмплы из датасета в сыром виде. Описание датасета доступно в файле `KITTI dataset description.txt`, который лежит в архиве с датасетом.


In [None]:
from kitti_dataset import split_train_val, KittiDataset

# Root of the KITTI subset inside the download directory
KITTI_ROOT = os.path.join(TARGET_DIR, "kitti")

# NOTE(review): `fraction=0.25` is presumably the share of items held out for validation — confirm in kitti_dataset
train_items, val_items = split_train_val(KITTI_ROOT, fraction=0.25)

# Both datasets read from the 'training' split; validation is restricted with only_easy=True
train_kitti = KittiDataset(KITTI_ROOT, items=train_items, split='training', only_easy=False)
val_kitti = KittiDataset(KITTI_ROOT, items=val_items, split='training', only_easy=True)

Каждый сэмпл датасета хранит картинку, лидарное облако, структуру с калибровками, структуру с gt данными детектора

In [None]:
# Display one raw KITTI training sample to inspect its fields
train_kitti[1]

Для того, чтобы сформировать датасет для обучения детектора потребуется дописать два класса: `Projector` и `Detector2DWrapper`.

Класс `Projector` реализует преобразования между различными фреймами. Вам нужно научиться трансформировать данные между фреймом камеры и лидара, а также проецировать точки на изображение и точки с изображения -- в лучи во фрейме камеры. Подробнее про формат данных можно прочитать в https://www.cvlibs.net/datasets/kitti/setup.php, https://towardsdatascience.com/kitti-coordinate-transformations-125094cd42fb

Для трансформаций облаков потребуется использовать класс `Calibration`, который имеет интерфейс dict'a. Из него вам потребуются следующие поля:
- `Tr_velo_to_cam` - матрица преобразований из системы координат лидара в систему координат камеры
- `R0_rect` - матрица ректификации изображения - преобразование точек из системы координат камеры в некоторую новую систему, учитывающую искажения линзы (дисторсию).
- `P2` - матрица проекции точек в системе координат камеры (с учтенной дисторсией) на плоскость картинки
Таким образом чтобы спроецировать точки, заданные в системе координат камеры, на плоскость картинки, надо их сначала домножить на матрицу `R0_rect` , а потом на матрицу `P2`

- **(1 балл)** Допишите преобразования между различными фреймами в класс `Projector`.

In [None]:
import numpy as np
from numpy.typing import NDArray

from kitti_dataset import Calibration


class Projector:
    """
    Coordinate-frame conversions between the lidar ("world") frame, the camera frame
    and the image projection frame (assignment template).

    Methods whose body is `raise NotImplementedError()` are intentionally left for
    the student; only the helpers and `projection_to_camera` are implemented.
    """
    @staticmethod
    def to_homogenous_coords(coords_3d: NDArray) -> NDArray:
        """
        Convert points to homogenous coordinates by adding one trailing dimension.

        @param coords_3d: (..., K), K < 4
        @return: NDArray of same dtype with shape (..., K + 1) -- same points in homogenous coordinates

        """
        assert coords_3d.shape[-1] < 4, "In this task, this function should never be called with last dimention >= 4"
        #===========YOUR CODE==========================
        raise NotImplementedError()

    @staticmethod
    def from_homogenous_coords(coords_3d: NDArray) -> NDArray:
        """
        Convert points from homogenous back to ordinary (non-homogenous) coordinates.

        @param coords_3d: (N, K), where last dimension corresponds to homogenous dimension
        @return: NDArray of same dtype with shape (N, K - 1) -- same points in non-homogenous coordinates

        """
        assert coords_3d.shape[-1] > 2
        #===========YOUR CODE==========================
        raise NotImplementedError()

    @staticmethod
    def _check_and_maybe_cast_to_homogenous(coords):
        """
        Accept coordinates whose last dimension is 3 or 4 and return them in 4-d form:
        3-d input is passed through `to_homogenous_coords`, 4-d input is returned unchanged.
        """
        assert coords.shape[-1] in (3, 4)
        if coords.shape[-1] == 3:
            coords = Projector.to_homogenous_coords(coords)
        return coords

    @staticmethod
    def world_to_camera(world_coords, calib: Calibration):
        """
        This function projects points from world coordinate frame to camera coordinate frame
        Note: check that you do not pass Velodyne points with intensity here, as intensity occupies 4th dimension as well

        @param world_coords: (..., 3) or (..., 4); if not homogenous, will be converted to homogenous
        @param calib: calibration of the scene
        @return: camera frame coordinates, homogenous
        """
        world_coords = Projector._check_and_maybe_cast_to_homogenous(world_coords)
        #===========YOUR CODE==========================
        raise NotImplementedError()

    @staticmethod
    def camera_to_world(camera_coords, calib: Calibration):
        """
        Inverse of `world_to_camera`.

        @param camera_coords: (..., 3) or (..., 4); if not homogenous, will be converted to homogenous
        @param calib: calibration of the scene
        @return: world frame coordinates, homogenous
        """
        camera_coords = Projector._check_and_maybe_cast_to_homogenous(camera_coords)
        #===========YOUR CODE==========================
        raise NotImplementedError()

    @staticmethod
    def camera_to_projection(camera_coords, calib: Calibration):
        """
        Project camera-frame points onto the image plane.

        @param camera_coords: (..., 3) or (..., 4); if not homogenous, will be converted to homogenous
        @param calib: calibration of the scene
        @return: xyw coordinates in image projection frame, shape (..., 3)
        """
        camera_coords = Projector._check_and_maybe_cast_to_homogenous(camera_coords)
        #===========YOUR CODE==========================
        raise NotImplementedError()

    @staticmethod
    def _pad_matrix_to_44(mtx):
        """
        Embed `mtx` into the top-left corner of a 4x4 identity matrix of the same dtype,
        so that a smaller matrix (e.g. 3x4) becomes an invertible-shaped 4x4 one.
        """
        padded = np.eye(4, dtype=mtx.dtype)
        padded[:mtx.shape[0], :mtx.shape[1]] = mtx
        return padded

    @staticmethod
    def projection_to_camera(projected_coords, calib: Calibration):
        """
        Map image projection coordinates back into the camera frame by applying the
        inverse of (P2 padded to 4x4) @ R0_rect.

        @param projected_coords: (..., 3) -- xyw; converted to 4-d homogenous form first
        @param calib: calibration providing 'P2' and 'R0_rect'
        @return: homogenous coordinates in camera frame, shape (..., 4)
        """
        projected_coords = Projector._check_and_maybe_cast_to_homogenous(projected_coords)
        # NOTE(review): the product requires calib['R0_rect'] to be 4x4 here — confirm in Calibration.
        # Points are rows, hence the right-multiplication by the transposed inverse.
        return projected_coords @ np.linalg.inv(Projector._pad_matrix_to_44(calib['P2']) @ calib['R0_rect']).T

Класс `Detector2DWrapper` - обертка над предобученным 2D детектором объектов на картинке из зоопарка моделей torchvision. Хотя в датасете доступны координаты объектов на картинке, будет некорректным учить frustum pointnet на "идеальных" кропах т.к. на этапе инференса они будут недоступны. Поэтому для обучения потребуется предварительно сопоставить результат 2D детектора с gt 2D боксами объектов и использовать для обучения наиболее близкий бокс из детектора. Эта логика должна быть реализована в методе `match_with_gt_objects`

- **(1 балл)** реализуйте жадный матчинг gt объектов к предсказанным 2д боксам из детектора по iou в функции `match_with_gt_objects`

In [None]:
import torch.utils.data as data
import torchvision
import torchvision.models.detection as detection_zoo
import numpy as np

from typing import TypedDict, Union, Literal, List, Tuple
from numpy.typing import NDArray

from kitti_dataset import Label


class Detector2DWrapper:
    """
    Wrapper around a pretrained torchvision Faster R-CNN used as the 2D object detector.

    `evaluate` returns boxes restricted to the COCO categories listed in
    SELECTED_LABELS; `match_with_gt_objects` is an assignment stub.
    """
    def __init__(self):
        self.camera_weights = detection_zoo.FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1
        self.camera_model = detection_zoo.fasterrcnn_resnet50_fpn_v2(weights=self.camera_weights, box_score_thresh=0.925)
        self.camera_model.eval()
        # Tracks whether the model currently lives on the GPU (toggled by to_cuda / to_cpu)
        self.cuda = False
        self.image_preprocess = self.camera_weights.transforms()

    def evaluate(self, image) -> torch.FloatTensor:
        """
        @param image: Tensor of shape (3, H, W), value range [0, 255]
        @return: camera bboxes as a CPU tensor of shape (N, 4), containing only detections
        whose category is in SELECTED_LABELS. N is 0 when nothing qualifies.

        """
        SELECTED_LABELS = ['person', 'bicycle', 'car', 'motorcycle', 'bus', 'truck']
        image = (image / 255.0).float()
        if self.cuda:
            image = image.cuda()
        with torch.no_grad():
            image_batch = [self.image_preprocess(image)]
            predictions = self.camera_model(image_batch)
        prediction = {x: y.cpu() for x, y in predictions[0].items()}

        # filter labels
        categories = self.camera_weights.meta["categories"]
        # An explicit bool dtype keeps the mask usable for indexing even with zero
        # detections. The filter is always applied: when no detection belongs to the
        # selected classes the result is an empty (0, 4) tensor rather than the
        # unfiltered boxes of unrelated classes.
        mask = torch.tensor(
            [categories[cat] in SELECTED_LABELS for cat in prediction["labels"].tolist()],
            dtype=torch.bool)
        return prediction['boxes'][mask]

    @staticmethod
    def match_with_gt_objects(gt_objects: List[Label], cam_bboxes: torch.FloatTensor, iou_threshold=0.5):
        """
        @param gt_objects: objects in KITTI scene
        @param cam_bboxes: detections bboxes returned from Detector2DWrapper.evaluate for this KITTI scene
        @param iou_threshold: minimum IOU for detection and GT entry to be considered matched

        @return: entries of gt_objects, for which a matching camera detection was found. gt `bbox` field should be
        replaced with found detection bbox. For each gt entry, pick the best (in terms of bbox IoU) detection
        if their IoU is at least iou_threshold. Do not drop 'DontCare' labels here, we need them for metrics later.
        If some gt object has no predicted object with iou higher than threshold, drop it.
        """

        gt_bboxes = [x['bbox'] for x in gt_objects]
        ans = []
        # ============YOUR CODE========================
        # you may use torchvision.ops.box_iou() to compute pairwise iou of gt_bboxes and cam_bboxes
        raise NotImplementedError()
        # ==============================================
        return ans

    def to_cuda(self):
        """Move the detector and its preprocessing transform to the GPU."""
        self.camera_model = self.camera_model.cuda()
        self.image_preprocess = self.image_preprocess.cuda()
        self.cuda = True

    def to_cpu(self):
        """Move the detector and its preprocessing transform back to the CPU."""
        self.camera_model = self.camera_model.cpu()
        self.image_preprocess = self.image_preprocess.cpu()
        self.cuda = False

In [None]:
# Single shared 2D detector instance; constructing it loads the pretrained COCO weights
detector_2d_wrapper = Detector2DWrapper()

Класс `FrustumDataset` реализован за вас, он принимает на вход датасет kitti, объекты классов Detector2DWrapper и Projector и возвращает необходимые для обучения данные.

In [None]:
from frustum_dataset import FrustumDataset

# Both datasets share the same 2D detector instance.
# NOTE(review): `cuda=True` presumably runs the 2D detector on the GPU — confirm in frustum_dataset
frustum_train = FrustumDataset(train_kitti, detector_2d_wrapper, Projector(), cuda=True)
frustum_val = FrustumDataset(val_kitti, detector_2d_wrapper, Projector(), cuda=True)


Каждый сэмпл датасета - это
- frustum с облаком (поле 'cloud'), выровненным так, чтобы ось z проходила по центру кропа объекта, которому этот frustum соответствует
- картинка с кропом (поле 'image'; нужна только для визуализаций, в обучении не используется),
- gt сегментация облака (поле 'cloud_segmentation'),
- закодированный размер бокса - поле `size_idx` с индексом класса и поле `size_residual` с поправками к размеру относительно типичных размеров в классах
- закодированная ориентация бокса - поле `heading_idx` с индексом класса и поле `heading_residual` с поправками к центрам бинов с ориентацией
- поле `world_location` c координатами центра бокса
- индекс сцены в датасете kitti `kitti_scene_idx`
- `frustrum_rotation_angle` - угол, на который был повернут frustum

In [None]:
# Display one frustum training sample to inspect its fields
frustum_train[1]

Повизуализируем gt данные для сегментации с помощью plotly

In [None]:
from plotly.offline import init_notebook_mode
from plotly.offline import iplot
from plotly.offline import plot
from plotly import graph_objs as go
# Enable plotly's offline rendering inside the notebook
init_notebook_mode()

def plotly_add_cloud(fig, cloud, color, colorscale=None, name=None, showscale=None, cmin=None, cmax=None):
    """
    Add a point cloud to `fig` as a marker-only 3D scatter trace.

    The first three entries along the last axis of `cloud` are used as x, y and z.
    `color`, `colorscale`, `showscale`, `cmin` and `cmax` are passed through to the
    marker settings; `name` labels the trace.
    """
    marker_settings = {
        'size': 2,
        'color': color,
        'colorscale': colorscale,
        'showscale': showscale,
        'cmin': cmin,
        'cmax': cmax,
    }
    xs, ys, zs = (cloud[..., axis] for axis in range(3))
    fig.add_scatter3d(x=xs, y=ys, z=zs, name=name, mode='markers', marker=marker_settings)


# Black-background 3D scene with data-driven aspect ratio and hidden axis grids
fig = go.Figure(layout=dict(scene=dict(
        aspectmode='data',
        xaxis=dict(showbackground=False, gridcolor="rgba(0, 0, 0, 0)"),
        yaxis=dict(showbackground=False, gridcolor="rgba(0, 0, 0, 0)"),
        zaxis=dict(showbackground=False, gridcolor="rgba(0, 0, 0, 0)"),
    ), plot_bgcolor='#000', paper_bgcolor='black'))

frustum_scene = frustum_train[2]
# Only referenced by the commented-out full-scene visualization below
kitti_scene = train_kitti[3]

# Color every frustum point by its ground-truth segmentation label (0 or 1)
plotly_add_cloud(fig, frustum_scene['cloud'], frustum_scene['cloud_segmentation'].astype(np.int32), colorscale='Reds', cmin=0, cmax=1, name='Velodyne')
# plotly_add_cloud(fig, frustum_scene['object_cloud'], frustum_scene['object_cloud'][:, -1], colorscale='Reds', cmin=0, cmax=1, name='Velodyne')
# plotly_add_cloud(fig, kitti_scene['cloud'], np.ones_like(kitti_scene['cloud'][:, 0]), colorscale='Blues', cmin=0, cmax=0, name='Velodyne')
iplot(fig)

Для референса в датасете также доступны картинки, но для обучения они нам не понадобятся:

In [None]:
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 9))
# Image tensors are channel-first; matplotlib expects channels last
plt.imshow(torch.moveaxis(frustum_train.kitti_dataset[0]['image'], 0, -1))

# Red: boxes predicted by the 2D detector
for x1, y1, x2, y2 in detector_2d_wrapper.evaluate(frustum_train.kitti_dataset[0]['image']):
    plt.plot([x1, x1, x2, x2, x1], [y1, y2, y2, y1, y1], color='red', linewidth=2)

# Green: ground-truth 2D boxes from the KITTI labels
for item in frustum_train.kitti_dataset[0]['labels']:
    x1, y1, x2, y2 = item['bbox']
    plt.plot([x1, x1, x2, x2, x1], [y1, y2, y2, y1, y1], color='green', linewidth=1.5)
plt.show()

In [None]:
# Full camera image of one KITTI scene (moved to channels-last for matplotlib)
plt.figure(figsize=(12, 12))
plt.imshow(torch.moveaxis(train_kitti[1]['image'], 0, -1))
plt.show()


# Image stored in one frustum sample.
# Image is only available for visualization purposes; do not use it in network
plt.figure()
plt.imshow(torch.moveaxis(frustum_train[7]['image'], 0, -1))
plt.show()

## 3. Обучение pointnet'a для сегментации облака (2 балла)

Учим часть про сегментацию. Вам необходимо дописать часть про подсчет лосса и функцию получения предсказаний для метрик по результатам прогона сетки.

В качестве лосса для сегментации используется `nn.functional.cross_entropy` и `feature_transform_regularizer` как регуляризатор на результат SpatialTransformerNetwork, использующийся внутри поинтнета для сегментации.

В функции `get_predictions_for_metrics` на выходе ожидаются вероятности за класс 1 и gt метки для точек.

**(1 балл)**: дописан недостающий код. Лосс уменьшается в ходе обучения, метрики вычисляются

**(1 балл)**: сеть достигла хотя бы 0.96 ROC AUC на валидации

In [None]:
import sklearn.metrics

def calculate_segmentation_loss(gt, segmentation_output, transform_regularizer_weight):
    """
    Compute the two loss terms for the segmentation PointNet (assignment stub).

    @param gt: batch dict produced by the data loader
        (NOTE(review): per-point labels presumably live under gt['cloud_segmentation'] — confirm)
    @param segmentation_output: the (segm, trans) pair returned by the segmentation network
    @param transform_regularizer_weight: factor the regularization term must be multiplied by
    @return: tuple of (classification loss, regularization from pointnet). Both are scalars (tensors of empty shape)
    """
    # ==========YOUR CODE GOES HERE==================
    # don't forget to multiply regularizer_loss by transform_regularizer_weight
    segm, trans = segmentation_output
    segmentation_loss = ...
    regularizer_loss = ...
    return segmentation_loss, regularizer_loss


def get_predictions_for_metrics(gt, segmentation_output):
    """
    Turn network output into inputs for the validation metrics (assignment stub).

    @param gt: batch dict produced by the data loader
    @param segmentation_output: the (segm, trans) pair returned by the segmentation network
    @return: tuple of (batch_pred, batch_true) -- for every point in every batch element, predicted probability
    of class 1 and gt label. Both tensors are of shape (batch_sz, NUM_POINTS)
    """
    # ==========YOUR CODE GOES HERE==================
    segm, trans = segmentation_output
    positive_class_probs = ...
    gt_labels = ...
    return positive_class_probs, gt_labels


class RunningMean:
    """Accumulates values one at a time and reports their arithmetic mean."""

    def __init__(self):
        self.sum = 0
        self.cnt = 0

    def add(self, value):
        """Include one more value in the mean."""
        self.cnt += 1
        self.sum += value

    def get(self):
        """Return the mean of the values added so far, or 0 if there are none."""
        if self.cnt == 0:
            return 0
        return self.sum / self.cnt


def train_segmentation_pointnet_one_epoch(segmentation_pointnet, train_data_generator, optimizer, transform_regularizer_weight):
    """
    Run one optimization pass over `train_data_generator`.

    Every batch is moved to the GPU, pushed through the network, and the sum of the
    segmentation loss and the transform regularizer is back-propagated. Running
    means of both terms are printed in place after each batch.

    @param segmentation_pointnet: network to train; it is switched to train mode
    @param train_data_generator: iterable of batch dicts holding at least 'cloud'
    @param optimizer: optimizer over the network's parameters
    @param transform_regularizer_weight: forwarded to calculate_segmentation_loss
    @return: mean segmentation loss over the epoch as a Python float (0 if there were no batches)
    """
    mean_segm_loss = RunningMean()
    mean_reg_loss = RunningMean()

    segmentation_pointnet.train()

    for batch_idx, batch in enumerate(train_data_generator):
        optimizer.zero_grad()

        batch = {x: y.cuda() for x, y in batch.items()}
        segmentation_output = segmentation_pointnet(batch['cloud'].float())

        segm_loss, reg_loss = calculate_segmentation_loss(batch, segmentation_output, transform_regularizer_weight)
        loss = segm_loss + reg_loss
        loss.backward()
        optimizer.step()
        # Accumulate plain floats: summing the loss tensors themselves would chain
        # every batch's autograd history into the running sum for the whole epoch.
        mean_segm_loss.add(float(segm_loss))
        mean_reg_loss.add(float(reg_loss))

        print(f'Train batch: {batch_idx:4d}/{len(train_data_generator)} segmentation: {mean_segm_loss.get():.3f}, '
              f'segmentation trans reg: {mean_reg_loss.get():.3f}', end='\r')
    print()
    return mean_segm_loss.get()


def eval_segmentation_model(segmentation_pointnet, val_data_generator, transform_regularizer_weight):
    """
    Evaluate the segmentation network on `val_data_generator` without gradients.

    Prints running means of both loss terms per batch, then average precision and
    ROC AUC computed over the predictions of all batches.

    @param segmentation_pointnet: network to evaluate; it is switched to eval mode and left there
    @param val_data_generator: iterable of batch dicts holding at least 'cloud'
    @param transform_regularizer_weight: forwarded to calculate_segmentation_loss
    @return: mean segmentation loss over the validation batches as a Python float
    """
    mean_segm_loss = RunningMean()
    mean_reg_loss = RunningMean()
    with torch.no_grad():
        segmentation_pointnet.eval()
        Y_pred = []
        Y_true = []
        for batch_idx, batch in enumerate(val_data_generator):
            batch = {x: y.cuda() for x, y in batch.items()}
            segmentation_output = segmentation_pointnet(batch['cloud'].float())
            segm_loss, reg_loss = calculate_segmentation_loss(batch, segmentation_output, transform_regularizer_weight)
            batch_pred, batch_true = get_predictions_for_metrics(batch, segmentation_output)
            Y_pred.append(batch_pred.cpu())
            Y_true.append(batch_true.cpu())
            # Accumulate plain floats (same convention as the training loop), so the
            # returned model-selection score is a host-side number, not a CUDA tensor.
            mean_segm_loss.add(float(segm_loss))
            mean_reg_loss.add(float(reg_loss))
            print(f'Valid batch: {batch_idx:4d}/{len(val_data_generator)} segmentation: {mean_segm_loss.get():.3f}, '
                  f'segmentation trans reg: {mean_reg_loss.get():.3f}', end='\r')
    print()
    Y_pred = torch.cat(Y_pred)
    Y_true = torch.cat(Y_true)
    # NOTE(review): Y_true / Y_pred are 2D (samples x points); scikit-learn presumably treats
    # that as a multilabel problem and macro-averages over the point index — confirm this is
    # the intended metric before flattening or changing it.
    print(f'Valid AP: {sklearn.metrics.average_precision_score(Y_true, Y_pred):.3f}, '
          f'ROC AUC: {sklearn.metrics.roc_auc_score(Y_true, Y_pred):.3f}' )
    return mean_segm_loss.get()


def train_segmentation_pointnet(segmentation_pointnet, frustum_train, frustum_val, batch_size=16, n_epochs=90,
                                transform_regularizer_weight=1.0):
    """
    Train the segmentation network and keep the checkpoint with the lowest validation loss.

    Uses Adam (lr=1e-3, weight_decay=1e-6) with a StepLR schedule that multiplies the
    learning rate by 0.2 every 30 epochs. After every epoch the network is evaluated on
    `frustum_val`; whenever the validation segmentation loss improves, the state dict
    is written to 'segmentation.pth'.

    @param segmentation_pointnet: network to train
    @param frustum_train: training dataset
    @param frustum_val: validation dataset
    @param batch_size: batch size for both loaders
    @param n_epochs: number of epochs
    @param transform_regularizer_weight: weight of the STN regularizer in the loss
    """
    train_data_generator = data.DataLoader(frustum_train, batch_size, shuffle=True, drop_last=True, pin_memory=True)
    # Validation must see the same, complete set of samples every epoch. With shuffling
    # plus drop_last a different random subset was discarded each time, which made the
    # model-selection loss incomparable across epochs.
    val_data_generator = data.DataLoader(frustum_val, batch_size, shuffle=False, drop_last=False, pin_memory=True)
    optim = torch.optim.Adam([
            {'params': segmentation_pointnet.parameters()}
        ], lr=1e-3, weight_decay=1e-6)
    lr_sched = torch.optim.lr_scheduler.StepLR(optim, step_size=30, gamma=0.2)

    best_model_loss = None
    for epoch in range(n_epochs):
        print('Epoch:', epoch)
        train_segmentation_pointnet_one_epoch(segmentation_pointnet, train_data_generator, optim, transform_regularizer_weight)

        model_loss = eval_segmentation_model(segmentation_pointnet, val_data_generator, transform_regularizer_weight)

        if best_model_loss is None or model_loss < best_model_loss:
            best_model_loss = model_loss
            torch.save(segmentation_pointnet.state_dict(), 'segmentation.pth')
            print('new best model is saved')
        print()
        lr_sched.step()

In [None]:
# Create a fresh segmentation network on the GPU and train it;
# the best checkpoint is written to 'segmentation.pth' by train_segmentation_pointnet
segmentation_pointnet = SegmentationPointNet().cuda()
train_segmentation_pointnet(segmentation_pointnet, frustum_train, frustum_val,
                            batch_size=16, n_epochs=90,
                            transform_regularizer_weight=1.0)


In [None]:
# Re-create the network and restore the best checkpoint saved during training,
# then switch to eval mode for inference
segmentation_pointnet = SegmentationPointNet().cuda()
segmentation_pointnet.load_state_dict(torch.load('segmentation.pth'))
segmentation_pointnet.eval()

## 4. Обучение pointnet'a для детекции (4 балла)

В этом блоке предстоит дописать код в функциях `calculate_detection_loss`, `cloud_to_object_cloud` и `detection_output_to_center_dims_rot`

- **(1 балл)** Дописан весь недостающий код, обучение идет, лосс уменьшается
- **(3 балла)**: сеть достигла
  - хотя бы 0.80/0.80 precision/recall для автомобилей с min_iou=0.70
  - хотя бы 0.80/0.80 precision/recall для пешеходов с min_iou=0.25

In [None]:
from frustum_dataset import get_detection3d_corner_points, NUM_SIZE_CLUSTER, NUM_HEADING_BIN, g_mean_size_arr

# Number of points drawn per object cloud; used as the default `num_samples` of `resample_cloud`.
OBJECT_CLOUD_SZ = 512

def calculate_detection_loss(gt, detection_output, transform_regularizer_weight):
    """Compute the six loss terms of the box-estimation network (exercise template).

    The `...` placeholders are intentionally left for the student. Until they are
    replaced with tensors, the scaling lines at the bottom of the function cannot run.

    Args:
        gt: mapping with the ground-truth fields named in the hints below
            ('world_location', 'size_idx', 'size_residual', 'heading_idx',
            'heading_residual').
        detection_output: 6-tuple unpacked as (center, size_class, size_reg,
            heading_class, heading_reg, stn_translation).
        transform_regularizer_weight: multiplier for the regularizer term.

    Returns:
        Tuple (center_loss, size_class_loss, size_reg_loss, heading_class_loss,
        heading_reg_loss, regularizer_loss). The training and eval loops stack,
        sum and call `.item()` on these, so each entry is expected to be a
        scalar tensor.
    """
    center, size_class, size_reg, heading_class, heading_reg, stn_translation = detection_output

    # ======== YOUR CODE =====================
    # center_loss - Huber loss between gt['world_location'] and the `center` prediction
    # size_class_loss - cross-entropy between gt['size_idx'] and size_class
    # heading_class_loss - cross-entropy between gt['heading_idx'] and heading_class
    # regularizer_loss - computed via feature_transform_regularizer (don't forget to multiply it by
    #     transform_regularizer_weight)
    # size_reg_loss - Huber loss between gt['size_residual'] and the size_reg entries THAT CORRESPOND
    #     TO THE GT CLASS gt['size_idx']
    # heading_reg_loss - Huber loss between gt['heading_residual'] and the heading_reg entries THAT
    #     CORRESPOND TO THE GT CLASS gt['heading_idx']

    center_loss = ...
    size_class_loss = ...
    size_reg_loss = ...
    heading_class_loss = ...
    heading_reg_loss = ...
    regularizer_loss = ...

    # Re-weight individual terms; adjust these hand-picked factors if needed.
    heading_reg_loss = heading_reg_loss * 20
    size_reg_loss = size_reg_loss * 3 * 20
    return center_loss, size_class_loss, size_reg_loss, heading_class_loss, heading_reg_loss, regularizer_loss


def resample_cloud(cloud, positive_points_mask, num_samples=OBJECT_CLOUD_SZ):
    """Randomly draw `num_samples` points per batch element, favouring masked-in points.

    Args:
        cloud: 3-D tensor indexed as (batch, point, feature); points are gathered
            along dimension 1.
        positive_points_mask: 2-D (batch, point) selector. It is converted to float
            and multiplied by 150 before a softmax over the point axis, so with a
            0/1 mask the unselected points get a weight e^-150 times smaller than
            the selected ones. An all-zero row yields uniform sampling.
        num_samples: number of points drawn (with replacement) per batch element.

    Returns:
        Tensor of shape (batch, num_samples, feature) holding the sampled points.
    """
    logits = 150 * positive_points_mask.float()
    sampling_probs = torch.softmax(logits, dim=1)
    picked = torch.multinomial(sampling_probs, num_samples=num_samples, replacement=True)
    # Repeat each sampled point index across the feature axis so it can drive `gather`.
    gather_index = picked.unsqueeze(-1).expand(-1, -1, cloud.shape[-1])
    return torch.gather(input=cloud, dim=1, index=gather_index)

def cloud_to_object_cloud(frustum_batch, segmentation_probas_batch):
    """Turn a frustum batch into an object-centred batch (exercise template).

    The `...` placeholders are intentionally left for the student.

    Args:
        frustum_batch: mapping holding the frustum points under 'cloud' and,
            optionally, ground-truth fields such as 'world_location'.
        segmentation_probas_batch: output of the segmentation network, used to
            decide which points belong to the object.

    Returns:
        Tuple (new_frustrum_batch, object_mean_points):
        - new_frustrum_batch holds 'object_cloud' (the points resampled from
          'cloud' and shifted by their mean), 'world_location' shifted by the
          same mean (only if present in the input), and the label fields copied
          over unchanged.
        - object_mean_points is the per-sample mean point that was subtracted.
    """
    # ===============YOUR CODE======================
    # Build new_frustrum_batch:
    # 1) copy the fields ['size_idx', 'size_residual', 'heading_idx', 'heading_residual', 'kitti_scene_idx']
    #    from frustum_batch (already done by the loop at the bottom)
    # 2) resample 'object_cloud' from the 'cloud' field with `resample_cloud`, using the predicted
    #    segmentation mask as positive_points_mask (OBJECT_CLOUD_SZ points must be sampled)
    # 3) compute the mean 3D point of every batch sample and subtract it from 'object_cloud' and
    #    'world_location'; keep the 4th coordinate (lidar intensity) unchanged

    positive_points_mask = ...
    object_points = ...
    object_mean_points = ...
    object_points_centered = ...

    new_frustrum_batch = {
        'object_cloud': object_points_centered,
    }
    # Ground truth may be absent, in which case there is no location to shift.
    if 'world_location' in frustum_batch:
        new_frustrum_batch['world_location'] = frustum_batch['world_location'] - object_mean_points

    # These fields do not depend on the coordinate frame, so they are passed through as-is.
    for field in ['size_idx', 'size_residual', 'heading_idx', 'heading_residual', 'kitti_scene_idx']:
        if field in frustum_batch:
            new_frustrum_batch[field] = frustum_batch[field]
    return new_frustrum_batch, object_mean_points


def detection_output_to_center_dims_rot(detection_output):
    """Decode a detection-network output into box centre, sizes and heading (exercise template).

    The `...` placeholders are intentionally left for the student.

    Args:
        detection_output: 6-tuple unpacked as (center, size_class, size_reg,
            heading_class, heading_reg, trans). Other code in this notebook also
            calls this with ground truth packed the same way (one-hot class
            tensors and `None` for the last element), so `trans` must not be
            required for decoding.

    Returns:
        Tuple (center, sizes, heading); `center` is detached and moved to the CPU.
    """
    center, size_class, size_reg, heading_class, heading_reg, trans = detection_output
    #================YOUR CODE==================
    # Decode the bounding box and return the predicted center, sizes and heading.
    size_idx = ...
    size_reg_for_pred_class = ...
    # The residual acts as a relative correction to the mean size of the chosen size class.
    sizes = g_mean_size_arr[size_idx] * (size_reg_for_pred_class + 1.0)

    heading_idx = ...
    heading_reg_for_pred_class = ...
    angle_step = (2 * np.pi / NUM_HEADING_BIN)
    # Heading = centre of the chosen angular bin (hence the +0.5) plus the residual for that bin.
    heading = angle_step * (heading_idx + 0.5) + heading_reg_for_pred_class
    return center.detach().cpu(), sizes, heading


def visualize_detection_output(object_batch, detection_output, title):
    """Plot the first sample of a batch with its ground-truth and predicted boxes.

    The points and box corners are drawn using their first two coordinates.
    Ground truth is drawn in red, the prediction in orange; each box also gets
    a line through its centre along the decoded heading.

    Args:
        object_batch: batch with 'object_cloud' and the ground-truth fields
            'world_location', 'size_idx', 'size_residual', 'heading_idx' and
            'heading_residual'.
        detection_output: raw output of the detection network for the batch.
        title: figure title.
    """
    x_axis, y_axis = 0, 1

    def _draw_box(center, dims, rotation_y, color):
        # Corner points of the first sample's box plus a line showing its heading.
        corners = get_detection3d_corner_points(center[0], dims[0], rotation_y[0])
        plt.scatter(corners[..., x_axis], corners[..., y_axis], color=color)
        start = (center[0][x_axis], center[0][y_axis])
        end = (center[0][x_axis] + np.cos(rotation_y[0]), center[0][y_axis] + np.sin(rotation_y[0]))
        plt.axline(start, end, color=color)

    plt.figure(figsize=(9, 9))
    points = object_batch['object_cloud'][0].detach().cpu().numpy()
    plt.scatter(points[..., x_axis], points[..., y_axis])

    # Ground truth is packed like a network output so the same decoder can be reused.
    # NOTE(review): 8 and 12 presumably mirror NUM_SIZE_CLUSTER / NUM_HEADING_BIN - confirm.
    gt_as_output = (
        object_batch['world_location'],
        torch.nn.functional.one_hot(object_batch['size_idx'], 8),
        object_batch['size_residual'],
        torch.nn.functional.one_hot(object_batch['heading_idx'], 12),
        object_batch['heading_residual'],
        None,
    )
    _draw_box(*detection_output_to_center_dims_rot(gt_as_output), color='red')
    plt.title(title)

    _draw_box(*detection_output_to_center_dims_rot(detection_output), color='orange')

    plt.gca().set_aspect('equal', adjustable='box')
    plt.show()


def train_detection_pointnet_one_epoch(detection_pointnet, segmentation_pointnet, train_data_generator,
                                       optim, transform_regularizer_weight):
    """Run one optimisation pass of the detection network over the training loader.

    For every batch the segmentation network (run without gradients) selects the
    object points, the detection network predicts box parameters from them, and
    the sum of the six loss terms is back-propagated. Running means of the
    individual terms are printed on a single, continuously overwritten line.

    Args:
        detection_pointnet: network being trained; put into train mode here.
        segmentation_pointnet: network providing the point segmentation; its
            mode is not changed by this function.
        train_data_generator: iterable of batches (mappings of tensors).
        optim: optimizer over the detection network's parameters.
        transform_regularizer_weight: forwarded to `calculate_detection_loss`.
    """
    means = [RunningMean() for _ in range(6)]

    detection_pointnet.train()
    for batch_idx, cpu_batch in enumerate(train_data_generator):
        optim.zero_grad()
        batch = {name: tensor.cuda() for name, tensor in cpu_batch.items()}

        # Segmentation is only a fixed preprocessing step here, so no graph is built for it.
        with torch.no_grad():
            point_scores, _ = segmentation_pointnet(batch['cloud'].float())
        object_batch, _ = cloud_to_object_cloud(batch, point_scores)

        detection_output = detection_pointnet(object_batch['object_cloud'].float())
        loss_terms = calculate_detection_loss(object_batch, detection_output, transform_regularizer_weight)

        # All terms are summed into the single scalar that is optimised.
        torch.stack(loss_terms).sum().backward()
        optim.step()

        for term_idx, running_mean in enumerate(means):
            running_mean.add(loss_terms[term_idx].item())

        print(f'Train batch: {batch_idx:4d}/{len(train_data_generator):-4d}    center: {means[0].get():.3f}, '
              f'size_class: {means[1].get():.3f}, size_reg: {means[2].get():.3f}, heading_class: {means[3].get():.3f}, '
              f'heading_reg: {means[4].get():.3f}, trans reg: {means[5].get():.3f}', end='\r')
    print()


def eval_detection_model(detection_pointnet, segmentation_pointnet, val_data_generator, transform_regularizer_weight,
                        visualize_sample=False):
    """Evaluate the detection network on a validation loader.

    Args:
        detection_pointnet: network under evaluation; put into eval mode here.
        segmentation_pointnet: network providing the point segmentation; its
            mode is not changed by this function.
        val_data_generator: iterable of batches (mappings of tensors).
        transform_regularizer_weight: forwarded to `calculate_detection_loss`.
        visualize_sample: when true, plot the first sample of the last evaluated
            batch after the loop. Nothing is plotted if the loader was empty.

    Returns:
        Sum of the running means of the six loss terms.
    """
    means = [RunningMean() for _ in range(6)]
    # Remember the last evaluated batch for the optional plot; stays None for an empty loader.
    object_batch = None
    detection_output = None
    with torch.no_grad():
        detection_pointnet.eval()
        for batch_idx, batch in enumerate(val_data_generator):
            batch = {x: y.cuda() for x, y in batch.items()}
            segmentation_output, _ = segmentation_pointnet(batch['cloud'].float())
            object_batch, _ = cloud_to_object_cloud(batch, segmentation_output)
            detection_output = detection_pointnet(object_batch['object_cloud'].float())

            detection_loss = calculate_detection_loss(object_batch, detection_output, transform_regularizer_weight)
            for i in range(6):
                means[i].add(detection_loss[i].item())

            print(f'Valid batch: {batch_idx:4d}/{len(val_data_generator):-4d}    center: {means[0].get():.3f}, '
                  f'size_class: {means[1].get():.3f}, size_reg: {means[2].get():.3f}, heading_class: {means[3].get():.3f}, '
                  f'heading_reg: {means[4].get():.3f}, trans reg: {means[5].get():.3f}', end='\r')
    print()

    if visualize_sample and detection_output is not None:
        visualize_detection_output(object_batch, detection_output, 'Val sample result')

    return sum(elem.get() for elem in means)


def train_detection_pointnet(detection_pointnet, segmentation_pointnet, frustum_train, frustum_val,
                             batch_size=16, n_epochs=90,
                             transform_regularizer_weight=1e-3, checkpoint_path='detection.pth'):
    """Train the detection network and keep the checkpoint with the lowest validation loss.

    Args:
        detection_pointnet: model to optimise (updated in place).
        segmentation_pointnet: already trained segmentation model; it is switched
            to eval mode and only used for inference.
        frustum_train: dataset used for the optimisation steps.
        frustum_val: dataset used for model selection after every epoch.
        batch_size: batch size of both data loaders.
        n_epochs: number of train + validate rounds.
        transform_regularizer_weight: forwarded unchanged to the per-epoch
            train and eval helpers.
        checkpoint_path: file that receives the model's state dict every time
            the validation loss improves.

    Returns:
        None. The best weights are only available through `checkpoint_path`.
    """
    train_data_generator = data.DataLoader(frustum_train, batch_size, shuffle=True, drop_last=True, pin_memory=True)
    # Validation order is fixed on purpose: with drop_last=True a shuffled loader would
    # score a different subset every epoch, so the losses compared below would not be
    # measured on the same samples. It also means the periodically plotted validation
    # sample is the same one each time, which makes progress easier to judge.
    val_data_generator = data.DataLoader(frustum_val, batch_size, shuffle=False, drop_last=True, pin_memory=True)
    optim = torch.optim.Adam(detection_pointnet.parameters(), lr=1e-3, weight_decay=1e-6)
    # The learning rate is multiplied by 0.2 every 30 epochs.
    lr_sched = torch.optim.lr_scheduler.StepLR(optim, step_size=30, gamma=0.2)

    best_model_loss = None

    segmentation_pointnet.eval()
    for epoch in range(n_epochs):
        print('Epoch:', epoch)
        train_detection_pointnet_one_epoch(detection_pointnet, segmentation_pointnet, train_data_generator,
                                           optim, transform_regularizer_weight)

        # Plot a validation sample every 10th epoch, starting from epoch 30.
        visualize_sample = epoch >= 30 and epoch % 10 == 0

        model_loss = eval_detection_model(detection_pointnet, segmentation_pointnet, val_data_generator,
                                          transform_regularizer_weight,
                                          visualize_sample)

        if best_model_loss is None or model_loss < best_model_loss:
            best_model_loss = model_loss
            torch.save(detection_pointnet.state_dict(), checkpoint_path)
            print('new best model is saved')
        print()
        lr_sched.step()


In [None]:
# Create a fresh detection network on the GPU and train it on top of the segmentation
# network defined earlier. The keyword values repeat the defaults declared by
# `train_detection_pointnet`.
detection_pointnet = PointNetDetector().cuda()
train_detection_pointnet(detection_pointnet, segmentation_pointnet, frustum_train, frustum_val, batch_size=16, n_epochs=90,
                        transform_regularizer_weight=1e-3)

In [None]:
# Re-create the network and restore the weights from 'detection.pth', the file
# `train_detection_pointnet` writes whenever the validation loss improves.
detection_pointnet = PointNetDetector().cuda()
detection_pointnet.load_state_dict(torch.load('detection.pth'))
# Switch to inference mode. As the last expression of the cell, the returned module
# is also what the notebook displays.
detection_pointnet.eval()

Посчитаем метрики на тестовом датасете

In [None]:
from metrics import compute_confusion_mtx

In [None]:
# Build the test split and wrap it into frustums using the same 2D detector wrapper and projector types.
test_kitti = KittiDataset(KITTI_ROOT, split='testing', only_easy=True)
frustum_test = FrustumDataset(test_kitti, detector_2d_wrapper, Projector(), cuda=True)
# One frustum per batch, in dataset order, keeping every sample.
test_loader = data.DataLoader(frustum_test, 1, shuffle=False, drop_last=False, pin_memory=True)

In [None]:
# Maps a KITTI scene index to a list of (box corner points, predicted size-class index) tuples.
evaluated_data = {}

with torch.no_grad():
    detection_pointnet.eval()
    segmentation_pointnet.eval()
    for batch_idx, batch in enumerate(tqdm.tqdm(test_loader)):
        batch = {x: y.cuda() for x, y in batch.items()}
        segmentation_output, _ = segmentation_pointnet(batch['cloud'].float())
        object_batch, centers_offset = cloud_to_object_cloud(batch, segmentation_output)
        detection_output = detection_pointnet(object_batch['object_cloud'].float())

        # The size bin is trained to coincide with the ground-truth object class, so the
        # arg-max over the size-class scores is stored as the predicted class of the box.
        size_idx = torch.argmax(detection_output[1].cpu(), dim=1).detach().numpy()
        center, dims, rotation_y = detection_output_to_center_dims_rot(detection_output)
        # Undo the centering applied by cloud_to_object_cloud.
        center += centers_offset.cpu()
        for frustum_idx in range(size_idx.shape[0]):
            kitti_scene_idx = batch['kitti_scene_idx'][frustum_idx]
            cuboid_pts = get_detection3d_corner_points(center[frustum_idx], dims[frustum_idx], rotation_y[frustum_idx])
            # NOTE(review): rotating by the negated frustum angle presumably maps the box back from the
            # rotated frustum frame to the original scene frame - confirm against FrustumDataset.
            frustum_rotation = FrustumDataset._rot_44_by_angle_world(-batch['frustum_rotation_angle'][frustum_idx].cpu())

            rotated_cuboid_pts = Projector.from_homogenous_coords(Projector.to_homogenous_coords(cuboid_pts) @ frustum_rotation.T)
            evaluated_data.setdefault(int(kitti_scene_idx), []).append((rotated_cuboid_pts, size_idx[frustum_idx]))


In [None]:
from kitti_dataset import g_class2type

conf = compute_confusion_mtx(frustum_test.kitti_dataset, evaluated_data, Projector())


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float('nan')


# Per-class report. A class without predictions (tp + fp == 0) or without ground truth
# (tp + fn == 0) prints `nan` for the affected metric instead of aborting the whole
# report with a division by zero.
for cls_idx, cls_name in g_class2type.items():
    cm = conf[cls_idx]
    precision = _safe_ratio(cm["tp"], cm["tp"] + cm["fp"])
    recall = _safe_ratio(cm["tp"], cm["tp"] + cm["fn"])
    print(f'{cls_name:15s}: tp={cm["tp"]:4d}, fp={cm["fp"]:4d}, fn={cm["fn"]:4d}, '
          f'precision={precision:.3f}, recall={recall:.3f}')

## 5. Визуализация того, что получилось
### Предсказания для одного фрустума в 3D

In [None]:
from plotly.offline import init_notebook_mode
from plotly.offline import iplot
from plotly.offline import plot
from plotly import graph_objs as go


def plotly_add_cloud(fig, cloud, color, colorscale=None, name=None, showscale=None, cmin=None, cmax=None):
    """Add a point cloud to `fig` as a 3D scatter trace drawn with markers.

    Args:
        fig: figure object providing `add_scatter3d`.
        cloud: array whose last axis holds at least three entries; entries 0, 1
            and 2 are used as x, y and z.
        color: marker colour (a single colour or per-point values).
        colorscale, showscale, cmin, cmax: passed through to the marker settings.
        name: trace name.
    """
    marker_style = dict(size=2, color=color, colorscale=colorscale, showscale=showscale, cmin=cmin, cmax=cmax)
    xs, ys, zs = (cloud[..., axis] for axis in range(3))
    fig.add_scatter3d(x=xs, y=ys, z=zs, name=name, mode='markers', marker=marker_style)

def plotly_add_cuboid(fig, points, color, opacity=0.4):
    """Add a box to `fig` as a flat-shaded 3D mesh.

    Args:
        fig: figure object providing `add_mesh3d`.
        points: 2-D array with the 8 box vertices as rows and x, y, z in
            columns 0, 1, 2.
        color: mesh colour.
        opacity: mesh opacity.
    """
    # Vertex-index triples of the 12 triangles (two per box face) that make up the mesh.
    triangles = [
        (6, 2, 0), (0, 4, 6), (0, 1, 3), (0, 3, 2),
        (4, 5, 7), (4, 7, 6), (7, 5, 1), (7, 3, 1),
        (4, 0, 5), (0, 1, 5), (2, 7, 6), (3, 2, 7),
    ]
    # The mesh API expects the first, second and third corner indices as three parallel lists.
    first, second, third = (list(column) for column in zip(*triangles))
    fig.add_mesh3d(
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        i=first,
        j=second,
        k=third,
        opacity=opacity,
        color=color,
        flatshading=True,
    )



In [None]:
# NOTE: this deliberately samples from the *validation* frustums and rebinds `test_loader`,
# which an earlier cell pointed at the test set. Re-create the test loader before
# re-running the metrics cell.
test_loader = data.DataLoader(frustum_val, 1, shuffle=True, drop_last=True, pin_memory=True)

# Take one random single-frustum batch and move all of its tensors to the GPU.
with torch.no_grad():
    test_batch = next(iter(test_loader))
    test_batch = {x: y.cuda() for x, y in test_batch.items()}


In [None]:
# Run the full pipeline (segmentation -> object cloud -> detection) on the sampled batch.
with torch.no_grad():
    segmentation_pointnet.eval()
    segmentation_output = segmentation_pointnet(test_batch['cloud'].float())
    # The per-sample centre offset is discarded here, so everything below stays in the
    # centred object-cloud frame.
    object_batch, *_ = cloud_to_object_cloud(test_batch, segmentation_output[0])
    detection_pointnet.eval()
    detection_output = detection_pointnet(object_batch['object_cloud'].float())

# Ground truth, packed like a network output so the same decoder can be reused.
# NOTE(review): 8 and 12 presumably mirror NUM_SIZE_CLUSTER / NUM_HEADING_BIN - confirm.
center, dims, rotation_y = detection_output_to_center_dims_rot((
    object_batch['world_location'],
    torch.nn.functional.one_hot(object_batch['size_idx'], 8),
    object_batch['size_residual'],
    torch.nn.functional.one_hot(object_batch['heading_idx'], 12),
    object_batch['heading_residual'],
    None))
gt_cuboid_pts = get_detection3d_corner_points(center[0], dims[0], rotation_y[0])

# Prediction; this rebinds center / dims / rotation_y, so the ground-truth corners
# had to be computed above before this line.
center, dims, rotation_y = detection_output_to_center_dims_rot(detection_output)

# Black canvas without axis backgrounds or grid lines; 'data' aspect mode keeps the
# axis proportions tied to the data ranges.
fig = go.Figure(layout=dict(scene=dict(
        aspectmode='data',
        xaxis=dict(showbackground=False, gridcolor="rgba(0, 0, 0, 0)"),
        yaxis=dict(showbackground=False, gridcolor="rgba(0, 0, 0, 0)"),
        zaxis=dict(showbackground=False, gridcolor="rgba(0, 0, 0, 0)"),
    ), plot_bgcolor='#000', paper_bgcolor='black'))
pred_cuboid_pts = get_detection3d_corner_points(center[0], dims[0], rotation_y[0])


# Prediction in red, ground truth in green.
plotly_add_cuboid(fig, pred_cuboid_pts, 'red', 0.3)
plotly_add_cuboid(fig, gt_cuboid_pts, 'green', 0.3)

object_cloud = object_batch['object_cloud'].detach().cpu().numpy()

# Points of the first sample, coloured by their last channel (presumably lidar intensity).
plotly_add_cloud(fig, object_cloud[0], object_cloud[0, ..., -1], colorscale='Reds', cmin=0, cmax=1, name='Velodyne')
iplot(fig)

### Предсказания для сцены

In [None]:
# Black canvas without axis backgrounds or grid lines.
fig = go.Figure(layout=dict(scene=dict(
        aspectmode='data',
        xaxis=dict(showbackground=False, gridcolor="rgba(0, 0, 0, 0)"),
        yaxis=dict(showbackground=False, gridcolor="rgba(0, 0, 0, 0)"),
        zaxis=dict(showbackground=False, gridcolor="rgba(0, 0, 0, 0)"),
    ), plot_bgcolor='#000', paper_bgcolor='black', height=1000))

kitti_idx = 15
# Look the scene up once and reuse it, instead of indexing the dataset again for the
# cloud, the labels, the image and the calibration of every single label.
kitti_scene = frustum_test.kitti_dataset[kitti_idx]
cloud = kitti_scene['cloud']

# Whole lidar cloud coloured by its last channel, plus every predicted box of this scene in red.
plotly_add_cloud(fig, cloud, cloud[..., -1], colorscale='Blues', cmin=0, cmax=1, name='Velodyne')
for obj_cuboid, obj_cls in evaluated_data.get(kitti_idx, []):
    plotly_add_cuboid(fig, obj_cuboid, 'red', 0.3)

# Ground-truth boxes: 'DontCare' labels are drawn in yellow, all others in green.
for label in kitti_scene['labels']:
    # Labels whose first dimension is negative carry no usable box and are skipped.
    if label['dimensions'][0] < 0:
        continue
    world_location = Projector.from_homogenous_coords(
        Projector.camera_to_world(np.array(label['location']), kitti_scene['calibration']))
    # NOTE(review): `pi / 2 - rotation_y` presumably converts the label's camera-frame yaw
    # into the heading convention of the lidar frame - confirm.
    cuboid_pts = get_detection3d_corner_points(
        torch.tensor(world_location),
        torch.tensor(label['dimensions']),
        np.pi / 2 - label['rotation_y'])
    plotly_add_cuboid(fig, cuboid_pts, 'yellow' if label['type'] == 'DontCare' else 'green', 0.3)

# Camera image of the same scene; the first axis is moved last for imshow.
plt.figure(figsize=(12, 9))
plt.imshow(torch.moveaxis(kitti_scene['image'], 0, -1))
plt.show()

iplot(fig)

In [None]:
# Display the raw label records of the scene selected by `kitti_idx` (shown as the cell output).
frustum_test.kitti_dataset[kitti_idx]['labels']