In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ============================
# Imports
# ============================
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torchvision import transforms as T
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
import random
from PIL import Image, ImageOps


# ============================
# Dataset Classes
# ============================
class XBDMulticlassDataset(Dataset):
    def __init__(self, root_dir, transform_pre=None, transform_post=None,
                 image_size=(1024, 1024), tile_size=(256, 256), max_images=None):
        self.image_dir = os.path.join(root_dir, "images")
        self.mask_dir = os.path.join(root_dir, "masks")
        self.files = sorted([f for f in os.listdir(self.image_dir) if '_pre_disaster.png' in f])
        if max_images:
            self.files = self.files[:max_images]
        self.transform_pre = transform_pre
        self.transform_post = transform_post
        self.image_size = image_size
        self.tile_size = tile_size
        self.tiles_per_image = (image_size[0] // tile_size[0]) * (image_size[1] // tile_size[1])

    def __getitem__(self, idx, return_raw=False):
        pre_file = self.files[idx]
        post_file = pre_file.replace('pre', 'post')
        mask_file = post_file.replace('.png', '_rgb.png')
        pre_path = os.path.join(self.image_dir, pre_file)
        post_path = os.path.join(self.image_dir, post_file)
        mask_path = os.path.join(self.mask_dir, mask_file)

        pre_img_raw = Image.open(pre_path).convert("RGB").resize(self.image_size)
        post_img_raw = Image.open(post_path).convert("RGB").resize(self.image_size)
        mask_img_raw = Image.open(mask_path).convert("RGB").resize(self.image_size, Image.NEAREST)

        def optical_to_sar_like(img):
            img = img.convert('L')
            img = ImageOps.autocontrast(img, cutoff=2)
            return img

        def update_mask_multiclass(mask_rgb_img):
            mask_np = np.array(mask_rgb_img)
            label_mask = np.zeros(mask_np.shape[:2], dtype=np.uint8)
            color_to_label = {
                (0, 0, 0): 0, (0, 255, 255): 0, (0, 0, 255): 1,
                (255, 255, 0): 2, (255, 0, 0): 3, (211, 211, 211): 0
            }
            for rgb, label in color_to_label.items():
                mask = np.all(mask_np == rgb, axis=-1)
                label_mask[mask] = label
            return label_mask

        post_img_sar_raw = optical_to_sar_like(post_img_raw)
        if return_raw:
            mask_np = update_mask_multiclass(mask_img_raw)
            return pre_img_raw, post_img_sar_raw, mask_np

        pre_img = self.transform_pre(pre_img_raw) if self.transform_pre else pre_img_raw
        post_img_sar = self.transform_post(post_img_sar_raw) if self.transform_post else post_img_sar_raw

        input_tensor = torch.cat([pre_img, post_img_sar], dim=0)
        mask_tensor = torch.tensor(update_mask_multiclass(mask_img_raw), dtype=torch.long)

        return tile_tensor_and_mask(input_tensor, mask_tensor, self.tile_size)

    def __len__(self):
        return len(self.files)

    def get_tile_dataset(self):
        return TiledXBDDataset(self)

class TiledXBDDataset(Dataset):
    def __init__(self, parent_dataset):
        self.parent_dataset = parent_dataset
        self.tiles_per_image = parent_dataset.tiles_per_image

    def __getitem__(self, idx):
        image_idx = idx // self.tiles_per_image
        tile_idx = idx % self.tiles_per_image
        tiles_input, tiles_mask = self.parent_dataset[image_idx]
        return tiles_input[tile_idx], tiles_mask[tile_idx]

    def __len__(self):
        return len(self.parent_dataset) * self.tiles_per_image

class FilteredTileDataset(Dataset):
    def __init__(self, tiled_dataset, keep_zero_damage_prob=0.1):
        self.tiled_dataset = tiled_dataset
        self.filtered_indices = []
        for idx in range(len(tiled_dataset)):
            _, mask = tiled_dataset[idx]
            if mask.max() == 0:
                if random.random() < keep_zero_damage_prob:
                    self.filtered_indices.append(idx)
            else:
                self.filtered_indices.append(idx)

    def __len__(self):
        return len(self.filtered_indices)

    def __getitem__(self, idx):
        return self.tiled_dataset[self.filtered_indices[idx]]

from torch.utils.data import Subset

def get_majority_class_per_tile(tile_mask):
    # Ignore padding or empty masks
    return torch.bincount(tile_mask.flatten()).argmax().item()

def stratified_tile_sample(dataset, total_samples=500):
    class_indices = {0: [], 1: [], 2: [], 3: []}

    print("Scanning tile classes for stratified sampling...")
    for idx in tqdm(range(len(dataset))):
        _, mask = dataset[idx]
        majority_class = get_majority_class_per_tile(mask)
        if majority_class in class_indices:
            class_indices[majority_class].append(idx)

    # Count all
    total_tiles = sum(len(v) for v in class_indices.values())
    proportions = {cls: len(v)/total_tiles for cls, v in class_indices.items()}
    print("Class proportions:", proportions)

    # Sample proportional counts
    sampled_indices = []
    for cls in range(4):
        n_samples = int(proportions[cls] * total_samples)
        n_available = len(class_indices[cls])
        chosen = random.sample(class_indices[cls], min(n_samples, n_available))
        sampled_indices.extend(chosen)

    return Subset(dataset, sampled_indices)

# ============================
# Helper Functions
# ============================
def tile_tensor_and_mask(input_tensor, mask_tensor, tile_size=(256, 256)):
    C, H, W = input_tensor.shape
    th, tw = tile_size
    tiles_input, tiles_mask = [], []
    for i in range(0, H, th):
        for j in range(0, W, tw):
            tiles_input.append(input_tensor[:, i:i+th, j:j+tw])
            tiles_mask.append(mask_tensor[i:i+th, j:j+tw])
    return tiles_input, tiles_mask

# ============================
# Visualization
# ============================
def unnormalize(img_tensor, mean, std):
    for t, m, s in zip(img_tensor, mean, std):
        t.mul_(s).add_(m)
    return img_tensor

def visualize_predictions(inputs, masks, outputs, n=4):
    preds = torch.argmax(outputs, dim=1).cpu().numpy()
    masks = masks.cpu().numpy()
    cmap = mcolors.ListedColormap(['black', 'blue', 'yellow', 'red'])
    bounds = [0, 1, 2, 3, 4]
    norm = mcolors.BoundaryNorm(bounds, cmap.N)

    for i in range(min(n, inputs.shape[0])):
        fig, axs = plt.subplots(1, 3, figsize=(12, 4))
        pre_rgb = unnormalize(inputs[i][:3].clone(), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        axs[0].imshow(np.transpose(torch.clamp(pre_rgb, 0, 1).cpu().numpy(), (1, 2, 0)))
        axs[0].set_title("Pre-disaster RGB")
        axs[0].axis('off')
        axs[1].imshow(inputs[i][3].cpu().numpy(), cmap='gray')
        axs[1].set_title("Post-disaster SAR")
        axs[1].axis('off')
        axs[2].imshow(preds[i], cmap=cmap, norm=norm)
        axs[2].set_title("Predicted")
        axs[2].axis('off')
        plt.show()
        plt.figure()
        plt.imshow(masks[i], cmap=cmap, norm=norm)
        plt.title("Ground Truth")
        plt.axis('off')
        plt.show()

print("Saved all")

class UNetOriginal(nn.Module):
    def __init__(self, in_channels=4, out_classes=4):
        super().__init__()

        def conv_block(in_c, out_c):
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True)
            )

        # Encoder
        self.enc1 = conv_block(in_channels, 64)
        self.enc2 = conv_block(64, 128)
        self.enc3 = conv_block(128, 256)
        self.enc4 = conv_block(256, 512)

        self.pool = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = conv_block(512, 1024)

        # Decoder
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = conv_block(1024, 512)

        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = conv_block(512, 256)

        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = conv_block(256, 128)

        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = conv_block(128, 64)

        # Output
        self.final = nn.Conv2d(64, out_classes, kernel_size=1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        b = self.bottleneck(self.pool(e4))

        d4 = self.dec4(torch.cat([self.up4(b), e4], dim=1))
        d3 = self.dec3(torch.cat([self.up3(d4), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))

        return self.final(d1)
model_path = 'C:/Users/sweta/anaconda_projects/non-trivial/performance_bias/iteration_models/model_epoch7.pt'
model = UNetOriginal(in_channels=4, out_classes=4)
model.load_state_dict(torch.load(model_path, map_location='cpu'))
model.eval()
print("model loaded successfully")


from torch.utils.data import DataLoader

# Transforms (same as training)
transform_pre = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

transform_post = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.5], std=[0.5])
])

# Load full xBD dataset
xbd_root = r"C:\Users\sweta\.cache\kagglehub\datasets\qianlanzz\xbd-dataset\versions\1\xbd\tier1"
xbd_dataset = XBDMulticlassDataset(
    xbd_root, 
    transform_pre=transform_pre, 
    transform_post=transform_post,
    image_size=(1024, 1024),
    tile_size=(256, 256)
)

# Convert to per-tile dataset for evaluation
xbd_tile_dataset = xbd_dataset.get_tile_dataset()
xbd_loader = DataLoader(xbd_tile_dataset, batch_size=8, shuffle=False)

def evaluate_model_on_xbd(model, dataloader, device, label, save_root):
    model.eval()
    model.to(device)
    
    results = []
    save_dir = os.path.join(save_root, f"xbd_{label}")
    os.makedirs(save_dir, exist_ok=True)

    with torch.no_grad():
        for i, (x, y) in enumerate(tqdm(dataloader, desc=f"Evaluating on xBD: {label}")):
            x, y = x.to(device), y.to(device)
            y_pred = torch.argmax(model(x), dim=1)

            for j in range(x.size(0)):
                acc = compute_accuracy(y_pred[j], y[j])
                iou = compute_iou(y_pred[j], y[j])
                results.append({
                    "tile_index": i * x.size(0) + j,
                    "accuracy": acc,
                    "iou": iou
                })

    # Save CSV
    df = pd.DataFrame(results)
    df.to_csv(os.path.join(save_dir, f"xbd_{label}_metrics.csv"), index=False)

    # Summary
    print(f"\n[SUMMARY: xBD {label}]")
    print(df.describe())
    

Saved all
model loaded successfully


  model.load_state_dict(torch.load(model_path, map_location='cpu'))


In [6]:
from torch.utils.data import Subset
from tqdm import tqdm
xbd_tile_dataset = xbd_dataset.get_tile_dataset()
stratified_subset = stratified_tile_sample(xbd_tile_dataset, total_samples=500)
xbd_loader = DataLoader(stratified_subset, batch_size=8, shuffle=False)

Scanning tile classes for stratified sampling...


100%|██████████| 44784/44784 [3:57:20<00:00,  3.14it/s]  

Class proportions: {0: 0.9993971061093248, 1: 0.0, 2: 0.00040192926045016077, 3: 0.00020096463022508038}





LIC

In [7]:
#goal: Test how the model behaves differently once it's fine tuned (are there any patterns?) also test how well the different models perform on LIC_test. LIC_test does have a damage mask as somewhat ground truth so use that.
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as T
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import confusion_matrix

# ============ Dataset =============
class LICTestDataset(Dataset):
    def __init__(self, pre_dir, post_dir, mask_dir):
        self.pre_dir = pre_dir
        self.post_dir = post_dir
        self.mask_dir = mask_dir
        self.files = [f for f in os.listdir(pre_dir) if f.endswith(('.png', '.jpg'))]
        self.transform = T.Compose([T.ToTensor()])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        fname = self.files[idx]
        pre_path = os.path.join(self.pre_dir, fname)
        post_path = os.path.join(self.post_dir, fname)
        mask_path = os.path.join(self.mask_dir, fname)

        pre_img = self.transform(Image.open(pre_path).convert("RGB"))
        post_img = self.transform(Image.open(post_path).convert("L"))
        mask = torch.from_numpy(np.array(Image.open(mask_path))).long()

        if post_img.dim() == 2:
            post_img = post_img.unsqueeze(0)
        x = torch.cat([pre_img, post_img], dim=0)
        return x, mask, fname

# ============ Metrics ============
def compute_iou(pred, target, num_classes=4):
    pred = pred.flatten()
    target = target.flatten()
    iou_per_class = []
    for cls in range(num_classes):
        pred_inds = pred == cls
        target_inds = target == cls
        intersection = (pred_inds & target_inds).sum().item()
        union = (pred_inds | target_inds).sum().item()
        if union == 0:
            iou_per_class.append(np.nan)
        else:
            iou_per_class.append(intersection / union)
    return np.nanmean(iou_per_class)

def compute_accuracy(pred, target):
    return (pred == target).float().mean().item()

# ============ Visualization ============
def save_visualization(pre_img, post_img, pred_mask, true_mask, fname, save_dir):
    fig, axs = plt.subplots(1, 4, figsize=(16, 4))

    axs[0].imshow(pre_img.permute(1, 2, 0).cpu())
    axs[0].set_title("Pre-disaster")
    axs[0].axis("off")

    axs[1].imshow(post_img.squeeze(0).cpu(), cmap='gray')
    axs[1].set_title("Post SAR")
    axs[1].axis("off")

    axs[2].imshow(true_mask.cpu(), cmap='tab10', vmin=0, vmax=3)
    axs[2].set_title("Ground Truth")
    axs[2].axis("off")

    axs[3].imshow(pred_mask.cpu(), cmap='tab10', vmin=0, vmax=3)
    axs[3].set_title("Prediction")
    axs[3].axis("off")

    os.makedirs(save_dir, exist_ok=True)
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, fname.replace('.png', '_viz.png')))
    plt.close()

# ============ Evaluation ============
def evaluate_and_visualize(model, dataloader, device, label, save_root):
    model.eval()
    model.to(device)

    results = []
    vis_dir = os.path.join(save_root, label)
    os.makedirs(vis_dir, exist_ok=True)

    with torch.no_grad():
        for x, y, fname in tqdm(dataloader, desc=f"Evaluating {label}"):
            x, y = x.to(device), y.to(device)
            y_pred = torch.argmax(model(x), dim=1)

            for i in range(len(y)):
                fname_i = fname[i]
                acc = compute_accuracy(y_pred[i], y[i])
                iou = compute_iou(y_pred[i], y[i])
                results.append({"filename": fname_i, "accuracy": acc, "iou": iou})

                save_visualization(x[i, :3], x[i, 3:], y_pred[i], y[i], fname_i, vis_dir)

    # Save CSV
    df = pd.DataFrame(results)
    df.to_csv(os.path.join(vis_dir, f"{label}_metrics.csv"), index=False)

    # Print summary
    print(f"\n[SUMMARY: {label}]")
    print(df.describe())

# ============ Load and Run ============

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_model(path):
    model = UNetOriginal(in_channels=4, out_classes=4)
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint)
    return model

# Paths
base_dir = r"C:\Users\sweta\anaconda_projects\non-trivial\performance_bias"
save_root = os.path.join(base_dir, "LIC_pseudo", "results")

baseline_path = os.path.join(base_dir, "iteration_models", "model_epoch8.pt")
half_path = os.path.join(base_dir, "iteration_models", "lic_half_finetunedmodel.pt")
full_path = os.path.join(base_dir, "iteration_models", "lic_full_finetunedmodel.pt")

# Load models
baseline = load_model(baseline_path)
half = load_model(half_path)
full = load_model(full_path)

# Dataset
test_dir = os.path.join(base_dir, "LIC_pseudo", "test")
test_dataset = LICTestDataset(os.path.join(test_dir, "pre"),
                              os.path.join(test_dir, "post"),
                              os.path.join(test_dir, "mask"))
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)

# Evaluate all 3
#bright_save_root = os.path.join(base_dir, "LIC_pseudo", "results", "bright")
#evaluate_and_visualize(baseline, test_loader, device, label="baseline", save_root=bright_save_root)
#evaluate_and_visualize(half, test_loader, device, label="lic_half", save_root=bright_save_root)
#evaluate_and_visualize(full, test_loader, device, label="lic_full", save_root=bright_save_root)
xbd_save_root = os.path.join(base_dir, "LIC_pseudo", "results", "xbd")
evaluate_model_on_xbd(baseline, xbd_loader, device, label="baseline", save_root=xbd_save_root)
evaluate_model_on_xbd(half, xbd_loader, device, label="lic_half", save_root=xbd_save_root)
evaluate_model_on_xbd(full, xbd_loader, device, label="lic_full", save_root=xbd_save_root)


  checkpoint = torch.load(path, map_location=device)
Evaluating on xBD: baseline: 100%|██████████| 63/63 [10:24<00:00,  9.91s/it]



[SUMMARY: xBD baseline]
       tile_index    accuracy         iou
count  499.000000  499.000000  499.000000
mean   247.136273    0.993386    0.900725
std    142.971455    0.029563    0.204571
min      0.000000    0.607132    0.299647
25%    124.500000    1.000000    1.000000
50%    246.000000    1.000000    1.000000
75%    370.500000    1.000000    1.000000
max    495.000000    1.000000    1.000000


Evaluating on xBD: lic_half: 100%|██████████| 63/63 [10:28<00:00,  9.98s/it]



[SUMMARY: xBD lic_half]
       tile_index    accuracy         iou
count  499.000000  499.000000  499.000000
mean   247.136273    0.932707    0.559485
std    142.971455    0.086282    0.242646
min      0.000000    0.485046    0.181788
25%    124.500000    0.897469    0.430180
50%    246.000000    0.968918    0.487396
75%    370.500000    0.998138    0.511667
max    495.000000    1.000000    1.000000


Evaluating on xBD: lic_full: 100%|██████████| 63/63 [10:22<00:00,  9.88s/it]


[SUMMARY: xBD lic_full]
       tile_index    accuracy         iou
count  499.000000  499.000000  499.000000
mean   247.136273    0.904372    0.486507
std    142.971455    0.119239    0.250159
min      0.000000    0.365463    0.121821
25%    124.500000    0.852379    0.305191
50%    246.000000    0.952408    0.455956
75%    370.500000    0.995987    0.498734
max    495.000000    1.000000    1.000000





In [8]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as T
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import confusion_matrix

# ============ Dataset =============
class MICTestDataset(Dataset):
    def __init__(self, pre_dir, post_dir, mask_dir):
        self.pre_dir = pre_dir
        self.post_dir = post_dir
        self.mask_dir = mask_dir
        self.files = [f for f in os.listdir(pre_dir) if f.endswith(('.png', '.jpg'))]
        self.transform = T.Compose([T.ToTensor()])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        fname = self.files[idx]
        pre_path = os.path.join(self.pre_dir, fname)
        post_path = os.path.join(self.post_dir, fname)
        mask_path = os.path.join(self.mask_dir, fname)

        pre_img = self.transform(Image.open(pre_path).convert("RGB"))
        post_img = self.transform(Image.open(post_path).convert("L"))
        mask = torch.from_numpy(np.array(Image.open(mask_path))).long()

        if post_img.dim() == 2:
            post_img = post_img.unsqueeze(0)
        x = torch.cat([pre_img, post_img], dim=0)
        return x, mask, fname

# ============ Metrics ============
def compute_iou(pred, target, num_classes=4):
    pred = pred.flatten()
    target = target.flatten()
    iou_per_class = []
    for cls in range(num_classes):
        pred_inds = pred == cls
        target_inds = target == cls
        intersection = (pred_inds & target_inds).sum().item()
        union = (pred_inds | target_inds).sum().item()
        if union == 0:
            iou_per_class.append(np.nan)
        else:
            iou_per_class.append(intersection / union)
    return np.nanmean(iou_per_class)

def compute_accuracy(pred, target):
    return (pred == target).float().mean().item()

# ============ Visualization ============
def save_visualization(pre_img, post_img, pred_mask, true_mask, fname, save_dir):
    try:
        fig, axs = plt.subplots(1, 4, figsize=(16, 4))
        axs[0].imshow(pre_img.permute(1, 2, 0).cpu())
        axs[0].set_title("Pre-disaster")
        axs[0].axis("off")

        axs[1].imshow(post_img.squeeze(0).cpu(), cmap='gray')
        axs[1].set_title("Post SAR")
        axs[1].axis("off")

        axs[2].imshow(true_mask.cpu(), cmap='tab10', vmin=0, vmax=3)
        axs[2].set_title("Ground Truth")
        axs[2].axis("off")

        axs[3].imshow(pred_mask.cpu(), cmap='tab10', vmin=0, vmax=3)
        axs[3].set_title("Prediction")
        axs[3].axis("off")

        os.makedirs(save_dir, exist_ok=True)
        safe_fname = os.path.basename(fname).replace('.png', '_viz.png').replace(' ', '_')
        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, safe_fname))
        plt.close()
    except Exception as e:
        print(f"[Warning] Could not save {fname}: {e}")
# ============ Evaluation ============
def evaluate_and_visualize(model, dataloader, device, label, save_root):
    model.eval()
    model.to(device)

    results = []
    vis_dir = os.path.join(save_root, label)
    os.makedirs(vis_dir, exist_ok=True)

    with torch.no_grad():
        for x, y, fname in tqdm(dataloader, desc=f"Evaluating {label}"):
            x, y = x.to(device), y.to(device)
            y_pred = torch.argmax(model(x), dim=1)

            for i in range(len(y)):
                fname_i = fname[i]
                acc = compute_accuracy(y_pred[i], y[i])
                iou = compute_iou(y_pred[i], y[i])
                results.append({"filename": fname_i, "accuracy": acc, "iou": iou})

                save_visualization(x[i, :3], x[i, 3:], y_pred[i], y[i], fname_i, vis_dir)

    # Save CSV
    df = pd.DataFrame(results)
    df.to_csv(os.path.join(vis_dir, f"{label}_metrics.csv"), index=False)

    # Print summary
    print(f"\n[SUMMARY: {label}]")
    print(df.describe())

# ============ Load and Run ============
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_model(path):
    model = UNetOriginal(in_channels=4, out_classes=4)
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint)
    return model

# Paths
base_dir = r"C:\Users\sweta\anaconda_projects\non-trivial\performance_bias"
mic_base = os.path.join(base_dir, "MIC_pseudo")
save_root = os.path.join(mic_base, "results")

baseline_path = os.path.join(base_dir, "iteration_models", "model_epoch8.pt")
half_path = os.path.join(base_dir, "iteration_models", "mic_half_finetunedmodel.pt")
full_path = os.path.join(base_dir, "iteration_models", "mic_full_finetunedmodel.pt")

# Load models
baseline = load_model(baseline_path)
half = load_model(half_path)
full = load_model(full_path)

# MIC Dataset
test_dir = os.path.join(mic_base, "test")
test_dataset = MICTestDataset(os.path.join(test_dir, "pre"),
                              os.path.join(test_dir, "post"),
                              os.path.join(test_dir, "mask"))
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)

# Evaluate all 3 models on MIC test set
#bright_save_root = os.path.join(mic_base, "results", "bright")
#evaluate_and_visualize(baseline, test_loader, device, label="baseline", save_root=bright_save_root)
#evaluate_and_visualize(half, test_loader, device, label="mic_half", save_root=bright_save_root)
#evaluate_and_visualize(full, test_loader, device, label="mic_full", save_root=bright_save_root)

# Optional: Evaluate on MIC-style xBD data (if xbd_loader is defined)
mic_xbd_save_root = os.path.join(mic_base, "results", "xbd")
evaluate_model_on_xbd(baseline, xbd_loader, device, label="baseline", save_root=mic_xbd_save_root)
evaluate_model_on_xbd(half, xbd_loader, device, label="mic_half", save_root=mic_xbd_save_root)
evaluate_model_on_xbd(full, xbd_loader, device, label="mic_full", save_root=mic_xbd_save_root)

  checkpoint = torch.load(path, map_location=device)
Evaluating on xBD: baseline: 100%|██████████| 63/63 [10:25<00:00,  9.93s/it]



[SUMMARY: xBD baseline]
       tile_index    accuracy         iou
count  499.000000  499.000000  499.000000
mean   247.136273    0.993386    0.900725
std    142.971455    0.029563    0.204571
min      0.000000    0.607132    0.299647
25%    124.500000    1.000000    1.000000
50%    246.000000    1.000000    1.000000
75%    370.500000    1.000000    1.000000
max    495.000000    1.000000    1.000000


Evaluating on xBD: mic_half: 100%|██████████| 63/63 [10:27<00:00,  9.96s/it]



[SUMMARY: xBD mic_half]
       tile_index    accuracy         iou
count  499.000000  499.000000  499.000000
mean   247.136273    0.928177    0.445440
std    142.971455    0.112306    0.294074
min      0.000000    0.155121    0.038780
25%    124.500000    0.901314    0.242765
50%    246.000000    0.979477    0.326355
75%    370.500000    0.999702    0.499836
max    495.000000    1.000000    1.000000


Evaluating on xBD: mic_full: 100%|██████████| 63/63 [10:21<00:00,  9.87s/it]


[SUMMARY: xBD mic_full]
       tile_index    accuracy         iou
count  499.000000  499.000000  499.000000
mean   247.136273    0.917107    0.399179
std    142.971455    0.122471    0.292692
min      0.000000    0.143768    0.035942
25%    124.500000    0.881729    0.225851
50%    246.000000    0.975510    0.247620
75%    370.500000    0.999222    0.498798
max    495.000000    1.000000    1.000000





In [9]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as T
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import confusion_matrix

# ============ Dataset =============
class HICTestDataset(Dataset):
    def __init__(self, pre_dir, post_dir, mask_dir):
        self.pre_dir = pre_dir
        self.post_dir = post_dir
        self.mask_dir = mask_dir
        self.files = [f for f in os.listdir(pre_dir) if f.endswith(('.png', '.jpg'))]
        self.transform = T.Compose([T.ToTensor()])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        fname = self.files[idx]
        pre_path = os.path.join(self.pre_dir, fname)
        post_path = os.path.join(self.post_dir, fname)
        mask_path = os.path.join(self.mask_dir, fname)

        pre_img = self.transform(Image.open(pre_path).convert("RGB"))
        post_img = self.transform(Image.open(post_path).convert("L"))
        mask = torch.from_numpy(np.array(Image.open(mask_path))).long()

        if post_img.dim() == 2:
            post_img = post_img.unsqueeze(0)
        x = torch.cat([pre_img, post_img], dim=0)
        return x, mask, fname

# ============ Metrics ============
def compute_iou(pred, target, num_classes=4):
    pred = pred.flatten()
    target = target.flatten()
    iou_per_class = []
    for cls in range(num_classes):
        pred_inds = pred == cls
        target_inds = target == cls
        intersection = (pred_inds & target_inds).sum().item()
        union = (pred_inds | target_inds).sum().item()
        if union == 0:
            iou_per_class.append(np.nan)
        else:
            iou_per_class.append(intersection / union)
    return np.nanmean(iou_per_class)

def compute_accuracy(pred, target):
    return (pred == target).float().mean().item()

# ============ Visualization ============
def save_visualization(pre_img, post_img, pred_mask, true_mask, fname, save_dir):
    try:
        fig, axs = plt.subplots(1, 4, figsize=(16, 4))
        axs[0].imshow(pre_img.permute(1, 2, 0).cpu())
        axs[0].set_title("Pre-disaster")
        axs[0].axis("off")

        axs[1].imshow(post_img.squeeze(0).cpu(), cmap='gray')
        axs[1].set_title("Post SAR")
        axs[1].axis("off")

        axs[2].imshow(true_mask.cpu(), cmap='tab10', vmin=0, vmax=3)
        axs[2].set_title("Ground Truth")
        axs[2].axis("off")

        axs[3].imshow(pred_mask.cpu(), cmap='tab10', vmin=0, vmax=3)
        axs[3].set_title("Prediction")
        axs[3].axis("off")

        os.makedirs(save_dir, exist_ok=True)
        safe_fname = os.path.basename(fname).replace('.png', '_viz.png').replace(' ', '_')
        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, safe_fname))
        plt.close()
    except Exception as e:
        print(f"[Warning] Could not save {fname}: {e}")


# ============ Evaluation ============
def evaluate_and_visualize(model, dataloader, device, label, save_root):
    model.eval()
    model.to(device)

    results = []
    vis_dir = os.path.join(save_root, label)
    os.makedirs(vis_dir, exist_ok=True)

    with torch.no_grad():
        for x, y, fname in tqdm(dataloader, desc=f"Evaluating {label}"):
            x, y = x.to(device), y.to(device)
            y_pred = torch.argmax(model(x), dim=1)

            for i in range(len(y)):
                fname_i = fname[i]
                acc = compute_accuracy(y_pred[i], y[i])
                iou = compute_iou(y_pred[i], y[i])
                results.append({"filename": fname_i, "accuracy": acc, "iou": iou})

                save_visualization(x[i, :3], x[i, 3:], y_pred[i], y[i], fname_i, vis_dir)

    # Save CSV
    df = pd.DataFrame(results)
    df.to_csv(os.path.join(vis_dir, f"{label}_metrics.csv"), index=False)

    # Print summary
    print(f"\n[SUMMARY: {label}]")
    print(df.describe())

# ============ Load and Run ============
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_model(path):
    model = UNetOriginal(in_channels=4, out_classes=4)
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint)
    return model

# Paths
base_dir = r"C:\Users\sweta\anaconda_projects\non-trivial\performance_bias"
hic_base = os.path.join(base_dir, "HIC_pseudo")
save_root = os.path.join(hic_base, "results")

baseline_path = os.path.join(base_dir, "iteration_models", "model_epoch8.pt")
half_path = os.path.join(base_dir, "iteration_models", "hic_half_finetunedmodel.pt")
full_path = os.path.join(base_dir, "iteration_models", "hic_full_finetunedmodel.pt")

# Load models
baseline = load_model(baseline_path)
half = load_model(half_path)
full = load_model(full_path)

# HIC Dataset
test_dir = os.path.join(hic_base, "test")
test_dataset = HICTestDataset(os.path.join(test_dir, "pre"),
                              os.path.join(test_dir, "post"),
                              os.path.join(test_dir, "mask"))
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)

# Evaluate all 3 models on HIC test set
#bright_save_root = os.path.join(hic_base, "results", "bright")
#evaluate_and_visualize(baseline, test_loader, device, label="baseline", save_root=bright_save_root)
#evaluate_and_visualize(half, test_loader, device, label="hic_half", save_root=bright_save_root)
#evaluate_and_visualize(full, test_loader, device, label="hic_full", save_root=bright_save_root)

# Optional: Evaluate on HIC-style xBD data (if xbd_loader is defined)
hic_xbd_save_root = os.path.join(hic_base, "results", "xbd")
evaluate_model_on_xbd(baseline, xbd_loader, device, label="baseline", save_root=hic_xbd_save_root)
evaluate_model_on_xbd(half, xbd_loader, device, label="hic_half", save_root=hic_xbd_save_root)
evaluate_model_on_xbd(full, xbd_loader, device, label="hic_full", save_root=hic_xbd_save_root)

  checkpoint = torch.load(path, map_location=device)
Evaluating on xBD: baseline: 100%|██████████| 63/63 [10:26<00:00,  9.95s/it]



[SUMMARY: xBD baseline]
       tile_index    accuracy         iou
count  499.000000  499.000000  499.000000
mean   247.136273    0.993386    0.900725
std    142.971455    0.029563    0.204571
min      0.000000    0.607132    0.299647
25%    124.500000    1.000000    1.000000
50%    246.000000    1.000000    1.000000
75%    370.500000    1.000000    1.000000
max    495.000000    1.000000    1.000000


Evaluating on xBD: hic_half: 100%|██████████| 63/63 [10:29<00:00, 10.00s/it]



[SUMMARY: xBD hic_half]
       tile_index    accuracy         iou
count  499.000000  499.000000  499.000000
mean   247.136273    0.898861    0.365926
std    142.971455    0.120764    0.136038
min      0.000000    0.235657    0.058914
25%    124.500000    0.850159    0.275195
50%    246.000000    0.945129    0.324808
75%    370.500000    0.985489    0.491058
max    495.000000    1.000000    1.000000


Evaluating on xBD: hic_full: 100%|██████████| 63/63 [10:24<00:00,  9.92s/it]


[SUMMARY: xBD hic_full]
       tile_index    accuracy         iou
count  499.000000  499.000000  499.000000
mean   247.136273    0.682382    0.191288
std    142.971455    0.182684    0.065277
min      0.000000    0.054825    0.013706
25%    124.500000    0.558769    0.149138
50%    246.000000    0.689087    0.183025
75%    370.500000    0.832489    0.225420
max    495.000000    0.995697    0.401630





In [5]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as T
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm

# ============ Dataset =============
class LICTestDataset(Dataset):
    def __init__(self, pre_dir, post_dir, mask_dir):
        self.pre_dir = pre_dir
        self.post_dir = post_dir
        self.mask_dir = mask_dir
        self.files = [f for f in os.listdir(pre_dir) if f.endswith(('.png', '.jpg'))]
        self.transform = T.Compose([T.ToTensor()])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        fname = self.files[idx]
        pre_path = os.path.join(self.pre_dir, fname)
        post_path = os.path.join(self.post_dir, fname)
        mask_path = os.path.join(self.mask_dir, fname)

        pre_img = self.transform(Image.open(pre_path).convert("RGB"))
        post_img = self.transform(Image.open(post_path).convert("L"))
        mask = torch.from_numpy(np.array(Image.open(mask_path))).long()

        if post_img.dim() == 2:
            post_img = post_img.unsqueeze(0)
        x = torch.cat([pre_img, post_img], dim=0)
        return x, mask, fname

# ============ Metrics ============
def compute_all_metrics(pred, target, num_classes=4):
    pred = pred.flatten().cpu().numpy()
    target = target.flatten().cpu().numpy()

    metrics = {'iou': [], 'precision': [], 'recall': [], 'dice': []}

    for cls in range(num_classes):
        pred_inds = pred == cls
        target_inds = target == cls

        intersection = np.logical_and(pred_inds, target_inds).sum()
        union = np.logical_or(pred_inds, target_inds).sum()
        tp = intersection
        fp = np.logical_and(pred_inds, ~target_inds).sum()
        fn = np.logical_and(~pred_inds, target_inds).sum()

        iou = intersection / union if union > 0 else np.nan
        precision = tp / (tp + fp) if (tp + fp) > 0 else np.nan
        recall = tp / (tp + fn) if (tp + fn) > 0 else np.nan
        dice = (2 * tp) / (2 * tp + fp + fn) if (2 * tp + fp + fn) > 0 else np.nan

        metrics['iou'].append(iou)
        metrics['precision'].append(precision)
        metrics['recall'].append(recall)
        metrics['dice'].append(dice)

    return metrics

def compute_accuracy(pred, target):
    return (pred == target).float().mean().item()

# ============ Visualization ============
def save_visualization(pre_img, post_img, pred_mask, true_mask, fname, save_dir):
    fig, axs = plt.subplots(1, 4, figsize=(16, 4))
    axs[0].imshow(pre_img.permute(1, 2, 0).cpu())
    axs[0].set_title("Pre-disaster"); axs[0].axis("off")
    axs[1].imshow(post_img.squeeze(0).cpu(), cmap='gray')
    axs[1].set_title("Post SAR"); axs[1].axis("off")
    axs[2].imshow(true_mask.cpu(), cmap='tab10', vmin=0, vmax=3)
    axs[2].set_title("Ground Truth"); axs[2].axis("off")
    axs[3].imshow(pred_mask.cpu(), cmap='tab10', vmin=0, vmax=3)
    axs[3].set_title("Prediction"); axs[3].axis("off")
    os.makedirs(save_dir, exist_ok=True)
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, fname.replace('.png', '_viz.png')))
    plt.close()

# ============ Evaluation ============
def evaluate_and_visualize(model, dataloader, device, label, save_root, num_classes=4):
    model.eval()
    model.to(device)

    results = []
    per_class_metrics = {cls: {'iou': [], 'precision': [], 'recall': [], 'dice': []} for cls in range(num_classes)}
    vis_dir = os.path.join(save_root, label)
    os.makedirs(vis_dir, exist_ok=True)

    with torch.no_grad():
        for x, y, fname in tqdm(dataloader, desc=f"Evaluating {label}"):
            x, y = x.to(device), y.to(device)
            y_pred = torch.argmax(model(x), dim=1)

            for i in range(len(y)):
                fname_i = fname[i]
                acc = compute_accuracy(y_pred[i], y[i])
                metrics = compute_all_metrics(y_pred[i], y[i], num_classes=num_classes)

                for cls in range(num_classes):
                    for metric in ['iou', 'precision', 'recall', 'dice']:
                        per_class_metrics[cls][metric].append(metrics[metric][cls])

                results.append({
                    "filename": fname_i,
                    "accuracy": acc,
                    "mean_iou": np.nanmean(metrics['iou']),
                    "mean_precision": np.nanmean(metrics['precision']),
                    "mean_recall": np.nanmean(metrics['recall']),
                    "mean_dice": np.nanmean(metrics['dice']),
                })

                save_visualization(x[i, :3], x[i, 3:], y_pred[i], y[i], fname_i, vis_dir)

    # Save metrics
    df = pd.DataFrame(results)
    df.to_csv(os.path.join(vis_dir, f"{label}_metrics.csv"), index=False)

    # Print overall summary
    print(f"\n[SUMMARY: {label}]")
    print(df.describe())

    # Per-class summary
    print("\n[Per-Class Metrics]")
    for cls in range(num_classes):
        print(f"Class {cls}:")
        for metric in ['iou', 'precision', 'recall', 'dice']:
            values = np.array(per_class_metrics[cls][metric])
            print(f"  {metric:>9}: mean={np.nanmean(values):.4f}, std={np.nanstd(values):.4f}")

# ============ Load and Run ============
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_model(path):
    model = UNetOriginal(in_channels=4, out_classes=4)
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint)
    return model

base_dir = r"C:\Users\sweta\anaconda_projects\non-trivial\performance_bias"
save_root = os.path.join(base_dir, "LIC_pseudo", "results")

baseline_path = os.path.join(base_dir, "iteration_models", "model_epoch8.pt")
half_path = os.path.join(base_dir, "iteration_models", "lic_half_finetunedmodel.pt")
full_path = os.path.join(base_dir, "iteration_models", "lic_full_finetunedmodel.pt")

baseline = load_model(baseline_path)
half = load_model(half_path)
full = load_model(full_path)

test_dir = os.path.join(base_dir, "LIC_pseudo", "test")
test_dataset = LICTestDataset(
    os.path.join(test_dir, "pre"),
    os.path.join(test_dir, "post"),
    os.path.join(test_dir, "mask")
)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)
bright_save_root = os.path.join(base_dir, "LIC_pseudo", "results", "bright")
evaluate_and_visualize(baseline, test_loader, device, label="baseline_updated", save_root=bright_save_root)
evaluate_and_visualize(half, test_loader, device, label="lic_half_updated", save_root=bright_save_root)
evaluate_and_visualize(full, test_loader, device, label="lic_full_updated", save_root=bright_save_root)
xbd_save_root = os.path.join(base_dir, "LIC_pseudo", "results", "xbd")
evaluate_model_on_xbd(baseline, xbd_loader, device, label="baseline_updated", save_root=xbd_save_root)
evaluate_model_on_xbd(half, xbd_loader, device, label="lic_half_updated", save_root=xbd_save_root)
evaluate_model_on_xbd(full, xbd_loader, device, label="lic_full_updated", save_root=xbd_save_root)
# Evaluate all models
#evaluate_and_visualize(baseline, test_loader, device, label="baseline", save_root=save_root)
#evaluate_and_visualize(half, test_loader, device, label="lic_half", save_root=save_root)
#evaluate_and_visualize(full, test_loader, device, label="lic_full", save_root=save_root)


  checkpoint = torch.load(path, map_location=device)
Evaluating baseline_updated: 100%|██████████| 30/30 [01:42<00:00,  3.41s/it]
  print(f"  {metric:>9}: mean={np.nanmean(values):.4f}, std={np.nanstd(values):.4f}")
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,



[SUMMARY: baseline_updated]
        accuracy   mean_iou  mean_precision  mean_recall  mean_dice
count  59.000000  59.000000       59.000000    59.000000  59.000000
mean    0.283409   0.071592        0.283409     0.251412   0.104448
std     0.183508   0.047194        0.183508     0.010849   0.051990
min     0.051620   0.012905        0.051620     0.250000   0.024543
25%     0.162010   0.040503        0.162010     0.250000   0.069696
50%     0.242752   0.060688        0.242752     0.250000   0.097667
75%     0.361473   0.090368        0.361473     0.250000   0.132749
max     0.938889   0.234722        0.938889     0.333333   0.242120

[Per-Class Metrics]
Class 0:
        iou: mean=0.2834, std=0.1819
  precision: mean=0.2834, std=0.1819
     recall: mean=1.0000, std=0.0000
       dice: mean=0.4139, std=0.1988
Class 1:
        iou: mean=0.0000, std=0.0000
  precision: mean=nan, std=nan
     recall: mean=0.0000, std=0.0000
       dice: mean=0.0000, std=0.0000
Class 2:
        iou: mean=0.0

Evaluating lic_half_updated: 100%|██████████| 30/30 [01:41<00:00,  3.40s/it]
  print(f"  {metric:>9}: mean={np.nanmean(values):.4f}, std={np.nanstd(values):.4f}")
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,



[SUMMARY: lic_half_updated]
        accuracy   mean_iou  mean_precision  mean_recall  mean_dice
count  59.000000  59.000000       59.000000    59.000000  59.000000
mean    0.294068   0.094204        0.275408     0.266111   0.142525
std     0.188199   0.065494        0.143277     0.036655   0.079478
min     0.054626   0.013865        0.057266     0.223151   0.026316
25%     0.190033   0.056984        0.175018     0.249620   0.098921
50%     0.252228   0.081585        0.232745     0.255830   0.133875
75%     0.359863   0.114970        0.373167     0.267138   0.173821
max     0.944916   0.336438        0.623771     0.480077   0.440837

[Per-Class Metrics]
Class 0:
        iou: mean=0.2895, std=0.1939
  precision: mean=0.2935, std=0.1949
     recall: mean=0.9445, std=0.0645
       dice: mean=0.4183, std=0.2076
Class 1:
        iou: mean=0.0000, std=0.0000
  precision: mean=nan, std=nan
     recall: mean=0.0000, std=0.0000
       dice: mean=0.0000, std=0.0000
Class 2:
        iou: mean=0.0

Evaluating lic_full_updated: 100%|██████████| 30/30 [01:43<00:00,  3.46s/it]
  print(f"  {metric:>9}: mean={np.nanmean(values):.4f}, std={np.nanstd(values):.4f}")
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,



[SUMMARY: lic_full_updated]
        accuracy   mean_iou  mean_precision  mean_recall  mean_dice
count  59.000000  59.000000       59.000000    59.000000  59.000000
mean    0.301318   0.125132        0.300623     0.305014   0.192537
std     0.186440   0.082789        0.134228     0.067107   0.098970
min     0.060089   0.017511        0.088565     0.196622   0.033603
25%     0.198608   0.079668        0.216208     0.262060   0.135147
50%     0.246765   0.105552        0.252275     0.291691   0.176179
75%     0.344200   0.133955        0.364894     0.325581   0.218985
max     0.960709   0.482361        0.698291     0.614091   0.629608

[Per-Class Metrics]
Class 0:
        iou: mean=0.2898, std=0.2094
  precision: mean=0.3152, std=0.2145
     recall: mean=0.7722, std=0.1854
       dice: mean=0.4147, std=0.2166
Class 1:
        iou: mean=0.0000, std=0.0000
  precision: mean=nan, std=nan
     recall: mean=0.0000, std=0.0000
       dice: mean=0.0000, std=0.0000
Class 2:
        iou: mean=0.1

Evaluating on xBD: baseline_updated:   1%|▏         | 77/5598 [13:27<16:04:50, 10.49s/it]


KeyboardInterrupt: 