In [6]:
%load_ext autoreload
%autoreload 2
# autoreload re-imports edited project modules (e.g. coralshift.*) so changes are picked up live

import sys
import os

# make the parent of the kernel's starting directory importable
# NOTE(review): a relative sys.path entry is resolved against the *current* working directory at
# import time, so after the chdir below it refers to the repository's parent — confirm it is needed.
sys.path.append("../")

# choose whether to work on a remote machine
location = "remote"
if location == "remote":
    # change this line to where the GitHub repository is located; the kernel's working
    # directory then becomes the repository checkout
    os.chdir("/lustre_scratch/orlando-code/coralshift/")

In [7]:
from __future__ import annotations

from pathlib import Path
import xarray as xa
import numpy as np
import json
import math as m
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
from tqdm import tqdm
from sklearn import model_selection

# from sklearn import datasets, ensemble
# from sklearn.inspection import permutation_importance
# from sklearn.metrics import mean_squared_error, log_loss
# from sklearn.model_selection import train_test_split
# import sklearn.metrics as sklmetrics
# from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier, GradientBoostingRegressor
from sklearn.linear_model import LogisticRegression

from sklearn.model_selection import train_test_split, RandomizedSearchCV


import cartopy.crs as ccrs
import cartopy.feature as cfeature
import rasterio
from rasterio.plot import show
from rasterio.enums import Resampling
import rioxarray as rio
import pickle

from coralshift.utils import directories, file_ops, utils
from coralshift.processing import spatial_data, baselines
from coralshift.plotting import spatial_plots, model_results

ERROR 1: PROJ: proj_create_from_database: Open of /home/jovyan/lustre_scratch/conda-envs/coralshift/share/proj failed


## Data Derivation

In [None]:
# load and process variables for reproducing Couce et al. (2013)
# NOTE(review): load_and_process_reproducing_xa_das is defined in the following cell, so on a
# fresh kernel (Restart & Run All) this line raises NameError — run the definition cell first.

reproduction_xa_list = load_and_process_reproducing_xa_das()

In [None]:
def load_and_process_reproducing_xa_das() -> list[xa.DataArray | xa.Dataset]:
    """
    Load the input variables used to reproduce the Couce et al. (2013) metrics.

    Arrays are opened from the processed / bathymetry / ERA5 data directories. The current
    magnitude is derived from its zonal (uo) and meridional (vo) components, and the ERA5
    surface net solar radiation is averaged into daily means.

    Returns
    -------
        list[xa.DataArray | xa.Dataset]: In order: daily potential temperature, daily salinity,
        daily current magnitude (renamed "current"), daily-mean surface net solar radiation,
        bathymetry, and the 1000 m ground-truth coral raster (renamed "gt").
        NOTE: the bathymetry entry is an ``xa.Dataset`` (opened with ``open_dataset``), unlike
        the other entries, which are DataArrays.
    """
    processed_dir = directories.get_processed_dir()

    # daily sea water potential temperature
    thetao_daily = xa.open_dataarray(processed_dir / "arrays/thetao.nc")

    # daily sea water salinity means
    salinity_daily = xa.open_dataarray(processed_dir / "arrays/so.nc")

    # daily zonal and meridional currents, combined into a single current magnitude
    uo_daily = xa.open_dataarray(processed_dir / "arrays/uo.nc")
    vo_daily = xa.open_dataarray(processed_dir / "arrays/vo.nc")
    current_daily = baselines.calculate_magnitude(uo_daily, vo_daily).rename("current")

    # NOTE(review): an earlier version selected the "bathymetry_A" variable and wrote an
    # EPSG:4326 CRS; the whole Dataset is returned here. Confirm downstream code expects that.
    bathymetry = xa.open_dataset(directories.get_bathymetry_datasets_dir() / "bathymetry_A_0-00030d.nc")

    # ERA5 surface net solar radiation, averaged over time into daily means
    # (only temporal aggregation happens here — no spatial resampling)
    solar_radiation = xa.open_dataarray(
        directories.get_era5_data_dir()
        / "weather_parameters/VAR_surface_net_solar_radiation_LATS_-10_-17_LONS_142_147_YEAR_1993-2020.nc"
    )
    solar_radiation_daily = solar_radiation.resample(time="1D").mean(dim="time")

    # ground-truth coral raster at 1000 m
    gt_1000m = xa.open_dataarray(processed_dir / "arrays/coral_raster_1000m.nc").rename("gt")

    return [thetao_daily, salinity_daily, current_daily, solar_radiation_daily, bathymetry, gt_1000m]




# def plot_class_balance(xa_ds):








def plot_train_test_spatial(xa_da: xa.DataArray, figsize: tuple[float,float] = (7,7), bath_mask: xa.DataArray = None):
    """
    Map the last time step of a train/test split array with a two-category colorbar.

    Parameters
    ----------
    xa_da (xa.DataArray): Array with a "time" dimension whose values encode the split; the
        colorbar labels the lower half of [0, 1] "train" and the upper half "test".
    figsize (tuple[float, float], optional): Figure size in inches. Defaults to (7, 7).
    bath_mask (xa.DataArray, optional): Mask applied with ``xa_da.where``: values where the
        mask is False become NaN. Skipped when None or when the mask has no True values.

    Returns
    -------
    xa.DataArray: ``xa_da`` after the optional masking (all time steps, not only the plotted one).
    """
    fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(projection=ccrs.PlateCarree()))

    cmap = spatial_plots.get_cbar()

    # the default bath_mask is None, which has no .any(); check for it explicitly
    if bath_mask is not None and bath_mask.any():
        xa_da = xa_da.where(bath_mask, np.nan)

    # TODO: fix cmap — the mesh uses the cmap's continuous mapping while the colorbar below uses a
    # BoundaryNorm: https://matplotlib.org/stable/api/_as_gen/matplotlib.colors.BoundaryNorm.html
    xa_da.isel(time=-1).plot.pcolormesh(ax=ax, cmap=cmap, add_colorbar=False)
    ax.set_aspect("equal")
    ax.set_title("Geographical visualisation of train-test split")
    ax.add_feature(
        cfeature.NaturalEarthFeature(
            "physical", "land", "10m", edgecolor="black", facecolor="#cccccc"
        )
    )
    ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=True)

    # categorical colorbar: two colour bins split at 0.5, each labelled at its midpoint
    norm = mpl.colors.BoundaryNorm([0, 0.5, 1], 2)
    tick_labels = {0.25: "train", 0.75: "test"}

    def formatter_func(x, pos):
        """Return the category label for tick value ``x`` (``pos`` is unused)."""
        # .get avoids a KeyError should matplotlib query a value other than the fixed ticks
        return str(tick_labels.get(x, ""))

    fig.colorbar(
        ax=ax,
        mappable=mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
        ticks=list(tick_labels),
        format=plt.FuncFormatter(formatter_func),
        spacing="proportional",
        pad=0.1,
        fraction=0.046,
    )
    return xa_da

    

def generate_reproducing_metrics_at_different_resolutions(resolutions: list[float], units: list[str]) -> None:
    """
    Run ``generate_reproducing_metrics`` on the pre-resampled arrays for each target resolution.

    For every resolution, all ``.nc`` files in the matching comparison directory
    (e.g. ``0-03691d_arrays``) are opened, tagged with an EPSG:4326 CRS, and collected into a
    dict keyed by the filename prefix before the first underscore.

    Parameters
    ----------
    resolutions (list[float]): Resolution values, paired element-wise with ``units``.
    units (list[str]): Unit of each resolution (e.g. "d" or "m"); each (value, unit) pair is
        passed to ``spatial_data.choose_resolution`` to obtain the target resolution.

    Returns
    -------
    None. The value returned by ``generate_reproducing_metrics`` is not collected.
        NOTE(review): presumably that function persists its own output — confirm.

    Raises
    ------
    KeyError: if a file contains no data variable other than "spatial_ref".
    """
    target_resolutions = [
        spatial_data.choose_resolution(number, string)[1] for number, string in zip(resolutions, units)
    ]
    for res in tqdm(target_resolutions, desc="Processing metrics at various resolutions"):
        # e.g. 0.03691 -> "0-03691d_arrays"
        files_dir = directories.get_comparison_dir() / utils.replace_dot_with_dash(f"{res:.05f}d_arrays")
        xa_ds_dict = {}

        for fl in file_ops.return_list_filepaths(files_dir, ".nc"):
            name = fl.stem.split("_")[0]
            ds = xa.open_dataset(fl).rio.write_crs("epsg:4326")
            variable_name = next((var_name for var_name in ds.data_vars if var_name != "spatial_ref"), None)
            if variable_name is None:
                # previously surfaced as an opaque ds[None] lookup failure
                raise KeyError(f"No data variable other than 'spatial_ref' found in {fl}")
            xa_ds_dict[name] = ds[variable_name]

        generate_reproducing_metrics(xa_ds_dict, res)
        











In [None]:
# target resolutions, paired element-wise with their units ("d" for the first four, "m" for the last)
resolutions = [1, 0.5, 0.25, 1 / 12, 4000]
units = ["d"] * 4 + ["m"]


In [None]:
# Reload the input variables and resample them to every (resolution, unit) pair above.
# NOTE(review): resample_list_xa_ds_to_target_res_list_and_save is not defined in this part of the
# notebook — presumably imported/defined elsewhere; the name suggests results are written to disk.
xa_list = load_and_process_reproducing_xa_das()
resample_list_xa_ds_to_target_res_list_and_save(xa_list, resolutions, units)

In [None]:
# generate_reproducing_metrics_at_different_resolutions(resolutions, units)


In [None]:
# NOTE(review): despite the name `twelve` (suggesting 1/12 degree), this loads the 0.03691 degree
# (~4000 m) comparison set; 1/12 degree corresponds to 0.08333 in target_resolutions.
twelve = xa.open_dataset(directories.get_comparison_dir() / "0-03691d_arrays/all_0-03691d_comparative.nc")
twelve

In [None]:
# quick look at the "thetao_annual_range" variable of the comparison set loaded above
twelve["thetao_annual_range"].plot()

In [None]:
# def generate_varied_data_resolutions(resolutions: list[float], xa_list: list[xa.DataArray]):
#     for res in tqdm(resolutions, desc="Generating various resolution xa.Datasets"):
#         save_path = (directories.get_comparison_dir() / f"all_{res:.05f}d").with_suffix(".nc")
#         if not save_path.exists():
#             resampled_xa_das_dict = resample_list_xa_ds_to_target_resolution_and_merge(xa_list, target_resolution=res, unit="m")
#             all = xa.merge(list(resampled_xa_das_dict.values()))
#             all.to_netcdf(save_path)


In [None]:
# NOTE(review): debugging leftover. `resampled_xa_das_dict` is only assigned in commented-out code,
# so this raises NameError on a fresh kernel, and type(list(...)) is always `list` regardless of
# the dict's contents. Consider removing this cell.
type(list(resampled_xa_das_dict.values()))

In [None]:
# Reload the input variables; the 4 km resampling / metric generation below is currently disabled.
xa_list = load_and_process_reproducing_xa_das()
# resampled_xa_das_dict = resample_list_xa_ds_to_target_resolution_and_merge(xa_list, target_resolution=4000, unit="m")
# all_4km = generate_reproducing_metrics(resampled_xa_das_dict)

In [14]:
# def resample_train_predict(data_):

import time

import joblib
# from ray.util.joblib import register_ray


    



In [None]:
# Load a previously saved model. The context manager closes the file handle once loading finishes.
# pickle can execute arbitrary code while loading, so only open files produced by this project.
with open(directories.get_datasets_dir() / "model_params/all_0-03691d_comparative_10_runs_0.pickle", "rb") as model_file:
    loaded_model = pickle.load(model_file)

In [12]:
# target resolutions, paired element-wise with their units ("d" for the first four, "m" for the last)
resolutions = [1, 0.5, 0.25, 1 / 12, 4000]
units = ["d"] * 4 + ["m"]

# choose_resolution returns a pair; element [1] is the resolution in degrees (4000 m -> ~0.0369)
target_resolutions = [
    spatial_data.choose_resolution(value, unit)[1]
    for value, unit in zip(resolutions, units)
]
target_resolutions

[1, 0.5, 0.25, 0.08333333333333333, 0.03691157541051416]

In [15]:
# Tune and evaluate each model type on the 0.03691 degree (~4 km, i.e. target_resolutions[-1]) data.
# NOTE(review): train_test_across_models is not defined in this part of the notebook — confirm its source.
# NOTE(review): the logged output reports 21/30 "brt" fits and 24/30 "maxent" fits failing. The sampled
# parameters include values such as loss="lad", criterion="mse" and solver="newton-cholesky" with
# penalty="l1", which the installed scikit-learn appears to reject during parameter validation; the
# search spaces (defined elsewhere) likely need updating.
models = ["rf_reg", "brt", "maxent", "rf_cla"]

all_outputs = train_test_across_models(models, d_resolution=0.03691)
# run_outputs = train_test_visualise_roc_across_resolutions(d_resolutions=target_resolutions,model_type="rf_cla")

Fitting each model via random search:   0%|          | 0/4 [00:00<?, ?it/s]

Fitting 3 folds for each of 10 candidates, totalling 30 fits


[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Using 

[CV] END bootstrap=True, max_depth=60, min_samples_leaf=1, min_samples_split=5, n_estimators=1400; total time=36.6min
[CV] END bootstrap=False, max_depth=30, min_samples_leaf=2, min_samples_split=2, n_estimators=1000; total time=38.0min


[Parallel(n_jobs=1)]: Done 600 out of 600 | elapsed: 34.0min finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 600 out of 600 | elapsed:    3.7s finished


[CV] END bootstrap=False, max_depth=10, min_samples_leaf=1, min_samples_split=5, n_estimators=1200; total time=41.3min
[CV] END bootstrap=False, max_depth=10, min_samples_leaf=2, min_samples_split=10, n_estimators=200; total time= 9.6min
[CV] END bootstrap=False, max_depth=50, min_samples_leaf=2, min_samples_split=5, n_estimators=600; total time=31.9min
[CV] END bootstrap=False, max_depth=10, min_samples_leaf=2, min_samples_split=10, n_estimators=200; total time=10.1min
[CV] END bootstrap=False, max_depth=50, min_samples_leaf=2, min_samples_split=5, n_estimators=600; total time=34.1min


[Parallel(n_jobs=1)]: Done 1000 out of 1000 | elapsed: 52.2min finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 1000 out of 1000 | elapsed:    3.4s finished
[Parallel(n_jobs=1)]: Done 1400 out of 1400 | elapsed: 52.8min finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 1400 out of 1400 | elapsed:    5.0s finished
[Parallel(n_jobs=1)]: Done 1400 out of 1400 | elapsed: 55.3min finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 1400 out of 1400 | elapsed:    8.5s finished
[Parallel(n_jobs=1)]: Done 1000 out of 1000 | elapsed: 56.0min finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 1200 out of 1200 | elapsed: 56.1min finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)

[CV] END bootstrap=False, max_depth=30, min_samples_leaf=2, min_samples_split=2, n_estimators=1000; total time=52.3min
[CV] END bootstrap=True, max_depth=60, min_samples_leaf=1, min_samples_split=5, n_estimators=1400; total time=52.9min


[Parallel(n_jobs=1)]: Done 1200 out of 1200 | elapsed: 58.7min finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 1200 out of 1200 | elapsed:    5.0s finished


[CV] END bootstrap=True, max_depth=60, min_samples_leaf=1, min_samples_split=5, n_estimators=1400; total time=55.4min
[CV] END bootstrap=False, max_depth=30, min_samples_leaf=2, min_samples_split=2, n_estimators=1000; total time=56.1min
[CV] END bootstrap=False, max_depth=10, min_samples_leaf=1, min_samples_split=5, n_estimators=1200; total time=56.2min
[CV] END bootstrap=False, max_depth=10, min_samples_leaf=1, min_samples_split=5, n_estimators=1200; total time=58.8min


[Parallel(n_jobs=1)]: Done 1800 out of 1800 | elapsed: 65.9min finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 1800 out of 1800 | elapsed:    6.6s finished
[Parallel(n_jobs=1)]: Done 1800 out of 1800 | elapsed: 65.2min finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 1800 out of 1800 | elapsed:    6.6s finished


[CV] END bootstrap=False, max_depth=70, min_samples_leaf=1, min_samples_split=5, n_estimators=1800; total time=66.1min


[Parallel(n_jobs=1)]: Done 1200 out of 1200 | elapsed: 59.6min finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 1200 out of 1200 | elapsed:    3.7s finished


[CV] END bootstrap=True, max_depth=100, min_samples_leaf=2, min_samples_split=10, n_estimators=200; total time= 5.3min
[CV] END bootstrap=False, max_depth=20, min_samples_leaf=4, min_samples_split=2, n_estimators=1800; total time=65.3min


[Parallel(n_jobs=1)]: Done 1200 out of 1200 | elapsed: 40.8min finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 1200 out of 1200 | elapsed:    4.1s finished
[Parallel(n_jobs=1)]: Done 1200 out of 1200 | elapsed: 62.9min finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 1200 out of 1200 | elapsed:    6.6s finished


[CV] END bootstrap=False, max_depth=10, min_samples_leaf=2, min_samples_split=10, n_estimators=200; total time= 7.0min
[CV] END bootstrap=True, max_depth=40, min_samples_leaf=4, min_samples_split=2, n_estimators=200; total time= 7.7min
[CV] END bootstrap=False, max_depth=None, min_samples_leaf=4, min_samples_split=2, n_estimators=1200; total time=59.6min
[CV] END bootstrap=True, max_depth=100, min_samples_leaf=2, min_samples_split=10, n_estimators=200; total time= 8.2min
[CV] END bootstrap=True, max_depth=40, min_samples_leaf=4, min_samples_split=2, n_estimators=200; total time= 5.3min
[CV] END bootstrap=False, max_depth=50, min_samples_leaf=2, min_samples_split=5, n_estimators=600; total time=22.8min
[CV] END bootstrap=False, max_depth=None, min_samples_leaf=4, min_samples_split=2, n_estimators=1200; total time=40.9min
[CV] END bootstrap=True, max_depth=100, min_samples_leaf=2, min_samples_split=10, n_estimators=200; total time= 7.7min
[CV] END bootstrap=True, max_depth=40, min_sample

[Parallel(n_jobs=1)]: Done 1800 out of 1800 | elapsed: 87.0min finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 1800 out of 1800 | elapsed:    5.1s finished
[Parallel(n_jobs=1)]: Done 1800 out of 1800 | elapsed: 88.2min finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 1800 out of 1800 | elapsed:    5.2s finished


[CV] END bootstrap=False, max_depth=70, min_samples_leaf=1, min_samples_split=5, n_estimators=1800; total time=87.1min


[Parallel(n_jobs=1)]: Done 1800 out of 1800 | elapsed: 92.5min finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 1800 out of 1800 | elapsed:    9.0s finished


[CV] END bootstrap=False, max_depth=20, min_samples_leaf=4, min_samples_split=2, n_estimators=1800; total time=88.3min


[Parallel(n_jobs=1)]: Done 1800 out of 1800 | elapsed: 93.9min finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 1800 out of 1800 | elapsed:    9.0s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.


[CV] END bootstrap=False, max_depth=70, min_samples_leaf=1, min_samples_split=5, n_estimators=1800; total time=92.7min
[CV] END bootstrap=False, max_depth=20, min_samples_leaf=4, min_samples_split=2, n_estimators=1800; total time=94.1min


[Parallel(n_jobs=1)]: Done 200 out of 200 | elapsed:  9.6min finished


Saved model to /lustre_scratch/orlando-code/datasets/model_params/best_models/rf_reg_0-03691d_tuned.pickle.
rf_reg_0-03691d_tuned metadata saved to /lustre_scratch/orlando-code/datasets/model_params/best_models/rf_reg_0-03691d_tuned_metadata.json


[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 200 out of 200 | elapsed:    1.0s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 200 out of 200 | elapsed:    1.0s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 200 out of 200 | elapsed:    1.0s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 200 out of 200 | elapsed:    1.1s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 200 out of 200 | elapsed:    1.1s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 200 out of 200 | elapsed:    1.1s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_j

Fitting 3 folds for each of 10 candidates, totalling 30 fits




[CV] END criterion=friedman_mse, learning_rate=0.01, loss=lad, max_depth=1, max_features=sqrt, min_samples_leaf=4, min_samples_split=10, n_estimators=1577, subsample=0.9; total time=   0.1s
[CV] END criterion=friedman_mse, learning_rate=0.0021544346900318843, loss=lad, max_depth=7, max_features=auto, min_samples_leaf=2, min_samples_split=5, n_estimators=1366, subsample=0.9; total time=   0.1s




[CV] END criterion=friedman_mse, learning_rate=0.46415888336127775, loss=lad, max_depth=10, max_features=sqrt, min_samples_leaf=4, min_samples_split=2, n_estimators=1155, subsample=0.6; total time=   0.1s
[CV] END criterion=mse, learning_rate=0.21544346900318823, loss=quantile, max_depth=5, max_features=sqrt, min_samples_leaf=2, min_samples_split=10, n_estimators=311, subsample=0.5; total time=   0.1s
[CV] END criterion=friedman_mse, learning_rate=1.0, loss=ls, max_depth=None, max_features=sqrt, min_samples_leaf=2, min_samples_split=5, n_estimators=1788, subsample=0.1; total time=   0.1s
[CV] END criterion=friedman_mse, learning_rate=0.0021544346900318843, loss=lad, max_depth=7, max_features=auto, min_samples_leaf=2, min_samples_split=5, n_estimators=1366, subsample=0.9; total time=   0.1s
[CV] END criterion=mse, learning_rate=0.01, loss=quantile, max_depth=9, max_features=auto, min_samples_leaf=1, min_samples_split=2, n_estimators=1366, subsample=0.30000000000000004; total time=   0.1



[CV] END criterion=mse, learning_rate=0.21544346900318823, loss=quantile, max_depth=5, max_features=sqrt, min_samples_leaf=2, min_samples_split=10, n_estimators=311, subsample=0.5; total time=   0.1s
[CV] END criterion=mse, learning_rate=0.004641588833612777, loss=huber, max_depth=6, max_features=auto, min_samples_leaf=4, min_samples_split=10, n_estimators=2000, subsample=0.1; total time=   0.1s
[CV] END criterion=friedman_mse, learning_rate=0.01, loss=lad, max_depth=1, max_features=sqrt, min_samples_leaf=4, min_samples_split=10, n_estimators=1577, subsample=0.9; total time=   0.1s




[CV] END criterion=friedman_mse, learning_rate=1.0, loss=ls, max_depth=None, max_features=sqrt, min_samples_leaf=2, min_samples_split=5, n_estimators=1788, subsample=0.1; total time=   0.1s
[CV] END criterion=mse, learning_rate=0.004641588833612777, loss=huber, max_depth=6, max_features=auto, min_samples_leaf=4, min_samples_split=10, n_estimators=2000, subsample=0.1; total time=   0.1s




      Iter       Train Loss      OOB Improve   Remaining Time 
         1           0.0553           0.0002            4.29m
         2           0.0554           0.0001            4.65m
         3           0.0552           0.0002            4.88m
         4           0.0549           0.0001            4.72m
         5           0.0548           0.0001            4.54m
         6           0.0548           0.0001            4.32m
         7           0.0548           0.0001            4.47m
         8           0.0547           0.0001            4.48m
         9           0.0542           0.0002            4.50m
        10           0.0542           0.0001            4.43m
        20           0.0528           0.0001            4.23m
        30           0.0514           0.0001            4.07m
        40           0.0503           0.0001            3.99m
        50           0.0490           0.0001            3.85m
        60           0.0478           0.0001            3.83m
       



      Iter       Train Loss      OOB Improve   Remaining Time 
         1           0.0824           0.0002            4.71m
         2           0.0824           0.0002            4.81m
         3           0.0818           0.0002            4.97m
         4           0.0819           0.0002            4.73m
         5           0.0817           0.0002            4.53m
         6           0.0813           0.0002            4.35m
         7           0.0816           0.0002            4.55m
         8           0.0813           0.0002            4.57m
         9           0.0813           0.0002            4.68m
        10           0.0808           0.0002            4.62m
        20           0.0788           0.0002            4.46m
        30           0.0773           0.0002            4.32m
        40           0.0755           0.0002            4.26m
        50           0.0740           0.0002            4.11m
        60           0.0724           0.0001            4.08m
       



      Iter       Train Loss      OOB Improve   Remaining Time 
         1           0.0925           0.0002            5.39m
         2           0.0923           0.0002            5.27m
         3           0.0917           0.0002            5.52m
         4           0.0918           0.0002            5.26m
         5           0.0915           0.0002            4.93m
         6           0.0911           0.0002            4.69m
         7           0.0911           0.0002            4.82m
         8           0.0907           0.0002            4.83m
         9           0.0908           0.0002            4.87m
        10           0.0903           0.0002            4.80m
        20           0.0881           0.0002            4.66m
        30           0.0862           0.0002            4.52m
        40           0.0840           0.0002            4.41m
        50           0.0820           0.0002            4.26m
        60           0.0802           0.0001            4.25m
       



      Iter       Train Loss      OOB Improve   Remaining Time 
         1           0.0551           0.0002            9.70m
         2           0.0555           0.0002           10.36m
         3           0.0551           0.0002           10.23m
         4           0.0546           0.0002            9.97m
         5           0.0541           0.0002            9.85m
         6           0.0537           0.0002            9.82m
         7           0.0540           0.0002            9.98m
         8           0.0544           0.0002            9.93m
         9           0.0540           0.0002            9.86m
        10           0.0540           0.0002            9.80m
        20           0.0516           0.0002           10.05m
        30           0.0499           0.0002           10.06m
        40           0.0486           0.0002           10.11m
        50           0.0466           0.0002           10.02m
        60           0.0449           0.0002            9.91m
       



      Iter       Train Loss      OOB Improve   Remaining Time 
         1           0.0829           0.0003           13.09m
         2           0.0821           0.0003           12.27m
         3           0.0812           0.0003           12.18m
         4           0.0813           0.0003           12.14m
         5           0.0810           0.0003           12.17m
         6           0.0800           0.0003           12.17m
         7           0.0808           0.0003           12.20m
         8           0.0813           0.0003           12.16m
         9           0.0813           0.0003           12.12m
        10           0.0791           0.0003           12.07m
        20           0.0769           0.0003           12.06m
        30           0.0749           0.0003           11.94m
        40           0.0714           0.0003           11.76m
        50           0.0692           0.0002           11.57m
        60           0.0679           0.0002           11.37m
       



      Iter       Train Loss      OOB Improve   Remaining Time 
         1           0.0929           0.0003           12.97m
         2           0.0928           0.0003           12.74m
         3           0.0917           0.0003           12.63m
         4           0.0912           0.0003           12.57m
         5           0.0911           0.0003           12.54m
         6           0.0907           0.0003           12.77m
         7           0.0906           0.0003           12.89m
         8           0.0900           0.0003           12.99m
         9           0.0906           0.0003           12.89m
        10           0.0884           0.0003           12.78m
        20           0.0867           0.0003           12.58m
        30           0.0840           0.0003           12.71m
        40           0.0799           0.0003           12.78m
        50           0.0771           0.0003           12.71m
        60           0.0757           0.0003           12.61m
       



[CV] END criterion=friedman_mse, learning_rate=0.01, loss=lad, max_depth=1, max_features=sqrt, min_samples_leaf=4, min_samples_split=10, n_estimators=1577, subsample=0.9; total time=   0.1s
      Iter       Train Loss      OOB Improve   Remaining Time 
         1           0.0554           0.0002           23.71m
         2           0.0552           0.0002           24.80m
         3           0.0548           0.0002           24.90m
         4           0.0545           0.0002           24.81m
         5           0.0543           0.0002           24.79m
         6           0.0542           0.0002           24.80m
         7           0.0540           0.0002           25.09m
         8           0.0540           0.0002           25.06m
         9           0.0533           0.0002           24.79m
        10           0.0534           0.0002           24.58m
        20           0.0509           0.0002           24.26m
        30           0.0488           0.0002           24.38m
   

21 fits failed out of a total of 30.
The score on these train-test partitions for these parameters will be set to nan.
If these failures are not expected, you can try to debug them by setting error_score='raise'.

Below are more details about the failures:
--------------------------------------------------------------------------------
1 fits failed with the following error:
Traceback (most recent call last):
  File "/home/jovyan/lustre_scratch/conda-envs/coralshift/lib/python3.10/site-packages/sklearn/model_selection/_validation.py", line 686, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "/home/jovyan/lustre_scratch/conda-envs/coralshift/lib/python3.10/site-packages/sklearn/ensemble/_gb.py", line 420, in fit
    self._validate_params()
  File "/home/jovyan/lustre_scratch/conda-envs/coralshift/lib/python3.10/site-packages/sklearn/base.py", line 600, in _validate_params
    validate_parameter_constraints(
  File "/home/jovyan/lustre_scratch/conda-envs/coral

      Iter       Train Loss      OOB Improve   Remaining Time 
         1           0.0767           0.0002            4.28m
         2           0.0765           0.0002            4.10m
         3           0.0764           0.0002            4.48m
         4           0.0761           0.0002            4.48m
         5           0.0763           0.0001            4.51m
         6           0.0758           0.0002            4.63m
         7           0.0757           0.0002            4.56m
         8           0.0755           0.0002            4.76m
         9           0.0753           0.0002            4.86m
        10           0.0753           0.0002            4.85m
        20           0.0736           0.0002            4.67m
        30           0.0718           0.0002            4.73m
        40           0.0703           0.0002            4.75m
        50           0.0688           0.0002            4.73m
        60           0.0671           0.0002            4.68m
       

 Running inference on 10 randomly-initialised test splits with 0.25 test fraction: 100%|██████████| 10/10 [00:13<00:00,  1.38s/it]
Fitting each model via random search:  50%|█████     | 2/4 [2:30:12<2:20:00, 4200.03s/it]

Fitting 3 folds for each of 10 candidates, totalling 30 fits


[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    3.0s finished
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    3.9s finished
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    6.2s finished
24 fits failed out of a total of 30.
The score on these train-test partitions for these parameters will be set to nan.
If these failures are not expected, you can try to debug them by setting error_score='raise'.

Below are more details about the failures:
--------------------------------------------------------------------------------
3 fits failed with the following error:
Traceback (most recent call last):
  File "/home/jovyan/lustre_scratch/conda-envs/coralshift/lib/python3.10/site-packages/sklearn/model_selection/_validation.py", line 6

Saved model to /lustre_scratch/orlando-code/datasets/model_params/best_models/maxent_0-03691d_tuned.pickle.
maxent_0-03691d_tuned metadata saved to /lustre_scratch/orlando-code/datasets/model_params/best_models/maxent_0-03691d_tuned_metadata.json


 Running inference on 10 randomly-initialised test splits with 0.25 test fraction: 100%|██████████| 10/10 [00:06<00:00,  1.50it/s]
Fitting each model via random search:  75%|███████▌  | 3/4 [2:30:29<38:09, 2289.96s/it]  

Fitting 3 folds for each of 10 candidates, totalling 30 fits


[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Using 

[CV] END C=0.1, dual=True, fit_intercept=True, intercept_scaling=1.0, max_iter=200, multi_class=auto, penalty=none, solver=newton-cholesky, tol=0.01, verbose=1, warm_start=True; total time=   0.1s
[CV] END C=0.1, dual=False, fit_intercept=True, intercept_scaling=5.0, max_iter=100, multi_class=multinomial, penalty=l1, solver=newton-cholesky, tol=0.0001, verbose=1, warm_start=False; total time=   0.1s
[CV] END C=10.0, dual=False, fit_intercept=False, intercept_scaling=2.0, max_iter=200, multi_class=multinomial, penalty=l1, solver=newton-cholesky, tol=0.001, verbose=1, warm_start=False; total time=   0.1s
[CV] END bootstrap=False, max_depth=10, min_samples_leaf=2, min_samples_split=10, n_estimators=200; total time= 2.3min
[CV] END bootstrap=False, max_depth=50, min_samples_leaf=2, min_samples_split=5, n_estimators=600; total time= 8.0min
[CV] END C=0.1, dual=True, fit_intercept=True, intercept_scaling=1.0, max_iter=200, multi_class=auto, penalty=none, solver=newton-cholesky, tol=0.01, ver

[Parallel(n_jobs=1)]: Done 1200 out of 1200 | elapsed: 14.0min finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 1200 out of 1200 | elapsed:   10.6s finished


[CV] END C=10.0, dual=True, fit_intercept=False, intercept_scaling=2.0, max_iter=100, multi_class=multinomial, penalty=l1, solver=saga, tol=0.0001, verbose=2, warm_start=True; total time=   0.1s
[CV] END C=1.0, dual=False, fit_intercept=False, intercept_scaling=1.0, max_iter=100, multi_class=ovr, penalty=l1, solver=newton-cholesky, tol=0.01, verbose=1, warm_start=True; total time=   0.1s
[CV] END bootstrap=False, max_depth=30, min_samples_leaf=2, min_samples_split=2, n_estimators=1000; total time=13.4min
[CV] END C=1.0, dual=False, fit_intercept=False, intercept_scaling=2.0, max_iter=500, multi_class=auto, penalty=l2, solver=sag, tol=0.01, verbose=0, warm_start=False; total time=   1.6s
[CV] END bootstrap=True, max_depth=60, min_samples_leaf=1, min_samples_split=5, n_estimators=1400; total time=13.5min
[CV] END C=1.0, dual=False, fit_intercept=True, intercept_scaling=2.0, max_iter=100, multi_class=ovr, penalty=l1, solver=newton-cholesky, tol=0.001, verbose=0, warm_start=True; total tim

[Parallel(n_jobs=1)]: Done 1200 out of 1200 | elapsed: 15.0min finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 1200 out of 1200 | elapsed:   10.6s finished


[CV] END C=1.0, dual=False, fit_intercept=True, intercept_scaling=2.0, max_iter=100, multi_class=ovr, penalty=l1, solver=newton-cholesky, tol=0.001, verbose=0, warm_start=True; total time=   0.1s
[CV] END bootstrap=False, max_depth=10, min_samples_leaf=1, min_samples_split=5, n_estimators=1200; total time=14.4min


[Parallel(n_jobs=1)]: Done 1800 out of 1800 | elapsed: 19.7min finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 1800 out of 1800 | elapsed:   13.5s finished
[Parallel(n_jobs=1)]: Done 1800 out of 1800 | elapsed: 20.7min finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 1800 out of 1800 | elapsed:   15.0s finished
[Parallel(n_jobs=1)]: Done 1800 out of 1800 | elapsed: 21.0min finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 1800 out of 1800 | elapsed: 19.4min finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 1800 out of 1800 | elapsed:   12.1s finished
[Parallel(n_jobs=1)]: Done 1800 out of 1800 | elapsed:   14.9s finished
[Parallel(n_jobs=1)]: Done 1800 out of 1800 | elapsed: 22.0min finished
[Parallel(n_jobs=1)]: Using 

[CV] END C=0.1, dual=True, fit_intercept=False, intercept_scaling=1.0, max_iter=100, multi_class=multinomial, penalty=l1, solver=sag, tol=0.0001, verbose=0, warm_start=False; total time=   0.1s
[CV] END bootstrap=True, max_depth=100, min_samples_leaf=2, min_samples_split=10, n_estimators=200; total time= 2.0min
[CV] END bootstrap=True, max_depth=40, min_samples_leaf=4, min_samples_split=2, n_estimators=200; total time= 1.8min
[CV] END bootstrap=False, max_depth=None, min_samples_leaf=4, min_samples_split=2, n_estimators=1200; total time=14.2min
[CV] END C=10.0, dual=True, fit_intercept=False, intercept_scaling=2.0, max_iter=100, multi_class=multinomial, penalty=l1, solver=saga, tol=0.0001, verbose=2, warm_start=True; total time=   0.1s
[CV] END C=0.1, dual=False, fit_intercept=True, intercept_scaling=5.0, max_iter=100, multi_class=multinomial, penalty=l1, solver=newton-cholesky, tol=0.0001, verbose=1, warm_start=False; total time=   0.1s
[CV] END C=10.0, dual=False, fit_intercept=False

[Parallel(n_jobs=1)]: Done 1200 out of 1200 | elapsed: 17.4min finished


Saved model to /lustre_scratch/orlando-code/datasets/model_params/best_models/rf_cla_0-03691d_tuned.pickle.
rf_cla_0-03691d_tuned metadata saved to /lustre_scratch/orlando-code/datasets/model_params/best_models/rf_cla_0-03691d_tuned_metadata.json


[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 1200 out of 1200 | elapsed:    9.6s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 1200 out of 1200 | elapsed:    9.6s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 1200 out of 1200 | elapsed:    9.5s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 1200 out of 1200 | elapsed:    9.6s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 1200 out of 1200 | elapsed:    9.6s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 1200 out of 1200 | elapsed:    9.5s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[

In [None]:
# NOTE(review): `bathymetry_flat` is only assigned by the bathymetry/gt histogram cell below, so
# this cell relies on out-of-order execution and raises NameError on a fresh kernel.
bathymetry_flat

In [None]:
# Joint distribution of the ground truth ("gt") against bathymetry ("bathymetry_A") in all_data.
all_data = get_comparison_xa_ds(d_resolution=0.03691)

gt_values, bathymetry_values = all_data["gt"], all_data["bathymetry_A"]

# element-aligned 1D copies of both grids
gt_flat = gt_values.values.flatten()
bathymetry_flat = bathymetry_values.values.flatten()

# keep only cells where both variables are finite (drops NaN and +/-inf)
valid_indices = np.isfinite(gt_flat) & np.isfinite(bathymetry_flat)
gt_valid, bathymetry_valid = gt_flat[valid_indices], bathymetry_flat[valid_indices]

# inclusive bathymetry window kept for plotting
threshold_min, threshold_max = -100, 100
mask = (bathymetry_valid >= threshold_min) & (bathymetry_valid <= threshold_max)
gt_filtered, bathymetry_filtered = gt_valid[mask], bathymetry_valid[mask]

# 2D histogram (50 bins per axis) of gt against bathymetry
plt.hist2d(bathymetry_filtered, gt_filtered, bins=50, cmap='viridis')
plt.xlabel('Bathymetry')
plt.ylabel('gt')
plt.colorbar(label='Count')
plt.title('Histogram of gt vs Bathymetry')

In [None]:
# map of bathymetry with the colour scale clipped to [-40, 0]
all_data["bathymetry_A"].plot(vmin=-40, vmax=0)

In [None]:
# map of the ground-truth variable
all_data["gt"].plot()

In [None]:
# Bar chart of how many "gt" cells are zero vs non-zero.
# NaN cells are excluded from both counts: `NaN != 0` evaluates True, so a plain
# `(gt_variable != 0).sum()` would count every missing cell as "non-zero".
gt_variable = all_data['gt']

gt_present = gt_variable.notnull()
non_zero_count = int(((gt_variable != 0) & gt_present).sum())
zero_count = int((gt_variable == 0).sum())

# Create a bar chart
labels = ['Non-Zero', 'Zero']
values = [non_zero_count, zero_count]

plt.bar(labels, values)
plt.xlabel('Value')
plt.ylabel('Count')
plt.title('Number of Non-Zero and Zero Values in "gt"')

In [None]:
# Histogram (10 bins) of the non-NaN "gt" values; `gt_variable` comes from the bar-chart cell above.
gt_values = np.asarray(gt_variable.values).reshape(-1)
gt_values = gt_values[~np.isnan(gt_values)]

plt.hist(gt_values, bins=10)
plt.xlabel('Value')
plt.ylabel('Frequency')
plt.title('Histogram of "gt" Values')

In [None]:
def visualise_class_balance(all_data: xa.Dataset, target_variable: str = "gt"):
    """
    Plot a bar chart of the number of zero and non-zero values of the target variable.

    NaN cells are excluded from both counts. (The original cell was a bare `def` line with no
    body, which is a SyntaxError.)

    Parameters
    ----------
        all_data (xa.Dataset): Dataset containing `target_variable`.
        target_variable (str, optional): Name of the label variable. Default is "gt".

    Returns
    -------
        dict[str, int]: Counts keyed by "Non-Zero" and "Zero".
    """
    target = all_data[target_variable]
    present = target.notnull()
    counts = {
        # `NaN != 0` is True, so restrict the non-zero count to present cells
        "Non-Zero": int(((target != 0) & present).sum()),
        "Zero": int((target == 0).sum()),
    }

    _, ax = plt.subplots()
    ax.bar(list(counts), list(counts.values()))
    ax.set_xlabel("Value")
    ax.set_ylabel("Count")
    ax.set_title(f'Number of Non-Zero and Zero Values in "{target_variable}"')
    return counts

In [None]:
# NOTE(review): `run_outputs` is not defined in this section; it must come from another cell
np.shape(run_outputs['1.00000'])

In [None]:
# keys look like resolution strings (e.g. '1.00000', '0.08333') — TODO confirm
run_outputs.keys()

In [None]:
# ROC curves for the '0.08333' entry (NOTE(review): rocs_n_runs is not defined in this section)
rocs_n_runs(run_outputs['0.08333'])

In [None]:
# scratch: ~12, i.e. 0.083 is roughly the 1/12-degree climate grid spacing
1/0.083

In [None]:
### SST (sea water potential temperature)
# calc_timeseries_params(da, freq, name) is unpacked below as (mean, stdev, (min, max)) for the
# given frequency ("y" yearly, "m" monthly, "w" weekly) — presumably; see baselines for details
# load in daily sea water potential temp
thetao_daily = xa.open_dataarray(directories.get_processed_dir() / "arrays/thetao.nc")

# annual average, stdev of annual averages, annual minimum, annual maximum
thetao_annual_average, _, (thetao_annual_min, thetao_annual_max) = baselines.calc_timeseries_params(thetao_daily, "y", "thetao")
# monthly average, stdev of monthly averages, monthly minimum, monthly maximum
thetao_monthly_average, thetao_monthly_stdev, (thetao_monthly_min, thetao_monthly_max) = baselines.calc_timeseries_params(
    thetao_daily, "m", "thetao")
# annual range = annual maximum - annual minimum
# NOTE(review): if the intended definition is warmest month minus coldest month, this should use
# the monthly extremes instead — confirm against the reference method
thetao_annual_range = (thetao_annual_max - thetao_annual_min).rename("thetao_annual_range")
# weekly minimum, weekly maximum
_, _, (thetao_weekly_min, thetao_weekly_max) = baselines.calc_timeseries_params(thetao_daily, "w", "thetao")


In [None]:
### Salinity
# daily mean sea water salinity, read from the processed arrays directory
salinity_daily = xa.open_dataarray(directories.get_processed_dir() / "arrays/so.nc")

# yearly statistics: only the mean (first element) is kept
salinity_annual_average = baselines.calc_timeseries_params(salinity_daily, "y", "salinity")[0]
# monthly statistics: only the (min, max) pair (third element) is kept
salinity_monthly_min, salinity_monthly_max = baselines.calc_timeseries_params(salinity_daily, "m", "salinity")[2]

In [None]:
### Current speed (magnitude of the current vector built from its two horizontal components)
# daily currents: longitudinal (uo) and latitudinal (vo) components
uo_daily = xa.open_dataarray(directories.get_processed_dir() / "arrays/uo.nc")
vo_daily = xa.open_dataarray(directories.get_processed_dir() / "arrays/vo.nc")
# combine components into a single speed field (presumably sqrt(uo**2 + vo**2); see baselines)
current_daily = baselines.calculate_magnitude(uo_daily, vo_daily)

# yearly statistics: only the mean (first element) is kept
current_annual_average = baselines.calc_timeseries_params(current_daily, "y", "current")[0]
# monthly statistics: only the (min, max) pair (third element) is kept
current_monthly_min, current_monthly_max = baselines.calc_timeseries_params(current_daily, "m", "current")[2]

In [None]:
### Light penetration proxy
# Load in bathymetry and scale to climate variable resolution
# (file name suggests a ~0.0003-degree source grid — TODO confirm; CRS is set to EPSG:4326)
# NOTE(review): the climate grid is much coarser than the source, so "upsample" here presumably
# aggregates to the coarser grid — confirm what upsample_xa_d_to_other does
bathymetry = xa.open_dataset(
    directories.get_bathymetry_datasets_dir() / "bathymetry_A_0-00030d.nc").rio.write_crs("EPSG:4326")["bathymetry_A"]
bathymetry_climate_res = spatial_data.upsample_xa_d_to_other(bathymetry, thetao_annual_average, name = "bathymetry")



In [None]:
# Load in ERA5 surface net solar radiation and upscale to climate variable resolution
# NOTE(review): the path hard-codes the region (lats -10..-17, lons 142..147) and years 1993-2020
solar_radiation = xa.open_dataarray(
    directories.get_era5_data_dir() / "weather_parameters/VAR_surface_net_solar_radiation_LATS_-10_-17_LONS_142_147_YEAR_1993-2020.nc"
    ).rio.write_crs("EPSG:4326")
    
# average solar_radiation daily
solar_radiation_daily = solar_radiation.resample(time="1D").mean(dim="time")
# regrid onto the same grid as thetao_annual_average
solar_climate_res = spatial_data.upsample_xa_d_to_other(solar_radiation_daily, thetao_annual_average, name = "solar_radiation")

# annual average
solar_annual_average, _, _ = baselines.calc_timeseries_params(solar_climate_res, "y", "net_solar")
# monthly min, monthly max
_, _, (solar_monthly_min, solar_monthly_max) = baselines.calc_timeseries_params(solar_climate_res, "m", "net_solar")


In [None]:
### Load in ground truth data
gt_1000m = xa.open_dataarray(directories.get_processed_dir() / "arrays/coral_raster_1000m.nc")
# upsample to same resolution as climate (1/12 of a degree)
# NOTE(review): 1/12 degree (~9 km) is coarser than 1000 m, so this presumably aggregates
gt_climate_res = spatial_data.upsample_xa_d_to_other(gt_1000m, thetao_annual_average, name = "gt")

In [None]:
### Display different resolutions
# side by side: ground truth at 1000 m (left) and on the 1/12-degree climate grid (right)
fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(16,9), subplot_kw=dict(projection=ccrs.PlateCarree()))

ax1 = spatial_plots.plot_spatial(gt_1000m, fax= (fig,ax_left), title="Ground truth coral presence map at 1000m")
ax2 = spatial_plots.plot_spatial(
    gt_climate_res, fax=(fig, ax_right), val_lims = (0,1), title="Ground truth coral presence map at 1/12 degree")

## Baseline Machine Learning Models

In [None]:
# combine data
# all predictor variables plus the ground truth ("gt") in a single dataset
all_data = xa.merge([thetao_annual_average, thetao_annual_range, thetao_monthly_min, thetao_monthly_max, 
    thetao_monthly_stdev, thetao_weekly_min, thetao_weekly_max, 
    salinity_annual_average, salinity_monthly_min, salinity_monthly_max,
    current_annual_average, current_monthly_min, current_monthly_max,
    bathymetry_climate_res, solar_annual_average, solar_monthly_min, solar_monthly_max,
    gt_climate_res
    ])
    
all_data

In [None]:
# flatten the merged dataset into a table (one row per coordinate combination)
test_df = all_data.to_dataframe()
# test_df = test_df.dropna(subset=subset, how="any",axis=0)
# NOTE(review): every missing value becomes 0, indistinguishable from a genuine zero
test_df = test_df.fillna(0)

test_df.describe()

In [None]:
# drop unnecessary columns: coordinate bookkeeping and the target "gt", leaving the predictors
# (get_dummies one-hot encodes any non-numeric columns — presumably none here, TODO confirm)
features_df = pd.get_dummies(test_df).drop(["spatial_ref","band","depth","gt"], axis=1)
features_df

In [None]:
# regression target: the "gt" column, row-aligned with features_df
target_df = test_df["gt"]
target_df

In [None]:
# X_train, X_test, y_train, y_test, train_coordinates,test_coordinates, xa_masked = spatial_split_train_test(
#     all_data, "gt", split_type="pixel", test_fraction = 0.2)

# test_train_da = visualise_train_test_split(all_data, train_coordinates, test_coordinates)
# NOTE(review): `test_train_da` is only assigned in the commented-out line above, so this cell
# depends on kernel state. The `plot_train_test_spatial` defined later in this notebook takes
# (dataset, figsize) and has no `bath_mask` parameter, so a different version is presumably
# intended here — confirm before re-running.
bath_mask = generate_var_mask(all_data)
test_xa = plot_train_test_spatial(test_train_da, bath_mask=bath_mask)


In [None]:
# predictions of `rf_random` over 10 runs (presumably 10 random train/test splits), then ROC curves
# NOTE(review): n_random_runs_preds and rocs_n_runs are not defined in this section
run_outcomes = n_random_runs_preds(rf_random, 10, all_data)

rocs_n_runs(run_outcomes)

In [None]:
def investigate_resolution_predictions(xa_ds: xa.DataArray):
    """
    Placeholder for comparing model predictions across spatial resolutions (not implemented).

    Parameters
    ----------
        xa_ds (xa.DataArray): Data to investigate.

    Raises
    ------
        NotImplementedError: always, until the analysis is written.
    """
    # A `def` line with no body is a SyntaxError that stops "Run All"; an explicit stub keeps
    # the notebook executable.
    raise NotImplementedError("investigate_resolution_predictions is not implemented yet")

In [None]:
# mean grid spacing of all_data (units as returned by calculate_spatial_resolution — presumably degrees)
out = np.mean(spatial_data.calculate_spatial_resolution(all_data))

In [None]:
# cells per unit of the resolution (e.g. ~12 for a 1/12-degree grid if `out` is in degrees)
1/out

In [None]:
def investigate_label_thresholds(
    thresholds: list[float],
    y_test: np.ndarray | pd.Series,
    y_predictions: np.ndarray | pd.Series,
    figsize=(7, 7),
):
    """
    Plot ROC curves with one line per label threshold.

    For each threshold, labels and predictions are binarised with
    `model_results.threshold_label` and an ROC curve (with its AUC) is drawn.

    Parameters
    ----------
        thresholds (list[float]): List of label thresholds.
        y_test (np.ndarray or pd.Series): True labels.
        y_predictions (np.ndarray or pd.Series): Predicted labels.
        figsize (tuple | list, optional): Figure size for the plot. Default is (7, 7).

    Returns
    -------
        None
    """
    # `sklearn.metrics` is commented out in the notebook's import cell, so without this
    # local import `sklmetrics` is an undefined name.
    from sklearn import metrics as sklmetrics

    # TODO: fix UndefinedMetricWarning: No positive samples in y_true, true positive value should be meaningless

    f, ax = plt.subplots(figsize=figsize)
    # one colour per threshold, sampled evenly from the sequential colour map
    color_map = spatial_plots.get_cbar("seq")
    num_colors = len(thresholds)
    colors = [color_map(i / num_colors) for i in range(num_colors)]

    # plot ROC curves
    for thresh, color in zip(thresholds, colors):
        binary_y_labels, binary_predictions = model_results.threshold_label(
            y_test, y_predictions, thresh
        )
        fpr, tpr, _ = sklmetrics.roc_curve(
            binary_y_labels, binary_predictions, drop_intermediate=False
        )
        roc_auc = sklmetrics.auc(fpr, tpr)

        # two decimals so closely spaced thresholds (e.g. 0.11 vs 0.15) get distinct labels
        label = f"{thresh:.02f} | {roc_auc:.02f}"
        ax.plot(fpr, tpr, label=label, color=color)

    # format
    title = "Receiver Operating Characteristic (ROC) Curve\nfor several coral presence/absence thresholds"
    # NOTE(review): `format_roc` is not defined in this section — confirm where it comes from
    format_roc(ax=ax, title=title)
    ax.legend(title="threshold value | auc")



In [None]:
# test-set predictions of the tuned forest
# NOTE(review): `rf_random` is assigned in a later cell, so this relies on execution order
rf_random_preds = rf_random.predict(X_test)

In [None]:
# ROC curves for 10 evenly spaced label thresholds in [0, 1]
investigate_label_thresholds(np.linspace(0,1,10), y_test, rf_random_preds)

In [None]:
# RandomForestRegressor is otherwise only imported further down (CART section); import it here
# so the notebook runs top to bottom.
from sklearn.ensemble import RandomForestRegressor

RANDOM_STATE = 42
# baseline random forest with default hyperparameters, seeded for reproducibility
rf_reg = RandomForestRegressor(random_state=RANDOM_STATE)

rf_reg.fit(X_train, y_train)

In [None]:
# inspect the baseline forest's hyperparameters
rf_reg.get_params()

In [None]:
# hand-picked random-forest hyperparameters (presumably from an earlier random search — TODO confirm)
best_rf_params = dict(
    n_estimators=400,
    min_samples_split=2,
    min_samples_leaf=4,
    max_features="sqrt",
    max_depth=10,
    bootstrap=True,
)

In [None]:
from sklearn.base import clone

# `set_params` mutates the estimator in place and returns it, so `rf_reg.set_params(...)` made
# `rf_random` and `rf_reg` the SAME already-fitted model, and the new hyperparameters were never
# used for fitting. Clone, configure and refit so the two models are genuinely distinct.
rf_random = clone(rf_reg).set_params(**best_rf_params).fit(X_train, y_train)

In [None]:
# Save the tuned random-forest hyperparameters as JSON in the current working directory.
# Only fitted search objects (e.g. RandomizedSearchCV) have `best_params_`; when `rf_random` is a
# plain estimator configured from `best_rf_params`, save those instead of raising AttributeError.
rf_random_search_best_params = getattr(rf_random, "best_params_", best_rf_params)

with open("rf_random_search_best_params.json", "w") as fp:
    json.dump(rf_random_search_best_params, fp)

In [None]:
# inspect the held-out feature table (NOTE(review): X_test is not created in this section)
X_test

In [None]:
# Binary cross-entropy of the tuned forest's test-set predictions.
# `predictions` and `log_loss` were both commented out / not imported here (NameError).
# NOTE(review): y_test appears to hold fractional values in [0, 1] rather than class labels, and
# sklearn's `log_loss` rejects continuous targets, so the soft-label BCE is computed directly,
# with predictions clipped away from 0 and 1 to keep the logs finite.
predictions = rf_random.predict(X_test)
_bce_p = np.clip(predictions, 1e-15, 1 - 1e-15)
_bce_y = np.asarray(y_test, dtype=float)
bce = float(-np.mean(_bce_y * np.log(_bce_p) + (1 - _bce_y) * np.log(1 - _bce_p)))

In [None]:
from sklearn.metrics import mean_squared_error, roc_auc_score
from sklearn import metrics

# Test-set predictions of the tuned forest. Computed here because the cell that used to define
# `predictions` has that line commented out.
predictions = rf_random.predict(X_test)

mean_squared_error(y_test, predictions)

In [None]:
# predictions = rf_reg.predict(X_test)
# predictions
# positive-label count: values > 0 are replaced by 1 and non-positive values are kept, so this
# equals the number of positive labels only when y_test has no negative values
sum(y_test.where(y_test <= 0,1))

In [None]:
# Only search objects (e.g. RandomizedSearchCV) have `best_params_`; a plain RandomForestRegressor
# does not, so fall back to the hand-specified parameters instead of raising AttributeError.
getattr(rf_random, "best_params_", best_rf_params)

In [None]:
# Number of test labels at or above each threshold 0.0, 0.1, ..., 1.0.
# np.linspace avoids the float drift of repeatedly adding 0.1 (0.30000000000000004, ...).
# `(y_test >= val).sum()` counts the samples directly. The previous
# `sum(y_test.where(y_test >= val, 1))` summed the label VALUES above the threshold plus a
# count of the samples below it.
for val in np.linspace(0, 1, 11):
    print(f"{val:.1f}", int((y_test >= val).sum()))

In [None]:
# Number of test labels strictly above 0.1.
# Uses its own name rather than `out`, so the spatial-resolution value stored in `out` by an
# earlier cell is not overwritten.
y_test_values = np.asarray(y_test)
np.count_nonzero(y_test_values > 0.1)

In [None]:

    

# test-set predictions of the baseline forest and the "tuned" forest
# NOTE(review): if `rf_random` was produced by `rf_reg.set_params(...)`, both names refer to the
# same estimator object and these two arrays are identical
rf_reg_preds = rf_reg.predict(X_test)
rf_random_preds = rf_random.predict(X_test)

In [None]:
# Number of test labels above 0.1. The previous `np.where(sum(np.array(y_test)) > 0.1, 1, 0)`
# summed all labels first and thresholded the total, returning a single 0/1 flag.
np.count_nonzero(np.asarray(y_test) > 0.1)

In [None]:
# TODO: function to do N model runs and plot resultant ROC
# TODO: test whether training/optimising on binary helps


# NOTE: this calls `model_results.investigate_label_thresholds`, not the same-named function
# defined earlier in this notebook; 100 thresholds evenly spaced in [0, 1]
model_results.investigate_label_thresholds(np.linspace(0,1,100), y_test, rf_reg_preds)


In [None]:
# ROC curve for one set of binarised labels/predictions.
# NOTE(review): `y_labels_bin` and `predictions_bin` are not defined in this section. They are
# presumably the outputs of `model_results.threshold_label(y_test, predictions, thresh)` for a
# chosen threshold — TODO define them here so the cell runs on a fresh kernel.
fpr, tpr, thresholds = metrics.roc_curve(y_labels_bin, predictions_bin, drop_intermediate=False)
roc_auc = metrics.auc(fpr, tpr)

# named `roc_display` so IPython's built-in `display()` function is not shadowed
roc_display = metrics.RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc,
                            estimator_name='example estimator')
roc_display.plot()

### Maximum Entropy Model

In [None]:
from sklearn import preprocessing

# convert y values to categorical (integer-encoded) values
# NOTE(review): target_df is the "gt" column, which appears to be continuous; LabelEncoder treats
# every distinct float as its own class — confirm this is intended
lab = preprocessing.LabelEncoder()
y_transformed = lab.fit_transform(target_df)

In [None]:
# maxent = LogisticRegression(random_state=0)
# maxent.fit(features_df, y_transformed)

In [None]:
# pred_maxent = maxent.predict(features_df)

# predicted_maxent_data = xa.DataArray(pred_maxent.reshape((85,61,28)),
#     coords=all_data.coords, 
#     dims=all_data.dims)

# f,a = plt.subplots(1,2,figsize=[14,7])
# all_data["gt"].plot(ax=a[0])
# predicted_maxent_data.isel(time=-1).plot(ax=a[1]
#     , vmin=all_data["gt"].values.min(), vmax=all_data["gt"].values.max()
#     )

### Classification and Regression Trees (CART)

In [None]:
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor

# `clf` is fitted on every row; `clf_lim` only on the first `num_vals` rows. Both then predict
# over the full feature table, to compare the effect of using less training data.
# Seeded with RANDOM_STATE (set in the random-forest cell) so reruns are reproducible.
clf = RandomForestRegressor(random_state=RANDOM_STATE)
clf_lim = RandomForestRegressor(random_state=RANDOM_STATE)

num_vals = 50000
clf.fit(features_df, target_df)
# .iloc makes the positional row slice explicit
clf_lim.fit(features_df.iloc[:num_vals], target_df.iloc[:num_vals])
pred_lim = clf_lim.predict(features_df)

In [None]:
# Predict over the full feature table and reshape back onto the all_data grid (first time step).
# The grid shape is derived from all_data instead of hard-coding (85, 61, 28). The hard-coded
# shape had to equal these sizes anyway for the DataArray to be built with dims=all_data.dims.
pred = clf.predict(features_df)
grid_shape = tuple(all_data.sizes[dim] for dim in all_data.dims)
predicted_data = xa.DataArray(pred.reshape(grid_shape),
    coords=all_data.coords,
    dims=all_data.dims).isel(time=0)

predicted_lim_data = xa.DataArray(pred_lim.reshape(grid_shape),
    coords=all_data.coords,
    dims=all_data.dims).isel(time=0)

In [None]:
### Compare predicted and ground truth values
# forest fitted on all rows vs ground truth on the climate grid
spatial_plots.plot_spatial_diffs(predicted_data, gt_climate_res, figsize=(14,13))


In [None]:
# forest fitted on the first `num_vals` rows only vs ground truth
spatial_plots.plot_spatial_diffs(predicted_lim_data, gt_climate_res, figsize=(14,13))

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle




def train_test_random_forest(xa_ds: xa.Dataset, target_variable: str = "gt", test_fraction=0.5, random_state=None):
    # TODO: tidy up and document
    # Extract latitude, longitude, and time values from the spatial image
    lats = xa_ds.latitude.values
    lons = xa_ds.longitude.values
    times = xa_ds.time.values


    if len(xa_ds.dims) > 2:
        # Flatten the spatial image into 2D arrays for training and testing
        flattened_data = xa_ds.stack(points=("latitude", "longitude", "time")).compute().to_dataframe().fillna(0).drop(
            ["time","spatial_ref","band","depth"], axis=1).astype("float32")
    else:
        flattened_data = xa_ds.stack(points=("latitude", "longitude")).compute().to_dataframe().fillna(0).drop(
            ["time","spatial_ref","band","depth"], axis=1).astype("float32")

    features = flattened_data.drop("gt", axis=1)
    # flattened_data = np.transpose(flattened_data, axes=(1, 0))
    labels = flattened_data["gt"]

    # # Split the data into training and testing datasets
    # X_train, X_test, y_train, y_test = train_test_split(
    #     features, labels, test_size=test_fraction, random_state=random_state
    # )

    # Train the random forest regressor
    regressor = RandomForestRegressor()
    regressor.fit(X_train, y_train)

    # Predict the target variable for the testing dataset
    y_pred = regressor.predict(X_test)

    train_indices = X_train.index.values
    test_indices = X_test.index.values

    lat_spacing = xa_ds.latitude.values[1] - xa_ds.latitude.values[0]
    lon_spacing = xa_ds.longitude.values[1] - xa_ds.longitude.values[0]

    # TODO: refer to generic data_var dimension rather than calling by variable
    train_pixs = np.empty(xa_ds["thetao_y_mean"].values.shape)
    train_pixs[:] = np.nan
    test_pixs = np.empty(xa_ds["thetao_y_mean"].values.shape)
    test_pixs[:] = np.nan
    # Color the spatial pixels corresponding to training and testing regions
    for train_index in tqdm(train_indices, desc="Coloring in training indices..."):
        row, col = find_index_pair(xa_ds, train_index[0], train_index[1], lat_spacing, lon_spacing)
        # ax.add_patch(Rectangle((lons[col], lats[row]), 1, 1, facecolor=train_color, alpha=0.2))
        train_pixs[row,col] = 0

    for test_index in tqdm(test_indices, desc="Coloring in training indices..."):
        row, col = find_index_pair(xa_ds, test_index[0], test_index[1], lat_spacing, lon_spacing)
        # ax.add_patch(Rectangle((lons[col], lats[row]), 1, 1, facecolor=test_color, alpha=0.2))
        test_pixs[row,col] = 1

    return regressor, y_pred, ds


# `target_variable` is the NAME of the label variable, not the DataArray itself
reg, random_pred, ds = train_test_random_forest(all_data.isel(time=0), target_variable="gt", random_state=42)

In [None]:
# NOTE(review): `random_pred` holds predictions for the held-out test points only (the function
# returns `regressor.predict(X_test)`), computed on a single time step, so reshaping it to the full
# (85, 61, 28) grid will fail — predict on the full feature table before reshaping.
random_pred_data = xa.DataArray(random_pred.reshape((85,61,28)),
    coords=all_data.coords,
    dims=all_data.dims)


spatial_plots.plot_spatial_diffs(random_pred_data, gt_climate_res, figsize=(14,13))



In [None]:
def plot_train_test_spatial(dataset, figsize: tuple[float,float] = (7,7)):
    """
    Overlay the training and testing pixel masks of a dataset on a PlateCarree map.

    Parameters
    ----------
        dataset (xarray.Dataset): Must contain "train_pixs" and "test_pixs" variables.
        figsize (tuple[float, float], optional): Figure size. Default is (7, 7).

    Returns
    -------
        None
    """
    projection = ccrs.PlateCarree()
    _, ax = plt.subplots(figsize=figsize, subplot_kw={"projection": projection})

    # both masks share the [0, 1] scale and have no colorbar; the testing mask is drawn last (on top)
    for var_name in ("train_pixs", "test_pixs"):
        dataset[var_name].plot(ax=ax, vmin=0, vmax=1, levels=2, label=var_name, add_colorbar=False)

    ax.set_aspect("equal")
    ax.coastlines(resolution="10m", color="red", linewidth=3)
    ax.gridlines(crs=projection, draw_labels=True)




# map the train/test pixels in `ds` from train_test_random_forest
plot_train_test_spatial(ds)




In [None]:


def plot_categorical_data(dataset, variable_name,  color1='red', color2='blue'):
    """
    Plot a binary (0/1) spatial variable from an xarray dataset in two colours.

    Parameters
    ----------
        dataset (xarray.Dataset): The dataset containing the categorical data; the variable is
            expected to be 2D with "latitude" and "longitude" coordinates.
        variable_name (str): The name of the variable in the dataset to plot.
        color1 (str, optional): Colour for cells equal to 0. Default is 'red'.
        color2 (str, optional): Colour for cells equal to 1. Default is 'blue'.

    Returns
    -------
        None
    """
    # local imports: only this plot needs them
    from matplotlib.colors import ListedColormap
    from matplotlib.patches import Patch

    variable = dataset[variable_name]
    lon_mesh, lat_mesh = np.meshgrid(variable.longitude.values, variable.latitude.values)
    values = variable.values

    ax = plt.gca()
    # Draw each category as a masked array with its own single-colour colormap. Passing
    # `facecolor` to a colour-mapped pcolormesh has no effect, and two full boolean grids with the
    # same 'Greys' map meant the second call completely covered the first.
    for category, color in ((0, color1), (1, color2)):
        in_category = values == category
        ax.pcolormesh(
            lon_mesh,
            lat_mesh,
            np.ma.masked_where(~in_category, in_category.astype(float)),
            cmap=ListedColormap([color]),
            shading="auto",
        )

    ax.set_xlabel('Longitude')
    ax.set_ylabel('Latitude')
    ax.set_title(variable_name)
    # pcolormesh artists carry no labels, so the legend needs explicit handles
    ax.legend(handles=[Patch(color=color1, label="0"), Patch(color=color2, label="1")])
    # Show the plot
    plt.show()

# NOTE(review): requires `ds` to contain a "test_train" variable
plot_categorical_data(ds,"test_train")


In [None]:
# inspect the train/test split dataset
ds

In [None]:
# inspect the merged feature/label dataset
all_data

In [None]:


# Side-by-side first-time-step maps: forest fitted on all rows (left) vs first `num_vals` rows
# (right), on a shared colour scale. Both are rebuilt from the raw predictions: `predicted_data`
# was already reduced with `.isel(time=0)` in an earlier cell, so calling `.isel(time=0)` on it
# again raised a ValueError. New names also keep `predicted_lim_data` from being overwritten.
grid_shape = tuple(all_data.sizes[dim] for dim in all_data.dims)
predicted_t0 = xa.DataArray(pred.reshape(grid_shape), coords=all_data.coords, dims=all_data.dims).isel(time=0)
predicted_lim_t0 = xa.DataArray(pred_lim.reshape(grid_shape), coords=all_data.coords, dims=all_data.dims).isel(time=0)

f,a = plt.subplots(1,2,figsize=[14,7])
predicted_t0.plot(ax=a[0])
predicted_lim_t0.plot(ax=a[1], vmin=float(predicted_t0.min()), vmax=float(predicted_t0.max()))

In [None]:
# TODO: training and testing on subsamples of data (train_test_split for linear, somehow something spatial...)
# add in bathymetry
# try binary (classifier)
# function to plot difference between predicted and true
# hyperparameter tuning

### Boosted Regression Trees (BRT)

https://scikit-learn.org/stable/auto_examples/ensemble/plot_gradient_boosting_regression.html

In [None]:
# seeded 90/10 train/test split of the full feature table for the boosted trees
X_train_t, X_test_t, y_train_t, y_test_t = train_test_split(
    features_df, target_df, test_size=0.1, random_state=13
)

# GradientBoostingRegressor hyperparameters
params = dict(
    n_estimators=500,
    max_depth=4,
    min_samples_split=5,
    learning_rate=0.01,
    loss="squared_error",
)

In [None]:
# inspect the boosted-tree training features
X_train_t

In [None]:
# `sklearn.ensemble` is commented out in the notebook's import cell, and `mean_squared_error` is
# only imported inside another cell; import both here so this cell runs on its own.
from sklearn import ensemble
from sklearn.metrics import mean_squared_error

# NOTE(review): this trains and evaluates on X_train/X_test (random-forest section), not on the
# X_train_t/X_test_t split created in the cell above — confirm which split is intended
reg = ensemble.GradientBoostingRegressor(**params)
reg.fit(X_train, y_train)

mse = mean_squared_error(y_test, reg.predict(X_test))
print("The mean squared error (MSE) on test set: {:.4f}".format(mse))

In [None]:
# Test-set MSE after each boosting stage, plotted against the model's training deviance.
test_score = np.zeros((params["n_estimators"],), dtype=np.float64)
for i, y_pred in enumerate(reg.staged_predict(X_test)):
    test_score[i] = mean_squared_error(y_test, y_pred)

boosting_iterations = np.arange(params["n_estimators"]) + 1

fig = plt.figure(figsize=(6, 6))
deviance_ax = fig.add_subplot(1, 1, 1)
deviance_ax.set_title("Deviance")
deviance_ax.plot(boosting_iterations, reg.train_score_, "b-", label="Training Set Deviance")
deviance_ax.plot(boosting_iterations, test_score, "r-", label="Test Set Deviance")
deviance_ax.legend(loc="upper right")
deviance_ax.set_xlabel("Boosting Iterations")
deviance_ax.set_ylabel("Deviance")
fig.tight_layout()
plt.show()

In [None]:
# `permutation_importance` is commented out in the notebook's import cell (NameError otherwise)
from sklearn.inspection import permutation_importance

# left: impurity-based (MDI) importances from the fitted boosted trees
feature_importance = reg.feature_importances_
sorted_idx = np.argsort(feature_importance)
pos = np.arange(sorted_idx.shape[0]) + 0.5
fig = plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.barh(pos, feature_importance[sorted_idx], align="center")
plt.yticks(pos, np.array(features_df.columns)[sorted_idx])
plt.title("Feature Importance (MDI)")

# right: permutation importances on the test set (10 shuffles per feature)
result = permutation_importance(
    reg, X_test, y_test, n_repeats=10, random_state=42, n_jobs=2
)
sorted_idx = result.importances_mean.argsort()
plt.subplot(1, 2, 2)
plt.boxplot(
    result.importances[sorted_idx].T,
    vert=False,
    labels=np.array(features_df.columns)[sorted_idx],
)
plt.title("Permutation Importance (test set)")
fig.tight_layout()
plt.show()

In [None]:
# boosted-tree predictions for every row of the full feature table (training rows included)
gbr_pred = reg.predict(features_df)

In [None]:

# Reshape the boosted-tree predictions back onto the all_data grid (first time step) and compare
# with the ground truth. The shape is derived from all_data instead of the hard-coded
# (85, 61, 28), which had to equal these sizes anyway for the DataArray to be constructed.
grid_shape = tuple(all_data.sizes[dim] for dim in all_data.dims)
predicted_gbr_data = xa.DataArray(gbr_pred.reshape(grid_shape),
    coords=all_data.coords,
    dims=all_data.dims).isel(time=0)


spatial_plots.plot_spatial_diffs(predicted_gbr_data, gt_climate_res, figsize=(14,13))



# f,a = plt.subplots(1,2,figsize=[14,7])
# all_data["gt"].plot(ax=a[0])
# predicted_gbr_data.isel(time=-1).plot(ax=a[1]
#     , vmin=all_data["gt"].values.min(), vmax=all_data["gt"].values.max()
#     )