# Landscape Clustering by FFT Footprints
## Primary Notebook

Hi! This is a personal project with a fun backstory. Have you ever been on vacation and wondered: Do I really have to travel this far, or is there a similar landscape closer to home that delivers the same experience? I know this is an unusual question, but it is one I asked myself while travelling all over Scotland.

This project aims to make similar "landscape types" easy to discover. The notebook creates one or more GeoTIFFs in which each landscape type gets its own color. The landscapes are clustered by their topographic footprint.

### The topographic footprint
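As computed further below in `_create_fft_tiles` and `_bin_and_average_fft_tiles`, the footprint of a tile is the log-scaled magnitude of its centered 2D FFT, averaged over concentric rings whose outer radii grow logarithmically, giving one value per level. The following is a simplified, NumPy-only sketch of that idea, not the notebook's code: `fft_footprint` is a hypothetical name, `numpy.fft` stands in for pyFFTW, and the one-pixel soft ring edges only approximate `RingImage`.

```python
import numpy as np

def fft_footprint(tile: np.ndarray, levels: int) -> np.ndarray:
    '''Simplified footprint of one square DEM tile: one ring average per level.'''
    size = tile.shape[0]

    # Log-scaled magnitude spectrum; fftshift moves the zero frequency to index size // 2
    spectrum = np.fft.fftshift(np.log(np.abs(np.fft.fft2(tile)) + 1))

    # Distance of every frequency bin from the zero-frequency bin
    offsets = np.arange(size) - size // 2
    radius = np.hypot(offsets[:, np.newaxis], offsets[np.newaxis, :])

    # Outer radii grow logarithmically from 1 px to half the tile side
    outer = np.logspace(0, np.log2(size / 2), levels, base=2)
    inner = np.append(0.0, outer[:-1])

    def disk(r):
        # Soft-edged disk: 1 inside, 0 outside, linear fade over one pixel
        if r == 0:
            return np.zeros_like(radius)
        return np.clip(r + 0.5 - radius, 0.0, 1.0)

    footprint = np.empty(levels)
    for i in range(levels):
        weights = disk(outer[i]) - disk(inner[i])  # ring between inner and outer radius
        footprint[i] = (weights * spectrum).sum() / weights.sum()
    return footprint
```

Each tile thus becomes a short vector (14 numbers with the default `fft_levels`). The inner rings roughly describe broad relief and the outer rings fine-grained roughness, and these vectors are what gets clustered.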


### Overview
This is the primary notebook of this repository, and the one that does the heavy work.

This repository contains two further notebooks:
- secondary_dem_retriever.ipynb -> Connects to opentopography.org and retrieves DEM GeoTIFFs (aka heightmaps) of parts of the Earth.
- color_separator.ipynb -> A small helper notebook (rough code only) that helped create custom color palettes for coloring the resulting images.



In [None]:
# What we might need eventually
# rasterio for geotiffs
# dask array for parallel computing of large arrays
# pyfftw for 2d fft

## Overview of the classes and their interconnection
The goal of this notebook is to cluster different landscapes across multiple DEM GeoTIFFs.

### GeographicBounds
This is an object that stores the west, south, east and north bounds together with their coordinate reference system (CRS) and, optionally, the affine transform of the raster.

### AugmentedDEM
This will be the "container" for all information regarding *one* specific DEM raster map.

### EmbeddingMap
This will be part of each AugmentedDEM. Here we will save the labels created by the clustering.
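The clustering runs globally over all DEMs (see the comment at the end of `AugmentedDEM`), so the footprints of every map share one feature space and a label means the same landscape type on every map. A minimal NumPy-only sketch of that idea, assuming each DEM contributes an array of shape (number_of_tiles, fft_levels): the plain Lloyd's k-means is a stand-in for scikit-learn's `KMeans` imported below, and both function names are hypothetical.

```python
import numpy as np

def kmeans_labels(features: np.ndarray, k: int, n_iter: int = 100, seed: int = 0) -> np.ndarray:
    '''Plain Lloyd's k-means (a stand-in for sklearn's KMeans); returns one label per row.'''
    rng = np.random.default_rng(seed)
    centers = features[rng.choice(len(features), size=k, replace=False)]
    for _ in range(n_iter):
        # Squared distance from every sample to every center, shape (samples, k)
        d2 = ((features[:, np.newaxis, :] - centers[np.newaxis, :, :]) ** 2).sum(axis=2)
        labels = d2.argmin(axis=1)
        # Move each center to the mean of its members; keep it in place if it has none
        new_centers = np.array([
            features[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
            for j in range(k)
        ])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    return labels

def cluster_jointly(footprints_per_dem: list[np.ndarray], k: int) -> list[np.ndarray]:
    '''Clusters the footprints of all DEMs in one shared feature space and splits the labels per DEM.'''
    stacked = np.concatenate(footprints_per_dem, axis=0)  # (all tiles, fft_levels)
    labels = kmeans_labels(stacked, k)
    split_points = np.cumsum([len(f) for f in footprints_per_dem])[:-1]
    return np.split(labels, split_points)
```

Clustering the stacked footprints once, instead of each DEM separately, is what makes the colors comparable between the resulting GeoTIFFs.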




## How to make the GeoTIFF images usable by my algorithm:
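It’s important to note that the original DEM data is gridded in arc seconds. Away from the equator, an arc second of longitude spans fewer metres than an arc second of latitude (by a factor of cos(latitude)), so to avoid skewed results each image has to be compressed in width by cos(latitude). That way, distances in pixels (almost) correspond to metric distances.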
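A minimal sketch of this correction, assuming a spherical Earth (hence the "almost"); the helper names are hypothetical. The pipeline further below takes a different route with the same goal: it reprojects each DEM to an azimuthal equidistant projection.

```python
import math

def lon_compression_factor(lat_deg: float) -> float:
    '''Ratio of east-west to north-south ground distance of one arc second at a given latitude.'''
    return math.cos(math.radians(lat_deg))

def compressed_width_px(width_px: int, lat_deg: float) -> int:
    '''Width an arc-second raster should be shrunk to so that its pixels become roughly square on the ground.'''
    return round(width_px * lon_compression_factor(lat_deg))
```

At about 57° N (Scotland), a 1° × 1° tile of 3600 × 3600 one-arc-second pixels would shrink to `compressed_width_px(3600, 57)`, i.e. 1961 pixels in width.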

After that, the images will be scaled to 1/4 of their side length (as of now), because processing then becomes about 4² = 16 times faster.
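Scaling each side to 1/4 leaves 1/16 of the pixels, which is where the factor 4² comes from (assuming the cost grows roughly linearly with the pixel count). A minimal sketch of such a reduction by averaging 4 × 4 blocks; `downscale_block_mean` is hypothetical, and the notebook itself lets rasterio resample (bilinearly) during reprojection instead.

```python
import numpy as np

def downscale_block_mean(dem: np.ndarray, factor: int = 4) -> np.ndarray:
    '''Shrinks a 2D array by an integer factor per side by averaging factor x factor blocks.'''
    h = dem.shape[0] // factor * factor
    w = dem.shape[1] // factor * factor
    cropped = dem[:h, :w]  # drop rows/columns that do not fill a whole block
    return cropped.reshape(h // factor, factor, w // factor, factor).mean(axis=(1, 3))
```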

Before processing, the data has to be split into square tiles of equal side length.

It probably makes sense to combine these steps: select tiles of equal size (in km) and then reduce them to a given size in pixels.
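`_resample_original_dem` below does this in one step: it reprojects the whole DEM at the resolution that makes a tile of `tile_size_km` exactly `tile_size_px` pixels wide, so cutting and resizing the tiles coincide. The arithmetic behind it, as a sketch with hypothetical helper names (ignoring tile overlap):

```python
def metres_per_pixel(tile_size_km: float, tile_size_px: int) -> float:
    '''Resolution that makes a tile of tile_size_km exactly tile_size_px pixels wide.'''
    return tile_size_km * 1000 / tile_size_px

def tiles_per_axis(extent_m: float, tile_size_km: float, tile_size_px: int) -> int:
    '''Number of non-overlapping tiles that fit along one axis after resampling.'''
    extent_px = int(extent_m / metres_per_pixel(tile_size_km, tile_size_px))
    return extent_px // tile_size_px
```

With the default settings (9 km, 20 px) every pixel covers 450 m, and a 100 km wide DEM fits 11 non-overlapping tiles per row.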

In [None]:
# Imports

# Filesystem, JSON
import os
import json


# Rasterio for handling GeoTIFFs
import rasterio
from rasterio.transform import Affine, from_origin

from rasterio.warp import calculate_default_transform, transform_bounds, reproject, Resampling
from rasterio.windows import from_bounds
from rasterio.enums import Resampling as ResampleEnum


# Resampling of the tiles
from PIL import Image 


# For interpolation of the results
from scipy.interpolate import griddata
from scipy import ndimage # for filtering
 

# Math must not be missing
import math
import numpy as np
import dask.array as da
import pyfftw 


# The heart: k-means clustering (the other clustering algorithms are imported as alternatives)
from sklearn.cluster import KMeans
import hdbscan
from sklearn.cluster import OPTICS
from sklearn.cluster import MeanShift
from sklearn.cluster import DBSCAN


# Regex for updating filenames
import re


# For earth related numbers
from pyproj import Geod, CRS, Transformer
geod = Geod(ellps = "WGS84")


# Matplotlib for graphics
import matplotlib.pyplot as plt
from matplotlib import colors
from matplotlib.colors import LinearSegmentedColormap


# For timing 
import time


# For multiprocessing
from multiprocess import Pool, cpu_count

In [None]:
# Helper functions

# A simple limit function for integers
def clamp(x, lower, upper) -> int:
    '''Limits x to the range [lower, upper] and returns it as an int.'''
    return int(max(lower, min(x, upper)))


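def mode_filter(label_map, radius = 10) -> np.ndarray:
    '''
    Replaces every pixel of a label map with the most common label within a circle of the given radius.
    The label map has to be in shape (y,x); the label values can be arbitrary integers.
    Ties are resolved in favour of the smallest label.
    '''
    y, x = np.ogrid[-radius:radius+1, -radius:radius+1]
    mask = x**2 + y**2 <= radius **2


    def mode_filter_func(values):
        # np.unique handles any integer labels (including negative ones),
        # whereas np.bincount on values cast to uint8 would wrap them around
        unique_values, counts = np.unique(values, return_counts = True)
        return unique_values[np.argmax(counts)]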
    
    filtered = ndimage.generic_filter(
        label_map,
        function = mode_filter_func,
        footprint = mask,
        mode = "nearest"
    )

    return filtered



def constrain_labels(input) -> np.ndarray:
    '''
    Takes an array of integer labels in any range and maps them to consecutive labels 0, 1, 2…, preserving their order.
    '''

    unique_labels = np.sort(np.unique(input))
    new_labels = np.arange(0,len(unique_labels), 1)
    from_to = dict(zip(unique_labels,new_labels))

    constrained_labels = np.vectorize(from_to.get)(input)

    return constrained_labels



def color_to_numpy(stringlist) -> np.ndarray:
    '''
    This converts a list of string hex color codes to actual rgb values in a numpy array.
    '''

    colorlist = np.empty((len(stringlist),3))

    for i, string in enumerate(stringlist):
        removed_hashtag = re.search(r"(?i)([a-f0-9]+)", string).group(0)
        colorlist[i] = np.array((int(removed_hashtag[0:2],16),
                        int(removed_hashtag[2:4],16),
                        int(removed_hashtag[4:6],16)))

    return colorlist



def euclidean_distance (y_a, y_b, x_a, x_b) -> float:
    '''
    Calculates the pythagorean distance between two points.
    '''

    return math.sqrt((y_a-y_b)**2 + (x_a-x_b)**2)



def threshold (input, threshold, bandwidth = 1) -> np.ndarray:
    '''
    Creates smooth edges in the circle masks.
    Values below threshold - bandwidth/2 become 0, values above threshold + bandwidth/2 become 1,
    values in between are faded linearly.
    '''

    # This is essentially anti-aliasing without subsampling.

    result = np.interp(input, [threshold-(bandwidth/2), threshold+(bandwidth/2)], [0,1])
    return result



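def sort_labels(labels):
    '''Rearrange label ids by their group size (the largest group becomes 0)'''
    # return_inverse gives each element's index into `unique`, so arbitrary
    # (e.g. negative or non-consecutive) label values work as well
    unique, inverse, counts = np.unique(labels, return_inverse = True, return_counts = True)
    size_order = np.argsort(-counts, kind = "stable")

    lookup = np.empty(len(unique), dtype = int)
    lookup[size_order] = np.arange(len(size_order))

    relabeled = lookup[inverse].reshape(np.shape(labels))
    return relabeled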

def CircleImage (height, width, radius, inverted = False, bandwidth = 1) -> np.ndarray:

    '''
    Returns an antialiased image of a circle with a given radius as a numpy array for masking.
    By default the disk is 0 and its surroundings are 1; inverted = True flips this.
    '''

    # Height, width define the shape of the 'image'
    # The bandwidth defines the width of the smooth edge of the circle (for anti-aliasing)

    if radius == 0:
        # A circle of radius 0 covers nothing
        return np.zeros((height, width)) if inverted else np.ones((height, width))

    # Distance of every pixel from the center. The center lies where
    # np.fft.fftshift puts the zero frequency (index height // 2, width // 2).
    y = np.arange(height) + (0.5 if height % 2 == 1 else 0)
    x = np.arange(width) + (0.5 if width % 2 == 1 else 0)
    circle_image = np.hypot(y[:, np.newaxis] - height / 2, x[np.newaxis, :] - width / 2)

    if inverted:
        return 1 - threshold(circle_image, radius, bandwidth)
    else:
        return threshold(circle_image, radius, bandwidth)



def RingImage (height, width, inner_radius, outer_radius, bandwidth) -> np.ndarray:

    '''
    Utilizes two circles to create a ring mask with smooth edges.
    Please look into CircleImage for in-depth definition.
    '''

    if outer_radius < inner_radius:
        raise ValueError("The inner radius must be smaller than the outer radius.")

    # This combines two circles to form a ring mask
    outercircle = CircleImage(height, width, outer_radius, 
                              inverted = True, bandwidth = bandwidth)
    innercircle = CircleImage(height, width, inner_radius, 
                              inverted = True, bandwidth = bandwidth)
    
    ringimg = outercircle - innercircle

    return  (ringimg - ringimg.min()) / (ringimg.max() - ringimg.min())



def RingImageSeries(height,width,steps,bandwidth) -> np.ndarray:

    '''
    This creates a 3D numpy array of the shape
    (Masks, individual Height, individual Width)
    This is used to sum and average the FFT magnitudes
    '''

    # The outer radius grows logarithmically, starting at 1 px and ending at half the shorter side
    smallest_side = min(height,width)

    outer_radii = np.logspace(0, np.log2(smallest_side / 2), steps, base = 2) # exponent 0 means it starts at 2**0 = 1
    inner_radii = np.append(0,outer_radii[:-1])


    all_masks = np.zeros((steps,height, width))

    for i in range(steps):
        all_masks[i] = RingImage(height,width,inner_radii[i],outer_radii[i],bandwidth = bandwidth)
            
    return all_masks



In [None]:
# Helper classes


class SimpleTimer:
    '''This class is just a timer for checking the performance'''



    def __init__(self, description):
        self.description = description
    
    def __enter__(self):
        self.timer = time.perf_counter()
        return self

    def __exit__(self, type, value, traceback):
        self.time_needed = time.perf_counter() - self.timer
        print(f"{self.description} took {self.time_needed:.1f} seconds")



class VerboseInfoTimer:
    '''
    This class wraps processes to help identify what is being done by printing info about them.
    current_index is 0-based.
    '''

    def __init__(self, process_description: str, current_index: int = 0, total_count: int = 1, single_description: str | None = None):
        self.sd = single_description
        self.pd = process_description
        self.curr_it = current_index + 1
        self.total_count = total_count
        self.rest = total_count - self.curr_it

    def __enter__(self):
        self.timer = time.perf_counter()

        if self.total_count > 1:
            if self.curr_it == 1:
                print(f"\n{self.pd} starting…\n")
            print(f"{self.sd} {self.curr_it} of {self.total_count}…")
        else:
            print(f"{self.pd} starting…")

        return self

    def __exit__(self, type, value, traceback):
        self.time_needed = time.perf_counter() - self.timer

        if self.total_count > 1:
            print(f"{self.sd} completed. This one took {self.time_needed:.1f} seconds.")

            if self.rest > 0:
                print(f"{self.rest} more to go.")
            else:
                print(f"{self.pd} ({self.total_count}) completed.")
        else:
            print(f"{self.pd} completed. This took {self.time_needed:.1f} seconds.")

        print("")

<a id = "settings"></a>
## Code explanation: Settings
The following section sets the parameters for the main functions. I chose dictionaries so that the settings can easily be exported to and imported from JSON configs later.
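A minimal sketch of such an export and import with the standard `json` module; `export_settings` and `import_settings` are hypothetical helpers, not part of the notebook. This works as long as the values are plain Python numbers, strings, lists and dicts (NumPy integers, for example, would have to be converted first).

```python
import json
import os
import tempfile

def export_settings(settings: dict, path: str) -> None:
    '''Writes a (nested) settings dictionary to a JSON file.'''
    with open(path, "w", encoding="utf-8") as f:
        json.dump(settings, f, indent=2)

def import_settings(path: str) -> dict:
    '''Reads a settings dictionary back from a JSON file.'''
    with open(path, encoding="utf-8") as f:
        return json.load(f)

# Round trip with a reduced example configuration
example = {"fft": {"tile_size_km": 9, "tile_size_px": 20}, "output": {"q_factor": 10}}
path = os.path.join(tempfile.mkdtemp(), "settings.json")
export_settings(example, path)
restored = import_settings(path)
```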

In general there are **internal_settings** that control the execution, as well as **temporary_data** in which the intermediary results will be stored.

The execution of the processing pipeline starts [at the end of the notebook](#execution).
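Two values in the settings cell below are derived rather than set directly. With the overlap $p$ (`tile_overlap_percent`), neighbouring tile positions are roughly $(1 - p/100)$ tile lengths apart, so the tile count per axis grows by the factor $m$ (`tile_overlap_multi`). The mode filter size $f$ (`label_mode_filter`) is given relative to a `q_factor` of 5 and converted to a radius $r$ in output pixels (`label_mode_filter_radius`), where $q$ is the actual `q_factor`:

$$
m = \frac{1}{1 - p/100}, \qquad r = \left\lfloor \frac{f}{q/5} \right\rfloor
$$

With the defaults ($p = 95$, $f = 4$, $q = 10$) this gives $m = 1/0.05 = 20$, i.e. about 20 times as many tile positions per axis (roughly 400 times as many tiles) as without overlap, and $r = \lfloor 4/2 \rfloor = 2$ px.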

In [None]:
# Settings for running this Notebook

internal_settings = {}

internal_settings["files"] = {}
internal_settings["files"]["dem_folder"] = "input_geotiffs" # Folder containing the GeoTIFFs to be processed

internal_settings["fft"] = {}
internal_settings["fft"]["tile_size_km"] = 9 # 9 Average side length (km) of a tile that is processed individually
internal_settings["fft"]["tile_size_px"] = 20 # 20 Side length (px) each tile is resampled to

internal_settings["fft"]["tile_overlap_percent"] = 95 # 95 How much (in percent of a tile's side length) neighbouring tiles overlap
internal_settings["fft"]["fft_levels"] = 14  # 14 Number of levels (rings) per FFT footprint

internal_settings["output"] = {}
internal_settings["output"]["folder_name"] = "output_label_images" 
internal_settings["output"]["q_factor"] = 10 # 5 Bigger q_factor: smaller resulting image, faster computation
internal_settings["output"]["label_mode_filter"] = 4  # 20 In arbitrary units; it adapts to the final image size, so bigger images are filtered the same way
internal_settings["output"]["label_count"] = 10 # 10 for now

# Faster settings to quickly test things: change False to True
if False:
    internal_settings["fft"]["tile_size_km"] = 15
    internal_settings["fft"]["tile_overlap_percent"] = 65 # 85 How much (in percent of a tile's side length) neighbouring tiles overlap

# Derived values: the mode filter radius in pixels of the final image, and the tile multiplier resulting from the overlap
internal_settings["output"]["label_mode_filter_radius"] = int(internal_settings["output"]["label_mode_filter"] / (internal_settings["output"]["q_factor"] / 5))
internal_settings["fft"]["tile_overlap_multi"] = 1 / (1 - (internal_settings["fft"]["tile_overlap_percent"]/100))


In [None]:
# Define classes specific to this project

class GeographicCoordinate:
    '''Stores coordinates and their respective coordinate system.'''

    def __init__ (self, x, y, crs):
        self.x = x
        self.y = y
        self.crs = crs

    def in_another_crs(self, dst_crs):
        '''Reproject the same coordinates into another coordinate system'''

        intermediary_transformer = Transformer.from_crs(self.crs, dst_crs, always_xy = True)
        projected_bounds = intermediary_transformer.transform(self.x, self.y)

        return GeographicCoordinate(*projected_bounds,dst_crs)
    
    def __str__(self):
        return f"X: {self.x:.1f}, Y: {self.y:.1f}"


class GeographicBounds:
    '''
    Stores bounds for both the dem maps and also the tiles.
    '''
    
    def __init__(self, xmin, ymin, xmax, ymax, crs, transform = None):
        self.xmin = xmin
        self.ymin = ymin
        self.xmax = xmax
        self.ymax = ymax
        self.crs = crs
        self.transform = transform


    # Convert the bounds to another coordinate system / projection
    def in_another_crs(self, dst_crs):
        '''Reproject the same bounds into another coordinate system'''

        projected_bounds = transform_bounds(self.crs, dst_crs, *self.bounds, densify_pts = 10) 
        # Densifying the edges accounts for bounds that bulge when reprojected. 10 points per edge is a first guess; increase it if the result is too inaccurate.

        return GeographicBounds(*projected_bounds, dst_crs)


    # Give back a tuple of bounds
    @property
    def bounds (self): 
        return (self.xmin,self.ymin,self.xmax,self.ymax)

    # Center along the x extent (unit depends on the CRS: degrees, metres, …)
    @property
    def center_x (self): 
        return (self.xmin+self.xmax)/2


    # Center along the y extent (unit depends on the CRS: degrees, metres, …)
    @property
    def center_y (self): 
        return (self.ymin+self.ymax)/2


    # This is a projected coordinate system the DEM file is reprojected to.
    # The reprojected image is where the individual tiles are taken from.
    @property
    def intermediary_aeqd_crs (self):
        # Center the azimuthal equidistant projection on the center of the bounds (in WGS84).
        # PROJ strings must not contain spaces around "=", and every parameter needs a leading "+".
        wgs84_bounds = self.in_another_crs(CRS("EPSG:4326"))
        return CRS.from_proj4(
            f"+proj=aeqd +lat_0={wgs84_bounds.center_y} +lon_0={wgs84_bounds.center_x} "
            "+x_0=0 +y_0=0 +ellps=WGS84 +units=m +no_defs +type=crs")

    def get_bound_length(self, side: str):
        '''Returns the length of the given side (left, top, right, bottom) in metres (for projected CRSs: in the CRS's linear unit, usually metres).'''

        if self.crs.is_geographic:
            if side == "left":
                _, _, distance = geod.inv(self.xmin, self.ymin, self.xmin, self.ymax)
            elif side == "top":
                _, _, distance = geod.inv(self.xmin, self.ymax, self.xmax, self.ymax)
            elif side == "right":
                _, _, distance = geod.inv(self.xmax, self.ymin, self.xmax, self.ymax)
            elif side == "bottom":
                _, _, distance = geod.inv(self.xmin, self.ymin, self.xmax, self.ymin)
            else:
                raise ValueError("The argument for side is not valid.")
        elif self.crs.is_projected:

            if side == "left":
                distance = abs(self.ymax-self.ymin)
            elif side == "top":
                distance = abs(self.xmax-self.xmin)
            elif side == "right":
                distance = abs(self.ymax-self.ymin)
            elif side == "bottom":
                distance = abs(self.xmax-self.xmin)
            else:
                raise ValueError("The argument for side is not valid.")
        else:
            raise ValueError("CRS is neither projected nor geographic.")
        
        return distance
    

    def get_projected_extent(self):
        '''Returns width and height in metres of the azimuthal equidistant projected bounds (x,y) as a tuple.'''

        # If this bounds object refers to geographic bounds, rerun the function on the converted bounds recursively.
        # I think this is rather ingenious of me, ha ha, as long as it doesn't fail... ^^

        if self.crs.is_geographic:
            return self.in_another_crs(self.intermediary_aeqd_crs).get_projected_extent()
        elif self.crs.is_projected:
            return (abs(self.xmax-self.xmin), abs(self.ymax-self.ymin))
        else: 
            raise ValueError("Coordinate system seems broken. Neither geographic nor projected.")
            


    # Return information about the object (e.g. for use in a print statement)
    def __str__(self):
        infostring = "Infos about the GeographicBounds (all rounded):\n"
        infostring += f"-----\n"
        infostring += f"Bounds:\n"
        infostring += f"X Min: {self.xmin:>7.2f} \n"
        infostring += f"Y Min: {self.ymin:>7.2f} \n"
        infostring += f"X Max: {self.xmax:>7.2f} \n"
        infostring += f"Y Max: {self.ymax:>7.2f}\n"
        infostring += f"-----\n"
        infostring += f"Left side: {self.get_bound_length('left'):.2f} m\n"
        infostring += f"Bottom side: {self.get_bound_length('bottom'):.2f} m\n"
        infostring += f"Right side: {self.get_bound_length('right'):.2f} m\n"
        infostring += f"Top side: {self.get_bound_length('top'):.2f} m\n"
        infostring += f"-----\n"
        infostring += f"Center: {self.center_y:.3f} (Y), {self.center_x:.3f} (X) (CRS units)\n"
        infostring += "\n"
        return infostring


    def as_list(self):
        return [self.xmin, self.ymin, self.xmax, self.ymax]

In [None]:
# The heart of the algorithm

class AugmentedDEM:
    '''
    Handles the segmentation and individual processing of each GeoTIFF (aka DEM map).
    '''
    # Settings will be handed over before the processing begins
    settings = None

    def __init__(self, dem_path):
        '''
        Load key metadata from the GeoTIFF at the specified filepath.
        '''
    
        # Hint: No DEM data is loaded until processing.
        with rasterio.open(dem_path) as dem_file:
            self.src_crs = dem_file.crs
            
            
            # Save the bounds (in the DEM's source CRS) in a new nested object
            self.geo_bounds = GeographicBounds(dem_file.bounds.left, 
                                               dem_file.bounds.bottom,
                                               dem_file.bounds.right,
                                               dem_file.bounds.top,
                                               dem_file.crs,
                                               dem_file.transform)
            
            # Save the original dimensions inside this object
            self.dem_width_px = dem_file.width
            self.dem_height_px = dem_file.height
        
        # Save the original path for later 
        self.dem_path = dem_path


    def _resample_original_dem(self):
        '''Takes the original DEM, reprojects it to azimuthal equidistant and cuts it into tiles.
        This might lead to inaccuracies at the bounds. However, those inaccuracies
        are small enough that the labeled areas still fall into the right categories.'''
        # This is part of the fast pipeline

        with rasterio.open(self.dem_path) as src:

            tile_size_px = int(self.settings["fft"]["tile_size_px"])

            intermediary_geo_bounds = self.geo_bounds.in_another_crs(self.geo_bounds.intermediary_aeqd_crs)
            dst_crs = intermediary_geo_bounds.crs

            target_res = (self.settings["fft"]["tile_size_km"] * 1000) / self.settings["fft"]["tile_size_px"] 
            # Metres per pixel

            scaled_transform, projected_dem_width_pixels, projected_dem_height_pixels = \
            calculate_default_transform(
            src.crs, dst_crs, src.width, src.height, *src.bounds, resolution = target_res, densify_pts=101
            )

            # Start from zeros so that pixels not covered by the source never contain uninitialized memory
            projected_dem = np.zeros((src.count, projected_dem_height_pixels, projected_dem_width_pixels))


            for i in range(1, src.count + 1):
                reproject(
                    source=rasterio.band(src, i),
                    destination=projected_dem[i-1],
                    src_transform=src.transform,
                    src_crs=src.crs,
                    dst_transform=scaled_transform,
                    dst_crs=dst_crs,
                    resampling=ResampleEnum.bilinear
                )

            # Set the number of tiles the projected DEM map will be split into
            tile_multiplier = self.settings["fft"]["tile_overlap_multi"]

            num_tiles_x = int((projected_dem_width_pixels // tile_size_px) * tile_multiplier)
            num_tiles_y = int((projected_dem_height_pixels // tile_size_px) * tile_multiplier)

            # Set the starting positions of each tile
            start_left = (projected_dem_width_pixels % tile_size_px) // 2
            end_right = projected_dem_width_pixels - start_left - tile_size_px

            start_top = (projected_dem_height_pixels % tile_size_px) // 2
            end_bottom = projected_dem_height_pixels - start_top - tile_size_px
            
            tile_starts_x = np.linspace(start_left,end_right,num_tiles_x, endpoint=False)
            tile_starts_y = np.linspace(start_top,end_bottom,num_tiles_y, endpoint=False)

            tile_starts_x = tile_starts_x.astype(int)
            tile_starts_y = tile_starts_y.astype(int)
            

            tile_starts_grid_x, tile_starts_grid_y = np.meshgrid(tile_starts_x, tile_starts_y)
            
            tile_starts_grid_xy = np.stack([tile_starts_grid_x.reshape(-1), \
                                            tile_starts_grid_y.reshape(-1)], axis = 1)

            # Calculate the total number of tiles
            num_tiles_total = num_tiles_x * num_tiles_y

            # Create a list for the sample positions to later grid-interpolate from them
            self.tile_centers_orig = []
            
            # Create the array for the resampled and reprojected tiles that later will be analyzed
            self.tiles_resampled = np.zeros((num_tiles_total, tile_size_px, tile_size_px))


            intermediary_transformer = Transformer.from_crs(dst_crs, self.geo_bounds.crs, always_xy = True)
            
            # Flip the projected DEM map vertically
            projected_dem[0] = np.flip(projected_dem[0], axis=0)

            # Create a random generator for scattering the sample points a bit
            # This prevents the map from looking very "pixelated" and rather natural
            # It scatters the sample points, not the results, so the results are still accurate
            rng = np.random.default_rng()
            
            

            with SimpleTimer("Calculating the center positions"):
                for i in range(num_tiles_total):

                    # Create randomness as mentioned above for x and y coordinates
                    random_shift = rng.uniform(-tile_size_px/2,tile_size_px/2,2)                        

                    # Set up the horizontal bounds of the current tile
                    x_base = tile_starts_grid_xy[i,0]+random_shift[0]
                    x_base = clamp(x_base, start_left, end_right)
                    x_start = x_base
                    x_end = x_base + tile_size_px
                    x_center = (x_start + x_end) // 2

                    # Set up the vertical bounds of the current tile
                    y_base = tile_starts_grid_xy[i,1] + random_shift[1]
                    y_base = clamp(y_base, start_top, end_bottom)
                    y_start = y_base
                    y_end = y_base + tile_size_px
                    y_center = (y_start + y_end) // 2

                    # Extract the tiles
                    self.tiles_resampled[i] = projected_dem[0][y_start:y_end, x_start:x_end]

                    # Save the metric centerpoints relative to the center
                    # to later infer the geographic position of the samples
                    projected_x_center = (x_center - (projected_dem_width_pixels / 2)) * target_res
                    projected_y_center = (y_center - (projected_dem_height_pixels / 2)) * target_res
                
                    # Append the geographic coordinates (relative to the source coordinate system)
                    # to our list of the tile centers (which are the sample locations)
                    # Optimization idea: this could be parallelized
                    self.tile_centers_orig.append( GeographicCoordinate(*intermediary_transformer.transform(\
                        projected_x_center, projected_y_center), self.geo_bounds.crs) )




    def _create_fft_tiles(self):
        '''
        This creates a (log-scaled, centered) 2D FFT magnitude map for each individual tile of the map.
        '''
        fft_input_array = pyfftw.empty_aligned((len(self.tiles_resampled),
                                                self.settings["fft"]["tile_size_px"],
                                                self.settings["fft"]["tile_size_px"]),
                                                dtype = "complex64")

        fft_output_array = pyfftw.empty_aligned((len(self.tiles_resampled),
                                                self.settings["fft"]["tile_size_px"],
                                                self.settings["fft"]["tile_size_px"]),
                                                dtype = "complex64")

        fft_execution_plan = pyfftw.FFTW(fft_input_array,
                                        fft_output_array, axes = (1,2),
                                        direction = "FFTW_FORWARD",
                                        flags = ("FFTW_MEASURE",))

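        # Important: [:] copies the tiles into the pre-allocated, aligned array the FFTW plan is bound to.
        # A plain assignment would only rebind the name to a new array that the plan never sees.
        # The data is filled in after planning, because FFTW_MEASURE may overwrite the arrays while planning.
        fft_input_array[:] = self.tiles_resampled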

        fft_execution_plan.execute()

        fft_magnitude_spectra_unshifted = np.log(np.abs(fft_output_array) + 1) 
        self.fft_footprint = np.fft.fftshift(fft_magnitude_spectra_unshifted, axes = (1,2)) #Centered FFT


        # Free memory. The resampled tiles are no longer needed
        del self.tiles_resampled




    def _bin_and_average_fft_tiles(self):
        '''
        This averages the FFT magnitude map of each tile over concentric rings (distance bands around the center).
        This is essentially the "FFT Footprint".
        The result will be a numpy array in the shape of (number_of_tiles, number_of_fft_levels)
        '''

        # Set up Dask Arrays for faster computation of the weighted averages


        # The general dimensions are:
        # 1. Number of tiles (total)
        # 2. Height of a tile
        # 3. Width of a tile
        # 4. Number of levels (circle filters)

        # To multiply and divide all the dimensions have to be in the same order for all arrays

        # Weights – here go the circles
        # Weights are still (#4, #2, #3), change them to #2, #3, #4


        with SimpleTimer("Creating the weighted averages per mask"):
            
            da_weights = da.from_array(
                # The height becomes axis 0, width becomes axis 1, and different masks become axis 2
                np.transpose(temporary_data["circle_masks"],(1,2,0)) 
            )

            # First dimension (#1) is missing, add it 
            # Just a note: the tiles themselves are not x-y addressed, but they get a continuous index (1D)
            da_weights = da_weights[np.newaxis, ...]
            # Now the form of the weights is (#1, #2, #3, #4) like explained above

            # Take the results of the fft and place them also in a dask array
            # of the wanted form (see above) 
            da_magnitude_spectra = da.from_array(self.fft_footprint)[...,np.newaxis]

            # Sum the weights covered by each circle, to later have something to divide by
            # to generate weighted averages
            da_weights_sum = da.sum(da_weights, axis = (1,2)) 

            # Sum the results of the FFTs multiplied by the weights of the different levels (circle filters)
            da_sum_of_spectra = da.sum(da_magnitude_spectra * da_weights, axis = (1, 2))

            # Calculate the weighted average 
            da_weighted_averages = da_sum_of_spectra / da_weights_sum

            '''# Calculate the weighted variance 
            da_weighted_variance = da.sum(
                da_weights * (da_magnitude_spectra - da_weighted_averages[:, np.newaxis, np.newaxis, :]) ** 2,
                  axis=(1, 2)) / da_weights_sum'''

            # The result is now in the shape of #1, #4 -> Number of Tiles, Number of Levels
            self.fft_magnitude_all_levels = da_weighted_averages.compute() 
            # self.fft_magnitude_all_levels = da_weighted_variance.compute() 


    
    def create_image(self):
        '''
        This method creates a colored image. 
        After the clustering has taken place the found labels will be translated to colors for each pixel.
        '''

        # Each tile has gotten a label
        # We know the center of each tile
        # That way we can interpolate an image by combining the two data sources
        # Remember the tiles are not in a strict raster in geographic coordinates: a tile of fixed
        # metric size spans more degrees of longitude the closer it gets to the poles.

        position_tuple = [(tile.y, tile.x) for tile in self.tile_centers_orig]

        # Calculate dimensions of the output image
        output_pic_width = self.dem_width_px // self.settings["output"]["q_factor"]
        output_pic_height = self.dem_height_px // self.settings["output"]["q_factor"]

        # Set up grid for interpolating
        tmp_grid_res_y = np.linspace(self.geo_bounds.ymax, self.geo_bounds.ymin, output_pic_height )
        tmp_grid_res_x = np.linspace(self.geo_bounds.xmin, self.geo_bounds.xmax, output_pic_width )
        tmp_grid_y, tmp_grid_x = np.meshgrid(tmp_grid_res_y, tmp_grid_res_x, indexing = "ij")

        # Interpolate the data using the geographic positions of the sample points 
        # to the output image
        output_image = griddata(position_tuple, self.labels,
                                (tmp_grid_y, tmp_grid_x), method = 'nearest')


        with SimpleTimer("Filtering of the mapped labels"):
            # The mode filter smoothes the result in a way that 
            # it takes the mode in a certain radius (the values that are most common)
            self.labels_interpolated = mode_filter(output_image, self.settings["output"]["label_mode_filter_radius"])


        # Start creating the image
   
        # I create my own color mapping because I like those colors
        custom_colors_str = ["#053608" ,
                            "#bdf2e6" ,
                            "#1c5532" ,
                            "#62f8ee" ,
                            "#ffba20" ,
                            "#2c7eb5" ,
                            "#b9c5d6" ,
                            "#404383" ,
                            "#91f88f" ,
                            "#759d7c" ,
                            "#9395a3"  
        ]

        # The image becomes colors mapped from the custom colors indexed by the constrained labels
        output_image_rgb = color_to_numpy(custom_colors_str)[constrain_labels(self.labels_interpolated)]
        self.image_rgb = (np.transpose(output_image_rgb, (2,0,1))).astype("uint8")


    def write_image(self):
        '''Writes the created image to the disk.'''

        output_pic_width = self.dem_width_px // self.settings["output"]["q_factor"]
        output_pic_height = self.dem_height_px // self.settings["output"]["q_factor"]

        # Calculate pixel resolution for the GeoTIFF output image (degrees per pixel)
        res_x = ((self.geo_bounds.xmax - self.geo_bounds.xmin) / output_pic_width)
        res_y = ((self.geo_bounds.ymax - self.geo_bounds.ymin) / output_pic_height)

        # Affine transform (pixel coordinates → WGS84): the source transform scaled by
        # q_factor, so that each output pixel spans q_factor source pixels per axis
        transform = self.geo_bounds.transform * Affine.scale(self.settings["output"]["q_factor"])


        # I put together a comprehensive (very long) filename to recall the settings later.
        # The ".tif" extension of the source DEM is stripped so it does not end up mid-name.
        dem_name = os.path.splitext(os.path.basename(self.dem_path))[0]
        filename_string = (
            f"{dem_name}  "
            f"tlszkm {self.settings['fft']['tile_size_km']:.1f}  "
            f"tlszpx {self.settings['fft']['tile_size_px']:.0f}  "
            f"fftlvls {self.settings['fft']['fft_levels']:.0f}  "
            f"qfctr {self.settings['output']['q_factor']:.0f}  "
            f"ovrlpmlt {self.settings['fft']['tile_overlap_multi']:.1f}x  "
            f"ovrlp-pct {self.settings['fft']['tile_overlap_percent']:.0f}  "
            f"fltrrds {self.settings['output']['label_mode_filter_radius']:.0f}  "
            f"lblct {self.settings['output']['label_count']:.0f}"
            ".tif"
        )

        filepath = os.path.join(self.settings["output"]["folder_name"], filename_string)

        # Write GeoTIFF
        with rasterio.open(
            filepath,
            "w",
            driver = "GTiff",
            height = output_pic_height,
            width = output_pic_width,
            count = 3, 
            dtype = "uint8",
            crs = self.geo_bounds.crs, 
            transform = transform,
        ) as dst:
            dst.write(self.image_rgb[0],1)
            dst.write(self.image_rgb[1],2)
            dst.write(self.image_rgb[2],3)



    def process_dem_fast(self):
        '''Faster, slightly less precise pipeline: resample the DEM, create the FFT footprints, then bin and average them.'''
        
        with SimpleTimer("Resampling the original DEM to an azimuthal equidistant projection"):
            self._resample_original_dem()

        with SimpleTimer("Creating FFT footprints"):
            self._create_fft_tiles()

        with SimpleTimer("Binning and averaging FFT footprints"):
            self._bin_and_average_fft_tiles()
        # Hereafter the clustering takes place globally across all DEMs (class method create_labels).
        # After that, each colored label map is written to disk.


    @classmethod
    def create_labels(cls, instances):
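        '''
        Cluster the FFT footprints of all tiles of all given DEM maps jointly.

        The footprints are concatenated into one long array, normalized, clustered
        with k-means, relabeled by cluster size and finally split back, so that
        each AugmentedDEM instance receives the labels of its own tiles.
        '''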
         
        all_magnitudes_np = np.concatenate([d.fft_magnitude_all_levels for d in instances], axis = 0)

        # Normalize each feature across all DEMs: center on the median, scale by the standard deviation
        all_magnitudes_np = (all_magnitudes_np - np.median(all_magnitudes_np, axis = 0, keepdims = True)) / np.std(all_magnitudes_np, axis = 0, keepdims = True)

        # Store the original sizes to later split the labels again
        original_tile_count = [d.fft_magnitude_all_levels.shape[0] for d in instances]
        original_indices_for_splitting = np.cumsum(original_tile_count[:-1])

        # Clustering section
        n_clusters = cls.settings["output"]["label_count"]
        k_means_clusterer = KMeans(n_clusters = n_clusters, random_state = 550)


        with SimpleTimer("Clustering"):
            # NaN values (e.g. from zero-variance features after normalization) are treated as 0
            all_magnitudes_np = np.nan_to_num(all_magnitudes_np, nan = 0)
            all_dem_labels = k_means_clusterer.fit_predict(all_magnitudes_np)

        with SimpleTimer("Relabeling Cluster IDs by size"):
            all_dem_labels = sort_labels(all_dem_labels)

        # Distribute the results back to the augmented dem objects
        single_dem_lables = np.split(all_dem_labels, original_indices_for_splitting) 

        for i, labels in enumerate(single_dem_lables):
            instances[i].labels = labels
        

## Code explanation: This is the core of the project

The **AugmentedDEM class** is where most of the processing takes place.
Each instance corresponds to one GeoTIFF (which can be downloaded via the secondary notebook).
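Because the footprints of all instances are clustered together, `create_labels` has to keep track of which tiles belong to which DEM. The following numpy-only sketch illustrates that bookkeeping: the median/standard-deviation normalization, a hypothetical `sort_labels_by_size` standing in for the notebook's `sort_labels` (assuming it gives label 0 to the largest cluster), and the `np.cumsum`/`np.split` round trip. The k-means step itself (scikit-learn's `KMeans` in the notebook) is left out, and the labels below are made up.

```python
import numpy as np


def normalize_features(features):
    """Center each feature column on its median and scale it by its standard deviation.

    A zero-variance column yields 0/0 = NaN; like the notebook's np.nan_to_num
    step, NaN values are replaced with 0.
    """
    median = np.median(features, axis=0, keepdims=True)
    std = np.std(features, axis=0, keepdims=True)
    with np.errstate(divide="ignore", invalid="ignore"):
        normalized = (features - median) / std
    return np.nan_to_num(normalized, nan=0.0)


def sort_labels_by_size(labels):
    """Hypothetical stand-in for sort_labels: label 0 becomes the largest cluster."""
    ids, counts = np.unique(labels, return_counts=True)
    # Cluster ids ordered by descending size (stable sort keeps ties in id order)
    order = ids[np.argsort(-counts, kind="stable")]
    mapping = np.zeros(ids.max() + 1, dtype=int)
    mapping[order] = np.arange(order.size)
    return mapping[labels]


def split_per_dem(all_labels, tile_counts):
    """Split the concatenated labels back into one array per DEM."""
    split_indices = np.cumsum(tile_counts[:-1])
    return np.split(all_labels, split_indices)


# Three DEMs with 3, 2 and 4 tiles; in the notebook the labels come from k-means
tile_counts = [3, 2, 4]
all_labels = np.array([2, 2, 0, 1, 2, 0, 2, 1, 2])
per_dem = split_per_dem(sort_labels_by_size(all_labels), tile_counts)
print([labels.tolist() for labels in per_dem])  # [[0, 0, 1], [2, 0], [1, 0, 2, 0]]
```

Only the first `len(tile_counts) - 1` cumulative sums are needed as split points, which is why the notebook slices `original_tile_count[:-1]`.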

### What it stores
It stores general information read from the GeoTIFF file, namely its bounds.
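The bounds and the affine transform read from the GeoTIFF are what `write_image` relies on later. An `Affine(a, b, c, d, e, f)` from the `affine` package (used by rasterio) maps a pixel `(col, row)` to `x = a*col + b*row + c`, `y = d*col + e*row + f`. The numpy-only sketch below uses hypothetical helper names and made-up bounds to show why `transform * Affine.scale(q_factor)` is the transform of the downsampled output: the scale is applied to the pixel indices first, so output pixel `(col, row)` lands on source pixel `(q*col, q*row)`.

```python
import numpy as np


def north_up_affine(west, north, res_x, res_y):
    """3x3 matrix form of a north-up transform, i.e. Affine(res_x, 0, west, 0, -res_y, north)."""
    return np.array([[res_x, 0.0, west],
                     [0.0, -res_y, north],
                     [0.0, 0.0, 1.0]])


def pixel_to_geo(transform, col, row):
    """Map a (col, row) pixel corner to (x, y) map coordinates."""
    x, y, _ = transform @ np.array([col, row, 1.0])
    return x, y


# Hypothetical DEM: 2000 x 2000 px covering 6°W..4°W and 56°N..58°N
dem_transform = north_up_affine(west=-6.0, north=58.0, res_x=0.001, res_y=0.001)

# Matrix equivalent of `transform * Affine.scale(q_factor)`:
# scale the pixel indices first, then apply the source transform
q_factor = 4
output_transform = dem_transform @ np.diag([q_factor, q_factor, 1.0])

print(pixel_to_geo(output_transform, 1, 1))      # same point as source pixel (4, 4)
print(pixel_to_geo(output_transform, 500, 500))  # south-east corner of the 500 x 500 px output
```

Since the output size is computed with floor division (`dem_width_px // q_factor`), the output grid can stop up to `q_factor - 1` source pixels short of the eastern and southern edges when the DEM size is not a multiple of `q_factor`.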

<a id = "execution">Anchor</a>

## Code explanation: Processing pipeline

The following cell runs the whole processing pipeline.
The settings are defined at the [top of the notebook](#settings).
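Before any footprints are computed, the pipeline builds ring masks with `RingImageSeries`, whose implementation is not part of this section. As a rough illustration of the idea, here is a sketch assuming that level `k` is a ring of radii `[k * bandwidth, (k + 1) * bandwidth)` around the center of the shifted 2D spectrum, and that a tile's footprint is the mean FFT magnitude per ring. The names `ring_masks` and `fft_footprint` are hypothetical, and subtracting the tile mean first is an assumption, not necessarily what the notebook does.

```python
import numpy as np


def ring_masks(height, width, levels, bandwidth=1):
    """Boolean masks of shape (levels, height, width) over a centered (fftshift-ed) spectrum.

    Ring k covers radii in [k * bandwidth, (k + 1) * bandwidth).
    """
    rows = np.arange(height) - height // 2
    cols = np.arange(width) - width // 2
    radius = np.hypot(rows[:, None], cols[None, :])
    return np.stack([(radius >= k * bandwidth) & (radius < (k + 1) * bandwidth)
                     for k in range(levels)])


def fft_footprint(tile, masks):
    """Mean FFT magnitude inside each ring.

    Averaging over rings discards direction, so the footprint describes how much
    relief there is at each spatial frequency, largely independent of orientation.
    """
    spectrum = np.fft.fftshift(np.fft.fft2(tile - tile.mean()))
    magnitude = np.abs(spectrum)
    return np.array([magnitude[mask].mean() for mask in masks])


# A synthetic 32 x 32 "terrain" with 4 ridges across the tile
masks = ring_masks(32, 32, levels=8)
x = np.arange(32)
ridges = np.tile(np.sin(2 * np.pi * 4 * x / 32), (32, 1))
print(np.argmax(fft_footprint(ridges, masks)))  # energy sits in ring 4
```

A mask stays non-empty only while its inner radius fits inside the tile, so the number of levels has to be limited relative to the tile size; an empty ring would produce a NaN mean.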

In [None]:
# Container for the intermediate results
temporary_data = {}
temporary_data["augmented_dems"] = []


# Assign the settings to the class variable of AugmentedDem
AugmentedDEM.settings = internal_settings


# Create objects (AugmentedDem) for all .tif-Files inside the dem_folder
with os.scandir(internal_settings["files"]["dem_folder"]) as dirlist:
    temporary_data["augmented_dems"] = [AugmentedDEM(d.path) 
                                        for d in dirlist if re.search(r"\.tif$",d.name)]
    

# Set up the masks for the weighted average
with VerboseInfoTimer("Creating the ring masks"):
    temporary_data["circle_masks"] = RingImageSeries(internal_settings["fft"]["tile_size_px"],
                                                        internal_settings["fft"]["tile_size_px"],
                                                        internal_settings["fft"]["fft_levels"],
                                                        bandwidth = 1)


# Main processing: create the FFT footprints of every DEM
for i, d in enumerate(temporary_data["augmented_dems"]):
    with VerboseInfoTimer("Processing of all DEM maps",
                          current_index = i,
                          total_count = len(temporary_data["augmented_dems"]),
                          single_description = "Processing DEM map"):
        d.process_dem_fast()


# Cluster the footprints of all DEMs jointly and hand the labels back to each DEM
with VerboseInfoTimer("Clustering and creating labels for all DEM maps"):
    AugmentedDEM.create_labels(temporary_data["augmented_dems"])


# Using the labels created across all DEMs, create and write the labeled image of each individual DEM
for i, d in enumerate(temporary_data["augmented_dems"]):
    with VerboseInfoTimer("Creating labeled images for all regions",
                          current_index = i,
                          total_count = len(temporary_data["augmented_dems"]),
                          single_description = "Creating labeled image"):
        d.create_image()
        d.write_image()
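`create_image` smooths the nearest-neighbour label map with `mode_filter`, which is defined elsewhere in the notebook. Purely as an illustration of what a mode filter does, here is a naive sketch; the name `mode_filter_sketch`, the square `(2r+1) x (2r+1)` window, the edge replication and the tie rule are all assumptions and may differ from the notebook's helper.

```python
import numpy as np


def mode_filter_sketch(labels, radius):
    """Replace each label by the most common label in a (2r+1) x (2r+1) window.

    Hypothetical stand-in for mode_filter: borders are handled by replicating the
    edge pixels, ties go to the smallest label, labels must be non-negative ints.
    """
    padded = np.pad(labels, radius, mode="edge")
    out = np.empty_like(labels)
    size = 2 * radius + 1
    for r in range(labels.shape[0]):
        for c in range(labels.shape[1]):
            # Window centered on (r, c) in original coordinates
            window = padded[r:r + size, c:c + size]
            out[r, c] = np.bincount(window.ravel()).argmax()
    return out


# A single stray label inside a uniform region disappears
spot = np.zeros((5, 5), dtype=int)
spot[2, 2] = 3
print(mode_filter_sketch(spot, 1))
```

The double Python loop is fine for a small example but slow for full-size label maps; a vectorized or compiled implementation is preferable there.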

In [None]:
if False:
    # Visual check of the first labeled DEM map (change to "if True:" to run)
    import matplotlib.pyplot as plt
    import cartopy.crs as ccrs
    import cartopy.feature as cfeature

    current_augmented_dem = temporary_data["augmented_dems"][0]
    image = np.transpose(current_augmented_dem.image_rgb,[1,2,0])

    west, south, east, north = current_augmented_dem.geo_bounds.as_list()
    extent = (west, east, south, north)

    # Start plotting
    fig = plt.figure(figsize = (10, 15), dpi = 256)
    ax = plt.axes(projection = ccrs.Orthographic(central_longitude = current_augmented_dem.geo_bounds.center_x,
                                                 central_latitude = current_augmented_dem.geo_bounds.center_y))

    # The extent is given in longitude/latitude
    ax.set_extent(extent, crs = ccrs.PlateCarree())

    # Add features
    ax.add_feature(cfeature.BORDERS, edgecolor = "black", linewidth = 0.5)
    ax.add_feature(cfeature.COASTLINE, edgecolor = "white", linewidth = 1.0)
    ax.add_feature(cfeature.LAKES, edgecolor = 'black', facecolor = 'none')
    ax.add_feature(cfeature.RIVERS, edgecolor = 'blue', linewidth = 0.3)
    ax.add_feature(cfeature.OCEAN, zorder = 1, facecolor = '#102d2d')

    # Plot image (projecting from PlateCarree, which assumes lat/lon coordinates)
    ax.imshow(image, origin = 'upper', extent = extent, transform = ccrs.PlateCarree())

    # Optionally, add gridlines
    ax.gridlines(draw_labels = True, linewidth = 0.2)

    # Show
    plt.title(f"Colored landscape clusters for {os.path.basename(current_augmented_dem.dem_path)}")
    plt.show()
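Two array conventions meet in `create_image` and in the plotting cell above: rasterio writes band-first `(3, H, W)` arrays, while `imshow` expects `(H, W, 3)` plus an extent ordered `(left, right, bottom, top)` rather than the `(west, south, east, north)` order of `geo_bounds.as_list()`. The sketch below walks through that round trip. `hex_palette_to_array` is a hypothetical stand-in for `color_to_numpy`, and wrapping out-of-range labels with a modulo is only an assumption about what `constrain_labels` does.

```python
import numpy as np


def hex_palette_to_array(hex_colors):
    """Hypothetical stand-in for color_to_numpy: '#rrggbb' strings -> (N, 3) uint8 array."""
    return np.array([[int(h[i:i + 2], 16) for i in (1, 3, 5)] for h in hex_colors],
                    dtype=np.uint8)


def labels_to_bands(labels, palette):
    """Map a 2D label image to a band-first (3, H, W) RGB array.

    Assumption: labels beyond the palette wrap around (modulo).
    """
    rgb = palette[labels % len(palette)]  # (H, W, 3) via fancy indexing
    return np.transpose(rgb, (2, 0, 1))   # (3, H, W), the layout rasterio writes


def bounds_to_extent(west, south, east, north):
    """Reorder bounds into Matplotlib's imshow extent (left, right, bottom, top)."""
    return (west, east, south, north)


palette = hex_palette_to_array(["#053608", "#bdf2e6", "#1c5532"])
bands = labels_to_bands(np.array([[0, 1], [2, 3]]), palette)
image = np.transpose(bands, [1, 2, 0])  # back to (H, W, 3) for imshow
print(bands.shape, image.shape)  # (3, 2, 2) (2, 2, 3)
```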