In [1]:
import pandas as pd
import os

def get_image_path(image_id:int):
    """Return the tile directory of one slide: `../tiles_768/<image_id>`.

    `image_id` is rendered with `str` and used as the directory name.
    """
    tile_dir_name = str(image_id)
    return os.path.join('../tiles_768', tile_dir_name)

# Index of the cross-validation fold; selects which CSV pair is loaded.
I_FOLD = 4
train = pd.read_csv(f"train_fold_{I_FOLD}.csv")
validation = pd.read_csv(f"val_fold_{I_FOLD}.csv")

# Give both splits a `tile_path` column derived from each row's `image_id`.
for split_frame in (train, validation):
    split_frame['tile_path'] = split_frame['image_id'].apply(get_image_path)
train.head()

Unnamed: 0,image_id,label,image_width,image_height,is_tma,tile_path
0,4,HGSC,23785,20008,False,../tiles_768/4
1,66,LGSC,48871,48195,False,../tiles_768/66
2,91,HGSC,3388,3388,True,../tiles_768/91
3,281,LGSC,42309,15545,False,../tiles_768/281
4,286,EC,37204,30020,False,../tiles_768/286


In [2]:
from PIL import Image
import torch
import torch.nn as nn
import timm
from timm.models import VisionTransformer
from timm.layers import SwiGLUPacked
from timm.models.layers import DropPath
import copy

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

class CustomViT(nn.Module):
    """ViT backbone with a learnable mask token and separate class/patch heads.

    The transformer blocks come from a timm `VisionTransformer`; this wrapper
    re-implements the token pipeline so that selected patch embeddings can be
    replaced by `mask_token` before the blocks run, and projects the class
    token and the patch tokens through two independent linear heads.

    Args:
        n_classes: output width of both heads.
        embed_dim: token embedding width of the backbone.
    """

    def __init__(self, n_classes=5, embed_dim=384):
        super().__init__()
        self.n_classes = n_classes
        self.embed_dim = embed_dim
        # Load the base ViT model.
        # NOTE(review): `forward` below never calls the base model's own
        # classifier or pooling, so `num_classes` / `global_pool` only affect
        # modules that stay unused here. With `global_pool='avg'` some timm
        # versions place the final LayerNorm in `fc_norm` and make
        # `base_model.norm` an identity — confirm for the installed version.
        self.base_model = VisionTransformer(img_size=224, num_classes=self.n_classes, patch_size=16, embed_dim=self.embed_dim, depth=12, num_heads=6, global_pool='avg', pre_norm=True)

        # Learnable embedding that stands in for masked patches.
        self.mask_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))

        max_drop_path_rate = 0.4
        dropout_rate = 0.1

        # Stochastic-depth rate grows linearly from 0 (first block) to the maximum (last block).
        drop_path_rates = [x.item() for x in torch.linspace(0, max_drop_path_rate, len(self.base_model.blocks))]

        # Overwrite the regularisation modules of every block in place.
        for i, block in enumerate(self.base_model.blocks):
            block.drop_path1 = DropPath(drop_prob=drop_path_rates[i])
            block.drop_path2 = DropPath(drop_prob=drop_path_rates[i])
            block.attn.attn_drop = nn.Dropout(p=dropout_rate, inplace=False)
            block.attn.proj_drop = nn.Dropout(p=dropout_rate, inplace=False)
            block.mlp.drop1 = nn.Dropout(p=dropout_rate, inplace=False)
            block.mlp.drop2 = nn.Dropout(p=dropout_rate, inplace=False)
        self.head_dropout = nn.Dropout(p=dropout_rate, inplace=False)

        self.class_token_head = nn.Linear(self.embed_dim, self.n_classes)
        self.patch_token_head = nn.Linear(self.embed_dim, self.n_classes)

    def forward_features(self, x, mask=None):
        """Embed the input, optionally mask patches, and run the transformer.

        Args:
            x: input batch passed to the backbone's `patch_embed`.
            mask: optional boolean tensor over patch positions, `True` where
                the patch embedding is replaced by `mask_token`. Either one
                shared row `(1, N)` or one row per sample `(B, N)`.

        Returns:
            All tokens after the final norm. When the backbone has a class
            token it is the first token along dim 1, followed by the patches.
        """
        x = self.base_model.patch_embed(x)

        # Prepend the class token (if the backbone has one) to the patch tokens.
        has_cls_token = self.base_model.cls_token is not None
        if has_cls_token:
            cls_tokens = self.base_model.cls_token.expand(x.shape[0], -1, -1)
            x = torch.cat([cls_tokens, x], dim=1)

        if mask is not None:
            if has_cls_token:
                # The class token is never masked: prepend a False column. It is built
                # from `mask` so that it matches the mask's row count, dtype and device
                # (this supports per-sample masks and avoids a host-to-device copy).
                cls_pad = mask.new_zeros(mask.shape[0], 1)
                mask = torch.cat([cls_pad, mask], dim=1)
            # Broadcast the single mask token over the batch, then swap it in
            # wherever the mask is True.
            mask_tokens = self.mask_token.expand(x.size(0), -1, -1)
            x = torch.where(mask.unsqueeze(-1), mask_tokens, x)

        # Positional embedding is added after masking, so masked positions keep their location signal.
        x = self.base_model.pos_drop(x + self.base_model.pos_embed)
        x = self.base_model.norm_pre(x)
        x = self.base_model.blocks(x)
        x = self.base_model.norm(x)

        # Class token and patch tokens are returned together.
        return x

    def forward_head(self, x):
        """Project the class token and the patch tokens through their own heads.

        Args:
            x: token tensor whose first token along dim 1 is the class token.

        Returns:
            Head outputs concatenated along dim 1 in the same token order
            (class token first), with last dimension `n_classes`.
        """
        class_token, patch_tokens = x[:, :1], x[:, 1:]

        # Apply dropout
        class_token = self.head_dropout(class_token)
        patch_tokens = self.head_dropout(patch_tokens)

        # Process class token and patch tokens through their respective heads
        class_token_output = self.class_token_head(class_token)
        patch_token_output = self.patch_token_head(patch_tokens)

        x = torch.cat([class_token_output, patch_token_output], dim=1)

        return x

    def forward(self, x, mask=None):
        """Run `forward_features` followed by `forward_head`."""
        x = self.forward_features(x, mask=mask)
        x = self.forward_head(x)
        return x

# Backbone embedding width and output width of both heads.
D_MODEL = 384
N_CLASSES = 8192
model = CustomViT(n_classes=N_CLASSES, embed_dim=D_MODEL).to(device)

# Exponential-moving-average copy of the model, starting from identical weights.
ema_decay = 0.9995  # decay factor for EMA
ema_model = copy.deepcopy(model).to(device)

In [3]:
import os
from PIL import Image
from torch.utils.data import Dataset
import random

# Label name -> integer index.
label_to_integer = {'HGSC': 0, 'CC': 1, 'EC': 2, 'LGSC': 3, 'MC': 4}

# Inverse lookup: integer index -> label name.
integer_to_label = {index: name for name, index in label_to_integer.items()}

class ImageDataset(Dataset):
    """Tile dataset that interleaves tiles from all slide folders.

    Each dataframe row names a folder (`tile_path`) of PNG tiles together
    with a `label` and an `image_id`. Tiles of every folder are shuffled
    once, then `all_images` is built round-robin: entry `i` of every folder,
    then entry `i + 1` of every folder, and so on. Folders with fewer tiles
    wrap around (modulo), so every folder contributes the same number of
    entries as the largest one.

    Rows whose folder does not exist or contains no PNG are skipped.

    Args:
        dataframe: frame with `tile_path`, `label` and `image_id` columns.
        transform: callable applied twice (independently) to each image to
            produce the two views; when `None` the RGB image itself is
            returned for both views.
    """

    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform
        self.all_images = []  # Store all images in an interlaced fashion
        self.folder_images = []  # One shuffled list of (path, label, image_id) per usable folder

        # Step 1: Collect all images from each folder
        for _, row in dataframe.iterrows():
            folder_path = row['tile_path']
            label = row['label']
            image_id = row['image_id']
            if not os.path.isdir(folder_path):
                continue
            image_files = [os.path.join(folder_path, f) for f in os.listdir(folder_path) if f.lower().endswith('.png')]
            if not image_files:
                # An empty folder would make the modulo in step 2 divide by zero.
                continue
            random.shuffle(image_files)
            self.folder_images.append([(image_file, label, image_id) for image_file in image_files])

        # Step 2: Interlace the images (`default=0` keeps an empty dataset valid).
        max_length = max((len(images) for images in self.folder_images), default=0)
        for i in range(max_length):
            for images in self.folder_images:
                self.all_images.append(images[i % len(images)])

    def __len__(self):
        """Number of interleaved entries (largest folder size x usable folders)."""
        return len(self.all_images)

    def __getitem__(self, idx):
        """Return `(view_a, view_b, label_index, image_id)` for entry `idx`."""
        image_path, label, image_id = self.all_images[idx]
        # Load eagerly inside a context manager so the file handle is released,
        # and force 3 channels so non-RGB PNGs do not break the 3-channel pipeline.
        with Image.open(image_path) as image_file:
            image = image_file.convert('RGB')

        if self.transform is None:
            augment_a, augment_b = image, image
        else:
            # Two independent draws of the (random) transform give two views of one tile.
            augment_a = self.transform(image)
            augment_b = self.transform(image)

        return augment_a, augment_b, label_to_integer[label], image_id

In [4]:
from torch.utils.data import DataLoader, WeightedRandomSampler
import torchvision.transforms as transforms

BATCH_SIZE = 32

# Training-time augmentation; every call draws fresh random parameters.
transform = transforms.Compose([
    # Geometric augmentations.
    transforms.RandomResizedCrop(size=224, scale=(0.25, 1.0), ratio=(0.75, 1.33)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(degrees=(0, 360)),
    transforms.RandomAffine(degrees=0, shear=(-20, 20, -20, 20)),
    # Photometric augmentations.
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3),
    transforms.RandomApply([transforms.Grayscale(num_output_channels=3)], p=0.5),
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 2))], p=0.5),
    # Tensor conversion and scaling of each channel with mean 0.5 / std 0.5.
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Deterministic pipeline for evaluation: resize, tensor conversion, same normalisation.
val_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

train_dataset = ImageDataset(dataframe=train, transform=transform)

# Every class gets the same count of 1, so all class weights come out equal.
class_counts = [1] * len(label_to_integer)  # equally weighted
num_samples = sum(class_counts)
class_weights = [num_samples / count for count in class_counts]

# Look up the weight of each dataset entry through its label (second tuple field).
sample_weights = [class_weights[label_to_integer[entry[1]]] for entry in train_dataset.all_images]

# Draw as many samples per epoch as the dataset has entries, with replacement.
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
)

# Batches are drawn through the sampler; the final batch of an epoch may be smaller than BATCH_SIZE.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    sampler=sampler,
    num_workers=4,
)

In [5]:
import logging
import sys

# Get the root logger
logger = logging.getLogger()

# Remove all existing handlers so re-running this cell does not duplicate output.
# Each removed handler is also closed; otherwise a FileHandler from a previous
# run would keep its log file open.
for handler in logger.handlers[:]:
    logger.removeHandler(handler)
    handler.close()

# Set the logging level
logger.setLevel(logging.INFO)

# Log INFO and above to a per-fold file; create the directory first because
# FileHandler does not create missing parent directories.
log_path = f'logs/vit_self_supervised/fold_{I_FOLD}.txt'
os.makedirs(os.path.dirname(log_path), exist_ok=True)
file_handler = logging.FileHandler(log_path)
file_handler.setLevel(logging.INFO)
logger.addHandler(file_handler)

# Create a StreamHandler for stderr and add it to the logger
stream_handler = logging.StreamHandler(sys.stderr)
stream_handler.setLevel(logging.ERROR)  # Only log ERROR and CRITICAL messages to stderr
logger.addHandler(stream_handler)

In [6]:
class KoLeoLoss(nn.Module):
    """Kozachenko-Leonenko entropic loss regularizer from Sablayrolles et al. - 2018 - Spreading vectors for similarity search

    For every row of the (L2-normalised) input, the loss is the negative log
    of the distance to its nearest other row, averaged over rows.
    """

    def __init__(self):
        super().__init__()
        self.pdist = nn.PairwiseDistance(2, eps=1e-8)

    def pairwise_NNs_inner(self, x):
        """
        Pairwise nearest neighbors for L2-normalized vectors.
        Uses Torch rather than Faiss to remain on GPU.

        Returns, for each row of `x`, the index of the other row with the
        largest inner product.
        """
        # The indices are not differentiable, so the search is kept out of the autograd graph.
        with torch.no_grad():
            # pairwise dot products (larger dot product = smaller distance)
            dots = torch.mm(x, x.t())
            dots.fill_diagonal_(-1)  # exclude each row from being its own neighbour
            # max inner prod -> min distance
            _, I = torch.max(dots, dim=1)  # noqa: E741
        return I

    def forward(self, student_output, eps=1e-8):
        """
        Args:
            student_output (BxD): backbone output of student
            eps: added inside the normalisation and before the log.

        Returns:
            Scalar float32 loss.
        """
        with torch.cuda.amp.autocast(enabled=False):
            # Disabling autocast does not up-cast inputs that are already half
            # precision; cast explicitly so that `eps` (1e-8) does not underflow
            # to zero and turn log(distance + eps) into -inf for duplicate rows.
            student_output = nn.functional.normalize(student_output.float(), eps=eps, p=2, dim=-1)
            I = self.pairwise_NNs_inner(student_output)  # noqa: E741
            distances = self.pdist(student_output, student_output[I])  # BxD, BxD -> B
            loss = -torch.log(distances + eps).mean()
        return loss

In [None]:
import torch
import torch.optim as optim
import logging
import numpy as np
import math
from sklearn.metrics import balanced_accuracy_score
import random
from torch.cuda.amp import GradScaler, autocast
import torch.nn.functional as F

# Upper bound on the number of passes over the training loader.
num_epochs = 10000

# Peak learning rate: 5e-4 scaled by the batch size relative to 256.
initial_lr = 5e-4 * BATCH_SIZE / 256
# Learning rate the schedule decays to: 1% of the peak.
final_lr = initial_lr * 0.01

def learning_rate(step, warmup_steps=10000, max_steps=100000):
    """Learning rate for optimisation step `step`.

    Three phases, using the module-level `initial_lr` and `final_lr`:
      * `step < warmup_steps`: linear ramp from 0 up to `initial_lr`;
      * `warmup_steps <= step < max_steps`: half-cosine decay from
        `initial_lr` down to `final_lr`;
      * `step >= max_steps`: constant `final_lr`.
    """
    if step < warmup_steps:
        warmup_fraction = float(step) / float(max(1, warmup_steps))
        return initial_lr * warmup_fraction

    if step >= max_steps:
        return final_lr

    # Fraction of the decay phase already completed, in [0, 1).
    decay_span = float(max(1, max_steps - warmup_steps))
    progress = float(step - warmup_steps) / decay_span
    cos_component = 0.5 * (1 + math.cos(math.pi * progress))
    return final_lr + (initial_lr - final_lr) * cos_component

def update_ema_variables(model, ema_model, alpha, global_step):
    """Blend the current model weights into the EMA model, in place.

    Every EMA parameter becomes `alpha * ema + (1 - alpha) * current`.
    Parameters are paired by iteration order of `parameters()`; buffers are
    not touched. `global_step` is accepted but not used.
    """
    current_weight = 1 - alpha
    with torch.no_grad():
        param_pairs = zip(ema_model.parameters(), model.parameters())
        for averaged, current in param_pairs:
            averaged.mul_(alpha)
            averaged.add_(current, alpha=current_weight)

scaler = GradScaler()
optimizer = optim.AdamW(model.parameters(), lr=initial_lr, weight_decay=5e-2)
koleo_loss = KoLeoLoss()

best_val_accuracy = 0.0
step = 0

# Softmax temperatures: the teacher distribution is sharper than the student's.
student_temperature = 0.1
teacher_temperature = 0.04

# Running means of the teacher outputs, subtracted before the teacher softmax.
center_momentum = 0.9
center_class = torch.zeros(1, N_CLASSES).to(device)
center_patch = torch.zeros(1, N_CLASSES).to(device)

# torch.save does not create missing directories, so make sure the target exists.
checkpoint_dir = f'vit_self_supervised_models/fold_{I_FOLD}'
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(num_epochs):
    model.train()
    ema_model.eval()

    for i, (augment_a_images, augment_b_images, _, _) in enumerate(train_dataloader, 0):
        # Move both augmented views to the training device
        augment_a_images = augment_a_images.to(device)
        augment_b_images = augment_b_images.to(device)

        # Set the scheduled learning rate (warmup, then cosine decay) for this step
        lr = learning_rate(step)
        for g in optimizer.param_groups:
            g['lr'] = lr

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass with autocast
        with autocast():
            # One random patch mask (about half the patches) shared by the whole batch.
            mask = torch.rand((1, (224 // 16) ** 2)) < 0.5
            mask = mask.to(device)
            # The student sees masked inputs; the EMA teacher sees the unmasked ones.
            student_a_outputs = model(augment_a_images, mask=mask)
            student_b_outputs = model(augment_b_images, mask=mask)
            with torch.no_grad():
                teacher_a_outputs = ema_model(augment_a_images)
                teacher_b_outputs = ema_model(augment_b_images)

            teacher_a_outputs = teacher_a_outputs.detach()
            teacher_b_outputs = teacher_b_outputs.detach()

            student_a_probs = F.softmax(student_a_outputs / student_temperature, dim=2)
            student_b_probs = F.softmax(student_b_outputs / student_temperature, dim=2)
            teacher_a_class_probs = F.softmax((teacher_a_outputs[:, :1] - center_class.unsqueeze(1)) / teacher_temperature, dim=2)
            teacher_a_patch_probs = F.softmax((teacher_a_outputs[:, 1:] - center_patch.unsqueeze(1)) / teacher_temperature, dim=2)
            teacher_b_class_probs = F.softmax((teacher_b_outputs[:, :1] - center_class.unsqueeze(1)) / teacher_temperature, dim=2)
            teacher_b_patch_probs = F.softmax((teacher_b_outputs[:, 1:] - center_patch.unsqueeze(1)) / teacher_temperature, dim=2)

            # Class-token loss is cross-view: teacher of view A supervises student of view B and vice versa.
            class_loss = - (teacher_a_class_probs * torch.log(student_b_probs[:, :1] + 1e-9)).sum(dim=2).mean()
            class_loss += - (teacher_b_class_probs * torch.log(student_a_probs[:, :1] + 1e-9)).sum(dim=2).mean()
            class_loss /= 2

            # Patch-token loss is same-view: patches of the two views are not spatially aligned.
            patch_loss = - (teacher_a_patch_probs * torch.log(student_a_probs[:, 1:] + 1e-9)).sum(dim=2).mean()
            patch_loss += - (teacher_b_patch_probs * torch.log(student_b_probs[:, 1:] + 1e-9)).sum(dim=2).mean()
            patch_loss /= 2

            loss = 0.495 * class_loss
            loss += 0.495 * patch_loss
            # Flatten the patch tokens with -1 rather than a hard-coded batch size,
            # so the final (smaller) batch of an epoch does not break the reshape.
            loss += 0.01 * (koleo_loss(student_a_outputs[:, 0]) + koleo_loss(student_a_outputs[:, 1:].reshape(-1, N_CLASSES)))

            # Update the running centers from this step's teacher outputs (both views pooled).
            teacher_class_means = torch.cat([teacher_a_outputs[:, :1], teacher_b_outputs[:, :1]], dim=1).mean(dim=(0, 1))
            teacher_patch_means = torch.cat([teacher_a_outputs[:, 1:], teacher_b_outputs[:, 1:]], dim=1).mean(dim=(0, 1))

            center_class = center_momentum * center_class + (1 - center_momentum) * teacher_class_means
            center_patch = center_momentum * center_patch + (1 - center_momentum) * teacher_patch_means

            # Entropy of the teacher's class-token distribution; logged only, not part of the loss.
            class_entropy = -torch.sum(teacher_a_class_probs * torch.log(teacher_a_class_probs + 1e-9), dim=2).mean()
            class_entropy += -torch.sum(teacher_b_class_probs * torch.log(teacher_b_class_probs + 1e-9), dim=2).mean()
            class_entropy /= 2

        # Backward pass
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        update_ema_variables(model, ema_model, ema_decay, step)

        logging.info('[%d, %5d] loss: %.3f | class_loss: %.3f | patch_loss: %.3f | class_entropy: %.3f' % (epoch + 1, step, loss.item(), class_loss.item(), patch_loss.item(), class_entropy.item()))

        # Checkpoint the EMA weights every 10000 steps (including step 0).
        if step % 10000 == 0:
            ema_model.eval()
            torch.save(ema_model.state_dict(), f'{checkpoint_dir}/epoch_{epoch}_step_{step}.pth')
            logging.info(f'Model saved after epoch {epoch} and step {step}')

            model.train()

        step += 1