In [None]:
import os
import rasterio
from rasterio.windows import Window
import numpy as np
import cv2
import glob
from pathlib import Path

# --- Tiling -----------------------------------------------------------------
TILE_SIZE = 512   # edge length of each square tile, in pixels
OVERLAP = 0       # leave at 0 so neighbouring patches never share pixels

# --- Normalization ceilings (raw values are clipped to these before scaling) --
THRESH_SUMMER = 3000  # standard for vegetation/urban scenes
THRESH_WINTER = 8000  # higher because snow is very bright

# --- Locations --------------------------------------------------------------
# NOTE: input_dir is a relative path (resolved against the current working
# directory), whereas output_dir lives under the *parent* of the working
# directory.
input_dir = "sat_images"
script_dir = Path.cwd()
project_dir = script_dir.parent
output_dir = project_dir / "dataset_final"

def get_band_paths(folder_path):
    """
    Find the B02, B03, B04 and B08 .jp2 files below ``folder_path``.

    The folder is searched recursively. A file is assigned to a band when its
    *file name* (not any parent directory name) contains ``_<band>_`` or ends
    with ``<band>.jp2``. When several files match the same band, the choice is
    deterministic: files are visited in sorted order, the first match is kept,
    and a file named ``*_<band>_10m.jp2`` replaces a match that is not.

    Args:
        folder_path: Directory to search.

    Returns:
        A dict mapping 'B02', 'B03', 'B04' and 'B08' to file paths, or None
        when at least one of the four bands is missing. A warning is printed
        only if some .jp2 files were found but not all four bands.
    """
    band_codes = ("B02", "B03", "B04", "B08")

    # Escape the folder part so characters such as '[' or '*' in a directory
    # name are treated literally instead of as glob wildcards.
    search_path = os.path.join(glob.escape(folder_path), "**", "*.jp2")
    # Sorting makes the selection independent of filesystem listing order.
    all_files = sorted(glob.glob(search_path, recursive=True))

    bands = {}
    for path in all_files:
        # Match against the file name only; a parent directory that happens
        # to contain a band code must not make every file inside it match.
        name = os.path.basename(path)
        for code in band_codes:
            if f"_{code}_" not in name and not name.endswith(f"{code}.jp2"):
                continue
            ten_m_suffix = f"_{code}_10m.jp2"
            current = bands.get(code)
            if current is None:
                bands[code] = path
            elif name.endswith(ten_m_suffix) and not os.path.basename(current).endswith(ten_m_suffix):
                # Prefer the 10 m file over any other candidate for the band.
                bands[code] = path

    if len(bands) < len(band_codes):
        # Only print warning if we found SOME jp2s but not all bands
        if all_files:
            print(f"Warning: Missing bands in {folder_path}. Found: {list(bands.keys())}")
        return None
    return bands

def get_threshold(folder_name):
    """
    Pick the normalization threshold for a season folder.

    Returns THRESH_WINTER when the folder name contains 'winter'
    (case-insensitive, e.g. 'winter_2024'); otherwise THRESH_SUMMER.
    """
    is_winter = 'winter' in folder_name.lower()
    return THRESH_WINTER if is_winter else THRESH_SUMMER

def normalize_band(band, threshold):
    """
    Scale raw band values into the 0-1 range as float32.

    Values are first limited to [0, threshold] so that very bright pixels do
    not wash out the rest of the image, then divided by ``threshold``.
    """
    limited = np.clip(band, 0, threshold)
    as_float = limited.astype(np.float32)
    return as_float / float(threshold)

def process_location(tile_code, folder_name, input_folder):
    """
    Cut one scene into RGB PNG tiles.

    Locates the band files under ``input_folder`` via get_band_paths(), reads
    B04/B03/B02 as red/green/blue in TILE_SIZE x TILE_SIZE windows, normalizes
    each band with the threshold chosen by get_threshold(folder_name), and
    writes 8-bit PNGs to ``output_dir/<tile_code>/<folder_name>/``.

    Incomplete tiles at the right/bottom edges are skipped, as are tiles whose
    red and green bands are both entirely zero. (B08 must be present for
    get_band_paths() to succeed, but it is not read here.)

    Args:
        tile_code: Tile identifier; used as the first output sub-directory.
        folder_name: Season/sub-folder name; selects the threshold and is the
            second output sub-directory.
        input_folder: Directory that is searched for the band files.

    Raises:
        ValueError: If OVERLAP is not smaller than TILE_SIZE (the grid step
            would be zero or negative).
    """
    print(f"   Processing: {tile_code} -> {folder_name} ...")

    band_paths = get_band_paths(input_folder)
    if not band_paths:
        print(f"   No valid bands found in {folder_name}")
        return

    step = TILE_SIZE - OVERLAP
    if step <= 0:
        # A non-positive step would either raise an obscure range() error or
        # silently produce no tiles at all.
        raise ValueError(
            f"OVERLAP ({OVERLAP}) must be smaller than TILE_SIZE ({TILE_SIZE})"
        )

    # Determine threshold based on folder name (winter vs summer)
    threshold = get_threshold(folder_name)

    # Setup Output Directory: dataset_final/T36UYA/summer_prev/
    save_dir = os.path.join(output_dir, tile_code, folder_name)
    os.makedirs(save_dir, exist_ok=True)

    count = 0
    failed = 0

    # Open bands
    with rasterio.open(band_paths['B04'], driver='JP2OpenJPEG') as src_r, \
         rasterio.open(band_paths['B03'], driver='JP2OpenJPEG') as src_g, \
         rasterio.open(band_paths['B02'], driver='JP2OpenJPEG') as src_b:

        width = src_r.width
        height = src_r.height

        # The window grid is derived from the red band; the other bands must
        # share its pixel grid or the windows would not line up.
        for label, src in (("B03", src_g), ("B02", src_b)):
            if (src.width, src.height) != (width, height):
                print(
                    f"   Band size mismatch in {folder_name}: "
                    f"B04 is {width}x{height}, {label} is {src.width}x{src.height}. Skipping."
                )
                return

        # Grid loop. The upper bounds exclude incomplete edge tiles, so every
        # window is exactly TILE_SIZE x TILE_SIZE.
        for x in range(0, width - TILE_SIZE + 1, step):
            for y in range(0, height - TILE_SIZE + 1, step):
                window = Window(x, y, TILE_SIZE, TILE_SIZE)

                # Read Data
                r = src_r.read(1, window=window)
                g = src_g.read(1, window=window)
                b = src_b.read(1, window=window)

                # Skip tiles with no data in both the red and green bands.
                if np.mean(r) == 0 and np.mean(g) == 0:
                    continue

                # Normalize
                r_norm = normalize_band(r, threshold)
                g_norm = normalize_band(g, threshold)
                b_norm = normalize_band(b, threshold)

                # Stack & Convert to 8-bit for PNG
                rgb = np.dstack((r_norm, g_norm, b_norm))
                rgb_uint8 = (rgb * 255).astype(np.uint8)

                # Convert RGB to BGR for OpenCV
                bgr_save = cv2.cvtColor(rgb_uint8, cv2.COLOR_RGB2BGR)

                # Filename: MUST be consistent across seasons for matching
                filename = f"tile_{x}_{y}.png"
                save_path = os.path.join(save_dir, filename)

                # cv2.imwrite reports failure through its return value rather
                # than raising, so only count tiles that were really written.
                if cv2.imwrite(save_path, bgr_save):
                    count += 1
                else:
                    failed += 1

    if failed:
        print(f"   Warning: failed to write {failed} tiles to {save_dir}")
    print(f"   Saved {count} tiles to {save_dir}")

def main():
    """
    Walk ``input_dir`` and process every season folder of every tile.

    Expected layout: ``input_dir/<tile_code>/<folder_name>/...``. Each
    second-level directory is handed to process_location().
    """
    def list_subdirs(parent):
        # Names of the immediate child directories of `parent`.
        return [
            entry
            for entry in os.listdir(parent)
            if os.path.isdir(os.path.join(parent, entry))
        ]

    print("--- Sentinel-2 Pre-processor ---")

    if not os.path.exists(input_dir):
        print(f"Error: Input folder '{input_dir}' not found!")
        return

    # First level: tile IDs (e.g., T36UYA, T36UVC)
    for tile_code in list_subdirs(input_dir):
        print(f"\nScanning Tile: {tile_code}")
        tile_path = os.path.join(input_dir, tile_code)

        # Second level: any sub-folder (winter, summer, summer_prev, etc.)
        season_names = list_subdirs(tile_path)
        if len(season_names) == 0:
            print("   Empty tile folder.")
            continue

        for folder_name in season_names:
            process_location(
                tile_code,
                folder_name,
                os.path.join(tile_path, folder_name),
            )

# Entry point: run the pre-processor only when this module is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()

--- Sentinel-2 Pre-processor ---
Error: Input folder 'sat_images' not found!
