In [1]:
import torch as T
import torchvision as TV
import torchaudio as TA
import cv2
import os
import numpy as np
import random
import tqdm as tqdm
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch import optim
from torch.utils.data import DataLoader, Dataset
import segmentation_models_pytorch as smp
from glob import glob
from tqdm import tqdm
import albumentations as A
from sklearn.metrics import accuracy_score, precision_score, f1_score, recall_score, confusion_matrix, average_precision_score

In [2]:
# Pick the compute device once: CUDA when PyTorch reports it as available, CPU otherwise.
device = T.device("cuda" if T.cuda.is_available() else "cpu")

print(device)

cuda


In [3]:
# ---------- Paths ----------
# Every split lives under the same dataset root; only the leaf folder differs.
# NOTE(review): these are absolute, machine-specific paths; edit DATASET_ROOT to relocate the data.
DATASET_ROOT = r"D:\AAU Internship\Code\CWF-788\IMAGE512x384"

train_images = DATASET_ROOT + r"\train_new"
train_masks = DATASET_ROOT + r"\trainlabel_new"
validation_images = DATASET_ROOT + r"\validation_new"
validation_masks = DATASET_ROOT + r"\validationlabel_new"
test_images = DATASET_ROOT + r"\test_new"
test_masks = DATASET_ROOT + r"\testlabel_new"


In [4]:
# ---------------------- Augmentations -----------------------
# Albumentations pipeline applied jointly to an image and its mask (the dataset
# calls it as train_transform(image=..., mask=...)).
#
# NOTE(review): HorizontalFlip / VerticalFlip / RandomRotate90 are followed by
#   D4(p=1), which presumably already samples flips and 90-degree rotations —
#   the first three look redundant; confirm against the albumentations docs.
# NOTE(review): RandomRotate90 and D4 can swap height and width. The data folder
#   name suggests 512x384 (non-square) images — TODO confirm; if so, samples in
#   one batch may end up with different shapes and fail to collate.
# NOTE(review): ElasticTransform appears twice (once with defaults, once fully
#   configured). The second has no explicit `p`, so the library default applies.
# NOTE(review): the cell output shows a warning raised while constructing these
#   transforms; some keyword arguments may not match the installed
#   albumentations version — verify.
train_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=15, p=0.5),
    A.ElasticTransform(p=0.5),
    A.D4(p=1),
    A.ISONoise(
        color_shift=[0.01, 0.05],
        intensity=[0.1, 0.5],
        p=0.5
    ),
    A.RandomBrightnessContrast(brightness_limit=[-0.2, 0.2], contrast_limit=[-0.2, 0.2], brightness_by_max=True, ensure_safe_range=False, p=0.5),
    A.ElasticTransform(
        alpha=300,
        sigma=10,
        # Nearest-neighbour is requested for the image as well as the mask.
        interpolation=cv2.INTER_NEAREST,
        approximate=False,
        same_dxdy=True,
        # Nearest-neighbour keeps mask values as discrete class labels.
        mask_interpolation=cv2.INTER_NEAREST,
        noise_distribution="gaussian",
        keypoint_remapping_method="mask",
        # Pixels displaced from outside the frame are filled with 0 in both image and mask.
        border_mode=cv2.BORDER_CONSTANT,
        fill=0,
        fill_mask=0
    ),
])

# Tensor conversion + per-channel normalisation applied to every image
# (augmented or not). The mean/std values are the standard ImageNet channel
# statistics, matching the ImageNet-pretrained encoder used by the model.
base_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# ---------------------- Dataset Class -----------------------
class SegmentationDataset(Dataset):
    """Image/mask pairs for semantic segmentation, read from two directories.

    Images are the ``*.jpg`` files in ``image_dir`` and masks the ``*.png`` files
    in ``mask_dir``. Both listings are sorted and must pair up one-to-one by
    file stem; this is checked at construction time.

    Each item is a dict of **CPU** tensors:

    * ``original_img`` / ``original_mask`` — always present.
    * ``augmented_img`` / ``augmented_mask`` — only when ``train_transform`` is set.

    Tensors are intentionally left on the CPU. Creating CUDA tensors inside
    ``__getitem__`` breaks ``DataLoader(pin_memory=True)`` and is unsafe in
    worker processes; the training/validation loops move batches to the device.

    Args:
        image_dir: Directory containing the ``.jpg`` images.
        mask_dir: Directory containing the ``.png`` masks.
        train_transform: Optional albumentations-style callable invoked as
            ``train_transform(image=ndarray, mask=ndarray)`` and returning a
            mapping with ``'image'`` and ``'mask'``.
        base_transform: Optional callable applied to the HWC ``uint8`` RGB
            ndarray (e.g. ``Compose([ToTensor(), Normalize(...)])``). When
            ``None`` the image is converted to a CHW float tensor in ``[0, 1]``.
        dataset_type: Human-readable split name used in messages.

    Raises:
        ValueError: If image and mask counts differ or a pair's stems mismatch.
    """

    def __init__(self, image_dir, mask_dir, train_transform=None, base_transform=None, dataset_type="Unknown"):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.train_transform = train_transform
        self.base_transform = base_transform
        self.dataset_type = dataset_type
        self.image_files = sorted(glob(os.path.join(image_dir, "*.jpg")))
        self.mask_files = sorted(glob(os.path.join(mask_dir, "*.png")))
        self._verify_file_pairs()

    def _verify_file_pairs(self):
        """Ensure every image has a mask with the same file stem, in the same position."""
        if len(self.image_files) != len(self.mask_files):
            raise ValueError(f"Mismatched counts in {self.dataset_type} dataset: {len(self.image_files)} images vs {len(self.mask_files)} masks")

        for img_path, mask_path in tqdm(zip(self.image_files, self.mask_files), total=len(self.image_files), desc=f"Verifying {self.dataset_type} File Pairs 🔍"):
            img_name = os.path.splitext(os.path.basename(img_path))[0]
            mask_name = os.path.splitext(os.path.basename(mask_path))[0]
            if img_name != mask_name:
                raise ValueError(f"Filename mismatch in {self.dataset_type} dataset: {img_name} vs {mask_name}")

    def __len__(self):
        return len(self.image_files)

    @staticmethod
    def _read_pair(img_path, mask_path):
        """Read one RGB image and its grayscale mask from disk as ndarrays.

        Raises:
            FileNotFoundError: If OpenCV cannot read either file (``cv2.imread``
                returns ``None`` instead of raising).
        """
        bgr = cv2.imread(img_path)
        if bgr is None:
            raise FileNotFoundError(f"Could not read image file: {img_path}")
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask file: {mask_path}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), mask

    def _to_tensors(self, img, mask):
        """Convert an (HWC uint8 image, HW mask) ndarray pair into CPU tensors.

        The image ndarray is passed to ``base_transform`` as-is: torchvision's
        ``ToTensor`` accepts PIL images / ndarrays and rejects ``torch.Tensor``
        inputs, so the image must not be pre-converted.
        """
        # Augmentations may hand back non-contiguous views; from_numpy needs
        # well-formed strides.
        img = np.ascontiguousarray(img)
        mask = np.ascontiguousarray(mask)

        if self.base_transform is not None:
            img_tensor = self.base_transform(img)
        else:
            img_tensor = T.from_numpy(img).permute(2, 0, 1).float().div(255.0)

        # NOTE(review): mask pixel values are used directly as class indices.
        # If the label PNGs encode foreground as 255 rather than 1 — TODO
        # confirm — they must be remapped before one-hot encoding in the loss.
        mask_tensor = T.from_numpy(mask).long()
        return img_tensor, mask_tensor

    def __getitem__(self, idx):
        img, mask = self._read_pair(self.image_files[idx], self.mask_files[idx])
        original_img, original_mask = self._to_tensors(img, mask)

        sample = {
            'original_img': original_img,
            'original_mask': original_mask
        }

        if self.train_transform:
            augmented = self.train_transform(image=img, mask=mask)
            aug_img, aug_mask = self._to_tensors(augmented['image'], augmented['mask'])
            sample['augmented_img'] = aug_img
            sample['augmented_mask'] = aug_mask

        return sample

# ---------------------- Datasets & DataLoaders -----------------------
# Only the training split gets the random augmentation pipeline. Evaluation
# reads just 'original_img' / 'original_mask', so augmenting validation/test
# items would be computed and then thrown away.
train_dataset = SegmentationDataset(
    train_images,
    train_masks,
    train_transform=train_transform,
    base_transform=base_transform,
    dataset_type="Training"
)

val_dataset = SegmentationDataset(
    validation_images,
    validation_masks,
    train_transform=None,
    base_transform=base_transform,
    dataset_type="Validation"
)

test_dataset = SegmentationDataset(
    test_images,
    test_masks,
    train_transform=None,
    base_transform=base_transform,
    dataset_type="Testing"
)

# Settings shared by all three loaders.
# NOTE(review): with num_workers > 0 on Windows, worker processes must be able
# to import the dataset class; classes defined inside a notebook can fail to
# unpickle there — set num_workers=0 if the loaders hang or error. TODO confirm.
_LOADER_KWARGS = dict(
    batch_size=32,
    pin_memory=True,
    num_workers=2,
    persistent_workers=True,
)

# Shuffle only for training; evaluation metrics do not depend on sample order,
# and a fixed order keeps evaluation runs reproducible.
train_dataloader = DataLoader(train_dataset, shuffle=True, **_LOADER_KWARGS)
val_dataloader = DataLoader(val_dataset, shuffle=False, **_LOADER_KWARGS)
test_dataloader = DataLoader(test_dataset, shuffle=False, **_LOADER_KWARGS)

  original_init(self, **validated_kwargs)
Verifying Training File Pairs 🔍: 100%|███████████████████████████████████████████| 800/800 [00:00<00:00, 72726.24it/s]
Verifying Validation File Pairs 🔍: 100%|█████████████████████████████████████████| 176/176 [00:00<00:00, 58652.27it/s]
Verifying Testing File Pairs 🔍: 100%|████████████████████████████████████████████| 600/600 [00:00<00:00, 66657.37it/s]


In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pathlib import Path
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import segmentation_models_pytorch as smp

# ====================== CONFIGURATION ======================
# Compute device: GPU when CUDA is available, CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Folder that receives the per-metric best checkpoints; created up front
# (including missing parents) so later torch.save calls cannot fail on it.
CUSTOM_SAVE_ROOT = Path(r"D:\AAU Internship\Code\UNet-Models")
CUSTOM_SAVE_ROOT.mkdir(parents=True, exist_ok=True)

# ====================== MODEL DEFINITION ======================
# U-Net with an ImageNet-pretrained EfficientNet-B7 encoder and SCSE attention
# in the decoder. The network outputs per-pixel class probabilities (softmax
# over the channel dimension), which is what FocalTverskyLoss.forward expects.
model = smp.Unet(
    # The keyword is `encoder_name`; `encoder=` is not a Unet parameter, so the
    # requested backbone was not being selected.
    encoder_name="efficientnet-b7",
    encoder_weights="imagenet",
    # Four encoder stages -> four decoder blocks; decoder_channels must have
    # exactly encoder_depth entries.
    encoder_depth=4,
    # NOTE(review): 'inplace' selects InplaceABN, which needs the optional
    # `inplace_abn` package installed — confirm it is available.
    decoder_use_batchnorm='inplace',
    decoder_attention_type='scse',
    decoder_channels=[256, 128, 64, 32],
    in_channels=3,
    classes=2,
    # "softmax2d" is the explicit softmax over dim=1 (channels); plain
    # "softmax" builds nn.Softmax() with an implicit, deprecated dim.
    activation="softmax2d",
    # `center=True` was dropped: smp.Unet has no such argument (the decoder's
    # center block is chosen internally by the library).
).to(device)

# ====================== LOSS FUNCTION ======================
class FocalTverskyLoss(nn.Module):
    """Focal Tversky loss over per-class probabilities.

    For each class the Tversky index is
    ``(TP + smooth) / (TP + alpha * FP + beta * FN + smooth)`` with soft counts
    summed over the batch and spatial dimensions; the loss is the mean over
    classes of ``(1 - Tversky) ** gamma``.

    Args:
        alpha: Weight of the soft false positives.
        beta: Weight of the soft false negatives.
        gamma: Focal exponent applied to ``1 - Tversky``.
        smooth: Additive constant that keeps the ratio defined for empty classes.
    """

    def __init__(self, alpha=0.7, beta=0.3, gamma=0.75, smooth=1e-6):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.smooth = smooth

    def update_hyperparams_by_epoch(self, epoch):
        """Re-derive alpha/beta/gamma from the epoch index, stepping every 5 epochs.

        alpha decays from 0.7 by 0.03 per step (floored at 0.4), beta is kept at
        ``1 - alpha`` and gamma grows from 0.5 by 0.1 per step (capped at 1.5).
        Note that this overwrites the constructor values from the first call on.
        """
        steps = epoch // 5
        self.alpha = max(0.4, 0.7 - 0.03*steps)
        self.beta = 1 - self.alpha
        self.gamma = min(1.5, 0.5 + 0.1*steps)

    def forward(self, preds, targets):
        """Compute the loss.

        Args:
            preds: Class probabilities (already softmax-ed) with the class axis
                at dim 1; indexed here as ``(batch, classes, H, W)``.
            targets: Integer class indices shaped ``(batch, H, W)`` with values
                in ``[0, classes)``.

        Returns:
            Scalar tensor: mean focal Tversky loss over classes.
        """
        targets_one_hot = F.one_hot(targets.long(), num_classes=preds.shape[1]).permute(0, 3, 1, 2).float()
        probs = preds
        dims = (0, 2, 3)  # reduce over batch and spatial axes, keep classes

        TP = torch.sum(probs * targets_one_hot, dims)
        FP = torch.sum(probs * (1 - targets_one_hot), dims)
        FN = torch.sum((1 - probs) * targets_one_hot, dims)
        tversky = (TP + self.smooth) / (TP + self.alpha * FP + self.beta * FN + self.smooth)

        # For gamma < 1 the derivative of x ** gamma is infinite at x == 0, so a
        # perfectly predicted class would inject inf/NaN gradients. Clamping the
        # base to a tiny positive floor keeps the backward pass finite while
        # changing the loss value only negligibly.
        floor = max(self.smooth, torch.finfo(tversky.dtype).eps)
        focal_base = (1 - tversky).clamp(min=floor)
        return torch.mean(focal_base ** self.gamma)

# Instantiate the loss with its default hyperparameters. TrainUNet calls
# update_hyperparams_by_epoch() on this same object every epoch, so
# alpha/beta/gamma change over the course of training.
loss_fn = FocalTverskyLoss().to(device)

# ====================== EVALUATION METRICS ======================
def compute_metrics(preds, targets):
    """Compute binary segmentation metrics (in percent) over all pixels.

    Class 1 is treated as the positive ("crop") class. Every metric is derived
    from a single set of confusion counts, so the result is well defined even
    when only one class occurs in the data (the previous 4-way unpacking of a
    confusion matrix failed in that case).

    Args:
        preds: Per-class scores/probabilities with the class axis at dim 1;
            the predicted label is the argmax over that axis.
        targets: Integer ground-truth labels with one entry per predicted pixel.

    Returns:
        dict with keys "Accuracy", "mPA", "Crop IoU", "mIoU", "Precision",
        "Recall", "F1-Score" and "FNR", each a float percentage.

    Raises:
        ValueError: If there are no pixels, or labels fall outside {0, 1}.
    """
    with torch.no_grad():
        pred_labels = torch.argmax(preds, dim=1).cpu().numpy().ravel()
        true_labels = targets.cpu().numpy().ravel()

    total = int(true_labels.size)
    if total == 0:
        raise ValueError("compute_metrics received no pixels to evaluate")
    if pred_labels.size != total:
        raise ValueError(f"Prediction/target size mismatch: {pred_labels.size} vs {total}")
    if true_labels.min() < 0 or true_labels.max() > 1 or pred_labels.min() < 0 or pred_labels.max() > 1:
        raise ValueError("compute_metrics expects binary labels in {0, 1}")

    true_pos_mask = true_labels == 1
    pred_pos_mask = pred_labels == 1

    TP = int(np.count_nonzero(true_pos_mask & pred_pos_mask))
    FP = int(np.count_nonzero(~true_pos_mask & pred_pos_mask))
    FN = int(np.count_nonzero(true_pos_mask & ~pred_pos_mask))
    TN = total - TP - FP - FN

    def ratio(numerator, denominator):
        # Undefined ratios are reported as 0, mirroring zero_division=0.
        return numerator / denominator if denominator > 0 else 0.0

    # IoU per class: intersection over union, with a small epsilon in the union.
    iou_background = TN / (TN + FP + FN + 1e-6)
    iou_crop = TP / (TP + FP + FN + 1e-6)

    # Mean pixel accuracy averages per-class recall over classes that actually
    # occur in the ground truth.
    class_acc = []
    if TN + FP > 0:
        class_acc.append(TN / (TN + FP))
    if TP + FN > 0:
        class_acc.append(TP / (TP + FN))
    mPA = 100 * (sum(class_acc) / len(class_acc))

    return {
        "Accuracy": 100 * ((TP + TN) / total),
        "mPA": mPA,
        "Crop IoU": 100 * iou_crop,
        "mIoU": 100 * ((iou_background + iou_crop) / 2),
        "Precision": 100 * ratio(TP, TP + FP),
        "Recall": 100 * ratio(TP, TP + FN),
        "F1-Score": 100 * ratio(2 * TP, 2 * TP + FP + FN),
        "FNR": 100 * (FN / (FN + TP + 1e-6))
    }

# ====================== TRAINING SETUP ======================
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-5,
)

# Checkpoint file name for each tracked validation metric.
_CHECKPOINT_FILENAMES = {
    "mPA": "best_mPA_model.pth",
    "mIoU": "best_mIoU_model.pth",
    "Crop IoU": "best_CropIoU_model.pth",
    "Accuracy": "best_Accuracy_model.pth",
    "F1-Score": "best_F1_model.pth",
    "Precision": "best_Precision_model.pth",
    "Recall": "best_Recall_model.pth",
    "FNR": "best_FNR_model.pth",
}

# Full save path per metric, all under the custom save root.
MODEL_PATHS = {
    metric: CUSTOM_SAVE_ROOT / filename
    for metric, filename in _CHECKPOINT_FILENAMES.items()
}

# Best value seen so far per metric. Every metric is "higher is better" and
# starts at -1, except FNR, which is minimised and therefore starts at +inf.
best_metrics = {
    metric: {
        "value": float('inf') if metric == "FNR" else -1,
        "path": path,
    }
    for metric, path in MODEL_PATHS.items()
}

# ====================== TRAINING LOOPS ======================
def TrainUNet(model, dataloader, loss_fn, optimizer, epoch):
    """Run one training epoch on the augmented samples.

    Args:
        model: Network to optimise; put into train mode here.
        dataloader: Yields dict batches with 'augmented_img' and 'augmented_mask'.
        loss_fn: Loss exposing ``update_hyperparams_by_epoch(epoch)``; its
            hyperparameters are refreshed once at the start of the epoch.
        optimizer: Optimizer stepping ``model``'s parameters.
        epoch: Zero-based epoch index, forwarded to the loss schedule.

    Returns:
        (avg_loss, metrics): mean batch loss and the ``compute_metrics`` dict
        evaluated over every prediction made during the epoch.
    """
    model.train()
    running_loss = 0
    all_preds = []
    all_targets = []
    loss_fn.update_hyperparams_by_epoch(epoch)

    for batch in dataloader:
        inputs = batch['augmented_img'].to(device)
        targets = batch['augmented_mask'].to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Park results on the CPU right away: keeping a whole epoch of
        # full-resolution outputs on the GPU competes with the model for memory.
        all_preds.append(outputs.detach().cpu())
        all_targets.append(targets.detach().cpu())

    avg_loss = running_loss / len(dataloader)
    all_preds = torch.cat(all_preds, dim=0)
    all_targets = torch.cat(all_targets, dim=0)
    metrics = compute_metrics(all_preds, all_targets)
    return avg_loss, metrics

def ValidateUNet(model, dataloader, loss_fn):
    """Evaluate the model on un-augmented samples without gradient tracking.

    Args:
        model: Network to evaluate; put into eval mode here.
        dataloader: Yields dict batches with 'original_img' and 'original_mask'.
        loss_fn: Loss callable; evaluated with whatever hyperparameters it
            currently holds (they are not updated here).

    Returns:
        (avg_loss, metrics): mean batch loss and the ``compute_metrics`` dict
        evaluated over the whole loader.
    """
    model.eval()
    running_loss = 0
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for batch in dataloader:
            inputs = batch['original_img'].to(device)
            targets = batch['original_mask'].to(device)
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)
            running_loss += loss.item()
            # Accumulate on the CPU so the full evaluation set is never
            # resident on the GPU at once.
            all_preds.append(outputs.cpu())
            all_targets.append(targets.cpu())

    avg_loss = running_loss / len(dataloader)
    all_preds = torch.cat(all_preds, dim=0)
    all_targets = torch.cat(all_targets, dim=0)
    metrics = compute_metrics(all_preds, all_targets)
    return avg_loss, metrics

# ====================== MAIN TRAINING ======================
num_epochs = 50

# Uses the loaders built in the data cell: train_dataloader, val_dataloader and
# test_dataloader (the previously referenced val_loader / test_loader names
# were never defined and raised NameError).

for epoch in range(num_epochs):
    train_loss, train_metrics = TrainUNet(model, train_dataloader, loss_fn, optimizer, epoch)
    val_loss, val_metrics = ValidateUNet(model, val_dataloader, loss_fn)

    # Update best models: one checkpoint per tracked validation metric.
    for metric_name in best_metrics.keys():
        current_value = val_metrics[metric_name]
        # FNR is the only metric where lower is better.
        is_better = (current_value > best_metrics[metric_name]["value"]) if metric_name != "FNR" else (current_value < best_metrics[metric_name]["value"])

        if is_better:
            best_metrics[metric_name]["value"] = current_value
            torch.save(model.state_dict(), str(best_metrics[metric_name]["path"]))
            print(f"New best {metric_name}: {current_value:.2f}% | Saved to: {best_metrics[metric_name]['path']}")

    # Epoch summary
    print(f"\nEpoch {epoch} Summary:")
    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    for k, v in val_metrics.items():
        print(f"{k}: {v:.2f}%")

# ====================== FINAL REPORT ======================
print("\n=== Best Models Saved ===")
for metric_name, data in best_metrics.items():
    print(f"{metric_name}: {data['value']:.2f}%")
    print(f"Location: {data['path']}\n")

# ====================== TESTING ======================
print("\n=== Testing Best Models ===")
for metric_name, data in best_metrics.items():
    checkpoint_path = Path(data["path"])
    # A metric that never improved (e.g. it stayed NaN) has no checkpoint.
    if not checkpoint_path.exists():
        print(f"\n{metric_name} Model: no checkpoint found at {checkpoint_path}, skipping")
        continue

    # map_location keeps loading working when the checkpoint was written on a
    # different device than the one in use now.
    model.load_state_dict(torch.load(str(checkpoint_path), map_location=device))
    test_loss, test_metrics = ValidateUNet(model, test_dataloader, loss_fn)
    print(f"\n{metric_name} Model Test Results:")
    for k, v in test_metrics.items():
        print(f"{k}: {v:.2f}%")