# Uncertainty


Loading of models, helper functions, dependencies, and the dataset

In [None]:
import os
import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
from torchvision.transforms.functional import to_tensor
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm 
from scipy.stats import ttest_rel

# Output directory for results (relative to the current working directory);
# it is created immediately, including any missing parents.
SAVE_DIR = Path("segmentation/validation_results")
SAVE_DIR.mkdir(parents=True, exist_ok=True)
# Learning rate passed to the optimizer.
LR = 0.0023337072137759097
# Batch size used for the train/val/test DataLoaders.
BATCH_SIZE=  6
# Disabled hyper-parameters kept for reference:
# DROPOUT_P = 0.43498360821913545
# DICE_WEIGHT = 0.35695347056147286
# Probability threshold used to binarize predicted masks.
PRED_THRESH = 0.42610762019278425
# NOTE(review): OPTIM is not read in the code visible here; the optimizer is
# constructed explicitly as torch.optim.Adam.
OPTIM= "Adam"
# Number of iterations of the (validation) epoch loop.
EPOCHS = 10
# NOTE(review): NUM_SAMPLES is not referenced in the code visible here.
NUM_SAMPLES = 30

# Paths
# Hard-coded absolute locations of the input images and their masks.
data_dir = Path(r"/Users/petrahlavinova/School/DP/d/data")
mask_dir = Path(r"/Users/petrahlavinova/School/DP/d/generated_masks5")


# Dataset Definition
class SegmentationDataset(Dataset):
    """Paired image/mask dataset for binary segmentation.

    Images and masks are matched by position after sorting the file names of
    both directories, so the two directories must contain the same number of
    files with names that sort in the same order.

    Args:
        data_dir: Directory with the input images (``str`` or ``Path``).
        mask_dir: Directory with the masks (``str`` or ``Path``).
        transform: Optional callable applied to the image and, separately, to
            the mask.

    Raises:
        ValueError: If the two directories hold a different number of files.
    """

    def __init__(self, data_dir, mask_dir, transform=None):
        self.data_dir = Path(data_dir)
        self.mask_dir = Path(mask_dir)
        self.transform = transform
        self.images = self._list_files(self.data_dir)
        self.masks = self._list_files(self.mask_dir)
        # Positional pairing silently mis-aligns samples when the counts differ.
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images but {len(self.masks)} masks; "
                "every image needs exactly one mask."
            )

    @staticmethod
    def _list_files(directory):
        """Return the sorted, non-hidden entry names of *directory*."""
        # Hidden entries (e.g. macOS ".DS_Store") are not images and would
        # shift the image/mask pairing.
        return sorted(
            name for name in os.listdir(directory) if not name.startswith(".")
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair at position *idx*."""
        img_path = self.data_dir / self.images[idx]
        mask_path = self.mask_dir / self.masks[idx]

        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")  # single-channel grayscale

        if self.transform:
            # NOTE(review): the same transform object is called twice, so a
            # random transform would not be synchronized between image and mask.
            image = self.transform(image)
            mask = self.transform(mask)

        return image, mask

import random

class AugmentedSegmentationDataset(Dataset):
    """Paired image/mask dataset with optional 4x rotation augmentation.

    With ``augment=True`` every image/mask pair is exposed four times, rotated
    by 0, 90, 180 and 270 degrees. With ``augment=False`` every pair is exposed
    once, unrotated.

    Args:
        data_dir: Path to the images.
        mask_dir: Path to the masks.
        transform: Transformations to apply to images and masks.
        augment: Whether to apply augmentations (default True).

    Raises:
        ValueError: If the two directories hold a different number of files.
    """

    def __init__(self, data_dir, mask_dir, transform=None, augment=True):
        self.data_dir = Path(data_dir)
        self.mask_dir = Path(mask_dir)
        self.transform = transform
        self.augment = augment
        self.images = self._list_files(self.data_dir)
        self.masks = self._list_files(self.mask_dir)
        # Pairs are matched by sorted position, so the counts must agree.
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images but {len(self.masks)} masks; "
                "every image needs exactly one mask."
            )

    @staticmethod
    def _list_files(directory):
        """Return the sorted, non-hidden entry names of *directory*."""
        return sorted(
            name for name in os.listdir(directory) if not name.startswith(".")
        )

    def __len__(self):
        return len(self.images) * (4 if self.augment else 1)

    def _resolve_index(self, idx):
        """Map a dataset index to ``(original_index, rotation_degrees)``."""
        if not self.augment:
            # Previously the 4x index arithmetic was applied even when
            # augmentation was off, so only the first quarter of the images
            # was ever served (and still rotated).
            return idx, 0
        original_idx, quarter_turns = divmod(idx, 4)
        return original_idx, quarter_turns * 90

    def __getitem__(self, idx):
        """Return the (optionally rotated) ``(image, mask)`` pair at *idx*."""
        original_idx, rotation = self._resolve_index(idx)

        img_path = self.data_dir / self.images[original_idx]
        mask_path = self.mask_dir / self.masks[original_idx]

        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")  # single-channel grayscale

        # Rotate image and mask by the same angle so they stay aligned.
        if rotation:
            image = image.rotate(rotation)
            mask = mask.rotate(rotation)

        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        return image, mask
    
# Transforms
# One pipeline shared by images and masks (the dataset applies it to both).
transform = T.Compose([
    T.Resize((256, 256)),  # Resize images and masks
    T.ToTensor(),           # Convert to PyTorch tensors
    # T.RandomRotation(30),  # Random rotation (disabled)
    # T.RandomHorizontalFlip(),  # Horizontal mirroring (disabled)
])

# Load Dataset
dataset = SegmentationDataset(data_dir, mask_dir, transform=transform)

# Split Dataset
# 70 % train, 15 % validation, the remainder (about 15 %) test.
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size

# NOTE(review): no `generator` is passed, so the split depends on the global
# torch RNG state and can differ between runs -- confirm that the loaded
# checkpoint was not trained on samples that end up in this test split.
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

class DiceLoss(nn.Module):
    """Soft Dice loss computed from raw logits.

    Args:
        smooth: Constant added to numerator and denominator of the Dice ratio.
    """

    def __init__(self, smooth=1e-7):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        """Return ``1 - mean(per-sample soft Dice)``.

        ``pred`` holds logits (sigmoid is applied here); both tensors are
        reduced over dimensions 1, 2 and 3, i.e. they must be 4-D.
        """
        reduce_dims = (1, 2, 3)
        probabilities = torch.sigmoid(pred)  # squash logits into [0, 1]
        truth = target.float()

        overlap = torch.sum(probabilities * truth, dim=reduce_dims)
        total = torch.sum(probabilities, dim=reduce_dims) + torch.sum(truth, dim=reduce_dims)
        per_sample_dice = (2. * overlap + self.smooth) / (total + self.smooth)

        return 1 - per_sample_dice.mean()


class CombinedLoss(nn.Module):
    """Weighted sum of Dice loss and binary cross-entropy (both on logits).

    Args:
        weight_dice: Multiplier of the Dice term.
        weight_bce: Multiplier of the BCE-with-logits term.
    """

    def __init__(self, weight_dice=0.5, weight_bce=0.5):
        super().__init__()
        self.weight_dice = weight_dice
        self.weight_bce = weight_bce
        self.dice_loss = DiceLoss()
        self.bce_loss = nn.BCEWithLogitsLoss()

    def forward(self, pred, target):
        """Return ``weight_dice * dice(pred, target) + weight_bce * bce(pred, target)``."""
        dice_term = self.dice_loss(pred, target)
        bce_term = self.bce_loss(pred, target)
        weighted_dice = self.weight_dice * dice_term
        weighted_bce = self.weight_bce * bce_term
        return weighted_dice + weighted_bce

class DoubleConv(nn.Module):
    """Two stacked stages of Conv3x3 -> BatchNorm -> ReLU -> Dropout2d(p=0.2).

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. Spatial size is preserved (padding=1).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # Built in forward order so layer indices inside the Sequential
        # (and therefore the state-dict keys) stay 0..7.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Dropout2d(p=0.2),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both convolution stages to ``x``."""
        out = self.double_conv(x)
        return out

class UNetWithDropout(nn.Module):
    """U-Net with four down-sampling and four up-sampling levels.

    Every convolutional stage is a ``DoubleConv`` block; the network returns
    raw logits with ``out_channels`` channels at the input resolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Contracting path: each level doubles the channels, pooling halves H and W.
        self.encoder1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.encoder2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.encoder3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.encoder4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)

        # Deepest level.
        self.bottleneck = DoubleConv(512, 1024)

        # Expanding path: transposed conv doubles H and W; the decoder input
        # has twice the channels because the skip connection is concatenated.
        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.decoder4 = DoubleConv(1024, 512)
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder3 = DoubleConv(512, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder2 = DoubleConv(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder1 = DoubleConv(128, 64)

        # 1x1 convolution producing the output logits.
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder; return logits."""
        down_path = (
            (self.encoder1, self.pool1),
            (self.encoder2, self.pool2),
            (self.encoder3, self.pool3),
            (self.encoder4, self.pool4),
        )
        up_path = (
            (self.upconv4, self.decoder4),
            (self.upconv3, self.decoder3),
            (self.upconv2, self.decoder2),
            (self.upconv1, self.decoder1),
        )

        skips = []
        features = x
        for encode, downsample in down_path:
            features = encode(features)
            skips.append(features)          # kept for the matching decoder level
            features = downsample(features)

        features = self.bottleneck(features)

        for upsample, decode in up_path:
            features = upsample(features)
            # Deepest skip is consumed first (LIFO).
            features = decode(torch.cat((features, skips.pop()), dim=1))

        return self.final_conv(features)
    
class EarlyStopping:
    """Track the best validation loss and signal when to stop.

    Args:
        patience (int): Number of consecutive checks without improvement
            after which stopping is signalled.
        min_delta (float): Minimum decrease of the loss that counts as an
            improvement.
    """

    def __init__(self, patience=5, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.counter = 0

    def check_stop(self, val_loss):
        """Record ``val_loss`` and return True when patience is exhausted."""
        improvement = self.best_loss - val_loss
        if improvement > self.min_delta:
            # New best: remember it and restart the patience counter.
            self.best_loss, self.counter = val_loss, 0
        else:
            self.counter = self.counter + 1

        should_stop = self.counter >= self.patience
        return should_stop


def dice_score(pred, target, threshold=PRED_THRESH):
    """Return the batch-mean hard Dice coefficient as a Python float.

    ``pred`` must contain raw logits: sigmoid is applied here and the result
    is binarized with ``threshold``. Both tensors are reduced over dimensions
    1, 2 and 3, so they must be 4-D.
    """
    reduce_dims = (1, 2, 3)
    binary_pred = torch.sigmoid(pred).gt(threshold).float()  # hard 0/1 mask
    truth = target.float()

    overlap = torch.sum(binary_pred * truth, dim=reduce_dims)
    total = binary_pred.sum(dim=reduce_dims) + truth.sum(dim=reduce_dims)
    per_sample = (2 * overlap) / (total + 1e-7)  # epsilon guards against 0/0

    return per_sample.mean().item()




# Data loaders: only the training loader is shuffled.
batch_size = BATCH_SIZE
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Model Initialization
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNetWithDropout(in_channels=3, out_channels=1)  # RGB images, binary masks
model.to(device)


# Loss and Optimizer
# criterion = torch.nn.BCEWithLogitsLoss()  # Binary segmentation task
# Combined loss: 0.7 * Dice + 0.3 * BCE-with-logits.
criterion = CombinedLoss(weight_dice=0.7, weight_bce=0.3)
# NOTE(review): the optimizer is bound to the parameters of *this* model
# instance; later cells create new UNetWithDropout instances, which this
# optimizer does not track.
optimizer = torch.optim.Adam(model.parameters(), lr=LR)



Uncertainty

In [None]:
def plot_segmentation(image, mask, prediction, epoch, idx, threshold=None):
    """Show an input image, its ground-truth mask and the predicted mask.

    Args:
        image: Image tensor in channel-first layout; it is permuted to
            channel-last for display.
        mask: Ground-truth mask tensor (2-D).
        prediction: Raw logits for the same sample (2-D); sigmoid and
            thresholding are applied here.
        epoch: Epoch identifier, shown in the figure title.
        idx: Batch identifier, shown in the figure title.
        threshold: Probability cut-off for the predicted mask. Defaults to
            ``PRED_THRESH`` so the plot matches the mask used by the Dice
            metric (a fixed 0.5 was used before).
    """
    if threshold is None:
        threshold = PRED_THRESH

    image_np = image.permute(1, 2, 0).cpu().numpy()  # CHW -> HWC for imshow
    mask_np = mask.cpu().numpy()
    pred_np = (torch.sigmoid(prediction) > threshold).float().cpu().numpy()

    panels = (
        (image_np, "Original Image"),
        (mask_np, "Ground Truth Mask"),
        (pred_np, "Predicted Mask"),
    )

    fig = plt.figure(figsize=(12, 4))
    for position, (data, title) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(data, cmap='gray')
        plt.title(title)
        plt.axis("off")

    fig.suptitle(f"epoch={epoch}, batch={idx}")
    plt.tight_layout()

    plt.show()

In [None]:
early_stopping = EarlyStopping(patience=5, min_delta=0.001)
model = UNetWithDropout(in_channels=3, out_channels=1).to(device)
# map_location makes a checkpoint saved on a GPU machine loadable on a
# CPU-only one (and vice versa) by remapping tensors to the active device.
model.load_state_dict(
    torch.load("/segmentation/unet_segmentationSE.pth", map_location=device)
)


# Training and Validation Loop
# NOTE(review): the training step below is commented out, so every iteration
# only re-validates the loaded checkpoint with identical weights.
epochs = EPOCHS
for epoch in range(epochs):
    # Training
    model.train()
    # train_loss = 0
    # for images, masks in tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{epochs}"):
    #     images, masks = images.to(device), masks.to(device)

    #     # Forward pass
    #     outputs = model(images)
    #     loss = criterion(outputs, masks)

    #     # Backward pass
    #     optimizer.zero_grad()
    #     loss.backward()
    #     optimizer.step()

    #     train_loss += loss.item()

    # avg_train_loss = train_loss / len(train_loader)

    # Validation
    model.eval()
    val_loss = 0
    val_dice = 0
    with torch.no_grad():  # Disable gradient computation
        for idx, (images, masks) in enumerate(tqdm(val_loader, desc=f"Validation Epoch {epoch+1}/{epochs}")):
            images, masks = images.to(device), masks.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, masks)

            val_loss += loss.item()
            val_dice += dice_score(outputs, masks)

            # Visualization for the first batch of each epoch
            if idx == 0:
                for i in range(min(3, len(images))):  # Visualize up to 3 images per epoch
                    plot_segmentation(images[i], masks[i][0], outputs[i][0], epoch, idx)

    # Averages are per batch, not per sample.
    avg_val_loss = val_loss / len(val_loader)
    avg_val_dice = val_dice / len(val_loader)

    # Print Epoch Summary
    print(f"Epoch [{epoch+1}/{epochs}] - "
        #   f"Train Loss: {avg_train_loss:.4f}, "
          f"Val Loss: {avg_val_loss:.4f}, "
          f"Val Dice Score: {avg_val_dice:.4f}")

    if early_stopping.check_stop(avg_val_loss):
        print("Early stopping triggered.")
        # torch.save(model.state_dict(), "unet_segmentationDropout.pth")
        break

# Save the Model
# torch.save(model.state_dict(), "unet_segmentationDropout.pth")

# Uncertainty Estimation starts here:
model.eval()



In [None]:
# Fresh model instance with the trained weights loaded from disk.
model = UNetWithDropout(in_channels=3, out_channels=1).to(device)
# map_location makes a checkpoint saved on a GPU machine loadable on a
# CPU-only one by remapping tensors to the active device.
model.load_state_dict(
    torch.load("/segmentation/unet_segmentationSE.pth", map_location=device)
)

# Uncertainty Estimation starts here:
model.eval()


In [None]:
def mc_dropout_predict(
    model,
    input_image,
    num_samples=20,
    return_entropy=False,
    return_std=False,
    return_iqr=False,
    return_range95=False,
    return_samples=False
):
    """
    Monte Carlo Dropout inference for segmentation uncertainty.

    The model is put in eval mode and only its dropout layers are switched to
    training mode, so normalization layers such as BatchNorm keep using (and
    do not overwrite) their running statistics. The training/eval mode of
    every sub-module is restored before returning.

    Args:
        model: PyTorch model with dropout layers; it must return logits.
        input_image: Input tensor accepted by the model, e.g. [B, C, H, W].
        num_samples: Number of stochastic forward passes (>= 1).
        return_entropy: If True, returns the pixel-wise binary entropy of the
            mean probability.
        return_std: If True, returns pixel-wise standard deviation map.
        return_iqr: If True, returns pixel-wise IQR map.
        return_range95: If True, returns pixel-wise 95% range map.
        return_samples: If True, returns raw probability samples.

    Returns:
        A tuple whose order is fixed regardless of keyword order:
            mean_pred (np.ndarray): Mean probability map.
            var_map (np.ndarray): Variance map (zeros when num_samples == 1).
            std_map (np.ndarray, optional): Std dev map.
            entropy_map (np.ndarray, optional): Entropy map.
            iqr_map (np.ndarray, optional): IQR map.
            range95_map (np.ndarray, optional): 97.5th minus 2.5th percentile.
            samples (np.ndarray, optional): Raw probability samples, sample
                axis first.
        A singleton channel axis and then a singleton batch axis are squeezed
        away, so a [1, 1, H, W] model output gives [H, W] maps.

    Raises:
        ValueError: If ``num_samples`` is smaller than 1.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    dropout_types = (
        nn.Dropout, nn.Dropout2d, nn.Dropout3d,
        nn.AlphaDropout, nn.FeatureAlphaDropout,
    )

    # Remember every sub-module's mode so the caller's model is left untouched.
    previous_modes = [(module, module.training) for module in model.modules()]

    # model.train() would also switch BatchNorm to batch statistics and update
    # its running mean/var on each pass; enable only the dropout layers.
    model.eval()
    for module in model.modules():
        if isinstance(module, dropout_types):
            module.train()

    # Collect Monte Carlo samples
    probs_list = []
    try:
        with torch.no_grad():
            for _ in range(num_samples):
                logits = model(input_image)
                prob = torch.sigmoid(logits)        # logits -> probabilities
                probs_list.append(prob.cpu().numpy())
    finally:
        for module, was_training in previous_modes:
            module.training = was_training

    # Stack to shape [N, B, C, H, W] or [N, C, H, W]
    probs = np.stack(probs_list, axis=0)
    # Squeeze channel dim if C=1, and batch dim if B=1
    if probs.ndim == 5 and probs.shape[2] == 1:
        probs = probs.squeeze(2)
    if probs.ndim == 4 and probs.shape[1] == 1:
        probs = probs[:, 0, ...]

    # Basic statistics over the sample axis
    mean_pred = np.mean(probs, axis=0)
    var_map = np.var(probs, axis=0) if num_samples > 1 else np.zeros_like(mean_pred)

    outputs = [mean_pred, var_map]

    # Standard deviation
    if return_std:
        outputs.append(np.sqrt(var_map))

    # Binary entropy of the mean prediction
    if return_entropy:
        eps = 1e-12
        p = np.clip(mean_pred, eps, 1 - eps)   # avoid log(0)
        entropy_map = -(p * np.log(p) + (1 - p) * np.log(1 - p))
        outputs.append(entropy_map)

    # Inter-quartile range (IQR)
    if return_iqr:
        q25, q75 = np.percentile(probs, [25, 75], axis=0)
        outputs.append(q75 - q25)

    # 95% range (between 2.5th and 97.5th percentiles)
    if return_range95:
        lower, upper = np.percentile(probs, [2.5, 97.5], axis=0)
        outputs.append(upper - lower)

    # Raw samples
    if return_samples:
        outputs.append(probs)

    return tuple(outputs)


## Task 1 – Demonstration on one example from the test loader

In [None]:
def visualize_multiple_uncertainty_maps(original,ground_truth, segmentation, uncertainties_dict):
    """Plot the input, segmentation and ground truth above the uncertainty maps.

    Args:
        original: Input image as a NumPy array; a 3-D array is treated as
            channel-first and transposed to channel-last for display.
        ground_truth: Ground-truth mask array.
        segmentation: Predicted binary mask array.
        uncertainties_dict: Mapping of title -> uncertainty map. A key equal
            to "entropy" (or the Slovak "entropia"), case-insensitive, is
            drawn on the fixed scale [0, ln 2]; every other map is scaled
            from 0 to its 99th percentile.
    """
    num_maps = len(uncertainties_dict)
    # The first row always uses three panels, so at least three columns are
    # needed (max(2, ...) raised IndexError with fewer than three maps).
    cols = max(3, num_maps)
    fig, axes = plt.subplots(2, cols, figsize=(4 * cols, 8))

    # First row, column 0: input image
    ax = axes[0, 0]
    img = np.transpose(original, (1, 2, 0)) if original.ndim == 3 else original
    ax.imshow(img.squeeze(), cmap='gray')
    ax.set_title('RTG')
    ax.axis('off')

    # First row, column 2: ground truth
    ax = axes[0, 2]
    ax.imshow(ground_truth.squeeze(), cmap='gray')
    ax.set_title('Ground Truth')
    ax.axis('off')

    # First row, column 1: segmentation
    ax = axes[0, 1]
    ax.imshow(segmentation.squeeze(), cmap='gray')
    ax.set_title('Segmentácia')
    ax.axis('off')

    # Hide the unused panels of the first row
    for j in range(3, cols):
        axes[0, j].axis('off')

    # Second row: uncertainty maps
    for j, (key, unc_map) in enumerate(uncertainties_dict.items()):
        ax = axes[1, j]
        data = unc_map.squeeze()

        if key.lower() in ('entropy', 'entropia'):
            # binary entropy has range [0, log(2)]
            vmin, vmax = 0.0, np.log(2.0)
        else:
            # clip outliers for the other maps
            vmin = 0.0
            vmax = np.percentile(data, 99)
        im = ax.imshow(data, cmap='jet', vmin=vmin, vmax=vmax, alpha=0.8)

        ax.set_title(key)
        ax.axis('off')
        fig.colorbar(im, ax=ax, shrink=0.6)

    # Hide the unused panels of the second row
    for j in range(num_maps, cols):
        axes[1, j].axis('off')

    plt.tight_layout()
    plt.show()


In [None]:
# 1) Take the first sample of the first test batch and keep a batch axis of 1.
sample_batch = next(iter(test_loader))
inputs, labels = sample_batch  
sample = inputs.to(device)[0].unsqueeze(0)
label = labels[0]  

# 2) MC-dropout (N=50 samples); the unpacking order below follows the fixed
#    return order of mc_dropout_predict: mean, var, std, entropy, IQR, range95.
mean_pred_np, var_map, std_map, entropy_map, iqr_map, range95_map = mc_dropout_predict(
    model,
    sample,               #  [B,C,H,W]
    num_samples     = 50,
    return_std      = True,
    return_entropy  = True,
    return_iqr      = True,
    return_range95  = True,
    return_samples  = False
)

# 3) Preparation for visualization
orig_np   = sample[0].cpu().numpy()            # (3,H,W)
gt_np     = label[0].squeeze().cpu().numpy()   # (H,W)
# Binarize the mean probability with the tuned threshold.
segmentation_np  = (mean_pred_np.squeeze() > PRED_THRESH).astype(float)     # (H,W)

# 4) Dictionary of uncertainty maps (title -> map)
uncertainties = {
    'Variance'            : var_map,
    'Std dev'             : std_map,
    'Entropy'             : entropy_map,
    'IQR'                 : iqr_map,
    'Range 95%'           : range95_map
}

# 5) Visualization - first row: original | prediction | ground truth;
#    second row: uncertainty maps
visualize_multiple_uncertainty_maps(orig_np,gt_np, segmentation_np, uncertainties)


## Task 1: visualization of 5 random samples from the test set

In [None]:
# 1) Reproducible choice of five test-set indices (fixed seed).
#    random.sample raises ValueError if the test set has fewer than 5 items.
random.seed(2025)
rand_idx = random.sample(range(len(test_dataset)), 5)   

# 2) Loop over the randomly chosen samples
for k, idx in enumerate(rand_idx, 1):
    img, gt = test_dataset[idx]          # tensors (C,H,W)  &  (1,H,W)

    img_batch = img.unsqueeze(0).to(device)  # (1,C,H,W) 

    # MC-dropout (N = 50 samples)
    mean_np, var_map, std_map, ent_map, iqr_map, r95_map = mc_dropout_predict(
        model,
        img_batch,
        num_samples     = 50,
        return_std      = True,
        return_entropy  = True,
        return_iqr      = True,
        return_range95  = True,
        return_samples  = False
    )

    # Visualization preparation
    orig_np  = img.cpu().numpy()                     # (3,H,W)
    gt_np    = gt.squeeze().cpu().numpy()            # (H,W)
    # Binarize the mean probability with the tuned threshold.
    pred_bin = (mean_np.squeeze() > PRED_THRESH).astype(float)

    # Title -> uncertainty map, in display order.
    uncert = {
        'Variance'  : var_map,
        'Std dev'   : std_map,
        'Entropy'   : ent_map,
        'IQR'       : iqr_map,
        'Range 95%' : r95_map
    }

    print(f"\n=== Príklad {k}  (dataset index {idx}) ===")
    visualize_multiple_uncertainty_maps(orig_np, gt_np, pred_bin, uncert)


In [None]:
def visualize_multiple_uncertainty_maps(original, ground_truth, segmentation, uncertainties_dict):
    """Plot input, segmentation and ground truth above normalized uncertainty maps.

    Every uncertainty map is rescaled to [0, 1] so all panels share one color
    scale: a key equal to "entropy" (case-insensitive) is divided by ln 2, and
    every other map is divided by its 99th percentile and clipped to [0, 1].

    Args:
        original: Input image as a NumPy array; a 3-D array is treated as
            channel-first and transposed to channel-last for display.
        ground_truth: Ground-truth mask array.
        segmentation: Predicted binary mask array.
        uncertainties_dict: Mapping of title -> uncertainty map.
    """
    num_maps = len(uncertainties_dict)
    # The first row always uses three panels, so at least three columns are
    # needed (max(2, ...) raised IndexError with fewer than three maps).
    cols = max(3, num_maps)
    fig, axes = plt.subplots(2, cols, figsize=(4 * cols, 8))

    # First row: input | segmentation | ground truth
    ax = axes[0, 0]
    img = np.transpose(original, (1, 2, 0)) if original.ndim == 3 else original
    ax.imshow(img.squeeze(), cmap='gray')
    ax.set_title('RTG')
    ax.axis('off')

    ax = axes[0, 2]
    ax.imshow(ground_truth.squeeze(), cmap='gray')
    ax.set_title('Ground Truth')
    ax.axis('off')

    ax = axes[0, 1]
    ax.imshow(segmentation.squeeze(), cmap='gray')
    ax.set_title('Segmentácia')
    ax.axis('off')

    for j in range(3, cols):
        axes[0, j].axis('off')

    normalized_uncertainties = {}
    for key, unc_map in uncertainties_dict.items():
        data = unc_map.squeeze()

        if key.lower() == 'entropy':
            normalized = data / np.log(2.0)
        else:
            max_val = np.percentile(data, 99)
            if max_val > 0:
                normalized = np.clip(data / max_val, 0, 1)
            else:
                # A zero 99th percentile would divide by zero (NaN for the
                # zero pixels); show positive pixels at full scale instead.
                normalized = (data > 0).astype(float)

        normalized_uncertainties[key] = normalized

    # Second row: normalized maps on a common [0, 1] scale
    for j, (key, unc_map) in enumerate(normalized_uncertainties.items()):
        ax = axes[1, j]
        im = ax.imshow(unc_map.squeeze(), cmap='jet', vmin=0, vmax=1, alpha=0.8)
        ax.set_title(f"{key} (norm.)")
        ax.axis('off')
        fig.colorbar(im, ax=ax, shrink=0.6)

    for j in range(num_maps, cols):
        axes[1, j].axis('off')

    plt.tight_layout()
    plt.show()

In [None]:

# Same five reproducible test-set indices as before (fixed seed); this cell
# re-runs the demo with whichever visualize_multiple_uncertainty_maps
# definition was executed last.
random.seed(2025)
rand_idx = random.sample(range(len(test_dataset)), 5)   

for k, idx in enumerate(rand_idx, 1):
    img, gt = test_dataset[idx]       

    # Add a batch axis and move the image to the model's device.
    img_batch = img.unsqueeze(0).to(device)  

    # MC-dropout (N = 50 samples)
    mean_np, var_map, std_map, ent_map, iqr_map, r95_map = mc_dropout_predict(
        model,
        img_batch,
        num_samples     = 50,
        return_std      = True,
        return_entropy  = True,
        return_iqr      = True,
        return_range95  = True,
        return_samples  = False
    )

    orig_np  = img.cpu().numpy()                     # (3,H,W)
    gt_np    = gt.squeeze().cpu().numpy()            # (H,W)
    # Binarize the mean probability with the tuned threshold.
    pred_bin = (mean_np.squeeze() > PRED_THRESH).astype(float)

    # Title -> uncertainty map, in display order.
    uncert = {
        'Variance'  : var_map,
        'Std dev'   : std_map,
        'Entropy'   : ent_map,
        'IQR'       : iqr_map,
        'Range 95%' : r95_map
    }

    print(f"\n=== Príklad {k}  (dataset index {idx}) ===")
    visualize_multiple_uncertainty_maps(orig_np, gt_np, pred_bin, uncert)


## Effect of the number of MC samples (N)

In [None]:
import  time

def visualize_entropy_vs_N(original, ground_truth, segmentation, entropies_dict):
    """
    Row 1: original | ground truth | prediction (N=50).
    Row 2: one entropy map per entry of ``entropies_dict``, all drawn on the
    fixed color scale [0, ln 2].
    """
    n_maps = len(entropies_dict)
    cols = max(3, n_maps)
    fig, axes = plt.subplots(2, cols, figsize=(4*cols, 8))

    # Row 1: the three reference panels (input is transposed CHW -> HWC)
    top_panels = (
        (np.transpose(original, (1, 2, 0)), "RTG"),
        (ground_truth, "GT maska"),
        (segmentation, "Predikcia (N=50)"),
    )
    for col, (panel, title) in enumerate(top_panels):
        top_ax = axes[0, col]
        top_ax.imshow(panel, cmap='gray')
        top_ax.set_title(title)
        top_ax.axis('off')

    # Hide the unused panels of row 1
    for col in range(3, cols):
        axes[0, col].axis('off')

    # Row 2: entropy maps with individual colorbars
    entropy_limit = np.log(2.0)
    for col, (title, entropy_map) in enumerate(entropies_dict.items()):
        bottom_ax = axes[1, col]
        image = bottom_ax.imshow(entropy_map.squeeze(), cmap='jet',
                                 vmin=0.0, vmax=entropy_limit)
        bottom_ax.set_title(title)
        bottom_ax.axis('off')
        fig.colorbar(image, ax=bottom_ax, fraction=0.045, pad=0.04)

    # Hide the unused panels of row 2
    for col in range(n_maps, cols):
        axes[1, col].axis('off')

    plt.tight_layout()
    plt.show()
    

# Numbers of MC samples to compare; N=0 is the deterministic baseline.
N_LIST = [0,5,10,15,20,25,30,35,40,45,50]

test_loader_1 = DataLoader(test_dataset, batch_size=1, shuffle=False)

dice_means, dice_stds = [], []

print("⇢ Výpočet Dice pre rôzne počty vzoriek N …")
for N in N_LIST:
    dice_vals = []
    t0 = time.time()
    for img, gt in test_loader_1:
        img = img.to(device)

        # --- baseline N=0 (without dropout)
        if N == 0:
            model.eval()
            with torch.no_grad():
                logits = model(img)
                mean_np = torch.sigmoid(logits)[0,0].cpu().numpy()
        # --- MC-dropout N>0
        else:
            mean_np, *_ = mc_dropout_predict(
                model, img,
                num_samples     = N,
                return_std      = False,
                return_entropy  = False,
                return_iqr      = False,
                return_range95  = False,
                return_samples  = False
            )

        # Dice is computed directly on the binary mask. dice_score() must not
        # be used here: it expects logits and would apply sigmoid + threshold
        # again, turning every pixel of an already binarized mask into
        # foreground (sigmoid(0) = 0.5 > PRED_THRESH).
        pred_bin = (torch.from_numpy(mean_np)[None, None] > PRED_THRESH).float()   # (1,1,H,W)
        gt_mask = gt.float()
        intersection = (pred_bin * gt_mask).sum()
        denominator = pred_bin.sum() + gt_mask.sum()
        d = (2 * intersection / (denominator + 1e-7)).item()
        dice_vals.append(d)

    dice_means.append(np.mean(dice_vals))
    dice_stds .append(np.std (dice_vals))
    print(f"  N={N:2d} | Dice: {dice_means[-1]:.4f} ± {dice_stds[-1]:.4f} "
          f"| čas: {time.time()-t0:.1f}s")

# Tight y-range around the observed means so small differences stay visible.
y_min = min(dice_means) - 0.02
y_max = max(dice_means) + 0.02

plt.figure(figsize=(6,4))
plt.errorbar(N_LIST, dice_means, yerr=dice_stds,
             marker='o', capsize=4, label="MC-dropout")
# baseline (N=0) horizontal line
plt.axhline(dice_means[0], linestyle='--', color='green',
            label=f"Baseline N=0 (Dice {dice_means[0]:.3f})")

plt.xlabel("MC Samples count N")
plt.ylabel("Average Dice on test set")
plt.title("Stabilization Dice with increasing N")
plt.ylim(y_min, y_max)
plt.xticks(N_LIST)
plt.grid(True, axis='y', alpha=0.4)
plt.legend()
plt.tight_layout()
plt.show()


# Pick one reproducible test sample (fixed seed) and add a batch axis.
random.seed(42)
idx = random.randrange(len(test_dataset))
img, gt = test_dataset[idx]
img_batch = img.unsqueeze(0).to(device)

# Reference prediction from 50 MC samples; only the mean map is kept.
mean_50, *_ = mc_dropout_predict(
    model, img_batch, num_samples=50,
    return_std=False, return_entropy=False,
    return_iqr=False, return_range95=False
)

orig_np  = img.cpu().numpy()
gt_np    = gt.squeeze().cpu().numpy()
pred_np  = (mean_50.squeeze() > PRED_THRESH).astype(float)

# Entropy map for each N; every N is an independent MC run (the samples are
# not nested subsets of one another). With only return_entropy=True the
# function returns (mean, var, entropy).
entropy_maps = {}
for N in [10, 20, 30, 40, 50]:
    _, _, ent_map = mc_dropout_predict(
        model, img_batch, num_samples=N,
        return_std=False, return_entropy=True,
        return_iqr=False, return_range95=False
    )
    entropy_maps[f"Entropy\nN={N}"] = ent_map

def visualize_entropy_vs_N(original, ground_truth, segmentation, entropies_dict):
    """Plot original | ground truth | prediction above per-N entropy maps.

    Entropy maps are drawn on the fixed color scale [0, ln 2]; panels without
    content are hidden.
    """
    cols = max(3, len(entropies_dict))
    fig, axes = plt.subplots(2, cols, figsize=(4*cols, 8))

    def _show_gray(ax, data, title):
        # Grayscale panel without axes.
        ax.imshow(data, cmap='gray')
        ax.set_title(title)
        ax.axis('off')

    _show_gray(axes[0,0], np.transpose(original,(1,2,0)), "RTG")
    _show_gray(axes[0,1], ground_truth, "GT maska")
    _show_gray(axes[0,2], segmentation, "Predikcia (N=50)")

    # Blank out whatever is left of the first row.
    for unused in range(3, cols):
        axes[0,unused].axis('off')

    used = 0
    for label, entropy in entropies_dict.items():
        panel = axes[1,used]
        mappable = panel.imshow(entropy.squeeze(), cmap='jet',
                                vmin=0.0, vmax=np.log(2.0))
        panel.set_title(label)
        panel.axis('off')
        fig.colorbar(mappable, ax=panel, fraction=0.045, pad=0.04)
        used += 1

    # Blank out whatever is left of the second row.
    for unused in range(len(entropies_dict), cols):
        axes[1,unused].axis('off')

    plt.tight_layout()
    plt.show()

# Draw the comparison figure for the sample and entropy maps prepared above.
visualize_entropy_vs_N(orig_np, gt_np, pred_np, entropy_maps)


In [None]:
def test_dropout_variability(model, input_image):
    """Print the mean absolute difference between two stochastic forward passes.

    A value clearly above zero shows that dropout is active at inference time.
    Only the dropout layers are switched to training mode; normalization
    layers such as BatchNorm stay in eval mode so their running statistics
    are not modified. The previous mode of every sub-module is restored
    before returning.

    Args:
        model: PyTorch model containing dropout layers.
        input_image: Input tensor accepted by the model.
    """
    dropout_types = (
        nn.Dropout, nn.Dropout2d, nn.Dropout3d,
        nn.AlphaDropout, nn.FeatureAlphaDropout,
    )
    previous_modes = [(module, module.training) for module in model.modules()]

    model.eval()
    for module in model.modules():
        if isinstance(module, dropout_types):
            module.train()

    try:
        with torch.no_grad():
            out1 = model(input_image)
            out2 = model(input_image)
    finally:
        for module, was_training in previous_modes:
            module.training = was_training

    diff = torch.abs(out1 - out2).mean().item()
    print(f"Average difference between two outputs: {diff:.6f}")


In [None]:
# Sanity check on the sample prepared above: prints how much two forward
# passes with active dropout differ.
test_dropout_variability(model, img_batch)

In [None]:

from sklearn.metrics import f1_score

# ----------------------------- Dice metric -----------------------------
def dice_score_without_threshold(pred_mask, true_mask):
    """
    Smoothed Dice coefficient of two masks, with no sigmoid or thresholding.

    pred_mask, true_mask: torch.Tensor of the same shape, e.g. (1, 1, H, W),
    holding values in {0, 1} or floats. Returns a 0-dim tensor.
    """
    eps = 1e-6
    flat_pred, flat_true = pred_mask.view(-1), true_mask.view(-1)
    overlap = torch.sum(flat_pred * flat_true)
    numerator = 2. * overlap + eps
    denominator = flat_pred.sum() + flat_true.sum() + eps
    return numerator / denominator

# --------------------------- DICE vs N GRAPH -----------------------------
# Numbers of MC samples to compare; N=0 is the deterministic baseline.
N_LIST = [0,5,10,15,20,25,30,35,40,45,50]         
test_loader_1 = DataLoader(test_dataset, batch_size=1, shuffle=False)

dice_means, dice_stds = [], []
print("⇢ Výpočet Dice pre rôzne počty vzoriek N …")
for N in N_LIST:
    dice_vals = []
    t0 = time.time()
    for img, gt in test_loader_1:
        img = img.to(device)

        if N == 0:
            # Baseline: a single deterministic pass in eval mode.
            model.eval()
            with torch.no_grad():
                logits = model(img)
                mean_np = torch.sigmoid(logits)[0,0].cpu().numpy()
        else:
            # Mean probability over N stochastic passes.
            mean_np, *_ = mc_dropout_predict(
                model, img,
                num_samples=N,
                return_entropy=False
            )

        # Binarize, then score the mask directly (no second sigmoid/threshold).
        pred_bin = torch.from_numpy(mean_np)[None, None] > PRED_THRESH
        d = dice_score_without_threshold(pred_bin.float(), gt.float())
        # NOTE(review): d is a 0-dim tensor, so dice_vals is a list of tensors
        # that np.mean/np.std must convert -- consider appending d.item().
        dice_vals.append(d)

    dice_means.append(np.mean(dice_vals))
    dice_stds .append(np.std (dice_vals))
    print(f"  N={N:2d} | Dice: {dice_means[-1]:.4f} ± {dice_stds[-1]:.4f} "
          f"| čas: {time.time()-t0:.1f}s")

# Mean Dice with std error bars; the dashed line marks the N=0 baseline.
plt.figure(figsize=(6,4))
plt.errorbar(N_LIST, dice_means, yerr=dice_stds,
             marker='o', capsize=4, label="MC-dropout")
plt.axhline(dice_means[0], linestyle='--', color='green',
            label=f"Baseline N=0 (Dice {dice_means[0]:.3f})")
plt.xlabel("MC vzorky N")
plt.ylabel("Priemerný Dice (test sada)")
plt.title("Dice stabilita s rastúcim N")
plt.ylim(min(dice_means)-0.02, max(dice_means)+0.02)
plt.xticks(N_LIST)
plt.grid(True, axis='y', alpha=0.4)
plt.legend()
plt.tight_layout()
plt.show()

# ------------------------- VISUALIZATION of entropy ------------------------
# Pick one reproducible test sample (fixed seed) and add a batch axis.
random.seed(42)
idx = random.randrange(len(test_dataset))
img, gt = test_dataset[idx]
img_batch = img.unsqueeze(0).to(device)

# Reference prediction from 50 MC samples; only the mean map is kept.
mean_50, *_ = mc_dropout_predict(
    model, img_batch, num_samples=50, return_entropy=False
)
orig_np  = img.cpu().numpy()
gt_np    = gt.squeeze().cpu().numpy()
pred_np  = (mean_50.squeeze() > PRED_THRESH).astype(float)

# Entropy map for each N; every N is an independent MC run. With only
# return_entropy=True the function returns (mean, var, entropy).
entropy_maps = {}
for N in [10, 20, 30, 40, 50]:
    _, _, ent_map = mc_dropout_predict(
        model, img_batch, num_samples=N,
        return_entropy=True
    )
    entropy_maps[f"Entropy\nN={N}"] = ent_map

def visualize_entropy_vs_N(original, ground_truth, segmentation, entropies_dict):
    """Two-row figure: reference panels on top, entropy maps per N below.

    Top row: original | ground truth | prediction. Bottom row: one panel per
    entry of ``entropies_dict`` on the fixed color scale [0, ln 2].
    """
    map_count = len(entropies_dict)
    cols = max(3, map_count)
    fig, axes = plt.subplots(2, cols, figsize=(4*cols, 8))
    top_row, bottom_row = axes[0], axes[1]

    references = [
        (np.transpose(original,(1,2,0)), "RTG"),
        (ground_truth, "GT maska"),
        (segmentation, "Predikcia (N=50)"),
    ]
    for ax, (content, caption) in zip(top_row, references):
        ax.imshow(content, cmap='gray')
        ax.set_title(caption)
        ax.axis('off')
    # Columns beyond the three reference panels stay empty.
    for ax in top_row[3:]:
        ax.axis('off')

    for ax, (caption, entropy) in zip(bottom_row, entropies_dict.items()):
        drawn = ax.imshow(entropy.squeeze(), cmap='jet', vmin=0.0, vmax=np.log(2.0))
        ax.set_title(caption)
        ax.axis('off')
        fig.colorbar(drawn, ax=ax, fraction=0.045, pad=0.04)
    # Columns without an entropy map stay empty.
    for ax in bottom_row[map_count:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Draw the comparison figure for the sample and entropy maps prepared above.
visualize_entropy_vs_N(orig_np, gt_np, pred_np, entropy_maps)


In [None]:
def compute_dice(pred_mask, true_mask):
    """Return the F1 score of two binary masks after flattening them.

    For binary labels the F1 score equals the Dice coefficient.
    """
    y_pred = pred_mask.flatten()
    y_true = true_mask.flatten()
    score = f1_score(y_true, y_pred)
    return score

# Use the tuned threshold when it is defined in this session, else 0.5.
threshold = PRED_THRESH if 'PRED_THRESH' in globals() else 0.5
# Numbers of MC samples to compare; N=0 is the deterministic baseline.
sample_counts = [0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50]

for N in sample_counts:
    dice_scores = []
    start = time.time()
    # Unlike the test-set loops, this one iterates the validation loader, so
    # each Dice value pools all pixels of one batch.
    for image, mask in val_loader:
        image = image.to(device)
        # Every non-zero mask value becomes True.
        mask = mask.squeeze().cpu().numpy().astype(bool)

        if N == 0:
            # Baseline: a single deterministic pass in eval mode.
            model.eval()
            with torch.no_grad():
                output = model(image)
                prob = torch.sigmoid(output).squeeze().cpu().numpy()
        else:
            # Mean probability over N stochastic passes.
            mean_pred, *_ = mc_dropout_predict(model, image, num_samples=N)
            prob = mean_pred

        pred_bin = prob > threshold
        dice = compute_dice(pred_bin, mask)
        dice_scores.append(dice)

    # Mean and spread over batches, plus wall-clock time for this N.
    mean_dice = np.mean(dice_scores)
    std_dice = np.std(dice_scores)
    elapsed = time.time() - start
    print(f"  N={N:2d} | Dice: {mean_dice:.4f} ± {std_dice:.4f} | čas: {elapsed:.1f}s")


## Task 3: Overlapping histograms of pixel-wise uncertainty (std dev) for correct vs. incorrect pixels

In [None]:
import numpy as np, torch, matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from scipy.stats import gaussian_kde   

N_OPT = 30   # MC-dropout forward passes per image
BINS  = 60   # number of histogram bins

# 1) loader with batch_size=1: one image per iteration
test_loader_1 = DataLoader(test_dataset, batch_size=1, shuffle=False)

correct_vals   = []   # std at pixels where prediction == ground truth
incorrect_vals = []   # std at pixels where prediction != ground truth

print(f"⇢ Zbieram std pre všetky pixely test-setu (N={N_OPT}) …")
for img, gt in test_loader_1:
    img = img.to(device)

    # One MC-dropout run yields both the mean prediction and the std map.
    # Running it twice doubles the cost and pairs the uncertainty of one
    # random sample set with the prediction of a different one.
    mean_np, _, std_map = mc_dropout_predict(
        model, img,
        num_samples     = N_OPT,
        return_std      = True,
        return_entropy  = False,
        return_iqr      = False,
        return_range95  = False,
        return_samples  = False
    )
    std_map = std_map.squeeze()

    # binary prediction and ground truth
    pred_bin = (mean_np.squeeze() > PRED_THRESH).astype(np.uint8)
    gt_bin   = gt.squeeze().numpy().astype(np.uint8)

    # per-pixel correctness masks
    correct_mask   = (pred_bin == gt_bin)
    incorrect_mask = ~correct_mask

    correct_vals.append(std_map[correct_mask])
    incorrect_vals.append(std_map[incorrect_mask])

# flatten the per-image pieces into one 1-D array per group
correct_vals   = np.concatenate(correct_vals)
incorrect_vals = np.concatenate(incorrect_vals)

print(f"  • number of correct pixels   : {correct_vals.size:,}")
print(f"  • number of incorrect pixels : {incorrect_vals.size:,}")

# Shared x-range and bin edges for both groups, taken from the data.  The
# previous fixed upper limit ln(2) is the maximum of binary entropy, not of a
# standard deviation, and per-group bins made the two histograms (and the
# overlap number below) incomparable.
x_max = max((v.max() for v in (correct_vals, incorrect_vals) if v.size),
            default=0.0)
if x_max <= 0:
    x_max = 0.5   # degenerate case (all std == 0): keep a valid axis
edges = np.linspace(0.0, x_max, BINS + 1)

# 2) Histogram
plt.figure(figsize=(7,4))

plt.hist(correct_vals,   bins=edges, density=True, alpha=0.5,
         label='Correct',   color='tab:blue')
plt.hist(incorrect_vals, bins=edges, density=True, alpha=0.5,
         label='Incorrect', color='tab:orange')

xs = np.linspace(0.0, x_max, 300)
for vals, col in [(correct_vals,'tab:blue'), (incorrect_vals,'tab:orange')]:
    # gaussian_kde cannot be fitted to fewer than two distinct values
    if vals.size < 2 or vals.min() == vals.max():
        continue
    kde = gaussian_kde(vals)
    plt.plot(xs, kde(xs), color=col)

# The plotted quantity is a standard deviation, not an entropy in bits.
plt.xlabel("Pixel-wise std (MC-dropout)")
plt.title(f"Overlap std – Correct vs Incorrect  (N={N_OPT})")
plt.legend()
plt.xlim(0, x_max)
plt.tight_layout()
plt.show()

# 3) Quantified overlap: area under the minimum of the two densities,
#    evaluated on the same bin edges for both groups.
hist_c , _ = np.histogram(correct_vals  , bins=edges, density=True)
hist_ic, _ = np.histogram(incorrect_vals, bins=edges, density=True)
overlap = np.sum(np.minimum(hist_c, hist_ic) * np.diff(edges))
print(f"Area of overlap = {overlap:.3f} ")


In [None]:
N_OPT = 30   # MC-dropout forward passes per image
BINS  = 60   # histogram bins per metric
metrics = ["Variance", "Std", "Entropy", "IQR", "Range95"]

test_loader_1 = DataLoader(test_dataset, batch_size=1, shuffle=False)
# Per-metric lists of pixel values, split by whether the pixel was classified
# correctly (corr_vals) or not (inc_vals).
corr_vals, inc_vals = {m: [] for m in metrics}, {m: [] for m in metrics}

print(f"⇢ Zbieram metriky (entropia z mean-p, N={N_OPT}) …")
for img, gt in test_loader_1:
    img = img.to(device)

    mean_np, var_map, std_map,  iqr_map, r95_map = mc_dropout_predict(
        model, img,
        num_samples     = N_OPT,
        return_std      = True,
        return_entropy  = False,
        return_iqr      = True,
        return_range95  = True,
        return_samples  = False
    )
    # Entropy of the mean probability, computed here with natural logs;
    # the clip keeps log() away from 0.
    p_mean = np.clip(mean_np.squeeze(), 1e-6, 1-1e-6)
    ent_map = -(p_mean*np.log(p_mean) + (1-p_mean)*np.log(1-p_mean))

    pred_bin = (p_mean > PRED_THRESH).astype(np.uint8)
    gt_bin   = gt.squeeze().numpy().astype(np.uint8)

    corr_mask = (pred_bin == gt_bin)
    inc_mask  = ~corr_mask

    maps = {
        "Variance": var_map,
        "Std":      std_map,
        "Entropy":  ent_map,
        "IQR":      iqr_map,
        "Range95":  r95_map
    }
    for m in metrics:
        corr_vals[m].append(maps[m][corr_mask])
        inc_vals [m].append(maps[m][inc_mask])

# one flat array per metric and group
for m in metrics:
    corr_vals[m] = np.concatenate(corr_vals[m])
    inc_vals [m] = np.concatenate(inc_vals[m])

# ── visualization & overlap ───────────────────────────────────────────────
fig, axes = plt.subplots(2, 3, figsize=(15, 8)); axes = axes.flatten()
overlap_num = {}

for ax, m in zip(axes, metrics):
    corr, inc = corr_vals[m], inc_vals[m]

    # Common bins from 0 to the pooled 99th percentile, so a few extreme
    # values do not stretch the axis.
    vmax = np.percentile(np.concatenate([corr, inc]), 99)
    bins = np.linspace(0, vmax, BINS+1)

    h_c, _ = np.histogram(corr, bins=bins, density=True)
    h_i, _ = np.histogram(inc , bins=bins, density=True)
    ax.hist(corr, bins=bins, alpha=0.5, density=True, label='Correct')
    ax.hist(inc , bins=bins, alpha=0.5, density=True, label='Incorrect')

    # area under the minimum of the two densities
    overlap = np.sum(np.minimum(h_c, h_i) * np.diff(bins))
    overlap_num[m] = overlap

    ax.set_title(f"{m}  (overlap={overlap:.2f})")
    ax.set_xlabel(m); ax.set_ylabel("Density"); ax.grid(alpha=0.2)
    # Legend on the panel itself.  A single plt.legend() after the loop would
    # address the current axes -- the hidden, empty last panel -- and leave
    # every histogram without a legend.
    ax.legend(loc='upper right')

# hide the panels that hold no metric
for ax in axes[len(metrics):]:
    ax.axis('off')
fig.suptitle(f"Pixel-wise overlap – MC-dropout N={N_OPT}", fontsize=16)
fig.tight_layout(rect=[0,0,1,0.95])
plt.show()

print("\nArea of overlap :")
for m in metrics:
    print(f"  {m:8s}: {overlap_num[m]:.3f}")


## Task 4 : Reliability heat-map
###     error-rate vs. epistemic uncertainty

In [None]:

from scipy.stats import spearmanr


N_OPT = NUM_SAMPLES   # MC-dropout forward passes per image
K_BINS = 20           # bins along the entropy axis and along the error-rate axis
test_loader_1 = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Entropy bin edges 0 … ln 2; identical for every image, so built once.
bins = np.linspace(0.0, np.log(2.0), K_BINS + 1)

# One (bin centre, error rate) point per image and non-empty entropy bin.
u_vals, err_vals = [], []

print(f"Going through test-set, N={N_OPT}, K={K_BINS} …")
for img, gt in test_loader_1:
    img = img.to(device)

    # MC-dropout → mean + entropy
    mean_np, _, ent_map= mc_dropout_predict(
        model, img,
        num_samples     = N_OPT,
        return_entropy  = True
    )
    ent_map = ent_map.squeeze()

    # binary prediction and per-pixel error mask
    pred_bin = (mean_np.squeeze() > PRED_THRESH).astype(np.uint8)
    gt_bin   = gt.squeeze().numpy().astype(np.uint8)
    err_mask = (pred_bin != gt_bin).astype(np.uint8)

    bin_idx = np.digitize(ent_map, bins) - 1
    # np.digitize uses half-open bins, so a value at or just above the last
    # edge (e.g. float32 entropy of p = 0.5, which rounds slightly above
    # ln 2) gets index K_BINS and would be silently dropped.  Put such finite
    # values into the last bin, matching histogram2d's closed last bin.
    bin_idx[(bin_idx == K_BINS) & np.isfinite(ent_map)] = K_BINS - 1

    for k in range(K_BINS):
        in_bin = (bin_idx == k)
        if in_bin.sum() == 0:
            continue
        err_rate = err_mask[in_bin].mean()          # fraction of wrong pixels in this bin
        u_center = 0.5 * (bins[k] + bins[k+1])      # bin centre on the entropy axis
        u_vals.append(u_center)
        err_vals.append(err_rate)


H, xedges, yedges = np.histogram2d(
    u_vals, err_vals,
    bins=[K_BINS, K_BINS],
    range=[[0, np.log(2.0)], [0, 1]]
)

# Average error rate per entropy column (NaN where a column has no points).
# The arrays are built once instead of being re-created for every bin.
u_arr   = np.asarray(u_vals)
err_arr = np.asarray(err_vals)
mean_err = []
for i in range(K_BINS):
    in_col = (u_arr >= xedges[i]) & (u_arr < xedges[i+1])
    mean_err.append(err_arr[in_col].mean() if in_col.any() else np.nan)

# Spearman correlation
rho, p_val = spearmanr(u_vals, err_vals)


plt.figure(figsize=(7,5))
# heat-map
plt.imshow(
    H.T, origin='lower',
    extent=[0, np.log(2.0), 0, 1], aspect='auto',
    cmap='viridis'
)
plt.colorbar(label = 'Number of (slice, bin) points')
# red average line
centers = 0.5 * (xedges[:-1] + xedges[1:])
plt.plot(centers, mean_err, color='red', linewidth=2, label='Average error-rate')

plt.xlabel('Epistemic uncertainty  $H(\\hat p)$')
plt.ylabel('Error-rate')
plt.title(f'Reliability heat-map  –  N={N_OPT}  |  Spearman ρ={rho:.2f} (p={p_val:.2e})')
plt.ylim(0, 1)
plt.legend()
plt.tight_layout()
plt.show()


## Task 5: Scatter VVC vs (1 – Dice)

In [None]:

from scipy.stats import linregress, spearmanr

N_OPT = NUM_SAMPLES          # MC-dropout passes per image
test_loader_1 = DataLoader(test_dataset, batch_size=1, shuffle=False)

dice_list = []   # 1 - Dice, one entry per test image
vvc_list = []    # volume-variation coefficient, one entry per test image

print("⇢ Počítam VVC a 1-Dice pre každú snímku test-setu …")
for idx, (img, gt) in enumerate(test_loader_1):
    img = img.to(device)

    # Only the mean prediction and the raw samples are requested.
    mean_np, _, samples = mc_dropout_predict(
        model,
        img,
        num_samples=N_OPT,
        return_std=False,
        return_entropy=False,
        return_iqr=False,
        return_range95=False,
        return_samples=True,
    )

    # Segmentation error of the thresholded mean prediction.  The mask is
    # already binary, so dice_score's own threshold of 0.5 leaves it as is.
    pred_bin = (mean_np > PRED_THRESH).astype(np.uint8)
    pred_tensor = torch.from_numpy(pred_bin)[None, None].float()
    d = dice_score(pred_tensor, gt.float(), threshold=0.5)
    dice_list.append(1 - d)

    # VVC: std over mean of the per-sample summed predictions.
    vols = samples.reshape(N_OPT, -1).sum(axis=1)
    muV = vols.mean()
    sigmaV = vols.std()
    vvc_list.append(sigmaV / (muV + 1e-8))   # epsilon guards an all-zero volume

# Scatter plot with a least-squares line and rank correlation.
x = np.array(dice_list)
y = np.array(vvc_list)

slope, intercept, r_val, p_val, _ = linregress(x, y)
rho, p_spear = spearmanr(x, y)

plt.figure(figsize=(4, 4))
plt.scatter(x, y, color="#c1440e", s=30)
plt.plot(x, slope * x + intercept, color="#c1440e")
plt.xlabel("1 – Dice")
plt.ylabel("VVC")
plt.title(f"TTR (MC-dropout)   R²={r_val**2:.3f}   ρ={rho:.3f}")
plt.xlim(0, x.max() * 1.05)
plt.ylim(0, y.max() * 1.1)
plt.tight_layout()
plt.show()

print(f"\nLinear regression:  y = {slope:.3f}·x + {intercept:.4f}   "
      f"(R² = {r_val**2:.3f},  p = {p_val:.2e})")
print(f"Spearman ρ = {rho:.3f}  (p = {p_spear:.2e})")


In [None]:

N_OPT = 80          # same analysis as the previous cell, with 80 MC-dropout passes
test_loader_1 = DataLoader(test_dataset, batch_size=1, shuffle=False)

dice_list, vvc_list = [], []

print("⇢ Počítam VVC a 1-Dice pre každú snímku test-setu …")
for idx, (img, gt) in enumerate(test_loader_1):
    img = img.to(device)

    mean_np, _, samples = mc_dropout_predict(
        model, img,
        num_samples=N_OPT,
        return_std=False, return_entropy=False,
        return_iqr=False, return_range95=False,
        return_samples=True,
    )

    # 1 - Dice of the thresholded mean prediction
    pred_bin = (mean_np > PRED_THRESH).astype(np.uint8)
    d = dice_score(
        torch.from_numpy(pred_bin)[None, None].float(),
        gt.float(),
        threshold=0.5,
    )
    dice_list.append(1 - d)

    # VVC = std / mean of the summed prediction of each MC sample
    vols = samples.reshape(N_OPT, -1).sum(axis=1)
    muV, sigmaV = vols.mean(), vols.std()
    vvc_list.append(sigmaV / (muV + 1e-8))

# x: segmentation error, y: volume variation
x = np.array(dice_list)
y = np.array(vvc_list)

slope, intercept, r_val, p_val, _ = linregress(x, y)
rho, p_spear = spearmanr(x, y)

plt.figure(figsize=(4, 4))
plt.scatter(x, y, s=30, color="#c1440e")
plt.plot(x, slope * x + intercept, color="#c1440e")   # fitted line
plt.xlabel("1 – Dice")
plt.ylabel("VVC")
plt.title(f"TTR (MC-dropout)   R²={r_val**2:.3f}   ρ={rho:.3f}")
plt.xlim(0, x.max() * 1.05)
plt.ylim(0, y.max() * 1.1)
plt.tight_layout()
plt.show()

print(f"\nLineárna regresia:  y = {slope:.3f}·x + {intercept:.4f}   "
      f"(R² = {r_val**2:.3f},  p = {p_val:.2e})")
print(f"Spearman ρ = {rho:.3f}  (p = {p_spear:.2e})")


## Task 6: Mean vs Median aggregation (MC-dropout) 

In [None]:
N_SAMPLES = NUM_SAMPLES          # number of MC-dropout samples
loader = DataLoader(test_dataset, batch_size=1, shuffle=False)


def _thresholded_dice(prob_map, target):
    """Binarise ``prob_map`` at PRED_THRESH and score it against ``target``."""
    binary = (prob_map > PRED_THRESH).astype(np.uint8)
    return dice_score(torch.from_numpy(binary)[None, None].float(),
                      target.float(), threshold=0.5)


dice_mean = []     # Dice of the mean-aggregated prediction, per image
dice_median = []   # Dice of the median-aggregated prediction, per image

print(f"→ Evaluating Dice for MEAN vs MEDIAN (N={N_SAMPLES})")
for img, gt in loader:
    img = img.to(device)

    # Keep the raw samples so the median can be taken over them.
    mean_pred, _, samples = mc_dropout_predict(
        model, img,
        num_samples=N_SAMPLES,
        return_entropy=False,
        return_samples=True,
    )
    median_pred = np.median(samples, axis=0)   # median across the sample axis

    dice_mean.append(_thresholded_dice(mean_pred, gt))
    dice_median.append(_thresholded_dice(median_pred, gt))

dice_mean = np.array(dice_mean)
dice_median = np.array(dice_median)

print(f"Mean aggregation  :  Dice = {dice_mean.mean():.4f} ± {dice_mean.std():.4f}")
print(f"Median aggregation:  Dice = {dice_median.mean():.4f} ± {dice_median.std():.4f}")

# Bar plot: average Dice with the standard deviation as error bar.
labels = ["Mean", "Median"]
vals = [dice_mean.mean(), dice_median.mean()]
errs = [dice_mean.std(), dice_median.std()]

plt.figure(figsize=(4, 4))
plt.bar(labels, vals, yerr=errs, capsize=6, color=["steelblue", "orange"])
plt.ylabel("Average Dice (test set)")
plt.title(f"Mean vs Median aggregation  |  N={N_SAMPLES}")
plt.ylim(0, 1)
plt.grid(axis="y", alpha=.3)
plt.tight_layout()
plt.show()


In [None]:
# Same evaluation as the preceding cell: Dice of mean- vs median-aggregated
# MC-dropout predictions over the test set.
N_SAMPLES = NUM_SAMPLES
loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

dice_mean, dice_median = [], []

print(f"→ Evaluating Dice for MEAN vs MEDIAN (N={N_SAMPLES})")
for img, gt in loader:
    img = img.to(device)

    mean_pred, _, samples = mc_dropout_predict(
        model, img,
        num_samples=N_SAMPLES,
        return_entropy=False,
        return_samples=True,
    )
    median_pred = np.median(samples, axis=0)

    # Score the mean first, then the median, each binarised at PRED_THRESH.
    for prob_map, scores in ((mean_pred, dice_mean), (median_pred, dice_median)):
        binary = (prob_map > PRED_THRESH).astype(np.uint8)
        scores.append(
            dice_score(torch.from_numpy(binary)[None, None].float(),
                       gt.float(), threshold=0.5)
        )

dice_mean = np.array(dice_mean)
dice_median = np.array(dice_median)

print(f"Mean aggregation  :  Dice = {dice_mean.mean():.4f} ± {dice_mean.std():.4f}")
print(f"Median aggregation:  Dice = {dice_median.mean():.4f} ± {dice_median.std():.4f}")

# --- bar plot: mean Dice, error bar = standard deviation -----------------
labels = ["Mean", "Median"]
vals = [arr.mean() for arr in (dice_mean, dice_median)]
errs = [arr.std() for arr in (dice_mean, dice_median)]

plt.figure(figsize=(4, 4))
plt.bar(labels, vals, yerr=errs, capsize=6, color=["steelblue", "orange"])
plt.ylabel("Average Dice (test set)")
plt.title(f"Mean vs Median aggregation  |  N={N_SAMPLES}")
plt.ylim(0, 1)
plt.grid(axis="y", alpha=.3)
plt.tight_layout()
plt.show()


In [None]:
# Pick one test slice reproducibly.
random.seed(123)
idx = random.randrange(len(test_dataset))
img, gt = test_dataset[idx]
img_b   = img.unsqueeze(0).to(device)          # add the batch dimension

# run MC-dropout and keep raw samples
mean_pred, _,  samples = mc_dropout_predict(
    model, img_b,
    num_samples    = N_SAMPLES,    # reuses N_SAMPLES set by an earlier cell
    return_entropy = False,
    return_samples = True
)
median_pred = np.median(samples, axis=0)       # median across the sample axis

# Binary entropy (natural log) of the mean and of the median prediction;
# the clip keeps log() away from 0.
eps = 1e-12
p_mean = np.clip(mean_pred,   eps, 1-eps)
p_med  = np.clip(median_pred, eps, 1-eps)
entropy_mean = -(p_mean*np.log(p_mean) + (1-p_mean)*np.log(1-p_mean))
entropy_med  = -(p_med *np.log(p_med ) + (1-p_med )*np.log(1-p_med ))

fig, axes = plt.subplots(2, 3, figsize=(10, 7))

axes[0,0].imshow(np.transpose(img.cpu().numpy(), (1,2,0)), cmap='gray')
axes[0,0].set_title("Original");           axes[0,0].axis('off')

axes[0,1].imshow((mean_pred.squeeze() > PRED_THRESH), cmap='gray')
# Hide the ticks of this panel too; previously only the empty third panel
# was switched off on this line, leaving the prediction with visible axes.
axes[0,1].set_title(f"Prediction (mean, N={N_SAMPLES})"); axes[0,1].axis('off')
axes[0,2].axis('off')                      # unused panel

im1 = axes[1,0].imshow(entropy_mean, cmap='jet',
                       vmin=0.0, vmax=np.log(2.0))
axes[1,0].set_title("Entropy (mean)");  axes[1,0].axis('off')
plt.colorbar(im1, ax=axes[1,0], fraction=0.05, pad=0.04)

im2 = axes[1,1].imshow(entropy_med, cmap='jet',
                       vmin=0.0, vmax=np.log(2.0))
axes[1,1].set_title("Entropy (median)"); axes[1,1].axis('off')
plt.colorbar(im2, ax=axes[1,1], fraction=0.05, pad=0.04)

diff = entropy_med - entropy_mean
im3 = axes[1,2].imshow(diff, cmap='bwr',
                       vmin=-0.2, vmax=0.2)   # symmetric range around 0
axes[1,2].set_title("Difference  (median – mean)")
axes[1,2].axis('off')
plt.colorbar(im3, ax=axes[1,2], fraction=0.05, pad=0.04)

plt.tight_layout()
plt.show()


In [None]:

N_SAMPLES = NUM_SAMPLES   # MC-dropout samples per image
K_RUNS = 10               # number of differently seeded repetitions
THR = PRED_THRESH

loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# One test-set-average Dice per run and aggregation rule.
dice_mean_runs = []
dice_median_runs = []

for run in range(K_RUNS):
    # Seed every RNG the run might draw from, using the run index.
    torch.manual_seed(run)
    np.random.seed(run)
    random.seed(run)

    dice_mean_all = []
    dice_median_all = []

    for img, gt in loader:
        img = img.to(device)

        mean_pred, _, samples = mc_dropout_predict(
            model, img, num_samples=N_SAMPLES, return_samples=True
        )
        median_pred = np.median(samples, axis=0)

        # Mean first, then median: binarise at THR and score against gt.
        for prob_map, bucket in ((mean_pred, dice_mean_all),
                                 (median_pred, dice_median_all)):
            binary = (prob_map > THR).astype(np.uint8)
            bucket.append(
                dice_score(torch.from_numpy(binary)[None, None].float(),
                           gt.float(), threshold=0.5)
            )

    dice_mean_runs.append(np.mean(dice_mean_all))
    dice_median_runs.append(np.mean(dice_median_all))

dice_mean_runs = np.array(dice_mean_runs)
dice_median_runs = np.array(dice_median_runs)

# Paired t-test across runs (run i of "mean" is paired with run i of "median").
t_stat, p_val = ttest_rel(dice_mean_runs, dice_median_runs)

print(f"Mean   Dice: {dice_mean_runs.mean():.4f} ± {dice_mean_runs.std():.4f}")
print(f"Median Dice: {dice_median_runs.mean():.4f} ± {dice_median_runs.std():.4f}")
print(f"Paired t-test:  t = {t_stat:.3f}   p = {p_val:.4f}")

# Box plot of the per-run averages.
plt.figure(figsize=(4, 4))
plt.boxplot([dice_mean_runs, dice_median_runs],
            labels=["Mean", "Median"], showmeans=True)
plt.ylabel("Average Dice per run")
plt.title(f"Mean vs Median  |  N={N_SAMPLES},  K={K_RUNS} runs")
plt.grid(axis="y", alpha=.3)
plt.tight_layout()
plt.show()


In [None]:

# Same repeated-run comparison as the previous cell, with 60 runs.
N_SAMPLES = NUM_SAMPLES   # MC-dropout samples per image
K_RUNS = 60               # number of differently seeded repetitions
THR = PRED_THRESH

loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

dice_mean_runs, dice_median_runs = [], []   # one average Dice per run

for run in range(K_RUNS):
    # The run index seeds torch, NumPy and Python's random module.
    torch.manual_seed(run)
    np.random.seed(run)
    random.seed(run)

    dice_mean_all, dice_median_all = [], []

    for img, gt in loader:
        img = img.to(device)
        gt_f = gt.float()

        mean_pred, _, samples = mc_dropout_predict(
            model, img,
            num_samples=N_SAMPLES,
            return_samples=True,
        )
        median_pred = np.median(samples, axis=0)

        mean_bin = (mean_pred > THR).astype(np.uint8)
        median_bin = (median_pred > THR).astype(np.uint8)

        mean_t = torch.from_numpy(mean_bin)[None, None].float()
        dice_mean_all.append(dice_score(mean_t, gt_f, threshold=0.5))

        median_t = torch.from_numpy(median_bin)[None, None].float()
        dice_median_all.append(dice_score(median_t, gt_f, threshold=0.5))

    # test-set average of this run
    dice_mean_runs.append(np.mean(dice_mean_all))
    dice_median_runs.append(np.mean(dice_median_all))

dice_mean_runs = np.array(dice_mean_runs)
dice_median_runs = np.array(dice_median_runs)

# --- paired t-test over the runs ------------------------------------------
t_stat, p_val = ttest_rel(dice_mean_runs, dice_median_runs)

print(f"Mean   Dice: {dice_mean_runs.mean():.4f} ± {dice_mean_runs.std():.4f}")
print(f"Median Dice: {dice_median_runs.mean():.4f} ± {dice_median_runs.std():.4f}")
print(f"Paired t-test:  t = {t_stat:.3f}   p = {p_val:.4f}")

# --- box plot of the per-run averages --------------------------------------
plt.figure(figsize=(4, 4))
plt.boxplot([dice_mean_runs, dice_median_runs],
            labels=["Mean", "Median"], showmeans=True)
plt.ylabel("Average Dice per run")
plt.title(f"Mean vs Median  |  N={N_SAMPLES},  K={K_RUNS} runs")
plt.grid(axis="y", alpha=.3)
plt.tight_layout()
plt.show()


## Task 7 : Remove top-10 % most uncertain pixels and recompute Dice


In [None]:
# Settings for the uncertainty-masking experiment.
N_SAMPLES = NUM_SAMPLES   # MC-dropout samples used for the entropy map
TOP_FRAC = 0.10           # fraction of highest-entropy pixels to discard
THR = PRED_THRESH         # binarisation threshold for the mean prediction
loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Per-image Dice on all pixels and on the retained pixels only.
dice_full = []
dice_masked = []

def masked_dice(pred_bin, target_bin, valid_mask):
    """Dice coefficient restricted to the pixels selected by ``valid_mask``.

    Prediction and target are cast to float and multiplied by the mask, so
    pixels where the mask is 0 contribute to neither the intersection nor the
    denominator.

    Args:
        pred_bin: predicted mask tensor.
        target_bin: reference mask tensor, broadcastable with ``pred_bin``.
        valid_mask: tensor that is 1 where a pixel counts and 0 elsewhere.

    Returns:
        A 0-dim tensor ``2 * |P ∩ T| / (|P| + |T| + 1e-7)``; the epsilon makes
        the result 0 instead of NaN when both masked inputs are empty.
    """
    kept_pred = pred_bin.float() * valid_mask
    kept_target = target_bin.float() * valid_mask
    overlap = (kept_pred * kept_target).sum()
    total = kept_pred.sum() + kept_target.sum()
    return (2 * overlap) / (total + 1e-7)

print(f"→ Evaluating Dice with top {int(TOP_FRAC*100)} % high-entropy pixels removed")
for img, gt in loader:
    img = img.to(device)

    # MC-dropout → mean prediction and entropy map
    mean_pred, _,  ent_map, *_ = mc_dropout_predict(
        model, img,
        num_samples    = N_SAMPLES,
        return_entropy = True
    )
    pred_bin = (mean_pred > THR).astype(np.uint8)
    H, W     = ent_map.shape

    # baseline Dice (all pixels)
    dice_full.append(
        dice_score(torch.from_numpy(pred_bin)[None,None].float(),
                   gt.float(), threshold=0.5)
    )

    # Keep the (1 - TOP_FRAC) lowest-entropy pixels.  k is clamped to the last
    # valid index so that TOP_FRAC = 0 keeps every pixel instead of raising
    # an out-of-bounds error in np.partition.
    n_pix    = H * W
    k        = min(int((1 - TOP_FRAC) * n_pix), n_pix - 1)
    thresh   = np.partition(ent_map.flatten(), k)[k]        # k-th smallest entropy
    valid    = (ent_map <= thresh).astype(np.uint8)         # ties at the threshold are kept

    dice_masked.append(
        masked_dice(torch.from_numpy(pred_bin)[None,None],
                    gt, torch.from_numpy(valid)[None,None])
        .item()
    )

dice_full   = np.array(dice_full)
dice_masked = np.array(dice_masked)

# paired t-test: the same image with and without its uncertain pixels
t_stat, p_val = ttest_rel(dice_full, dice_masked)

print(f"\nFull-image Dice  :  {dice_full.mean():.4f} ± {dice_full.std():.4f}")
print(f"Masked Dice      :  {dice_masked.mean():.4f} ± {dice_masked.std():.4f}")
print(f"Paired t-test    :  t = {t_stat:.3f}   p = {p_val:.4f}")

# box-plot
plt.figure(figsize=(4,4))
plt.boxplot([dice_full, dice_masked],
            labels=["Full", "Masked"], showmeans=True)
plt.ylabel("Dice score")
# The percentage comes from TOP_FRAC (as in the message printed above) so the
# title stays correct when the fraction is changed.
plt.title(f"Effect of removing top {int(TOP_FRAC*100)} % uncertain pixels  |  N={N_SAMPLES}")
plt.grid(axis="y", alpha=.3)
plt.tight_layout()
plt.show()


In [None]:
# Visual demo: which pixels are discarded when the highest-entropy
# fraction (TOP_FRAC) is removed.
random.seed(8)
idx = random.randrange(len(test_dataset))
img, gt = test_dataset[idx]
img_b = img.unsqueeze(0).to(device)

# Mean prediction and entropy map from MC-dropout.
mean_pred, _, ent_map, *_ = mc_dropout_predict(
    model, img_b, num_samples=N_SAMPLES, return_entropy=True
)
pred_bin = (mean_pred > THR).astype(np.uint8)

# Valid mask: pixels whose entropy does not exceed the k-th smallest value.
H, W = ent_map.shape
k = int((1 - TOP_FRAC) * H * W)
thr_e = np.partition(ent_map.flatten(), k)[k]
valid = (ent_map <= thr_e).astype(np.uint8)

# figure ---------------------------------------------------------------
fig, axes = plt.subplots(2, 2, figsize=(8, 8))
(ax_base, ax_ent), (ax_valid, ax_removed) = axes

# top-left: input with the prediction as a translucent overlay
ax_base.imshow(np.transpose(img.cpu().numpy(), (1,2,0)), cmap='gray')
ax_base.imshow(pred_bin, alpha=0.3, cmap='Reds')
ax_base.set_title("Baseline prediction")
ax_base.axis('off')

# top-right: entropy map on a 0 … ln 2 colour scale
im = ax_ent.imshow(ent_map, cmap='jet', vmin=0, vmax=np.log(2.0))
ax_ent.set_title("Predictive entropy")
ax_ent.axis('off')
plt.colorbar(im, ax=ax_ent, fraction=0.046, pad=0.04)

# bottom-left: pixels that are kept
ax_valid.imshow(valid, cmap='Greens')
ax_valid.set_title("Valid mask (90 % kept)")
ax_valid.axis('off')

# bottom-right: prediction with the removed pixels highlighted;
# NaN entries stay transparent in the overlay
removed = np.where(valid==0, 1, np.nan)
ax_removed.imshow(pred_bin, cmap='gray')
ax_removed.imshow(removed, cmap='autumn', alpha=0.7)
ax_removed.set_title("Removed pixels (red)")
ax_removed.axis('off')

plt.tight_layout()
plt.show()


In [None]:
# Same visual demo as the previous cell, for the slice selected by seed 2.
random.seed(2)
idx = random.randrange(len(test_dataset))
img, gt = test_dataset[idx]
img_b = img.unsqueeze(0).to(device)

# MC-dropout: mean prediction + entropy map
mean_pred, _, ent_map, *_ = mc_dropout_predict(
    model,
    img_b,
    num_samples=N_SAMPLES,
    return_entropy=True,
)

pred_bin = (mean_pred > THR).astype(np.uint8)

# Entropy value at rank k marks the cut; everything at or below it is kept.
H, W = ent_map.shape
k = int((1 - TOP_FRAC) * H * W)
thr_e = np.partition(ent_map.flatten(), k)[k]
valid = (ent_map <= thr_e).astype(np.uint8)

# figure ---------------------------------------------------------------
fig, axes = plt.subplots(2, 2, figsize=(8, 8))
top_left, top_right = axes[0]
bottom_left, bottom_right = axes[1]

top_left.imshow(np.transpose(img.cpu().numpy(), (1,2,0)), cmap='gray')
top_left.imshow(pred_bin, alpha=0.3, cmap='Reds')   # prediction overlay
top_left.set_title("Baseline prediction")
top_left.axis('off')

im = top_right.imshow(ent_map, cmap='jet', vmin=0, vmax=np.log(2.0))
top_right.set_title("Predictive entropy")
top_right.axis('off')
plt.colorbar(im, ax=top_right, fraction=0.046, pad=0.04)

bottom_left.imshow(valid, cmap='Greens')
bottom_left.set_title("Valid mask (90 % kept)")
bottom_left.axis('off')

# 1 where a pixel was removed, NaN (drawn transparent) elsewhere
removed = np.where(valid==0, 1, np.nan)
bottom_right.imshow(pred_bin, cmap='gray')
bottom_right.imshow(removed, cmap='autumn', alpha=0.7)
bottom_right.set_title("Removed pixels (red)")
bottom_right.axis('off')

plt.tight_layout()
plt.show()


In [None]:
# Visual demo for one slice, with the full and masked Dice printed below it.
random.seed(2)
idx = random.randrange(len(test_dataset))
img, gt = test_dataset[idx]
img_b = img.unsqueeze(0).to(device)

# MC-dropout: mean prediction and entropy map
mean_pred, _, ent_map = mc_dropout_predict(
    model, img_b, num_samples=N_SAMPLES, return_entropy=True
)

pred_bin = (mean_pred > THR).astype(np.uint8)
pred_4d = torch.from_numpy(pred_bin)[None, None]   # add batch and channel dims
gt_4d = gt.unsqueeze(0)                            # add the batch dim

# Dice over every pixel
dice_full = dice_score(pred_4d.float(), gt_4d.float(), threshold=0.5)

# Valid mask: entropy at or below the k-th smallest value
H, W = ent_map.shape
k = int((1 - TOP_FRAC) * H * W)
thr_e = np.partition(ent_map.flatten(), k)[k]
valid = (ent_map <= thr_e).astype(np.uint8)

# Dice over the retained pixels only
dice_masked = masked_dice(pred_4d, gt_4d, torch.from_numpy(valid)[None, None])

# --------------------------- figure ------------------------------------
fig, axes = plt.subplots(2, 2, figsize=(8, 8))
(ax_base, ax_ent), (ax_valid, ax_removed) = axes

ax_base.imshow(np.transpose(img.cpu().numpy(), (1,2,0)), cmap='gray')
ax_base.imshow(pred_bin, alpha=0.3, cmap='Reds')
ax_base.set_title("Baseline prediction")
ax_base.axis('off')

im = ax_ent.imshow(ent_map, cmap='jet', vmin=0, vmax=np.log(2.0))
ax_ent.set_title("Predictive entropy")
ax_ent.axis('off')
plt.colorbar(im, ax=ax_ent, fraction=0.046, pad=0.04)

ax_valid.imshow(valid, cmap='Greens')
ax_valid.set_title("Valid mask (90 % kept)")
ax_valid.axis('off')

# removed pixels get value 1, the rest NaN so they stay transparent
removed = np.where(valid == 0, 1, np.nan)
ax_removed.imshow(pred_bin, cmap='gray')
ax_removed.imshow(removed, cmap='autumn', alpha=0.7)
ax_removed.set_title("Removed pixels (red)")
ax_removed.axis('off')

# caption anchored at the bottom edge of the figure
caption = (f"Dice score: Dice full: {dice_full:.3f}   "
           f"Dice masked: {dice_masked:.3f}")
plt.figtext(0.5, 0, caption, fontsize=12, ha='center', va='top')
plt.tight_layout()
plt.show()




## Task 8 : Compare SENOAUG (no aug) vs. SE (flip + rot) on epistemic uncertainty

In [None]:
# Checkpoints of the two models to compare.
# NOTE(review): both paths start at the filesystem root, whereas SAVE_DIR at
# the top of the notebook is relative ("segmentation/...") -- confirm.
PATH_SE = "/segmentation/unet_segmentationSE.pth"
PATH_SENOAUG = "/segmentation/unet_segmentationSENOAUG.pth"

# One image per batch, fixed order.
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_unet(path):
    """Create a ``UNetWithDropout`` and fill it with the weights stored at *path*.

    The network is built with 3 input channels and 1 output channel, moved to
    the global ``device``, loaded from the state dict and switched to eval
    mode before it is returned.
    """
    net = UNetWithDropout(in_channels=3, out_channels=1)
    net = net.to(device)
    state_dict = torch.load(path, map_location=device)
    net.load_state_dict(state_dict)
    net.eval()
    return net

# SENOAUG checkpoint -> model treated as "no augmentation";
# SE checkpoint -> model treated as "with augmentation".
model_no_aug = load_unet(PATH_SENOAUG)
model_aug = load_unet(PATH_SE)

# ------------------------------------------------------------------ eval helper
def eval_dice_entropy(model, n_mc=30, thr=0.4):
    """Evaluate *model* on the global ``test_loader`` with MC-dropout.

    Args:
        model: network passed on to ``mc_dropout_predict``.
        n_mc: number of MC-dropout samples per image.
        thr: threshold applied to the mean prediction.

    Returns:
        Two NumPy arrays with one entry per image: the Dice score of the
        thresholded mean prediction and the mean of the entropy map.
    """
    per_image_dice = []
    per_image_entropy = []
    for img, gt in test_loader:
        mean_pred, _, ent_map = mc_dropout_predict(
            model, img.to(device), num_samples=n_mc, return_entropy=True
        )
        binary = (mean_pred > thr).astype(np.uint8)
        binary_t = torch.from_numpy(binary)[None, None].float()
        per_image_dice.append(dice_score(binary_t, gt.float(), threshold=0.5))
        per_image_entropy.append(ent_map.mean())
    return np.array(per_image_dice), np.array(per_image_entropy)

# Per-image Dice and mean entropy for both models (function defaults apply:
# 30 MC samples, threshold 0.4).
dice_no_aug , ent_no_aug  = eval_dice_entropy(model_no_aug)
dice_aug, ent_aug = eval_dice_entropy(model_aug)

# --------------------------- statistics --------------------------
# Paired tests: both models are evaluated on the same images in the same order.
t_dice, p_dice = ttest_rel(dice_no_aug,  dice_aug)
t_ent , p_ent  = ttest_rel(ent_no_aug,   ent_aug)

print("\n===  RESULTS  ===")
print(f"Dice  No AUG       : {dice_no_aug.mean():.4f} ± {dice_no_aug.std():.4f}")
print(f"Dice  With AUG  : {dice_aug.mean():.4f} ± {dice_aug.std():.4f}")
print(f"Paired t-test Dice   p = {p_dice:.4e}")

print(f"\nEntropy  No AUG    : {ent_no_aug.mean():.4f} ± {ent_no_aug.std():.4f}")
print(f"Entropy  With AUG   : {ent_aug.mean():.4f} ± {ent_aug.std():.4f}")
print(f"Paired t-test entropy p = {p_ent:.4e}")

# --------------------------- bar plots ---------------------------
fig, axes = plt.subplots(1, 2, figsize=(8,4))
axes[0].bar(["No AUG","With AUG"], [dice_no_aug.mean(), dice_aug.mean()],
            yerr=[dice_no_aug.std(), dice_aug.std()], capsize=6)
axes[0].set_title("Average Dice"); axes[0].set_ylabel("Dice score")

axes[1].bar(["No AUG","With AUG"], [ent_no_aug.mean(), ent_aug.mean()],
            yerr=[ent_no_aug.std(), ent_aug.std()], capsize=6, color="orange")
# Entropy maps in this notebook are drawn on a 0 … ln 2 scale, i.e. in
# natural-log units, so the axis unit is nats rather than bits.
axes[1].set_title("Mean predictive entropy"); axes[1].set_ylabel("nats")

plt.suptitle("Effect of TRAINING augmentation (flip + rotations)")
plt.tight_layout()
plt.show()


In [None]:
# Qualitative comparison of the two models on one reproducibly chosen slice.
random.seed(2)
idx = random.randrange(len(test_dataset))
img, _ = test_dataset[idx]
img_b  = img.unsqueeze(0).to(device)

# predictions + entropy
prob_no, _, ent_no = mc_dropout_predict(
    model_no_aug, img_b, num_samples=30, return_entropy=True)
prob_aug, _, ent_aug = mc_dropout_predict(
    model_aug, img_b, num_samples=30, return_entropy=True)

pred_no  = (prob_no  > 0.4).astype(np.uint8)
pred_aug = (prob_aug > 0.4).astype(np.uint8)
diff_ent = ent_aug - ent_no      # positive where the augmented model is less certain

fig, ax = plt.subplots(2, 3, figsize=(11, 7))

# row 1 – predictions
ax[0,0].imshow(np.transpose(img.cpu().numpy(),(1,2,0)), cmap="gray")
ax[0,0].set_title("Original"); ax[0,0].axis('off')

ax[0,1].imshow(pred_no, cmap="gray"); ax[0,1].set_title("Prediction  • No AUG"); ax[0,1].axis('off')
ax[0,2].imshow(pred_aug, cmap="gray"); ax[0,2].set_title("Prediction  • With AUG"); ax[0,2].axis('off')

# row 2 – entropies; the third panel title previously read "With UG" (typo)
for j,(emap,ttl) in enumerate([(ent_no,"Entropy  • No AUG"),
                               (ent_aug,"Entropy  • With AUG"),
                               (diff_ent,"Δ Entropy  (With AUG–No AUG)")]):
    # entropy maps: 0 … ln 2 on 'jet'; difference map: symmetric diverging scale
    cmap = 'jet' if j<2 else 'bwr'
    vmin,vmax = (0, np.log(2.0)) if j<2 else (-0.3,0.3)
    im = ax[1,j].imshow(emap, cmap=cmap, vmin=vmin, vmax=vmax)
    ax[1,j].set_title(ttl); ax[1,j].axis('off'); plt.colorbar(im, ax=ax[1,j], fraction=0.046)

plt.suptitle("Effect of training augmentation on epistemic uncertainty", y=1)
plt.tight_layout(); plt.show()


In [None]:
# Multi-example figure
import matplotlib.patches as mpatches

# Settings for the multi-example figure.
N_EX = 30               # number of test examples to draw
N_MC = NUM_SAMPLES      # MC-dropout samples per example
THR_BIN = PRED_THRESH   # binarisation threshold
random.seed(36)         # fixes the example selection below

def overlay(img_rgb, mask, color):
    """Return a copy of ``img_rgb`` with the mask outline painted in ``color``.

    The outline is obtained by running ``cv2.Canny`` (thresholds 100 / 200) on
    ``mask`` scaled by 255 and cast to uint8; every pixel where the edge map is
    non-zero is overwritten with ``color``. The input image is not modified.
    """
    mask_u8 = (mask * 255).astype(np.uint8)
    contour = cv2.Canny(mask_u8, 100, 200)
    painted = img_rgb.copy()
    painted[contour > 0] = color
    return painted

# Pick N_EX test indices from a shuffled index pool, then draw one figure per
# example: original image, deterministic prediction vs. ground truth, entropy
# map, and MC-dropout ("TTD") prediction vs. ground truth.
idx_pool = list(range(len(test_dataset))); random.shuffle(idx_pool)
sel_idx  = idx_pool[:N_EX]

for ex, idx in enumerate(sel_idx, 1):
    img, gt = test_dataset[idx]
    img_b   = img.unsqueeze(0).to(device)                   # add a leading batch axis
    rtg     = np.transpose(img.cpu().numpy(), (1,2,0))      # CHW -> HWC for imshow
    mGT     = gt.squeeze().cpu().numpy().astype(np.uint8)

    # baseline: one forward pass in eval mode.
    # eval() is re-applied on every iteration — presumably because
    # mc_dropout_predict leaves the model in train mode; verify.
    model.eval()
    with torch.no_grad():
        log_det = model(img_b)
    prob_det = torch.sigmoid(log_det)[0,0].cpu().numpy()
    mDET     = (prob_det > THR_BIN).astype(np.uint8)

    # MC-dropout
    pm, _, ent = mc_dropout_predict(
        model, img_b, num_samples=N_MC, return_entropy=True)
    mTTD = (pm > THR_BIN).astype(np.uint8)

    # overlays: prediction outline first, ground-truth outline (yellow) on top
    rgb_det = overlay(overlay(rtg, mDET, (0,1,0)), mGT, (1,1,0))
    rgb_ttd = overlay(overlay(rtg, mTTD, (0,1,1)), mGT, (1,1,0))

    # layout  (extra room on the right for colour-bar + legend)
    gs = dict(width_ratios=[1,1,0.07], height_ratios=[1,0.15,1],
              wspace=0.02, hspace=0.04)
    fig, ax = plt.subplots(3, 3, figsize=(9.2, 8), gridspec_kw=gs)

    # row-1 ----------------------------------------------------------------
    ax[0,0].imshow(rtg, cmap='gray');   ax[0,0].set_title("Original X-ray")
    ax[0,1].imshow(rgb_det);            ax[0,1].set_title("Baseline + GT")
    for a in ax[0,:2]: a.axis('off')
    ax[0,2].axis('off')

    # ------------------ gap between lines ---------------------
    for a in ax[1,:]:
        a.axis('off')
    # row-2 ----------------------------------------------------------------
    # (title fixed: previously read "Epistemic unertainty")
    im = ax[2,0].imshow(ent, cmap='jet', vmin=0, vmax=np.log(2.0))
    ax[2,0].set_title("Epistemic uncertainty");          ax[2,0].axis('off')
    ax[2,1].imshow(rgb_ttd);                    ax[2,1].set_title("TTD + GT")
    ax[2,1].axis('off');                       ax[2,2].axis('off')

    # color-bar
    # NOTE(review): the label says "bits" while the colour scale tops out at
    # np.log(2.0), the binary-entropy maximum in nats — confirm the unit
    # returned by mc_dropout_predict and adjust the label or vmax.
    cbar = fig.colorbar(im, ax=ax.ravel().tolist(), shrink=0.75)
    cbar.set_label("Entropy  H(p̂)  [bits]")

    # legend – placed **above** colour-bar (bbox y > 1)
    handles = [mpatches.Patch(color='yellow', label='Ground truth'),
               mpatches.Patch(color='lime',   label='Baseline prediction'),
               mpatches.Patch(color='cyan',   label='TTD prediction')]
    fig.legend(handles=handles, loc='upper right',
               bbox_to_anchor=(0.97, 0.90), frameon=True)

    plt.tight_layout(rect=[0,0,0.94,0.97])
    plt.show()


## Task 9 : Entropy–penalised loss (epistemic consistency)

In [None]:
# Entropy-penalised training (epistemic consistency) — configuration.

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Hyper-parameters.
K_PASS = 8          # number of dropout forward passes per mini-batch
LAMBDA_ENT = 0.2    # weight of the entropy penalty in the total loss
N_EPOCHS = 1

# Fresh network, base segmentation loss and optimiser.
model_unc = UNetWithDropout(in_channels=3, out_channels=1)
model_unc = model_unc.to(device)
base_crit = CombinedLoss(weight_dice=0.7, weight_bce=0.3)
optim = torch.optim.Adam(model_unc.parameters(), lr=1e-4)

# ------------- helper : entropy of mean-probabilities --------------------
def entropy_from_mean(p_mean, eps=1e-12):
    """Element-wise binary entropy (natural logarithm) of a probability tensor.

    Args:
        p_mean: tensor of probabilities. Values are clamped into
            ``[eps, 1 - eps]`` before the logarithms are taken.
        eps: requested clamping margin. For floating-point inputs it is raised
            to the machine epsilon of ``p_mean.dtype`` when smaller: in
            float32 ``1 - 1e-12`` rounds to exactly 1.0, so ``p == 1`` would
            stay unclamped and ``(1 - p) * log(1 - p)`` would be
            ``0 * -inf = NaN``.

    Returns:
        Tensor with the same shape as ``p_mean`` holding
        ``-(p * log(p) + (1 - p) * log(1 - p))``.
    """
    if torch.is_floating_point(p_mean):
        # Make sure 1 - eps is representable (and < 1) in the tensor's dtype.
        eps = max(eps, torch.finfo(p_mean.dtype).eps)
    p = p_mean.clamp(eps, 1. - eps)
    return -(p * p.log() + (1 - p) * (1 - p).log())

# ------------- training loop --------------------------------------------
# Every mini-batch is pushed through the network K_PASS times in train mode;
# the mean of the resulting probabilities feeds both the base segmentation
# loss and the entropy penalty (weighted by LAMBDA_ENT).
for epoch in range(1, N_EPOCHS + 1):
    model_unc.train()
    running_loss = 0.0
    for imgs, gts in tqdm(train_loader, desc=f"Epoch {epoch}/{N_EPOCHS}"):
        imgs, gts = imgs.to(device), gts.to(device)

        # K stochastic forward passes ------------------------------------
        probs = []
        for _ in range(K_PASS):
            logits = model_unc(imgs)            # dropout active
            probs.append(torch.sigmoid(logits))
        probs = torch.stack(probs, dim=0)        # [K,B,1,H,W]

        # mean prediction & base loss ------------------------------------
        # NOTE(review): base_crit receives probabilities, not logits —
        # confirm that CombinedLoss expects probabilities.
        p_mean = probs.mean(dim=0)               # [B,1,H,W]
        base_loss = base_crit(p_mean, gts)

        # entropy penalty -------------------------------------------------
        ent_map   = entropy_from_mean(p_mean)
        ent_loss  = ent_map.mean()

        loss = base_loss + LAMBDA_ENT * ent_loss

        # backward --------------------------------------------------------
        optim.zero_grad()
        loss.backward()
        optim.step()

        running_loss += loss.item()

    # validation Dice ------------------------------------------------------
    # Previously `dice_vals` was read here without ever being computed in
    # this cell, so the printed value was either a NameError or a stale
    # result from another cell.
    # NOTE(review): raw logits are passed to dice_score, mirroring the later
    # "correct val-Dice" cell — confirm dice_score applies the sigmoid itself.
    model_unc.eval()
    dice_vals = []
    with torch.no_grad():
        for imgs, gts in val_loader:
            imgs, gts = imgs.to(device), gts.to(device)
            dice_vals.append(
                float(dice_score(model_unc(imgs), gts, threshold=0.5))
            )
    dice_val = sum(dice_vals) / len(dice_vals) if dice_vals else float("nan")
    print(f"Epoch {epoch:02d} | loss = {running_loss/len(train_loader):.4f} | val-Dice={dice_val:.4f}")

print("Finished uncertainty-aware training.  You can now evaluate the new model.")
# Save the model that was just trained (`model_unc`); the previous code stored
# the unrelated baseline `model`.
torch.save(model_unc.state_dict(), "segmentation/unet_segmentationUncertainty.pth")


In [None]:
# Entropy-penalised training with early stopping — configuration.

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Hyper-parameters.
K_PASS = 8
LAMBDA_ENT = 0.2
N_EPOCHS = 3

# Early-stopping helper, configured with patience=5 and min_delta=0.001.
early_stopping = EarlyStopping(patience=5, min_delta=0.001)

# Fresh network, base segmentation loss and optimiser.
model_unc = UNetWithDropout(in_channels=3, out_channels=1)
model_unc = model_unc.to(device)
base_crit = CombinedLoss(weight_dice=0.7, weight_bce=0.3)
optim = torch.optim.Adam(model_unc.parameters(), lr=1e-4)

def entropy_from_mean(p_mean, eps=1e-12):
    """Element-wise binary entropy (natural logarithm) of a probability tensor.

    Args:
        p_mean: tensor of probabilities. Values are clamped into
            ``[eps, 1 - eps]`` before the logarithms are taken.
        eps: requested clamping margin. For floating-point inputs it is raised
            to the machine epsilon of ``p_mean.dtype`` when smaller: in
            float32 ``1 - 1e-12`` rounds to exactly 1.0, so ``p == 1`` would
            stay unclamped and ``(1 - p) * log(1 - p)`` would be
            ``0 * -inf = NaN``.

    Returns:
        Tensor with the same shape as ``p_mean`` holding
        ``-(p * log(p) + (1 - p) * log(1 - p))``.
    """
    if torch.is_floating_point(p_mean):
        # Make sure 1 - eps is representable (and < 1) in the tensor's dtype.
        eps = max(eps, torch.finfo(p_mean.dtype).eps)
    p = p_mean.clamp(eps, 1. - eps)
    return -(p * p.log() + (1 - p) * (1 - p).log())

# ---------------- training loop -----------------------------------------
# Per epoch: entropy-penalised training on train_loader, then a validation
# pass that yields both the Dice score (for reporting) and the base loss
# (the quantity monitored by early stopping).
for epoch in range(1, N_EPOCHS + 1):
    model_unc.train()
    loss_accum = 0.0

    for imgs, gts in tqdm(train_loader, desc=f"Epoch {epoch}/{N_EPOCHS}"):
        imgs, gts = imgs.to(device), gts.to(device)

        # K stochastic passes
        probs = torch.stack(
            [torch.sigmoid(model_unc(imgs)) for _ in range(K_PASS)], dim=0
        )                                # [K,B,1,H,W]

        p_mean   = probs.mean(dim=0)
        base_loss = base_crit(p_mean, gts)
        ent_loss  = entropy_from_mean(p_mean).mean()
        loss      = base_loss + LAMBDA_ENT * ent_loss

        optim.zero_grad(); loss.backward(); optim.step()
        loss_accum += loss.item()

    # ---------- validation Dice + loss ------------------------------------
    model_unc.eval()
    dice_vals, val_loss_vals = [], []
    with torch.no_grad():
        for imgs, gts in val_loader:
            imgs, gts = imgs.to(device), gts.to(device)
            p_mean = torch.sigmoid(model_unc(imgs))
            # Validation loss uses the base criterion only (no entropy term).
            val_loss_vals.append(base_crit(p_mean, gts).item())
            # NOTE(review): probabilities are passed to dice_score here, while
            # the later "correct val-Dice" cell passes raw logits — confirm
            # which input dice_score expects.
            dice_vals.append(
                dice_score(p_mean, gts, threshold=0.5)
            )
    dice_val = torch.tensor(dice_vals).mean().item()
    # Previously `avg_val_loss` was read by early stopping without ever being
    # computed in this cell (NameError or a stale value from another cell).
    avg_val_loss = float(np.mean(val_loss_vals))

    print(f"Epoch {epoch:02d} | loss={loss_accum/len(train_loader):.4f}"
          f" | val-Dice={dice_val:.4f}")
    if early_stopping.check_stop(avg_val_loss):
        print("Early stopping triggered.")
        # Save the model being trained (`model_unc`), not the baseline `model`.
        torch.save(model_unc.state_dict(), "segmentation/unet_segmentationSEUncertainty.pth")
        break

print("Training finished; model `model_unc` now contains epistemic penalty.")
# Save the model being trained (`model_unc`), not the baseline `model`.
# This also runs after an early stop, because `break` only leaves the loop.
torch.save(model_unc.state_dict(), "segmentation/unet_segmentationUncertainty.pth")


In [None]:
# Entropy-penalised training + correct val-Dice + early stopping: configuration.

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Hyper-parameters.
K_PASS = 8
LAMBDA_ENT = 0.2
N_EPOCHS = 30

# Early-stopping helper, configured with patience=5 and min_delta=0.001.
early_stopping = EarlyStopping(patience=5, min_delta=0.001)

# Fresh network, base segmentation loss and optimiser.
model_unc = UNetWithDropout(in_channels=3, out_channels=1)
model_unc = model_unc.to(device)
base_crit = CombinedLoss(weight_dice=0.7, weight_bce=0.3)
optim = torch.optim.Adam(model_unc.parameters(), lr=1e-4)

def entropy_from_mean(p_mean, eps=1e-12):
    """Element-wise binary entropy (natural logarithm) of a probability tensor.

    Args:
        p_mean: tensor of probabilities. Values are clamped into
            ``[eps, 1 - eps]`` before the logarithms are taken.
        eps: requested clamping margin. For floating-point inputs it is raised
            to the machine epsilon of ``p_mean.dtype`` when smaller: in
            float32 ``1 - 1e-12`` rounds to exactly 1.0, so ``p == 1`` would
            stay unclamped and ``(1 - p) * log(1 - p)`` would be
            ``0 * -inf = NaN``.

    Returns:
        Tensor with the same shape as ``p_mean`` holding
        ``-(p * log(p) + (1 - p) * log(1 - p))``.
    """
    if torch.is_floating_point(p_mean):
        # Make sure 1 - eps is representable (and < 1) in the tensor's dtype.
        eps = max(eps, torch.finfo(p_mean.dtype).eps)
    p = p_mean.clamp(eps, 1. - eps)
    return -(p * p.log() + (1 - p) * (1 - p).log())

# Training loop: entropy-penalised optimisation on train_loader, followed each
# epoch by a validation pass whose mean base loss drives early stopping.
for epoch in range(1, N_EPOCHS + 1):
    model_unc.train()
    train_loss_accum = 0.0

    for imgs, gts in tqdm(train_loader, desc=f"Epoch {epoch}/{N_EPOCHS}"):
        imgs = imgs.to(device)
        gts = gts.to(device)

        # K_PASS forward passes in train mode, stacked along a new leading axis.
        mc_probs = []
        for _ in range(K_PASS):
            mc_probs.append(torch.sigmoid(model_unc(imgs)))
        probs = torch.stack(mc_probs, dim=0)

        # Total loss = base loss on the mean prediction + weighted mean entropy.
        p_mean = probs.mean(dim=0)
        base_loss = base_crit(p_mean, gts)
        ent_loss = entropy_from_mean(p_mean).mean()
        loss = base_loss + LAMBDA_ENT * ent_loss

        optim.zero_grad()
        loss.backward()
        optim.step()
        train_loss_accum += loss.item()

    # Validation: single forward pass per batch in eval mode, no gradients.
    model_unc.eval()
    val_dice_vals = []
    val_loss_vals = []

    with torch.no_grad():
        for imgs, gts in val_loader:
            imgs = imgs.to(device)
            gts = gts.to(device)

            logits = model_unc(imgs)
            probs = torch.sigmoid(logits)

            # The loss is computed on probabilities, Dice on the raw logits.
            val_loss_vals.append(base_crit(probs, gts).item())
            val_dice_vals.append(dice_score(logits, gts, threshold=0.5))

    avg_val_loss = np.mean(val_loss_vals)
    avg_val_dice = np.mean(val_dice_vals)

    print(" | ".join([
        f"Epoch {epoch:02d}",
        f"train-loss={train_loss_accum/len(train_loader):.4f}",
        f"val-loss={avg_val_loss:.4f}",
        f"val-Dice={avg_val_dice:.4f}",
    ]))

    # Early stopping on the validation loss: store a checkpoint, then leave.
    if early_stopping.check_stop(avg_val_loss):
        print("Early stopping triggered.")
        torch.save(
            model_unc.state_dict(),
            "segmentation/unet_segmentation_uncertaintySE.pth",
        )
        break

# Save the final state. This also runs after an early stop, because `break`
# only leaves the loop.
torch.save(
    model_unc.state_dict(),
    "segmentation/unet_segmentation_uncertainty.pth",
)
print("Training finished; model saved.")
