In [1]:
import pandas as pd
import os

def get_image_path(image_id:int):
    """Return the tile directory of a slide, i.e. ``../tiles_768/<image_id>``."""
    tiles_root = '../tiles_768'
    folder_name = str(image_id)
    return os.path.join(tiles_root, folder_name)

# Training table for this run is the fold-1 split; the commented-out line
# reads ../data/train.csv instead.
# train = pd.read_csv(f"../data/train.csv")
train = pd.read_csv("train_fold_1.csv")

# Add the tile directory of every slide, derived from its image_id.
train['tile_path'] = train['image_id'].apply(get_image_path)
train.head()

Unnamed: 0,image_id,label,image_width,image_height,is_tma,tile_path
0,4,HGSC,23785,20008,False,../tiles_768/4
1,66,LGSC,48871,48195,False,../tiles_768/66
2,91,HGSC,3388,3388,True,../tiles_768/91
3,281,LGSC,42309,15545,False,../tiles_768/281
4,286,EC,37204,30020,False,../tiles_768/286


In [2]:
from timm.models.vision_transformer import Block
import torch
import torch.nn as nn
import copy

# Run on the GPU when CUDA reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

class GlobalModel(nn.Module):
    """Transformer that turns a set of feature vectors into class logits.

    Every input vector is projected to ``embed_dim`` by a linear layer, a
    learnable token is prepended, the sequence is layer-normalised and passed
    through ``n_layers`` timm ``Block`` layers. The mean over all positions
    *except* the prepended token is then normalised and classified by a linear
    head. No positional embedding is added anywhere in this module.

    Args:
        n_heads: number of attention heads in each block.
        n_layers: number of transformer blocks.
        embed_dim: width of the token embeddings.
        n_classes: number of output logits.
        in_dim: size of each incoming feature vector. Defaults to 768, the
            value that used to be hard-coded.
        drop_rate: dropout probability shared by the pre-block dropout, the
            head dropout and the blocks' ``proj_drop`` / ``attn_drop``.
            Defaults to 0.0 (the previous fixed value).
        drop_path_rate: stochastic-depth rate of the last block; earlier
            blocks get linearly smaller rates starting from 0. Defaults to 0.0.

    Sub-module attribute names are unchanged, so state dicts saved from the
    earlier version of this class still load.
    """

    def __init__(self, n_heads, n_layers, embed_dim, n_classes,
                 in_dim=768, drop_rate=0.0, drop_path_rate=0.0):
        super().__init__()
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.embed_dim = embed_dim
        self.n_classes = n_classes
        self.in_dim = in_dim

        # Learnable token that is prepended to every sequence (initialised to zeros).
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.patch_embed = nn.Linear(in_dim, embed_dim)
        self.norm_pre = nn.LayerNorm(embed_dim)
        self.drop_pre = nn.Dropout(p=drop_rate)

        # Per-block stochastic-depth rates, ramping linearly from 0 to drop_path_rate.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)]
        self.blocks = nn.Sequential(*[
            Block(
                dim=embed_dim,
                num_heads=n_heads,
                proj_drop=drop_rate,
                attn_drop=drop_rate,
                drop_path=dpr[i]
            )
            for i in range(n_layers)])
        self.norm_post = nn.LayerNorm(embed_dim)
        self.head_drop = nn.Dropout(p=drop_rate)
        self.fc_head = nn.Linear(embed_dim, n_classes)

    def forward(self, x):
        """Compute logits for a batch of sets.

        Args:
            x: tensor of shape ``(batch, n_tokens, in_dim)``.

        Returns:
            Tensor of shape ``(batch, n_classes)`` holding unnormalised logits.
        """
        x = self.patch_embed(x)
        # Prepend one copy of the learnable token to every sequence in the batch.
        x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)

        x = self.norm_pre(x)
        x = self.drop_pre(x)
        x = self.blocks(x)

        # Pool by averaging the input positions; the prepended token (index 0)
        # takes part in the blocks but is left out of the average.
        x = x[:, 1:].mean(dim=1)
        x = self.norm_post(x)
        x = self.head_drop(x)
        x = self.fc_head(x)

        return x

# Configurations of earlier runs, kept for reference:
# global_model = GlobalModel(n_heads=3, n_layers=12, embed_dim=192, n_classes=5) # second third fourth
# global_model = GlobalModel(n_heads=12, n_layers=1, embed_dim=768, n_classes=5) # fifth
# Configuration of the sixth and seventh runs.
global_model = GlobalModel(n_heads=3, n_layers=1, embed_dim=192, n_classes=5).to(device)

# A deep copy of the freshly initialised network serves as the exponential
# moving average of the weights; ema_decay is the factor applied to the old
# average at every update.
ema_decay = 0.99
ema_global_model = copy.deepcopy(global_model).to(device)

In [3]:
import os
from PIL import Image
from torch.utils.data import Dataset
import random

# Class index -> class name.
integer_to_label = dict(enumerate(['HGSC', 'CC', 'EC', 'LGSC', 'MC']))

# Class name -> class index (the inverse of the mapping above).
label_to_integer = {name: index for index, name in integer_to_label.items()}

# Number of rows drawn for every item produced by the dataset.
LOCAL_SAMPLES = 24

class ImageDataset(Dataset):
    """Randomly generated mixtures of rows from per-slide tensors, with soft labels.

    On construction, one tensor per dataframe row is loaded from
    ``<save_folder>/<image_id>.pt`` and grouped by class index. Every item is
    generated on the fly: a random subset of classes is chosen, a fixed number
    of rows is drawn (with replacement) from one randomly picked tensor per
    chosen class, and the target is the fraction of rows taken from each class.

    Args:
        dataframe: table with at least the columns ``label`` (a key of
            ``label_to_integer``) and ``image_id``.
        save_folder: directory holding the ``<image_id>.pt`` files. Defaults
            to ``'../image_tensors'``, the previously hard-coded location.
        local_samples: rows drawn per item. ``None`` (default) means "use the
            module-level ``LOCAL_SAMPLES`` at call time", as before.

    Raises:
        ValueError: if ``local_samples`` is given and smaller than 1.
    """

    def __init__(self, dataframe, save_folder='../image_tensors', local_samples=None):
        if local_samples is not None and local_samples < 1:
            raise ValueError(f"local_samples must be >= 1, got {local_samples}")

        self.dataframe = dataframe
        self.local_samples = local_samples
        self.images_by_label_integer = {i: [] for i in range(len(label_to_integer))}

        # Iterate the two needed columns directly instead of materialising a
        # Series per row with iterrows().
        for label, image_id in zip(dataframe['label'], dataframe['image_id']):
            load_path = os.path.join(save_folder, f"{image_id}.pt")
            # NOTE(review): torch.load unpickles the file; only point save_folder
            # at tensors you produced yourself.
            images_tensor = torch.load(load_path, map_location=torch.device('cpu'))
            self.images_by_label_integer[label_to_integer[label]].append(images_tensor)

    def __len__(self):
        # Items are synthesised at random, so the length is only a nominal,
        # very large number of draws.
        return 1_000_000_000

    def _samples_per_item(self):
        """Rows per item: the instance override if set, else the global LOCAL_SAMPLES."""
        override = getattr(self, 'local_samples', None)
        return LOCAL_SAMPLES if override is None else override

    def __getitem__(self, idx):
        """Build one random mixture; ``idx`` is ignored.

        Returns:
            ``(images, label)`` where ``images`` stacks the drawn rows along
            dim 0 (grouped by ascending class index) and ``label`` is a float
            tensor with one entry per class holding the fraction of rows drawn
            from that class.

        Raises:
            ValueError: if no class has any tensor to sample from.
        """
        n_samples = self._samples_per_item()
        n_classes = len(self.images_by_label_integer)

        # Only classes that actually have tensors can be sampled. Without this
        # filter a class missing from the dataframe made random.randint(0, -1)
        # raise as soon as that class was picked.
        available = [i for i in range(n_classes) if self.images_by_label_integer[i]]
        if not available:
            raise ValueError("ImageDataset has no image tensors to sample from")

        n_labels = random.randint(1, len(available))
        labels = random.sample(available, n_labels)

        # Distribute the n_samples draws uniformly at random over the chosen classes.
        n_from_each_label = [0] * n_classes
        for _ in range(n_samples):
            n_from_each_label[random.choice(labels)] += 1

        images = []
        for i, num_samples in enumerate(n_from_each_label):
            if num_samples == 0:
                continue
            # One tensor of this class, then num_samples of its rows (with replacement).
            selected_images_tensor_raw = random.choice(self.images_by_label_integer[i])
            row_ids = torch.randint(0, selected_images_tensor_raw.shape[0], (num_samples,))
            images.append(selected_images_tensor_raw[row_ids])

        images = torch.cat(images, dim=0)
        label = torch.tensor(n_from_each_label) / n_samples

        return images, label

In [4]:
from torch.utils.data import DataLoader

# Number of items per optimisation step.
BATCH_SIZE = 512

# Constructing the dataset loads the tensor file of every row in `train`.
train_dataset = ImageDataset(train)

# Items are generated by 8 worker processes.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=8,
)

In [5]:
import logging
import sys

# Logging setup for this run: INFO and above go to a run-specific file,
# ERROR and above are additionally echoed to stderr.
LOG_PATH = 'logs/eva02_global_attention/seventh_try.txt'

# Get the root logger
logger = logging.getLogger()

# Remove AND close handlers left over from an earlier execution of this cell;
# removal alone would keep the previous log file handle open.
for handler in logger.handlers[:]:
    logger.removeHandler(handler)
    handler.close()

# Set the logging level
logger.setLevel(logging.INFO)

# FileHandler does not create missing directories, so make sure the target
# folder exists before opening the log file.
os.makedirs(os.path.dirname(LOG_PATH), exist_ok=True)
file_handler = logging.FileHandler(LOG_PATH)
file_handler.setLevel(logging.INFO)
logger.addHandler(file_handler)

# Create a StreamHandler for stderr and add it to the logger
stream_handler = logging.StreamHandler(sys.stderr)
stream_handler.setLevel(logging.ERROR)  # Only log ERROR and CRITICAL messages to stderr
logger.addHandler(stream_handler)

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import logging
import numpy as np
import math
from sklearn.metrics import balanced_accuracy_score
import random
from torch.cuda.amp import GradScaler, autocast

# Upper bound on the number of passes of the outer training loop.
num_epochs = 10000

# Peak learning rate: 5e-4 per 256 items, scaled linearly with the batch size.
initial_lr = 0.0005 * BATCH_SIZE / 256
# The schedule bottoms out at 1% of the peak.
final_lr = 0.01 * initial_lr

# Learning-rate schedule: linear warmup, cosine decay, then a constant floor.
def learning_rate(step, warmup_steps=100, max_steps=1000, base_lr=None, min_lr=None):
    """Return the learning rate to use at optimisation step ``step``.

    * ``step < warmup_steps``: rises linearly from 0 to the peak rate.
    * ``warmup_steps <= step < max_steps``: half-cosine decay from the peak
      rate down to the floor rate.
    * ``step >= max_steps``: stays at the floor rate.

    Args:
        step: zero-based optimisation step.
        warmup_steps: length of the linear warmup.
        max_steps: step at which the cosine decay reaches the floor.
        base_lr: peak rate; ``None`` (default) uses the module-level ``initial_lr``.
        min_lr: floor rate; ``None`` (default) uses the module-level ``final_lr``.

    Returns:
        The learning rate as a float.
    """
    peak_lr = initial_lr if base_lr is None else base_lr
    floor_lr = final_lr if min_lr is None else min_lr

    if step < warmup_steps:
        # max(1, ...) guards the division when warmup_steps is 0 or negative.
        return peak_lr * (float(step) / float(max(1, warmup_steps)))
    elif step < max_steps:
        progress = (float(step - warmup_steps) / float(max(1, max_steps - warmup_steps)))
        cos_component = 0.5 * (1 + math.cos(math.pi * progress))
        return floor_lr + (peak_lr - floor_lr) * cos_component
    else:
        return floor_lr

def update_ema_variables(model, ema_model, alpha):
    """Blend the current weights of ``model`` into ``ema_model`` in place.

    Each parameter of ``ema_model`` becomes
    ``alpha * ema_param + (1 - alpha) * param``. Parameters are paired in the
    order ``parameters()`` yields them; ``model`` is not modified.
    """
    weight_of_new = 1 - alpha
    with torch.no_grad():
        pairs = zip(ema_model.parameters(), model.parameters())
        for averaged, current in pairs:
            averaged.data.mul_(alpha)
            averaged.data.add_(current.data, alpha=weight_of_new)

# Training run: AdamW + mixed precision, with an EMA copy of the weights
# that is the model actually written to disk.
MAX_STEPS = 1000        # step at which the LR schedule ends and final.pth is written
CHECKPOINT_EVERY = 100  # an EMA checkpoint is saved every this many steps
CHECKPOINT_DIR = 'eva02_global_attention_models/seventh_try'

# torch.save does not create missing directories.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

scaler = GradScaler()
optimizer = optim.AdamW(global_model.parameters(), lr=initial_lr, weight_decay=5e-2)

# The loss compares the model's log-probabilities with the soft label
# fractions produced by the dataset.
criterion = nn.KLDivLoss(reduction='batchmean')

best_val_accuracy = 0.0  # not used anywhere in this cell
step = 0
training_done = False

ema_global_model.eval()
global_model.train()
for epoch in range(num_epochs):

    for images, labels in train_dataloader:
        # Move the batch to the training device
        images = images.to(device)
        labels = labels.to(device)

        # Set this step's learning rate (warmup, then cosine decay)
        lr = learning_rate(step, max_steps=MAX_STEPS)
        for g in optimizer.param_groups:
            g['lr'] = lr

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass with autocast
        with autocast():
            global_outputs = global_model(images)
            log_probs = F.log_softmax(global_outputs, dim=1)
            loss = criterion(log_probs, labels)

        # Backward pass; the scaler handles loss scaling for mixed precision
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        update_ema_variables(global_model, ema_global_model, ema_decay)

        logging.info('[%d, %5d] loss: %.3f' % (epoch + 1, step, loss.item()))

        if step % CHECKPOINT_EVERY == 0:
            torch.save(ema_global_model.state_dict(), f'{CHECKPOINT_DIR}/epoch_{epoch}_step_{step}.pth')
            logging.info(f'Model saved after epoch {epoch} and step {step}')

        if step == MAX_STEPS:
            torch.save(ema_global_model.state_dict(), f'{CHECKPOINT_DIR}/final.pth')
            # Stop here: the schedule has reached its floor and the final
            # weights are on disk. Previously the loop kept running until it
            # was interrupted by hand.
            training_done = True
            break

        step += 1

    if training_done:
        break

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f41d0bc6b00>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
    self._shutdown_workers()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1443, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/opt/conda/lib/python3.10/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
  File "/opt/conda/lib/python3.10/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "/opt/conda/lib/python3.10/selectors.py", line 416, in select
    fd_event_list = self._selector.poll(timeout)
KeyboardInterrupt: 

KeyboardInterrupt

