In [5]:
import torch as T
import torchvision as TV
import torchaudio as TA
import cv2
import os
import numpy as np
import random
import tqdm as tqdm
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch import optim
from torch.utils.data import DataLoader, Dataset
import segmentation_models_pytorch as smp
from glob import glob
from tqdm import tqdm
import albumentations as A
from sklearn.metrics import accuracy_score, precision_score, f1_score, recall_score, confusion_matrix, average_precision_score

In [6]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = T.device("cuda" if T.cuda.is_available() else "cpu")

# Seed the RNGs this notebook draws from (Python `random`, NumPy, and PyTorch
# on CPU + all CUDA devices) so shuffling and weight init repeat run to run.
# NOTE(review): albumentations >= 2 keeps its own generators — pass `seed=` to
# A.Compose as well if augmentations must be reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
T.manual_seed(SEED)

print(device)

cuda


In [7]:
# ---------- Paths ----------
# Root of the CWF-788 512x384 split. Set the CWF788_DATA_ROOT environment
# variable to run on another machine; the default is the original local path.
DATA_ROOT = os.environ.get("CWF788_DATA_ROOT", r"D:\AAU Internship\Code\CWF-788\IMAGE512x384")
train_images = os.path.join(DATA_ROOT, "train_new")
train_masks = os.path.join(DATA_ROOT, "trainlabel_new")
validation_images = os.path.join(DATA_ROOT, "validation_new")
validation_masks = os.path.join(DATA_ROOT, "validationlabel_new")
test_images = os.path.join(DATA_ROOT, "test_new")
test_masks = os.path.join(DATA_ROOT, "testlabel_new")


In [10]:
# ---------------------- Augmentations -----------------------
# Joint image/mask augmentations for training.
#
# The images appear to be non-square (the dataset folder is IMAGE512x384), so
# only aspect-preserving geometric transforms are used. RandomRotate90 and D4
# were removed: their 90-degree rotations / transposes swap H and W, and
# samples of different shapes can neither be collated into one batch nor
# concatenated with the un-augmented originals in the training loop.
# HorizontalFlip + VerticalFlip together already cover the aspect-preserving
# part of D4 (identity, the two flips, and the 180-degree rotation).
# NOTE(review): ShiftScaleRotate is deprecated in albumentations 2.x in
# favour of A.Affine — consider migrating.
train_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=15, p=0.5),
    A.ISONoise(
        color_shift=[0.01, 0.05],
        intensity=[0.1, 0.5],
        p=0.5
    ),
    A.RandomBrightnessContrast(brightness_limit=[-0.2, 0.2], contrast_limit=[-0.2, 0.2], brightness_by_max=True, ensure_safe_range=False, p=0.5),
    # A single elastic warp (the pipeline previously applied two back to back).
    # The RGB image is resampled bilinearly to avoid blocky artefacts; the mask
    # stays nearest-neighbour so its pixels remain valid class ids.
    A.ElasticTransform(
        alpha=300,
        sigma=10,
        interpolation=cv2.INTER_LINEAR,
        approximate=False,
        same_dxdy=True,
        mask_interpolation=cv2.INTER_NEAREST,
        noise_distribution="gaussian",
        keypoint_remapping_method="mask",
        border_mode=cv2.BORDER_CONSTANT,
        fill=0,
        fill_mask=0,
        p=0.5,
    ),
])

# Image-only preprocessing: ToTensor turns an HxWxC uint8 array (or PIL image)
# into a CxHxW float tensor in [0, 1], then Normalize standardises each
# channel with the ImageNet statistics used by the pretrained encoder.
base_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# ---------------------- Dataset Class -----------------------
class SegmentationDataset(Dataset):
    """Paired image/mask segmentation dataset.

    Images are the ``*.jpg`` files of ``image_dir`` and masks the ``*.png``
    files of ``mask_dir``. Pairs are matched by sorted order and verified to
    share the same file stem.

    Each item is a dict with ``original_img`` (float CxHxW tensor produced by
    ``base_transform``) and ``original_mask`` (int64 HxW tensor). When
    ``train_transform`` is set, a jointly augmented copy is also returned under
    ``augmented_img`` / ``augmented_mask``. All tensors are moved to the
    module-level ``device``.

    NOTE(review): mask pixel values are used directly as class indices, so they
    are assumed to be in {0, 1} — TODO confirm the PNGs are not stored as 0/255.
    NOTE(review): because tensors are placed on ``device`` here, a DataLoader
    over this dataset cannot use ``pin_memory=True`` with CUDA and should stay
    single-process; moving the ``.to(device)`` into the training loop would
    lift that restriction.
    """

    def __init__(self, image_dir, mask_dir, train_transform=None, base_transform=None, dataset_type="Unknown"):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.train_transform = train_transform
        self.base_transform = base_transform
        self.dataset_type = dataset_type
        self.image_files = sorted(glob(os.path.join(image_dir, "*.jpg")))
        self.mask_files = sorted(glob(os.path.join(mask_dir, "*.png")))
        self._verify_file_pairs()

    def _verify_file_pairs(self):
        """Raise ValueError unless images and masks pair up one-to-one by stem."""
        if len(self.image_files) != len(self.mask_files):
            raise ValueError(f"Mismatched counts in {self.dataset_type} dataset: {len(self.image_files)} images vs {len(self.mask_files)} masks")

        for img_path, mask_path in tqdm(zip(self.image_files, self.mask_files), total=len(self.image_files), desc=f"Verifying {self.dataset_type} File Pairs 🔍"):
            img_name = os.path.splitext(os.path.basename(img_path))[0]
            mask_name = os.path.splitext(os.path.basename(mask_path))[0]
            if img_name != mask_name:
                raise ValueError(f"Filename mismatch in {self.dataset_type} dataset: {img_name} vs {mask_name}")

    def __len__(self):
        return len(self.image_files)

    def _load_pair(self, idx):
        """Read the RGB image and grayscale mask for ``idx`` as uint8 arrays.

        Raises:
            FileNotFoundError: if OpenCV cannot read either file.
        """
        img_path, mask_path = self.image_files[idx], self.mask_files[idx]
        img = cv2.imread(img_path)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread reports unreadable / missing files by returning None.
        if img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB), mask

    def _image_to_tensor(self, img):
        """Convert an HxWxC uint8 RGB array into a CxHxW float tensor.

        ``base_transform`` (torchvision ``ToTensor`` + ``Normalize``) is given
        the NumPy array itself. ``ToTensor`` rejects torch tensors, and it is
        the step that scales uint8 to [0, 1], which the ImageNet mean/std
        assume. Without a ``base_transform`` the image is only scaled to [0, 1].
        """
        # Augmentations (e.g. flips) may return negative-stride views, which
        # torch.from_numpy cannot wrap.
        img = np.ascontiguousarray(img)
        if self.base_transform is not None:
            return self.base_transform(img)
        return T.from_numpy(img).permute(2, 0, 1).float().div_(255.0)

    @staticmethod
    def _mask_to_tensor(mask):
        """Convert an HxW mask array into an int64 tensor of class indices."""
        return T.from_numpy(np.ascontiguousarray(mask)).long()

    def __getitem__(self, idx):
        img, mask = self._load_pair(idx)
        sample = {
            'original_img': self._image_to_tensor(img).to(device),
            'original_mask': self._mask_to_tensor(mask).to(device),
        }

        if self.train_transform:
            augmented = self.train_transform(image=img, mask=mask)
            sample['augmented_img'] = self._image_to_tensor(augmented['image']).to(device)
            sample['augmented_mask'] = self._mask_to_tensor(augmented['mask']).to(device)

        return sample

# ---------------------- Datasets & DataLoaders -----------------------
# Only the training split is augmented. The validation and test loops read
# just the 'original_*' entries, so augmenting those splits (including an
# expensive elastic warp) was wasted work on every item.
train_dataset = SegmentationDataset(
    train_images,
    train_masks,
    train_transform=train_transform,
    base_transform=base_transform,
    dataset_type="Training"
)

val_dataset = SegmentationDataset(
    validation_images,
    validation_masks,
    train_transform=None,
    base_transform=base_transform,
    dataset_type="Validation"
)

test_dataset = SegmentationDataset(
    test_images,
    test_masks,
    train_transform=None,
    base_transform=base_transform,
    dataset_type="Testing"
)

# SegmentationDataset moves every tensor to `device` inside __getitem__, so
# loading stays in the main process without pinned memory:
# - only CPU tensors can be pinned, so pin_memory=True fails on CUDA tensors;
# - CUDA work inside DataLoader worker processes is fragile (and in a Windows
#   notebook, spawned workers may be unable to import notebook-defined classes);
# - persistent_workers requires num_workers > 0, so it is dropped too.
# If the dataset is changed to return CPU tensors, workers and pinning can be
# re-enabled. Evaluation loaders are not shuffled, which keeps batch
# composition (and the batch-averaged metrics) fixed between epochs.
_loader_kwargs = dict(batch_size=32, num_workers=0, pin_memory=False)

train_dataloader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_dataloader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_dataloader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

Verifying Training File Pairs 🔍: 100%|███████████████████████████████████████████| 800/800 [00:00<00:00, 99989.37it/s]
Verifying Validation File Pairs 🔍: 100%|█████████████████████████████████████████| 176/176 [00:00<00:00, 58675.58it/s]
Verifying Testing File Pairs 🔍: 100%|████████████████████████████████████████████| 600/600 [00:00<00:00, 60010.07it/s]


In [None]:
# ---------------------- Model Definition -----------------------
# U-Net with an ImageNet-pretrained EfficientNet-B7 encoder (4 stages), a
# matching 4-entry decoder with scSE attention, and a channel-wise softmax
# head producing per-pixel probabilities for 2 classes.
#
# `encoder_name` is the smp keyword that selects the backbone. The previous
# `encoder=` keyword is not an smp.Unet parameter, and neither is `center`,
# so both were removed. "softmax2d" is smp's explicit softmax over dim=1;
# plain "softmax" relied on nn.Softmax's deprecated implicit dimension.
# NOTE(review): decoder_use_batchnorm='inplace' requires the `inplace_abn`
# package — TODO confirm it is installed.
model = smp.Unet(
    encoder_name="efficientnet-b7",
    encoder_weights="imagenet",
    encoder_depth=4,
    decoder_use_batchnorm='inplace',
    decoder_attention_type='scse',
    decoder_channels=[256, 128, 64, 32],
    in_channels=3,
    classes=2,
    activation="softmax2d",
)

model = model.to(device)

# ---------------------- Loss Functions -----------------------
class TverskyLoss(nn.Module):
    """Per-image, per-class Tversky loss on class-probability maps.

    For every (image, class) pair:
        TI = (TP + ep) / (TP + alpha * FP + beta * FN + ep)
    and the loss is ``1 - mean(TI)`` over images and classes.

    ``alpha`` weights false positives and ``beta`` weights false negatives.
    Raising ``beta`` (lowering ``alpha``) therefore favours recall.

    Args (forward):
        outputs: probabilities of shape (N, C, H, W), e.g. a softmax output.
            The number of classes is taken from ``outputs.shape[1]``.
        targets: int64 class indices of shape (N, H, W).
    """

    def __init__(self, alpha=0.7, beta=0.3, ep=1e-6):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.ep = ep

    def update_params(self, alpha, beta):
        """Replace the FP (alpha) and FN (beta) weights."""
        self.alpha = alpha
        self.beta = beta

    def forward(self, outputs, targets):
        num_classes = outputs.shape[1]
        true = F.one_hot(targets, num_classes=num_classes).permute(0, 3, 1, 2).float()
        # Sum over the spatial dims -> one value per (image, class).
        TP = (outputs * true).sum(dim=(2, 3))
        FP = (outputs * (1 - true)).sum(dim=(2, 3))
        FN = ((1 - outputs) * true).sum(dim=(2, 3))
        tversky = (TP + self.ep) / (TP + self.alpha * FP + self.beta * FN + self.ep)
        return 1 - tversky.mean()

class CombinedLoss(nn.Module):
    """Weighted sum of smp's multiclass focal loss and a Tversky loss.

    ``outputs`` are per-pixel class probabilities (the model is built with a
    softmax activation). smp's ``FocalLoss`` works on logits: it scores each
    class channel with a sigmoid-based binary focal loss. Feeding it
    probabilities made it treat values in [0, 1] as logits, which squashes
    them into roughly [0.5, 0.73].

    The probabilities are therefore mapped through ``logit`` first. Because
    ``sigmoid(logit(p)) == p``, the focal term scores the model's actual
    probabilities. The Tversky term consumes the probabilities directly.
    """

    # Clamp used when mapping probabilities to logits (keeps logit finite).
    _LOGIT_EPS = 1e-6

    def __init__(self, focal_weight=0.5, tversky_weight=0.5, tversky_alpha=0.7, tversky_beta=0.3):
        super(CombinedLoss, self).__init__()
        self.focal_loss = smp.losses.FocalLoss(mode='multiclass')
        self.tversky_loss = TverskyLoss(alpha=tversky_alpha, beta=tversky_beta)
        self.focal_weight = focal_weight
        self.tversky_weight = tversky_weight

    def forward(self, outputs, targets, return_components=False):
        """Return the weighted loss, or ``(total, focal, tversky)`` when requested.

        Args:
            outputs: class probabilities, shape (N, C, H, W).
            targets: int64 class indices, shape (N, H, W).
            return_components: also return the unweighted focal and Tversky terms.
        """
        focal = self.focal_loss(T.logit(outputs, eps=self._LOGIT_EPS), targets)
        tversky = self.tversky_loss(outputs, targets)
        total_loss = self.focal_weight * focal + self.tversky_weight * tversky
        if return_components:
            return total_loss, focal, tversky
        return total_loss

# ---------------------- Dynamic Loss Wrapper -----------------------
class DynamicLossWrapper:
    """Adapts a CombinedLoss-style criterion between epochs.

    * ``__call__`` returns the criterion's total loss. While autograd is
      enabled (training steps) it also records the focal and Tversky
      components. Calls under ``torch.no_grad()`` (validation / test) are not
      recorded, so they cannot influence the reweighting.
    * ``update_weights`` turns the current epoch's recorded batch losses into
      per-epoch means and reweights the criterion from the last ``history``
      epochs. Each term gets a weight proportional to the *other* term's mean,
      so both weighted contributions come out equal.
    * ``update_tversky_params`` shifts the Tversky alpha/beta trade-off based
      on validation recall.

    The criterion is expected to expose ``focal_weight``, ``tversky_weight``,
    ``tversky_loss.alpha``, ``tversky_loss.update_params(alpha, beta)`` and to
    accept ``return_components=True``.
    """

    def __init__(self, criterion, recall_threshold=0.8, history=5, alpha_step=0.05, min_alpha=0.1, balanced_alpha=0.5):
        self.criterion = criterion
        self.recall_threshold = recall_threshold
        self.history = history
        self.alpha_step = alpha_step
        self.min_alpha = min_alpha
        self.balanced_alpha = balanced_alpha
        # Per-batch loss components of the current epoch (cleared by update_weights).
        self.focal_losses = []
        self.tversky_losses = []
        # Per-epoch means, most recent last.
        self.focal_history = []
        self.tversky_history = []

    def update_weights(self):
        """Fold this epoch's batch losses into the history and reweight the criterion."""
        if not self.focal_losses or not self.tversky_losses:
            return
        self.focal_history.append(float(np.mean(self.focal_losses)))
        self.tversky_history.append(float(np.mean(self.tversky_losses)))
        self.focal_losses = []
        self.tversky_losses = []

        avg_focal = float(np.mean(self.focal_history[-self.history:]))
        avg_tversky = float(np.mean(self.tversky_history[-self.history:]))
        total = avg_focal + avg_tversky
        if total > 0:
            self.criterion.focal_weight = avg_tversky / total
            self.criterion.tversky_weight = avg_focal / total

    def update_tversky_params(self, recall):
        """Adjust the Tversky alpha (FP weight) and beta = 1 - alpha (FN weight).

        In TverskyLoss the denominator is ``TP + alpha * FP + beta * FN``, so
        lowering alpha penalises missed foreground more and raises recall.
        When recall is below ``recall_threshold``, alpha is lowered by
        ``alpha_step`` (down to ``min_alpha``). Otherwise alpha moves one step
        towards ``balanced_alpha``.

        Returns:
            The new ``(alpha, beta)``.
        """
        alpha = self.criterion.tversky_loss.alpha
        if recall < self.recall_threshold:
            alpha = max(alpha - self.alpha_step, self.min_alpha)
        elif alpha > self.balanced_alpha:
            alpha = max(alpha - self.alpha_step, self.balanced_alpha)
        else:
            alpha = min(alpha + self.alpha_step, self.balanced_alpha)
        beta = 1.0 - alpha
        self.criterion.tversky_loss.update_params(alpha, beta)
        return alpha, beta

    def __call__(self, outputs, targets):
        loss, focal, tversky = self.criterion(outputs, targets, return_components=True)
        # Only training steps (autograd enabled) feed the reweighting history.
        if T.is_grad_enabled():
            self.focal_losses.append(focal.item())
            self.tversky_losses.append(tversky.item())
        return loss

# ---------------------- Metrics -----------------------
def compute_accuracy(outputs, targets):
    """Pixel accuracy of the argmax prediction over the whole batch.

    outputs: class scores (N, C, H, W); targets: class indices (N, H, W).
    Returns a Python float in [0, 1].
    """
    hits = (outputs.argmax(dim=1) == targets).float()
    return hits.sum().div(targets.numel()).item()

def compute_iou(outputs, targets, class_id=1):  # IoU for foreground (class_id=1)
    """Mean per-image IoU of ``class_id`` between the argmax prediction and targets.

    Intersection and union are smoothed by 1e-6, so an image where the class
    is absent from both prediction and target scores 1.0.
    """
    pred_mask = T.argmax(outputs, dim=1).eq(class_id).float()
    true_mask = targets.eq(class_id).float()
    overlap = (pred_mask * true_mask).sum((1, 2))
    covered = T.maximum(pred_mask, true_mask).sum((1, 2))
    per_image = (overlap + 1e-6) / (covered + 1e-6)
    return per_image.mean().item()

def compute_miou(outputs, targets, num_classes=2):
    """Mean IoU over classes ``0..num_classes-1`` and over images in the batch.

    Each (image, class) IoU is smoothed with 1e-6, so a class absent from both
    prediction and target in an image contributes 1.0.
    """
    labels = T.argmax(outputs, dim=1)
    per_class = []
    for cls in range(num_classes):
        pred_c = (labels == cls).float()
        true_c = (targets == cls).float()
        inter = (pred_c * true_c).sum((1, 2))
        union = T.maximum(pred_c, true_c).sum((1, 2))
        per_class.append((inter + 1e-6) / (union + 1e-6))
    return T.stack(per_class).mean().item()

def compute_map(outputs, targets, iou_thresholds=(0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95)):
    """Share of IoU thresholds reached by the batch's mean foreground IoU.

    NOTE: despite the name this is not a ranking-based mean average precision.
    The foreground IoU (``compute_iou`` for class 1, averaged over the batch)
    is computed once. Each threshold then scores 1.0 if that IoU is at least
    the threshold, else 0.0, and the mean of those scores is returned.

    Args:
        outputs: class scores of shape (N, C, H, W).
        targets: class-index masks of shape (N, H, W).
        iou_thresholds: non-empty sequence of thresholds (default: the
            0.50:0.05:0.95 grid).

    Returns:
        A float in [0, 1].

    Raises:
        ValueError: if ``iou_thresholds`` is empty.
    """
    thresholds = list(iou_thresholds)
    if not thresholds:
        raise ValueError("iou_thresholds must contain at least one threshold")
    # The IoU does not depend on the threshold, so compute it once.
    iou = compute_iou(outputs, targets, class_id=1)
    return float(np.mean([1.0 if iou >= threshold else 0.0 for threshold in thresholds]))

def compute_fnr(outputs, targets, class_id=1):
    """Mean per-image false-negative rate FN / (TP + FN) for ``class_id``.

    Only the denominator is smoothed (1e-6). An image whose target contains
    no ``class_id`` pixels therefore has nothing to miss and scores 0.0, and
    the result equals ``1 - compute_recall(...)``. The previous formula also
    smoothed the numerator, which gave such images an FNR of 1.0.
    """
    preds = (T.argmax(outputs, dim=1) == class_id).float()  # [batch, H, W]
    targets = (targets == class_id).float()  # [batch, H, W]
    FN = ((1 - preds) * targets).sum((1, 2))
    TP = (preds * targets).sum((1, 2))
    fnr = FN / (TP + FN + 1e-6)
    return fnr.mean().item()

def compute_recall(outputs, targets, class_id=1):
    preds = (T.argmax(outputs, dim=1) == class_id).float()  # [batch, H, W]
    targets = (targets == class_id).float()  # [batch, H, W]
    TP = (preds * targets).sum((1, 2))
    FN = ((1 - preds) * targets).sum((1, 2))
    recall = (TP + 1e-6) / (TP + FN + 1e-6)
    return recall.mean().item()

def compute_precision(outputs, targets, class_id=1):
    """Mean per-image precision TP / (TP + FP) for ``class_id``.

    Numerator and denominator are smoothed by 1e-6, so an image where
    ``class_id`` is never predicted scores 1.0.
    """
    hit = T.argmax(outputs, dim=1).eq(class_id).float()
    gt = targets.eq(class_id).float()
    tp = (hit * gt).sum(dim=(1, 2))
    fp = (hit * (1 - gt)).sum(dim=(1, 2))
    return ((tp + 1e-6) / (tp + fp + 1e-6)).mean().item()

def compute_f1_score(outputs, targets, class_id=1):
    """F1 for ``class_id`` built from the batch-mean precision and recall.

    This is the harmonic mean of the two batch-averaged scores (smoothed by
    1e-6 in the denominator), not the average of per-image F1 values.
    """
    p = compute_precision(outputs, targets, class_id)
    r = compute_recall(outputs, targets, class_id)
    return 2 * (p * r) / (p + r + 1e-6)

def compute_confusion_matrix(outputs, targets, class_id=1):
    """Pixel counts of TP/FP/FN/TN for ``class_id`` over the whole batch.

    Returns a dict with float counts under the keys 'TP', 'FP', 'FN', 'TN'.
    """
    pred_pos = T.argmax(outputs, dim=1).eq(class_id).float()
    true_pos = targets.eq(class_id).float()
    pred_neg = 1 - pred_pos
    true_neg = 1 - true_pos
    return {
        'TP': (pred_pos * true_pos).sum().item(),
        'FP': (pred_pos * true_neg).sum().item(),
        'FN': (pred_neg * true_pos).sum().item(),
        'TN': (pred_neg * true_neg).sum().item(),
    }

# ---------------------- Training and Testing -----------------------
def _batch_metric_values(outputs, masks):
    """Return every tracked metric for one batch (each value is already a batch mean)."""
    return {
        'accuracy': compute_accuracy(outputs, masks),
        'iou': compute_iou(outputs, masks),
        'miou': compute_miou(outputs, masks),
        'map': compute_map(outputs, masks),
        'fnr': compute_fnr(outputs, masks),
        'precision': compute_precision(outputs, masks),
        'f1': compute_f1_score(outputs, masks),
        'recall': compute_recall(outputs, masks),
    }


def _print_confusion_matrix(title, cm):
    """Print a pixel-level 2x2 confusion matrix normalised by its total count."""
    total = cm['TP'] + cm['FP'] + cm['FN'] + cm['TN']
    if total <= 0:
        return
    norm = {key: value / total for key, value in cm.items()}
    print(f"\n{title}")
    print(f"{'':>10} {'Predicted':>20}")
    print(f"{'':>10} {'Positive':>10} {'Negative':>10}")
    print(f"{'Actual':>10}")
    print(f"{'Positive':>10} {norm['TP']:.4f} {norm['FN']:.4f}")
    print(f"{'Negative':>10} {norm['FP']:.4f} {norm['TN']:.4f}")


def _train_one_epoch(model, dataloader, dynamic_loss, optimizer):
    """Run one optimisation pass over each batch's original + augmented samples.

    Returns ``(mean loss, {metric: mean})``, both weighted per sample.
    """
    model.train()
    loss_sum = 0.0
    metric_sums = {}
    n_seen = 0
    for batch in tqdm(dataloader, desc="Training", leave=False):
        # Each step trains on the original and the augmented view together.
        imgs = T.cat([batch['original_img'], batch['augmented_img']], dim=0).to(device, non_blocking=True)
        masks = T.cat([batch['original_mask'], batch['augmented_mask']], dim=0).to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = dynamic_loss(outputs, masks)
        loss.backward()
        optimizer.step()

        n = imgs.size(0)
        loss_sum += loss.item() * n
        with T.no_grad():
            for name, value in _batch_metric_values(outputs, masks).items():
                metric_sums[name] = metric_sums.get(name, 0.0) + value * n
        n_seen += n
    if n_seen == 0:
        raise ValueError("Training dataloader yielded no samples")
    return loss_sum / n_seen, {name: total / n_seen for name, total in metric_sums.items()}


def _evaluate(model, dataloader, criterion, desc, leave=True, with_cm=False):
    """Evaluate ``model`` on the un-augmented samples of ``dataloader``.

    ``criterion`` is called directly rather than through the dynamic wrapper,
    so evaluation batches never enter the wrapper's loss-reweighting history.

    Returns:
        ``(mean loss, {metric: mean}, confusion-matrix dict or None)``, with
        means weighted per sample.
    """
    model.eval()
    loss_sum = 0.0
    metric_sums = {}
    cm = {'TP': 0, 'FP': 0, 'FN': 0, 'TN': 0} if with_cm else None
    n_seen = 0
    with T.no_grad():
        for batch in tqdm(dataloader, desc=desc, leave=leave):
            imgs = batch['original_img'].to(device, non_blocking=True)
            masks = batch['original_mask'].to(device, non_blocking=True)

            outputs = model(imgs)
            n = imgs.size(0)
            loss_sum += criterion(outputs, masks).item() * n
            for name, value in _batch_metric_values(outputs, masks).items():
                metric_sums[name] = metric_sums.get(name, 0.0) + value * n
            if cm is not None:
                for key, value in compute_confusion_matrix(outputs, masks).items():
                    cm[key] += value
            n_seen += n
    if n_seen == 0:
        raise ValueError(f"{desc} dataloader yielded no samples")
    return loss_sum / n_seen, {name: total / n_seen for name, total in metric_sums.items()}, cm


def train_model(model, train_dataloader, val_dataloader, test_dataloader, epochs=50, checkpoint_path='best_model.pth'):
    """Train with the dynamic focal + Tversky loss, then test the best checkpoint.

    Each epoch does the following:
    * one training pass over original + augmented views;
    * a validation pass on the original images;
    * a rebalance of the focal/Tversky weights from the epoch's training losses;
    * a Tversky alpha/beta update driven by validation recall;
    * a ReduceLROnPlateau step on the validation loss.

    The weights with the lowest validation loss are saved to
    ``checkpoint_path`` and reloaded for the test pass. Normalised confusion
    matrices are printed for the final validation pass and for the test pass.

    Args:
        model: segmentation network returning class probabilities (N, C, H, W).
        train_dataloader: yields ``original_*`` and ``augmented_*`` tensors.
        val_dataloader / test_dataloader: yield ``original_*`` tensors.
        epochs: number of training epochs.
        checkpoint_path: where the best state dict is written.
    """
    criterion = CombinedLoss(focal_weight=0.5, tversky_weight=0.5, tversky_alpha=0.7, tversky_beta=0.3).to(device)
    dynamic_loss = DynamicLossWrapper(criterion, recall_threshold=0.8)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)

    best_val_loss = float('inf')

    for epoch in tqdm(range(epochs), desc="Epochs"):
        is_last_epoch = epoch == epochs - 1

        train_loss, tm = _train_one_epoch(model, train_dataloader, dynamic_loss, optimizer)
        val_loss, vm, val_cm = _evaluate(model, val_dataloader, criterion, "Validation",
                                         leave=False, with_cm=is_last_epoch)

        # Update dynamic hyperparameters
        dynamic_loss.update_weights()
        tversky_alpha, tversky_beta = dynamic_loss.update_tversky_params(vm['recall'])

        print(f'Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.4f}, Train Accuracy: {tm["accuracy"]:.4f}, '
              f'Train IoU: {tm["iou"]:.4f}, Train MIoU: {tm["miou"]:.4f}, Train mAP: {tm["map"]:.4f}, '
              f'Train FNR: {tm["fnr"]:.4f}, Train Precision: {tm["precision"]:.4f}, Train F1: {tm["f1"]:.4f}, '
              f'Val Loss: {val_loss:.4f}, Val Accuracy: {vm["accuracy"]:.4f}, Val IoU: {vm["iou"]:.4f}, '
              f'Val MIoU: {vm["miou"]:.4f}, Val mAP: {vm["map"]:.4f}, Val FNR: {vm["fnr"]:.4f}, '
              f'Val Precision: {vm["precision"]:.4f}, Val F1: {vm["f1"]:.4f}, Val Recall: {vm["recall"]:.4f}, '
              f'Focal Weight: {dynamic_loss.criterion.focal_weight:.3f}, '
              f'Tversky Weight: {dynamic_loss.criterion.tversky_weight:.3f}, '
              f'Tversky Alpha: {tversky_alpha:.3f}, Beta: {tversky_beta:.3f}')

        if is_last_epoch:
            _print_confusion_matrix("Final Validation Confusion Matrix (Normalized):", val_cm)

        # NOTE(review): the loss weights and Tversky alpha/beta change every
        # epoch, so val_loss values from different epochs measure different
        # objectives. Checkpoint selection and LR scheduling on it can be noisy;
        # consider tracking a fixed metric such as validation IoU instead.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            T.save(model.state_dict(), checkpoint_path)

        scheduler.step(val_loss)

    # Testing Phase
    print("\nEvaluating on Test Set...")
    if best_val_loss < float('inf'):
        model.load_state_dict(T.load(checkpoint_path, map_location=device))
    else:
        # No epoch ran or every val_loss was NaN: nothing was saved this run.
        print(f"No checkpoint was saved to {checkpoint_path!r}; evaluating the final weights.")
    test_loss, sm, test_cm = _evaluate(model, test_dataloader, criterion, "Testing", with_cm=True)

    print(f'\nTest Results - Test Loss: {test_loss:.4f}, Test Accuracy: {sm["accuracy"]:.4f}, '
          f'Test IoU: {sm["iou"]:.4f}, Test MIoU: {sm["miou"]:.4f}, Test mAP: {sm["map"]:.4f}, '
          f'Test FNR: {sm["fnr"]:.4f}, Test Precision: {sm["precision"]:.4f}, Test F1: {sm["f1"]:.4f}, '
          f'Test Recall: {sm["recall"]:.4f}')

    _print_confusion_matrix("Test Confusion Matrix (Normalized):", test_cm)

# Kick off the full run: 50 epochs of training/validation, then the test pass.
train_model(
    model,
    train_dataloader,
    val_dataloader,
    test_dataloader,
    epochs=50,
)

In [None]:
# ---------------------- Model Definition -----------------------
# Same architecture as the first experiment: U-Net with an ImageNet-pretrained
# EfficientNet-B7 encoder (4 stages), scSE decoder attention and a
# channel-wise softmax head for 2 classes.
#
# `encoder_name` is the smp keyword that selects the backbone. `encoder=` and
# `center` are not smp.Unet parameters, so they were replaced/removed.
# "softmax2d" is smp's explicit softmax over dim=1.
# NOTE(review): decoder_use_batchnorm='inplace' requires the `inplace_abn`
# package — TODO confirm it is installed.
model = smp.Unet(
    encoder_name="efficientnet-b7",
    encoder_weights="imagenet",
    encoder_depth=4,
    decoder_use_batchnorm='inplace',
    decoder_attention_type='scse',
    decoder_channels=[256, 128, 64, 32],
    in_channels=3,
    classes=2,
    activation="softmax2d",
)

model = model.to(device)

# ---------------------- Focal-Tversky Loss -----------------------
class FocalTverskyLoss(nn.Module):
    """Focal Tversky loss on class-probability maps.

    Per class, with counts pooled over the whole batch (sums over N, H, W):
        TI_c = (TP_c + smooth) / (TP_c + alpha * FP_c + beta * FN_c + smooth)
    and the loss is ``mean_c (1 - TI_c) ** gamma``.

    ``alpha`` weights false positives and ``beta`` false negatives.

    Args (forward):
        preds: probabilities of shape (N, C, H, W); assumed already softmaxed.
        targets: int64 class indices of shape (N, H, W).
    """

    def __init__(self, alpha=0.7, beta=0.3, gamma=0.75, smooth=1e-6):
        super(FocalTverskyLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.smooth = smooth

    def update_params(self, alpha, beta, gamma=None):
        """Replace alpha/beta, and gamma when given (e.g. from ``update_ft_params``)."""
        self.alpha = alpha
        self.beta = beta
        if gamma is not None:
            self.gamma = gamma

    def forward(self, preds, targets):
        targets_one_hot = F.one_hot(targets, num_classes=preds.shape[1]).permute(0, 3, 1, 2).float()
        probs = preds  # Assume softmax

        dims = (0, 2, 3)  # pool batch + spatial dims -> one value per class
        TP = T.sum(probs * targets_one_hot, dims)
        FP = T.sum(probs * (1 - targets_one_hot), dims)
        FN = T.sum((1 - probs) * targets_one_hot, dims)

        Tversky = (TP + self.smooth) / (TP + self.alpha * FP + self.beta * FN + self.smooth)
        # Clamp before the power. For gamma < 1 the derivative of x**gamma is
        # infinite at x = 0, and float rounding can make 1 - Tversky exactly 0
        # (or slightly negative, which would give NaN). Either case would
        # produce inf/NaN gradients; clamped entries receive zero gradient.
        return T.mean((1 - Tversky).clamp_min(self.smooth) ** self.gamma)

# ---------------------- Evaluation Metrics -----------------------
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, average_precision_score
import torch.nn.functional as F

def compute_metrics(preds, targets, num_classes=2):
    """Dataset-level segmentation metrics over every pixel of ``preds``/``targets``.

    Args:
        preds: either class scores / probabilities of shape (N, C, H, W),
            which are reduced with argmax over dim 1, or already-decoded label
            maps with the same shape as ``targets``.
        targets: integer class indices, shape (N, H, W).
        num_classes: with 2, precision/recall/F1 are binary metrics for class 1
            and FNR is reported. Otherwise they are macro averages and FNR is None.

    Returns:
        Dict with "Accuracy", "F1-Score", "Precision", "Recall",
        "IoU (Foreground)" (class 1, or class 0 when there is one class),
        "Mean IoU" and "FNR". Precision/recall/F1 are 0 when undefined.
    """
    with T.no_grad():
        # Scores carry an extra class dimension; label maps do not.
        if preds.dim() == targets.dim() + 1:
            preds = T.argmax(preds, dim=1)
        preds = preds.cpu().numpy().ravel()
        targets = targets.cpu().numpy().ravel()

        average = 'binary' if num_classes == 2 else 'macro'
        acc = accuracy_score(targets, preds)
        # zero_division=0 gives the same value sklearn falls back to by
        # default, without emitting a warning every epoch.
        f1 = f1_score(targets, preds, average=average, zero_division=0)
        precision = precision_score(targets, preds, average=average, zero_division=0)
        recall = recall_score(targets, preds, average=average, zero_division=0)
        cm = confusion_matrix(targets, preds, labels=list(range(num_classes)))

        if num_classes == 2:
            TN, FP, FN, TP = cm.ravel()
            fnr = FN / (FN + TP + 1e-6)
        else:
            fnr = None

        ious = []
        for cls in range(num_classes):
            intersection = ((preds == cls) & (targets == cls)).sum()
            union = ((preds == cls) | (targets == cls)).sum()
            ious.append(intersection / (union + 1e-6))

        mean_iou = sum(ious) / num_classes

        return {
            "Accuracy": acc,
            "F1-Score": f1,
            "Precision": precision,
            "Recall": recall,
            "IoU (Foreground)": ious[1] if num_classes > 1 else ious[0],
            "Mean IoU": mean_iou,
            "FNR": fnr
        }

# ---------------------- Optimizer -----------------------
optimizer = T.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

# ---------------------- LR Scheduler -----------------------
# `verbose` is deprecated for PyTorch LR schedulers, so it is not passed here.
# To see reductions, log optimizer.param_groups[0]['lr'] after scheduler.step(...).
scheduler = T.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',          # use 'max' if monitoring a metric like Mean IoU
    factor=0.5,          # reduce LR by half
    patience=3,          # wait for 3 bad epochs before reducing
    threshold=1e-4,      # minimum improvement to avoid being a bad epoch
    threshold_mode='rel',
    cooldown=0,
    min_lr=1e-6,
)

# ---------------------- Dynamic Focal-Tversky Logic ----------------------
# Schedule for a 50-epoch run:
# - parameters change once every `interval` (5) epochs, giving 10 plateaus
#   (steps 0..9)
# - alpha moves linearly from 0.7 to 0.4, and beta = 1 - alpha from 0.3 to 0.6
# - gamma moves linearly from 0.75 to 1.5
# The end values are reached on the final plateau (epochs 45-49). The earlier
# fixed increments (0.03 / 0.075 per step) assumed 10 increments and stopped
# one step short (alpha 0.43, gamma 1.425).

def update_ft_params(epoch, total_epochs=50, interval=5):
    """Return ``(alpha, beta, gamma)`` for FocalTverskyLoss at ``epoch`` (0-based).

    Args:
        epoch: current epoch index. Values past the last plateau are clamped to
            the end values, and negative values to the start values.
        total_epochs: run length used to place the final plateau.
        interval: number of epochs between parameter changes.
    """
    last_step = max((total_epochs - 1) // interval, 1)
    step = min(max(epoch // interval, 0), last_step)
    progress = step / last_step  # 0.0 on the first plateau, 1.0 on the last
    alpha = 0.7 - 0.3 * progress      # 0.7 -> 0.4
    beta = 1 - alpha                  # 0.3 -> 0.6
    gamma = 0.75 + 0.75 * progress    # 0.75 -> 1.5
    return alpha, beta, gamma

# ---------------------- Training & Validation Loops ----------------------
def TrainUNet(model, dataloader, loss_fn, optimizer):
    """Run one training epoch on the augmented samples and score it.

    Returns:
        ``(avg_loss, metrics)``.
        ``avg_loss`` is the per-sample (batch-size-weighted) mean loss, so a
        short final batch is not over-weighted.
        ``metrics`` is ``compute_metrics`` applied to every prediction of the epoch.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    n_samples = 0
    all_preds, all_targets = [], []

    for batch in dataloader:
        inputs = batch['augmented_img'].to(device, non_blocking=True)
        targets = batch['augmented_mask'].to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(inputs)

        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item() * inputs.size(0)
        n_samples += inputs.size(0)
        # Park predictions on the CPU: keeping every full-resolution output on
        # the GPU for the whole epoch grows device memory with the dataset,
        # and compute_metrics moves them to the CPU anyway.
        all_preds.append(outputs.detach().cpu())
        all_targets.append(targets.detach().cpu())

    if n_samples == 0:
        raise ValueError("TrainUNet: dataloader yielded no samples")

    metrics = compute_metrics(T.cat(all_preds, dim=0), T.cat(all_targets, dim=0))
    return loss_sum / n_samples, metrics


def ValidateUNet(model, dataloader, loss_fn):
    """Evaluate on the un-augmented samples without gradient tracking.

    Returns:
        ``(avg_loss, metrics)``.
        ``avg_loss`` is the per-sample (batch-size-weighted) mean loss.
        ``metrics`` is ``compute_metrics`` over every prediction.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    n_samples = 0
    all_preds, all_targets = [], []

    with T.no_grad():
        for batch in dataloader:
            inputs = batch['original_img'].to(device, non_blocking=True)
            targets = batch['original_mask'].to(device, non_blocking=True)

            outputs = model(inputs)
            loss = loss_fn(outputs, targets)

            loss_sum += loss.item() * inputs.size(0)
            n_samples += inputs.size(0)
            # Keep the collected predictions on the CPU to bound GPU memory.
            all_preds.append(outputs.cpu())
            all_targets.append(targets.cpu())

    if n_samples == 0:
        raise ValueError("ValidateUNet: dataloader yielded no samples")

    metrics = compute_metrics(T.cat(all_preds, dim=0), T.cat(all_targets, dim=0))
    return loss_sum / n_samples, metrics
