In [15]:
import os
import random
import math

import numpy as np
import pandas as pd
import pdb
from collections import OrderedDict
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import collections  as mc
matplotlib.rcParams['figure.figsize'] = [6, 6]
matplotlib.rcParams['figure.dpi'] = 200

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms

from data_helper import UnlabeledDataset, LabeledDataset
from helper import draw_box

# random.seed(0)
# np.random.seed(0)
# torch.manual_seed(0);

# All the images are saved in image_folder.
# All the labels are saved in the annotation_csv file.
# Both default to the original cluster paths; set ROADMAP_IMAGE_FOLDER /
# ROADMAP_ANNOTATION_CSV to run the notebook elsewhere without editing it.
image_folder = os.environ.get('ROADMAP_IMAGE_FOLDER', '/scratch/brs426/data')
annotation_csv = os.environ.get('ROADMAP_ANNOTATION_CSV',
                                os.path.join(image_folder, 'annotation.csv'))

# You shouldn't change the unlabeled_scene_index.
# The first 106 scenes (0-105) are unlabeled.
unlabeled_scene_index = np.arange(106)
# Scenes 106-133 are labeled; divide them into training, validation and test subsets.
train_labeled_scene_index = np.arange(106, 128)
val_labeled_scene_index = np.arange(128, 132)
test_labeled_scene_index = np.arange(132, 134)
test_labeled_scene_index = np.arange(132, 134)

In [16]:
def round_up(x, base=50):
    """Round ``x`` up to the nearest multiple of ``base`` (default: one 50 px grid cell)."""
    return int(math.ceil(x / float(base))) * base


def round_down(x, base=50):
    """Round ``x`` down to the nearest multiple of ``base`` (the grid-cell origin).

    The previous ``round_up(x) - 50`` form was one cell too low whenever ``x``
    was already a multiple of 50 (100 -> 50, and 0 -> -50, which is not a key
    of ``class_dict``).
    """
    return int(math.floor(x / float(base))) * base

In [17]:
# Row-major enumeration of the 16x16 grid of 50-pixel cells: the cell whose
# origin is (i, j) gets class id (i // 50) * 16 + (j // 50).
reverse_class_dict = [(i, j) for i in range(0, 800, 50) for j in range(0, 800, 50)]
class_dict = {cell: label for label, cell in enumerate(reverse_class_dict)}
# Number of grid-cell classes (the final value of the original counter).
class_label = len(reverse_class_dict)

In [4]:
# Inspect the (x, y) grid-cell origin -> class id mapping built in the previous cell.
class_dict

{(0, 0): 0,
 (0, 50): 1,
 (0, 100): 2,
 (0, 150): 3,
 (0, 200): 4,
 (0, 250): 5,
 (0, 300): 6,
 (0, 350): 7,
 (0, 400): 8,
 (0, 450): 9,
 (0, 500): 10,
 (0, 550): 11,
 (0, 600): 12,
 (0, 650): 13,
 (0, 700): 14,
 (0, 750): 15,
 (50, 0): 16,
 (50, 50): 17,
 (50, 100): 18,
 (50, 150): 19,
 (50, 200): 20,
 (50, 250): 21,
 (50, 300): 22,
 (50, 350): 23,
 (50, 400): 24,
 (50, 450): 25,
 (50, 500): 26,
 (50, 550): 27,
 (50, 600): 28,
 (50, 650): 29,
 (50, 700): 30,
 (50, 750): 31,
 (100, 0): 32,
 (100, 50): 33,
 (100, 100): 34,
 (100, 150): 35,
 (100, 200): 36,
 (100, 250): 37,
 (100, 300): 38,
 (100, 350): 39,
 (100, 400): 40,
 (100, 450): 41,
 (100, 500): 42,
 (100, 550): 43,
 (100, 600): 44,
 (100, 650): 45,
 (100, 700): 46,
 (100, 750): 47,
 (150, 0): 48,
 (150, 50): 49,
 (150, 100): 50,
 (150, 150): 51,
 (150, 200): 52,
 (150, 250): 53,
 (150, 300): 54,
 (150, 350): 55,
 (150, 400): 56,
 (150, 450): 57,
 (150, 500): 58,
 (150, 550): 59,
 (150, 600): 60,
 (150, 650): 61,
 (150, 700): 62,

In [5]:
# Inverse lookup: class id 7 -> its (x, y) grid-cell origin.
reverse_class_dict[7]

(0, 350)

In [29]:
def collate_fn(batch):
    """Collate ``LabeledDataset`` samples into batched tensors for training.

    Each element of ``batch`` is read as ``sample[0]`` (the six camera images),
    ``sample[1]['bounding_box']`` (one [2, 4] corner tensor per object, in
    metres) and ``sample[2]`` (the road map). The road map is assumed to be an
    800x800 binary image, as everywhere else in this notebook -- TODO confirm.

    Returns a 6-tuple:
        images        -- stacked six-image tensors, one entry per sample
        target        -- (batch, len(class_dict)) multi-hot grid-cell occupancy
        road_maps     -- stacked raw road maps
        bbs           -- per sample, a list of (xs, ys) pixel-space box outlines
        target_counts -- (batch, 90) one-hot object counts, label-smoothed
        road_bins     -- (batch, (800 // 5) ** 2) per-block road labels
    """
    BLOCK_SIZE = 5
    MAP_SIZE = 800      # side of the road map, in pixels
    CELL_SIZE = 50      # side of one classification grid cell, in pixels
    MAX_COUNT = 90      # count classes 0 .. 89
    blocks_per_side = MAP_SIZE // BLOCK_SIZE

    images, target, road_maps, bbs, target_counts, road_bins = [], [], [], [], [], []

    for sample in batch:
        road_image = torch.as_tensor(sample[2])
        road_maps.append(road_image)

        # Sum every BLOCK_SIZE x BLOCK_SIZE block in one vectorised op (the old
        # per-block Python loop did 25,600 slices + .item() calls per sample).
        # Row-major flattening keeps the original (row block, column block) order.
        block_sums = (road_image[:MAP_SIZE, :MAP_SIZE].float()
                      .reshape(blocks_per_side, BLOCK_SIZE, blocks_per_side, BLOCK_SIZE)
                      .sum(dim=(1, 3)))
        # A block is road when more than half of its pixels are 1.
        road_bins.append((block_sums > (BLOCK_SIZE ** 2) / 2).float().flatten())

        # Collect the six camera images for this sample.
        images.append(torch.stack([torch.Tensor(sample[0][k]) for k in range(6)]))

        bins = np.zeros(len(class_dict))
        counts = np.zeros(MAX_COUNT)
        current_bbs = []
        for corners in sample[1]['bounding_box']:
            # Reorder corners so consecutive points trace the outline, then convert
            # metres to pixels (x10, centre at pixel (400, 400), y axis reversed).
            point_sequence = torch.stack([corners[:, 0], corners[:, 1], corners[:, 3], corners[:, 2]])
            xs = point_sequence.T[0] * 10 + 400
            ys = -point_sequence.T[1] * 10 + 400
            current_bbs.append((xs, ys))

            # Mark the grid cell containing the box centre. Centres outside the
            # map are clamped to the nearest border cell instead of crashing
            # with a KeyError.
            center_x, center_y = torch.mean(xs).item(), torch.mean(ys).item()
            key = (min(max(round_down(center_x), 0), MAP_SIZE - CELL_SIZE),
                   min(max(round_down(center_y), 0), MAP_SIZE - CELL_SIZE))
            bins[class_dict[key]] = 1

        # One-hot object count; counts beyond the last class share that class
        # instead of raising IndexError.
        count = min(len(current_bbs), MAX_COUNT - 1)
        counts[count] = 1
        # Label smoothing on the neighbouring counts.
        if 10 < count < MAX_COUNT - 2:
            counts[count + 1] = 0.2
            counts[count - 1] = 0.2

        target.append(torch.Tensor(bins))
        target_counts.append(torch.Tensor(counts))
        bbs.append(current_bbs)

    return (torch.stack(images), torch.stack(target), torch.stack(road_maps),
            bbs, torch.stack(target_counts), torch.stack(road_bins))

In [30]:
# The labeled dataset can only be retrieved by sample, and every returned item
# is a tuple of tensors because the bounding boxes differ in size per sample.
# extra_info is optional; it is requested here for both splits.
val_transform = transforms.ToTensor()

# Training augmentation: randomly (RandomApply with its default probability)
# colour-jitter the image and turn it into 3-channel grayscale, then convert it
# to a tensor.
train_transform = transforms.Compose([
    transforms.RandomApply([
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.4, hue=(-0.5, 0.5)),
        transforms.Grayscale(3),
        # transforms.RandomAffine(3),
    ]),
    transforms.ToTensor(),
])

_shared_dataset_kwargs = dict(image_folder=image_folder,
                              annotation_file=annotation_csv,
                              extra_info=True)
labeled_trainset = LabeledDataset(scene_index=train_labeled_scene_index,
                                  transform=train_transform,
                                  **_shared_dataset_kwargs)
labeled_valset = LabeledDataset(scene_index=val_labeled_scene_index,
                                transform=val_transform,
                                **_shared_dataset_kwargs)

_shared_loader_kwargs = dict(batch_size=16, num_workers=2, shuffle=True, collate_fn=collate_fn)
train_loader = torch.utils.data.DataLoader(labeled_trainset, **_shared_loader_kwargs)
val_loader = torch.utils.data.DataLoader(labeled_valset, **_shared_loader_kwargs)



#### Testing

In [None]:
# Pull one training batch for inspection. collate_fn returns six items, and
# DataLoader iterators are advanced with the built-in next().
sample, target, road_img, bbs, counts, road_bin_targets = next(iter(train_loader))

In [None]:
# Reset the sample cursor; the next cell advances it, so re-running that cell steps through the batch.
idx = -1

In [None]:
# Advance to the next sample in the batch (not idempotent: each re-run moves one sample forward).
idx += 1

In [None]:
# Object count for sample `idx`: argmax of its (label-smoothed) one-hot count target.
print(torch.argmax(counts[idx]))

In [None]:
# Show the six camera images of sample `idx` as a 3-column grid (CHW -> HWC for matplotlib).
plt.imshow(torchvision.utils.make_grid(sample[idx], nrow=3).numpy().transpose(1, 2, 0))
plt.axis('off');

In [None]:
fig, ax = plt.subplots()
ax.imshow(road_img[idx], cmap='binary')
ax.plot(400, 400, 'x', color="red")  # map centre, pixel (400, 400)


def append_first_to_last(tens):
    """Return ``tens`` with its first element appended, closing a polygon outline."""
    return torch.cat((tens, tens[:1]))


# `target` has one entry per 50x50 grid cell (see class_dict); outline every
# occupied cell in green. The outline is drawn inline so this cell does not
# depend on which of the notebook's `draw_box` definitions is currently bound.
for bin_id in (target[idx] == 1).nonzero().flatten().tolist():
    cell_x, cell_y = reverse_class_dict[bin_id]
    ax.plot([cell_x, cell_x, cell_x + 50, cell_x + 50, cell_x],
            [cell_y, cell_y + 50, cell_y + 50, cell_y, cell_y],
            color='green')

# Ground-truth boxes, already converted to pixel coordinates by collate_fn.
for xs, ys in bbs[idx]:
    ax.plot(append_first_to_last(xs), append_first_to_last(ys), color='orange')
    


In [None]:
def draw_box(ax, class_box, color):
    """Outline the 50x50-pixel grid cell whose origin is ``class_box`` = (x, y).

    NOTE(review): this shadows ``helper.draw_box`` imported at the top of the
    notebook, and a later cell redefines ``draw_box`` with a corner-tensor
    signature -- which one a cell sees depends on execution order.
    """
    dx = (0, 0, 50, 50, 0)
    dy = (0, 50, 50, 0, 0)
    ax.plot([class_box[0] + d for d in dx],
            [class_box[1] + d for d in dy],
            color=color)

### model

In [20]:
class SimpleModel(nn.Module):
    """Multi-task model over the six camera images of each sample.

    Every image is encoded by a ResNet-50 trunk whose ``fc`` is replaced by
    Identity (2048 features), compressed to ``feature_dim`` features, and the
    ``num_images`` per-sample vectors are concatenated. Three heads read the
    concatenation:

    * ``classification`` -- ``num_bins`` logits, one per 50x50 grid cell; this
      must match the width of the ``target`` built by ``collate_fn`` (256).
      The original head emitted 200 logits, so the bin loss could not be
      computed against the 256-wide targets.
    * ``counts``         -- ``num_counts`` logits over the object count.
    * ``segmentation``   -- ``num_road_blocks`` sigmoid probabilities, one per
      5x5 road-map block.

    Module and layer names are unchanged, so state_dict keys stay the same.
    """

    def __init__(self, num_bins=256, num_counts=90, num_road_blocks=25600,
                 feature_dim=200, num_images=6):
        super(SimpleModel, self).__init__()

        self.encoder = torchvision.models.resnet50()
        self.encoder.fc = nn.Identity()
        self.concat_dim = feature_dim * num_images

        self.compress = nn.Sequential(OrderedDict([
            ('linear0', nn.Linear(2048, feature_dim)),
            ('drop', nn.Dropout(p=0.5)),
            ('relu', nn.ReLU()),
        ]))

        self.classification = nn.Sequential(OrderedDict([
            ('linear1', nn.Linear(self.concat_dim, num_bins)),
        ]))

        self.counts = nn.Sequential(OrderedDict([
            ('count1', nn.Linear(self.concat_dim, num_counts)),
        ]))

        self.segmentation = nn.Sequential(OrderedDict([
            ('linear1_segmentation', nn.Linear(self.concat_dim, num_road_blocks)),
            ('sigmoid', nn.Sigmoid()),
        ]))

    def forward(self, x):
        """Run the model on a batch of samples.

        Args:
            x: tensor of shape (batch, num_images, channels, height, width).

        Returns:
            (bin_logits, count_logits, road_block_probabilities), each with a
            leading batch dimension.
        """
        channels, height, width = x.shape[2:]
        # Encode every image of every sample in one pass through the trunk.
        x = x.reshape(-1, channels, height, width)
        x = self.encoder(x)
        x = self.compress(x)
        # Concatenate the num_images feature vectors belonging to each sample.
        x = x.reshape(-1, self.concat_dim)
        return self.classification(x), self.counts(x), self.segmentation(x)

In [21]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleModel()

# Weighting certain classes more.
# positive_weight = torch.ones(256).to(device)
# for (x, y), bin_id in class_dict.items():
#     if abs(x - 400) <= 200 and abs(y - 400) <= 200:
#         positive_weight[bin_id] = 2

# for name, param in model.encoder.named_parameters():
#     if("bn" not in name):
#         param.requires_grad = False

# unfreeze_layers = [model.encoder.layer3, model.encoder.layer4]
# for layer in unfreeze_layers:
#     for param in layer.parameters():
#         param.requires_grad = True

model = model.to(device)
bin_criterion = nn.BCEWithLogitsLoss()      # multi-hot grid-cell logits
count_criterion = nn.BCEWithLogitsLoss()    # label-smoothed count logits
segmentation_criterion = nn.BCELoss()       # the segmentation head already applies a sigmoid
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
# Best validation loss so far; val() saves a checkpoint whenever it improves.
# Start at +inf so the first epoch always checkpoints (a fixed 100 would
# silently never save if the summed loss stayed above it).
best_val_loss = float('inf')

In [22]:
def mixup_data(x, y, alpha=0.2, use_cuda=True):
    '''Returns mixed inputs, pairs of targets, and lambda.

    Draws lam ~ Beta(alpha, alpha) (lam = 1, i.e. no mixing, when alpha <= 0)
    and blends every example with a randomly permuted partner from the batch.

    Args:
        x: batch of inputs; the first dimension is the batch.
        y: targets aligned with ``x`` along the first dimension.
        alpha: Beta-distribution concentration.
        use_cuda: kept for backward compatibility only. The permutation is now
            created on ``x.device`` rather than on the notebook-global
            ``device``, so the function works wherever ``x`` lives.

    Returns:
        (mixed_x, y_a, y_b, lam) where y_a = y and y_b = y permuted like x.
    '''
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    index = torch.randperm(x.size(0), device=x.device)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    '''Mixup loss: the lam-weighted blend of ``criterion`` against both target sets.'''
    return lam * criterion(pred, y_a.float()) + (1 - lam) * criterion(pred, y_b.float())

In [None]:
def train():
    """Train the global ``model`` for one epoch on the labeled training scenes.

    Rebuilds the training dataset and a batch-size-10 loader on every call,
    optimises ``bin_loss + 2 * count_loss + segmentation_loss`` with the global
    ``optimizer``, prints progress every 50 batches and the per-component epoch
    averages at the end. Reads the global ``epoch`` only for the progress line.
    """
    model.train()
    labeled_trainset = LabeledDataset(image_folder=image_folder,
                                      annotation_file=annotation_csv,
                                      scene_index=train_labeled_scene_index,
                                      transform=train_transform,
                                      extra_info=True
                                      )
    train_loader = torch.utils.data.DataLoader(labeled_trainset, batch_size=10, num_workers=3,
                                               shuffle=True, collate_fn=collate_fn)

    train_losses = []
    bin_losses = []
    count_losses = []
    segmentation_losses = []
    for i, (sample, target, road_img, bbs, target_count, road_bins) in enumerate(train_loader):
        optimizer.zero_grad()

        sample = sample.to(device)
        target = target.to(device)
        target_count = target_count.to(device)
        road_bins = road_bins.to(device)

        # Mixup (currently disabled):
        # sample, target_a, target_b, lam = mixup_data(sample, target)
        # sample, target_a, target_b = map(torch.autograd.Variable, (sample, target_a, target_b))

        # The model consumes the whole (batch, 6, C, H, W) tensor at once.
        y_hat, y_count, segmentation = model(sample)

        # With mixup enabled, compute bin_loss as:
        # bin_loss = mixup_criterion(bin_criterion, y_hat, target_a, target_b, lam)
        bin_loss = bin_criterion(y_hat, target.float())
        count_loss = count_criterion(y_count, target_count.float())
        segmentation_loss = segmentation_criterion(segmentation, road_bins.float())
        loss = bin_loss + 2 * count_loss + segmentation_loss

        train_losses.append(loss.item())
        bin_losses.append(bin_loss.item())
        count_losses.append(count_loss.item())
        segmentation_losses.append(segmentation_loss.item())

        loss.backward()
        optimizer.step()

        if i % 50 == 0:
            # Percentage of this epoch's batches processed (previously scaled by
            # 50, which reported half the real progress).
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, i * len(sample), len(train_loader.dataset),
                100. * i / len(train_loader), loss.item()))

    print("\nAverage Train Epoch Loss: ", np.mean(train_losses))
    print("Average Train Bin Epoch Loss: ", np.mean(bin_losses))
    print("Average Train Count Epoch Loss: ", np.mean(count_losses))
    print("Average Train Segmentation Epoch Loss: ", np.mean(segmentation_losses))
            
def val():
    """Evaluate the global ``model`` on ``val_loader`` and checkpoint on improvement.

    Prints the average total, bin, count and segmentation losses, and saves
    ``model.state_dict()`` whenever the average total beats the global
    ``best_val_loss``.

    NOTE(review): the total here is bin + count + segmentation, whereas
    train() optimises bin + 2 * count + segmentation, so the printed train and
    validation totals are not directly comparable. Left unchanged because it is
    the checkpoint-selection criterion -- confirm which weighting is intended.
    """
    global best_val_loss
    model.eval()
    val_losses = []
    bin_losses = []
    count_losses = []
    segmentation_losses = []
    # Used only by the disabled accuracy metrics below.
    count_correct = 0
    count_off_by_1 = 0
    total_count = 0
    bin_correct = 0
    total_bins = 0

    with torch.no_grad():
        # collate_fn yields six items; the previous five-name unpack raised
        # ValueError, and road_bins was never bound in this function.
        for sample, target, road_img, bbs, target_count, road_bins in val_loader:
            sample = sample.to(device)
            target = target.to(device)
            target_count = target_count.to(device)
            road_bins = road_bins.to(device)

            y_hat, y_count, segmentation = model(sample)

            # Disabled count / bin accuracy metrics:
            # total_count += y_count.size(0)
            # for j in range(y_count.size(0)):
            #     pred_count = torch.argmax(y_count[j])
            #     t_count = torch.argmax(target_count[j])
            #     if pred_count == t_count:
            #         count_correct += 1
            #     elif abs(pred_count - t_count) == 1:
            #         count_off_by_1 += 1
            #     pred_bins = torch.topk(y_hat[j], k=pred_count).indices
            #     t_bins = (target[j] == 1).nonzero()
            #     for b in pred_bins:
            #         if b in t_bins:
            #             bin_correct += 1
            #     total_bins += len(t_bins)

            bin_loss = bin_criterion(y_hat, target.float())
            count_loss = count_criterion(y_count, target_count.float())
            segmentation_loss = segmentation_criterion(segmentation, road_bins.float())
            loss = bin_loss + count_loss + segmentation_loss

            val_losses.append(loss.item())
            bin_losses.append(bin_loss.item())
            count_losses.append(count_loss.item())
            segmentation_losses.append(segmentation_loss.item())

    epoch_val_loss = np.mean(val_losses)
    print("Average Validation Epoch Loss: ", epoch_val_loss)
    print("Average Validation Bin Epoch Loss: ", np.mean(bin_losses))
    print("Average Validation Count Epoch Loss: ", np.mean(count_losses))
    print("Average Validation Segmentation Epoch Loss: ", np.mean(segmentation_losses))
    # print("\tAverage Validation Count Accuracy: ", 100*count_correct/total_count)
    # print("\tAverage Validation Count-off-by-1 Accuracy: ", 100*count_off_by_1/total_count)
    # if total_bins != 0:
    #     print("\tAverage Validation Bin Accuracy: ", 100*bin_correct/total_bins)
    # print("\n")
    if epoch_val_loss < best_val_loss:
        best_val_loss = epoch_val_loss
        torch.save(model.state_dict(), '/scratch/brs426/all_six_images_classify_count.pt')

In [None]:
# Main training loop. train() reads the global `epoch` for its progress line;
# val() saves a checkpoint whenever the average validation loss improves.
epochs = 40
for epoch in range(epochs):
    train()
    val()

In [5]:
# Scratch check of how .view() regroups rows: the model flattens per-image
# features (batch * 6, 200) into (batch, 1200) the same way, e.g. 60 x 200 -> 10 x 1200.
# 60 x 200
# 10 x 1200
x = torch.randn((12, 5))

In [6]:
# Display the random 12 x 5 scratch tensor.
x

tensor([[-1.2217e-01, -7.3901e-01, -1.8286e+00, -7.0720e-01,  5.5493e-01],
        [ 8.8764e-01, -1.8905e-01, -7.1408e-02,  9.6601e-01, -2.5093e+00],
        [ 4.9115e-01, -1.9423e+00, -1.1985e+00,  1.1864e+00,  1.8880e+00],
        [-2.1469e-01,  2.9941e-01,  1.9331e+00,  4.4181e-01, -1.7342e+00],
        [-3.7422e-01,  1.4801e+00, -8.4294e-01, -2.2830e+00, -1.5767e-01],
        [-1.5516e-01,  7.6859e-01,  7.9272e-01, -9.6438e-01, -3.8129e-01],
        [-1.2397e+00, -8.6256e-01, -1.8095e+00,  9.0200e-01, -7.2423e-01],
        [-6.2345e-01, -1.9860e+00,  5.2594e-01, -6.8100e-01, -1.9903e-01],
        [ 1.8554e+00, -3.2209e-02,  1.5327e+00, -1.2397e+00,  1.1343e+00],
        [ 1.4491e+00, -8.2009e-01, -6.2748e-01,  1.9462e+00,  1.8221e-03],
        [-8.3885e-01,  1.1912e+00,  8.6688e-01, -3.4989e-01,  5.7890e-01],
        [-5.4305e-01, -7.8100e-01, -1.9309e+00, -8.4028e-01, -2.5602e-01]])

In [9]:
# Each output row concatenates 6 consecutive input rows (12 rows -> 2 rows of 30).
y = x.view(2, -1)
y

tensor([[-1.2217e-01, -7.3901e-01, -1.8286e+00, -7.0720e-01,  5.5493e-01,
          8.8764e-01, -1.8905e-01, -7.1408e-02,  9.6601e-01, -2.5093e+00,
          4.9115e-01, -1.9423e+00, -1.1985e+00,  1.1864e+00,  1.8880e+00,
         -2.1469e-01,  2.9941e-01,  1.9331e+00,  4.4181e-01, -1.7342e+00,
         -3.7422e-01,  1.4801e+00, -8.4294e-01, -2.2830e+00, -1.5767e-01,
         -1.5516e-01,  7.6859e-01,  7.9272e-01, -9.6438e-01, -3.8129e-01],
        [-1.2397e+00, -8.6256e-01, -1.8095e+00,  9.0200e-01, -7.2423e-01,
         -6.2345e-01, -1.9860e+00,  5.2594e-01, -6.8100e-01, -1.9903e-01,
          1.8554e+00, -3.2209e-02,  1.5327e+00, -1.2397e+00,  1.1343e+00,
          1.4491e+00, -8.2009e-01, -6.2748e-01,  1.9462e+00,  1.8221e-03,
         -8.3885e-01,  1.1912e+00,  8.6688e-01, -3.4989e-01,  5.7890e-01,
         -5.4305e-01, -7.8100e-01, -1.9309e+00, -8.4028e-01, -2.5602e-01]])

In [None]:
# 0.253 val bin loss

# Random Affine 3 degrees
# 0.263 val bin loss


# Need to do 5 * bin_loss + count_loss or something like that. Also more extreme Random Affine maybe?

# Random Affine 5 degrees
# 0.266 val bin loss

# Took out Random Affine. 
# 0.268 val bin loss

# Increased compress dim from 128 to 200. 
# 0.259 val bin loss

# 5 * bin_loss + count_loss
# 0.249 + 0.055

# 8 *
# 0.251 + 0.054

# 8*, RandomAffine(3)
# 0.255

# 8*, RandomAffine(3), weight_decay 0.1

# 10 *, RandomAffine(3)
# 0.259

# 8 *, Normalize (mean, std)
# 0.26

# 8 *, Dropout
# 0.241, 0.253

# 5 *, Dropout
# 0.254

# 11 *, Dropout
# 0.249

# Want to try positive-weights for classes within 200 to 600. 
# Want to get the model to get those classes correct. 

# Mixup 0.2, 1 *, Dropout
# (0.244, 0.053), 


## Testing Model Output

In [None]:
class SimpleModel(nn.Module):
    """Inference copy of the classification + count model.

    Uses the same module names as the training model, so its trunk, bin and
    count weights map onto this class by key. There is no segmentation head,
    and ``forward`` returns only ``(bin_logits, count_logits)``.
    """

    def __init__(self):
        super(SimpleModel, self).__init__()

        self.encoder = torchvision.models.resnet50()
        self.encoder.fc = nn.Identity()
        self.concat_dim = 200 * 6

        self.compress = nn.Sequential(OrderedDict([
            ('linear0', nn.Linear(2048, 200)),
            ('drop', nn.Dropout(p=0.5)),
            ('relu', nn.ReLU()),
        ]))

        self.classification = nn.Sequential(OrderedDict([
            ('linear1', nn.Linear(self.concat_dim, 256)),
        ]))

        self.counts = nn.Sequential(OrderedDict([
            ('count1', nn.Linear(self.concat_dim, 90)),
        ]))

    def forward(self, x):
        """Score one sample or a batch of samples.

        Args:
            x: either the six images of one sample, shape (6, C, H, W), as
               passed by get_bounding_boxes, or a batch (batch, 6, C, H, W) as
               the training model takes.

        Returns:
            (bin_logits, count_logits) of shapes (N, 256) and (N, 90), where N
            is 1 for a single sample and ``batch`` otherwise.
        """
        if x.dim() == 5:
            # Fold the per-sample image axis into the batch axis for the trunk.
            x = x.reshape(-1, *x.shape[2:])
        x = self.encoder(x)
        x = self.compress(x)
        # Concatenate the six per-image feature vectors of each sample.
        x = x.reshape(-1, self.concat_dim)
        return self.classification(x), self.counts(x)

In [None]:
# Load the checkpoint written by val(). The training model also carries a
# segmentation head that this inference class lacks, so a strict load would
# fail on those extra keys. Tolerate unexpected keys, but require every weight
# this model needs.
CHECKPOINT_PATH = '/scratch/brs426/all_six_images_classify_count.pt'
model = SimpleModel()
state_dict = torch.load(CHECKPOINT_PATH, map_location='cpu')
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
assert not missing_keys, 'checkpoint is missing weights: {}'.format(missing_keys)
if unexpected_keys:
    print('Ignoring checkpoint weights unused at inference:', unexpected_keys)
model.cuda()
model.eval()

In [None]:
# Rebuild the grid-cell lookup tables: class id k <-> cell origin
# (50 * (k // 16), 50 * (k % 16)) on the 16x16 grid of 50-pixel cells.
reverse_class_dict = [(50 * (k // 16), 50 * (k % 16)) for k in range(256)]
class_dict = dict(zip(reverse_class_dict, range(len(reverse_class_dict))))
# Number of grid-cell classes (the final value of the original counter).
class_label = len(reverse_class_dict)

In [None]:
def get_bounding_boxes(samples, device='cuda'):
    """Predict boxes for every sample in a batch.

    For each sample, the count head's argmax decides how many boxes to emit.
    That many top-scoring grid cells from the bin head each become one box,
    50 x 20 px (5 m x 2 m) in size and placed inside its 50x50-pixel cell.

    Args:
        samples: batch of six-image samples; the original notes give the shape
            as (batch_size, 6, 3, 256, 306).
        device: where the returned tensors are placed. Defaults to 'cuda',
            since the evaluation expects CUDA tensors.

    Returns:
        A tuple with one float tensor of shape [N, 2, 4] per sample. Row 0
        holds x and row 1 holds y corner coordinates, in metres. N is 0 when
        the predicted count is 0.
    """
    bb_samples = []

    with torch.no_grad():
        for x in samples:
            preds_class, preds_count = model(x)

            # Take as many of the highest-scoring cells as the predicted count.
            num_boxes = torch.argmax(preds_count).item()
            pred_ids = torch.topk(preds_class, k=num_boxes).indices

            bounding_boxes = []
            for bin_id in pred_ids[0].tolist():
                buck_x, buck_y = reverse_class_dict[bin_id]

                # Pixel corners of the box inside the cell. Float from the start:
                # in-place division of the old int64 tensors raised a RuntimeError.
                xs = torch.tensor([buck_x, buck_x, buck_x + 50, buck_x + 50], dtype=torch.float)
                ys = torch.tensor([buck_y + 16, buck_y + 36, buck_y + 16, buck_y + 36], dtype=torch.float)

                # Pixels -> metres: centre pixel (400, 400), 10 px per metre,
                # y axis flipped so it points up.
                xs = (xs - 400) / 10
                ys = (400 - ys) / 10

                bounding_boxes.append(torch.stack((xs, ys)))

            if bounding_boxes:
                boxes = torch.stack(bounding_boxes)
            else:
                # torch.stack([]) raises, so emit an explicit empty [0, 2, 4] tensor.
                boxes = torch.zeros((0, 2, 4))
            bb_samples.append(boxes.to(device))

    return tuple(bb_samples)
    

In [None]:
# Pull one validation batch for inspection. collate_fn returns six items, and
# DataLoader iterators are advanced with the built-in next().
sample, target, road_img, bbs, counts, road_bin_targets = next(iter(val_loader))

In [None]:
# Move the batch to the GPU, where the inference model was placed with .cuda().
sample = sample.cuda()

In [None]:
# Predicted boxes for the whole batch: a tuple with one [N, 2, 4] tensor per sample.
boom = get_bounding_boxes(sample)

In [None]:
# Per-cell occupancy probabilities for sample `idx` (the model's first output
# is the bin logits). no_grad avoids keeping the trunk's activations alive.
with torch.no_grad():
    sigmoid_preds = torch.sigmoid(model(sample[idx])[0]).squeeze()

In [None]:
# Reset the sample cursor; the next cell advances it, so re-running that cell steps through the batch.
idx = -1

In [None]:
# Advance to the next sample in the batch (not idempotent: each re-run moves one sample forward).
idx += 1

In [None]:
# Show the six camera images of sample `idx` (moved back to the CPU) as a 3-column grid.
plt.imshow(torchvision.utils.make_grid(sample[idx].cpu().detach(), nrow=3).numpy().transpose(1, 2, 0))
plt.axis('off');

In [None]:
fig, ax = plt.subplots()
ax.imshow(road_img[idx], cmap='binary')
ax.plot(400, 400, 'x', color="red")  # map centre, pixel (400, 400)


def _plot_grid_cell(ax, cell, color):
    """Outline the 50x50-pixel grid cell whose origin is ``cell`` = (x, y)."""
    x0, y0 = cell
    ax.plot([x0, x0, x0 + 50, x0 + 50, x0], [y0, y0 + 50, y0 + 50, y0, y0], color=color)


def _plot_corners(ax, corners, color):
    """Plot a [2, 4] corner tensor given in metres onto the road-map pixel axes."""
    loop = torch.stack([corners[:, 0], corners[:, 1], corners[:, 3], corners[:, 2], corners[:, 0]])
    ax.plot(loop.T[0] * 10 + 400, -loop.T[1] * 10 + 400, color=color)


# The helpers above are defined here so this cell does not depend on later
# cells (draw_vish_box / the corner-based draw_box) having been run first.

# Red: grid cells whose predicted probability exceeds 0.25.
for bin_id in (sigmoid_preds > 0.25).nonzero().flatten().tolist():
    _plot_grid_cell(ax, reverse_class_dict[bin_id], 'red')

# Green: ground-truth occupied grid cells from `target` (one entry per cell).
for bin_id in (target[idx] == 1).nonzero().flatten().tolist():
    _plot_grid_cell(ax, reverse_class_dict[bin_id], 'green')

# Orange: boxes returned by get_bounding_boxes, converted back to pixels.
for bb in boom[idx]:
    _plot_corners(ax, bb.cpu().detach(), 'orange')

In [None]:
def draw_box(ax, corners, color):
    """Plot a box given by its [2, 4] corner tensor onto the road-map pixel axes.

    ``corners[0]`` holds the x and ``corners[1]`` the y coordinates, in metres.
    Corners are visited in the order 0, 1, 3, 2, 0 so the outline closes.

    Returns:
        (xs, ys): the plotted pixel coordinates.
    """
    point_sequence = torch.stack([corners[:, 0], corners[:, 1], corners[:, 3], corners[:, 2], corners[:, 0]])

    # The corners are in metres, and multiplying by 10 converts them to pixels.
    # Add 400, since the centre of the image is at pixel (400, 400).
    # The negative sign is because the y axis is reversed for matplotlib.
    xs = point_sequence.T[0] * 10 + 400
    ys = -point_sequence.T[1] * 10 + 400
    ax.plot(xs, ys, color=color)
    return xs, ys

In [None]:
def draw_vish_box(ax, class_box, color):
    """Outline the 50x50-pixel grid cell whose origin is ``class_box`` = (x, y)."""
    x0, y0 = class_box[0], class_box[1]
    x1, y1 = x0 + 50, y0 + 50
    outline = [(x0, y0), (x0, y1), (x1, y1), (x1, y0), (x0, y0)]
    ax.plot([px for px, _ in outline], [py for _, py in outline], color=color)