In [6]:
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback
from datasets import load_dataset, ClassLabel, load_from_disk
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from datetime import datetime
import numpy as np
import wandb
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

from model import MLPClassifier, MLPConfig

In [7]:
# --- Task selection -------------------------------------------------------
# Macro population the data is restricted to (matched against "Cluster_names").
sub_group = "Myeloid cells"
# Column holding the target classes for the classifier.
label_column = "Detailed_Cluster_names"

# --- Optimisation budget --------------------------------------------------
epochs = 200
batch_size = 256

# --- Output location for checkpoints and the final model ------------------
output_dir = "/equilibrium/datasets/TCGA-histological-data/scImmunology/output"

In [8]:
# Load the previously saved dataset from disk; it is indexed by split
# ('train' / 'validation' / 'test') in the cells below.
dataset_path = "/equilibrium/datasets/TCGA-histological-data/scImmunology/dataset_v2"
dataset = load_from_disk(dataset_path)

In [9]:
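# # Map Detailed_Cluster_names → Cluster_names (macro categories).
# # NOTE: sub-cluster names such as "Naive", "Naive-IFN", "Temra" and
# # "Proliferative" repeat across macro categories below; in a plain dict
# # literal the later entries silently overwrite the earlier ones.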
# sub_to_macro = {
#     # B cells
#     "Activated": "B cells",
#     "Atypical memory": "B cells",
#     "CD5+ B cells": "B cells",
#     "Naive": "B cells",
#     "Naive-IFN": "B cells",
#     "Non-switched memory": "B cells",
#     "Plasma cells": "B cells",
#     "Switched memory": "B cells",
#     "Transitional": "B cells",

#     # CD4+ T cells
#     "Exhausted-like memory": "CD4+ T cells",
#     "HLA-DR+ memory": "CD4+ T cells",
#     "Naive": "CD4+ T cells",
#     "Naive-IFN": "CD4+ T cells",
#     "Temra": "CD4+ T cells",
#     "Terminal effector": "CD4+ T cells",
#     "Tfh": "CD4+ T cells",
#     "Th1": "CD4+ T cells",
#     "Th1/Th17": "CD4+ T cells",
#     "Th17": "CD4+ T cells",
#     "Th2": "CD4+ T cells",
#     "Th22": "CD4+ T cells",
#     "Treg KLRB1+RORC+": "CD4+ T cells",
#     "Treg cytotoxic": "CD4+ T cells",
#     "Treg memory": "CD4+ T cells",
#     "Treg naive": "CD4+ T cells",

#     # gd T cells
#     "Vd1 GZMB+": "gd T cells",
#     "Vd1 GZMK+": "gd T cells",
#     "Vd2 GZMB+": "gd T cells",
#     "Vd2 GZMK+": "gd T cells",
#     "gd naive": "gd T cells",

#     # Myeloid cells
#     "Classical monocytes": "Myeloid cells",
#     "Non-classical monocytes": "Myeloid cells",
#     "cDCs": "Myeloid cells",
#     "pDCs": "Myeloid cells",

#     # NK cells
#     "CD56bright": "NK cells",
#     "CD56dim CD57+": "NK cells",
#     "CD56dim CD57-": "NK cells",
#     "CD56dim CD57int": "NK cells",
#     "CD56dim CD57low": "NK cells",
#     "Proliferative": "NK cells",   # ← in NK classifier

#     # TRAV1-2- CD8+ T cells
#     "HLA-DR+": "TRAV1-2- CD8+ T cells",
#     "NKT-like": "TRAV1-2- CD8+ T cells",
#     "Naive": "TRAV1-2- CD8+ T cells",
#     "Naive-IFN": "TRAV1-2- CD8+ T cells",
#     "Proliferative": "TRAV1-2- CD8+ T cells",  # ← in CD8 classifier
#     "Tcm CCR4+": "TRAV1-2- CD8+ T cells",
#     "Tcm CCR4-": "TRAV1-2- CD8+ T cells",
#     "Tem GZMB+": "TRAV1-2- CD8+ T cells",
#     "Tem GZMK+": "TRAV1-2- CD8+ T cells",
#     "Temra": "TRAV1-2- CD8+ T cells",
#     "Tmem KLRC2+": "TRAV1-2- CD8+ T cells",
#     "Trm": "TRAV1-2- CD8+ T cells",

#     # Direct macro categories (already defined as single classes)
#     "DN T cells": "DN T cells",
#     "MAIT cells": "MAIT cells",
#     "Progenitor cells": "Progenitor cells"
# }

# dataset = dataset.map(lambda x: {
#     "Cluster_names": sub_to_macro.get(x["Detailed_Cluster_names"], "Unknown")
# })


In [10]:
# When training a sub-type classifier, keep only the cells of the selected
# macro population.
if label_column != "Cluster_names":
    # Batched filter that reads only the "Cluster_names" column: the predicate
    # needs a single string per row, so there is no reason to decode every
    # full example one at a time.
    dataset = dataset.filter(
        lambda cluster_names: [name == sub_group for name in cluster_names],
        input_columns=["Cluster_names"],
        batched=True,
    )

# Label handling.
# Build the label vocabulary from *every* split: a class that is missing from
# the validation split but present in train/test would otherwise make the
# cast to ClassLabel fail on an unknown string.
unique_labels = sorted(
    {label for split in dataset.values() for label in split[label_column]}
)

# Fewer than two classes (including an empty dataset after filtering) cannot
# be used to train a classifier.
if len(unique_labels) < 2:
    raise ValueError(
        f"Found {len(unique_labels)} class(es) in the dataset: {unique_labels}. "
        "Cannot train a classifier with fewer than two classes."
    )

class_label = ClassLabel(names=unique_labels)
dataset = dataset.cast_column(label_column, class_label)

# Integer id <-> readable name mappings, taken from the cast feature so they
# match the encoding actually stored in the dataset.
label_feature = dataset['train'].features[label_column]
id2label = {i: label_feature.int2str(i) for i in range(label_feature.num_classes)}
label2id = {name: idx for idx, name in id2label.items()}

Filter:   0%|          | 0/1533093 [00:00<?, ? examples/s]

Filter: 100%|██████████| 1533093/1533093 [34:18<00:00, 744.75 examples/s]
Filter: 100%|██████████| 191637/191637 [04:16<00:00, 746.31 examples/s]
Filter: 100%|██████████| 191637/191637 [04:20<00:00, 736.91 examples/s]
Casting the dataset: 100%|██████████| 269548/269548 [01:06<00:00, 4023.44 examples/s] 
Casting the dataset: 100%|██████████| 33693/33693 [00:08<00:00, 3837.04 examples/s] 
Casting the dataset: 100%|██████████| 33694/33694 [00:14<00:00, 2317.96 examples/s] 


In [None]:
# Keep only the model input and the label column; drop everything else.
column_remove = [
    col
    for col in dataset['validation'].column_names
    if col not in (label_column, 'input_ids')
]
dataset = dataset.remove_columns(column_remove)

# Materialise tensors on the GPU when one is available, otherwise fall back
# to the CPU instead of crashing on a hard-coded 'cuda' device.
tensor_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Expose the train and validation splits as torch tensors.
# The test split keeps its default (Python objects) format, as before.
for split_name in ('train', 'validation'):
    dataset[split_name] = dataset[split_name].with_format(
        type='torch',
        columns=['input_ids', label_column],
        device=tensor_device,
    )

# Rename the target column to "labels" for the training and evaluation code.
dataset = dataset.rename_columns({label_column: "labels"})

In [9]:
# Size of one input vector.
# NOTE(review): presumably the length of each `input_ids` entry — confirm.
input_dim = 3072

# One output unit per known class. Derived from the label mapping rather than
# from the labels observed in the test split: a class absent from test would
# otherwise shrink the output layer below the number of classes in id2label.
output_dim = len(id2label)
# Class ids 0 .. output_dim-1 (kept for backward compatibility with code that
# reads `labels`).
labels = np.arange(output_dim)

hidden_dims = [1536, 768]
# Compact description of the hidden layout, e.g. "hdim_1536x768".
hidden_str = "hdim_" + "x".join(map(str, hidden_dims))

config = MLPConfig(
    input_dim=input_dim,
    hidden_dims=hidden_dims,
    output_dim=output_dim,
    dropout_rate=0.2,
    use_residual_in_hidden=True,
    id2label=id2label,
    label2id=label2id
)

model = MLPClassifier(config)

In [10]:
# Timestamped run name so repeated runs of the same configuration stay
# distinguishable.
current_time = datetime.now()
timestamp = current_time.strftime('%Y-%m-%d_%H-%M-%S')
run_name = f"{sub_group}_MLPClassifier_{hidden_str}_{timestamp}"

# Start the Weights & Biases run; runs are grouped by cell population.
wandb_settings = {
    "project": "scImmunologyClassification",
    "group": sub_group,
    "name": run_name,
}
wandb.init(**wandb_settings)

[34m[1mwandb[0m: Currently logged in as: [33mvincenzo-civale[0m ([33mvincenzo-civale-universi-degli-studi-di-firenze[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
def compute_metrics(eval_pred):
    """Compute accuracy and weighted F1 for one evaluation pass.

    Args:
        eval_pred: Pair ``(logits, labels)``. The predicted class of each
            row is the argmax of ``logits`` along the last axis.

    Returns:
        dict: ``{"accuracy": ..., "f1": ...}`` where ``f1`` uses the
        ``"weighted"`` average.
    """
    scores, references = eval_pred
    predicted = scores.argmax(axis=-1)
    return {
        "accuracy": accuracy_score(references, predicted),
        "f1": f1_score(references, predicted, average="weighted"),
    }

# Training arguments
training_args = TrainingArguments(
    output_dir=output_dir,
    logging_strategy="steps",
    logging_steps=50,
    eval_strategy="epoch",
    save_strategy="epoch",
    # A checkpoint is written every epoch; without a limit up to `epochs`
    # checkpoints pile up in output_dir. With load_best_model_at_end the best
    # checkpoint is always retained in addition to the most recent ones.
    save_total_limit=2,
    learning_rate=5e-4,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=epochs,
    weight_decay=0.01,
    logging_dir="./logs",
    load_best_model_at_end=True,
    report_to="wandb",
    # Keep every dataset column instead of letting the Trainer drop the ones
    # it cannot match to the model's forward signature.
    remove_unused_columns=False,
    optim="adamw_torch",
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    fp16=False,
    # Load batches in the main process and do not pin memory.
    # NOTE(review): presumably required because the formatted datasets may
    # already hold GPU tensors — confirm.
    dataloader_num_workers=0,
    dataloader_pin_memory=False
)

# Stop when the monitored metric has not improved for 10 consecutive evaluations.
early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=10)

# Multiprocessing: force the 'spawn' start method (force=True overrides any
# method that was set earlier in this kernel).
torch.multiprocessing.set_start_method('spawn', force=True)

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['validation'],
    compute_metrics=compute_metrics,
    callbacks=[early_stopping_callback],
)

trainer.train()



Epoch,Training Loss,Validation Loss


In [None]:
# Persist the trained model in a run-specific sub-directory of output_dir.
final_model_dir = f"{output_dir}/{run_name}"
trainer.save_model(final_model_dir)

In [None]:
def log_confusion_matrix(trainer, dataset, split_name):
    """Predict on ``dataset`` and log its confusion matrix to Weights & Biases.

    Args:
        trainer: Object exposing ``predict(dataset)`` whose result has
            ``predictions`` (per-class scores) and ``label_ids``.
        dataset: Split to run inference on.
        split_name: Name used in the W&B key ``confusion_matrix/<split_name>``.

    Returns:
        numpy.ndarray: Raw count matrix (rows = true class, columns =
        predicted class), ordered by the keys of the module-level
        ``id2label``. Previously this matrix was computed and discarded.
    """
    preds_output = trainer.predict(dataset)
    preds = np.argmax(preds_output.predictions, axis=-1)
    labels = preds_output.label_ids

    # Compute the confusion matrix over the full set of class ids, so classes
    # absent from this split still get a (zero) row and column.
    class_ids = list(id2label.keys())
    cm = confusion_matrix(labels, preds, labels=class_ids)

    # Log to W&B with human-readable class names.
    wandb.log({
        f"confusion_matrix/{split_name}": wandb.plot.confusion_matrix(
            y_true=labels,
            preds=preds,
            class_names=list(id2label.values())
        )
    })

    return cm

# Log the confusion matrix for the validation split, then the test split.
for split_name in ("validation", "test"):
    log_confusion_matrix(trainer, dataset[split_name], split_name)