In [1]:
# Run identifier. Used further down to name the training log file, the
# checkpoint directory and the archived copy of the notebook.
MODEL_NAME = 'cutmix_mixup_final_try_all_the_data'

In [2]:
import pandas as pd
import os

def get_image_path(image_id:int):
    """Return the tile directory of a slide, i.e. ``../tiles_768/<image_id>``."""
    tiles_root = '../tiles_768'
    folder_name = str(image_id)
    return os.path.join(tiles_root, folder_name)

# Slide-level metadata, one row per image.
train = pd.read_csv("../data/train.csv")

# Attach the directory that holds each slide's tiles.
train['tile_path'] = train['image_id'].map(get_image_path)
train.head()

Unnamed: 0,image_id,label,image_width,image_height,is_tma,tile_path
0,4,HGSC,23785,20008,False,../tiles_768/4
1,66,LGSC,48871,48195,False,../tiles_768/66
2,91,HGSC,3388,3388,True,../tiles_768/91
3,281,LGSC,42309,15545,False,../tiles_768/281
4,286,EC,37204,30020,False,../tiles_768/286


In [3]:
from PIL import Image
import torch
import torch.nn as nn
import timm
from timm.models.layers import DropPath
import copy
from itertools import cycle

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "timm/eva02_base_patch14_448.mim_in22k_ft_in22k_in1k"

print(f"Using device {device} and model {model_name}")

# Start from the pretrained timm weights; a fine-tuned checkpoint is loaded on
# top of them at the end of this cell.
model = timm.create_model(model_name, pretrained=True)

# Regularisation settings for this run.
# 25000 cutmix_mixup was 0.3 and 0.1
# NOTE(review): presumably the values an earlier run used -- confirm which
# two settings the 0.3 / 0.1 refer to.
drop_path_rate = 0.5
dropout_rate = 0.
head_dropout_rate = 0.3
# One drop-path probability per transformer block, growing linearly from 0 at
# the first block to `drop_path_rate` at the last one.
drop_path_rates = [x.item() for x in torch.linspace(0, drop_path_rate, len(model.blocks))]

# Replace every block's stochastic-depth and dropout modules with fresh ones
# that use the rates chosen above.
for i, block in enumerate(model.blocks):
    block.drop_path1 = DropPath(drop_prob=drop_path_rates[i])
    block.drop_path2 = DropPath(drop_prob=drop_path_rates[i])
    block.attn.attn_drop = nn.Dropout(p=dropout_rate, inplace=False)
    block.attn.proj_drop = nn.Dropout(p=dropout_rate, inplace=False)
    block.mlp.drop1 = nn.Dropout(p=dropout_rate, inplace=False)
    block.mlp.drop2 = nn.Dropout(p=dropout_rate, inplace=False)

# Swap the classifier for a freshly initialised 5-output linear layer (one
# output per label in this notebook) and reset the remaining dropout modules.
model.head = nn.Linear(model.head.in_features, 5)
model.pos_drop = nn.Dropout(dropout_rate)
model.head_drop = nn.Dropout(head_dropout_rate)

# Resume from the step-54000 EMA(0.999) checkpoint of this run.
# strict=False means keys that are missing from, or unexpected in, the
# checkpoint do not raise -- a mismatched file would load without an error.
# NOTE(review): consider checking the value returned by load_state_dict.
model_location = 'eva02_base_models_cutmix_mixup_final_try_all_the_data/ema_0.999_step_54000.pth'
state_dict = torch.load(model_location, map_location=device)
model.load_state_dict(state_dict, strict=False)
model = model.to(device)

# One EMA copy of the model per decay rate; each copy is restored from its own
# step-54000 checkpoint and kept in eval mode.
ema_decays = [0.999, 0.9995, 0.9998, 0.9999, 0.99995, 0.99998, 0.99999]

ema_models = []
for _decay in ema_decays:
    ema_models.append(copy.deepcopy(model))

for i_ema, ema_model in enumerate(ema_models):
    model_location = f'eva02_base_models_cutmix_mixup_final_try_all_the_data/ema_{ema_decays[i_ema]}_step_54000.pth'
    state_dict = torch.load(model_location, map_location=device)
    ema_model.load_state_dict(state_dict, strict=False)
    # nn.Module.to() moves the module in place, so the entry stored in
    # `ema_models` ends up on `device` as well.
    ema_model = ema_model.to(device)
    ema_model.eval()

Using device cuda and model timm/eva02_base_patch14_448.mim_in22k_ft_in22k_in1k


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [4]:
# Class index -> label name. The position in the list is the class index.
integer_to_label = dict(enumerate(['HGSC', 'CC', 'EC', 'LGSC', 'MC']))

# Inverse lookup: label name -> class index.
label_to_integer = {name: index for index, name in integer_to_label.items()}

In [5]:
import os
from PIL import Image
from torch.utils.data import Dataset
import random
    
class ImageDataset(Dataset):
    """Tile-level dataset built from a slide-level dataframe.

    Each row of ``dataframe`` must provide ``tile_path`` (a directory of PNG
    tiles), ``label`` and ``image_id``. Every ``*.png`` file found in an
    existing ``tile_path`` becomes one sample; rows whose directory does not
    exist contribute nothing. The flat sample list is shuffled once at
    construction time with the ``random`` module.

    ``__getitem__`` returns ``(image, label_index, image_id)`` where
    ``label_index`` is looked up in the module-level ``label_to_integer``.
    """

    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform
        # Flat list of (tile_file, label, image_id) tuples.
        self.all_images = []
        # Never filled by this class; kept so code that reads the attribute
        # keeps working.
        self.label_images = [[] for _ in range(5)]

        for _, row in dataframe.iterrows():
            folder_path = row['tile_path']
            if not os.path.isdir(folder_path):
                # Slides without a tile directory are skipped silently.
                continue
            label = row['label']
            image_id = row['image_id']
            image_files = [
                os.path.join(folder_path, f)
                for f in os.listdir(folder_path)
                if f.lower().endswith('.png')
            ]
            random.shuffle(image_files)
            self.all_images.extend((image_file, label, image_id) for image_file in image_files)

        random.shuffle(self.all_images)

    def __len__(self):
        """Number of tiles collected across all slides."""
        return len(self.all_images)

    def __getitem__(self, idx):
        """Load tile ``idx`` and return ``(image, label_index, image_id)``."""
        image_path, label, image_id = self.all_images[idx]

        # The context manager closes the file handle as soon as the pixels are
        # read (important with many DataLoader workers). convert('RGB') loads
        # the data and guarantees three channels, which the 3-channel
        # Normalize used in this notebook requires.
        with Image.open(image_path) as tile:
            image = tile.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label_to_integer[label], image_id

In [6]:
from torch.utils.data import DataLoader, WeightedRandomSampler
import torchvision.transforms as transforms
import torchvision.transforms.v2 as v2
from torch.utils.data import default_collate

BATCH_SIZE = 16

# Per-channel normalisation statistics shared by the train and validation
# pipelines.
_NORM_MEAN = [0.48145466, 0.4578275, 0.40821073]
_NORM_STD = [0.26862954, 0.26130258, 0.27577711]

# Training pipeline: random crop to 448, RandAugment, tensor conversion and
# normalisation, followed by three independent random-erasing passes.
train_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=448, scale=(0.5, 1.0), ratio=(0.75, 1.33)),
        transforms.RandAugment(9, 15, 31),
        transforms.Resize(448),
        transforms.ToTensor(),
        transforms.Normalize(mean=list(_NORM_MEAN), std=list(_NORM_STD)),
    ]
    + [
        transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False)
        for _erase_pass in range(3)
    ]
)

# Evaluation pipeline: deterministic resize, tensor conversion, normalisation.
val_transform = transforms.Compose([
    transforms.Resize(448),
    transforms.ToTensor(),
    transforms.Normalize(mean=list(_NORM_MEAN), std=list(_NORM_STD)),
])

# Training tiles with the augmentation pipeline attached.
train_dataset = ImageDataset(dataframe=train, transform=train_transform)

# Batch-level mixing: RandomChoice picks one of CutMix / MixUp per call.
cutmix = v2.CutMix(num_classes=5)
mixup = v2.MixUp(num_classes=5)
cutmix_or_mixup = v2.RandomChoice([cutmix, mixup])

def collate_fn(batch):
    """Collate samples with the default collate, then apply CutMix or MixUp.

    The collated batch is unpacked into positional arguments, so the mixing
    transform receives the dataset's three fields (image, label, image_id).
    """
    collated = default_collate(batch)
    return cutmix_or_mixup(*collated)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=7,
    collate_fn=collate_fn,
)

In [7]:
import logging
import sys

# Configure the root logger: INFO and above go to a per-run log file, ERROR
# and above additionally go to stderr.
logger = logging.getLogger()

# Drop handlers left over from a previous run of this cell so records are not
# written twice.
for handler in logger.handlers[:]:
    logger.removeHandler(handler)

logger.setLevel(logging.INFO)

# FileHandler opens its file immediately and raises if the directory is
# missing, so create the directory first.
log_path = f'logs/eva02_base_train_{MODEL_NAME}.txt'
os.makedirs(os.path.dirname(log_path), exist_ok=True)
file_handler = logging.FileHandler(log_path)
file_handler.setLevel(logging.INFO)
logger.addHandler(file_handler)

# Only ERROR and CRITICAL records reach stderr, which keeps the per-step loss
# lines out of the notebook output.
stream_handler = logging.StreamHandler(sys.stderr)
stream_handler.setLevel(logging.ERROR)
logger.addHandler(stream_handler)

In [None]:
import torch
import torch.optim as optim
import logging
import numpy as np
import math
from sklearn.metrics import balanced_accuracy_score
import random
from torch.cuda.amp import GradScaler, autocast

initial_lr = 1e-5
final_lr = 1e-7
num_epochs = 10000

def learning_rate(step, warmup_steps=5000, max_steps=200000):
    """Learning rate for a global step: linear warmup, then cosine decay.

    * ``step < warmup_steps``: rises linearly from 0 towards ``initial_lr``.
    * ``warmup_steps <= step < max_steps``: half-cosine from ``initial_lr``
      down to ``final_lr``.
    * ``step >= max_steps``: constant ``final_lr``.
    """
    if step < warmup_steps:
        warmup_fraction = float(step) / float(max(1, warmup_steps))
        return initial_lr * warmup_fraction

    if step >= max_steps:
        return final_lr

    decay_span = float(max(1, max_steps - warmup_steps))
    progress = float(step - warmup_steps) / decay_span
    cos_component = 0.5 * (1 + math.cos(math.pi * progress))
    return final_lr + (initial_lr - final_lr) * cos_component

def update_ema_variables(model, ema_model, alpha, global_step):
    """Update ``ema_model`` in place as an exponential moving average of ``model``.

    For every parameter pair: ``ema = alpha * ema + (1 - alpha) * param``.
    Only ``parameters()`` are averaged. ``global_step`` is accepted but not
    used by the update.
    """
    keep = alpha
    blend = 1 - alpha
    with torch.no_grad():
        paired = zip(ema_model.parameters(), model.parameters())
        for ema_param, param in paired:
            ema_param.data.mul_(keep)
            ema_param.data.add_(param.data, alpha=blend)

# Loss scaler for mixed-precision training and the optimizer over all model
# parameters.
scaler = GradScaler()
optimizer = optim.AdamW(model.parameters(), lr=initial_lr, weight_decay=1e-7)

# Class counts: {'CC': 146576, 'EC': 165037, 'HGSC': 272477, 'LGSC': 46346, 'MC': 79036}
# Listed here in class-index order (HGSC, CC, EC, LGSC, MC).
class_counts = torch.tensor([272477, 146576, 165037, 46346, 79036], dtype=torch.float32)

# Inverse-frequency class weights, normalised to sum to 1 and moved to the
# training device.
inverse_frequency = 1.0 / class_counts
weights = inverse_frequency / inverse_frequency.sum()
weights = weights.to(device)

# Class-weighted cross-entropy.
criterion = torch.nn.CrossEntropyLoss(weight=weights)

best_val_accuracy = 0.0
# Global step the loop below starts counting from.
step = 54001

# Checkpoint directory for this run. torch.save() does not create missing
# directories, so make sure it exists before the first save.
checkpoint_dir = f'eva02_base_models_{MODEL_NAME}'
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(num_epochs):
    model.train()  # training mode; the EMA copies stay in eval mode

    for images, labels, _ in train_dataloader:
        # Move the batch to the training device.
        images = images.to(device)
        labels = labels.to(device)

        # Apply the scheduled learning rate for this global step.
        lr = learning_rate(step)
        for g in optimizer.param_groups:
            g['lr'] = lr

        optimizer.zero_grad()

        # Forward pass and loss under autocast (mixed precision).
        with autocast():
            outputs = model(images)
            loss = criterion(outputs, labels)

        # Scaled backward pass; the scaler performs the optimizer step and
        # then adjusts its scale factor.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Fold the new weights into every EMA copy.
        for i_ema, ema_model in enumerate(ema_models):
            update_ema_variables(model, ema_model, ema_decays[i_ema], step)

        logging.info('[%d, %5d] loss: %.3f' % (epoch + 1, step, loss.item()))

        # Every 1000 steps, save all EMA copies. Only the EMA weights are
        # written; the raw model and optimizer state are not saved.
        if step % 1000 == 0:
            for i_ema, ema_model in enumerate(ema_models):
                torch.save(ema_model.state_dict(), f'{checkpoint_dir}/ema_{ema_decays[i_ema]}_step_{step}.pth')
            logging.info(f'Models saved after epoch {epoch} and step {step}')

        # Stop before incrementing so `step` stays at 200000 and the outer
        # check below also terminates the epoch loop.
        if step == 200000:
            break

        step += 1

    if step >= 200000:
        break

In [None]:
import shutil
import os

def duplicate_ipynb_with_new_name(src_file_path, dest_dir, new_name):
    """
    Duplicate an IPython notebook file to a directory under a new file name.

    The destination directory is created if it does not exist yet.

    Parameters:
    src_file_path (str): The path of the source IPython notebook file.
    dest_dir (str): The destination directory where the file will be copied.
    new_name (str): The new file name; '.ipynb' is appended when missing.

    Returns:
    str: The path of the duplicated file with the new name.

    Raises:
    OSError: If the source cannot be read or the destination cannot be
        written (propagated from ``shutil.copy``).
    """
    if not new_name.endswith('.ipynb'):
        new_name += '.ipynb'

    # shutil.copy does not create parent directories. An empty dest_dir means
    # "current directory" and needs no creation.
    if dest_dir:
        os.makedirs(dest_dir, exist_ok=True)

    dest_file_path = os.path.join(dest_dir, new_name)
    shutil.copy(src_file_path, dest_file_path)

    return dest_file_path

# Archive a copy of the notebook file under the run name.
# NOTE(review): presumably `src_file` is this notebook's own filename -- confirm.
src_file = "eva02-base-finetune.ipynb"
dest_directory = "notebook_history"
new_filename = f"{MODEL_NAME}.ipynb"
duplicate_ipynb_with_new_name(
    src_file_path=src_file,
    dest_dir=dest_directory,
    new_name=new_filename,
)


In [None]:
# Ask PyTorch's CUDA caching allocator to release cached memory it is not
# currently using, before the evaluation cells below run.
torch.cuda.empty_cache()

In [None]:
import random
from sklearn.metrics import balanced_accuracy_score

# NOTE: `ema_model` is the same object as `model` (no copy is made), so the
# checkpoint loaded below overwrites the training model's weights in place.
# NOTE(review): the checkpoint path names a different run than MODEL_NAME --
# confirm this is the model that should be evaluated.
ema_model = model
state_dict = torch.load('eva02_base_models_25000_cutmix_mixup/ema_0.9998_step_50000.pth', map_location=device)
ema_model.load_state_dict(state_dict, strict=False)
ema_model = ema_model.to(device)
ema_model.eval()

# Maximum number of tiles sampled per image
MAX_TILES_PER_IMAGE = 8

# Maximum number of images per forward pass
MAX_IMAGES_PER_BATCH = 8  # Adjust based on model capacity and memory constraints

# Per-image results, index-aligned: image_ids[k], labels[k] and logits[k]
# (a tensor with one row of class logits per sampled tile).
image_ids = []
logits = []
labels = []

def _flush_batch(batch_tiles, batch_image_ids, batch_labels):
    """Run the pending images through the model and record per-image logits.

    ``batch_tiles`` holds one stacked tile tensor per image. The tensors are
    concatenated for a single forward pass and the output is split back into
    per-image groups. Does nothing when the batch is empty.
    """
    if not batch_tiles:
        return

    batch_input = torch.concat(batch_tiles, dim=0).to(device)
    batch_output = ema_model(batch_input)

    start = 0
    for tiles in batch_tiles:
        end = start + tiles.shape[0]
        logits.append(batch_output[start:end])
        start = end

    image_ids.extend(batch_image_ids)
    labels.extend(batch_labels)

with torch.no_grad():
    # Temporary storage for the current batch
    batch_tiles = []
    batch_image_ids = []
    batch_labels = []

    for _, row in train.iterrows():
        print(row['image_id'])
        path = row['tile_path']

        # Same guard as ImageDataset: an image without a tile directory (or
        # without any PNG tiles) is skipped instead of crashing the whole run.
        if not os.path.isdir(path):
            print(f"skipping {row['image_id']}: no tile directory at {path}")
            continue
        all_files = [f for f in os.listdir(path) if f.lower().endswith('.png')]
        if not all_files:
            print(f"skipping {row['image_id']}: no PNG tiles in {path}")
            continue

        # Randomly sample tiles from this image
        sample_size = min(MAX_TILES_PER_IMAGE, len(all_files))
        sampled_files = random.sample(all_files, sample_size)

        image_tiles = []
        for image_name in sampled_files:
            image_path = os.path.join(path, image_name)
            # Close the file promptly and force three channels for Normalize.
            with Image.open(image_path) as sub_image:
                tile = val_transform(sub_image.convert('RGB')).unsqueeze(0)
            image_tiles.append(tile)

        # Add this image's tiles to the batch
        batch_tiles.append(torch.concat(image_tiles, dim=0))
        batch_image_ids.append(row['image_id'])
        batch_labels.append(row['label'])

        # Process the batch once it is full
        if len(batch_tiles) == MAX_IMAGES_PER_BATCH:
            _flush_batch(batch_tiles, batch_image_ids, batch_labels)
            batch_tiles = []
            batch_image_ids = []
            batch_labels = []

    # Process whatever is left after the last row (a partial batch).
    _flush_batch(batch_tiles, batch_image_ids, batch_labels)
    batch_tiles = []
    batch_image_ids = []
    batch_labels = []
        

In [None]:
# Plurality vote: each tile votes for its arg-max class and the class with the
# most votes becomes the image-level prediction.
predictions = []
for image_logits in logits:
    argmax_indices = image_logits.argmax(dim=1)
    frequency_counts = argmax_indices.bincount(minlength=5)
    max_vote_key = frequency_counts.argmax().cpu().item()
    predictions += [integer_to_label[max_vote_key]]

plurality_accuracy = balanced_accuracy_score(labels, predictions)
plurality_accuracy

In [None]:
# Logit sum: add the raw logits of all tiles and take the arg-max class.
predictions = []
for image_logits in logits:
    summed_logits = torch.sum(image_logits, dim=0)
    max_vote_key = torch.argmax(summed_logits).cpu().item()
    predictions += [integer_to_label[max_vote_key]]

logit_sum_accuracy = balanced_accuracy_score(labels, predictions)
logit_sum_accuracy

In [None]:
# Probability sum: add the per-tile softmax probabilities and take the
# arg-max class.
predictions = []
for image_logits in logits:
    tile_probs = torch.softmax(image_logits, dim=1)
    summed_probs = torch.sum(tile_probs, dim=0)
    max_vote_key = torch.argmax(summed_probs).cpu().item()
    predictions += [integer_to_label[max_vote_key]]

prob_sum_accuracy = balanced_accuracy_score(labels, predictions)
prob_sum_accuracy

In [None]:
# Log-probability sum: add the per-tile log-probabilities (the log of the
# product of tile probabilities) and take the arg-max class.
predictions = []
for image_logits in logits:
    # log_softmax computes log(softmax(x)) in a numerically stable way;
    # log(softmax(x)) itself yields -inf as soon as a probability underflows
    # to 0, which can make several classes indistinguishable.
    summed_log_probs = image_logits.log_softmax(dim=1).sum(dim=0)

    max_vote_key = summed_log_probs.argmax().cpu().item()
    predictions.append(integer_to_label[max_vote_key])

log_prob_sum_accuracy = balanced_accuracy_score(labels, predictions)
log_prob_sum_accuracy

In [None]:
# Complement log-probability sum: for each class add log(1 - p) over all
# tiles and pick the class with the smallest total.
predictions = []
for image_logits in logits:
    tile_probs = torch.softmax(image_logits, dim=1)
    summed_one_minus_log_probs = torch.sum(torch.log(1 - tile_probs), dim=0)
    min_vote_key = torch.argmin(summed_one_minus_log_probs).cpu().item()
    predictions += [integer_to_label[min_vote_key]]

one_minus_log_prob_sum_accuracy = balanced_accuracy_score(labels, predictions)
one_minus_log_prob_sum_accuracy

In [None]:
# Diagnostic: print every image for which the log-probability rule and the
# complement log-probability rule pick different classes.
predictions = []
for i, image_logits in enumerate(logits):
    tile_probs = image_logits.softmax(dim=1)
    summed_log_probs = torch.log(tile_probs).sum(dim=0)
    summed_one_minus_log_probs = torch.log(1 - tile_probs).sum(dim=0)

    log_prob_choice = summed_log_probs.argmax().cpu().item()
    complement_choice = summed_one_minus_log_probs.argmin().cpu().item()
    if log_prob_choice != complement_choice:
        print(i, labels[i], integer_to_label[log_prob_choice], integer_to_label[complement_choice], summed_log_probs, summed_one_minus_log_probs)


In [None]:
from collections import defaultdict

def borda_count_winner(image_logits):
    """Return the Borda-count winner over the tiles of one image.

    Each row of ``image_logits`` is one voter. A voter gives
    ``num_classes - 1`` points to its highest-logit class, one point less to
    the next, and so on down to 0. The class with the highest total wins;
    among equal totals the class that first received points wins.
    """
    num_classes = image_logits.size(1)
    top_points = num_classes - 1
    borda_scores = {}

    for tile_logits in image_logits:
        # Class indices ordered from most to least preferred by this tile.
        ordering = torch.argsort(tile_logits, descending=True).tolist()
        for position, class_index in enumerate(ordering):
            borda_scores[class_index] = borda_scores.get(class_index, 0) + (top_points - position)

    borda_winner = max(borda_scores, key=borda_scores.get)
    return borda_winner

# Image-level predictions from the Borda count, scored with balanced accuracy.
predictions = []
for image_logits in logits:
    borda_winner = borda_count_winner(image_logits)
    predictions += [integer_to_label[borda_winner]]

borda_accuracy = balanced_accuracy_score(labels, predictions)
borda_accuracy

In [None]:
def instant_runoff_winner(image_logits):
    """Return the instant-runoff winner over the tiles of one image.

    Each row of ``image_logits`` is one voter whose ballot ranks the classes
    by descending logit. In every round each ballot counts for its
    highest-ranked class that is still active. A class holding more than half
    of the ballots wins; otherwise the active class with the fewest ballots
    is eliminated and the count is repeated.
    """
    num_voters, num_classes = image_logits.shape
    active_candidates = set(range(num_classes))
    majority_threshold = num_voters / 2

    while True:
        # Tally the top active choice of every ballot.
        first_pref_counts = torch.zeros(num_classes)
        for tile_logits in image_logits:
            ballot = torch.argsort(tile_logits, descending=True)
            top_active = next((c for c in ballot if c.item() in active_candidates), None)
            if top_active is not None:
                first_pref_counts[top_active] += 1

        # A strict majority ends the election.
        if torch.any(first_pref_counts > majority_threshold):
            return torch.argmax(first_pref_counts).item()

        # Otherwise eliminate the weakest active class.
        remaining = list(active_candidates)
        _, weakest_position = torch.min(first_pref_counts[remaining], 0)
        active_candidates.remove(remaining[weakest_position.item()])
        
# Image-level predictions from instant-runoff voting, scored with balanced
# accuracy.
predictions = []
for image_logits in logits:
    irv_winner = instant_runoff_winner(image_logits)
    predictions += [integer_to_label[irv_winner]]

instant_runoff_accuracy = balanced_accuracy_score(labels, predictions)
instant_runoff_accuracy

In [None]:
def calculate_minimax_winner(image_logits):
    """Return the minimax winner over the tiles of one image.

    Each row of ``image_logits`` is one voter. For every pair of classes the
    number of voters preferring each side is counted (a voter whose two
    logits are equal counts for the higher-indexed class). A class's regret
    is the largest number of voters that prefer some single opponent over it;
    the class with the smallest regret wins.
    """
    num_voters, num_classes = image_logits.shape
    max_regrets = torch.zeros(num_classes)

    class_pairs = [
        (first, second)
        for first in range(num_classes)
        for second in range(first + 1, num_classes)
    ]
    for first, second in class_pairs:
        prefers_first = (image_logits[:, first] > image_logits[:, second]).sum()
        prefers_second = num_voters - prefers_first

        # Opposition against a class is the support for its opponent.
        max_regrets[first] = max(max_regrets[first], prefers_second)
        max_regrets[second] = max(max_regrets[second], prefers_first)

    return max_regrets.argmin().item()
        
# Image-level predictions from the minimax rule, scored with balanced accuracy.
predictions = []
for image_logits in logits:
    minimax_winner = calculate_minimax_winner(image_logits)
    predictions += [integer_to_label[minimax_winner]]

minimax_accuracy = balanced_accuracy_score(labels, predictions)
minimax_accuracy

In [None]:
from collections import defaultdict

def find_winner(graph, source, visited):
    """Return True if a directed cycle can be reached from ``source``.

    ``graph`` maps each node to the set of nodes it points to. ``visited`` is
    the set of nodes on the current depth-first path: a node is added on
    entry and removed again when its subtree has been explored without
    finding a cycle. Pass an empty set for a fresh search.
    """
    if source not in graph:
        return False

    visited.add(source)
    for target in graph[source]:
        # An edge back to a node on the current path closes a cycle.
        if target in visited or find_winner(graph, target, visited):
            return True
    visited.remove(source)
    return False

def ranked_pairs_winner(image_logits):
    """Return the ranked-pairs (Tideman) winner over the tiles of one image.

    Each row of ``image_logits`` is one voter whose preference between two
    classes is decided by the larger logit. Pairwise victories are sorted by
    margin, strongest first, and locked into a directed graph unless locking
    one would create a cycle. The winner is the class no locked victory
    points at; if several such classes remain (only possible with pairwise
    ties) the one with the largest summed logit is returned.
    """
    _, num_classes = image_logits.shape

    # Margin of every pairwise victory (winner, loser). Only strictly
    # positive margins are kept: a defeat or a tie is not a victory to lock.
    margins = {}
    for i in range(num_classes):
        for j in range(num_classes):
            if i != j:
                votes_for_i = int(torch.sum(image_logits[:, i] > image_logits[:, j]))
                votes_for_j = int(torch.sum(image_logits[:, j] > image_logits[:, i]))
                margin = votes_for_i - votes_for_j
                if margin > 0:
                    margins[(i, j)] = margin

    # Strongest victories first; sorted() is stable, so equal margins keep
    # their (i, j) enumeration order.
    sorted_pairs = sorted(margins, key=margins.get, reverse=True)

    # Lock each victory unless it would close a cycle.
    graph = {i: set() for i in range(num_classes)}
    for winner, loser in sorted_pairs:
        graph[winner].add(loser)
        # The graph was acyclic before this edge, so any cycle must pass
        # through it and is therefore reachable from `loser`.
        if find_winner(graph, loser, set()):
            graph[winner].remove(loser)

    # Classes that no locked victory points at.
    beaten = set().union(*graph.values())
    unbeaten = [c for c in range(num_classes) if c not in beaten]
    if len(unbeaten) == 1:
        return unbeaten[0]

    # Several unbeaten classes: break the tie with the summed logits.
    total_votes = torch.sum(image_logits, dim=0)
    return max(unbeaten, key=lambda c: total_votes[c].item())
        
# Image-level predictions from ranked_pairs_winner, scored with balanced
# accuracy.
predictions = []
for image_logits in logits:
    rp_winner = ranked_pairs_winner(image_logits)
    predictions += [integer_to_label[rp_winner]]

rp_accuracy = balanced_accuracy_score(labels, predictions)
rp_accuracy

In [None]:
from collections import defaultdict

def star_voting_winner(image_logits):
    """Return the STAR ("score then automatic runoff") winner for one image.

    Each row of ``image_logits`` is one voter and the raw logits are used as
    scores. The two classes with the highest summed score become finalists;
    the finalist preferred by more voters (larger logit) wins. A tied runoff
    goes to the finalist with the higher summed score.
    """
    num_voters, num_classes = image_logits.shape

    # With a single class there is no runoff to hold.
    if num_classes == 1:
        return 0

    # Step 1: Sum the scores for each candidate
    total_scores = torch.sum(image_logits, dim=0)

    # Step 2: Finalists, ordered by summed score (highest first)
    top_two = torch.topk(total_scores, 2).indices
    leader, runner_up = top_two[0], top_two[1]

    # Step 3: Runoff. Voters scoring both finalists equally abstain.
    leader_votes = torch.sum(image_logits[:, leader] > image_logits[:, runner_up])
    runner_up_votes = torch.sum(image_logits[:, runner_up] > image_logits[:, leader])

    # The runner-up needs a strict majority of the expressed preferences to
    # overturn the score leader.
    if runner_up_votes > leader_votes:
        return runner_up.item()
    return leader.item()
        
# Image-level predictions from STAR voting, scored with balanced accuracy.
predictions = []
for image_logits in logits:
    # This cell evaluates STAR voting, so it must call star_voting_winner;
    # it previously called ranked_pairs_winner and re-reported that score.
    star_winner = star_voting_winner(image_logits)
    predictions.append(integer_to_label[star_winner])

star_accuracy = balanced_accuracy_score(labels, predictions)
star_accuracy