# Visualize IoU

This Python code takes the first ground-truth raster tile that has a matching IGN prediction (matched by tile ID) and reads both with rasterio. It keeps only class 2 in each raster, then shows the original labels, the class-2 masks, and their intersection and union — the two areas whose ratio is the IoU. The figure is a 2x3 grid of matplotlib plots with the x and y ticks removed for a cleaner look.

In [None]:
import matplotlib.pyplot as plt
import rasterio
import os
from random import randint

gt_path = "/proj-soil/data/GT/20231004/3-cut_to_grid/1024px_extent1/"
ign_path = "/proj-soil/data/IGN/2-cut_to_grid/1024px_extent1/smp-unet-resnet34-imagenet_RVBI"

# NOTE(review): random_idx is not used in this cell; kept in case another cell relies on it.
random_idx = randint(0, 20)

# Take the first ground-truth tile that has a matching IGN prediction.
# The tile id is the last "_"-separated token of the GT filename (extension
# removed). IGN files are matched by substring, so an id like "1" could also
# match "11" -- TODO confirm the naming scheme keeps ids unambiguous.
ign_files = os.listdir(ign_path)  # list once instead of once per GT tile
for gt_file in os.listdir(gt_path):
    tile_id = gt_file.split("_")[-1].split(".")[0]
    print(f'{tile_id = }')

    ign_matches = [file for file in ign_files if tile_id in file]
    if not ign_matches:
        continue
    with rasterio.open(os.path.join(gt_path, gt_file)) as src:
        gt = src.read()
    with rasterio.open(os.path.join(ign_path, ign_matches[0])) as src:
        ign = src.read()
    break
else:
    # Without this the plotting below fails with a confusing NameError on `gt`
    # (or silently reuses arrays left over from a previous run).
    raise FileNotFoundError(
        f"No ground-truth tile in {gt_path} has a matching prediction in {ign_path}")

_, axs = plt.subplots(ncols=3, nrows=2, figsize=(15, 10))

# Column 0: first band of each raster as read.
axs[0, 0].set_title("Ground Truth")
axs[0, 0].imshow(gt[0])
axs[1, 0].set_title("IGN")
axs[1, 0].imshow(ign[0])

# Column 1: keep only class 2, every other label becomes 0.
gt_class = gt[0].copy()
gt_class[gt_class != 2] = 0
axs[0, 1].set_title("GT, Class 2")
axs[0, 1].imshow(gt_class, cmap="Greys_r")

ign_class = ign[0].copy()
ign_class[ign_class != 2] = 0
axs[1, 1].set_title("IGN, Class 2")
axs[1, 1].imshow(ign_class, cmap="Greys_r")

# Column 2: both masks only hold 0 or 2, so bitwise & / | act as a pixel-wise
# AND / OR (assumes integer label rasters -- bitwise ops fail on floats).
intersection = gt_class & ign_class
axs[0, 2].set_title("intersection, Class 2")
axs[0, 2].imshow(intersection, cmap="Greys_r")

union = gt_class | ign_class
axs[1, 2].set_title("union, Class 2")
axs[1, 2].imshow(union, cmap="Greys_r")

# Remove ticks from every subplot.
for ax in axs.flat:
    ax.set_xticks([])
    ax.set_yticks([])

# Metrics

## Load from CSV

In [None]:
import pandas as pd
import glob
import os
import numpy as np


import matplotlib.pyplot as plt
import matplotlib.patches as patches
plt.rc('text', usetex=True)

import seaborn as sns
import re

root = '/soil_fribourg/proj-soils/csv/metrics'

# define a function to rename the architectures, such that it is clear
# from which institution they come from
def rename_archs(x):
    """Return display names for ``x["arch"]`` that show the originating institution.

    - names containing "heig": a "_<digits>cm" suffix is removed, "heig-vd"
      becomes "HEIG-VD-original" and "heigvd" becomes "HEIG-VD";
    - names containing "ADELE" get an "OFS_" prefix;
    - every other name gets an "IGN_" prefix.
    """
    def _display_name(name):
        if "heig" in name:
            name = re.sub(r'[_]\d+cm', '', name)
            return name.replace("heig-vd", "HEIG-VD-original").replace("heigvd", "HEIG-VD")
        prefix = "OFS_" if "ADELE" in name else "IGN_"
        return prefix + name

    return [_display_name(name) for name in x["arch"]]

def rename_extents(x):
    """Return ``x["extent"]`` with every name containing "seed6" shortened to "seed6-adj"."""
    return ["seed6-adj" if "seed6" in extent else extent for extent in x["extent"]]

        

all_metrics = []
all_paths = glob.glob(os.path.join(root, "*.csv"))
for path in all_paths:
    file = os.path.basename(path)
    # Only per-model metric files; this also skips the combined
    # STDL_proj-soils_evaluation_metrics.csv export written below.
    if not file.startswith("metrics"):
        continue

    # Expected name: metrics_<inst>_<type>_<extent>_<res>.csv
    # (`kind` receives the literal "metrics" prefix).
    try:
        kind, inst, typee, extent, res = file.split(".")[0].split("_")
    except ValueError as err:
        raise ValueError(
            f'file {file} does not have the right format. It should be '
            'metrics_<inst>_<type>_<extent>_<res>.csv\n'
            '(e.g., metrics_heigvd-mixed-145k-10cm_mc_seed6-adjusted_10cm.csv)'
        ) from err

    all_metrics.append(pd.read_csv(path).assign(extent=extent, res=res))

if not all_metrics:
    # pd.concat([]) would otherwise fail with "No objects to concatenate".
    raise FileNotFoundError(f"No metrics*.csv files found in {root}")

metrics = (
    pd.concat(all_metrics, ignore_index=True)
    .assign(arch=rename_archs, extent=rename_extents)
)

metrics.to_csv(os.path.join(root, "STDL_proj-soils_evaluation_metrics.csv"), index=False)

# Subset of metrics to compare across institutions. Names must be the ones
# produced by rename_archs(): raw "heig-vd" becomes "HEIG-VD-original".
# NOTE(review): no renamed architecture can equal 'OFS' -- rename_archs()
# prefixes non-HEIG names with "OFS_" or "IGN_". Replace it with the intended
# OFS model (e.g. 'OFS_ADELE2(+SAM)').
cross_inst = (
    metrics
    .query("arch in ['HEIG-VD-original', 'IGN_smp-unet-resnet34-imagenet_RVBI', 'OFS']")
    .assign(arch_short=lambda x: ["IGN" if "IGN" in arch else arch for arch in x["arch"]])
)

binary = metrics.query("type == 'binary'")
multiclass = metrics.query("type == 'multiclass'")

## Plots

### Definitions

In [None]:
def plot_archs_multiclass(data, y, hue, ylabel, title, dpi, figsize=(5, 5), ax=None, **kwargs):
    """Draw a grouped bar plot of a per-class metric and frame selected classes.

    Bars are grouped by ``class_name`` on the x axis and split by ``hue``.
    Every class whose name contains "sol" gets a green frame. The "mean"
    class gets a black frame.

    Args:
        data: long-format DataFrame with a ``class_name`` column and the
            ``y`` and ``hue`` columns.
        y: column used for the bar heights.
        hue: column used to split the bars within a class.
        ylabel: y-axis label.
        title: plot title.
        dpi: figure resolution, applied through seaborn's rc settings.
        figsize: figure size, applied through seaborn's rc settings.
        ax: axes to draw on; defaults to the current axes (``plt.gca()``).
        **kwargs: forwarded to ``sns.barplot`` (e.g. ``order``, ``hue_order``).

    Returns:
        The Axes that was drawn on (the figure is shown before returning).
    """
    sns.set(
        rc={
            'figure.figsize': figsize,
            'figure.dpi': dpi,
            'axes.facecolor': 'white'
        },
        style="darkgrid"
    )

    if ax is None:
        ax = plt.gca()

    sns.barplot(
        data=data,
        x="class_name",
        y=y,
        hue=data[hue],
        ax=ax,
        **kwargs
    )

    # Horizontal gridlines only, one every 0.1.
    ax.yaxis.grid(True, color='grey', linestyle='-', linewidth=0.5)
    ax.yaxis.set_major_locator(plt.MultipleLocator(0.1))
    ax.xaxis.grid(False)

    # Legend in the lower right corner.
    ax.legend(loc='lower right', framealpha=0.9)

    # Global seaborn style change: only affects figures created afterwards.
    sns.set_style("whitegrid")

    ax.set_xlabel('Class')
    ax.tick_params(axis='x', labelrotation=45)
    ax.set_ylabel(ylabel)
    ax.set_title(title)

    # Frame positions must follow the categorical x positions, not the row
    # index of `data` (which holds one row per class AND hue value). Seaborn
    # uses the explicit `order` if given, otherwise the categories of a
    # categorical column, otherwise the order of first appearance.
    class_order = kwargs.get("order")
    if class_order is None:
        class_col = data["class_name"]
        if hasattr(class_col, "cat"):
            class_order = list(class_col.cat.categories)
        else:
            class_order = list(pd.unique(class_col))

    # Read the limits once, before frames are added: patches can enlarge the
    # autoscaled limits.
    y_low, y_high = ax.get_ylim()
    for pos, class_name in enumerate(class_order):
        if "sol" in class_name:
            # Green frame around the whole column of a soil class.
            frame = dict(xy=(pos - 0.5, y_low + 0.005), width=1,
                         height=y_high - y_low - 0.03, linewidth=3, edgecolor="green")
        elif class_name == "mean":
            # Black frame around the "mean" column (anchored at its position).
            frame = dict(xy=(pos - 1, y_low - 0.1), width=1.5,
                         height=1.5, linewidth=2, edgecolor="black")
        else:
            continue
        ax.add_patch(patches.Rectangle(facecolor='none', zorder=5, **frame))

    plt.show()
    return ax

In [None]:
def plot_archs_binary(x, y, hue, data, ylabel, title, dpi, figsize=(5, 5), ax=None, max_ytick=None, **kwargs):
    """Draw a grouped bar plot for a binary-segmentation metric.

    Args:
        x: column of ``data`` used for the x axis.
        y: column of ``data`` used for the bar heights.
        hue: column of ``data`` used for the bar colours.
        data: DataFrame holding those columns.
        ylabel: y-axis label.
        title: plot title.
        dpi: figure resolution, set through seaborn's rc settings before any
            new figure is created.
        figsize: figure size, set the same way.
        ax: axes to draw on; a new figure and axes are created when omitted.
        max_ytick: if given, y ticks are placed every 0.1 from 0 up to about
            ``max_ytick``; otherwise one tick every 0.1 over the autoscaled range.
        **kwargs: forwarded to ``sns.barplot`` (e.g. ``order``, ``hue_order``).

    Returns:
        The Axes drawn on. Its x label is always "Class"; callers override it.
    """
    sns.set(
        style="darkgrid",
        rc={'figure.figsize': figsize, 'figure.dpi': dpi, 'axes.facecolor': 'white'},
    )

    target = plt.subplots()[1] if ax is None else ax

    sns.barplot(x=data[x], y=data[y], hue=data[hue], ax=target, **kwargs)

    # Horizontal gridlines only.
    target.yaxis.grid(True, color='grey', linestyle='-', linewidth=0.5)
    target.xaxis.grid(False)
    if max_ytick is not None:
        target.set_yticks(np.arange(0, max_ytick + 0.1, 0.1))
    else:
        target.yaxis.set_major_locator(plt.MultipleLocator(0.1))

    target.legend(loc='lower right', fontsize=9, framealpha=0.9)
    target.set_xlabel('Class')
    target.set_ylabel(ylabel)
    target.set_title(title)
    return target

In [None]:
# Keep the "extent*" evaluations, drop the OFS prototype and rename
# "HEIG-VD-original" to "HEIG-VD". Boolean masks avoid relying on `.str`
# accessors inside query(), whose support depends on the query engine.
data = (
    metrics[
        metrics["extent"].str.contains('extent')
        & (metrics["arch"] != 'OFS_ADELE_proto')
    ]
    .assign(arch=lambda x: x["arch"].str.replace("HEIG-VD-original", "HEIG-VD"))
)

### MCC, Extents

In [None]:
# Mean binary MCC: one row per architecture, one column per extent, sorted by
# extent1 (best first).
(
    data[(data["type"] == 'binary') & (data["class_name"] == 'mean')]
    .pivot_table(index=["arch"], columns="extent", values="mcc")
    .sort_values("extent1", ascending=False)
    .reset_index()
)

In [None]:
# Mean binary scores of every architecture on extent1.
data[
    (data["class_name"] == 'mean')
    & (data["extent"] == 'extent1')
    & (data["type"] == 'binary')
]

In [None]:
# enable tex
plt.rc('text', usetex=True)

plot_archs_binary(
    x="extent",
    y="iou",
    hue="arch",
    data=data.query("class_name == 'mean' & type == 'binary'"),
    ylabel="mIoU",
    title=r"\textbf{mIoU Values for different extents and models}",
    dpi=200,
    figsize=(6, 6),
    max_ytick=0.8,
    order=["extent1", "extent1-masked", "extent2"],
    hue_order=['IGN_smp-unet-resnet34-imagenet_RVBI', 'OFS_ADELE2(+SAM)',
       'IGN_smp-unet-resnet34-imagenet_RVBIE', 'HEIG-VD',
       'IGN_smp-fpn-resnet34-imagenet_RVBIE', 'IGN_odeon-unet-vgg16_RVBI',
       'IGN_smp-fpn-resnet34-imagenet_RVBI', 'IGN_odeon-unet-vgg16_RVBIE'],
)

# The x axis holds the evaluation extents (the models are the hue).
plt.xlabel("Extent")
plt.show()

### IoU, Multiclass, Architectures, extent1

In [None]:
# Architectures ranked by multiclass mIoU on extent1 (best first).
(
    data
    .query("type == 'multiclass' & extent == 'extent1' & class_name == 'mean'")
    .sort_values("iou", ascending=False)["arch"]
    .unique()
)

In [None]:
plt.figure(figsize=(15, 6))

extent = "extent1"

# Per-class IoU of every architecture on one extent.
plot_archs_multiclass(
    data=data.query("type == 'multiclass' & extent == @extent"),
    y="iou",
    hue="arch",
    ylabel="mIoU",
    title=rf"\textbf{{mIoU values for different classes and models on the {extent}}}",
    dpi=200,
    hue_order=['OFS_ADELE2(+SAM)', 'HEIG-VD',
       'IGN_smp-unet-resnet34-imagenet_RVBIE',
       'IGN_smp-unet-resnet34-imagenet_RVBI',
       'IGN_smp-fpn-resnet34-imagenet_RVBIE',
       'IGN_smp-fpn-resnet34-imagenet_RVBI', 'IGN_odeon-unet-vgg16_RVBI',
       'IGN_odeon-unet-vgg16_RVBIE'],
)

### MCC, Binary, resolutions

In [None]:
# Binary architectures evaluated on seed6-adj, without the "32k" variants.
# The "32k" filter is applied to the queried frame itself: a mask built on the
# unfiltered `metrics` has a different index and pandas re-aligns it with a warning.
(
    metrics
    .query("type == 'binary' & extent == 'seed6-adj'")
    .loc[lambda d: ~d["arch"].str.contains("32k"), "arch"]
    .unique()
)

In [None]:
# Mean binary MCC per architecture and resolution on seed6-adj ("32k" variants
# excluded; the mask is built on the queried frame so no re-alignment is needed).
(
    metrics
    .query("type == 'binary' & extent == 'seed6-adj' & class_name == 'mean'")
    .loc[lambda d: ~d["arch"].str.contains("32k"), ["arch", "mcc", "res"]]
    .sort_values("res", ascending=True)
)

In [None]:
# enable tex
plt.rc('text', usetex=True)

plot_archs_binary(
    x="arch",
    y="mcc",
    hue="res",
    # Filter the "32k" models on the queried frame itself; a mask built on the
    # unfiltered `metrics` would be re-aligned by pandas (with a warning).
    data=(
        metrics
        .query("type == 'binary' & extent == 'seed6-adj' & class_name == 'mean'")
        .loc[lambda d: ~d["arch"].str.contains("32k")]
    ),
    ylabel="MCC",
    title=r"\textbf{MCC Values for different models and resolutions}",
    dpi=50,
    figsize=(6, 6),
    hue_order=("10cm", "20cm", "40cm"),
    order=[
        'HEIG-VD-original',
        'HEIG-VD-10cm-71k',
        'HEIG-VD-mixed-145k',
        ]
)

plt.xlabel("Model")
plt.show()

In [None]:
# Show the full multiclass metrics table (all architectures, extents and resolutions).
multiclass

### IoU, Multiclass, Resolutions, HEIG-VD-mixed-145k

In [None]:
plt.figure(figsize=(15, 6))
# Per-class IoU of one HEIG-VD model at three input resolutions.
# To plot the other model, use "HEIG-VD-10cm-71k" in both the query and the title.
plot_archs_multiclass(
    data=multiclass.query("arch == 'HEIG-VD-mixed-145k'"),
    y="iou",
    hue="res",
    hue_order=("10cm", "20cm", "40cm"),
    ylabel="IoU",
    title=r"\textbf{IoU values of the HEIG-VD-mixed-145k model for different classes and resolutions}",
    dpi=200,
)

# Counts

In [None]:
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.patheffects as path_effects
plt.rcParams['text.usetex'] = True
import numpy as np
import os

The function `transform2short(df)` reshapes a wide DataFrame of class pixel counts (one column per class) into a long-format DataFrame. The output has a `class` column, the absolute `count`, and the `relative` frequency (the count divided by the total of its input row). Despite its name, the result is long, not short. Helper columns whose name starts with `rel_freq_` are excluded, and classes with a count of 1 or less are dropped.

In [None]:
def transform2short(df):
    """Reshape a wide table of per-class pixel counts into a long table.

    Each column of ``df`` is a class and each cell a pixel count. The result has
    one row per (input row, class) pair with the columns ``class``, ``count`` and
    ``relative`` (the count divided by the total of its input row).

    Columns whose name starts with ``rel_freq_`` are ignored. They are dropped
    *before* the totals are computed, so they neither inflate the denominator
    nor upcast the counts to float. Classes with a count of 1 or less are
    removed from the result.
    """
    count_columns = [col for col in df.columns if not str(col).startswith('rel_freq_')]
    counts = df[count_columns]

    # Relative frequency of each class within its row.
    relative = counts.div(counts.sum(axis=1), axis=0)

    # Both melts enumerate the cells in the same (column-major) order, so their
    # default indexes line up.
    long_df = counts.melt(var_name='class', value_name='count')
    long_df['relative'] = relative.melt(value_name='relative')['relative']

    return long_df[long_df['count'] > 1]


In [None]:
root = "/soil_fribourg/proj-soils/csv/counts"
extent = "extent2"


def _load_counts(filename, *data_kinds):
    """Read ``root/filename`` once and return one long-format table per data kind.

    For each value in ``data_kinds`` (e.g. "gt", "pred"), the rows whose
    ``data`` column equals it are kept. The ``type`` and ``data`` columns are
    dropped and the rest is passed through ``transform2short``.
    """
    raw = pd.read_csv(os.path.join(root, filename))
    return tuple(
        transform2short(raw[raw["data"] == kind].drop(["type", "data"], axis=1))
        for kind in data_kinds
    )


# ground truth
allgt_multiclass_counts = transform2short(pd.read_csv(
    os.path.join(root, "all_gt_counts.csv")).drop(["type", "data"], axis=1))

# binary
gt_binary_counts, heigvd_binary_counts = _load_counts(
    f"counts_heig_vd_binary_{extent}.csv", "gt", "pred")
(ign_binary_counts,) = _load_counts(f"counts_ign_binary_{extent}.csv", "pred")

# multiclass (not available for every extent)
try:
    gt_multiclass_counts, heigvd_multiclass_counts = _load_counts(
        f"counts_heig_vd_multiclass_{extent}.csv", "gt", "pred")
    (ign_multiclass_counts,) = _load_counts(f"counts_ign_multiclass_{extent}.csv", "pred")
    (ofs_multiclass_counts,) = _load_counts(f"counts_OFS_multiclass_{extent}.csv", "pred")
except FileNotFoundError:
    print(f"!! No multiclass counts found for {extent} !!")


In [None]:
def create_multiclass_piecharts(df, suptitle, dpi, threshold=0.03, subthreshold=0.01):
    """Draw two donut charts of class frequencies: frequent classes and a zoom on rare ones.

    Left chart: every class with a relative frequency of at least ``threshold``,
    plus an "Other" wedge grouping all rarer classes. Right chart: the classes
    grouped into that "Other" wedge. Those below ``subthreshold`` times the
    "Other" total are grouped again into a nested "Other" wedge. Each wedge is
    annotated with its percentage of the whole and its absolute pixel count,
    where "Other" always shows the count of exactly the classes it groups.

    Args:
        df: long-format DataFrame with ``class``, ``count`` and ``relative``
            columns, one row per class (as built by ``transform2short``).
        suptitle: figure title; it is wrapped in LaTeX bold.
        dpi: figure resolution.
        threshold: minimum relative frequency for a class to get its own
            wedge in the left chart.
        subthreshold: minimum share of the "Other" group for a class to get
            its own wedge in the right chart.
    """
    series_relative = df.set_index('class')['relative'].sort_values(ascending=True)
    series_count = df.set_index('class')['count']

    # Colour per class; unknown classes and "Other" are grey.
    colors = {
        "batiment": "#4949e7",
        "surface_non_beton": "#949452",
        "surface_beton": "#f04d68",
        "eau_bassin": "#76d6ff",
        "roche_dure_meuble": "#929292",
        "eau_naturelle": "#76d6ff",
        "sol_neige": "#ffffff",
        "sol_vegetalise": "#45b72f",
        "sol_vigne": "#5bc08c",
        "sol_agricole": "#8dba37",
        "serre_permanente": "#00fdff",
    }
    colors_mapped = {name: colors.get(name, '#808080') for name in series_relative.index}
    colors_mapped["Other"] = "#808080"

    # Left chart: frequent classes keep their own wedge, the rest become "Other".
    # Each chart gets its own count lookup because "Other" groups different
    # classes in each chart.
    small_relative = series_relative[series_relative < threshold]
    main_relative = series_relative[series_relative >= threshold].copy()
    main_counts = series_count.copy()
    if not small_relative.empty:
        main_relative["Other"] = small_relative.sum()
        main_counts["Other"] = series_count[series_count.index.isin(small_relative.index)].sum()

    # Right chart: inside "Other", group the very smallest classes once more.
    subthreshold_value = subthreshold * small_relative.sum()
    very_small_relative = small_relative[small_relative < subthreshold_value]
    sub_relative = small_relative[small_relative >= subthreshold_value].copy()
    sub_counts = series_count.copy()
    if not very_small_relative.empty:
        sub_relative["Other"] = very_small_relative.sum()
        sub_counts["Other"] = series_count[series_count.index.isin(very_small_relative.index)].sum()

    # White halo behind the annotations for readability.
    text_effect = [path_effects.withStroke(linewidth=5, foreground="white", alpha=0.8)]

    fig, axs = plt.subplots(1, 2, figsize=(18, 9), dpi=dpi)
    panels = [
        (main_relative, main_counts, r'\textbf{Most Frequent Classes}', 1),
        (sub_relative, sub_counts, r'\textbf{"Other" Categories}', 0.8),
    ]
    for ax, (data, counts, title, radius) in zip(axs, panels):
        ax.set_title(title)
        if data.empty:
            # Nothing to draw (e.g. no class fell below the threshold).
            ax.axis("off")
            continue
        wedges, _ = ax.pie(data, labels=data.index,
                           colors=[colors_mapped.get(name) for name in data.index],
                           startangle=90, radius=radius)
        ax.add_artist(plt.Circle((0, 0), 0.55, fc='white'))

        # Annotate each wedge with its percentage and absolute count; even and
        # odd wedges use different radial offsets (presumably to reduce overlap).
        for i, (wedge, label) in enumerate(zip(wedges, data.index)):
            angle = (wedge.theta2 - wedge.theta1) / 2. + wedge.theta1
            x = (radius - 0.1 - (0.2 * (i % 2 == 0))) * np.cos(np.deg2rad(angle))
            y = (radius - (0.2 * (i % 2 == 0))) * np.sin(np.deg2rad(angle))
            # "\\%" is LaTeX's escaped percent sign.
            text = "{:.2f}\\%\n({:,d})".format(data[label] * 100, int(counts[label]))
            ax.text(x, y, text, ha="center", va="center", path_effects=text_effect, size=10)

    plt.suptitle(rf"\textbf{{{suptitle}}}")
    plt.tight_layout()
    plt.show()

In [None]:
def create_binary_piecharts(gt, heigvd, ign, suptitle, dpi=50):
    """Draw soil / non-soil donut charts for the ground truth and two prediction sets.

    The ground truth spans the top row; the heig-vd and IGN predictions share
    the bottom row. Each wedge is labelled with its pixel count and percentage.

    Args:
        gt: long-format DataFrame with ``class`` and ``count`` columns for the
            ground truth; "soil" and "non-soil" get a fixed colour, any other
            class gets matplotlib's default.
        heigvd: same layout, for the heig-vd predictions.
        ign: same layout, for the IGN predictions.
        suptitle: figure title, passed to ``plt.suptitle`` unchanged.
        dpi: figure resolution.
    """
    gt = gt.set_index('class')['count']
    heigvd = heigvd.set_index('class')['count']
    ign = ign.set_index('class')['count']

    colors = {
        "non-soil": "grey",
        "soil": "green"
    }

    def pct_func(pct, allvals):
        """Label a wedge with its pixel count and percentage (empty for 0%)."""
        if not pct > 0:
            return ''
        # Recover the count from matplotlib's percentage. Round instead of
        # truncating: float error can turn e.g. 12345 into 12344.999...
        count = int(round(pct / 100. * np.sum(allvals)))
        # "\\%" is LaTeX's escaped percent sign.
        return "{:,d}\n({:.1f}\\%)".format(count, pct)

    # White halo behind the labels for readability.
    text_effect = [path_effects.withStroke(linewidth=3, foreground="white", alpha=0.8)]

    fig = plt.figure(tight_layout=True, figsize=(5, 5), dpi=dpi)
    gs = gridspec.GridSpec(2, 2)
    axs = [fig.add_subplot(gs[0, :]), fig.add_subplot(gs[1, 0]), fig.add_subplot(gs[1, 1])]

    titles = [r'\textbf{Ground Truth}', r'\textbf{heig-vd predictions}', r'\textbf{IGN predictions}']
    for ax, data, title in zip(axs, [gt, heigvd, ign], titles):
        _, _, autotexts = ax.pie(
            data,
            labels=data.index,
            # Bind `data` as a default so the label never depends on the loop variable.
            autopct=lambda pct, vals=data: pct_func(pct, vals),
            colors=[colors.get(col) for col in data.index],
            startangle=90, pctdistance=0.73, textprops={"size": 8}
        )

        ax.add_artist(plt.Circle((0, 0), 0.55, fc='white'))
        ax.set_title(title)

        for autotext in autotexts:
            autotext.set_path_effects(text_effect)

    plt.suptitle(suptitle)
    plt.tight_layout()
    plt.show()

## Binary Piecharts

In [None]:
# The title names the extent whose counts were loaded (`extent`); a hard-coded
# "Extent1-masked" title contradicted the extent2 counts actually plotted.
create_binary_piecharts(
    gt_binary_counts, heigvd_binary_counts, ign_binary_counts,
    suptitle=rf"\textbf{{{extent.capitalize()}: Binary Pixel Counts}}\\",
    dpi=200,
)

## Multiclass Piecharts

In [None]:
dpi = 200
# Other inputs for this chart (replace both df and suptitle):
#   gt_multiclass_counts     -> "Extent2: Ground Truth Class Pixel Counts"
#   ign_multiclass_counts    -> "Extent2: smp-unet-resnet34-imagenet_RVBI (IGN) Predictions Class Pixel Counts"
#   heigvd_multiclass_counts -> "Extent2: mask2former_beit-adapter (heig-vd) Predictions Class Pixel Counts"
create_multiclass_piecharts(
    df=ofs_multiclass_counts,
    suptitle="Extent2: OFS Model Class Pixel Counts",
    dpi=dpi,
    threshold=0.05,
    subthreshold=0.01,
)

In [None]:
dpi = 200
threshold = 0.02
# One donut chart per source. The "GT" chart must plot the ground-truth
# counts; it previously reused heigvd_multiclass_counts and so duplicated the
# heig-vd chart.
for counts, suptitle in [
    (gt_multiclass_counts, "Class Pixel Counts in the GT of the 04.10.2023"),
    (heigvd_multiclass_counts, "Class Pixel Counts in the Predictions of heig-vd"),
    (ign_multiclass_counts, "Class Pixel Counts in the Predictions of IGN"),
]:
    create_multiclass_piecharts(
        df=counts,
        suptitle=suptitle,
        dpi=dpi,
        threshold=threshold
    )

# Visualize Dataset Balance

In [None]:
import os
import rasterio
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams['text.usetex'] = True
import yaml

def process_dataset(tgt_dir, split_path, merge_dic, class_names):
    """Count pixels per merged class in the train and validation splits.

    Args:
        tgt_dir: folder with the label rasters; only ``.tif`` files are read.
        split_path: YAML file mapping split names to collections of tile ids.
        merge_dic: mapping from raw label values to merged class numbers; every
            label found in a counted raster must be a key.
        class_names: names of the merged classes 1..len(class_names), in order.

    Returns:
        DataFrame with one row per class 1..len(class_names): ``train`` and
        ``val`` pixel counts, ``class_name``, ``sum`` and the normalized
        columns ``sum_normalized``, ``train_normalized``, ``val_normalized``.
        Merged classes outside 1..len(class_names) (e.g. 0) are not reported.

    Raises:
        KeyError: if a counted raster holds a label missing from ``merge_dic``,
            or a tile belongs to a split other than "train" / "val".
    """
    with open(split_path, 'r') as file:
        split_dic = yaml.safe_load(file)

    combined_counts = {"train": {}, "val": {}}
    for filename in os.listdir(tgt_dir):
        if not filename.endswith(".tif"):
            continue
        # Tile id: filename up to the first "_", then up to the first "-".
        tile_id = filename.split("_")[0].split("-")[0]
        split = next((name for name in split_dic if tile_id in split_dic[name]), None)
        if split is None:
            # Tile not listed in any split: it would not be counted, so skip reading it.
            continue
        split_counts = combined_counts[split]  # only "train" and "val" are supported

        with rasterio.open(os.path.join(tgt_dir, filename)) as tif:
            raw = tif.read()

        # Count raw labels first, then merge: merge_dic is looked up once per
        # distinct label instead of once per pixel (np.vectorize is a Python loop).
        labels, counts = np.unique(raw, return_counts=True)
        for label, count in zip(labels, counts):
            merged = merge_dic[label]
            split_counts[merged] = split_counts.get(merged, 0) + int(count)
    print(f'{combined_counts = }')

    # Per split: counts for classes 1..len(class_names), 0 where absent.
    sorted_counts = {
        split: [per_class.get(number, 0) for number in range(1, len(class_names) + 1)]
        for split, per_class in combined_counts.items()
    }
    print(f'{sorted_counts = }')

    return pd.DataFrame(sorted_counts).assign(
        class_name=class_names,
        sum=lambda x: x["train"] + x["val"],
        sum_normalized=lambda x: x["sum"] / x["sum"].sum(),
        train_normalized=lambda x: x["train"] / x["train"].sum(),
        val_normalized=lambda x: x["val"] / x["val"].sum(),
    )


In [None]:
# Mapping from the 17-class labels to the 12-class labels (reference; the
# 12-class rasters processed below are already merged).
merge_dic_17to12cl = {
        1: 1, # batiment
        2: 1, # toit_vegetalise -> batiment
        3: 2, # surface_non_beton
        4: 3, # surface_beton
        5: 5, # eau_bassin -> eau_naturelle
        6: 4, # roche_dure_meuble
        7: 5, # eau_naturelle
        8: 6, # roseliere
        9: 7, # sol_neige
        10: 8, # sol_vegetalise
        11: 8, # surface_riparienne -> sol_vegetalise
        12: 9, # sol_divers
        13: 10, # sol_vigne
        14: 11, # sol_agricole
        15: 12, # sol_bache
        16: 12, # sol_serre_temporaire -> sol_bache
        17: 1, # serre_permanente -> batiment
    }

# Identity mapping for labels 0..12.
merge_dic = {label: label for label in range(0, 12 + 1)}

# The 12 merged classes in class-number order; comments give the 17-class number.
classes = [
            "batiment", # 1
            "surface_non_beton", # 3
            "surface_beton", # 4
            "roche_dure_meuble", # 6
            "eau_naturelle", # 7
            "roseliere", # 8
            "sol_neige", # 9
            "sol_vegetalise", # 10
            "sol_divers", # 12
            "sol_vigne", # 13
            "sol_agricole", # 14
            "sol_bache", # 15
        ]

# Set to True to write the counts. An existing CSV is never overwritten. The
# existence check used to run unconditionally, which made this cell fail as
# soon as the CSV (read further below) existed, even though nothing was written.
SAVE_CSV = False

for seed in ["6_adjusted"]:
    tgt_dir = "/soil_fribourg/data/GT/20240216/4-cut-to-grid-12cl/512px"
    split_path = f"/soil_fribourg/data/datasets.nosync/dataset_12cl_seed{seed}/split.yaml"
    csv_path = f"/soil_fribourg/proj-soils/csv/counts/tv-split_seed{seed}_12classes.csv"
    if SAVE_CSV and os.path.exists(csv_path):
        raise FileExistsError(f"File {csv_path} already exists")

    df = process_dataset(tgt_dir, split_path, merge_dic, classes)

    if SAVE_CSV:
        df.to_csv(csv_path, index=False)

In [None]:
# The 12 merged classes in class-number order (1..12). Former 17-class labels
# missing here (toit_vegetalise, eau_bassin, surface_riparienne,
# sol_serre_temporaire, serre_permanente) were merged into other classes.
classes = """
    batiment surface_non_beton surface_beton roche_dure_meuble eau_naturelle
    roseliere sol_neige sol_vegetalise sol_divers sol_vigne sol_agricole sol_bache
""".split()
def class_numbers(x):
    """Return the 17-class number (1..17) of every name in ``x["class_name"]``.

    Raises:
        KeyError: for a name that is not one of the 17 classes.
    """
    names_17cl = (
        "batiment", "toit_vegetalise", "surface_non_beton", "surface_beton",
        "eau_bassin", "roche_dure_meuble", "eau_naturelle", "roseliere",
        "sol_neige", "sol_vegetalise", "surface_riparienne", "sol_divers",
        "sol_vigne", "sol_agricole", "sol_bache", "sol_serre_temporaire",
        "serre_permanente",
    )
    number_of = {name: number for number, name in enumerate(names_17cl, start=1)}
    return [number_of[name] for name in x["class_name"]]

In [None]:
def compare_splits(dfs, split_names):
    """Plot per-class pixel counts of several train/val split versions side by side.

    Args:
        dfs: one DataFrame per split version, as returned by ``process_dataset``
            (needs ``class_name``, ``train`` and ``val`` columns).
        split_names: label of each split version, shown in the legend.

    Returns:
        The concatenated DataFrame that was plotted, with added
        ``class_number``, ``num_classes`` and ``tvt_split`` columns.
    """
    concat = pd.concat(
        [
            df.assign(class_number=class_numbers, num_classes=12, tvt_split=name)
            for df, name in zip(dfs, split_names)
        ],
        ignore_index=True,
    )

    # Log-scale ticks at every decade; labelled with ints so they read
    # "1,000" rather than the "1,000.0" produced by formatting the float ticks.
    tick_values = [10**k for k in range(3, 10)]
    # Thin lines at 2..9 x 10^k and dashed lines at 10^k, for k = 4..8
    # (each line drawn once).
    minor_lines = [j * 10**k for k in range(4, 9) for j in range(2, 10)]
    major_lines = [10**k for k in range(4, 9)]

    _, axs = plt.subplots(nrows=2, figsize=(10, 10))
    for ax, tvt in zip(axs, ["train", "val"]):
        sns.barplot(
            x="class_name",
            y=tvt,
            data=concat,
            hue="tvt_split",
            order=classes,
            ax=ax
        )

        ax.tick_params(axis="x", labelrotation=60)
        ax.set_xlabel(r"\textbf{Class Name}")
        ax.set_title(rf"\textbf{{{tvt} set}}", size=15)
        ax.legend(title=r"\textbf{Split Version}", loc="upper right")

        ax.set_yscale("log")
        ax.set_yticks(tick_values)
        ax.set_yticklabels([f"{tick:,}" for tick in tick_values])
        for value in minor_lines:
            ax.axhline(y=value, color='gray', linewidth=0.1)
        for value in major_lines:
            ax.axhline(y=value, color='gray', linestyle='--')

    # Only the train and validation sets are plotted.
    plt.suptitle(r"\textbf{Class Frequencies in the Train and Validation Sets}", size=20)
    plt.tight_layout()
    plt.show()
    return concat


# Compare the seed-6 split with its adjusted version. To compare every seed,
# use names = [f"seed{i}" for i in range(7)] + ["seed6_adjusted"].
names = ["seed6", "seed6_adjusted"]
seeds = [
    pd.read_csv(f"/soil_fribourg/proj-soils/csv/counts/tv-split_{name}_12classes.csv")
    for name in names
]

concat = compare_splits(seeds, names)

In [None]:
# cl17 = pd.read_csv("/soil_fribourg/proj-soils/csv/counts/tvt-split_v1_17classes.csv")
# cl12 = pd.read_csv("/soil_fribourg/proj-soils/csv/counts/tvt-split_v2_12classes.csv")

# cl17 = cl17.assign(
#     class_number=lambda x: class_numbers(x),
#     num_classes=17,
#     tvt_split = "V1, 17 Classes")
# cl12 = cl12.assign(
#     class_number=lambda x: class_numbers(x),
#     num_classes=12,
#     tvt_split = "V2, 12 Classes")
# both = pd.concat([cl12, cl17], ignore_index=True).reset_index(drop=True)
# both = both.assign(
#     valtest = lambda x: x["val"] + x["test"]
# )

In [None]:
# Class frequencies of the 12-class split, one log-scale bar plot per set.
# NOTE(review): `cl12` is only defined in the commented-out cell above, so this
# cell raises NameError as-is. It also needs a "test" column, which the
# train/val tables built by process_dataset() do not have.
_, axs = plt.subplots(nrows=3, figsize=(10, 15))
for i, tvt in enumerate(["train", "val", "test"]):
    sns.barplot(
        x="class_name",
        y=tvt,
        data=cl12,
        color="lightblue",
        order=cl12.sort_values(tvt, ascending=False)["class_name"],
        ax=axs[i]
    )

    axs[i].set_xticklabels(axs[i].get_xticklabels(), rotation=60)
    axs[i].set_xlabel(r"\textbf{Class Name}")
    axs[i].set_title(rf"\textbf{{{tvt} set}}", size=15)
    axs[i].set_yscale("log")
    # Decade ticks 1e3..1e9 labelled with thousands separators.
    axs[i].set_yticks([10**i for i in range(3, 10)])
    axs[i].set_yticklabels([f"{10**i:,}" for i in range(3, 10)])
    # Add horizontal lines at the yticks
    yticks = [10**i for i in range(4, 10)]
    for ytick in yticks:
        axs[i].axhline(y=ytick, color='gray', linestyle='--')

# plt.xticks(rotation=45, size=14)
plt.suptitle(r"\textbf{Class Frequencies in the Train, Validation, and Test Sets}", size=20)
plt.tight_layout()
plt.show()

In [None]:
# Compare the 17-class (V1) and 12-class (V2) splits per set on a log scale.
# NOTE(review): `both` and `cl17` are only defined in the commented-out cell
# above, so this cell raises NameError as-is.
_, axs = plt.subplots(nrows=3, figsize=(10, 15))
for i, tvt in enumerate(["train", "val", "test"]):
    sns.barplot(
        x="class_name",
        y=tvt,
        data=both,
        hue="tvt_split",
        hue_order=["V1, 17 Classes", "V2, 12 Classes"],
        # Classes ordered by their V1 frequency in this set.
        order=cl17.sort_values(tvt, ascending=False)["class_name"],
        ax=axs[i]
    )

    axs[i].set_xticklabels(axs[i].get_xticklabels(), rotation=60)
    axs[i].set_xlabel(r"\textbf{Class Name}")
    axs[i].set_title(rf"\textbf{{{tvt} set}}", size=15)
    axs[i].set_yscale("log")
    axs[i].set_yticks([10**i for i in range(3, 10)])
    axs[i].set_yticklabels([f"{10**i:,}" for i in range(3, 10)])
    axs[i].legend(title=r"\textbf{Split Version}", loc="upper right")
    # Add horizontal lines at the yticks
    yticks = [10**i for i in range(4, 10)]
    for ytick in yticks:
        axs[i].axhline(y=ytick, color='gray', linestyle='--')

# plt.xticks(rotation=45, size=14)
plt.suptitle(r"\textbf{Class Frequencies in the Train, Validation, and Test Sets}", size=20)
plt.tight_layout()
plt.show()

In [None]:
# Long format: one row per (class, split version, set) with the pixel count in `counts`.
# NOTE(review): depends on `both` and its test/valtest/test_normalized columns,
# which only the commented-out cell above creates.
df_melted = both.melt(
    id_vars=['class_name', 'sum', 'sum_normalized', 'train_normalized', 'val_normalized', 'test_normalized', 'class_number', 'num_classes', 'tvt_split'],
    value_vars=['train', 'val', 'test', "valtest"],
    var_name='tvt',
    value_name='counts')

In [None]:
# Validation + test pixel counts per class, V1 (17 classes) vs V2 (12 classes),
# ordered by the V1 counts. Requires `df_melted` from the previous cell.
plt.figure(figsize=(15, 10))
sns.barplot(
    x="class_name",
    y="counts",
    data=df_melted.query("tvt == 'valtest'"),
    hue="tvt_split",
    hue_order=["V1, 17 Classes", "V2, 12 Classes"],
    order=df_melted.query("tvt == 'valtest' & tvt_split == 'V1, 17 Classes'").sort_values("counts", ascending=False)["class_name"]
)

plt.xticks(rotation=45, size=12)
plt.xlabel(r"\textbf{Class Name}", size=14)
# plt.yscale("log")
# plt.yticks([10**i for i in range(3, 10)], [f"{10**i:,}" for i in range(3, 10)])
# plt.ylim(0, 100000000)
plt.ylabel(r"\textbf{Pixel Count}", size=14)
plt.title(r"\textbf{Class Frequencies in the Validation + Test Set: V1 vs V2}\newline", size="xx-large")
plt.legend(title="Split Version", loc="upper right")

# Add horizontal lines at the yticks
# yticks = [10**i for i in range(4, 10)]
# for ytick in yticks:
#     plt.axhline(y=ytick, color='gray', linestyle='--')

plt.show()

# Testscratch

In [None]:
import os
import rasterio
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
# Typeset all figure text with LaTeX (the titles below use \textbf{...}).
plt.rcParams['text.usetex'] = True

# Directory holding the inference rasters of the "testscratch" tiles.
root = "/soil_fribourg/data/heig-vd_finetuned/0-inferences/testscratch"

# Per-model sub-directories below `root` (the "stride512" variant of each).
mixed, cm10 = (
    os.path.join(model_dir, "stride512")
    for model_dir in ("mixed-145k", "10cm-71k")
)

# File names of the three example tiles, one per scene type.
countryside, urban, mountainous = (
    "scratch_20200323_1151_12501_0_11_1km.tif",
    "scratch_20200319_1025_12501_0_51_1km.tif",
    "scratch_20200805_0940_12504_0_2_1km.tif",
)

# Output directory for the saved figures.
# NOTE(review): spelled "testcratch" here but "testscratch" in `root`; confirm
# which folder name is intended.
save_folder = "/soil_fribourg/figures/inference_testcratch"

In [None]:
# Overlay colour for each class label 1..12. Position i holds label i + 1,
# with the French class key noted alongside.
_multiclass_colors = [
    '#E86767',  # 1: batiment
    '#FFCC00',  # 2: surface_non_beton
    '#FFBECE',  # 3: surface_beton
    '#FFFFFF',  # 4: roche_dure_meuble
    '#2B78D9',  # 5: eau_naturelle
    '#99C7FF',  # 6: roseliere
    '#CCCCCC',  # 7: sol_neige
    '#B3F135',  # 8: sol_vegetalise
    '#6EF3C1',  # 9: sol_divers
    '#35886C',  # 10: sol_vigne
    '#FFFF67',  # 11: sol_agricole
    '#A8A800',  # 12: sol_bache
]
cmap_multiclass = ListedColormap(_multiclass_colors)

# Two-colour view: labels 1-6 grey, labels 7-12 green (when drawn with
# vmin=1, vmax=12).
cmap_binary = ListedColormap(["#ABABAB"] * 6 + ["#3AA336"] * 6)

## Comparison of Models

Vertical

In [None]:
plt.ioff()  # render off-screen: every figure is saved to disk, then closed

for res in ["10cm", "20cm", "40cm"]:
    # Grid: one row per scene, one column per model.
    _, axs = plt.subplots(nrows=3, ncols=2, figsize=(10, 15), dpi=200)

    for i, img in enumerate([countryside, urban, mountainous]):
        for j, model in enumerate([cm10, mixed]):
            with rasterio.open(os.path.join(root, model, res, img)) as tif:
                ar = tif.read(1)
            # Explicit multiclass colormap. A bare `cmap` is not defined at this
            # point of the notebook; it only exists later as a loop variable.
            axs[i, j].imshow(ar, cmap=cmap_multiclass, vmin=1, vmax=12)

    plt.suptitle(
        r"\textbf{Comparison of the 10cm Model and the Mixed Model}" + \
            "\n" + \
            rf"\textbf{{with input resolution of {res}}}"+ \
            "\n",
        size=20)
    axs[0, 0].set_title("10cm Model", size="xx-large")
    axs[0, 1].set_title("Mixed Model", size="xx-large")
    axs[0, 0].set_ylabel("Countryside", size="xx-large")
    axs[1, 0].set_ylabel("Urban", size="xx-large")
    axs[2, 0].set_ylabel("Mountainous", size="xx-large")

    # no ticks
    for ax in axs.flatten():
        ax.set_xticks([])
        ax.set_yticks([])

    plt.tight_layout()
    plt.savefig(os.path.join(save_folder, f"vertical_model_comparison_{res}.png"))
    # Release the figure. With interactive mode off, unclosed figures keep
    # accumulating across loop iterations.
    plt.close()

plt.ion()

Horizontal

In [None]:
plt.ioff()  # render off-screen: every figure is saved to disk, then closed

for res in ["10cm", "20cm", "40cm"]:
    # Grid: one row per model, one column per scene.
    _, axs = plt.subplots(nrows=2, ncols=3, figsize=(14, 11), dpi=200)

    for i, model in enumerate([cm10, mixed]):
        for j, img in enumerate([countryside, urban, mountainous]):
            with rasterio.open(os.path.join(root, model, res, img)) as tif:
                ar = tif.read(1)
            axs[i, j].imshow(ar, cmap=cmap_multiclass, vmin=1, vmax=12)

    plt.suptitle(
        r"\textbf{Comparison of the 10cm Model and the Mixed Model}" + \
            "\n" + \
            rf"\textbf{{with input resolution of {res}}}"+ \
            "\n",
        size=20)
    
    axs[0, 0].set_ylabel("10cm Model", size="xx-large")
    axs[1, 0].set_ylabel("Mixed Model", size="xx-large")

    axs[0, 0].set_title("Countryside", size="xx-large")
    axs[0, 1].set_title("Urban", size="xx-large")
    axs[0, 2].set_title("Mountainous", size="xx-large")

    # no ticks
    for ax in axs.flatten():
        ax.set_xticks([])
        ax.set_yticks([])

    plt.tight_layout()
    plt.savefig(os.path.join(save_folder, f"horizontal_model_comparison_{res}.png"))
    # Release the figure. With interactive mode off, unclosed figures keep
    # accumulating across loop iterations.
    plt.close()

plt.ion()

## Comparison of Resolutions

Vertical

In [None]:
plt.ioff()  # render off-screen: every figure is saved to disk, then closed

for (img, scene_name) in zip([countryside, urban, mountainous], ["Countryside", "Urban", "Mountainous"]):
    # Grid: one row per input resolution, one column per model.
    _, axs = plt.subplots(nrows=3, ncols=2, figsize=(10, 15), dpi=200)

    for i, res in enumerate(["10cm", "20cm", "40cm"]):
        for j, model in enumerate([cm10, mixed]):
            with rasterio.open(os.path.join(root, model, res, img)) as tif:
                ar = tif.read(1)
            # Explicit multiclass colormap. A bare `cmap` is not defined at this
            # point of the notebook; it only exists later as a loop variable.
            axs[i, j].imshow(ar, cmap=cmap_multiclass, vmin=1, vmax=12)

    plt.suptitle(
        r"\textbf{Comparison of the 10cm Model and the Mixed Model}" + \
            "\n" + \
            rf"\textbf{{for different resolutions on a {scene_name} scene}}" + \
            "\n",
        size=20)
    axs[0, 0].set_title("10cm Model", size="xx-large")
    axs[0, 1].set_title("Mixed Model", size="xx-large")
    axs[0, 0].set_ylabel("10cm", size="xx-large")
    axs[1, 0].set_ylabel("20cm", size="xx-large")
    axs[2, 0].set_ylabel("40cm", size="xx-large")

    # no ticks
    for ax in axs.flatten():
        ax.set_xticks([])
        ax.set_yticks([])

    plt.tight_layout()
    plt.savefig(os.path.join(save_folder, f"vertical_resolution_comparison_{scene_name}.png"))
    plt.close()
plt.ion()

Horizontal

In [None]:
plt.ioff()  # render off-screen: every figure is saved to disk, then closed

for img, scene_name in zip([countryside, urban, mountainous], ["Countryside", "Urban", "Mountainous"]):
    # One figure per scene: rows are models, columns are input resolutions.
    fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(14, 11), dpi=200)

    for row_axes, model in zip(axs, [cm10, mixed]):
        for ax, res in zip(row_axes, ["10cm", "20cm", "40cm"]):
            with rasterio.open(os.path.join(root, model, res, img)) as tif:
                ar = tif.read(1)
            ax.imshow(ar, cmap=cmap_multiclass, vmin=1, vmax=12)

    fig.suptitle(
        r"\textbf{Comparison of the 10cm Model and the Mixed Model}"
        "\n"
        rf"\textbf{{for different resolutions on a {scene_name} scene}}"
        "\n",
        size=20,
    )

    # Row labels on the first column, column headers on the first row.
    for ax, row_label in zip(axs[:, 0], ["10cm Model", "Mixed Model"]):
        ax.set_ylabel(row_label, size="xx-large")
    for ax, col_label in zip(axs[0], ["10cm", "20cm", "40cm"]):
        ax.set_title(col_label, size="xx-large")

    # Images only: hide every axis tick.
    for ax in axs.flat:
        ax.set_xticks([])
        ax.set_yticks([])

    fig.tight_layout()
    fig.savefig(os.path.join(save_folder, f"horizontal_resolution_comparison_{scene_name}.png"))
    plt.close(fig)
plt.ion()

In [None]:
plt.ioff()  # render off-screen: every figure is saved to disk, then closed

for model_path, model_name in zip([cm10, mixed], ["10cm", "Mixed"]):

    for (img, scene_name) in zip([countryside, urban, mountainous], ["Countryside", "Urban", "Mountainous"]):
        # One figure per (model, scene). Rows are the colour scheme
        # (multiclass / binary); columns are the input resolution.
        _, axs = plt.subplots(nrows=2, ncols=3, figsize=(14, 11), dpi=200)
        for j, res in enumerate(["10cm", "20cm", "40cm"]):
            # Read each raster once and draw it with both colour schemes.
            with rasterio.open(os.path.join(root, model_path, res, img)) as tif:
                ar = tif.read(1)
            for i, cmap in enumerate([cmap_multiclass, cmap_binary]):
                axs[i, j].imshow(ar, cmap=cmap, vmin=1, vmax=12)

        # Each figure shows a single model, so name that model in the title.
        plt.suptitle(
            rf"\textbf{{Multiclass and Binary Predictions of the {model_name} Model}}" + \
                "\n" + \
                rf"\textbf{{for different resolutions on a {scene_name} scene}}" + \
                "\n",
            size=20)

        # Rows differ by colour scheme, not by model.
        axs[0, 0].set_ylabel("Multiclass", size="xx-large")
        axs[1, 0].set_ylabel("Binary", size="xx-large")

        axs[0, 0].set_title("10cm", size="xx-large")
        axs[0, 1].set_title("20cm", size="xx-large")
        axs[0, 2].set_title("40cm", size="xx-large")

        # no ticks
        for ax in axs.flatten():
            ax.set_xticks([])
            ax.set_yticks([])

        plt.tight_layout()
        plt.savefig(os.path.join(save_folder, f"horizontal_resolution_comparison_{model_name}_{scene_name}.png"))
        plt.close()
plt.ion()

# Training Log

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import matplotlib.lines as mlines

# mIoU training logs of the two fine-tuned models. The cells below use the
# columns "Wall time", "Step" and "Value" (presumably a TensorBoard scalar
# export, given the "_tf_logs" file names; confirm).
# NOTE(review): this rebinds `cm10` / `mixed`, which held raster sub-paths in
# the inference-comparison cells above. Re-run the setup cell before re-running those.
cm10, mixed = (
    pd.read_csv(log_csv)
    for log_csv in (
        "/proj-soils/csv/train_log/miou_mask2former_beit_adapter_large_512_160k_proj-soils_12class_10cm_tf_logs.csv",
        "/proj-soils/csv/train_log/miou_mask2former_beit_adapter_large_512_160k_proj-soils_12class_multiscale_tf_logs.csv",
    )
)



In [None]:
import time
import datetime

def convert_to_datetime(x):
    """Convert a POSIX timestamp ``x`` (seconds) to a naive local ``datetime``."""
    return datetime.datetime.fromtimestamp(x)


def format_duration(seconds):
    """Format an elapsed time in seconds as 'D days, HH hours, MM minutes, SS seconds'.

    Days are counted from 0 and are not capped. The previous
    ``time.strftime('%d ...', time.gmtime(seconds))`` printed the 1-based
    day of the month: it reported "01 days" for runs shorter than 24 h and
    wrapped after 31 days.
    """
    total = int(round(seconds))
    days, rem = divmod(total, 86_400)
    hours, rem = divmod(rem, 3_600)
    minutes, secs = divmod(rem, 60)
    return f"{days} days, {hours:02d} hours, {minutes:02d} minutes, {secs:02d} seconds"


# Total training time = span between the first and last logged wall-clock time.
total_time_mixed = mixed["Wall time"].max() - mixed["Wall time"].min()
total_time_mixed = format_duration(total_time_mixed)
print(total_time_mixed)

total_time_cm10 = cm10["Wall time"].max() - cm10["Wall time"].min()
total_time_cm10 = format_duration(total_time_cm10)
print(total_time_cm10)

In [None]:
plt.figure(figsize=(10, 5), dpi=200)

# Series colours are pinned explicitly so that the data, the dashed
# checkpoint lines and the proxy legend handles always agree, whatever
# palette is active. These are the seaborn "deep" C0 / C1 colours.
COLOR_10CM = "#4C72B0"
COLOR_MIXED = "#DD8452"

for log, model_label, color in (
    (cm10, "10cm Model", COLOR_10CM),
    (mixed, "Mixed Model", COLOR_MIXED),
):
    sns.scatterplot(data=log, x="Step", y="Value", label=model_label, s=20, color=color)
    sns.lineplot(data=log, x="Step", y="Value", label=model_label, alpha=0.8, color=color)

# x tick labels in "20k", "40k", ... format
xtick_positions = np.arange(0, 180_000, 20_000)
plt.xticks(xtick_positions, [f"{t // 1000}k" for t in xtick_positions])

plt.grid(axis='y', linestyle='-', color='gray', alpha=0.25)

# Dashed vertical lines at each model's selected checkpoint (71k for 10cm,
# 145k for mixed). The subtitle calls these the best mIoU.
# NOTE(review): positions are hard-coded, not derived from the logs.
plt.axvline(x=71_250, color=COLOR_10CM, linestyle="--")
plt.axvline(x=145_250, color=COLOR_MIXED, linestyle="--")

# Text labels for the vertical lines. The y position (0.441) is hard-coded and
# assumes the current mIoU axis range.
plt.text(69_000, 0.441, "71k", color=COLOR_10CM, fontsize=12)
plt.text(145_000, 0.441, "145k", color=COLOR_MIXED, fontsize=12)

plt.xlabel("Iteration")
plt.ylabel("mIoU")
plt.title(
    r"\textbf{Evolution of mIoU values during fine-tuning}" + \
    "\nThe dashed lines indicate the best mIoU for each model")

# Proxy handles: one marker+line entry per model, replacing the duplicated
# scatter/line entries seaborn would otherwise list.
line_marker_10cm = mlines.Line2D([], [], color=COLOR_10CM, marker='o', markersize=4, label='10cm Model')
line_marker_mixed = mlines.Line2D([], [], color=COLOR_MIXED, marker='o', markersize=4, label='Mixed Model')
plt.legend(handles=[line_marker_10cm, line_marker_mixed])

plt.tight_layout()
plt.show()

# Inference Example

In [None]:
import rasterio
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import matplotlib.patches as mpatches
# Typeset figure text with LaTeX (the suptitle below uses \textbf{...}).
matplotlib.rc('text', usetex=True)

# Prediction raster of the mixed-145k model (stride512, 10cm input) for one tile.
inf_path = "/soil_fribourg/data/heig-vd_finetuned/0-inferences/testscratch/mixed-145k/stride512/10cm/scratch_20200319_1025_12501_0_51_1km.tif"
# Source image of the same tile, used as the background under the prediction.
img_path = "/Documents/AutoDelete/scratch_20200319_1025_12501_0_51_1km.tif"

In [None]:
# Display name and overlay colour for each predicted class label 1..12.
# Position i describes label i + 1; the French class key is noted alongside.
_class_styles = [
    ("Building", '#E86767'),               # 1: batiment
    ("Pervious (Non-Soil)", '#FFCC00'),    # 2: surface_non_beton
    ("Impervious", '#FFBECE'),             # 3: surface_beton
    ("Hard Rock and Pebbles", '#FFFFFF'),  # 4: roche_dure_meuble
    ("Natural Water", '#2B78D9'),          # 5: eau_naturelle
    ("Reedbed", '#99C7FF'),                # 6: roseliere
    ("Snow", '#CCCCCC'),                   # 7: sol_neige
    ("Vegetated Soil", '#B3F135'),         # 8: sol_vegetalise
    ("Diverse Soil", '#6EF3C1'),           # 9: sol_divers
    ("Vineyard", '#35886C'),               # 10: sol_vigne
    ("Agricultural Soil", '#FFFF67'),      # 11: sol_agricole
    ("Plastic Cover", '#A8A800'),          # 12: sol_bache
]
classes = [name for name, _ in _class_styles]
cmap_multiclass = ListedColormap([colour for _, colour in _class_styles])

# Two-colour view: with vmin=1, vmax=12, labels 1-6 get the first colour and
# labels 7-12 the second. The inference figure's legend names them
# "Non-Soil" / "Soil".
cmap_binary = ListedColormap(["#c9665f"] * 6 + ["#3AA336"] * 6)

In [None]:
def _read_bands(path, *bands):
    """Open the raster at ``path`` and return the given band(s), or all bands if none are given."""
    with rasterio.open(path) as src:
        return src.read(*bands)


ar_inf = _read_bands(inf_path, 1)  # single band: 2-D array of predicted labels
ar_img = _read_bands(img_path)     # every band: (bands, rows, cols)

print(f'{ar_img.shape = }')

In [None]:
# Pixel window cropped from the tile for display.
ymin = 0
xmin = 8250
size = 1500

ymax = ymin + size
xmax = xmin + size

# Background as (rows, cols, bands). Only the first three bands are kept,
# because imshow would treat a 4th band (e.g. NIR) as alpha. For a 3-band
# image this is a no-op.
# NOTE(review): assumes bands 1-3 are R, G, B.
background = ar_img[:3, ymin:ymax, xmin:xmax].transpose(1, 2, 0)
label_crop = ar_inf[ymin:ymax, xmin:xmax]

multiclass_handles = [mpatches.Patch(color=cmap_multiclass(i), label=classes[i]) for i in range(12)]
# Take the binary legend colours from the colormap itself, so the legend stays
# in sync with the overlay. Entry 0 covers labels 1-6 and entry 11 covers 7-12.
binary_handles = [
    mpatches.Patch(color=cmap_binary(0), label="Non-Soil"),
    mpatches.Patch(color=cmap_binary(11), label="Soil"),
]

_, ax = plt.subplots(ncols=2, figsize=(12, 7), dpi=200)

# Both panels show the same crop with a half-transparent class overlay; they
# differ only in colour scheme, legend and title.
for panel, overlay_cmap, handles, title in (
    (ax[0], cmap_multiclass, multiclass_handles, "Multiclass"),
    (ax[1], cmap_binary, binary_handles, "Binary"),
):
    panel.imshow(background)
    panel.imshow(label_crop, alpha=0.5, cmap=overlay_cmap, vmin=1, vmax=12)
    panel.legend(handles=handles, loc="lower right")
    panel.set_title(title, size="x-large")
    panel.axis("off")

plt.suptitle(r"\textbf{Exemplary Inference on Urban Scene in Bulles, FR}", size="xx-large")
plt.tight_layout()
plt.show()