<a href="https://colab.research.google.com/github/shubham8899/wound-few-shot/blob/main/Supervised_Contrastive_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q timm pytorch-metric-learning

[K     |████████████████████████████████| 431 kB 8.4 MB/s 
[K     |████████████████████████████████| 110 kB 53.8 MB/s 
[?25h

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import pandas as pd
import os
import numpy as np
import time
import torch
import copy
from PIL import Image
from tqdm.notebook import trange, tqdm
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image, ImageReadMode
from pytorch_metric_learning import losses
import timm
from torch.cuda import amp
from collections import defaultdict
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
class Wound_Dataset(Dataset):

    def __init__(self, root_dir, transform=None):
        """
        Wound dataset. Returns tensorized images with
        corresponding class 1: healer, 0: nonhealer
        Args:
            root_dir: path to healer/nonhealer image dataset.
        """
        self.root_dir = root_dir
        self.transform = transform 
        self.samples, self.targets = self.make_dataset(self.root_dir)

    def make_dataset(self, root_dir):
 
        instances = []
        targets = []


        for i in os.listdir(root_dir):
            instances.append(root_dir + "/" + i)
            # print(df[df['Filename'] == i]['Label_bin'])
            # print(i)
            label = i.split('_')[2].split('.')[0]
            if label == 'NonHealer':
                targets.append(0)
            elif label == 'Healer':
                targets.append(1)

        print(instances)
        print(targets)

        return instances, targets

    def pil_loader(self, path):
        # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
        with open(path, 'rb') as f:
            img = Image.open(f)
            img = img.convert('RGB')
            img = img.resize((224,224), Image.BICUBIC)
            return img

    def __getitem__(self, index: int):
            """
            Returns tensor data and label.
            Args:
                index (int): Index
            Returns:
                tuple: (sample, target) where sample is a tensor of input data
                        target is class of sample.
            """
            # Get images and targets, drop alpha
            img = self.pil_loader(self.samples[index])
            target = self.targets[index]

            # print(img.size)
            # print(self.samples[index])

            if self.transform:
                img = self.transform(img)

            # Concat and return sample, target
            return img , target

    def __len__(self):
        return len(self.samples)

In [5]:
def wound_dataloader(dataset, batch_size: int, num_workers = 0, shuffle = True, pin_memory = False):
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle = shuffle,
        pin_memory = pin_memory)

In [24]:
class Encoder(nn.Module):
    """
    Encoder model Pytorch. 
    """   
    def __init__(self):
        # Initialize self._modules as OrderedDict
        super(Encoder, self).__init__() 
        # Initialize densenet121
        self.embed_model = torch.hub.load('pytorch/vision:v0.10.0', 'densenet121', pretrained=False)
        # Remove Classifying layer
        self.embed_model = nn.Sequential(*list(self.embed_model.children())[:-1])
        # 7x7 average pool layer
        self.avg = nn.AvgPool2d(kernel_size=7, stride=1)
        # Left image connected layers
        self.fc_16 = nn.Linear(1024, 16)


    def forward(self, x):
        # Embed Left
        u1 = self.embed_model(x)
        u1 = self.avg(u1)
        u1 = u1.view(-1,1024)
        u1 = self.fc_16(u1)
        #u1 = torch.relu(u1)
        return u1

    def load_embed_wts(self, device):
        """
        load pretrained model weights, use only when transfer learning from ImageNET data
        """
        # Initialize densenet121
        self.embed_model = torch.hub.load('pytorch/vision:v0.10.0', 'densenet121', pretrained=True).to(device)

        # Remove Classifying layer
        self.embed_model = nn.Sequential(*list(self.embed_model.children())[:-1])

class Projection(nn.Module):
    """
    Creates projection head
    Args:
        n_in (int): Number of input features
        n_hidden (int): Number of hidden features
        n_out (int): Number of output features
        use_bn (bool): Whether to use batch norm
    """
    def __init__(self, n_in: int, n_hidden: int, n_out: int,
                use_bn: bool = True):
        super().__init__()
        
        # No point in using bias if we've batch norm
        self.lin1 = nn.Linear(16, 16, bias = True)
        self.bn = nn.BatchNorm1d(16)
        self.relu = nn.ReLU()
        # No bias for the final linear layer
        self.lin2 = nn.Linear(16, 32, bias=False)
    
    def forward(self, x):
        x = self.lin1(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.lin2(x)
        return x

class SupConClassifier(nn.Module):
    def __init__(self, projection_n_in: int = 16,
               projection_n_hidden: int = 16, projection_n_out: int = 16,
               projection_use_bn: bool = True):
        super().__init__()
        
        self.encoder = Encoder()

        for param in self.encoder.parameters():
            param.requires_grad = False

        self.projection = Projection(projection_n_in, projection_n_hidden,
                                    projection_n_out, projection_use_bn)
  
    def forward(self, x):
        x = self.encoder(x)
        x = self.projection(x)
        return x

# class Classifier(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.selfsup = SupConClassifier()
#         self.middle = nn.Linear(128,32)
#         # self.relu = nn.ReLU()
#         # self.classifier = nn.Linear(4,2)

#     def forward(self, x):
#         x = self. selfsup(x)
#         x = self. middle(x)
#         # x = self. middle(x)
#         # x = self.classifier(x)

#         return x

In [25]:
def train_model(model, dataloaders, criterion, optimizer, num_epochs=25):
    since = time.time()

    val_acc_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode
            # model.eval()
            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in tqdm(dataloaders[phase],desc='batches', leave = False):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    # Get model outputs and calculate loss
                    outputs = model(inputs)
                    outputs = nn.functional.normalize(outputs)
                    loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            # epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            # print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            print('{} Loss: {:.4f}'.format(phase, epoch_loss))

            # # deep copy the model
            # if phase == 'val' and epoch_acc > best_acc:
            #     best_acc = epoch_acc
            #     best_model_wts = copy.deepcopy(model.state_dict())
            # if phase == 'val':
            #     val_acc_history.append(epoch_acc)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    # print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, val_acc_history

In [26]:
# Data augmentation and normalization for training
# Just normalization for validation

input_size = 224 # Muse be 224 (3x224x224) for Densenet121

# Standard Pytorch image transforms (source:https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html)
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize(input_size),
        # transforms.RandomResizedCrop(input_size),
        # transforms.RandomRotation(degrees=(0,360)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(input_size),
        # transforms.RandomResizedCrop(input_size),
        # transforms.RandomRotation(degrees=(0,360)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

In [27]:
# Hyper Parameters:
# Load few-shot classifier cleaned dataset
data_path = "/content/drive/MyDrive/HealNet/Healer vs. Non Healer/Few_Shot_V2/"
learning_rate = 0.001
weight_decay = 0.1
batch_size = 33
num_epochs = 25

# Generate Datasets & Loaders
data_loaders_dict = {phase: wound_dataloader(Wound_Dataset(data_path+"/"+phase, transform=data_transforms[phase]), 
                                             batch_size, num_workers = 0, shuffle = True, pin_memory = True) for phase in ['train', 'val']}

# Create Model
model = SupConClassifier()
model.to(device)
# model.load_encoder_wts(device)
# Path to encoder wts
path_to_wts = "/content/drive/MyDrive/HealNet/Healer vs. Non Healer/Weights/meta_learned_ENCODER.tar" 
model.encoder.load_state_dict(torch.load(path_to_wts))

# Specify optimizer and loss function
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# lossfun = nn.CrossEntropyLoss()
# lossfun = SupervisedContrastiveLoss()
lossfun = losses.SupConLoss(temperature = 0.1).to(device)

['/content/drive/MyDrive/HealNet/Healer vs. Non Healer/Few_Shot_V2//train/BAART065_5_NonHealer.png', '/content/drive/MyDrive/HealNet/Healer vs. Non Healer/Few_Shot_V2//train/BAART060_35_Healer.png', '/content/drive/MyDrive/HealNet/Healer vs. Non Healer/Few_Shot_V2//train/BAART060_23_Healer.png', '/content/drive/MyDrive/HealNet/Healer vs. Non Healer/Few_Shot_V2//train/BAART047_33_Healer.png', '/content/drive/MyDrive/HealNet/Healer vs. Non Healer/Few_Shot_V2//train/BAART060_49_Healer.png', '/content/drive/MyDrive/HealNet/Healer vs. Non Healer/Few_Shot_V2//train/BAART060_28_Healer.png', '/content/drive/MyDrive/HealNet/Healer vs. Non Healer/Few_Shot_V2//train/BAART046_21_NonHealer.png', '/content/drive/MyDrive/HealNet/Healer vs. Non Healer/Few_Shot_V2//train/BAART046_56_NonHealer.png', '/content/drive/MyDrive/HealNet/Healer vs. Non Healer/Few_Shot_V2//train/BAART065_12_NonHealer.png', '/content/drive/MyDrive/HealNet/Healer vs. Non Healer/Few_Shot_V2//train/BAART046_49_NonHealer.png', '/con

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


In [28]:
from prettytable import PrettyTable

def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params+=params
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params
    
count_parameters(model)

+------------------------+------------+
|        Modules         | Parameters |
+------------------------+------------+
| projection.lin1.weight |    256     |
|  projection.lin1.bias  |     16     |
|  projection.bn.weight  |     16     |
|   projection.bn.bias   |     16     |
| projection.lin2.weight |    512     |
+------------------------+------------+
Total Trainable Params: 816


816

In [None]:
healnet, hist = train_model(model, data_loaders_dict, criterion=lossfun, optimizer = optimizer, num_epochs=num_epochs)

Epoch 0/24
----------


batches:   0%|          | 0/1 [00:00<?, ?it/s]

train Loss: 5.3126


batches:   0%|          | 0/5 [00:00<?, ?it/s]

val Loss: 4.3591

Epoch 1/24
----------


batches:   0%|          | 0/1 [00:00<?, ?it/s]

train Loss: 5.2021


batches:   0%|          | 0/5 [00:00<?, ?it/s]

val Loss: 4.5099

Epoch 2/24
----------


batches:   0%|          | 0/1 [00:00<?, ?it/s]

train Loss: 5.0960


batches:   0%|          | 0/5 [00:00<?, ?it/s]

val Loss: 4.5742

Epoch 3/24
----------


batches:   0%|          | 0/1 [00:00<?, ?it/s]

train Loss: 4.9944


batches:   0%|          | 0/5 [00:00<?, ?it/s]

val Loss: 4.8022

Epoch 4/24
----------


batches:   0%|          | 0/1 [00:00<?, ?it/s]

train Loss: 4.8973


batches:   0%|          | 0/5 [00:00<?, ?it/s]

val Loss: 4.8310

Epoch 5/24
----------


batches:   0%|          | 0/1 [00:00<?, ?it/s]

train Loss: 4.8046


batches:   0%|          | 0/5 [00:00<?, ?it/s]

val Loss: 4.8162

Epoch 6/24
----------


batches:   0%|          | 0/1 [00:00<?, ?it/s]

train Loss: 4.7161


batches:   0%|          | 0/5 [00:00<?, ?it/s]