In [None]:
import os
import numpy as np
import rasterio
from PIL import Image
import random

# --- Paths ------------------------------------------------------------------
# Source rasters (.jp2), their label masks (.png), and the root folder that
# receives the generated crops.
imageDir = r"C:\Users\wasif\Desktop\fine_tuning\images"
maskDir = r"C:\Users\wasif\Desktop\fine_tuning\masks"
outputDir = r"C:\Users\wasif\Desktop\fineTuned"

# --- Config -----------------------------------------------------------------
cropSize = 256        # side length of each square crop, in pixels
cropsPerImage = 200   # upper bound on crops written per source image

# --- Output folders ---------------------------------------------------------
# Crops go to <outputDir>/images and <outputDir>/masks under the same base name.
os.makedirs(os.path.join(outputDir, 'images'), exist_ok=True)
os.makedirs(os.path.join(outputDir, 'masks'), exist_ok=True)

# --- Input listing ----------------------------------------------------------
# Both listings are sorted because the processing loop pairs image i with
# mask i purely by position.
imageFiles = sorted(f for f in os.listdir(imageDir) if f.endswith(".jp2"))
maskFiles = sorted(f for f in os.listdir(maskDir) if f.endswith(".png"))

if len(imageFiles) != len(maskFiles):
    print("Error: Mismatch between number of images and masks.")
    # Raise SystemExit directly instead of calling exit(): exit() is a helper
    # injected by the `site` module for interactive sessions, and with no
    # argument it would also report a success (0) status for this failure.
    raise SystemExit(1)

# Running index over every crop written so far; used to build unique file names.
globalCounter = 0

# Process each image/mask pair: cut up to `cropsPerImage` non-overlapping
# cropSize x cropSize tiles (chosen at random from a regular grid) and write
# each tile as a georeferenced JP2 plus a PNG mask with the same base name.
for imgName, maskName in zip(imageFiles, maskFiles):
    imgPath = os.path.join(imageDir, imgName)
    maskPath = os.path.join(maskDir, maskName)

    print(f"\nProcessing {imgName} and {maskName}...")

    # Pairing is positional (sorted listings), so a missing or extra file can
    # shift every later pair. Flag differing base names instead of silently
    # training on the wrong mask.
    if os.path.splitext(imgName)[0] != os.path.splitext(maskName)[0]:
        print(f"Warning: {imgName} and {maskName} have different base names; "
              "verify that they belong together.")

    try:
        with rasterio.open(imgPath) as src:
            image = src.read()  # all bands, indexed (band, row, col)
            height, width = src.height, src.width
            numBands = src.count
            dtype = src.dtypes[0]  # dtype of the first band is used for all
            crs = src.crs
            transformOrig = src.transform

        # Close the mask file deterministically instead of leaking the handle.
        with Image.open(maskPath) as maskFile:
            mask = np.array(maskFile.convert("L"))

    except Exception as e:
        print(f"Error reading {imgName}: {e}")
        continue

    if mask.shape != (height, width):
        print(f"Size mismatch. Skipping {imgName}.")
        continue

    # Top-left corners (x, y) of every full tile on a regular grid; partial
    # tiles at the right/bottom edges are not generated.
    cropCoords = [(x, y) for y in range(0, height - cropSize + 1, cropSize)
                         for x in range(0, width - cropSize + 1, cropSize)]

    if not cropCoords:
        print("No crops possible (image too small). Skipping.")
        continue

    # Random order so the first `cropsPerImage` tiles are a random subset.
    random.shuffle(cropCoords)

    saved = 0
    for x, y in cropCoords:
        if saved >= cropsPerImage:
            break

        cropImg = image[:, y:y+cropSize, x:x+cropSize]
        cropMask = mask[y:y+cropSize, x:x+cropSize]

        baseName = f"{os.path.splitext(imgName)[0]}_crop_{globalCounter}"
        imgOutPath = os.path.join(outputDir, 'images', f"{baseName}.jp2")
        maskOutPath = os.path.join(outputDir, 'masks', f"{baseName}.png")

        # Shift the source transform by (x, y) pixels. Composing with a
        # translation keeps the sign of the y pixel size and any rotation
        # terms; from_origin() negates its ysize argument, so feeding it the
        # (already negative) `e` coefficient flipped north-up crops vertically.
        cropTransform = transformOrig * rasterio.transform.Affine.translation(x, y)

        try:
            # NOTE(review): no creation options are passed, so the driver's
            # default compression settings apply — confirm they are
            # acceptable (e.g. lossless vs. lossy) for training data.
            with rasterio.open(
                imgOutPath, "w",
                driver="JP2OpenJPEG",
                height=cropSize,
                width=cropSize,
                count=numBands,
                dtype=dtype,
                crs=crs,
                transform=cropTransform
            ) as dst:
                dst.write(cropImg)

            Image.fromarray(cropMask).save(maskOutPath)

            # Count the crop only once both files were written; on failure
            # the same base name is reused by the next attempt.
            globalCounter += 1
            saved += 1

        except Exception as e:
            print(f"Error saving crop {baseName}: {e}")

    print(f"Saved {saved} crops for {imgName}.")

print("\nCropping complete.")
print(f"Total crops saved: {globalCounter}")



Processing Image_50_stacked.jp2 and Image_50_stacked.png...




Saved 200 crops for Image_50_stacked.jp2.

Processing Image_52_stacked.jp2 and Image_52_stacked.png...
Saved 200 crops for Image_52_stacked.jp2.

Cropping complete.
Total crops saved: 400
