In [None]:
import pickle
import re

import numpy as np
import skorch
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import torchvision
from skorch.classifier import NeuralNetClassifier
from torch.utils.data import Dataset
from torchvision import datasets, models, transforms

In [None]:
# Load data
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load these .pk files from a trusted source.
def _load_pickle(path):
    """Unpickle and return the single object stored in *path*.

    The file is opened in a ``with`` block so the handle is closed even if
    unpickling fails. ``encoding='bytes'`` is passed through to
    ``pickle.load`` exactly as before.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh, encoding='bytes')


train_imgs = _load_pickle('train_images_512.pk')
train_labels = _load_pickle('train_labels_512.pk')
test_imgs = _load_pickle('test_images_512.pk')

In [None]:
# DenseNet-121 wrapper whose head matches the CheXNet checkpoint layout
class DenseNet121(nn.Module):
    """torchvision DenseNet-121 with its classifier replaced by Linear + Sigmoid.

    Args:
        classCount: number of output units of the new head.
        isTrained: forwarded to torchvision as ``pretrained``.

    The backbone is stored as ``self.densenet121`` and the head as a
    two-element ``nn.Sequential`` so that parameter names take the form
    ``densenet121.classifier.0.*``; the checkpoint-loading cell relies on
    those names.
    """

    def __init__(self, classCount, isTrained):
        super().__init__()
        backbone = torchvision.models.densenet121(pretrained=isTrained)
        num_features = backbone.classifier.in_features
        # Head compatible with CheXNet; a later cell swaps it for the covid classifier.
        backbone.classifier = nn.Sequential(
            nn.Linear(num_features, classCount),
            nn.Sigmoid(),
        )
        self.densenet121 = backbone

    def forward(self, x):
        """Run the wrapped DenseNet-121 (including its sigmoid head) on *x*."""
        return self.densenet121(x)

In [None]:
# Let cuDNN benchmark convolution algorithms and pick the fastest for this hardware
cudnn.benchmark = True

# Build the CheXNet architecture (14 outputs) so the checkpoint's keys and
# shapes line up. ImageNet weights are not requested (isTrained=False): the
# strict load_state_dict call below overwrites every parameter and buffer
# anyway, so downloading them first is wasted work.
model = torch.nn.DataParallel(DenseNet121(14, False).cuda())

# Load the weights for CheXNet. map_location='cpu' keeps loading independent of
# the GPU layout the checkpoint was saved on; load_state_dict then copies the
# tensors into the model's CUDA parameters.
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; if this
# legacy checkpoint fails to load, pass weights_only=False (trusted files only).
modelCheckpoint = torch.load('m-25012018-123527.pth.tar', map_location='cpu')

state_dict = modelCheckpoint['state_dict']

# Fix key-name incompatibility between the checkpoint and this model: the
# pattern matches a "." immediately followed by a digit, so e.g.
# "norm.1.weight" becomes "norm1.weight" - presumably mapping an older DenseNet
# naming scheme onto the one this torchvision model uses. The raw string avoids
# the invalid-escape warning a plain string gives for the regex escapes.
p = re.compile(r"(\.)(?=\d)")

# The classifier keys are kept verbatim: stripping their "." would produce
# "classifier0.weight" / "classifier0.bias", which do not exist on the model.
keep_verbatim = {
    "module.densenet121.classifier.0.weight",
    "module.densenet121.classifier.0.bias",
}
new_state_dict = {
    (key if key in keep_verbatim else p.sub("", key)): value
    for key, value in state_dict.items()
}

model.load_state_dict(new_state_dict)


In [None]:
 # Set FREEZE_BACKBONE = True to train only the new classification layer:
# every existing parameter is frozen first, and the replacement head created
# below gets fresh parameters (requires_grad=True), so only it stays trainable.
FREEZE_BACKBONE = False
NUM_CLASSES = 2

if FREEZE_BACKBONE:
    for param in model.parameters():
        param.requires_grad = False

# Replace the 14-output CheXNet head (Linear + Sigmoid) with a NUM_CLASSES-way
# linear layer producing raw logits (nn.CrossEntropyLoss applies log-softmax
# itself). in_features is read from the current head rather than hard-coding
# 1024. The trailing empty nn.Sequential() is a no-op kept from the original
# two-element head layout; it has no parameters.
model.module.densenet121.classifier = nn.Sequential(
    nn.Linear(
        in_features=model.module.densenet121.classifier[0].in_features,
        out_features=NUM_CLASSES,
    ),
    nn.Sequential(),
)

In [None]:
# Perform the same transformations and normalization on the data as was 
# used in CheXNet

transResize = 256
transCrop = 224     

normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        
transformList = []
transformList.append(transforms.RandomResizedCrop(transCrop))
transformList.append(transforms.RandomHorizontalFlip())
transformList.append(transforms.ToTensor())
transformList.append(normalize)      
transformSequence=transforms.Compose(transformList)

# Define custom Dataset for the training data to apply transformations
class CovidDatasetTrainTransforms(Dataset):
    """Covid chest scans training dataset."""

    def __init__(self, imgs, labels, transform=transformSequence):
        self.imgs = imgs
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        imgs = self.imgs
        if self.transform != None: imageData = self.transform(transforms.ToPILImage()(imgs[idx]))
        return imgs[idx], self.labels[idx]


In [None]:
# Parameters to play with: https://skorch.readthedocs.io/en/stable/user/neuralnet.html
# Used NeuralNetClassifier to perform CV with skorch
net = NeuralNetClassifier(
    model.module,
    criterion=nn.CrossEntropyLoss,
    optimizer=optim.SGD,
    optimizer__momentum=0,
    lr=0.001,         # Relatively low learning rate to maintain weights found by CheXNet
    max_epochs=750,   # Region found to have the best validation error
    batch_size=20,    # Set as high a colab will allow to reduce stocasticity in gradients - may need to restart run to get better cpu
    iterator_train__shuffle=True,
    train_split=None, # Set to None for final training of the whole data set - use skorch.dataset.CVSplit(5) for CV
    device='cuda'
)

In [None]:
# Train the model
net.fit(CovidDatasetTrainTransforms(train_imgs, train_labels), y=None) 

In [None]:
# Predict labels
net.predict(test_imgs)

In [None]:
# Return prediction probabilities to adjust threshold
net.predict_proba(test_imgs)

In [None]:
# Results of six runs
# NOTE(review): these values appear to be pasted outputs from six separate
# training runs of this notebook - presumably l1..l6 from net.predict and
# p1..p6 from net.predict_proba, one row per test image (20 rows) and one
# column per class. The p rows are not normalized (negative entries, rows do
# not sum to 1), so they look like raw logits from the linear head rather than
# probabilities. In every row, the column holding the larger value equals the
# label at the same position in the matching l list.
l1 = [1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1]
l2 = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1]
l3 = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
l4 = [0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1]
l5 = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
l6 = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

p1 = np.array([[-0.27811393,  0.14735694],
       [-0.31309873,  0.31708255],
       [-1.645754  ,  1.8220613 ],
       [-2.024819  ,  2.2430968 ],
       [-2.1563482 ,  2.3367486 ],
       [-3.0071325 ,  3.6435819 ],
       [-1.1942502 ,  1.325249  ],
       [ 0.08353306, -0.18525176],
       [-0.44969025,  0.42628393],
       [-3.2869964 ,  3.7485967 ],
       [-1.5691644 ,  1.6424458 ],
       [-0.80113155,  0.84656566],
       [-0.31289968,  0.2370822 ],
       [-1.5189306 ,  1.715897  ],
       [-2.802221  ,  3.2429638 ],
       [ 0.0137746 , -0.07757613],
       [-2.841208  ,  3.2509167 ],
       [-3.3783474 ,  3.8599539 ],
       [-1.7117329 ,  1.9713236 ],
       [-0.45359203,  0.48920602]])
p2 = np.array([[-1.2749506 ,  1.0391239 ],
       [-1.900973  ,  1.9431375 ],
       [-2.5834403 ,  2.681222  ],
       [-1.4890026 ,  1.4130588 ],
       [-3.5990112 ,  3.9035232 ],
       [-6.6630864 ,  7.079091  ],
       [-4.0231423 ,  4.281615  ],
       [-1.5663888 ,  1.6640311 ],
       [-2.9456632 ,  2.81032   ],
       [-6.8810472 ,  7.2846017 ],
       [-1.9778112 ,  1.8363433 ],
       [-0.6686553 ,  0.43312752],
       [-0.1640156 , -0.06224612],
       [-2.4926245 ,  2.6010714 ],
       [-4.0785923 ,  4.278558  ],
       [ 0.06012314, -0.2652411 ],
       [-4.0402765 ,  4.1026406 ],
       [-3.7486196 ,  4.107744  ],
       [-0.93817925,  0.82568735],
       [-3.3648107 ,  3.5886166 ]])
p3 = np.array([[-1.2306907 ,  0.56224006],
       [-1.0445266 ,  0.5948984 ],
       [-2.5809722 ,  2.4847565 ],
       [-0.7480598 , -0.08442096],
       [-1.7244426 ,  1.0858076 ],
       [-0.92488146,  0.25149146],
       [-1.6521794 ,  1.1366856 ],
       [-1.3261532 ,  0.98897165],
       [-1.6935257 ,  1.1551778 ],
       [-3.2475042 ,  2.8874002 ],
       [-1.5037724 ,  1.3778486 ],
       [-1.5310456 ,  1.5377936 ],
       [-1.060547  ,  0.16646525],
       [-1.545185  ,  1.1831194 ],
       [-1.6743516 ,  1.6942255 ],
       [-1.330892  ,  1.205843  ],
       [-2.9033842 ,  2.721028  ],
       [-2.1217332 ,  2.112975  ],
       [-1.2982858 ,  0.74919575],
       [-1.4203405 ,  0.8983746 ]])
p4 = np.array([[ 1.2999684 , -1.0184019 ],
       [ 0.302412  , -0.14628246],
       [-2.2321208 ,  2.45701   ],
       [-0.76543415,  0.9794851 ],
       [-2.483469  ,  2.8138683 ],
       [-0.5529454 ,  0.7383251 ],
       [-1.0397247 ,  1.1657188 ],
       [ 0.48791945, -0.4144309 ],
       [-0.5531086 ,  0.7455833 ],
       [-1.9503564 ,  2.2771873 ],
       [ 0.11591991,  0.04245187],
       [-0.10108346,  0.32727188],
       [ 1.3019236 , -1.0425917 ],
       [-1.9521222 ,  2.1308362 ],
       [-2.0220902 ,  2.0503178 ],
       [ 0.66444993, -0.4163615 ],
       [-4.0514026 ,  4.4185004 ],
       [-2.42929   ,  2.610626  ],
       [-1.5444623 ,  1.6495777 ],
       [-0.87637365,  1.0479659 ]])
p5 = np.array([[-3.684291 ,  4.226165 ],
       [-3.4298754,  3.8904073],
       [-4.6021442,  4.902417 ],
       [-4.867758 ,  5.4403343],
       [-4.9466047,  5.800642 ],
       [-3.6632683,  4.315364 ],
       [-4.094542 ,  4.6687307],
       [-1.7738253,  2.3605208],
       [-2.4931495,  2.9982553],
       [-5.4592533,  6.141738 ],
       [-2.4794126,  2.9762235],
       [-1.9761522,  2.305777 ],
       [-2.9955099,  3.5300498],
       [-4.1956825,  4.6662884],
       [-3.3402123,  3.8599875],
       [-2.4725065,  2.8076339],
       [-4.437647 ,  5.302191 ],
       [-4.1420946,  4.736328 ],
       [-2.4322925,  3.0357833],
       [-3.4101582,  3.8587825]])
p6 = np.array([[-4.233863  ,  3.7813666 ],
       [-1.6022581 ,  1.2509221 ],
       [-3.7062836 ,  3.3190281 ],
       [-4.1410913 ,  3.7958035 ],
       [-3.1741126 ,  2.8784363 ],
       [-1.8122883 ,  1.5302601 ],
       [-3.6508923 ,  3.201008  ],
       [-2.4756792 ,  1.9590744 ],
       [-2.5117185 ,  2.1222913 ],
       [-3.9538887 ,  3.6661747 ],
       [-4.596003  ,  3.9455087 ],
       [-2.9618855 ,  2.4459069 ],
       [-1.5618628 ,  1.1708493 ],
       [-3.517534  ,  3.2123692 ],
       [-3.7762914 ,  3.3270233 ],
       [-0.98137105,  0.5727222 ],
       [-4.3943396 ,  4.0746307 ],
       [-4.2565136 ,  3.8912156 ],
       [-3.6053953 ,  3.129113  ],
       [-2.7658763 ,  2.4037802 ]])

In [None]:
# Average over all probabilites

np.mean([p1, p2, p3, p4, p5, p6], axis=0)

In [None]:
# Average for the three runs that were not degenerate 
# i.e. predicted not covid at least once

np.mean([p1, p2, p4], axis=0)