# In [None]:
import os
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from skimage.transform import resize
from netCDF4 import Dataset
from pyrosm import OSM, get_data
import numpy as np
from skimage.transform import rotate, rescale, AffineTransform, warp
from scipy.ndimage import shift

def reproject_sentinel_to_utm(sentinel_band_file, utm_zone, bands=("B02", "B03", "B04", "B08"),
                              southern_hemisphere=False):
    """Load Sentinel-2 bands from a NetCDF file and express them in a UTM CRS.

    Parameters
    ----------
    sentinel_band_file : str or path-like
        NetCDF file readable by ``xarray.load_dataset``.
    utm_zone : int
        UTM zone number (1-60).
    bands : sequence of str
        Data variables to keep. The default is a tuple so it cannot be
        mutated and shared between calls.
    southern_hemisphere : bool
        Use WGS 84 / UTM south (EPSG:327xx) instead of north (EPSG:326xx).
        Defaults to False, which matches the previous north-only behaviour.

    Returns
    -------
    xarray.Dataset
        The selected bands. If no CRS can be found on them, the UTM CRS is
        *assigned* (the data is assumed to already be in that projection).
        Otherwise the data is *reprojected* to it.

    Raises
    ------
    ValueError
        If ``utm_zone`` is outside 1..60.

    Notes
    -----
    Relies on the ``.rio`` accessor, which is registered by importing
    ``rioxarray``. That import is not in this cell; TODO confirm it runs
    elsewhere before this function is called.
    """
    zone = int(utm_zone)
    if not 1 <= zone <= 60:
        raise ValueError(f"utm_zone must be in 1..60, got {utm_zone!r}")

    ds = xr.load_dataset(sentinel_band_file)
    # Dataset.__getitem__ treats a tuple as a single key, so always pass a list.
    band_data = ds[list(bands)]

    epsg_base = 32700 if southern_hemisphere else 32600
    utm_crs = f"EPSG:{epsg_base + zone}"

    # rio.crs also finds a CRS stored via grid_mapping / spatial_ref. Checking
    # only attrs['crs'] would overwrite such a CRS with the UTM label instead of
    # reprojecting. The attrs check is kept for Dataset-level 'crs' attributes.
    has_crs = band_data.rio.crs is not None or "crs" in band_data.attrs
    if not has_crs:
        band_data = band_data.rio.write_crs(utm_crs, inplace=True)
    else:
        band_data = band_data.rio.reproject(utm_crs)

    return band_data

def plot_overlay(sentinel_band_file, jpeg_path, bands, time):
    """Plot a Sentinel-2 natural-colour composite, then overlay an image on it.

    Parameters
    ----------
    sentinel_band_file : xarray.Dataset
        Despite the name, an already-loaded dataset (e.g. the return value of
        ``reproject_sentinel_to_utm``), not a path. It must support the
        ``.rio`` accessor and have a ``t`` dimension.
    jpeg_path : str
        Image drawn semi-transparently over the scene, stretched to the
        scene's projected bounds.
    bands : list of str
        Band variables to read. The composite uses B04/B03/B02 (red/green/blue)
        when present, otherwise the first three entries.
    time : int
        Positional index along the ``t`` dimension.

    Raises
    ------
    ValueError
        If fewer than three bands are given.
    """
    ds = sentinel_band_file
    band_data = ds[bands].to_array(dim="bands")

    # xarray's imshow interprets a 3- or 4-long band axis as RGB(A). Passing
    # B02/B03/B04/B08 unchanged would draw blue as red and use NIR as alpha,
    # so pick the natural-colour bands explicitly.
    natural = [b for b in ("B04", "B03", "B02") if b in bands]
    rgb_bands = natural if len(natural) == 3 else list(bands)[:3]
    if len(rgb_bands) < 3:
        raise ValueError("Sentinel patch must have at least three bands for RGB visualization.")
    rgb = band_data.sel(bands=rgb_bands)[{"t": time}]

    fig, ax = plt.subplots(figsize=(8, 8))
    rgb.plot.imshow(ax=ax, rgb="bands", vmin=0, vmax=3500)
    plt.title(f"Sentinel-2 Natural Color Composite ({', '.join(rgb_bands)})")
    plt.show()

    img = mpimg.imread(jpeg_path)

    # rio.bounds() returns (left, bottom, right, top) in the data's CRS.
    xmin, ymin, xmax, ymax = band_data.rio.bounds()

    # The composite is RGB, so a colormap would have no effect; none is passed.
    fig, ax = plt.subplots(figsize=(10, 8))
    rgb.plot.imshow(ax=ax, rgb="bands", vmin=0, vmax=2000)
    ax.imshow(img, extent=[xmin, xmax, ymin, ymax], alpha=0.6)

    plt.title("Overlay of Sentinel-2 Band and JPEG Image")
    plt.xlabel("x coordinate of projection [m]")
    plt.ylabel("y coordinate of projection [m]")
    plt.show()


def detect_clouds(patch_data, blue_threshold=2000, nir_threshold=5000, max_high_pixels=750,
                  blue_index=0, nir_index=3):
    """Return True when too many pixels in a patch are bright in blue or NIR.

    Parameters
    ----------
    patch_data : array-like
        Band-first patch; ``patch_data[blue_index]`` and
        ``patch_data[nir_index]`` must be same-shaped 2-D arrays.
    blue_threshold, nir_threshold : number
        A pixel is "high" when it strictly exceeds either threshold.
    max_high_pixels : int
        The patch is flagged when the high-pixel count strictly exceeds this.
    blue_index, nir_index : int
        Band positions of blue and NIR. The defaults (0 and 3) match a
        B02, B03, B04, B08 band order.

    Returns
    -------
    bool-like
        Result of ``count > max_high_pixels``.
    """
    # asanyarray accepts nested lists but keeps masked arrays masked.
    data = np.asanyarray(patch_data)
    blue_band = data[blue_index]
    nir_band = data[nir_index]

    # A pixel bright in both bands is counted once (logical OR, not a sum).
    high_value_pixels = np.logical_or(blue_band > blue_threshold, nir_band > nir_threshold)
    high_value_pixels_count = np.sum(high_value_pixels)

    print(f"Number of unique high-value pixels: {high_value_pixels_count}")

    return high_value_pixels_count > max_high_pixels

def _patch_count(length, patch_size, stride):
    """Number of patch origins along one axis (at least 1, as before)."""
    return max(1, (length - patch_size) // stride + 1)


def _draw_augmentation_params(augmentation_type):
    """Draw the random parameter for one augmentation, once per patch."""
    if augmentation_type == 'rotate':
        return np.random.randint(-30, 30)
    if augmentation_type == 'shear':
        return np.random.uniform(-0.2, 0.2)
    return None


def _apply_augmentation(patch, augmentation_type, param):
    """Apply one augmentation deterministically to a 2-D band or an (H, W[, C]) image."""
    if augmentation_type == 'rotate':
        return rotate(patch, param, mode='reflect', preserve_range=True)
    if augmentation_type == 'shear':
        return warp(patch, AffineTransform(shear=param), mode='reflect', preserve_range=True)
    if augmentation_type == 'reflect':
        return np.fliplr(patch)
    return patch


def create_and_save_patches(band_array, jpeg_path, patch_size, stride, output_dir, bands, time_index=1):
    """Tile a Sentinel-2 scene and a matching image into aligned patches.

    Cloudy tiles are dropped. Each remaining tile is saved, together with a
    rotated, a sheared and a mirrored copy.

    Parameters
    ----------
    band_array : numpy.ndarray
        Array indexed as ``[band, time, y, x]``. Band order must match what
        ``save_patches`` writes (B02, B03, B04, B08) and what
        ``detect_clouds`` reads (blue at 0, NIR at 3).
    jpeg_path : str
        Image that is resized onto the Sentinel-2 pixel grid before tiling.
        It may be RGB(A) or grayscale.
    patch_size : int
        Edge length of each square patch, in Sentinel-2 pixels.
    stride : int
        Step between patch origins. ``stride == patch_size`` gives
        non-overlapping tiles; a smaller stride gives overlapping tiles.
    output_dir : str
        Created if missing.
    bands : list of str
        Currently unused; kept for interface compatibility.
    time_index : int
        Index into the time axis of ``band_array``.

    Raises
    ------
    ValueError
        If ``stride`` is not positive.
    """
    if stride <= 0:
        raise ValueError(f"stride must be a positive integer, got {stride!r}")

    img = mpimg.imread(jpeg_path)
    print(f"Shape of band_array: {band_array.shape}")
    print(f"Shape of JPEG image: {img.shape}")

    # Resize onto the Sentinel-2 grid. img.shape[2:] keeps the channel axis
    # for colour images and is empty for grayscale ones.
    sentinel_height, sentinel_width = band_array.shape[2], band_array.shape[3]
    resized_img = resize(img, (sentinel_height, sentinel_width) + img.shape[2:], anti_aliasing=True)
    print(f"Shape of resized JPEG image: {resized_img.shape}")

    num_patches_x = _patch_count(sentinel_width, patch_size, stride)
    num_patches_y = _patch_count(sentinel_height, patch_size, stride)
    print(f"Total patches to be created: {num_patches_x * num_patches_y}")
    print(f"Number of patches in X direction: {num_patches_x}")
    print(f"Number of patches in Y direction: {num_patches_y}")

    os.makedirs(output_dir, exist_ok=True)

    augmentation_types = ['rotate', 'shear', 'reflect']
    cloud_counter = 0
    patch_index = 0
    for i in range(num_patches_y):
        for j in range(num_patches_x):
            y0, x0 = i * stride, j * stride
            sentinel_patch = band_array[:, time_index, y0:y0 + patch_size, x0:x0 + patch_size]

            # Check the size first so truncated edge tiles are not counted as clouds.
            if sentinel_patch.shape[1] != patch_size or sentinel_patch.shape[2] != patch_size:
                print(f"Skipping Sentinel patch at ({i}, {j}) due to incorrect size.")
                continue

            print(f"checking cloudy patch at ({i}, {j})")
            if detect_clouds(sentinel_patch):
                cloud_counter += 1
                print(f"Skipping cloudy patch at ({i}, {j})")
                continue

            # Two-axis slicing works for both (H, W) and (H, W, C) images.
            jpeg_patch = resized_img[y0:y0 + patch_size, x0:x0 + patch_size]
            if jpeg_patch.shape[0] != patch_size or jpeg_patch.shape[1] != patch_size:
                print(f"Skipping JPEG patch at ({i}, {j}) due to incorrect size.")
                continue

            save_patches(sentinel_patch, jpeg_patch, patch_index, output_dir)

            for aug_type in augmentation_types:
                # Draw the random angle/shear ONCE so every band and the label
                # image get the identical geometric transform and stay aligned.
                param = _draw_augmentation_params(aug_type)
                aug_sentinel_patch = np.array(
                    [_apply_augmentation(band, aug_type, param) for band in sentinel_patch]
                )
                aug_jpeg_patch = _apply_augmentation(jpeg_patch, aug_type, param)
                save_patches(aug_sentinel_patch, aug_jpeg_patch, patch_index, output_dir, aug_type)

            patch_index += 1

    print(f'Created {patch_index * (len(augmentation_types) + 1)} patches (including augmentations) and saved to {output_dir}')
    print(f"Cloud patches removed: {cloud_counter}")

def save_patches(sentinel_patch, jpeg_patch, patch_index, output_dir, aug_type='',
                 band_names=("B02", "B03", "B04", "B08")):
    """Write one Sentinel-2 patch as NetCDF and its image patch as JPEG.

    Files are named ``sentinel_patch_<index><aug_type>.nc`` and
    ``jpeg_patch_<index><aug_type>.jpg`` inside ``output_dir``.

    Parameters
    ----------
    sentinel_patch : array-like
        Band-first patch; ``sentinel_patch[k]`` is stored as variable
        ``band_names[k]`` with dims ``('y', 'x')``.
    jpeg_patch : numpy.ndarray
        Image patch passed to ``plt.imsave``.
    patch_index : int
    output_dir : str
    aug_type : str
        Suffix appended to the index (empty for the un-augmented patch).
    band_names : sequence of str
        Variable names for the leading bands. The default matches the
        previously hard-coded B02, B03, B04, B08.

    Raises
    ------
    IndexError
        If ``sentinel_patch`` has fewer bands than ``band_names``.
    """
    suffix = f"{patch_index}{aug_type}"

    sentinel_patch_file = os.path.join(output_dir, f'sentinel_patch_{suffix}.nc')
    sentinel_patch_ds = xr.Dataset(
        {name: (['y', 'x'], sentinel_patch[k]) for k, name in enumerate(band_names)}
    )
    sentinel_patch_ds.to_netcdf(sentinel_patch_file)

    # plt.imsave rejects float RGB values outside [0, 1]. Resizing and
    # interpolation can overshoot by rounding error, so clip float images.
    jpeg_patch = np.asarray(jpeg_patch)
    if np.issubdtype(jpeg_patch.dtype, np.floating):
        jpeg_patch = np.clip(jpeg_patch, 0.0, 1.0)

    jpeg_patch_file = os.path.join(output_dir, f'jpeg_patch_{suffix}.jpg')
    plt.imsave(jpeg_patch_file, jpeg_patch)

    print(f"Patch {patch_index}{aug_type}:")
    print(f"  Sentinel patch file: {sentinel_patch_file}")
    print(f"  JPEG patch file: {jpeg_patch_file}")

def visualize_and_overlay_patches(patch_folder, num_patches, bands):
    """Display up to ``num_patches`` saved Sentinel-2 patches with their image patches.

    Each ``sentinel_patch_<id>.nc`` is paired with ``jpeg_patch_<id>.jpg`` by
    name. Previously the two files were paired by position in two separately
    sorted listings, so one missing file shifted every later pair or raised
    IndexError. A patch with no matching image is skipped with a message.

    For each pair, three figures are drawn: a pixel overlay, the natural-colour
    composite, and the composite with the image overlaid.

    Parameters
    ----------
    patch_folder : str
        Directory written by ``save_patches``.
    num_patches : int
        Maximum number of pairs to display.
    bands : list of str
        Variables read from each NetCDF file. The composite uses B04/B03/B02
        when present, otherwise the first three entries.

    Raises
    ------
    ValueError
        If fewer than three bands are given.
    """
    bands = list(bands)
    natural = [b for b in ("B04", "B03", "B02") if b in bands]
    rgb_bands = natural if len(natural) == 3 else bands[:3]
    if len(rgb_bands) < 3:
        raise ValueError("Sentinel patch must have at least three bands for RGB visualization.")
    rgb_index = [bands.index(b) for b in rgb_bands]

    prefix, ext = 'sentinel_patch_', '.nc'
    patch_files = sorted(
        f for f in os.listdir(patch_folder) if f.startswith(prefix) and f.endswith(ext)
    )

    shown = 0
    for patch_file in patch_files:
        if shown >= num_patches:
            break

        patch_id = patch_file[len(prefix):-len(ext)]
        jpeg_file = f'jpeg_patch_{patch_id}.jpg'
        patch_file_path = os.path.join(patch_folder, patch_file)
        jpeg_file_path = os.path.join(patch_folder, jpeg_file)
        if not os.path.isfile(jpeg_file_path):
            print(f"Skipping {patch_file}: matching {jpeg_file} not found")
            continue

        with Dataset(patch_file_path, 'r') as ds:
            patch_data = np.array([ds[band][:] for band in bands])
        band_data = xr.DataArray(patch_data, dims=("bands", "y", "x"), coords={"bands": bands})
        # Passing all bands to xarray's imshow would be read as RGBA (NIR as
        # alpha), so keep only the natural-colour bands.
        rgb_data = band_data.sel(bands=rgb_bands)

        # Matplotlib ignores vmin/vmax for RGB input and clips floats to
        # [0, 1], which turns raw reflectance white. Scale 0..3500 -> 0..1.
        sentinel_rgb_patch = np.clip(
            np.stack([patch_data[k] for k in rgb_index], axis=-1) / 3500.0, 0.0, 1.0
        )

        jpeg_patch = mpimg.imread(jpeg_file_path)

        # Figure 1: pixel-space overlay of the composite and the image patch.
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.imshow(sentinel_rgb_patch)
        ax.imshow(jpeg_patch, alpha=0.6)
        ax.set_title(f'Overlay Patch {shown}')
        ax.axis('off')

        # Figure 2: natural-colour composite rendered by xarray.
        fig, ax = plt.subplots(figsize=(8, 8))
        rgb_data.plot.imshow(ax=ax, rgb="bands", vmin=0, vmax=3500)
        plt.title(f"Sentinel-2 Natural Color Composite ({', '.join(rgb_bands)})")
        plt.show()

        # Figure 3: composite (brighter stretch) with the image overlaid.
        fig, ax = plt.subplots(figsize=(10, 8))
        rgb_data.plot.imshow(ax=ax, rgb="bands", vmin=0, vmax=2000)
        ax.imshow(jpeg_patch, alpha=0.6)
        plt.title("Overlay of Sentinel-2 Band and JPEG Image")
        plt.show()

        shown += 1

# Per-city configuration:
#   lat / lon  : [min, max] bounding box in degrees
#   osm        : OSM extract file (used only by the disabled OSM step below)
#   utm_zone   : UTM zone passed to reproject_sentinel_to_utm (northern hemisphere)
#   sentinel_file : Sentinel-2 NetCDF file
#   time       : positional index along the "t" axis to use (presumably the
#                least cloudy acquisition per city; confirm)
cities = {
    "muenchen": {"lat": [48.0800, 48.1900], "lon": [11.5000, 11.6700], "osm": "muenchen.osm.pbf", "utm_zone": 32, "sentinel_file": "muenchen_sentinel.nc", "time": 0},
    "rotterdam": {"lat": [51.8750, 51.9700], "lon": [4.4200, 4.5400], "osm": "rotterdam.osm.pbf", "utm_zone": 31, "sentinel_file": "rotterdam_sentinel.nc", "time": 0},
    "duesseldorf": {"lat": [51.1646, 51.3086], "lon": [6.6847, 6.8785], "osm": "duesseldorf.osm.pbf", "utm_zone": 32, "sentinel_file": "duesseldorf_sentinel.nc", "time": 0},
    "prag": {"lat": [49.9900, 50.1100], "lon": [14.3500, 14.5400], "osm": "prag.osm.pbf", "utm_zone": 33, "sentinel_file": "prag_sentinel.nc", "time": 0},
    "greater_london": {"lat": [51.4400, 51.5800], "lon": [-0.2000, 0.0200], "osm": "greater_london.osm.pbf", "utm_zone": 30, "sentinel_file": "greater_london_sentinel.nc", "time": 2},
    "dublin": {"lat": [53.3244, 53.4273], "lon": [-6.3874, -6.1071], "osm": "dublin.osm.pbf", "utm_zone": 29, "sentinel_file": "dublin_sentinel.nc", "time": 1},
    "frankfurt": {"lat": [50.0250, 50.2100], "lon": [8.5200, 8.7500], "osm": "frankfurt.osm.pbf", "utm_zone": 32, "sentinel_file": "frankfurt_sentinel.nc", "time": 2},
    "madrid": {"lat": [40.3125, 40.6437], "lon": [-3.8890, -3.5556], "osm": "madrid.osm.pbf", "utm_zone": 30, "sentinel_file": "madrid_sentinel.nc", "time": -1},
    "bruessel": {"lat": [50.7968, 50.9101], "lon": [4.3054, 4.4317], "osm": "bruessel.osm.pbf", "utm_zone": 31, "sentinel_file": "bruessel_sentinel.nc", "time": 0},
    "berlin": {"lat": [52.454927, 52.574409], "lon": [13.294333, 13.500205], "osm": "berlin.osm.pbf", "utm_zone": 33, "sentinel_file": "berlin_sentinel.nc", "time": 1},
    "magdeburg": {"lat": [52.0762, 52.1985], "lon": [11.5430, 11.6760], "osm": "magdeburg.osm.pbf", "utm_zone": 32, "sentinel_file": "magdeburg_sentinel.nc", "time": 0},
    "bremerhaven": {"lat": [53.4978, 53.6008], "lon": [8.5142, 8.6552], "osm": "bremerhaven.osm.pbf", "utm_zone": 32, "sentinel_file": "bremerhaven_sentinel.nc", "time": 1},
    "nuernberg": {"lat": [49.3955, 49.4904], "lon": [11.0023, 11.1445], "osm": "nuernberg.osm.pbf", "utm_zone": 32, "sentinel_file": "nuernberg_sentinel.nc", "time": 0},
    "erfurt": {"lat": [50.9300, 51.0100], "lon": [11.0100, 11.1100], "osm": "erfurt.osm.pbf", "utm_zone": 32, "sentinel_file": "erfurt_sentinel.nc", "time": 0},
    "rostock": {"lat": [54.0647, 54.1237], "lon": [12.0735, 12.1616], "osm": "rostock.osm.pbf", "utm_zone": 33, "sentinel_file": "rostock_sentinel.nc", "time": 0},
    "chemnitz": {"lat": [50.7830, 50.8750], "lon": [12.8830, 12.9550], "osm": "chemnitz.osm.pbf", "utm_zone": 33, "sentinel_file": "chemnitz_sentinel.nc", "time": 3},
    "potsdam": {"lat": [52.3567, 52.4348], "lon": [13.0142, 13.1208], "osm": "potsdam.osm.pbf", "utm_zone": 33, "sentinel_file": "potsdam_sentinel.nc", "time": 1},
    "bonn": {"lat": [50.6540, 50.7640], "lon": [7.0525, 7.1650], "osm": "bonn.osm.pbf", "utm_zone": 32, "sentinel_file": "bonn_sentinel.nc", "time": 0},
    "duisburg": {"lat": [51.3650, 51.5150], "lon": [6.6840, 6.8460], "osm": "duisburg.osm.pbf", "utm_zone": 32, "sentinel_file": "duisburg_sentinel.nc", "time": 0},
    "osnabrueck": {"lat": [52.2450, 52.3050], "lon": [7.9500, 8.0800], "osm": "osnabrueck.osm.pbf", "utm_zone": 32, "sentinel_file": "osnabrueck_sentinel.nc", "time": 3},
}
patch_size = 64
stride = 64
bands = ['B02', 'B03', 'B04', 'B08']
num_patches_to_visualize = 5
for city, data in cities.items():
    lat = data["lat"]
    lon = data["lon"]
    time = data["time"]
    # (min_lon, min_lat, max_lon, max_lat); only needed by the disabled OSM step.
    bbox = [lon[0], lat[0], lon[1], lat[1]]

    print(f"Reprojecting Sentinel-2 data for {city}")
    sentinel_ds = reproject_sentinel_to_utm(data["sentinel_file"], data["utm_zone"], bands=bands)

    # create_and_save_patches creates this directory if it is missing.
    output_path = f"./output/{city}_patches"
    jpeg_path = os.path.join(output_path, f'{city}_buildings.jpg')

    # NOTE: the OSM rasterisation step (process_osm_data, not defined in this
    # cell) is disabled, so the buildings JPEG must already exist from an
    # earlier run. Re-enable the next line to regenerate it.
    # process_osm_data(city, data["osm"], sentinel_ds, bbox, output_path)
    if not os.path.isfile(jpeg_path):
        print(f"Skipping {city}: {jpeg_path} not found")
        continue

    # Dataset.to_array puts the new "bands" axis first: (bands, t, y, x).
    band_array = sentinel_ds[bands].to_array(dim="bands").values

    plot_overlay(sentinel_ds, jpeg_path, bands, time)

    print(f"Creating and saving patches for {city}")
    create_and_save_patches(band_array, jpeg_path, patch_size, stride, output_path, bands, time)

    print(f"Visualizing and overlaying patches for {city}")
    visualize_and_overlay_patches(output_path, num_patches_to_visualize, bands)

    # Disabled: display_cloud_free_patches is not defined in this cell.
    # display_cloud_free_patches(output_path, num_patches_to_visualize, bands)