In [None]:
import timm 
from datasets import load_dataset, DatasetDict
from sklearn.model_selection import train_test_split
from transformers import  TrainingArguments
import wandb
import torch
from torch import nn

from src.seq2image import GenomicImageGenerator
from src.model import inizialize_model
from src.train import inizialize_training

In [None]:
# --- Experiment configuration -------------------------------------------
# Model identifier handed to `inizialize_model`, and whether to start from
# pretrained weights (both are also embedded in the W&B run name).
model_name = "timm/resnet18.a1_in1k"
pretrained = True

# Image representation: `input_channel` must be 1 or 3 (any other value
# raises a ValueError further down). `method` is only used in the
# single-channel case, where it is forwarded to
# `generate_single_channel_dataset`.
# NOTE(review): "cgr" presumably means Chaos Game Representation - confirm
# against GenomicImageGenerator.
input_channel = 1
method = "cgr" 

# Training hyper-parameters consumed by TrainingArguments: `batch_size` is
# used for both the per-device train and eval batch size.
batch_size = 128
epochs = 50
lr = 0.005

In [None]:
# Load the human non-TATA promoters benchmark by its dataset identifier.
# Later cells index the result with the 'train' and 'test' splits.
dataset = load_dataset("katarinagresova/Genomic_Benchmarks_human_nontata_promoters")

In [None]:
# Start the Weights & Biases run in the "GenomicVision" project.
# No run name is passed at this point.
wandb.init(project="GenomicVision")

In [None]:
# Size the classification head from the distinct labels found across *all*
# splits rather than the test split alone, so that a class which happens to
# be absent from one split cannot shrink the head.
num_labels = len({label for split in dataset.values() for label in split['label']}) # type: ignore

# The helper returns the model together with its image processor; the
# processor is reused later to size and preprocess the generated images.
model, processor = inizialize_model(model_name=model_name, num_labels=num_labels, pretrained=pretrained) # type: ignore

In [None]:
# Build the sequence-to-image generator. The image size is taken from the last
# entry of the processor's configured input size; sequences are read from the
# "seq" column and labels from the "label" column.
generator = GenomicImageGenerator(
    image_size=processor.data_config["input_size"][-1],
    sequence_col="seq",
    label_col="label",
)

In [None]:
# Turn both splits into image datasets using the representation chosen in the
# config cell, and remember a short tag describing that representation.
if input_channel == 1:
    # Single-channel images, encoded with the configured `method`.
    processed_datasets = DatasetDict({
        'train': generator.generate_single_channel_dataset(dataset['train'], method), # type: ignore
        'test': generator.generate_single_channel_dataset(dataset['test'], method) # type: ignore
    })
    encoding_tag = method

elif input_channel == 3:
    # Three-channel images; no encoding method is passed in this case.
    processed_datasets = DatasetDict({
        'train': generator.generate_3_channel_dataset(dataset['train']), # type: ignore
        'test': generator.generate_3_channel_dataset(dataset['test']) # type: ignore
    })
    encoding_tag = "3channel"

else:
    raise ValueError("Input channel must be 1 or 3")

# Run name: <model>_<representation>_<pretrained|not_pretrained>.
run_name = f"{model_name}_{encoding_tag}_{'pretrained' if pretrained else 'not_pretrained'}"

# `wandb.init()` was already called earlier without a name. The Transformers
# W&B callback only forwards `TrainingArguments.run_name` to `wandb.init` when
# no run is active, so the name would otherwise never reach W&B. Rename the
# active run explicitly.
if wandb.run is not None:
    wandb.run.name = run_name

In [None]:
def transform(examples):
    """Add model-ready ``pixel_values`` to a batch of examples.

    Designed to work on a whole batch at once (in this notebook it is applied
    with ``map(..., batched=True)``).

    Args:
        examples: Batch mapping whose ``"image"`` entry is a list of
            ``PIL.Image`` objects.

    Returns:
        The same ``examples`` mapping, with a ``"pixel_values"`` entry holding
        the processor output requested as PyTorch tensors.
    """
    # Force three channels so that single-channel images are also accepted
    # by the processor.
    rgb_images = [img.convert("RGB") for img in examples["image"]]

    # The processor is expected to take care of resizing and normalisation;
    # `return_tensors="pt"` asks for PyTorch tensors.
    encoded = processor(rgb_images, return_tensors="pt")
    examples["pixel_values"] = encoded["pixel_values"]

    return examples

In [None]:
# Run `transform` over every split in batches and drop the raw 'image' column,
# so that each split carries `pixel_values` instead.
# NOTE(review): `map` performs this preprocessing eagerly for the whole dataset
# and the recorded progress bar shows it is slow; an on-the-fly transform may be
# preferable - confirm it is compatible with how `inizialize_training` builds
# the trainer before switching.
processed_datasets = processed_datasets.map(transform, batched=True, remove_columns=['image'])

Map:  63%|██████▎   | 17000/27097 [07:01<04:14, 39.60 examples/s]

In [None]:
training_args = TrainingArguments(
    # Output location and experiment tracking.
    output_dir="/equilibrium/datasets/TCGA-histological-data/genomic_vision/results",
    run_name=run_name,
    report_to="wandb",
    logging_steps=10,

    # Optimisation hyper-parameters (values come from the config cell).
    num_train_epochs=epochs,
    learning_rate=lr,
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,

    # Evaluate and checkpoint once per epoch, keep at most two checkpoints,
    # and reload the best one (highest "f1") when training ends.
    # NOTE(review): "f1" presumably has to be reported by the metrics function
    # wired up inside `inizialize_training` - confirm.
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
)

In [None]:
# Build the trainer through the project helper from the model, the training
# arguments and the processed train/test datasets.
trainer = inizialize_training(model, training_args, processed_datasets)

In [None]:
# Launch training with the trainer configured in the previous cell.
trainer.train()

In [None]:
# Close the Weights & Biases run that was opened earlier in the notebook.
wandb.finish()