# In [None]:
import os
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from skimage.transform import resize
from netCDF4 import Dataset
from pyrosm import OSM, get_data

def reproject_sentinel_to_utm(sentinel_band_file, utm_zone, bands=None):
    """Load selected Sentinel-2 bands and tag or reproject them to a UTM CRS.

    Args:
        sentinel_band_file: Path handed to ``xr.load_dataset``.
        utm_zone: UTM zone number (1-60). The EPSG code is built as
            ``32600 + utm_zone``, i.e. the northern-hemisphere series only.
        bands: Variable names to select from the dataset. Defaults to
            ``["B02", "B03", "B04", "B08"]``.

    Returns:
        The selected bands, either with the UTM CRS written onto them (when the
        selection carries no ``'crs'`` attribute) or reprojected to it.

    Raises:
        ValueError: If ``utm_zone`` is outside 1-60.
    """
    # A fresh list per call: a list literal in the signature would be shared
    # between calls and could be mutated by accident.
    if bands is None:
        bands = ["B02", "B03", "B04", "B08"]

    if not 1 <= utm_zone <= 60:
        raise ValueError(f"utm_zone must be between 1 and 60, got {utm_zone!r}")

    ds = xr.load_dataset(sentinel_band_file)
    band_data = ds[bands]

    utm_crs = f"EPSG:{32600 + utm_zone}"  # Assuming Northern Hemisphere

    # NOTE(review): the `.rio` accessor is normally registered by importing
    # rioxarray, which this file does not visibly import - confirm it is loaded.
    if 'crs' not in band_data.attrs:
        # No CRS recorded: only label the data, the pixel grid is left untouched.
        band_data = band_data.rio.write_crs(utm_crs, inplace=True)
    else:
        band_data = band_data.rio.reproject(utm_crs)

    return band_data

def plot_overlay(sentinel_band_file, jpeg_path, bands, time):
    """Show the Sentinel-2 composite, then the same frame with a JPEG laid on top.

    Args:
        sentinel_band_file: Dataset-like object indexable by a list of band
            names (despite the name, it is used as an already-loaded dataset).
        jpeg_path: Path of the JPEG image to overlay.
        bands: Band names to stack along a new ``"bands"`` dimension.
        time: Index along the ``"t"`` dimension selecting the frame to draw.

    Two figures are displayed; nothing is returned.
    """
    stacked = sentinel_band_file[bands].to_array(dim="bands")
    frame = {"t": time}

    # Figure 1: the selected frame on its own.
    _, composite_ax = plt.subplots(figsize=(8, 8))
    stacked[frame].plot.imshow(ax=composite_ax, vmin=0, vmax=3500)
    plt.title(f"Sentinel-2 Natural Color Composite ({', '.join(bands)})")
    plt.show()

    # Figure 2: the JPEG stretched over the spatial bounds of the raster.
    overlay_image = mpimg.imread(jpeg_path)
    left, bottom, right, top = stacked.rio.bounds()

    _, overlay_ax = plt.subplots(figsize=(10, 8))
    stacked[frame].plot.imshow(ax=overlay_ax, vmin=0, vmax=2000, cmap='gray')
    overlay_ax.imshow(overlay_image, extent=[left, right, bottom, top], alpha=0.6)

    plt.title("Overlay of Sentinel-2 Band and JPEG Image")
    plt.xlabel("x coordinate of projection [m]")
    plt.ylabel("y coordinate of projection [m]")
    plt.show()

def detect_clouds(patch_data, blue_threshold=2000, nir_threshold=5000, max_high_pixels=750):
    """Decide whether a patch has too many bright pixels to be kept.

    Args:
        patch_data: Band-first patch; index 0 is read as the blue band (B02)
            and index 3 as the near-infrared band (B08).
        blue_threshold: Blue values strictly above this count as bright.
        nir_threshold: NIR values strictly above this count as bright.
        max_high_pixels: Largest number of bright pixels still tolerated.

    Returns:
        True when the number of pixels bright in at least one of the two bands
        is strictly greater than ``max_high_pixels``.
    """
    blue, nir = patch_data[0], patch_data[3]

    # A pixel that is bright in both bands is still counted only once.
    bright = (blue > blue_threshold) | (nir > nir_threshold)
    bright_total = np.sum(bright)

    # Debug information
    print(f"Number of unique high-value pixels: {bright_total}")

    return bright_total > max_high_pixels

def create_and_save_patches(band_array, jpeg_path, patch_size, stride, output_dir, bands, time_index=1):
    """Cut a Sentinel-2 array and its label JPEG into aligned square patches on disk.

    The JPEG is resized to the spatial size of ``band_array`` and both are tiled
    on the same non-overlapping ``patch_size`` grid. Patches flagged by
    ``detect_clouds`` are skipped. Every kept patch is written as
    ``sentinel_patch_<n>normal.nc`` plus ``jpeg_patch_<n>normal.jpg``.

    Args:
        band_array: Array indexed as ``[band, time, y, x]``.
        jpeg_path: Path of the JPEG to tile alongside the satellite data.
        patch_size: Edge length of the square patches, in pixels.
        stride: Currently unused - tiles always advance by ``patch_size`` so
            that they never overlap. Kept so existing callers keep working.
        output_dir: Directory receiving the patch files (created if missing).
        bands: Names given to the saved NetCDF variables, in the order of the
            first axis of ``band_array``.
        time_index: Index along the time axis of ``band_array`` to tile.
    """
    # Read the JPEG image
    img = mpimg.imread(jpeg_path)
    cloud_counter = 0
    # Debug: Print shape of the band array and the JPEG image
    print(f"Shape of band_array: {band_array.shape}")
    print(f"Shape of JPEG image: {img.shape}")

    # A grayscale JPEG is read as a 2-D array; promote it to three identical
    # channels so the channel axis used below always exists.
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)

    # Resize JPEG image to match the Sentinel-2 data dimensions
    sentinel_height, sentinel_width = band_array.shape[2], band_array.shape[3]
    resized_img = resize(img, (sentinel_height, sentinel_width, img.shape[2]), anti_aliasing=True)

    # Debug: Print shape of the resized JPEG image
    print(f"Shape of resized JPEG image: {resized_img.shape}")

    # Number of whole, non-overlapping tiles per direction (at least one is tried)
    num_patches_x = max(1, int((sentinel_width - patch_size) / patch_size) + 1)
    num_patches_y = max(1, int((sentinel_height - patch_size) / patch_size) + 1)

    # Debug: Print number of patches in X and Y directions
    print(f"Total patches to be created: {num_patches_x * num_patches_y}")
    print(f"Number of patches in X direction: {num_patches_x}")
    print(f"Number of patches in Y direction: {num_patches_y}")

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Iterate over the grid to create patches
    patch_index = 0
    for i in range(num_patches_y):
        rows = slice(i * patch_size, (i + 1) * patch_size)
        for j in range(num_patches_x):
            cols = slice(j * patch_size, (j + 1) * patch_size)

            # All bands of the chosen time step for this tile
            sentinel_patch = band_array[:, time_index, rows, cols]

            # Reject undersized tiles first so they are never counted as cloudy
            if sentinel_patch.shape[1] != patch_size or sentinel_patch.shape[2] != patch_size:
                print(f"Skipping Sentinel patch at ({i}, {j}) due to incorrect size.")
                continue

            # detect_clouds reads positions 0 and 3 of the band axis
            print(f"checking cloudy patch at ({i}, {j})")
            if detect_clouds(sentinel_patch):
                cloud_counter += 1
                print(f"Skipping cloudy patch at ({i}, {j})")
                continue

            # Extract the corresponding patch from the resized JPEG image
            jpeg_patch = resized_img[rows, cols, :]

            # Ensure the patch is of the correct size
            if jpeg_patch.shape[0] != patch_size or jpeg_patch.shape[1] != patch_size:
                print(f"Skipping JPEG patch at ({i}, {j}) due to incorrect size.")
                continue

            # Save the Sentinel-2 patch as a NetCDF file, one variable per band name
            sentinel_patch_file = os.path.join(output_dir, f'sentinel_patch_{patch_index}normal.nc')
            sentinel_patch_ds = xr.Dataset(
                {
                    band_name: (['y', 'x'], band_values)
                    for band_name, band_values in zip(bands, sentinel_patch)
                }
            )
            sentinel_patch_ds.to_netcdf(sentinel_patch_file)

            # Save the JPEG patch as a JPEG file
            jpeg_patch_file = os.path.join(output_dir, f'jpeg_patch_{patch_index}normal.jpg')
            plt.imsave(jpeg_patch_file, jpeg_patch)

            print(f"Patch {patch_index}:")
            print(f"  Sentinel patch file: {sentinel_patch_file}")
            print(f"  JPEG patch file: {jpeg_patch_file}")

            patch_index += 1

    print(f'Created {patch_index} patches and saved to {output_dir}')
    print(f"cloud patches removed {cloud_counter}")

# Function to visualize and overlay patches from Sentinel-2 and JPEG images
def visualize_and_overlay_patches(patch_folder, num_patches, bands):
    """Plot saved Sentinel-2 patches together with their matching JPEG patches.

    For each shown patch three figures are produced: an RGB/JPEG overlay, the
    band composite drawn through xarray, and that composite with the JPEG on top.

    Args:
        patch_folder: Directory holding ``sentinel_patch_*.nc`` and
            ``jpeg_patch_*.jpg`` files.
        num_patches: Maximum number of patch pairs to display.
        bands: NetCDF variable names to read; positions 0, 1, 2 are used as
            blue, green, red for the RGB overlay.

    Raises:
        ValueError: If fewer than three bands are available for RGB display.
    """
    import re  # function-scope import: only the sort key below needs it

    nc_prefix, nc_suffix = 'sentinel_patch_', '.nc'
    folder_listing = os.listdir(patch_folder)
    available = set(folder_listing)

    def _patch_order(filename):
        # Order by the numeric patch index so patch 2 precedes patch 10.
        leading_digits = re.match(r'\d+', filename[len(nc_prefix):])
        index = int(leading_digits.group()) if leading_digits else float('inf')
        return (index, filename)

    patch_files = sorted(
        (f for f in folder_listing if f.startswith(nc_prefix) and f.endswith(nc_suffix)),
        key=_patch_order,
    )

    shown = 0
    for patch_file in patch_files:
        if shown >= num_patches:
            break

        # Pair by name rather than by list position, so a missing JPEG cannot
        # shift every later pair or run off the end of the list.
        jpeg_file = f"jpeg_patch_{patch_file[len(nc_prefix):-len(nc_suffix)]}.jpg"
        if jpeg_file not in available:
            print(f"No matching JPEG for {patch_file}; skipping.")
            continue

        patch_file_path = os.path.join(patch_folder, patch_file)
        jpeg_file_path = os.path.join(patch_folder, jpeg_file)

        # Read NetCDF patch
        with Dataset(patch_file_path, 'r') as ds:
            patch_data = np.array([ds[band][:] for band in bands])
            band_data = xr.DataArray(patch_data, dims=("bands", "y", "x"), coords={"bands": bands})

        # Ensure the patch is 3D for RGB display
        if patch_data.shape[0] < 3:
            raise ValueError("Sentinel patch must have at least three bands for RGB visualization.")

        # matplotlib ignores vmin/vmax for RGB input and clips the values, so
        # the reflectances are scaled into [0, 1] explicitly.
        rgb = np.stack([patch_data[2], patch_data[1], patch_data[0]], axis=-1).astype(np.float32)
        sentinel_rgb_patch = np.clip(rgb / 3500.0, 0.0, 1.0)

        # Load JPEG patch (reused for both overlays)
        jpeg_patch = mpimg.imread(jpeg_file_path)

        # Overlay of the RGB patch and the JPEG patch
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.imshow(sentinel_rgb_patch)
        ax.imshow(jpeg_patch, alpha=0.6)
        ax.set_title(f'Overlay Patch {shown}')
        ax.axis('off')

        # Band composite drawn through xarray
        fig, ax = plt.subplots(figsize=(8, 8))
        band_data.plot.imshow(ax=ax, vmin=0, vmax=3500)
        plt.title(f"Sentinel-2 Natural Color Composite ({', '.join(bands)})")
        plt.show()

        # Same composite with the JPEG laid on top
        fig, ax = plt.subplots(figsize=(10, 8))
        band_data.plot.imshow(ax=ax, vmin=0, vmax=2000)
        ax.imshow(jpeg_patch, alpha=0.6)
        plt.title("Overlay of Sentinel-2 Band and JPEG Image")
        plt.show()

        shown += 1

# Per-city configuration, keyed by city name. Each entry holds:
#   "lat" / "lon"     - two-element [min, max] lists used to build the bbox below
#   "osm"             - OSM extract file name (only referenced by a disabled call here)
#   "utm_zone"        - zone number passed to reproject_sentinel_to_utm
#   "sentinel_file"   - NetCDF file loaded by reproject_sentinel_to_utm
#   "time"            - index along the time axis passed to the plotting/patching calls
cities = {
    "muenchen": {"lat": [48.0800, 48.1900], "lon": [11.5000, 11.6700], "osm": "muenchen.osm.pbf", "utm_zone": 32, "sentinel_file": "muenchen_sentinel.nc", "time": 0},
    "rotterdam": {"lat": [51.8750, 51.9700], "lon": [4.4200, 4.5400], "osm": "rotterdam.osm.pbf", "utm_zone": 31, "sentinel_file": "rotterdam_sentinel.nc", "time": 0},
    "duesseldorf": {"lat": [51.1646, 51.3086], "lon": [6.6847, 6.8785], "osm": "duesseldorf.osm.pbf", "utm_zone": 32, "sentinel_file": "duesseldorf_sentinel.nc", "time": 0},
    "prag": {"lat": [49.9900, 50.1100], "lon": [14.3500, 14.5400], "osm": "prag.osm.pbf", "utm_zone": 33, "sentinel_file": "prag_sentinel.nc", "time": 0},
    "greater_london": {"lat": [51.4400, 51.5800], "lon": [-0.2000, 0.0200], "osm": "greater_london.osm.pbf", "utm_zone": 30, "sentinel_file": "greater_london_sentinel.nc", "time": 2},
    "dublin": {"lat": [53.3244, 53.4273], "lon": [-6.3874, -6.1071], "osm": "dublin.osm.pbf", "utm_zone": 29, "sentinel_file": "dublin_sentinel.nc", "time": 1},
    "frankfurt": {"lat": [50.0250, 50.2100], "lon": [8.5200, 8.7500], "osm": "frankfurt.osm.pbf", "utm_zone": 32, "sentinel_file": "frankfurt_sentinel.nc", "time": 2},
    "madrid": {"lat": [40.3125, 40.6437], "lon": [-3.8890, -3.5556], "osm": "madrid.osm.pbf", "utm_zone": 30, "sentinel_file": "madrid_sentinel.nc", "time": -1},
    "bruessel": {"lat": [50.7968, 50.9101], "lon": [4.3054, 4.4317], "osm": "bruessel.osm.pbf", "utm_zone": 31, "sentinel_file": "bruessel_sentinel.nc", "time": 0},
    "berlin": {"lat": [52.454927,52.574409], "lon": [13.294333, 13.500205], "osm": "berlin.osm.pbf", "utm_zone": 33, "sentinel_file": "berlin_sentinel.nc", "time": 1},
    "magdeburg": {"lat": [52.0762, 52.1985], "lon": [11.5430, 11.6760], "osm": "magdeburg.osm.pbf", "utm_zone": 32, "sentinel_file": "magdeburg_sentinel.nc", "time" :0},
    "bremerhaven": {"lat": [53.4978, 53.6008], "lon": [8.5142, 8.6552], "osm": "bremerhaven.osm.pbf", "utm_zone": 32, "sentinel_file": "bremerhaven_sentinel.nc", "time" :1},
    "nuernberg": {"lat": [49.3955, 49.4904], "lon": [11.0023, 11.1445], "osm": "nuernberg.osm.pbf", "utm_zone": 32, "sentinel_file": "nuernberg_sentinel.nc", "time" :0},
    "erfurt": {"lat": [50.9300, 51.0100], "lon": [11.0100, 11.1100], "osm": "erfurt.osm.pbf", "utm_zone": 32, "sentinel_file": "erfurt_sentinel.nc", "time" :0},
    "rostock": {"lat": [54.0647, 54.1237], "lon": [12.0735, 12.1616], "osm": "rostock.osm.pbf", "utm_zone": 33, "sentinel_file": "rostock_sentinel.nc", "time" :0},
    "chemnitz": {"lat": [50.7830, 50.8750], "lon": [12.8830, 12.9550], "osm": "chemnitz.osm.pbf", "utm_zone": 33, "sentinel_file": "chemnitz_sentinel.nc", "time" :3},
    "potsdam": {"lat": [52.3567, 52.4348], "lon": [13.0142, 13.1208], "osm": "potsdam.osm.pbf", "utm_zone": 33, "sentinel_file": "potsdam_sentinel.nc", "time" :1},
    "bonn": {"lat": [50.6540, 50.7640], "lon": [7.0525, 7.1650], "osm": "bonn.osm.pbf", "utm_zone": 32, "sentinel_file": "bonn_sentinel.nc", "time" :0},
    "duisburg": {"lat": [51.3650, 51.5150], "lon": [6.6840, 6.8460], "osm": "duisburg.osm.pbf", "utm_zone": 32, "sentinel_file": "duisburg_sentinel.nc", "time" :0},
    "osnabrueck": {"lat": [52.2450, 52.3050], "lon": [7.9500, 8.0800], "osm": "osnabrueck.osm.pbf", "utm_zone": 32, "sentinel_file": "osnabrueck_sentinel.nc", "time" :3},
}
# Tiling and display settings shared by every city.
patch_size = 64
stride = 64
bands = ['B02', 'B03', 'B04', 'B08']
num_patches_to_visualize = 5

for city, data in cities.items():
    lat, lon, time = data["lat"], data["lon"], data["time"]
    # [min lon, min lat, max lon, max lat]; only the disabled OSM step consumed it.
    bbox = [lon[0], lat[0], lon[1], lat[1]]

    print(f"Reprojecting Sentinel-2 data for {city}")
    sentinel_ds = reproject_sentinel_to_utm(data["sentinel_file"], data["utm_zone"], bands=bands)

    # Patches are written here; the building JPEG is read from the same folder.
    output_path = f"./output/{city}_patches"

    # NOTE: the OSM rasterisation call (process_osm_data) is disabled, so
    # '<city>_buildings.jpg' has to be present in output_path already.
    print(f"Processing OSM data for {city}")

    jpeg_path = os.path.join(output_path, f'{city}_buildings.jpg')
    band_array = sentinel_ds[bands].to_array(dim="bands").values

    # Full-scene preview, then tiling, then a look at the first few tiles.
    plot_overlay(sentinel_ds, jpeg_path, bands, time)
    create_and_save_patches(band_array, jpeg_path, patch_size, stride, output_path, bands, time)
    visualize_and_overlay_patches(output_path, num_patches_to_visualize, bands)

    # Progress messages kept from earlier pipeline steps that are now disabled.
    print(f"Creating and saving patches for {city}")
    print(f"Displaying cloud-free patches for {city}")

# In [None]:
import os
import numpy as np
import tensorflow as tf
from netCDF4 import Dataset
from PIL import Image
import matplotlib.pyplot as plt

def create_tensors_from_patches(patch_folder, bands, aug_type = ''):
    """Load matching Sentinel/JPEG patch pairs from a folder into two tensors.

    A Sentinel file ``sentinel_patch_<id>.nc`` is paired with
    ``jpeg_patch_<id>.jpg`` by name. Sentinel patches without a matching JPEG
    are reported and skipped, so a missing file cannot shift later pairs.

    Args:
        patch_folder: Directory containing the patch files.
        bands: NetCDF variable names to read, in channel order.
        aug_type: Substring a file name must contain to be used (the empty
            default matches every file).

    Returns:
        ``(input_tensor, target_tensor)`` as float32 tensors - inputs in
        channels-last layout, targets as grayscale images scaled by 1/255 -
        or ``(None, None)`` when no pair could be loaded.
    """
    nc_prefix, nc_suffix = 'sentinel_patch_', '.nc'
    jpeg_prefix, jpeg_suffix = 'jpeg_patch_', '.jpg'

    folder_listing = os.listdir(patch_folder)
    nc_files = sorted(
        f for f in folder_listing
        if f.startswith(nc_prefix) and aug_type in f and f.endswith(nc_suffix)
    )
    jpeg_files = sorted(
        f for f in folder_listing
        if f.startswith(jpeg_prefix) and aug_type in f and f.endswith(jpeg_suffix)
    )

    print(f"Found {len(nc_files)} NC files and {len(jpeg_files)} JPEG files")

    def load_sentinel_patch(file_path):
        # Best effort: an unreadable file is reported and skipped, not fatal.
        try:
            with Dataset(file_path, 'r') as ds:
                patch_data = np.array([ds[band][:] for band in bands])
            return patch_data
        except Exception as e:
            print(f"Error loading Sentinel patch {file_path}: {str(e)}")
            return None

    def load_jpeg_patch(file_path):
        # Grayscale conversion, scaled to [0, 1]; failures are reported and skipped.
        try:
            with Image.open(file_path) as img:
                img_array = np.array(img.convert('L')).astype(np.float32) / 255.0
            return img_array
        except Exception as e:
            print(f"Error loading JPEG patch {file_path}: {str(e)}")
            return None

    sentinel_patches = []
    jpeg_patches = []

    jpeg_lookup = set(jpeg_files)
    for nc_file in nc_files:
        # Derive the partner name from the shared id part of the file name.
        patch_id = nc_file[len(nc_prefix):-len(nc_suffix)]
        jpeg_file = f"{jpeg_prefix}{patch_id}{jpeg_suffix}"
        if jpeg_file not in jpeg_lookup:
            print(f"No matching JPEG patch for {nc_file}; skipping")
            continue

        sentinel_patch = load_sentinel_patch(os.path.join(patch_folder, nc_file))
        jpeg_patch = load_jpeg_patch(os.path.join(patch_folder, jpeg_file))

        if sentinel_patch is not None and jpeg_patch is not None:
            sentinel_patches.append(sentinel_patch)
            jpeg_patches.append(jpeg_patch)

    if not (sentinel_patches and jpeg_patches):
        print("Failed to create tensors: No valid patches found")
        return None, None

    # Patches are stored band-first; move the band axis last (N, H, W, C).
    input_tensor = tf.convert_to_tensor(np.transpose(np.array(sentinel_patches), (0, 2, 3, 1)), dtype=tf.float32)

    # Target tensor should not include a channel dimension
    target_tensor = tf.convert_to_tensor(np.array(jpeg_patches), dtype=tf.float32)

    print(f"Input tensor shape: {input_tensor.shape}")
    print(f"Target tensor shape: {target_tensor.shape}")

    return input_tensor, target_tensor

def visualize_tensors(input_tensor, target_tensor, bands, num_samples=5):
    """Plot each band and the mask for the first few samples, printing mask stats.

    Args:
        input_tensor: Samples indexed as ``[sample, y, x, band]``.
        target_tensor: One mask per sample; each item must offer ``.numpy()``.
        bands: Band names, one per channel of ``input_tensor``.
        num_samples: Upper bound on the number of samples displayed.
    """
    sample_count = min(num_samples, input_tensor.shape[0])
    mask_column = len(bands)  # the mask goes in the panel after the last band

    for sample_idx in range(sample_count):
        fig, axes = plt.subplots(1, mask_column + 1, figsize=(20, 4))
        fig.suptitle(f"Sample {sample_idx+1}")

        for column, band_name in enumerate(bands):
            band_ax = axes[column]
            band_ax.imshow(input_tensor[sample_idx, :, :, column], cmap='viridis')
            band_ax.set_title(f"Band {band_name}")
            band_ax.axis('off')

        # The mask is already 2-D, so it is drawn as is.
        mask_ax = axes[mask_column]
        mask_ax.imshow(target_tensor[sample_idx], cmap='binary')
        mask_ax.set_title("Binary Building Mask")
        mask_ax.axis('off')

        plt.tight_layout()
        plt.show()

        # Summary statistics of the mask that was just drawn
        mask_values = target_tensor[sample_idx].numpy()
        print(f"Sample {sample_idx+1} Binary Building Mask Statistics:")
        print(f"  Min value: {np.min(mask_values):.4f}")
        print(f"  Max value: {np.max(mask_values):.4f}")
        print(f"  Mean value: {np.mean(mask_values):.4f}")
        print(f"  Unique values: {np.unique(mask_values)}")
        print()


def process_city(city_name, city_data, bands):
    """Build and preview the unfiltered tensors for one city's patch folder.

    Args:
        city_name: Used to locate ``./output/<city_name>_patches``.
        city_data: Accepted for symmetry with the city table; not read here.
        bands: Band names forwarded to the loader and the plots.

    Returns:
        ``(input_tensor, target_tensor)``; both are ``None`` when the folder is
        missing or no tensors could be created.
    """
    print(f"\nProcessing {city_name}")
    patch_folder = f"./output/{city_name}_patches"

    if not os.path.exists(patch_folder):
        print(f"Patch folder not found for {city_name}. Skipping.")
        return None, None

    input_tensor, target_tensor = create_tensors_from_patches(patch_folder, bands)

    if input_tensor is None or target_tensor is None:
        print(f"Failed to create tensors for {city_name}")
        return input_tensor, target_tensor

    print(f"Tensors created for {city_name}")
    print(f"Input tensor shape: {input_tensor.shape}")
    print(f"Target tensor shape: {target_tensor.shape}")

    # Show the first few samples
    visualize_tensors(input_tensor, target_tensor, bands)

    return input_tensor, target_tensor

# City data: one entry per city, keyed by the name used for the
# "./output/<city>_patches" folder. The per-city dict is passed to the
# processing functions as `city_data`, which do not read its fields in this cell.
# The "singapore" entry is commented out and therefore not processed.
cities = {
    "muenchen": {"lat": [48.0800, 48.1900], "lon": [11.5000, 11.6700], "osm": "muenchen.osm.pbf", "utm_zone": 32, "sentinel_file": "muenchen_sentinel.nc", "time": 0},
    "rotterdam": {"lat": [51.8750, 51.9700], "lon": [4.4200, 4.5400], "osm": "rotterdam.osm.pbf", "utm_zone": 31, "sentinel_file": "rotterdam_sentinel.nc", "time": 0},
    "duesseldorf": {"lat": [51.1646, 51.3086], "lon": [6.6847, 6.8785], "osm": "duesseldorf.osm.pbf", "utm_zone": 32, "sentinel_file": "duesseldorf_sentinel.nc", "time": 0},
    "prag": {"lat": [49.9900, 50.1100], "lon": [14.3500, 14.5400], "osm": "prag.osm.pbf", "utm_zone": 33, "sentinel_file": "prag_sentinel.nc", "time": 0},
    "greater_london": {"lat": [51.4400, 51.5800], "lon": [-0.2000, 0.0200], "osm": "greater_london.osm.pbf", "utm_zone": 30, "sentinel_file": "greater_london_sentinel.nc", "time": 2},
    "dublin": {"lat": [53.3244, 53.4273], "lon": [-6.3874, -6.1071], "osm": "dublin.osm.pbf", "utm_zone": 29, "sentinel_file": "dublin_sentinel.nc", "time": 1},
    "frankfurt": {"lat": [50.0250, 50.2100], "lon": [8.5200, 8.7500], "osm": "frankfurt.osm.pbf", "utm_zone": 32, "sentinel_file": "frankfurt_sentinel.nc", "time": 2},
    "madrid": {"lat": [40.3125, 40.6437], "lon": [-3.8890, -3.5556], "osm": "madrid.osm.pbf", "utm_zone": 30, "sentinel_file": "madrid_sentinel.nc", "time": -1},
    "bruessel": {"lat": [50.7968, 50.9101], "lon": [4.3054, 4.4317], "osm": "bruessel.osm.pbf", "utm_zone": 31, "sentinel_file": "bruessel_sentinel.nc", "time": 0},
    #"singapore": {"lat": [1.2378, 1.4707], "lon": [103.6017, 104.0123], "osm": "singapore.osm.pbf", "utm_zone": 48, "sentinel_file": "singapore_sentinel.nc", "time": 0},
    "berlin": {"lat": [52.454927,52.574409], "lon": [13.294333, 13.500205], "osm": "berlin.osm.pbf", "utm_zone": 33, "sentinel_file": "berlin_sentinel.nc", "time": 1},
    "magdeburg": {"lat": [52.0762, 52.1985], "lon": [11.5430, 11.6760], "osm": "magdeburg.osm.pbf", "utm_zone": 32, "sentinel_file": "magdeburg_sentinel.nc", "time" :0},
    "bremerhaven": {"lat": [53.4978, 53.6008], "lon": [8.5142, 8.6552], "osm": "bremerhaven.osm.pbf", "utm_zone": 32, "sentinel_file": "bremerhaven_sentinel.nc", "time" :1},
    "nuernberg": {"lat": [49.3955, 49.4904], "lon": [11.0023, 11.1445], "osm": "nuernberg.osm.pbf", "utm_zone": 32, "sentinel_file": "nuernberg_sentinel.nc", "time" :0},
    "erfurt": {"lat": [50.9300, 51.0100], "lon": [11.0100, 11.1100], "osm": "erfurt.osm.pbf", "utm_zone": 32, "sentinel_file": "erfurt_sentinel.nc", "time" :0},
    "rostock": {"lat": [54.0647, 54.1237], "lon": [12.0735, 12.1616], "osm": "rostock.osm.pbf", "utm_zone": 33, "sentinel_file": "rostock_sentinel.nc", "time" :0},
    "chemnitz": {"lat": [50.7830, 50.8750], "lon": [12.8830, 12.9550], "osm": "chemnitz.osm.pbf", "utm_zone": 33, "sentinel_file": "chemnitz_sentinel.nc", "time" :3},
    "potsdam": {"lat": [52.3567, 52.4348], "lon": [13.0142, 13.1208], "osm": "potsdam.osm.pbf", "utm_zone": 33, "sentinel_file": "potsdam_sentinel.nc", "time" :1},
    "bonn": {"lat": [50.6540, 50.7640], "lon": [7.0525, 7.1650], "osm": "bonn.osm.pbf", "utm_zone": 32, "sentinel_file": "bonn_sentinel.nc", "time" :0},
    "duisburg": {"lat": [51.3650, 51.5150], "lon": [6.6840, 6.8460], "osm": "duisburg.osm.pbf", "utm_zone": 32, "sentinel_file": "duisburg_sentinel.nc", "time" :0},
    "osnabrueck": {"lat": [52.2450, 52.3050], "lon": [7.9500, 8.0800], "osm": "osnabrueck.osm.pbf", "utm_zone": 32, "sentinel_file": "osnabrueck_sentinel.nc", "time" :3},
}

def create_binary_mask(target_tensor, threshold=0.08):
    """Threshold the complement of a floating-point mask into a 0/1 mask.

    The input is always complemented (``1 - target_tensor``) before the
    comparison; no check is made for whether it was inverted before.

    Args:
        target_tensor: Mask values; the callers in this file pass a tensor
            with one 2-D mask per sample.
        threshold: A pixel becomes 1.0 when its complemented value is strictly
            greater than this (default: 0.08).

    Returns:
        A float32 tensor with the shape of the input, holding 0.0 and 1.0.
    """
    complemented = 1 - target_tensor
    exceeds_threshold = tf.greater(complemented, threshold)
    return tf.cast(exceeds_threshold, tf.float32)

# Drop samples whose binary mask is (almost) entirely 0 or entirely 1
def filter_empty_or_full_masks(input_tensor, target_tensor, min_threshold=0.01, max_threshold=0.99):
    """Keep only samples whose binarised mask has a mid-range share of 1-pixels.

    Args:
        input_tensor: Samples, first axis is the sample axis.
        target_tensor: Floating-point masks, binarised via ``create_binary_mask``.
        min_threshold: The mean of the binary mask must be strictly above this.
        max_threshold: The mean of the binary mask must be strictly below this.

    Returns:
        ``(filtered_input, filtered_target)``, where the target is the filtered
        *binary* mask rather than the original floating-point one.
    """
    building_mask = create_binary_mask(target_tensor)

    # Share of 1-pixels per sample, averaged over the two spatial axes
    building_fraction = tf.reduce_mean(building_mask, axis=[1, 2])

    above_min = building_fraction > min_threshold
    below_max = building_fraction < max_threshold
    keep_mask = tf.logical_and(above_min, below_max)

    # The same selection is applied to the inputs and to the binary masks
    filtered_input = tf.boolean_mask(input_tensor, keep_mask, axis=0)
    filtered_target = tf.boolean_mask(building_mask, keep_mask, axis=0)

    print(f"Kept {tf.reduce_sum(tf.cast(keep_mask, tf.int32))} out of {keep_mask.shape[0]} samples")

    return filtered_input, filtered_target

# Load one city's patches, filter them by mask coverage and preview the result
def process_city_with_filtering(city_name, city_data, bands, aug_type, min_threshold=0.01, max_threshold=0.99):
    """Build tensors for one city and drop samples with extreme mask coverage.

    Args:
        city_name: Used to locate ``./output/<city_name>_patches``.
        city_data: Accepted for symmetry with the city table; not read here.
        bands: Band names forwarded to the loader and the plots.
        aug_type: File-name substring selecting which patch variant to load.
        min_threshold: Lower coverage bound passed to the filter.
        max_threshold: Upper coverage bound passed to the filter.

    Returns:
        ``(filtered_input, filtered_target)``, or ``(None, None)`` when the
        folder is missing or no tensors could be created.
    """
    print(f"\nProcessing {city_name}")
    patch_folder = f"./output/{city_name}_patches"

    if not os.path.exists(patch_folder):
        print(f"Patch folder not found for {city_name}. Skipping.")
        return None, None

    input_tensor, target_tensor = create_tensors_from_patches(patch_folder, bands, aug_type = aug_type)

    if input_tensor is None or target_tensor is None:
        print(f"Failed to create tensors for {city_name}")
        return None, None

    print("Before filtering:")
    print(f"  Input tensor shape: {input_tensor.shape}")
    print(f"  Target tensor shape: {target_tensor.shape}")

    filtered_input, filtered_target = filter_empty_or_full_masks(
        input_tensor, target_tensor, min_threshold, max_threshold
    )

    print("After filtering:")
    print(f"  Filtered input tensor shape: {filtered_input.shape}")
    print(f"  Filtered target tensor shape: {filtered_target.shape}")

    # Show the first few samples that survived the filter
    visualize_tensors(filtered_input, filtered_target, bands)

    return filtered_input, filtered_target

# Process all cities
# `bands` fixes the channel order used when patches are loaded;
# `all_city_tensors` maps a city name to its {"input", "target"} tensors.
bands = ['B02', 'B03', 'B04', 'B08']
all_city_tensors = {}
def invert_mask(mask):
    """Return the complement ``1 - mask`` (so 0 becomes 1 and 1 becomes 0)."""
    complement = 1 - mask
    return complement


for city_name, city_data in cities.items():
    # Load the plain and the mirrored patch variants. Either call returns
    # (None, None) when the folder is missing or holds no such patches.
    input_tensor_normal, target_tensor_normal = process_city_with_filtering(city_name, city_data, bands, min_threshold=0.0, aug_type='normal')
    input_tensor_reflect, target_tensor_reflect = process_city_with_filtering(city_name, city_data, bands, min_threshold=0.0, aug_type='reflect')

    # tf.concat cannot take None, so only the variants that loaded are joined.
    input_parts = [t for t in (input_tensor_normal, input_tensor_reflect) if t is not None]
    target_parts = [t for t in (target_tensor_normal, target_tensor_reflect) if t is not None]
    if not input_parts or not target_parts:
        print(f"No usable patches for {city_name}; skipping.")
        continue

    input_tensor = tf.concat(input_parts, axis=0)
    target_tensor = tf.concat(target_parts, axis=0)

    # The filtered targets are binary masks; store their complement.
    inverted_target_tensor = invert_mask(target_tensor)
    all_city_tensors[city_name] = {
        "input": input_tensor,
        "target": inverted_target_tensor
    }

# Print summary of processed cities
print("\nSummary of processed cities:")
for city_name, tensors in all_city_tensors.items():
    print(f"{city_name}:")
    print(f"  Input tensor shape: {tensors['input'].shape}")
    print(f"  Target tensor shape: {tensors['target'].shape}")

# In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import EarlyStopping
import matplotlib.pyplot as plt
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2

# Sanity check: list the cities for which tensors are available.
# (all_city_tensors must already exist from the preceding cell.)
print("Loaded city tensors keys:", all_city_tensors.keys())

# Berlin is held out entirely and serves as the test set
test_input = all_city_tensors['berlin']['input']
test_target = all_city_tensors['berlin']['target']
print(f"Berlin test set shapes - Input: {test_input.shape}, Target: {test_target.shape}")

# Every other city contributes to the pooled training/validation data
train_val_input = []
train_val_target = []
for city, tensors in all_city_tensors.items():
    if city == 'berlin':
        continue
    train_val_input.append(tensors['input'])
    train_val_target.append(tensors['target'])
    print(f"Adding data from city: {city}, Input shape: {tensors['input'].shape}, Target shape: {tensors['target'].shape}")

# Stack the per-city tensors along the sample axis
train_val_input = np.concatenate(train_val_input, axis=0)
train_val_target = np.concatenate(train_val_target, axis=0)
print(f"Combined train/val shapes - Input: {train_val_input.shape}, Target: {train_val_target.shape}")

# 80/20 split with a fixed seed so the partition is reproducible
train_input, val_input, train_target, val_target = train_test_split(
    train_val_input,
    train_val_target,
    test_size=0.2,
    random_state=42,
)

print(f"Training set shapes - Input: {train_input.shape}, Target: {train_target.shape}")
print(f"Validation set shapes - Input: {val_input.shape}, Target: {val_target.shape}")

# Define the CNN model
def create_model(input_shape):
    """Build a small fully-convolutional network for per-pixel binary prediction.

    Three 3x3 ReLU convolutions with 'same' padding keep the spatial size; a
    final 1x1 convolution yields one channel per pixel.

    Args:
        input_shape: Shape of one input sample, without the batch axis.

    Returns:
        An uncompiled ``Sequential`` model whose output lies in [0, 1].
    """
    p = 'same'
    model = models.Sequential([
        layers.Conv2D(32, (3, 3), activation='relu', padding=p, input_shape=input_shape),
        layers.Conv2D(64, (3, 3), activation='relu', padding=p),
        layers.Conv2D(128, (3, 3), activation='relu', padding=p),
        # Sigmoid output: this model is trained with 'binary_crossentropy',
        # which by default expects probabilities, not raw logits.
        layers.Conv2D(1, (1, 1), padding=p, activation='sigmoid')
    ])
    return model

# Build the baseline CNN from the shape of one training sample
input_shape = train_input.shape[1:]  # (H, W, C)
print(f"Model input shape: {input_shape}")
model = create_model(input_shape)
model.summary()  # Print model summary

# Adam optimiser with a per-pixel binary cross-entropy objective
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Stop after 5 epochs without improvement and roll back to the best weights
early_stopping = EarlyStopping(patience=5, restore_best_weights=True)
history = model.fit(
    train_input,
    train_target,
    validation_data=(val_input, val_target),
    epochs=50,
    batch_size=32,
    callbacks=[early_stopping],
)

# Held-out evaluation on the Berlin test set
test_loss, test_accuracy = model.evaluate(test_input, test_target)
print(f"Test accuracy: {test_accuracy}")

# Training curves: accuracy on the left panel, loss on the right
plt.figure(figsize=(12, 4))
for panel, (metric, label) in enumerate((('accuracy', 'Accuracy'), ('loss', 'Loss')), start=1):
    plt.subplot(1, 2, panel)
    plt.plot(history.history[metric], label=f'Training {label}')
    plt.plot(history.history[f'val_{metric}'], label=f'Validation {label}')
    plt.title(f'Model {label}')
    plt.xlabel('Epoch')
    plt.ylabel(label)
    plt.legend()

plt.tight_layout()
plt.show()

# Hyperparameter tuning
def create_tuned_model(input_shape, l2_reg=0.01, learning_rate=0.001):
    """Build and compile the L2-regularised CNN used by the grid search.

    Args:
        input_shape: Shape of one input sample, without the batch axis.
        l2_reg: L2 penalty factor applied to the kernel of every convolution.
        learning_rate: Learning rate handed to the Adam optimiser.

    Returns:
        A ``Sequential`` model compiled with binary cross-entropy and accuracy.
    """
    p = 'same'
    model = models.Sequential([
        layers.Conv2D(32, (3, 3), activation='relu', padding=p, input_shape=input_shape, kernel_regularizer=l2(l2_reg)),
        layers.Conv2D(64, (3, 3), activation='relu', padding=p, kernel_regularizer=l2(l2_reg)),
        layers.Conv2D(128, (3, 3), activation='relu', padding=p, kernel_regularizer=l2(l2_reg)),
        # Sigmoid output: 'binary_crossentropy' below uses its default
        # from_logits=False and therefore expects probabilities.
        layers.Conv2D(1, (1, 1), padding=p, activation='sigmoid', kernel_regularizer=l2(l2_reg))
    ])
    model.compile(optimizer=Adam(learning_rate=learning_rate), loss='binary_crossentropy', metrics=['accuracy'])
    return model

# Hyperparameter grid
l2_regs = [0.001, 0.01, 0.1]
learning_rates = [0.0001, 0.001, 0.01]

# Start below any reachable accuracy so the first candidate is always
# recorded and best_model can never stay None.
best_val_accuracy = float('-inf')
best_model = None
best_params = None

for l2_reg in l2_regs:
    for lr in learning_rates:
        print(f"Training with L2 reg: {l2_reg}, Learning rate: {lr}")
        # Separate names: the baseline CNN's `model` and `history` from the
        # earlier training run stay untouched for later evaluation/plots.
        candidate_model = create_tuned_model(input_shape, l2_reg, lr)
        candidate_model.fit(
            train_input, train_target,
            validation_data=(val_input, val_target),
            epochs=30,
            batch_size=32,
            callbacks=[EarlyStopping(patience=5, restore_best_weights=True)],
            verbose=0
        )
        # Score the weights the model actually ends up with, rather than the
        # highest accuracy seen at some epoch that may have been discarded.
        candidate_val_loss, val_accuracy = candidate_model.evaluate(val_input, val_target, verbose=0)
        print(f"Validation accuracy for L2 reg {l2_reg}, Learning rate {lr}: {val_accuracy}")
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            best_model = candidate_model
            best_params = {'l2_reg': l2_reg, 'learning_rate': lr}

print("Best hyperparameters:", best_params)
print("Best validation accuracy:", best_val_accuracy)

# Evaluate best model on test set (Berlin)
test_loss, test_accuracy = best_model.evaluate(test_input, test_target)
print(f"Test accuracy with best model: {test_accuracy}")

# U-Net implementation
def unet_model(input_shape):
    """Build and compile a two-level U-Net-style encoder/decoder.

    Two conv + max-pool stages go down, one bridge convolution sits at the
    bottom, and two upsample + skip-concatenate + conv stages go back up. A
    1x1 sigmoid convolution gives a single output channel.

    Args:
        input_shape: Shape of one input sample, without the batch axis.

    Returns:
        A ``Model`` compiled with Adam, binary cross-entropy and accuracy.
    """
    def conv3x3(filters):
        # Every feature convolution shares kernel size, activation and padding.
        return layers.Conv2D(filters, 3, activation='relu', padding='same')

    inputs = layers.Input(input_shape)

    # Encoder: two convolution + 2x2 max-pool stages
    enc1 = conv3x3(64)(inputs)
    down1 = layers.MaxPooling2D(pool_size=(2, 2))(enc1)

    enc2 = conv3x3(128)(down1)
    down2 = layers.MaxPooling2D(pool_size=(2, 2))(enc2)

    # Bridge at the lowest resolution
    bridge = conv3x3(256)(down2)

    # Decoder: upsample, join with the matching encoder features, convolve
    merged2 = layers.UpSampling2D(size=(2, 2))(bridge)
    merged2 = layers.concatenate([merged2, enc2])
    dec2 = conv3x3(128)(merged2)

    merged1 = layers.UpSampling2D(size=(2, 2))(dec2)
    merged1 = layers.concatenate([merged1, enc1])
    dec1 = conv3x3(64)(merged1)

    outputs = layers.Conv2D(1, 1, activation='sigmoid')(dec1)

    model = models.Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

# Train the U-Net with the same data, epoch budget and early stopping
unet = unet_model(input_shape)
unet_history = unet.fit(
    train_input,
    train_target,
    validation_data=(val_input, val_target),
    epochs=50,
    batch_size=32,
    callbacks=[EarlyStopping(patience=5, restore_best_weights=True)],
)

# Held-out evaluation of the U-Net on the Berlin test set
unet_test_loss, unet_test_accuracy = unet.evaluate(test_input, test_target)
print(f"U-Net Test accuracy: {unet_test_accuracy}")

# Side-by-side curves. `history` is whichever CNN training history is
# currently bound to that name; `unet_history` is the run above.
plt.figure(figsize=(12, 4))
for panel, (metric, label) in enumerate((('accuracy', 'Accuracy'), ('loss', 'Loss')), start=1):
    plt.subplot(1, 2, panel)
    for run_name, run_history in (('CNN', history), ('U-Net', unet_history)):
        plt.plot(run_history.history[metric], label=f'{run_name} Training')
        plt.plot(run_history.history[f'val_{metric}'], label=f'{run_name} Validation')
    plt.title(f'Model {label} Comparison')
    plt.xlabel('Epoch')
    plt.ylabel(label)
    plt.legend()

plt.tight_layout()
plt.show()
