In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

In [None]:
import logging

# Send INFO-and-above log records to test_log.txt, truncating it on each kernel run.
logging.basicConfig(filename="test_log.txt", filemode="w", level=logging.INFO)

In [None]:
import os
from tqdm import tqdm

In [None]:
# Root folder that contains the CT_COVID / CT_NonCOVID image sub-directories
# (relative to the notebook's working directory).
data_dir = 'unzip_data'

In [None]:
# Pin torchvision; presumably chosen to match the torch 1.11 build reported by the next cell — TODO confirm.
!pip install torchvision==0.12.0

Defaulting to user installation because normal site-packages is not writeable


In [None]:
# Show the installed PyTorch version (to check it against the torchvision pin above).
torch.__version__

'1.11.0+cu102'

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
!pip install tqdm
from tqdm.notebook import tqdm
import os
import copy
import pandas as pd
import PIL 
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import skimage
  
# Use the first CUDA device when one is visible; otherwise fall back to the CPU
# and tell the user how to enable a GPU runtime.
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda:0" if _cuda_ok else "cpu")
if _cuda_ok:
    print("Using the GPU!")
else:
    for _msg in ("WARNING: Could not find GPU! Using CPU only",
                 "You may want to try to use the GPU in Google Colab by clicking in:",
                 "Runtime > Change Runtime type > Hardware accelerator > GPU."):
        print(_msg)

Defaulting to user installation because normal site-packages is not writeable
Using the GPU!


In [None]:
# Per-channel normalisation constants. The same value is repeated for all three
# channels; presumably these are statistics of the grayscale CT images after
# conversion to RGB — TODO confirm how they were computed.
_CT_MEAN = [0.45271412] * 3
_CT_STD = [0.33165374] * 3
normalize = transforms.Normalize(mean=_CT_MEAN, std=_CT_STD)

# Training pipeline: random crop covering 50-100% of the resized image, random
# horizontal flip, and mild brightness/contrast jitter before normalisation.
train_transformer = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224, scale=(0.5, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    normalize,
])

# Evaluation pipeline: deterministic resize straight to 224x224 (aspect ratio not kept).
val_transformer = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize,
])

In [None]:
# Build the datasets from the txt split files, so the train/val/test split
# matches the one used by the original paper.

batchsize = 4  # mini-batch size shared by every DataLoader in this notebook
def read_txt(txt_path):
    """Read a split file and return its entries, one per non-blank line.

    Args:
        txt_path: path to a text file listing one image file name per line.

    Returns:
        list[str]: the stripped lines, in file order. Blank or whitespace-only
        lines are skipped: they would otherwise become empty names that resolve
        to the class directory itself and fail later when the image is opened.
    """
    with open(txt_path) as f:
        return [line.strip() for line in f if line.strip()]


class CovidCTDataset(Dataset):
    """COVID-CT slices listed by per-class split files.

    Each item is a dict ``{'img': <transformed image>, 'label': int}`` where the
    label is the index into ``self.classes``: 0 = CT_NonCOVID, 1 = CT_COVID.
    """

    def __init__(self, root_dir, txt_COVID, txt_NonCOVID, transform=None):
        """
        Args:
            root_dir (string): Directory with all the images.
            txt_COVID (string): Path to the split file listing COVID image file
                names (one per line); they are resolved under ``root_dir/CT_COVID``.
            txt_NonCOVID (string): Path to the split file listing non-COVID image
                file names; they are resolved under ``root_dir/CT_NonCOVID``.
            transform (callable, optional): Optional transform applied to each
                RGB PIL image.
        File structure:
        - root_dir
            - CT_COVID
                - img1.png
                - img2.png
                - ......
            - CT_NonCOVID
                - img1.png
                - img2.png
                - ......
        """
        self.root_dir = root_dir
        # Ordered so that the list position is the class label (0 = NonCOVID, 1 = COVID).
        self.txt_path = [txt_NonCOVID, txt_COVID]
        self.classes = ['CT_NonCOVID', 'CT_COVID']
        self.num_cls = len(self.classes)
        self.img_list = []
        for c in range(self.num_cls):
            cls_list = [[os.path.join(self.root_dir, self.classes[c], item), c]
                        for item in read_txt(self.txt_path[c])]
            self.img_list += cls_list
        self.transform = transform

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path, label = self.img_list[idx]
        # Close the file handle deterministically; convert() returns a new,
        # fully loaded image that does not depend on the open file.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)
        sample = {'img': image,
                  'label': int(label)}
        return sample

In [None]:
# Three splits defined by the txt files: training uses the augmenting pipeline,
# validation and test use the deterministic one.
trainset = CovidCTDataset(
    root_dir=data_dir,
    txt_COVID='Data-split/COVID/trainCT_COVID.txt',
    txt_NonCOVID='Data-split/NonCOVID/trainCT_NonCOVID.txt',
    transform=train_transformer,
)
valset = CovidCTDataset(
    root_dir=data_dir,
    txt_COVID='Data-split/COVID/valCT_COVID.txt',
    txt_NonCOVID='Data-split/NonCOVID/valCT_NonCOVID.txt',
    transform=val_transformer,
)
testset = CovidCTDataset(
    root_dir=data_dir,
    txt_COVID='Data-split/COVID/testCT_COVID.txt',
    txt_NonCOVID='Data-split/NonCOVID/testCT_NonCOVID.txt',
    transform=val_transformer,
)

# Split sizes (train, val, test).
for _split in (trainset, valset, testset):
    print(len(_split))

train_loader = DataLoader(trainset, batch_size=batchsize, drop_last=False, shuffle=True)
val_loader = DataLoader(valset, batch_size=batchsize, drop_last=False, shuffle=False)
test_loader = DataLoader(testset, batch_size=batchsize, drop_last=False, shuffle=False)

425
118
203


In [None]:
# Materialise the training split once as [image, label] pairs.
# NOTE(review): train_transformer is random (crop/flip/jitter), but it is applied
# only here, once per image, so every epoch sees the same augmented copy.
# Iterate `trainset` directly instead if per-epoch augmentation is intended.
train_data = []
for i in range(len(trainset)):
    # One __getitem__ call per index; calling it twice (once for the image, once
    # for the label) loaded and transformed every image two times.
    sample = trainset[i]
    train_data.append([sample['img'], sample['label']])

# drop_last=True: with the 425 training images printed above and batchsize 4, the
# last batch would hold a single sample, which the BatchNorm1d in TransferNet's
# bottleneck rejects in train mode.
trainloader = torch.utils.data.DataLoader(train_data, shuffle=True, batch_size=batchsize, drop_last=True)

In [None]:
# Materialise the validation split once as [image, label] pairs
# (val_transformer is deterministic, so precomputing loses nothing).
val_data = []
for i in range(len(valset)):
    sample = valset[i]  # single __getitem__ call per image
    val_data.append([sample['img'], sample['label']])

# drop_last=False: score every validation image. Dropping the trailing partial
# batch (118 % 4 == 2 images) would silently leave them out of the accuracy used
# for model selection. Only eval-mode predict() runs on this loader, so a short
# final batch is fine.
valloader = torch.utils.data.DataLoader(val_data, shuffle=False, batch_size=batchsize, drop_last=False)

In [None]:
# Materialise the test split once as [image, label] pairs (deterministic transform).
test_data = []
for i in range(len(testset)):
    sample = testset[i]  # single __getitem__ call per image
    test_data.append([sample['img'], sample['label']])

# drop_last=False: every test image must be scored. With drop_last=True the last
# 203 % 4 == 3 images were never evaluated, while the accuracy was still divided
# by all 203 samples.
testloader = torch.utils.data.DataLoader(test_data, shuffle=False, batch_size=batchsize, drop_last=False)

In [None]:
# Source-domain images: one sub-folder per class under additional_data, loaded
# with the same augmenting pipeline as the target training split.
source_dataset = datasets.ImageFolder(root='additional_data', transform=train_transformer)

In [None]:
# Display the source-dataset summary (size, root and transform pipeline).
source_dataset

Dataset ImageFolder
    Number of datapoints: 14137
    Root location: additional_data
    StandardTransform
Transform: Compose(
               Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
               RandomResizedCrop(size=(224, 224), scale=(0.5, 1.0), ratio=(0.75, 1.3333), interpolation=bilinear)
               RandomHorizontalFlip(p=0.5)
               ColorJitter(brightness=[0.8, 1.2], contrast=[0.8, 1.2], saturation=None, hue=None)
               ToTensor()
               Normalize(mean=[0.45271412, 0.45271412, 0.45271412], std=[0.33165374, 0.33165374, 0.33165374])
           )

In [None]:
# Class folder names found under additional_data.
source_dataset.classes

['1NonCOVID', '2COVID']

In [None]:
# Folder -> label mapping; it must agree with CovidCTDataset (0 = NonCOVID, 1 = COVID).
source_dataset.class_to_idx

{'1NonCOVID': 0, '2COVID': 1}

In [None]:
# Shuffled source-domain batches; drop_last keeps every batch at `batchsize`
# (14137 images would otherwise leave a final batch of one).
source_loader = torch.utils.data.DataLoader(
    source_dataset,
    batch_size=batchsize,
    shuffle=True,
    drop_last=True,
)

In [None]:
# Display the loader object (sanity check that it was created).
source_loader

<torch.utils.data.dataloader.DataLoader at 0x2afca8faf550>

In [None]:
class MMD_loss(nn.Module):
    """Maximum Mean Discrepancy (MMD) between a source and a target feature batch.

    Args:
        kernel_type: 'rbf' (sum of Gaussian kernels at several bandwidths) or
            'linear' (squared distance between the two batch means).
        kernel_mul: multiplicative step between consecutive Gaussian bandwidths.
        kernel_num: number of Gaussian kernels that are summed.
        fix_sigma: optional fixed base bandwidth. When None (the default, as
            before) the bandwidth is estimated from the batch's mean pairwise
            squared distance.
    """

    def __init__(self, kernel_type='rbf', kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        super(MMD_loss, self).__init__()
        self.kernel_num = kernel_num
        self.kernel_mul = kernel_mul
        self.fix_sigma = fix_sigma
        self.kernel_type = kernel_type

    def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        """Summed multi-bandwidth Gaussian kernel matrix of ``[source; target]``.

        ``source`` (n_s, d) and ``target`` (n_t, d) are 2-D feature tensors; the
        result has shape (n_s + n_t, n_s + n_t).
        """
        n_samples = int(source.size(0)) + int(target.size(0))
        total = torch.cat([source, target], dim=0)
        # Pairwise squared Euclidean distances via broadcasting:
        # entry [i, j] is ||total[j] - total[i]||^2.
        L2_distance = ((total.unsqueeze(0) - total.unsqueeze(1)) ** 2).sum(2)
        if fix_sigma:
            bandwidth = fix_sigma
        else:
            # Mean off-diagonal squared distance; detached so the bandwidth is
            # treated as a constant during back-propagation.
            bandwidth = torch.sum(L2_distance.detach()) / (n_samples ** 2 - n_samples)
        # Centre the geometric ladder of bandwidths on the estimate above.
        bandwidth = bandwidth / (kernel_mul ** (kernel_num // 2))
        bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
        kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
        return sum(kernel_val)

    # Correctly spelled alias; the original (misspelled) name is kept for callers.
    gaussian_kernel = guassian_kernel

    def linear_mmd2(self, f_of_X, f_of_Y):
        """Squared Euclidean distance between the mean feature vectors."""
        delta = f_of_X.float().mean(0) - f_of_Y.float().mean(0)
        return torch.dot(delta, delta)

    def forward(self, source, target):
        """Return the MMD estimate between ``source`` and ``target`` as a 0-dim tensor.

        Raises:
            ValueError: if ``kernel_type`` is neither 'linear' nor 'rbf'
                (previously this silently returned None).
        """
        if self.kernel_type == 'linear':
            return self.linear_mmd2(source, target)
        if self.kernel_type == 'rbf':
            n_source = int(source.size(0))
            kernels = self.guassian_kernel(
                source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)
            XX = torch.mean(kernels[:n_source, :n_source])
            YY = torch.mean(kernels[n_source:, n_source:])
            XY = torch.mean(kernels[:n_source, n_source:])
            YX = torch.mean(kernels[n_source:, :n_source])
            return XX + YY - XY - YX
        raise ValueError(f"Unsupported kernel_type {self.kernel_type!r}; expected 'rbf' or 'linear'")

In [None]:
def CORAL(source, target):
    """CORAL loss: distance between the feature covariances of two batches.

    Args:
        source: (ns, d) tensor of source features.
        target: (nt, d) tensor of target features with the same d.

    Returns:
        0-dim tensor ``||C_s - C_t||_F / (4 * d^2)`` with unbiased covariances.
        NOTE(review): Deep CORAL (Sun & Saenko, 2016) uses the *squared*
        Frobenius norm; this implementation takes the square root — kept as-is.

    Raises:
        ValueError: if either batch has fewer than 2 rows (the unbiased
            covariance would divide by zero).
    """
    d = source.size(1)
    ns, nt = source.size(0), target.size(0)
    if ns < 2 or nt < 2:
        raise ValueError(f"CORAL needs at least 2 samples per batch, got ns={ns}, nt={nt}")

    # Row vectors of ones are created on the inputs' own device/dtype instead of
    # a hard-coded .cuda(), so the loss also works on CPU.
    # source covariance
    tmp_s = torch.ones((1, ns), device=source.device, dtype=source.dtype) @ source
    cs = (source.t() @ source - (tmp_s.t() @ tmp_s) / ns) / (ns - 1)

    # target covariance
    tmp_t = torch.ones((1, nt), device=target.device, dtype=target.dtype) @ target
    ct = (target.t() @ target - (tmp_t.t() @ tmp_t) / nt) / (nt - 1)

    # frobenius norm
    loss = (cs - ct).pow(2).sum().sqrt()
    return loss / (4 * d * d)

In [None]:
class ResNet50Fc(nn.Module):
    """ImageNet-pretrained ResNet-50 with its final fc layer removed.

    ``forward`` returns the trunk output flattened to (batch, -1);
    ``output_num()`` reports the in_features of the removed fc layer.
    """

    def __init__(self):
        super(ResNet50Fc, self).__init__()
        model_resnet50 = models.resnet50(pretrained=True)
        # All children except the last one (the fc classification head).
        self.newmodel = torch.nn.Sequential(*(list(model_resnet50.children())[:-1]))
        self.__in_features = model_resnet50.fc.in_features

    def forward(self, x):
        x = self.newmodel(x)
        # The leftover debugging print(x.shape), which ran on every forward pass,
        # has been removed.
        return x.view(x.size(0), -1)

    def output_num(self):
        return self.__in_features

In [None]:
class ResNet101Fc(nn.Module):
    """ImageNet-pretrained ResNet-101 trunk (fc head dropped) used as a feature extractor.

    ``forward`` flattens the trunk output to (batch, -1); ``output_num`` returns
    the in_features of the dropped fc layer.
    """

    def __init__(self):
        super(ResNet101Fc, self).__init__()
        backbone = models.resnet101(pretrained=True)
        # Keep every child module except the last one (the fc head).
        trunk_layers = list(backbone.children())[:-1]
        self.newmodel = nn.Sequential(*trunk_layers)
        self.__in_features = backbone.fc.in_features

    def forward(self, x):
        features = self.newmodel(x)
        return features.view(features.size(0), -1)

    def output_num(self):
        return self.__in_features

In [None]:
class ResNet18Fc(nn.Module):
    """Feature extractor: ImageNet-pretrained ResNet-18 without its fc head.

    ``output_num()`` gives the in_features of the removed head, i.e. the width
    that TransferNet builds its bottleneck and classifier for.
    """

    def __init__(self):
        super(ResNet18Fc, self).__init__()
        resnet = models.resnet18(pretrained=True)
        # Split off the last child (the fc head); everything before it is the trunk.
        *trunk, _fc_head = resnet.children()
        self.newmodel = torch.nn.Sequential(*trunk)
        self.__in_features = resnet.fc.in_features

    def output_num(self):
        return self.__in_features

    def forward(self, x):
        pooled = self.newmodel(x)
        batch = pooled.size(0)
        return pooled.view(batch, -1)

In [None]:
class TransferNet(nn.Module):
    """Pretrained backbone + linear classifier trained with a domain-adaptation loss.

    Args:
        num_class: number of output classes.
        base_net: 'resnet50' or 'resnet101'; any other value (including the
            default 'resnet18') selects ResNet-18.
        transfer_loss: 'mmd' or 'coral'; any other value yields a zero transfer loss.
        use_bottleneck: if True, the transfer loss is computed on bottleneck
            features (Linear -> BatchNorm1d -> ReLU -> Dropout) instead of the
            raw backbone features.
        bottleneck_width: output width of the bottleneck.
        width: currently unused (kept for interface compatibility).

    Note: the classifier always consumes the raw backbone features; the
    bottleneck only feeds the transfer loss.
    """

    def __init__(self,
                 num_class,
                 base_net='resnet18',
                 transfer_loss='mmd',
                 use_bottleneck=True,
                 bottleneck_width=256,
                 width=1024):
        super(TransferNet, self).__init__()
        if base_net == 'resnet50':
            self.base_network = ResNet50Fc()
        elif base_net == 'resnet101':
            # Previously 'resnet101' silently fell through to ResNet-18.
            self.base_network = ResNet101Fc()
        else:
            self.base_network = ResNet18Fc()
        self.use_bottleneck = use_bottleneck
        self.transfer_loss = transfer_loss
        bottleneck_list = [nn.Linear(self.base_network.output_num(), bottleneck_width),
                           nn.BatchNorm1d(bottleneck_width), nn.ReLU(), nn.Dropout(0.5)]
        self.bottleneck_layer = nn.Sequential(*bottleneck_list)
        classifier_layer_list = [nn.Linear(self.base_network.output_num(), num_class)]
        self.classifier_layer = nn.Sequential(*classifier_layer_list)

        self.bottleneck_layer[0].weight.data.normal_(0, 0.005)
        self.bottleneck_layer[0].bias.data.fill_(0.1)

    def forward(self, source, target):
        """Return ``(source_logits, transfer_loss, target_logits)`` for one paired step."""
        source = self.base_network(source)
        target = self.base_network(target)
        source_clf = self.classifier_layer(source)
        # Target logits are returned too, so the caller can add a supervised target loss.
        target_clf = self.classifier_layer(target)
        if self.use_bottleneck:
            source = self.bottleneck_layer(source)
            target = self.bottleneck_layer(target)
        transfer_loss = self.adapt_loss(source, target, self.transfer_loss)
        return source_clf, transfer_loss, target_clf

    def predict(self, x):
        """Class logits for ``x`` (backbone + classifier, no bottleneck)."""
        features = self.base_network(x)
        clf = self.classifier_layer(features)
        return clf

    def adapt_loss(self, X, Y, adapt_loss):
        """Compute adaptation loss, currently we support mmd and coral

        Arguments:
            X {tensor} -- source matrix
            Y {tensor} -- target matrix
            adapt_loss {string} -- loss type, 'mmd' or 'coral'. You can add your own loss

        Returns:
            [tensor] -- adaptation loss tensor; a 0-dim zero tensor on X's
            device/dtype for unsupported types (previously the int 0, which broke
            callers that call ``.detach()`` / ``.item()`` on the result)
        """
        if adapt_loss == 'mmd':
            mmd_loss = MMD_loss()
            loss = mmd_loss(X, Y)
        elif adapt_loss == 'coral':
            loss = CORAL(X, Y)
        else:
            # Your own loss
            loss = X.new_zeros(())
        return loss

In [None]:
# Experiment configuration.
transfer_loss = 'coral'  # domain-adaptation loss passed to TransferNet: 'mmd' or 'coral'
learning_rate = 0.0001
n_class = 2              # CT_NonCOVID / CT_COVID
lamb = 2.5               # weight of the transfer loss in train()'s total loss
n_epoch = 40             # maximum number of training epochs

# Name the backbone explicitly: the previous base_net='' only produced ResNet-18
# because TransferNet falls back to it for unrecognised names.
transfer_model = TransferNet(n_class, transfer_loss=transfer_loss, base_net='resnet18').cuda()
optimizer = optim.SGD(transfer_model.parameters(), lr=learning_rate, momentum=0.9)

In [None]:
def get_eval_results(model, data_loader):
    """Run ``model.predict`` over ``data_loader`` and collect the data needed for metrics.

    Args:
        model: module exposing ``predict(inputs) -> logits`` of shape (batch, num_classes).
        data_loader: iterable of ``(inputs, labels)`` batches; labels are presumably
            1-D class indices — TODO confirm for other callers.

    Returns:
        tuple of numpy arrays ``(true_labels, predicted_labels, probabilities, images)``:
            true labels with a trailing unit axis, i.e. (N, 1);
            top-1 predicted class indices, (N, 1);
            softmax probabilities, (N, num_classes);
            the input batches concatenated along dim 0.
    """
    model.eval()
    # Evaluate on whatever device the model lives on instead of assuming CUDA.
    device = next(model.parameters()).device
    true_label_list, predicted_label_list = [], []
    outputs_list, original_image_list = [], []

    with torch.no_grad():
        for data, target in data_loader:
            inputs, labels = data.to(device), target.to(device)
            outputs = model.predict(inputs)
            true_label_list.append(labels)
            # Keep the (large) image batches on the host rather than accumulating
            # them in GPU memory; the values are unchanged.
            original_image_list.append(data.cpu())
            outputs_list.append(outputs)
            _, preds = torch.topk(outputs, k=1, dim=1)
            predicted_label_list.append(preds)

    return (torch.cat(true_label_list).unsqueeze(-1).cpu().numpy(),
            torch.cat(predicted_label_list).cpu().numpy(),
            torch.softmax(torch.cat(outputs_list), dim=1).cpu().numpy(),
            torch.cat(original_image_list).cpu().numpy())

In [None]:
def evaluate(model, data_loader, is_labelled = False, generate_labels = True, k = 5):
    """Evaluate ``model.predict`` on ``data_loader``.

    Args:
        model: module exposing ``predict(inputs) -> logits`` of shape (batch, num_classes).
        data_loader: iterable of ``(inputs, labels)`` batches.
        is_labelled: if True, compute the mean cross-entropy loss and top-1 accuracy.
        generate_labels: if True, collect ground-truth labels and top-k predictions.
        k: number of top predictions kept per sample (``torch.topk`` requires
            k <= num_classes).

    Returns:
        ``(epoch_loss, epoch_top1_acc, gt_labels, predicted_labels)``.
        Loss and accuracy are averaged over the samples actually yielded by the
        loader (not ``len(loader.dataset)``, which over-counts when batches are
        dropped), and are None when ``is_labelled`` is False. ``gt_labels`` is
        a flat list of labels; ``predicted_labels`` holds one list of k class
        indices per sample. Both are empty when ``generate_labels`` is False.

    Raises:
        ValueError: if ``is_labelled`` and the loader yields no samples.
    """
    model.eval()
    # Evaluate on the model's own device instead of assuming CUDA.
    device = next(model.parameters()).device
    criterion = nn.CrossEntropyLoss()  # one instance for the whole pass
    running_loss = 0.0
    running_top1_correct = 0
    n_seen = 0
    predicted_labels = []
    gt_labels = []

    with torch.no_grad():
        for data, target in data_loader:
            inputs, labels = data.to(device), target.to(device)
            outputs = model.predict(inputs)
            # Indices of the k largest logits per sample, shape (batch, k).
            _, preds = torch.topk(outputs, k=k, dim=1)

            if generate_labels:
                nparr = preds.cpu().numpy()
                predicted_labels.extend(list(row) for row in nparr)
                gt_labels.extend(np.array(labels.cpu()))

            if is_labelled:
                batch_size = inputs.size(0)
                # CrossEntropyLoss returns the batch mean; weight it by the batch
                # size so the epoch loss is a per-sample mean even when the last
                # batch is smaller.
                running_loss += criterion(outputs, labels).item() * batch_size
                running_top1_correct += int((preds[:, 0] == labels).sum())
                n_seen += batch_size

    # Only compute loss & accuracy if we have the labels
    if is_labelled:
        if n_seen == 0:
            raise ValueError("evaluate(): data_loader yielded no samples")
        epoch_loss = running_loss / n_seen
        epoch_top1_acc = running_top1_correct / n_seen
    else:
        epoch_loss = None
        epoch_top1_acc = None

    return epoch_loss, epoch_top1_acc, gt_labels, predicted_labels

In [None]:
criterion = nn.CrossEntropyLoss()  # classification loss for both source and target logits
early_stop = 20  # epochs without improvement before train() stops

# Key names are what train() reads: 'src' = source-domain training batches,
# 'val' = target-domain *training* batches, 'tar' = target-domain batches used
# for per-epoch model selection.
dataloaders = dict(src=source_loader, val=trainloader, tar=valloader)

In [None]:
def test(model, target_test_loader):
    model.eval()
    correct = 0
    len_target_dataset = len(target_test_loader.dataset)
    with torch.no_grad():
        for data, target in target_test_loader:
            data, target = data.cuda(), target.cuda()
            s_output = model.predict(data)
            pred = torch.max(s_output, 1)[1]
            correct += torch.sum(pred == target)
    acc = correct.double() / len(target_test_loader.dataset)
    return acc

In [None]:
import os

# Make CUDA kernel launches synchronous so errors surface at the failing call.
# A plain Python assignment has no effect; CUDA only reads the *environment
# variable*. NOTE: it is read when the CUDA context is created, so it only takes
# effect if this runs before the first CUDA call (move it to the top of the
# notebook / restart the kernel when debugging).
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
CUDA_LAUNCH_BLOCKING = 1  # kept so any code reading the old name still works

In [None]:
def train(dataloaders, model, optimizer):
    """Jointly train ``model`` on labelled source and target batches, with early stopping.

    Each step pairs one source batch with one target-train batch (``zip`` stops
    at the shorter loader) and minimises
    ``clf_loss(source) + lamb * transfer_loss + clf_loss(target)``.
    After every epoch the model is scored with ``test`` on ``dataloaders['tar']``.
    Whenever that accuracy improves, the weights are saved to ``trans_model.pkl``.
    Training stops after ``n_epoch`` epochs or ``early_stop`` epochs without improvement.

    Reads the notebook-level globals ``n_epoch``, ``lamb``, ``early_stop``,
    ``criterion`` and ``test``.

    Args:
        dataloaders: dict with keys 'src' (source train), 'val' (target train) and
            'tar' (target evaluation); the key names are kept for compatibility.
        model: module with ``model(src, tar) -> (src_logits, transfer_loss, tar_logits)``.
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        The best evaluation accuracy observed (previously nothing was returned).
    """
    source_loader = dataloaders['src']
    target_train_loader = dataloaders['val']
    target_test_loader = dataloaders['tar']
    device = next(model.parameters()).device
    # zip() below stops at the shorter loader, which fixes the steps per epoch;
    # clamp to 1 so an empty loader cannot cause a division by zero in the log.
    n_batch = max(1, min(len(source_loader), len(target_train_loader)))
    # Start below any achievable accuracy so the first epoch always writes a
    # checkpoint; with 0 a run that never beats 0% would never create
    # trans_model.pkl, and the caller's torch.load would fail or load a stale file.
    best_acc = -1.0
    stop = 0
    for e in range(n_epoch):
        stop += 1
        train_loss_clf = train_loss_transfer = train_loss_total = target_loss_ = 0.0
        model.train()
        for (data_source, label_source), (data_target, label_target) in zip(source_loader, target_train_loader):
            data_source, label_source = data_source.to(device), label_source.to(device)
            data_target, label_target = data_target.to(device), label_target.to(device)

            optimizer.zero_grad()
            label_source_pred, transfer_loss, target_clf = model(data_source, data_target)
            clf_loss = criterion(label_source_pred, label_source)
            # Supervised loss on the labelled target batch as well.
            clf_loss_target = criterion(target_clf, label_target)

            loss = clf_loss + lamb * transfer_loss + clf_loss_target
            loss.backward()
            optimizer.step()

            train_loss_clf += clf_loss.item()
            # float() accepts both a tensor and a plain number (e.g. an int 0
            # returned for an unsupported transfer-loss type).
            train_loss_transfer += float(transfer_loss)
            train_loss_total += loss.item()
            target_loss_ += clf_loss_target.item()
        acc = test(model, target_test_loader)
        logging.info(f'Epoch: [{e:2d}/{n_epoch}], target_loss_cls_:{target_loss_/n_batch:.4f}, loss: {train_loss_clf/n_batch:.4f}, transfer_loss: {train_loss_transfer/n_batch:.4f}, total_Loss: {train_loss_total/n_batch:.4f}, acc: {acc:.4f}')
        if best_acc < acc:
            best_acc = acc
            torch.save(model.state_dict(), 'trans_model.pkl')
            stop = 0
        if stop >= early_stop:
            break
    return best_acc

In [None]:
def calc_prec_recall(y_true, y_pred):
    """Precision and recall for the positive class (label 1).

    Args:
        y_true, y_pred: array-likes of class labels with the same shape.

    Returns:
        ``(precision, recall)``. By convention each is 1 when its denominator is 0
        (no predicted positives / no actual positives).
    """
    true_pos = np.asarray(y_true) == 1
    pred_pos = np.asarray(y_pred) == 1
    TP = (true_pos & pred_pos).sum()
    FP = (pred_pos & ~true_pos).sum()
    FN = (~pred_pos & true_pos).sum()

    precision = TP / (TP + FP) if TP + FP > 0 else 1
    recall = TP / (TP + FN) if TP + FN > 0 else 1
    return precision, recall


def calc_f1_auc(y_true, y_pred, y_score=None, num_thresholds=100):
    """F1 score and area under the precision-recall curve for the positive class.

    Args:
        y_true: true labels (any shape; flattened).
        y_pred: predicted labels with the same number of elements as ``y_true``.
        y_score: positive-class scores, shape (N,) or (N, C); for 2-D input,
            column 1 is used. If None, falls back to the notebook-global ``outputs``
            (softmax probabilities from ``get_eval_results``) for backward
            compatibility.
        num_thresholds: number of evenly spaced score thresholds in [0, 1].

    Returns:
        ``(f1, area)``. ``f1`` is 0 when precision + recall == 0. ``area`` is the
        trapezoidal area under the precision-recall points obtained by
        thresholding the scores (a PR-AUC, not a ROC-AUC).
    """
    # Previously the global `y_label` was used here instead of the `y_true` argument.
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    if y_score is None:
        # NOTE(review): legacy path that relies on a global set by the caller.
        y_score = outputs
    y_score = np.asarray(y_score)
    if y_score.ndim == 2:
        y_score = y_score[:, 1]
    y_score = y_score.ravel()

    precision, recall = calc_prec_recall(y_true, y_pred)
    f1 = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0.0

    precision_scores, recall_scores = [], []
    for p in np.linspace(0, 1, num=num_thresholds, endpoint=True):
        y_new_pred = (y_score > p).astype(int)
        prec, rec = calc_prec_recall(y_true, y_new_pred)
        precision_scores.append(prec)
        recall_scores.append(rec)

    area_val = 0
    for i in range(1, len(precision_scores)):
        area_val += abs((precision_scores[i] + precision_scores[i - 1])
                        * (recall_scores[i] - recall_scores[i - 1]) / 2)

    return f1, area_val

In [None]:
# Record the run configuration in the log.
# NOTE(review): this text is hand-written; keep it in sync with the config cell
# (transfer_loss, learning_rate, lamb, n_epoch) and with the loss used in train().
logging.info("resnet 18, transfer_loss = CORAL, learning_rate = 0.0001, lamb = 2.5, n_epoch = 40, loss=clf_loss + lamb * transfer_loss + clf_loss_target")

In [None]:
test_acc, test_f1, test_auc = [], [], []
n_runs = 4  # independent repeats summarised by the mean/std cell below

for run in range(n_runs):
    # Re-create the model and optimizer so every repeat starts from the pretrained
    # backbone. Without this, runs 2..n continued from the previous run's best
    # checkpoint (and optimizer momentum), so they were not independent repeats.
    transfer_model = TransferNet(n_class, transfer_loss=transfer_loss, base_net='resnet18').cuda()
    optimizer = optim.SGD(transfer_model.parameters(), lr=learning_rate, momentum=0.9)

    train(dataloaders, transfer_model, optimizer)
    # Restore the best checkpoint written by train().
    transfer_model.load_state_dict(torch.load('trans_model.pkl'))
    acc_test = test(transfer_model, testloader)
    logging.info(f'Test accuracy: {acc_test}')
    test_loss_yours, test_top1_yours, _, test_labels_yours = evaluate(transfer_model, testloader, is_labelled = True, generate_labels = True, k = 1)

    print("Our Trained model: ")
    print("Test Top-1 Accuracy: {}".format(test_top1_yours))
    test_acc.append(test_top1_yours)

    # calc_f1_auc also reads the probabilities through the global name `outputs` set here.
    y_label, y_pred, outputs, inputs = get_eval_results(transfer_model, testloader)
    f1, auc = calc_f1_auc(y_label, y_pred)

    test_f1.append(f1)
    test_auc.append(auc)

    print("f1 score is :", f1)
    print("AUC score is ", auc)

In [None]:
# Summarise the repeated runs: mean line first, then standard deviations.
acc_mean, acc_std = np.mean(test_acc), np.std(test_acc)
f1_mean, f1_std = np.mean(test_f1), np.std(test_f1)
auc_mean, auc_std = np.mean(test_auc), np.std(test_auc)
logging.info(f'Test accuracy: {acc_mean:.4f}, f1 score:{f1_mean:.4f}, AUC score:{auc_mean:.4f}')
logging.info(f'std accuracy: {acc_std:.4f}, f1 score:{f1_std:.4f}, AUC score:{auc_std:.4f}')