# In [None]:
"""
Author: Wasif Rasool Qazi
Cropping script for Sentinel-2-Dataset
"""


import os
import numpy as np
import rasterio
from PIL import Image

# Input folders (Sentinel-2 .jp2 images and their .png masks) and the root
# folder that receives the cropped tiles.
imageDir = r"C:\Users\wasif\Desktop\AugTest\images"
maskDir = r"C:\Users\wasif\Desktop\AugTest\masks"
OutputPth = r"C:\Users\wasif\Desktop\Cropped_balanced"

# Thresholds
cropSize = 256                # edge length of a square crop, in pixels
maxCropsPerCategory = 10000   # cap on saved crops per category, across all images
targetCropsPerImage = 200     # cap on saved crops taken from one image

# Categories: each one gets an 'images' and a 'masks' sub-folder under OutputPth.
categories = ['land', 'water', 'mixed']
for cat in categories:
    for sub in ('images', 'masks'):
        os.makedirs(os.path.join(OutputPth, cat, sub), exist_ok=True)

# Running count of saved crops per category; the current count is also used
# as the index in the next crop's file name.
categoryCounters = dict.fromkeys(categories, 0)

# Classify crop based on mask 
def classifyCrop(maskCrop):
    """Return the category name for a mask crop.

    Args:
        maskCrop: Mask values for one crop (the script passes a 2-D slice
            of the grayscale mask array).

    Returns:
        "land" when every value equals 255, "water" when every value
        equals 0, and "mixed" otherwise. An empty crop has no value that
        differs from 255, so it is reported as "land".
    """
    # Checked in order: a uniform crop matches the first value it never
    # deviates from.
    uniformLabels = ((255, "land"), (0, "water"))
    for value, label in uniformLabels:
        if not np.any(maskCrop != value):
            return label
    return "mixed"

# Crop processing.
# Images (.jp2) and masks (.png) are paired purely by sorted filename order.
# Each pair is tiled into non-overlapping cropSize x cropSize windows and every
# tile is written under the category returned by classifyCrop().
imageFiles = sorted(f for f in os.listdir(imageDir) if f.endswith(".jp2"))
maskFiles = sorted(f for f in os.listdir(maskDir) if f.endswith(".png"))

# zip() below stops at the shorter list, so report empty or unbalanced inputs
# rather than silently finishing with fewer (or zero) crops.
if not imageFiles or len(imageFiles) != len(maskFiles):
    print(
        f"Warning: found {len(imageFiles)} .jp2 image(s) and "
        f"{len(maskFiles)} .png mask(s); only the first "
        f"{min(len(imageFiles), len(maskFiles))} pair(s) will be processed."
    )

for imgName, maskName in zip(imageFiles, maskFiles):
    imgPath = os.path.join(imageDir, imgName)
    maskPath = os.path.join(maskDir, maskName)

    # src.read() returns all bands as a (bands, rows, cols) array.
    with rasterio.open(imgPath) as src:
        image = src.read()
        meta = src.meta

    # Mask is loaded as a single-channel 8-bit grayscale array.
    mask = np.array(Image.open(maskPath).convert("L"))
    numBands, imgHeight, imgWidth = image.shape

    print(f"--- Processing {imgName} and {maskName} ---")

    # Tile only the area covered by both rasters. A tile that runs past the
    # image edge would come out smaller than cropSize and fail on write.
    if (imgHeight, imgWidth) != mask.shape:
        print(
            f"Warning: {imgName} is {imgWidth}x{imgHeight} but {maskName} is "
            f"{mask.shape[1]}x{mask.shape[0]}; cropping only the overlapping area."
        )
    height = min(imgHeight, mask.shape[0])
    width = min(imgWidth, mask.shape[1])

    # Affine coefficients of the source image, indexed (a, b, c, d, e, f):
    # [0] = pixel width, [2] = west edge, [4] = row step (negative for
    # north-up rasters), [5] = north edge.
    # NOTE(review): the rotation terms [1] and [3] are ignored here, as in the
    # original code - confirm the inputs are unrotated.
    srcTransform = meta["transform"]

    cropCount = 0
    for y in range(0, height - cropSize + 1, cropSize):
        # The per-image quota must end the row loop as well. Breaking only the
        # inner loop would keep scanning every remaining row for nothing.
        if cropCount >= targetCropsPerImage:
            break

        for x in range(0, width - cropSize + 1, cropSize):
            if cropCount >= targetCropsPerImage:
                break

            cropImg = image[:, y:y + cropSize, x:x + cropSize]
            cropMask = mask[y:y + cropSize, x:x + cropSize]
            label = classifyCrop(cropMask)

            # Category already full: skip this tile but keep scanning, since
            # later tiles may belong to a category that still has room.
            if categoryCounters[label] >= maxCropsPerCategory:
                continue

            # The running counter doubles as the file index, so a tile whose
            # save failed is overwritten by the next tile of the same label.
            cropIndex = categoryCounters[label]
            baseName = f"{label}_crop_{cropIndex}"

            imagePth = os.path.join(OutputPth, label, 'images', f"{baseName}.jp2")
            maskOutPath = os.path.join(OutputPth, label, 'masks', f"{baseName}.png")

            # Georeference the tile by shifting the source origin by (x, y)
            # pixels. from_origin() negates its ysize argument internally, so
            # it must receive the positive pixel height (-e). Passing e itself
            # would flip the tile's georeferencing vertically.
            transform = rasterio.transform.from_origin(
                srcTransform[2] + x * srcTransform[0],
                srcTransform[5] + y * srcTransform[4],
                srcTransform[0],
                -srcTransform[4],
            )

            # Best effort: a tile that fails to save is reported and skipped,
            # and the counters only advance after both files were written.
            try:
                with rasterio.open(
                    imagePth, "w",
                    driver="JP2OpenJPEG",
                    height=cropSize,
                    width=cropSize,
                    count=numBands,
                    dtype=image.dtype,
                    crs=meta["crs"],
                    transform=transform
                ) as dst:
                    dst.write(cropImg)

                Image.fromarray(cropMask).save(maskOutPath)
                categoryCounters[label] += 1
                cropCount += 1

            except Exception as e:
                print(f"Error saving crop {baseName},{e}")

    print(f"Finished processing: {imgName} — {cropCount} crops saved.\n")

# Final report: one line per category with the number of crops saved.
print("Cropping completed.")
summaryLines = [
    f"{name.capitalize()} crops saved: {categoryCounters[name]}"
    for name in categories
]
for line in summaryLines:
    print(line)


# Cell output:
# Cropping completed.
# Land crops saved: 0
# Water crops saved: 0
# Mixed crops saved: 0
