<a href="https://colab.research.google.com/github/vladd11/ImVisible/blob/master/LYTNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive at /gdrive so the dataset CSVs, image folders and the
# checkpoint/plot save location used by later cells (/gdrive/MyDrive/...) are
# reachable from this Colab runtime.
from google.colab import drive
drive.mount('/gdrive')

Mounted at /gdrive


In [2]:
import torch
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import math

import pandas as pd
import os
import numpy as np
from PIL import ImageFile, Image, ImageDraw
from torchvision import transforms
from torchvision.transforms import functional as F
import random, time
import torch.nn as nn

# Helpers

In [3]:
def my_loss(classifier, regression, points, mode, alpha=0.5):
    """Combined classification + midline-regression loss.

    Args:
        classifier: class scores predicted by the network, shape (N, C).
        regression: predicted midline coordinates, same shape as ``points``.
        points: ground-truth midline coordinates.
        mode: ground-truth class indices, shape (N,).
        alpha: weight of the cross-entropy term; the MSE term gets
            ``1 - alpha``. Defaults to 0.5 (equal weighting).

    Returns:
        (loss, mse, ce): the weighted total as a float tensor, plus the two
        unweighted components (useful for logging).
    """
    mse = nn.functional.mse_loss(regression, points)
    # NOTE: cross_entropy applies log-softmax internally, i.e. it expects
    # unnormalized scores in ``classifier``.
    ce = nn.functional.cross_entropy(classifier, mode)

    loss = (ce * alpha + mse * (1 - alpha)).float()
    return loss, mse, ce

In [4]:
def angle_difference(pred, label):
    """Return the signed angle, in degrees, between predicted and true midlines.

    Only row 0 of each argument is read, as ``[x1, y1, x2, y2]`` (the first
    pair is the endpoint, the second the startpoint). The direction of a
    line is ``atan2(y2 - y1, x1 - x2)``.

    Each ``atan2`` result lies in (-180, 180] degrees, so their raw
    difference lies in (-360, 360). It is wrapped into [-180, 180). Without
    the wrap, two almost identical directions on either side of the ±180
    degree seam would be reported as an error of nearly 360 degrees.
    """
    def _direction(row):
        # angle between the direction vector and the x-axis, in radians
        x1, y1, x2, y2 = row[0], row[1], row[2], row[3]
        return math.atan2(y2 - y1, x1 - x2)

    diff = math.degrees(_direction(pred[0]) - _direction(label[0]))
    return (diff + 180.0) % 360.0 - 180.0

In [5]:
def startpoint_difference(pred, label):
    """Euclidean distance between predicted and ground-truth startpoints.

    Only row 0 of each argument is read. The startpoint is the second
    coordinate pair of that row, at indices 2 (x) and 3 (y).
    """
    predicted_row, true_row = pred[0], label[0]
    dx = predicted_row[2] - true_row[2]
    dy = predicted_row[3] - true_row[3]
    return math.sqrt(dx * dx + dy * dy)

In [6]:
def endpoint_difference(pred, label):
    """Euclidean distance between predicted and ground-truth endpoints.

    Only row 0 of each argument is read. The endpoint is the first
    coordinate pair of that row, at indices 0 (x) and 1 (y).
    """
    px, py = pred[0][0], pred[0][1]
    gx, gy = label[0][0], label[0][1]
    delta_x = px - gx
    delta_y = py - gy
    return math.sqrt(delta_x * delta_x + delta_y * delta_y)

In [7]:
def direction_performance(pred, label):
    """Score a midline prediction against its ground truth.

    Both tensors are detached, copied to the CPU and converted to nested
    Python lists. They are then passed to ``angle_difference``,
    ``startpoint_difference`` and ``endpoint_difference``. Those helpers
    read only row 0, so for a batch only the first sample is scored.

    Returns:
        (angle_error, startpoint_error, endpoint_error): the absolute value
        of each helper's result, in that order.
    """
    pred_list = pred.detach().cpu().numpy().tolist()
    label_list = label.detach().cpu().numpy().tolist()

    scorers = (angle_difference, startpoint_difference, endpoint_difference)
    return tuple(math.fabs(score(pred_list, label_list)) for score in scorers)

In [8]:
def display_image(image, title, points_pred, points_gt, factor):
    """Show ``image`` with the predicted (red) and ground-truth (blue) midline points.

    Each point list is ``[x1, y1, x2, y2]`` in normalized units. X values
    are scaled by ``factor * 4`` and y values by ``factor * 3`` (a 4:3
    frame; e.g. factor=192 gives 768x576).
    """
    plt.imshow(image)
    plt.title(title)

    # predictions first (red), then ground truth (blue)
    for pts, colour in ((points_pred, 'r'), (points_gt, 'b')):
        xs = [pts[0] * factor * 4, pts[2] * factor * 4]
        ys = [pts[1] * factor * 3, pts[3] * factor * 3]
        plt.scatter(xs, ys, c=colour)
    plt.show()

# LYTNetv1 classes

In [9]:
def conv_bn(inp, oup, stride):
    """Stem block: 3x3 conv (padding 1, no bias) -> BatchNorm -> ReLU6 -> 2x2 max-pool.

    The conv downsamples by ``stride``; the max-pool (kernel 2, stride 2)
    then halves the spatial size again.
    """
    layers = [
        nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True),
        nn.MaxPool2d(kernel_size=2, stride=2),
    ]
    return nn.Sequential(*layers)

In [10]:
def conv_1x1_bn(inp, oup):
    """Pointwise 1x1 conv (stride 1, no padding, no bias) -> BatchNorm -> ReLU6.

    Changes the channel count from ``inp`` to ``oup`` and keeps the spatial size.
    """
    pointwise = nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False)
    norm = nn.BatchNorm2d(oup)
    activation = nn.ReLU6(inplace=True)
    return nn.Sequential(pointwise, norm, activation)

In [11]:
class InvertedResidual(nn.Module):
    """Inverted-residual bottleneck block in the style of MobileNetV2.

    Layer layout of ``self.conv``:

    * (only when ``expand_ratio != 1``) 1x1 expansion conv to
      ``round(inp * expand_ratio)`` channels + BatchNorm + ReLU6;
    * 3x3 depthwise conv with ``stride`` + BatchNorm + ReLU6;
    * 1x1 projection conv to ``oup`` channels + BatchNorm (no activation).

    An identity shortcut is added when ``stride == 1`` and ``inp == oup``.
    ``stride`` must be 1 or 2.
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # pointwise expansion
            layers += [
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
            ]
        layers += [
            # depthwise
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU6(inplace=True),
            # pointwise projection (linear: no activation afterwards)
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        return x + out if self.use_res_connect else out

In [12]:
class LYTNet(nn.Module):
    """LYTNet: traffic-light classifier plus crosswalk-midline regressor.

    The backbone is a MobileNetV2-style stack: a ``conv_bn`` stem, inverted
    residual bottlenecks, ``conv_1x1_bn`` up to ``last_channel`` features,
    then dropout. Its output is globally average-pooled and fed to two heads:

    * ``classifier_light`` -> ``n_class`` softmax probabilities;
    * ``regression_direction`` -> 4 midline coordinates ``[x1, y1, x2, y2]``.

    Args:
        n_class: number of light classes produced by the classifier head.
        input_size_x: expected input width; must be a multiple of 64.
        input_size_y: expected input height; must be a multiple of 64.
            64 is the backbone's total downsampling: x2 in the stem conv,
            x2 in the stem max-pool, and four stride-2 bottleneck stages.
            The sizes are only validated, not otherwise used.
        width_mult: multiplier applied to every layer's channel count. The
            final 1x1 conv is only ever widened, never narrowed.

    ``forward(x)`` returns ``(class_probabilities, direction)``.
    """

    def __init__(self, n_class=5, input_size_x = 768, input_size_y = 576, width_mult=1.):
        super(LYTNet, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        inverted_residual_setting = [
            # t (expand ratio), c (out channels), n (repeats), s (stride of first repeat)
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 1, 2],
            [6, 64, 3, 2],
            [6, 96, 1, 1],
            [6, 160, 2, 2],
            [6, 320, 1, 1],
        ]

        #first layer
        assert input_size_x % 64 == 0
        assert input_size_y % 64 == 0
        input_channel = int(input_channel * width_mult)
        self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
        self.features = [conv_bn(3, input_channel, 2)]

        #bottleneck blocks; only the first repeat of each stage uses stride s
        for t, c, n, s in inverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                stride = s if i == 0 else 1
                self.features.append(block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel

        #last convolutional layer
        self.features.append(conv_1x1_bn(input_channel, self.last_channel))

        self.features = nn.Sequential(*self.features,
                                      nn.Dropout(0.1))

        #classifier section of the network
        # NOTE(review): the training loss (my_loss) passes these probabilities
        # to cross-entropy, which applies log-softmax again (double softmax).
        # It is left as-is because removing it would change training
        # behaviour. Argmax-based predictions are unaffected either way.
        self.classifier_light = nn.Sequential(
            nn.Linear(self.last_channel, 160),
            nn.BatchNorm1d(160),
            nn.ReLU6(inplace = True),
            nn.Linear(160, n_class),
            # explicit dim=1: implicit-dim Softmax is deprecated (it already picked dim=1 for (N, C) input)
            nn.Softmax(dim=1)
        )

        # regression for direction
        self.regression_direction = nn.Sequential(
            nn.Linear(self.last_channel, 80),
            nn.BatchNorm1d(80),
            nn.ReLU6(inplace = True),
            nn.Linear(80, 4)
        )

        self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = x.mean(3).mean(2) #global average pooling over width, then height
        x1 = self.classifier_light(x)
        x2 = self.regression_direction(x)
        return x1, x2

    def _initialize_weights(self):
        """Initialize the parameters of every submodule.

        Conv weights are Xavier-normal, BatchNorm affine parameters are 1/0,
        Linear weights are N(0, 0.01), and all biases are zero.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m.bias is not None:
                    m.bias.data.zero_()
                nn.init.xavier_normal_(m.weight.data)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

# LYTNet dataset class

In [13]:
class TrafficLightDataset(Dataset):
    """Traffic-light images with light-mode labels and crosswalk midlines.

    Labels come from a CSV file read with ``pd.read_csv``. Columns are used
    by position:

    * 0 - image file name, joined onto ``img_dir``;
    * 1 - light mode (class index);
    * 2..5 - midline coordinates ``x1, y1, x2, y2`` in pixels, normalized
      here by 4032 (x) and 3024 (y);
    * 6 - blocked/unblocked tag.

    Each item is a dict ``{'image', 'mode', 'points', 'block'}``.
    ``image`` is a channel-first (C, H, W) numpy array and ``points`` is a
    4-element FloatTensor of normalized image coordinates.

    When ``transformation`` is true (the default), training augmentation is
    applied:

    * a 50% horizontal flip;
    * a random 768x576 crop (the crop maths assume an 876x657 source image;
      TODO confirm against the stored files);
    * a light ColorJitter.

    The midline is clipped and remapped to the crop. Pass
    ``transformation=False`` for validation and testing.
    """

    def __init__(self, csv_file, img_dir, transformation = True):
        self.labels = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transformation = transformation

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):

        # let PIL decode truncated image files instead of raising
        ImageFile.LOAD_TRUNCATED_IMAGES = True
        img_name = os.path.join(self.img_dir, self.labels.iloc[index, 0]) #gets image name in csv file
        image = Image.open(img_name)

        light_mode = self.labels.iloc[index, 1] #mode of the traffic light
        block = self.labels.iloc[index,6] #label of blocked or unblocked
        raw_points = self.labels.iloc[index, 2:6] #midline coordinates
        # .iloc is required: this Series is labelled by column names, so
        # raw_points[0] would rely on pandas' deprecated positional fallback.
        points = [raw_points.iloc[0]/4032, raw_points.iloc[1]/3024,
                  raw_points.iloc[2]/4032, raw_points.iloc[3]/3024] #normalize coordinate values to be between [0,1]

        if self.transformation:
            #random horizontal flip with 50% probability
            num = random.random()
            if num >= 0.5:
                image = F.hflip(image)
                #flip x coordinates when entire image is flipped
                points[0] = 1 - points[0]
                points[2] = 1 - points[2]

            #random crop
            cp = [points[0]*876, (1-points[1])*657, 876*points[2], (1-points[3])*657] #convert points to cartesian coordinates (y pointing up)
            #shifts to determine what region to crop
            shiftx = random.randint(0, 108)
            shifty = random.randint(0, 81)

            #slope of the midline; a vertical (or degenerate) midline gets a
            #huge stand-in slope instead of a division by zero
            dx = cp[0] - cp[2]
            m = (cp[1]-cp[3])/dx if dx != 0 else 1e16

            b = cp[1] - m*cp[0] #y-intercept

            #changing the coordinates based on the new cropped area
            #(endpoints outside the crop are slid along the line onto its edge)
            if(shiftx > cp[0]):
                cp[0] = shiftx
                cp[1] = (cp[0]*m + b)
            elif((768+shiftx) < cp[0]):
                cp[0] = (768+shiftx)
                cp[1] = (cp[0]*m + b)
            if(shiftx > cp[2]):
                cp[2] = shiftx
                cp[3] = (cp[2]*m + b)
            elif((768+shiftx) < cp[2]):
                cp[2] = (768+shiftx)
                cp[3] = (cp[2]*m + b)
            if(657-shifty < cp[1]):
                cp[1] = 657-shifty
                cp[0] = (cp[1]-b)/m if (cp[1]-b)/m>0 else 0
            if(657-576-shifty > cp[3]):
                cp[3] = 657-576-shifty
                cp[2] = (cp[3]-b)/m

            #converting the coordinates from a 876x657 image to a 768x576 image
            cp[0] -= shiftx
            cp[1] -= (657-576-shifty)
            cp[2] -= shiftx
            cp[3] -= (657-576-shifty)

            #converting the cartesian coordinates back to image coordinates
            points = [cp[0]/768, 1-cp[1]/576, cp[2]/768, 1-cp[3]/576]

            image = F.crop(image, shifty, shiftx, 576, 768) #top, left, height, width
            image = transforms.ColorJitter(0.05,0.05,0.05,0.01)(image)

        # PIL image -> (H, W, C) array -> (C, H, W)
        image = np.asarray(image).transpose(2, 0, 1)
        points = torch.FloatTensor(points)

        #combine all the info into a dictionary
        final_label = {'image': image, 'mode':light_mode, 'points': points, 'block': block}
        return final_label

# Training

In [14]:
cuda_available = torch.cuda.is_available()

BATCH_SIZE = 32
MAX_EPOCHS = 800
INIT_LR = 0.001
# Weight decay passed to Adam below. NOTE(review): this constant previously
# read 0.00005, but the optimizer was hard-coded to 0.000005. The value the
# optimizer actually used is kept here; confirm which one was intended.
WEIGHT_DECAY = 0.000005
LR_DROP_MILESTONES = [400,600] #epochs at which MultiStepLR multiplies the LR by its default gamma (0.1)

# never ask for more DataLoader workers than this runtime has CPUs
CPU_COUNT = os.cpu_count() or 1

train_file_dir = '/gdrive/MyDrive/Colab Notebooks/lights/training_file.csv'
valid_file_dir = '/gdrive/MyDrive/Colab Notebooks/lights/validation_file.csv'
train_img_dir = '/gdrive/MyDrive/Colab Notebooks/lights/dataset'
valid_img_dir = '/gdrive/MyDrive/Colab Notebooks/lights/validation_dataset'
save_path = '/gdrive/MyDrive/Colab Notebooks/lights'

train_dataset = TrafficLightDataset(csv_file = train_file_dir, img_dir = train_img_dir)
# validation must be deterministic: no random flip / crop / colour jitter
valid_dataset = TrafficLightDataset(csv_file = valid_file_dir, img_dir = valid_img_dir, transformation = False)

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=min(8, CPU_COUNT))
valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=min(2, CPU_COUNT))

net = LYTNet()

if cuda_available:
    net = net.cuda()

loss_fn = my_loss

optimizer = torch.optim.Adam(net.parameters(), lr = INIT_LR, weight_decay = WEIGHT_DECAY)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, LR_DROP_MILESTONES)

#storing all data during training
train_losses = [] #stores the overall training loss at each epoch
train_MSE = [] #stores the MSE loss during training at each epoch
train_CE = [] #stores the cross entropy loss during training at each epoch
valid_losses = [] #stores the overall validation loss at each epoch
valid_MSE = [] #stores the MSE loss during validation at each epoch
valid_CE = [] #stores the cross entropy loss during validation at each epoch
train_accuracies = [] #stores the training accuracy (percent) of the network at each epoch
valid_accuracies = [] #stores the validation accuracy (percent) of the network at each epoch
val_angles = [] #stores the average angle error of the network during validation at each epoch
val_start = [] #stores the average startpoint error of the network during validation at each epoch
val_end = [] #stores the average endpoint error of the network during validation at each epoch

  cpuset_checked))


In [15]:
for epoch in range(MAX_EPOCHS):

    ##########
    #TRAINING#
    ##########

    net.train()

    running_loss = 0.0 #stores the total loss for the epoch
    running_loss_MSE = 0.0 #stores the total MSE loss for the epoch
    running_loss_cross_entropy = 0.0 #store the total cross entropy loss for the epoch
    #summed direction errors; direction_performance scores only sample 0 of each batch
    angle_error = 0.0
    startpoint_error = 0.0
    endpoint_error = 0.0
    train_correct = 0 #correctly classified training images this epoch
    train_seen = 0 #training images processed this epoch (the last batch may be smaller than BATCH_SIZE)
    train_total = 0 #batches processed this epoch

    for j, data in enumerate(train_dataloader, 0):

        optimizer.zero_grad()
        train_total += 1

        images = data['image'].type(torch.FloatTensor)
        mode = data['mode'] #index of traffic light mode
        points = data['points'] #array of midline coordinates

        if cuda_available:
            images = images.cuda()
            mode = mode.cuda()
            points = points.cuda()

        pred_classes, pred_direc = net(images)
        _, predicted = torch.max(pred_classes, 1) #finds index of largest probability
        train_correct += (predicted == mode).sum().item() #number of correct predictions in this batch
        train_seen += mode.size(0)
        loss, MSE, cross_entropy = loss_fn(pred_classes, pred_direc, points, mode)
        angle, start, end = direction_performance(pred_direc, points)
        angle_error += angle
        endpoint_error += end
        startpoint_error += start

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        running_loss_MSE += MSE.item()
        running_loss_cross_entropy += cross_entropy.item()

    train_accuracy = train_correct/train_seen #fraction of training images classified correctly

    print('Epoch: ' + str(epoch+1))
    print('Average training loss: ' + str(running_loss/train_total))
    print('Average training MSE loss: ' + str(running_loss_MSE/train_total))
    print('Average training cross entropy loss: ' + str(running_loss_cross_entropy/train_total))
    print('Training accuracy: ' + str(train_accuracy))

    train_MSE.append(running_loss_MSE/train_total)
    train_CE.append(running_loss_cross_entropy/train_total)
    train_losses.append(running_loss/train_total)
    train_accuracies.append(train_accuracy*100)

    #one scheduler step per epoch; the LR drops at the LR_DROP_MILESTONES epochs
    lr_scheduler.step()

    ############
    #VALIDATION#
    ############

    net.eval()

    #class indices as strings, in the order red, green, countdown_green, countdown_blank, none
    class_keys = ['0', '1', '2', '3', '4']
    tp = {c: 0 for c in class_keys} #true positives per class
    fp = {c: 0 for c in class_keys} #false positives per class
    fn = {c: 0 for c in class_keys} #false negatives per class

    precisions = [] #stores the precision for each class
    recalls = [] #stores the recall for each class

    #stores losses and errors for network during validation
    val_running_loss = 0
    val_mse_loss = 0
    val_ce_loss = 0
    val_angle_error = 0
    val_start_error = 0
    val_end_error = 0
    val_correct = 0 #correctly classified validation images
    val_seen = 0 #validation images processed
    val_total = 0 #validation batches processed

    with torch.no_grad():

        for k, data in enumerate(valid_dataloader, 0):

            images = data['image'].type(torch.FloatTensor)
            mode = data['mode']
            points = data['points']

            if cuda_available:
                images = images.cuda()
                mode = mode.cuda()
                points = points.cuda()

            pred_classes, pred_direc = net(images)
            _, predicted = torch.max(pred_classes, 1)
            val_correct += (predicted == mode).sum().item()
            val_seen += mode.size(0)
            val_total += 1

            #per-image confusion bookkeeping (works for any batch size)
            for p, t in zip(predicted.tolist(), mode.tolist()):
                if p == t:
                    tp[str(p)] += 1 #true positive for the correct class
                else:
                    fp[str(p)] += 1 #false positive for the predicted class
                    fn[str(t)] += 1 #false negative for the ground-truth class

            loss, MSE, cross_entropy = loss_fn(pred_classes, pred_direc, points, mode)
            val_running_loss += loss.item()
            val_mse_loss += MSE.item()
            val_ce_loss += cross_entropy.item()

            angle, start, end = direction_performance(pred_direc, points)
            val_angle_error += angle
            val_start_error += start
            val_end_error += end

        #precision and recall per class; 0 when the denominator is empty
        for c in class_keys:
            predicted_positives = tp[c] + fp[c]
            actual_positives = tp[c] + fn[c]
            precisions.append(tp[c]/predicted_positives if predicted_positives else 0)
            recalls.append(tp[c]/actual_positives if actual_positives else 0)

        print("Average validation loss: " + str(val_running_loss/val_total))
        print("Average validation MSE loss: " + str(val_mse_loss/val_total))
        print("Average validation cross entropy loss: " + str(val_ce_loss/val_total))
        print("Validation accuracy: " + str(100*val_correct/val_seen))

        valid_accuracies.append(100*val_correct/val_seen)
        valid_losses.append(val_running_loss/val_total)
        valid_MSE.append(val_mse_loss/val_total)
        valid_CE.append(val_ce_loss/val_total)

        print("Precisions: " + str(precisions))
        print("Recalls: " + str(recalls))
        print("Angle Error: " + str(val_angle_error/val_total))
        print("Startpoint Error: " + str(val_start_error/val_total))
        print("Endpoint Error: " + str(val_end_error/val_total))

        val_angles.append(val_angle_error/val_total)
        val_start.append(val_start_error/val_total)
        val_end.append(val_end_error/val_total)

        #graphs average losses every epoch_num of epochs
        epoch_num = 100
        if epoch % epoch_num == (epoch_num - 1):
            plt.title('Train and Validation losses')
            plt.plot(train_losses)
            plt.plot(valid_losses)
            plt.show()

        #stores the network and optimizer weights every 50th epoch
        if epoch%50 == 49:
            states = {
                    'epoch': epoch+1,
                    'state_dict': net.state_dict(),
                    'optimizer': optimizer.state_dict()
                    }
            torch.save(states, save_path + '_epoch_' + str(epoch+1))

KeyboardInterrupt: ignored

In [None]:
#plots training and validation loss (labelled so the two curves can be told apart)
plt.title('train and validation loss')
plt.plot(valid_losses, label='validation')
plt.plot(train_losses, label='train')
plt.xlabel('epoch')
plt.legend()
plt.savefig(save_path + '_losses')
plt.show()

In [None]:
#plots training and validation cross entropy loss (labelled so the two curves can be told apart)
plt.title('train and valid cross entropy')
plt.plot(train_CE, label='train')
plt.plot(valid_CE, label='validation')
plt.xlabel('epoch')
plt.legend()
#'_' separator so the file is named like the other saved plots (<save_path>_losses, <save_path>_accuracies)
plt.savefig(save_path + '_train_valid_ce')
plt.show()

In [None]:
#plots training and validation MSE loss (labelled so the two curves can be told apart)
plt.title('train and valid MSE')
plt.plot(train_MSE, label='train')
plt.plot(valid_MSE, label='validation')
plt.xlabel('epoch')
plt.legend()
#'_' separator so the file is named like the other saved plots (<save_path>_losses, <save_path>_accuracies)
plt.savefig(save_path + '_train_valid_MSE')
plt.show()

In [None]:
#plots training and validation accuracies in percent (labelled so the two curves can be told apart)
plt.title('train and validation accuracies')
plt.plot(valid_accuracies, label='validation')
plt.plot(train_accuracies, label='train')
plt.xlabel('epoch')
plt.ylabel('accuracy (%)')
plt.legend()
plt.savefig(save_path + '_accuracies')
plt.show()

In [None]:
#save final network weights
# persist only the model parameters (a bare state_dict, no optimizer state)
final_weights_file = save_path + '_final_weights'
torch.save(net.state_dict(), final_weights_file)

# LYTNet testing

In [None]:
cuda_available = torch.cuda.is_available()

test_file_loc = '/gdrive/MyDrive/Colab Notebooks/lights/testing_file.csv'
test_image_directory = '/gdrive/MyDrive/Colab Notebooks/lights/images'
MODEL_PATH = '/gdrive/MyDrive/Colab Notebooks/lights/LytNetV2_weights'

# testing must be deterministic: no random flip / crop / colour jitter
dataset = TrafficLightDataset(csv_file = test_file_loc, img_dir = test_image_directory, transformation = False)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=2)

net = LYTNet()
checkpoint = torch.load(MODEL_PATH, map_location=torch.device('cuda' if cuda_available else 'cpu'))
# the training loop saves {'epoch', 'state_dict', 'optimizer'} checkpoints;
# accept those as well as a bare state_dict
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
    checkpoint = checkpoint['state_dict']
net.load_state_dict(checkpoint)
net.eval()

if cuda_available:
    net = net.cuda()

loss_fn = my_loss


def _safe_mean(total_value, count):
    """Return total_value / count, or NaN when nothing was accumulated."""
    return total_value / count if count else float('nan')


#storing data
running_loss = 0
running_test_angle = 0
running_test_start = 0
running_test_end = 0

#errors when zebra crossing is blocked
running_angle_block = 0
running_start_block = 0
running_end_block = 0
block_count = 0

#errors when zebra crossing is unblocked
running_angle_unblock = 0
running_start_unblock = 0
running_end_unblock = 0
unblock_count = 0

total = 0
correct = 0

#class indices as strings, in the order red, green, countdown_green, countdown_blank, none
class_keys = ['0', '1', '2', '3', '4']
tp = {c: 0 for c in class_keys}
fp = {c: 0 for c in class_keys}
fn = {c: 0 for c in class_keys}
classes = {'0':'red', '1':'green', '2':'countdown_green', '3':'countdown_blank', '4':'none'}
precisions = []
recalls = []

start_time = time.time()

with torch.no_grad():

    for i, data in enumerate(dataloader):

        images = data['image'].type(torch.FloatTensor)
        mode = data['mode']
        points = data['points']
        blocked = data['block'] #tag for blocked zebra crossing

        if cuda_available:
            images = images.cuda()
            mode = mode.cuda()
            points = points.cuda()

        pred_classes, pred_direc = net(images)
        _, predicted = torch.max(pred_classes, 1)

        #batch_size is 1, so only element 0 is inspected
        predicted_idx = str(predicted.cpu().numpy()[0])
        gt_idx = str(mode.cpu().numpy()[0])

        if predicted_idx == gt_idx:
            #correct prediction
            correct += 1
            tp[predicted_idx] += 1
        else:
            #incorrect prediction
            fp[predicted_idx] += 1
            fn[gt_idx] += 1

            #display image when incorrect
            image = images.cpu().numpy()[0]
            image = np.transpose(image, (1,2,0))
            image = image.astype(int)

            title = 'predicted: ' + classes[predicted_idx] + ' ground_truth: ' + classes[gt_idx] + ' ' + str(i+1)
            ax = plt.subplot()
            ax.axis('on')
            pred_points = pred_direc.cpu().detach().numpy()[0].tolist()
            gt_points = points.cpu().detach().numpy()[0]

            display_image(image,title,pred_points, gt_points, 192) #factor is 192 because 4*192 = 768

        loss, MSE, cross_entropy =  loss_fn(pred_classes, pred_direc, points, mode)
        running_loss += loss.item()
        angle, start, end = direction_performance(pred_direc, points)

        if(blocked[0] == "blocked"):
            running_angle_block += angle
            running_start_block += start
            running_end_block += end
            block_count += 1

        else:
            running_angle_unblock += angle
            running_start_unblock += start
            running_end_unblock += end
            unblock_count += 1

        running_test_angle += angle
        running_test_start += start
        running_test_end += end
        total += 1

#precision and recall per class; 0 when the denominator is empty
for c in class_keys:
    predicted_positives = tp[c] + fp[c]
    actual_positives = tp[c] + fn[c]
    precisions.append(tp[c]/predicted_positives if predicted_positives else 0)
    recalls.append(tp[c]/actual_positives if actual_positives else 0)

#averages print nan instead of crashing when a group is empty
print("Average loss: " + str(_safe_mean(running_loss, total)))
print("Average angle error: " + str(_safe_mean(running_test_angle, total)))
print("Average startpoint error: " + str(_safe_mean(running_test_start, total)))
print("Average endpoint error: " + str(_safe_mean(running_test_end, total)))
print("Blocked angle error: " + str(_safe_mean(running_angle_block, block_count)))
print("Blocked startpoint error: " + str(_safe_mean(running_start_block, block_count)))
print("Blocked endpoint error: " + str(_safe_mean(running_end_block, block_count)))
print("Unblocked angle error: " + str(_safe_mean(running_angle_unblock, unblock_count)))
print("Unblocked startpoint error: " + str(_safe_mean(running_start_unblock, unblock_count)))
print("Unblocked endpoint error: " + str(_safe_mean(running_end_unblock, unblock_count)))
print("Accuracy: " + str(_safe_mean(correct, total)*100))

print("Precisions: " + str(precisions))
print("Recalls: " + str(recalls))
print("Time Elapsed: " + str(time.time() - start_time))