In [2]:
import pandas as pd
import os

def get_image_path(image_id:int):
    """Return the tile directory path for ``image_id`` under ``../tiles_768``.

    The id is converted with ``str`` and joined onto the relative tiles root;
    the path is not checked for existence here.
    """
    tile_dir_name = str(image_id)
    return os.path.join('../tiles_768', tile_dir_name)

# Index of the cross-validation fold this notebook run works on; it selects the
# train/validation CSV pair below and is reused by later cells for output paths.
I_FOLD = 3

train = pd.read_csv(f"train_fold_{I_FOLD}.csv")
validation = pd.read_csv(f"val_fold_{I_FOLD}.csv")

# Attach the per-image tile directory to both splits.
train['tile_path'] = train['image_id'].apply(get_image_path)
validation['tile_path'] = validation['image_id'].apply(get_image_path)

train.head()

Unnamed: 0,image_id,label,image_width,image_height,is_tma,tile_path
0,4,HGSC,23785,20008,False,../tiles_768/4
1,66,LGSC,48871,48195,False,../tiles_768/66
2,91,HGSC,3388,3388,True,../tiles_768/91
3,281,LGSC,42309,15545,False,../tiles_768/281
4,286,EC,37204,30020,False,../tiles_768/286


In [4]:
"""
FOR FOLD_3 I USE DROP PATH RATE = 0.5
"""

# import json

# # Open the JSON file for reading
# with open('./convnextv2_base_config/config.json', 'r') as file:
#     data = json.load(file)
    
# data['drop_path_rate'] = 0.5

# with open('./convnextv2_base_config/config.json', 'w') as file:
#     json.dump(data, file, indent=4)

'\nFOR FOLD_3 I USE DROP PATH RATE = 0.5\n'

In [5]:
from PIL import Image
import torch
import torch.nn as nn
from transformers import ConvNextV2ForImageClassification

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# model_name = "facebook/convnextv2-base-22k-384"
# print(f"Using device {device} and model {model_name}")

# Load the ConvNeXt V2 classifier from the local directory.
model = ConvNextV2ForImageClassification.from_pretrained('./convnextv2_base_config')

# Swap the classification head for a fresh 5-way linear layer that keeps the
# input width of the head it replaces.
model.classifier = nn.Linear(in_features=model.classifier.in_features, out_features=5)

# state_dict = torch.load('vit_models_1/epoch_249_step_3500.pth', map_location=device)
# model.load_state_dict(state_dict, strict=False)
model = model.to(device)

In [6]:
import os
import random

from PIL import Image
from torch.utils.data import Dataset

# Class index -> label string; the index is the position in the tuple below.
integer_to_label = dict(enumerate(('HGSC', 'CC', 'EC', 'LGSC', 'MC')))

# Inverse lookup (label string -> class index), derived from the mapping above
# so the two tables cannot drift apart.
label_to_integer = {name: index for index, name in integer_to_label.items()}

class ImageDataset(Dataset):
    """Dataset that yields one randomly chosen PNG tile per dataframe row.

    Rows whose ``tile_path`` is not an existing directory are dropped at
    construction time, so ``len(dataset)`` can be smaller than
    ``len(dataframe)``. Every ``__getitem__`` call draws a fresh random tile
    from the row's directory, so the same index can return different images.

    Args:
        dataframe: frame with ``tile_path``, ``label`` and ``image_id`` columns.
        transform: optional callable applied to the loaded PIL image.
    """

    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform
        self.folder_paths = []
        self.labels = []
        self.image_ids = []
        for index, row in dataframe.iterrows():
            folder_path = row['tile_path']
            label = row['label']
            image_id = row['image_id']
            if os.path.isdir(folder_path):  # Keep only rows whose tile directory exists
                self.folder_paths.append(folder_path)
                self.labels.append(label)
                self.image_ids.append(image_id)

    def __len__(self):
        """Number of rows that survived the directory check in ``__init__``."""
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Return ``(image, class_index, image_id)`` for a random tile of row ``idx``.

        Raises:
            FileNotFoundError: if the row's directory contains no ``.png`` file.
        """
        folder_path = self.folder_paths[idx]

        # Filter to PNG files first and pick uniformly among them. Retrying random
        # indices over the unfiltered listing never terminates when the folder
        # holds files but none of them is a PNG.
        tile_names = [name for name in os.listdir(folder_path) if name.lower().endswith('.png')]
        if not tile_names:
            raise FileNotFoundError(f"No PNG tiles found in {folder_path!r}")
        tile_path = os.path.join(folder_path, random.choice(tile_names))

        # Decode inside the context manager so the file handle is released right
        # away; convert() forces the lazy load and guarantees three channels,
        # which the 3-channel Normalize used by this notebook's transforms needs.
        with Image.open(tile_path) as tile:
            image = tile.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)
        return image, label_to_integer[self.labels[idx]], self.image_ids[idx]


In [7]:
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

BATCH_SIZE = 8

# Training pipeline: colour and geometric augmentation on the PIL tile, resize,
# tensor conversion and normalization, then three independent random-erasing
# passes (each applied with probability 0.5) on the normalized tensor.
train_transform = transforms.Compose(
    [
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomAffine(degrees=45, translate=(0.25, 0.25), scale=(1, 2), shear=(-30, 30, -30, 30)),
        transforms.Resize(384),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
    + [
        transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False)
        for _ in range(3)
    ]
)

# Validation pipeline: the deterministic part of the training pipeline only.
val_transform = transforms.Compose(
    [
        transforms.Resize(384),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

train_dataset = ImageDataset(dataframe=train, transform=train_transform)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=3,
)

In [8]:
import logging
import sys

# Get the root logger
logger = logging.getLogger()

# Remove AND close handlers left over from a previous run of this cell, so that
# re-running it neither duplicates log lines nor leaks the open log file.
for handler in logger.handlers[:]:
    logger.removeHandler(handler)
    handler.close()

# Set the logging level
logger.setLevel(logging.INFO)

# Create a FileHandler and add it to the logger. FileHandler does not create
# missing parent directories, so make sure the log folder exists first.
log_path = f'logs/convnextv2_base/fold_{I_FOLD}.txt'
os.makedirs(os.path.dirname(log_path), exist_ok=True)
file_handler = logging.FileHandler(log_path)
file_handler.setLevel(logging.INFO)
logger.addHandler(file_handler)

# Create a StreamHandler for stderr and add it to the logger
stream_handler = logging.StreamHandler(sys.stderr)
stream_handler.setLevel(logging.ERROR)  # Only log ERROR and CRITICAL messages to stderr
logger.addHandler(stream_handler)

In [9]:
# Count rows per label to inspect the class balance of the training fold.
train.groupby('label').count()

Unnamed: 0_level_0,image_id,image_width,image_height,is_tma,tile_path
label,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
CC,79,79,79,79,79
EC,100,100,100,100,100
HGSC,178,178,178,178,178
LGSC,37,37,37,37,37
MC,37,37,37,37,37


In [None]:
import torch
import torch.optim as optim
import logging
import numpy as np
import math
from sklearn.metrics import balanced_accuracy_score
import random
import copy

# Peak and floor of the learning-rate schedule, and the epoch bound of the
# training loop.
initial_lr, final_lr = 1e-5, 1e-6
num_epochs = 10000

def learning_rate(step, warmup_steps=500, max_steps=20000):
    """Return the learning rate to use at optimizer step ``step``.

    The schedule has three phases:
      * ``step < warmup_steps``: linear ramp from 0 up to ``initial_lr``;
      * ``warmup_steps <= step < max_steps``: cosine decay from ``initial_lr``
        down to ``final_lr``;
      * ``step >= max_steps``: constant ``final_lr``.
    """
    if step < warmup_steps:
        return initial_lr * (float(step) / float(max(1, warmup_steps)))
    if step >= max_steps:
        return final_lr

    # Fraction of the decay phase completed so far, in [0, 1).
    decay_span = float(max(1, max_steps - warmup_steps))
    progress = (float(step - warmup_steps) / decay_span)
    cos_component = 0.5 * (1 + math.cos(math.pi * progress))
    return final_lr + (initial_lr - final_lr) * cos_component

# Function to calculate weighted accuracy
def weighted_accuracy(true_labels, predictions, class_weights):
    """Return the class-weighted fraction of correct predictions.

    Each sample contributes ``class_weights[true_label]`` to the denominator,
    and to the numerator as well when its prediction equals the true label.

    Args:
        true_labels: iterable of true labels, usable as indices/keys into
            ``class_weights``.
        predictions: iterable of predicted labels, paired positionally with
            ``true_labels``.
        class_weights: indexable mapping from label to weight.

    Returns:
        The weighted accuracy, or ``0.0`` when there are no samples or the
        total weight is zero (instead of raising ``ZeroDivisionError``).
    """
    correct = 0
    total_weight = 0

    for label, pred in zip(true_labels, predictions):
        weight = class_weights[label]
        if label == pred:
            correct += weight
        total_weight += weight

    # Empty input (or all-zero weights) has no defined ratio; report 0.0.
    if total_weight == 0:
        return 0.0
    return correct / total_weight

def update_ema_variables(model, ema_model, alpha, global_step):
    """Blend the live model's state into its exponential-moving-average copy.

    Parameters are updated in place as ``ema = alpha * ema + (1 - alpha) * live``.
    Buffers (for example normalization running statistics) are copied verbatim
    so the EMA copy does not keep stale values for state that is not a parameter.

    Args:
        model: the model being trained.
        ema_model: the shadow copy with the same architecture; modified in place.
        alpha: EMA decay factor; larger values make the average move more slowly.
        global_step: current optimizer step. Unused; kept so existing call
            sites keep working.
    """
    with torch.no_grad():
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            ema_param.mul_(alpha).add_(param.detach(), alpha=1 - alpha)
        for ema_buffer, buffer in zip(ema_model.buffers(), model.buffers()):
            ema_buffer.copy_(buffer)

# Initialize EMA model: a deep copy of the freshly built model that the training
# loop keeps updated as a moving average of the trained weights.
ema_decay = 0.999  # decay factor for EMA
ema_model = copy.deepcopy(model)
ema_model.to(device)

# Define the optimizer (the per-step learning rate is overwritten by the
# schedule inside the training loop).
optimizer = optim.AdamW(model.parameters(), lr=initial_lr, weight_decay=5e-2)

# Per-class sample counts, ordered by class index. They are derived from the
# loaded training fold rather than hard-coded, so the loss weights stay correct
# when I_FOLD (and therefore the class balance) changes.
label_order = [integer_to_label[class_index] for class_index in range(len(integer_to_label))]
class_counts = (
    train['label']
    .value_counts()
    .reindex(label_order, fill_value=0)
    .to_numpy(dtype=np.float32)
)
if (class_counts == 0).any():
    raise ValueError(
        f"Training fold {I_FOLD} has no samples for at least one class: "
        f"{dict(zip(label_order, class_counts.tolist()))}"
    )

# Inverse-frequency class weights, normalized to sum to 1.
class_weights = 1. / class_counts
class_weights /= class_weights.sum()

# Convert class weights to tensor
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

# Define the loss function with class weights
criterion = torch.nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)

# Best validation score seen so far and the global optimizer-step counter.
best_val_accuracy = 0.0
step = 0

EVAL_EVERY_STEPS = 500   # validate the EMA model every this many optimizer steps
FINAL_STEP = 20000       # step at which 'final.pth' is written
TILES_PER_SLIDE = 16     # maximum number of tiles sampled per validation image
CHECKPOINT_DIR = f'convnextv2_base_models/fold_{I_FOLD}'

# torch.save does not create missing parent directories.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)


def _predict_slide_label(net, tile_dir, max_tiles=TILES_PER_SLIDE):
    """Predict one label for the image whose PNG tiles live in ``tile_dir``.

    Up to ``max_tiles`` tiles are sampled at random, run through ``net`` as a
    single batch, and their softmax probabilities are summed over ALL sampled
    tiles before taking the argmax. Returns ``None`` when the directory is
    missing or holds no PNG tile. Call under ``torch.no_grad()``.
    """
    if not os.path.isdir(tile_dir):
        return None
    tile_names = [f for f in os.listdir(tile_dir) if f.lower().endswith('.png')]
    if not tile_names:
        return None

    sampled_names = random.sample(tile_names, min(max_tiles, len(tile_names)))
    tiles = []
    for tile_name in sampled_names:
        # Context manager releases the file handle; convert() forces the decode
        # and guarantees the three channels the 3-channel Normalize expects.
        with Image.open(os.path.join(tile_dir, tile_name)) as tile_image:
            tiles.append(val_transform(tile_image.convert('RGB')))

    batch = torch.stack(tiles, dim=0).to(device)
    probs = net(batch).logits.softmax(dim=1)
    # Aggregate over every sampled tile; using only probs[0] would throw away
    # all but the first tile's prediction.
    probabilities = probs.sum(dim=0)
    return integer_to_label[probabilities.argmax().item()]


def _balanced_accuracy_or_nan(labels, preds):
    """Balanced accuracy of ``preds`` vs ``labels``, or NaN when the group is empty."""
    return balanced_accuracy_score(labels, preds) if labels else float('nan')


for epoch in range(num_epochs):
    model.train()  # set the model to training mode

    for images, labels, _ in train_dataloader:
        # Move the batch to the training device
        images = images.to(device)
        labels = labels.to(device)

        # Apply the scheduled learning rate for this step
        lr = learning_rate(step)
        for g in optimizer.param_groups:
            g['lr'] = lr

        # Zero the parameter gradients
        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs.logits, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        update_ema_variables(model, ema_model, ema_decay, step)

        logging.info('[%d, %5d] loss: %.3f', epoch + 1, step, loss.item())

        if step % EVAL_EVERY_STEPS == 0:
            ema_model.eval()

            tma_preds, tma_labels = [], []
            non_tma_preds, non_tma_labels = [], []

            with torch.no_grad():
                for _, row in validation.iterrows():
                    pred_label = _predict_slide_label(ema_model, row['tile_path'])
                    if pred_label is None:
                        logging.warning(
                            'Skipping validation image %s: no PNG tiles in %s',
                            row['image_id'], row['tile_path'],
                        )
                        continue

                    label = row['label']
                    if row['is_tma']:
                        tma_preds.append(pred_label)
                        tma_labels.append(label)
                    else:
                        non_tma_preds.append(pred_label)
                        non_tma_labels.append(label)

            tma_accuracy = _balanced_accuracy_or_nan(tma_labels, tma_preds)
            non_tma_accuracy = _balanced_accuracy_or_nan(non_tma_labels, non_tma_preds)
            # Average over the groups that actually contain samples.
            group_scores = [s for s in (tma_accuracy, non_tma_accuracy) if not math.isnan(s)]
            accuracy = sum(group_scores) / len(group_scores) if group_scores else 0.0
            logging.info(f'TMA Accuracy: {tma_accuracy} | Non-TMA Accuracy: {non_tma_accuracy} | Overall Accuracy: {accuracy}')

            if accuracy > best_val_accuracy:
                best_val_accuracy = accuracy
                torch.save(ema_model.state_dict(), f'{CHECKPOINT_DIR}/epoch_{epoch}_step_{step}.pth')
                logging.info(f'Model saved after epoch {epoch} and step {step}')

            model.train()

        # NOTE(review): training continues after this snapshot until num_epochs is
        # exhausted; add a stop condition here if FINAL_STEP is meant to end the run.
        if step == FINAL_STEP:
            torch.save(ema_model.state_dict(), f'{CHECKPOINT_DIR}/final.pth')

        step += 1