In [1]:
# numpy, scipy, pandas, sklearn, matplotlib
import numpy as np
import pandas as pd
from scipy.stats import entropy
# from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt

# pytorch and pytorch lightning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision 
from torchvision import datasets
import torchvision.transforms as transforms
!pip install torchsummary
from torchsummary import summary
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
    !pip install --quiet pytorch-lightning>=1.4
    import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, EarlyStopping

# others
import os
from tqdm import tqdm_notebook as tqdm
import time
import warnings
warnings.simplefilter("ignore")
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

Collecting torchsummary
  Downloading torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1
[0m

In [2]:
# Setting the seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)
# print("Number of workers:", NUM_WORKERS)

Device: cuda:0


In [3]:
# global constants
IMAGE_SIZE = (224, 224)
NUM_CLASSES = 3
BATCH_SIZE = 32

In [4]:
# dataset path
deploy_path = '../input/covid19radiographydatabaseedited/COVID-19_Radiography_Dataset'

# import data
image_transform = transforms.Compose([transforms.Resize(IMAGE_SIZE), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])
deployset = datasets.ImageFolder(deploy_path,transform = image_transform)

# reduce size of deployment data (for computational time and memory)
n_deployment = 800
deploy_idx = np.random.choice(len(deployset.targets), size=n_deployment, replace=False)
deployset = Subset(deployset, deploy_idx)

# print("Training set size: {}\nValidation set size: {}\nDeployment set size: {}".format(len(trainset.targets),len(valset.targets),n_deployment))
# print(deployset.class_to_idx)

In [5]:
class SimCLR(pl.LightningModule):

    def __init__(self, hidden_dim, lr, temperature, weight_decay, pretrained_model=None, max_epochs=500):
        super().__init__()
        self.save_hyperparameters()
        assert self.hparams.temperature > 0.0, 'The temperature must be a positive float!'
        # Base model f(.)
        self.model = pretrained_model if pretrained_model is not None else torchvision.models.resnet50(pretrained=True)  # Output of last linear layer: 2048-dim representation
        # print("Hi\n{}".format(list(self.model.children())))
        # The MLP for g(.) consists of Linear->ReLU->Linear
        # print(self.model.fc)
        # print(len(list(list(self.model.children())[-1].children())))
        # print(list(self.model.children())[-1])
        # print(list(list(self.model.children())[-1].children()))
        if len(list(list(self.model.children())[-1].children())) == 0:
            self.model = nn.Sequential(*(list(self.model.children())[:-1]),
                                        nn.Flatten(),
                                        nn.Linear(self.model.fc.in_features, 2048),
                                        nn.ReLU(inplace=True),
                                        nn.Linear(2048, hidden_dim))
            print("done")
        # print("Hihi\n{}".format(list(self.model.children())))
    
    def forward(self, x):
        return self.model(x)
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(),
                                lr=self.hparams.lr,
                                weight_decay=self.hparams.weight_decay)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                            T_max=self.hparams.max_epochs,
                                                            eta_min=self.hparams.lr/50) # search what does it refer to
        return [optimizer], [lr_scheduler]

    def info_nce_loss(self, batch, mode='train'):
        imgs, _ = batch
        # print(imgs.shape)
        imgs = torch.cat(imgs, dim=0)
        # print(imgs.shape)

        # Encode all images
        feats = self(imgs)
        # print(feats.shape)
        # Calculate cosine similarity
        cos_sim = F.cosine_similarity(feats[:,None,:], feats[None,:,:], dim=-1)
        # Mask out cosine similarity to itself
        self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool, device=cos_sim.device)
        cos_sim.masked_fill_(self_mask, -9e15)
        # Find positive example -> batch_size//2 away from the original example
        pos_mask = self_mask.roll(shifts=cos_sim.shape[0]//2, dims=0)
        # InfoNCE loss
        cos_sim = cos_sim / self.hparams.temperature
        nll = -cos_sim[pos_mask] + torch.logsumexp(cos_sim, dim=-1)
        nll = nll.mean()

        # Logging loss
        print(mode+'_loss: {:.4f}'.format(nll))
        self.log(mode+'_loss', nll)
        # Get ranking position of positive example
        comb_sim = torch.cat([cos_sim[pos_mask][:,None],  # First position positive example
                              cos_sim.masked_fill(pos_mask, -9e15)],
                             dim=-1)
        sim_argsort = comb_sim.argsort(dim=-1, descending=True).argmin(dim=-1)
        # Logging ranking metrics
        self.log(mode+'_acc_top1', (sim_argsort == 0).float().mean())
        self.log(mode+'_acc_top5', (sim_argsort < 5).float().mean())
        self.log(mode+'_acc_mean_pos', 1+sim_argsort.float().mean())

        return nll

    def training_step(self, batch, batch_idx):
        return self.info_nce_loss(batch, mode='train')

    def validation_step(self, batch, batch_idx):
        self.info_nce_loss(batch, mode='val')

In [6]:
simclr_model = SimCLR.load_from_checkpoint("../input/ssrchestxray/ssr-chest-x-ray.ckpt") # load pre-trained model
feature_model = torch.nn.Sequential(*(list(simclr_model.model.children())[:-3]))
# feature_model = simclr_model

FEATURE_SIZE = 2048

done


In [7]:
deploy_loader = DataLoader(deployset,batch_size=BATCH_SIZE,shuffle=False,num_workers=4)
# deploy_values = torch.empty(n_deployment, FEATURE_SIZE)
# deploy_labels = torch.empty(n_deployment)
deploy_values = torch.empty(n_deployment, FEATURE_SIZE)
deploy_labels = torch.empty(n_deployment)

for i, (images, labels) in enumerate(tqdm(deploy_loader, total=int(len(deploy_loader)))):
    deploy_values[BATCH_SIZE*i:BATCH_SIZE*i+images.shape[0],:] = feature_model(images).detach().cpu()
    deploy_labels[BATCH_SIZE*i:BATCH_SIZE*i+images.shape[0]] = labels

  0%|          | 0/25 [00:00<?, ?it/s]

In [8]:
# print(deploy_values.shape)
# print(deploy_labels)
# print(torch.sum(deploy_labels == 0.), torch.sum(deploy_labels == 1.), torch.sum(deploy_labels == 2.))

In [9]:
# class distribution
print("Number of class 0 (COVID) in the dataset: {}\nNumber of class 1 (Normal) in the dataset: {}\nNumber of class 2 (Viral Pneumonia) in the dataset: {}".format(torch.sum(deploy_labels == 0.).item(), torch.sum(deploy_labels == 1.).item(), torch.sum(deploy_labels == 2.).item()))

Number of class 0 (COVID) in the dataset: 194
Number of class 1 (Normal) in the dataset: 536
Number of class 2 (Viral Pneumonia) in the dataset: 70


In [10]:
cos_sim = nn.CosineSimilarity(dim = 1)
cos_sim_matrix = F.cosine_similarity(deploy_values[:,None,:], deploy_values[None,:,:], dim=-1)

In [11]:
# print(cos_sim_matrix)

In [12]:
# self mask
self_mask = torch.eye(cos_sim_matrix.shape[0], dtype=torch.bool, device=cos_sim_matrix.device)
cos_sim_matrix.masked_fill_(self_mask, -2)
# print(cos_sim_matrix)

tensor([[-2.0000,  0.8809,  0.5422,  ...,  0.7102,  0.9351,  0.9827],
        [ 0.8809, -2.0000,  0.5649,  ...,  0.9029,  0.9681,  0.9024],
        [ 0.5422,  0.5649, -2.0000,  ...,  0.7108,  0.4944,  0.4618],
        ...,
        [ 0.7102,  0.9029,  0.7108,  ..., -2.0000,  0.8090,  0.7036],
        [ 0.9351,  0.9681,  0.4944,  ...,  0.8090, -2.0000,  0.9649],
        [ 0.9827,  0.9024,  0.4618,  ...,  0.7036,  0.9649, -2.0000]])

In [13]:
# active learning step
SELECTION_BATCH_SIZE = 16
MIN_ALL_CLASSES = 2 # alternative: edit algorithm so random selection happens with a probability (exploration-exploitation trade-off)
selection_idx = []
least_selected = None

for i in range(int(0.5*n_deployment/SELECTION_BATCH_SIZE)):
    print("Selection batch {}".format(i+1))
    if i == 0 or torch.sum(deploy_labels[selection_idx] == least_selected).item() < MIN_ALL_CLASSES:
        # search until all classes have at least 1 data
        new_selection = np.random.choice(cos_sim_matrix.shape[0], size=SELECTION_BATCH_SIZE, replace=False)
    else:
        least_selected_pos = np.where(deploy_labels[selection_idx] == least_selected)[0] # get the positions in selection_idx for the least selected class
        least_selected_idx = selection_idx[np.random.choice(least_selected_pos, size=1)[0]] # get an index belonging to the least selected class
        sim_array = cos_sim_matrix[least_selected_idx,:] # get the similarity array for the specific index
        sim_array_rank = torch.argsort(sim_array, descending=True) # get the rank of each image in terms of similarity
        # print(sim_array) # sanity check
        # print(sim_array_rank) # sanity check
        new_selection = sim_array_rank[:SELECTION_BATCH_SIZE]
    for idx in new_selection.tolist():
        selection_idx.append(idx)
    # mask the index so they won't be selected again
    new_selection_mask = torch.Tensor([(i in selection_idx) for i in range(cos_sim_matrix.shape[0])]).type(torch.BoolTensor)
    cos_sim_matrix.masked_fill_(new_selection_mask, -2) # mask by columns
    
    least_selected = np.argmin([torch.sum(deploy_labels[selection_idx] == 0.).item(), torch.sum(deploy_labels[selection_idx] == 1.).item(), torch.sum(deploy_labels[selection_idx] == 2.).item()])
    print("Number of class 0 (COVID) selected: {}\nNumber of class 1 (Normal) selected: {}\nNumber of class 2 (Viral Pneumonia) selected: {}".format(torch.sum(deploy_labels[selection_idx] == 0.).item(), torch.sum(deploy_labels[selection_idx] == 1.).item(), torch.sum(deploy_labels[selection_idx] == 2.).item()))
    print("The least selected class is class {}".format(least_selected)) # sanity check

print(selection_idx)

Selection batch 1
Number of class 0 (COVID) selected: 1
Number of class 1 (Normal) selected: 13
Number of class 2 (Viral Pneumonia) selected: 2
The least selected class is class 0
Selection batch 2
Number of class 0 (COVID) selected: 4
Number of class 1 (Normal) selected: 25
Number of class 2 (Viral Pneumonia) selected: 3
The least selected class is class 2
Selection batch 3
Number of class 0 (COVID) selected: 10
Number of class 1 (Normal) selected: 26
Number of class 2 (Viral Pneumonia) selected: 12
The least selected class is class 0
Selection batch 4
Number of class 0 (COVID) selected: 17
Number of class 1 (Normal) selected: 34
Number of class 2 (Viral Pneumonia) selected: 13
The least selected class is class 2
Selection batch 5
Number of class 0 (COVID) selected: 23
Number of class 1 (Normal) selected: 44
Number of class 2 (Viral Pneumonia) selected: 13
The least selected class is class 2
Selection batch 6
Number of class 0 (COVID) selected: 31
Number of class 1 (Normal) selected: 