In [1]:
# Run identifier: interpolated into the training log file name and into the
# checkpoint directory name further down in this notebook.
MODEL_NAME = 'third_try'

In [2]:
import pandas as pd
import os

def get_image_path(image_id:int):
    """Return the directory that holds the tiles of the image ``image_id``.

    The directory is ``../tiles_768/<image_id>``, relative to the notebook's
    working directory.
    """
    tiles_root = '../tiles_768'
    folder_name = str(image_id)
    return os.path.join(tiles_root, folder_name)

# Per-image metadata table; one row per image_id.
train = pd.read_csv("../data/train.csv")

# Attach the tile directory of every image so later cells can list its tiles.
train['tile_path'] = train['image_id'].map(get_image_path)
train.head()

Unnamed: 0,image_id,label,image_width,image_height,is_tma,tile_path
0,4,HGSC,23785,20008,False,../tiles_768/4
1,66,LGSC,48871,48195,False,../tiles_768/66
2,91,HGSC,3388,3388,True,../tiles_768/91
3,281,LGSC,42309,15545,False,../tiles_768/281
4,286,EC,37204,30020,False,../tiles_768/286


In [3]:
from timm.models.vision_transformer import Block
import torch
import torch.nn as nn
import timm
import copy

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

class GlobalModel(nn.Module):
    """Two-stage classifier for a bag of image tiles.

    A pretrained TinyViT ("local" model) turns every tile into a feature
    vector; a stack of timm ViT ``Block`` layers ("global" model) then attends
    across the tiles of one bag, and the averaged tile tokens are classified.

    NOTE: attribute names and the order in which sub-modules are created
    determine the ``state_dict`` keys and the order of ``parameters()``;
    changing either breaks loading of existing checkpoints.
    """

    def __init__(self, n_heads, n_layers, embed_dim, n_classes):
        """Build the local feature extractor and the global transformer.

        Args:
            n_heads: number of attention heads in every global ``Block``.
            n_layers: number of global ``Block`` layers.
            embed_dim: token width of the global transformer.
            n_classes: number of output logits.
        """
        super().__init__()
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.embed_dim = embed_dim
        self.n_classes = n_classes
        
        # Regularisation strengths shared by the local and the global model.
        drop_path_rate = 0.1
        drop_out_rate = 0.1
        
        local_model_name = "timm/tiny_vit_21m_224.dist_in22k_ft_in1k"
        self.local_model = timm.create_model(local_model_name, pretrained=True)

        # Swap the dropout modules inside every MLP of the backbone for ones
        # using `drop_out_rate`; stages/blocks without the expected attribute
        # are skipped.
        for stage in self.local_model.stages:
            if hasattr(stage, 'blocks'):
                for block in stage.blocks:
                    if hasattr(block, 'mlp'):
                        block.mlp.drop1 = nn.Dropout(p=drop_out_rate, inplace=False)
                        block.mlp.drop2 = nn.Dropout(p=drop_out_rate, inplace=False)

        # Replace the backbone's classification layer with an identity so the
        # backbone emits features instead of class logits.
        self.local_model.head.drop = nn.Dropout(p=drop_out_rate, inplace=False)
        self.local_model.head.fc = nn.Identity()

        # A learnable class token is prepended to the tile tokens. Note that
        # forward() pools over the tile tokens only, not over this token.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Projects backbone features (width head.in_features) to embed_dim.
        self.patch_embed = nn.Linear(self.local_model.head.in_features, embed_dim)
        self.norm_pre = nn.LayerNorm(embed_dim)
        self.drop_pre = nn.Dropout(p=drop_out_rate)

        # Stochastic-depth rate grows linearly from 0 (first block) to
        # drop_path_rate (last block).
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)]
        self.blocks = nn.Sequential(*[
            Block(
                dim=embed_dim,
                num_heads=n_heads,
                proj_drop=drop_out_rate,
                attn_drop=drop_out_rate,
                drop_path=dpr[i]
            )
            for i in range(n_layers)])
        self.norm_post = nn.LayerNorm(embed_dim)
        self.head_drop = nn.Dropout(p=drop_out_rate)
        self.fc_head = nn.Linear(embed_dim, n_classes)

    def forward(self, x):
        """Classify a batch of tile bags.

        Args:
            x: 5-D tensor indexed as (batch, tiles per bag, dim 2, dim 3,
                dim 4); the last three dimensions are passed unchanged to the
                local model as one image each (presumably channels, height,
                width - confirm against the transforms in use).

        Returns:
            Tensor of shape (batch, n_classes) holding unnormalised logits.
        """
        batch_size = x.shape[0]
        local_samples = x.shape[1]
        # Fold the bag dimension into the batch so the backbone sees plain images.
        x = x.view(batch_size * local_samples, x.shape[2], x.shape[3], x.shape[4])
        x = self.local_model(x)
        # Unfold again: one feature vector per tile, grouped by bag.
        x = x.view(batch_size, local_samples, -1)
        x = self.patch_embed(x)
        # Prepend the class token; no positional embedding is added to the tokens.
        x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
        
        x = self.norm_pre(x)
        x = self.drop_pre(x)
        x = self.blocks(x)
        
        # Average the tile tokens (index 0, the class token, is excluded).
        x = x[:, 1:].mean(dim=1)
        x = self.norm_post(x)
        x = self.head_drop(x)
        x = self.fc_head(x)
        
        return x

model = GlobalModel(n_heads=3, n_layers=12, embed_dim=192, n_classes=5)

# Warm-start from an earlier run's checkpoint.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code - only load checkpoints from a trusted source.
state_dict = torch.load('tinyvit_models_global_attention_higher_lr/epoch_0_step_20000.pth', map_location=device)

# strict=False tolerates key mismatches, which means a wrong or renamed
# checkpoint would otherwise load "successfully" while leaving layers at their
# initial values. Surface the mismatch instead of discarding the result.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        'WARNING: checkpoint does not fully match the model. '
        f'Missing keys: {list(load_result.missing_keys)}; '
        f'unexpected keys: {list(load_result.unexpected_keys)}'
    )
model = model.to(device)

# Exponential-moving-average copy of the weights; this copy is what gets
# checkpointed and evaluated. deepcopy keeps it on the same device as `model`.
ema_decay = 0.999
ema_model = copy.deepcopy(model)

In [4]:
import os
from PIL import Image
from torch.utils.data import Dataset
import random

# Class index -> label string, in the order used for the model's output logits.
integer_to_label = dict(enumerate(['HGSC', 'CC', 'EC', 'LGSC', 'MC']))

# Inverse lookup: label string -> class index.
label_to_integer = {name: index for index, name in integer_to_label.items()}

# Number of tiles drawn for every training sample (one "bag").
LOCAL_SAMPLES = 16

class ImageDataset(Dataset):
    """Endless dataset of randomly mixed tile bags with soft labels.

    Every item is a freshly sampled bag of ``LOCAL_SAMPLES`` tiles drawn from
    a random subset of the classes. Its target is the fraction of tiles in the
    bag that came from each class, so the target sums to 1.

    Args:
        dataframe: table with the columns ``tile_path`` (directory holding
            the PNG tiles of one image) and ``label`` (a key of
            ``label_to_integer``).
        transform: callable applied to every opened PIL image; it must return
            a tensor, and all returned tensors must share one shape.

    Raises:
        ValueError: if no tile directory contains any PNG file.
    """

    def __init__(self, dataframe, transform):
        self.dataframe = dataframe
        self.transform = transform
        # class index -> list of images, each image being a list of tile paths
        self.images_by_label_integer = {i: [] for i in range(len(label_to_integer))}

        for folder_path, label in zip(dataframe['tile_path'], dataframe['label']):
            image_files = [
                os.path.join(folder_path, f)
                for f in os.listdir(folder_path)
                if f.lower().endswith('.png')
            ]
            if not image_files:
                # An image without tiles can never be sampled from; keeping it
                # would crash __getitem__ later with an opaque randint error.
                continue
            self.images_by_label_integer[label_to_integer[label]].append(image_files)

        # Only classes that actually have tiles may be drawn in __getitem__.
        self.available_labels = [
            label_index
            for label_index, images in self.images_by_label_integer.items()
            if images
        ]
        if not self.available_labels:
            raise ValueError('ImageDataset found no PNG tiles in any tile_path directory')

    def __len__(self):
        # Items are sampled at random, so the length is only a nominal,
        # practically unbounded epoch size.
        return 1_000_000_000

    def __getitem__(self, idx):
        """Sample one random bag; ``idx`` is ignored.

        Returns:
            images: the ``LOCAL_SAMPLES`` transformed tiles stacked along a
                new first dimension, grouped by ascending class index.
            label: float tensor with one entry per class holding the share of
                the bag's tiles taken from that class.
        """
        # Pick how many classes take part in this bag, then which ones.
        n_labels = random.randint(1, len(self.available_labels))
        labels = list(self.available_labels)
        random.shuffle(labels)
        labels = labels[:n_labels]

        # Distribute the LOCAL_SAMPLES tile slots over the chosen classes.
        n_from_each_label = [0] * len(self.images_by_label_integer)
        for _ in range(LOCAL_SAMPLES):
            n_from_each_label[random.choice(labels)] += 1

        images = []
        for label_index, num_samples in enumerate(n_from_each_label):
            for _ in range(num_samples):
                # Uniform over images of the class first, then uniform over
                # that image's tiles, so images with many tiles do not dominate.
                images_list = random.choice(self.images_by_label_integer[label_index])
                tile_file = random.choice(images_list)
                # Close the file as soon as the transform has produced its tensor.
                with Image.open(tile_file) as tile_image:
                    images.append(self.transform(tile_image))

        images = torch.stack(images, dim=0)
        label = torch.tensor(n_from_each_label) / LOCAL_SAMPLES

        return images, label

In [None]:
from torch.utils.data import DataLoader, WeightedRandomSampler
import torchvision.transforms as transforms

# Number of bags per optimisation step (each bag holds LOCAL_SAMPLES tiles).
BATCH_SIZE = 4

# Channel statistics shared by the training and the validation pipeline
# (the standard ImageNet mean / std).
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

# Geometric and photometric augmentation on the PIL image.
_pil_augmentations = [
    transforms.RandomResizedCrop(size=224, scale=(0.75, 1.0), ratio=(0.75, 1.33)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(degrees=(0, 360)),
    transforms.RandomAffine(degrees=0, shear=(-20, 20, -20, 20)),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.3, hue=0.3),
    transforms.RandomApply([transforms.Grayscale(num_output_channels=3)], p=0.25),
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 1))], p=0.25),
    transforms.Resize(224),
]

# Tensor conversion and normalisation, identical for training and validation.
_to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
]

# Three independent erasing passes, each applied with probability 0.5.
_tensor_erasing = [
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False)
    for _ in range(3)
]

train_transform = transforms.Compose(_pil_augmentations + _to_normalized_tensor + _tensor_erasing)

val_transform = transforms.Compose([transforms.Resize(224)] + _to_normalized_tensor)

# The dataset draws its samples at random internally, so no sampler or
# shuffling is configured on the loader.
train_dataset = ImageDataset(dataframe=train, transform=train_transform)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=3)

In [None]:
import logging
import sys

# Configure the root logger: everything at INFO and above goes to a per-run
# log file, only ERROR and above is echoed to stderr (i.e. the notebook output).
logger = logging.getLogger()

# Detach handlers left over from a previous execution of this cell so lines
# are not written twice, and close them so their file handles are released.
for handler in logger.handlers[:]:
    logger.removeHandler(handler)
    handler.close()

# Set the logging level
logger.setLevel(logging.INFO)

# FileHandler does not create missing directories, so make sure it exists.
os.makedirs('logs', exist_ok=True)

# Create a FileHandler and add it to the logger
file_handler = logging.FileHandler(f'logs/tinyvit_train_{MODEL_NAME}.txt')
file_handler.setLevel(logging.INFO)
logger.addHandler(file_handler)

# Create a StreamHandler for stderr and add it to the logger
stream_handler = logging.StreamHandler(sys.stderr)
stream_handler.setLevel(logging.ERROR)  # Only log ERROR and CRITICAL messages to stderr
logger.addHandler(stream_handler)

In [None]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import logging
import numpy as np
import math
from sklearn.metrics import balanced_accuracy_score
import random
from torch.cuda.amp import GradScaler, autocast

# Peak learning rate: 5e-4 per 1024 tiles, scaled linearly to the number of
# tiles processed per optimisation step (BATCH_SIZE bags of LOCAL_SAMPLES tiles).
tiles_per_step = BATCH_SIZE * LOCAL_SAMPLES
initial_lr = 0.0005 * tiles_per_step / 1024
# Floor of the schedule: 1% of the peak.
final_lr = 0.01 * initial_lr
# Upper bound on passes over the (nominally endless) dataloader.
num_epochs = 10_000

# Learning-rate schedule: linear warmup followed by cosine decay.
def learning_rate(step, warmup_steps=2500, max_steps=25000, base_lr=None, min_lr=None):
    """Return the learning rate to use at optimisation step ``step``.

    The schedule has three phases:
      * ``step < warmup_steps``: linear ramp from 0 up to ``base_lr``;
      * ``warmup_steps <= step < max_steps``: half-cosine decay from
        ``base_lr`` down to ``min_lr``;
      * ``step >= max_steps``: constant ``min_lr``.

    Args:
        step: zero-based optimisation step.
        warmup_steps: length of the linear ramp.
        max_steps: step at which the decay reaches ``min_lr``.
        base_lr: peak learning rate; defaults to the module-level ``initial_lr``.
        min_lr: final learning rate; defaults to the module-level ``final_lr``.

    Returns:
        The learning rate as a float.
    """
    # Resolve the defaults at call time so later changes to the globals are honoured.
    if base_lr is None:
        base_lr = initial_lr
    if min_lr is None:
        min_lr = final_lr

    if step < warmup_steps:
        return base_lr * (float(step) / float(max(1, warmup_steps)))
    if step < max_steps:
        # max(1, ...) guards against a zero-length decay phase.
        progress = float(step - warmup_steps) / float(max(1, max_steps - warmup_steps))
        cos_component = 0.5 * (1 + math.cos(math.pi * progress))
        return min_lr + (base_lr - min_lr) * cos_component
    return min_lr

def update_ema_variables(model, ema_model, alpha):
    """Advance the exponential moving average held in ``ema_model`` by one step.

    Parameters are blended as ``ema = alpha * ema + (1 - alpha) * current``.
    Buffers (for example BatchNorm running statistics and their batch
    counters) are copied from ``model`` as they are: they are not trained by
    the optimiser, and without this copy the EMA model would keep the buffer
    values it had when it was created.

    Args:
        model: the model being trained (source of the current values).
        ema_model: a structurally identical copy, updated in place.
        alpha: decay factor in [0, 1]; larger values average over more steps.
    """
    with torch.no_grad():
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            ema_param.mul_(alpha).add_(param.detach(), alpha=1 - alpha)
        for ema_buffer, buffer in zip(ema_model.buffers(), model.buffers()):
            ema_buffer.copy_(buffer)

# Step budget and checkpoint cadence, named once so the schedule, the periodic
# saves and the stop condition cannot drift apart.
TOTAL_STEPS = 25000
CHECKPOINT_EVERY = 1000

# torch.save does not create missing directories; without this the very first
# save (at step 0) fails when the run directory does not exist yet.
CHECKPOINT_DIR = f'tinyvit_models_{MODEL_NAME}'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

scaler = GradScaler()
optimizer = optim.AdamW(model.parameters(), lr=initial_lr, weight_decay=1e-5)

# The targets are per-class fractions (soft labels), not class indices.
criterion = nn.CrossEntropyLoss()

best_val_accuracy = 0.0
step = 0

for epoch in range(num_epochs):
    model.train()  # set the model to training mode

    for images, labels in train_dataloader:
        # Move the batch to the training device
        images = images.to(device)
        labels = labels.to(device)

        # Apply the warmup / cosine schedule for this step
        lr = learning_rate(step, max_steps=TOTAL_STEPS)
        for g in optimizer.param_groups:
            g['lr'] = lr

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass with autocast
        with autocast():
            outputs = model(images)
            loss = criterion(outputs, labels)

        # Backward pass and optimiser step through the gradient scaler
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        update_ema_variables(model, ema_model, ema_decay)

        logging.info('[%d, %5d] loss: %.3f' % (epoch + 1, step, loss.item()))

        # Only the EMA weights are checkpointed (this includes step 0).
        if step % CHECKPOINT_EVERY == 0:
            ema_model.eval()
            torch.save(ema_model.state_dict(), f'{CHECKPOINT_DIR}/epoch_{epoch}_step_{step}.pth')
            logging.info(f'Model saved after epoch {epoch} and step {step}')

        # Stop after the step numbered TOTAL_STEPS has been taken.
        if step == TOTAL_STEPS:
            torch.save(ema_model.state_dict(), f'{CHECKPOINT_DIR}/final.pth')
            break

        step += 1

    if step >= TOTAL_STEPS:
        break

In [None]:
# Ask PyTorch to release GPU memory it has cached but is not currently using
# (run between training and evaluation).
torch.cuda.empty_cache()

In [None]:
!nvidia-smi

In [None]:
import random
from sklearn.metrics import balanced_accuracy_score

# Largest number of tiles of one image that is fed to the model as one bag.
EVAL_MAX_TILES = 256

ema_model.eval()

# Parallel lists, one entry per evaluated image.
image_ids = []
logits = []
labels = []

with torch.no_grad():
    for _, row in train.iterrows():
        path = row['tile_path']
        image_id = row['image_id']
        label = row['label']

        all_files = [f for f in os.listdir(path) if f.lower().endswith('.png')]
        if not all_files:
            # Nothing to feed the model; skip the image entirely so the three
            # result lists stay aligned.
            print(f'Skipping image {image_id}: no PNG tiles in {path}')
            continue

        # Use a random subset when the image has more tiles than fit in one bag.
        sample_size = min(EVAL_MAX_TILES, len(all_files))
        sampled_files = random.sample(all_files, sample_size)

        # Transform on the CPU and close every file right after it is read.
        tiles = []
        for image_name in sampled_files:
            with Image.open(os.path.join(path, image_name)) as sub_image:
                tiles.append(val_transform(sub_image))

        # Each chunk becomes one bag with a leading batch dimension of 1 and
        # is moved to the device in a single transfer.
        batch_logits = []
        for i_batch in range(0, len(tiles), EVAL_MAX_TILES):
            chunk = tiles[i_batch:i_batch + EVAL_MAX_TILES]
            model_input = torch.stack(chunk, dim=0).unsqueeze(0).to(device)
            batch_logits.append(ema_model(model_input))

        batch_logits = torch.concat(batch_logits, dim=0)
        mean_logits = batch_logits.mean(dim=0)
        print(batch_logits.shape, label, image_id, integer_to_label[mean_logits.argmax().cpu().item()], mean_logits)
        image_ids.append(image_id)
        logits.append(batch_logits)
        labels.append(label)
        

In [None]:
# Predict, for every image, the class whose logits summed over the image's
# bags are largest.
predictions = [
    integer_to_label[image_logits.sum(dim=0).argmax().cpu().item()]
    for image_logits in logits
]

# Balanced accuracy of those predictions against the true labels.
logit_sum_accuracy = balanced_accuracy_score(labels, predictions)
logit_sum_accuracy