In [None]:
import glob
import torch
from torchvision import datasets, models, transforms
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import StratifiedKFold
import torch.nn as nn
import torch.optim as optim
import numpy as np

%matplotlib inline
import matplotlib.pyplot as plt
from PIL import Image

from torch.nn import functional as F

In [None]:
# Show the installed PyTorch version (displayed as this cell's output).
torch.__version__

In [None]:
import os

# Pick the drawing type(s) to train on: Spiral is enabled below. Uncomment the
# Meander folders (and comment out Spiral) to use Meander instead, or keep all
# four folders to train a mixed model.
flds = [#'Meander_HandPD/Meander_Control',
        #'Meander_HandPD/Meander_Patients',
        'Spiral_HandPD/Spiral_Control',
        'Spiral_HandPD/Spiral_Patients']

images = []     # decoded RGB PIL images
labels = []     # 0 = control, 1 = patient (from the folder-name suffix)
exams = []      # exam id = file-name prefix before the first '-' (used as CV group)
img_files = []  # source path of each image
for f in flds:
    # sorted(): glob order is filesystem-dependent; a fixed order keeps the
    # seeded cross-validation splits reproducible across machines.
    img_list = sorted(glob.glob(os.path.join(f, '*.jpg')))
    label = 0 if f.split('_')[-1] == 'Control' else 1
    labels += [label] * len(img_list)
    for im in img_list:
        # convert() returns a fully loaded copy, so the file can be closed here.
        with Image.open(im) as raw:
            images.append(raw.convert("RGB"))
        # basename() also handles the '\\' separators glob returns on Windows.
        exams.append(os.path.basename(im).split('-')[0])
        img_files.append(im)

In [None]:
# Sanity-check the loaded data before building the dataset: every image needs
# exactly one label, exam id and file path. An empty list usually means the
# notebook was run from a directory without the HandPD folders.
assert images, 'No images found - check that the HandPD folders are in the working directory'
assert len(images) == len(labels) == len(exams) == len(img_files)
print(len(images))
print(len(labels))

In [None]:
from torch.utils.data import Dataset
class DataSet(Dataset):
    """Map-style dataset over images that are already held in memory.

    Args:
        images: sequence of images; each one is passed through ``transform``.
        labels: sequence of labels aligned with ``images``.
        transform: callable applied to an image in ``__getitem__``.
        exams: optional sequence of per-image exam ids aligned with ``images``
            (list, numpy array, ...). When given, items are
            ``(tensor_image, label, exam)``; otherwise ``(tensor_image, label)``.
    """

    def __init__(self, images, labels, transform, exams=None):
        self.images = images
        self.labels = labels
        self.transform = transform
        self.exams = exams

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        tensor_image = self.transform(image)
        # Test against None rather than truthiness: `if array:` raises for numpy
        # arrays / pandas Series of exam ids.
        if self.exams is not None:
            return tensor_image, label, self.exams[idx]
        return tensor_image, label

In [None]:
# Per-channel normalisation with the standard ImageNet mean/std used by
# torchvision's pretrained models.
normalize = transforms.Normalize(std=[0.229, 0.224, 0.225],
                                 mean=[0.485, 0.456, 0.406])

# Preprocessing pipeline: resize to a fixed 224x224, convert to a tensor, normalise.
transform = transforms.Compose([transforms.Resize((224, 224)),
                                transforms.ToTensor(),
                                normalize])

dataset = DataSet(images=images, labels=labels, transform=transform, exams=exams)

In [None]:
def save_model(model, filepath):
    """Save ``model.state_dict()`` to ``filepath`` and print a confirmation.

    Args:
        model: the ``nn.Module`` whose parameters/buffers are saved.
        filepath: destination path (str or path-like), or a writable binary
            file-like object, as accepted by ``torch.save``. For paths, missing
            parent directories are created first, because ``torch.save`` does
            not create them.
    """
    import os

    if isinstance(filepath, (str, os.PathLike)):
        parent = os.path.dirname(os.fspath(filepath))
        if parent:
            os.makedirs(parent, exist_ok=True)
    torch.save(model.state_dict(), filepath)
    print(f"Model saved to {filepath}")

In [None]:
# Use the default CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
torch.manual_seed(8)

# ResNet-50 initialised with torchvision's pretrained weights.
model = models.resnet50(pretrained=True)
# Freeze every pretrained parameter (requires_grad_ walks all parameters); only
# the new head created below, after the freeze, stays trainable.
model.requires_grad_(False)

# Replace the final fully connected layer with a small MLP head whose
# parameters are learned from the HandPD data (2 output classes).
n_feats = model.fc.in_features
model.fc = nn.Sequential(nn.Linear(n_feats, 128),
                         nn.ReLU(),
                         nn.Linear(128, 2)).to(device)


In [None]:
# nn.Module.to moves parameters in place and returns the same module; the ';'
# keeps the notebook from echoing the model repr.
model.to(device);

In [None]:
def _run_phase(model, loader, criterion, device, optimizer=None, collect=False):
    """Run one pass over ``loader``: train when ``optimizer`` is given, else evaluate without autograd.

    Batches are ``(inputs, targets, exams)``; ``exams`` is ignored.
    Returns ``(mean_loss, accuracy, preds, outputs)``:
    - ``mean_loss`` is averaged per sample.
    - ``accuracy`` is a Python float.
    - ``preds`` is a list of per-batch predicted-class tensors and ``outputs`` a
      list of per-batch numpy logit arrays; both are empty unless ``collect`` is True.
    """
    training = optimizer is not None
    model.train(training)
    running_loss = 0.0
    running_corrects = 0
    n_seen = 0
    batch_preds = []
    batch_outputs = []
    # Gradients are only needed while training; skipping them elsewhere saves memory/time.
    with torch.set_grad_enabled(training):
        for inputs, targets, _exams in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            _, preds = torch.max(outputs, 1)
            running_loss += loss.item() * inputs.size(0)
            running_corrects += (preds == targets).sum().item()
            n_seen += len(preds)
            if collect:
                batch_preds.append(preds)
                # .cpu() is required before .numpy() when the model runs on a GPU.
                batch_outputs.append(outputs.detach().cpu().numpy())
    return running_loss / n_seen, running_corrects / n_seen, batch_preds, batch_outputs


def train_model(dataloaders, model, criterion, optimizer, num_epochs=20, scheduler=None, device=None):
    """Train ``model``, keep the weights with the lowest validation loss, and evaluate them on the test split.

    Args:
        dataloaders: dict with ``'train'``, ``'val'`` and ``'test'`` iterables of
            ``(inputs, targets, exams)`` batches.
        model: network to optimise; it is trained in place.
        criterion: loss taking ``(outputs, targets)``.
        optimizer: optimiser over the trainable parameters of ``model``.
        num_epochs: number of train + validation epochs.
        scheduler: optional LR scheduler; ``scheduler.step()`` is called once per
            epoch after the training phase.
        device: device batches are moved to; defaults to the device of the
            model's first parameter (CPU if it has none).

    Returns:
        ``(best_model, all_preds, all_outputs, test_acc)``:
        - ``best_model`` is ``model`` with the lowest-validation-loss weights loaded.
        - ``all_preds`` is a list of per-batch predicted-class tensors for the test split.
        - ``all_outputs`` holds the raw model outputs (logits) for the test split,
          stacked into one numpy array.
        - ``test_acc`` is the test accuracy as a 0-d float64 tensor.
    """
    import copy

    if device is None:
        first_param = next(iter(model.parameters()), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    # state_dict() returns live references that later optimizer steps overwrite,
    # so the best weights must be deep-copied (assigning `model` would just alias it).
    best_state = copy.deepcopy(model.state_dict())
    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 10)

        for phase in ['train', 'val']:
            epoch_loss, epoch_acc, _, _ = _run_phase(
                model, dataloaders[phase], criterion, device,
                optimizer=optimizer if phase == 'train' else None)

            print('{} loss: {:.4f}, acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            if phase == 'train' and scheduler is not None:
                scheduler.step()
            if phase == 'val' and epoch_loss < best_val_loss:
                best_val_loss = epoch_loss
                best_state = copy.deepcopy(model.state_dict())
                print('Set best model.')

    model.load_state_dict(best_state)
    best_model = model

    test_loss, test_acc, all_preds, all_outputs = _run_phase(
        best_model, dataloaders['test'], criterion, device, collect=True)
    print('+++++++++++ Test {} loss: {:.4f}, acc: {:.4f}'.format('test', test_loss, test_acc))
    all_outputs = np.vstack(all_outputs)
    return (best_model, all_preds, all_outputs, torch.tensor(test_acc, dtype=torch.float64))

In [None]:
# Outer cross-validation splitter: stratified on the label and grouped by the
# `groups` passed to skf.split (the exam ids in this notebook), so all images
# of one group land in the same test fold.
from sklearn.model_selection import StratifiedGroupKFold
k_folds = 5
skf = StratifiedGroupKFold(n_splits=k_folds, random_state=8, shuffle=True)

In [None]:
from torch.optim import lr_scheduler
from sklearn.model_selection import StratifiedShuffleSplit
import time


def build_fold_model():
    """Return a fresh ResNet-50 (torchvision pretrained weights) on ``device``.

    The backbone is frozen and the final layer is replaced by a new, trainable
    2-class MLP head. This is the same architecture as the model cell above.
    It is rebuilt per fold so no fold starts from weights already trained on
    another fold's data.
    """
    net = models.resnet50(pretrained=True)
    for param in net.parameters():
        param.requires_grad = False
    net.fc = nn.Sequential(
        nn.Linear(net.fc.in_features, 128),
        nn.ReLU(),
        nn.Linear(128, 2))
    return net.to(device)


# Cross-validation process
best_model = model
results = []
test_outputs = []
test_ids = []

start_time = time.time()

for fold, (train_val_idx, test_idx) in enumerate(skf.split(np.zeros(len(labels)), labels, exams)):
    print(f'Fold {fold + 1}')

    # Inner train/validation split of this fold's non-test images.
    # split() returns POSITIONS within train_val_idx; they must be mapped back to
    # dataset indices, otherwise train/val would draw from the whole dataset,
    # including this fold's test images.
    # NOTE(review): this inner split is stratified but not grouped by exam, so one
    # exam's images can appear in both train and val (affects model selection).
    split_train_val = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
    train_pos, val_pos = next(split_train_val.split(np.zeros(len(train_val_idx)),
                                                    [labels[i] for i in train_val_idx]))
    train_idx = train_val_idx[train_pos]
    val_idx = train_val_idx[val_pos]
    assert set(train_idx).isdisjoint(test_idx) and set(val_idx).isdisjoint(test_idx)

    # Subset datasets
    train_subset = Subset(dataset, train_idx)
    val_subset = Subset(dataset, val_idx)
    test_subset = Subset(dataset, test_idx)

    # Data loaders
    train_loader = DataLoader(train_subset, batch_size=8, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=8, shuffle=False)
    test_loader = DataLoader(test_subset, batch_size=8, shuffle=False)

    dataloaders = {'train': train_loader, 'val': val_loader, 'test': test_loader}

    # Every fold starts from the pretrained backbone with a freshly initialised head.
    torch.manual_seed(8 + fold)
    model = build_fold_model()

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.fc.parameters(), lr=0.001)

    # Decay LR by a factor of 0.1 every 10 epochs.
    # NOTE(review): train_model below is not given this scheduler, so it is never
    # stepped and the learning rate stays at 0.001.
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    best_model, all_preds, all_outputs, fold_test_acc = train_model(dataloaders, model, criterion, optimizer, num_epochs=80)

    test_outputs.append(all_outputs)
    test_ids += test_idx.tolist()

    #save_model(best_model,f'models/ResNet_handPD_cv_by_EXAM_fold_{fold + 1}.pth')
    results.append(fold_test_acc)

test_outputs = np.vstack(test_outputs)

print("--- %s seconds ---" % (time.time() - start_time))
test_outputs.shape


In [None]:
import pandas as pd

# One row per test image, gathered across all CV folds in test_ids order.
test_imgs = [img_files[i] for i in test_ids]
test_labels = [labels[i] for i in test_ids]

# test_outputs are raw model outputs (logits: the head has no softmax and is
# trained with CrossEntropyLoss). Convert them to class probabilities with a
# numerically stable softmax so the *_prob columns really are probabilities.
shifted_logits = test_outputs - test_outputs.max(axis=1, keepdims=True)
exp_logits = np.exp(shifted_logits)
test_probs = exp_logits / exp_logits.sum(axis=1, keepdims=True)

detailed_result_df = pd.DataFrame({
    'test_idx': test_ids,
    'healthy_prob': test_probs[:, 0],
    'patient_prob': test_probs[:, 1],
    'image_path': test_imgs,
    'label': test_labels,
    # Raw logits kept for anyone who needs the unnormalised scores.
    'healthy_logit': test_outputs[:, 0],
    'patient_logit': test_outputs[:, 1],
})

detailed_result_df.to_csv('test_results.csv', index=False)

In [None]:
# Per-fold test accuracies as a float array, then their mean and (population)
# standard deviation across folds.
fold_accs = np.array([r.tolist() for r in results])
print(fold_accs.mean())
print(fold_accs.std())

# Grad-CAM

In [None]:
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision.models import resnet50



In [None]:
import os

import cv2  # used below for resizing, colour conversion and writing images; not imported elsewhere

# Example patient drawings to explain (use the Meander paths with a Meander model).
img_paths = [
    # 'Meander_HandPD/Meander_Patients/0213-8.jpg',
    # 'Meander_HandPD/Meander_Patients/0187-6.jpg',
    'Spiral_HandPD/Spiral_Patients/0197-2.jpg',
    'Spiral_HandPD/Spiral_Patients/0246-3.jpg',
]

# cv2.imwrite does not create directories (it just returns False), so make sure
# the output folder exists.
os.makedirs('results', exist_ok=True)

# First and last block of each ResNet stage (layer1 .. layer4); one CAM per layer.
target_layers = [best_model.layer1[0], best_model.layer1[-1],
                 best_model.layer2[0], best_model.layer2[-1],
                 best_model.layer3[0], best_model.layer3[-1],
                 best_model.layer4[0], best_model.layer4[-1],
                ]

# We have to specify the target we want to generate the CAM for: class index 1 (patient).
targets = [ClassifierOutputTarget(1)]

for img_path in img_paths:
    img_name = os.path.splitext(os.path.basename(img_path))[0]
    rgb_img = Image.open(img_path).convert('RGB')

    # Model input: the same preprocessing used for training, on the model's device.
    input_tensor = transform(rgb_img).unsqueeze(0).to(device)

    # Display image: un-normalised 224x224 RGB scaled to [0, 1], which is what
    # show_cam_on_image expects (the normalised tensor is not a displayable image).
    cv_img = cv2.resize(np.array(rgb_img, dtype=np.uint8), (224, 224))
    overlay_base = np.float32(cv_img) / 255
    orig_bgr = cv2.cvtColor(cv_img, cv2.COLOR_RGB2BGR)

    for ti, layer in enumerate(target_layers):
        with AblationCAM(model=best_model, target_layers=[layer]) as cam:
            # You can also pass aug_smooth=True and eigen_smooth=True, to apply smoothing.
            grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]

        visualization = show_cam_on_image(overlay_base, grayscale_cam, use_rgb=True)
        cam_image = cv2.cvtColor(visualization, cv2.COLOR_RGB2BGR)
        # NOTE(review): cv2.imwrite expects BGR, so the middle (RGB) panel is saved
        # with red/blue swapped; the three-panel layout is kept from the original.
        im_h = cv2.hconcat([orig_bgr, visualization, cam_image])

        cv2.imwrite(f'results/{img_name}_layerid_{ti+1}.jpg', im_h)
