In [None]:
import os

import timm
import wandb
from datasets import DatasetDict, load_from_disk
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split
from torch import nn
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments

from src.data.preprocess import DatasetProcessor

In [None]:

# # --- Category 1: Fast and balanced (good baselines) ---
# modelli_veloci = [
#     "resnet18",
#     "resnet34",
#     "resnet50",
#     "efficientnet_b0",
#     "efficientnet_b1",
#     "mobilenetv3_large_100",
# ]

# # --- Category 2: High performance (for maximum accuracy) ---
# modelli_performanti = [
#     "efficientnetv2_s",
#     "efficientnetv2_m",
#     "convnext_tiny",
#     "convnext_small",
#     "maxvit_tiny_tf_224",
# ]

# # --- Category 3: Vision Transformers (attention-based architectures) ---
# modelli_transformer = [
#     "vit_small_patch16_224",
#     "vit_base_patch16_224",
#     "swin_tiny_patch4_window7_224",
#     "swin_small_patch4_window7_224",
# ]

In [None]:
# Backbone to build from the timm registry, and whether to start from
# timm's default pretrained weights for that architecture.
model_name: str = "resnet18"
pretrained: bool = True

# Directory holding the `datasets` object written with `save_to_disk`
# (read back with `load_from_disk` in the data cell).
DATASET_PATH: str = "/home/vcivale/GenomicVision/data/interim"

In [None]:
# Optimisation hyper-parameters.
BATCH_SIZE = 2
LEARNING_RATE = 1e-4
NUM_EPOCHS = 50
# NOTE(review): this directory is shared by every `model_name`, so runs of
# different backbones write checkpoints and the saved model into the same
# place. Consider appending WANDB_RUN_NAME to it.
OUTPUT_DIR = "/equilibrium/datasets/TCGA-histological-data/genomic_vision/results"

# W&B
WANDB_PROJECT = "genomic-vision"
WANDB_RUN_NAME = f"{model_name}_pretrained_{pretrained}"
# transformers' WandbCallback takes the project name from the WANDB_PROJECT
# environment variable (default "huggingface"), not from a Python variable.
# Export it so runs are actually logged under WANDB_PROJECT.
os.environ["WANDB_PROJECT"] = WANDB_PROJECT

In [None]:
class TimmForImageClassification(nn.Module):
    """Adapt a timm classifier to the keyword interface `transformers.Trainer` uses.

    `Trainer` calls `model(**batch)` with dataset columns as keyword arguments
    (the label column arrives as `labels`) and reads `outputs["loss"]`. A timm
    model instead takes one positional tensor and returns raw logits, so it
    cannot be trained by `Trainer` directly.

    The wrapped network is available as `.backbone`. Saved checkpoints
    therefore prefix its parameter names with "backbone.".
    """

    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, pixel_values, labels=None):
        """Return {"logits": ...} plus {"loss": ...} when `labels` are given.

        NOTE(review): assumes the image tensor column is named `pixel_values`.
        Trainer drops dataset columns whose names are not forward() parameters
        (remove_unused_columns=True by default), so rename this argument if the
        dataset uses a different key.
        """
        logits = self.backbone(pixel_values)
        outputs = {"logits": logits}
        if labels is not None:
            outputs["loss"] = self.loss_fn(logits, labels)
        return outputs


# 4 input channels and a 2-class head; timm adapts the pretrained stem to in_chans.
model = TimmForImageClassification(
    timm.create_model(
        model_name,
        pretrained=pretrained,
        in_chans=4,
        num_classes=2,
    )
)

In [None]:
# Project preprocessing helper configured for the chosen backbone and 4-channel input.
# NOTE(review): its only use in this notebook is a commented-out line in the data cell.
dataset_processor = DatasetProcessor(
    in_channels=4,
    model_name=model_name,
)

In [None]:
raw_dataset = load_from_disk(DATASET_PATH)

# load_from_disk returns a bare Dataset when a single split was saved; wrap it
# so the split logic below can always index by split name.
if not isinstance(raw_dataset, DatasetDict):
    raw_dataset = DatasetDict({'train': raw_dataset})

# TODO(review): preprocessing is disabled, so the data reaches the model exactly
# as stored on disk. Confirm DATASET_PATH already holds model-ready tensors.
# dataset = dataset_processor.process_dataset(raw_dataset)
dataset = raw_dataset

SPLIT_SEED = 42
HOLDOUT_FRACTION = 0.1

if 'test' in dataset:
    train_pool, test_split = dataset['train'], dataset['test']
else:
    # No held-out test split on disk (e.g. the bare-Dataset case above):
    # carve one out of the training data instead of failing with KeyError('test').
    holdout = dataset['train'].train_test_split(test_size=HOLDOUT_FRACTION, seed=SPLIT_SEED)
    train_pool, test_split = holdout['train'], holdout['test']

# Validation split for per-epoch evaluation / early stopping.
train_val = train_pool.train_test_split(test_size=HOLDOUT_FRACTION, seed=SPLIT_SEED)

dataset = DatasetDict({
    'train': train_val['train'],
    'validation': train_val['test'],
    'test': test_split,
})

In [None]:
args = TrainingArguments(
    # Output location and experiment tracking.
    output_dir=OUTPUT_DIR,
    run_name=WANDB_RUN_NAME,
    report_to="wandb",

    # Optimisation schedule.
    num_train_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    weight_decay=0.01,
    warmup_steps=500,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=2 * BATCH_SIZE,

    # Validate and checkpoint once per epoch; log every 50 steps.
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    logging_steps=50,

    # Reload the checkpoint with the best validation "f1" when training ends.
    # The evaluation must report an "f1" metric (via compute_metrics) for
    # this and for EarlyStoppingCallback to work.
    load_best_model_at_end=True,
    metric_for_best_model="f1",

    # Data loading in the main process. drop_last presumably avoids a trailing
    # batch of size 1 with BATCH_SIZE=2. NOTE(review): transformers applies it
    # to evaluation dataloaders too, so validation/test metrics may skip up to
    # per_device_eval_batch_size - 1 samples.
    dataloader_num_workers=0,
    dataloader_drop_last=True,
)

In [None]:
def compute_metrics(eval_pred):
    """Compute accuracy and F1 from a Trainer `EvalPrediction`.

    Args:
        eval_pred: `(predictions, label_ids)` pair. `predictions` are the model
            logits for every evaluated sample; presumably shape
            (n_samples, 2) given num_classes=2 — TODO confirm.

    Returns:
        dict with "accuracy" and "f1". Trainer prefixes the keys (e.g.
        "eval_f1", "test_f1"); "eval_f1" is what metric_for_best_model="f1"
        and EarlyStoppingCallback look up.
    """
    logits, labels = eval_pred
    predictions = logits.argmax(axis=-1)
    return {
        "accuracy": accuracy_score(labels, predictions),
        # Binary F1 for class 1 (sklearn default); assumes labels are {0, 1}.
        "f1": f1_score(labels, predictions),
    }


EARLY_STOPPING_PATIENCE = 5  # epochs without "f1" improvement before stopping

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['validation'],
    # Without this, evaluation reports only eval_loss and the "f1"-based
    # best-model selection / early stopping fail at the first epoch-end eval.
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=EARLY_STOPPING_PATIENCE)],
)

In [None]:
# Fine-tune with per-epoch validation; early stopping may end before NUM_EPOCHS.
train_output = trainer.train()
train_output

In [None]:
# Final evaluation on the held-out test split. With load_best_model_at_end=True
# the Trainer now holds the best validation checkpoint. Metric keys are
# prefixed with "test_".
# NOTE(review): dataloader_drop_last=True also applies here, so up to
# per_device_eval_batch_size - 1 test samples may be excluded from these numbers.
test_metrics = trainer.evaluate(dataset['test'], metric_key_prefix="test")
test_metrics

In [None]:
# Write the in-memory model (the best checkpoint, since load_best_model_at_end=True)
# to the Trainer's output directory, then close the W&B run the Trainer opened.
trainer.save_model(trainer.args.output_dir)
wandb.finish()