In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys

import geopandas as gpd
import matplotlib.pyplot as plt 
import numpy as np
import xarray as xa
import pandas as pd
import cartopy.crs as ccrs
from pathlib import Path
from tqdm import tqdm


from coralshift import functions_creche
from coralshift.plotting import spatial_plots
from coralshift.processing import spatial_data

In [None]:
!pwd
# absolute path to the data directory (user-specific default); set the CORALSHIFT_DATA_DIR
# environment variable to point the notebook at a different location
data_dir_p = Path(os.environ.get("CORALSHIFT_DATA_DIR", "/home/rt582/rds/hpc-work/coralshift/data/"))

# Load data

In [None]:
# latitude/longitude grid resolution of the common grid
resolution_lat = resolution_lon = 1.0
# region of interest, each given as [min, max]
lats = [-30, -10]   # latitude bounds (negative = south)
lons = [140, 160]   # longitude bounds
depths = [0, 200]   # depth bounds

## Ground Truth – UNEP-WCMC data


In [None]:
# UNEP-WCMC coral reef polygons (shapefile) -> GeoDataFrame
unep_fp = data_dir_p.joinpath("ground_truth/unep_wcmc/01_Data/WCMC008_CoralReef2021_Py_v4_1.shp")
unep_gdf = gpd.read_file(unep_fp)

# preview the first rows
unep_gdf.head()

In [None]:
# burn the reef polygons into a regular raster at the configured resolution
unep_raster = functions_creche.rasterize_geodf(
    unep_gdf,
    resolution_lat=resolution_lat,
    resolution_lon=resolution_lon,
)
# spatial limits of the GeoDataFrame (first four values returned by the helper)
unep_limits = functions_creche.lat_lon_vals_from_geo_df(unep_gdf)[:4]
# wrap the raster in a labelled xarray object
xa_unep = functions_creche.raster_to_xarray(
    unep_raster,
    x_y_limits=unep_limits,
    resolution_lat=resolution_lat,
    resolution_lon=resolution_lon,
    name="unep_coral_presence",
)

# map of the rasterised ground truth
spatial_plots.plot_spatial(xa_unep, title="Rasterised UNEP Reef Presence", orient_colorbar="horizontal")

# Environmental variables

### WOA 2018

In [None]:
# directory holding the WOA 2018 monthly (1981-2010) temperature .nc files
env_data_dir_p = data_dir_p.joinpath("env_vars/woa/woa_2018/monthly_1981-2010_temp")


In [None]:
# TODO: expand to other variables.
# Could also use open_mfdataset with a preprocess function to limit spatial range: https://docs.xarray.dev/en/stable/generated/xarray.open_mfdataset.html

# Path.glob yields files in arbitrary (filesystem-dependent) order, but the files are concatenated along
# "time" in list order below, so sort them. Assumes the file names sort chronologically — TODO confirm.
nc_files = sorted(env_data_dir_p.glob("*.nc"))
if not nc_files:
    raise FileNotFoundError(f"No .nc files found in {env_data_dir_p}")

temp_list = []
for file_p in tqdm(nc_files, desc=f"Opening .nc files in {env_data_dir_p}"):
    # decode_times=False because the time encoding could not be decoded  # TODO: understand time format
    temp_array = xa.open_dataset(file_p, decode_times=False)
    # select spatial region of interest
    temp_array = temp_array.sel(lat=slice(*lats), lon=slice(*lons), depth=slice(*depths))
    temp_list.append(temp_array)

# concat list of datasets into one dataset along "time". N.B. may not be this simple
env_xa = spatial_data.process_xa_d(xa.concat(temp_list, dim="time"))

In [None]:
from pyinterp import fill
from pyinterp.backends import xarray

# t_an at the first depth level
var_data = env_xa["t_an"].isel(depth=0)
# wrap the first time slice in a pyinterp 2D grid
first_time_slice = var_data.isel(time=0)
grid = xarray.Grid2D(first_time_slice)
# fill missing cells with pyinterp's LOESS filter (nx=5, ny=5)
filled = fill.loess(grid, nx=5, ny=5)

# NOTE(review): pyinterp grids are presumably (x, y)-ordered, so this image may appear transposed — confirm
plt.imshow(filled)

In [None]:
# gap-fill the first depth level with LOESS (nx=ny=2), then post-process with spatial_data.process_xa_d
buffered_ds = spatial_data.process_xa_d(
    functions_creche.apply_fill_loess(env_xa.isel(depth=0), nx=2, ny=2)
)

# side-by-side PlateCarree maps: original vs filled t_an at time index 10
f, a = plt.subplots(ncols=2, figsize=(14, 5), subplot_kw={"projection": ccrs.PlateCarree()})
comparison_panels = (
    (env_xa["t_an"].isel(time=10, depth=0), "Original Data"),
    (buffered_ds["t_an"].isel(time=10), "Filled Data"),
)
for panel_ax, (panel_data, panel_title) in zip(a, comparison_panels):
    spatial_plots.plot_spatial(panel_data, fax=[f, panel_ax])
    panel_ax.set_title(panel_title)


In [None]:
# parameterisation à la Couce et al. 2013

# annual average: t_an
# plot the gap-filled t_an (first depth level) at a single grid cell, index 19 along latitude and longitude
buffered_ds["t_an"].isel(latitude=19, longitude=19).plot()

In [None]:
# inspect the coordinates of the (re)processed filled dataset, for comparison with xa_unep's below
spatial_data.process_xa_d(buffered_ds).coords

In [None]:
# inspect the coordinates of the rasterised ground truth
xa_unep.coords

In [None]:
# generate comparable temporal parameterisations

# load in other relevant environmental data

# era5 irradiance

# oras5 currents

# nice-to-haves:
# cyclones
# population

### ORAS5 Timeseries

In [None]:
# TODO: fix this projection nightmare

In [None]:
oras5_var_dir_p = data_dir_p / "env_vars/oras5/test/annual_sosaline"

# create subdirectory in file directory (if doesn't already exist) with info about region selection.
# `create_subdirectory` was never defined/imported in this notebook, so create it directly with pathlib.
# Joining with `/` works whether subdir_name is a bare name or an absolute path.
subdir_name = functions_creche.create_coord_subdir_name(oras5_var_dir_p, lats, lons, depths)
(oras5_var_dir_p / subdir_name).mkdir(parents=True, exist_ok=True)

# process each file in the directory. Sorted so the iteration order (and the `file_p` that later
# cells reuse after the loop) is deterministic rather than filesystem-dependent.
for file_p in sorted(oras5_var_dir_p.glob("*.nc")):
    # load into xarray; the context manager closes the file once this iteration is done
    with xa.open_dataset(file_p) as array:
        # TODO: convert x and y to lat lon
        # TODO: select region of interest (lats, lons, depths) from open xarray
        # TODO: save as .nc file to subdirectory with original name
        # TODO: delete original file from main directory
        pass

In [None]:
# reopen the last file visited by the ORAS5 loop above (`file_p` is left over from that loop)
test_xa = xa.open_dataarray(file_p)
test_xa

In [None]:
# side-by-side plots of the nav_lat and nav_lon coordinate fields
f, ax = plt.subplots(ncols=2, figsize=(14,5))
for coord_name, coord_ax in zip(["nav_lat", "nav_lon"], ax):
    test_xa[coord_name].plot(ax=coord_ax)

In [None]:
import cartopy.crs as ccrs

def polar_axis():
    """Create a new 6x5 figure with one North Polar Stereographic cartopy GeoAxes.

    The axes (central longitude 0) shows black 50m coastlines and solid
    PlateCarree gridlines, and its extent is limited to 60-90 degrees latitude
    across all longitudes.

    Returns:
        the configured cartopy GeoAxes.
    """
    plt.figure(figsize=(6, 5))
    geo_ax = plt.axes(projection=ccrs.NorthPolarStereo(central_longitude=0))
    geo_ax.coastlines(linewidth=0.75, color='black', resolution='50m')

    plate_carree = ccrs.PlateCarree()
    geo_ax.gridlines(crs=plate_carree, linestyle='-')
    geo_ax.set_extent([-180, 180, 60, 90], crs=plate_carree)
    return geo_ax

In [None]:
# # generates <cartopy.mpl.geocollection.GeoQuadMesh at 0x150b260a6ce0> doesn't plot
# ax = polar_axis()
# test_xa.plot.pcolormesh(ax=ax, x="nav_lat", y="nav_lon", transform=ccrs.PlateCarree())

In [None]:
# rename nav_lat/nav_lon to lat/lon, presumably the coordinate names xesmf looks for — TODO confirm
test_xa_rename = test_xa.rename({"nav_lat": "lat", "nav_lon": "lon"})

In [None]:
import os
# point ESMF at its build-configuration file inside the conda env; presumably this must be set before
# `import xesmf` (next cell). NOTE(review): hardcoded, user-specific path
os.environ["ESMFMKFILE"] = "/home/rt582/rds/.conda/envs/pycoral/lib/esmf.mk"

In [None]:
import xesmf as xe

# global target grid for regridding; grid_global(1, 1) presumably means 1-degree spacing — TODO confirm argument order
ds_out = xe.util.grid_global(1, 1)
ds_out


In [None]:
# bilinear regridder from the renamed ORAS5 grid onto ds_out; ignore_degenerate=True tolerates
# degenerate cells in the source grid (presumably near the poles — TODO confirm)
regridder = xe.Regridder(test_xa_rename, ds_out, "bilinear", ignore_degenerate=True)

In [None]:
# mask (set to NaN, not drop) the y coordinate wherever nav_lat <= 60, keeping values where nav_lat > 60.
# The result is only displayed, not stored; use .where(..., drop=True) to actually remove entries.
test_xa["y"].where(test_xa["nav_lat"] > 60)

In [None]:
# draw the ORAS5 field on the north-polar axes, using nav_lon/nav_lat as plotting coordinates
polar_ax = polar_axis()
test_xa.plot.pcolormesh(
    ax=polar_ax,
    transform=ccrs.PlateCarree(),
    x='nav_lon',
    y='nav_lat',
    add_colorbar=False,
)

In [None]:
# quick look at the raw field in the last ORAS5 file visited above (`file_p` left over from that loop)
# xa.open_dataarray(file_p).isel(deptht=0).plot()
xa.open_dataarray(file_p).plot()

## Bathymetry – GEBCO

In [None]:
# GEBCO 2023 bathymetry; the filename encodes the region (n-10, s-30, w140, e160)
gebco_f_path = data_dir_p.joinpath("bathymetry/gebco/gebco_2023_n-10.0_s-30.0_w140.0_e160.0.nc")
# TODO: processing not working properly e.g. wrt crs
raw_gebco_xa = xa.open_dataarray(gebco_f_path)
gebco_xa = spatial_data.process_xa_d(raw_gebco_xa)
spatial_plots.plot_spatial(gebco_xa)

In [None]:
# calculate slopes (gradient magnitude of the GEBCO elevation field). Look in bathymetry.py
from coralshift.dataloading import bathymetry
gebco_slopes_xa = bathymetry.calculate_gradient_magnitude(gebco_xa)
spatial_plots.plot_spatial(gebco_slopes_xa, title="Seafloor gradients")

# TODO: save slopes to new file in data dir

# Pre-processing

In [None]:
# rename lat and lon to latitude and longitude (via process_xa_d) so the sources can be spatially aligned
input_dss = [
    spatial_data.process_xa_d(source_ds)
    for source_ds in (buffered_ds, gebco_xa, xa_unep)
]

In [None]:
# spatially align datasets into a single xarray dataset, using the configured region (lats, lons)
# and grid resolution (resolution_lat, resolution_lon)
common_dataset = functions_creche.spatially_combine_xa_d_list(input_dss, lats, lons, resolution_lat, resolution_lon) 

In [None]:
import cartopy.crs as ccrs
# variables to map, one PlateCarree-projected row each
vars_to_plot = ["t_gp", "elevation", "unep_coral_presence"]
f, ax = plt.subplots(nrows=len(vars_to_plot), figsize=(40,8), subplot_kw={"projection": ccrs.PlateCarree()})

for row_idx, var_name in enumerate(tqdm(vars_to_plot, desc="Plotting...")):
    var_array = common_dataset[var_name]
    # time-dependent variables: show the first time step only
    if "time" in var_array.dims:
        var_array = var_array.isel(time=0)
    spatial_plots.plot_spatial(var_array, fax=[f, ax[row_idx]])

In [None]:
# compare elevation and t_an (first time step) side by side over a small lat/lon box
zoom_box = dict(latitude=slice(-30, -28), longitude=slice(155, 157))

f, ax = plt.subplots(ncols=2, figsize=(14,5))

common_dataset["elevation"].sel(**zoom_box).plot(ax=ax[0])
ax[0].set_title("elevation")
common_dataset["t_an"].sel(**zoom_box).isel(time=0).plot(ax=ax[1])
# previously set on ax[0], overwriting the "elevation" title and leaving the t_an panel untitled
ax[1].set_title("t_an")

In [None]:
# inspect the combined, spatially aligned dataset
common_dataset

In [None]:
# predictor variables and ground-truth target to keep
predictors = ["t_an", "t_mn", "t_dd", "t_sd", "t_se", "t_oa", "t_ma", "t_gp", "elevation"]
gt = "unep_coral_presence"
kept_vars = [*predictors, gt]
# flatten the selected variables into a dataframe indexed by the dataset's dimensions
combined_df = common_dataset[kept_vars].to_dataframe()

# train-test-val split: spatial/pixel-wise, fractions 0.6 / 0.2 / 0.2
df_list = functions_creche.tvt_spatial_split(combined_df, [0.6, 0.2, 0.2])
# TODO: generate and save scaling parameters, ignoring nans. Start with min-max scaling
# TODO: one-hot encode nans
# TODO: cast to numpy array


In [None]:
# NOTE(review): "tvt_spatial_split" suggests a train/val/test return order (as unpacked for
# process_df_for_rfr below); confirm these labels match the order of df_list
order = ["train", "test", "val"]

f, ax = plt.subplots(nrows=len(order), figsize=(20,5), subplot_kw={"projection": ccrs.PlateCarree()})

for split_idx, split_df in enumerate(df_list):
    split_ds = xa.Dataset.from_dataframe(split_df)
    # collapse extra dimensions (depth first, then time) to get a single 2D map
    for extra_dim in ("depth", "time"):
        if extra_dim in split_ds.dims:
            split_ds = split_ds.isel({extra_dim: 0})
    spatial_plots.plot_spatial(split_ds["t_an"], title=f"{order[split_idx]} dataset", fax=[f, ax[split_idx]])


In [None]:
# build (X, y) arrays for the train, val and test splits; dfs_list presumably holds the per-split
# dataframes (dfs_list[0] is paired with X_train when mapping predictions back via reform_df below)
((X_train, y_train), (X_val, y_val), (X_test, y_test)), dfs_list = functions_creche.process_df_for_rfr(combined_df, predictors, gt)
# vals= process_df_for_rfr(combined_df, predictors, gt)


In [None]:
# train model iteratively (using batching)
from sklearn.ensemble import RandomForestRegressor

rfr_model = RandomForestRegressor(random_state=42, warm_start=True)

def train_rf_model_iteratively(rf_model, train_X: np.ndarray, train_y: np.ndarray, batch_size: int = 100,
                               trees_per_batch: int = 1):
    """
    Fit a warm-started random forest on consecutive batches of the training data.

    The first batch is fitted with the model's current ``n_estimators``. Before each later batch,
    ``n_estimators`` is increased by ``trees_per_batch``, so ``warm_start`` grows that many new trees on
    the new batch while keeping the previously grown trees. After training, ``n_estimators`` equals
    the number of fitted trees.

    N.B. each tree only sees the batch it was grown on, so with the defaults the initial trees are all
    trained on the first batch alone.

    Args:
        rf_model: scikit-learn forest constructed with ``warm_start=True``. Modified in place.
        train_X (np.ndarray): training inputs, batched along the first axis.
        train_y (np.ndarray): training targets, aligned with ``train_X``.
        batch_size (int, optional): samples per batch; the final batch holds the remainder. Defaults to 100.
        trees_per_batch (int, optional): new trees grown for every batch after the first. Defaults to 1.

    Returns:
        the fitted ``rf_model`` (the same object that was passed in).
    """
    n_points = len(train_X)
    for batch_idx, start in enumerate(tqdm(range(0, n_points, batch_size), desc="Batched training")):
        # grow extra trees before fitting each later batch; incrementing after the fit (as before) left
        # n_estimators one larger than the number of trees actually fitted
        if batch_idx > 0:
            rf_model.n_estimators += trees_per_batch
        # slicing past the end is safe, so the last (possibly partial) batch needs no special case
        stop = start + batch_size
        rf_model.fit(train_X[start:stop], train_y[start:stop])

    return rf_model

basic_model = train_rf_model_iteratively(rfr_model, X_train, y_train)

In [None]:
from coralshift.machine_learning import baselines


# randomised hyperparameter search (search_type="random") for the "rf_reg" model on the training split;
# the fitted search object exposes best_params_, used by the grid search below
random_model = baselines.train_tune(
    X_train, y_train, "rf_reg", resolution = 1, name="first_random", search_type="random", n_jobs=-1, verbose=False)
# rfr_grid = baselines.rf_search_grid()

# grid_search_cv = baselines.initialise_grid_search(model_type="rf_reg")

In [None]:
# grid search, passing the random search's best parameters to train_tune via best_params_dict
best_params_dict = random_model.best_params_

# NOTE(review): verbose=0 here vs verbose=False for the random search; both are falsy
grid_model = baselines.train_tune(
    X_train, y_train, "rf_reg", resolution = 1, name="first_grid", search_type="grid", n_jobs=-1, verbose=0, best_params_dict=best_params_dict)

In [None]:
from sklearn.metrics import mean_squared_error

# in-sample (training-set) predictions of the random-search model
y_pred = random_model.predict(X_train)
# combine the training dataframe with the predictions (read back below as a "prediction" column)
train_pred_df = functions_creche.reform_df(dfs_list[0], y_pred)

# mean_squared_error(y_train, y_pred)
# training-set MSE between ground truth and prediction columns
mean_squared_error(train_pred_df["unep_coral_presence"], train_pred_df["prediction"])

In [None]:
# in-sample predictions of the grid-search model, reassembled (ground truth + "prediction") into an
# xarray Dataset. Assigned to `pred_xa` because the residual-plot cell below reads that name; previously
# the result was only displayed, so that cell failed on a fresh kernel.
y_pred = grid_model.predict(X_train)
pred_xa = xa.Dataset.from_dataframe(functions_creche.reform_df(dfs_list[0], y_pred))
pred_xa

In [None]:
# spatial comparison of ground truth vs prediction at the first time step;
# requires `pred_xa` (with "unep_coral_presence" and "prediction" variables) from an earlier-executed cell
spatial_plots.plot_spatial_diffs(
    pred_xa["unep_coral_presence"].isel(time=0), 
    pred_xa["prediction"].isel(time=0), 
    # title="Prediction/Ground Truth Residual"
    )

In [None]:
# Usage examples:
# evaluate_model(random_model, dfs_list[1], X_val, y_val)
# evaluate_model(random_model, dfs_list[0], X_train, y_train)
# TODO: add a side-by-side spatial comparison (ground truth vs prediction) to the plot

def evaluate_model(model, df: pd.DataFrame, X: np.ndarray, y: np.ndarray, figsize: tuple = (4, 4),
                   gt_col: str = "unep_coral_presence"):
    """
    Evaluate model (visually and mse) on a given dataset, returning an xarray with predictions and ground truth.

    Draws a scatter of ground truth vs prediction with a y=x reference line, titled with the MSE.

    Args:
        model (sklearn model): trained model exposing ``predict``
        df (pd.DataFrame): dataframe with ground truth, aligned with ``X``
        X (np.ndarray): input data
        y (np.ndarray): ground truth (used for the scatter plot)
        figsize (tuple, optional): figure size. Defaults to (4, 4).
        gt_col (str, optional): ground-truth column in the reformed dataframe. Defaults to "unep_coral_presence".

    Returns:
        pred_xa (xa.Dataset): xarray dataset with ground truth and predictions
    """
    y_pred = model.predict(X)
    pred_df = functions_creche.reform_df(df, y_pred)
    mse = mean_squared_error(pred_df[gt_col], pred_df["prediction"])

    fig, ax = plt.subplots(figsize=figsize)
    ax.scatter(y, y_pred)
    # y=x for comparison
    ax.axline((0, 0), slope=1, c="k")
    ax.axis("equal")
    ax.set_xlabel("Ground Truth")
    ax.set_ylabel("Prediction")
    ax.set_xlim([0, 1])
    fig.suptitle(f"MSE: {mse:.04f}")

    # previously returned an undefined name; build the dataset from the predictions computed here
    return xa.Dataset.from_dataframe(pred_df)


pred_xa = evaluate_model(grid_model, dfs_list[0], X_train, y_train)


In [None]:
# total number of values in the ground-truth layer once flattened
common_dataset["unep_coral_presence"].values.flatten().shape

In [None]:
# distribution of ground-truth coral presence values across the whole combined grid
gts_vals = common_dataset["unep_coral_presence"].values.flatten()
plt.hist(gts_vals, bins=20);

In [None]:
# ground-truth column of the first split returned by tvt_spatial_split
df_list[0].unep_coral_presence

In [None]:
# histogram of the ground-truth values in the first split
df_list[0].unep_coral_presence.plot.hist(bins=100)

## PyTorch

In [None]:
from coralshift.machine_learning import transformer_utils
from coralshift.machine_learning import run_transformer

# dummy inputs/targets for the transformer experiments, plus the object returned by get_data()
pyt_train_x, pyt_train_y = transformer_utils.create_dummy_data()
dh = transformer_utils.get_data()

In [None]:
# run the transformer pipeline entry point
run_transformer.run_all()

In [None]:
# inspect the torch dataset attribute of the object returned by get_data()
dh.torch_dataset

In [None]:
# # TODO: visualise any differences in distributions between train and test/val datasets
# # save scaling parameters to json
# # N.B. this file saving might not be necessary
# from coralshift.utils import file_ops
# file_ops.save_json(scale_dict, data_dir_p / "scaling_params.json")
# # TODO: include more metadata about the exact dataset used to generate the scaling parameters
# # loading scaling data
# import json
# # TODO: smoother way to do this?
# scaling_params_p = data_dir_p / "scaling_params.json"
# f = open(scaling_params_p)
# scale_dict = json.load(f)