## The Dataset

In [2]:
import os, sys

# sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from torchvision.io import read_image
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset
import pandas as pd


class Rxrx1(Dataset):
    """Map-style dataset over the RxRx1 image folder.

    Reads ``<root_dir>/rxrx1_v1.0/metadata.csv``, replaces its ``cell_type``
    column with integer codes from a ``LabelEncoder`` (kept as ``self.le_``),
    and indexes one PNG per metadata row located at
    ``images/<experiment>/Plate<plate>/<well>_s<site>.png``.
    Each item is ``(image_tensor, encoded_cell_type, sirna_id)``.
    """

    def __init__(self, root_dir=None):
        self.le_ = LabelEncoder()

        self.root_dir = os.path.join(root_dir, "rxrx1_v1.0")
        self.imgs_dir = os.path.join(self.root_dir, "images")
        self.metadata = pd.read_csv(os.path.join(self.root_dir, "metadata.csv"))
        # Learn the sorted set of cell-type names and swap the names for their integer codes.
        self.metadata['cell_type'] = self.le_.fit_transform(self.metadata['cell_type'])

        self.items = []
        for row in self.metadata.itertuples(index=False):
            plate_dir = os.path.join(self.imgs_dir, row.experiment, "Plate" + str(row.plate))
            file_name = row.well + '_s' + str(row.site) + '.png'
            self.items.append((os.path.join(plate_dir, file_name), row.cell_type, row.sirna_id))

    def __getitem__(self, index):
        path, label, sirna = self.items[index]
        image = read_image(path)
        return (image, label, sirna)

    def __len__(self):
        return len(self.items)


## Network

In [3]:
import torch
import torch.nn as nn
from torchvision import models


class SimCLR(nn.Module):
    """ResNet-18 encoder followed by a two-layer MLP projection head.

    The classifier (``fc``) of the ResNet is replaced by ``nn.Identity`` so the
    backbone returns its pooled features; the projection head maps them to
    ``out_dim`` dimensions.

    Args:
        hidden_dim: width of the projection head's hidden layer (default 128).
        out_dim: size of the returned projections (default 64).
        weights: passed unchanged to torchvision's ``models.resnet18``
            (default ``'DEFAULT'``, as before).
    """

    def __init__(self, hidden_dim=128, out_dim=64, weights='DEFAULT'):
        super().__init__()
        self.backbone = models.resnet18(weights=weights)
        # Read the feature width from the classifier before removing it instead of hard-coding 512.
        feature_dim = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()  # fully-connected removed
        self.projection_head = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        """Return the projection-head output for the image batch ``x``."""
        features = self.backbone(x)
        return self.projection_head(features)


## Utils

In [4]:
import os, sys

from prettytable import PrettyTable
from wilds import get_dataset
import torch
import torch.nn.functional as F
import os
from source.net import SimCLR
from sklearn.cluster import KMeans
from tqdm import tqdm
import numpy as np
from scipy.stats import mode


def display_configs(configs):
    """Print each ``name -> value`` pair of ``configs`` as a right-aligned two-column table."""
    rows = [[name, value] for name, value in configs.items()]
    table = PrettyTable(field_names=["Name", "Value"])
    table.align = "r"
    for row in rows:
        table.add_row(row)
    print(table, flush=True)


def load_device(config):
    """Pick the compute device named by ``config["device"]``.

    Any value other than ``"gpu"`` yields ``torch.device("cpu")``. For ``"gpu"``
    the function asserts CUDA is available, prints the name of every visible
    GPU, and returns the string ``"cuda:0"``.
    """
    if config["device"] != "gpu":
        return torch.device("cpu")

    assert torch.cuda.is_available(), "Notebook is not configured properly!"
    device = "cuda:0"
    print("Training network on {}".format(torch.cuda.get_device_name(device=device)))
    for gpu_index in range(torch.cuda.device_count()):
        print(torch.cuda.get_device_properties(gpu_index).name)
    return device


def download_dataset(root_dir=""):
    """Download (if needed) the WILDS ``rxrx1`` dataset and return the dataset object.

    Args:
        root_dir: directory handed to ``wilds.get_dataset``. The default ``""``
            keeps the original call; presumably it resolves relative to the
            current working directory (TODO confirm against the WILDS docs).

    Returns:
        Whatever ``wilds.get_dataset`` returns for ``rxrx1``.
    """
    return get_dataset(dataset="rxrx1", download=True, root_dir=root_dir)


def contrastive_loss(features, device, temperature=0.5):
    """NT-Xent contrastive loss for a batch of paired views.

    ``features`` stacks two views of the same N samples: rows ``0..N-1`` are the
    first views and rows ``N..2N-1`` the second views, in the same order (the
    trainer builds its input block as ``cat([standard_views, augmented_views])``).
    For every row the positive is its counterpart in the other half. All
    other rows are negatives, and the similarity of a row with itself is excluded.

    Args:
        features: tensor of shape (2N, D).
        device: device on which the target indices are placed.
        temperature: divisor applied to the cosine similarities.

    Returns:
        Scalar loss tensor (mean cross-entropy over the 2N rows).

    Raises:
        ValueError: if ``features`` has zero or an odd number of rows.
    """
    n_rows = features.shape[0]
    if n_rows == 0 or n_rows % 2 != 0:
        raise ValueError(
            f"contrastive_loss expects a non-empty, even number of rows (two views), got {n_rows}")
    batch_size = n_rows // 2

    features = F.normalize(features, dim=1)
    similarity_matrix = torch.mm(features, features.T) / temperature
    # Self-similarity is always maximal; mask it so it is neither the target nor a negative.
    self_mask = torch.eye(n_rows, dtype=torch.bool, device=similarity_matrix.device)
    similarity_matrix = similarity_matrix.masked_fill(self_mask, float("-inf"))

    # Row i's positive is the other view of the same sample: i + N (first half) or i - N (second half).
    labels = ((torch.arange(n_rows) + batch_size) % n_rows).to(device)
    return F.cross_entropy(similarity_matrix, labels)


def info_nce_loss(features, device, temperature=0.5):
    """
    InfoNCE / NT-Xent loss over two stacked views, as described in the SimCLR paper.
    Adapted from https://github.com/sthalles/SimCLR/blob/master/simclr.py
    Args:
        - features: torch tensor of shape (2*N, D) where N is the batch size.
            Rows 0..N-1 are the original views and rows N..2N-1 the modified
            views of the same samples, in the same order.
        - device: torch device for the masks and targets
        - temperature: float dividing the cosine similarities
    Returns:
        Scalar cross-entropy where, for each row, the single positive (the other
        view of the same sample) is placed in column 0 of the logits.
    """
    n_views = 2
    assert features.shape[0] % n_views == 0  # make sure shapes are correct
    batch_size = features.shape[0] // n_views
    total = n_views * batch_size

    # same_sample[i, j] == 1 when rows i and j are views of the same sample.
    sample_ids = torch.arange(batch_size).repeat(n_views)
    same_sample = (sample_ids[None, :] == sample_ids[:, None]).float().to(device)

    normed = F.normalize(features, dim=1)
    sims = torch.matmul(normed, normed.T)
    assert sims.shape == (total, total)
    assert sims.shape == same_sample.shape

    # Drop each row's similarity with itself from both matrices.
    off_diagonal = ~torch.eye(total, dtype=torch.bool).to(device)
    same_sample = same_sample[off_diagonal].view(total, -1)
    sims = sims[off_diagonal].view(total, -1)

    is_positive = same_sample.bool()
    positives = sims[is_positive].view(total, -1)
    negatives = sims[~is_positive].view(total, -1)

    # Positive first, so the target class of every row is index 0.
    logits = torch.cat([positives, negatives], dim=1) / temperature
    targets = torch.zeros(total, dtype=torch.long).to(device)
    return F.cross_entropy(logits, targets)


def config_loader(config):
    """Build ``(network, loss_function, optimizer)`` from a configuration mapping.

    Matching is by substring, as before:
      - ``config["net"]`` containing ``"simclr"`` -> ``SimCLR()``
      - ``config["loss"]`` containing ``"contrastive"`` -> ``contrastive_loss``;
        containing ``"NCE"`` -> ``info_nce_loss`` (checked second, so it wins
        if both substrings are present)
      - ``config["opt"]`` containing ``"adam"`` -> ``torch.optim.Adam`` with
        learning rate ``config.get("lr", 0.005)``

    Raises:
        ValueError: if any of the three names is not recognised. Previously,
            unrecognised names left ``Ellipsis`` placeholders that either
            crashed later or were returned to the caller.
    """
    net_name = str(config["net"])
    loss_name = str(config["loss"])
    opt_name = str(config["opt"])

    if "simclr" in net_name:
        net = SimCLR()
    else:
        raise ValueError(f"Unsupported net {net_name!r}: expected a name containing 'simclr'.")

    loss = None
    if "contrastive" in loss_name:
        loss = contrastive_loss
    if "NCE" in loss_name:
        loss = info_nce_loss
    if loss is None:
        raise ValueError(f"Unsupported loss {loss_name!r}: expected 'contrastive' or 'NCE'.")

    if "adam" in opt_name:
        opt = torch.optim.Adam(net.parameters(), lr=float(config.get("lr", 0.005)))
    else:
        raise ValueError(f"Unsupported optimizer {opt_name!r}: expected a name containing 'adam'.")

    return (net, loss, opt)


def test_kmean_accuracy(net, test_loader, device, n_clusters=4, random_state=42):
    """Score how well ``net``'s features separate the labels using K-Means.

    Features are extracted for every batch of ``test_loader`` and clustered with
    K-Means. Each cluster is then mapped to the most frequent ground-truth label
    among its members, and the fraction of samples whose mapped label matches
    their true label is reported.

    Args:
        net: module mapping an image batch to a (batch, dim) feature tensor.
            Its previous train/eval mode is restored before returning.
        test_loader: iterable of ``(images, labels, _)`` batches. Labels must be
            non-negative integer class ids.
        device: device the images are moved to before the forward pass.
        n_clusters: number of K-Means clusters (default 4; presumably the number
            of distinct ``cell_type`` labels — TODO confirm for other data).
        random_state: seed for K-Means.

    Returns:
        float: accuracy in ``[0, 1]``. It is also printed as a percentage.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    was_training = net.training
    net.eval()
    feature_chunks, label_chunks = [], []
    try:
        with torch.no_grad():
            for x_batch, y_batch, _ in tqdm(test_loader):  # the test set must carry labels
                if not torch.is_floating_point(x_batch):
                    # Integer images (uint8 from read_image) are rescaled to [0, 1] to match the
                    # ToDtype(torch.float32, scale=True) preprocessing the network is trained with.
                    x_batch = x_batch.to(torch.float) / torch.iinfo(x_batch.dtype).max
                features = net(x_batch.to(torch.float).to(device))  # feature extraction
                # Collect on the CPU so the whole test set is not kept in device memory.
                feature_chunks.append(features.cpu())
                label_chunks.append(y_batch.cpu())
    finally:
        net.train(was_training)

    if not feature_chunks:
        raise ValueError("test_loader yielded no batches")

    test_features = torch.cat(feature_chunks).numpy()
    test_labels = torch.cat(label_chunks).numpy()

    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state)
    predicted_clusters = kmeans.fit_predict(test_features)

    # Majority label per cluster. On ties, bincount().argmax() picks the smallest label.
    # Empty clusters keep 0, but no sample is assigned to them, so that value is never used.
    cluster_to_class = np.zeros(n_clusters, dtype=test_labels.dtype)
    for cluster_id in range(n_clusters):
        members = test_labels[predicted_clusters == cluster_id]
        if members.size:
            cluster_to_class[cluster_id] = np.bincount(members).argmax()

    mapped_predictions = cluster_to_class[predicted_clusters]
    accuracy = float(np.mean(mapped_predictions == test_labels))
    print(f"Test Accuracy: {accuracy * 100:.2f}%")
    return accuracy


def validation_loss(net, val_loader, device, transform, std_transform, loss_func):
    """Compute the contrastive loss on every batch of ``val_loader`` without updating ``net``.

    For each batch, the standard views (``std_transform``) are stacked first and
    the augmented views (``transform``) second. The block is passed through
    ``net``, and ``loss_func(features, device)`` is evaluated under
    ``torch.no_grad()``. The module's previous train/eval mode is restored
    afterwards, so a training loop calling this is not left in eval mode.

    Returns:
        list[float]: one loss value per batch.
    """
    validation_loss_values = []
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad(), tqdm(total=len(val_loader), desc="validation") as pbar:
            for x_batch, _, _ in val_loader:
                standard_views = torch.cat(
                    [std_transform(img.unsqueeze(0)) for img in x_batch], dim=0).to(device)
                augmented_views = torch.cat(
                    [transform(img.unsqueeze(0)) for img in x_batch], dim=0).to(device)
                block = torch.cat([standard_views, augmented_views], dim=0)
                out_feat = net(block.to(torch.float))
                loss_value = loss_func(out_feat, device).item()

                validation_loss_values.append(loss_value)
                pbar.update(1)
                pbar.set_postfix({"Validation Loss": loss_value})
    finally:
        net.train(was_training)

    return validation_loss_values


def save_model(epoch, net, opt, train_loss, val_loss, batch_size, checkpoint_dir, optimizer, scheduler=None):
    """Write a training checkpoint to ``<checkpoint_dir>/checkpoint<epoch + 1>``.

    Args:
        epoch: index of the epoch that just finished (stored as-is).
        net: model to save. A ``DataParallel``/``DistributedDataParallel``
            wrapper is unwrapped first, so parameter names carry no
            ``module.`` prefix and the checkpoint loads with or without a wrapper.
        opt: optimizer whose ``state_dict`` is saved.
        train_loss: training-loss history to store.
        val_loss: validation-loss history to store.
        batch_size: batch size used for training.
        checkpoint_dir: existing directory to write into.
        optimizer: optimizer *name* from the config, stored under ``"optimizer"``.
        scheduler: optional LR scheduler. When given, its ``state_dict`` is stored
            under ``"scheduler_state_dict"``, the key the trainer's checkpoint
            loader looks for.
    """
    wrappers = (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
    model = net.module if isinstance(net, wrappers) else net

    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": opt.state_dict(),
        "training_loss_values": train_loss,
        "validation_loss_values": val_loss,
        "batch_size": batch_size,
        "optimizer": optimizer,
    }
    if scheduler is not None:
        checkpoint["scheduler_state_dict"] = scheduler.state_dict()

    torch.save(checkpoint, os.path.join(checkpoint_dir, "checkpoint{}".format(epoch + 1)))

def load_yaml(inFile):
    """Load a YAML training configuration, print it, and sanity-check its paths.

    Args:
        inFile: path to the YAML file.

    Returns:
        dict: the parsed configuration.

    Raises:
        AssertionError: if ``checkpoint_dir`` or ``dataset_dir`` is not an
            existing directory, or if ``load_checkpoint`` is present but is not
            an existing file.
    """
    # Imported locally: these names were otherwise only brought in by a later notebook cell.
    import yaml
    from pathlib import Path

    with open(inFile, "r", encoding="utf-8") as f:
        # NOTE(review): FullLoader is kept for compatibility; yaml.safe_load suffices for plain configs.
        config = yaml.load(f, Loader=yaml.FullLoader)
    display_configs(config)

    assert Path(config['checkpoint_dir']).is_dir(), "Please provide a valid directory to save checkpoints in."
    assert Path(config['dataset_dir']).is_dir(), "Please provide a valid directory to load dataset."
    if 'load_checkpoint' in config:
        # The trainer hands this path straight to torch.load, so it must be a file, not a directory.
        assert Path(config['load_checkpoint']).is_file(), "Please provide a valid checkpoint file to load."

    return config

## Trainer

In [5]:
import os, sys

import torch, torchvision, yaml
import torchvision.transforms.v2 as transforms
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import DataLoader, random_split
from pathlib import Path

SEED = 42
torch.manual_seed(SEED)  # seed torch's global RNG so weight init and random splits are reproducible


class Trainer:
    """Two-view self-supervised training loop with checkpointing and periodic evaluation.

    Args:
        net: model mapping a float image batch to projection features. The
            K-Means probe in ``train`` uses its ``.backbone`` attribute.
        device: target device for the model and batches. It is also forwarded
            to ``loss_func`` and the evaluation helpers.
        config: mapping read for ``batch_size``, ``epochs``, ``model_save_freq``,
            ``evaluation_freq``, ``checkpoint_dir`` and ``opt``, and optionally
            ``multiple_gpus``, ``load_checkpoint`` and ``split_seed``.
        opt: optimizer over ``net``'s parameters.
        loss_func: callable ``(features, device) -> scalar loss tensor``.
        scheduler: optional LR scheduler. This class only restores its state
            from a checkpoint; it never calls ``scheduler.step()``.
    """

    def __init__(self, net, device, config, opt, loss_func, scheduler=None):
        self.net = net
        self.device = device
        self.config = config
        self.opt = opt
        self.loss_func = loss_func
        self.scheduler = scheduler

    def _unwrapped_net(self):
        """Return the underlying model, stripping an ``nn.DataParallel`` wrapper if present."""
        return self.net.module if isinstance(self.net, nn.DataParallel) else self.net

    def load_checkpoint(self):
        """Restore model, optimizer (and scheduler, if any) from ``config['load_checkpoint']``.

        A leading ``module.`` prefix (present when a checkpoint was saved from an
        ``nn.DataParallel`` wrapper) is stripped from parameter names, so the
        checkpoint loads whether or not it was saved from a wrapped model.

        Side effect: overwrites ``self.config['batch_size']`` with the batch size
        stored in the checkpoint.

        Returns:
            ``(last_epoch, training_loss_values, validation_loss_values)``, where
            ``last_epoch`` is the epoch index stored in the checkpoint.
        """
        checkpoint = torch.load(self.config['load_checkpoint'], map_location=self.device)

        prefix = "module."
        model_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in checkpoint['model_state_dict'].items()
        }
        self._unwrapped_net().load_state_dict(model_dict)
        self.opt.load_state_dict(checkpoint['optimizer_state_dict'])

        if self.scheduler is not None:
            if 'scheduler_state_dict' in checkpoint:
                self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            else:
                print("Scheduler not found in the checkpoint.")

        self.config['batch_size'] = checkpoint['batch_size']
        return (checkpoint['epoch'], checkpoint['training_loss_values'], checkpoint['validation_loss_values'])

    def train(self, split_sizes, dataset, transform, std_transform):
        """Run the training loop and return the recorded loss histories.

        Args:
            split_sizes: fractions of ``dataset``. Only the first two entries
                (train, val) are read; the test split receives the remainder.
            dataset: map-style dataset yielding ``(image, label, extra)`` items.
            transform: random augmentation producing the second view of each image.
            std_transform: deterministic transform producing the first view.

        Returns:
            ``(training_loss_values, validation_loss_values)``: per-batch losses,
            including any history restored from a checkpoint.
        """
        train_dataset, val_dataset, test_dataset = self._split_dataset(dataset, split_sizes)

        # Move the model to the device *before* restoring the checkpoint:
        # Optimizer.load_state_dict places its state on the parameters' current
        # device, so loading first would leave the optimizer state on the CPU.
        base_net = self._unwrapped_net().to(self.device)
        self.net = base_net

        if 'load_checkpoint' in self.config:
            print('Loading latest checkpoint... ')
            last_epoch, training_loss_values, validation_loss_values = self.load_checkpoint()
            print(f"Checkpoint {self.config['load_checkpoint']} Loaded")
            # Checkpoints are written after epoch `last_epoch` completes; resume at the next one.
            start_epoch = last_epoch + 1
        else:
            start_epoch = 0
            training_loss_values = []  # store every training loss value
            validation_loss_values = []  # store every validation loss value

        if self.config.get('multiple_gpus', False):
            self.net = nn.DataParallel(base_net)

        # Built after the checkpoint is loaded so a restored batch size is honoured.
        batch_size = self.config['batch_size']
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

        for epoch in range(start_epoch, int(self.config['epochs'])):
            # Re-enter train mode every epoch: the evaluation helpers below may leave the model in eval mode.
            self.net.train()
            self._train_one_epoch(epoch, train_dataloader, transform, std_transform, training_loss_values)

            if (epoch + 1) % int(self.config['model_save_freq']) == 0:
                # Save the unwrapped model so checkpoint keys carry no "module." prefix.
                save_model(epoch, base_net, self.opt, training_loss_values, validation_loss_values,
                           self.config['batch_size'], self.config['checkpoint_dir'], self.config['opt'])
                # `.backbone` lives on the unwrapped model; DataParallel does not forward it.
                test_kmean_accuracy(base_net.backbone, test_dataloader, self.device)

            if (epoch + 1) % int(self.config['evaluation_freq']) == 0:
                print("Running Validation...")
                validation_loss_values += validation_loss(
                    self.net, self._make_val_loader(val_dataset), self.device,
                    transform, std_transform, self.loss_func)

        return training_loss_values, validation_loss_values

    def _split_dataset(self, dataset, split_sizes):
        """Randomly split ``dataset`` into train/val/test subsets; the test split takes the remainder."""
        n_items = len(dataset)
        train_size = int(split_sizes[0] * n_items)
        val_size = int(split_sizes[1] * n_items)
        test_size = n_items - train_size - val_size
        # A dedicated generator keeps the split identical across runs and resumes,
        # regardless of how much global RNG state was consumed beforehand.
        generator = torch.Generator().manual_seed(int(self.config.get('split_seed', 42)))
        return random_split(dataset, [train_size, val_size, test_size], generator=generator)

    def _make_views(self, x_batch, transform, std_transform):
        """Stack all standard views, then all augmented views, as one float batch on the device."""
        standard_views = torch.cat([std_transform(img.unsqueeze(0)) for img in x_batch], dim=0)
        augmented_views = torch.cat([transform(img.unsqueeze(0)) for img in x_batch], dim=0)
        return torch.cat([standard_views, augmented_views], dim=0).to(self.device, dtype=torch.float)

    def _train_one_epoch(self, epoch, loader, transform, std_transform, loss_history):
        """Run one optimisation pass over ``loader``, appending every batch loss to ``loss_history``."""
        with tqdm(total=len(loader), desc=f"Epoch-{epoch}") as pbar:
            for x_batch, _, _ in loader:
                out_feat = self.net(self._make_views(x_batch, transform, std_transform))
                loss = self.loss_func(out_feat, self.device)

                self.opt.zero_grad()
                loss.backward()
                self.opt.step()

                loss_value = loss.item()
                loss_history.append(loss_value)
                pbar.update(1)
                pbar.set_postfix({'Loss': loss_value})

    def _make_val_loader(self, val_dataset):
        """Build the validation DataLoader; memory pinning is requested only for CUDA devices."""
        return DataLoader(val_dataset, batch_size=self.config['batch_size'],
                          pin_memory=str(self.device).startswith("cuda"),
                          shuffle=True, num_workers=4, drop_last=True, prefetch_factor=1)

## Main

In [6]:
if __name__ == "__main__":

    # The config path can be overridden via an environment variable instead of editing the notebook.
    config_path = os.environ.get(
        "AIBIO_TRAIN_CONFIG", "/homes/nmorelli/AIBIO_proj/config/train/server_conf.yaml")
    config = load_yaml(config_path)
    device = load_device(config)
    dataset = Rxrx1(config['dataset_dir'])
    net, loss_func, opt = config_loader(config)

    # no augmentation: only convert to a float32 image scaled to [0, 1]
    std_transform = transforms.Compose(
        [transforms.ToImage(), transforms.ToDtype(torch.float32, scale=True)])
    # randomly augmented second view for self-supervised learning
    transform = transforms.Compose([
        transforms.RandomResizedCrop(256),
        transforms.ColorJitter(brightness=0.5, contrast=0.5),
        transforms.GaussianBlur(3, sigma=(0.1, 2.0)),
        transforms.ToImage(),
        transforms.ToDtype(torch.float32, scale=True),
    ])

    tr_ = Trainer(net, device, config, opt, loss_func)

    # Trainer.train reads only the train and val fractions and gives the test split the
    # remainder (1 - 0.7 - 0.05 = 0.25), so the third entry states that value explicitly.
    split_sizes = [0.7, 0.05, 0.25]
    training_loss_values, validation_loss_values = tr_.train(split_sizes, dataset, transform, std_transform)

+-----------------+-------------------------------------------+
|            Name |                                     Value |
+-----------------+-------------------------------------------+
|      batch_size |                                       128 |
|     dataset_dir |                    /work/ai4bio2024/rxrx1 |
|  checkpoint_dir | /work/ai4bio2024/rxrx1/checkpoints/simclr |
|          device |                                       gpu |
|          epochs |                                         4 |
|             net |                                    simclr |
|            loss |                                       NCE |
|             opt |                                      adam |
| evaluation_freq |                                         4 |
| model_save_freq |                                         4 |
|   multiple_gpus |                                      True |
+-----------------+-------------------------------------------+
Training network on Tesla P100-PCIE-16GB

validation: 100%|██████████| 49/49 [00:43<00:00,  1.13it/s, Validation Loss=5.43]
Epoch-0:   0%|          | 0/687 [00:00<?, ?it/s]

: 