In [None]:
!pip install pytorch-metric-learning

In [None]:
!pip install natsort

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
import pandas as pd
import os
import numpy as np
import time
import torch
import copy
from PIL import Image
from tqdm.notebook import trange, tqdm
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image, ImageReadMode
from pytorch_metric_learning import losses
from torch.cuda import amp
from collections import defaultdict
import natsort
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
class Human_Wound_Dataset(Dataset):
    def __init__(self, root_dir, seq_len, transform=None):
        self.root_dir = root_dir
        self.transform = transform 
        self.sequence_length = seq_len
        self.end_day = seq_len+1
        # self.labels_csv = pd.read_csv('/content/gdrive/MyDrive/HealNet/Healer vs. Non Healer/Few_Shot_Learning_Data/Human_Labels_Final.csv')
        self.data_sequence = self.make_dataset(self.root_dir)
   
    def make_dataset(self, root_dir):
        
        image_name = []
        image_one_hot = []
        patient_image_sequence = []
        for image_root, image_dirs, image_files in os.walk(self.root_dir, topdown=False):
            image_dirs.sort()
            # print(image_files)
            for image_dirname in image_dirs:
                image_sub_path = os.path.join(image_root, image_dirname)
                image_patient_list = natsort.natsorted(os.listdir(image_sub_path))
                num_images_in_dir = len(image_patient_list)

                sequence_class = image_patient_list[0].split('_')[2]

                path_string = os.path.join(image_root, image_dirname + '/')
                
                

                if num_images_in_dir < self.sequence_length:
                    continue
                elif num_images_in_dir == self.sequence_length:
                    image_sequence = [path_string + s for s in image_patient_list]
                    if sequence_class[:-4] == 'Healer':
                        image_sequence.append(1)
                    else:
                        image_sequence.append(0)
                    patient_image_sequence.append(image_sequence)

                else:
                    image_sequence = []
                    for i in range(0, num_images_in_dir-self.sequence_length+1):
                        temp = image_patient_list[i: i+self.sequence_length]
                        image_sequence = [path_string + s for s in temp]
                        if sequence_class[:-4] == 'Healer':
                            image_sequence.append(1)
                        else:
                            image_sequence.append(0)
                        patient_image_sequence.append(image_sequence)
        for pat_seq in patient_image_sequence:
            print(pat_seq)
        return patient_image_sequence
        

    def pil_loader(self, path):
        # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def __getitem__(self, index: int):
        """
        Returns tensor data and label.
        Args:
            index (int): Index
        Returns:
            tuple: (sample, target) where sample is a tensor of input data
                    target is class of sample.
        """
        # Get images and targets, drop alpha
        loaded_img_list = []
        image_paths_sequence = self.data_sequence[index][0:-1]
        target = self.data_sequence[index][-1]

        for img_path in image_paths_sequence:
            img = self.pil_loader(img_path)
            loaded_img_list.append(img)

        final_list = []
        if self.transform:
            for img in loaded_img_list:
                img = self.transform(img)
                final_list.append(img)            

        stacked_final = torch.stack(final_list)
        return stacked_final, target

    def __len__(self):
        return len(self.data_sequence)

In [6]:
human = Human_Wound_Dataset('/content/drive/MyDrive/HealNet/Healer vs. Non Healer/Few_Shot_V2/Dataset/query_set/', 3)

['/content/drive/MyDrive/HealNet/Healer vs. Non Healer/Few_Shot_V2/Dataset/query_set/BAART044/BAART044_0_Healer.png', '/content/drive/MyDrive/HealNet/Healer vs. Non Healer/Few_Shot_V2/Dataset/query_set/BAART044/BAART044_7_Healer.png', '/content/drive/MyDrive/HealNet/Healer vs. Non Healer/Few_Shot_V2/Dataset/query_set/BAART044/BAART044_14_Healer.png', 1]
['/content/drive/MyDrive/HealNet/Healer vs. Non Healer/Few_Shot_V2/Dataset/query_set/BAART044/BAART044_7_Healer.png', '/content/drive/MyDrive/HealNet/Healer vs. Non Healer/Few_Shot_V2/Dataset/query_set/BAART044/BAART044_14_Healer.png', '/content/drive/MyDrive/HealNet/Healer vs. Non Healer/Few_Shot_V2/Dataset/query_set/BAART044/BAART044_21_Healer.png', 1]
['/content/drive/MyDrive/HealNet/Healer vs. Non Healer/Few_Shot_V2/Dataset/query_set/BAART044/BAART044_14_Healer.png', '/content/drive/MyDrive/HealNet/Healer vs. Non Healer/Few_Shot_V2/Dataset/query_set/BAART044/BAART044_21_Healer.png', '/content/drive/MyDrive/HealNet/Healer vs. Non Hea

In [11]:
def wound_dataloader(dataset, batch_size: int, num_workers = 0, shuffle = True, pin_memory = False):
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle = shuffle,
        pin_memory = pin_memory)

In [None]:
class Encoder(nn.Module):
    """
    Encoder model Pytorch. 
    """   
    def __init__(self):
        # Initialize self._modules as OrderedDict
        super(Encoder, self).__init__() 
        # Initialize densenet121
        self.embed_model = torch.hub.load('pytorch/vision:v0.10.0', 'densenet121', pretrained=False)
        # Remove Classifying layer
        self.embed_model = nn.Sequential(*list(self.embed_model.children())[:-1])
        # 7x7 average pool layer
        self.avg = nn.AvgPool2d(kernel_size=7, stride=1)
        # Left image connected layers
        self.fc_16 = nn.Linear(1024, 16)

    def forward(self, x):
        # Embed Left
        u1 = self.embed_model(x)
        u1 = self.avg(u1)
        u1 = u1.view(-1,1024)
        u1 = self.fc_16(u1)
        #u1 = torch.relu(u1)
        return u1

    def load_embed_wts(self, device):
        """
        load pretrained model weights, use only when transfer learning from ImageNET data
        """
        # Initialize densenet121
        self.embed_model = torch.hub.load('pytorch/vision:v0.10.0', 'densenet121', pretrained=True).to(device)

        # Remove Classifying layer
        self.embed_model = nn.Sequential(*list(self.embed_model.children())[:-1])

class Projection(nn.Module):
    """
    Creates projection head
    Args:
        n_in (int): Number of input features
        n_hidden (int): Number of hidden features
        n_out (int): Number of output features
        use_bn (bool): Whether to use batch norm
    """
    def __init__(self, n_in: int, n_hidden: int, n_out: int,
                use_bn: bool = True):
        super().__init__()
        
        # No point in using bias if we've batch norm
        self.lin1 = nn.Linear(1024, 1024, bias = True)
        self.bn = nn.BatchNorm1d(1024)
        self.relu = nn.ReLU()
        # No bias for the final linear layer
        self.lin2 = nn.Linear(1024, 128, bias=False)
    
    def forward(self, x):
        x = self.lin1(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.lin2(x)
        return x

class SupConClassifier(nn.Module):
    def __init__(self, projection_n_in: int = 16,
               projection_n_hidden: int = 16, projection_n_out: int = 16,
               projection_use_bn: bool = True):
        super().__init__()
        
        # self.encoder = Encoder()
        self.pre_encoder = Encoder()
        self.encoder = nn.Sequential(*list(self.pre_encoder.children())[:-1])

        for param in self.encoder.parameters():
            param.requires_grad = False
        self.projection = Projection(projection_n_in, projection_n_hidden,
                                    projection_n_out, projection_use_bn)
  
    def forward(self, x):
        x = self.encoder(x)
        x = x.view(-1,1024)
        x = self.projection(x)
        return x

class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.selfsup = SupConClassifier()
        self.middle = nn.Linear(128,32)
        # self.relu = nn.ReLU()
        # self.classifier = nn.Linear(4,2)

    def forward(self, x):
        x = self. selfsup(x)
        x = self. middle(x)
        # x = self. middle(x)
        # x = self.classifier(x)

        return x

In [None]:
def train_model(model, dataloaders, criterion, optimizer, num_epochs=25):
    since = time.time()

    val_acc_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode
            # model.eval()
            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in tqdm(dataloaders[phase],desc='batches', leave = False):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    # Get model outputs and calculate loss
                    outputs = model(inputs)
                    outputs = nn.functional.normalize(outputs)
                    loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            # epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            # print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            print('{} Loss: {:.4f}'.format(phase, epoch_loss))

            # # deep copy the model
            # if phase == 'val' and epoch_acc > best_acc:
            #     best_acc = epoch_acc
            #     best_model_wts = copy.deepcopy(model.state_dict())
            # if phase == 'val':
            #     val_acc_history.append(epoch_acc)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    # print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, val_acc_history

In [13]:
# Data augmentation and normalization for training
# Just normalization for validation

input_size = 224 # Muse be 224 (3x224x224) for Densenet121

# Standard Pytorch image transforms (source:https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html)
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize(input_size),
        # transforms.RandomResizedCrop(input_size),
        # transforms.RandomRotation(degrees=(0,360)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(input_size),
        # transforms.RandomResizedCrop(input_size),
        # transforms.RandomRotation(degrees=(0,360)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

In [24]:
# Hyper Parameters:
#path to Wound_images
data_path = "/content/drive/MyDrive/HealNet/Healer vs. Non Healer/Few_Shot_V2/Dataset/query_set/"
learning_rate = 0.001
weight_decay = 0.1
batch_size = 8
num_epochs = 25
seq_length = 3
1
# Generate Datasets & Loaders
data_loaders_dict = {phase: wound_dataloader(Human_Wound_Dataset(data_path+"/"+phase, seq_length, transform=data_transforms[phase]), 
                                             batch_size, num_workers = 0, shuffle = True, pin_memory = True) for phase in ['train', 'val']}

# Create Model
model = SupConClassifier()
model.to(device)
path_to_wts = "/content/drive/MyDrive/HealNet/Healer vs. Non Healer/Weights/meta_learned_ENCODER.tar" # Path to encoder wts
model.pre_encoder.load_state_dict(torch.load(path_to_wts))

# Specify optimizer and loss function
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# lossfun = nn.CrossEntropyLoss()
# lossfun = SupervisedContrastiveLoss()
lossfun = losses.SupConLoss(temperature = 0.1).to(device)

IndexError: ignored

In [None]:
from prettytable import PrettyTable

def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params+=params
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params
    
count_parameters(model)

+--------------------------+------------+
|         Modules          | Parameters |
+--------------------------+------------+
| pre_encoder.fc_16.weight |   16384    |
|  pre_encoder.fc_16.bias  |     16     |
|  projection.lin1.weight  |  1048576   |
|   projection.lin1.bias   |    1024    |
|   projection.bn.weight   |    1024    |
|    projection.bn.bias    |    1024    |
|  projection.lin2.weight  |   131072   |
+--------------------------+------------+
Total Trainable Params: 1199120


1199120

In [None]:
healnet, hist = train_model(model, data_loaders_dict, criterion=lossfun, optimizer = optimizer, num_epochs=num_epochs)

Epoch 0/24
----------


batches:   0%|          | 0/1 [00:00<?, ?it/s]

train Loss: 3.6293


batches:   0%|          | 0/5 [00:00<?, ?it/s]

val Loss: 3.6601

Epoch 1/24
----------


batches:   0%|          | 0/1 [00:00<?, ?it/s]

train Loss: 2.8666


batches:   0%|          | 0/5 [00:00<?, ?it/s]

val Loss: 3.6615

Epoch 2/24
----------


batches:   0%|          | 0/1 [00:00<?, ?it/s]

train Loss: 2.7980


batches:   0%|          | 0/5 [00:00<?, ?it/s]

val Loss: 3.7306

Epoch 3/24
----------


KeyboardInterrupt: ignored

In [None]:
save_as = "/content/drive/MyDrive/HealNet/Healer vs. Non Healer/Weights/Contrastive_Encoder_June_16" # path and name of trained healNet weights
torch.save(healnet.state_dict(), save_as)

In [None]:
classification_model = Classifier()