In [1]:
import pandas as pd
# train = pd.read_csv("train-yes-tma.csv")
# validation = pd.read_csv("validation-yes-tma.csv")
# Both frames are read from the same CSV file.
# NOTE(review): this means `validation` holds the same rows as `train`, so any
# metric computed on it measures fit to the training data — confirm this is
# intentional (the commented-out lines above point at separate split files).
TRAIN_CSV = "data/train.csv"
train = pd.read_csv(TRAIN_CSV)
validation = pd.read_csv(TRAIN_CSV)
train.head()

Unnamed: 0,image_id,label,image_width,image_height,is_tma
0,4,HGSC,23785,20008,False
1,66,LGSC,48871,48195,False
2,91,HGSC,3388,3388,True
3,281,LGSC,42309,15545,False
4,286,EC,37204,30020,False


In [2]:
import os

def get_image_path(image_id:int, directory:str):
    """Return the path of the entry named after ``image_id`` inside ``directory``.

    The id is converted to its string form and joined onto ``directory`` with
    the platform's path separator. Nothing is checked on disk.
    """
    entry_name = f"{image_id}"
    return os.path.join(directory, entry_name)

# Add one `tile_path_<k>` column per tile directory to both frames; the column
# suffix is the position of the directory in this tuple.
for tile_index, tile_directory in enumerate(('tiles', 'tiles_2964')):
    for frame in (train, validation):
        frame[f'tile_path_{tile_index}'] = frame['image_id'].apply(
            lambda image_id: get_image_path(image_id, tile_directory)
        )
train.head()

Unnamed: 0,image_id,label,image_width,image_height,is_tma,tile_path_0,tile_path_1
0,4,HGSC,23785,20008,False,tiles/4,tiles_2964/4
1,66,LGSC,48871,48195,False,tiles/66,tiles_2964/66
2,91,HGSC,3388,3388,True,tiles/91,tiles_2964/91
3,281,LGSC,42309,15545,False,tiles/281,tiles_2964/281
4,286,EC,37204,30020,False,tiles/286,tiles_2964/286


In [3]:
from PIL import Image
import torch
import torch.nn as nn
from transformers import ViTImageProcessor, ViTForImageClassification

# Number of ViT backbones whose features are concatenated before classification.
N_MODELS = 2

model_name = "google/vit-base-patch16-224"
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device {device} and model {model_name}")

processor = ViTImageProcessor.from_pretrained(model_name)

# Load the pretrained backbones first; their original classification head is
# still attached at this point and tells us the feature width.
models = []
for _ in range(N_MODELS):
    models.append(ViTForImageClassification.from_pretrained(model_name))
feature_dim = models[0].classifier.in_features

# A single linear head over the concatenated features of all backbones,
# producing 5 outputs.
classifier = nn.Linear(feature_dim * N_MODELS, 5)

# Replace each backbone's own head with a pass-through so `.logits` carries the
# features fed to that head, then move the backbone to the target device.
for model in models:
    model.classifier = nn.Identity()
    model.to(device)

classifier = classifier.to(device)

Using device cuda and model google/vit-base-patch16-224


In [47]:
# epoch = 780
# step = 50000

# models = [ViTForImageClassification.from_pretrained(model_name) for _ in range(N_MODELS)]
# classifier = nn.Linear(models[0].classifier.in_features * N_MODELS, 5)
# for i_model, model in enumerate(models):
#     model.classifier = nn.Identity()
#     model = model.to(device)

#     state_dict = torch.load(f'vit-finetune-big-and-small-models-pt-3/model_{i_model}_epoch_{epoch}_step_{step}.pth', map_location=device)
#     model.load_state_dict(state_dict)

# state_dict = torch.load(f'vit-finetune-big-and-small-models-pt-3/classifier_epoch_{epoch}_step_{step}.pth', map_location=device)
# classifier.load_state_dict(state_dict)
# classifier = classifier.to(device)

In [4]:
import random
import os
from PIL import Image, ImageOps
from torch.utils.data import Dataset

# Class index -> label string. The index is the position in this tuple.
integer_to_label = dict(enumerate(('HGSC', 'CC', 'EC', 'LGSC', 'MC')))

# Label string -> class index (exact inverse of `integer_to_label`).
label_to_integer = {
    label_name: class_index
    for class_index, label_name in integer_to_label.items()
}

class ImageDataset(Dataset):
    """Dataset that returns two randomly chosen PNG tiles per tile directory.

    Each dataframe row contributes one sample. For every model index in
    ``range(N_MODELS)`` the row's ``tile_path_<i>`` directory is listed and two
    ``.png`` files are drawn from it, so a sample holds ``2 * N_MODELS`` images
    ordered by model index.

    Args:
        dataframe: Frame with a ``label`` column and ``tile_path_<i>`` columns.
        transform: Optional callable applied to every loaded PIL image.
    """

    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform
        # One (label, [tile directory per model]) tuple per dataframe row.
        self.image_info = []
        for index, row in dataframe.iterrows():
            tile_paths = [row[f'tile_path_{i_model}'] for i_model in range(N_MODELS)]
            self.image_info.append((row['label'], tile_paths))

    def __len__(self):
        """Return the number of samples (one per dataframe row)."""
        return len(self.image_info)

    def __getitem__(self, idx):
        """Load the tiles for sample ``idx``.

        Returns:
            ``(images, label_index)`` where ``images`` is a list of
            ``2 * N_MODELS`` items (transformed if a transform was given) and
            ``label_index`` is ``label_to_integer[label]``.

        Raises:
            FileNotFoundError: A tile directory of this sample does not exist.
                Silently skipping it would return fewer images than the other
                samples and only fail later, far from the cause.
            ValueError: A tile directory contains no ``.png`` file.
        """
        label, tile_paths = self.image_info[idx]
        images = []
        for tile_path in tile_paths:
            if not os.path.isdir(tile_path):
                raise FileNotFoundError(f"Tile directory not found: {tile_path}")
            image_names = [img for img in os.listdir(tile_path) if img.lower().endswith('.png')]
            if not image_names:
                raise ValueError(f"No .png tiles found in: {tile_path}")
            if len(image_names) == 1:
                # Only one tile available: use it twice so the count stays fixed.
                selected_images = image_names * 2
            else:
                selected_images = random.sample(image_names, 2)
            for img in selected_images:
                images.append(self._load_image(os.path.join(tile_path, img)))
        return images, label_to_integer[label]

    def _load_image(self, path):
        """Open ``path``, force 3-channel RGB, and apply the transform if any."""
        # The context manager releases the file handle; `convert` returns a
        # fully loaded copy, so the image stays usable after the file closes.
        # RGB is forced because the consumers of this dataset reshape batches
        # to 3 channels and normalise with 3-element mean/std.
        with Image.open(path) as opened:
            image = opened.convert('RGB')
        return self.transform(image) if self.transform else image


In [5]:
from torch.utils.data import DataLoader
from torch.utils.data import random_split
import torchvision.transforms as transforms

BATCH_SIZE = 8


def _tensor_and_normalize():
    """Return fresh ToTensor + Normalize(mean=0.5, std=0.5 per channel) steps."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]


# Training pipeline: random flips, rotation, a random resized crop to 224x224
# and RandAugment, followed by tensor conversion and normalisation.
train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation((-45, 45)),
        transforms.RandomResizedCrop((224, 224), scale=(0.4, 1)),
        transforms.RandAugment(),
    ]
    + _tensor_and_normalize()
)

# Validation pipeline: tensor conversion and normalisation only (no resizing).
val_transform = transforms.Compose(_tensor_and_normalize())

train_dataset = ImageDataset(dataframe=train, transform=train_transform)
val_dataset = ImageDataset(dataframe=validation, transform=val_transform)

# Both loaders use the same settings, including shuffling.
_loader_options = dict(batch_size=BATCH_SIZE, num_workers=3, shuffle=True)
train_dataloader = DataLoader(train_dataset, **_loader_options)
val_dataloader = DataLoader(val_dataset, **_loader_options)

In [6]:
import logging
import sys

# Configure the root logger: everything at INFO and above goes to a log file,
# while only ERROR and CRITICAL records are echoed to stderr.
logger = logging.getLogger()

# Start from a clean slate so re-running this cell does not stack handlers.
while logger.handlers:
    logger.removeHandler(logger.handlers[0])

logger.setLevel(logging.INFO)

file_handler = logging.FileHandler('training_log_finetune_vit_non_tma_big_and_small_pt_5.txt')
stream_handler = logging.StreamHandler(sys.stderr)

for log_handler, log_level in ((file_handler, logging.INFO), (stream_handler, logging.ERROR)):
    log_handler.setLevel(log_level)
    logger.addHandler(log_handler)

In [None]:
import torch
import torch.optim as optim
import logging
import numpy as np
import math
from sklearn.metrics import accuracy_score

initial_lr = 1e-5
num_epochs = 1000
max_steps = 20000        # training stops once this many optimizer steps were taken
eval_interval = 1000     # validate and checkpoint every this many steps
checkpoint_dir = 'vit-finetune-big-and-small-models-pt-5'


def warmup_linear(step, warmup_steps=1000, total_steps=20000):
    """Learning-rate multiplier: linear warmup followed by cosine decay.

    Args:
        step: Number of optimizer steps taken so far.
        warmup_steps: Steps over which the multiplier rises linearly from 0 to 1.
        total_steps: Step at which the cosine decay reaches 0.

    Returns:
        The factor by which the base learning rate is scaled at ``step``.
    """
    if step < warmup_steps:
        return float(step) / float(max(1, warmup_steps))
    progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
    return 0.5 * (1.0 + math.cos(math.pi * progress))


def set_train_mode(training):
    """Put every backbone and the classifier into train (True) or eval (False) mode."""
    for backbone in models:
        backbone.train(training)
    classifier.train(training)


def ensemble_logits(images):
    """Run one batch through all backbones and the shared classifier.

    ``images`` is the list produced by the dataloader; it is split into
    ``N_MODELS`` consecutive, equally sized chunks, one per backbone.
    (presumably each entry is a (batch, 3, 224, 224) tensor — TODO confirm
    against the dataset/collate output.)

    For each backbone the chunk is stacked, flattened to (-1, 3, 224, 224),
    embedded, and the 768-wide embeddings of the tiles belonging to the same
    sample are averaged. The per-backbone features are concatenated along the
    feature axis and passed to ``classifier``.
    """
    chunk_size = len(images) // N_MODELS
    features = []
    for i_model in range(N_MODELS):
        chunk = images[i_model * chunk_size:(i_model + 1) * chunk_size]
        tiles = torch.stack(chunk).to(device)
        embeddings = models[i_model](tiles.view(-1, 3, 224, 224)).logits
        # Leading axis = tile within the sample; average it away.
        embeddings = embeddings.view(tiles.shape[0], -1, 768).mean(dim=0).view(-1, 768)
        features.append(embeddings)
    return classifier(torch.cat(features, dim=1))


def evaluate(min_predictions=100, max_passes=100):
    """Return accuracy over a little more than ``min_predictions`` validation samples.

    The validation loader is iterated (up to ``max_passes`` times, for very
    small validation sets) until more than ``min_predictions`` predictions
    were collected. Both loops stop at that point; breaking out of the inner
    loop only would restart the dataloader for every remaining pass.
    Models are switched to eval mode for the duration and back to train mode
    afterwards.
    """
    set_train_mode(False)
    all_labels = []
    all_preds = []
    with torch.no_grad():
        for _ in range(max_passes):
            for images, labels in val_dataloader:
                logits = ensemble_logits(images)
                # argmax over logits equals argmax over their softmax.
                all_preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
                all_labels.extend(labels.numpy())
                if len(all_preds) > min_predictions:
                    break
            if len(all_preds) > min_predictions:
                break
    set_train_mode(True)
    return accuracy_score(all_labels, all_preds)


all_parameters = []
for model in models:
    all_parameters += list(model.parameters())
all_parameters += list(classifier.parameters())

optimizer = optim.Adam(all_parameters, lr=initial_lr, weight_decay=1e-2)

# LambdaLR applies `initial_lr * warmup_linear(n)` where n is the number of
# scheduler steps taken, which is the same value as the training step counter.
# No separate manual write to `param_groups` is needed.
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_linear)

# Calculate class weights: inverse class frequency, normalised to sum to 1.
# class_counts = np.array([216, 93, 118, 41, 40], dtype=np.float32)
class_counts = np.array([222, 99, 124, 47, 46], dtype=np.float32)
class_weights = 1. / class_counts
class_weights /= class_weights.sum()

# Convert class weights to tensor
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

# Define the loss function with class weights
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

best_val_accuracy = 0.0
step = 0

# Create the checkpoint directory up front so the first save cannot fail on a
# missing folder after many training steps.
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(num_epochs):
    if step >= max_steps:
        break

    set_train_mode(True)

    for images, labels in train_dataloader:
        labels = labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        all_outputs = ensemble_logits(images)
        loss = criterion(all_outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        scheduler.step()
        step += 1

        logging.info('[%d, %5d] loss: %.3f' % (epoch + 1, step, loss.item()))

        if step % eval_interval == 0:
            # Validation accuracy
            accuracy = evaluate()
            best_val_accuracy = max(best_val_accuracy, accuracy)
            logging.info("Validation Accuracy: %s" % accuracy)

            # Checkpoint every backbone and the classifier.
            for i_model, model in enumerate(models):
                torch.save(model.state_dict(), f'{checkpoint_dir}/model_{i_model}_epoch_{epoch}_step_{step}.pth')
            torch.save(classifier.state_dict(), f'{checkpoint_dir}/classifier_epoch_{epoch}_step_{step}.pth')

        if step >= max_steps:
            # Leave the batch loop; the check at the top of the epoch loop
            # then ends training instead of spinning through unused epochs.
            break

logging.info('Finished Training')
