In [1]:
import pandas as pd
import os

def get_image_path(image_id:int):
    """Return the tile directory for ``image_id`` under ``../tiles_768``."""
    tiles_root = '../tiles_768'
    folder_name = f"{image_id}"
    return os.path.join(tiles_root, folder_name)

# Index of the cross-validation fold whose CSV splits are used below.
I_FOLD = 0
train = pd.read_csv(f"train_fold_{I_FOLD}.csv")
validation = pd.read_csv(f"val_fold_{I_FOLD}.csv")

# Add a 'tile_path' column to both splits, derived from each row's image_id.
for split_frame in (train, validation):
    split_frame['tile_path'] = split_frame['image_id'].apply(get_image_path)
train.head()

Unnamed: 0,image_id,label,image_width,image_height,is_tma,tile_path
0,8280,HGSC,2964,2964,True,../tiles_768/8280
1,8985,LGSC,65003,31754,False,../tiles_768/8985
2,9183,LGSC,74091,34185,False,../tiles_768/9183
3,9200,MC,3388,3388,True,../tiles_768/9200
4,10252,LGSC,49053,39794,False,../tiles_768/10252


In [2]:
from PIL import Image
import torch
import torch.nn as nn
import timm
from timm.models import VisionTransformer
from timm.layers import SwiGLUPacked
from timm.models.layers import DropPath
import copy
import torch.nn.functional as F

# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

class CustomViT(nn.Module):
    """Wrapper around a timm ``VisionTransformer`` with a mask token and per-token heads.

    The wrapper builds a 12-block, 12-head ViT (patch size 16, constructed for
    512x512 inputs), overrides the stochastic-depth and dropout modules of
    every block, and replaces the pooled classification path with two linear
    heads: one applied to the class token and one applied to every patch token.
    An optional boolean mask lets selected patch embeddings be swapped for a
    learnable mask token before the transformer blocks run.

    Args:
        n_classes: Output width of both linear heads (also passed to the
            underlying timm model as ``num_classes``).
        embed_dim: Token embedding width.
    """
    def __init__(self, n_classes=5, embed_dim=768):
        super().__init__()
        self.n_classes = n_classes
        self.embed_dim = embed_dim
        self.patch_size = 16
        # Load the base ViT model
        # NOTE(review): with global_pool='avg', some timm versions keep the final
        # LayerNorm in `fc_norm` and make `base_model.norm` an Identity - confirm
        # against the installed timm version, since forward_features uses `norm`.
        self.base_model = VisionTransformer(
            img_size=512, 
            num_classes=self.n_classes, 
            patch_size=self.patch_size, 
            embed_dim=self.embed_dim, 
            depth=12, 
            num_heads=12, 
            global_pool='avg', 
            pre_norm=True, 
            act_layer=nn.SiLU
        )

        # Initialize a learnable mask token (shape (1, 1, embed_dim), normal init)
        self.mask_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
        
        max_drop_path_rate = 0.4
        dropout_rate = 0.1

        # Stochastic-depth rate grows linearly from 0 (first block) to
        # max_drop_path_rate (last block).
        drop_path_rates = [x.item() for x in torch.linspace(0, max_drop_path_rate, len(self.base_model.blocks))]

        # Assign drop path rates and replace the attention / MLP dropout modules
        # of every block with dropout_rate.
        for i, block in enumerate(self.base_model.blocks):
            block.drop_path1 = DropPath(drop_prob=drop_path_rates[i])
            block.drop_path2 = DropPath(drop_prob=drop_path_rates[i])
            block.attn.attn_drop = nn.Dropout(p=dropout_rate, inplace=False)
            block.attn.proj_drop = nn.Dropout(p=dropout_rate, inplace=False)
            block.mlp.drop1 = nn.Dropout(p=dropout_rate, inplace=False)
            block.mlp.drop2 = nn.Dropout(p=dropout_rate, inplace=False)
        # Dropout applied to all tokens right before the linear heads.
        self.head_dropout = nn.Dropout(p=dropout_rate, inplace=False)

        # Separate linear heads for the class token and for the patch tokens.
        self.class_token_head = nn.Linear(self.embed_dim, self.n_classes)
        self.patch_token_head = nn.Linear(self.embed_dim, self.n_classes) 
        
        # NOTE(review): presumably clears timm's fixed input-size check in
        # PatchEmbed so other resolutions can be fed - confirm against timm.
        self.base_model.patch_embed.img_size = None

    def resize(self, new_img_size):
        """Bilinearly resample the positional embeddings for a new square input size.

        The first position of ``pos_embed`` is kept as-is (class token); the
        remaining positions are treated as a square grid, interpolated to
        ``new_img_size // patch_size`` per side, and stored back as a new
        ``nn.Parameter``.

        Args:
            new_img_size: Side length, in pixels, of the new square input.
        """
        # Calculate the size of the grid of patches
        num_patches_side = new_img_size // self.patch_size
        # NOTE(review): num_patches is computed but never used below.
        num_patches = num_patches_side ** 2

        # Extract the original positional embeddings, excluding the class token
        pos_embed = self.base_model.pos_embed
        old_num_patches_side = int((pos_embed.size(1) - 1) ** 0.5)
        pos_grid = pos_embed[:, 1:].reshape(1, old_num_patches_side, old_num_patches_side, -1)
        # Move the embedding dimension to the channel axis expected by F.interpolate.
        pos_grid = pos_grid.permute(0, 3, 1, 2).contiguous()
        
        # Resize using bilinear interpolation (make sure to keep the embedding dimension unchanged)
        new_pos_grid = F.interpolate(pos_grid, size=(num_patches_side, num_patches_side), mode='bilinear', align_corners=False)

        # Flatten the grid back to a sequence and re-add the class token
        new_pos_embed = torch.cat([pos_embed[:, :1], new_pos_grid.permute(0, 2, 3, 1).contiguous().view(1, num_patches_side * num_patches_side, -1)], dim=1)

        # Update the positional embeddings
        # NOTE(review): this creates a new Parameter object; an optimizer built
        # before calling resize() would presumably still hold the old one - verify.
        self.base_model.pos_embed = nn.Parameter(new_pos_embed)
        self.base_model.patch_embed.img_size = (new_img_size, new_img_size)
        
    def forward_features(self, x, mask=None):
        """Encode an image batch into per-token features.

        Args:
            x: Image batch passed to the base model's patch embedding.
            mask: Optional per-patch mask; it is concatenated with a
                (batch, 1) boolean column and used as the ``torch.where``
                condition, so it is expected to be a boolean tensor with one
                column per patch. Positions that are True are replaced by the
                mask token.

        Returns:
            The output of the transformer blocks and final norm for every
            token, with the class token (when the base model has one) first.
        """
        # Get the patch embeddings (excluding the class token)
        x = self.base_model.patch_embed(x)
        
        # Prepend the class token, if the base model defines one.
        to_cat = []
        if self.base_model.cls_token is not None:
            to_cat.append(self.base_model.cls_token.expand(x.shape[0], -1, -1))
        x = torch.cat(to_cat + [x], dim=1)

        # Handle masked patches if a mask is provided
        if mask is not None:
            # Adjust mask to account for the class token (the class token is never masked)
            mask = torch.cat([torch.zeros(x.shape[0], 1).bool().to(mask.device), mask], dim=1)
            # Expand mask token to match the batch size; it broadcasts over the token axis
            mask_tokens = self.mask_token.expand(x.size(0), -1, -1)
            # Apply the mask - replace masked patches with the mask token
            x = torch.where(mask.unsqueeze(-1), mask_tokens, x)
        
        # Positional embeddings are added after masking, so masked positions
        # still carry their location.
        x = self.base_model.pos_drop(x + self.base_model.pos_embed)
        x = self.base_model.norm_pre(x)
        x = self.base_model.blocks(x)
        x = self.base_model.norm(x)

        # Return every token (class token included); callers slice as needed.
        return x

    def forward_head(self, x):
        """Apply dropout and the class-token / patch-token linear heads.

        Args:
            x: Token features whose first token along dim 1 is the class token.

        Returns:
            Tensor with the same token layout as ``x`` and ``n_classes``
            values per token: class-token head output first, then the
            patch-token head outputs.
        """
        class_token, patch_tokens = x[:, :1], x[:, 1:]

        # Apply dropout
        class_token = self.head_dropout(class_token)
        patch_tokens = self.head_dropout(patch_tokens)

        # Process class token and patch tokens through their respective heads
        class_token_output = self.class_token_head(class_token)
        patch_token_output = self.patch_token_head(patch_tokens)
        
        # Re-assemble the sequence in the original token order.
        x = torch.cat([class_token_output, patch_token_output], dim=1)

        return x

    def forward(self, x, mask=None):
        """Run ``forward_features`` followed by ``forward_head``."""
        x = self.forward_features(x, mask=mask)
        x = self.forward_head(x)
        return x

# Patch edge length in pixels, and the square input size expressed as a
# 32 x 32 grid of such patches (32 * 16 = 512).
PATCH_SIZE = 16
IMAGE_SIZE = 32 * PATCH_SIZE

In [3]:
class ClassifierModel(nn.Module):
    """Slide-tile classifier built on top of a ``CustomViT`` backbone.

    The backbone is created with heads of width 16 * 16 * 3 = 768, which equals
    ``d_model``. The backbone's per-token outputs (class token dropped) are
    averaged over tokens, layer-normalised and projected to ``n_classes`` logits.

    NOTE(review): ``forward`` pools the backbone's *head* outputs rather than
    ``forward_features``; this only type-checks because the head width happens
    to equal ``d_model`` - confirm this is intended.
    """

    def __init__(self, n_classes=5):
        super().__init__()
        self.d_model = 768
        self.n_classes = n_classes

        # Backbone whose per-token heads emit 16 * 16 * 3 values each.
        backbone_head_width = 16 * 16 * 3
        self.base_model = CustomViT(n_classes=backbone_head_width, embed_dim=self.d_model)

        # Classification head applied to the pooled token outputs.
        self.layer_norm = nn.LayerNorm(self.d_model)
        self.linear = nn.Linear(self.d_model, self.n_classes)

    def forward(self, x):
        """Return class logits for an image batch."""
        token_outputs = self.base_model(x)
        # Skip the class token and average the remaining tokens.
        pooled = token_outputs[:, 1:].mean(dim=1)
        return self.linear(self.layer_norm(pooled))

N_CLASSES = 5
classifier_model = ClassifierModel(N_CLASSES).to(device)

# NOTE(security): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
# NOTE(review): the checkpoint path is pinned to fold_0 while I_FOLD selects the
# data fold - confirm they are meant to stay in sync when I_FOLD changes.
state_dict = torch.load('vit_mim_upscale_models/fold_0/ema_model_epoch_1_step_25000.pth', map_location=device)

# strict=False tolerates key mismatches; report them so that a wrong or
# incompatible checkpoint does not silently leave the backbone randomly
# initialised.
load_result = classifier_model.base_model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        "Backbone checkpoint key mismatch - "
        f"missing: {list(load_result.missing_keys)}, "
        f"unexpected: {list(load_result.unexpected_keys)}"
    )

# Initialize EMA model as an exact copy of the freshly loaded classifier; the
# deep copy already lives on the same device as the original.
ema_decay = 0.999  # decay factor for EMA
ema_classifier_model = copy.deepcopy(classifier_model)
# state_dict = torch.load('vit_mim_upscale_models/fold_0/ema_model_epoch_0_step_9000.pth', map_location=device)
# ema_model.load_state_dict(state_dict, strict=False)

In [4]:
import os
from PIL import Image
from torch.utils.data import Dataset
import random

# Class names listed in label-id order: position i is the name of class i.
_CLASS_NAMES = ('HGSC', 'CC', 'EC', 'LGSC', 'MC')

# Forward map (id -> name) and its inverse (name -> id).
integer_to_label = dict(enumerate(_CLASS_NAMES))
label_to_integer = {name: index for index, name in integer_to_label.items()}

class ImageDataset(Dataset):
    """Dataset of PNG tiles collected from per-slide folders, interleaved across slides.

    Each row of ``dataframe`` must provide ``tile_path`` (a directory),
    ``label`` and ``image_id``. Rows whose directory does not exist, or holds
    no ``.png`` file, are skipped. The tiles of each remaining folder are
    shuffled and then emitted round-robin (one tile per folder per pass);
    folders with fewer tiles than the largest one wrap around, so every folder
    contributes the same number of entries.

    Args:
        dataframe: Table with ``tile_path``, ``label`` and ``image_id`` columns.
        transform: Optional callable applied to the loaded RGB PIL image.

    Raises:
        ValueError: If no row points to a directory containing a PNG tile.
    """

    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform
        self.all_images = []  # Flat, interleaved list of (path, label, image_id)
        self.folder_images = []  # One list of (path, label, image_id) per usable folder

        # Step 1: Collect all images from each folder
        for _, row in dataframe.iterrows():
            folder_path = row['tile_path']
            label = row['label']
            image_id = row['image_id']
            if not os.path.isdir(folder_path):
                continue
            image_files = [os.path.join(folder_path, f) for f in os.listdir(folder_path) if f.lower().endswith('.png')]
            if not image_files:
                # An empty folder would make the round-robin below divide by
                # zero (i % len(images)); it has nothing to contribute anyway.
                continue
            random.shuffle(image_files)
            self.folder_images.append([(image_file, label, image_id) for image_file in image_files])

        if not self.folder_images:
            raise ValueError("ImageDataset found no PNG tiles in any 'tile_path' directory")

        # Step 2: Interlace the images; shorter folders wrap around.
        max_length = max(len(images) for images in self.folder_images)
        for i in range(max_length):
            for images in self.folder_images:
                self.all_images.append(images[i % len(images)])

    def __len__(self):
        """Number of interleaved entries (largest folder size x usable folders)."""
        return len(self.all_images)

    def __getitem__(self, idx):
        """Return ``(image, integer_label, image_id)`` for entry ``idx``."""
        image_path, label, image_id = self.all_images[idx]
        # Decode eagerly inside a context manager so the file handle is
        # released, and normalise to 3-channel RGB so palette / RGBA PNGs match
        # the 3-channel normalisation applied downstream.
        with Image.open(image_path) as image_file:
            image = image_file.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label_to_integer[label], image_id

In [5]:
from torch.utils.data import DataLoader, WeightedRandomSampler
import torchvision.transforms as transforms

BATCH_SIZE = 8

# Spatial augmentations: random crop-and-resize to IMAGE_SIZE, flips, a
# rotation drawn from the full circle, and a shear in both axes.
_geometric_augmentations = [
    transforms.RandomResizedCrop(size=IMAGE_SIZE, scale=(0.75, 1.0), ratio=(0.75, 1.33)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(degrees=(0, 360)),
    transforms.RandomAffine(degrees=0, shear=(-20, 20, -20, 20)),
]

# Appearance augmentations: colour jitter, then occasional grayscale and blur
# (each applied with probability 0.25).
_photometric_augmentations = [
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.3, hue=0.3),
    transforms.RandomApply([transforms.Grayscale(num_output_channels=3)], p=0.25),
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 1))], p=0.25),
]

# Full training pipeline; the final two steps convert to a tensor and
# normalise each channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    *_geometric_augmentations,
    *_photometric_augmentations,
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Disabled validation pipeline, kept for reference:
# val_transform = transforms.Compose([
#     transforms.Resize(448),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),

train_dataset = ImageDataset(dataframe=train, transform=transform)

train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)

In [6]:
import logging
import sys

# Get the root logger
logger = logging.getLogger()

# Detach any handlers left over from a previous run of this cell and close
# them, so an earlier log file handle is released instead of leaking.
for handler in logger.handlers[:]:
    logger.removeHandler(handler)
    handler.close()

# Set the logging level
logger.setLevel(logging.INFO)

# Create a FileHandler and add it to the logger. The directory is created
# first because FileHandler raises if the parent folder does not exist.
log_path = f'logs/vit_mim_upscale_finetune/fold_{I_FOLD}.txt'
os.makedirs(os.path.dirname(log_path), exist_ok=True)
file_handler = logging.FileHandler(log_path)
file_handler.setLevel(logging.INFO)
logger.addHandler(file_handler)

# Create a StreamHandler for stderr and add it to the logger
stream_handler = logging.StreamHandler(sys.stderr)
stream_handler.setLevel(logging.ERROR)  # Only log ERROR and CRITICAL messages to stderr
logger.addHandler(stream_handler)

In [None]:
import torch
import torch.optim as optim
import logging
import numpy as np
import math
from sklearn.metrics import balanced_accuracy_score
import random
from torch.cuda.amp import GradScaler, autocast
import torch.nn.functional as F

# Peak learning rate: a 5e-4 reference value scaled by BATCH_SIZE / 256.
initial_lr = 5e-4 * BATCH_SIZE / 256
# Learning-rate floor: 1% of the peak value.
final_lr = 0.01 * initial_lr
# Upper bound on passes over the training dataloader.
num_epochs = 1000

# Learning-rate schedule: linear warmup followed by cosine decay
def learning_rate(step, warmup_steps=2500, max_steps=25000):
    """Return the learning rate to use at optimisation step ``step``.

    The rate rises linearly from 0 to ``initial_lr`` over ``warmup_steps``,
    then follows a half-cosine from ``initial_lr`` down to ``final_lr`` until
    ``max_steps``, and stays at ``final_lr`` afterwards. ``initial_lr`` and
    ``final_lr`` are read from module scope.
    """
    # Warmup phase.
    if step < warmup_steps:
        warmup_fraction = float(step) / float(max(1, warmup_steps))
        return initial_lr * warmup_fraction

    # Past the end of the schedule: hold the floor value.
    if not step < max_steps:
        return final_lr

    # Cosine phase: the weight goes from 1 (start) to 0 (end).
    decay_span = float(max(1, max_steps - warmup_steps))
    progress = (float(step - warmup_steps) / decay_span)
    cosine_weight = 0.5 * (1 + math.cos(math.pi * progress))
    return final_lr + (initial_lr - final_lr) * cosine_weight

def update_ema_variables(model, ema_model, alpha):
    """Blend ``model``'s parameters into ``ema_model`` in place.

    Every EMA parameter becomes ``alpha * ema + (1 - alpha) * current``.
    Parameters are paired in the order returned by ``parameters()``; only
    parameters are updated, and ``model`` itself is left unchanged.
    """
    current_weight = 1 - alpha
    with torch.no_grad():
        parameter_pairs = zip(ema_model.parameters(), model.parameters())
        for averaged, current in parameter_pairs:
            averaged.data.mul_(alpha)
            averaged.data.add_(current.data, alpha=current_weight)

# Directory that receives the periodic model / EMA / optimizer checkpoints.
save_dir = f"vit_mim_upscale_finetune_models/fold_{I_FOLD}"
os.makedirs(save_dir, exist_ok=True)

# Gradient scaler used together with autocast for mixed-precision training.
scaler = GradScaler()
optimizer = optim.AdamW(classifier_model.parameters(), lr=initial_lr, weight_decay=5e-2)
# state_dict = torch.load('vit_mim_upscale_models/fold_0/optimizer_epoch_0_step_9000.pth', map_location=device)
# optimizer.load_state_dict(state_dict)

# Calculate weights for each class: inverse of the number of rows of `train`
# carrying that label (ordered like label_to_integer), normalised to sum to 1.
# NOTE(review): the counts are per row of `train`, not per tile - confirm that
# this is the intended balancing unit.
class_counts = np.array([train.groupby('label').count().loc[label]['image_id'] for label in label_to_integer], dtype=np.float32)
class_weights = 1. / class_counts
class_weights /= class_weights.sum()
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

criterion = torch.nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)

# NOTE(review): best_val_accuracy is initialised but never read or updated in
# this cell; no validation pass is run here.
best_val_accuracy = 0.0
# Global optimisation step counter; drives the LR schedule and checkpointing.
step = 0

classifier_model.train()
for epoch in range(num_epochs):
    for i, (images, labels, _) in enumerate(train_dataloader, 0):
        # Move the batch to the training device
        images = images.to(device)
        labels = labels.to(device)

        # Apply the scheduled learning rate for this step to every param group
        lr = learning_rate(step)
        for g in optimizer.param_groups:
            g['lr'] = lr

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass with autocast
        with autocast():
            outputs = classifier_model(images)
            logits_per_image = outputs
            loss = criterion(logits_per_image, labels)
        
        # Backward pass on the scaled loss, then optimizer step and scaler update
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Fold the updated weights into the EMA copy after every step.
        update_ema_variables(classifier_model, ema_classifier_model, ema_decay)

        logging.info('[%d, %5d] loss: %.3f' % (epoch + 1, step, loss.item()))

        # Checkpoint every 500 steps (this includes step 0).
        if step % 500 == 0:
            # Save the EMA model
            ema_model_save_path = os.path.join(save_dir, f'ema_classifier_model_epoch_{epoch}_step_{step}.pth')
            torch.save(ema_classifier_model.state_dict(), ema_model_save_path)
            
            # Save the model
            model_save_path = os.path.join(save_dir, f'classifier_model_epoch_{epoch}_step_{step}.pth')
            torch.save(classifier_model.state_dict(), model_save_path)

            # Save the optimizer
            optimizer_save_path = os.path.join(save_dir, f'optimizer_epoch_{epoch}_step_{step}.pth')
            torch.save(optimizer.state_dict(), optimizer_save_path)

            logging.info(f'Model and optimizer saved after epoch {epoch} and step {step}')

        step += 1
