In [1]:
import pandas as pd
import os

def get_image_path(image_id: int) -> str:
    """Return the tile folder for ``image_id``: ``../tiles_768/<image_id>``."""
    tile_root = '../tiles_768'
    return os.path.join(tile_root, f"{image_id}")

I_FOLD = 0
train = pd.read_csv(f"train_fold_{I_FOLD}.csv")
validation = pd.read_csv(f"val_fold_{I_FOLD}.csv")

# Each image's tiles live in a folder named after its image_id.
for split_df in (train, validation):
    split_df['tile_path'] = split_df['image_id'].apply(get_image_path)

train.head()

Unnamed: 0,image_id,label,image_width,image_height,is_tma,tile_path
0,8280,HGSC,2964,2964,True,../tiles_768/8280
1,8985,LGSC,65003,31754,False,../tiles_768/8985
2,9183,LGSC,74091,34185,False,../tiles_768/9183
3,9200,MC,3388,3388,True,../tiles_768/9200
4,10252,LGSC,49053,39794,False,../tiles_768/10252


In [2]:
from PIL import Image
import torch
import torch.nn as nn
import timm
from timm.models import VisionTransformer
from timm.layers import SwiGLUPacked
from timm.models.layers import DropPath
import copy

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

class CustomViT(nn.Module):
    """timm VisionTransformer encoder with separate linear heads for the class and patch tokens.

    The encoder is a 224px / 16px-patch ViT with 12 blocks and 12 heads, SiLU
    activations, pre-norm, stochastic depth increasing linearly to 0.4, and
    dropout 0.1. ``forward`` returns one ``n_classes``-wide prediction per
    token. Index 0 comes from ``class_token_head``. Indices 1.. come from
    ``patch_token_head``.

    Args:
        n_classes: Output width of both heads.
        embed_dim: Transformer width.
    """

    def __init__(self, n_classes=5, embed_dim=768):
        super().__init__()
        # NOTE: creation order of submodules/parameters defines model.parameters()
        # order, which saved optimizer state (restored by index) relies on -- keep it.
        self.n_classes = n_classes
        self.embed_dim = embed_dim
        # Load the base ViT model. Its own `head` is not used by forward();
        # forward_head uses class_token_head / patch_token_head instead.
        self.base_model = VisionTransformer(
            img_size=224, 
            num_classes=self.n_classes, 
            patch_size=16, 
            embed_dim=self.embed_dim, 
            depth=12, 
            num_heads=12, 
            global_pool='avg', 
            pre_norm=True, 
            act_layer=nn.SiLU
        )

        # Learnable token substituted for masked patches in forward_features.
        self.mask_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
        
        max_drop_path_rate = 0.4
        dropout_rate = 0.1

        # Stochastic depth grows linearly from 0 (first block) to max_drop_path_rate (last block).
        drop_path_rates = [x.item() for x in torch.linspace(0, max_drop_path_rate, len(self.base_model.blocks))]

        # Replace the per-block regularization modules created by timm.
        for i, block in enumerate(self.base_model.blocks):
            block.drop_path1 = DropPath(drop_prob=drop_path_rates[i])
            block.drop_path2 = DropPath(drop_prob=drop_path_rates[i])
            block.attn.attn_drop = nn.Dropout(p=dropout_rate, inplace=False)
            block.attn.proj_drop = nn.Dropout(p=dropout_rate, inplace=False)
            block.mlp.drop1 = nn.Dropout(p=dropout_rate, inplace=False)
            block.mlp.drop2 = nn.Dropout(p=dropout_rate, inplace=False)
        self.head_dropout = nn.Dropout(p=dropout_rate, inplace=False)

        self.class_token_head = nn.Linear(self.embed_dim, self.n_classes)
        self.patch_token_head = nn.Linear(self.embed_dim, self.n_classes) 

    def forward_features(self, x, mask=None):
        """Encode an image batch into normalized token features.

        Args:
            x: Image batch passed to ``base_model.patch_embed``. Assumed to be
                (B, 3, 224, 224); TODO confirm for other callers.
            mask: Optional tensor of shape (B, num_patches). Truthy entries mark
                patches replaced by ``mask_token`` before position embeddings are
                added. The class token is never masked.

        Returns:
            Tensor of shape (B, num_prefix + num_patches, embed_dim). The class
            token, when present, is at index 0; it is NOT stripped here.
        """
        x = self.base_model.patch_embed(x)

        # Prepend the class token (if the base model has one).
        to_cat = []
        if self.base_model.cls_token is not None:
            to_cat.append(self.base_model.cls_token.expand(x.shape[0], -1, -1))
        x = torch.cat(to_cat + [x], dim=1)

        if mask is not None:
            # Prefix (class) tokens are never masked: prepend one False column per
            # prefix token, allocated directly with the right dtype and device.
            prefix_keep = torch.zeros(x.shape[0], len(to_cat), dtype=torch.bool, device=mask.device)
            mask = torch.cat([prefix_keep, mask.bool()], dim=1)
            mask_tokens = self.mask_token.expand(x.size(0), -1, -1)
            # (B, T, 1) condition broadcasts over the embedding dim.
            x = torch.where(mask.unsqueeze(-1), mask_tokens, x)

        x = self.base_model.pos_drop(x + self.base_model.pos_embed)
        x = self.base_model.norm_pre(x)
        x = self.base_model.blocks(x)
        x = self.base_model.norm(x)
        return x

    def forward_head(self, x):
        """Apply the per-token heads.

        Args:
            x: Output of ``forward_features`` with the class token at index 0.

        Returns:
            Tensor (B, T, n_classes). Row 0 comes from ``class_token_head``;
            rows 1.. come from ``patch_token_head``.
        """
        class_token, patch_tokens = x[:, :1], x[:, 1:]

        class_token = self.head_dropout(class_token)
        patch_tokens = self.head_dropout(patch_tokens)

        class_token_output = self.class_token_head(class_token)
        patch_token_output = self.patch_token_head(patch_tokens)

        return torch.cat([class_token_output, patch_token_output], dim=1)

    def forward(self, x, mask=None):
        """Encode ``x`` (optionally masking patches) and return per-token predictions."""
        x = self.forward_features(x, mask=mask)
        x = self.forward_head(x)
        return x

D_MODEL = 768
N_CLASSES = 16 * 16 * 3  # each token predicts one flattened 16x16 RGB patch
model = CustomViT(n_classes=N_CLASSES, embed_dim=D_MODEL)
model = model.to(device)

# Resume from THIS fold's checkpoints. The path was previously hard-coded to
# fold_0, which would load fold 0's weights whatever I_FOLD is.
RESUME_STEP = 87000
resume_dir = f'vit_mim_models/fold_{I_FOLD}'

# NOTE(review): the training loop in this notebook has written
# ema_model.state_dict() under the model_*.pth name, so this file may hold EMA
# weights rather than the online model -- confirm.
state_dict = torch.load(os.path.join(resume_dir, f'model_epoch_0_step_{RESUME_STEP}.pth'), map_location=device)
# strict=True: a resumed checkpoint must match the architecture exactly.
model.load_state_dict(state_dict, strict=True)

# EMA copy of the model; updated manually (not by the optimizer).
ema_decay = 0.999  # decay factor for EMA
ema_model = copy.deepcopy(model)  # deepcopy keeps parameters on the same device
ema_model.requires_grad_(False)   # no gradients are ever needed for EMA weights
state_dict = torch.load(os.path.join(resume_dir, f'ema_model_epoch_0_step_{RESUME_STEP}.pth'), map_location=device)
ema_model.load_state_dict(state_dict, strict=True)

<All keys matched successfully>

In [3]:
import os
from PIL import Image
from torch.utils.data import Dataset
import random

# Class index <-> subtype name. The inverse map is derived so the two can never drift apart.
integer_to_label = dict(enumerate(['HGSC', 'CC', 'EC', 'LGSC', 'MC']))

label_to_integer = {name: idx for idx, name in integer_to_label.items()}

class ImageDataset(Dataset):
    """Tile-level dataset with one sample per ``.png`` file in each row's ``tile_path`` folder.

    Args:
        dataframe: Must provide ``tile_path``, ``label`` and ``image_id`` columns.
            Rows whose ``tile_path`` is not an existing directory contribute no
            samples.
        transform: Optional callable applied to the RGB PIL image in
            ``__getitem__``. When ``None``, the PIL image is returned unchanged.

    Items are ``(image, label_index, image_id)``. ``label_index`` comes from
    ``label_to_integer``.
    """

    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform
        # Flat list of (tile_path, label, image_id). Tiles of one image are
        # contiguous, and images follow the dataframe's row order.
        self.all_images = []

        rows = zip(dataframe['tile_path'], dataframe['label'], dataframe['image_id'])
        for folder_path, label, image_id in rows:
            if not os.path.isdir(folder_path):
                continue
            # Sort so sample indices are reproducible (os.listdir order is arbitrary).
            tile_names = sorted(f for f in os.listdir(folder_path) if f.lower().endswith('.png'))
            self.all_images.extend(
                (os.path.join(folder_path, name), label, image_id) for name in tile_names
            )

    def __len__(self):
        return len(self.all_images)

    def __getitem__(self, idx):
        image_path, label, image_id = self.all_images[idx]
        # Decode eagerly and close the file handle. Convert to RGB so that
        # RGBA / palette / grayscale PNGs always yield 3 channels.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, label_to_integer[label], image_id

In [4]:
from torch.utils.data import DataLoader, WeightedRandomSampler
import torchvision.transforms as transforms

BATCH_SIZE = 32


def _tensor_and_normalize():
    """ToTensor followed by per-channel Normalize(0.5, 0.5), mapping [0, 1] pixels to [-1, 1]."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]


# Training augmentation: geometric first (crop, flips, rotation, shear),
# then photometric (color jitter, random grayscale, random blur).
train_augmentations = [
    transforms.RandomResizedCrop(size=224, scale=(0.75, 1.0), ratio=(0.75, 1.33)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(degrees=(0, 360)),
    transforms.RandomAffine(degrees=0, shear=(-20, 20, -20, 20)),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.3, hue=0.3),
    transforms.RandomApply([transforms.Grayscale(num_output_channels=3)], p=0.25),
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 1))], p=0.25),
]
transform = transforms.Compose(train_augmentations + _tensor_and_normalize())

# Deterministic evaluation pipeline: resize only (shorter side to 224).
val_transform = transforms.Compose([transforms.Resize(224)] + _tensor_and_normalize())

train_dataset = ImageDataset(dataframe=train, transform=transform)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)

In [5]:
import logging
import sys

# Configure the root logger. INFO and above go to a per-fold log file; only
# ERROR and above are echoed to stderr.
logger = logging.getLogger()

# Remove handlers left over from earlier runs of this cell, so lines are not
# duplicated. Close them too, so old log files are not kept open.
for handler in logger.handlers[:]:
    logger.removeHandler(handler)
    handler.close()

logger.setLevel(logging.INFO)

# FileHandler does not create missing directories, so make sure the folder exists.
log_dir = os.path.join('logs', 'vit_mim')
os.makedirs(log_dir, exist_ok=True)
file_handler = logging.FileHandler(os.path.join(log_dir, f'fold_{I_FOLD}.txt'))
file_handler.setLevel(logging.INFO)
logger.addHandler(file_handler)

stream_handler = logging.StreamHandler(sys.stderr)
stream_handler.setLevel(logging.ERROR)  # Only log ERROR and CRITICAL messages to stderr
logger.addHandler(stream_handler)

In [None]:
import torch
import torch.optim as optim
import logging
import numpy as np
import math
from sklearn.metrics import balanced_accuracy_score
import random
from torch.cuda.amp import GradScaler, autocast
import torch.nn.functional as F

initial_lr = 0.0005 * BATCH_SIZE / 256  # peak LR: 5e-4 per 256 samples, scaled to BATCH_SIZE
final_lr = initial_lr * 0.01            # floor of the cosine decay (1% of the peak)
num_epochs = 10_000                     # upper bound on training epochs

def learning_rate(step, warmup_steps=10000, max_steps=100000, base_lr=None, min_lr=None):
    """Return the learning rate for ``step``: linear warmup, cosine decay, then a constant floor.

    - ``step < warmup_steps``: rises linearly from 0 to ``base_lr``.
    - ``warmup_steps <= step < max_steps``: cosine-anneals from ``base_lr`` to ``min_lr``.
    - ``step >= max_steps``: stays at ``min_lr``.

    Args:
        step: Global optimizer step.
        warmup_steps: Length of the linear warmup, in steps.
        max_steps: Step at which the cosine decay reaches ``min_lr``.
        base_lr: Peak rate. Defaults to the notebook-level ``initial_lr``.
        min_lr: Final rate. Defaults to the notebook-level ``final_lr``.
    """
    # Fall back to the globals at call time, as the original implicit lookup did.
    if base_lr is None:
        base_lr = initial_lr
    if min_lr is None:
        min_lr = final_lr

    if step < warmup_steps:
        return base_lr * (float(step) / float(max(1, warmup_steps)))
    if step < max_steps:
        progress = float(step - warmup_steps) / float(max(1, max_steps - warmup_steps))
        cos_component = 0.5 * (1 + math.cos(math.pi * progress))
        return min_lr + (base_lr - min_lr) * cos_component
    return min_lr

def update_ema_variables(model, ema_model, alpha, global_step):
    """Blend ``model``'s parameters into ``ema_model`` in place.

    For each parameter pair: ``ema = alpha * ema + (1 - alpha) * param``.
    Parameters are paired positionally, and buffers are not touched.
    ``global_step`` is accepted for call-site compatibility but is unused.
    """
    weight_new = 1 - alpha
    with torch.no_grad():
        for avg_p, live_p in zip(ema_model.parameters(), model.parameters()):
            avg_p.data.mul_(alpha)
            avg_p.data.add_(live_p.data, alpha=weight_new)

save_dir = f"vit_mim_models/fold_{I_FOLD}"
os.makedirs(save_dir, exist_ok=True)

RESUME_STEP = 87000      # checkpoint step being resumed; training continues at RESUME_STEP + 1
CHECKPOINT_EVERY = 1000  # save EMA model, model and optimizer every N steps
PATCH_SIZE = 16          # must match the model's patch_size
MASK_RATIO = 0.5         # expected fraction of patches replaced by the mask token
N_PATCHES = (224 // PATCH_SIZE) ** 2

scaler = GradScaler()
optimizer = optim.AdamW(model.parameters(), lr=initial_lr, weight_decay=5e-2)
# Resume from this fold's directory (the path was hard-coded to fold_0 before).
optimizer_resume_path = os.path.join(save_dir, f'optimizer_epoch_0_step_{RESUME_STEP}.pth')
optimizer.load_state_dict(torch.load(optimizer_resume_path, map_location=device))

best_val_accuracy = 0.0
step = RESUME_STEP + 1

model.train()
for epoch in range(num_epochs):
    logging.info('start of epoch time!!! :D')
    for images, _, _ in train_dataloader:
        images = images.to(device)
        # The DataLoader keeps the last partial batch, so size the mask and the
        # targets from the actual batch rather than BATCH_SIZE.
        batch_size = images.size(0)

        # Warmup + cosine schedule, applied by hand every step.
        lr = learning_rate(step)
        for g in optimizer.param_groups:
            g['lr'] = lr

        optimizer.zero_grad()

        with autocast():
            # True = patch is replaced by the mask token and contributes to the loss.
            mask = torch.rand((batch_size, N_PATCHES), device=device) < MASK_RATIO
            outputs = model(images, mask=mask)  # (B, 1 + N_PATCHES, N_CLASSES)
            outputs = outputs[:, 1:, :]         # drop the class-token prediction
            masked_outputs = outputs[mask]      # (num_masked, N_CLASSES)

            # Targets: the normalized input cut into non-overlapping patches,
            # ordered row-major over the patch grid. Each patch is flattened as
            # (pixel_row, pixel_col, channel). Presumably this matches
            # patch_embed's token order; TODO confirm.
            patches = images.unfold(2, PATCH_SIZE, PATCH_SIZE).unfold(3, PATCH_SIZE, PATCH_SIZE)
            patches = patches.permute(0, 2, 3, 4, 5, 1).contiguous()
            patches = patches.reshape(batch_size, N_PATCHES, PATCH_SIZE * PATCH_SIZE * 3)
            masked_patches = patches[mask]

            # L2 reconstruction loss on the masked patches only.
            loss = F.mse_loss(masked_outputs, masked_patches)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        update_ema_variables(model, ema_model, ema_decay, step)

        logging.info('[%d, %5d] loss: %.3f' % (epoch + 1, step, loss.item()))

        if step % CHECKPOINT_EVERY == 0:
            ema_model_save_path = os.path.join(save_dir, f'ema_model_epoch_{epoch}_step_{step}.pth')
            torch.save(ema_model.state_dict(), ema_model_save_path)
            
            # Save the online (optimized) model. This used to write
            # ema_model.state_dict(), making model_*.pth a copy of the EMA weights.
            model_save_path = os.path.join(save_dir, f'model_epoch_{epoch}_step_{step}.pth')
            torch.save(model.state_dict(), model_save_path)

            optimizer_save_path = os.path.join(save_dir, f'optimizer_epoch_{epoch}_step_{step}.pth')
            torch.save(optimizer.state_dict(), optimizer_save_path)

            logging.info(f'Model and optimizer saved after epoch {epoch} and step {step}')
            torch.cuda.empty_cache()

        step += 1

    logging.info('end of epoch time :/ emptying cache')
    torch.cuda.empty_cache()