In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ============================
# Imports
# ============================
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torchvision import transforms as T
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
import random
from PIL import Image, ImageOps


# ============================
# Dataset Classes
# ============================
class XBDMulticlassDataset(Dataset):
    def __init__(self, root_dir, transform_pre=None, transform_post=None,
                 image_size=(1024, 1024), tile_size=(256, 256), max_images=None):
        self.image_dir = os.path.join(root_dir, "images")
        self.mask_dir = os.path.join(root_dir, "masks")
        self.files = sorted([f for f in os.listdir(self.image_dir) if '_pre_disaster.png' in f])
        if max_images:
            self.files = self.files[:max_images]
        self.transform_pre = transform_pre
        self.transform_post = transform_post
        self.image_size = image_size
        self.tile_size = tile_size
        self.tiles_per_image = (image_size[0] // tile_size[0]) * (image_size[1] // tile_size[1])

    def __getitem__(self, idx, return_raw=False):
        pre_file = self.files[idx]
        post_file = pre_file.replace('pre', 'post')
        mask_file = post_file.replace('.png', '_rgb.png')
        pre_path = os.path.join(self.image_dir, pre_file)
        post_path = os.path.join(self.image_dir, post_file)
        mask_path = os.path.join(self.mask_dir, mask_file)

        pre_img_raw = Image.open(pre_path).convert("RGB").resize(self.image_size)
        post_img_raw = Image.open(post_path).convert("RGB").resize(self.image_size)
        mask_img_raw = Image.open(mask_path).convert("RGB").resize(self.image_size, Image.NEAREST)

        def optical_to_sar_like(img):
            img = img.convert('L')
            img = ImageOps.autocontrast(img, cutoff=2)
            return img

        def update_mask_multiclass(mask_rgb_img):
            mask_np = np.array(mask_rgb_img)
            label_mask = np.zeros(mask_np.shape[:2], dtype=np.uint8)
            color_to_label = {
                (0, 0, 0): 0, (0, 255, 255): 0, (0, 0, 255): 1,
                (255, 255, 0): 2, (255, 0, 0): 3, (211, 211, 211): 0
            }
            for rgb, label in color_to_label.items():
                mask = np.all(mask_np == rgb, axis=-1)
                label_mask[mask] = label
            return label_mask

        post_img_sar_raw = optical_to_sar_like(post_img_raw)
        if return_raw:
            mask_np = update_mask_multiclass(mask_img_raw)
            return pre_img_raw, post_img_sar_raw, mask_np

        pre_img = self.transform_pre(pre_img_raw) if self.transform_pre else pre_img_raw
        post_img_sar = self.transform_post(post_img_sar_raw) if self.transform_post else post_img_sar_raw

        input_tensor = torch.cat([pre_img, post_img_sar], dim=0)
        mask_tensor = torch.tensor(update_mask_multiclass(mask_img_raw), dtype=torch.long)

        return tile_tensor_and_mask(input_tensor, mask_tensor, self.tile_size)

    def __len__(self):
        return len(self.files)

    def get_tile_dataset(self):
        return TiledXBDDataset(self)

class TiledXBDDataset(Dataset):
    def __init__(self, parent_dataset):
        self.parent_dataset = parent_dataset
        self.tiles_per_image = parent_dataset.tiles_per_image

    def __getitem__(self, idx):
        image_idx = idx // self.tiles_per_image
        tile_idx = idx % self.tiles_per_image
        tiles_input, tiles_mask = self.parent_dataset[image_idx]
        return tiles_input[tile_idx], tiles_mask[tile_idx]

    def __len__(self):
        return len(self.parent_dataset) * self.tiles_per_image

class FilteredTileDataset(Dataset):
    def __init__(self, tiled_dataset, keep_zero_damage_prob=0.1):
        self.tiled_dataset = tiled_dataset
        self.filtered_indices = []
        for idx in range(len(tiled_dataset)):
            _, mask = tiled_dataset[idx]
            if mask.max() == 0:
                if random.random() < keep_zero_damage_prob:
                    self.filtered_indices.append(idx)
            else:
                self.filtered_indices.append(idx)

    def __len__(self):
        return len(self.filtered_indices)

    def __getitem__(self, idx):
        return self.tiled_dataset[self.filtered_indices[idx]]

from torch.utils.data import Subset

def get_majority_class_per_tile(tile_mask):
    # Ignore padding or empty masks
    return torch.bincount(tile_mask.flatten()).argmax().item()

def stratified_tile_sample(dataset, total_samples=500):
    class_indices = {0: [], 1: [], 2: [], 3: []}

    print("Scanning tile classes for stratified sampling...")
    for idx in tqdm(range(len(dataset))):
        _, mask = dataset[idx]
        majority_class = get_majority_class_per_tile(mask)
        if majority_class in class_indices:
            class_indices[majority_class].append(idx)

    # Count all
    total_tiles = sum(len(v) for v in class_indices.values())
    proportions = {cls: len(v)/total_tiles for cls, v in class_indices.items()}
    print("Class proportions:", proportions)

    # Sample proportional counts
    sampled_indices = []
    for cls in range(4):
        n_samples = int(proportions[cls] * total_samples)
        n_available = len(class_indices[cls])
        chosen = random.sample(class_indices[cls], min(n_samples, n_available))
        sampled_indices.extend(chosen)

    return Subset(dataset, sampled_indices)

# ============================
# Helper Functions
# ============================
def tile_tensor_and_mask(input_tensor, mask_tensor, tile_size=(256, 256)):
    C, H, W = input_tensor.shape
    th, tw = tile_size
    tiles_input, tiles_mask = [], []
    for i in range(0, H, th):
        for j in range(0, W, tw):
            tiles_input.append(input_tensor[:, i:i+th, j:j+tw])
            tiles_mask.append(mask_tensor[i:i+th, j:j+tw])
    return tiles_input, tiles_mask

# ============================
# Visualization
# ============================
def unnormalize(img_tensor, mean, std):
    for t, m, s in zip(img_tensor, mean, std):
        t.mul_(s).add_(m)
    return img_tensor

def visualize_predictions(inputs, masks, outputs, n=4):
    preds = torch.argmax(outputs, dim=1).cpu().numpy()
    masks = masks.cpu().numpy()
    cmap = mcolors.ListedColormap(['black', 'blue', 'yellow', 'red'])
    bounds = [0, 1, 2, 3, 4]
    norm = mcolors.BoundaryNorm(bounds, cmap.N)

    for i in range(min(n, inputs.shape[0])):
        fig, axs = plt.subplots(1, 3, figsize=(12, 4))
        pre_rgb = unnormalize(inputs[i][:3].clone(), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        axs[0].imshow(np.transpose(torch.clamp(pre_rgb, 0, 1).cpu().numpy(), (1, 2, 0)))
        axs[0].set_title("Pre-disaster RGB")
        axs[0].axis('off')
        axs[1].imshow(inputs[i][3].cpu().numpy(), cmap='gray')
        axs[1].set_title("Post-disaster SAR")
        axs[1].axis('off')
        axs[2].imshow(preds[i], cmap=cmap, norm=norm)
        axs[2].set_title("Predicted")
        axs[2].axis('off')
        plt.show()
        plt.figure()
        plt.imshow(masks[i], cmap=cmap, norm=norm)
        plt.title("Ground Truth")
        plt.axis('off')
        plt.show()

print("Saved all")

class UNetOriginal(nn.Module):
    def __init__(self, in_channels=4, out_classes=4):
        super().__init__()

        def conv_block(in_c, out_c):
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True)
            )

        # Encoder
        self.enc1 = conv_block(in_channels, 64)
        self.enc2 = conv_block(64, 128)
        self.enc3 = conv_block(128, 256)
        self.enc4 = conv_block(256, 512)

        self.pool = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = conv_block(512, 1024)

        # Decoder
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = conv_block(1024, 512)

        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = conv_block(512, 256)

        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = conv_block(256, 128)

        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = conv_block(128, 64)

        # Output
        self.final = nn.Conv2d(64, out_classes, kernel_size=1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        b = self.bottleneck(self.pool(e4))

        d4 = self.dec4(torch.cat([self.up4(b), e4], dim=1))
        d3 = self.dec3(torch.cat([self.up3(d4), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))

        return self.final(d1)
model_path = 'C:/Users/sweta/anaconda_projects/non-trivial/performance_bias/iteration_models/model_epoch7.pt'
model = UNetOriginal(in_channels=4, out_classes=4)
model.load_state_dict(torch.load(model_path, map_location='cpu'))
model.eval()
print("model loaded successfully")


from torch.utils.data import DataLoader

# Transforms (same as training)
transform_pre = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

transform_post = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.5], std=[0.5])
])

# Load full xBD dataset
xbd_root = r"C:\Users\sweta\.cache\kagglehub\datasets\qianlanzz\xbd-dataset\versions\1\xbd\tier1"
xbd_dataset = XBDMulticlassDataset(
    xbd_root, 
    transform_pre=transform_pre, 
    transform_post=transform_post,
    image_size=(1024, 1024),
    tile_size=(256, 256)
)

# Convert to per-tile dataset for evaluation
xbd_tile_dataset = xbd_dataset.get_tile_dataset()
xbd_loader = DataLoader(xbd_tile_dataset, batch_size=8, shuffle=False)

def evaluate_model_on_xbd(model, dataloader, device, label, save_root):
    model.eval()
    model.to(device)
    
    results = []
    save_dir = os.path.join(save_root, f"xbd_{label}")
    os.makedirs(save_dir, exist_ok=True)

    with torch.no_grad():
        for i, (x, y) in enumerate(tqdm(dataloader, desc=f"Evaluating on xBD: {label}")):
            x, y = x.to(device), y.to(device)
            y_pred = torch.argmax(model(x), dim=1)

            for j in range(x.size(0)):
                acc = compute_accuracy(y_pred[j], y[j])
                iou = compute_iou(y_pred[j], y[j])
                results.append({
                    "tile_index": i * x.size(0) + j,
                    "accuracy": acc,
                    "iou": iou
                })

    # Save CSV
    df = pd.DataFrame(results)
    df.to_csv(os.path.join(save_dir, f"xbd_{label}_metrics.csv"), index=False)

    # Summary
    print(f"\n[SUMMARY: xBD {label}]")
    print(df.describe())

from torch.utils.data import Subset
from tqdm import tqdm
xbd_tile_dataset = xbd_dataset.get_tile_dataset()
stratified_subset = stratified_tile_sample(xbd_tile_dataset, total_samples=500)
xbd_loader = DataLoader(stratified_subset, batch_size=8, shuffle=False)


#goal: Test how the model behaves differently once it's fine tuned (are there any patterns?) also test how well the different models perform on LIC_test. LIC_test does have a damage mask as somewhat ground truth so use that.
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as T
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import confusion_matrix

# ============ Dataset =============
class LICTestDataset(Dataset):
    def __init__(self, pre_dir, post_dir, mask_dir):
        self.pre_dir = pre_dir
        self.post_dir = post_dir
        self.mask_dir = mask_dir
        self.files = [f for f in os.listdir(pre_dir) if f.endswith(('.png', '.jpg'))]
        self.transform = T.Compose([T.ToTensor()])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        fname = self.files[idx]
        pre_path = os.path.join(self.pre_dir, fname)
        post_path = os.path.join(self.post_dir, fname)
        mask_path = os.path.join(self.mask_dir, fname)

        pre_img = self.transform(Image.open(pre_path).convert("RGB"))
        post_img = self.transform(Image.open(post_path).convert("L"))
        mask = torch.from_numpy(np.array(Image.open(mask_path))).long()

        if post_img.dim() == 2:
            post_img = post_img.unsqueeze(0)
        x = torch.cat([pre_img, post_img], dim=0)
        return x, mask, fname

# ============ Metrics ============
def compute_iou(pred, target, num_classes=4):
    pred = pred.flatten()
    target = target.flatten()
    iou_per_class = []
    for cls in range(num_classes):
        pred_inds = pred == cls
        target_inds = target == cls
        intersection = (pred_inds & target_inds).sum().item()
        union = (pred_inds | target_inds).sum().item()
        if union == 0:
            iou_per_class.append(np.nan)
        else:
            iou_per_class.append(intersection / union)
    return np.nanmean(iou_per_class)

def compute_dice(pred, target, class_id):
    """
    Compute Dice coefficient for a single class.
    pred, target: 1D or 2D torch tensors with class labels.
    class_id: int, the class to compute Dice for.
    """
    pred_inds = (pred == class_id)
    target_inds = (target == class_id)

    intersection = (pred_inds & target_inds).sum().item()
    pred_sum = pred_inds.sum().item()
    target_sum = target_inds.sum().item()

    if pred_sum + target_sum == 0:
        return float('nan')  # no samples in pred or target for this class

    dice = (2 * intersection) / (pred_sum + target_sum)
    return dice

def compute_precision(pred, target, class_id):
    """
    Compute Precision for a single class.
    """
    pred_inds = (pred == class_id)
    target_inds = (target == class_id)

    tp = (pred_inds & target_inds).sum().item()
    fp = (pred_inds & (~target_inds)).sum().item()

    if tp + fp == 0:
        return float('nan')

    precision = tp / (tp + fp)
    return precision

def compute_recall(pred, target, class_id):
    """
    Compute Recall for a single class.
    """
    pred_inds = (pred == class_id)
    target_inds = (target == class_id)

    tp = (pred_inds & target_inds).sum().item()
    fn = ((~pred_inds) & target_inds).sum().item()

    if tp + fn == 0:
        return float('nan')

    recall = tp / (tp + fn)
    return recall
    
def compute_accuracy(pred, target):
    return (pred == target).float().mean().item()

def compute_all_metrics(pred, target, num_classes=4):
    pred = pred.flatten().cpu().numpy()
    target = target.flatten().cpu().numpy()

    metrics = {'iou': [], 'precision': [], 'recall': [], 'dice': []}

    for cls in range(num_classes):
        pred_inds = pred == cls
        target_inds = target == cls

        intersection = np.logical_and(pred_inds, target_inds).sum()
        union = np.logical_or(pred_inds, target_inds).sum()
        tp = intersection
        fp = np.logical_and(pred_inds, ~target_inds).sum()
        fn = np.logical_and(~pred_inds, target_inds).sum()

        iou = intersection / union if union > 0 else np.nan
        precision = tp / (tp + fp) if (tp + fp) > 0 else np.nan
        recall = tp / (tp + fn) if (tp + fn) > 0 else np.nan
        dice = (2 * tp) / (2 * tp + fp + fn) if (2 * tp + fp + fn) > 0 else np.nan

        metrics['iou'].append(iou)
        metrics['precision'].append(precision)
        metrics['recall'].append(recall)
        metrics['dice'].append(dice)

    return metrics

def compute_accuracy(pred, target):
    return (pred == target).float().mean().item()

# ============ Visualization ============
def save_visualization(pre_img, post_img, pred_mask, true_mask, fname, save_dir):
    fig, axs = plt.subplots(1, 4, figsize=(16, 4))

    axs[0].imshow(pre_img.permute(1, 2, 0).cpu())
    axs[0].set_title("Pre-disaster")
    axs[0].axis("off")

    axs[1].imshow(post_img.squeeze(0).cpu(), cmap='gray')
    axs[1].set_title("Post SAR")
    axs[1].axis("off")

    axs[2].imshow(true_mask.cpu(), cmap='tab10', vmin=0, vmax=3)
    axs[2].set_title("Ground Truth")
    axs[2].axis("off")

    axs[3].imshow(pred_mask.cpu(), cmap='tab10', vmin=0, vmax=3)
    axs[3].set_title("Prediction")
    axs[3].axis("off")

    os.makedirs(save_dir, exist_ok=True)
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, fname.replace('.png', '_viz.png')))
    plt.close()

# ============ Evaluation ============
def evaluate_and_visualize(model, dataloader, device, label, save_root, num_classes=4):
    model.eval()
    model.to(device)

    results = []
    per_class_metrics = {cls: {'iou': [], 'precision': [], 'recall': [], 'dice': []} for cls in range(num_classes)}
    vis_dir = os.path.join(save_root, label)
    os.makedirs(vis_dir, exist_ok=True)

    with torch.no_grad():
        for x, y, fname in tqdm(dataloader, desc=f"Evaluating {label}"):
            x, y = x.to(device), y.to(device)
            y_pred = torch.argmax(model(x), dim=1)

            for i in range(len(y)):
                fname_i = fname[i]
                acc = compute_accuracy(y_pred[i], y[i])
                metrics = compute_all_metrics(y_pred[i], y[i], num_classes=num_classes)

                for cls in range(num_classes):
                    for metric in ['iou', 'precision', 'recall', 'dice']:
                        per_class_metrics[cls][metric].append(metrics[metric][cls])

                results.append({
                    "filename": fname_i,
                    "accuracy": acc,
                    "mean_iou": np.nanmean(metrics['iou']),
                    "mean_precision": np.nanmean(metrics['precision']),
                    "mean_recall": np.nanmean(metrics['recall']),
                    "mean_dice": np.nanmean(metrics['dice']),
                })

                save_visualization(x[i, :3], x[i, 3:], y_pred[i], y[i], fname_i, vis_dir)

    # Save metrics
    df = pd.DataFrame(results)
    df.to_csv(os.path.join(vis_dir, f"{label}_metrics.csv"), index=False)

    # Print overall summary
    print(f"\n[SUMMARY: {label}]")
    print(df.describe())

    # Per-class summary
    print("\n[Per-Class Metrics]")
    for cls in range(num_classes):
        print(f"Class {cls}:")
        for metric in ['iou', 'precision', 'recall', 'dice']:
            values = np.array(per_class_metrics[cls][metric])
            print(f"  {metric:>9}: mean={np.nanmean(values):.4f}, std={np.nanstd(values):.4f}")
    
# ============ Load and Run ============

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_model(path):
    model = UNetOriginal(in_channels=4, out_classes=4)
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint)
    return model

# Paths
base_dir = r"C:\Users\sweta\anaconda_projects\non-trivial\performance_bias"
save_root = os.path.join(base_dir, "LIC_pseudo", "results")

baseline_path = os.path.join(base_dir, "iteration_models", "model_epoch8.pt")
half_path = os.path.join(base_dir, "iteration_models", "lic_half_finetunedmodel.pt")
full_path = os.path.join(base_dir, "iteration_models", "lic_full_finetunedmodel.pt")

# Load models
baseline = load_model(baseline_path)
half = load_model(half_path)
full = load_model(full_path)

# Dataset
test_dir = os.path.join(base_dir, "LIC_pseudo", "test")
test_dataset = LICTestDataset(os.path.join(test_dir, "pre"),
                              os.path.join(test_dir, "post"),
                              os.path.join(test_dir, "mask"))
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)

# Evaluate all 3
bright_save_root = os.path.join(base_dir, "LIC_pseudo", "results", "bright")
evaluate_and_visualize(baseline, test_loader, device, label="baseline_updated", save_root=bright_save_root)
evaluate_and_visualize(half, test_loader, device, label="lic_half_updated", save_root=bright_save_root)
evaluate_and_visualize(full, test_loader, device, label="lic_full_updated", save_root=bright_save_root)
xbd_save_root = os.path.join(base_dir, "LIC_pseudo", "results", "xbd")
evaluate_model_on_xbd(baseline, xbd_loader, device, label="baseline_updated", save_root=xbd_save_root)
evaluate_model_on_xbd(half, xbd_loader, device, label="lic_half_updated", save_root=xbd_save_root)
evaluate_model_on_xbd(full, xbd_loader, device, label="lic_full_updated", save_root=xbd_save_root)

print("LIC COMPLETE") #we move on to mic
# ============ Dataset =============
class MICTestDataset(Dataset):
    def __init__(self, pre_dir, post_dir, mask_dir):
        self.pre_dir = pre_dir
        self.post_dir = post_dir
        self.mask_dir = mask_dir
        self.files = [f for f in os.listdir(pre_dir) if f.endswith(('.png', '.jpg'))]
        self.transform = T.Compose([T.ToTensor()])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        fname = self.files[idx]
        pre_path = os.path.join(self.pre_dir, fname)
        post_path = os.path.join(self.post_dir, fname)
        mask_path = os.path.join(self.mask_dir, fname)

        pre_img = self.transform(Image.open(pre_path).convert("RGB"))
        post_img = self.transform(Image.open(post_path).convert("L"))
        mask = torch.from_numpy(np.array(Image.open(mask_path))).long()

        if post_img.dim() == 2:
            post_img = post_img.unsqueeze(0)
        x = torch.cat([pre_img, post_img], dim=0)
        return x, mask, fname


base_dir = r"C:\Users\sweta\anaconda_projects\non-trivial\performance_bias"
mic_base = os.path.join(base_dir, "MIC_pseudo")
save_root = os.path.join(mic_base, "results")

baseline_path = os.path.join(base_dir, "iteration_models", "model_epoch8.pt")
half_path = os.path.join(base_dir, "iteration_models", "mic_half_finetunedmodel.pt")
full_path = os.path.join(base_dir, "iteration_models", "mic_full_finetunedmodel.pt")

# Load models
baseline = load_model(baseline_path)
half = load_model(half_path)
full = load_model(full_path)

# MIC Dataset
test_dir = os.path.join(mic_base, "test")
test_dataset = MICTestDataset(os.path.join(test_dir, "pre"),
                              os.path.join(test_dir, "post"),
                              os.path.join(test_dir, "mask"))
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)

# Evaluate all 3 models on MIC test set
bright_save_root = os.path.join(mic_base, "results", "bright")
evaluate_and_visualize(baseline, test_loader, device, label="baseline_updated", save_root=bright_save_root)
evaluate_and_visualize(half, test_loader, device, label="mic_half_updated", save_root=bright_save_root)
evaluate_and_visualize(full, test_loader, device, label="mic_full_updated", save_root=bright_save_root)

# Optional: Evaluate on MIC-style xBD data (if xbd_loader is defined)
mic_xbd_save_root = os.path.join(mic_base, "results", "xbd")
evaluate_model_on_xbd(baseline, xbd_loader, device, label="baseline_updated", save_root=mic_xbd_save_root)
evaluate_model_on_xbd(half, xbd_loader, device, label="mic_half_updated", save_root=mic_xbd_save_root)
evaluate_model_on_xbd(full, xbd_loader, device, label="mic_full_updated", save_root=mic_xbd_save_root)

print("DONE WITH MIC") #move on to hic
class HICTestDataset(Dataset):
    def __init__(self, pre_dir, post_dir, mask_dir):
        self.pre_dir = pre_dir
        self.post_dir = post_dir
        self.mask_dir = mask_dir
        self.files = [f for f in os.listdir(pre_dir) if f.endswith(('.png', '.jpg'))]
        self.transform = T.Compose([T.ToTensor()])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        fname = self.files[idx]
        pre_path = os.path.join(self.pre_dir, fname)
        post_path = os.path.join(self.post_dir, fname)
        mask_path = os.path.join(self.mask_dir, fname)

        pre_img = self.transform(Image.open(pre_path).convert("RGB"))
        post_img = self.transform(Image.open(post_path).convert("L"))
        mask = torch.from_numpy(np.array(Image.open(mask_path))).long()

        if post_img.dim() == 2:
            post_img = post_img.unsqueeze(0)
        x = torch.cat([pre_img, post_img], dim=0)
        return x, mask, fname


# Paths
base_dir = r"C:\Users\sweta\anaconda_projects\non-trivial\performance_bias"
hic_base = os.path.join(base_dir, "HIC_pseudo")
save_root = os.path.join(hic_base, "results")

baseline_path = os.path.join(base_dir, "iteration_models", "model_epoch8.pt")
half_path = os.path.join(base_dir, "iteration_models", "hic_half_finetunedmodel.pt")
full_path = os.path.join(base_dir, "iteration_models", "hic_full_finetunedmodel.pt")

# Load models
baseline = load_model(baseline_path)
half = load_model(half_path)
full = load_model(full_path)

# HIC Dataset
test_dir = os.path.join(hic_base, "test")
test_dataset = HICTestDataset(os.path.join(test_dir, "pre"),
                              os.path.join(test_dir, "post"),
                              os.path.join(test_dir, "mask"))
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)

# Evaluate all 3 models on HIC test set
bright_save_root = os.path.join(hic_base, "results", "bright")
evaluate_and_visualize(baseline, test_loader, device, label="baseline_updated", save_root=bright_save_root)
evaluate_and_visualize(half, test_loader, device, label="hic_half_updated", save_root=bright_save_root)
evaluate_and_visualize(full, test_loader, device, label="hic_full_updated", save_root=bright_save_root)

# Optional: Evaluate on HIC-style xBD data (if xbd_loader is defined)
hic_xbd_save_root = os.path.join(hic_base, "results", "xbd")
evaluate_model_on_xbd(baseline, xbd_loader, device, label="baseline_updated", save_root=hic_xbd_save_root)
evaluate_model_on_xbd(half, xbd_loader, device, label="hic_half_updated", save_root=hic_xbd_save_root)
evaluate_model_on_xbd(full, xbd_loader, device, label="hic_full_updated", save_root=hic_xbd_save_root)

In [3]:
#gem ver
# ============================
# Imports
# ============================
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
import torch.optim as optim
from torchvision import transforms as T
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
import random
from PIL import Image, ImageOps
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import confusion_matrix

print("All libraries imported successfully.")

# ============================
# Model Definition (U-Net)
# ============================
class UNetOriginal(nn.Module):
    def __init__(self, in_channels=4, out_classes=4):
        super().__init__()

        def conv_block(in_c, out_c):
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True)
            )

        # Encoder
        self.enc1 = conv_block(in_channels, 64)
        self.enc2 = conv_block(64, 128)
        self.enc3 = conv_block(128, 256)
        self.enc4 = conv_block(256, 512)
        self.pool = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = conv_block(512, 1024)

        # Decoder
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = conv_block(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = conv_block(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = conv_block(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = conv_block(128, 64)

        # Output
        self.final = nn.Conv2d(64, out_classes, kernel_size=1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        b = self.bottleneck(self.pool(e4))

        d4 = self.dec4(torch.cat([self.up4(b), e4], dim=1))
        d3 = self.dec3(torch.cat([self.up3(d4), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))

        return self.final(d1)

# ============================
# Helper & Metric Functions
# ============================
def tile_tensor_and_mask(input_tensor, mask_tensor, tile_size=(256, 256)):
    C, H, W = input_tensor.shape
    th, tw = tile_size
    tiles_input, tiles_mask = [], []
    for i in range(0, H, th):
        for j in range(0, W, tw):
            tiles_input.append(input_tensor[:, i:i+th, j:j+tw])
            tiles_mask.append(mask_tensor[i:i+th, j:j+tw])
    return tiles_input, tiles_mask

def compute_accuracy(pred, target):
    return (pred == target).float().mean().item()

def compute_all_metrics(pred, target, num_classes=4):
    pred_flat = pred.flatten().cpu().numpy()
    target_flat = target.flatten().cpu().numpy()
    metrics = {'iou': [], 'precision': [], 'recall': [], 'dice': []}

    for cls in range(num_classes):
        pred_inds = pred_flat == cls
        target_inds = target_flat == cls
        
        intersection = np.logical_and(pred_inds, target_inds).sum()
        union = np.logical_or(pred_inds, target_inds).sum()
        tp = intersection
        fp = np.logical_and(pred_inds, ~target_inds).sum()
        fn = np.logical_and(~pred_inds, target_inds).sum()

        iou = tp / union if union > 0 else np.nan
        precision = tp / (tp + fp) if (tp + fp) > 0 else np.nan
        recall = tp / (tp + fn) if (tp + fn) > 0 else np.nan
        dice = (2 * tp) / (2 * tp + fp + fn) if (2 * tp + fp + fn) > 0 else np.nan

        metrics['iou'].append(iou)
        metrics['precision'].append(precision)
        metrics['recall'].append(recall)
        metrics['dice'].append(dice)
        
    return metrics

# ============================
# Dataset Classes
# ============================
class XBDMulticlassDataset(Dataset):
    def __init__(self, root_dir, transform_pre=None, transform_post=None,
                 image_size=(1024, 1024), tile_size=(256, 256), max_images=None):
        self.image_dir = os.path.join(root_dir, "images")
        self.mask_dir = os.path.join(root_dir, "masks")
        self.files = sorted([f for f in os.listdir(self.image_dir) if '_pre_disaster.png' in f])
        if max_images:
            self.files = self.files[:max_images]
        # DEBUG: Confirm number of images found
        print(f"[XBDDataset] Found {len(self.files)} pre-disaster images in {self.image_dir}")
        self.transform_pre = transform_pre
        self.transform_post = transform_post
        self.image_size = image_size
        self.tile_size = tile_size
        self.tiles_per_image = (image_size[0] // tile_size[0]) * (image_size[1] // tile_size[1])

    def __getitem__(self, idx):
        pre_file = self.files[idx]
        post_file = pre_file.replace('pre', 'post')
        mask_file = post_file.replace('.png', '_rgb.png')
        pre_path = os.path.join(self.image_dir, pre_file)
        post_path = os.path.join(self.image_dir, post_file)
        mask_path = os.path.join(self.mask_dir, mask_file)

        pre_img_raw = Image.open(pre_path).convert("RGB").resize(self.image_size)
        post_img_raw = Image.open(post_path).convert("RGB").resize(self.image_size)
        mask_img_raw = Image.open(mask_path).convert("RGB").resize(self.image_size, Image.NEAREST)

        def optical_to_sar_like(img):
            img = img.convert('L')
            img = ImageOps.autocontrast(img, cutoff=2)
            return img

        def update_mask_multiclass(mask_rgb_img):
            mask_np = np.array(mask_rgb_img)
            label_mask = np.zeros(mask_np.shape[:2], dtype=np.uint8)
            color_to_label = {
                (0, 0, 0): 0, (0, 255, 255): 0, (0, 0, 255): 1,
                (255, 255, 0): 2, (255, 0, 0): 3, (211, 211, 211): 0
            }
            for rgb, label in color_to_label.items():
                mask = np.all(mask_np == rgb, axis=-1)
                label_mask[mask] = label
            return label_mask

        post_img_sar_raw = optical_to_sar_like(post_img_raw)
        pre_img = self.transform_pre(pre_img_raw) if self.transform_pre else pre_img_raw
        post_img_sar = self.transform_post(post_img_sar_raw) if self.transform_post else post_img_sar_raw

        input_tensor = torch.cat([pre_img, post_img_sar], dim=0)
        mask_tensor = torch.tensor(update_mask_multiclass(mask_img_raw), dtype=torch.long)

        return tile_tensor_and_mask(input_tensor, mask_tensor, self.tile_size)

    def __len__(self):
        return len(self.files)

    def get_tile_dataset(self):
        return TiledXBDDataset(self)

class TiledXBDDataset(Dataset):
    def __init__(self, parent_dataset):
        self.parent_dataset = parent_dataset
        self.tiles_per_image = parent_dataset.tiles_per_image

    def __getitem__(self, idx):
        image_idx = idx // self.tiles_per_image
        tile_idx = idx % self.tiles_per_image
        tiles_input, tiles_mask = self.parent_dataset[image_idx]
        return tiles_input[tile_idx], tiles_mask[tile_idx]

    def __len__(self):
        return len(self.parent_dataset) * self.tiles_per_image

# Note: The three test datasets are identical. They could be refactored into one class.
class BaseTestDataset(Dataset):
    def __init__(self, pre_dir, post_dir, mask_dir, dataset_name="Test"):
        self.pre_dir = pre_dir
        self.post_dir = post_dir
        self.mask_dir = mask_dir
        self.files = [f for f in os.listdir(pre_dir) if f.endswith(('.png', '.jpg'))]
        # DEBUG: Confirm number of images found
        print(f"[{dataset_name} Dataset] Found {len(self.files)} images in {pre_dir}")
        self.transform = T.Compose([T.ToTensor()])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        fname = self.files[idx]
        pre_path = os.path.join(self.pre_dir, fname)
        post_path = os.path.join(self.post_dir, fname)
        mask_path = os.path.join(self.mask_dir, fname)

        pre_img = self.transform(Image.open(pre_path).convert("RGB"))
        post_img = self.transform(Image.open(post_path).convert("L"))
        mask = torch.from_numpy(np.array(Image.open(mask_path))).long()

        if post_img.dim() == 2:
            post_img = post_img.unsqueeze(0)
        
        x = torch.cat([pre_img, post_img], dim=0)
        
        # DEBUG: Check tensor shapes for the first item
        if idx == 0:
            print(f"[{self.__class__.__name__}] Sample 0 shapes - Input: {x.shape}, Mask: {mask.shape}")
            
        return x, mask, fname

class LICTestDataset(BaseTestDataset):
    def __init__(self, pre_dir, post_dir, mask_dir):
        super().__init__(pre_dir, post_dir, mask_dir, dataset_name="LIC Test")

class MICTestDataset(BaseTestDataset):
    def __init__(self, pre_dir, post_dir, mask_dir):
        super().__init__(pre_dir, post_dir, mask_dir, dataset_name="MIC Test")

class HICTestDataset(BaseTestDataset):
    def __init__(self, pre_dir, post_dir, mask_dir):
        super().__init__(pre_dir, post_dir, mask_dir, dataset_name="HIC Test")
from torch.utils.data import Subset

def get_majority_class_per_tile(tile_mask):
    # Ignore padding or empty masks
    return torch.bincount(tile_mask.flatten()).argmax().item()

def stratified_tile_sample(dataset, total_samples=500):
    class_indices = {0: [], 1: [], 2: [], 3: []}

    print("Scanning tile classes for stratified sampling...")
    for idx in tqdm(range(len(dataset))):
        _, mask = dataset[idx]
        majority_class = get_majority_class_per_tile(mask)
        if majority_class in class_indices:
            class_indices[majority_class].append(idx)

    # Count all
    total_tiles = sum(len(v) for v in class_indices.values())
    proportions = {cls: len(v)/total_tiles for cls, v in class_indices.items()}
    print("Class proportions:", proportions)

    # Sample proportional counts
    sampled_indices = []
    for cls in range(4):
        n_samples = int(proportions[cls] * total_samples)
        n_available = len(class_indices[cls])
        chosen = random.sample(class_indices[cls], min(n_samples, n_available))
        sampled_indices.extend(chosen)

    return Subset(dataset, sampled_indices)

# ============================
# Evaluation & Visualization
# ============================
def save_visualization(pre_img, post_img, pred_mask, true_mask, fname, save_dir):
    fig, axs = plt.subplots(1, 4, figsize=(16, 4))
    cmap = mcolors.ListedColormap(['black', 'blue', 'yellow', 'red'])
    norm = mcolors.BoundaryNorm(np.arange(-0.5, 4, 1), cmap.N)

    axs[0].imshow(pre_img.permute(1, 2, 0).cpu().numpy())
    axs[0].set_title("Pre-disaster")
    axs[0].axis("off")

    axs[1].imshow(post_img.squeeze(0).cpu().numpy(), cmap='gray')
    axs[1].set_title("Post SAR")
    axs[1].axis("off")

    axs[2].imshow(true_mask.cpu().numpy(), cmap=cmap, norm=norm)
    axs[2].set_title("Ground Truth")
    axs[2].axis("off")

    axs[3].imshow(pred_mask.cpu().numpy(), cmap=cmap, norm=norm)
    axs[3].set_title("Prediction")
    axs[3].axis("off")

    os.makedirs(save_dir, exist_ok=True)
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, fname.replace('.png', '_viz.png')))
    plt.close()

def evaluate_and_visualize(model, dataloader, device, label, save_root, num_classes=4):
    print(f"\n{'='*20}\n[START EVAL] Model: '{label}' on custom test set.\n{'='*20}")
    model.eval()
    model.to(device)

    results = []
    per_class_metrics = {cls: {m: [] for m in ['iou', 'precision', 'recall', 'dice']} for cls in range(num_classes)}
    vis_dir = os.path.join(save_root, label)
    os.makedirs(vis_dir, exist_ok=True)
    # DEBUG: Show where results will be saved
    print(f"Saving visualizations and metrics to: {vis_dir}")

    with torch.no_grad():
        for i, (x, y, fname) in enumerate(tqdm(dataloader, desc=f"Evaluating {label}")):
            x, y = x.to(device), y.to(device)
            y_pred = torch.argmax(model(x), dim=1)
            
            # DEBUG: Check tensor shapes inside the loop (only for the first batch)
            if i == 0:
                print(f"  Batch 0 shapes -> x: {x.shape}, y: {y.shape}, y_pred: {y_pred.shape}")

            for j in range(len(y)):
                fname_i = fname[j]
                pred_j, y_j = y_pred[j], y[j]
                
                acc = compute_accuracy(pred_j, y_j)
                metrics = compute_all_metrics(pred_j, y_j, num_classes=num_classes)

                for cls in range(num_classes):
                    for metric_name in per_class_metrics[cls]:
                        per_class_metrics[cls][metric_name].append(metrics[metric_name][cls])

                results.append({
                    "filename": fname_i, "accuracy": acc,
                    "mean_iou": np.nanmean(metrics['iou']),
                    "mean_precision": np.nanmean(metrics['precision']),
                    "mean_recall": np.nanmean(metrics['recall']),
                    "mean_dice": np.nanmean(metrics['dice']),
                })
                save_visualization(x[j, :3], x[j, 3:], pred_j, y_j, fname_i, vis_dir)

    df = pd.DataFrame(results)
    csv_path = os.path.join(vis_dir, f"{label}_metrics.csv")
    df.to_csv(csv_path, index=False)
    # DEBUG: Confirm CSV saved
    print(f"Metrics saved to {csv_path}")

    print(f"\n[SUMMARY: {label}]")
    print(df.describe())
    print("\n[Per-Class Metrics]")
    for cls in range(num_classes):
        print(f"Class {cls}:")
        for metric in ['iou', 'precision', 'recall', 'dice']:
            values = np.array(per_class_metrics[cls][metric])
            print(f"  {metric:>9}: mean={np.nanmean(values):.4f}, std={np.nanstd(values):.4f}")
    print(f"[END EVAL] Model: '{label}'.\n{'='*20}")

def evaluate_model_on_xbd(model, dataloader, device, label, save_root):
    print(f"\n{'='*20}\n[START EVAL] Model: '{label}' on xBD tiles.\n{'='*20}")
    model.eval()
    model.to(device)
    
    results = []
    save_dir = os.path.join(save_root, label)
    os.makedirs(save_dir, exist_ok=True)
    # DEBUG: Show where results will be saved
    print(f"Saving xBD metrics to: {save_dir}")
    
    with torch.no_grad():
        for i, (x, y) in enumerate(tqdm(dataloader, desc=f"Evaluating on xBD: {label}")):
            x, y = x.to(device), y.to(device)
            y_pred = torch.argmax(model(x), dim=1)

            for j in range(x.size(0)):
                pred_j, y_j = y_pred[j], y[j]
                metrics = compute_all_metrics(pred_j, y_j, num_classes=4)
                results.append({
                    "tile_index": i * x.size(0) + j,
                    "accuracy": compute_accuracy(pred_j, y_j),
                    "mean_iou": np.nanmean(metrics['iou']),
                    "mean_precision": np.nanmean(metrics['precision']),
                    "mean_recall": np.nanmean(metrics['recall']),
                    "mean_dice": np.nanmean(metrics['dice']),
})

    df = pd.DataFrame(results)
    csv_path = os.path.join(save_dir, f"xbd_{label}_metrics.csv")
    df.to_csv(csv_path, index=False)
    # DEBUG: Confirm CSV saved
    print(f"xBD metrics saved to {csv_path}")

    print(f"\n[SUMMARY: xBD {label}]")
    print(df.describe())
    print(f"[END EVAL] xBD Model: '{label}'.\n{'='*20}")

# ============================
# Main Execution
# ============================
if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # --- Step 1: Define Paths ---
    base_dir = r"C:\Users\sweta\anaconda_projects\non-trivial\performance_bias"
    
    # Model paths
    iteration_models_dir = os.path.join(base_dir, "iteration_models")
    baseline_path = os.path.join(iteration_models_dir, "model_epoch8.pt")
    
    lic_half_path = os.path.join(iteration_models_dir, "lic_half_finetunedmodel.pt")
    lic_full_path = os.path.join(iteration_models_dir, "lic_full_finetunedmodel.pt")
    
    mic_half_path = os.path.join(iteration_models_dir, "mic_half_finetunedmodel.pt")
    mic_full_path = os.path.join(iteration_models_dir, "mic_full_finetunedmodel.pt")
    
    hic_half_path = os.path.join(iteration_models_dir, "hic_half_finetunedmodel.pt")
    hic_full_path = os.path.join(iteration_models_dir, "hic_full_finetunedmodel.pt")
    
    # Data paths
    xbd_root = r"C:\Users\sweta\.cache\kagglehub\datasets\qianlanzz\xbd-dataset\versions\1\xbd\tier1"
    lic_test_dir = os.path.join(base_dir, "LIC_pseudo", "test")
    mic_test_dir = os.path.join(base_dir, "MIC_pseudo", "test")
    hic_test_dir = os.path.join(base_dir, "HIC_pseudo", "test")

    # Results paths
    lic_results_root = os.path.join(base_dir, "LIC_pseudo", "results")
    mic_results_root = os.path.join(base_dir, "MIC_pseudo", "results")
    hic_results_root = os.path.join(base_dir, "HIC_pseudo", "results")

    # --- Step 2: Load Models ---
    def load_model_from_path(path):
        print(f"Loading model from: {path}")
        model = UNetOriginal(in_channels=4, out_classes=4)
        if not os.path.exists(path):
            print(f"  [ERROR] Model path does not exist: {path}")
            return None
        checkpoint = torch.load(path, map_location=device)
        # Handle both dict and checkpoint objects
        if 'model_state_dict' in checkpoint:
            # DEBUG
            print("  Loading from a checkpoint object...")
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            # DEBUG
            print("  Loading from a raw state dictionary...")
            model.load_state_dict(checkpoint)
        return model

    models_to_test = {
        "LIC": {
            "baseline": load_model_from_path(baseline_path),
            "half_finetuned": load_model_from_path(lic_half_path),
            "full_finetuned": load_model_from_path(lic_full_path),
        },
        "MIC": {
            "baseline": load_model_from_path(baseline_path),
            "half_finetuned": load_model_from_path(mic_half_path),
            "full_finetuned": load_model_from_path(mic_full_path),
        },
        "HIC": {
            "baseline": load_model_from_path(baseline_path),
            "half_finetuned": load_model_from_path(hic_half_path),
            "full_finetuned": load_model_from_path(hic_full_path),
        }
    }
    
    # --- Step 3: Prepare DataLoaders ---
    # Create xBD loader (used for all evaluations)
    transform_pre = T.Compose([T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    transform_post = T.Compose([T.ToTensor(), T.Normalize(mean=[0.5], std=[0.5])])
    from torch.utils.data import Subset
    from tqdm import tqdm
    
    xbd_tile_dataset = xbd_dataset.get_tile_dataset()
    stratified_subset = stratified_tile_sample(xbd_tile_dataset, total_samples=500)
    xbd_loader = DataLoader(stratified_subset, batch_size=8, shuffle=False)


    # Create LIC, MIC, HIC loaders
    lic_test_dataset = LICTestDataset(os.path.join(lic_test_dir, "pre"), os.path.join(lic_test_dir, "post"), os.path.join(lic_test_dir, "mask"))
    lic_loader = DataLoader(lic_test_dataset, batch_size=2, shuffle=False)
    
    mic_test_dataset = MICTestDataset(os.path.join(mic_test_dir, "pre"), os.path.join(mic_test_dir, "post"), os.path.join(mic_test_dir, "mask"))
    mic_loader = DataLoader(mic_test_dataset, batch_size=2, shuffle=False)

    hic_test_dataset = HICTestDataset(os.path.join(hic_test_dir, "pre"), os.path.join(hic_test_dir, "post"), os.path.join(hic_test_dir, "mask"))
    hic_loader = DataLoader(hic_test_dataset, batch_size=2, shuffle=False)

    # --- Step 4: Run Evaluations ---
    
    # Evaluate on LIC
    print("\n\n--- Starting LIC Evaluation ---")
    for name, model in models_to_test["LIC"].items():
        if model:
            evaluate_and_visualize(model, lic_loader, device, label=f"lic_{name}_updated", save_root=lic_results_root)
            evaluate_model_on_xbd(model, xbd_loader, device, label=f"lic_{name}_updated", save_root=os.path.join(lic_results_root, "xbd_eval"))

    # Evaluate on MIC
    print("\n\n--- Starting MIC Evaluation ---")
    for name, model in models_to_test["MIC"].items():
        if model:
            evaluate_and_visualize(model, mic_loader, device, label=f"mic_{name}_updated", save_root=mic_results_root)
            evaluate_model_on_xbd(model, xbd_loader, device, label=f"mic_{name}_updated", save_root=os.path.join(mic_results_root, "xbd_eval"))
            
    # Evaluate on HIC
    print("\n\n--- Starting HIC Evaluation ---")
    for name, model in models_to_test["HIC"].items():
        if model:
            evaluate_and_visualize(model, hic_loader, device, label=f"hic_{name}_updated", save_root=hic_results_root)
            evaluate_model_on_xbd(model, xbd_loader, device, label=f"hic_{name}_updated", save_root=os.path.join(hic_results_root, "xbd_eval"))

    print("\n\nAll evaluations complete!")

All libraries imported successfully.
Using device: cpu
Loading model from: C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\iteration_models\model_epoch8.pt


  checkpoint = torch.load(path, map_location=device)


  Loading from a checkpoint object...
Loading model from: C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\iteration_models\lic_half_finetunedmodel.pt
  Loading from a raw state dictionary...
Loading model from: C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\iteration_models\lic_full_finetunedmodel.pt
  Loading from a raw state dictionary...
Loading model from: C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\iteration_models\model_epoch8.pt
  Loading from a checkpoint object...
Loading model from: C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\iteration_models\mic_half_finetunedmodel.pt
  Loading from a raw state dictionary...
Loading model from: C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\iteration_models\mic_full_finetunedmodel.pt
  Loading from a raw state dictionary...
Loading model from: C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\iteration_models\model_epoch8.pt
  Loading from a checkpoin

100%|██████████| 44784/44784 [4:19:40<00:00,  2.87it/s]  


Class proportions: {0: 0.9993971061093248, 1: 0.0, 2: 0.00040192926045016077, 3: 0.00020096463022508038}
[LIC Test Dataset] Found 59 images in C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\LIC_pseudo\test\pre
[MIC Test Dataset] Found 542 images in C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\MIC_pseudo\test\pre
[HIC Test Dataset] Found 324 images in C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\HIC_pseudo\test\pre


--- Starting LIC Evaluation ---

[START EVAL] Model: 'lic_baseline_updated' on custom test set.
Saving visualizations and metrics to: C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\LIC_pseudo\results\lic_baseline_updated


Evaluating lic_baseline_updated:   0%|          | 0/30 [00:00<?, ?it/s]

[LICTestDataset] Sample 0 shapes - Input: torch.Size([4, 256, 256]), Mask: torch.Size([256, 256])
  Batch 0 shapes -> x: torch.Size([2, 4, 256, 256]), y: torch.Size([2, 256, 256]), y_pred: torch.Size([2, 256, 256])


Evaluating lic_baseline_updated: 100%|██████████| 30/30 [01:31<00:00,  3.05s/it]
  print(f"  {metric:>9}: mean={np.nanmean(values):.4f}, std={np.nanstd(values):.4f}")
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,


Metrics saved to C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\LIC_pseudo\results\lic_baseline_updated\lic_baseline_updated_metrics.csv

[SUMMARY: lic_baseline_updated]
        accuracy   mean_iou  mean_precision  mean_recall  mean_dice
count  59.000000  59.000000       59.000000    59.000000  59.000000
mean    0.283409   0.071592        0.283409     0.251412   0.104448
std     0.183508   0.047194        0.183508     0.010849   0.051990
min     0.051620   0.012905        0.051620     0.250000   0.024543
25%     0.162010   0.040503        0.162010     0.250000   0.069696
50%     0.242752   0.060688        0.242752     0.250000   0.097667
75%     0.361473   0.090368        0.361473     0.250000   0.132749
max     0.938889   0.234722        0.938889     0.333333   0.242120

[Per-Class Metrics]
Class 0:
        iou: mean=0.2834, std=0.1819
  precision: mean=0.2834, std=0.1819
     recall: mean=1.0000, std=0.0000
       dice: mean=0.4139, std=0.1988
Class 1:
        iou: mea

Evaluating on xBD: lic_baseline_updated: 100%|██████████| 63/63 [09:51<00:00,  9.39s/it]


xBD metrics saved to C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\LIC_pseudo\results\xbd_eval\lic_baseline_updated\xbd_lic_baseline_updated_metrics.csv

[SUMMARY: xBD lic_baseline_updated]
       tile_index    accuracy    mean_iou
count  499.000000  499.000000  499.000000
mean   247.136273    0.993480    0.927033
std    142.971455    0.028016    0.180515
min      0.000000    0.638916    0.301753
25%    124.500000    1.000000    1.000000
50%    246.000000    1.000000    1.000000
75%    370.500000    1.000000    1.000000
max    495.000000    1.000000    1.000000
[END EVAL] xBD Model: 'lic_baseline_updated'.

[START EVAL] Model: 'lic_half_finetuned_updated' on custom test set.
Saving visualizations and metrics to: C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\LIC_pseudo\results\lic_half_finetuned_updated


Evaluating lic_half_finetuned_updated:   0%|          | 0/30 [00:00<?, ?it/s]

[LICTestDataset] Sample 0 shapes - Input: torch.Size([4, 256, 256]), Mask: torch.Size([256, 256])
  Batch 0 shapes -> x: torch.Size([2, 4, 256, 256]), y: torch.Size([2, 256, 256]), y_pred: torch.Size([2, 256, 256])


Evaluating lic_half_finetuned_updated: 100%|██████████| 30/30 [01:27<00:00,  2.90s/it]
  print(f"  {metric:>9}: mean={np.nanmean(values):.4f}, std={np.nanstd(values):.4f}")
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,


Metrics saved to C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\LIC_pseudo\results\lic_half_finetuned_updated\lic_half_finetuned_updated_metrics.csv

[SUMMARY: lic_half_finetuned_updated]
        accuracy   mean_iou  mean_precision  mean_recall  mean_dice
count  59.000000  59.000000       59.000000    59.000000  59.000000
mean    0.294068   0.094204        0.275408     0.266111   0.142525
std     0.188199   0.065494        0.143277     0.036655   0.079478
min     0.054626   0.013865        0.057266     0.223151   0.026316
25%     0.190033   0.056984        0.175018     0.249620   0.098921
50%     0.252228   0.081585        0.232745     0.255830   0.133875
75%     0.359863   0.114970        0.373167     0.267138   0.173821
max     0.944916   0.336438        0.623771     0.480077   0.440837

[Per-Class Metrics]
Class 0:
        iou: mean=0.2895, std=0.1939
  precision: mean=0.2935, std=0.1949
     recall: mean=0.9445, std=0.0645
       dice: mean=0.4183, std=0.2076
Class 1

Evaluating on xBD: lic_half_finetuned_updated: 100%|██████████| 63/63 [09:40<00:00,  9.22s/it]


xBD metrics saved to C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\LIC_pseudo\results\xbd_eval\lic_half_finetuned_updated\xbd_lic_half_finetuned_updated_metrics.csv

[SUMMARY: xBD lic_half_finetuned_updated]
       tile_index    accuracy    mean_iou
count  499.000000  499.000000  499.000000
mean   247.136273    0.931178    0.579562
std    142.971455    0.087627    0.255206
min      0.000000    0.595581    0.198882
25%    124.500000    0.884468    0.428734
50%    246.000000    0.971649    0.489563
75%    370.500000    1.000000    1.000000
max    495.000000    1.000000    1.000000
[END EVAL] xBD Model: 'lic_half_finetuned_updated'.

[START EVAL] Model: 'lic_full_finetuned_updated' on custom test set.
Saving visualizations and metrics to: C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\LIC_pseudo\results\lic_full_finetuned_updated


Evaluating lic_full_finetuned_updated:   0%|          | 0/30 [00:00<?, ?it/s]

[LICTestDataset] Sample 0 shapes - Input: torch.Size([4, 256, 256]), Mask: torch.Size([256, 256])
  Batch 0 shapes -> x: torch.Size([2, 4, 256, 256]), y: torch.Size([2, 256, 256]), y_pred: torch.Size([2, 256, 256])


Evaluating lic_full_finetuned_updated: 100%|██████████| 30/30 [01:29<00:00,  2.98s/it]
  print(f"  {metric:>9}: mean={np.nanmean(values):.4f}, std={np.nanstd(values):.4f}")
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,


Metrics saved to C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\LIC_pseudo\results\lic_full_finetuned_updated\lic_full_finetuned_updated_metrics.csv

[SUMMARY: lic_full_finetuned_updated]
        accuracy   mean_iou  mean_precision  mean_recall  mean_dice
count  59.000000  59.000000       59.000000    59.000000  59.000000
mean    0.301318   0.125132        0.300623     0.305014   0.192537
std     0.186440   0.082789        0.134228     0.067107   0.098970
min     0.060089   0.017511        0.088565     0.196622   0.033603
25%     0.198608   0.079668        0.216208     0.262060   0.135147
50%     0.246765   0.105552        0.252275     0.291691   0.176179
75%     0.344200   0.133955        0.364894     0.325581   0.218985
max     0.960709   0.482361        0.698291     0.614091   0.629608

[Per-Class Metrics]
Class 0:
        iou: mean=0.2898, std=0.2094
  precision: mean=0.3152, std=0.2145
     recall: mean=0.7722, std=0.1854
       dice: mean=0.4147, std=0.2166
Class 1

Evaluating on xBD: lic_full_finetuned_updated: 100%|██████████| 63/63 [09:42<00:00,  9.25s/it]


xBD metrics saved to C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\LIC_pseudo\results\xbd_eval\lic_full_finetuned_updated\xbd_lic_full_finetuned_updated_metrics.csv

[SUMMARY: xBD lic_full_finetuned_updated]
       tile_index    accuracy    mean_iou
count  499.000000  499.000000  499.000000
mean   247.136273    0.903064    0.507916
std    142.971455    0.120363    0.271105
min      0.000000    0.328400    0.109467
25%    124.500000    0.846092    0.309293
50%    246.000000    0.957962    0.458839
75%    370.500000    0.998451    0.499466
max    495.000000    1.000000    1.000000
[END EVAL] xBD Model: 'lic_full_finetuned_updated'.


--- Starting MIC Evaluation ---

[START EVAL] Model: 'mic_baseline_updated' on custom test set.
Saving visualizations and metrics to: C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\MIC_pseudo\results\mic_baseline_updated


Evaluating mic_baseline_updated:   0%|          | 0/271 [00:00<?, ?it/s]

[MICTestDataset] Sample 0 shapes - Input: torch.Size([4, 256, 256]), Mask: torch.Size([256, 256])
  Batch 0 shapes -> x: torch.Size([2, 4, 256, 256]), y: torch.Size([2, 256, 256]), y_pred: torch.Size([2, 256, 256])


Evaluating mic_baseline_updated: 100%|██████████| 271/271 [13:46<00:00,  3.05s/it]
  print(f"  {metric:>9}: mean={np.nanmean(values):.4f}, std={np.nanstd(values):.4f}")
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,


Metrics saved to C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\MIC_pseudo\results\mic_baseline_updated\mic_baseline_updated_metrics.csv

[SUMMARY: mic_baseline_updated]
         accuracy    mean_iou  mean_precision  mean_recall   mean_dice
count  542.000000  542.000000      542.000000   542.000000  542.000000
mean     0.226407    0.057182        0.224352     0.251053    0.086865
std      0.168111    0.045031        0.167086     0.015584    0.047979
min      0.054092    0.013523        0.054092     0.244237    0.025658
25%      0.132263    0.033161        0.129459     0.250000    0.058555
50%      0.175690    0.044037        0.174339     0.250000    0.074883
75%      0.225693    0.056423        0.225220     0.250000    0.092067
max      0.939484    0.428413        0.939484     0.500000    0.461447

[Per-Class Metrics]
Class 0:
        iou: mean=0.2264, std=0.1680
  precision: mean=0.2264, std=0.1680
     recall: mean=0.9999, std=0.0013
       dice: mean=0.3448, std=0.180

Evaluating on xBD: mic_baseline_updated: 100%|██████████| 63/63 [09:48<00:00,  9.35s/it]


xBD metrics saved to C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\MIC_pseudo\results\xbd_eval\mic_baseline_updated\xbd_mic_baseline_updated_metrics.csv

[SUMMARY: xBD mic_baseline_updated]
       tile_index    accuracy    mean_iou
count  499.000000  499.000000  499.000000
mean   247.136273    0.993480    0.927033
std    142.971455    0.028016    0.180515
min      0.000000    0.638916    0.301753
25%    124.500000    1.000000    1.000000
50%    246.000000    1.000000    1.000000
75%    370.500000    1.000000    1.000000
max    495.000000    1.000000    1.000000
[END EVAL] xBD Model: 'mic_baseline_updated'.

[START EVAL] Model: 'mic_half_finetuned_updated' on custom test set.
Saving visualizations and metrics to: C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\MIC_pseudo\results\mic_half_finetuned_updated


Evaluating mic_half_finetuned_updated:   0%|          | 0/271 [00:00<?, ?it/s]

[MICTestDataset] Sample 0 shapes - Input: torch.Size([4, 256, 256]), Mask: torch.Size([256, 256])
  Batch 0 shapes -> x: torch.Size([2, 4, 256, 256]), y: torch.Size([2, 256, 256]), y_pred: torch.Size([2, 256, 256])


Evaluating mic_half_finetuned_updated: 100%|██████████| 271/271 [14:40<00:00,  3.25s/it]


Metrics saved to C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\MIC_pseudo\results\mic_half_finetuned_updated\mic_half_finetuned_updated_metrics.csv

[SUMMARY: mic_half_finetuned_updated]
         accuracy    mean_iou  mean_precision  mean_recall   mean_dice
count  542.000000  542.000000      542.000000   542.000000  542.000000
mean     0.410042    0.217675        0.395924     0.365476    0.327700
std      0.156702    0.089139        0.095287     0.088788    0.102814
min      0.086456    0.035357        0.048285     0.116818    0.063627
25%      0.308361    0.152318        0.329385     0.297070    0.252156
50%      0.372131    0.199885        0.379253     0.345783    0.322192
75%      0.466103    0.271478        0.454546     0.419043    0.389846
max      0.949951    0.539870        0.710470     0.657706    0.665917

[Per-Class Metrics]
Class 0:
        iou: mean=0.2437, std=0.2596
  precision: mean=0.3579, std=0.2547
     recall: mean=0.3796, std=0.2964
       dice: mean

Evaluating on xBD: mic_half_finetuned_updated: 100%|██████████| 63/63 [09:47<00:00,  9.33s/it]


xBD metrics saved to C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\MIC_pseudo\results\xbd_eval\mic_half_finetuned_updated\xbd_mic_half_finetuned_updated_metrics.csv

[SUMMARY: xBD mic_half_finetuned_updated]
       tile_index    accuracy    mean_iou
count  499.000000  499.000000  499.000000
mean   247.136273    0.930939    0.480239
std    142.971455    0.105581    0.322872
min      0.000000    0.223526    0.055882
25%    124.500000    0.891884    0.243214
50%    246.000000    0.982285    0.327428
75%    370.500000    1.000000    1.000000
max    495.000000    1.000000    1.000000
[END EVAL] xBD Model: 'mic_half_finetuned_updated'.

[START EVAL] Model: 'mic_full_finetuned_updated' on custom test set.
Saving visualizations and metrics to: C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\MIC_pseudo\results\mic_full_finetuned_updated


Evaluating mic_full_finetuned_updated:   0%|          | 0/271 [00:00<?, ?it/s]

[MICTestDataset] Sample 0 shapes - Input: torch.Size([4, 256, 256]), Mask: torch.Size([256, 256])
  Batch 0 shapes -> x: torch.Size([2, 4, 256, 256]), y: torch.Size([2, 256, 256]), y_pred: torch.Size([2, 256, 256])


Evaluating mic_full_finetuned_updated: 100%|██████████| 271/271 [14:43<00:00,  3.26s/it]


Metrics saved to C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\MIC_pseudo\results\mic_full_finetuned_updated\mic_full_finetuned_updated_metrics.csv

[SUMMARY: mic_full_finetuned_updated]
         accuracy    mean_iou  mean_precision  mean_recall   mean_dice
count  542.000000  542.000000      542.000000   542.000000  542.000000
mean     0.424408    0.227494        0.414654     0.384240    0.336345
std      0.163272    0.096737        0.097523     0.090900    0.108763
min      0.083313    0.028886        0.100769     0.126869    0.053068
25%      0.323505    0.161167        0.351422     0.317036    0.261593
50%      0.387703    0.206747        0.397895     0.367667    0.325611
75%      0.477337    0.270679        0.460838     0.436399    0.400122
max      0.974945    0.627752        0.726900     0.866346    0.679618

[Per-Class Metrics]
Class 0:
        iou: mean=0.2272, std=0.2761
  precision: mean=0.3687, std=0.2738
     recall: mean=0.3228, std=0.3119
       dice: mean

Evaluating on xBD: mic_full_finetuned_updated: 100%|██████████| 63/63 [10:08<00:00,  9.66s/it]


xBD metrics saved to C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\MIC_pseudo\results\xbd_eval\mic_full_finetuned_updated\xbd_mic_full_finetuned_updated_metrics.csv

[SUMMARY: xBD mic_full_finetuned_updated]
       tile_index    accuracy    mean_iou
count  499.000000  499.000000  499.000000
mean   247.136273    0.919561    0.439374
std    142.971455    0.116963    0.326413
min      0.000000    0.162079    0.040520
25%    124.500000    0.876343    0.224776
50%    246.000000    0.976074    0.247887
75%    370.500000    0.999962    0.499981
max    495.000000    1.000000    1.000000
[END EVAL] xBD Model: 'mic_full_finetuned_updated'.


--- Starting HIC Evaluation ---

[START EVAL] Model: 'hic_baseline_updated' on custom test set.
Saving visualizations and metrics to: C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\HIC_pseudo\results\hic_baseline_updated


Evaluating hic_baseline_updated:   0%|          | 0/162 [00:00<?, ?it/s]

[HICTestDataset] Sample 0 shapes - Input: torch.Size([4, 256, 256]), Mask: torch.Size([256, 256])
  Batch 0 shapes -> x: torch.Size([2, 4, 256, 256]), y: torch.Size([2, 256, 256]), y_pred: torch.Size([2, 256, 256])


Evaluating hic_baseline_updated: 100%|██████████| 162/162 [09:19<00:00,  3.45s/it]
  print(f"  {metric:>9}: mean={np.nanmean(values):.4f}, std={np.nanstd(values):.4f}")
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,


Metrics saved to C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\HIC_pseudo\results\hic_baseline_updated\hic_baseline_updated_metrics.csv

[SUMMARY: hic_baseline_updated]
         accuracy    mean_iou  mean_precision  mean_recall   mean_dice
count  324.000000  324.000000      324.000000   324.000000  324.000000
mean     0.189436    0.048179        0.189436     0.251800    0.074942
std      0.157130    0.042315        0.157130     0.012135    0.046880
min      0.037201    0.009300        0.037201     0.250000    0.017933
25%      0.103283    0.025833        0.103283     0.250000    0.046827
50%      0.133553    0.033522        0.133553     0.250000    0.059117
75%      0.190083    0.047829        0.190083     0.250000    0.080295
max      0.953156    0.293605        0.953156     0.333333    0.312210

[Per-Class Metrics]
Class 0:
        iou: mean=0.1894, std=0.1569
  precision: mean=0.1894, std=0.1569
     recall: mean=1.0000, std=0.0000
       dice: mean=0.2958, std=0.174

Evaluating on xBD: hic_baseline_updated: 100%|██████████| 63/63 [10:09<00:00,  9.68s/it]


xBD metrics saved to C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\HIC_pseudo\results\xbd_eval\hic_baseline_updated\xbd_hic_baseline_updated_metrics.csv

[SUMMARY: xBD hic_baseline_updated]
       tile_index    accuracy    mean_iou
count  499.000000  499.000000  499.000000
mean   247.136273    0.993480    0.927033
std    142.971455    0.028016    0.180515
min      0.000000    0.638916    0.301753
25%    124.500000    1.000000    1.000000
50%    246.000000    1.000000    1.000000
75%    370.500000    1.000000    1.000000
max    495.000000    1.000000    1.000000
[END EVAL] xBD Model: 'hic_baseline_updated'.

[START EVAL] Model: 'hic_half_finetuned_updated' on custom test set.
Saving visualizations and metrics to: C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\HIC_pseudo\results\hic_half_finetuned_updated


Evaluating hic_half_finetuned_updated:   0%|          | 0/162 [00:00<?, ?it/s]

[HICTestDataset] Sample 0 shapes - Input: torch.Size([4, 256, 256]), Mask: torch.Size([256, 256])
  Batch 0 shapes -> x: torch.Size([2, 4, 256, 256]), y: torch.Size([2, 256, 256]), y_pred: torch.Size([2, 256, 256])


Evaluating hic_half_finetuned_updated: 100%|██████████| 162/162 [09:18<00:00,  3.45s/it]


Metrics saved to C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\HIC_pseudo\results\hic_half_finetuned_updated\hic_half_finetuned_updated_metrics.csv

[SUMMARY: hic_half_finetuned_updated]
         accuracy    mean_iou  mean_precision  mean_recall   mean_dice
count  324.000000  324.000000      324.000000   324.000000  324.000000
mean     0.311788    0.138416        0.312632     0.288649    0.217823
std      0.151756    0.064056        0.116038     0.061405    0.084997
min      0.023987    0.014913        0.024635     0.142597    0.028783
25%      0.222157    0.097995        0.255706     0.249266    0.163913
50%      0.286385    0.133473        0.314959     0.270027    0.220273
75%      0.368786    0.172026        0.375450     0.323148    0.276830
max      0.949799    0.349564        0.693091     0.498353    0.433355

[Per-Class Metrics]
Class 0:
        iou: mean=0.1676, std=0.1990
  precision: mean=0.2157, std=0.2062
     recall: mean=0.4258, std=0.3123
       dice: mean

Evaluating on xBD: hic_half_finetuned_updated: 100%|██████████| 63/63 [10:22<00:00,  9.89s/it]


xBD metrics saved to C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\HIC_pseudo\results\xbd_eval\hic_half_finetuned_updated\xbd_hic_half_finetuned_updated_metrics.csv

[SUMMARY: xBD hic_half_finetuned_updated]
       tile_index    accuracy    mean_iou
count  499.000000  499.000000  499.000000
mean   247.136273    0.899604    0.370501
std    142.971455    0.118905    0.135559
min      0.000000    0.306320    0.076580
25%    124.500000    0.842590    0.287359
50%    246.000000    0.951202    0.326502
75%    370.500000    0.986031    0.491871
max    495.000000    1.000000    1.000000
[END EVAL] xBD Model: 'hic_half_finetuned_updated'.

[START EVAL] Model: 'hic_full_finetuned_updated' on custom test set.
Saving visualizations and metrics to: C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\HIC_pseudo\results\hic_full_finetuned_updated


Evaluating hic_full_finetuned_updated:   0%|          | 0/162 [00:00<?, ?it/s]

[HICTestDataset] Sample 0 shapes - Input: torch.Size([4, 256, 256]), Mask: torch.Size([256, 256])
  Batch 0 shapes -> x: torch.Size([2, 4, 256, 256]), y: torch.Size([2, 256, 256]), y_pred: torch.Size([2, 256, 256])


Evaluating hic_full_finetuned_updated: 100%|██████████| 162/162 [09:58<00:00,  3.70s/it]


Metrics saved to C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\HIC_pseudo\results\hic_full_finetuned_updated\hic_full_finetuned_updated_metrics.csv

[SUMMARY: hic_full_finetuned_updated]
         accuracy    mean_iou  mean_precision  mean_recall   mean_dice
count  324.000000  324.000000      324.000000   324.000000  324.000000
mean     0.372926    0.149571        0.343983     0.284949    0.226709
std      0.148302    0.060775        0.105393     0.060237    0.076342
min      0.018570    0.018217        0.044550     0.162084    0.035290
25%      0.269974    0.108306        0.273467     0.247382    0.180241
50%      0.364014    0.145942        0.335164     0.270537    0.228320
75%      0.469788    0.185887        0.422839     0.312967    0.277120
max      0.891113    0.335161        0.670525     0.505018    0.415908

[Per-Class Metrics]
Class 0:
        iou: mean=0.1137, std=0.2175
  precision: mean=0.2853, std=0.2778
     recall: mean=0.1540, std=0.2487
       dice: mean

Evaluating on xBD: hic_full_finetuned_updated: 100%|██████████| 63/63 [10:38<00:00, 10.14s/it]

xBD metrics saved to C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\HIC_pseudo\results\xbd_eval\hic_full_finetuned_updated\xbd_hic_full_finetuned_updated_metrics.csv

[SUMMARY: xBD hic_full_finetuned_updated]
       tile_index    accuracy    mean_iou
count  499.000000  499.000000  499.000000
mean   247.136273    0.688153    0.193252
std    142.971455    0.174964    0.062264
min      0.000000    0.029541    0.007385
25%    124.500000    0.561707    0.152918
50%    246.000000    0.692703    0.185143
75%    370.500000    0.830841    0.226228
max    495.000000    0.986847    0.328949
[END EVAL] xBD Model: 'hic_full_finetuned_updated'.


All evaluations complete!





In [6]:
import os
import numpy as np
import pandas as pd
import scipy.stats as st

def compute_summary_stats(values, metric_name="", model_label=""):
    values = np.array(values)
    values = values[~np.isnan(values)]  # remove NaNs

    print(f"  → [{model_label}] {metric_name}: {len(values)} valid entries")

    if len(values) == 0:
        print(f"    [WARNING] No valid entries for {metric_name} in {model_label}")
        return {
            "mean": np.nan,
            "median": np.nan,
            "ci95_low": np.nan,
            "ci95_high": np.nan
        }

    mean = np.mean(values)
    median = np.median(values)
    ci95 = st.t.interval(
        confidence=0.95,
        df=len(values) - 1,
        loc=mean,
        scale=st.sem(values) if len(values) > 1 else 0
    )

    return {
        "mean": mean,
        "median": median,
        "ci95_low": ci95[0],
        "ci95_high": ci95[1]
    }

def summarize_bright_metrics(results_root, region_label):
    print(f"\n=== [Summary for {region_label}] ===")
    print(f"[INFO] Root directory: {results_root}")
    
    models = ["baseline", "half_finetuned", "full_finetuned"]
    metrics = ["mean_iou", "mean_precision", "mean_recall", "mean_dice"]
    summary = {}

    for model in models:
        label = f"{region_label.lower()}_{model}_updated"
        csv_path = os.path.join(results_root, label, f"{label}_metrics.csv")
        
        #print(f"\n[DEBUG] Looking for: {csv_path}")
        if not os.path.exists(csv_path):
            #print(f"  [!] Missing file: {csv_path}")
            continue

        #print(f"  [+] Loading CSV for model: {model.upper()}")

        df = pd.read_csv(csv_path)
        #print(f"    → CSV loaded. Rows: {len(df)}, Columns: {list(df.columns)}")

        model_stats = {}
        for metric in metrics:
            if metric not in df.columns:
                #print(f"    [WARNING] Metric '{metric}' not found in CSV!")
                continue
            stats = compute_summary_stats(df[metric], metric_name=metric, model_label=model)
            model_stats[metric] = stats

        summary[model] = model_stats

    # Pretty print results
    for model, stats in summary.items():
        print(f"\n-- {model.upper()} --")
        for metric, values in stats.items():
            print(f"{metric:>13}: mean={values['mean']:.4f}, median={values['median']:.4f}, 95% CI=({values['ci95_low']:.4f}, {values['ci95_high']:.4f})")

    return summary


# After all evaluations complete
lic_summary = summarize_bright_metrics(lic_results_root, "LIC")
mic_summary = summarize_bright_metrics(mic_results_root, "MIC")
hic_summary = summarize_bright_metrics(hic_results_root, "HIC")


=== [Summary for LIC] ===
[INFO] Root directory: C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\LIC_pseudo\results
  → [baseline] mean_iou: 59 valid entries
  → [baseline] mean_precision: 59 valid entries
  → [baseline] mean_recall: 59 valid entries
  → [baseline] mean_dice: 59 valid entries
  → [half_finetuned] mean_iou: 59 valid entries
  → [half_finetuned] mean_precision: 59 valid entries
  → [half_finetuned] mean_recall: 59 valid entries
  → [half_finetuned] mean_dice: 59 valid entries
  → [full_finetuned] mean_iou: 59 valid entries
  → [full_finetuned] mean_precision: 59 valid entries
  → [full_finetuned] mean_recall: 59 valid entries
  → [full_finetuned] mean_dice: 59 valid entries

-- BASELINE --
     mean_iou: mean=0.0716, median=0.0607, 95% CI=(0.0593, 0.0839)
mean_precision: mean=0.2834, median=0.2428, 95% CI=(0.2356, 0.3312)
  mean_recall: mean=0.2514, median=0.2500, 95% CI=(0.2486, 0.2542)
    mean_dice: mean=0.1044, median=0.0977, 95% CI=(0.0909, 0.1180)

-

In [8]:
#evaluate xBD only
#gem ver
# ============================
# Imports
# ============================
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
import torch.optim as optim
from torchvision import transforms as T
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
import random
from PIL import Image, ImageOps
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import confusion_matrix

print("All libraries imported successfully.")

# ============================
# Model Definition (U-Net)
# ============================
class UNetOriginal(nn.Module):
    def __init__(self, in_channels=4, out_classes=4):
        super().__init__()

        def conv_block(in_c, out_c):
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True)
            )

        # Encoder
        self.enc1 = conv_block(in_channels, 64)
        self.enc2 = conv_block(64, 128)
        self.enc3 = conv_block(128, 256)
        self.enc4 = conv_block(256, 512)
        self.pool = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = conv_block(512, 1024)

        # Decoder
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = conv_block(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = conv_block(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = conv_block(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = conv_block(128, 64)

        # Output
        self.final = nn.Conv2d(64, out_classes, kernel_size=1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        b = self.bottleneck(self.pool(e4))

        d4 = self.dec4(torch.cat([self.up4(b), e4], dim=1))
        d3 = self.dec3(torch.cat([self.up3(d4), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))

        return self.final(d1)

# ============================
# Helper & Metric Functions
# ============================
def tile_tensor_and_mask(input_tensor, mask_tensor, tile_size=(256, 256)):
    C, H, W = input_tensor.shape
    th, tw = tile_size
    tiles_input, tiles_mask = [], []
    for i in range(0, H, th):
        for j in range(0, W, tw):
            tiles_input.append(input_tensor[:, i:i+th, j:j+tw])
            tiles_mask.append(mask_tensor[i:i+th, j:j+tw])
    return tiles_input, tiles_mask

def compute_accuracy(pred, target):
    return (pred == target).float().mean().item()

def compute_all_metrics(pred, target, num_classes=4):
    pred_flat = pred.flatten().cpu().numpy()
    target_flat = target.flatten().cpu().numpy()
    metrics = {'iou': [], 'precision': [], 'recall': [], 'dice': []}

    for cls in range(num_classes):
        pred_inds = pred_flat == cls
        target_inds = target_flat == cls
        
        intersection = np.logical_and(pred_inds, target_inds).sum()
        union = np.logical_or(pred_inds, target_inds).sum()
        tp = intersection
        fp = np.logical_and(pred_inds, ~target_inds).sum()
        fn = np.logical_and(~pred_inds, target_inds).sum()

        iou = tp / union if union > 0 else np.nan
        precision = tp / (tp + fp) if (tp + fp) > 0 else np.nan
        recall = tp / (tp + fn) if (tp + fn) > 0 else np.nan
        dice = (2 * tp) / (2 * tp + fp + fn) if (2 * tp + fp + fn) > 0 else np.nan

        metrics['iou'].append(iou)
        metrics['precision'].append(precision)
        metrics['recall'].append(recall)
        metrics['dice'].append(dice)
        
    return metrics

# ============================
# Dataset Classes
# ============================
class XBDMulticlassDataset(Dataset):
    def __init__(self, root_dir, transform_pre=None, transform_post=None,
                 image_size=(1024, 1024), tile_size=(256, 256), max_images=None):
        self.image_dir = os.path.join(root_dir, "images")
        self.mask_dir = os.path.join(root_dir, "masks")
        self.files = sorted([f for f in os.listdir(self.image_dir) if '_pre_disaster.png' in f])
        if max_images:
            self.files = self.files[:max_images]
        # DEBUG: Confirm number of images found
        print(f"[XBDDataset] Found {len(self.files)} pre-disaster images in {self.image_dir}")
        self.transform_pre = transform_pre
        self.transform_post = transform_post
        self.image_size = image_size
        self.tile_size = tile_size
        self.tiles_per_image = (image_size[0] // tile_size[0]) * (image_size[1] // tile_size[1])

    def __getitem__(self, idx):
        pre_file = self.files[idx]
        post_file = pre_file.replace('pre', 'post')
        mask_file = post_file.replace('.png', '_rgb.png')
        pre_path = os.path.join(self.image_dir, pre_file)
        post_path = os.path.join(self.image_dir, post_file)
        mask_path = os.path.join(self.mask_dir, mask_file)

        pre_img_raw = Image.open(pre_path).convert("RGB").resize(self.image_size)
        post_img_raw = Image.open(post_path).convert("RGB").resize(self.image_size)
        mask_img_raw = Image.open(mask_path).convert("RGB").resize(self.image_size, Image.NEAREST)

        def optical_to_sar_like(img):
            img = img.convert('L')
            img = ImageOps.autocontrast(img, cutoff=2)
            return img

        def update_mask_multiclass(mask_rgb_img):
            mask_np = np.array(mask_rgb_img)
            label_mask = np.zeros(mask_np.shape[:2], dtype=np.uint8)
            color_to_label = {
                (0, 0, 0): 0, (0, 255, 255): 0, (0, 0, 255): 1,
                (255, 255, 0): 2, (255, 0, 0): 3, (211, 211, 211): 0
            }
            for rgb, label in color_to_label.items():
                mask = np.all(mask_np == rgb, axis=-1)
                label_mask[mask] = label
            return label_mask

        post_img_sar_raw = optical_to_sar_like(post_img_raw)
        pre_img = self.transform_pre(pre_img_raw) if self.transform_pre else pre_img_raw
        post_img_sar = self.transform_post(post_img_sar_raw) if self.transform_post else post_img_sar_raw

        input_tensor = torch.cat([pre_img, post_img_sar], dim=0)
        mask_tensor = torch.tensor(update_mask_multiclass(mask_img_raw), dtype=torch.long)

        return tile_tensor_and_mask(input_tensor, mask_tensor, self.tile_size)

    def __len__(self):
        return len(self.files)

    def get_tile_dataset(self):
        return TiledXBDDataset(self)

class TiledXBDDataset(Dataset):
    def __init__(self, parent_dataset):
        self.parent_dataset = parent_dataset
        self.tiles_per_image = parent_dataset.tiles_per_image

    def __getitem__(self, idx):
        image_idx = idx // self.tiles_per_image
        tile_idx = idx % self.tiles_per_image
        tiles_input, tiles_mask = self.parent_dataset[image_idx]
        return tiles_input[tile_idx], tiles_mask[tile_idx]

    def __len__(self):
        return len(self.parent_dataset) * self.tiles_per_image

# Note: The three test datasets are identical. They could be refactored into one class.
class BaseTestDataset(Dataset):
    def __init__(self, pre_dir, post_dir, mask_dir, dataset_name="Test"):
        self.pre_dir = pre_dir
        self.post_dir = post_dir
        self.mask_dir = mask_dir
        self.files = [f for f in os.listdir(pre_dir) if f.endswith(('.png', '.jpg'))]
        # DEBUG: Confirm number of images found
        print(f"[{dataset_name} Dataset] Found {len(self.files)} images in {pre_dir}")
        self.transform = T.Compose([T.ToTensor()])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        fname = self.files[idx]
        pre_path = os.path.join(self.pre_dir, fname)
        post_path = os.path.join(self.post_dir, fname)
        mask_path = os.path.join(self.mask_dir, fname)

        pre_img = self.transform(Image.open(pre_path).convert("RGB"))
        post_img = self.transform(Image.open(post_path).convert("L"))
        mask = torch.from_numpy(np.array(Image.open(mask_path))).long()

        if post_img.dim() == 2:
            post_img = post_img.unsqueeze(0)
        
        x = torch.cat([pre_img, post_img], dim=0)
        
        # DEBUG: Check tensor shapes for the first item
        if idx == 0:
            print(f"[{self.__class__.__name__}] Sample 0 shapes - Input: {x.shape}, Mask: {mask.shape}")
            
        return x, mask, fname

class LICTestDataset(BaseTestDataset):
    def __init__(self, pre_dir, post_dir, mask_dir):
        super().__init__(pre_dir, post_dir, mask_dir, dataset_name="LIC Test")

class MICTestDataset(BaseTestDataset):
    def __init__(self, pre_dir, post_dir, mask_dir):
        super().__init__(pre_dir, post_dir, mask_dir, dataset_name="MIC Test")

class HICTestDataset(BaseTestDataset):
    def __init__(self, pre_dir, post_dir, mask_dir):
        super().__init__(pre_dir, post_dir, mask_dir, dataset_name="HIC Test")
from torch.utils.data import Subset

def get_majority_class_per_tile(tile_mask):
    # Ignore padding or empty masks
    return torch.bincount(tile_mask.flatten()).argmax().item()

def stratified_tile_sample(dataset, total_samples=500):
    class_indices = {0: [], 1: [], 2: [], 3: []}

    print("Scanning tile classes for stratified sampling...")
    for idx in tqdm(range(len(dataset))):
        _, mask = dataset[idx]
        majority_class = get_majority_class_per_tile(mask)
        if majority_class in class_indices:
            class_indices[majority_class].append(idx)

    # Count all
    total_tiles = sum(len(v) for v in class_indices.values())
    proportions = {cls: len(v)/total_tiles for cls, v in class_indices.items()}
    print("Class proportions:", proportions)

    # Sample proportional counts
    sampled_indices = []
    for cls in range(4):
        n_samples = int(proportions[cls] * total_samples)
        n_available = len(class_indices[cls])
        chosen = random.sample(class_indices[cls], min(n_samples, n_available))
        sampled_indices.extend(chosen)

    return Subset(dataset, sampled_indices)

# ============================
# Evaluation & Visualization
# ============================
def save_visualization(pre_img, post_img, pred_mask, true_mask, fname, save_dir):
    fig, axs = plt.subplots(1, 4, figsize=(16, 4))
    cmap = mcolors.ListedColormap(['black', 'blue', 'yellow', 'red'])
    norm = mcolors.BoundaryNorm(np.arange(-0.5, 4, 1), cmap.N)

    axs[0].imshow(pre_img.permute(1, 2, 0).cpu().numpy())
    axs[0].set_title("Pre-disaster")
    axs[0].axis("off")

    axs[1].imshow(post_img.squeeze(0).cpu().numpy(), cmap='gray')
    axs[1].set_title("Post SAR")
    axs[1].axis("off")

    axs[2].imshow(true_mask.cpu().numpy(), cmap=cmap, norm=norm)
    axs[2].set_title("Ground Truth")
    axs[2].axis("off")

    axs[3].imshow(pred_mask.cpu().numpy(), cmap=cmap, norm=norm)
    axs[3].set_title("Prediction")
    axs[3].axis("off")

    os.makedirs(save_dir, exist_ok=True)
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, fname.replace('.png', '_viz.png')))
    plt.close()

def evaluate_and_visualize(model, dataloader, device, label, save_root, num_classes=4):
    print(f"\n{'='*20}\n[START EVAL] Model: '{label}' on custom test set.\n{'='*20}")
    model.eval()
    model.to(device)

    results = []
    per_class_metrics = {cls: {m: [] for m in ['iou', 'precision', 'recall', 'dice']} for cls in range(num_classes)}
    vis_dir = os.path.join(save_root, label)
    os.makedirs(vis_dir, exist_ok=True)
    # DEBUG: Show where results will be saved
    print(f"Saving visualizations and metrics to: {vis_dir}")

    with torch.no_grad():
        for i, (x, y, fname) in enumerate(tqdm(dataloader, desc=f"Evaluating {label}")):
            x, y = x.to(device), y.to(device)
            y_pred = torch.argmax(model(x), dim=1)
            
            # DEBUG: Check tensor shapes inside the loop (only for the first batch)
            if i == 0:
                print(f"  Batch 0 shapes -> x: {x.shape}, y: {y.shape}, y_pred: {y_pred.shape}")

            for j in range(len(y)):
                fname_i = fname[j]
                pred_j, y_j = y_pred[j], y[j]
                
                acc = compute_accuracy(pred_j, y_j)
                metrics = compute_all_metrics(pred_j, y_j, num_classes=num_classes)

                for cls in range(num_classes):
                    for metric_name in per_class_metrics[cls]:
                        per_class_metrics[cls][metric_name].append(metrics[metric_name][cls])

                results.append({
                    "filename": fname_i, "accuracy": acc,
                    "mean_iou": np.nanmean(metrics['iou']),
                    "mean_precision": np.nanmean(metrics['precision']),
                    "mean_recall": np.nanmean(metrics['recall']),
                    "mean_dice": np.nanmean(metrics['dice']),
                })
                save_visualization(x[j, :3], x[j, 3:], pred_j, y_j, fname_i, vis_dir)

    df = pd.DataFrame(results)
    csv_path = os.path.join(vis_dir, f"{label}_metrics.csv")
    df.to_csv(csv_path, index=False)
    # DEBUG: Confirm CSV saved
    print(f"Metrics saved to {csv_path}")

    print(f"\n[SUMMARY: {label}]")
    print(df.describe())
    print("\n[Per-Class Metrics]")
    for cls in range(num_classes):
        print(f"Class {cls}:")
        for metric in ['iou', 'precision', 'recall', 'dice']:
            values = np.array(per_class_metrics[cls][metric])
            print(f"  {metric:>9}: mean={np.nanmean(values):.4f}, std={np.nanstd(values):.4f}")
    print(f"[END EVAL] Model: '{label}'.\n{'='*20}")

def evaluate_model_on_xbd(model, dataloader, device, label, save_root):
    print(f"\n{'='*20}\n[START EVAL] Model: '{label}' on xBD tiles.\n{'='*20}")
    model.eval()
    model.to(device)
    
    results = []
    save_dir = os.path.join(save_root, label)
    os.makedirs(save_dir, exist_ok=True)
    # DEBUG: Show where results will be saved
    print(f"Saving xBD metrics to: {save_dir}")
    
    with torch.no_grad():
        for i, (x, y) in enumerate(tqdm(dataloader, desc=f"Evaluating on xBD: {label}")):
            x, y = x.to(device), y.to(device)
            y_pred = torch.argmax(model(x), dim=1)

            for j in range(x.size(0)):
                pred_j, y_j = y_pred[j], y[j]
                metrics = compute_all_metrics(pred_j, y_j, num_classes=4)
                results.append({
                    "tile_index": i * x.size(0) + j,
                    "accuracy": compute_accuracy(pred_j, y_j),
                    "mean_iou": np.nanmean(metrics['iou']),
                    "mean_precision": np.nanmean(metrics['precision']),
                    "mean_recall": np.nanmean(metrics['recall']),
                    "mean_dice": np.nanmean(metrics['dice']),
})

    df = pd.DataFrame(results)
    csv_path = os.path.join(save_dir, f"xbd_{label}_metrics.csv")
    df.to_csv(csv_path, index=False)
    # DEBUG: Confirm CSV saved
    print(f"xBD metrics saved to {csv_path}")

    print(f"\n[SUMMARY: xBD {label}]")
    print(df.describe())
    print(f"[END EVAL] xBD Model: '{label}'.\n{'='*20}")

# ============================
# Main Execution
# ============================
if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # --- Step 1: Define Paths ---
    base_dir = r"C:\Users\sweta\anaconda_projects\non-trivial\performance_bias"
    
    # Model paths
    iteration_models_dir = os.path.join(base_dir, "iteration_models")
    baseline_path = os.path.join(iteration_models_dir, "model_epoch8.pt")
    
    lic_half_path = os.path.join(iteration_models_dir, "lic_half_finetunedmodel.pt")
    lic_full_path = os.path.join(iteration_models_dir, "lic_full_finetunedmodel.pt")
    
    mic_half_path = os.path.join(iteration_models_dir, "mic_half_finetunedmodel.pt")
    mic_full_path = os.path.join(iteration_models_dir, "mic_full_finetunedmodel.pt")
    
    hic_half_path = os.path.join(iteration_models_dir, "hic_half_finetunedmodel.pt")
    hic_full_path = os.path.join(iteration_models_dir, "hic_full_finetunedmodel.pt")
    
    # Data paths
    xbd_root = r"C:\Users\sweta\.cache\kagglehub\datasets\qianlanzz\xbd-dataset\versions\1\xbd\tier1"
    lic_test_dir = os.path.join(base_dir, "LIC_pseudo", "test")
    mic_test_dir = os.path.join(base_dir, "MIC_pseudo", "test")
    hic_test_dir = os.path.join(base_dir, "HIC_pseudo", "test")

    # Results paths
    lic_results_root = os.path.join(base_dir, "LIC_pseudo", "results")
    mic_results_root = os.path.join(base_dir, "MIC_pseudo", "results")
    hic_results_root = os.path.join(base_dir, "HIC_pseudo", "results")

    # --- Step 2: Load Models ---
    def load_model_from_path(path):
        print(f"Loading model from: {path}")
        model = UNetOriginal(in_channels=4, out_classes=4)
        if not os.path.exists(path):
            print(f"  [ERROR] Model path does not exist: {path}")
            return None
        checkpoint = torch.load(path, map_location=device)
        # Handle both dict and checkpoint objects
        if 'model_state_dict' in checkpoint:
            # DEBUG
            print("  Loading from a checkpoint object...")
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            # DEBUG
            print("  Loading from a raw state dictionary...")
            model.load_state_dict(checkpoint)
        return model

    models_to_test = {
        "LIC": {
            "baseline": load_model_from_path(baseline_path),
            "half_finetuned": load_model_from_path(lic_half_path),
            "full_finetuned": load_model_from_path(lic_full_path),
        },
        "MIC": {
            "baseline": load_model_from_path(baseline_path),
            "half_finetuned": load_model_from_path(mic_half_path),
            "full_finetuned": load_model_from_path(mic_full_path),
        },
        "HIC": {
            "baseline": load_model_from_path(baseline_path),
            "half_finetuned": load_model_from_path(hic_half_path),
            "full_finetuned": load_model_from_path(hic_full_path),
        }
    }
    
    # --- Step 3: Prepare DataLoaders ---
    # Create xBD loader (used for all evaluations)
    transform_pre = T.Compose([T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    transform_post = T.Compose([T.ToTensor(), T.Normalize(mean=[0.5], std=[0.5])])
    from torch.utils.data import Subset
    from tqdm import tqdm
    
    xbd_tile_dataset = xbd_dataset.get_tile_dataset()
    stratified_subset = stratified_tile_sample(xbd_tile_dataset, total_samples=500)
    xbd_loader = DataLoader(stratified_subset, batch_size=8, shuffle=False)


    # Create LIC, MIC, HIC loaders
    lic_test_dataset = LICTestDataset(os.path.join(lic_test_dir, "pre"), os.path.join(lic_test_dir, "post"), os.path.join(lic_test_dir, "mask"))
    lic_loader = DataLoader(lic_test_dataset, batch_size=2, shuffle=False)
    
    mic_test_dataset = MICTestDataset(os.path.join(mic_test_dir, "pre"), os.path.join(mic_test_dir, "post"), os.path.join(mic_test_dir, "mask"))
    mic_loader = DataLoader(mic_test_dataset, batch_size=2, shuffle=False)

    hic_test_dataset = HICTestDataset(os.path.join(hic_test_dir, "pre"), os.path.join(hic_test_dir, "post"), os.path.join(hic_test_dir, "mask"))
    hic_loader = DataLoader(hic_test_dataset, batch_size=2, shuffle=False)

    # --- Step 4: Run Evaluations ---
    
    # Evaluate on LIC
    print("\n\n--- Starting LIC Evaluation ---")
    for name, model in models_to_test["LIC"].items():
        if model:
          #  evaluate_and_visualize(model, lic_loader, device, label=f"lic_{name}_updated", save_root=lic_results_root)
            evaluate_model_on_xbd(model, xbd_loader, device, label=f"lic_{name}_updated", save_root=os.path.join(lic_results_root, "xbd_eval"))

    # Evaluate on MIC
    print("\n\n--- Starting MIC Evaluation ---")
    for name, model in models_to_test["MIC"].items():
        if model:
           # evaluate_and_visualize(model, mic_loader, device, label=f"mic_{name}_updated", save_root=mic_results_root)
            evaluate_model_on_xbd(model, xbd_loader, device, label=f"mic_{name}_updated", save_root=os.path.join(mic_results_root, "xbd_eval"))
            
    # Evaluate on HIC
    print("\n\n--- Starting HIC Evaluation ---")
    for name, model in models_to_test["HIC"].items():
        if model:
           # evaluate_and_visualize(model, hic_loader, device, label=f"hic_{name}_updated", save_root=hic_results_root)
            evaluate_model_on_xbd(model, xbd_loader, device, label=f"hic_{name}_updated", save_root=os.path.join(hic_results_root, "xbd_eval"))

    print("\n\nAll evaluations complete!")

All libraries imported successfully.
Using device: cpu
Loading model from: C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\iteration_models\model_epoch8.pt


  checkpoint = torch.load(path, map_location=device)


  Loading from a checkpoint object...
Loading model from: C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\iteration_models\lic_half_finetunedmodel.pt
  Loading from a raw state dictionary...
Loading model from: C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\iteration_models\lic_full_finetunedmodel.pt
  Loading from a raw state dictionary...
Loading model from: C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\iteration_models\model_epoch8.pt
  Loading from a checkpoint object...
Loading model from: C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\iteration_models\mic_half_finetunedmodel.pt
  Loading from a raw state dictionary...
Loading model from: C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\iteration_models\mic_full_finetunedmodel.pt
  Loading from a raw state dictionary...
Loading model from: C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\iteration_models\model_epoch8.pt
  Loading from a checkpoin

100%|██████████| 44784/44784 [4:00:34<00:00,  3.10it/s]  


Class proportions: {0: 0.9993971061093248, 1: 0.0, 2: 0.00040192926045016077, 3: 0.00020096463022508038}
[LIC Test Dataset] Found 59 images in C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\LIC_pseudo\test\pre
[MIC Test Dataset] Found 542 images in C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\MIC_pseudo\test\pre
[HIC Test Dataset] Found 324 images in C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\HIC_pseudo\test\pre


--- Starting LIC Evaluation ---

[START EVAL] Model: 'lic_baseline_updated' on xBD tiles.
Saving xBD metrics to: C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\LIC_pseudo\results\xbd_eval\lic_baseline_updated


Evaluating on xBD: lic_baseline_updated: 100%|██████████| 63/63 [09:52<00:00,  9.41s/it]


xBD metrics saved to C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\LIC_pseudo\results\xbd_eval\lic_baseline_updated\xbd_lic_baseline_updated_metrics.csv

[SUMMARY: xBD lic_baseline_updated]
       tile_index    accuracy    mean_iou  mean_precision  mean_recall  \
count  499.000000  499.000000  499.000000      499.000000   499.000000   
mean   247.136273    0.994949    0.915736        0.985393     0.923734   
std    142.971455    0.020293    0.192093        0.069192     0.181397   
min      0.000000    0.785706    0.284500        0.454786     0.333333   
25%    124.500000    1.000000    1.000000        1.000000     1.000000   
50%    246.000000    1.000000    1.000000        1.000000     1.000000   
75%    370.500000    1.000000    1.000000        1.000000     1.000000   
max    495.000000    1.000000    1.000000        1.000000     1.000000   

        mean_dice  
count  499.000000  
mean     0.918539  
std      0.188181  
min      0.306987  
25%      1.000000  
50%    

Evaluating on xBD: lic_half_finetuned_updated: 100%|██████████| 63/63 [09:48<00:00,  9.35s/it]


xBD metrics saved to C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\LIC_pseudo\results\xbd_eval\lic_half_finetuned_updated\xbd_lic_half_finetuned_updated_metrics.csv

[SUMMARY: xBD lic_half_finetuned_updated]
       tile_index    accuracy    mean_iou  mean_precision  mean_recall  \
count  499.000000  499.000000  499.000000      499.000000   499.000000   
mean   247.136273    0.932510    0.578811        0.625101     0.877617   
std    142.971455    0.092291    0.257956        0.230006     0.179626   
min      0.000000    0.462814    0.213852        0.326088     0.318739   
25%    124.500000    0.906540    0.427540        0.500000     0.837217   
50%    246.000000    0.972473    0.488045        0.500000     0.966034   
75%    370.500000    1.000000    1.000000        1.000000     1.000000   
max    495.000000    1.000000    1.000000        1.000000     1.000000   

        mean_dice  
count  499.000000  
mean     0.595327  
std      0.248065  
min      0.260548  
25%      

Evaluating on xBD: lic_full_finetuned_updated: 100%|██████████| 63/63 [09:51<00:00,  9.40s/it]


xBD metrics saved to C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\LIC_pseudo\results\xbd_eval\lic_full_finetuned_updated\xbd_lic_full_finetuned_updated_metrics.csv

[SUMMARY: xBD lic_full_finetuned_updated]
       tile_index    accuracy    mean_iou  mean_precision  mean_recall  \
count  499.000000  499.000000  499.000000      499.000000   499.000000   
mean   247.136273    0.904028    0.503843        0.553856     0.856113   
std    142.971455    0.127911    0.266966        0.240428     0.190016   
min      0.000000    0.307449    0.102483        0.318583     0.286896   
25%    124.500000    0.856323    0.306836        0.333333     0.782639   
50%    246.000000    0.960052    0.464058        0.500000     0.948685   
75%    370.500000    0.998634    0.499847        0.547853     0.998596   
max    495.000000    1.000000    1.000000        1.000000     1.000000   

        mean_dice  
count  499.000000  
mean     0.522289  
std      0.256454  
min      0.156768  
25%      

Evaluating on xBD: mic_baseline_updated: 100%|██████████| 63/63 [09:18<00:00,  8.87s/it]


xBD metrics saved to C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\MIC_pseudo\results\xbd_eval\mic_baseline_updated\xbd_mic_baseline_updated_metrics.csv

[SUMMARY: xBD mic_baseline_updated]
       tile_index    accuracy    mean_iou  mean_precision  mean_recall  \
count  499.000000  499.000000  499.000000      499.000000   499.000000   
mean   247.136273    0.994949    0.915736        0.985393     0.923734   
std    142.971455    0.020293    0.192093        0.069192     0.181397   
min      0.000000    0.785706    0.284500        0.454786     0.333333   
25%    124.500000    1.000000    1.000000        1.000000     1.000000   
50%    246.000000    1.000000    1.000000        1.000000     1.000000   
75%    370.500000    1.000000    1.000000        1.000000     1.000000   
max    495.000000    1.000000    1.000000        1.000000     1.000000   

        mean_dice  
count  499.000000  
mean     0.918539  
std      0.188181  
min      0.306987  
25%      1.000000  
50%    

Evaluating on xBD: mic_half_finetuned_updated: 100%|██████████| 63/63 [09:15<00:00,  8.82s/it]


xBD metrics saved to C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\MIC_pseudo\results\xbd_eval\mic_half_finetuned_updated\xbd_mic_half_finetuned_updated_metrics.csv

[SUMMARY: xBD mic_half_finetuned_updated]
       tile_index    accuracy    mean_iou  mean_precision  mean_recall  \
count  499.000000  499.000000  499.000000      499.000000   499.000000   
mean   247.136273    0.929429    0.481271        0.509299     0.878730   
std    142.971455    0.120692    0.323391        0.310860     0.185297   
min      0.000000    0.167206    0.041801        0.235047     0.167206   
25%    124.500000    0.908470    0.242768        0.250000     0.840748   
50%    246.000000    0.984909    0.328064        0.333333     0.981186   
75%    370.500000    1.000000    1.000000        1.000000     1.000000   
max    495.000000    1.000000    1.000000        1.000000     1.000000   

        mean_dice  
count  499.000000  
mean     0.491852  
std      0.316338  
min      0.071627  
25%      

Evaluating on xBD: mic_full_finetuned_updated: 100%|██████████| 63/63 [09:15<00:00,  8.82s/it]


xBD metrics saved to C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\MIC_pseudo\results\xbd_eval\mic_full_finetuned_updated\xbd_mic_full_finetuned_updated_metrics.csv

[SUMMARY: xBD mic_full_finetuned_updated]
       tile_index    accuracy    mean_iou  mean_precision  mean_recall  \
count  499.000000  499.000000  499.000000      499.000000   499.000000   
mean   247.136273    0.919244    0.443886        0.472069     0.863670   
std    142.971455    0.129163    0.331154        0.319027     0.196842   
min      0.000000    0.152405    0.038101        0.233430     0.152405   
25%    124.500000    0.887024    0.225168        0.250000     0.804151   
50%    246.000000    0.978958    0.248318        0.261962     0.972977   
75%    370.500000    1.000000    1.000000        1.000000     1.000000   
max    495.000000    1.000000    1.000000        1.000000     1.000000   

        mean_dice  
count  499.000000  
mean     0.453883  
std      0.324334  
min      0.066125  
25%      

Evaluating on xBD: hic_baseline_updated: 100%|██████████| 63/63 [09:16<00:00,  8.83s/it]


xBD metrics saved to C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\HIC_pseudo\results\xbd_eval\hic_baseline_updated\xbd_hic_baseline_updated_metrics.csv

[SUMMARY: xBD hic_baseline_updated]
       tile_index    accuracy    mean_iou  mean_precision  mean_recall  \
count  499.000000  499.000000  499.000000      499.000000   499.000000   
mean   247.136273    0.994949    0.915736        0.985393     0.923734   
std    142.971455    0.020293    0.192093        0.069192     0.181397   
min      0.000000    0.785706    0.284500        0.454786     0.333333   
25%    124.500000    1.000000    1.000000        1.000000     1.000000   
50%    246.000000    1.000000    1.000000        1.000000     1.000000   
75%    370.500000    1.000000    1.000000        1.000000     1.000000   
max    495.000000    1.000000    1.000000        1.000000     1.000000   

        mean_dice  
count  499.000000  
mean     0.918539  
std      0.188181  
min      0.306987  
25%      1.000000  
50%    

Evaluating on xBD: hic_half_finetuned_updated: 100%|██████████| 63/63 [09:14<00:00,  8.80s/it]


xBD metrics saved to C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\HIC_pseudo\results\xbd_eval\hic_half_finetuned_updated\xbd_hic_half_finetuned_updated_metrics.csv

[SUMMARY: xBD hic_half_finetuned_updated]
       tile_index    accuracy    mean_iou  mean_precision  mean_recall  \
count  499.000000  499.000000  499.000000      499.000000   499.000000   
mean   247.136273    0.899208    0.370721        0.410756     0.851788   
std    142.971455    0.127586    0.131368        0.114799     0.185653   
min      0.000000    0.256775    0.064194        0.241307     0.244811   
25%    124.500000    0.857826    0.284344        0.333333     0.784988   
50%    246.000000    0.955307    0.327199        0.360690     0.945328   
75%    370.500000    0.986343    0.492542        0.500000     0.985924   
max    495.000000    1.000000    1.000000        1.000000     1.000000   

        mean_dice  
count  499.000000  
mean     0.387590  
std      0.122742  
min      0.102156  
25%      

Evaluating on xBD: hic_full_finetuned_updated: 100%|██████████| 63/63 [09:13<00:00,  8.79s/it]

xBD metrics saved to C:\Users\sweta\anaconda_projects\non-trivial\performance_bias\HIC_pseudo\results\xbd_eval\hic_full_finetuned_updated\xbd_hic_full_finetuned_updated_metrics.csv

[SUMMARY: xBD hic_full_finetuned_updated]
       tile_index    accuracy    mean_iou  mean_precision  mean_recall  \
count  499.000000  499.000000  499.000000      499.000000   499.000000   
mean   247.136273    0.688780    0.194814        0.281192     0.664373   
std    142.971455    0.185514    0.065283        0.045014     0.204337   
min      0.000000    0.007675    0.001919        0.230965     0.007675   
25%    124.500000    0.569176    0.151577        0.250000     0.523003   
50%    246.000000    0.705078    0.189602        0.250000     0.684250   
75%    370.500000    0.840645    0.236661        0.333333     0.837547   
max    495.000000    0.982803    0.336471        0.499666     0.982803   

        mean_dice  
count  499.000000  
mean     0.226622  
std      0.059072  
min      0.003808  
25%      




In [2]:
import os 
os.getcwd()

'C:\\Users\\sweta\\anaconda_projects\\non-trivial\\performance_bias'