In [None]:
import sys
import os

# Make the project root (the parent of the current working directory)
# importable, presumably so the `from src... import ...` lines below resolve.
current_dir = os.path.abspath(os.getcwd())
project_home_dir = os.path.abspath(os.path.join(current_dir, os.pardir))

# Guard the append so re-running this cell in the same kernel does not keep
# adding duplicate entries to sys.path.
if project_home_dir not in sys.path:
    sys.path.append(project_home_dir)

In [None]:
import random

import torch
from transformers import SegformerForSemanticSegmentation, SegformerFeatureExtractor
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import wandb

from src.data_preprocessing import prepare_dataset, visualize_sample
from src.model import train_segformer

In [None]:
# Colour of each class in the RGB label images: class index -> [R, G, B].
# Indices come from list position, so they stay contiguous (0..N-1); this
# matters because the training cell uses len(class_rgb_values) as num_classes.
class_rgb_values = dict(enumerate([
    [0, 0, 255],      # class 0 - blue
    [0, 255, 0],      # class 1 - green
    [255, 0, 0],      # class 2 - red
    [255, 85, 255],   # class 3 - pink
    # [0, 170, 255],  # class 4 - light blue (currently excluded)
]))

In [None]:
# Dataset locations. Defaults reproduce the original workstation paths; set the
# environment variables to run the notebook on another machine.
DATA_ROOT = os.environ.get(
    "SEGMENTATION_DATA_ROOT", "/home/inside-tech/Desktop/image_segmentation"
)
img_dir = os.environ.get(
    "SEGMENTATION_IMG_DIR",
    os.path.join(DATA_ROOT, "data", "raw", "_4_classi", "images"),
)
mask_dir = os.environ.get("SEGMENTATION_MASK_DIR", os.path.join(DATA_ROOT, "masks"))
BATCH_SIZE = 8

wandb.init(project="CowSegmentation", name="SegFormer")

# Build the dataloaders (split logic lives in src.data_preprocessing.prepare_dataset).
train_loader, val_loader, test_loader, full_dataset = prepare_dataset(
    img_dir=img_dir,
    mask_dir=mask_dir,
    class_rgb_values=class_rgb_values,
    batch_size=BATCH_SIZE,
)

# Per-class weights used to balance the classes during training
# (computed by the dataset's get_class_weight()).
class_weights = full_dataset.get_class_weight()

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Reproducibility. The cuDNN flags only restrict algorithm selection: they do
# not make a run repeatable on their own. The RNGs that drive stochastic steps
# must also be seeded, e.g. random head initialisation, dropout and
# default-sampler shuffling if they draw from the global generators.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # no-op when CUDA is unavailable
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False  # autotuner may pick non-deterministic kernels

EPOCHS = 30
LEARNING_RATE = 1e-4

model, trainer = train_segformer(
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    num_classes=len(class_rgb_values),
    epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    device=device,
    class_weights=class_weights,
)

In [None]:
# NOTE(review): this cell contains only commented-out code. The helper writes a
# 0/255 binary mask per class to `<output_path>/class_<idx>/`. Its output_path
# below is the same directory used as `mask_dir` in the data-preparation cell,
# and its class map includes class 4 ([0, 170, 255]), which the active
# `class_rgb_values` omits. Consider moving it to a script under src/ rather
# than keeping dead code in the notebook.
# import os
# import numpy as np
# from PIL import Image
# import cv2
# from tqdm import tqdm

# def extract_binary_masks(input_path, output_path, class_rgb_values):
#     """
#     Extract a binary mask for each RGB class present in the segmentation images.

#     Args:
#         input_path (str): Directory containing the segmentation (label) images.
#         output_path (str): Directory where the binary masks are saved.
#         class_rgb_values (dict): Maps each class index to its [R, G, B] value.
#     """
#     # Create the output directory if it does not exist
#     os.makedirs(output_path, exist_ok=True)

#     # Create one sub-directory per class
#     for class_idx in class_rgb_values:
#         class_dir = os.path.join(output_path, f"class_{class_idx}")
#         os.makedirs(class_dir, exist_ok=True)

#     # List every image file in the input directory
#     image_files = [f for f in os.listdir(input_path) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tif'))]

#     for image_file in tqdm(image_files, desc="Elaborazione immagini"):
#         # Load the image
#         img_path = os.path.join(input_path, image_file)
#         img = cv2.imread(img_path)

#         if img is None:
#             print(f"Errore nel caricamento dell'immagine: {img_path}")
#             continue

#         # Convert BGR to RGB (OpenCV loads images as BGR)
#         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

#         # File name without extension
#         file_name = os.path.splitext(image_file)[0]

#         # Build one binary mask per class
#         for class_idx, rgb_value in class_rgb_values.items():
#             # Pixels whose RGB value matches exactly become 255, all others 0.
#             # NOTE(review): exact matching is likely unreliable on lossy
#             # formats such as .jpg (compression alters colours) - confirm.
#             mask = np.all(img == rgb_value, axis=2).astype(np.uint8) * 255

#             # Only save a mask if the class is present (at least one pixel)
#             if np.any(mask):
#                 # Save the binary mask
#                 mask_path = os.path.join(output_path, f"class_{class_idx}", f"{file_name}_class{class_idx}.png")
#                 cv2.imwrite(mask_path, mask)


# class_rgb_values = {
#         0: [0, 0, 255],        # Class 0 - blue
#         1: [0, 255, 0],        # Class 1 - green
#         2: [255, 0, 0],        # Class 2 - red
#         3: [255, 85, 255],     # Class 3 - pink
#         4: [0, 170, 255]       # Class 4 - light blue
#   }

# # Input and output paths
# input_path = "/home/inside-tech/Desktop/image_segmentation/data/raw/_4_classi/labels"
# output_path = "/home/inside-tech/Desktop/image_segmentation/masks"

# # Run the mask extraction
# extract_binary_masks(input_path, output_path, class_rgb_values)