In [1]:

import torch
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset, DataLoader
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

from src.data.dataset import PatchFromH5Dataset, stratified_split, plot_class_distributions
from src.rl.train import ViTUCBTrainer
from transformers import (
    TrainingArguments,
    EarlyStoppingCallback,
)
from src.rl.modelling import ViT_UCB_Pruning
import evaluate
import os


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
os.environ["WANDB_PROJECT"] = "ViT-Pruning-Project"

In [3]:
IMG_SIZE = 224
TRAIN_BATCH_SIZE = 8
NUM_EPOCHS = 30

PRUNING_RATIO = 0.35

DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
dataset = PatchFromH5Dataset(
    h5_dir='/equilibrium/datasets/TCGA-histological-data/hest/patches/patches/',
    transform=transforms.Compose([
        transforms.Resize(IMG_SIZE),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),])
)

In [5]:
labels = dataset.labels

In [6]:
# Crea un DataFrame con indici e label
df = pd.DataFrame({
    "index": np.arange(len(labels)),
    "label": labels
})

# Trova il numero di elementi della classe minoritaria
min_count = df["label"].value_counts().min()

# Per ogni classe, seleziona min_count elementi a caso
undersampled_df = (
    df.groupby("label", group_keys=False)
      .apply(lambda x: x.sample(n=min_count, random_state=42)).reset_index(drop=True)
)

# Mischia gli indici
undersampled_indices = undersampled_df["index"].sample(frac=1, random_state=42).tolist()


  .apply(lambda x: x.sample(n=min_count, random_state=42)).reset_index(drop=True)


In [7]:
undersampled_labels = [labels[i] for i in undersampled_indices]

trainval_idx, test_idx = train_test_split(
    undersampled_indices,
    test_size=0.3,
    stratify=undersampled_labels,
    random_state=42
)

# Ottieni i label corrispondenti per il secondo split
trainval_labels = [labels[i] for i in trainval_idx]

# Split: train vs val
train_idx, val_idx = train_test_split(
    trainval_idx,
    test_size=0.3,
    stratify=trainval_labels,
    random_state=42
)

# Crea i subset
train_dataset = Subset(dataset, train_idx)
val_dataset   = Subset(dataset, val_idx)
test_dataset  = Subset(dataset, test_idx)

In [8]:
accuracy_metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return accuracy_metric.compute(predictions=predictions, references=labels)

In [9]:
run_name = f"vit-ucb-pruning-{PRUNING_RATIO}"  

In [10]:
labels_num = len(np.unique(dataset.labels))

print(f"Number of classes: {labels_num}")
model = ViT_UCB_Pruning(model_name="hf-hub:MahmoodLab/uni", 
    pretrained=True, 
    n_classes=labels_num, 
    keep_ratio=PRUNING_RATIO,        
    exclude_cls=False
)

Number of classes: 27
Loading source model 'hf-hub:MahmoodLab/uni'...


09/03/2025 12:28:23 - INFO - timm.models._builder - Loading pretrained weights from Hugging Face hub (MahmoodLab/uni)


In [None]:
training_args = TrainingArguments(
    # --- Gestione dei salvataggi e output ---
    output_dir=f"/equilibrium/datasets/TCGA-histological-data/ViT-UCB-Pruning-checkpoints/Ratio-{PRUNING_RATIO}",  # Directory temporanea, non userÃ  molto spazio
    report_to="wandb",                   
    push_to_hub=True,                    
    hub_model_id=f"Yuto2007/ViT-UCB-Pruning",
    run_name=run_name, 

    # --- Pianificazione del training ---
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    dataloader_num_workers=16,
    
    # --- Learning Rate Scheduler ---
    learning_rate=1e-1,
    warmup_steps=500,                     
    weight_decay=0.01,

    # --- Valutazione e Best Model ---
    eval_strategy="epoch",          
    save_strategy="epoch",               
    load_best_model_at_end=True,         
    metric_for_best_model="eval_loss",    
    greater_is_better=False,             

    logging_strategy="steps",
    logging_steps=100,
    save_total_limit=1,    

    fp16=False             
)

In [13]:
early_stopping_callback = EarlyStoppingCallback(
    early_stopping_patience=7,
    early_stopping_threshold=0.0,
)

In [14]:
# def vit_ucb_collate_fn_fast(batch):
#     pixel_values = torch.stack([item['pixel_values'].permute(2,0,1) if item['pixel_values'].ndim==3 else item['pixel_values'].unsqueeze(0) for item in batch])
#     labels = torch.tensor([item['labels'] for item in batch], dtype=torch.long)
#     return {"pixel_values": pixel_values, "labels": labels}


In [15]:
trainer = ViTUCBTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,           
    compute_metrics=compute_metrics,      
    callbacks=[early_stopping_callback],
    # data_collator=vit_ucb_collate_fn  
)


In [None]:
print("ðŸš€ Inizio dell'addestramento...")
trainer.train()

ðŸš€ Inizio dell'addestramento...


[34m[1mwandb[0m: Currently logged in as: [33mvincenzo-civale[0m ([33mvincenzo-civale-universi-degli-studi-di-firenze[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




Epoch,Training Loss,Validation Loss,Accuracy
1,4.2294,4.09078,0.037037




In [None]:
print("\nðŸš€ Valutazione finale sul test set con il modello migliore...")
test_results = trainer.predict(test_dataset)

In [None]:
print("Risultati del Test Set:")
print(test_results.metrics)

In [None]:
import wandb
wandb.log({"final_test_metrics": test_results.metrics})

In [None]:
wandb.finish()
print("\nâœ… Addestramento e valutazione completati.")