In [1]:
import numpy as np
import pandas as pd
from random import shuffle
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import lr_scheduler
import torchvision
from torch.utils.data import DataLoader, Dataset, random_split, ConcatDataset
from torchvision import transforms, datasets, models
from train import train_model, train_loss, train_accuracy, val_loss, val_accuracy
from torchsummary import summary
import time
import copy
import seaborn as sns
sns.set(font_scale=1.4)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

batch_size = 32


In [2]:
def norm(dataset):
    output = []
    for sample in tqdm(dataset):
        newimg = []
        img = sample[0]
        label = sample[1]
        for s in img:
            s = s[0]
            s /= torch.max(s)
            s = torch.stack([s, s, s], 0)
            newimg.append(s)
        output.append((newimg, label))
    return output

dataset = torch.load('../../datasets/64skulldataset.pt')
dataset = norm(dataset)

extraAD1 = torch.load('../../datasets/GAN_SS/64ADgan1SS.pt')
extraAD2 = torch.load('../../datasets/GAN_SS/64ADgan2SS.pt')
extraAD3 = torch.load('../../datasets/GAN_SS/64ADgan3SS.pt')

extraNC1 = torch.load('../../datasets/GAN_SS/64NCgan1SS.pt')
extraNC2 = torch.load('../../datasets/GAN_SS/64NCgan2SS.pt')
extraNC3 = torch.load('../../datasets/GAN_SS/64NCgan3SS.pt')

extraAD = [([extraAD1[j][0], extraAD2[j][0], extraAD3[j][0]], extraAD1[j][1]) 
              for j in tqdm(range(len(extraAD1))) ]
extraNC = [([extraNC1[j][0], extraNC2[j][0], extraNC3[j][0]], extraNC1[j][1]) 
              for j in tqdm(range(len(extraNC1))) ]


lengths = [

    int(len(dataset) * 0.8),
    int(len(dataset) * 0.1),
    int(len(dataset) * 0.1) + 1
]


trainset, valset, testset = random_split(dataset, lengths)

trainset = torch.utils.data.ConcatDataset((trainset, extraNC, extraAD))

image_datasets = {'train': trainset, 'val': valset, 'test': testset}
dataloaders = {x: DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=4)
              for x in ['train', 'val', 'test']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val', 'test']}  

HBox(children=(FloatProgress(value=0.0, max=1181.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=705.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=476.0), HTML(value='')))




In [3]:
class MultiCNN(nn.Module):
    def __init__(self):
        super(MultiCNN, self).__init__()
        self.fc1 = nn.Linear(6144, 256)
        self.fc2 = nn.Linear(256, 2)
        self.fc3 = nn.Linear(10, 2)
        resnet = models.resnet50(pretrained=True).to(device) 
        self.new_resnet1 = nn.Sequential(*list(resnet.children())[:-1])
        self.new_resnet2 = nn.Sequential(*list(resnet.children())[:-1])
        self.new_resnet3 = nn.Sequential(*list(resnet.children())[:-1])
        self.drop = nn.Dropout(p=0.5)
        
    def forward(self, x_slices):

        x1 = x_slices[0]
        x1 = self.new_resnet1(x1)
        x1 = x1.view(-1, 2048)

        x2 = x_slices[1]
        x2 = self.new_resnet2(x2)
        x2 = x2.view(-1, 2048)

        x3 = x_slices[2]
        x3 = self.new_resnet3(x3)
        x3 = x3.view(-1, 2048)

        out = torch.cat((x1, x2, x3), dim=-1)
        out = F.relu(self.drop(self.fc1(out)))
        out = F.relu(self.fc2(out))

        return out

In [4]:
model = MultiCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
model = train_model(model, criterion, optimizer, exp_lr_scheduler, dataloaders, dataset_sizes,
                       num_epochs=50, batch_size = batch_size)

Epoch 0/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.3464 Acc: 0.8193


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.5956 Acc: 0.7288
Epoch 1/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.1979 Acc: 0.9205


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.6031 Acc: 0.7712
Epoch 2/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.1339 Acc: 0.9487


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.4956 Acc: 0.8051
Epoch 3/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0743 Acc: 0.9755


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.8893 Acc: 0.7627
Epoch 4/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0539 Acc: 0.9840


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.9314 Acc: 0.7288
Epoch 5/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0571 Acc: 0.9793


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.7905 Acc: 0.7373
Epoch 6/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0289 Acc: 0.9906


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.1909 Acc: 0.7288
Epoch 7/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0136 Acc: 0.9962


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.9882 Acc: 0.7797
Epoch 8/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0080 Acc: 0.9986


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.0265 Acc: 0.7881
Epoch 9/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0059 Acc: 0.9981


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.1144 Acc: 0.7797
Epoch 10/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0033 Acc: 0.9995


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.1764 Acc: 0.7797
Epoch 11/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0019 Acc: 1.0000


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.0601 Acc: 0.7881
Epoch 12/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0026 Acc: 1.0000


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.1311 Acc: 0.7966
Epoch 13/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0031 Acc: 0.9991


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.1470 Acc: 0.8051
Epoch 14/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0035 Acc: 0.9995


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.0736 Acc: 0.7881
Epoch 15/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0022 Acc: 0.9995


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.1873 Acc: 0.8051
Epoch 16/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0024 Acc: 0.9995


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.1440 Acc: 0.7966
Epoch 17/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0015 Acc: 1.0000


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.0940 Acc: 0.7881
Epoch 18/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0014 Acc: 1.0000


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.1318 Acc: 0.8051
Epoch 19/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0051 Acc: 0.9986


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.1076 Acc: 0.8136
Epoch 20/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0015 Acc: 0.9995


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.0296 Acc: 0.7881
Epoch 21/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0018 Acc: 1.0000


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.2038 Acc: 0.7966
Epoch 22/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0038 Acc: 0.9991


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.0662 Acc: 0.7966
Epoch 23/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0015 Acc: 1.0000


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.2035 Acc: 0.7881
Epoch 24/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0009 Acc: 1.0000


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.1298 Acc: 0.8051
Epoch 25/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0010 Acc: 1.0000


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.1776 Acc: 0.7881
Epoch 26/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0011 Acc: 1.0000


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.0839 Acc: 0.8051
Epoch 27/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0024 Acc: 0.9995


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.1768 Acc: 0.7881
Epoch 28/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0017 Acc: 1.0000


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.1567 Acc: 0.7881
Epoch 29/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0017 Acc: 1.0000


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.1456 Acc: 0.7966
Epoch 30/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0020 Acc: 1.0000


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.0342 Acc: 0.8051
Epoch 31/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0019 Acc: 1.0000


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.1036 Acc: 0.7966
Epoch 32/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0011 Acc: 1.0000


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.1268 Acc: 0.8051
Epoch 33/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0015 Acc: 1.0000


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.1487 Acc: 0.8136
Epoch 34/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0017 Acc: 1.0000


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.1163 Acc: 0.7881
Epoch 35/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0014 Acc: 1.0000


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.0379 Acc: 0.7881
Epoch 36/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0024 Acc: 0.9995


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.0559 Acc: 0.7966
Epoch 37/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0012 Acc: 1.0000


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.1293 Acc: 0.7966
Epoch 38/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0013 Acc: 1.0000


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.1181 Acc: 0.7966
Epoch 39/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))


train Loss: 0.0022 Acc: 0.9995


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 1.1709 Acc: 0.8051
Epoch 40/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=67.0), HTML(value='')))




KeyboardInterrupt: 

In [5]:
running_loss = 0.0
running_corrects = 0
for inputs, labels in tqdm(dataloaders['test'], total = dataset_sizes['test']//32+1):
    labels = labels.to(device)
    inputs = [i.to(device, dtype = torch.float) for i in inputs]

    with torch.no_grad():
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, labels)


    running_loss += loss.item() * 32
    running_corrects += torch.sum(preds == labels.data)
    
print(f"Test Loss: {running_loss / dataset_sizes['test']}\nTest Accuracy: {running_corrects.double() / dataset_sizes['test']}")


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


Test Loss: 1.0762737899267374
Test Accuracy: 0.7815126050420168


In [None]:
from confusionmatrix import make_confusion_matrix

nb_classes = 2
cf = torch.zeros(nb_classes, nb_classes)

with torch.no_grad():
    for i, (inputs, classes) in enumerate(dataloaders['test']):
        classes = classes.to(device)
        inputs = [i.to(device, dtype=torch.float) for i in inputs]
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)

        for t, p in zip(classes.view(-1), preds.view(-1)):
            cf[t.long(), p.long()] += 1

cf = np.array(cf)

group_names = ['bruh','False Pos','False Neg','True Pos']
group_counts = ['{0:0.0f}'.format(value) for value in
                cf.flatten()]
group_percentages = ['{0:.2%}'.format(value) for value in
                     cf.reshape(-1)/cf.sum()]

labels = [f'{v1}\n{v2}\n{v3}' for v1, v2, v3 in zip(group_names,group_counts,group_percentages)]

make_confusion_matrix(cf, group_names=None,
                          categories='auto',
                          count=True,
                          percent=True,
                          cbar=True,
                          xyticks=True,
                          xyplotlabels=True,
                          sum_stats=True,
                          figsize= (10,7),
                          cmap='Blues',
                          title='Resnet 50 with GAN\n')