In [None]:
import sys

# Report which interpreter runs this notebook, then append the local checkout
# directories to the module search path (presumably so that the
# `pre_processing` import further down resolves -- verify on your machine).
print("Python %s on %s" % (sys.version, sys.platform))
sys.path += [
    "/home/xultaeculcis/repos/climate-super-resolution",
    "/home/xultaeculcis/repos/climate-super-resolution/sr",
]

import os
from dataclasses import dataclass
from typing import List, Optional, Union

import folium
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import xarray as xr
from pre_processing.cruts_config import CRUTSConfig
from sklearn.metrics import mean_absolute_error, mean_squared_error

## Load data

Load data from Tomasz. Each location is represented by 3 values (lat, lon, alt). We split them into separate arrays per
feature for convenience.

In [None]:
# Three parallel lists: entry i of `lats`, `lons` and `alts` describes the
# same location (23 locations in total). They are zipped together wherever
# they are used (map markers, dataset selection).
#
# The last two entries (indices 21 and 22) repeat locations #9 and #1 with
# coordinates rounded to 5 decimals; the "Only two locations" section picks
# them with `[-2:]`.

# Latitudes of the locations.
lats = [
    51.0571888888889,
    51.0422222222222,
    50.7934805555556,
    50.7901777777778,
    50.4172222222222,
    50.3986111111111,
    50.4443333333333,
    50.4421111111111,
    50.7614722222222,
    50.7603333333333,
    50.8071111111111,
    50.8091111111111,
    50.81275,
    50.7593611111111,
    50.6729722222222,
    50.6456666666667,
    50.6971666666667,
    50.8953611111111,
    50.8190277777778,
    50.8917222222222,
    50.96025,
    50.76033,
    51.04222,
]

# Longitudes of the locations.
lons = [
    15.478,
    15.6838888888889,
    15.6849972222222,
    15.6788361111111,
    16.48975,
    16.4648611111111,
    16.7188055555556,
    16.7666388888889,
    15.7469166666667,
    15.7261388888889,
    15.6083055555556,
    15.5868611111111,
    15.6107777777778,
    15.7611666666667,
    16.1326666666667,
    16.3668333333333,
    16.4656944444444,
    15.6205,
    15.5150555555556,
    15.3588611111111,
    15.4858888888889,
    15.72614,
    15.68389,
]

# Altitudes of the locations; displayed with an "m" suffix in tooltips and
# box-plot labels.
alts = [
    402,
    325,
    646,
    709,
    486,
    499,
    486,
    546,
    825,
    930,
    726,
    744,
    620,
    752,
    604,
    706,
    658,
    543,
    717,
    606,
    469,
    930,
    325,
]

Load NetCDF datasets from CRU-TS and from our Super-Resolution NN.

In [None]:
# Open the six NetCDF files used below. Naming: `*_nn` datasets come from the
# "inference-europe-extent-nc" directory (network output), `*_cru` datasets
# from the "original" directory (CRU-TS 4.04). File-name codes: tmp / tmn /
# tmx, matching the `CRUTSConfig.tmp` / `.tmn` / `.tmx` variables used later.
# NOTE(review): the paths are absolute and machine-specific; adjust them
# before running elsewhere.

# Network output -- per its file name, the "tmp-combined" inference run.
ds_temp_nn = xr.open_dataset(
    "/media/xultaeculcis/2TB/datasets/cruts/inference-europe-extent-nc/cru_ts4.04.nn.inference.1901.2019.tmp-combined.dat.nc"
)
# Network output for tmn.
ds_tmin_nn = xr.open_dataset(
    "/media/xultaeculcis/2TB/datasets/cruts/inference-europe-extent-nc/cru_ts4.04.nn.inference.1901.2019.tmn.dat.nc"
)
# Network output for tmx.
ds_tmax_nn = xr.open_dataset(
    "/media/xultaeculcis/2TB/datasets/cruts/inference-europe-extent-nc/cru_ts4.04.nn.inference.1901.2019.tmx.dat.nc"
)
# Original CRU-TS tmp file.
ds_temp_cru = xr.open_dataset(
    "/media/xultaeculcis/2TB/datasets/cruts/original/cru_ts4.04.1901.2019.tmp.dat.nc"
)
# Original CRU-TS tmx file.
ds_tmax_cru = xr.open_dataset(
    "/media/xultaeculcis/2TB/datasets/cruts/original/cru_ts4.04.1901.2019.tmx.dat.nc"
)
# Original CRU-TS tmn file.
ds_tmin_cru = xr.open_dataset(
    "/media/xultaeculcis/2TB/datasets/cruts/original/cru_ts4.04.1901.2019.tmn.dat.nc"
)

Load the dataset of 200 mountain peaks.

In [None]:
# Mountain peaks table, read from a path relative to the working directory.
# Columns used later in this notebook: "lat", "lon", "altitude",
# "mountain_range" and "mountain_peak_name".
peaks = pd.read_csv("../datasets/mountain_peaks.csv")

## Visualize data

Visualize data from Tomasz first.

In [None]:
# 500x500 map centred on location #14, with one marker per location.
f = folium.Figure(width=500, height=500)
m = folium.Map(location=[lats[14], lons[14]], zoom_start=7)
m.add_to(f)

# The tooltip shows the zero-padded list index and the altitude.
for idx, (lat, lon, alt) in enumerate(zip(lats, lons, alts)):
    folium.Marker(
        [lat, lon], tooltip=f"Location #{str(idx).rjust(2, '0')} - alt: {alt}m"
    ).add_to(m)

# Trailing expression: the notebook renders the figure as the cell output.
f

Now visualize the 200 peaks dataset.

In [None]:
# 500x500 map with a fixed centre and one marker per row of `peaks`.
f = folium.Figure(width=500, height=500)
m = folium.Map(location=[50.200328, 19.770119], zoom_start=6)
m.add_to(f)

for idx, row in peaks.iterrows():
    lat, lon, alt = row["lat"], row["lon"], row["altitude"]
    # Tooltip format: "<range> - <peak name> - alt: <altitude>m".
    folium.Marker(
        [lat, lon],
        tooltip=f"{row['mountain_range']} - {row['mountain_peak_name']} - alt: {alt}m",
    ).add_to(m)

# Trailing expression: the notebook renders the figure as the cell output.
f

## Helper classes
Let's define some helper classes to compute our metrics and visualize the time series results.

In [None]:
@dataclass
class StatsResult:
    """Container of summary statistics, one array per statistic.

    Every field is an array; `empty` creates them all with the same length.
    """

    minima: np.ndarray
    means: np.ndarray
    medians: np.ndarray
    q25: np.ndarray
    q50: np.ndarray
    q75: np.ndarray
    maxima: np.ndarray

    @classmethod
    def empty(cls, size):
        """Return an instance whose every field is a separate `np.zeros(size)`."""
        # The dataclass field table keeps declaration order, so this builds one
        # independent zero array per declared field.
        zeroed = {name: np.zeros(size) for name in cls.__dataclass_fields__}
        return cls(**zeroed)


@dataclass
class CompareStatsResults:
    """Per-location comparison of one variable between two datasets.

    Instances are normally built with `compute`, which samples both datasets
    at every (lat, lon) location over `time_range`, collects summary
    statistics per location and averages the error metrics over locations.
    The plotting methods re-read the stored datasets with the same selection.
    """

    # Per-location summary statistics of the CRU-TS series.
    stats_cru: StatsResult
    # Per-location summary statistics of the network (NN / SR) series.
    stats_nn: StatsResult
    # Name of the data variable read from both datasets.
    var: str
    ds_cru: xr.Dataset
    ds_nn: xr.Dataset
    # Time steps requested from both datasets (matched with method="nearest").
    time_range: pd.DatetimeIndex
    # Parallel sequences: entry i of each describes location i.
    lats: Union[List, np.ndarray]
    lons: Union[List, np.ndarray]
    alts: Union[List, np.ndarray]
    # Optional display names; when None, box-plot labels use "#<index>".
    names: Optional[Union[List, np.ndarray]]
    # Error metrics, each computed per location and then averaged.
    mse: float
    rmse: float
    mae: float

    @staticmethod
    def _select_series(ds, var, lat, lon, time_range):
        """Return `ds[var]` at the grid cell / time steps nearest to the request."""
        return ds[var].sel(lat=lat, lon=lon, time=time_range, method="nearest")

    @staticmethod
    def _fill_stats(stats, idx, series):
        """Write the summary statistics of `series` into row `idx` of `stats`."""
        # One quantile call yields all three quartiles.
        q25, q50, q75 = series.quantile([0.25, 0.5, 0.75]).values
        stats.q25[idx] = q25
        stats.q50[idx] = q50
        stats.q75[idx] = q75
        stats.minima[idx] = series.min()
        stats.maxima[idx] = series.max()
        stats.means[idx] = series.mean()
        stats.medians[idx] = series.median()

    @classmethod
    def compute(cls, var, time_range, lats, lons, alts, ds_cru, ds_nn, names=None):
        """Compare `ds_cru[var]` with `ds_nn[var]` at every location.

        Args:
            var: Name of the variable to read from both datasets.
            time_range: Time steps to select (nearest-neighbour matched).
            lats, lons, alts: Parallel sequences describing the locations.
            ds_cru: Reference (CRU-TS) dataset.
            ds_nn: Dataset produced by the network.
            names: Optional display names, parallel to `lats`.

        Returns:
            A `CompareStatsResults` holding the per-location statistics and
            the MAE / MSE / RMSE averaged over locations.

        Raises:
            ValueError: If `lons`, `alts` or `names` differ in length from
                `lats`.
        """
        n_locations = len(lats)

        # zip() would silently drop trailing locations on a length mismatch
        # and leave all-zero rows that bias the averaged statistics.
        parallel = {"lons": lons, "alts": alts}
        if names is not None:
            parallel["names"] = names
        for arg_name, values in parallel.items():
            if len(values) != n_locations:
                raise ValueError(
                    f"`{arg_name}` has {len(values)} entries but `lats` has "
                    f"{n_locations}"
                )

        cru_stats = StatsResult.empty(n_locations)
        nn_stats = StatsResult.empty(n_locations)

        mae = np.zeros(n_locations)
        mse = np.zeros(n_locations)
        rmse = np.zeros(n_locations)

        for idx, (lat, lon) in enumerate(zip(lats, lons)):
            cru_data = cls._select_series(ds_cru, var, lat, lon, time_range)
            nn_data = cls._select_series(ds_nn, var, lat, lon, time_range)

            cls._fill_stats(cru_stats, idx, cru_data)
            cls._fill_stats(nn_stats, idx, nn_data)

            mae[idx] = mean_absolute_error(cru_data, nn_data)
            mse[idx] = mean_squared_error(cru_data, nn_data)
            # Derive RMSE from the MSE instead of passing `squared=False`,
            # which newer scikit-learn releases no longer accept.
            rmse[idx] = np.sqrt(mse[idx])

        return cls(
            stats_cru=cru_stats,
            stats_nn=nn_stats,
            var=var,
            ds_cru=ds_cru,
            ds_nn=ds_nn,
            time_range=time_range,
            lats=lats,
            lons=lons,
            alts=alts,
            names=names,
            mae=mae.mean(),
            mse=mse.mean(),
            rmse=rmse.mean(),
        )

    def line_plot(self):
        """Plot the CRU-TS and network time series of every location on one axis."""
        plt.figure(figsize=(15, 15))
        ax = plt.subplot(1, 1, 1)
        for lat, lon in zip(self.lats, self.lons):
            cru_data = self._select_series(
                self.ds_cru, self.var, lat, lon, self.time_range
            )
            nn_data = self._select_series(
                self.ds_nn, self.var, lat, lon, self.time_range
            )
            cru_data.plot(marker="x", color="blue", ax=ax)
            nn_data.plot(marker="o", color="orange", ax=ax)

        ax.set_title("Temperature comparison between CRU-TS and SR across time")

        # Only the first two lines (CRU-TS and SR of the first location) get
        # legend entries; all later lines reuse the same two colours.
        ax.legend(("CRU-TS", "SR"))
        plt.show()

    def box_plot(self):
        """Draw one pair of box plots (CRU-TS vs. SR) per location."""
        values = []
        locations = []
        sources = []

        for idx, (lat, lon) in enumerate(zip(self.lats, self.lons)):
            cru_data = self._select_series(
                self.ds_cru, self.var, lat, lon, self.time_range
            )
            nn_data = self._select_series(
                self.ds_nn, self.var, lat, lon, self.time_range
            )

            prefix = f"#{idx}" if self.names is None else f"{self.names[idx]}"
            location = f"{prefix} - {self.alts[idx]} m"

            # Label every value with its own source and location; counts come
            # from each series' own length so the columns always line up.
            for source, series in (("CRU-TS", cru_data), ("SR", nn_data)):
                series_values = series.values.tolist()
                values.extend(series_values)
                sources.extend([source] * len(series_values))
                locations.extend([location] * len(series_values))

        df = pd.DataFrame(
            data={
                "Location": locations,
                "Data source": sources,
                "Temperature (Celsius)": values,
            }
        )
        # A single figure, widened with the number of distinct locations
        # (0.25 per location, never narrower than 20).
        plt.figure(figsize=(max(0.25 * df["Location"].nunique(), 20), 10))
        chart = sns.boxplot(
            x="Location", y="Temperature (Celsius)", data=df, hue="Data source"
        )
        chart.set_xticklabels(
            chart.get_xticklabels(), rotation=45, horizontalalignment="right"
        )
        plt.show()

    def print_comparison_summary(self):
        """Print location-averaged statistics of both sources and the error metrics.

        Each `diff` is the mean of (CRU-TS - NN), so a negative value means the
        network output is higher on average.
        """
        summary_rows = (
            ("min", "minima"),
            ("mean", "means"),
            ("median", "medians"),
            ("max", "maxima"),
            ("q25", "q25"),
            ("q50", "q50"),
            ("q75", "q75"),
        )
        for label, attr in summary_rows:
            cru_values = getattr(self.stats_cru, attr)
            nn_values = getattr(self.stats_nn, attr)
            print(
                f"Avg {label} CTS: {cru_values.mean()}, "
                f"NN: {nn_values.mean()}, "
                f"diff: {(cru_values - nn_values).mean()}"
            )
        print(f"Mean Absolute Error between CRU-TS and SR-NN: {self.mae}")
        print(f"Mean Squared Error between CRU-TS and SR-NN: {self.mse}")
        print(f"Root Mean Squared Error between CRU-TS and SR-NN: {self.rmse}")

## Compute statistics

Now, let's compute some stats to compare our NN to CRU-TS.

### Data from Tomasz

In [None]:
# One timestamp per year on the "AS-MAY" (year start anchored in May)
# frequency, spanning from the 6th time step of the CRU-TS file to its last.
# NOTE(review): recent pandas releases prefer the "YS-MAY" spelling of this
# alias -- check against the installed version before changing it.
may_only = pd.date_range(
    start=ds_temp_cru["time"][5].values,
    end=ds_temp_cru["time"][-1].values,
    freq="AS-MAY",
)

# Compare the tmp variable at the 23 hand-picked locations.
results = CompareStatsResults.compute(
    var=CRUTSConfig.tmp,
    time_range=may_only,
    lats=lats,
    lons=lons,
    alts=alts,
    ds_cru=ds_temp_cru,
    ds_nn=ds_temp_nn,
)
results.print_comparison_summary()
results.line_plot()
results.box_plot()

Data generated by our NN follows CRU-TS very closely. On average, it produces slightly higher
values (hence the negative differences). RMSE=0.54

### Mountain peaks dataset
Let's see how our NN compares with CRU-TS when the sample size is increased.

In [None]:
# tmp comparison at every mountain peak; peak names label the box plot.
results = CompareStatsResults.compute(
    var=CRUTSConfig.tmp,
    time_range=may_only,
    lats=peaks["lat"].values,
    lons=peaks["lon"].values,
    alts=peaks["altitude"].values,
    ds_cru=ds_temp_cru,
    ds_nn=ds_temp_nn,
    names=peaks["mountain_peak_name"].values,
)
results.print_comparison_summary()
results.line_plot()
results.box_plot()

The situation is similar with the larger dataset. The bigger sample size shows that our NN actually produces cooler data points in comparison with CRU-TS.
This is the behaviour that we expected. Our NN is still, on average, producing slightly higher output than what we get from CRU-TS. RMSE=0.63.

In [None]:
# tmn comparison at every mountain peak, using the tmn datasets.
results = CompareStatsResults.compute(
    var=CRUTSConfig.tmn,
    time_range=may_only,
    lats=peaks["lat"].values,
    lons=peaks["lon"].values,
    alts=peaks["altitude"].values,
    ds_cru=ds_tmin_cru,
    ds_nn=ds_tmin_nn,
    names=peaks["mountain_peak_name"].values,
)
results.print_comparison_summary()
results.line_plot()
results.box_plot()

When it comes to the NN model trained only on **tmn** data, it produces data points that are also cooler than CRU-TS. However, this time we have a lot more outliers on both ends of the temperature spectrum.

In [None]:
# tmx comparison at every mountain peak, using the tmx datasets.
results = CompareStatsResults.compute(
    var=CRUTSConfig.tmx,
    time_range=may_only,
    lats=peaks["lat"].values,
    lons=peaks["lon"].values,
    alts=peaks["altitude"].values,
    ds_cru=ds_tmax_cru,
    ds_nn=ds_tmax_nn,
    names=peaks["mountain_peak_name"].values,
)
results.print_comparison_summary()
results.line_plot()
results.box_plot()

The model trained on only **tmx** data shows the biggest differences from CRU-TS. Its RMSE is the biggest of the 3 models - RMSE=0.71.
The overall trend of the model is to produce, on average, slightly cooler data points - this is the expected behaviour.

### Only two locations
Let's focus on only two locations: one at an elevation of 325m and the other at 930m. We'll go through:
1. predictions generated by model trained on combined tmn and tmx data
2. predictions generated by model trained on only tmn data
3. predictions generated by model trained on only tmx data

In [None]:
# 500x500 map centred on the last location, showing only the last two entries
# of the location lists.
f = folium.Figure(width=500, height=500)
m = folium.Map(location=[lats[-1], lons[-1]], zoom_start=9)
m.add_to(f)

# `idx` restarts at 0 here, so the two markers are labelled #00 and #01.
for idx, (lat, lon, alt) in enumerate(zip(lats[-2:], lons[-2:], alts[-2:])):
    folium.Marker(
        [lat, lon], tooltip=f"Location #{str(idx).rjust(2, '0')} - alt: {alt}m"
    ).add_to(m)

# Trailing expression: the notebook renders the figure as the cell output.
f

In [None]:
# tmp comparison restricted to the last two locations.
results = CompareStatsResults.compute(
    var=CRUTSConfig.tmp,
    time_range=may_only,
    lats=lats[-2:],
    lons=lons[-2:],
    alts=alts[-2:],
    ds_cru=ds_temp_cru,
    ds_nn=ds_temp_nn,
)
results.print_comparison_summary()
results.line_plot()
results.box_plot()

For the **tmp** data - you can see our model produces data points that are between the CRU-TS data points. The model follows the
325m data points very closely (although it tends to produce slightly cooler data points). However, for the 930m elevation it
produces data points that are slightly warmer. Our assumption was that it would do the opposite.

In [None]:
# tmn comparison restricted to the last two locations.
results = CompareStatsResults.compute(
    var=CRUTSConfig.tmn,
    time_range=may_only,
    lats=lats[-2:],
    lons=lons[-2:],
    alts=alts[-2:],
    ds_cru=ds_tmin_cru,
    ds_nn=ds_tmin_nn,
)
results.print_comparison_summary()
results.line_plot()
results.box_plot()

The RMSE for the **tmn** model (notice that we are checking on the tmn CRU-TS dataset) is much lower than for the model trained
on combined data. The 325m elevation data points are almost spot on when compared with CRU-TS.

In [None]:
# tmx comparison restricted to the last two locations.
results = CompareStatsResults.compute(
    var=CRUTSConfig.tmx,
    time_range=may_only,
    lats=lats[-2:],
    lons=lons[-2:],
    alts=alts[-2:],
    ds_cru=ds_tmax_cru,
    ds_nn=ds_tmax_nn,
)
results.print_comparison_summary()
results.line_plot()
results.box_plot()

The RMSE is only slightly lower than that of the model trained on both **tmn** and **tmx** data. The model still has a tendency to produce slightly higher data
points for the 930m elevation.