In [None]:
# xBD Pipeline: Data Loading and Augmentation

This notebook demonstrates how to:
1. Load and preprocess xBD dataset images
2. Visualize sample images and their masks
3. Apply data augmentation and visualize the results
4. Load batches with a PyTorch `DataLoader`


In [None]:
import sys
sys.path.append('..')

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from src.utils.transforms import normalize_image, augment_image_mask, resize_image_mask
from src.data.dataset import XBDDataset


In [None]:
## 1. Load Dataset


In [None]:
# Build the dataset from the ../Data/train directory. Augmentation is switched
# off here so the samples displayed further down are not randomly altered.
data_dir = Path('../Data/train')
dataset = XBDDataset(data_dir=data_dir, image_size=(512, 512), augment=False)

# Report how many samples the dataset exposes
print(f'Dataset size: {len(dataset)}')


In [None]:
## 2. Visualize Sample Images


In [None]:
def plot_sample(image, mask, title):
    """Display an image and its mask side by side under a shared title.

    Args:
        image: Data for the left panel; passed straight to ``imshow``.
        mask: Data for the right panel; passed to ``imshow`` with the
            ``'gray'`` colormap.
        title: Text used as the figure-level title.

    Returns:
        None. The figure is rendered with ``plt.show()``.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    # (axis, data, panel title, extra imshow options) for each panel,
    # listed left to right
    panels = (
        (axes[0], image, 'Image', {}),
        (axes[1], mask, 'Mask', {'cmap': 'gray'}),
    )
    for ax, data, label, imshow_options in panels:
        ax.imshow(data, **imshow_options)
        ax.set_title(label)
        ax.axis('off')  # hide ticks and frame

    plt.suptitle(title)
    plt.show()

# Fetch the first dataset entry and preview it
image, mask = dataset[0]
plot_sample(
    # move axis 0 to the end (presumably channels-first -> channels-last for imshow)
    image.numpy().transpose(1, 2, 0),
    # drop size-1 dimensions from the mask
    mask.numpy().squeeze(),
    'Original Sample',
)


In [None]:
## 3. Data Augmentation Examples


In [None]:
# The augmentation helper is fed numpy arrays, so convert the sample tensors:
# the image gets axis 0 moved to the end, the mask loses its size-1 dimensions
image_np = np.transpose(image.numpy(), (1, 2, 0))
mask_np = np.squeeze(mask.numpy())

# Flag combinations to demonstrate: flip only, rotate only, then both
augmentations = [
    dict(do_flip=True, do_rotate=False),
    dict(do_flip=False, do_rotate=True),
    dict(do_flip=True, do_rotate=True),
]

# Run every configuration on the same sample and plot each result
for i, aug_params in enumerate(augmentations):
    aug_img, aug_mask = augment_image_mask(image_np, mask_np, **aug_params)
    plot_sample(aug_img, aug_mask, f'Augmentation {i+1}')


In [None]:
## 4. Batch Loading Example


In [None]:
from torch.utils.data import DataLoader

# Requested samples per batch. Defined once so the loader and the plot below
# cannot drift apart.
BATCH_SIZE = 4

# Create dataloader (num_workers=0 loads data in the main process)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

# Get a batch
images, masks = next(iter(dataloader))

# A batch holds fewer than BATCH_SIZE samples when the dataset is smaller than
# the batch size, so size the grid from the batch itself rather than assuming
# four rows (indexing images[3] on a short batch would raise IndexError).
n_rows = len(images)

# Plot batch: one row per sample, image on the left and mask on the right.
# squeeze=False keeps `axes` two-dimensional even when there is a single row,
# so axes[i, 0] / axes[i, 1] indexing always works.
fig, axes = plt.subplots(n_rows, 2, figsize=(12, 6 * n_rows), squeeze=False)
for i in range(n_rows):
    # move axis 0 to the end so imshow receives the image in its expected layout
    axes[i, 0].imshow(images[i].numpy().transpose(1, 2, 0))
    axes[i, 0].set_title(f'Image {i+1}')
    axes[i, 0].axis('off')

    axes[i, 1].imshow(masks[i].numpy().squeeze(), cmap='gray')
    axes[i, 1].set_title(f'Mask {i+1}')
    axes[i, 1].axis('off')

plt.tight_layout()
plt.show()
