In [33]:
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback
from datasets import load_dataset, ClassLabel, load_from_disk
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from datetime import datetime
import numpy as np
import wandb
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler


In [34]:
# Experiment configuration
sub_group = "B cells"                    # macro group kept when filtering on 'Cluster_names'
label_column = "Detailed_Cluster_names"  # column used as the classification target

batch_size = 1_024  # per-device batch size, used for both training and evaluation
epochs = 200        # upper bound on epochs; early stopping may end training sooner

output_dir = "/equilibrium/datasets/TCGA-histological-data/scImmunology/output"

In [35]:
# Embedding dataset saved with `save_to_disk`; later cells use its
# 'train', 'validation' and 'test' splits.
dataset = load_from_disk("/equilibrium/datasets/TCGA-histological-data/scImmunology/dataset_v3")

In [36]:
# Restrict the data to one macro group (Cluster_names == sub_group) only when
# training a sub-type classifier. When the target itself is Cluster_names,
# this filter would leave a single class and the label-handling cell would
# refuse to train.
if label_column != "Cluster_names":
    dataset = dataset.filter(lambda example: example['Cluster_names'] == sub_group)

In [37]:
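# # Map Detailed_Cluster_names → Cluster_names (macro categories)
# # NOTE(review): several sub-type names ("Naive", "Naive-IFN", "Temra", "Proliferative")
# # are listed under more than one macro group. As dict keys, the later entries silently
# # overwrite the earlier ones, so this mapping cannot recover Cluster_names for those cells.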
# sub_to_macro = {
#     # B cells
#     "Activated": "B cells",
#     "Atypical memory": "B cells",
#     "CD5+ B cells": "B cells",
#     "Naive": "B cells",
#     "Naive-IFN": "B cells",
#     "Non-switched memory": "B cells",
#     "Plasma cells": "B cells",
#     "Switched memory": "B cells",
#     "Transitional": "B cells",

#     # CD4+ T cells
#     "Exhausted-like memory": "CD4+ T cells",
#     "HLA-DR+ memory": "CD4+ T cells",
#     "Naive": "CD4+ T cells",
#     "Naive-IFN": "CD4+ T cells",
#     "Temra": "CD4+ T cells",
#     "Terminal effector": "CD4+ T cells",
#     "Tfh": "CD4+ T cells",
#     "Th1": "CD4+ T cells",
#     "Th1/Th17": "CD4+ T cells",
#     "Th17": "CD4+ T cells",
#     "Th2": "CD4+ T cells",
#     "Th22": "CD4+ T cells",
#     "Treg KLRB1+RORC+": "CD4+ T cells",
#     "Treg cytotoxic": "CD4+ T cells",
#     "Treg memory": "CD4+ T cells",
#     "Treg naive": "CD4+ T cells",

#     # gd T cells
#     "Vd1 GZMB+": "gd T cells",
#     "Vd1 GZMK+": "gd T cells",
#     "Vd2 GZMB+": "gd T cells",
#     "Vd2 GZMK+": "gd T cells",
#     "gd naive": "gd T cells",

#     # Myeloid cells
#     "Classical monocytes": "Myeloid cells",
#     "Non-classical monocytes": "Myeloid cells",
#     "cDCs": "Myeloid cells",
#     "pDCs": "Myeloid cells",

#     # NK cells
#     "CD56bright": "NK cells",
#     "CD56dim CD57+": "NK cells",
#     "CD56dim CD57-": "NK cells",
#     "CD56dim CD57int": "NK cells",
#     "CD56dim CD57low": "NK cells",
#     "Proliferative": "NK cells",   # ← in NK classifier

#     # TRAV1-2- CD8+ T cells
#     "HLA-DR+": "TRAV1-2- CD8+ T cells",
#     "NKT-like": "TRAV1-2- CD8+ T cells",
#     "Naive": "TRAV1-2- CD8+ T cells",
#     "Naive-IFN": "TRAV1-2- CD8+ T cells",
#     "Proliferative": "TRAV1-2- CD8+ T cells",  # ← in CD8 classifier
#     "Tcm CCR4+": "TRAV1-2- CD8+ T cells",
#     "Tcm CCR4-": "TRAV1-2- CD8+ T cells",
#     "Tem GZMB+": "TRAV1-2- CD8+ T cells",
#     "Tem GZMK+": "TRAV1-2- CD8+ T cells",
#     "Temra": "TRAV1-2- CD8+ T cells",
#     "Tmem KLRC2+": "TRAV1-2- CD8+ T cells",
#     "Trm": "TRAV1-2- CD8+ T cells",

#     # Direct macro categories (already defined as single classes)
#     "DN T cells": "DN T cells",
#     "MAIT cells": "MAIT cells",
#     "Progenitor cells": "Progenitor cells"
# }

# dataset = dataset.map(lambda x: {
#     "Cluster_names": sub_to_macro.get(x["Detailed_Cluster_names"], "Unknown")
# })


In [38]:
# Label handling. The sub_group filter is applied once, right after loading
# the dataset; it is not repeated here.

# Build the label vocabulary from every split. Using only the validation split
# would make cast_column fail on train/test rows whose label never appears in
# validation.
unique_labels = sorted({value for split in dataset.values() for value in split[label_column]})

if not unique_labels:
    raise ValueError(f"No examples left for sub_group={sub_group!r}; check the filter on 'Cluster_names'.")
if len(unique_labels) == 1:
    raise ValueError(f"Only one class found in the dataset: {unique_labels}. Cannot train a classifier with a single class.")

class_label = ClassLabel(names=unique_labels)

# Encode the string labels as integer class ids (order = sorted label names).
dataset = dataset.cast_column(label_column, class_label)
label_feature = dataset['train'].features[label_column]
id2label = {i: label_feature.int2str(i) for i in range(label_feature.num_classes)}
label2id = {v: k for k, v in id2label.items()}

In [39]:
# Keep only the model input and the target, under the argument names that the
# model's forward() receives (input_ids / labels).
dataset = dataset.rename_column("embedding", "input_ids")
wanted = {"input_ids", label_column}
dataset = dataset.remove_columns(
    [name for name in dataset['validation'].column_names if name not in wanted]
)

# Serve the train/validation splits as torch tensors that are already on the GPU.
# NOTE(review): the test split is intentionally left unformatted (plain Python
# values); the model-construction cell calls np.unique on its labels.
for split in ('train', 'validation'):
    dataset[split] = dataset[split].with_format(
        type='torch',
        columns=['input_ids', label_column],
        device='cuda',
    )

dataset = dataset.rename_columns({label_column: "labels"})

In [40]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, List
from transformers.modeling_outputs import SequenceClassifierOutput


class MLPBlock(nn.Module):
    """One hidden layer: Linear -> BatchNorm1d -> GELU -> Dropout, with an optional skip.

    The skip connection (block output + block input) is only active when it is
    requested *and* ``input_dim == output_dim``, because the addition needs
    matching shapes.
    """

    def __init__(self, input_dim: int, output_dim: int, dropout_rate: float = 0.2, use_residual: bool = False):
        super().__init__()
        self.use_residual = use_residual and (input_dim == output_dim)

        # Attribute names and registration order define the state_dict keys;
        # keep them stable so saved checkpoints stay loadable.
        self.linear = nn.Linear(input_dim, output_dim)
        self.bn = nn.BatchNorm1d(output_dim)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        transformed = self.dropout(self.activation(self.bn(self.linear(x))))
        return transformed + x if self.use_residual else transformed


class AdvancedMLPClassifier(nn.Module):
    """MLP classifier over fixed-size feature vectors, used here with the HF Trainer.

    Architecture: BatchNorm1d on the input, a stack of ``MLPBlock`` layers
    (Linear -> BatchNorm1d -> GELU -> Dropout, with a residual connection where
    consecutive widths match), then a linear projection to class logits.

    Args:
        input_dim: size of each (flattened) input vector.
        hidden_dims: widths of the hidden layers, in order (a list, tuple or
            any other sequence of ints).
        output_dim: number of classes (size of the logits).
        dropout_rate: dropout probability used in every hidden block.
        use_residual_in_hidden: enable the residual connection in hidden blocks
            whose input and output widths are equal.
        loss_fn: loss called as ``loss_fn(logits, labels)``; defaults to
            ``nn.CrossEntropyLoss()``.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        output_dim: int,
        dropout_rate: float = 0.2,
        use_residual_in_hidden: bool = True,
        loss_fn: Optional[nn.Module] = None
    ):
        super().__init__()

        # Normalise the raw input features before the first hidden layer.
        self.initial_bn = nn.BatchNorm1d(input_dim)

        # Widths input -> hidden_dims[0] -> ... -> hidden_dims[-1]; unpacking into
        # a fresh list also accepts tuples and other sequences.
        all_dims = [input_dim, *hidden_dims]
        mlp_layers = [
            MLPBlock(
                input_dim=in_dim,
                output_dim=out_dim,
                dropout_rate=dropout_rate,
                use_residual=use_residual_in_hidden and (in_dim == out_dim),
            )
            for in_dim, out_dim in zip(all_dims[:-1], all_dims[1:])
        ]
        self.hidden_network = nn.Sequential(*mlp_layers)
        self.output_projection = nn.Linear(all_dims[-1], output_dim)
        self.loss_fn = loss_fn if loss_fn is not None else nn.CrossEntropyLoss()

        self._initialize_weights()

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = True,
        **kwargs
    ) -> SequenceClassifierOutput:
        """Compute logits and, when ``labels`` are given, the loss.

        Args:
            input_ids: batch of feature vectors, shape (batch, features); any
                extra trailing dimensions are flattened into the feature axis.
            labels: optional target class ids for ``loss_fn`` (integer ids when
                the default CrossEntropyLoss is used).
            attention_mask, token_type_ids, **kwargs: accepted for interface
                compatibility with text models and ignored.
            return_dict: when truthy, return a ``SequenceClassifierOutput``.

        Returns:
            ``SequenceClassifierOutput(loss, logits)`` when ``return_dict`` is
            truthy; otherwise a tuple that follows the Hugging Face convention:
            ``(loss, logits)`` when labels are given, else ``(logits,)``.
        """
        if input_ids.ndim > 2:
            # flatten() also works on non-contiguous tensors, unlike view().
            input_ids = input_ids.flatten(start_dim=1)

        x = self.initial_bn(input_ids)
        x = self.hidden_network(x)
        logits = self.output_projection(x)

        loss = self.loss_fn(logits, labels) if labels is not None else None

        if not return_dict:
            # Loss first, as in HF models: the Trainer reads outputs[0] as the
            # loss when a model returns a tuple.
            return (loss, logits) if loss is not None else (logits,)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None
        )

    def _initialize_weights(self):
        """Initialise all Linear and BatchNorm1d sub-modules in place.

        Every Linear (including the output projection) gets Kaiming-normal
        weights with the ReLU gain and zero bias. Every BatchNorm1d gets
        weight 1 and bias 0.
        NOTE(review): the hidden activation is GELU; the ReLU gain is only an
        approximation for it.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


In [41]:
# Infer the input width from the data rather than hard-coding it, so the model
# stays in sync with the embedding file. numel() matches the flattening that
# the model's forward() applies to inputs with extra dimensions.
input_dim = torch.as_tensor(dataset['train'][0]['input_ids']).numel()

# One logit per class in the label vocabulary. Counting the labels that occur
# in a single split would under-size the head whenever that split lacks a class.
output_dim = len(id2label)

hidden_dims = [3072, 1536, 768]
hidden_str = "hdim_" + "x".join(map(str, hidden_dims))

model = AdvancedMLPClassifier(input_dim, hidden_dims, output_dim)

In [42]:
# Run name: <sub_group>_MLPClassifier_<hidden dims>_<timestamp>. It names the
# W&B run and the directory the trained model is saved to.
timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
run_name = f"{sub_group}_MLPClassifier_{hidden_str}_{timestamp}"

# Start the W&B run that the Trainer (report_to="wandb") logs into.
wandb.init(project="scImmunologyClassification", name=run_name, group=sub_group)

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


In [43]:
def compute_metrics(eval_pred):
    """Trainer metric hook: accuracy and support-weighted F1 of argmax predictions.

    Args:
        eval_pred: ``(logits, label_ids)`` pair as passed by the Trainer.

    Returns:
        dict with keys ``"accuracy"`` and ``"f1"`` (weighted average over classes).
    """
    raw_logits, y_true = eval_pred
    y_pred = raw_logits.argmax(axis=-1)
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "f1": f1_score(y_true, y_pred, average="weighted"),
    }

# Training configuration. No metric_for_best_model is set, so best-checkpoint
# selection and early stopping use the Trainer's default (the evaluation loss).
training_args = TrainingArguments(
    # Output and reporting
    output_dir=output_dir,
    logging_dir="./logs",
    report_to="wandb",

    # Optimisation
    optim="adamw_torch",
    learning_rate=5e-4,
    weight_decay=0.01,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    num_train_epochs=epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    fp16=False,

    # Evaluate and checkpoint every epoch; reload the best checkpoint at the end
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    logging_strategy="steps",
    logging_steps=50,

    # Data loading: the train/validation splits already yield CUDA tensors, so
    # load in the main process and skip pinning (which applies to CPU tensors)
    remove_unused_columns=False,
    dataloader_num_workers=0,
    dataloader_pin_memory=False,
)

# Stop once 10 consecutive evaluations bring no improvement.
early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=10)

# 'spawn' start method, in case any worker processes touch CUDA.
torch.multiprocessing.set_start_method('spawn', force=True)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['validation'],
    compute_metrics=compute_metrics,
    callbacks=[early_stopping_callback],
)

trainer.train()



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,1.010747,0.71275,0.691885
2,1.454600,0.791379,0.755481,0.748615
3,1.454600,0.731254,0.757715,0.75216
4,0.793100,0.707282,0.761207,0.756611
5,0.793100,0.696555,0.767211,0.76254
6,0.710400,0.693142,0.764837,0.759481
7,0.710400,0.690909,0.766792,0.763061
8,0.671100,0.691518,0.765536,0.762998
9,0.648600,0.683509,0.767072,0.761998
10,0.648600,0.686672,0.765117,0.758276




TrainOutput(global_step=588, training_loss=0.7029283452196186, metrics={'train_runtime': 398.1796, 'train_samples_per_second': 28776.461, 'train_steps_per_second': 14.064, 'total_flos': 0.0, 'train_loss': 0.7029283452196186, 'epoch': 21.0})

In [44]:
# Save the final weights (the best checkpoint, since load_best_model_at_end=True) under output_dir/run_name.
trainer.save_model(f"{output_dir}/{run_name}")

In [45]:
def log_confusion_matrix(trainer, dataset, split_name):
    """Predict on ``dataset``, log a W&B confusion-matrix plot and return the matrix.

    Args:
        trainer: trained Trainer used for prediction.
        dataset: split to evaluate; it must provide a ``labels`` column.
        split_name: suffix of the W&B key ``confusion_matrix/<split_name>``.

    Returns:
        The sklearn confusion matrix (rows = true class, columns = predicted
        class), with classes ordered by id as in ``id2label``.

    Raises:
        ValueError: if the prediction output contains no labels.
    """
    preds_output = trainer.predict(dataset)
    preds = np.argmax(preds_output.predictions, axis=-1)
    labels = preds_output.label_ids
    if labels is None:
        raise ValueError(f"No labels returned for split {split_name!r}; cannot build a confusion matrix.")

    # Fix the class order so every split yields a matrix with the same layout,
    # even when some classes are absent from the split.
    cm = confusion_matrix(labels, preds, labels=list(id2label.keys()))

    # Log to W&B with human-readable class names; plain lists keep the plot
    # helper independent of numpy scalar types.
    wandb.log({
        f"confusion_matrix/{split_name}": wandb.plot.confusion_matrix(
            y_true=np.asarray(labels).tolist(),
            preds=preds.tolist(),
            class_names=list(id2label.values()),
        )
    })
    return cm


# Log confusion matrices for the validation and test splits and keep them for inspection.
cm_validation = log_confusion_matrix(trainer, dataset["validation"], "validation")
cm_test = log_confusion_matrix(trainer, dataset["test"], "test")



