In [1]:
import pandas as pd
import os

def get_image_path(image_id: int, tiles_root: str = '../tiles_768') -> str:
    """Return the directory that holds the tiles of one image.

    Args:
        image_id: image identifier; converted with ``str`` to form the folder name.
        tiles_root: root folder containing one sub-folder per image. Defaults to
            the 768-px tile root this notebook was written against.

    Returns:
        ``os.path.join(tiles_root, str(image_id))``.
    """
    return os.path.join(tiles_root, str(image_id))

I_FOLD = 3  # cross-validation fold evaluated by this notebook
validation = pd.read_csv(f"val_fold_{I_FOLD}.csv")

# Attach each image's tile directory, derived from its image_id.
validation['tile_path'] = validation['image_id'].apply(get_image_path)
validation.head()

Unnamed: 0,image_id,label,image_width,image_height,is_tma,tile_path
0,33976,LGSC,38146,21220,False,../tiles_768/33976
1,34277,LGSC,65805,35570,False,../tiles_768/34277
2,34688,LGSC,27441,19507,False,../tiles_768/34688
3,35565,MC,2964,2964,True,../tiles_768/35565
4,37385,LGSC,3388,3388,True,../tiles_768/37385


In [2]:
import torch
import torch.nn as nn
import timm
from timm.models import VisionTransformer
from timm.models.layers import DropPath
import copy

class CustomViT(nn.Module):
    """ViT classifier with separate heads for the class token and the patch tokens.

    ``forward`` returns only the class-token logits, shape (batch, n_classes).
    ``forward_head`` additionally projects every patch token to ``embed_dim``
    through ``patch_token_head``. ``forward_features`` can optionally replace
    selected patch embeddings with a learnable ``mask_token``.
    """

    def __init__(self, n_classes=5, embed_dim=768):
        super().__init__()
        self.n_classes = n_classes
        self.embed_dim = embed_dim
        # Backbone: 384x384 input, 16x16 patches (24x24 = 576 patch tokens), 12 blocks x 12 heads.
        # Built with random weights here; trained weights are expected to arrive via
        # load_state_dict. Its own classifier head is not used by forward_head below.
        self.base_model = VisionTransformer(img_size=384, num_classes=self.n_classes, patch_size=16, embed_dim=self.embed_dim, depth=12, num_heads=12, global_pool='avg', pre_norm=True)

        # Learnable embedding substituted for masked patches in forward_features.
        self.mask_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))

        max_drop_path_rate = 0.3
        dropout_rate = 0.1

        # Stochastic depth grows linearly from 0 (first block) to max_drop_path_rate (last block).
        drop_path_rates = [x.item() for x in torch.linspace(0, max_drop_path_rate, len(self.base_model.blocks))]

        # Overwrite each block's regularisation modules with the rates chosen above.
        for i, block in enumerate(self.base_model.blocks):
            block.drop_path1 = DropPath(drop_prob=drop_path_rates[i])
            block.drop_path2 = DropPath(drop_prob=drop_path_rates[i])
            block.attn.attn_drop = nn.Dropout(p=dropout_rate, inplace=False)
            block.attn.proj_drop = nn.Dropout(p=dropout_rate, inplace=False)
            block.mlp.drop1 = nn.Dropout(p=dropout_rate, inplace=False)
            block.mlp.drop2 = nn.Dropout(p=dropout_rate, inplace=False)
        self.head_dropout = nn.Dropout(p=dropout_rate, inplace=False)

        self.class_token_head = nn.Linear(self.embed_dim, self.n_classes)
        self.patch_token_head = nn.Linear(self.embed_dim, self.embed_dim)

    def forward_features(self, x, mask=None):
        """Run the backbone and return normalised token features.

        Args:
            x: image batch fed to ``base_model.patch_embed``
                (presumably (B, 3, 384, 384) -- TODO confirm).
            mask: optional boolean tensor of shape (B, num_patches). True entries
                replace the corresponding patch embedding with ``mask_token``.
                The mask covers patch tokens only; the class token is never masked.

        Returns:
            Tensor (B, 1 + num_patches, embed_dim) with the class token at index 0
            (when the backbone has one) followed by the patch tokens.
        """
        # Patch embeddings only; the class token is prepended further down.
        x = self.base_model.patch_embed(x)

        if mask is not None:
            # x has no class-token position yet, so the mask is applied as-is. Prepending a
            # class-token column here would give it num_patches + 1 entries and make
            # torch.where fail with a shape mismatch.
            mask = mask.to(device=x.device, dtype=torch.bool)
            mask_tokens = self.mask_token.expand(x.size(0), -1, -1)
            x = torch.where(mask.unsqueeze(-1), mask_tokens, x)

        to_cat = []
        if self.base_model.cls_token is not None:
            to_cat.append(self.base_model.cls_token.expand(x.shape[0], -1, -1))
        x = torch.cat(to_cat + [x], dim=1)

        x = self.base_model.pos_drop(x + self.base_model.pos_embed)
        x = self.base_model.norm_pre(x)
        x = self.base_model.blocks(x)
        x = self.base_model.norm(x)

        # All tokens are returned (class token included); forward_head does the split.
        return x

    def forward_head(self, x):
        """Split token features and project them through the two heads.

        Args:
            x: output of ``forward_features``; token 0 is treated as the class token.

        Returns:
            dict with ``"class_token_output"`` (B, n_classes) and
            ``"patch_token_output"`` (B, num_tokens - 1, embed_dim).
        """
        class_token, patch_tokens = x[:, 0], x[:, 1:]

        # The same dropout module regularises both heads' inputs.
        class_token = self.head_dropout(class_token)
        patch_tokens = self.head_dropout(patch_tokens)

        class_token_output = self.class_token_head(class_token)
        patch_token_output = self.patch_token_head(patch_tokens)

        return {"class_token_output": class_token_output, "patch_token_output": patch_token_output}

    def forward(self, x):
        """Return class-token logits (B, n_classes); the patch-token output is discarded."""
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x['class_token_output']

def load_model(model_location, strict=False):
    """Build a CustomViT, load weights from ``model_location`` and return it in eval mode.

    Args:
        model_location: path to a file produced by ``torch.save`` holding a state dict.
        strict: forwarded to ``load_state_dict``. Defaults to False (the previous
            behaviour). Any missing/unexpected keys are logged, so a mismatched
            checkpoint can no longer silently leave layers randomly initialised.

    Returns:
        The model on CUDA when available (else CPU), switched to eval mode.

    Raises:
        RuntimeError: from ``load_state_dict`` when ``strict`` is True and keys differ.
    """
    import logging

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = CustomViT(n_classes=5, embed_dim=768)

    # torch.load unpickles the file -- only point this at checkpoints you trust.
    state_dict = torch.load(model_location, map_location=device)
    incompatible = model.load_state_dict(state_dict, strict=strict)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        logging.warning(
            f'{model_location}: {len(incompatible.missing_keys)} missing keys '
            f'{incompatible.missing_keys[:5]}, {len(incompatible.unexpected_keys)} '
            f'unexpected keys {incompatible.unexpected_keys[:5]}'
        )

    model = model.to(device)
    model.eval()
    return model

In [3]:
import torchvision.transforms as transforms
import numpy as np
import cv2  # Required for CLAHE

# Class index <-> subtype label; the index is the logit column used by argmax below.
integer_to_label = dict(enumerate(['HGSC', 'CC', 'EC', 'LGSC', 'MC']))

# Inverse table, derived from the forward one so the two can never drift apart.
label_to_integer = {label: index for index, label in integer_to_label.items()}

def apply_clahe_to_color_image(img):
    """Apply CLAHE independently to each colour channel of an image.

    Args:
        img: PIL image (any mode; converted to RGB first) or an RGB uint8 array.

    Returns:
        PIL RGB image whose R, G and B channels were each equalised with
        ``cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))``.
    """
    # Imported locally: the module-level PIL import only happens in a later cell,
    # so referencing the global ``Image`` here could raise NameError.
    from PIL import Image

    # Normalise RGBA / grayscale / palette inputs to exactly three channels.
    if isinstance(img, Image.Image):
        img = img.convert('RGB')
    rgb = np.asarray(img)

    # Each channel is equalised on its own, so channel order does not matter and
    # no RGB<->BGR conversion is needed.
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    equalized = cv2.merge([clahe.apply(channel) for channel in cv2.split(rgb)])

    return Image.fromarray(equalized)

# val_transform = transforms.Compose([
#     transforms.Resize(448),
#     # transforms.Lambda(lambda img: apply_clahe_to_color_image(img)),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[
#         0.48145466,
#         0.4578275,
#         0.40821073
#     ], std=[
#         0.26862954,
#         0.26130258,
#         0.27577711
#     ]),
# ])

val_transform = transforms.Compose(
    [
        # int size: the shorter edge becomes 384 px and the aspect ratio is kept
        # (NOTE(review): assumes square tiles so the result is 384x384 -- confirm).
        transforms.Resize(384),
        transforms.ToTensor(),  # PIL -> CHW float tensor (uint8 scaled to [0, 1])
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),  # (x - 0.5) / 0.5: [0, 1] -> [-1, 1]
    ]
)

In [4]:
import logging
import sys

# Route INFO+ from the root logger into a per-fold log file; echo only ERROR+ to stderr.
logger = logging.getLogger()

# Remove AND close handlers left over from a previous run of this cell, so messages
# are not duplicated and old log files are not kept open.
for handler in logger.handlers[:]:
    logger.removeHandler(handler)
    handler.close()

logger.setLevel(logging.INFO)

# FileHandler raises FileNotFoundError when the directory does not exist yet.
log_dir = os.path.join('validation_logs', 'vit_base_pretrained')
os.makedirs(log_dir, exist_ok=True)
file_handler = logging.FileHandler(os.path.join(log_dir, f'fold_{I_FOLD}.txt'))
file_handler.setLevel(logging.INFO)
logger.addHandler(file_handler)

stream_handler = logging.StreamHandler(sys.stderr)
stream_handler.setLevel(logging.ERROR)  # Only ERROR and CRITICAL messages reach stderr
logger.addHandler(stream_handler)

In [5]:
from collections import defaultdict

def find_winner(graph, source, visited, finished=None):
    """Return True if a directed cycle is reachable from ``source`` in ``graph``.

    Used by ranked pairs: after tentatively locking ``winner -> loser`` the caller
    asks whether a cycle is now reachable from ``loser``. Since the graph was
    acyclic before, any new cycle must use that edge and therefore pass through
    ``loser``.

    Args:
        graph: mapping node -> set of successor nodes. Nodes missing from the
            mapping are treated as having no outgoing edges.
        source: node to start the depth-first search from.
        visited: nodes on the current DFS path (pass an empty set). Reaching one
            of them again means a back edge, i.e. a cycle.
        finished: nodes already fully explored without finding a cycle; they are
            skipped so each node is expanded at most once. Created if None.

    Returns:
        True if a cycle is reachable from ``source``, otherwise False.
    """
    if finished is None:
        finished = set()
    if source not in graph or source in finished:
        return False

    visited.add(source)
    for target in graph[source]:
        # A successor already on the current path closes a cycle.
        if target in visited or find_winner(graph, target, visited, finished):
            return True
    visited.remove(source)
    finished.add(source)
    return False

def ranked_pairs_winner(image_logits):
    """Ranked-pairs style vote where every row of ``image_logits`` is one ballot.

    For each ordered class pair (a, b), the margin is
    (#ballots with logit_a > logit_b) - (#remaining ballots). Ballots with equal
    logits therefore count against ``a``. Pairs are locked as edges ``a -> b`` in
    descending margin order. An edge is dropped again whenever
    ``find_winner(locked, b, set())`` returns True (intended as a cycle check).
    The winner is the lowest-index class with no incoming locked edge. If every
    class has one, it falls back to the argmax of the summed logits.
    """
    num_voters, num_classes = image_logits.shape

    ordered_pairs = [(a, b) for a in range(num_classes) for b in range(num_classes) if a != b]
    margins = {}
    for a, b in ordered_pairs:
        wins = (image_logits[:, a] > image_logits[:, b]).sum()
        margins[(a, b)] = 2 * wins - num_voters  # == wins - (num_voters - wins)

    locked = {c: set() for c in range(num_classes)}
    for winner, loser in sorted(margins, key=margins.get, reverse=True):
        locked[winner].add(loser)
        if find_winner(locked, loser, set()):
            locked[winner].discard(loser)

    has_incoming = set().union(*locked.values())
    for candidate in range(num_classes):
        if candidate not in has_incoming:
            return candidate

    return torch.argmax(image_logits.sum(dim=0)).item()

def star_voting_winner(image_logits):
    """STAR (Score Then Automatic Runoff) vote; every row of ``image_logits`` is a ballot.

    1. Score: sum each class's logits over all ballots.
    2. Runoff: between the two highest-scoring classes, count the ballots that
       strictly prefer each finalist. Ballots giving both finalists the same
       logit express no preference and are not counted.
    3. A runoff tie goes to the finalist with the higher total score, as in STAR.

    Args:
        image_logits: 2-D tensor (num_ballots, num_classes).

    Returns:
        Index of the winning class. With a single class, that class (0).
    """
    _, num_classes = image_logits.shape
    if num_classes == 1:
        # torch.topk(..., 2) would fail; the only class wins trivially.
        return 0

    total_scores = torch.sum(image_logits, axis=0)
    top_two = torch.topk(total_scores, 2).indices
    first, second = top_two[0], top_two[1]  # first has the higher total score

    prefers_first = torch.sum(image_logits[:, first] > image_logits[:, second])
    prefers_second = torch.sum(image_logits[:, second] > image_logits[:, first])

    if prefers_first >= prefers_second:
        return first.item()
    return second.item()

def instant_runoff_winner(image_logits):
    """Instant-runoff vote where every row of ``image_logits`` is one ballot.

    A ballot ranks classes by descending logit. In each round, a ballot counts
    for its highest-ranked class that is still in the race. A class holding more
    than half of the ballots wins. Otherwise the class with the fewest first
    preferences is eliminated and the round repeats. On ties, the first such
    class in the iteration order of the remaining set is eliminated.
    """
    num_voters, num_classes = image_logits.shape
    remaining = set(range(num_classes))
    majority_threshold = num_voters / 2

    while True:
        tally = torch.zeros(num_classes)
        for ballot in image_logits:
            preference_order = torch.argsort(ballot, descending=True)
            top_choice = next((c for c in preference_order if c.item() in remaining), None)
            if top_choice is not None:
                tally[top_choice] += 1

        if torch.any(tally > majority_threshold):
            return torch.argmax(tally).item()

        contenders = list(remaining)
        weakest = contenders[torch.argmin(tally[contenders]).item()]
        remaining.remove(weakest)

In [6]:
from PIL import Image
import random
from sklearn.metrics import balanced_accuracy_score

device = "cuda" if torch.cuda.is_available() else "cpu"

CHECKPOINT_STEPS = range(30000, -2000, -2000)  # 30000, 28000, ..., 2000, 0
TILES_PER_IMAGE = 32        # tiles sampled per image
INFERENCE_BATCH_SIZE = 32   # tiles per forward pass
TILE_SAMPLING_SEED = 0      # fixed so every checkpoint is scored on the same tiles
LOG_SEPARATOR = '-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_'


def _load_tiles(tile_dir):
    """Return up to TILES_PER_IMAGE sampled PNG tiles from ``tile_dir`` as one CPU tensor.

    Each tile goes through ``val_transform``; the results are stacked along a new
    batch dimension. Returns None when the directory contains no PNG files.
    """
    # Sorting removes the dependence on os.listdir's arbitrary order, and a private
    # Random instance keeps the draw reproducible without reseeding global state.
    png_files = sorted(f for f in os.listdir(tile_dir) if f.lower().endswith('.png'))
    if not png_files:
        return None
    sampled_files = random.Random(TILE_SAMPLING_SEED).sample(png_files, min(TILES_PER_IMAGE, len(png_files)))

    tiles = []
    for file_name in sampled_files:
        # `with` closes the file; convert('RGB') guarantees 3 channels for RGBA/palette PNGs.
        with Image.open(os.path.join(tile_dir, file_name)) as tile_image:
            tiles.append(val_transform(tile_image.convert('RGB')))
    return torch.stack(tiles)


def _tile_logits(model, tiles):
    """Run ``model`` over ``tiles`` in mini-batches and return all logits on the CPU."""
    outputs = [
        model(tiles[start:start + INFERENCE_BATCH_SIZE].to(device)).cpu()
        for start in range(0, len(tiles), INFERENCE_BATCH_SIZE)
    ]
    return torch.cat(outputs, dim=0)


def _plurality_vote(image_logits):
    """Class that is the per-tile argmax most often (ties -> lowest index)."""
    per_tile = torch.argmax(image_logits, dim=1)
    return torch.bincount(per_tile, minlength=image_logits.shape[1]).argmax().item()


def _logit_sum_vote(image_logits):
    """Argmax of the logits summed over tiles."""
    return image_logits.sum(dim=0).argmax().item()


def _prob_sum_vote(image_logits):
    """Argmax of the softmax probabilities summed over tiles."""
    return image_logits.softmax(dim=1).sum(dim=0).argmax().item()


def _log_prob_sum_vote(image_logits):
    """Argmax of summed log-probabilities (log_softmax avoids log(0) = -inf)."""
    return torch.log_softmax(image_logits, dim=1).sum(dim=0).argmax().item()


def _one_minus_log_prob_sum_vote(image_logits):
    """Argmin over classes of sum_tiles log(1 - p); log1p(-p) is the accurate form of log(1 - p)."""
    return torch.log1p(-image_logits.softmax(dim=1)).sum(dim=0).argmin().item()


def _score_voting_rule(metric_name, vote_fn, all_logits, all_labels):
    """Predict each image with ``vote_fn``, then log and return the balanced accuracy."""
    predictions = [integer_to_label[vote_fn(image_logits)] for image_logits in all_logits]
    accuracy = balanced_accuracy_score(all_labels, predictions)
    logging.info(f'{metric_name}: {accuracy}')
    return accuracy


with torch.no_grad():
    for step in CHECKPOINT_STEPS:
        # Use the active fold so the checkpoint matches the validation split loaded above.
        model = load_model(f'vit_base_pretrained_models/fold_{I_FOLD}/epoch_0_step_{step}.pth')

        image_ids = []
        logits = []
        labels = []
        for idx, row in validation.iterrows():
            if idx % 10 == 0:
                logging.info(f'idx: {idx}')
            tiles = _load_tiles(row['tile_path'])
            if tiles is None:
                logging.error(f"no PNG tiles in {row['tile_path']}; skipping image_id {row['image_id']}")
                continue
            image_ids.append(row['image_id'])
            logits.append(_tile_logits(model, tiles))
            labels.append(row['label'])

        logging.info(LOG_SEPARATOR)
        logging.info(f'step: {step}')

        plurality_accuracy = _score_voting_rule('plurality_accuracy', _plurality_vote, logits, labels)
        logit_sum_accuracy = _score_voting_rule('logit_sum_accuracy', _logit_sum_vote, logits, labels)
        prob_sum_accuracy = _score_voting_rule('prob_sum_accuracy', _prob_sum_vote, logits, labels)
        log_prob_sum_accuracy = _score_voting_rule('log_prob_sum_accuracy', _log_prob_sum_vote, logits, labels)
        one_minus_log_prob_sum_accuracy = _score_voting_rule(
            'one_minus_log_prob_sum_accuracy', _one_minus_log_prob_sum_vote, logits, labels)
        rp_accuracy = _score_voting_rule('rp_accuracy', ranked_pairs_winner, logits, labels)
        star_accuracy = _score_voting_rule('star_accuracy', star_voting_winner, logits, labels)
        instant_runoff_accuracy = _score_voting_rule('instant_runoff_accuracy', instant_runoff_winner, logits, labels)

        logging.info(LOG_SEPARATOR)
