In [1]:
import evaluate

import torch
import torch.nn as nn
import torch.nn.functional as F

import albumentations as A
import numpy as np


from albumentations.pytorch import ToTensorV2
from argparse import Namespace

from datasets import load_from_disk
from transformers import (
    PreTrainedModel, 
    PretrainedConfig,
    Trainer,
    TrainingArguments,
)
import segmentation_models_pytorch as smp

from transformers import (
    Trainer,
    TrainingArguments,
)

class DiceLoss(nn.Module):
    """Soft Dice loss for binary segmentation.

    ``preds`` are expected to be probabilities (e.g. sigmoid outputs in [0, 1])
    and ``targets`` a binary mask with the same number of elements per sample.
    The Dice coefficient is computed per sample and averaged over the batch;
    the returned loss is ``1 - mean_dice``.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        # Added to numerator and denominator so all-empty masks do not divide by zero.
        self.smooth = smooth

    def forward(self, preds, targets):
        """Return ``1 - mean(dice)`` over the batch.

        Args:
            preds: probability tensor whose first dim is the batch, e.g. (B, 1, H, W).
            targets: binary mask tensor with the same per-sample element count.
        """
        # reshape (not view) so non-contiguous inputs, e.g. produced by
        # transpose/permute or slicing, are accepted instead of raising.
        preds_flat = preds.reshape(preds.size(0), -1)
        targets_flat = targets.reshape(targets.size(0), -1)

        intersection = (preds_flat * targets_flat).sum(dim=1)
        union = preds_flat.sum(dim=1) + targets_flat.sum(dim=1)

        dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice.mean()  # average over the batch


class CombinedLoss(nn.Module):
    """Weighted blend of soft Dice loss and binary cross-entropy.

    loss = alpha * Dice(preds, targets) + (1 - alpha) * BCE(preds, targets)

    Because ``nn.BCELoss`` is used, ``preds`` must already be probabilities
    in [0, 1] (apply sigmoid/softmax before calling), not raw logits.
    """

    def __init__(self, alpha=0.5, smooth=1e-6):
        super().__init__()
        # Weight given to the Dice term; BCE receives the complement (1 - alpha).
        self.alpha = alpha
        self.dice = DiceLoss(smooth=smooth)
        self.bce = nn.BCELoss()

    def forward(self, preds, targets):
        """Return the alpha-weighted sum of the Dice and BCE losses."""
        dice_weight = self.alpha
        bce_weight = 1 - dice_weight
        return dice_weight * self.dice(preds, targets) + bce_weight * self.bce(preds, targets)

def seg_data_collator(features):
    """Stack per-example tensors into a batch dict for ``Trainer``.

    Each feature must hold a ``"pixel_values"`` tensor and its mask under
    ``"labels"`` (the key ``preprocess_fn`` returns). ``"label"`` is accepted
    as a fallback so rows that still use the raw column name also work.

    Returns:
        ``{"pixel_values": stacked_images, "labels": stacked_masks}``.
    """
    pixel_values = torch.stack([f["pixel_values"] for f in features])
    # preprocess_fn emits "labels"; reading only "label" raised KeyError.
    labels = torch.stack([f["labels"] if "labels" in f else f["label"] for f in features])
    return {"pixel_values": pixel_values, "labels": labels}






# Albumentations pipelines.
# ImageNet statistics are used for normalisation (the encoder uses encoder_weights='imagenet').
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Training pipeline: photometric jitter + random occlusion, then normalise and convert to CHW tensor.
# Geometric augmentations (e.g. A.HorizontalFlip, A.ShiftScaleRotate) are currently not enabled.
train_transform = A.Compose(
    [
        # Photometric changes (image only).
        A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.4),
        A.RandomBrightnessContrast(p=0.3),
        # Cut 2-4 rectangles of 10-20 px so the model cannot rely on a single local cue.
        A.CoarseDropout(
            num_holes_range=(2, 4),
            hole_height_range=(10, 20),
            hole_width_range=(10, 20),
            p=0.3,
        ),
        A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ToTensorV2(),
    ]
)

# Evaluation pipeline: deterministic, normalisation + tensor conversion only.
val_transform = A.Compose(
    [
        A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ToTensorV2(),
    ]
)


def preprocess_fn(examples, transform=None):
    """Turn a batch of dataset rows into model-ready tensors.

    Meant for ``Dataset.set_transform``: ``examples`` is a batch dict with
    PIL-like images under ``"image"`` and masks under ``"label"``.

    Args:
        examples: batch dict with ``"image"`` and ``"label"`` lists.
        transform: Albumentations-style callable invoked as
            ``transform(image=..., mask=...)`` and returning a dict with
            ``"image"`` and ``"mask"``. Defaults to ``train_transform`` (the
            original behaviour); pass ``val_transform`` (e.g. through
            ``functools.partial``) for evaluation splits so they are not
            randomly augmented.

    Returns:
        ``{"pixel_values": [...], "labels": [...]}`` where each label is a
        ``torch.long`` tensor.
    """
    if transform is None:
        transform = train_transform

    pixel_values, labels = [], []
    for image, mask in zip(examples["image"], examples["label"]):
        augmented = transform(
            image=np.array(image.convert("RGB")),
            # Single-channel ("L") mask so each pixel holds one class id.
            mask=np.array(mask.convert("L")),
        )
        pixel_values.append(augmented["image"])
        # Albumentations returns the mask under "mask" (there is no "label" key).
        # as_tensor also copes with transforms that return a numpy mask.
        # NOTE(review): mask values are not rescaled; the loss treats only
        # value 1 as foreground, so masks stored as 0/255 would need mapping.
        labels.append(torch.as_tensor(augmented["mask"]).long())

    return {"pixel_values": pixel_values, "labels": labels}



In [3]:

def compute_metrics(eval_pred):
    """Compute mean IoU (``evaluate``'s ``mean_iou``) for binary logo segmentation.

    Args:
        eval_pred: ``(logits, labels)`` from ``Trainer``. ``logits`` is
            ``(N, C, h, w)``; ``labels`` presumably ``(N, H, W)`` with class ids
            0/1 and 255 as the ignore value — confirm against the dataset.

    Returns:
        The ``mean_iou`` result dict, with numpy arrays converted to lists so
        ``Trainer`` can serialise it into its JSON state.
    """
    logits, labels = eval_pred
    if isinstance(logits, (tuple, list)):
        # Multiple model outputs: the first one holds the segmentation logits.
        logits = logits[0]
    logits = np.asarray(logits, dtype=np.float32)
    labels = np.asarray(labels)

    # Bring logits to label resolution (some models, e.g. Segformer, predict at a
    # reduced size); no-op when the spatial shapes already agree.
    if logits.shape[-2:] != labels.shape[-2:]:
        logits = F.interpolate(
            torch.from_numpy(logits),
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        ).numpy()

    if logits.shape[1] == 1:
        # Single channel (smp model with classes=1, activation=None): a raw
        # foreground logit, so logit > 0 is equivalent to sigmoid > 0.5.
        # argmax over one channel would always predict class 0.
        predictions = (logits[:, 0] > 0).astype(np.int64)
    else:
        predictions = logits.argmax(axis=1)

    metric = evaluate.load("mean_iou")
    results = metric.compute(
        predictions=predictions,
        references=labels,
        num_labels=2,
        ignore_index=255,
    )
    # Per-category arrays are not JSON serialisable; Trainer stores metrics in trainer_state.json.
    return {
        key: value.tolist() if isinstance(value, np.ndarray) else value
        for key, value in results.items()
    }

class LogoSegmentationTrainer(Trainer):
    """``Trainer`` that optimises ``CombinedLoss`` (Dice + BCE) on the logo class.

    Works with models returning a raw logits tensor (e.g. ``smp.UnetPlusPlus``
    with ``classes=1``) as well as a mapping / ``ModelOutput`` exposing
    ``"logits"``. Single-channel logits are treated as a foreground logit
    (sigmoid); multi-channel logits use softmax with channel 1 as "logo".
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        """Return the Dice + BCE loss, plus ``{"logits": ...}`` when requested.

        Args:
            model: the (possibly wrapped) model, called positionally with the images.
            inputs: batch dict with ``"pixel_values"`` and ``"labels"``; left unmodified.
            return_outputs: also return the outputs dict (used by ``prediction_step``).
            num_items_in_batch: passed by recent ``transformers`` releases; unused
                because ``CombinedLoss`` already averages over the batch.
            **kwargs: accepted for compatibility with future ``Trainer`` versions.
        """
        labels = inputs["labels"]
        # Positional call: HF models take pixel_values first, smp models take a single x.
        outputs = model(inputs["pixel_values"])
        if isinstance(outputs, dict):  # also covers HF ModelOutput
            logits = outputs["logits"]
        elif isinstance(outputs, (tuple, list)):
            logits = outputs[0]
        else:
            logits = outputs

        # Loss in float32 so BCE stays numerically stable under bf16 training.
        logits = logits.float()
        # Match label resolution for models that predict at a reduced size; no-op otherwise.
        if logits.shape[-2:] != labels.shape[-2:]:
            logits = F.interpolate(logits, size=labels.shape[-2:], mode="bilinear", align_corners=False)

        if logits.shape[1] == 1:
            preds = torch.sigmoid(logits)  # (B, 1, H, W)
        else:
            preds = torch.softmax(logits, dim=1)[:, 1:2]  # prob of logo class, (B, 1, H, W)

        labels_binary = (labels == 1).float()
        if labels_binary.dim() == 3:
            labels_binary = labels_binary.unsqueeze(1)  # (B, H, W) -> (B, 1, H, W)

        loss = CombinedLoss()(preds, labels_binary)

        # Return a dict: for non-dict outputs Trainer.prediction_step would take
        # outputs[1:], which on a bare tensor silently drops the first sample.
        return (loss, {"logits": logits}) if return_outputs else loss



In [None]:
"""
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/segformer-b0-finetuned-ade-512-512",
    num_labels=1,
    ignore_mismatched_sizes=True
)
"""
class UNetPlusPlusConfig(PretrainedConfig):
    """Hugging Face config holding the ``smp.UnetPlusPlus`` constructor arguments.

    Every argument is stored unchanged as an attribute of the same name so a
    wrapper model can forward it to ``segmentation_models_pytorch.UnetPlusPlus``.
    Remaining keyword arguments are passed to ``PretrainedConfig``. The defaults
    match the Unet++ setup used in this notebook's training cells.
    """

    model_type = "unetplusplus"

    def __init__(
        self,
        encoder_name="resnet34",
        encoder_depth=5,
        encoder_weights="imagenet",
        decoder_use_norm="batchnorm",
        decoder_channels=(256, 128, 64, 32, 16),
        decoder_attention_type=None,
        decoder_interpolation="nearest",
        in_channels=3,
        classes=1,
        activation=None,
        aux_params=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.encoder_name = encoder_name
        self.encoder_depth = encoder_depth
        self.encoder_weights = encoder_weights
        self.decoder_use_norm = decoder_use_norm
        self.decoder_channels = decoder_channels
        self.decoder_attention_type = decoder_attention_type
        self.decoder_interpolation = decoder_interpolation
        self.in_channels = in_channels
        self.classes = classes
        self.activation = activation
        self.aux_params = aux_params

class UNetPlusPlusHF(PreTrainedModel):
    """``PreTrainedModel`` wrapper around ``smp.UnetPlusPlus``.

    Gives the smp network a Hugging Face ``config`` and a ``forward`` that
    accepts ``pixel_values`` / ``labels`` keywords and returns a dict with
    ``"logits"`` (plus ``"loss"`` when labels are supplied), which is the
    interface ``Trainer`` works with.
    """

    config_class = UNetPlusPlusConfig
    main_input_name = "pixel_values"

    def __init__(self, config):
        super().__init__(config)
        self.model = smp.UnetPlusPlus(
            encoder_name=config.encoder_name,
            encoder_depth=config.encoder_depth,
            encoder_weights=config.encoder_weights,
            decoder_use_norm=config.decoder_use_norm,
            decoder_channels=config.decoder_channels,
            decoder_attention_type=config.decoder_attention_type,
            decoder_interpolation=config.decoder_interpolation,
            in_channels=config.in_channels,
            classes=config.classes,
            activation=config.activation,
            aux_params=config.aux_params,
        )

    def forward(self, pixel_values, labels=None):
        """Run the network and, if ``labels`` is given, the Dice + BCE loss.

        Args:
            pixel_values: image batch passed straight to ``smp.UnetPlusPlus``.
            labels: optional mask batch; pixels equal to 1 count as logo.

        Returns:
            ``{"logits": ...}`` or ``{"loss": ..., "logits": ...}``.
        """
        logits = self.model(pixel_values)
        if isinstance(logits, tuple):
            # With aux_params set, smp returns (mask_logits, aux_logits).
            logits = logits[0]
        if labels is None:
            return {"logits": logits}

        logits_fp32 = logits.float()
        if logits_fp32.shape[1] == 1:
            # classes=1 with activation=None: a raw foreground logit.
            probs = torch.sigmoid(logits_fp32)
        else:
            # Multi-class head: probability of class 1 ("logo").
            probs = torch.softmax(logits_fp32, dim=1)[:, 1:2]

        targets = (labels == 1).float()
        if targets.dim() == 3:
            targets = targets.unsqueeze(1)  # (B, H, W) -> (B, 1, H, W)

        # CombinedLoss uses BCELoss, so it must receive probabilities, not logits.
        loss = CombinedLoss()(probs, targets)
        return {"loss": loss, "logits": logits}

# The Unet++ model itself is created in the training cell below.

training_args = TrainingArguments(
    output_dir="./output_image_segmentation/",
    learning_rate=6e-5,
    num_train_epochs=30,
    per_device_train_batch_size=8,    # adjust to your GPU (4-16 typical)
    per_device_eval_batch_size=8,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_steps=50,
    remove_unused_columns=False,       # important for custom datasets
    # smp models' forward() has no `labels` argument, so Trainer cannot infer the
    # label column by itself; without this, evaluation calls model(**inputs) with
    # a `labels` kwarg and compute_metrics never receives labels.
    label_names=["labels"],
    push_to_hub=False,                 # set True later if you want
    report_to="none",                  # or "wandb", "tensorboard"
    load_best_model_at_end=True,
    metric_for_best_model="mean_iou",
    greater_is_better=True,
    bf16=True,                         # if GPU supports
    save_total_limit=3,
    do_train=True,
    do_eval=True,
    do_predict=True,
)



In [None]:
ds_path = "/home/yassir/projects/image_semantic_segmentation/data/processed/"
ds = load_from_disk(ds_path)

# Attach the preprocessing function to every split.
# NOTE(review): preprocess_fn augments with train_transform, so validation and
# test samples are randomly augmented too — consider a non-augmenting transform there.
for split_name in ("train", "validation", "test"):
    ds[split_name].set_transform(preprocess_fn)

# Unet++ with an ImageNet-pretrained ResNet-34 encoder and a single raw-logit output channel.
model = smp.UnetPlusPlus(
    encoder_name="resnet34",
    encoder_depth=5,
    encoder_weights="imagenet",
    decoder_use_norm="batchnorm",
    decoder_channels=(256, 128, 64, 32, 16),
    decoder_attention_type=None,
    decoder_interpolation="nearest",
    in_channels=3,
    classes=1,
    activation=None,
    aux_params=None,
)
# Expose the smp config dict as an attribute namespace, presumably so code that
# reads `model.config.<attr>` (e.g. Trainer) works — confirm it is still needed.
model.config = Namespace(**model.config)
model.config.use_cache = True

trainer = LogoSegmentationTrainer(
    model=model,
    args=training_args,
    data_collator=seg_data_collator,
    train_dataset=ds["train"],
    eval_dataset=ds["validation"],
    compute_metrics=compute_metrics,
)

trainer.train()

In [None]:
# Display the config namespace attached to the smp model in the training cell.
model.config

In [None]:
# Reload the processed dataset from disk and display its splits.
ds_path = "/home/yassir/projects/image_semantic_segmentation/data/processed/"
ds = load_from_disk(ds_path)
ds

In [None]:
# Display the current DatasetDict (splits, features, row counts).
ds