In [None]:
from PIL import Image
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
from torchvision import transforms, models, datasets
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader, Dataset
import numpy as np
#import cv2
from tqdm import tqdm
import matplotlib.pyplot as plt
import random
from torchvision.models import resnet50, efficientnet_b0, inception_v3, efficientnet_b4
from torchvision.models import ResNet50_Weights
from sklearn.metrics import f1_score, recall_score, precision_score, confusion_matrix
import seaborn as sns
import csv
import pandas as pd

from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

from efficientnet_pytorch import EfficientNet

In [None]:
# Pick the compute device: the first CUDA GPU when one is visible, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

## Data Preparation

In [None]:
# Folders whose files make up the image pool.
dir_paths = [r'SNACGH_Grade1_Disc', r'test-images', r'train-images', r'validation-images']

# Module-level containers (not filled in this cell).
res = []      # list to store files
images = []   # image data array
labels = []   # image labels

# Training hyper-parameters: batch size and number of epochs.
batch = 16
epochs = 800

# Every regular, non-hidden file that sits directly inside one of the folders above,
# folder by folder, in the order os.listdir returns the entries.
whole_paths = [
    os.path.join(folder, entry)
    for folder in dir_paths
    for entry in os.listdir(folder)
    if entry[0] != '.' and os.path.isfile(os.path.join(folder, entry))
]


In [None]:
# =============================================================================
# Resize and Crop (use early)
# =============================================================================
class ResizeCrop(object):
    ''' Center-crop the sample image, then resize the crop.

    Both steps are built from the same `size`: `transforms.CenterCrop(size)`
    runs first and `transforms.Resize(size)` (nearest-exact interpolation)
    runs on its output.  Only the 'image' entry of the sample dict is
    replaced; the same dict object is returned.
    '''

    def __init__(self, size):
        nearest = transforms.InterpolationMode.NEAREST_EXACT
        self.centercrop = transforms.CenterCrop(size=size)
        self.resize = transforms.Resize(size=size, interpolation=nearest)

    def __call__(self, sample):
        # crop first, resize second
        cropped = self.centercrop(sample['image'])
        sample['image'] = self.resize(cropped)
        return sample

# =============================================================================
# Horizontal Flip (all to one side)
# =============================================================================
class RandomHorizontalFlip(object):
    ''' Mirror right-eye images so every image has the same orientation.

    Despite the name nothing is random here: the text before the first '.'
    of sample['filename'] is split on '-', and the image is flipped
    horizontally exactly when the last piece equals 'RE'.
    '''

    def __call__(self, sample):
        image, filename = sample['image'], sample['filename']

        stem = filename.split('.')[0]
        if stem.split('-')[-1] == 'RE':
            image = transforms.functional.hflip(image)

        sample['image'] = image
        return sample
    

# =============================================================================
# Random Vertical Flip (use early)
# =============================================================================
class RandomVerticalFlip(object):
    ''' Flip the sample image vertically with probability `p`.

    One `random.random()` draw is made per call; the flip happens when the
    draw is below `p`.
    '''

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, sample):
        image = sample['image']
        flip_now = random.random() < self.p
        sample['image'] = torchvision.transforms.functional.vflip(image) if flip_now else image
        return sample
    

#
# =============================================================================
#  Normalise (use middle)
# =============================================================================
class Normalize(object):
    ''' Apply `transforms.Normalize(mean, std)` to the 'image' entry of a sample.

    Used by both the training and the evaluation pipelines.
    '''

    def __init__(self, mean, std):
        self.normalize = transforms.Normalize(mean, std)

    def __call__(self, sample):
        normalized = self.normalize(sample['image'])
        sample['image'] = normalized
        return sample


# =============================================================================
# ToTensor class (use last)
# =============================================================================
class ToTensor(object):
    ''' Replace the 'image' entry of a sample with its tensor form
    (`transforms.functional.to_tensor`). '''

    def __call__(self, sample):
        as_tensor = transforms.functional.to_tensor(sample['image'])
        sample['image'] = as_tensor
        return sample


In [None]:
# define the customized dataset
class MyDataset(Dataset):
    """Image dataset whose binary label is parsed from each file name.

    Every item is a dict with keys 'image' (transformed tensor), 'label'
    (0 or 1) and 'filename' (base name of the file, extension included).

    Parameters
    ----------
    paths : list of str
        Image files to serve.  The list is used exactly as given -- no
        shuffling or splitting happens here, so callers that need disjoint
        train/validation/test sets must pass disjoint lists.
    mode : str
        'train' adds the eye-side flip and the random vertical flip to the
        pipeline; any other value builds the plain evaluation pipeline.
    size : int, optional
        Value handed to `ResizeCrop`.  The default is 380.
    """

    def __init__(self, paths, mode, size=380):
        self.mode = mode
        self.size = size

        # Both pipelines share the first two steps and the final normalisation;
        # only training inserts the flips in between.
        steps = [ToTensor(), ResizeCrop(self.size)]
        if mode == 'train':
            steps += [RandomHorizontalFlip(), RandomVerticalFlip(p=0.5)]
        # Alternative kept for reference (ImageNet statistics):
        #   Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        steps.append(Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(steps)

        self.paths = paths

    def __len__(self):
        return len(self.paths)

    @staticmethod
    def _label_from_filename(filename):
        """Return 1 when the integer before the first '-' of the name is > 2, else 0."""
        ori_label = int(filename.split('.')[0].split('-')[0])
        # An alternative mapping (label 1 only when ori_label == 1) was drafted
        # for a separate testing dataset but is not active.
        return 1 if ori_label > 2 else 0

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        im = Image.open(self.paths[idx]).convert("RGB")
        path = os.path.basename(self.paths[idx])

        item = {
            'image': im,
            'label': self._label_from_filename(path),
            'filename': path
        }

        if self.transform:
            item = self.transform(item)

        return item

    def load_gradcam_item(self, idx):
        """Return the untransformed RGB image together with the transformed item.

        The returned dict has keys 'image' (PIL RGB image), 'tensor'
        (transformed image), 'label' and 'filename' (including the extension).
        """
        im = Image.open(self.paths[idx]).convert("RGB")

        # One lookup only: the image is decoded and transformed a single time,
        # and tensor and label are guaranteed to come from the same pass.
        item = self[idx]

        return {
            'image': im,
            'tensor': item['image'],
            'label': item['label'],
            'filename': os.path.basename(self.paths[idx])
        }



In [None]:
    
# Split the pooled paths ONCE into three disjoint subsets.  Handing the full list
# to all three datasets would make validation and test metrics a re-measurement of
# the training data.  Ratios (70 / 15 / 15) and seed follow the split drafted in
# MyDataset's commented-out code.
# NOTE(review): if the folder names in dir_paths already encode the intended split,
# build the three lists from those folders instead -- confirm with the data owner.
SPLIT_SEED = 62
split_paths = sorted(whole_paths)                 # os.listdir order is arbitrary; sort for reproducibility
random.Random(SPLIT_SEED).shuffle(split_paths)    # private RNG: the global random state stays untouched

train_end = int(0.7 * len(split_paths))
test_end = int(0.85 * len(split_paths))

train_ds = MyDataset(split_paths[:train_end], 'train')
test_ds = MyDataset(split_paths[train_end:test_end], 'test')
val_ds = MyDataset(split_paths[test_end:], 'val')

train_loader = DataLoader(train_ds, batch_size=batch, drop_last=True, shuffle=True) # train
test_loader = DataLoader(test_ds, batch_size=2, shuffle=False) # test
val_loader = DataLoader(val_ds, batch_size=2, shuffle=False) # val


print("LOADERS: train, test, validation:", train_loader, test_loader, val_loader)
print("DATASET SIZES: train, test, validation:", len(train_ds), len(test_ds), len(val_ds))


In [None]:
# test images

test_img_dir = r'E:\E-GLAU ONH Cirrus En Face Images'

# Every regular, non-hidden file directly inside test_img_dir, in os.listdir order.
test_img_whole_paths = [
    os.path.join(test_img_dir, entry)
    for entry in os.listdir(test_img_dir)
    if entry[0] != '.' and os.path.isfile(os.path.join(test_img_dir, entry))
]

In [None]:
# Dataset and loader over the external test images, served in list order with the evaluation pipeline.
new_test_ds = MyDataset(paths=test_img_whole_paths, mode='test')
new_test_loader = DataLoader(dataset=new_test_ds, batch_size=2, shuffle=False)

In [None]:

# Optional sanity check: show the first image of one training batch with its label.
# Change the guard to True to run it.
if False:
    # The loader yields dicts, not (image, label) tuples, so index by key.
    sample_batch = next(iter(train_loader))
    img, label = sample_batch['image'], sample_batch['label']

    print(img)
    print(label)
    # Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) maps pixels to [-1, 1];
    # undo it so imshow receives values in [0, 1].
    preview = img[0].cpu().detach().numpy() * 0.5 + 0.5
    plt.imshow(np.transpose(preview, (1, 2, 0)))
    plt.title(f"Label: {label[0]}")

## Model Definition and Training Process

In [None]:
def save_checkpoint(dir_name, epoch, f1, model):
    """Save `model.state_dict()` as ``checkpoint-<f1>-<epoch>.pth`` inside `dir_name`.

    Parameters
    ----------
    dir_name : str
        Target folder; created (including missing parents) when absent.
    epoch : int
        Epoch number embedded in the file name.
    f1 : float
        Score embedded in the file name, formatted with three decimals.
    model : torch.nn.Module
        Model whose weights are written.

    Returns
    -------
    str
        Path of the file that was written.
    """
    # makedirs(exist_ok=True) handles nested paths and avoids the
    # check-then-create race of isdir() followed by mkdir().
    os.makedirs(dir_name, exist_ok=True)

    filename = "checkpoint-{:.3f}-{}.pth".format(f1, epoch)
    whole_name = os.path.join(dir_name, filename)
    torch.save(model.state_dict(), whole_name)
    return whole_name


In [None]:
# models
#
# Alternative backbones tried earlier, kept for reference (note that they use a
# 4-class head, whereas the active model below uses 2 classes):
#
# ResNet50 -- predefined model, only the last fc layer is replaced
#model_ft = resnet50(pretrained=True) # for python 3.6
#model_ft = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)  #for python 3.9  #Alternatively, we can use weights='DEFAULT'
#num_ftrs = model_ft.fc.in_features
# Alternatively, it can be generalized to ``nn.Linear(num_ftrs, len(class_names))``.
#model_ft.fc = nn.Linear(num_ftrs, 4)
#model_ft = EfficientNet.from_pretrained('efficientnet-b0', num_classes=4)
#
# Inception v3
#model_ft = inception_v3(weights='DEFAULT')
#in_ftrs = model_ft.fc.in_features
#model_ft.fc = nn.Linear(in_ftrs, 4)

# Active model: torchvision EfficientNet-B0 with its default weights; the linear
# layer at classifier[1] is replaced by a 2-output layer of the same input width.
model_ft = efficientnet_b0(weights='DEFAULT')
in_ftrs = model_ft.classifier[1].in_features
model_ft.classifier[1] = nn.Linear(in_features=in_ftrs, out_features=2)

print(model_ft)

In [None]:
# Move the model onto the device selected at the top of the notebook.
model_ft = model_ft.to(device)

In [None]:
# loss functions
# Per-class weights for the two output classes (index 0 -> 0.4, index 1 -> 0.6),
# created directly on the training device to balance the classes.
class_weight = torch.tensor([0.4, 0.6], device=device)
#criterion = nn.CrossEntropyLoss()
criterions = dict(ce=nn.CrossEntropyLoss(weight=class_weight))

In [None]:

def _run_epoch(model, loader, criterion, optimizer=None, scheduler=None):
    """Run one pass over `loader`: a training pass when `optimizer` is given, else evaluation.

    Returns a dict with the epoch's 'loss' (mean of the per-batch losses), 'acc',
    macro 'f1' / 'recall' / 'precision', and the raw 'labels' and 'preds' lists.
    """
    training = optimizer is not None
    model.train(training)

    labels, preds, losses = [], [], []
    with torch.set_grad_enabled(training):
        for item in tqdm(loader):
            labels.extend(item['label'].numpy())

            # pass the data to the device
            inputs = item['image'].to(device)
            targets = item['label'].to(device)

            if training:
                optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            if training:
                loss.backward()
                optimizer.step()
                if scheduler is not None:
                    # Stepped once per mini-batch, so the scheduler's period is
                    # counted in optimizer steps rather than in epochs.
                    scheduler.step()

            _, batch_preds = torch.max(outputs, 1)
            preds.extend(batch_preds.detach().cpu().numpy())
            losses.append(loss.item())

    if not labels:
        raise ValueError("the data loader produced no batches")

    return {
        'loss': float(np.mean(losses)),
        # Fraction of the samples actually seen: with drop_last=True the loader
        # skips the incomplete last batch, so dividing by the dataset size
        # would under-report accuracy.
        'acc': float(np.mean(np.asarray(labels) == np.asarray(preds))),
        'f1': f1_score(labels, preds, average='macro'),
        'recall': recall_score(labels, preds, average='macro'),
        'precision': precision_score(labels, preds, average='macro'),
        'labels': labels,
        'preds': preds,
    }


def _plot_history(history):
    """Draw one train-vs-validation curve per tracked metric."""
    panels = [
        ('loss', 'Loss', 'Loss Graph'),
        ('acc', 'Accuracy', 'Accuracy Graph'),
        ('f1', 'Macro F1', 'Macro-F1 Graph'),
        ('recall', 'Macro Recall', 'Macro-Recall Graph'),
        ('precision', 'Macro Precision', 'Macro-Precision Graph'),
    ]
    for key, ylabel, title in panels:
        plt.plot(history['train'][key], label='Train')
        plt.plot(history['val'][key], label='Validation')
        plt.xlabel('Epochs')
        plt.ylabel(ylabel)
        plt.title(title)
        plt.legend()  # without this the label= arguments are never displayed
        plt.show()


def train_model(train_data, val_data, model, epochs, dir_name, min_f1=0.85, patience=30):
    """Train `model`, checkpoint on new best validation macro-F1, plot the history.

    Parameters
    ----------
    train_data, val_data : DataLoader
        Loaders yielding dicts with 'image' and 'label' entries.
    model : torch.nn.Module
        Model to optimise; moved to the global `device`.
    epochs : int
        Maximum number of epochs.
    dir_name : str
        Folder handed to `save_checkpoint`.
    min_f1 : float, optional
        Checkpointing and early stopping are only active in epochs whose
        validation macro-F1 exceeds this value.  The default is 0.85.
    patience : int, optional
        Training stops once more than this many such epochs have passed since
        the last new best.  The default is 30.

    Returns
    -------
    dict
        ``history[split][metric]`` lists, split in {'train', 'val'} and metric
        in {'loss', 'acc', 'f1', 'recall', 'precision'}.
    """
    model = model.to(device)

    # Weighted cross-entropy, plain SGD, cosine annealing with warm restarts.
    criterion = criterions['ce']
    optimizer_ft = optim.SGD(model.parameters(), lr=0.001)
    lrsdl = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ft, T_0=200, eta_min=0.00001)

    metric_names = ('loss', 'acc', 'f1', 'recall', 'precision')
    history = {split: {name: [] for name in metric_names} for split in ('train', 'val')}

    best_val_f1 = float('-inf')  # best validation F1 of all earlier epochs
    f1_improve = 0               # qualifying epochs since the last new best
    val_stats = None

    for epoch in range(epochs):
        train_stats = _run_epoch(model, train_data, criterion, optimizer_ft, lrsdl)
        for name in metric_names:
            history['train'][name].append(train_stats[name])
        print('Epoch {} Training Loss: {:.4f} Acc: {:.4f} F1: {:.4f} Recall: {:.4f} Precision: {:.4f}'.format(
            epoch+1, train_stats['loss'], train_stats['acc'], train_stats['f1'],
            train_stats['recall'], train_stats['precision']))

        val_stats = _run_epoch(model, val_data, criterion)
        for name in metric_names:
            history['val'][name].append(val_stats[name])
        print('Epoch {} Validation Loss: {:.4f} Acc: {:.4f} F1: {:.4f} Recall: {:.4f} Precision: {:.4f}'.format(
            epoch+1, val_stats['loss'], val_stats['acc'], val_stats['f1'],
            val_stats['recall'], val_stats['precision']))

        # Tracking the running best avoids max() over an empty list, which
        # would raise if the very first epoch already exceeded min_f1.
        val_f1 = val_stats['f1']
        is_new_best = val_f1 > best_val_f1
        best_val_f1 = max(best_val_f1, val_f1)

        if val_f1 > min_f1:
            f1_improve += 1
            if is_new_best:
                save_checkpoint(dir_name, epoch+1, val_f1, model)
                f1_improve = 0
            if f1_improve > patience:
                print('Epoch {} early stopping'.format(epoch+1))
                break

    _plot_history(history)

    # Row-normalised confusion matrix of the last validation pass (skipped when no epoch ran).
    if val_stats is not None:
        val_cfmtx = confusion_matrix(val_stats['labels'], val_stats['preds'])
        val_cfmtxn = val_cfmtx.astype('float') / val_cfmtx.sum(axis=1)[:, np.newaxis]
        sns.heatmap(val_cfmtxn, annot=True, fmt='.2%', cmap='Blues')

    return history

In [None]:
#model_ft.load_state_dict(torch.load('model_saved.pth'))

# Train; checkpoints are written into this folder by train_model.
checkpoint_dir_name = 'binary_classification_efficientnetb0_checkpoints'
train_model(
    train_data=train_loader,
    val_data=val_loader,
    model=model_ft,
    epochs=epochs,
    dir_name=checkpoint_dir_name,
)

# Persist only the weights (state_dict) of the model as it stands after training.
torch.save(obj=model_ft.state_dict(), f='model_saved.pth')

## Result Analysis and Visualization

In [None]:
load_model_path = r'C:\Users\user\Downloads\Susan_project\binary_classification_efficientnetb0_checkpoints\checkpoint-0.950-228.pth'
# map_location remaps the stored tensors onto the current device, so a checkpoint
# written on a GPU machine also loads on a CPU-only one.
model_ft.load_state_dict(torch.load(load_model_path, map_location=device))

In [None]:

#load the model again
#model = TheModelClass(*args, **kwargs) // use the original model structure
#model.load_state_dict(torch.load(PATH))
#model.eval()

# make prediction
def make_predictions(model, test_data):
    """Run `model` in eval mode over `test_data` and collect labels and predictions.

    Parameters
    ----------
    model : torch.nn.Module
        Model to evaluate; inputs are moved to the global `device`.
    test_data : iterable
        Yields dicts with an 'image' batch and a 'label' batch.

    Returns
    -------
    dict
        'groundtruth': labels in iteration order;
        'predictions': index of the largest output per sample, same order.
    """
    model.eval()
    test_labels = []
    test_preds = []

    with torch.no_grad():
        for item in tqdm(test_data):
            test_labels.extend(item['label'].numpy())

            # No loss is computed here: its value was never used, and computing
            # it tied this function to the global loss dictionary.
            test_outputs = model(item['image'].to(device))

            _, test_pred = torch.max(test_outputs, 1)
            test_preds.extend(test_pred.detach().cpu().numpy())

    return {
        'groundtruth': test_labels,
        'predictions': test_preds
    }



In [None]:

# create a csv file to record the wrongly predicted ones
def load_error_record(ds, predicts, out_path='error_record.csv'):
    """Write one CSV row per sample whose prediction differs from its ground truth.

    Parameters
    ----------
    ds : MyDataset
        Dataset the predictions were computed on, in the same order.
    predicts : dict
        Holds the parallel lists 'groundtruth' and 'predictions'.
    out_path : str, optional
        CSV file to (over)write.  The default is 'error_record.csv'.

    Columns are filename, true_label, predicted_label.
    """
    # When the dataset exposes its path list, the file name is read from it
    # directly; going through load_gradcam_item would decode and transform the
    # whole image just to obtain a name.
    paths = getattr(ds, 'paths', None)

    with open(out_path, 'w', newline='') as csvfile:

        fieldnames = ['filename', 'true_label', 'predicted_label']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

        writer.writeheader()

        for i in range(len(ds)):
            truth = predicts['groundtruth'][i]
            predicted = predicts['predictions'][i]
            if truth == predicted:
                continue

            if paths is not None:
                filename = os.path.basename(paths[i])
            else:
                filename = ds.load_gradcam_item(i)['filename']

            writer.writerow({'filename': filename, 'true_label': truth, 'predicted_label': predicted})



In [None]:
#grad cam
def show_gradcam(model, target_layers, desired_ds, predictions, use_cuda=True, true_labels=True):
    '''
    Save a Grad-CAM overlay for every misclassified item, then write the error CSV.

    Parameters
    ----------
    model :
        The model to be used.
    target_layers :
        Layers handed to GradCAM (e.g. layer4 for resnet50).
    desired_ds : MyDataset class
        Which dataset want to use.
    predictions : dict
        Output of make_predictions for the same dataset, in the same order.
    use_cuda : Boolean, optional
        Whether to use cuda. The default is True. Ignored when no CUDA device
        is available.
    true_labels : Boolean, optional
        True: explain the true class and prefix the file with 'true-'.
        False: explain the predicted class, replace the first character of the
        file name with the prediction and prefix it with 'predicted-'.
        The default is True.

    Returns
    -------
    None.

    '''
    # save it to a new folder; create it when missing
    dire = "gradcam_figures"
    os.makedirs(dire, exist_ok=True)

    # Crop/resize the RGB image with the same size the dataset used for the tensor.
    side = getattr(desired_ds, 'size', 380)

    # Requesting CUDA on a machine without it would fail, so fall back to the CPU.
    use_cuda = use_cuda and torch.cuda.is_available()

    with GradCAM(model=model, target_layers=target_layers, use_cuda=use_cuda) as cam:

        for i in range(len(desired_ds)):
            predicted = predictions['predictions'][i]

            # Load the item once: every call decodes and transforms the image again.
            item = desired_ds.load_gradcam_item(i)

            # Only mismatches are saved, so skip correct items before the CAM is computed.
            if item['label'] == predicted:
                continue

            category = item['label'] if true_labels else predicted
            targets = [ClassifierOutputTarget(int(category))]
            grayscale_cam = cam(input_tensor=item['tensor'].unsqueeze(0), targets=targets)
            grayscale_cam = grayscale_cam[0, :]

            img = transforms.CenterCrop(side)(item['image'])
            img = transforms.Resize(side)(img)
            # Explicit float array in [0, 1]; a constant image becomes all zeros
            # instead of dividing by zero.
            img = np.asarray(img, dtype=np.float32)
            low, high = img.min(), img.max()
            img = (img - low) / (high - low) if high > low else np.zeros_like(img)

            visualization = show_cam_on_image(img, grayscale_cam, use_rgb=True)
            pImg = Image.fromarray(visualization, 'RGB')

            # adjust the filename
            fn = item['filename']
            if true_labels:
                new_fn = ''.join(['true-', fn])
            else:
                # need the predictions to form the figures
                fn = ''.join([str(predicted), fn[1:]])
                new_fn = ''.join(['predicted-', fn])

            pImg.save(os.path.join(dire, new_fn))

    load_error_record(desired_ds, predictions)
    

In [None]:
# Predict over the external test images; the result holds 'groundtruth' and 'predictions' lists.
new_predictions = make_predictions(model_ft, new_test_loader)

In [None]:
# Grad-CAM target: the last module of the model's `features` stack.
target_layers = [model_ft.features[-1]]

In [None]:
# Grad-CAM for the external test images, explaining the predicted class (true_labels=False).
show_gradcam(model_ft, target_layers, new_test_ds, new_predictions, true_labels=False) 

In [None]:
# Predict over test_loader; the result holds 'groundtruth' and 'predictions' lists.
predictions = make_predictions(model_ft, test_loader)

In [None]:
# Grad-CAM target: the last module of the model's `features` stack
# (the commented line is the equivalent choice for a model exposing `layer4`).
#target_layers = [model_ft.layer4[-1]]
target_layers = [model_ft.features[-1]]

In [None]:
# NOTE: show_gradcam finishes by writing 'error_record.csv' in write mode, so this
# call replaces the CSV produced by the earlier show_gradcam call.
show_gradcam(model_ft, target_layers, test_ds, predictions, true_labels=True) # use true labels

In [None]:
# Macro-averaged F1, recall and precision over the test-set predictions.
test_f1, test_recall, test_precision = (
    metric(predictions['groundtruth'], predictions['predictions'], average='macro')
    for metric in (f1_score, recall_score, precision_score)
)
print('Test: F1, Recall, Precision:', test_f1, test_recall, test_precision)

# confusion matrix of the last time
#test_cfmtx = confusion_matrix(predictions['groundtruth'], predictions['predictions'])
#sns.heatmap(test_cfmtx/np.sum(test_cfmtx), annot=True, fmt='.2%', cmap='Blues')


# now it can successfully save!
# what next: save it as the original name and the predicted labels...?/
# how to get the original filename?
# how to get their predicted labels? -> put inside the make_predictions function