In [1]:
import pandas as pd
train = pd.read_csv("train-no-tma.csv")
train.head()

Unnamed: 0,image_id,label,image_width,image_height,is_tma
0,38366,LGSC,31951,21718,False
1,63298,HGSC,26067,20341,False
2,54928,CC,36166,31487,False
3,18813,CC,54671,32443,False
4,63429,EC,67783,29066,False


In [2]:
validation = pd.read_csv("validation-no-tma.csv")
validation.head()

Unnamed: 0,image_id,label,image_width,image_height,is_tma
0,9658,CC,52900,45380,False
1,12522,EC,46605,45511,False
2,34845,HGSC,42908,25840,False
3,38585,LGSC,64822,30320,False
4,23523,MC,74723,45387,False


In [3]:
import os

def get_image_path(image_id:int):
    return os.path.join('tiles', str(image_id))

train['tile_path'] = train['image_id'].apply(lambda x: get_image_path(x))
train.head()

Unnamed: 0,image_id,label,image_width,image_height,is_tma,tile_path
0,38366,LGSC,31951,21718,False,tiles/38366
1,63298,HGSC,26067,20341,False,tiles/63298
2,54928,CC,36166,31487,False,tiles/54928
3,18813,CC,54671,32443,False,tiles/18813
4,63429,EC,67783,29066,False,tiles/63429


In [4]:
validation['tile_path'] = validation['image_id'].apply(lambda x: get_image_path(x))
validation.head()

Unnamed: 0,image_id,label,image_width,image_height,is_tma,tile_path
0,9658,CC,52900,45380,False,tiles/9658
1,12522,EC,46605,45511,False,tiles/12522
2,34845,HGSC,42908,25840,False,tiles/34845
3,38585,LGSC,64822,30320,False,tiles/38585
4,23523,MC,74723,45387,False,tiles/23523


In [5]:
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel

device = "cuda" if torch.cuda.is_available() else "cpu"
# model_name = "openai/clip-vit-large-patch14"
model_name = "openai/clip-vit-base-patch32"
print(f"Using device {device} and model {model_name}")

model = CLIPModel.from_pretrained(model_name).to(device)
processor = CLIPProcessor.from_pretrained(model_name)

Using device cuda and model openai/clip-vit-base-patch32


In [6]:
import os
from PIL import Image, ImageOps
from torch.utils.data import Dataset
import torchvision.transforms as transforms

integer_to_label = {
    0: 'HGSC',
    1: 'CC',
    2: 'EC',
    3: 'LGSC',
    4: 'MC',
}

label_to_integer = {
    'HGSC': 0,
    'CC': 1,
    'EC': 2,
    'LGSC': 3,
    'MC': 4,
}

classes = ["high-grade serous carcinoma", "clear-cell ovarian carcinoma", "endometrioid carcinoma", "low-grade serous carcinoma", "mucinous carcinoma"]

class ImageDataset(Dataset):
    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform
        self.image_paths = []
        self.labels = []
        for index, row in dataframe.iterrows():
            folder_path = row['tile_path']
            label = row['label']
            if os.path.isdir(folder_path):  # Check if the folder_path is a valid directory
                for image_name in os.listdir(folder_path):
                    if image_name.lower().endswith('.png'):  # Check if the file is a PNG
                        image_path = os.path.join(folder_path, image_name)
                        self.image_paths.append(image_path)
                        self.labels.append(label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx])
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label_to_integer[label]


In [12]:
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.transforms import autoaugment

"""# FOR YES TMA
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    autoaugment.RandAugment(),
    transforms.ToTensor(),
    # transforms.Normalize(mean=[0.6152, 0.5353, 0.5934], std=[0.2387, 0.2385, 0.2317]), # FOR YES TMA. calculated above
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # transforms.Normalize(mean=[0.6152, 0.5353, 0.5934], std=[0.2387, 0.2385, 0.2317]), # FOR YES TMA. calculated above
])"""

# FOR NO TMA
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

train_dataset = ImageDataset(dataframe=train, transform=train_transform)
val_dataset = ImageDataset(dataframe=validation, transform=val_transform)

train_dataloader = DataLoader(train_dataset, batch_size=64, num_workers=8, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64, num_workers=8, shuffle=True)

In [13]:
import logging
import sys

# Get the root logger
logger = logging.getLogger()

# Optional: Remove all existing handlers from the logger
for handler in logger.handlers[:]:
    logger.removeHandler(handler)

# Set the logging level
logger.setLevel(logging.INFO)

# Create a FileHandler and add it to the logger
file_handler = logging.FileHandler('training_log_finetune_non_tma.txt')
file_handler.setLevel(logging.INFO)
logger.addHandler(file_handler)

# Create a StreamHandler for stderr and add it to the logger
stream_handler = logging.StreamHandler(sys.stderr)
stream_handler.setLevel(logging.ERROR)  # Only log ERROR and CRITICAL messages to stderr
logger.addHandler(stream_handler)

In [14]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [15]:
import torch
import torch.optim as optim
import logging
import numpy as np
from sklearn.metrics import accuracy_score

# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-5, weight_decay=1e-3)

# Calculate class weights
class_counts = np.array([3521456, 1876772, 2126428, 589002, 1053114], dtype=np.float32) # These were derived by looking at the number of files in tile_path for each label
# class_counts = np.array([703, 690, 631, 581, 706], dtype=np.float32) # These were derived by looking at the number of files in tile_path for each label
class_weights = 1. / class_counts
class_weights /= class_weights.sum()

# Convert class weights to tensor
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

# Define the loss function with class weights
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

num_epochs = 100
best_val_accuracy = 0.0
step = 0

for epoch in range(num_epochs):
    model.train()  # set the model to training mode
    
    for i, (images, labels) in enumerate(train_dataloader, 0):
        # Convert images to PIL format
        images = images.to(device)
        labels = labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        # Assuming 'processor' and 'classes' are defined
        inputs = processor(text=classes, images=images, return_tensors="pt", padding=True)
        for key in inputs.keys():
            inputs[key] = inputs[key].to(device)

        outputs = model(**inputs)
        logits_per_image = outputs.logits_per_image
        loss = criterion(logits_per_image, labels)
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        step += 1

        logging.info('[%d, %5d] loss: %.3f' % (epoch + 1, step, loss.item()))

        if i % 100 == 0:
            model.eval()

            all_preds = []
            all_labels = []

            with torch.no_grad():
                for images, labels in val_dataloader:
                    images = images.to(device)
                    labels = labels.numpy()  # Convert labels to numpy array for later use in accuracy calculation

                    # Preprocess and forward pass
                    inputs = processor(text=classes, images=images, return_tensors="pt", padding=True)
                    for key in inputs.keys():
                        inputs[key] = inputs[key].to(device)

                    outputs = model(**inputs)
                    logits_per_image = outputs.logits_per_image
                    probs = logits_per_image.softmax(dim=1)
                    print(probs)

                    # Get predicted labels
                    preds = torch.argmax(probs, dim=1).cpu().numpy()

                    # Store predictions and labels
                    all_preds.extend(preds)
                    all_labels.extend(labels)
                    
                    if len(all_preds) > 10000:
                        break
        
            # Calculate accuracy
            accuracy = accuracy_score(all_labels, all_preds)
            logging.info("Validation Accuracy: %s" % accuracy)
            model.train()

        if i % 1000 == 0:
            # Assuming 'model' is defined
            torch.save(model.state_dict(), f'clip-finetune-non-tma-models/epoch_{epoch}_batch_{i}.pth')

    # Save model after each epoch
    torch.save(model.state_dict(), f'clip-finetune-non-tma-models/model_epoch_{epoch+1}.pth')
    logging.info(f'Model saved after epoch {epoch+1}')

logging.info('Finished Training')


tensor([[0.2100, 0.2214, 0.1905, 0.2021, 0.1760],
        [0.2100, 0.2214, 0.1905, 0.2021, 0.1760],
        [0.2100, 0.2214, 0.1905, 0.2021, 0.1760],
        [0.2100, 0.2214, 0.1905, 0.2021, 0.1760],
        [0.2100, 0.2214, 0.1905, 0.2021, 0.1760],
        [0.2100, 0.2214, 0.1905, 0.2021, 0.1760],
        [0.2100, 0.2214, 0.1905, 0.2021, 0.1760],
        [0.2100, 0.2214, 0.1905, 0.2021, 0.1760],
        [0.2100, 0.2214, 0.1905, 0.2021, 0.1760],
        [0.2100, 0.2214, 0.1905, 0.2021, 0.1760],
        [0.2100, 0.2214, 0.1905, 0.2021, 0.1760],
        [0.2100, 0.2214, 0.1905, 0.2021, 0.1760],
        [0.2100, 0.2214, 0.1905, 0.2021, 0.1760],
        [0.2100, 0.2214, 0.1905, 0.2021, 0.1760],
        [0.2100, 0.2214, 0.1905, 0.2021, 0.1760],
        [0.2100, 0.2214, 0.1905, 0.2021, 0.1760],
        [0.2100, 0.2214, 0.1905, 0.2021, 0.1760],
        [0.2100, 0.2214, 0.1905, 0.2021, 0.1760],
        [0.2100, 0.2214, 0.1905, 0.2021, 0.1760],
        [0.2100, 0.2214, 0.1905, 0.2021, 0.1760],


KeyboardInterrupt: 

In [None]:

#             # Validation function
#             model.eval()

#             all_preds = []
#             all_labels = []

#             with torch.no_grad():
#                 for images, labels in val_dataloader:
#                     images = images.to(device)
#                     labels = labels.numpy()  # Convert labels to numpy array for later use in accuracy calculation

#                     # Preprocess and forward pass
#                     inputs = processor(text=classes, images=images, return_tensors="pt", padding=True)
#                     for key in inputs.keys():
#                         inputs[key] = inputs[key].to(device)

#                     outputs = model(**inputs)
#                     logits_per_image = outputs.logits_per_image
#                     probs = logits_per_image.softmax(dim=1)

#                     # Get predicted labels
#                     preds = torch.argmax(probs, dim=1).cpu().numpy()

#                     # Store predictions and labels
#                     all_preds.extend(preds)
#                     all_labels.extend(labels)

#             # Calculate accuracy
#             accuracy = accuracy_score(all_labels, all_preds)
#             logging.info("Validation Accuracy: %s" % accuracy)