In [None]:
# Report the NVIDIA GPU(s) available to this runtime.
!nvidia-smi

In [None]:
from google.colab import drive
# Mount Google Drive at /gdrive; force_remount=True remounts even if it is already mounted.
drive.mount('/gdrive', force_remount=True)

In [None]:
# NOTE(review): Drive is already mounted at /gdrive by the previous cell and no visible
# cell reads from /content/drive — this second mount looks redundant; confirm before removing.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Extract the dataset archive from the mounted Drive into ./content.
# NOTE(review): later cells read from ./content/COVID-19_Radiography_Dataset/ — confirm the
# archive really unpacks to that folder name.
!unzip /gdrive/MyDrive/COVID-19_Radiography_Database_old.zip -d ./content

In [None]:
import os
import sys
import random
import math
import numpy as np
import torch
import torch.nn as nn
import torchvision
import matplotlib.pyplot as plt
import glob
import cv2
import tqdm
import time
import torch.nn.functional as F
from torch.utils.data import Dataset
import torchvision
from PIL import Image

In [None]:
def _min_max_scaling(img):
    return (img-np.min(img)) / (np.max(img)-np.min(img))

In [None]:
path = './content/COVID-19_Radiography_Dataset/COVID/'
covid_images = sorted(glob.glob(path + '*.png'))

# Pick one random COVID image. random.randint(0, n) includes n itself and would
# occasionally raise IndexError, so draw from range(n) instead.
idx = random.randrange(len(covid_images))

# --- image exactly as loaded by OpenCV: value range + picture + intensity histogram ---
img = cv2.imread(covid_images[idx])
print("image min value : ",img.min())
print("image max value : ",img.max())
plt.subplot(121)
plt.imshow(img)
plt.subplot(122)
plt.hist(img.flatten())
plt.show()

# --- same image after min-max scaling, for comparison ---
img = cv2.imread(covid_images[idx])
img = _min_max_scaling(img)
print("image min value : ",img.min())
print("image max value : ",img.max())
plt.subplot(121)
plt.imshow(img)
plt.subplot(122)
plt.hist(img.flatten())
plt.show()
    

In [None]:
path = './content/COVID-19_Radiography_Dataset/COVID/'
covid_images = sorted(glob.glob(path + '*.png'))

# randrange excludes the upper bound; randint(0, len(...)) could index one past the end.
idx = random.randrange(len(covid_images))

# Show a random COVID image after min-max scaling, next to its intensity histogram.
img = cv2.imread(covid_images[idx])
img = _min_max_scaling(img)
print("image min value : ",img.min())
print("image max value : ",img.max())
plt.subplot(121)
plt.imshow(img)
plt.subplot(122)
plt.hist(img.flatten())
plt.show()
    

In [None]:
path = './content/COVID-19_Radiography_Dataset/Normal/'
normal_images = sorted(glob.glob(path + '*.png'))

# randrange excludes the upper bound; randint(0, len(...)) could index one past the end.
idx = random.randrange(len(normal_images))

# Show a random Normal image as loaded by OpenCV, next to its intensity histogram.
img = cv2.imread(normal_images[idx])
print("image min value : ",img.min())
print("image max value : ",img.max())
plt.subplot(121)
plt.imshow(img)
plt.subplot(122)
plt.hist(img.flatten())
plt.show()

In [None]:
path = './content/COVID-19_Radiography_Dataset/Viral Pneumonia/'
pneumonia_images = sorted(glob.glob(path + '*.png'))

# randrange excludes the upper bound; randint(0, len(...)) could index one past the end.
idx = random.randrange(len(pneumonia_images))

# Show a random Viral Pneumonia image as loaded by OpenCV, next to its intensity histogram.
img = cv2.imread(pneumonia_images[idx])
print("image min value : ",img.min())
print("image max value : ",img.max())
plt.subplot(121)
plt.imshow(img)
plt.subplot(122)
plt.hist(img.flatten())
plt.show()

In [None]:
# Root folder of the extracted dataset; each class lives in its own sub-folder.
path = './content/COVID-19_Radiography_Dataset/'

# Sorted so the file order (and therefore every later split) is the same on each run.
normal_images = sorted(glob.glob(path + 'Normal/*.png'))
covid_images = sorted(glob.glob(path + 'COVID/*.png'))
opacity_images = sorted(glob.glob(path + 'Lung_Opacity/*.png'))
pneumonia_images = sorted(glob.glob(path + 'Viral Pneumonia/*.png'))

In [None]:
def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Args:
        seed: integer seed applied to Python's ``random``, the ``PYTHONHASHSEED``
            environment variable, NumPy, and PyTorch (CPU and all CUDA devices).
    """
    random.seed(seed)  # Python's built-in RNG
    os.environ['PYTHONHASHSEED'] = str(seed)  # hash seed exported through the environment
    np.random.seed(seed)  # NumPy's global RNG
    torch.manual_seed(seed)  # PyTorch RNG
    torch.cuda.manual_seed(seed)  # current CUDA device
    torch.cuda.manual_seed_all(seed)  # every CUDA device, so multi-GPU runs are covered too
    torch.backends.cudnn.deterministic = True  # restrict cuDNN to deterministic algorithms
    # Disable the cuDNN auto-tuner: when enabled it benchmarks and may pick a
    # different (fastest) algorithm from run to run.
    torch.backends.cudnn.benchmark = False

In [None]:
from sklearn.model_selection import train_test_split
# Single seed shared by seed_everything() and by every train_test_split call in this notebook.
random_stat = 42
seed_everything(random_stat)

In [None]:
def _split_train_val_test(paths, seed):
    """Hold out 40 % of ``paths``, then halve the hold-out into validation and test."""
    train_part, held_out = train_test_split(paths, test_size=0.4, random_state=seed)
    val_part, test_part = train_test_split(held_out, test_size=0.5, random_state=seed)
    return train_part, val_part, test_part


def _print_split_sizes(prefix, train_part, val_part, test_part):
    """Print the size of each split for one class, followed by a separator line."""
    print(prefix + "_train_list :" , len(train_part))
    print(prefix + "_val_list :" , len(val_part))
    print(prefix + "_test_list :" , len(test_part))
    print('-'*20)


# The same split recipe is applied independently to each class.
pneumonia_train_list, pneumonia_val_list, pneumonia_test_list = _split_train_val_test(pneumonia_images, random_stat)
_print_split_sizes("pneumonia", pneumonia_train_list, pneumonia_val_list, pneumonia_test_list)

normal_train_list, normal_val_list, normal_test_list = _split_train_val_test(normal_images, random_stat)
_print_split_sizes("normal", normal_train_list, normal_val_list, normal_test_list)

covid_train_list, covid_val_list, covid_test_list = _split_train_val_test(covid_images, random_stat)
_print_split_sizes("covid", covid_train_list, covid_val_list, covid_test_list)

opacity_train_list, opacity_val_list, opacity_test_list = _split_train_val_test(opacity_images, random_stat)
_print_split_sizes("opacity", opacity_train_list, opacity_val_list, opacity_test_list)

In [None]:
# Every class contributes at most len(pneumonia_test_list) images to each split.
# (Disabled alternative kept from the original notebook: concatenate the full per-class
#  train/val lists instead of truncating them.)
per_class = len(pneumonia_test_list)

overall_train_list = (
    normal_train_list[:per_class]
    + covid_train_list[:per_class]
    + opacity_train_list[:per_class]
    + pneumonia_train_list[:per_class]
)
overall_val_list = (
    normal_val_list[:per_class]
    + covid_val_list[:per_class]
    + opacity_val_list[:per_class]
    + pneumonia_val_list[:per_class]
)
overall_test_list = (
    normal_test_list[:per_class]
    + covid_test_list[:per_class]
    + opacity_test_list[:per_class]
    + pneumonia_test_list
)

print("overall_train_list :" , len(overall_train_list))
print("overall_val_list :" , len(overall_val_list))
print("overall_test_list :" , len(overall_test_list))

In [None]:
# Install albumentations straight from its GitHub repository.
# NOTE(review): this is unpinned, so the installed version can change between runs — consider pinning.
!pip install -U git+https://github.com/albumentations-team/albumentations --quiet

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Feature generalization --> makes the model robust.

In [None]:
class DiseaseDataset(Dataset):
    """Chest X-ray dataset yielding ``(image_tensor, label)`` pairs.

    The label comes from the name of the folder that contains the image:
    ``Normal`` -> 0, ``COVID`` -> 1, ``Lung_Opacity`` -> 2, anything else -> 3.

    Args:
        phase_list: sequence of image file paths.
        mode: ``'train'`` allows augmentation; any other value uses resize only.
        image_size: side length (pixels) of the square the image is resized to.
        aug: augmentation switch. ``True`` (or the legacy string ``'True'``)
            enables augmentation in train mode.
        transform: accepted for backward compatibility; not used.
    """

    # Parent-folder name -> class index; folders not listed here map to _DEFAULT_LABEL.
    _LABELS = {'Normal': 0, 'COVID': 1, 'Lung_Opacity': 2}
    _DEFAULT_LABEL = 3

    def __init__(self ,phase_list, mode, image_size, aug, transform=None):
        self.mode = mode
        self.image_size = image_size
        self.samples = phase_list
        self.aug = aug

        if mode == 'train' and self._aug_enabled(aug):
            self.transform = A.Compose([
                A.Resize(self.image_size, self.image_size),
                A.OneOf([
                    A.MedianBlur(blur_limit=3, p=0.1),
                    A.MotionBlur(p=0.2),
                    ], p=0.2),
                A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=10, p=0.2),
                A.OneOf([
                    A.OpticalDistortion(p=0.3),
                    ], p=0.2),
                A.OneOf([
                    A.GaussNoise(p=0.2),
                    A.MultiplicativeNoise(p=0.2),
                    ], p=0.2),
                A.HueSaturationValue(hue_shift_limit=0, sat_shift_limit=0, val_shift_limit=0.1, p=0.3),
                ToTensorV2(),
                ])
        else:
            # Validation/test data and un-augmented training data: resize only.
            self.transform = A.Compose([
                A.Resize(self.image_size, self.image_size),
                ToTensorV2(),
                ])

    @staticmethod
    def _aug_enabled(aug):
        """Return True when ``aug`` asks for augmentation.

        The flag used to be compared with the string ``'True'`` only, so the
        boolean ``True`` passed by the notebook never enabled augmentation.
        Strings keep their old meaning; everything else uses its truth value.
        """
        if isinstance(aug, str):
            return aug == 'True'
        return bool(aug)

    @classmethod
    def _label_from_path(cls, path):
        """Map an image path to its class index using the parent folder name."""
        # os.path handles the platform's separator; split('/') only worked with '/'.
        folder = os.path.basename(os.path.dirname(path))
        return cls._LABELS.get(folder, cls._DEFAULT_LABEL)

    def __getitem__(self, idx):
        sample_path = self.samples[idx]
        imgs = self.transform(image=self._preprocessing(sample_path))['image']
        labels = self._label_from_path(sample_path)
        return imgs, labels

    def __len__(self):
        return len(self.samples)

    def _preprocessing(self, path):
        """Read the image at ``path`` as float32 and min-max scale it to [0, 1]."""
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise OSError("cv2.imread could not read image: {}".format(path))
        return self._min_max_scaling(img.astype(np.float32))

    def _min_max_scaling(self, img):
        """Rescale ``img`` to [0, 1]; a constant image becomes all zeros, not NaN."""
        lowest = np.min(img)
        span = np.max(img) - lowest
        denom = span if span != 0 else 1
        return (img - lowest) / denom


In [None]:
# Run configuration shared by the datasets and data loaders.
img_size = 256
aug = True
batch_size = 16
w = 6  # passed as num_workers to the DataLoaders

# Only the training split receives the augmentation flag; val/test are built in 'test' mode.
train_datasets = DiseaseDataset(
    phase_list=overall_train_list, mode='train', image_size=img_size, aug=aug)
val_datasets = DiseaseDataset(
    phase_list=overall_val_list, mode='test', image_size=img_size, aug=False)
test_datasets = DiseaseDataset(
    phase_list=overall_test_list, mode='test', image_size=img_size, aug=False)

In [None]:
# Training: shuffled, and the last partial batch is dropped so every step sees a full batch.
train_loader = torch.utils.data.DataLoader(train_datasets, batch_size=batch_size,
                                num_workers=w, pin_memory=True,
                                shuffle=True, drop_last=True)
# Evaluation loaders must see every sample: drop_last=True would silently discard the
# final partial batch and the reported metrics would cover only part of the split.
val_loader = torch.utils.data.DataLoader(val_datasets, batch_size=batch_size,
                                num_workers=w, pin_memory=True,
                                shuffle=False, drop_last=False)
test_loader = torch.utils.data.DataLoader(test_datasets, batch_size=1,
                                num_workers=w, pin_memory=True,
                                shuffle=False, drop_last=False)

# shuffle , drop_last , pin_memory

In [None]:
from torchvision import models

# ResNet-50 built without pretrained weights (pretrained=False) and with a 4-class output layer.
model = models.resnet50(pretrained=False , num_classes=4)

In [None]:
from torchsummary import summary # Unlike Keras, torch has no built-in way to visualise the model structure.

summary(model, input_size=(3, 256, 256), device='cpu') # (channels, height, width)

In [None]:
# Training hyper-parameters.
epochs = 3
print_freq = 30  # print progress every `print_freq` iterations
lr = 0.001

# Move the model to the GPU first and only then hand its parameters to the optimizer,
# as the torch.optim documentation recommends.
model = model.cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

In [None]:
# Helper classes for tracking and printing training statistics.
class AverageMeter(object):
    """Keeps the most recent value of a metric together with its running weighted mean.

    Args:
        name: label printed in front of the values.
        fmt: format spec (including the leading ':') applied to both the
            latest value and the average when the meter is converted to a string.
    """

    def __init__(self, name, fmt=':f'):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Clear the latest value, the average, the weighted sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        # Produces e.g. "Loss 0.1234 (0.2345)": latest value, then the average in parentheses.
        template = '{name} {val%s} ({avg%s})' % (self.fmt, self.fmt)
        return template.format(**vars(self))

class ProgressMeter(object):
    """Prints a tab-separated progress line: prefix, "[batch/total]", then each meter.

    Args:
        num_batches: total number of batches, shown after the slash.
        meters: objects whose ``str()`` is appended to the line.
        prefix: text printed before the batch counter.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.prefix = prefix
        self.meters = meters
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)

    def display(self, batch):
        """Print the progress line for batch index ``batch``."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([header, *map(str, self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        """Build "[{:Nd}/total]" where N is the number of digits in ``num_batches``."""
        width = len(str(num_batches // 1))
        field = '{:%dd}' % width
        return '[{}/{}]'.format(field, field.format(num_batches))

In [None]:
# Install livelossplot (provides PlotLosses, imported in the next cell).
# NOTE(review): unpinned install — consider fixing a version for reproducibility.
!pip install livelossplot --quiet

In [None]:
from livelossplot import PlotLosses # Library that draws the loss graph dynamically while training runs.

In [None]:
liveloss = PlotLosses()
# The loss module keeps no state between calls, so one instance serves every batch.
criterion = nn.CrossEntropyLoss()

for epoch in range(0, epochs):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, losses],
        prefix='Epoch: [{}]'.format(epoch))

    # ---------------- training pass ----------------
    model.train()
    correct = 0
    total = 0
    end = time.time()
    running_loss = 0
    logs = {}

    for iter_, (imgs, labels) in enumerate(train_loader):
        imgs = imgs.cuda()
        labels = labels.cuda()

        outputs = model(imgs)
        loss = criterion(outputs, labels)
        _, preds = outputs.max(1)
        total += labels.size(0)
        correct += preds.eq(labels).sum().item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight the running loss by the batch size. imgs[0].size(0) was the first
        # dimension of a single image, not the number of images in the batch.
        losses.update(loss.item(), imgs.size(0))
        batch_time.update(time.time() - end)
        end = time.time()
        running_loss += loss.item()

        if iter_ % print_freq == 0 and iter_ != 0:
            progress.display(iter_)

    logs['train' + ' loss'] = running_loss / len(train_loader)
    logs['train' + '  acc'] = (100.*correct/total)

    # ---------------- validation pass ----------------
    model.eval()

    val_batch_time = AverageMeter('Time', ':6.3f')
    val_losses = AverageMeter('Loss', ':.4f')
    progress = ProgressMeter(
        len(val_loader),
        [val_batch_time, val_losses],
        prefix='Epoch: [{}]'.format(epoch))

    val_correct = 0
    val_total = 0
    val_running_loss = 0

    # Restart the timer so the first validation batch is not charged with the time
    # elapsed since the last training batch.
    end = time.time()
    with torch.no_grad():
        for iter_, (imgs, labels) in enumerate(val_loader):
            imgs = imgs.cuda()
            labels = labels.cuda()

            outputs = model(imgs)
            loss = criterion(outputs, labels)

            _, preds = outputs.max(1)
            val_total += labels.size(0)
            val_correct += preds.eq(labels).sum().item()

            val_losses.update(loss.item(), imgs.size(0))
            val_batch_time.update(time.time() - end)
            val_running_loss += loss.item()
            end = time.time()

            if iter_ % print_freq == 0 and iter_ != 0:
                progress.display(iter_)

    logs['val' + ' loss'] = val_running_loss / len(val_loader)
    logs['val' + '  acc'] = (100.*val_correct/val_total)

    liveloss.update(logs)
    liveloss.draw()

    # ---------------- test pass (model is still in eval mode) ----------------
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for iter_, (imgs, labels) in enumerate(test_loader):
            imgs = imgs.cuda()
            labels = labels.cuda()

            outputs = model(imgs)
            outputs = F.softmax(outputs, dim=1)
            _, preds = outputs.max(1)

            test_total += labels.size(0)
            test_correct += preds.eq(labels).sum().item()

    test_acc = 100.*test_correct/test_total
    print('[*] Test Acc: {:5f}'.format(test_acc))

    # Back to training mode for the next epoch.
    model.train()
    

In [None]:
from sklearn.metrics import confusion_matrix
import itertools

def show_confusion_matrix(cm, target_names, title='CFMatrix', cmap=None, normalize=False):
    """Draw a confusion matrix as a coloured grid with one number per cell.

    Args:
        cm: square NumPy array of counts (rows: true labels, columns: predictions).
        target_names: tick labels for both axes, or None for no tick labels.
        title: figure title.
        cmap: matplotlib colormap; defaults to 'Blues'.
        normalize: when True, each row is divided by its sum and the cells show
            fractions instead of counts.
    """
    # Overall accuracy is always computed from the raw counts.
    acc = np.trace(cm) / float(np.sum(cm))

    if cmap is None:
        cmap = plt.get_cmap('Blues')

    # Normalise BEFORE drawing, so the colours and the numbers written in the cells
    # describe the same values (previously the colours showed raw counts).
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

    plt.figure(figsize=(12,10))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    if target_names is not None:
        tick_marks = np.arange(len(target_names))
        plt.xticks(tick_marks, target_names, rotation=45)
        plt.yticks(tick_marks, target_names)

    # Cells above the threshold get white text, the rest black text.
    thresh = cm.max() / 1.5 if normalize else cm.max() / 2
    cell_format = "{:0.4f}" if normalize else "{:,}"
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, cell_format.format(cm[i, j]),
                horizontalalignment="center",
                color="white" if cm[i,j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label\n accuracy={:0.4f}'.format(acc))
    # Called last so the axis labels are taken into account by the layout.
    plt.tight_layout()
    
def get_mertrix(gt, pred, class_list=['Normal', 'Abnormal']):
    """Print per-class metrics, plot the confusion matrix and return it.

    Args:
        gt: ground-truth labels.
        pred: predicted labels, same length as ``gt``.
        class_list: class names passed to ``show_confusion_matrix`` as tick labels.

    Returns:
        The confusion matrix produced by ``confusion_matrix(gt, pred)``.

    Every printed metric is an array with one entry per class (one-vs-rest).
    A metric whose denominator is zero is reported as ``nan``.
    """
    cnf_matrix = confusion_matrix(gt,pred)

    # One-vs-rest counts for each class.
    TP = np.diag(cnf_matrix).astype(float)
    FP = cnf_matrix.sum(axis=0).astype(float) - TP
    FN = cnf_matrix.sum(axis=1).astype(float) - TP
    TN = float(cnf_matrix.sum()) - (FP + FN + TP)

    # 0/0 (e.g. a class that was never predicted) yields nan; silence the
    # NumPy warnings because nan is the intended "undefined" marker here.
    with np.errstate(divide='ignore', invalid='ignore'):
        # Sensitivity, hit rate, recall, or true positive rate
        TPR = TP/(TP+FN)
        # Specificity or true negative rate
        TNR = TN/(TN+FP)
        # Precision or positive predictive value
        PPV = TP/(TP+FP)
        # Negative predictive value
        NPV = TN/(TN+FN)
        F1_Score = 2*(PPV*TPR) / (PPV+TPR)
        # Overall accuracy for each class
        ACC = (TP + TN)/ (TP+FP+FN+TN)

    print('specificity: ', TNR)
    print('sensitivity (recall): ', TPR) # true positive rate
    print('positive predictive value (precision): ', PPV)
    print('negative predictive value: ', NPV)
    print('acc: ', ACC)
    print('F1_score: ', F1_Score)
    show_confusion_matrix(cnf_matrix, class_list)

    return cnf_matrix

In [None]:
# Class names in label-index order (0..3), used as confusion-matrix tick labels.
class_list = ['Normal', 'COVID', 'Lung Opacity' ,'Pneumonia']

def evaluate(loader, model , class_list):
    """Run ``model`` over ``loader``, print accuracy and report per-class metrics.

    Args:
        loader: iterable of ``(imgs, labels)`` tensor batches.
        model: classifier; inputs are moved to CUDA, so the model must be on CUDA too.
        class_list: class names forwarded to ``get_mertrix``.

    Returns:
        The confusion matrix returned by ``get_mertrix``.
    """
    model.eval()

    correct = 0
    total = 0
    overall_preds = []
    overall_gts = []

    # Inference only: without no_grad every forward pass would keep its autograd graph.
    with torch.no_grad():
        # Wrapping the loader itself lets tqdm show the total number of batches.
        for imgs, labels in tqdm.tqdm(loader):
            imgs = imgs.cuda()
            labels = labels.cuda()

            outputs = model(imgs)
            outputs = F.softmax(outputs, dim=1)
            _, preds = outputs.max(1)

            total += labels.size(0)
            correct += preds.eq(labels).sum().item()

            ## For evaluation
            overall_preds += preds.cpu().numpy().tolist()
            overall_gts += labels.cpu().numpy().tolist()

    print('[*] Test Acc: {:5f}'.format(100.*correct/total))
    return get_mertrix(overall_gts, overall_preds, class_list)


In [None]:
# Evaluate the model on the test loader; the confusion matrix returned by evaluate()
# is the cell's last expression.
evaluate(test_loader, model.cuda() , class_list)

In [None]:
# Clone the MONAI source tree into ./MONAI (it is added to sys.path in the next cell).
# NOTE(review): clones the default branch at whatever commit is current — consider pinning a tag.
!git clone https://github.com/Project-MONAI/MONAI.git

In [None]:
import sys
# Make the MONAI checkout in ./MONAI importable without installing it.
sys.path.append('./MONAI')
from monai.visualize import GradCAM
from monai.visualize import CAM

In [None]:
def evaluate_cam(loader, model):
    """Show Grad-CAM overlays for the first 10 batches of ``loader``.

    For each batch the true label and the softmax probabilities are printed, and
    the first image of the batch is plotted next to the same image with the
    Grad-CAM heat-map of ``layer3`` blended on top.

    Args:
        loader: iterable of ``(imgs, labels)`` batches; the squeeze/overlay logic
            assumes batch size 1 — TODO confirm for other loaders.
        model: classifier exposing a ``layer3`` sub-module; must be on CUDA.
    """
    model.eval()

    # The wrapper depends only on the model, so build it once instead of once per batch.
    # NOTE(review): MONAI presumably registers hooks on the model at construction, which
    # would pile up when constructed inside the loop — verify.
    cam = GradCAM(nn_module = model, target_layers = 'layer3')

    for iter_, (imgs, labels) in tqdm.tqdm(enumerate(loader)):
        if iter_ == 10:
            break
        imgs = imgs.cuda()
        labels = labels.cuda()
        pred_labels = model(imgs)

        result = cam(x=imgs, layer_idx=-1)
        result = result.squeeze().cpu().detach().numpy()
        # Convert the CAM map to 8-bit and colour it with OpenCV's JET colormap.
        # NOTE(review): cv2.applyColorMap returns BGR and imshow ignores its cmap argument
        # for 3-channel input — confirm the overlay colours mean what you expect.
        heatmap = np.uint8(255 * result)
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        heatmap = heatmap/255

        # NCHW -> NHWC so a single image can be handed to matplotlib.
        gt_imgs = imgs.cpu().detach().numpy().transpose(0,2,3,1)

        pred_labels = F.softmax(pred_labels, dim=1)

        print("Class label 0 : Normal  , Class label 1 : COVID  , Class label 2 : Lung Opacity  , Class label 3 : Pneumonia ")
        print("Labels is {} , pred_labels is {} ".format(labels.cpu().detach().numpy(),
                                                         np.round(pred_labels.cpu().detach().numpy(), 4)))
        print('-'*30)
        plt.figure(figsize=(15,15))
        plt.subplot(121)
        plt.imshow(gt_imgs[0][:,:,0],'gray')
        plt.subplot(122)
        plt.imshow(gt_imgs[0][:,:,0],'gray')
        plt.imshow(heatmap , 'inferno', alpha=0.3)
        plt.show()
        

In [None]:
# Plot Grad-CAM overlays for the first test samples.
evaluate_cam(test_loader, model.cuda())

In [None]:
class DiseaseDataset(Dataset):
    """Chest X-ray dataset yielding normalised ``(image_tensor, label)`` pairs.

    This redefinition replaces the earlier ``DiseaseDataset`` in the notebook; the
    difference is the ``A.Normalize`` step applied in every transform pipeline.

    The label comes from the name of the folder that contains the image:
    ``Normal`` -> 0, ``COVID`` -> 1, ``Lung_Opacity`` -> 2, anything else -> 3.

    Args:
        phase_list: sequence of image file paths.
        mode: ``'train'`` allows augmentation; any other value uses resize + normalize only.
        image_size: side length (pixels) of the square the image is resized to.
        aug: augmentation switch. ``True`` (or the legacy string ``'True'``)
            enables augmentation in train mode.
        transform: accepted for backward compatibility; not used.
    """

    # Parent-folder name -> class index; folders not listed here map to _DEFAULT_LABEL.
    _LABELS = {'Normal': 0, 'COVID': 1, 'Lung_Opacity': 2}
    _DEFAULT_LABEL = 3

    def __init__(self ,phase_list, mode, image_size, aug, transform=None):
        self.mode = mode
        self.image_size = image_size
        self.samples = phase_list
        self.aug = aug

        # NOTE(review): the image is already min-max scaled to [0, 1] before it reaches
        # A.Normalize, whose max_pixel_value is left at the library default — confirm the
        # resulting scale matches how the fine-tuned checkpoint was trained before changing it.
        if mode == 'train' and self._aug_enabled(aug):
            self.transform = A.Compose([
                A.Resize(self.image_size, self.image_size),
                A.OneOf([
                    A.MedianBlur(blur_limit=3, p=0.1),
                    A.MotionBlur(p=0.2),
                    ], p=0.2),
                A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=10, p=0.2),
                A.OneOf([
                    A.OpticalDistortion(p=0.3),
                    ], p=0.2),
                A.OneOf([
                    A.GaussNoise(p=0.2),
                    A.MultiplicativeNoise(p=0.2),
                    ], p=0.2),
                A.HueSaturationValue(hue_shift_limit=0, sat_shift_limit=0, val_shift_limit=0.1, p=0.3),
                A.Normalize(mean=(0.485), std=(0.229)),
                ToTensorV2(),
                ])
        else:
            # Validation/test data and un-augmented training data: resize + normalize only.
            self.transform = A.Compose([
                A.Resize(self.image_size, self.image_size),
                A.Normalize(mean=(0.485), std=(0.229)),
                ToTensorV2(),
                ])

    @staticmethod
    def _aug_enabled(aug):
        """Return True when ``aug`` asks for augmentation.

        The flag used to be compared with the string ``'True'`` only, so the
        boolean ``True`` passed by the notebook never enabled augmentation.
        Strings keep their old meaning; everything else uses its truth value.
        """
        if isinstance(aug, str):
            return aug == 'True'
        return bool(aug)

    @classmethod
    def _label_from_path(cls, path):
        """Map an image path to its class index using the parent folder name."""
        # os.path handles the platform's separator; split('/') only worked with '/'.
        folder = os.path.basename(os.path.dirname(path))
        return cls._LABELS.get(folder, cls._DEFAULT_LABEL)

    def __getitem__(self, idx):
        sample_path = self.samples[idx]
        imgs = self.transform(image=self._preprocessing(sample_path))['image']
        labels = self._label_from_path(sample_path)
        return imgs, labels

    def __len__(self):
        return len(self.samples)

    def _preprocessing(self, path):
        """Read the image at ``path`` as float32 and min-max scale it to [0, 1]."""
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise OSError("cv2.imread could not read image: {}".format(path))
        return self._min_max_scaling(img.astype(np.float32))

    def _min_max_scaling(self, img):
        """Rescale ``img`` to [0, 1]; a constant image becomes all zeros, not NaN."""
        lowest = np.min(img)
        span = np.max(img) - lowest
        denom = span if span != 0 else 1
        return (img - lowest) / denom
    

In [None]:
# Rebuild datasets and loaders so they use the DiseaseDataset definition from the cell above.
img_size=  256
aug = True
batch_size = 16
w = 6  # passed as num_workers to the DataLoaders

train_datasets = DiseaseDataset(overall_train_list, mode='train',
                                image_size=img_size, aug=aug)
val_datasets = DiseaseDataset(overall_val_list, mode='test',
                              image_size=img_size, aug=False)
test_datasets = DiseaseDataset(overall_test_list, mode='test',
                               image_size=img_size, aug=False)

# Training: shuffled, and the last partial batch is dropped so every step sees a full batch.
train_loader = torch.utils.data.DataLoader(train_datasets, batch_size=batch_size,
                                num_workers=w, pin_memory=True,
                                shuffle=True, drop_last=True)
# Evaluation loaders must see every sample: drop_last=True would silently discard the
# final partial batch and the reported metrics would cover only part of the split.
val_loader = torch.utils.data.DataLoader(val_datasets, batch_size=batch_size,
                                num_workers=w, pin_memory=True,
                                shuffle=False, drop_last=False)
test_loader = torch.utils.data.DataLoader(test_datasets, batch_size=1,
                                num_workers=w, pin_memory=True,
                                shuffle=False, drop_last=False)

In [None]:

print('[*] build network...')
resnet_model = models.resnet50(pretrained=False)
resnet_model_pretrained = '/gdrive/MyDrive/moco_resnet50.pth.tar'

if resnet_model_pretrained is not None:
    if os.path.isfile(resnet_model_pretrained):
        print("[*] loading checkpoint '{}'".format(resnet_model_pretrained))
        checkpoint = torch.load(resnet_model_pretrained, map_location="cpu")

        # Keep only the query-encoder weights, minus its fc (embedding) layer, and strip
        # the "module.encoder_q." prefix so the keys match a plain torchvision ResNet.
        prefix = "module.encoder_q."
        state_dict = {
            k[len(prefix):]: v
            for k, v in checkpoint['state_dict'].items()
            if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc')
        }

        # strict=False because the fc layer was deliberately left out; the assert
        # checks that fc is the ONLY thing missing.
        msg = resnet_model.load_state_dict(state_dict, strict=False)
        assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}

        print("=> loaded pre-trained model '{}'".format(resnet_model_pretrained))
        # Reported only on success; it used to print even when no checkpoint was found.
        print("[*] moco weight load completed")
    else:
        print("=> no checkpoint found at '{}'".format(resnet_model_pretrained))


# Replace the classification head with a 4-class layer before loading the fine-tuned weights.
model = resnet_model
num_classes = 4
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)

# map_location="cpu" (as for the MoCo checkpoint above) lets the file load regardless of
# the device it was saved from; the model is moved to the GPU by the caller afterwards.
checkpoint = torch.load('/gdrive/MyDrive/fine_tuning_moco.pth.tar', map_location="cpu")
model.load_state_dict(checkpoint['state_dict'])


In [None]:
# Evaluate the model loaded from the fine-tuned checkpoint on the rebuilt test loader.
evaluate(test_loader, model.cuda() , class_list)

In [None]:
def evaluate_cam(loader, model):
    """Show Grad-CAM overlays for the first 100 batches of ``loader``.

    This redefinition replaces the earlier ``evaluate_cam``; it additionally
    rescales the images with ``img * 0.229 + 0.485`` before plotting.

    For each batch the true label and the softmax probabilities are printed, and
    the first image of the batch is plotted next to the same image with the
    Grad-CAM heat-map of ``layer3`` blended on top.

    Args:
        loader: iterable of ``(imgs, labels)`` batches; the squeeze/overlay logic
            assumes batch size 1 — TODO confirm for other loaders.
        model: classifier exposing a ``layer3`` sub-module; must be on CUDA.
    """
    model.eval()
    img_mean = 0.485
    img_std = 0.229

    # The wrapper depends only on the model, so build it once instead of once per batch.
    # NOTE(review): MONAI presumably registers hooks on the model at construction, which
    # would pile up when constructed inside the loop — verify.
    cam = GradCAM(nn_module = model, target_layers = 'layer3' )

    for iter_, (imgs, labels) in tqdm.tqdm(enumerate(loader)):
        if iter_ == 100:
            break
        imgs = imgs.cuda()
        labels = labels.cuda()
        pred_labels = model(imgs)

        result = cam(x=imgs, layer_idx=-1)
        result = result.squeeze().cpu().detach().numpy()

        # Convert the CAM map to 8-bit and colour it with OpenCV's JET colormap.
        # NOTE(review): cv2.applyColorMap returns BGR and imshow ignores its cmap argument
        # for 3-channel input — confirm the overlay colours mean what you expect.
        heatmap = np.uint8(255 * result)
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        heatmap = heatmap/255

        # NCHW -> NHWC so a single image can be handed to matplotlib.
        gt_imgs = imgs.cpu().detach().numpy().transpose(0,2,3,1)
        # Presumably undoes the dataset's Normalize(mean=0.485, std=0.229) for display —
        # TODO confirm it matches the scale Normalize actually applied.
        gt_imgs = (gt_imgs*img_std)+img_mean

        pred_labels = F.softmax(pred_labels, dim=1)

        print("Class label 0 : Normal  , Class label 1 : COVID  , Class label 2 : Lung Opacity  , Class label 3 : Pneumonia ")
        print("Labels is {} , pred_labels is {} ".format(labels.cpu().detach().numpy(), np.round(pred_labels.cpu().detach().numpy(), 4)))
        print('-'*30)
        plt.figure(figsize=(15,15))
        plt.subplot(121)
        plt.imshow(gt_imgs[0][:,:,0],'gray')
        plt.subplot(122)
        plt.imshow(gt_imgs[0][:,:,0],'gray')
        plt.imshow(heatmap , 'inferno', alpha=0.3)
        plt.show()

        

In [None]:
# Plot Grad-CAM overlays for the model loaded from the fine-tuned checkpoint.
evaluate_cam(test_loader,model.cuda())