In [1]:
import matplotlib.pyplot as plt
import numpy as np
import time
import torch
import torch.nn as nn
import torch.optim as optim

from torchvision import transforms, utils

from colour_demosaicing import demosaicing_CFA_Bayer_bilinear,demosaicing_CFA_Bayer_Malvar2004, demosaicing_CFA_Bayer_Menon2007

from utils import RAISE2_People_Dataset, Demosaicing, Resize, Normalize, get_dataset_stats
from architecture import Net
from model import Model

In [2]:
# Three per-channel values each, reshaped to (3, 1, 1); both tensors are handed
# to Normalize(mean, std) when the datasets are built.
# NOTE(review): presumably precomputed dataset statistics (get_dataset_stats is
# imported above) -- confirm where these numbers come from.
mean = torch.tensor([1078.7721, 1902.7466, 1580.9901]).view(3, 1, 1)
std = torch.tensor([324.2676, 609.9332, 616.3193]).view(3, 1, 1)

In [3]:
def _build_dataset(csv_file, root_dir, pattern):
    """Create a RAISE2_People_Dataset with the notebook's transform chain.

    The transform is Compose([Demosaicing(bilinear, pattern), Resize((256, 256)),
    Normalize(mean, std)]), applied in that order.

    Args:
        csv_file: path of the CSV file passed to the dataset.
        root_dir: image directory passed to the dataset.
        pattern: Bayer colour-filter arrangement given to Demosaicing,
            e.g. 'RGGB' or 'GRBG'.
    """
    pipeline = transforms.Compose([
        Demosaicing(demosaic=demosaicing_CFA_Bayer_bilinear, pattern=pattern),
        Resize((256, 256)),
        Normalize(mean, std),
    ])
    return RAISE2_People_Dataset(csv_file=csv_file, root_dir=root_dir, transform=pipeline)


def _build_loader(dataset):
    """Wrap a dataset in a shuffling DataLoader (batches of 4, 4 workers)."""
    return torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)


# Training split, demosaiced with the 'RGGB' pattern.
trainset = _build_dataset('./data/huge_train.csv', './data/images/huge/train', 'RGGB')
trainloader = _build_loader(trainset)

# Test set A: test split with the same 'RGGB' pattern as the training set.
A = _build_dataset('./data/huge_test.csv', './data/images/huge/test', 'RGGB')
testloader = _build_loader(A)

# Test set B: same test split, but a different transformation is applied by
# changing the arrangement of the colour filters on the pixel array
# from 'RGGB' to 'GRBG'.
B = _build_dataset('./data/huge_test.csv', './data/images/huge/test', 'GRBG')
B_testloader = _build_loader(B)

# Class labels as a pair of names (order as given in the original notebook).
classes = ('no people', 'people')

In [4]:
import torchvision.models as models
import plotly.graph_objects as go

def _load_classifier(backbone, in_features, checkpoint_path):
    """Build a 2-class ResNet, wrap it in Model and load trained weights.

    Args:
        backbone: torchvision constructor, e.g. ``models.resnet18``.
        in_features: input width of the final fully connected layer
            (512 for resnet18/34, 2048 for resnet50/101, as used below).
        checkpoint_path: path of the saved state dict.

    Returns:
        (net, wrapped): the bare network and the Model wrapping it.

    Raises:
        RuntimeError: from ``load_state_dict`` if the checkpoint keys do not
            match the network (default strict loading).
    """
    net = backbone(pretrained=False)
    # Replace the classification head with a 2-output layer.
    net.fc = torch.nn.Linear(in_features=in_features, out_features=2, bias=True)
    wrapped = Model(net, trainloader, testloader, B_testloader)
    # map_location='cpu' lets a checkpoint saved from a GPU run load on a
    # CPU-only machine; load_state_dict copies the values into the existing
    # parameters on whatever device they live on.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    state = torch.load(checkpoint_path, map_location='cpu')
    wrapped.net.load_state_dict(state)
    return net, wrapped


PATH = './trained/raise_classifier_ResNet18_pre-trained_epochs_25_lr_0.001_95223.pth'
resnet18, model18 = _load_classifier(models.resnet18, 512, PATH)

PATH = './trained/raise_classifier_ResNet34_pre-trained_epochs_25_lr_0.001_61748.pth'
resnet34, model34 = _load_classifier(models.resnet34, 512, PATH)

PATH = './trained/raise_classifier_ResNet50_pre-trained_epochs_25_lr_0.001_88068.pth'
resnet50, model50 = _load_classifier(models.resnet50, 2048, PATH)

PATH = './trained/raise_classifier_ResNet101_pre-trained_epochs_25_lr_0.001_37873.pth'
resnet101, model101 = _load_classifier(models.resnet101, 2048, PATH)

<All keys matched successfully>

In [5]:
# NOTE: this list reuses the name `models`, which the model-loading cell binds
# to the torchvision.models module; after this cell runs, `models.resnet18`
# etc. no longer resolve until that import is executed again.
models = [model18, model34, model50, model101]

# One row per model, in the order of `models`; each row holds the results of
# predict() on the training loader, test loader A and test loader B, in that
# order.
acc_total = [
    [net_wrapper.predict(loader) for loader in (trainloader, testloader, B_testloader)]
    for net_wrapper in models
]

In [64]:
# Model names for the first table column. Defined here because this cell uses
# `nets`; without it the cell raises NameError on a fresh kernel.
nets = ['resnet18', 'resnet34', 'resnet50', 'resnet101']

# acc_total has one row per model; L1 is its 2-D array form.
L1 = np.array(acc_total)
# Model names as a column vector, shape (len(nets), 1).
L2 = np.array([nets]).T
# Put the names in front of each model's accuracies, then transpose so that
# each row of L is one table column: names first, then one row per loader.
# Mixing strings and floats makes the resulting array string-typed.
L = np.block([L2, L1]).T

In [65]:
nets = ['resnet18', 'resnet34', 'resnet50', 'resnet101']

# Alternating background colours for the table rows.
rowEvenColor = 'lightblue'
rowOddColor = 'white'

# Header row: first column left-aligned, remaining columns centred.
header_spec = {
    'values': ['<b> Accuracy</b>', '<b>training set</b>', '<b>A</b>', '<b>B</b>'],
    'line_color': 'darkslategray',
    'fill_color': 'black',
    'align': ['left', 'center'],
    'font': {'color': 'white', 'size': 14},
}

# Body cells: `L` (built in the previous cell) supplies one entry per column.
cells_spec = {
    'values': L,
    'line_color': 'darkslategray',
    'fill_color': [[rowOddColor, rowEvenColor] * 2],
    'align': ['left', 'center'],
    'font': {'color': 'darkslategray', 'size': 12},
}

accuracy_table = go.Table(header=header_spec, cells=cells_spec)
fig = go.Figure(data=[accuracy_table])
fig.show()


In [21]:
# go.Table reads cells.values as a list of COLUMNS, but acc_total stores one
# ROW per model (training set, A, B). Transpose it so that each header column
# receives one value per model; passing the rows directly produced five
# misaligned columns under a four-column header.
accuracy_columns = [list(column) for column in zip(*acc_total)]

fig = go.Figure(data=[go.Table(
    header=dict(
        values=['<b> Accuracy</b>', '<b>training set</b>', '<b>A</b>', '<b>B</b>'],
        line_color='darkslategray',
        fill_color='black',
        align=['left', 'center'],
        font=dict(color='white', size=14),
    ),
    cells=dict(
        # First column: model names; then one column per evaluated loader.
        values=[nets] + accuracy_columns,
        line_color='darkslategray',
        fill_color=[[rowOddColor, rowEvenColor] * 2],
        align=['left', 'center'],
        font=dict(color='darkslategray', size=12),
    ),
)])
fig.show()

[[98.40546697038724, 68.4090909090909, 65.9090909090909],
 [99.65831435079727, 69.31818181818181, 66.13636363636364],
 [99.37357630979498, 67.27272727272727, 65.9090909090909],
 [98.51936218678816, 69.77272727272727, 66.36363636363636]]