In [2]:
import cv2
from sklearn.metrics import f1_score
from sklearn.metrics import classification_report
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import os
from glob import glob

from tqdm.notebook import tqdm_notebook as tqdm
import numpy as np
import pandas as pd

from torch.utils.data import DataLoader, sampler, random_split, Dataset
import torch
import timm
import random

In [16]:
DEVICE = torch.device("cpu")
SEED = 276
# Seed every RNG the notebook touches: torch (model / DataLoader), Python's
# `random` (used by showRandomSamples via random.sample) and numpy.
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)
SHAPE = (256,256,1)  # model input as (height, width, channels)
IN_CHANNELS = 1
NUM_CLASSES = 4
BATCH_SIZE = 16
# Override with the BRAIN_MRI_WEIGHT_PATH environment variable on other machines.
WEIGHT_PATH = os.environ.get('BRAIN_MRI_WEIGHT_PATH', '/data/users/6370327221/project/brain-mri-cls/weight/')
# Integer class id (as a string) -> human-readable class name.
CLASS_MAP = {'0':'no_tumor', '1':'glioma', '2':'meningioma', '3':'pituitary'}

In [5]:
def getImageFromDataset(testDataset, idx):
    """Fetch one sample from a dataset in a form ready for matplotlib.

    Args:
        testDataset: dataset whose items are ``(image, label, path)`` with
            ``image`` a CHW torch tensor and ``label`` a scalar tensor, as
            produced by ``BrainDataset`` with the ToTensorV2 transform.
        idx: index of the sample.

    Returns:
        ``(image, class_id, path)``: ``image`` as an HWC numpy array,
        ``class_id`` as a Python int and ``path`` as a string.
    """
    # Index exactly once: every lookup re-reads and re-transforms the image.
    image, label, path = testDataset[idx]
    image_hwc = image.permute(1, 2, 0).numpy()
    return image_hwc, int(label), str(path)

In [6]:
def test_model(dataloader, model):
    """Run ``model`` over every batch of ``dataloader`` and collect predictions.

    Args:
        dataloader: iterable of ``(images, labels, paths)`` batches (e.g. a
            ``DataLoader`` over ``BrainDataset``); must support ``len()``.
        model: classifier returning raw logits; assumed shape
            ``(batch, n_classes)`` given the softmax over ``dim=1``.

    Returns:
        ``(correct_images, total_images, all_labels, all_predicted, all_conf)``:
        number of correct predictions, number of samples, and numpy arrays of
        true labels, predicted labels and the highest softmax probability per
        sample. For an empty dataloader the counts are 0 and the arrays are
        empty rather than ``torch.cat`` raising on an empty list.
    """
    model.eval()
    correct_images = 0
    total_images = 0
    all_labels, all_predicted, all_conf = [], [], []
    test_bar = tqdm(dataloader, total=len(dataloader))
    with torch.no_grad():
        for images, labels, _im_path in test_bar:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            outputs = model(images)
            scores = torch.nn.functional.softmax(outputs, dim=1)
            # Confidence = largest class probability for each sample.
            conf = torch.max(scores, 1).values
            _, predicted = torch.max(outputs, 1)
            correct_images += (predicted == labels).sum().item()
            total_images += labels.size(0)
            accum_acc = round((correct_images / total_images) * 100, 4)
            test_bar.set_description("Testing accuracy: {}".format(accum_acc))
            # Move to CPU per batch so results don't pile up on an accelerator.
            all_labels.append(labels.cpu())
            all_predicted.append(predicted.cpu())
            all_conf.append(conf.cpu())

    if not all_labels:
        # torch.cat([]) raises; return empty arrays for an empty dataloader.
        return (correct_images, total_images,
                np.empty(0, dtype=np.int64), np.empty(0, dtype=np.int64),
                np.empty(0, dtype=np.float32))

    all_conf = torch.cat(all_conf).numpy()
    all_labels = torch.cat(all_labels).numpy()
    all_predicted = torch.cat(all_predicted).numpy()
    return correct_images, total_images, all_labels, all_predicted, all_conf

In [7]:
def plot_confusion_matrix(cm,target_names,title='Confusion matrix',cmap=None,normalize=True):
    """Render a confusion matrix as an annotated heatmap and show it.

    Args:
        cm: square confusion matrix, rows = true labels and columns =
            predicted labels (as returned by ``sklearn.metrics.confusion_matrix``).
        target_names: one tick label per class in label-index order, or
            ``None`` to omit tick labels.
        title: figure title.
        cmap: matplotlib colormap; defaults to ``'Blues'``.
        normalize: if True, each row is divided by its sum so cells (colours
            and text) show the fraction of each true class; otherwise raw
            counts are shown.

    Raises:
        ValueError: if ``target_names`` does not have one entry per class.
    """
    cm = np.asarray(cm)
    n_classes = cm.shape[0]
    if target_names is not None and len(target_names) != n_classes:
        raise ValueError(
            f"target_names has {len(target_names)} entries but the confusion "
            f"matrix has {n_classes} classes"
        )

    # Accuracy is always computed on raw counts.
    accuracy = np.trace(cm) / float(np.sum(cm))
    misclass = 1 - accuracy

    if cmap is None:
        cmap = plt.get_cmap('Blues')

    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True).astype('float')
        # A class with no true samples would divide by zero; keep its row at 0.
        display_cm = np.divide(cm.astype('float'), row_sums,
                               out=np.zeros(cm.shape, dtype='float'),
                               where=row_sums != 0)
        cell_fmt = "{:0.4f}"
        thresh = display_cm.max() / 1.5
    else:
        display_cm = cm
        cell_fmt = "{:,}"
        thresh = display_cm.max() / 2

    fig, ax = plt.subplots(figsize=(8, 6))
    # Draw the same matrix that is annotated so colour and text agree.
    image = ax.imshow(display_cm, interpolation='nearest', cmap=cmap)
    ax.set_title(title)
    fig.colorbar(image, ax=ax)

    if target_names is not None:
        tick_marks = np.arange(n_classes)
        ax.set_xticks(tick_marks)
        ax.set_xticklabels(target_names, rotation=45)
        ax.set_yticks(tick_marks)
        ax.set_yticklabels(target_names)

    for i, j in np.ndindex(display_cm.shape):
        ax.text(j, i, cell_fmt.format(display_cm[i, j]),
                horizontalalignment="center",
                color="white" if display_cm[i, j] > thresh else "black")

    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass))
    # tight_layout must run after the labels exist, otherwise they can be clipped.
    fig.tight_layout()
    plt.show()

In [8]:
def showRandomSamples(labels, predicted, conf,dataset, sampleNum=5):
    """Plot random correctly and incorrectly classified samples.

    Top row: up to ``sampleNum`` correct predictions with their confidence.
    Bottom row: up to ``sampleNum`` mistakes with actual vs predicted class.

    Args:
        labels: 1-D array of true class ids, aligned with ``dataset`` indices
            (as returned by ``test_model`` on an unshuffled loader).
        predicted: 1-D array of predicted class ids, aligned with ``labels``.
        conf: 1-D array of confidence scores, aligned with ``labels``.
        dataset: dataset accepted by ``getImageFromDataset``.
        sampleNum: number of columns, i.e. samples per row.

    Returns:
        List of image paths of the misclassified samples that were plotted.
    """
    incorrect_im_path = []

    is_correct = np.asarray(labels) == np.asarray(predicted)
    correct_idx = np.flatnonzero(is_correct).tolist()
    incorrect_idx = np.flatnonzero(~is_correct).tolist()

    # random.sample raises ValueError when asked for more items than exist
    # (e.g. fewer than sampleNum mistakes); show as many as are available.
    random_correct_idx = random.sample(correct_idx, min(sampleNum, len(correct_idx)))
    random_incorrect_idx = random.sample(incorrect_idx, min(sampleNum, len(incorrect_idx)))

    # squeeze=False keeps `axes` 2-D even when sampleNum == 1.
    _, axes = plt.subplots(2, sampleNum, figsize=(30, 10), squeeze=False)

    for col, sample_idx in enumerate(random_correct_idx):
        image, className, path = getImageFromDataset(dataset, sample_idx)
        sample_conf = conf[sample_idx]
        # Images are HWC with a single grayscale channel; channel index 1
        # would be out of bounds, so display channel 0.
        axes[0, col].imshow(image[:, :, 0], cmap='gray')
        className = map_label_name(className)
        axes[0, col].title.set_text(f'[CORRECT] {str(className)} confidence : {str(sample_conf)}')

    for col, sample_idx in enumerate(random_incorrect_idx):
        image, className, path = getImageFromDataset(dataset, sample_idx)
        incorrect_im_path.append(path)
        axes[1, col].imshow(image[:, :, 0], cmap='gray')
        predictedClassName = map_label_name(int(predicted[sample_idx]))
        sample_conf = conf[sample_idx]
        className = map_label_name(className)
        axes[1, col].title.set_text(f'Actual: {str(className)} Predicted: {str(predictedClassName)} Conf : {str(sample_conf)}')
    plt.show()
    return incorrect_im_path

In [9]:
def map_label_name(cat, class_map=None):
    """Translate a class id into its human-readable class name.

    Args:
        cat: class id (int, numpy integer or string such as ``'2'``); it is
            looked up as ``str(cat)``.
        class_map: optional ``{str(id): name}`` mapping; defaults to the
            notebook-level ``CLASS_MAP``.

    Returns:
        The class name mapped to ``cat``.

    Raises:
        KeyError: if ``cat`` is not a key of the mapping.
    """
    mapping = CLASS_MAP if class_map is None else class_map
    try:
        return mapping[str(cat)]
    except KeyError:
        raise KeyError(f"Unknown class id {cat!r}; expected one of {sorted(mapping)}") from None

In [39]:
def create_model(backbone, is_feat_extract=False, pretrained=True):
    """Build a timm classifier sized for this task.

    Args:
        backbone: timm model name, e.g. ``'resnet18'``.
        is_feat_extract: if True, pass the model through ``unfreeze_only_fc``.
        pretrained: load timm's pretrained weights. Set to False when the
            weights are about to be replaced (e.g. by ``load_weight``) to skip
            an unnecessary download.

    Returns:
        The timm model with ``IN_CHANNELS`` input channels and
        ``NUM_CLASSES`` outputs.
    """
    model = timm.create_model(backbone, pretrained=pretrained,
                              in_chans=IN_CHANNELS, num_classes=NUM_CLASSES)
    if is_feat_extract:
        # NOTE(review): `unfreeze_only_fc` is not defined in any visible cell;
        # this branch raises NameError unless it is defined elsewhere.
        model = unfreeze_only_fc(model)

    return model

In [43]:
def load_weight(model, backbone, fold):
    """Restore the saved state dict for ``backbone`` / ``fold`` into ``model``.

    Reads ``<WEIGHT_PATH>/<backbone>_fold_<fold>.pth`` onto ``DEVICE``, loads
    it into ``model``, switches the model to eval mode and returns it.
    """
    checkpoint_name = f"{backbone}_fold_{fold}.pth"
    checkpoint_file = os.path.join(WEIGHT_PATH, checkpoint_name)
    # NOTE(review): torch.load unpickles arbitrary objects - only point this
    # at trusted checkpoint files.
    model.load_state_dict(torch.load(checkpoint_file, map_location=DEVICE))
    model.eval()
    return model

In [46]:
class BrainDataset(Dataset):
    """Grayscale brain-MRI classification dataset read from a folder tree.

    Files are discovered as ``<root_dir>/Training/*/*.jpg`` or
    ``<root_dir>/Testing/*/*.jpg``. The class id is the integer prefix of
    each image's parent directory name, up to the first ``_``
    (e.g. ``1_glioma`` -> 1).

    Each item is ``(image, label, im_path)``. ``image`` is the grayscale image
    scaled to [0, 1] (passed through ``transform`` and cast to float when a
    transform is given). ``label`` is a long tensor and ``im_path`` is the
    file path.
    """

    def __init__(
            self,
            root_dir = '/data/users/6370327221/dataset/MRI-Brain-tumor-cls/',
            is_train = True,
            transform = None
            ):
        split_dir = 'Training' if is_train else 'Testing'
        # os.path.join works whether or not root_dir ends with a separator.
        # The single '*' matches one class sub-directory level, exactly like
        # the non-recursive '**' used previously.
        pattern = os.path.join(root_dir, split_dir, '*', '*.jpg')
        self.im_paths = sorted(glob(pattern))
        self.transform = transform

    def __len__(self):
        return len(self.im_paths)

    def __getitem__(self, index):
        im_path = self.im_paths[index]
        image = cv2.imread(im_path)
        if image is None:
            # cv2.imread returns None (no exception) for missing/unreadable files.
            raise FileNotFoundError(f"Could not read image: {im_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        image = image / 255.0

        class_dir = os.path.basename(os.path.dirname(im_path))
        label = torch.tensor(int(class_dir.split('_')[0])).long()

        if self.transform is not None:
            augmentations = self.transform(image=image)
            image = augmentations["image"]
            image = image.float()

        return image, label, im_path

In [47]:
# Resize to the model's input size taken from SHAPE (height, width, channels),
# so the image size is configured in one place, then convert the numpy image
# to a torch tensor.
transform = A.Compose(
    [
        A.Resize(height=SHAPE[0], width=SHAPE[1]),
        ToTensorV2(),
    ]
)

In [48]:
# Held-out split (images under <root_dir>/Testing/), resized and tensorised by `transform`.
test_dataset = BrainDataset(transform=transform, is_train=False)

In [51]:
# Single source of truth for the evaluated checkpoint, so the architecture and
# the weight file name cannot drift apart.
BACKBONE = 'resnet18'
FOLD = 4
model = create_model(BACKBONE)
model = load_weight(model, BACKBONE, FOLD)

In [None]:
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
correctImages, totalImages, testLabels, testPredicted, testConf = test_model(test_loader, model)

# Class ids 0..NUM_CLASSES-1 and their names, in label-index order.
class_ids = list(range(NUM_CLASSES))
class_names = [CLASS_MAP[str(i)] for i in class_ids]

print(classification_report(testLabels, testPredicted, labels=class_ids, target_names=class_names))
# labels= keeps the matrix NUM_CLASSES x NUM_CLASSES even if a class never occurs.
confusionMatrix = confusion_matrix(testLabels, testPredicted, labels=class_ids)
plot_confusion_matrix(cm           = confusionMatrix,
                      normalize    = False,
                      target_names = class_names,
                      title        = "Classification Confusion Matrix")

# Per-class precision / recall from the multi-class confusion matrix
# (rows = true class, columns = predicted class); 0 where undefined.
true_pos = np.diag(confusionMatrix).astype(float)
predicted_totals = confusionMatrix.sum(axis=0)
actual_totals = confusionMatrix.sum(axis=1)
precision = np.divide(true_pos, predicted_totals, out=np.zeros_like(true_pos), where=predicted_totals != 0)
recall = np.divide(true_pos, actual_totals, out=np.zeros_like(true_pos), where=actual_totals != 0)
for name, class_precision, class_recall in zip(class_names, precision, recall):
    print(f"{name}: precision={class_precision:.4f} recall={class_recall:.4f}")
print(f"Macro recall : {recall.mean()}")
print(f"Macro precision : {precision.mean()}")

  0%|          | 0/25 [00:00<?, ?it/s]