In [1]:
# ============================
# Imports
# ============================
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torchvision import transforms as T
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
import random
from PIL import Image, ImageOps


# ============================
# Dataset Classes
# ============================
class XBDMulticlassDataset(Dataset):
    def __init__(self, root_dir, transform_pre=None, transform_post=None,
                 image_size=(1024, 1024), tile_size=(256, 256), max_images=None):
        self.image_dir = os.path.join(root_dir, "images")
        self.mask_dir = os.path.join(root_dir, "masks")
        self.files = sorted([f for f in os.listdir(self.image_dir) if '_pre_disaster.png' in f])
        if max_images:
            self.files = self.files[:max_images]
        self.transform_pre = transform_pre
        self.transform_post = transform_post
        self.image_size = image_size
        self.tile_size = tile_size
        self.tiles_per_image = (image_size[0] // tile_size[0]) * (image_size[1] // tile_size[1])

    def __getitem__(self, idx, return_raw=False):
        pre_file = self.files[idx]
        post_file = pre_file.replace('pre', 'post')
        mask_file = post_file.replace('.png', '_rgb.png')
        pre_path = os.path.join(self.image_dir, pre_file)
        post_path = os.path.join(self.image_dir, post_file)
        mask_path = os.path.join(self.mask_dir, mask_file)

        pre_img_raw = Image.open(pre_path).convert("RGB").resize(self.image_size)
        post_img_raw = Image.open(post_path).convert("RGB").resize(self.image_size)
        mask_img_raw = Image.open(mask_path).convert("RGB").resize(self.image_size, Image.NEAREST)

        def optical_to_sar_like(img):
            img = img.convert('L')
            img = ImageOps.autocontrast(img, cutoff=2)
            return img

        def update_mask_multiclass(mask_rgb_img):
            mask_np = np.array(mask_rgb_img)
            label_mask = np.zeros(mask_np.shape[:2], dtype=np.uint8)
            color_to_label = {
                (0, 0, 0): 0, (0, 255, 255): 0, (0, 0, 255): 1,
                (255, 255, 0): 2, (255, 0, 0): 3, (211, 211, 211): 0
            }
            for rgb, label in color_to_label.items():
                mask = np.all(mask_np == rgb, axis=-1)
                label_mask[mask] = label
            return label_mask

        post_img_sar_raw = optical_to_sar_like(post_img_raw)
        if return_raw:
            mask_np = update_mask_multiclass(mask_img_raw)
            return pre_img_raw, post_img_sar_raw, mask_np

        pre_img = self.transform_pre(pre_img_raw) if self.transform_pre else pre_img_raw
        post_img_sar = self.transform_post(post_img_sar_raw) if self.transform_post else post_img_sar_raw

        input_tensor = torch.cat([pre_img, post_img_sar], dim=0)
        mask_tensor = torch.tensor(update_mask_multiclass(mask_img_raw), dtype=torch.long)

        return tile_tensor_and_mask(input_tensor, mask_tensor, self.tile_size)

    def __len__(self):
        return len(self.files)

    def get_tile_dataset(self):
        return TiledXBDDataset(self)

class TiledXBDDataset(Dataset):
    def __init__(self, parent_dataset):
        self.parent_dataset = parent_dataset
        self.tiles_per_image = parent_dataset.tiles_per_image

    def __getitem__(self, idx):
        image_idx = idx // self.tiles_per_image
        tile_idx = idx % self.tiles_per_image
        tiles_input, tiles_mask = self.parent_dataset[image_idx]
        return tiles_input[tile_idx], tiles_mask[tile_idx]

    def __len__(self):
        return len(self.parent_dataset) * self.tiles_per_image

class FilteredTileDataset(Dataset):
    def __init__(self, tiled_dataset, keep_zero_damage_prob=0.1):
        self.tiled_dataset = tiled_dataset
        self.filtered_indices = []
        for idx in range(len(tiled_dataset)):
            _, mask = tiled_dataset[idx]
            if mask.max() == 0:
                if random.random() < keep_zero_damage_prob:
                    self.filtered_indices.append(idx)
            else:
                self.filtered_indices.append(idx)

    def __len__(self):
        return len(self.filtered_indices)

    def __getitem__(self, idx):
        return self.tiled_dataset[self.filtered_indices[idx]]

# ============================
# Helper Functions
# ============================
def tile_tensor_and_mask(input_tensor, mask_tensor, tile_size=(256, 256)):
    C, H, W = input_tensor.shape
    th, tw = tile_size
    tiles_input, tiles_mask = [], []
    for i in range(0, H, th):
        for j in range(0, W, tw):
            tiles_input.append(input_tensor[:, i:i+th, j:j+tw])
            tiles_mask.append(mask_tensor[i:i+th, j:j+tw])
    return tiles_input, tiles_mask

# ============================
# Visualization
# ============================
def unnormalize(img_tensor, mean, std):
    for t, m, s in zip(img_tensor, mean, std):
        t.mul_(s).add_(m)
    return img_tensor

def visualize_predictions(inputs, masks, outputs, n=4):
    preds = torch.argmax(outputs, dim=1).cpu().numpy()
    masks = masks.cpu().numpy()
    cmap = mcolors.ListedColormap(['black', 'blue', 'yellow', 'red'])
    bounds = [0, 1, 2, 3, 4]
    norm = mcolors.BoundaryNorm(bounds, cmap.N)

    for i in range(min(n, inputs.shape[0])):
        fig, axs = plt.subplots(1, 3, figsize=(12, 4))
        pre_rgb = unnormalize(inputs[i][:3].clone(), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        axs[0].imshow(np.transpose(torch.clamp(pre_rgb, 0, 1).cpu().numpy(), (1, 2, 0)))
        axs[0].set_title("Pre-disaster RGB")
        axs[0].axis('off')
        axs[1].imshow(inputs[i][3].cpu().numpy(), cmap='gray')
        axs[1].set_title("Post-disaster SAR")
        axs[1].axis('off')
        axs[2].imshow(preds[i], cmap=cmap, norm=norm)
        axs[2].set_title("Predicted")
        axs[2].axis('off')
        plt.show()
        plt.figure()
        plt.imshow(masks[i], cmap=cmap, norm=norm)
        plt.title("Ground Truth")
        plt.axis('off')
        plt.show()

print("Saved all")

Saved all


In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNetOriginal(nn.Module):
    def __init__(self, in_channels=4, out_classes=4):
        super().__init__()

        def conv_block(in_c, out_c):
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True)
            )

        # Encoder
        self.enc1 = conv_block(in_channels, 64)
        self.enc2 = conv_block(64, 128)
        self.enc3 = conv_block(128, 256)
        self.enc4 = conv_block(256, 512)

        self.pool = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = conv_block(512, 1024)

        # Decoder
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = conv_block(1024, 512)

        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = conv_block(512, 256)

        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = conv_block(256, 128)

        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = conv_block(128, 64)

        # Output
        self.final = nn.Conv2d(64, out_classes, kernel_size=1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        b = self.bottleneck(self.pool(e4))

        d4 = self.dec4(torch.cat([self.up4(b), e4], dim=1))
        d3 = self.dec3(torch.cat([self.up3(d4), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))

        return self.final(d1)
model_path = 'C:/Users/sweta/anaconda_projects/non-trivial/performance_bias/iteration_models/model_epoch7.pt'
model = UNetOriginal(in_channels=4, out_classes=4)
model.load_state_dict(torch.load(model_path, map_location='cpu'))
model.eval()
print("model loaded successfully")


model loaded successfully


  model.load_state_dict(torch.load(model_path, map_location='cpu'))


In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms as T
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
import random
from PIL import Image, ImageOps
import csv

# ============================
# Dataset Classes and Helpers
# ============================

class XBDMulticlassDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, transform_pre=None, transform_post=None,
                 image_size=(1024, 1024), tile_size=(256, 256), max_images=None):
        self.image_dir = os.path.join(root_dir, "images")
        self.mask_dir = os.path.join(root_dir, "masks")
        self.files = sorted([f for f in os.listdir(self.image_dir) if '_pre_disaster.png' in f])
        if max_images:
            self.files = self.files[:max_images]
        self.transform_pre = transform_pre
        self.transform_post = transform_post
        self.image_size = image_size
        self.tile_size = tile_size
        self.tiles_per_image = (image_size[0] // tile_size[0]) * (image_size[1] // tile_size[1])

    def __getitem__(self, idx, return_raw=False):
        pre_file = self.files[idx]
        post_file = pre_file.replace('pre', 'post')
        mask_file = post_file.replace('.png', '_rgb.png')
        pre_path = os.path.join(self.image_dir, pre_file)
        post_path = os.path.join(self.image_dir, post_file)
        mask_path = os.path.join(self.mask_dir, mask_file)

        pre_img_raw = Image.open(pre_path).convert("RGB").resize(self.image_size)
        post_img_raw = Image.open(post_path).convert("RGB").resize(self.image_size)
        mask_img_raw = Image.open(mask_path).convert("RGB").resize(self.image_size, Image.NEAREST)

        def optical_to_sar_like(img):
            img = img.convert('L')
            img = ImageOps.autocontrast(img, cutoff=2)
            return img

        def update_mask_multiclass(mask_rgb_img):
            mask_np = np.array(mask_rgb_img)
            label_mask = np.zeros(mask_np.shape[:2], dtype=np.uint8)
            color_to_label = {
                (0, 0, 0): 0, (0, 255, 255): 0, (0, 0, 255): 1,
                (255, 255, 0): 2, (255, 0, 0): 3, (211, 211, 211): 0
            }
            for rgb, label in color_to_label.items():
                mask = np.all(mask_np == rgb, axis=-1)
                label_mask[mask] = label
            return label_mask

        post_img_sar_raw = optical_to_sar_like(post_img_raw)
        if return_raw:
            mask_np = update_mask_multiclass(mask_img_raw)
            return pre_img_raw, post_img_sar_raw, mask_np

        pre_img = self.transform_pre(pre_img_raw) if self.transform_pre else pre_img_raw
        post_img_sar = self.transform_post(post_img_sar_raw) if self.transform_post else post_img_sar_raw

        input_tensor = torch.cat([pre_img, post_img_sar], dim=0)
        mask_tensor = torch.tensor(update_mask_multiclass(mask_img_raw), dtype=torch.long)

        return tile_tensor_and_mask(input_tensor, mask_tensor, self.tile_size)

    def __len__(self):
        return len(self.files)

    def get_tile_dataset(self):
        return TiledXBDDataset(self)

class TiledXBDDataset(torch.utils.data.Dataset):
    def __init__(self, parent_dataset):
        self.parent_dataset = parent_dataset
        self.tiles_per_image = parent_dataset.tiles_per_image

    def __getitem__(self, idx):
        image_idx = idx // self.tiles_per_image
        tile_idx = idx % self.tiles_per_image
        tiles_input, tiles_mask = self.parent_dataset[image_idx]
        return tiles_input[tile_idx], tiles_mask[tile_idx]

    def __len__(self):
        return len(self.parent_dataset) * self.tiles_per_image

class FilteredTileDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, keep_zero_damage_prob=0.1):
        """
        Filters full-size images from dataset before tiling to avoid unnecessary computation.
        """
        self.dataset = dataset  # this should be an XBDMulticlassDataset, not a tiled one
        self.valid_image_indices = []

        print("Filtering full images before tiling...")

        for idx in range(len(dataset)):
            if idx % 100 == 0:
                print(f"Checking image {idx}/{len(dataset)}")

            _, _, mask = dataset.__getitem__(idx, return_raw=True)
            if mask.max() == 0:
                if random.random() < keep_zero_damage_prob:
                    self.valid_image_indices.append(idx)
            else:
                self.valid_image_indices.append(idx)

        print(f"Kept {len(self.valid_image_indices)} of {len(dataset)} images after filtering.")
        self.tiles_per_image = dataset.tiles_per_image
        self.total_tiles = len(self.valid_image_indices) * self.tiles_per_image

    def __len__(self):
        return self.total_tiles

    def __getitem__(self, idx):
        image_idx = self.valid_image_indices[idx // self.tiles_per_image]
        tile_idx = idx % self.tiles_per_image
        tiles_input, tiles_mask = self.dataset[image_idx]
        return tiles_input[tile_idx], tiles_mask[tile_idx]


def tile_tensor_and_mask(input_tensor, mask_tensor, tile_size=(256, 256)):
    C, H, W = input_tensor.shape
    th, tw = tile_size
    tiles_input, tiles_mask = [], []
    for i in range(0, H, th):
        for j in range(0, W, tw):
            tiles_input.append(input_tensor[:, i:i+th, j:j+tw])
            tiles_mask.append(mask_tensor[i:i+th, j:j+tw])
    return tiles_input, tiles_mask

def unnormalize(img_tensor, mean, std):
    for t, m, s in zip(img_tensor, mean, std):
        t.mul_(s).add_(m)
    return img_tensor

def visualize_predictions(inputs, masks, outputs, n=4):
    preds = torch.argmax(outputs, dim=1).cpu().numpy()
    masks = masks.cpu().numpy()
    cmap = mcolors.ListedColormap(['black', 'blue', 'yellow', 'red'])
    bounds = [0, 1, 2, 3, 4]
    norm = mcolors.BoundaryNorm(bounds, cmap.N)

    for i in range(min(n, inputs.shape[0])):
        fig, axs = plt.subplots(1, 3, figsize=(12, 4))
        pre_rgb = unnormalize(inputs[i][:3].clone(), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        axs[0].imshow(np.transpose(torch.clamp(pre_rgb, 0, 1).cpu().numpy(), (1, 2, 0)))
        axs[0].set_title("Pre-disaster RGB")
        axs[0].axis('off')
        axs[1].imshow(inputs[i][3].cpu().numpy(), cmap='gray')
        axs[1].set_title("Post-disaster SAR")
        axs[1].axis('off')
        axs[2].imshow(preds[i], cmap=cmap, norm=norm)
        axs[2].set_title("Predicted")
        axs[2].axis('off')
        plt.show()
        plt.figure()
        plt.imshow(masks[i], cmap=cmap, norm=norm)
        plt.title("Ground Truth")
        plt.axis('off')
        plt.show()
import matplotlib.pyplot as plt

def save_visual_predictions(inputs, masks, outputs, epoch, visuals_path, n=30):
    preds = torch.argmax(outputs, dim=1).cpu().numpy()
    masks_np = masks.cpu().numpy()
    
    cmap = mcolors.ListedColormap(['black', 'blue', 'yellow', 'red'])
    bounds = [0, 1, 2, 3, 4]
    norm = mcolors.BoundaryNorm(bounds, cmap.N)

    for i in range(min(n, inputs.shape[0])):
        fig, axs = plt.subplots(1, 4, figsize=(16, 4))  # 4 subplots now

        # Pre-disaster RGB
        pre_rgb = unnormalize(inputs[i][:3].clone(), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        axs[0].imshow(np.transpose(torch.clamp(pre_rgb, 0, 1).cpu().numpy(), (1, 2, 0)))
        axs[0].set_title("Pre-disaster RGB")
        axs[0].axis('off')

        # Post-disaster SAR
        axs[1].imshow(inputs[i][3].cpu().numpy(), cmap='gray')
        axs[1].set_title("Post-disaster SAR")
        axs[1].axis('off')

        # Predicted mask
        axs[2].imshow(preds[i], cmap=cmap, norm=norm)
        axs[2].set_title("Predicted")
        axs[2].axis('off')

        # Ground truth mask
        axs[3].imshow(masks_np[i], cmap=cmap, norm=norm)
        axs[3].set_title("Ground Truth")
        axs[3].axis('off')

        fig_path = os.path.join(visuals_path, f'epoch{epoch}_sample{i}.png')
        plt.savefig(fig_path, bbox_inches='tight')
        plt.close(fig)  # Close to free memory

print("Dataset and helper functions loaded")

# ============================
# Model training + metrics + visualization
# ============================

# Define your model class or import it here
# from your_model_file import UNetOriginal

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Model & optimizer
model = UNetOriginal(in_channels=4, out_classes=4).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# Transforms
transform_pre = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
transform_post = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.5], std=[0.5])
])

data_root = r"C:\Users\sweta\.cache\kagglehub\datasets\qianlanzz\xbd-dataset\versions\1\xbd\tier1"
print("Loading dataset from:", data_root)

dataset = XBDMulticlassDataset(
    data_root,
    transform_pre=transform_pre,
    transform_post=transform_post,
    image_size=(1024, 1024),
    tile_size=(256, 256),
)

print(f"Loaded XBDMulticlassDataset with {len(dataset)} images")

filtered_dataset = FilteredTileDataset(dataset, keep_zero_damage_prob=0.1)
print(f"Filtered tiled dataset has {len(filtered_dataset)} tiles after zero-damage filtering")


train_loader = DataLoader(filtered_dataset, batch_size=8, shuffle=True, num_workers=0)
print("DataLoader created")

checkpoint_path = 'C:/Users/sweta/anaconda_projects/non-trivial/performance_bias/iteration_models/model_epoch7.pt'
if os.path.exists(checkpoint_path):
    print(f"Loading checkpoint from {checkpoint_path} ...")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint)  # checkpoint is state_dict only
    print("Checkpoint loaded. Model weights restored.")
    start_epoch = 8
else:
    print("No checkpoint found, starting from scratch")
    start_epoch = 0

num_epochs = 15  # total epochs

metrics_csv_path = r'C:/Users/sweta/anaconda_projects/non-trivial/performance_bias/iteration_models/metrics_log.csv'
visuals_path = r'C:/Users/sweta/anaconda_projects/non-trivial/performance_bias/iteration_models/visuals' #this is a folder

def compute_iou_per_class(preds, labels, num_classes=4):
    ious = []
    preds = preds.flatten()
    labels = labels.flatten()
    for cls in range(num_classes):
        pred_inds = preds == cls
        label_inds = labels == cls
        intersection = np.logical_and(pred_inds, label_inds).sum()
        union = np.logical_or(pred_inds, label_inds).sum()
        iou = intersection / union if union > 0 else 1.0
        ious.append(iou)
    return ious

def compute_dice_per_class(preds, labels, num_classes=4):
    dices = []
    preds = preds.flatten()
    labels = labels.flatten()
    for cls in range(num_classes):
        pred_inds = preds == cls
        label_inds = labels == cls
        intersection = 2 * np.logical_and(pred_inds, label_inds).sum()
        denominator = pred_inds.sum() + label_inds.sum()
        dice = intersection / denominator if denominator > 0 else 1.0
        dices.append(dice)
    return dices

for epoch in range(start_epoch, num_epochs):
    print(f"\n======== Epoch {epoch} ========")
    model.train()
    total_loss = 0
    for batch_idx, (inputs, masks) in enumerate(train_loader):
        print(f"Batch {batch_idx} - Loading data...")
        inputs, masks = inputs.to(device), masks.to(device)
        print(f"Batch {batch_idx} - Inputs shape: {inputs.shape}, Masks shape: {masks.shape}")

        optimizer.zero_grad()
        outputs = model(inputs)
        print(f"Batch {batch_idx} - Forward pass done. Outputs shape: {outputs.shape}")

        loss = criterion(outputs, masks)
        print(f"Batch {batch_idx} - Loss computed: {loss.item():.6f}")

        loss.backward()
        optimizer.step()
        print(f"Batch {batch_idx} - Backward pass and optimizer step done")

        total_loss += loss.item()

        if batch_idx % 10 == 0:
            print(f"Epoch {epoch} Batch {batch_idx} Loss: {loss.item():.6f}")

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch} Average Loss: {avg_loss:.6f}")

    # Save checkpoint
    checkpoint_file = f'model_epoch{epoch}.pt'
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, checkpoint_file)
    print(f"Checkpoint saved: {checkpoint_file}")

    # Evaluate on entire filtered_dataset for metrics
    model.eval()
    print("Evaluating on entire filtered dataset for metrics...")
    all_preds = []
    all_masks = []

    with torch.no_grad():
        filtered_loader_eval = DataLoader(filtered_dataset, batch_size=8, shuffle=False, num_workers=0)
        for batch_idx_eval, (inputs_eval, masks_eval) in enumerate(filtered_loader_eval):
            inputs_eval, masks_eval = inputs_eval.to(device), masks_eval.to(device)
            outputs_eval = model(inputs_eval)
            preds_eval = torch.argmax(outputs_eval, dim=1).cpu().numpy()
            masks_eval_np = masks_eval.cpu().numpy()

            all_preds.append(preds_eval)
            all_masks.append(masks_eval_np)

    all_preds_np = np.concatenate(all_preds, axis=0)
    all_masks_np = np.concatenate(all_masks, axis=0)

    ious = compute_iou_per_class(all_preds_np, all_masks_np)
    dices = compute_dice_per_class(all_preds_np, all_masks_np)
    print(f"Epoch {epoch} IoU per class: {ious}")
    print(f"Epoch {epoch} Dice per class: {dices}")

    # Append metrics to CSV
    file_exists = os.path.isfile(metrics_csv_path)
    with open(metrics_csv_path, 'a', newline='') as f:
        writer = csv.writer(f)
        if not file_exists:
            print(f"Creating new metrics CSV at {metrics_csv_path}")
            writer.writerow(['epoch',
                             'iou_class_0', 'iou_class_1', 'iou_class_2', 'iou_class_3',
                             'dice_class_0', 'dice_class_1', 'dice_class_2', 'dice_class_3'])
        writer.writerow([epoch] + ious + dices)
        print(f"Logged metrics for epoch {epoch} to CSV")

    # Visualize predictions on a small batch
    with torch.no_grad():
        print("Saving visual predictions on a sample batch...")
        inputs_viz, masks_viz = next(iter(train_loader))
        inputs_viz, masks_viz = inputs_viz.to(device), masks_viz.to(device)
        outputs_viz = model(inputs_viz)
        save_visual_predictions(inputs_viz, masks_viz, outputs_viz, epoch, visuals_path, n=15)

    print(f"Epoch {epoch} complete\n")

Dataset and helper functions loaded
Using device: cpu
Loading dataset from: C:\Users\sweta\.cache\kagglehub\datasets\qianlanzz\xbd-dataset\versions\1\xbd\tier1
Loaded XBDMulticlassDataset with 2799 images
Filtering full images before tiling...
Checking image 0/2799
Checking image 100/2799
Checking image 200/2799
Checking image 300/2799
Checking image 400/2799
Checking image 500/2799
Checking image 600/2799
Checking image 700/2799
Checking image 800/2799
Checking image 900/2799
Checking image 1000/2799
Checking image 1100/2799
Checking image 1200/2799
Checking image 1300/2799
Checking image 1400/2799
Checking image 1500/2799
Checking image 1600/2799
Checking image 1700/2799
Checking image 1800/2799
Checking image 1900/2799
Checking image 2000/2799
Checking image 2100/2799
Checking image 2200/2799
Checking image 2300/2799
Checking image 2400/2799
Checking image 2500/2799
Checking image 2600/2799
Checking image 2700/2799
Kept 1476 of 2799 images after filtering.
Filtered tiled dataset has

  checkpoint = torch.load(checkpoint_path, map_location=device)


Batch 0 - Loading data...
Batch 0 - Inputs shape: torch.Size([8, 4, 256, 256]), Masks shape: torch.Size([8, 256, 256])
Batch 0 - Forward pass done. Outputs shape: torch.Size([8, 4, 256, 256])
Batch 0 - Loss computed: 0.026508
Batch 0 - Backward pass and optimizer step done
Epoch 8 Batch 0 Loss: 0.026508
Batch 1 - Loading data...
Batch 1 - Inputs shape: torch.Size([8, 4, 256, 256]), Masks shape: torch.Size([8, 256, 256])
Batch 1 - Forward pass done. Outputs shape: torch.Size([8, 4, 256, 256])
Batch 1 - Loss computed: 0.063511
Batch 1 - Backward pass and optimizer step done
Batch 2 - Loading data...
Batch 2 - Inputs shape: torch.Size([8, 4, 256, 256]), Masks shape: torch.Size([8, 256, 256])
Batch 2 - Forward pass done. Outputs shape: torch.Size([8, 4, 256, 256])
Batch 2 - Loss computed: 0.028944
Batch 2 - Backward pass and optimizer step done
Batch 3 - Loading data...
Batch 3 - Inputs shape: torch.Size([8, 4, 256, 256]), Masks shape: torch.Size([8, 256, 256])
Batch 3 - Forward pass done.

In [26]:
import os
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

checkpoint_path = 'C:/Users/sweta/anaconda_projects/non-trivial/performance_bias/iteration_models/model_epoch7.pt'

if os.path.exists(checkpoint_path):
    print(f"Loading checkpoint from {checkpoint_path} ...")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint)  # load weights directly
    print("Checkpoint loaded. Model weights restored.")
else:
    print("Checkpoint not found.")


Loading checkpoint from C:/Users/sweta/anaconda_projects/non-trivial/performance_bias/iteration_models/model_epoch7.pt ...
Checkpoint loaded. Model weights restored.


  checkpoint = torch.load(checkpoint_path, map_location=device)
