Stitches mask patches back together while preserving the original raster projection information from georeferenced TIFF files.

In [1]:
from PIL import Image
import numpy as np
import rasterio
from rasterio.transform import from_bounds
import glob
import os
from pathlib import Path
import re

In [9]:
# --- Configuration -----------------------------------------------------------
# Name of the map to process; every input and output path below is derived from it.
map_name = "Selworthy"
patch_size = 256  # edge length in pixels of one square mask patch
original_image_path = f"../../Full/{map_name}.tif"
mask_patches_dir = f"{map_name}/{map_name}_predictions/masks"  # Directory containing mask patches

# --- Georeferencing metadata of the original raster ---------------------------
# Only dataset attributes are read here; no pixel data is loaded.
with rasterio.open(original_image_path) as src:
    original_crs, original_transform, original_bounds = src.crs, src.transform, src.bounds
    original_height, original_width = src.height, src.width
    original_dtype = src.dtypes[0]  # dtype of the first band

# Echo what was read so it can be compared with the stitched output later on.
print(
    f"Original image dimensions: {original_width} x {original_height}",
    f"Original CRS: {original_crs}",
    f"Original transform: {original_transform}",
    f"Original bounds: {original_bounds}",
    sep="\n",
)

Original image dimensions: 16534 x 15592
Original CRS: LOCAL_CS["OSGB 1936 / British National Grid",UNIT["metre",1,AUTHORITY["EPSG","9001"]],AXIS["Easting",EAST],AXIS["Northing",NORTH]]
Original transform: | 0.50, 0.00, 288217.83|
| 0.00,-0.50, 150356.23|
| 0.00, 0.00, 1.00|
Original bounds: BoundingBox(left=288217.8345919344, bottom=142560.22751245857, right=296484.8345919344, top=150356.22751245857)


In [10]:
# Find all mask patch files; sorting keeps the processing order deterministic
patch_files = sorted(glob.glob(os.path.join(mask_patches_dir, "patch_*.png")))

if not patch_files:
    raise FileNotFoundError(f"No patch files found in {mask_patches_dir}")

print(f"Found {len(patch_files)} mask patches")

# Map each (row index i, column index j) parsed from the filename to its patch file.
# Expected filename format: patch_<i>_<j>.png
patch_dict = {}
max_i, max_j = 0, 0

for patch_file in patch_files:
    filename = os.path.basename(patch_file)
    match = re.search(r'patch_(\d+)_(\d+)', filename)
    if match is None:
        # The glob is looser than the regex (e.g. "patch_extra.png" matches the glob),
        # so report such files instead of dropping them silently.
        print(f"Warning: Skipping file with unexpected name: {filename}")
        continue
    i, j = int(match.group(1)), int(match.group(2))
    patch_dict[(i, j)] = patch_file
    max_i = max(max_i, i)
    max_j = max(max_j, j)

# Without this guard an empty mapping would only surface later as an opaque IndexError.
if not patch_dict:
    raise FileNotFoundError(
        f"No patch files named patch_<i>_<j>.png found in {mask_patches_dir}"
    )

grid_rows = max_i + 1
grid_cols = max_j + 1
print(f"Patch grid dimensions: {grid_rows} x {grid_cols}")

# Calculate the padded dimensions based on patch grid
padded_height = grid_rows * patch_size
padded_width = grid_cols * patch_size

# Calculate the padding that was originally applied
pad_h = padded_height - original_height
pad_w = padded_width - original_width

print(f"Padded dimensions: {padded_width} x {padded_height}")
print(f"Original padding applied: height={pad_h}, width={pad_w}")
print(f"Will crop final image to: {original_width} x {original_height}")

# Fail before the expensive stitching step if the grid cannot cover the original raster;
# cropping could never restore the original dimensions in that case.
if pad_h < 0 or pad_w < 0:
    raise ValueError(
        f"Patch grid ({padded_width} x {padded_height}) is smaller than the original image "
        f"({original_width} x {original_height}); patches are missing or patch_size is wrong"
    )

# Load a sample patch to determine the number of channels and the dtype.
# The context manager closes the file handle as soon as the pixels are copied out.
with Image.open(next(iter(patch_dict.values()))) as sample_patch:
    sample_array = np.array(sample_patch)

# A 2-D array is a grayscale patch; otherwise the third axis holds the channels (RGB/RGBA).
channels = 1 if sample_array.ndim == 2 else sample_array.shape[2]
channel_dims = sample_array.shape[2:]  # () for grayscale, (channels,) otherwise
expected_patch_shape = (patch_size, patch_size) + channel_dims
stitched_array = np.zeros((padded_height, padded_width) + channel_dims, dtype=sample_array.dtype)

print(f"Mask patches are {'grayscale' if channels == 1 else f'{channels}-channel'}")

# Stitch patches together; cells without a patch keep the zero fill
for i in range(grid_rows):
    for j in range(grid_cols):
        patch_path = patch_dict.get((i, j))
        if patch_path is None:
            print(f"Warning: Missing patch at position ({i}, {j})")
            continue

        # Load patch and release the file handle immediately (thousands of patches are read)
        with Image.open(patch_path) as patch_img:
            patch_array = np.array(patch_img)

        # Name the offending file rather than failing with a bare numpy broadcast error
        if patch_array.shape != expected_patch_shape:
            raise ValueError(
                f"Patch {patch_path} has shape {patch_array.shape}, "
                f"expected {expected_patch_shape}"
            )

        # Calculate position in the full array
        start_row = i * patch_size
        end_row = start_row + patch_size
        start_col = j * patch_size
        end_col = start_col + patch_size

        # Slicing the first two axes works for both grayscale and multi-channel arrays
        stitched_array[start_row:end_row, start_col:end_col] = patch_array

print("Patches stitched successfully")

# Remove padding to restore original dimensions (the padding sits at the bottom/right edge,
# since the crop keeps the top-left corner). The result is a view into stitched_array.
final_array = stitched_array[:original_height, :original_width]

print(f"Final array shape after cropping: {final_array.shape}")
print(f"Expected shape: ({original_height}, {original_width}{', ' + str(channels) if channels > 1 else ''})")

# Verify the dimensions match
expected_shape = (original_height, original_width) if channels == 1 else (original_height, original_width, channels)
assert final_array.shape == expected_shape, f"Shape mismatch: got {final_array.shape}, expected {expected_shape}"

Found 3965 mask patches
Patch grid dimensions: 61 x 65
Padded dimensions: 16640 x 15616
Original padding applied: height=24, width=106
Will crop final image to: 16534 x 15592
Mask patches are grayscale
Patches stitched successfully
Final array shape after cropping: (15592, 16534)
Expected shape: (15592, 16534)


In [11]:
# Prepare output filename
output_filename = f"{map_name}/{map_name}_stitched_mask.tif"

# Determine the appropriate data type for the output.
# PIL yields boolean arrays for 1-bit (mode "1") PNGs, but numpy's bool is not among the
# dtypes rasterio can write, so such masks are stored as uint8 with values 0/1.
# All other dtypes are written unchanged.
if final_array.dtype == np.bool_:
    output_source = final_array.astype(np.uint8)
else:
    output_source = final_array
output_dtype = output_source.dtype

# Prepare array for rasterio, which expects (bands, height, width)
if channels == 1:
    # For grayscale, add a leading band axis of length 1
    output_array = output_source.reshape(1, original_height, original_width)
    count = 1
else:
    # For multi-channel, move the channel axis to the front
    output_array = np.transpose(output_source, (2, 0, 1))
    count = channels

# Save the georeferenced TIFF, reusing the CRS and affine transform of the original raster
# so the mask lines up pixel-for-pixel with the source image
with rasterio.open(
    output_filename,
    'w',
    driver='GTiff',
    height=original_height,
    width=original_width,
    count=count,
    dtype=output_dtype,
    crs=original_crs,
    transform=original_transform,
    compress='lzw'  # Optional compression
) as dst:
    dst.write(output_array)

print(f"Georeferenced mask saved as: {output_filename}")
print(f"Output dimensions: {original_width} x {original_height}")
print(f"Output CRS: {original_crs}")
print(f"Output channels: {count}")

# Verify the output file by re-opening it and reading back its georeferencing
with rasterio.open(output_filename) as verify:
    print("\nVerification:")
    print(f"Saved file CRS: {verify.crs}")
    print(f"Saved file transform: {verify.transform}")
    print(f"Saved file bounds: {verify.bounds}")
    print(f"Saved file shape: {verify.shape}")
    # Make the check enforceable instead of relying on someone reading the printout
    assert verify.shape == (original_height, original_width), (
        f"Saved raster shape {verify.shape} differs from original "
        f"({original_height}, {original_width})"
    )

Georeferenced mask saved as: Selworthy/Selworthy_stitched_mask.tif
Output dimensions: 16534 x 15592
Output CRS: LOCAL_CS["OSGB 1936 / British National Grid",UNIT["metre",1,AUTHORITY["EPSG","9001"]],AXIS["Easting",EAST],AXIS["Northing",NORTH]]
Output channels: 1

Verification:
Saved file CRS: LOCAL_CS["OSGB 1936 / British National Grid",UNIT["metre",1,AUTHORITY["EPSG","9001"]],AXIS["Easting",EAST],AXIS["Northing",NORTH]]
Saved file transform: | 0.50, 0.00, 288217.83|
| 0.00,-0.50, 150356.23|
| 0.00, 0.00, 1.00|
Saved file bounds: BoundingBox(left=288217.8345919344, bottom=142560.22751245857, right=296484.8345919344, top=150356.22751245857)
Saved file shape: (15592, 16534)
