# NOTEBOOK PER L'ADDESTRAMENTO DEL MODELLO DI SEGMENTAZIONE

In [None]:
from pathlib import Path
from torch.utils.data import DataLoader , ConcatDataset , random_split
import os
import json
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision.io import read_image
import torchvision.transforms.v2 as transforms
import torch.nn as nn
import cv2
import segmentation_models_pytorch as smp
from tqdm import tqdm
from torchvision import tv_tensors
import timm
from CableDetection import CableTrainDataset

In [None]:
#from google.colab import drive
#drive.mount('/content/drive')

## IMPOSTAZIONE DEI PERCORSI DEL PROGETTO

In questa sezione vengono definiti i percorsi principali del progetto, i file di dataset e i parametri utilizzati per l’addestramento e il fine-tuning del modello.

In [None]:
# Project layout, resolved relative to the directory the notebook runs from.
current_dir = Path.cwd()

# Training images and the JSON annotation file stored in the same folder.
image_train_path = current_dir / "dataset" / "train"
image_train_json_path = image_train_path / "train.json"

# Value handed to the dataset through its `size=` argument.
TARGET_SIZE = (704, 704)

# Checkpoint folder: the weights loaded before training and the file that
# receives the fine-tuned weights afterwards.
dir_model_path = current_dir / "models"
model_path = dir_model_path / "UNET++.pth"
fine_tuned_model_path = dir_model_path / "UNET++_ft.pth"

## DEFINIZIONE DEL DEVICE 

In [None]:
# Pick the compute backend in order of preference: CUDA, then Apple MPS,
# then CPU. MPS availability is only queried when CUDA is not available.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

## DEFINIZIONE DATASET DI ADDESTRAMENTO

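## Data Augmentation per il training del modello

Le tre celle seguenti definiscono pipeline di augmentation alternative: ognuna riassegna `transforms_cavi`, quindi viene utilizzata solo l'ultima cella eseguita prima della creazione del dataset.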

In [None]:
# Flip augmentation: RandomChoice selects one of the listed transforms per
# call, and p=1.0 makes the selected flip always apply.
# NOTE(review): later cells rebind `transforms_cavi`; if every cell is run in
# order, this pipeline is replaced before the dataset is built.
transforms_cavi = transforms.Compose([
    transforms.RandomChoice([
        transforms.RandomHorizontalFlip(p=1.0),
        transforms.RandomVerticalFlip(p=1.0),
    ])
])

In [None]:
# Elastic-deformation augmentation (alpha=50.0, sigma=5.0).
# NOTE(review): this rebinds `transforms_cavi`, discarding any pipeline
# assigned to that name by a previously executed cell.
transforms_cavi= transforms.Compose([transforms.ElasticTransform(alpha=50.0, sigma=5.0)])

In [None]:
# Right-angle rotation augmentation: each call applies exactly one rotation,
# chosen among 90, 180 and 270 degrees. A degenerate range (a, a) pins
# RandomRotation to that single angle.
transforms_cavi = transforms.Compose([
    transforms.RandomChoice([
        transforms.RandomRotation(degrees=(angle, angle))
        for angle in (90, 180, 270)
    ])
])

## Suddivisione del Dataset

- Imposta la dimensione del batch a **8**
- Calcola automaticamente le dimensioni di training e validation
- Suddivide il dataset in:
  - **80% Training**
  - **20% Validation**

La suddivisione è effettuata in modo casuale tramite `random_split`.


In [None]:
BATCH_SIZE = 8
# Seed for the train/validation split. Without it, every re-run of this cell
# draws a different partition, so validation losses are not comparable across
# runs and samples can move between the two subsets.
SPLIT_SEED = 42

# NOTE(review): `transforms_cavi` is attached to the whole dataset, so the
# validation subset presumably receives the same augmentations as the
# training subset -- confirm against CableTrainDataset.
dataset = CableTrainDataset(
    image_train_path,
    image_train_json_path,
    size=TARGET_SIZE,
    transform=transforms_cavi,
)

# 80% training / 20% validation; the training subset absorbs the rounding
# remainder left by int().
dataset_size = len(dataset)
val_size = int(0.2 * dataset_size)
train_size = dataset_size - val_size

train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

## CREAZIONE DEL DATALOADER

In [None]:
# Training batches are reshuffled at every epoch.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Validation keeps a fixed order: shuffling brings no benefit for evaluation
# and regroups the samples into different batches each epoch, which can make
# the batch-averaged validation loss fluctuate for reasons unrelated to the
# model.
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

## DEFINIZIONE MODELLO

In [None]:
# Number of output channels requested from the model (passed as `classes=`).
NUM_CLASSES = 1

# U-Net++ from segmentation_models_pytorch with the "timm-resnest50d" encoder
# initialised from ImageNet weights, 3 input channels and "scse" attention in
# the decoder.
# NOTE(review): if the checkpoint-loading cell is run afterwards,
# load_state_dict replaces these ImageNet encoder weights.
model = smp.UnetPlusPlus(
    encoder_name="timm-resnest50d",
    encoder_weights="imagenet",
    in_channels=3,
    classes=NUM_CLASSES,
    decoder_attention_type="scse",
)

## CARICAMENTO DEI PESI DEL MODELLO PRE-ADDESTRATO

In [None]:
# Restore the previously trained weights as the starting point for fine-tuning.
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while being loaded;
# a state_dict needs nothing beyond that.
model.load_state_dict(
    torch.load(model_path, map_location=device, weights_only=True)
)
model.to(device)

## DEFINIZIONE OTTIMIZZATORE E FUNZIONE DI LOSS

### Ottimizzazione e Funzione di Loss

- **Ottimizzatore**: `AdamW`  
  - Learning rate: `1e-4`  
  - Weight decay: `1e-4`

- **Scheduler**: `ReduceLROnPlateau`  
  - Riduce il learning rate quando la loss di validazione smette di migliorare  
  - Fattore di riduzione: `0.5`  
  - Pazienza: `3` epoche

- **Loss combinata**:
  - **Dice Loss** (40%) → migliora la segmentazione delle aree
  - **Focal Loss** (60%) → gestisce lo sbilanciamento tra classi

In [None]:
# AdamW over all model parameters; learning rate and weight decay are both 1e-4.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-4
)

# Multiply the learning rate by 0.5 when the monitored value stops decreasing
# (mode="min") for longer than the configured patience of 3 epochs. The
# monitored value is whatever is passed to scheduler.step(); the training loop
# in this notebook passes the validation loss.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=3
)

def combined_loss(preds, targets):
    """Return the weighted Dice + Focal loss for one batch.

    Both terms are built in binary mode and evaluated on the same
    ``preds``/``targets`` pair; the result is ``0.4 * dice + 0.6 * focal``.

    Args:
        preds: Model output for the batch (presumably raw logits -- confirm
            against the defaults of the smp loss classes).
        targets: Ground-truth masks for the same batch.

    Returns:
        The combined loss value.
    """
    dice_term = smp.losses.DiceLoss(mode="binary")(preds, targets)
    focal_term = smp.losses.FocalLoss(mode="binary")(preds, targets)
    return 0.4 * dice_term + 0.6 * focal_term

## ADDESTRAMENTO DEL MODELLO

In [None]:
def train_one_epoch(model, loader):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Relies on the notebook-level ``device``, ``optimizer`` and
    ``combined_loss``.

    Args:
        model: Network to train; it is switched to training mode here.
        loader: Iterable yielding ``(images, masks, extra)`` triples; the
            third element is ignored.

    Returns:
        The sum of the per-batch loss values divided by ``len(loader)``.
    """
    model.train()
    running_loss = 0

    for batch_images, batch_masks, _ in tqdm(loader):
        batch_images = batch_images.to(device).float()
        batch_masks = batch_masks.to(device).long()

        # Usual step order: clear stale gradients, forward, backward, update.
        optimizer.zero_grad()
        # Shape presumably (B, NUM_CLASSES, H, W); the model in this notebook
        # is built with a single class, i.e. one output channel.
        predictions = model(batch_images)
        batch_loss = combined_loss(predictions, batch_masks)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    return running_loss / len(loader)

In [None]:
@torch.no_grad()
def validate(model, loader):
    """Evaluate ``model`` on ``loader`` and return the mean batch loss.

    Gradient tracking is disabled for the whole call by the decorator.
    Relies on the notebook-level ``device`` and ``combined_loss``.

    Args:
        model: Network to evaluate; it is switched to evaluation mode here.
        loader: Iterable yielding ``(images, masks, extra)`` triples; the
            third element is ignored.

    Returns:
        The sum of the per-batch loss values divided by ``len(loader)``.
    """
    model.eval()
    running_loss = 0

    for batch_images, batch_masks, _ in tqdm(loader):
        batch_images = batch_images.to(device).float()
        batch_masks = batch_masks.to(device).long()

        predictions = model(batch_images)
        running_loss += combined_loss(predictions, batch_masks).item()

    return running_loss / len(loader)

In [None]:
num_epochs = 10
best_val = float("inf")
# Copy of the weights from the epoch with the lowest validation loss so far.
best_state = None

for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_loader)
    val_loss = validate(model, val_loader)

    # The plateau scheduler is driven by the validation loss.
    scheduler.step(val_loss)

    print(f"Epoch {epoch:02d} | train {train_loss:.4f} | val {val_loss:.4f}")

    # Keep an in-memory CPU copy of the best weights instead of writing one
    # file per improving epoch; clone() detaches the copy from the live
    # parameters, which keep changing in later epochs.
    if val_loss < best_val:
        best_val = val_loss
        best_state = {
            name: tensor.detach().cpu().clone()
            for name, tensor in model.state_dict().items()
        }

# Leave `model` holding the best epoch's weights rather than the last epoch's,
# so that whatever is saved afterwards is the best model seen. best_state stays
# None when no epoch ran or no finite validation loss was produced.
if best_state is not None:
    model.load_state_dict(best_state)

## SALVATAGGIO DEL MODELLO

In [None]:
# Persist only the parameters (state_dict) of `model` to the fine-tuned
# checkpoint path, then report where the file was written.
torch.save(model.state_dict(), fine_tuned_model_path)
print(f"Modello salvato in {fine_tuned_model_path}")