In [3]:
import os
import random
import math

import numpy as np
import pandas as pd
import pdb
from collections import OrderedDict
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import collections  as mc
matplotlib.rcParams['figure.figsize'] = [6, 6]
matplotlib.rcParams['figure.dpi'] = 200

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms

from data_helper import UnlabeledDataset, LabeledDataset
from helper import draw_box

# Fix every RNG the notebook touches (python, numpy, torch) so anything
# stochastic -- e.g. DataLoader shuffling -- is reproducible across runs.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# All the images are saved in image_folder.
# All the labels are saved in the annotation_csv file inside it.
# Set ROADMAP_DATA_DIR to point at a different copy of the data.
image_folder = os.environ.get('ROADMAP_DATA_DIR', '/scratch/brs426/data')
annotation_csv = os.path.join(image_folder, 'annotation.csv')

# You shouldn't change the unlabeled_scene_index.
# The first 106 scenes (0-105) are unlabeled.
unlabeled_scene_index = np.arange(106)
# Scenes 106-133 are labeled; divide them into training and validation subsets.
train_labeled_scene_index = np.arange(106, 132)
val_labeled_scene_index = np.arange(132, 134)
# NOTE: the "test" split is the same two scenes as validation (132-133), so the
# threat score reported by this notebook is not independent of validation.
test_labeled_scene_index = np.arange(132, 134)

from helper import compute_ats_bounding_boxes, compute_ts_road_map

In [4]:
def round_up(x, base=50):
    """Round ``x`` up to the nearest multiple of ``base`` (default 50 px).

    Exact multiples map to themselves: ``round_up(100) == 100``,
    ``round_up(100.1) == 150``.
    """
    return int(math.ceil(x / float(base))) * base

def round_down(x, base=50):
    """Round ``x`` down to the nearest multiple of ``base`` (default 50 px).

    This is the lower edge of the ``base``-wide grid cell containing ``x``,
    e.g. ``round_down(100) == 100`` and ``round_down(149.9) == 100``.
    Using floor (rather than ``round_up(x) - base``) keeps exact multiples,
    including 0, in their own cell instead of shifting them one cell down.
    """
    return int(math.floor(x / float(base))) * base

In [7]:
class SimpleModel(nn.Module):
    """Six-camera, multi-head model.

    Every camera image goes through one shared ResNet-50 trunk (its ``fc`` is
    replaced by ``nn.Identity`` so the trunk yields 2048 features), is
    compressed to 180 features, and the six per-camera vectors of a sample are
    concatenated into a single ``180 * 6``-d vector that feeds five heads:

    * ``classification`` - 256 raw scores (one per cell of the 16x16 grid
      used by the evaluation cell).
    * ``counts`` - 90 raw scores over the number of objects.
    * ``segmentation`` - 25600 sigmoid outputs (160x160 road blocks, as built
      by ``collate_fn``).
    * ``x_offset`` / ``y_offset`` - 256 tanh outputs each, in [-1, 1].

    Attribute and layer names are the keys of the saved checkpoint, so they
    must not be renamed.
    """

    def __init__(self):
        super().__init__()
        per_camera_features = 180
        num_cameras = 6

        # Shared trunk; the Identity head exposes its 2048-d features.
        self.encoder = torchvision.models.resnet50()
        self.encoder.fc = nn.Identity()
        self.concat_dim = per_camera_features * num_cameras

        self.compress = nn.Sequential(OrderedDict([
            ('linear0', nn.Linear(2048, per_camera_features)),
            ('drop', nn.Dropout(p=0.5)),
            ('relu', nn.ReLU()),
        ]))

        # One raw score per grid cell (no activation in this head).
        self.classification = nn.Sequential(OrderedDict([
            ('linear1', nn.Linear(self.concat_dim, 256)),
        ]))

        # Per-cell centre offsets, squashed to [-1, 1] by tanh.
        self.x_offset = nn.Sequential(OrderedDict([
            ('xoff1', nn.Linear(self.concat_dim, 256)),
            ('tanh1', nn.Tanh()),
        ]))
        self.y_offset = nn.Sequential(OrderedDict([
            ('yoff1', nn.Linear(self.concat_dim, 256)),
            ('tanh2', nn.Tanh()),
        ]))

        self.counts = nn.Sequential(OrderedDict([
            ('count1', nn.Linear(self.concat_dim, 90)),
        ]))

        self.segmentation = nn.Sequential(OrderedDict([
            ('linear1_segmentation', nn.Linear(self.concat_dim, 25600)),
            ('sigmoid', nn.Sigmoid()),
        ]))

    def forward(self, x):
        """Run every head on a batch of six-image samples.

        ``x`` is indexed as (batch, images, channels, height, width). The image
        axis is folded into the batch for the shared encoder and unfolded again
        through ``concat_dim``, so each sample must carry 6 images.

        Returns ``(classification, counts, segmentation, x_offset, y_offset)``.
        """
        channels, height, width = x.shape[2], x.shape[3], x.shape[4]
        # Fold cameras into the batch so all images share the encoder.
        per_camera = self.compress(self.encoder(x.view(-1, channels, height, width)))
        # Unfold: concatenate each sample's per-camera feature vectors.
        fused = per_camera.view(-1, self.concat_dim)
        return (
            self.classification(fused),
            self.counts(fused),
            self.segmentation(fused),
            self.x_offset(fused),
            self.y_offset(fused),
        )

In [36]:
# Evaluate on CPU; map_location remaps the checkpoint's stored device to it.
device = torch.device("cpu")
CHECKPOINT_PATH = '/scratch/vr1059/all_six_images_classify_count_offset.pt'

model = SimpleModel()
model = model.to(device)
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
checkpoint_state = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(checkpoint_state)

<All keys matched successfully>

In [37]:
def collate_fn(batch):
    """Collate labelled samples into batched model inputs and targets.

    Each element ``x`` of ``batch`` is read as ``x[0]`` (six camera images),
    ``x[1]`` (annotation dict with a ``'bounding_box'`` entry of per-object
    (2, 4) corner tensors -- row 0 x, row 1 y, in metres; assumed from the
    indexing below) and ``x[2]`` (the road map, assumed 800x800 pixels).

    Returns a tuple
    ``(images, target, road_maps, bbs, target_counts, road_bins)``:

    * ``images`` - stacked six-image tensors, float32.
    * ``target`` - (B, 256) multi-hot over the 16x16 grid of 50-px cells
      (x-major id, ``gx * 16 + gy``) containing a box centre.
    * ``road_maps`` - the original road maps, stacked.
    * ``bbs`` - per sample, the list of ground-truth corner tensors.
    * ``target_counts`` - (B, 90) one-hot box count; counts of 90 or more
      are clipped into the last slot, and for 10 < count < 88 the two
      neighbouring counts get 0.2 (label smoothing).
    * ``road_bins`` - (B, 25600) float labels, one per 5x5 road block, in
      row-major block order.

    Box centres on or beyond the map edge are clamped into the nearest border
    cell instead of raising, so one edge-case object cannot abort a run.
    """
    BLOCK_SIZE = 5          # road-map pixels per side of a road block
    MAP_SIZE = 800          # road map / bird's-eye view side, in pixels
    GRID_CELL = 50          # pixels per side of a box-centre grid cell
    GRID_SIDE = MAP_SIZE // GRID_CELL   # 16 cells per side -> 256 classes
    MAX_COUNT = 90          # length of the count one-hot vector
    PX_PER_M = 10           # world metres -> pixels
    n_blocks = MAP_SIZE // BLOCK_SIZE

    images, target, road_maps, road_bins, bbs, target_counts = [], [], [], [], [], []
    for x in batch:
        # Road map: keep the original, and reduce it to one label per 5x5 block.
        road_image = torch.as_tensor(x[2])
        road_maps.append(road_image)
        block_sums = (road_image[:MAP_SIZE, :MAP_SIZE].float()
                      .reshape(n_blocks, BLOCK_SIZE, n_blocks, BLOCK_SIZE)
                      .sum(dim=(1, 3)))
        # A block is road when more than half of its pixels are road.
        road_bins.append((block_sums > (BLOCK_SIZE ** 2) / 2).float().flatten())

        # Six camera images for this sample.
        images.append(torch.stack(
            [torch.as_tensor(x[0][i], dtype=torch.float32) for i in range(6)]))

        current_bbs = []
        bins = np.zeros(GRID_SIDE * GRID_SIDE)
        for corners in x[1]['bounding_box']:
            current_bbs.append(corners)
            # World metres -> pixels (y flipped); centre = mean of the 4 corners.
            center_x = (corners[0] * PX_PER_M + MAP_SIZE / 2).mean().item()
            center_y = (-corners[1] * PX_PER_M + MAP_SIZE / 2).mean().item()
            # Grid cell containing the centre, clamped into the 16x16 grid.
            gx = min(max(int(center_x // GRID_CELL), 0), GRID_SIDE - 1)
            gy = min(max(int(center_y // GRID_CELL), 0), GRID_SIDE - 1)
            bins[gx * GRID_SIDE + gy] = 1

        count = len(current_bbs)
        counts = np.zeros(MAX_COUNT)
        counts[min(count, MAX_COUNT - 1)] = 1
        # Label smoothing on the neighbouring counts (kept inside the vector).
        if 10 < count < MAX_COUNT - 2:
            counts[count + 1] = 0.2
            counts[count - 1] = 0.2

        target_counts.append(torch.tensor(counts, dtype=torch.float32))
        target.append(torch.tensor(bins, dtype=torch.float32))
        bbs.append(current_bbs)

    return (torch.stack(images), torch.stack(target), torch.stack(road_maps),
            bbs, torch.stack(target_counts), torch.stack(road_bins))

In [38]:
# Evaluation data: only ToTensor, no augmentation.
test_transform = transforms.ToTensor()
labeled_testset = LabeledDataset(image_folder=image_folder,
                                  annotation_file=annotation_csv,
                                  scene_index=test_labeled_scene_index,
                                  transform=test_transform,
                                  extra_info=True
                                 )

# batch_size=1: the evaluation loop reads element [0] of every output.
# No shuffling: the average threat score does not depend on sample order, and a
# fixed order keeps runs reproducible and per-sample results comparable.
test_loader = torch.utils.data.DataLoader(labeled_testset, batch_size=1, shuffle=False, collate_fn=collate_fn)

In [39]:

# 16x16 grid of 50-px cells covering the 800x800-px bird's-eye view.
# reverse_class_dict[class_id] -> (x_px, y_px) of the cell's minimum corner, in
# x-major order; class_dict is the inverse mapping. Both stay module-level so
# other cells can share the same class numbering.
reverse_class_dict = [(i, j) for i in range(0, 800, 50) for j in range(0, 800, 50)]
class_dict = {cell: label for label, cell in enumerate(reverse_class_dict)}


def decode_bounding_boxes(pred_ids, x_offset, y_offset):
    """Turn predicted grid-cell ids into fixed-size boxes in world coordinates.

    Args:
        pred_ids: 1-D tensor of class ids (indices into ``reverse_class_dict``).
        x_offset, y_offset: 1-D per-cell tanh outputs in [-1, 1] for one sample.

    Returns:
        float64 tensor of shape (num_boxes, 2, 4) -- row 0 x, row 1 y -- or an
        empty (0, 2, 4) tensor when ``pred_ids`` is empty.
    """
    boxes = []
    for idx in pred_ids.tolist():
        bin_x, bin_y = reverse_class_dict[idx]
        # Cell centre shifted by the offsets, scaled to +/- half a cell (25 px).
        center_x = bin_x + 25 + x_offset[idx].item() * 25
        center_y = bin_y + 25 + y_offset[idx].item() * 25
        # Fixed 50 x 20 px rectangle; corner order (0, 1, 3, 2) walks its outline,
        # matching how collate_fn reads ground-truth corners.
        xs = torch.tensor([center_x - 25, center_x - 25, center_x + 25, center_x + 25])
        ys = torch.tensor([center_y - 10, center_y + 10, center_y - 10, center_y + 10])
        # Pixels -> world: inverse of x_px = 10 * x + 400, y_px = -10 * y + 400.
        xs = (xs - 400) / 10.
        ys = (400 - ys) / 10.
        boxes.append(torch.stack((xs, ys)))
    if not boxes:
        return torch.zeros((0, 2, 4), dtype=torch.float64)
    return torch.stack(boxes).double()


model.eval()

threat_scores = 0

with torch.no_grad():
    for sample, _, _, bbs, _, _ in test_loader:
        sample = sample.to(device)
        y_hat, y_count, _, x_offset, y_offset = model(sample)

        # Number of boxes = argmax of the count head; when more than 15 objects
        # are predicted, 6 extra boxes are kept (tuning heuristic).
        num_boxes = torch.argmax(y_count).item()
        if num_boxes > 15:
            num_boxes += 6
        pred_ids = torch.topk(y_hat, k=num_boxes).indices

        bb_samples = decode_bounding_boxes(pred_ids[0], x_offset[0], y_offset[0]).cpu()
        gt_boxes = bbs[0]

        if len(gt_boxes) == 0 or len(bb_samples) == 0:
            # torch.stack raises on an empty list, and the helper's behaviour on
            # empty inputs is unverified: score these cases directly (perfect
            # only when both sides are empty, otherwise no true positives).
            ts_bounding_box = 1.0 if len(gt_boxes) == len(bb_samples) else 0.0
        else:
            ts_bounding_box = compute_ats_bounding_boxes(bb_samples, torch.stack(gt_boxes).cpu())

        threat_scores += ts_bounding_box

    print("Average threat score", threat_scores / len(test_loader))

Average threat score tensor(0.0121)


In [None]:
# Best average bounding-box threat score so far: 0.012