In [1]:
from sklearn.model_selection import train_test_split
import pandas as pd

# Tile-level annotations: one row per tile with its image path and mask label.
whole_df = pd.read_csv('../tile_masks_768.csv')

# The CSV sits one directory up; prefix every stored path with '../' so it resolves
# from this notebook's working directory (presumably the paths are relative to the
# CSV's folder — confirm if tiles fail to load).
whole_df['image_path'] = whole_df['image_path'].map(lambda rel_path: f'../{rel_path}')

# 80/20 train/validation split with a fixed seed for reproducibility.
train, validation = train_test_split(whole_df, test_size=0.2, random_state=42)

In [2]:
from PIL import Image
import torch
import torch.nn as nn
import timm
from timm.models.layers import DropPath
import copy

MODEL_NAME = 'even-more-balanced'

device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "timm/eva02_base_patch14_448.mim_in22k_ft_in22k_in1k"

print(f"Using device {device} and model {model_name}")

model = timm.create_model(model_name, pretrained=True)

# Regularisation: stochastic depth ramps linearly from 0 (first block) to
# max_drop_path_rate (last block); ordinary dropout is the same in every block.
max_drop_path_rate = 0.3
dropout_rate = 0.1

drop_path_rates = torch.linspace(0, max_drop_path_rate, len(model.blocks)).tolist()

# NOTE(review): these assignments replace submodules by attribute name. If a timm
# block variant does not read one of these attributes in forward(), the assignment
# silently has no effect — confirm against the installed timm Eva block.
for block, block_rate in zip(model.blocks, drop_path_rates):
    block.drop_path1 = DropPath(drop_prob=block_rate)
    block.drop_path2 = DropPath(drop_prob=block_rate)
    block.attn.attn_drop = nn.Dropout(p=dropout_rate, inplace=False)
    block.attn.proj_drop = nn.Dropout(p=dropout_rate, inplace=False)
    block.mlp.drop1 = nn.Dropout(p=dropout_rate, inplace=False)
    block.mlp.drop2 = nn.Dropout(p=dropout_rate, inplace=False)

# Fresh 3-way classification head (tumor / stroma / necrosis).
model.head = nn.Linear(model.head.in_features, 3)

# To resume from a checkpoint, load it here, e.g.:
#   model.load_state_dict(torch.load('vit_models_1/epoch_249_step_3500.pth', map_location=device), strict=False)

model = model.to(device)

# Exponential-moving-average copy of the weights; the training loop updates it with
# this decay and checkpoints it.
ema_decay = 0.999
ema_model = copy.deepcopy(model).to(device)

Using device cuda and model timm/eva02_base_patch14_448.mim_in22k_ft_in22k_in1k


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [3]:
import os
from PIL import Image
from torch.utils.data import Dataset
import random

# Class index <-> class name. The index order is the order of the model's output
# logits and of the integer targets fed to the loss.
integer_to_label = dict(enumerate(['tumor', 'stroma', 'necrosis']))

# Inverse lookup, derived from the forward mapping so the two can never disagree.
label_to_integer = {name: index for index, name in integer_to_label.items()}

class ImageDataset(Dataset):
    """Tile-classification dataset backed by a DataFrame.

    Each row supplies an ``image_path`` and a string ``mask_label``. Items are
    ``(image, class_index)`` pairs, where ``class_index`` is looked up in the
    module-level ``label_to_integer`` mapping.

    Args:
        dataframe: DataFrame with ``image_path`` and ``mask_label`` columns.
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform
        # (path, label) pairs materialised once so __getitem__ is a cheap list lookup.
        # Reading the two columns directly avoids iterrows(), which builds a Series
        # per row and is slow on tens of thousands of rows.
        self.all_images = list(zip(dataframe['image_path'], dataframe['mask_label']))

    def __len__(self):
        return len(self.all_images)

    def __getitem__(self, idx):
        image_path, label = self.all_images[idx]
        # Force 3-channel RGB: the normalisation used in this notebook has 3-channel
        # statistics, so a grayscale, RGBA or palette PNG would otherwise fail or be
        # mis-normalised. convert() loads the pixel data, so the file can be closed
        # immediately instead of leaking a handle per sample in every worker.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label_to_integer[label]

In [4]:
# Tiles per class in the training split — makes the class imbalance visible.
train.groupby(by='mask_label').count()

Unnamed: 0_level_0,image_path
mask_label,Unnamed: 1_level_1
necrosis,505
stroma,2438
tumor,39464


In [5]:
from torch.utils.data import DataLoader, WeightedRandomSampler
import torchvision.transforms as transforms

BATCH_SIZE = 8

# Per-channel normalisation statistics shared by the train and validation pipelines
# (these are the OpenAI CLIP mean/std values).
NORM_MEAN = [0.48145466, 0.4578275, 0.40821073]
NORM_STD = [0.26862954, 0.26130258, 0.27577711]

# NOTE(review): Resize(448) scales the *shorter* side to 448. This only yields the
# square 448x448 input the model expects if the tiles are square — confirm.
train_transform = transforms.Compose([
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomAffine(degrees=45, translate=(0.25, 0.25), scale=(1, 2), shear=(-30, 30, -30, 30)),
    transforms.Resize(448),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
    # Three independent erasing passes, each applied with probability 0.5.
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False),
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False),
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False),
])

val_transform = transforms.Compose([
    transforms.Resize(448),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

train_dataset = ImageDataset(dataframe=train, transform=train_transform)

# Per-class tile counts, ordered by class index, computed from the actual training
# split. Hard-coded numbers silently go stale when the CSV or the split changes.
label_counts = train['mask_label'].value_counts()
class_counts = [int(label_counts.get(integer_to_label[idx], 0)) for idx in sorted(integer_to_label)]
num_samples = sum(class_counts)

# Inverse-frequency weights, so each class is drawn roughly equally often.
# A class absent from the split gets weight 0 instead of raising ZeroDivisionError.
class_weights = [num_samples / count if count else 0.0 for count in class_counts]

# One weight per sample, taken from its class.
sample_weights = [class_weights[label_to_integer[label]] for _, label in train_dataset.all_images]

# Sample with replacement; one "epoch" draws as many tiles as the split contains.
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler, num_workers=3)

In [6]:
import logging
import sys

import os

# Route INFO+ records to a per-run log file and only ERROR+ to stderr, so the
# per-step training-loss messages do not flood the notebook output.
logger = logging.getLogger()

# Remove handlers left over from earlier runs of this cell (or installed by the
# environment). Close them too: a removed-but-open FileHandler keeps its log file
# open, leaking one handle per re-run. StreamHandler.close() does not close the
# underlying stream, so closing an stderr handler is safe.
for handler in logger.handlers[:]:
    logger.removeHandler(handler)
    handler.close()

logger.setLevel(logging.INFO)

# FileHandler raises FileNotFoundError when the directory is missing.
os.makedirs('logs', exist_ok=True)
file_handler = logging.FileHandler(f'logs/{MODEL_NAME}.txt')
file_handler.setLevel(logging.INFO)
logger.addHandler(file_handler)

stream_handler = logging.StreamHandler(sys.stderr)
stream_handler.setLevel(logging.ERROR)  # Only ERROR and CRITICAL reach stderr
logger.addHandler(stream_handler)

In [7]:
import torch
import torch.optim as optim
import logging
import numpy as np
import math
from sklearn.metrics import balanced_accuracy_score
import random
from torch.cuda.amp import GradScaler, autocast

initial_lr = 1e-5
final_lr = 1e-6
num_epochs = 10000


def learning_rate(step, warmup_steps=2000, max_steps=20000, peak_lr=None, min_lr=None):
    """Learning rate at ``step`` under linear warmup followed by cosine decay.

    The rate rises linearly from 0 to ``peak_lr`` over ``warmup_steps`` steps.
    It then follows a half-cosine from ``peak_lr`` down to ``min_lr`` at
    ``max_steps``, and stays at ``min_lr`` afterwards.

    Args:
        step: Current optimizer step (0-based).
        warmup_steps: Length of the linear warmup phase.
        max_steps: Step at which the cosine decay reaches ``min_lr``.
        peak_lr: Rate at the end of warmup. Defaults to the module-level ``initial_lr``
            (read at call time, as before).
        min_lr: Final rate. Defaults to the module-level ``final_lr``.

    Returns:
        The learning rate as a float.
    """
    if peak_lr is None:
        peak_lr = initial_lr
    if min_lr is None:
        min_lr = final_lr

    if step < warmup_steps:
        # max(1, ...) guards against warmup_steps == 0.
        return peak_lr * (float(step) / float(max(1, warmup_steps)))
    if step < max_steps:
        progress = float(step - warmup_steps) / float(max(1, max_steps - warmup_steps))
        cos_component = 0.5 * (1 + math.cos(math.pi * progress))
        return min_lr + (peak_lr - min_lr) * cos_component
    return min_lr

# Function to calculate weighted accuracy
def weighted_accuracy(true_labels, predictions, class_weights):
    correct = 0
    total_weight = 0

    for label, pred in zip(true_labels, predictions):
        if label == pred:
            correct += class_weights[label]
        total_weight += class_weights[label]

    return correct / total_weight

def update_ema_variables(model, ema_model, alpha, global_step):
    """Fold ``model``'s current weights into ``ema_model`` as an exponential moving average.

    Every parameter is updated as ``ema = alpha * ema + (1 - alpha) * param``.
    Buffers (non-parameter module state) are copied directly from ``model``, so the
    EMA model does not keep the stale buffers it had when it was cloned. Copying also
    works for integer buffers, which an in-place float blend would reject.

    Args:
        model: The model being trained (source of the new values).
        ema_model: Same-architecture model holding the averages; updated in place.
        alpha: Decay factor; larger values make the average move more slowly.
        global_step: Current training step. Not used by this update; kept so
            existing call sites continue to work.
    """
    with torch.no_grad():
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            ema_param.mul_(alpha).add_(param, alpha=1 - alpha)
        for ema_buffer, buffer in zip(ema_model.buffers(), model.buffers()):
            ema_buffer.copy_(buffer)

import os

scaler = GradScaler()

optimizer = optim.AdamW(model.parameters(), lr=initial_lr, weight_decay=5e-2)

# Unweighted cross-entropy with label smoothing. Class imbalance is handled by the
# WeightedRandomSampler feeding train_dataloader rather than by loss weights.
criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)

best_val_accuracy = 0.0
step = 0

# Training length in optimizer steps. It is also passed to the LR schedule, so the
# run stops once the cosine decay has bottomed out and final.pth is written. Before,
# training kept going until num_epochs ran out or the kernel was interrupted.
total_steps = 20000
checkpoint_every = 2000
checkpoint_dir = f'eva02_base_models/{MODEL_NAME}'
os.makedirs(checkpoint_dir, exist_ok=True)  # torch.save does not create directories

training_done = False
for epoch in range(num_epochs):
    model.train()

    for images, labels in train_dataloader:
        images = images.to(device)
        labels = labels.to(device)

        # Warmup + cosine-decay schedule, set explicitly on every param group each step.
        lr = learning_rate(step, max_steps=total_steps)
        for g in optimizer.param_groups:
            g['lr'] = lr

        optimizer.zero_grad()

        # Mixed-precision forward; GradScaler scales the loss for the backward pass
        # and skips the optimizer step if non-finite gradients appear.
        with autocast():
            outputs = model(images)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        update_ema_variables(model, ema_model, ema_decay, step)

        logging.info('[%d, %5d] loss: %.3f' % (epoch + 1, step, loss.item()))

        if step % checkpoint_every == 0:
            ema_model.eval()
            torch.save(ema_model.state_dict(), f'{checkpoint_dir}/epoch_{epoch}_step_{step}.pth')
            logging.info(f'Model saved after epoch {epoch} and step {step}')

        if step == total_steps:
            torch.save(ema_model.state_dict(), f'{checkpoint_dir}/final.pth')
            training_done = True
            break

        step += 1

    if training_done:
        break


KeyboardInterrupt



In [None]:
# Evaluate the EMA weights on the held-out validation tiles.
#
# `validation` is split from tile_masks_768.csv, whose only non-label column is
# `image_path` (see the count table above), and the model has
# len(integer_to_label) == 3 outputs. The previous version of this cell read
# `tile_path` / `label` / `is_tma` columns and accumulated into 5-class buffers, so
# it could not run against this data or model. This version scores every validation
# tile exactly once, deterministically (no random tile sampling).
ema_model.eval()

val_dataset = ImageDataset(dataframe=validation, transform=val_transform)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=3)

val_preds = []
val_labels = []

with torch.no_grad():
    for images, labels in val_dataloader:
        logits = ema_model(images.to(device))
        val_preds.extend(logits.argmax(dim=1).cpu().tolist())
        val_labels.extend(labels.tolist())

# Balanced accuracy is the mean per-class recall, so the dominant tumor class
# cannot hide poor stroma/necrosis performance.
accuracy = balanced_accuracy_score(val_labels, val_preds)
print(f'Validation balanced accuracy: {accuracy:.4f}')