In [1]:
import torch
import torch.nn as nn

# Vision Transformer hyper-parameters; VisionTransformer.__init__ reads these keys.
config = dict(
    image_size=224,  # images are assumed square: (image_size // patch_size) ** 2 patches
    patch_size=32,   # kernel size and stride of the patch-embedding convolution
    num_classes=5,   # output width of the classification head
    dim=768,         # token embedding width
    depth=12,        # number of transformer blocks
    heads=12,        # attention heads per block
    mlp_dim=3072,    # NOTE(review): not read by the model code visible here; Block
                     # sizes its MLP from mlp_ratio (default 4) instead -- confirm intent
)

class PatchEmbedding(nn.Module):
    """Project an image batch into a sequence of patch embeddings.

    A convolution whose kernel size and stride both equal ``patch_size`` maps
    every non-overlapping ``patch_size`` x ``patch_size`` patch to one
    ``emb_size``-dimensional vector.

    Args:
        in_channels: number of channels in the input images.
        patch_size: side length of each square patch.
        emb_size: width of each output embedding.
    """

    def __init__(self, in_channels=3, patch_size=16, emb_size=768):
        super().__init__()
        self.patch_embed = nn.Conv2d(
            in_channels,
            emb_size,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Return patch tokens of shape ``[B, num_patches, emb_size]`` for images ``x``."""
        feature_map = self.patch_embed(x)           # [B, emb_size, H', W']
        tokens = feature_map.flatten(start_dim=2)   # [B, emb_size, H'*W']
        return tokens.transpose(1, 2)               # [B, H'*W', emb_size]

class Attention(nn.Module):
    """Multi-head self-attention over a token sequence.

    Args:
        dim: embedding width of the input and output tokens.
        heads: number of attention heads; must divide ``dim`` evenly.

    Raises:
        ValueError: if ``dim`` is not divisible by ``heads``.
    """

    def __init__(self, dim, heads=12):
        super().__init__()
        if dim % heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by heads ({heads})")
        self.heads = heads
        self.head_dim = dim // heads
        # Scaled dot-product attention divides the logits by sqrt(d_k), where d_k
        # is the width of ONE head. Scaling by the full embedding width (the
        # previous behaviour) under-scales the logits by a factor of sqrt(heads).
        # NOTE: checkpoints trained with the previous scale will produce different
        # attention weights when loaded into this module.
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``[B, N, C]``; the output has the same shape."""
        B, N, C = x.shape
        # One projection produces q, k and v together; split them into
        # [3, B, heads, N, head_dim].
        qkv = self.qkv(x).reshape(B, N, 3, self.heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale   # [B, heads, N, N]
        attn = attn.softmax(dim=-1)

        # Merge the heads back into the channel dimension before the output projection.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x

class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        in_features: width of the input.
        hidden_features: width of the hidden layer.
        out_features: width of the output.
        p: dropout probability applied after the activation and after the output layer.
    """

    def __init__(self, in_features, hidden_features, out_features, p=0.1):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        # A single dropout module is reused at both positions in forward().
        self.drop = nn.Dropout(p)

    def forward(self, x):
        """Run the feed-forward network on ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP sub-layers, each added back residually.

    Args:
        dim: token embedding width.
        heads: number of attention heads.
        mlp_ratio: the MLP hidden width is ``int(dim * mlp_ratio)``.
        qkv_bias: accepted but not used by this implementation.
        p: dropout probability applied to the output of each sub-layer.
        attn_p: accepted but not used by this implementation.
    """

    def __init__(self, dim, heads, mlp_ratio=4., qkv_bias=True, p=0., attn_p=0.):
        super().__init__()
        hidden_width = int(dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = Attention(dim, heads=heads)
        # Despite the name, this is ordinary element-wise dropout.
        self.drop_path = nn.Dropout(p)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.mlp = MLP(in_features=dim, hidden_features=hidden_width, out_features=dim)
        self.mlp_drop = nn.Dropout(p)

    def forward(self, x):
        """Apply the attention sub-layer, then the MLP sub-layer, to ``x``."""
        attn_branch = self.drop_path(self.attn(self.norm1(x)))
        x = x + attn_branch
        mlp_branch = self.mlp_drop(self.mlp(self.norm2(x)))
        return x + mlp_branch

class VisionTransformer(nn.Module):
    """Vision Transformer classifier built from the sibling modules in this cell.

    Args:
        config: mapping with the keys ``patch_size``, ``dim``, ``image_size``,
            ``heads``, ``depth`` and ``num_classes``.
    """

    def __init__(self, config):
        super().__init__()
        patch_size = config['patch_size']
        dim = config['dim']

        self.patch_embed = PatchEmbedding(patch_size=patch_size, emb_size=dim)
        # Learnable class token and position embeddings, both initialised to zeros.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        num_patches = (config['image_size'] // patch_size) ** 2
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, dim))  # +1 for the class token
        self.pos_drop = nn.Dropout(p=0.1)

        encoder_layers = [Block(dim=dim, heads=config['heads']) for _ in range(config['depth'])]
        self.blocks = nn.Sequential(*encoder_layers)
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.head = nn.Linear(dim, config['num_classes'])

    def forward(self, x):
        """Return class logits of shape ``[B, num_classes]`` for the image batch ``x``."""
        batch_size = x.shape[0]
        patch_tokens = self.patch_embed(x)

        # Prepend one copy of the class token to every sequence in the batch.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tokens, patch_tokens), dim=1)
        tokens = self.pos_drop(tokens + self.pos_embed)

        encoded = self.norm(self.blocks(tokens))

        # Only the class token (position 0) feeds the classification head.
        return self.head(encoded[:, 0])

In [2]:
import pandas as pd
import os

def get_image_path(image_id: int):
    """Return the relative tile directory for ``image_id``: ``tiles/<image_id>``."""
    tile_dir_name = str(image_id)
    return os.path.join("tiles", tile_dir_name)

# Slide-level metadata for the training and validation splits.
train = pd.read_csv("train-no-tma.csv")
val = pd.read_csv("validation-no-tma.csv")

# Add each slide's tile directory, derived from its image_id.
for split_frame in (train, val):
    split_frame['tile_path'] = split_frame['image_id'].apply(get_image_path)

train.head()

Unnamed: 0,image_id,label,image_width,image_height,is_tma,tile_path
0,38366,LGSC,31951,21718,False,tiles/38366
1,63298,HGSC,26067,20341,False,tiles/63298
2,54928,CC,36166,31487,False,tiles/54928
3,18813,CC,54671,32443,False,tiles/18813
4,63429,EC,67783,29066,False,tiles/63429


In [3]:
import os
from torchvision import transforms
from torchvision.transforms import autoaugment
from torch.utils.data import Dataset, DataLoader
from PIL import Image, ImageOps

# Class index -> label string. The position in the list is the class index.
integer_to_label = dict(enumerate(['HGSC', 'CC', 'EC', 'LGSC', 'MC']))

# Label string -> class index, derived from the forward map so the two cannot drift apart.
label_to_integer = {label: index for index, label in integer_to_label.items()}

class ImageDataset(Dataset):
    """Dataset of PNG tiles, labelled with the label of the row whose directory holds them.

    Args:
        dataframe: frame with a ``tile_path`` column (directory of tiles) and a
            ``label`` column (a key of ``label_to_integer``).
        transform: optional callable applied to each loaded PIL image.

    Attributes set by ``__init__``:
        image_paths: path of every ``.png`` file found (case-insensitive match),
            sorted by file name within each directory so indices are stable
            across runs and machines.
        labels: label string for each entry of ``image_paths``.
        missing_dirs: ``tile_path`` values that are not existing directories;
            those rows contribute no samples.
    """

    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.missing_dirs = []
        for _, row in dataframe.iterrows():
            folder_path = row['tile_path']
            label = row['label']
            if not os.path.isdir(folder_path):
                # Keep the best-effort behaviour (skip the row) but record it so
                # callers can detect slides that have no tiles on disk.
                self.missing_dirs.append(folder_path)
                continue
            # os.listdir returns entries in arbitrary order; sort for reproducibility.
            for image_name in sorted(os.listdir(folder_path)):
                if image_name.lower().endswith('.png'):
                    self.image_paths.append(os.path.join(folder_path, image_name))
                    self.labels.append(label)

    def __len__(self):
        """Return the number of tiles found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, class_index)`` for the tile at position ``idx``."""
        # Image.open is lazy and keeps the file open; convert() forces the pixel
        # data to load, and the context manager then closes the file handle.
        # Converting to RGB guarantees three channels even for RGBA, palette or
        # grayscale PNGs.
        with Image.open(self.image_paths[idx]) as tile:
            image = tile.convert('RGB')
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label_to_integer[label]


In [4]:
# import torch
# from torch.utils.data import DataLoader
# from torchvision import transforms
# from PIL import Image
# import os

# def calculate_dataset_stats(dataset, num_samples=10000):
#     loader = DataLoader(dataset, batch_size=1, num_workers=4, shuffle=True)
    
#     mean = 0.
#     std = 0.
#     nb_samples = 0.
    
#     for i, (data, _) in enumerate(loader):
#         data = data.view(data.size(0), data.size(1), -1)
#         mean += data.mean(2).sum(0)
#         std += data.var(2).sum(0)
#         nb_samples += data.size(0)
        
#         if i >= num_samples:
#             break

#     mean /= nb_samples
#     std /= nb_samples
#     std = torch.sqrt(std)
    
#     return mean, std

# mean, std = calculate_dataset_stats(train_dataset)
# print("Mean:", mean)
# print("Standard Deviation:", std)

In [5]:
from torch.utils.data import DataLoader
from torch.utils.data import random_split


"""# FOR YES TMA
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    autoaugment.RandAugment(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.6152, 0.5353, 0.5934], std=[0.2387, 0.2385, 0.2317]), # FOR YES TMA. calculated above
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.6152, 0.5353, 0.5934], std=[0.2387, 0.2385, 0.2317]), # FOR YES TMA. calculated above
])"""


# FOR NO TMA
def _build_no_tma_transform():
    """Build the preprocessing pipeline: resize to 224x224, to tensor, normalise.

    The per-channel mean/std are the NO TMA statistics (calculated above).
    """
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.8265, 0.7217, 0.8247], std=[0.1133, 0.1265, 0.0960]),
    ])


# Training and validation use the same (augmentation-free) pipeline, but each
# split gets its own Compose instance.
train_transform = _build_no_tma_transform()
val_transform = _build_no_tma_transform()

train_dataset = ImageDataset(dataframe=train, transform=train_transform)
val_dataset = ImageDataset(dataframe=val, transform=val_transform)

# Both loaders shuffle. The training cell stops each validation pass once it has
# seen 10000 samples, so a shuffled validation loader gives a different subset
# on every pass.
loader_options = dict(batch_size=128, num_workers=8, shuffle=True)
train_dataloader = DataLoader(train_dataset, **loader_options)
val_dataloader = DataLoader(val_dataset, **loader_options)

In [6]:
import numpy as np

# Prefer the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

model = VisionTransformer(config)
model.to(device)

total_params = sum(param.numel() for param in model.parameters())
print('Total number of parameters:', total_params)

# Calculate class weights
# These were derived by looking at the number of files in tile_path for each label
class_counts = np.array(
    [3521456, 1876772, 2126428, 589002, 1053114],
    dtype=np.float32,
)
# class_counts = np.array([703, 690, 631, 581, 706], dtype=np.float32) # These were derived by looking at the number of files in tile_path for each label

# Inverse-frequency weights, normalised so that they sum to 1.
class_weights = 1. / class_counts
class_weights = class_weights / class_weights.sum()

# Convert class weights to tensor
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

# Define the loss function with class weights
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-5,
    weight_decay=1e-3,
)

Using device: cuda
Total number of parameters: 87459077


In [7]:
# from torchsummary import summary

# summary(model, input_size=(3, 224, 224))

In [7]:
import logging
import sys

# Configure the root logger: INFO and above go to a file, ERROR and above to stderr.
logger = logging.getLogger()

# Remove any handlers already attached (iterate over a copy, since
# removeHandler mutates the list), so re-running this cell does not duplicate output.
for stale_handler in list(logger.handlers):
    logger.removeHandler(stale_handler)

# The logger itself must let INFO through; each handler then applies its own threshold.
logger.setLevel(logging.INFO)

# File handler: receives everything at INFO and above.
file_handler = logging.FileHandler('training_log_vit_non_tma_pt_3.txt')
file_handler.setLevel(logging.INFO)

# Stream handler: only ERROR and CRITICAL messages reach stderr.
stream_handler = logging.StreamHandler(sys.stderr)
stream_handler.setLevel(logging.ERROR)

for active_handler in (file_handler, stream_handler):
    logger.addHandler(active_handler)

In [9]:
# checkpoint = torch.load('vit-non-tma-models/model_step_16000.pt')
# model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [11]:
import os
# Specify the directory to save the model checkpoints
checkpoint_dir = "vit-non-tma-models-pt-2/"
os.makedirs(checkpoint_dir, exist_ok=True)

# Number of epochs
num_epochs = 100
step = 0

# A checkpoint and a validation pass happen every CHECKPOINT_EVERY_STEPS optimizer
# steps (including step 0); each validation pass stops after MAX_VAL_SAMPLES samples.
CHECKPOINT_EVERY_STEPS = 1000
MAX_VAL_SAMPLES = 10000


def evaluate_accuracy(model, dataloader, device, max_samples=MAX_VAL_SAMPLES):
    """Return the top-1 accuracy (in percent) of ``model`` on ``dataloader``.

    Iteration stops after the first batch that brings the number of evaluated
    samples to ``max_samples`` or more. The model is put in eval mode for the
    pass and restored to its previous train/eval mode afterwards, even if the
    pass is interrupted.

    Args:
        model: module mapping an image batch to class logits.
        dataloader: iterable yielding ``(images, labels)`` batches.
        device: device the batches are moved to before the forward pass.
        max_samples: sample budget for the pass.

    Returns:
        Accuracy in the range 0-100, or NaN if the loader yielded no samples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in dataloader:
                images, labels = images.to(device), labels.to(device)

                # Predicted class = index of the largest logit.
                predicted = model(images).argmax(dim=1)

                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                if total >= max_samples:
                    break
    finally:
        model.train(was_training)

    if total == 0:
        return float('nan')
    return 100 * correct / total


# Training loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for images, labels in train_dataloader:
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        current_loss = loss.item()

        running_loss += current_loss
        logging.info(f"epoch: {epoch}, step: {step}, loss: {current_loss}")

        if step % CHECKPOINT_EVERY_STEPS == 0:
            checkpoint_filename = os.path.join(checkpoint_dir, f"model_step_{step}.pt")
            torch.save({
                'step': step,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, checkpoint_filename)

            val_accuracy = evaluate_accuracy(model, val_dataloader, device)
            logging.info(f"epoch: {epoch}, step: {step}, validation accuracy: {val_accuracy:.4f}%")

        step += 1

    # End-of-epoch validation and summary.
    val_accuracy = evaluate_accuracy(model, val_dataloader, device)
    # max(..., 1) avoids a ZeroDivisionError when the training loader is empty.
    epoch_loss = running_loss / max(len(train_dataloader), 1)
    logging.info(f'Epoch {epoch+1}/{num_epochs}, Steps: {step}, Loss: {epoch_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}%')

KeyboardInterrupt: 

In [None]:
# Save a final checkpoint with the current step counter plus the model and
# optimizer state, using the same keys as the periodic checkpoints.
final_state = {
    'step': step,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}
checkpoint_filename = f"{checkpoint_dir}final.pt"
torch.save(final_state, checkpoint_filename)