In [4]:
import pathlib
import pandas as pd
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import pathlib
from pathos.multiprocessing import ProcessingPool as Pool
from src.utils import load_data, predict_and_plot_SLC
import seaborn as sns
from tqdm import tqdm
from multiprocessing import Pool
# Global figure styling shared by every plot produced in this notebook:
# seaborn grid background, colour-blind-friendly palette, Arial text.
sns.set_style("whitegrid")
sns.set_palette("colorblind")
mpl.rcParams["font.family"] = "Arial"

In [5]:
# Load the dataset that the cells below pass to predict_and_plot_SLC.
data = load_data([0, 151])

# Sample size forwarded to predict_and_plot_SLC by the cells below.
# A small value (such as 10) keeps the notebook fast when you only want to
# check that the code runs; use sample_size = 3000 to reproduce the results
# reported in the paper.
sample_size = 10


In [6]:
# Generate one figure per resolution, sequentially, and save each one as a PDF.
resolutions = range(2, 9)
plot_timestamp = 150  # second positional argument of predict_and_plot_SLC; also part of the file name
plot_max_val = 350    # third positional argument of predict_and_plot_SLC
total_tasks = len(resolutions)  # drives the progress bar so it cannot drift from the loop

# Create the destination folder up front: plt.savefig raises if it is missing.
fig_dir = pathlib.Path("./../assets/plots/fig_4")
fig_dir.mkdir(parents=True, exist_ok=True)

with tqdm(total=total_tasks, desc="Generating plots") as pbar:
    for resolution in resolutions:
        # Only the figure drawn by the call is needed here; the return value is ignored.
        predict_and_plot_SLC(
            data,
            plot_timestamp,
            plot_max_val,
            resolution,
            plot=True,
            sample_size=sample_size,
        )
        plt.savefig(
            fig_dir / f"res_{resolution}_time_{plot_timestamp}.pdf",
            bbox_inches="tight",
        )
        plt.close()  # Close the figure to avoid accumulating open figures
        pbar.update(1)

Generating plots: 100%|█████████████████████████████████████████████████████████████████████████| 7/7 [00:05<00:00,  1.23it/s]


In [7]:
# Values matched against the "years" column when evaluating predictions.
timestamp_list = [100, 125, 150]
# Passed to predict_and_plot_SLC as its third positional argument, paired
# index-by-index with timestamp_list.
# NOTE(review): this list has one more entry than timestamp_list, so zipping
# the two drops the trailing 400 -- confirm whether that is intended.
max_val_list = [150, 200, 300, 400]
# Melt averages 12.5, 25.0, ..., 137.5 (all the melt averages except the first and last).
melt_average = [12.5 * step for step in range(1, 12)]


def process_task(data, timestamp, max_val_un, resolution, melt, fit_gp, sample_size):
    """Return a (prediction, label) pair for one point of the evaluation grid.

    Parameters
    ----------
    data : DataFrame-like with "melt_average", "resolution", "years" and
        "SLC" columns; also forwarded unchanged to predict_and_plot_SLC.
    timestamp : value matched against the "years" column and forwarded to
        predict_and_plot_SLC.
    max_val_un : forwarded to predict_and_plot_SLC as its third positional
        argument.
    resolution : value matched against the "resolution" column and forwarded
        to predict_and_plot_SLC.
    melt : melt average at which the prediction is requested; rows whose
        "melt_average" lies strictly within +/- 0.05 of it supply the label.
    fit_gp : forwarded to predict_and_plot_SLC as ``fit_gp``.
    sample_size : forwarded to predict_and_plot_SLC as ``sample_size``.

    Returns
    -------
    tuple
        ``(pred, label)`` where ``pred`` is element ``[1][0][0][0]`` of the
        value returned by predict_and_plot_SLC and ``label`` is the mean
        "SLC" of the matching rows (NaN when no row matches).
    """
    tolerance = 0.05
    melt_col = data["melt_average"]
    # Rows whose melt average lies strictly inside (melt - tol, melt + tol).
    near_melt = data[(melt_col > melt - tolerance) & (melt_col < melt + tolerance)]

    # Single query point as a column vector of shape (1, 1).
    query_point = np.array([melt]).reshape(-1, 1)
    output = predict_and_plot_SLC(
        data,
        timestamp,
        max_val_un,
        resolution,
        plot=False,
        sample_size=sample_size,
        X_m=query_point,
        fit_gp=fit_gp,
    )
    # NOTE(review): the nesting of the returned object is defined by
    # predict_and_plot_SLC; only the first scalar of its second item is used.
    pred = output[1][0][0][0]

    same_setting = (near_melt["resolution"] == resolution) & (near_melt["years"] == timestamp)
    label = near_melt[same_setting]["SLC"].mean()
    return (pred, label)

def compute_rmse(results):
    """Root mean squared error between the two columns of ``results``.

    ``results`` is a 2-D array; column 0 is compared against column 1
    (the callers in this notebook store predictions and labels there).
    """
    first_col = results[:, 0]
    second_col = results[:, 1]
    squared_error = (first_col - second_col) ** 2
    return np.sqrt(squared_error.mean())

# Evaluate the predictions over the whole grid, once with fit_gp=True and
# once with fit_gp=False, and report the RMSE for each setting.
for fit_gp in (True, False):
    # One task per (melt, (timestamp, max_val), resolution) combination.
    # zip() stops at the shorter list, so each timestamp is paired with the
    # max value found at the same index.
    tasks = [
        (data, timestamp, max_val, resolution, melt, fit_gp, sample_size)
        for melt in melt_average
        for timestamp, max_val in zip(timestamp_list, max_val_list)
        for resolution in range(2, 9)
    ]

    # Run the tasks one after another; tqdm infers the total from len(tasks).
    progress = tqdm(tasks, desc=f"Processing tasks (GP fit: {fit_gp})")
    results = [process_task(*task) for task in progress]

    # Rows are (pred, label) pairs, as expected by compute_rmse.
    results = np.array(results)
    rmse = compute_rmse(results)
    print("GP fit:", fit_gp)
    print("The root mean squared error averaged across all resolution and melt average is:", rmse)


Processing tasks (GP fit: True): 100%|██████████████████████████████████████████████████████| 231/231 [00:16<00:00, 13.92it/s]


GP fit: True
The root mean squared error averaged across all resolution and melt average is: 1.953486818464643


Processing tasks (GP fit: False): 100%|████████████████████████████████████████████████████| 231/231 [00:01<00:00, 119.95it/s]

GP fit: False
The root mean squared error averaged across all resolution and melt average is: 3.7350990760592593



