## Notebook for visualising experiment 2 results
### This notebook lets the user visualise the results from experiment 2


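Defining Functions for Plotting: This section imports the dependencies and defines two helpers: one that converts projected EASE2 (Lambert azimuthal equal-area) x/y coordinates to WGS84 longitude/latitude, and one that scatter-plots a column of values on a North Polar Stereographic map.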

In [1]:
import torch
import pandas as pd
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeat
from pyproj import Transformer
from global_land_mask import globe


def EASE2toWGS84(x, y, return_vals="both", lon_0=0, lat_0=90):
    """Convert projected EASE2 (Lambert azimuthal equal-area) x/y to WGS84 lon/lat.

    Parameters
    ----------
    x, y : scalar or array-like
        Projected coordinates in metres (``+units=m``) in a ``+proj=laea``
        projection centred on (``lon_0``, ``lat_0``) on the WGS84 ellipsoid.
    return_vals : {"both", "lon", "lat"}
        Which coordinate(s) to return.
    lon_0, lat_0 : float
        Projection centre in degrees; the defaults centre it on the North Pole.

    Returns
    -------
    ``(lon, lat)`` when ``return_vals == "both"``, otherwise only ``lon`` or
    only ``lat``, in degrees.

    Raises
    ------
    AssertionError
        If ``return_vals`` is not one of the valid options.
    """
    valid_return_vals = ['both', 'lon', 'lat']
    assert return_vals in valid_return_vals, f"return_val: {return_vals} is not in valid set: {valid_return_vals}"
    EASE2 = f"+proj=laea +lon_0={lon_0} +lat_0={lat_0} +x_0=0 +y_0=0 +ellps=WGS84 +towgs84=0,0,0,0,0,0,0 +units=m +no_defs"
    WGS84 = "+proj=longlat +ellps=WGS84 +datum=WGS84 +no_defs"
    # always_xy=True pins the order to (x, y) in and (lon, lat) out, instead of
    # relying on whatever axis order each CRS definition declares.
    transformer = Transformer.from_crs(EASE2, WGS84, always_xy=True)
    lon, lat = transformer.transform(x, y)
    if return_vals == "lon":
        return lon
    if return_vals == "lat":
        return lat
    return lon, lat

def plot_f_star(data, lon_col, lat_col, f_star_col, title=None, vmin=None, vmax=None, cmap='YlGnBu_r', point_size=1,
                figsize=(12, 12), cbar_label=None):
    """Scatter-plot one column of ``data`` on a North Polar Stereographic map.

    Parameters
    ----------
    data : pandas.DataFrame (or any mapping of column name -> array-like)
        Source of the longitude, latitude and value columns.
    lon_col, lat_col : str
        Columns holding longitude / latitude in degrees; the points are
        passed to matplotlib with ``transform=ccrs.PlateCarree()``.
    f_star_col : str
        Column whose values colour the points.
    title : str, optional
        Axes title; omitted when falsy.
    vmin, vmax : float, optional
        Colour-scale limits forwarded to ``ax.scatter``.
    cmap : str
        Matplotlib colormap name.
    point_size : float
        Marker size forwarded as ``s=`` to ``ax.scatter``.
    figsize : tuple of float
        Figure size in inches (default ``(12, 12)``).
    cbar_label : str, optional
        Colour-bar label; defaults to ``"<f_star_col> values"``.

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(1, 1, 1, projection=ccrs.NorthPolarStereo())
    # NOTE(review): coastlines keep cartopy's default zorder, which appears to be
    # below the ocean/land features (zorder 5/6) added next, so this white line
    # may be hidden; the land feature's black edge then draws the coast — confirm.
    ax.coastlines(resolution='50m', color='white')
    ax.add_feature(cfeat.LAND.with_scale('50m'), facecolor='lightgray', edgecolor='black', zorder=6)
    ax.add_feature(cfeat.OCEAN.with_scale('50m'), facecolor='dimgray', edgecolor='black', zorder=5)
    ax.gridlines()

    # zorder=7 draws the points above both the land and ocean features.
    scatter = ax.scatter(data[lon_col], data[lat_col], c=data[f_star_col],
                         cmap=cmap, vmin=vmin, vmax=vmax, transform=ccrs.PlateCarree(), s=point_size, zorder=7)
    cbar = fig.colorbar(scatter, ax=ax, orientation='horizontal', pad=0.05, shrink=0.8)
    if cbar_label is None:
        cbar_label = f"{f_star_col} values"
    cbar.set_label(cbar_label, fontsize=14)
    cbar.ax.tick_params(labelsize=12)

    if title:
        ax.set_title(title, fontsize=16)

    plt.tight_layout()
    plt.show()


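Loading Data: The next cell loads the predicted mean and variance tensors and the test locations. It converts the locations to longitude/latitude, attaches the predictions to them, and keeps only the points that fall over the ocean.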

In [None]:
# Load the model outputs onto the CPU so the notebook also runs on machines
# without the GPU the tensors may have been saved from.
predicted_mean = torch.load('final_predictions_exp2.pt', map_location='cpu')
predicted_var = torch.load('final_variance_exp2.pt', map_location='cpu')
# Relative path, like the .pt files above (a leading '/' would point at the
# filesystem root). NOTE(review): adjust if the CSV lives elsewhere.
test_data_path = 'test_loc.csv'
pred_df = pd.read_csv(test_data_path)

print(pred_df.shape)
print(pred_df.head())
print(pred_df.isnull().sum())

# Projected test locations (x/y) -> longitude/latitude in degrees.
pred_df['lon'], pred_df['lat'] = EASE2toWGS84(pred_df['pred_loc_x'], pred_df['pred_loc_y'])

# torch.as_tensor accepts tensors of any dtype as well as arrays/lists (the
# legacy torch.Tensor constructor rejects non-float32 tensors); detach() lets
# .numpy() work even when the tensors track gradients.
average_predictions = torch.as_tensor(predicted_mean)
predicted_var = torch.as_tensor(predicted_var)
print(average_predictions.shape)
average_predictions_np = average_predictions.detach().cpu().numpy().ravel()
predicted_var_np = predicted_var.detach().cpu().numpy().ravel()

# Exactly one prediction per test location is expected.
assert len(average_predictions_np) == len(pred_df), (
    f"{len(average_predictions_np)} predicted means for {len(pred_df)} test locations")
assert len(predicted_var_np) == len(pred_df), (
    f"{len(predicted_var_np)} predicted variances for {len(pred_df)} test locations")

pred_df['f*'] = average_predictions_np
pred_df['f*_var'] = predicted_var_np
# Keep only ocean points; globe.is_ocean takes latitude first, then longitude.
pred_df["is_in_ocean"] = globe.is_ocean(pred_df['lat'], pred_df['lon'])
pred_df = pred_df.loc[pred_df['is_in_ocean']]


Final Plotting: With all data prepared, we now plot the model predictions (predicted mean and variance).

In [None]:
# Predicted mean over ocean test locations, colour scale fixed to [-0.1, 0.3].
plot_f_star(
    pred_df,
    lon_col='lon',
    lat_col='lat',
    f_star_col='f*',
    title="f*",
    vmin=-0.1,
    vmax=0.3,
    point_size=1,
)


In [None]:
# Predicted variance over ocean test locations, colour scale fixed to [0, 0.01];
# the title distinguishes this figure from the predicted-mean plot.
plot_f_star(pred_df, lon_col='lon', lat_col='lat', f_star_col='f*_var',
            title="f* variance", vmin=0, vmax=0.01, point_size=1)
