In [35]:
import os
import random

import numpy as np
import pandas as pd

import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams['figure.figsize'] = [5, 5]
matplotlib.rcParams['figure.dpi'] = 200

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.optim as optim 

from data_helper import UnlabeledDataset, LabeledDataset
from helper import collate_fn, draw_box
import resnet_no_fc

from collections import OrderedDict

# Threat score for road detection
from helper import compute_ats_bounding_boxes, compute_ts_road_map

In [36]:
random.seed(0)
np.random.seed(0)
torch.manual_seed(0);

cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if cuda else "cpu")

In [37]:
# Image folder
image_folder = '/scratch/brs426/data'
annotation_csv = '/scratch/brs426/data/annotation.csv'

In [38]:
train_labeled_scene_index = np.arange(106, 132)
val_labeled_scene_index = np.arange(132, 134)
test_labeled_scene_index = np.arange(132, 134)

In [42]:
def segmentation_collate_fn(batch):
    BLOCK_SIZE = 5
    road_maps = []
    road_bins = []
    images = []
    for x in batch:
        
        grid = []
        # Collect six images for this sample. 
        six_images = []
        for i in range(6):
            six_images.append(torch.as_tensor(x[0][i]))
        
        # Get road_image and cast it to float
        road_image = torch.as_tensor(x[2])
        road_maps.append(road_image)
        road_image = road_image.float()
        
        for x in range(0, 800, BLOCK_SIZE):
            for y in range(0, 800, BLOCK_SIZE):
                block = road_image[x:x+BLOCK_SIZE, y:y+BLOCK_SIZE]
                score = torch.sum(block).item()
                # If more than have the pixels are 1, classify as road
                if score > (BLOCK_SIZE**2) / 2:
                    grid.append(1.0)
                else:
                    grid.append(0.0)
            
        images.append(torch.stack(six_images))
                
        road_bins.append(torch.as_tensor(grid))
                
    boom = torch.stack(images), torch.stack(road_bins), torch.stack(road_maps)
    return boom

In [43]:
transform = torchvision.transforms.ToTensor()
aug_transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomApply([
        torchvision.transforms.ColorJitter(brightness = 0.5, contrast = 0.5, saturation = 0.4, hue = (-0.5, 0.5)),
        torchvision.transforms.Grayscale(3),
#         transforms.RandomAffine(3),
    ]),
    torchvision.transforms.ToTensor(),
])

In [44]:
transform = torchvision.transforms.ToTensor()

labeled_trainset = LabeledDataset(image_folder=image_folder,
                                  annotation_file=annotation_csv,
                                  scene_index=train_labeled_scene_index,
                                  transform=aug_transform,
                                  extra_info=True
                                 )
trainloader = torch.utils.data.DataLoader(labeled_trainset, batch_size=16, shuffle=True, num_workers=10, collate_fn=segmentation_collate_fn)

labeled_valset = LabeledDataset(image_folder=image_folder,
                                  annotation_file=annotation_csv,
                                  scene_index=val_labeled_scene_index,
                                  transform=transform,
                                  extra_info=True
                                 )
valloader = torch.utils.data.DataLoader(labeled_valset, batch_size=1, shuffle=True, num_workers=2, collate_fn=segmentation_collate_fn)

labeled_testset = LabeledDataset(image_folder=image_folder,
                                  annotation_file=annotation_csv,
                                  scene_index=test_labeled_scene_index,
                                  transform=transform,
                                  extra_info=True
                                 )
testloader = torch.utils.data.DataLoader(labeled_testset, batch_size=2, shuffle=True, num_workers=2, collate_fn=segmentation_collate_fn)

### Starting Segmentation Architecture

In [50]:
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        
        self.encoder = torchvision.models.resnet50()
        self.encoder.fc = nn.Identity()
        self.concat_dim = 200 * 6
        
        self.compress = nn.Sequential(OrderedDict([
            ('linear0', nn.Linear(2048, 200)),
            ('drop', nn.Dropout(p = 0.5)),
            ('relu', nn.ReLU()),
        ]))
        
        self.segmentation = nn.Sequential(OrderedDict([
            ('linear1', nn.Linear(self.concat_dim, 25600)),
            ('sigmoid', nn.Sigmoid())
        ]))
        
    def forward(self, x):
        batch_size = x.shape[0]
        num_images = x.shape[1]
        channels = x.shape[2]
        height = x.shape[3]
        width = x.shape[4]
        
        # Reshape to feed in images
        x = x.reshape(-1, channels, height, width)
        
        x = self.encoder(x)
        x = self.compress(x)
        x = x.view(-1, self.concat_dim)
        return self.segmentation(x)

### Training Logic

In [9]:
# Train logic, return average loss over training set after each epoch
def train(model, device, train_loader, optimizer, epoch, log_file_path, log_interval = 250):
    # Set model to training mode
    model.train()
    
    # Number correct for accuracy
    num_correct = 0
    
    # Train loss
    train_loss = 0
    
    f = open(log_file_path, "a+")
    # Loop through examples
    for batch_idx, (images, bins, road_map) in enumerate(train_loader):
        
        
        # Send data and target to device
        data, target = images.to(device), bins.to(device)
        
        # Zero out optimizer
        optimizer.zero_grad()
        
        # Pass data through model - right now only segmentation
        output = model(data)
        # Should be batch_size X 800 X 800
        output = output.squeeze()
        
        # Compute the loss
        loss = F.binary_cross_entropy(output, target)
        train_loss += loss.item()
        
        # Backpropagate loss
        loss.backward()
        
        # Make a step with the optimizer
        optimizer.step()
        
        # Print loss (uncomment lines below once implemented)
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            f.write('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\n'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
    
    
    # Average train loss
    average_train_loss = train_loss / len(train_loader)
    # Print loss (uncomment lines below once implemented)
    print('\nTrain set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        average_train_loss, num_correct, len(train_loader.dataset),
        100. * num_correct / len(train_loader.dataset)))
    f.write('\nTrain set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        average_train_loss, num_correct, len(train_loader.dataset),
        100. * num_correct / len(train_loader.dataset)))
    f.close()

### Validation Logic

In [10]:
def reconstruct_from_bins(bins, block_size, threshold):
    road_map = torch.zeros((800, 800))
    idx = 0
    for x in range(0, 800, block_size):
        for y in range(0, 800, block_size):
            road_map[x:x+block_size, y:y+block_size] = bins[idx]
            idx += 1
    return road_map > threshold


# Define test method
def test(model, device, test_loader, log_file_path):
    # Set model to evaluation mode
    model.eval()
    # Variable for the total loss 
    test_loss = 0
    # Counter for the correct predictions
    num_correct = 0
    
#     thresholds = [0.42]
    
#     threat_scores = torch.zeros(1)
    
    f = open(log_file_path, "a+")
    # don't need autograd for eval
    with torch.no_grad():
        for batch_idx, (images, bins, road_map) in enumerate(test_loader):

            # Send data and target to device
            data, target= images.to(device), bins.to(device)

            # Pass data through model - right now only segmentation
            output = model(data)
            # Should be batch_size X 6400

            # Compute the loss
            loss = F.binary_cross_entropy(output, target)
            test_loss += loss.item()
            
            # Now squeeze for reconstruction
#             output = output.squeeze()
            
            # Compute threat score at 4 different thresholds
#              for idx in range(len(thresholds)):
#                 reconstructed_road_map = reconstruct_from_bins(output, 5, thresholds[idx]).cpu()
#                 ts_road_map = compute_ts_road_map(reconstructed_road_map, road_map)
#                 threat_scores[idx] += ts_road_map
         
    # Compute the average test_loss
    # avg_test_loss = TODO
    avg_test_loss = test_loss / len(test_loader)
    
    # Compute average threat scores
#     avg_threat_scores = threat_scores / len(test_loader)
    
#     print('\Threat scores: \t {}:{}\n'.format(thresholds[0], avg_threat_scores[0]))
    # Print loss (uncomment lines below once implemented)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        avg_test_loss, num_correct, len(test_loader.dataset),
        100. * num_correct / len(test_loader.dataset)))
    
#     f.write('\Threat scores: \t {}:{}\n'.format(thresholds[0], avg_threat_scores[0]))
    f.write('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        avg_test_loss, num_correct, len(test_loader.dataset),
        100. * num_correct / len(test_loader.dataset)))
    f.close()
    
    return avg_test_loss

In [11]:
# Segmentation model
model = SimpleModel().to(device)
# Optimizer
optimizer = optim.Adam(model.parameters(), lr=2e-4)
# Scheduler
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', factor=0.25, patience=1)

In [None]:
f = "/scratch/brs426/ben_models/simple_model_block_size_25_test_threat_alone.log"
best_val_loss = 100
save_path = "/scratch/brs426/ben_models/simple_model_block_size_25_test_threat_alone.p"
epochs = 40
for epoch in range(1, epochs + 1):
    # Train model
    loss = train(model, device, trainloader, optimizer, epoch, f)
    val_loss = test(model, device, valloader, f)
#     scheduler.step(val_threat)
    #Save model
    if val_loss < best_val_loss:
        torch.save(model.state_dict(), save_path)
        best_val_loss = val_loss


Train set: Average loss: 0.4101, Accuracy: 0/3276 (0%)


Test set: Average loss: 0.5740, Accuracy: 0/252 (0%)


Train set: Average loss: 0.3447, Accuracy: 0/3276 (0%)


Test set: Average loss: 0.3608, Accuracy: 0/252 (0%)


Train set: Average loss: 0.3026, Accuracy: 0/3276 (0%)


Test set: Average loss: 0.3533, Accuracy: 0/252 (0%)


Train set: Average loss: 0.2754, Accuracy: 0/3276 (0%)


Test set: Average loss: 0.3694, Accuracy: 0/252 (0%)


Train set: Average loss: 0.2570, Accuracy: 0/3276 (0%)


Test set: Average loss: 0.3162, Accuracy: 0/252 (0%)


Train set: Average loss: 0.2427, Accuracy: 0/3276 (0%)


Test set: Average loss: 0.2849, Accuracy: 0/252 (0%)


Train set: Average loss: 0.2301, Accuracy: 0/3276 (0%)


Test set: Average loss: 0.2823, Accuracy: 0/252 (0%)


Train set: Average loss: 0.2179, Accuracy: 0/3276 (0%)


Test set: Average loss: 0.3097, Accuracy: 0/252 (0%)


Train set: Average loss: 0.2060, Accuracy: 0/3276 (0%)


Test set: Average loss: 0.2617, Accuracy: 0/25

### Threat Score Evaluation

In [49]:
def reconstruct_from_bins(bins, block_size, threshold):
    road_map = torch.zeros((800, 800))
    idx = 0
    for x in range(0, 800, block_size):
        for y in range(0, 800, block_size):
            road_map[x:x+block_size, y:y+block_size] = bins[idx]
            idx += 1
    return road_map > threshold

# Predicting everything as road
model = SimpleModel().to(device)
model.load_state_dict(torch.load("/scratch/brs426/ben_models/simple_model_block_size_25_test_threat_alone.p"))
model.eval()

thresholds = [0.4]

threat_scores = torch.zeros(1)

with torch.no_grad():
    for batch_idx, (images, bins, road_map) in enumerate(valloader):

        # Send data and target to device
        data, target= images.to(device), bins.to(device)

        # Pass data through model - right now only segmentation
        output = model(data)
        # Should be batch_size X 6400

#         # Compute the loss
#         loss = F.binary_cross_entropy(output, target)
#         test_loss += loss.item()

        # Now squeeze for reconstruction
        output = output.squeeze()

        # Compute threat score at 4 different thresholds
        for idx in range(len(thresholds)):
            reconstructed_road_map = reconstruct_from_bins(output, 25, thresholds[idx]).cpu()
            ts_road_map = compute_ts_road_map(reconstructed_road_map, road_map)
            threat_scores[idx] += ts_road_map

avg_threat_scores = threat_scores / len(valloader)
print("Threshold 0.4: {}".format(avg_threat_scores))

Threshold 0.4: tensor([0.8311])
