# Notebook 1: Dataset Preparation and Mask Generation

This notebook prepares the dataset for mural restoration by:
1. Loading ground truth images from Mural512
2. Generating synthetic damage masks (random and structured patterns)
3. Creating damaged versions of images
4. Organizing data into the project structure


In [None]:
import os
import random
import sys
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage import binary_dilation, binary_erosion
from skimage.draw import line, polygon
from tqdm import tqdm

try:
    from skimage.draw import circle  # scikit-image < 0.19
except ImportError:
    # scikit-image 0.19 removed `circle`; `disk((r, c), radius, shape=...)` replaces it.
    from skimage.draw import disk as _disk

    def circle(r, c, radius, shape=None):
        """Compatibility wrapper for ``skimage.draw.circle`` on scikit-image >= 0.19."""
        return _disk((r, c), radius, shape=shape)

# Add parent directory to path for utils
sys.path.append('..')
from utils import load_image, save_image, ensure_dir, get_image_files

# Fix the seeds of both RNGs used below (Python's `random` and NumPy's global
# generator) so that every generated mask is reproducible across runs.
random.seed(42)
np.random.seed(42)

print("Libraries imported successfully!")


## 1. Set Up Directories and Load Dataset


In [None]:
# Source images (Mural512) and the three output folders of the prepared dataset
MURAL512_PATH = Path('../MuralDH/Mural512')
GROUND_TRUTH_DIR = Path('../data/ground_truth')
DAMAGED_DIR = Path('../data/damaged')
MASKS_DIR = Path('../data/masks')

# Create every output folder before anything is written into it
for _output_dir in (GROUND_TRUTH_DIR, DAMAGED_DIR, MASKS_DIR):
    ensure_dir(_output_dir)

# Collect all PNG files from Mural512
image_files = get_image_files(MURAL512_PATH, '*.png')
print(f"Found {len(image_files)} images in Mural512")

# Sanity check: report shape, dtype and name of the first image
if image_files:
    sample_img = load_image(image_files[0])
    print(f"\nSample image shape: {sample_img.shape}")
    print(f"Sample image dtype: {sample_img.dtype}")
    print(f"Sample image path: {image_files[0].name}")


## 2. Mask Generation Functions

We'll create two types of masks, which `generate_combined_mask` can also layer together (`mask_type='both'`):
- **Random masks**: Circles, rectangles, irregular polygons
- **Structured masks**: Crack patterns, scratch lines, patch regions


In [None]:
def generate_random_circle_mask(height, width, min_radius=20, max_radius=80):
    """Return a uint8 mask (0/255) containing one randomly placed filled circle.

    The centre is drawn so that a circle of radius ``max_radius`` fits entirely
    inside the image, and the radius is drawn from ``[min_radius, max_radius)``.
    A pixel is inside the circle when its squared distance to the centre is
    strictly less than ``radius ** 2`` (the same rule ``skimage.draw.disk`` uses).

    Raises:
        ValueError: from ``np.random.randint`` if the image is not larger than
            ``2 * max_radius`` in both dimensions, or ``min_radius >= max_radius``.
    """
    # Same draw order as before (x, y, radius) so seeded runs stay reproducible.
    center_x = np.random.randint(max_radius, width - max_radius)
    center_y = np.random.randint(max_radius, height - max_radius)
    radius = np.random.randint(min_radius, max_radius)

    # Rasterize with plain NumPy: `skimage.draw.circle` was removed in
    # scikit-image 0.19, so relying on it breaks on current releases.
    rows, cols = np.ogrid[:height, :width]
    inside = (rows - center_y) ** 2 + (cols - center_x) ** 2 < radius ** 2

    mask = np.zeros((height, width), dtype=np.uint8)
    mask[inside] = 255
    return mask


def generate_random_rectangle_mask(height, width, min_size=30, max_size=150):
    """Return a uint8 mask (0/255) holding one random axis-aligned rectangle.

    The top-left corner is drawn from ``[0, dim - max_size)`` so that any side
    length drawn from ``[min_size, max_size)`` keeps the rectangle inside the image.
    """
    left = np.random.randint(0, width - max_size)
    top = np.random.randint(0, height - max_size)
    rect_w = np.random.randint(min_size, max_size)
    rect_h = np.random.randint(min_size, max_size)

    right = min(left + rect_w, width)
    bottom = min(top + rect_h, height)

    mask = np.zeros((height, width), dtype=np.uint8)
    mask[top:bottom, left:right] = 255
    return mask


def generate_random_polygon_mask(height, width, num_vertices=5):
    """Return a uint8 mask (0/255) filled inside a polygon with random vertices.

    Each vertex is an (x, y) point drawn uniformly over the image. Vertices are
    not sorted, so the polygon may be self-intersecting. ``skimage.draw.polygon``
    rasterizes it, clipped to the image shape.
    """
    # x is drawn before y for every vertex, exactly as the RNG sequence expects.
    corners = np.array([
        [np.random.randint(0, width), np.random.randint(0, height)]
        for _ in range(num_vertices)
    ])

    mask = np.zeros((height, width), dtype=np.uint8)
    rows, cols = polygon(corners[:, 1], corners[:, 0], shape=(height, width))
    mask[rows, cols] = 255
    return mask


def generate_crack_mask(height, width, num_cracks=3):
    """Return a uint8 mask (0/255) of ``num_cracks`` straight crack lines.

    Each crack joins two uniformly random pixels. All cracks are drawn first and
    the union is then dilated once with a 3x3 structuring element. Every crack
    therefore gets the same thickness (about 3 px), matching how
    ``generate_scratch_mask`` thickens its lines.
    """
    mask = np.zeros((height, width), dtype=np.uint8)

    for _ in range(num_cracks):
        # Random start and end points
        x1 = np.random.randint(0, width)
        y1 = np.random.randint(0, height)
        x2 = np.random.randint(0, width)
        y2 = np.random.randint(0, height)

        # Draw line, keeping only in-bounds pixels
        rr, cc = line(y1, x1, y2, x2)
        valid = (rr >= 0) & (rr < height) & (cc >= 0) & (cc < width)
        mask[rr[valid], cc[valid]] = 255

    # Thicken once, after every crack is drawn. Dilating inside the loop would
    # re-dilate earlier cracks on each iteration, making them progressively wider.
    mask = binary_dilation(mask, structure=np.ones((3, 3))).astype(np.uint8) * 255
    return mask


def generate_scratch_mask(height, width, num_scratches=5):
    """Return a uint8 mask (0/255) of thin straight scratches.

    Each scratch starts at a random pixel and heads in a random direction for
    50-199 px. Its end point is clamped to the image. All scratches are then
    thickened together with a 2x2 dilation.
    """
    mask = np.zeros((height, width), dtype=np.uint8)

    for _ in range(num_scratches):
        start_x = np.random.randint(0, width)
        start_y = np.random.randint(0, height)
        theta = np.random.uniform(0, 2 * np.pi)
        span = np.random.randint(50, 200)

        # Truncate to int first, then clamp onto the image
        end_x = np.clip(int(start_x + span * np.cos(theta)), 0, width - 1)
        end_y = np.clip(int(start_y + span * np.sin(theta)), 0, height - 1)

        rr, cc = line(start_y, start_x, end_y, end_x)
        inside = (rr >= 0) & (rr < height) & (cc >= 0) & (cc < width)
        mask[rr[inside], cc[inside]] = 255

    return binary_dilation(mask, structure=np.ones((2, 2))).astype(np.uint8) * 255


def generate_patch_mask(height, width, num_patches=2):
    """Return a uint8 mask (0/255) of ``num_patches`` large rectangular patches.

    Patch sides are drawn from ``[50, min(200, dim // 3))``. This needs both
    dimensions to be at least 153 px; smaller images make ``np.random.randint``
    raise ``ValueError``.
    """
    mask = np.zeros((height, width), dtype=np.uint8)
    height_cap = min(200, height // 3)
    width_cap = min(200, width // 3)

    for _ in range(num_patches):
        patch_h = np.random.randint(50, height_cap)
        patch_w = np.random.randint(50, width_cap)
        left = np.random.randint(0, width - patch_w)
        top = np.random.randint(0, height - patch_h)
        mask[top:top + patch_h, left:left + patch_w] = 255

    return mask


def generate_combined_mask(height, width, mask_type='random',
                           min_coverage=0.05, max_coverage=0.25):
    """
    Generate a combined mask with multiple damage patterns.

    Args:
        height: Image height
        width: Image width
        mask_type: 'random', 'structured', or 'both'
        min_coverage: Target lower bound on the damaged fraction of the image.
            Extra random rectangles (sides 40-99 px) are added until it is
            reached, for at most a fixed number of attempts.
        max_coverage: Target upper bound on the damaged fraction. The mask is
            eroded with a 5x5 structuring element until it is at or below this
            bound, for at most a fixed number of attempts and never down to an
            empty mask.

    Returns:
        uint8 array of shape (height, width) with values 0 and 255.

    Raises:
        ValueError: if ``mask_type`` is not 'random', 'structured' or 'both'.
    """
    if mask_type not in ('random', 'structured', 'both'):
        raise ValueError(
            f"mask_type must be 'random', 'structured' or 'both', got {mask_type!r}"
        )

    mask = np.zeros((height, width), dtype=np.uint8)

    if mask_type in ('random', 'both'):
        # Add 2-5 random shapes
        num_shapes = np.random.randint(2, 6)
        for _ in range(num_shapes):
            shape_type = np.random.choice(['circle', 'rectangle', 'polygon'])
            if shape_type == 'circle':
                shape_mask = generate_random_circle_mask(height, width)
            elif shape_type == 'rectangle':
                shape_mask = generate_random_rectangle_mask(height, width)
            else:
                shape_mask = generate_random_polygon_mask(height, width)
            mask = np.maximum(mask, shape_mask)

    if mask_type in ('structured', 'both'):
        # Each structured pattern is included independently at random
        # (cracks ~50%, scratches ~50%, patches ~70%), so possibly none of them.
        if np.random.random() > 0.5:
            mask = np.maximum(mask, generate_crack_mask(height, width))
        if np.random.random() > 0.5:
            mask = np.maximum(mask, generate_scratch_mask(height, width))
        if np.random.random() > 0.3:
            mask = np.maximum(mask, generate_patch_mask(height, width))

    total_pixels = height * width
    max_adjust_steps = 20  # safety cap so both adjustment loops always terminate

    # Enforce the lower bound. A single extra rectangle can leave coverage
    # below it (e.g. when no structured pattern was drawn), so keep adding.
    for _ in range(max_adjust_steps):
        if np.count_nonzero(mask) / total_pixels >= min_coverage:
            break
        mask = np.maximum(mask, generate_random_rectangle_mask(height, width, 40, 100))

    # Enforce the upper bound by repeated erosion. Stop before the mask would
    # vanish entirely, since thin cracks/scratches disappear quickly under 5x5.
    for _ in range(max_adjust_steps):
        if np.count_nonzero(mask) / total_pixels <= max_coverage:
            break
        eroded = binary_erosion(mask, structure=np.ones((5, 5))).astype(np.uint8) * 255
        if not eroded.any():
            break
        mask = eroded

    return mask

print("Mask generation functions defined!")


In [None]:
# Preview: build one mask of each type for the first image and plot them in a 2x3 grid
if image_files:
    sample_img = load_image(image_files[0])
    h, w = sample_img.shape[:2]

    # Call order matters for reproducibility: each call advances NumPy's global RNG
    random_mask = generate_combined_mask(h, w, 'random')
    structured_mask = generate_combined_mask(h, w, 'structured')
    combined_mask = generate_combined_mask(h, w, 'both')

    # Damaged preview: black out every pixel covered by the combined mask
    damaged = sample_img.copy()
    damaged[combined_mask == 255] = 0

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    # (image, colormap, title) per panel, in row-major grid order
    panels = [
        (sample_img, None, 'Original Image'),
        (random_mask, 'gray', 'Random Mask'),
        (cv2.bitwise_and(sample_img, sample_img, mask=random_mask), None, 'Random Mask Applied'),
        (structured_mask, 'gray', 'Structured Mask'),
        (combined_mask, 'gray', 'Combined Mask'),
        (damaged, None, 'Damaged Image'),
    ]
    for ax, (panel_img, cmap, title) in zip(axes.flat, panels):
        ax.imshow(panel_img, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

    print(f"Mask coverage: {np.sum(combined_mask > 0) / (h * w) * 100:.2f}%")


In [None]:
def create_damaged_image(image, mask, damage_type='black'):
    """
    Create damaged image by applying mask.

    Args:
        image: Original image. It is copied, never modified in place.
        mask: Damage mask whose shape matches the leading (height, width)
            dimensions of ``image``. Any nonzero / True pixel counts as
            damaged, so 0/255, 0/1 and boolean masks all work.
        damage_type: 'black', 'white', or 'noise'. Any other value falls
            back to 'black'.

    Returns:
        A copy of ``image`` whose masked pixels are set to 0 ('black'),
        255 ('white') or uniformly random uint8 values ('noise').
    """
    damaged = image.copy()
    # `> 0` rather than `== 255`: comparing a boolean or 0/1 mask with 255
    # selects nothing and would silently leave the image undamaged.
    mask_bool = np.asarray(mask) > 0

    if damage_type == 'white':
        damaged[mask_bool] = 255
    elif damage_type == 'noise':
        # One random value per selected element: (N,) for single-channel
        # images, (N, C) when the image has trailing channel dimensions.
        noise_shape = (int(np.count_nonzero(mask_bool)),) + damaged.shape[mask_bool.ndim:]
        damaged[mask_bool] = np.random.randint(0, 256, size=noise_shape, dtype=np.uint8)
    else:
        # 'black' and any unrecognised damage type
        damaged[mask_bool] = 0

    return damaged

print("Damage function defined!")


In [None]:
# Process all images: build a mask, damage the image, and save the
# ground-truth / damaged / mask triple under the source file's base name.
print(f"Processing {len(image_files)} images...")

# Statistics
mask_types_used = {'random': 0, 'structured': 0, 'both': 0}
damage_types_used = {'black': 0, 'white': 0, 'noise': 0}
mask_coverage_stats = []

for idx, img_path in enumerate(tqdm(image_files, desc="Processing images")):
    # Load image
    img = load_image(img_path)
    h, w = img.shape[:2]

    # Mask type cycles every image; damage type cycles every third image.
    # Keying both on `idx % 3` ties them together (random->black,
    # structured->white, both->noise), so 6 of the 9 combinations never occur.
    mask_type = ('random', 'structured', 'both')[idx % 3]
    damage_type = ('black', 'white', 'noise')[(idx // 3) % 3]
    mask_types_used[mask_type] += 1
    damage_types_used[damage_type] += 1

    # Generate mask and record its coverage
    mask = generate_combined_mask(h, w, mask_type)
    mask_coverage = np.count_nonzero(mask) / (h * w)
    mask_coverage_stats.append(mask_coverage)

    # Create damaged image
    damaged_img = create_damaged_image(img, mask, damage_type)

    # Save files
    base_name = img_path.stem
    save_image(img, GROUND_TRUTH_DIR / f"{base_name}.png")
    save_image(damaged_img, DAMAGED_DIR / f"{base_name}.png")
    # cv2.imwrite reports failure by returning False instead of raising
    if not cv2.imwrite(str(MASKS_DIR / f"{base_name}.png"), mask):
        print(f"\n⚠ Failed to write mask for {base_name}")

print("\nProcessing complete!")
print("\nMask type distribution:")
for k, v in mask_types_used.items():
    print(f"  {k}: {v}")
print("\nDamage type distribution:")
for k, v in damage_types_used.items():
    print(f"  {k}: {v}")
print("\nMask coverage statistics:")
if mask_coverage_stats:
    print(f"  Mean: {np.mean(mask_coverage_stats)*100:.2f}%")
    print(f"  Std: {np.std(mask_coverage_stats)*100:.2f}%")
    print(f"  Min: {np.min(mask_coverage_stats)*100:.2f}%")
    print(f"  Max: {np.max(mask_coverage_stats)*100:.2f}%")
else:
    # np.min / np.max raise ValueError on an empty sequence
    print("  (no images processed)")


## 5. Verify Data Consistency


In [None]:
# Verify that every source image produced a ground-truth / damaged / mask triple
gt_files = get_image_files(GROUND_TRUTH_DIR)
damaged_files = get_image_files(DAMAGED_DIR)
mask_files = get_image_files(MASKS_DIR)

print(f"Ground truth images: {len(gt_files)}")
print(f"Damaged images: {len(damaged_files)}")
print(f"Mask images: {len(mask_files)}")

# Check if counts match
if len(gt_files) == len(damaged_files) == len(mask_files) == len(image_files):
    print("\n✓ All files created successfully! Counts match.")
else:
    print("\n⚠ Warning: File counts don't match!")

# Equal counts do not prove the same files exist (e.g. leftovers from an
# earlier run), so also compare file names (stems) across the directories.
damaged_by_stem = {p.stem: p for p in damaged_files}
mask_by_stem = {p.stem: p for p in mask_files}
gt_stems = {p.stem for p in gt_files}
unpaired = (gt_stems ^ set(damaged_by_stem)) | (gt_stems ^ set(mask_by_stem))
if unpaired:
    print(f"⚠ Warning: {len(unpaired)} file name(s) missing from at least one "
          f"directory, e.g. {sorted(unpaired)[:5]}")

# Verify one sample triple, matched by file name rather than list position
if gt_files and damaged_files and mask_files:
    sample_name = gt_files[0].stem
    damaged_path = damaged_by_stem.get(sample_name)
    mask_path = mask_by_stem.get(sample_name)

    if damaged_path is None or mask_path is None:
        print(f"\n⚠ No damaged image and/or mask found for sample {sample_name}")
    else:
        gt_sample = load_image(gt_files[0])
        damaged_sample = load_image(damaged_path)
        # cv2.imread returns None (rather than raising) when a file can't be read
        mask_sample = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)

        if mask_sample is None:
            print(f"\n⚠ Could not read mask {mask_path}")
        else:
            print(f"\nSample verification ({sample_name}):")
            print(f"  Ground truth shape: {gt_sample.shape}")
            print(f"  Damaged shape: {damaged_sample.shape}")
            print(f"  Mask shape: {mask_sample.shape}")

            if gt_sample.shape == damaged_sample.shape == (mask_sample.shape[0], mask_sample.shape[1], 3):
                print("  ✓ Shapes match!")
            else:
                print("  ⚠ Shape mismatch!")


## 6. Summary

Dataset preparation complete! The following data has been created:
- **Ground truth images**: `data/ground_truth/` (5096 images)
- **Damaged images**: `data/damaged/` (5096 images)
- **Masks**: `data/masks/` (5096 masks)

All images are ready for restoration methods in subsequent notebooks.
