In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
from datasets import load_dataset, ClassLabel, load_from_disk
from transformers import Trainer, TrainingArguments, default_data_collator
from huggingface_hub import login
from sklearn.metrics import accuracy_score, f1_score
from torch.utils.data import DataLoader, WeightedRandomSampler
from sklearn.utils.class_weight import compute_class_weight
import torch.nn.functional as F
from datetime import datetime
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import wandb
from sklearn.metrics import confusion_matrix



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
SUB_GROUP = "All"


In [3]:
LABEL_COLUMN = "Cluster_names"


In [4]:
dataset = load_from_disk(f"/equilibrium/datasets/TCGA-histological-data/scImmunology/dataset_v3")

In [5]:
# convert labels to ClassLabel type
class_names = dataset['validation'].unique(LABEL_COLUMN)
class_names.sort()
num_classes = len(class_names)
class_label = ClassLabel(num_classes=num_classes, names=class_names)
dataset = dataset.cast_column(LABEL_COLUMN, class_label)

In [6]:
dataset = dataset.rename_column("Cluster_names", "labels")
dataset = dataset.rename_column("embedding", "input_ids")

In [7]:
import torch
import torch.nn as nn
from transformers import Trainer, TrainingArguments
from transformers.modeling_outputs import SequenceClassifierOutput

class SimpleRNAClassifier(nn.Module):
    def __init__(self, input_dim=3072, num_classes=10):
        super().__init__()
        self.main = nn.Sequential(
            nn.BatchNorm1d(input_dim),

            # Blocco 1
            ResidualBlock(input_dim, 2048),
            nn.Dropout(0.4),

            # Blocco 2
            ResidualBlock(2048, 1024),
            nn.Dropout(0.3),

            # Blocco finale
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.GELU(),
            nn.Dropout(0.2),

            nn.Linear(512, num_classes)
        )

        # Inizializzazione pesi
        self._init_weights()

    def _init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='gelu')
                nn.init.zeros_(module.bias)

    def forward(self, input_ids, labels=None, **kwargs):
        # input_ids: (batch_size, seq_len) -> (batch_size, input_dim)
        if input_ids.dim() > 2:
            input_ids = input_ids.view(input_ids.size(0), -1)

        logits = self.main(input_ids)

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss(label_smoothing=0.1)(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits
        )

class ResidualBlock(nn.Module):
    """Blocco residuo semplificato con bottleneck"""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(in_dim, out_dim // 2),
            nn.BatchNorm1d(out_dim // 2),
            nn.GELU(),
            nn.Linear(out_dim // 2, out_dim),
            nn.BatchNorm1d(out_dim)
        )

        self.shortcut = nn.Linear(in_dim, out_dim) if in_dim != out_dim else nn.Identity()

    def forward(self, x):
        return F.gelu(self.block(x) + self.shortcut(x))

In [8]:
# def compute_class_weights_multiclass(labels: np.ndarray, num_classes: int) -> torch.Tensor:
#     """
#     Calcola i pesi delle classi per classificazione multi-class

#     Args:
#         labels: Array 1D con gli indici delle classi [0, 1, 2, ...]
#         num_classes: Numero totale di classi

#     Returns:
#         Tensor con i pesi per ogni classe [num_classes]
#     """
#     # Conta quanti esempi per ogni classe
#     unique, counts = np.unique(labels, return_counts=True)

#     # Crea tensor con tutti i conteggi (incluse classi con 0 esempi)
#     cls_counts = torch.zeros(num_classes, dtype=torch.float)
#     cls_counts[unique] = torch.tensor(counts, dtype=torch.float)

#     # Evita divisione per zero (classi senza esempi avranno peso minimo)
#     cls_counts = torch.clamp(cls_counts, min=1.0)

#     # Formula standard: n_samples / (n_classes * n_samples_per_class)
#     total_samples = torch.tensor(len(labels), dtype=torch.float)
#     weights = total_samples / (num_classes * cls_counts)

#     # Normalizza (opzionale - rende il peso minimo = 1.0)
#     weights = weights / weights.min()

#     return weights

# # Il tuo codice corretto:
# all_labels = np.array(dataset['train']['labels'])

# # Assicurati che sia 1D (per multi-class)
# if all_labels.ndim == 2 and all_labels.shape[1] == 1:
#     all_labels = all_labels.flatten()

# print(f"Shape labels: {all_labels.shape}")
# print(f"Classi uniche: {np.unique(all_labels)}")
# print(f"Range classi: {all_labels.min()} - {all_labels.max()}")

# # Determina il numero di classi
# num_classes = len(np.unique(all_labels))
# # O se conosci il numero totale: num_classes = dataset['train'].features['labels'].num_classes

# print(f"Numero classi: {num_classes}")

# # Calcola i pesi
# class_weights = compute_class_weights_multiclass(all_labels, num_classes)

# loss_fn = nn.CrossEntropyLoss(weight=class_weights)

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Union

class ResidualBlock(nn.Module):
    def __init__(self, dim, dropout_rate=0.2):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(dim, dim),
            nn.BatchNorm1d(dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )

    def forward(self, x):
        return x + self.block(x)

class ResidualMLPConv1D(nn.Module):
    def __init__(
        self, input_dim, hidden_dims, output_dim,
        dropout_rate=0.2, use_residual=True,
        use_conv=True, conv_channels=64, kernel_size=3
    ):
        super().__init__()

        self.use_conv = use_conv

        # Normalizzazione iniziale
        self.input_bn = nn.BatchNorm1d(input_dim)

        # Blocco convoluzionale opzionale
        if use_conv:
            self.conv_block = nn.Sequential(
                nn.Conv1d(1, conv_channels, kernel_size=kernel_size, padding=kernel_size // 2),
                nn.BatchNorm1d(conv_channels),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
                nn.Conv1d(conv_channels, 1, kernel_size=kernel_size, padding=kernel_size // 2),
                nn.BatchNorm1d(1),
                nn.ReLU()
            )

        # Primo layer MLP
        self.first_layer = nn.Sequential(
            nn.Linear(input_dim, hidden_dims[0]),
            nn.BatchNorm1d(hidden_dims[0]),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )

        # Blocchi nascosti
        hidden_layers = []
        for i in range(len(hidden_dims) - 1):
            if use_residual and hidden_dims[i] == hidden_dims[i + 1]:
                hidden_layers.append(ResidualBlock(hidden_dims[i], dropout_rate))
            else:
                hidden_layers.extend([
                    nn.Linear(hidden_dims[i], hidden_dims[i + 1]),
                    nn.BatchNorm1d(hidden_dims[i + 1]),
                    nn.ReLU(),
                    nn.Dropout(dropout_rate)
                ])
        self.hidden_layers = nn.Sequential(*hidden_layers)

        # Bottleneck + output
        self.bottleneck = nn.Sequential(
            nn.Linear(hidden_dims[-1], hidden_dims[-1] // 2),
            nn.BatchNorm1d(hidden_dims[-1] // 2),
            nn.ReLU(),
        )
        self.output_layer = nn.Linear(hidden_dims[-1] // 2, output_dim)

        self._initialize_weights()

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        **kwargs
    ) -> dict:
        x = self.input_bn(input_ids)

        if self.use_conv:
            x = x.unsqueeze(1)  # [B, 1, D]
            x = self.conv_block(x)
            x = x.squeeze(1)    # [B, D]

        x = self.first_layer(x)
        x = self.hidden_layers(x)
        x = self.bottleneck(x)
        logits = self.output_layer(x)

        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)

        return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear) or isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)


In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, List
from transformers.modeling_outputs import SequenceClassifierOutput


class MLPBlock(nn.Module):
    def __init__(self, input_dim: int, output_dim: int, dropout_rate: float = 0.2, use_residual: bool = False):
        super().__init__()
        self.use_residual = use_residual and (input_dim == output_dim)

        self.linear = nn.Linear(input_dim, output_dim)
        self.bn = nn.BatchNorm1d(output_dim)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x
        x = self.linear(x)
        x = self.bn(x)
        x = self.activation(x)
        x = self.dropout(x)
        if self.use_residual:
            x = x + identity
        return x


class AdvancedMLPClassifier(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        output_dim: int,
        dropout_rate: float = 0.2,
        use_residual_in_hidden: bool = True,
        loss_fn: Optional[nn.Module] = None
    ):
        super().__init__()



        self.initial_bn = nn.BatchNorm1d(input_dim)

        all_dims = [input_dim] + hidden_dims
        mlp_layers = []
        for i in range(len(all_dims) - 1):
            mlp_layers.append(
                MLPBlock(
                    input_dim=all_dims[i],
                    output_dim=all_dims[i + 1],
                    dropout_rate=dropout_rate,
                    use_residual=use_residual_in_hidden and (all_dims[i] == all_dims[i + 1])
                )
            )
        self.hidden_network = nn.Sequential(*mlp_layers)
        self.output_projection = nn.Linear(all_dims[-1], output_dim)
        self.loss_fn = loss_fn if loss_fn is not None else nn.CrossEntropyLoss()

        self._initialize_weights()

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = True,
        **kwargs
    ) -> SequenceClassifierOutput:

        if input_ids.ndim > 2:
            input_ids = input_ids.view(input_ids.size(0), -1)  # Flatten if necessary

        x = self.initial_bn(input_ids)
        x = self.hidden_network(x)
        logits = self.output_projection(x)

        loss = self.loss_fn(logits, labels) if labels is not None else None

        if not return_dict:
            return (logits, loss) if loss is not None else (logits,)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None
        )

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


In [11]:

input_dim = len(dataset["train"][0]["input_ids"])
labels = np.unique(dataset['test']['labels'])
output_dim = len(labels)
hidden_dims = [2048, 1024, 512]


hidden_str = "hdim_" + "x".join(map(str, hidden_dims))
model = AdvancedMLPClassifier(input_dim, hidden_dims, output_dim)


In [12]:
current_time = datetime.now()

run_name = f"AdvancedMLPClassifier_{hidden_str}_{current_time.strftime('%Y-%m-%d_%H-%M-%S')}"



In [13]:
from transformers import EarlyStoppingCallback
output_dir=f"/equilibrium/datasets/TCGA-histological-data/scImmunology/checkpoints/{run_name}"


wandb.init(
    project="scFoundationClassification",  # Nome del progetto W&B
    group=SUB_GROUP,                         # Gruppo: utile se hai più classificatori
    name=run_name,
)

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = logits.argmax(axis=-1)
    acc = accuracy_score(labels, predictions)
    f1 = f1_score(labels, predictions, average="weighted")
    return {"accuracy": acc, "f1": f1}

training_args = TrainingArguments(
    output_dir=output_dir,
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-4,
    per_device_train_batch_size=1024 ,
    per_device_eval_batch_size=1024 ,
    num_train_epochs=600,
    weight_decay=0.01,
    logging_dir="./logs",
    load_best_model_at_end=True,
    report_to="wandb",
    remove_unused_columns=True,
    optim="adamw_torch",
    lr_scheduler_type="cosine", 
    fp16=False,
    dataloader_num_workers=0,
    dataloader_pin_memory=False

)

# Early stopping callback per fermare l'addestramento se la valutazione non migliora
early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=10)


[34m[1mwandb[0m: Currently logged in as: [33mvincenzo-civale[0m ([33mvincenzo-civale-universi-degli-studi-di-firenze[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    compute_metrics=compute_metrics,
    callbacks=[early_stopping_callback],
)

trainer.train()
# trainer.train(resume_from_checkpoint=True)



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,0.2876,0.244491,0.906187,0.904369




In [None]:
# id2label = {v: k for k, v in label2id.items()}  # Inverti il dizionario

def log_confusion_matrix(y_true_ids, y_pred_ids, title):
    y_true = [id2label[i] for i in y_true_ids]
    y_pred = [id2label[i] for i in y_pred_ids]

    labels = list(id2label.values())
    cm = confusion_matrix(y_true, y_pred, labels=labels)

    # Calcola le percentuali riga per riga
    cm_percent = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis] * 100
    cm_annot = np.empty_like(cm, dtype=object)

    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            cm_annot[i, j] = f"{cm[i, j]}\n({cm_percent[i, j]:.1f}%)" if cm.sum(axis=1)[i] > 0 else f"{cm[i, j]}\n(0.0%)"

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=cm_annot, fmt='', cmap='Blues',
                xticklabels=labels,
                yticklabels=labels,
                cbar=False)
    plt.title(title)
    plt.xlabel("Predicted")
    plt.ylabel("True")

    wandb.log({title: wandb.Image(plt)})
    plt.close()

# === 3. Predizione su validation set ===
val_preds = trainer.predict(dataset["validation"])
val_y_true = val_preds.label_ids
val_y_pred = np.argmax(val_preds.predictions, axis=1)
log_confusion_matrix(val_y_true, val_y_pred, title=f"{SUB_GROUP} - Validation")

# === 4. Predizione su test set ===
test_preds = trainer.predict(dataset["test"])
test_y_true = test_preds.label_ids
test_y_pred = np.argmax(test_preds.predictions, axis=1)
log_confusion_matrix(test_y_true, test_y_pred, title=f"{SUB_GROUP} - Test")

In [None]:
output_dir = f"saved_models/{run_name}"
trainer.save_model(output_dir)

artifact = wandb.Artifact(name=run_name, type="model")
artifact.add_dir(output_dir)
wandb.log_artifact(artifact)