In [1]:
import pandas as pd
import os

def get_image_path(image_id: int, tiles_dir: str = '../tiles_768'):
    """Return the path of the folder that holds the tiles of one image.

    Args:
        image_id: Identifier of the image; converted with ``str`` to form the
            folder name.
        tiles_dir: Directory containing one sub-folder per image. Defaults to
            ``'../tiles_768'``, the location that used to be hard-coded here.

    Returns:
        ``tiles_dir`` joined with ``str(image_id)``. The path is only built,
        not checked for existence.
    """
    return os.path.join(tiles_dir, str(image_id))

# The whole training CSV is evaluated here. To score one held-out fold instead,
# set I_FOLD = 1 and read f"../folds/val_fold_{I_FOLD}.csv".
validation_df = pd.read_csv("../data/train.csv")

# Add the tile folder of each image as a column, derived from its image_id.
validation_df['tile_path'] = validation_df['image_id'].map(get_image_path)
validation_df.head()

Unnamed: 0,image_id,label,image_width,image_height,is_tma,tile_path
0,4,HGSC,23785,20008,False,../tiles_768/4
1,66,LGSC,48871,48195,False,../tiles_768/66
2,91,HGSC,3388,3388,True,../tiles_768/91
3,281,LGSC,42309,15545,False,../tiles_768/281
4,286,EC,37204,30020,False,../tiles_768/286


In [2]:
from PIL import Image
import torch
import torch.nn as nn
import timm
from timm.models.layers import DropPath
import copy
from itertools import cycle

# device = "cuda"
device = "cpu"
# timm identifier of the pretrained backbone. Kept under its own name so it is
# not confused with the checkpoint run name (`model_name`, set further down).
backbone_name = "timm/eva02_base_patch14_448.mim_in22k_ft_in22k_in1k"

print(f"Using device {device} and model {backbone_name}")

model = timm.create_model(backbone_name, pretrained=True)

drop_path_rate = 0.5
dropout_rate = 0.
head_dropout_rate = 0.3
# Drop-path probability rises linearly from 0 (first block) to drop_path_rate (last block).
drop_path_rates = [x.item() for x in torch.linspace(0, drop_path_rate, len(model.blocks))]

# Swap in the regularisation layers of every transformer block.
# NOTE(review): presumably this mirrors the training-time setup; these layers
# are inactive once .eval() is called below.
for i, block in enumerate(model.blocks):
    block.drop_path1 = DropPath(drop_prob=drop_path_rates[i])
    block.drop_path2 = DropPath(drop_prob=drop_path_rates[i])
    block.attn.attn_drop = nn.Dropout(p=dropout_rate, inplace=False)
    block.attn.proj_drop = nn.Dropout(p=dropout_rate, inplace=False)
    block.mlp.drop1 = nn.Dropout(p=dropout_rate, inplace=False)
    block.mlp.drop2 = nn.Dropout(p=dropout_rate, inplace=False)

# Replace the classification head with a 5-way linear layer.
model.head = nn.Linear(model.head.in_features, 5)
model.pos_drop = nn.Dropout(dropout_rate)
model.head_drop = nn.Dropout(head_dropout_rate)

model = model.to(device)

# One checkpoint per EMA decay, all taken at the same training step.
ema_decays = [0.999, 0.9995, 0.9998, 0.9999, 0.99995, 0.99998, 0.99999]
# ema_decays = [0.99, 0.995, 0.998, 0.999, 0.9995, 0.9998, 0.9999]
# model_name = "cutmix_mixup_different_sampling"
# model_name = "cutmix_mixup_final_try_all_the_data"
# model_name = "cutmix_mixup_final_try_fold"
# model_name = "cutmix_mixup_high_reg"
model_name = "tma_special_pt_2"
model_step = 26000
model_locations = [
    f'eva02_base_models_{model_name}/ema_{decay}_step_{model_step}.pth'
    for decay in ema_decays
]
# Alternative: one decay, several training steps.
# ema_decay = 0.9995
# model_locations = [
#     f'eva02_base_models_{model_name}/ema_{ema_decay}_step_53000.pth',
#     f'eva02_base_models_{model_name}/ema_{ema_decay}_step_43000.pth',
#     f'eva02_base_models_{model_name}/ema_{ema_decay}_step_33000.pth',
#     f'eva02_base_models_{model_name}/ema_{ema_decay}_step_23000.pth',
#     f'eva02_base_models_{model_name}/ema_{ema_decay}_step_13000.pth',
#     f'eva02_base_models_{model_name}/ema_{ema_decay}_step_3000.pth',
# ]
# Alternative: hand-picked decay/step pairs of an earlier run.
# model_name = "cutmix_mixup_third_try"
# model_locations = [
#     f'eva02_base_models_{model_name}/ema_0.9998_step_10000.pth',
#     f'eva02_base_models_{model_name}/ema_0.9999_step_20000.pth',
#     f'eva02_base_models_{model_name}/ema_0.99995_step_30000.pth',
#     f'eva02_base_models_{model_name}/ema_0.99998_step_50000.pth',
# ]

# Each checkpoint gets its own copy of the architecture built above.
ema_models = [copy.deepcopy(model) for _ in model_locations]
for ema_model, location in zip(ema_models, model_locations):
    # NOTE(review): torch.load unpickles the file, so only load checkpoints from
    # a trusted source (consider weights_only=True if the installed torch supports it).
    state_dict = torch.load(location, map_location=device)
    # strict=False tolerates key mismatches; report them so a checkpoint that
    # does not match the architecture is not evaluated silently.
    load_result = ema_model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            f"WARNING: {location}: "
            f"{len(load_result.missing_keys)} missing and "
            f"{len(load_result.unexpected_keys)} unexpected state-dict keys"
        )
    ema_model.to(device)
    ema_model.eval()

# The template is no longer needed once every copy has its weights.
del model

Using device cpu and model timm/eva02_base_patch14_448.mim_in22k_ft_in22k_in1k


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [3]:
# Class index -> label name; the index is the position in the list.
integer_to_label = dict(enumerate(['HGSC', 'CC', 'EC', 'LGSC', 'MC']))

# Inverse lookup, label name -> class index, derived from the mapping above.
label_to_integer = {name: index for index, name in integer_to_label.items()}

In [4]:
import os
from PIL import Image
from torch.utils.data import Dataset
import random

class ImageDataset(Dataset):
    """Map-style dataset over tile PNGs, interlaced by class and by TMA/WSI origin.

    Every row of ``dataframe`` names a folder (``tile_path``) of ``.png`` tiles.
    Tiles are grouped per class separately for TMA and WSI images, shuffled
    within each group, and then interlaced round-robin into ``all_images``.
    Smaller groups are repeated so that every non-empty group contributes one
    tile per round.

    ``__len__`` reports a nominal 1,000,000,000 items; ``__getitem__`` wraps the
    index around ``all_images`` so that every index below that length is valid.

    Args:
        dataframe: Frame with the columns ``tile_path``, ``label``,
            ``image_id`` and ``is_tma``.
        transform: Optional callable applied to each opened image.
        seed: Optional seed for the per-group shuffle. ``None`` (default) uses
            the global ``random`` state, as before.
    """

    def __init__(self, dataframe, transform=None, seed=None):
        self.dataframe = dataframe
        self.transform = transform
        num_classes = len(label_to_integer)
        self.all_images = []  # Interlaced (path, label, image_id, is_tma) tuples
        self.wsi_label_images = [[] for _ in range(num_classes)]
        self.tma_label_images = [[] for _ in range(num_classes)]

        # Step 1: collect the PNG tiles of every image folder, grouped by class
        # and by TMA/WSI origin. Rows whose folder does not exist are skipped.
        for _, row in dataframe.iterrows():
            folder_path = row['tile_path']
            if not os.path.isdir(folder_path):
                continue
            label = row['label']
            image_id = row['image_id']
            is_tma = row['is_tma']
            # Sorted so the order before shuffling does not depend on the
            # arbitrary order returned by os.listdir.
            entries = [
                (os.path.join(folder_path, file_name), label, image_id, is_tma)
                for file_name in sorted(os.listdir(folder_path))
                if file_name.lower().endswith('.png')
            ]
            groups = self.tma_label_images if is_tma else self.wsi_label_images
            groups[label_to_integer[label]].extend(entries)

        # Step 2: shuffle inside each group.
        rng = random.Random(seed) if seed is not None else random
        for class_index in range(num_classes):
            rng.shuffle(self.tma_label_images[class_index])
            rng.shuffle(self.wsi_label_images[class_index])

        # Step 3: interlace the groups, repeating the smaller ones as needed.
        max_length = max(
            max(len(tma) for tma in self.tma_label_images),
            max(len(wsi) for wsi in self.wsi_label_images),
        )
        for i in range(max_length):
            for class_index in range(num_classes):
                tma_group = self.tma_label_images[class_index]
                wsi_group = self.wsi_label_images[class_index]
                if tma_group:
                    self.all_images.append(tma_group[i % len(tma_group)])  # Repeat TMA data
                if wsi_group:
                    self.all_images.append(wsi_group[i % len(wsi_group)])  # Repeat WSI data

    def __len__(self):
        # Nominal length; __getitem__ wraps indices, so iteration is bounded by
        # the consumer rather than by the number of collected tiles.
        return 1_000_000_000

    def __getitem__(self, idx):
        """Return ``(image, class_index, image_id)`` for ``idx``, wrapping around."""
        if not self.all_images:
            raise IndexError("ImageDataset contains no images; check the 'tile_path' folders")
        # Wrap so that indices up to __len__ do not run past the interlaced list.
        image_path, label, image_id, _ = self.all_images[idx % len(self.all_images)]
        image = Image.open(image_path)

        if self.transform:
            image = self.transform(image)

        return image, label_to_integer[label], image_id

In [5]:
from torch.utils.data import DataLoader, WeightedRandomSampler
import torchvision.transforms as transforms
import torchvision.transforms.v2 as v2
from torch.utils.data import default_collate

BATCH_SIZE = 16

# Per-channel statistics handed to Normalize.
_NORMALIZE_MEAN = [0.48145466, 0.4578275, 0.40821073]
_NORMALIZE_STD = [0.26862954, 0.26130258, 0.27577711]

# Resize to 448, convert to a tensor, then normalise per channel.
validation_transform = transforms.Compose([
    transforms.Resize(448),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
])

validation_dataset = ImageDataset(dataframe=validation_df, transform=validation_transform)

validation_dataloader = DataLoader(
    validation_dataset,
    batch_size=BATCH_SIZE,
    num_workers=7,
)

In [6]:
from sklearn.metrics import confusion_matrix

def class_wise_accuracy_score(actual_classes, predicted_classes):
    """Return the accuracy of each class, rounded to three decimals.

    For every class that occurs in either input, the score is the number of
    samples of that class that were predicted correctly divided by the number
    of samples whose actual class it is (the diagonal of the confusion matrix
    over its row sums).

    Args:
        actual_classes: Iterable of true class labels.
        predicted_classes: Iterable of predicted class labels, same length.

    Returns:
        Dict mapping each class (in sorted order) to its accuracy. A class that
        only occurs in ``predicted_classes`` gets 0.0. Empty input yields ``{}``.
    """
    # Every class seen in either list, so predicted-only classes are reported too.
    all_classes = sorted(set(actual_classes) | set(predicted_classes))
    if not all_classes:
        return {}

    matrix = confusion_matrix(actual_classes, predicted_classes, labels=all_classes)

    correct = matrix.diagonal()
    row_totals = matrix.sum(axis=1)
    # Divide only where a class has true samples; the rest stay 0.0. This gives
    # the same values as nan_to_num(0/0) without the divide RuntimeWarning.
    ratios = np.divide(
        correct,
        row_totals,
        out=np.zeros(len(all_classes), dtype=float),
        where=row_totals != 0,
    )
    class_accuracies = np.round(ratios, 3)

    return dict(zip(all_classes, class_accuracies))

In [7]:
import torch
import torch.optim as optim
import logging
import numpy as np
import math
from sklearn.metrics import balanced_accuracy_score
import random
from torch.cuda.amp import GradScaler, autocast

# # Assuming the existence of 'ema_models', 'validation_dataloader', and 'device' from your context
# # Class counts: {'CC': 146576, 'EC': 165037, 'HGSC': 272477, 'LGSC': 46346, 'MC': 79036}
# class_counts = torch.tensor([272477, 146576, 165037, 46346, 79036], dtype=torch.float32)

# # Calculate weights: Inverse of class frequencies
# weights = 1.0 / class_counts
# weights = weights / weights.sum()  # Normalize to make the sum of weights equal to 1
# weights = weights.to(device)  # Move weights to the device (CPU/GPU)

# Plain (unweighted) cross-entropy; the class-weighted variant above is commented out.
criterion = torch.nn.CrossEntropyLoss()

# Optional cap on the number of evaluated batches. None keeps the previous
# behaviour: iterate until the dataloader ends or the cell is interrupted.
max_eval_steps = None

# torch.cuda.amp.autocast() targets CUDA only. Select autocast from the
# configured device and enable it only on CUDA, so CPU runs stay in full
# precision without a CUDA-unavailable warning.
device_type = torch.device(device).type
use_amp = device_type == "cuda"

step = 0
cumulative_losses = [0.0] * len(ema_models)  # Sum of batch losses for each model
model_steps = [0] * len(ema_models)  # Number of batches seen by each model

# Predictions and true labels accumulated per model across batches
all_predictions = [[] for _ in ema_models]
all_labels = [[] for _ in ema_models]

with torch.no_grad():
    for i, (images, labels, _) in enumerate(validation_dataloader, 0):
        if max_eval_steps is not None and i >= max_eval_steps:
            break
        print(i)
        images = images.to(device)
        labels = labels.to(device)

        with torch.autocast(device_type=device_type, enabled=use_amp):
            # The same batch is scored by every EMA checkpoint.
            for model_index, ema_model in enumerate(ema_models):
                logits_per_image = ema_model(images)
                loss = criterion(logits_per_image, labels)

                # Predicted class = index of the largest logit
                predictions = torch.argmax(logits_per_image, dim=1)

                cumulative_losses[model_index] += loss.item()
                model_steps[model_index] += 1

                # Store predictions and labels for the accuracy metrics
                all_predictions[model_index].extend(predictions.cpu().numpy())
                all_labels[model_index].extend(labels.cpu().numpy())

                # Running mean of the per-batch losses of this model
                running_avg_loss = cumulative_losses[model_index] / model_steps[model_index]

                # Metrics are recomputed over everything seen so far
                current_balanced_acc_score = balanced_accuracy_score(all_labels[model_index], all_predictions[model_index])
                current_class_acc_score = class_wise_accuracy_score(all_labels[model_index], all_predictions[model_index])
                print(f"Model {model_index} | {running_avg_loss:.3f} | {current_balanced_acc_score:.3f} | {current_class_acc_score}")

        step += 1

0
Model 0 | 0.111 | 1.000 | {0: 1.0, 1: 1.0, 2: 1.0, 3: 1.0, 4: 1.0}
Model 1 | 0.111 | 1.000 | {0: 1.0, 1: 1.0, 2: 1.0, 3: 1.0, 4: 1.0}
Model 2 | 0.111 | 1.000 | {0: 1.0, 1: 1.0, 2: 1.0, 3: 1.0, 4: 1.0}
Model 3 | 0.118 | 1.000 | {0: 1.0, 1: 1.0, 2: 1.0, 3: 1.0, 4: 1.0}
Model 4 | 0.173 | 1.000 | {0: 1.0, 1: 1.0, 2: 1.0, 3: 1.0, 4: 1.0}
Model 5 | 0.538 | 0.850 | {0: 0.25, 1: 1.0, 2: 1.0, 3: 1.0, 4: 1.0}
Model 6 | 0.720 | 0.750 | {0: 0.25, 1: 1.0, 2: 1.0, 3: 0.5, 4: 1.0}
1
Model 0 | 0.084 | 1.000 | {0: 1.0, 1: 1.0, 2: 1.0, 3: 1.0, 4: 1.0}


KeyboardInterrupt: 

In [None]:
# 50 10000 tma
# Model 0 | 0.304 | 0.888 | {0: 0.811, 1: 0.902, 2: 0.823, 3: 0.938, 4: 0.963}
# Model 1 | 0.304 | 0.888 | {0: 0.811, 1: 0.902, 2: 0.823, 3: 0.938, 4: 0.963}
# Model 2 | 0.304 | 0.888 | {0: 0.811, 1: 0.902, 2: 0.823, 3: 0.938, 4: 0.963}
# Model 3 | 0.305 | 0.888 | {0: 0.811, 1: 0.902, 2: 0.823, 3: 0.938, 4: 0.963}
# Model 4 | 0.305 | 0.885 | {0: 0.811, 1: 0.902, 2: 0.811, 3: 0.938, 4: 0.963}
# Model 5 | 0.299 | 0.885 | {0: 0.811, 1: 0.896, 2: 0.817, 3: 0.938, 4: 0.963}
# Model 6 | 0.303 | 0.891 | {0: 0.817, 1: 0.896, 2: 0.829, 3: 0.957, 4: 0.957}

In [None]:
# 50 5000 tma
# Model 0 | 0.320 | 0.895 | {0: 0.811, 1: 0.927, 2: 0.848, 3: 0.944, 4: 0.944}
# Model 1 | 0.319 | 0.899 | {0: 0.811, 1: 0.927, 2: 0.86, 3: 0.951, 4: 0.944}
# Model 2 | 0.318 | 0.895 | {0: 0.793, 1: 0.927, 2: 0.86, 3: 0.951, 4: 0.944}
# Model 3 | 0.317 | 0.895 | {0: 0.787, 1: 0.927, 2: 0.866, 3: 0.951, 4: 0.944}
# Model 4 | 0.314 | 0.889 | {0: 0.787, 1: 0.915, 2: 0.841, 3: 0.951, 4: 0.951}
# Model 5 | 0.318 | 0.888 | {0: 0.75, 1: 0.915, 2: 0.872, 3: 0.951, 4: 0.951}
# Model 6 | 0.349 | 0.883 | {0: 0.756, 1: 0.921, 2: 0.872, 3: 0.914, 4: 0.951}

In [None]:
# 50 1000 tma
# Model 0 | 0.343 | 0.872 | {0: 0.829, 1: 0.89, 2: 0.701, 3: 0.969, 4: 0.969}
# Model 1 | 0.346 | 0.878 | {0: 0.866, 1: 0.89, 2: 0.701, 3: 0.963, 4: 0.969}
# Model 2 | 0.339 | 0.879 | {0: 0.86, 1: 0.909, 2: 0.713, 3: 0.963, 4: 0.951}
# Model 3 | 0.338 | 0.874 | {0: 0.799, 1: 0.909, 2: 0.774, 3: 0.938, 4: 0.951}
# Model 4 | 0.361 | 0.883 | {0: 0.762, 1: 0.915, 2: 0.841, 3: 0.938, 4: 0.957}
# Model 5 | 0.408 | 0.830 | {0: 0.524, 1: 0.878, 2: 0.854, 3: 0.938, 4: 0.957}
# Model 6 | 0.434 | 0.819 | {0: 0.463, 1: 0.878, 2: 0.854, 3: 0.944, 4: 0.957}

In [None]:
# 50 44000 all the data
# Model 0 | 0.484 | 0.818 | {0: 0.445, 1: 0.878, 2: 0.86, 3: 0.944, 4: 0.963}
# Model 1 | 0.489 | 0.812 | {0: 0.421, 1: 0.872, 2: 0.86, 3: 0.944, 4: 0.963}
# Model 2 | 0.491 | 0.812 | {0: 0.421, 1: 0.878, 2: 0.86, 3: 0.938, 4: 0.963}
# Model 3 | 0.451 | 0.820 | {0: 0.506, 1: 0.884, 2: 0.848, 3: 0.907, 4: 0.957}
# Model 4 | 0.407 | 0.869 | {0: 0.756, 1: 0.921, 2: 0.799, 3: 0.901, 4: 0.969}
# Model 5 | 0.522 | 0.804 | {0: 0.811, 1: 0.915, 2: 0.671, 3: 0.821, 4: 0.802}
# Model 6 | 0.825 | 0.694 | {0: 0.713, 1: 0.915, 2: 0.683, 3: 0.506, 4: 0.654}

In [None]:
# 72 20000 all the data
# Model 0 | 0.529 | 0.833 | {0: 0.808, 1: 0.923, 2: 0.615, 3: 0.859, 4: 0.961}
# Model 1 | 0.524 | 0.834 | {0: 0.812, 1: 0.927, 2: 0.615, 3: 0.855, 4: 0.961}
# Model 2 | 0.529 | 0.834 | {0: 0.846, 1: 0.923, 2: 0.59, 3: 0.85, 4: 0.961}
# Model 3 | 0.561 | 0.807 | {0: 0.88, 1: 0.906, 2: 0.585, 3: 0.812, 4: 0.853}
# Model 4 | 0.687 | 0.735 | {0: 0.795, 1: 0.902, 2: 0.607, 3: 0.65, 4: 0.72}
# Model 5 | 1.063 | 0.619 | {0: 0.577, 1: 0.906, 2: 0.62, 3: 0.376, 4: 0.616}
# Model 6 | 1.412 | 0.438 | {0: 0.064, 1: 0.803, 2: 0.509, 3: 0.265, 4: 0.547}