In [1]:
!pip install albumentations
!pip install prettytable



In [2]:
import sys
import os
sys.path.append('/irad_users/smithk/beholder-interns/')

import torchvision
import torch
import numpy as np
from determined.pytorch import DataLoader
from prettytable import PrettyTable
from dataset import NuScenes, KITTI
from intern_dataset import InternData
from distresnet import DistResNet18
from mlpnet import DistanceRegressor
from distnet import DistResNeXt50

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from albumentations import (HorizontalFlip,
                            HueSaturationValue,
                            RandomBrightnessContrast,
                            Blur,
                            GaussNoise,
                            CLAHE,
                            CoarseDropout,
                            RGBShift,
                            BboxParams)

In [4]:
transforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                                          torchvision.transforms.Normalize(
                                                              mean=[0.485, 0.456, 0.406],
                                                              std=[0.229, 0.224, 0.225]
                                                          )])

In [5]:
test_data = NuScenes(img_dir='/irad_mounts/lambda-quad-5-data/beholder/intern_data/nuscenes-full/',
                                  meta_path='/irad_mounts/lambda-quad-5-data/beholder/intern_data/nuscenes-full/nuscenes-v1.0.csv',
                                  split = 'test',
                                  augs = None,
                                  transforms = transforms,
                                  size = 1024,
                                  map_to_kitti=True
                                  )


testloader = DataLoader(test_data,
                        batch_size=1,
                        drop_last=True,
                        shuffle=False,
                        collate_fn=test_data.collate_fn)

In [6]:
test_data_mlp = NuScenes(img_dir='/irad_mounts/lambda-quad-5-data/beholder/intern_data/nuscenes-full/',
                                  meta_path='/irad_mounts/lambda-quad-5-data/beholder/intern_data/nuscenes-full/nuscenes-v1.0.csv',
                                  split = 'test',
                                  augs = None,
                                  transforms = transforms,
                                  size = 1024,
                                  map_to_kitti=True,
                                  mlp = True
                                  )
testloader_mlp = DataLoader(test_data_mlp,
                        batch_size=1,
                        drop_last=True,
                        shuffle=False,
                        collate_fn=test_data.collate_fn)

In [9]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [37]:
models = {}

In [39]:
resnext = DistResNeXt50(8, image_size=1024, pretrained=False, keypoints=False)
model_pth = "/irad_users/determined/checkpoints/3300d9d6-f061-419c-9b92-84ee6f6d7253/state_dict.pth"
resnext.load_state_dict(torch.load(model_pth)['models_state_dict'][0], strict=False)
resnext.to(device)
models['resnext'] = {'name': 'resnext', 'model': resnext}

In [40]:
resnet = DistResNet18(8, image_size=1024, pretrained=False,keypoints=False)
model_pth = "/irad_users/determined/checkpoints/90dab9a5-24cb-4ba2-8e49-db24e8f12d7c/state_dict.pth"
resnet.load_state_dict(torch.load(model_pth)['models_state_dict'][0], strict=False)
resnet.to(device)
models['resnet'] = {'name': 'resnet', 'model': resnet}

In [41]:
mlp = DistanceRegressor(n_features= 10)
model_pth = "/irad_users/determined/checkpoints/90dab9a5-24cb-4ba2-8e49-db24e8f12d7c/state_dict.pth"
mlp.load_state_dict(torch.load(model_pth)['models_state_dict'][0], strict=False)
mlp.to(device)
models['mlp'] = {'name': 'mlp', 'model': mlp}

In [42]:
# source: https://deci.ai/blog/measure-inference-time-deep-neural-networks/
starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
for model_dict in models:

    model = models[model_dict]['model']
    num_res = 300
    timings=np.zeros((num_res,1))
    
    if model_dict == 'mlp':
        data = testloader_mlp
    else:
        data = testloader

    for i,batch in enumerate(data):
        # MEASURE PERFORMANCE
        inputs,boxes,distances,classes = batch[0],batch[1],batch[2],batch[3]
        inputs = inputs.to(device)
        boxes = [b.to(device) for b in boxes]
        classes = [c.to(device) for c in classes]

        # GPU warmup 
        if i < 5: 
            if model_dict != 'mlp':
                _ = model(inputs,boxes)
            else:
                inputs = torch.cat([torch.cat([bbox, class_encoding], dim=1) for bbox, class_encoding in zip(boxes, classes)])
                _ = model(inputs)
        with torch.no_grad():
            starter.record()
            if model_dict != 'mlp':
                _ = model(inputs,boxes)
            else:
                inputs = torch.cat([torch.cat([bbox, class_encoding], dim=1) for bbox, class_encoding in zip(boxes, classes)])
                _ = model(inputs)
            ender.record()
            # WAIT FOR GPU SYNC
            torch.cuda.synchronize()
            curr_time = starter.elapsed_time(ender)
            timings[i] = curr_time
            if i >= num_res - 1:
                break
    mean_syn = np.sum(timings) / num_res
    std_syn = np.std(timings)
    models[model_dict]['time']=mean_syn
    models[model_dict]['std']=std_syn
    models[model_dict]['num_features'] = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())

In [47]:
tab = PrettyTable(['Model','Time (ms)','Std Deviation', 'Num Features'])
table = []
for model in models:
    table.append([models[model]['name'],models[model]['time'], models[model]['std'],models[model]['num_features']])

In [48]:
tab.add_rows(table)
print(tab)

+---------+---------------------+----------------------+--------------+
|  Model  |      Time (ms)      |    Std Deviation     | Num Features |
+---------+---------------------+----------------------+--------------+
| resnext |  27.351108601888022 | 0.30467881008816694  |   23644489   |
|  resnet |   8.14128205458323  | 0.41680646480497224  |   11312201   |
|   mlp   | 0.42664693256219227 | 0.027961753210412506 |   2792577    |
+---------+---------------------+----------------------+--------------+
