In [1]:
%load_ext autoreload
%autoreload 2

import sys
import os

sys.path.append("../")

# choose whether to work on a remote machine
location = "remote"
if location == "remote":
    # change this line to the where the GitHub repository is located
    os.chdir("/lustre_scratch/orlando-code/coralshift/")

In [2]:
from __future__ import annotations

from pathlib import Path
import xarray as xa
import numpy as np
import time
import json
import math as m
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
from tqdm import tqdm
from sklearn import model_selection

# from sklearn import datasets, ensemble
# from sklearn.inspection import permutation_importance
# from sklearn.metrics import mean_squared_error, log_loss
# from sklearn.model_selection import train_test_split
# import sklearn.metrics as sklmetrics
# from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier, GradientBoostingRegressor
from sklearn.linear_model import LogisticRegression

from sklearn.model_selection import train_test_split, RandomizedSearchCV


import cartopy.crs as ccrs
import cartopy.feature as cfeature
import rasterio
from rasterio.plot import show
from rasterio.enums import Resampling
import rioxarray as rio
import pickle

from coralshift.utils import directories, file_ops, utils
from coralshift.processing import spatial_data
from coralshift.plotting import spatial_plots, model_results
from coralshift.machine_learning import baselines
from coralshift.dataloading import bathymetry

ERROR 1: PROJ: proj_create_from_database: Open of /home/jovyan/lustre_scratch/conda-envs/coralshift/share/proj failed


## Data Derivation

In [None]:
# # load and process variables for reproducing Couce et al. (2013)

# reproduction_xa_list = load_and_process_reproducing_xa_das("A")
# resampled_xa_das_dict = spatial_data.resample_list_xa_ds_to_target_resolution_and_merge(reproduction_xa_list, target_resolution=4000, unit="m")
# all_4km = generate_reproducing_metrics(resampled_xa_das_dict)

In [None]:
# quick-look check: open one pre-computed comparison file and plot its ground truth
# NOTE(review): this path hard-codes a 0-0367d resolution while the next cell loads 0.0368 — confirm intended file
test = xa.open_dataset(directories.get_comparison_dir() / "Great_Barrier_Reef_C/0-0367d_arrays/all_0-0367d_comparative.nc")
test["gt"].plot()

In [None]:
reef_areas = bathymetry.ReefAreas()
comparison_dir = directories.get_comparison_dir()
region_list = ["A","B","C","D"]
resolution = 0.0368
# one resolution string for both the directory and the file name, so changing `resolution`
# no longer leaves the file name pointing at the old (hard-coded 0-0368) resolution
res_str = utils.replace_dot_with_dash(str(resolution))

# load the comparison dataset of every reef region and tag each with its region letter
das = []
paths = []
for region in region_list:
    name = reef_areas.get_short_filename(region)
    path = comparison_dir / name / f"{res_str}d_arrays" / f"all_{res_str}d_comparative.nc"
    paths.append(path)
    da = xa.open_dataset(path, decode_coords="all")
    da.attrs["region"] = region
    das.append(da)
    

In [None]:
# ((lon_min, lon_max), (lat_min, lat_max)) window cut from each regional dataset, in `das` order
region_windows = [
    ((142, 147), (-17, -10)),
    ((147, 148), (-18, -17)),
    ((148, 154), (-24, -18)),
    ((154, 156), (-29, -24)),
]
das_test = [
    das[idx].sel({"longitude": slice(*lon_bounds), "latitude": slice(*lat_bounds)})
    for idx, (lon_bounds, lat_bounds) in enumerate(region_windows)
]
test0, test1, test2, test3 = das_test
# stitch the non-overlapping windows back into one dataset and show its ground truth
merged = xa.merge(das_test)

merged["gt"].plot()

In [37]:
def xa_dss_to_df(
    xa_dss: list[xa.Dataset],
    split_type: str = "pixel",
    test_lats: tuple[float, float] = None,
    test_lons: tuple[float, float] = None,
    test_fraction: float = 0.2,
    bath_mask: bool = True,
    ignore_vars: list = None,
):
    """
    Flatten datasets into one feature DataFrame and collect their train/test coordinates.

    Parameters
    ----------
        xa_dss (list[xa.Dataset]): Datasets with "latitude", "longitude" and "time" dimensions.
        split_type (str): Forwarded to `baselines.generate_test_train_coordinates`.
        test_lats (tuple[float, float]): Forwarded unchanged (test-region latitude bounds).
        test_lons (tuple[float, float]): Forwarded unchanged (test-region longitude bounds).
        test_fraction (float): Forwarded unchanged.
        bath_mask (bool): Forwarded unchanged.
        ignore_vars (list, optional): Columns dropped when present.
            Defaults to ["spatial_ref", "band", "depth"].

    Returns
    -------
        tuple: (concatenated DataFrame of all datasets, train coordinates, test coordinates).
            The DataFrame has an extra int column "onehotnan" (1 where the row had any missing
            feature) and all remaining NaNs filled with 0.
    """
    if ignore_vars is None:
        ignore_vars = ["spatial_ref", "band", "depth"]

    train_coords, test_coords, dfs = [], [], []
    for xa_ds in xa_dss:
        # compute dask chunks and flatten to one row per (latitude, longitude, time) point
        df = xa_ds.stack(points=("latitude", "longitude", "time")).compute().astype("float32").to_dataframe()
        # Drop ignored and datetime columns *before* flagging/filling NaNs:
        # - a NaN in a column the model never sees should not mark the row as missing;
        # - fillna(0) on a datetime column containing NaT turns it into object dtype, after which
        #   select_dtypes("datetime64") would no longer find (and drop) it.
        datetime_cols = set(df.select_dtypes(include="datetime64").columns)
        ignored_cols = set(ignore_vars).intersection(df.columns)
        df = df.drop(columns=list(ignored_cols | datetime_cols))
        # flag rows with any missing feature, then fill missing values with 0
        df["onehotnan"] = df.isnull().any(axis=1).astype(int)
        df = df.fillna(0)

        train_coordinates, test_coordinates = baselines.generate_test_train_coordinates(
            xa_ds, split_type, test_lats, test_lons, test_fraction, bath_mask
        )

        train_coords.extend(train_coordinates)
        test_coords.extend(test_coordinates)
        dfs.append(df)

    return pd.concat(dfs), train_coords, test_coords


def spatial_split_train_test(
    xa_dss: list[xa.Dataset],
    gt_label: str = "gt",
    data_type: str = "continuous",
    ignore_vars: list = ["spatial_ref", "band", "depth"],
    split_type: str = "pixel",
    test_lats: tuple[float] = None,
    test_lons: tuple[float] = None,
    test_fraction: float = 0.25,
    bath_mask: bool = True,
) -> tuple:
    """
    Split the input datasets into train and test sets based on spatial coordinates.

    Parameters
    ----------
        xa_dss (list[xa.Dataset]): The input xarray Datasets.
        gt_label (str): Name of the ground-truth column. Default is "gt".
        data_type (str): "discrete" thresholds the labels with `baselines.threshold_array`;
            any other value leaves them continuous. Default is "continuous".
        ignore_vars (list): Columns dropped before training. Default is ["spatial_ref", "band", "depth"].
        split_type (str): The split type, either "pixel" or "region". Default is "pixel".
        test_lats (tuple[float]): Latitude range of the test region (for "region" splits). Default is None.
        test_lons (tuple[float]): Longitude range of the test region (for "region" splits). Default is None.
        test_fraction (float): Fraction of data used for the test set. Default is 0.25.
        bath_mask (bool): Forwarded to the coordinate generator. Default is True.

    Returns
    -------
        tuple: (X_train, X_test, y_train, y_test, train_coords, test_coords).
    """
    flattened_data, train_coords, test_coords = xa_dss_to_df(
        xa_dss,
        split_type=split_type,
        test_lats=test_lats,
        test_lons=test_lons,
        test_fraction=test_fraction,
        bath_mask=bath_mask,
        # previously not forwarded, so this argument had no effect
        ignore_vars=ignore_vars,
    )
    # Min/max scale every column (the label included) to [0, 1]. A constant column has zero range;
    # dividing by it produced NaN features that most sklearn estimators reject, so such columns map to 0.
    # NOTE(review): the statistics cover train *and* test rows, which leaks test information into the
    # scaling — consider fitting them on the training rows only.
    col_min = flattened_data.min()
    col_range = (flattened_data.max() - col_min).replace(0, 1)
    normalised_data = (flattened_data - col_min) / col_range

    # return train and test rows from dataframe
    train_rows = utils.select_df_rows_by_coords(normalised_data, train_coords)
    test_rows = utils.select_df_rows_by_coords(normalised_data, test_coords)

    # assign rows to test and train features/labels (label column taken from gt_label, not a hard-coded "gt")
    X_train, X_test = train_rows.drop(gt_label, axis=1), test_rows.drop(gt_label, axis=1)
    y_train, y_test = train_rows[gt_label], test_rows[gt_label]

    if data_type == "discrete":
        y_train, y_test = baselines.threshold_array(y_train), baselines.threshold_array(y_test)

    return X_train, X_test, y_train, y_test, train_coords, test_coords


def train_tune_across_models(
    model_types: list[str], d_resolution: float = 0.03691, split_type: str = "pixel", test_fraction: float=0.25):
    """
    Tune several model types with randomised hyperparameter searches on one shared train split.

    Parameters
    ----------
        model_types (list[str]): Identifiers passed to `train_tune` / `baselines.initialise_model`
            (this notebook uses "maxent", "rf_reg", "brt", "rf_cla").
        d_resolution (float): Resolution (degrees) of the comparison dataset loaded via
            `baselines.get_comparison_xa_ds`.
        split_type (str): Passed to `spatial_split_train_test`.
        test_fraction (float): Held-out fraction for the split; also recorded in each model's metadata.

    Returns
    -------
        None. Each model is handed to `train_tune` with save_dir `<datasets dir>/model_params/best_models`
        and name "{model}_{resolution string}_tuned".
    """
    model_comp_dir = file_ops.guarantee_existence(
        directories.get_datasets_dir() / "model_params/best_models"
    )

    all_data = baselines.get_comparison_xa_ds(d_resolution=d_resolution)
    res_string = utils.replace_dot_with_dash(f"{d_resolution:.04f}d")

    # one split shared by all models so they are tuned on identical training rows
    X_train, _, y_train, _, _, _ = spatial_split_train_test(
        utils.list_if_not_already(all_data),
        "gt",
        split_type=split_type,
        test_fraction=test_fraction,
    )

    for model in tqdm(model_types, total=len(model_types), desc="Fitting each model via random search"):
        train_tune(
            X_train, y_train,
            model_type=model,
            resolution=d_resolution,
            save_dir=model_comp_dir,
            name=f"{model}_{res_string}_tuned",
            # previously hard-coded to 0.25, so the recorded metadata ignored this function's argument
            test_fraction=test_fraction,
        )


def train_tune(
    X_train,
    y_train,
    model_type: str,
    resolution: float,
    name: str = "_",
    test_fraction: float = 0.25,
    save_dir: Path | str = None,
    n_iter: int = 50,
    cv: int = 3,
    random_state: int = 42,
):
    """
    Run a randomised hyperparameter search for one model type, save it and write training metadata.

    Parameters
    ----------
        X_train (pd.DataFrame): Training features; column names are recorded as the feature list.
        y_train: Training labels; thresholded with `baselines.threshold_array` for discrete model types.
        model_type (str): Identifier passed to `baselines.initialise_model`.
        resolution (float): Spatial resolution, recorded in the metadata only.
        name (str): Base name for the saved model and its metadata.
        test_fraction (float): Recorded in the metadata only.
        save_dir (Path | str, optional): Output directory. Defaults to `<datasets dir>/model_params`.
        n_iter (int): Number of parameter settings sampled.
        cv (int): Number of cross-validation folds.
        random_state (int): Seed for the parameter sampling. Default 42 (the previously hard-coded value).
    """
    model, data_type, search_grid = baselines.initialise_model(model_type)

    if data_type == "discrete":
        y_train = baselines.threshold_array(y_train)

    # resolve the output directory *before* the (long) fit so that a bad path fails fast
    if not save_dir:
        save_dir = file_ops.guarantee_existence(
            directories.get_datasets_dir() / "model_params"
        )

    # NOTE: this only times *constructing* the search object, not the search itself (that is fit_time);
    # kept because it is recorded in the metadata as randomised_search_time
    start_time = time.time()
    model_random = baselines.RandomizedSearchCV(
        estimator=model,
        param_distributions=search_grid,
        n_iter=n_iter,
        cv=cv,
        verbose=2,
        random_state=random_state,
        n_jobs=-1,
    )
    randomised_search_time = time.time() - start_time

    print("Fitting model with a randomized hyperparameter search...")
    start_time = time.time()
    model_random.fit(X_train, y_train)
    fit_time = time.time() - start_time

    save_path = baselines.save_sklearn_model(model_random, save_dir, name)
    baselines.create_train_metadata(
        name=name,
        model_path=save_path,
        model_type=model_type,
        data_type=data_type,
        randomised_search_time=randomised_search_time,
        fit_time=fit_time,
        test_fraction=test_fraction,
        features=list(X_train.columns),
        resolution=resolution,
    )



# names = [region["short_name"] for region in bathymetry.ReefAreas().datasets]
# X_train, X_test, y_train, y_test, train_coordinates, test_coordinates = spatial_split_train_test(das)



In [38]:
# tune every baseline model type on the 0.0368-degree comparison data
train_tune_across_models(
    model_types=["maxent", "rf_reg", "brt", "rf_cla"],
    d_resolution=0.0368,
)

Fitting each model via random search:   0%|          | 0/4 [00:00<?, ?it/s]

Fitting model with a randomized hyperparameter search...
Fitting 3 folds for each of 50 candidates, totalling 150 fits


1196.23s - pydevd: Sending message related to process being replaced timed-out after 5 seconds
1196.41s - pydevd: Sending message related to process being replaced timed-out after 5 seconds
1196.60s - pydevd: Sending message related to process being replaced timed-out after 5 seconds
1196.79s - pydevd: Sending message related to process being replaced timed-out after 5 seconds
1196.97s - pydevd: Sending message related to process being replaced timed-out after 5 seconds
1197.16s - pydevd: Sending message related to process being replaced timed-out after 5 seconds
 /lustre_scratch/orlando-code/conda-envs/coralshift/lib/python3.10/site-packages/debugpy/_vendored/pydevd/_pydev_bundle/__init__.py
  /lustre_scratch/orlando-code/conda-envs/coralshift/lib/python3.10/site-packages/debugpy/_vendored/pydevd/_pydev_bundle/_pydev_calltip_util.py
  /lustre_scratch/orlando-code/conda-envs/coralshift/lib/python3.10/site-packages/debugpy/_vendored/pydevd/_pydev_bundle/_pydev_completer.py
  /lustre_scr

In [None]:
# size of what the comparison-data loader returns; the loader lives in `baselines`
# (the bare name `get_comparison_xa_ds` is not defined in this notebook)
len(baselines.get_comparison_xa_ds())

In [None]:
# display the first regional dataset loaded above
das[0]

In [None]:
# NOTE(review): neither `calculate_class_weight` nor a notebook-level `y_test` is defined in the visible cells
# (y_test only exists inside the split/tune functions), so this raises NameError on a fresh kernel
calculate_class_weight(y_test)

In [None]:
# class weighting necessary, as shown by this plot
# class weighting necessary, as shown by this plot
region_dict = dict(zip([da.region for da in das], das))
# NOTE(review): `calculate_region_class_imbalance` is not defined in the visible cells — confirm where it lives
region_imbalance = calculate_region_class_imbalance(region_dict)
# plot the result just computed (the previous `region_imbalance_dict` name was never defined)
model_results.visualise_region_class_imbalance(region_imbalance)

In [None]:
# re-display the per-region class imbalance plot
model_results.visualise_region_class_imbalance(region_imbalance)

In [None]:


# def plot_class_balance(xa_ds):


def plot_train_test_spatial(xa_da: xa.DataArray, figsize: tuple[float,float] = (7,7), bath_mask: xa.DataArray = None):
    """
    Map a train/test split on a PlateCarree projection with a two-entry categorical colorbar.

    Parameters
    ----------
    xa_da (xa.DataArray): Split labels with a "time" dimension; the last time step is plotted.
        Values below 0.5 are shown as "train" and values from 0.5 to 1 as "test"
        (presumably 0 = train, 1 = test — TODO confirm).
    figsize (tuple[float, float]): Figure size. Default (7, 7).
    bath_mask (xa.DataArray, optional): Boolean mask; pixels where it is False are set to NaN.
        Ignored when None or entirely False.

    Returns
    -------
    xa.DataArray: The (possibly masked) input array.
    """
    fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(projection=ccrs.PlateCarree()))

    # get_cmap accepts both colormap names and Colormap objects, so `.N` is always available below
    cmap = plt.get_cmap(spatial_plots.get_cbar())
    # two bins: [0, 0.5) -> train, [0.5, 1] -> test. ncolors must be cmap.N so the two bins map to the two
    # ends of the colormap; ncolors=2 picked colormap entries 0 and 1, near-identical for a continuous map
    bounds = [0, 0.5, 1]
    norm = mpl.colors.BoundaryNorm(bounds, cmap.N)

    # the default bath_mask=None previously raised AttributeError on `.any()`
    if bath_mask is not None and bath_mask.any():
        xa_da = xa_da.where(bath_mask, np.nan)

    # share one norm between the map and the colorbar so their colours agree
    xa_da.isel(time=-1).plot.pcolormesh(ax=ax, cmap=cmap, norm=norm, add_colorbar=False)
    ax.set_aspect("equal")
    ax.set_title("Geographical visualisation of train-test split")
    ax.add_feature(
        cfeature.NaturalEarthFeature(
            "physical", "land", "10m", edgecolor="black", facecolor="#cccccc"
        )
    )
    ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=True)

    # categorical colorbar: one tick in the middle of each bin
    positions = [0.25, 0.75]
    val_lookup = dict(zip(positions, ["train", "test"]))

    def formatter_func(x, pos):
        """Label the bin-centre ticks; any other tick gets an empty label instead of a KeyError."""
        return str(val_lookup.get(x, ""))

    formatter = plt.FuncFormatter(formatter_func)
    fig.colorbar(ax=ax, mappable=mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
        ticks=positions, format=formatter, spacing='proportional', pad=0.1, fraction=0.046)
    return xa_da


def generate_reproducing_metrics_at_different_resolutions(resolutions: list[float], units: list[str]) -> None:
    """
    Compute reproduction metrics at each requested resolution from pre-computed comparison arrays.

    For every (resolution, unit) pair, the second element returned by `spatial_data.choose_resolution`
    selects the directory `<comparison dir>/<resolution to 5 d.p., passed through
    utils.replace_dot_with_dash>d_arrays`. Every ".nc" file in that directory is opened and its first data
    variable other than "spatial_ref" is collected under the file-stem prefix before the first "_". The
    resulting dict and the resolution are passed to `generate_reproducing_metrics`. Nothing is returned.

    NOTE(review): `generate_reproducing_metrics` is not defined in the visible notebook cells.

    Parameters
    ----------
    resolutions (list[float]): Resolution values, paired element-wise with `units`.
    units (list[str]): Unit of each resolution (this notebook uses "d" and "m").
    """
    target_resolutions = [spatial_data.choose_resolution(number, string)[1] for number, string in zip(resolutions, units)]
    for res in tqdm(target_resolutions, total=len(target_resolutions), desc="Processing metrics at various resolutions"):
        files_dir = directories.get_comparison_dir() / utils.replace_dot_with_dash(f"{res:.05f}d_arrays")
        xa_ds_dict = {}

        for fl in file_ops.return_list_filepaths(files_dir, ".nc"):
            ds = xa.open_dataset(fl).rio.write_crs("epsg:4326")
            variable_name = next((var_name for var_name in ds.data_vars if var_name != "spatial_ref"), None)
            if variable_name is None:
                # file holds nothing but spatial_ref; previously this raised KeyError on ds[None]
                continue
            xa_ds_dict[fl.stem.split("_")[0]] = ds[variable_name]

        generate_reproducing_metrics(xa_ds_dict, res)


In [None]:
# resolutions to evaluate, paired element-wise with their units (four in "d", the last in "m")
resolutions, units = [1, 0.5, 0.25, 1 / 12, 4000], ["d"] * 4 + ["m"]


In [None]:
# NOTE: pickle.load can execute arbitrary code from the file — only load pickles this project produced.
# The context manager closes the file handle, which the bare open() never did.
with open(directories.get_datasets_dir() / "model_params/all_0-03691d_comparative_10_runs_0.pickle", 'rb') as model_file:
    loaded_model = pickle.load(model_file)

In [None]:
resolutions = [1, 0.5, 0.25, 1 / 12, 4000]
units = ["d"] * 4 + ["m"]

# convert every (value, unit) pair, keeping the second element of choose_resolution's result
target_resolutions = [
    spatial_data.choose_resolution(value, unit)[1]
    for value, unit in zip(resolutions, units)
]
target_resolutions

In [None]:
models = ["rf_reg", "brt", "maxent", "rf_cla"]

# NOTE(review): `train_test_across_models` is not defined in the visible cells (the tuning helper above is
# `train_tune_across_models`, which returns nothing) — this raises NameError on a fresh kernel
all_outputs = train_test_across_models(models, d_resolution=0.03691)
# run_outputs = train_test_visualise_roc_across_resolutions(d_resolutions=target_resolutions,model_type="rf_cla")

In [None]:
# NOTE(review): `bathymetry_flat` is only assigned in the next cell, so this fails when run top-to-bottom
bathymetry_flat

In [None]:
# the loader lives in `baselines`; the bare name `get_comparison_xa_ds` is not defined in this notebook
all_data = baselines.get_comparison_xa_ds(d_resolution=0.03691)

# 2D histogram of ground-truth value against region-A bathymetry over all grid points
gt_values = all_data['gt']
bathymetry_values = all_data['bathymetry_A']

# Flatten the variables into 1D arrays (both variables are assumed to share one shape)
gt_flat = gt_values.values.flatten()
bathymetry_flat = bathymetry_values.values.flatten()

# Keep only points where both values are finite (drops NaN and +/-inf)
valid_indices = np.logical_and(np.isfinite(gt_flat), np.isfinite(bathymetry_flat))
gt_valid = gt_flat[valid_indices]
bathymetry_valid = bathymetry_flat[valid_indices]

# Inclusive bathymetry range to keep
threshold_min = -100
threshold_max = 100

mask = np.logical_and(bathymetry_valid >= threshold_min, bathymetry_valid <= threshold_max)
gt_filtered = gt_valid[mask]
bathymetry_filtered = bathymetry_valid[mask]

plt.hist2d(bathymetry_filtered, gt_filtered, bins=50, cmap='viridis')
plt.xlabel('Bathymetry')
plt.ylabel('gt')
plt.colorbar(label='Count')
plt.title('Histogram of gt vs Bathymetry')

In [None]:
# region-A bathymetry with the colour scale limited to [-40, 0]
all_data["bathymetry_A"].plot(vmin=-40, vmax=0)

In [None]:
# ground-truth map
all_data["gt"].plot()

In [None]:
# Access the "gt" variable from the dataset
# Access the "gt" variable from the dataset
gt_variable = all_data['gt']

# Count non-zero and zero pixels among valid values. `gt != 0` is also True for NaN, so without the
# notnull() guard every missing pixel was counted as "Non-Zero".
gt_notnull = gt_variable.notnull()
non_zero_count = int((gt_notnull & (gt_variable != 0)).sum())
zero_count = int((gt_variable == 0).sum())

# Create a bar chart
labels = ['Non-Zero', 'Zero']
values = [non_zero_count, zero_count]

plt.bar(labels, values)
plt.xlabel('Value')
plt.ylabel('Count')
plt.title('Number of Non-Zero and Zero Values in "gt"')

In [None]:
# Flatten the variable into a 1D array and remove NaN values
# 1D array of all ground-truth values with NaNs removed
gt_values = np.ravel(gt_variable.values)
gt_values = gt_values[np.logical_not(np.isnan(gt_values))]

# Histogram of the remaining values in ten bins
plt.hist(gt_values, bins=10)
plt.xlabel('Value')
plt.ylabel('Frequency')
plt.title('Histogram of "gt" Values')

In [None]:
def visualise_class_balance(all_data: xa.Dataset):
    """
    Placeholder for a class-balance plot of `all_data`.

    The original cell held only this signature with no body, which is a SyntaxError and stopped the
    notebook from running top-to-bottom; the stub makes the gap explicit instead.

    Raises
    ------
    NotImplementedError: always, until implemented.
    """
    raise NotImplementedError("visualise_class_balance is a stub")

In [None]:
# NOTE(review): `run_outputs` is only produced by a commented-out line earlier and `rocs_n_runs` is not
# defined in the visible cells — these three cells fail on a fresh kernel
np.shape(run_outputs['1.00000'])

In [None]:
run_outputs.keys()

In [None]:
rocs_n_runs(run_outputs['0.08333'])

In [None]:
# scratch arithmetic: reciprocal of a 0.083 resolution
1/0.083

In [None]:
### SST (sea water potential temperature)
# load in daily sea water potential temp
thetao_daily = xa.open_dataarray(directories.get_processed_dir() / "arrays/thetao.nc")

# annual average, stdev of annual averages, annual minimum, annual maximum
thetao_annual_average, _, (thetao_annual_min, thetao_annual_max) = baselines.calc_timeseries_params(thetao_daily, "y", "thetao")
# monthly average, stdev of monthly averages, monthly minimum, monthly maximum
thetao_monthly_average, thetao_monthly_stdev, (thetao_monthly_min, thetao_monthly_max) = baselines.calc_timeseries_params(
    thetao_daily, "m", "thetao")
# annual range (monthly max - monthly min)
thetao_annual_range = (thetao_annual_max - thetao_annual_min).rename("thetao_annual_range")
# weekly minimum, weekly maximum
_, _, (thetao_weekly_min, thetao_weekly_max) = baselines.calc_timeseries_params(thetao_daily, "w", "thetao")


In [None]:
### Salinity
# load in daily sea water salinity means
salinity_daily = xa.open_dataarray(directories.get_processed_dir() / "arrays/so.nc")

# annual average
salinity_annual_average, _, _ = baselines.calc_timeseries_params(salinity_daily, "y", "salinity")
# monthly min, monthly max
_, _, (salinity_monthly_min, salinity_monthly_max) = baselines.calc_timeseries_params(salinity_daily, "m", "salinity")

In [None]:
### Current speed (dot product of horizontal and vertical)
# load in daily currents (longitudinal and latitudinal)
uo_daily = xa.open_dataarray(directories.get_processed_dir() / "arrays/uo.nc")
vo_daily = xa.open_dataarray(directories.get_processed_dir() / "arrays/vo.nc")
# calculate current magnitude
current_daily = baselines.calculate_magnitude(uo_daily, vo_daily)

# annual average
current_annual_average, _, _ = baselines.calc_timeseries_params(current_daily, "y", "current")
# monthly min, monthly max
_, _, (current_monthly_min, current_monthly_max) = baselines.calc_timeseries_params(current_daily, "m", "current")

In [None]:
### Light penetration proxy
# Load in bathymetry and scale to climate variable resolution
# Named `bathymetry_A_da` rather than `bathymetry`: that name is the `coralshift.dataloading.bathymetry`
# module imported at the top, and overwriting it broke any cell calling `bathymetry.ReefAreas()` on re-run.
bathymetry_A_da = xa.open_dataset(
    directories.get_bathymetry_datasets_dir() / "bathymetry_A_0-00030d.nc").rio.write_crs("EPSG:4326")["bathymetry_A"]
# regrid onto the grid of thetao_annual_average
bathymetry_climate_res = spatial_data.upsample_xa_d_to_other(bathymetry_A_da, thetao_annual_average, name = "bathymetry")



In [None]:
# Load in ERA5 surface net solar radiation and upscale to climate variable resolution
# NOTE(review): per its file name this ERA5 file covers LATS -10..-17 / LONS 142..147 only
solar_radiation = xa.open_dataarray(
    directories.get_era5_data_dir() / "weather_parameters/VAR_surface_net_solar_radiation_LATS_-10_-17_LONS_142_147_YEAR_1993-2020.nc"
    ).rio.write_crs("EPSG:4326")
    
# average solar_radiation daily, then regrid onto the grid of thetao_annual_average
solar_radiation_daily = solar_radiation.resample(time="1D").mean(dim="time")
solar_climate_res = spatial_data.upsample_xa_d_to_other(solar_radiation_daily, thetao_annual_average, name = "solar_radiation")

# annual average
solar_annual_average, _, _ = baselines.calc_timeseries_params(solar_climate_res, "y", "net_solar")
# monthly min, monthly max
_, _, (solar_monthly_min, solar_monthly_max) = baselines.calc_timeseries_params(solar_climate_res, "m", "net_solar")


In [None]:
### Load in ground truth data
# ground-truth coral raster at 1000 m
gt_1000m = xa.open_dataarray(directories.get_processed_dir() / "arrays/coral_raster_1000m.nc")
# upsample to same resolution as climate (1/12 of a degree), i.e. onto thetao_annual_average's grid
gt_climate_res = spatial_data.upsample_xa_d_to_other(gt_1000m, thetao_annual_average, name = "gt")

In [None]:
### Display different resolutions
# side-by-side maps: native 1000 m ground truth (left) and the same data on the climate grid (right)
fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(16,9), subplot_kw=dict(projection=ccrs.PlateCarree()))

ax1 = spatial_plots.plot_spatial(gt_1000m, fax= (fig,ax_left), title="Ground truth coral presence map at 1000m")
ax2 = spatial_plots.plot_spatial(
    gt_climate_res, fax=(fig, ax_right), val_lims = (0,1), title="Ground truth coral presence map at 1/12 degree")

## Baseline Machine Learning Models

In [None]:
# combine data
# combine all derived predictors and the regridded ground truth into a single dataset
all_data = xa.merge([thetao_annual_average, thetao_annual_range, thetao_monthly_min, thetao_monthly_max, 
    thetao_monthly_stdev, thetao_weekly_min, thetao_weekly_max, 
    salinity_annual_average, salinity_monthly_min, salinity_monthly_max,
    current_annual_average, current_monthly_min, current_monthly_max,
    bathymetry_climate_res, solar_annual_average, solar_monthly_min, solar_monthly_max,
    gt_climate_res
    ])
    
all_data

In [None]:
# tabulate the merged dataset (one row per grid point) and replace every missing value with 0
test_df = all_data.to_dataframe().fillna(0)

test_df.describe()

In [None]:
# drop unnecessary columns
# drop unnecessary columns
# NOTE(review): drop() raises KeyError if any of these columns is absent from test_df
features_df = pd.get_dummies(test_df).drop(["spatial_ref","band","depth","gt"], axis=1)
features_df

In [None]:
# ground-truth column used as the model target
target_df = test_df["gt"]
target_df

In [None]:
# X_train, X_test, y_train, y_test, train_coordinates,test_coordinates, xa_masked = spatial_split_train_test(
#     all_data, "gt", split_type="pixel", test_fraction = 0.2)

# test_train_da = visualise_train_test_split(all_data, train_coordinates, test_coordinates)
# NOTE(review): `generate_var_mask` is not defined in the visible cells and `test_train_da` only comes from the
# commented-out line above — this cell fails on a fresh kernel
bath_mask = generate_var_mask(all_data)
test_xa = plot_train_test_spatial(test_train_da, bath_mask=bath_mask)


In [None]:
# NOTE(review): `n_random_runs_preds` and `rocs_n_runs` are not defined in the visible cells, and `rf_random`
# is only created further down — this cell depends on out-of-order execution
run_outcomes = n_random_runs_preds(rf_random, 10, all_data)

rocs_n_runs(run_outcomes)

In [None]:
def investigate_resolution_predictions(xa_ds: xa.DataArray):
    """
    Placeholder for inspecting predictions made at different resolutions.

    The original cell held only this signature with no body (a SyntaxError); the stub makes the gap explicit.

    Raises
    ------
    NotImplementedError: always, until implemented.
    """
    raise NotImplementedError("investigate_resolution_predictions is a stub")

In [None]:
# mean of the values returned by spatial_data.calculate_spatial_resolution for all_data
out = np.mean(spatial_data.calculate_spatial_resolution(all_data))

In [None]:
# reciprocal of that mean resolution
1/out

In [None]:
def investigate_label_thresholds(
    thresholds: list[float],
    y_test: np.ndarray | pd.Series,
    y_predictions: np.ndarray | pd.Series,
    figsize=(7, 7),
):
    """
    Plot ROC curves with one line per label threshold.

    For each threshold, `model_results.threshold_label` binarises labels and predictions, and the ROC
    curve plus its AUC are computed with scikit-learn.

    Parameters
    ----------
        thresholds (list[float]): Label thresholds.
        y_test (np.ndarray or pd.Series): True labels.
        y_predictions (np.ndarray or pd.Series): Predicted values.
        figsize (tuple, optional): Figure size for the plot. Default is (7, 7).

    Returns
    -------
        None
    """
    # the notebook-level `import sklearn.metrics as sklmetrics` is commented out
    import sklearn.metrics as sklmetrics

    f, ax = plt.subplots(figsize=figsize)
    # one colour per threshold, sampled evenly from the sequential colormap
    color_map = spatial_plots.get_cbar("seq")
    num_colors = len(thresholds)
    colors = [color_map(i / num_colors) for i in range(num_colors)]

    for colour, thresh in zip(colors, thresholds):
        binary_y_labels, binary_predictions = model_results.threshold_label(
            y_test, y_predictions, thresh
        )
        # ROC is undefined when the binarised labels contain a single class (sklearn emits
        # UndefinedMetricWarning and returns NaN rates), so such thresholds are skipped
        if np.unique(binary_y_labels).size < 2:
            continue
        fpr, tpr, _ = sklmetrics.roc_curve(
            binary_y_labels, binary_predictions, drop_intermediate=False
        )
        roc_auc = sklmetrics.auc(fpr, tpr)
        ax.plot(fpr, tpr, label=f"{thresh:.01f} | {roc_auc:.02f}", color=colour)

    # format (explicitly, since the previously called `format_roc` helper is not defined in this notebook)
    ax.plot([0, 1], [0, 1], linestyle="--", color="grey", label="_nolegend_")
    ax.set(
        xlim=(0, 1),
        ylim=(0, 1),
        xlabel="False Positive Rate",
        ylabel="True Positive Rate",
        title="Receiver Operating Characteristic (ROC) Curve\nfor several coral presence/absence thresholds",
    )
    ax.set_aspect("equal")
    ax.legend(title="threshold value | auc")



In [None]:
# NOTE(review): `rf_random` is only created in a later cell and no notebook-level X_test/y_test is defined
rf_random_preds = rf_random.predict(X_test)

In [None]:
# ROC curves for 10 thresholds evenly spaced over [0, 1]
investigate_label_thresholds(np.linspace(0,1,10), y_test, rf_random_preds)

In [None]:
# the notebook-level import of RandomForestRegressor is commented out, so import it where it is first used
from sklearn.ensemble import RandomForestRegressor

RANDOM_STATE = 42
# baseline (untuned) forest with default hyper-parameters
rf_reg = RandomForestRegressor(random_state=RANDOM_STATE)
rf_reg.fit(X_train, y_train)
# rf_reg.fit(X_train, y_train)

In [None]:
# inspect the baseline forest's hyper-parameters
rf_reg.get_params()

In [None]:
# hyper-parameters for the tuned forest (presumably the best values of an earlier randomised search — confirm)
best_rf_params={"n_estimators":400,
"min_samples_split":2,
"min_samples_leaf":4,
"max_features":"sqrt",
"max_depth":10,
"bootstrap":True}

In [None]:
from sklearn.base import clone

# set_params mutates and returns the *same* estimator, so `rf_random = rf_reg.set_params(...)` made both
# names point at one object, and without a refit its trees were still those fitted with default parameters.
# Clone an independent, unfitted copy, apply the tuned parameters, then fit it.
rf_random = clone(rf_reg).set_params(**best_rf_params).fit(X_train, y_train)

In [None]:
# NOTE(review): best_params_ only exists on a fitted search object (e.g. RandomizedSearchCV)
rf_random_search_best_params = rf_random.best_params_
# write the best parameters as JSON into the current working directory (the coralshift folder);
# json is already imported at the top of the notebook
with open("rf_random_search_best_params.json", "w") as params_file:
    json.dump(rf_random_search_best_params, params_file)

In [None]:
# NOTE(review): no notebook-level X_test is assigned in the visible cells (it only exists inside functions)
X_test

In [None]:
# the notebook-level import of log_loss is commented out
from sklearn.metrics import log_loss

# predictions were previously commented out here, leaving `predictions` undefined for this and the next cell
predictions = rf_random.predict(X_test)
# NOTE(review): log_loss expects class labels in y_true; if y_test holds continuous values sklearn rejects it —
# threshold y_test first (e.g. baselines.threshold_array) if that is the case
bce = log_loss(y_true=list(y_test), y_pred=predictions)

In [None]:
from sklearn.metrics import mean_squared_error, roc_auc_score
from sklearn import metrics
# predictions = rf_random.predict(X_test)

# mean squared error of the test predictions
# NOTE(review): relies on `predictions` already existing from an earlier cell
mean_squared_error(y_test,predictions)

In [None]:
# predictions = rf_reg.predict(X_test)
# predictions
# NOTE(review): replaces every value > 0 with 1 and sums, i.e. the count of positive labels plus the sum of the
# non-positive ones (the plain positive count when labels are non-negative)
sum(y_test.where(y_test <= 0,1))

In [None]:
# NOTE(review): rf_random is built from a RandomForestRegressor via set_params, not a search object,
# so it has no best_params_ attribute
rf_random.best_params_

In [None]:
# For each threshold 0.0, 0.1, ..., 1.0 print how many test labels reach it.
# np.linspace avoids the drift of repeatedly adding 0.1 (0.30000000000000004, ..., 0.9999999999999999).
# The count replaces `sum(y_test.where(y_test >= val, 1))`, which added the label values at/above the
# threshold to 1 for every label below it, rather than counting anything.
for val in np.linspace(0, 1, 11):
    print(round(val, 1), int((y_test >= val).sum()))

In [None]:
out = np.array(y_test)
# number of test labels strictly greater than 0.1
np.count_nonzero(out > 0.1)

In [None]:

    

# test-set predictions from the baseline and the tuned forests
rf_reg_preds = rf_reg.predict(X_test)
rf_random_preds = rf_random.predict(X_test)

In [None]:
# NOTE(review): this sums *all* test labels before comparing, so it yields a single 0/1; the per-label count
# `np.where(np.array(y_test) > 0.1, 1, 0).sum()` may have been intended
np.where(sum(np.array(y_test)) > 0.1, 1, 0)

In [None]:
# TODO: function to do N model runs and plot resultant ROC
# TODO: test whether training/optimising on binary helps


# ROC curves of the baseline forest for 100 label thresholds evenly spaced over [0, 1]
model_results.investigate_label_thresholds(np.linspace(0,1,100), y_test, rf_reg_preds)


In [None]:
# NOTE(review): `y_labels_bin` and `predictions_bin` are not defined in the visible cells; also, the name
# `display` below shadows IPython's built-in display() for the rest of the session
fpr, tpr, thresholds = metrics.roc_curve(y_labels_bin, predictions_bin, drop_intermediate=False)
roc_auc = metrics.auc(fpr, tpr)

display = metrics.RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc,
                            estimator_name='example estimator')
display.plot()

### Maximum Entropy Model

In [None]:
from sklearn import preprocessing

#convert y values to categorical values
# NOTE(review): LabelEncoder assigns one class per *distinct* value, so a continuous target becomes many
# classes rather than presence/absence — threshold first if a binary target is intended
lab = preprocessing.LabelEncoder()
y_transformed = lab.fit_transform(target_df)

In [None]:
# maxent = LogisticRegression(random_state=0)
# maxent.fit(features_df, y_transformed)

In [None]:
# pred_maxent = maxent.predict(features_df)

# predicted_maxent_data = xa.DataArray(pred_maxent.reshape((85,61,28)),
#     coords=all_data.coords, 
#     dims=all_data.dims)

# f,a = plt.subplots(1,2,figsize=[14,7])
# all_data["gt"].plot(ax=a[0])
# predicted_maxent_data.isel(time=-1).plot(ax=a[1]
#     , vmin=all_data["gt"].values.min(), vmax=all_data["gt"].values.max()
#     )

### Classification and Regression Trees (CART)

In [None]:
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
# fixed seeds so both forests (and the comparison between them) are reproducible across re-runs
clf = RandomForestRegressor(random_state=42)
clf_lim = RandomForestRegressor(random_state=42)

# one forest on every row, one on only the first num_vals rows (a contiguous block in table order, not a
# random sample); the limited forest then predicts the whole table
num_vals = 50000
clf.fit(features_df, target_df)
clf_lim.fit(features_df[:num_vals], target_df[:num_vals])
pred_lim = clf_lim.predict(features_df)

In [None]:
# clf.fit(features_df, target_df)
# Grid shape in all_data's dimension order (the order to_dataframe() used to flatten the rows) instead of the
# hard-coded (85, 61, 28), which silently broke whenever the region or time span changed.
grid_shape = tuple(all_data.sizes[dim] for dim in all_data.dims)

pred = clf.predict(features_df)
predicted_data = xa.DataArray(pred.reshape(grid_shape),
    coords=all_data.coords,
    dims=all_data.dims).isel(time=0)

predicted_lim_data = xa.DataArray(pred_lim.reshape(grid_shape),
    coords=all_data.coords,
    dims=all_data.dims).isel(time=0)

In [None]:
### Compare predicted and ground truth values
# forest trained on every row vs ground truth
spatial_plots.plot_spatial_diffs(predicted_data, gt_climate_res, figsize=(14,13))


In [None]:
# forest trained on the first `num_vals` rows only vs ground truth
spatial_plots.plot_spatial_diffs(predicted_lim_data, gt_climate_res, figsize=(14,13))

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle




def train_test_random_forest(xa_ds: xa.Dataset, target_variable: str = "gt", test_fraction=0.5, random_state=None):
    """
    Fit a random forest regressor on a random row split of `xa_ds` and map which pixels were train/test.

    Parameters
    ----------
        xa_ds (xa.Dataset): Dataset with "latitude" and "longitude" dimensions (optionally also "time").
        target_variable (str | xa.DataArray): Name of the label variable. A DataArray is also accepted, in
            which case its `.name` is used (the existing caller passes `all_data["gt"]`). Default "gt".
        test_fraction (float): Fraction of rows held out for testing. Default 0.5.
        random_state (int, optional): Seed for the split and the forest.

    Returns
    -------
        tuple: (fitted regressor, predictions for the held-out rows, xa.Dataset on (latitude, longitude) with
        "train_pixs" (0 where the pixel has a training row), "test_pixs" (1 where it has a test row) and
        "test_train" (both combined, test taking precedence); NaN elsewhere).
    """
    # the notebook-level import of RandomForestRegressor is commented out
    from sklearn.ensemble import RandomForestRegressor

    target_name = target_variable if isinstance(target_variable, str) else target_variable.name

    # Flatten to one row per point. Stack over "time" only if it is still a dimension (after .isel(time=0) it
    # is not), and ignore absent bookkeeping columns rather than raising KeyError.
    stack_dims = tuple(dim for dim in ("latitude", "longitude", "time") if dim in xa_ds.dims)
    flattened_data = (
        xa_ds.stack(points=stack_dims)
        .compute()
        .to_dataframe()
        .fillna(0)
        .drop(columns=["time", "spatial_ref", "band", "depth"], errors="ignore")
        .astype("float32")
    )
    features = flattened_data.drop(columns=target_name)
    labels = flattened_data[target_name]

    # this split was commented out, leaving X_train/X_test undefined
    X_train, X_test, y_train, y_test = train_test_split(
        features, labels, test_size=test_fraction, random_state=random_state
    )

    regressor = RandomForestRegressor(random_state=random_state)
    regressor.fit(X_train, y_train)
    y_pred = regressor.predict(X_test)

    lats = xa_ds.latitude.values
    lons = xa_ds.longitude.values
    lat_lookup, lon_lookup = pd.Index(lats), pd.Index(lons)

    def _pixel_map(row_index, fill_value):
        """Grid (latitude x longitude) holding `fill_value` at every pixel that has a row in `row_index`."""
        pixs = np.full((lats.size, lons.size), np.nan)
        # row-index levels carry the exact coordinate values, so an exact lookup replaces the undefined
        # `find_index_pair` helper and the per-row Python loop
        rows = lat_lookup.get_indexer(row_index.get_level_values("latitude"))
        cols = lon_lookup.get_indexer(row_index.get_level_values("longitude"))
        if (rows < 0).any() or (cols < 0).any():
            raise ValueError("row coordinates not found on the dataset grid")
        pixs[rows, cols] = fill_value
        return pixs

    train_pixs = _pixel_map(X_train.index, 0)
    test_pixs = _pixel_map(X_test.index, 1)
    grid_dims = ("latitude", "longitude")
    # the previous `return ..., ds` referenced an undefined name; build the dataset the later cells read
    ds = xa.Dataset(
        {
            "train_pixs": (grid_dims, train_pixs),
            "test_pixs": (grid_dims, test_pixs),
            "test_train": (grid_dims, np.where(np.isnan(test_pixs), train_pixs, test_pixs)),
        },
        coords={"latitude": lats, "longitude": lons},
    )
    return regressor, y_pred, ds


# fit on the first time step only; returns the regressor, test-row predictions and a train/test pixel dataset
# NOTE(review): target_variable receives a DataArray here although the parameter is annotated as a str name
reg,random_pred,ds = train_test_random_forest(all_data.isel(time=0), target_variable=all_data["gt"], random_state=42)

In [None]:
# NOTE(review): random_pred holds predictions for the held-out rows of a single time step only, so it generally
# cannot be reshaped to the full (85, 61, 28) grid — this reshape is expected to fail
random_pred_data = xa.DataArray(random_pred.reshape((85,61,28)),
    coords=all_data.coords,
    dims=all_data.dims)


spatial_plots.plot_spatial_diffs(random_pred_data, gt_climate_res, figsize=(14,13))



In [None]:
def plot_train_test_pixels(dataset: xa.Dataset, figsize: tuple[float, float] = (7, 7)):
    """
    Overlay the "train_pixs" and "test_pixs" variables of `dataset` on one PlateCarree map.

    Renamed from a second `plot_train_test_spatial` definition, which silently replaced the
    `plot_train_test_spatial(xa_da, figsize, bath_mask)` function defined earlier in this notebook and broke
    its `bath_mask=` callers on re-run.

    Parameters
    ----------
    dataset (xarray.Dataset): Dataset containing "train_pixs" and "test_pixs" variables.
    figsize (tuple[float, float]): Figure size. Default (7, 7).

    Returns
    -------
    None
    """
    fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(projection=ccrs.PlateCarree()))

    # both layers share the 0-1 colour range; where both are defined the test layer is drawn on top
    dataset["train_pixs"].plot(ax=ax, vmin=0, vmax=1, levels=2, label="train_pixs", add_colorbar=False)
    dataset["test_pixs"].plot(ax=ax, vmin=0, vmax=1, levels=2, label="test_pixs", add_colorbar=False)

    ax.set_aspect("equal")
    ax.coastlines(resolution="10m", color="red", linewidth=3)
    ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=True)


plot_train_test_pixels(ds)




In [None]:


def plot_categorical_data(dataset, variable_name,  color1='red', color2='blue', labels=None):
    """
    Plot binary (0/1) categorical spatial data from an xarray dataset.

    Cells equal to 0 are drawn in ``color1`` and cells equal to 1 in ``color2``;
    any other value (including NaN) is left transparent.

    Parameters:
    dataset (xarray.Dataset): The dataset containing the categorical data.
    variable_name (str): Name of the variable to plot. It must have ``latitude``
        and ``longitude`` coordinates, and its values are expected to be laid out
        as (latitude, longitude) so they match the meshgrid built below
        — TODO confirm for variables with extra dimensions.
    color1 (str): Colour used for cells equal to 0.
    color2 (str): Colour used for cells equal to 1.
    labels (tuple[str, str] | None): Legend labels for the 0 and 1 classes.
        Defaults to ("<variable_name> = 0", "<variable_name> = 1").

    Returns:
    None
    """
    from matplotlib.colors import ListedColormap
    from matplotlib.patches import Patch

    # Extract the required data variable and its coordinate/value arrays
    variable = dataset[variable_name]
    latitudes = variable.latitude.values
    longitudes = variable.longitude.values
    values = np.asarray(variable.values)

    # Create a meshgrid from latitude and longitude arrays: shape (n_lat, n_lon)
    lon_mesh, lat_mesh = np.meshgrid(longitudes, latitudes)

    # Draw both classes in ONE mesh: a two-colour colormap maps 0 -> color1 and
    # 1 -> color2, and masking leaves every other cell transparent. Two separate
    # full-grid meshes would paint over each other, hiding the first class.
    categories = np.ma.masked_where(~np.isin(values, (0, 1)), values)
    two_colour_cmap = ListedColormap([color1, color2])
    plt.pcolormesh(
        lon_mesh, lat_mesh, categories,
        cmap=two_colour_cmap, vmin=0, vmax=1, shading="auto",
    )

    if labels is None:
        labels = (f"{variable_name} = 0", f"{variable_name} = 1")
    # pcolormesh artists cannot appear in a legend, so use proxy patches instead.
    legend_handles = [
        Patch(facecolor=color1, label=labels[0]),
        Patch(facecolor=color2, label=labels[1]),
    ]

    plt.xlabel('Longitude')
    plt.ylabel('Latitude')
    plt.title(variable_name)
    plt.legend(handles=legend_handles)
    plt.show()

plot_categorical_data(ds,"test_train")


In [None]:
# Inspect the Dataset returned by train_test_random_forest above.
ds

In [None]:
# Inspect the full input dataset (it includes the "gt" variable used as the target above).
all_data

In [None]:


# Compare predicted_data (left) with pred_lim reshaped onto the all_data grid (right)
# at time index 0, using predicted_data's min/max as the colour range for the right panel.
# NOTE(review): pred_lim and predicted_data are not created in this part of the notebook —
# confirm an earlier cell defines them before running this one.
# NOTE(review): coords are left commented out, so the right-hand panel is plotted against
# integer indices rather than latitude/longitude — confirm this is intended.
predicted_lim_data = xa.DataArray(
    pred_lim.reshape(tuple(all_data.sizes.values())),
    # coords=all_data.coords,
    dims=tuple(all_data.sizes),
)

f, a = plt.subplots(1, 2, figsize=[14, 7])
predicted_data.isel(time=0).plot(ax=a[0])
predicted_lim_data.isel(time=0).plot(
    ax=a[1],
    vmin=float(predicted_data.min()),
    vmax=float(predicted_data.max()),
);

In [None]:
# TODO:
# - train and test on subsamples of the data (train_test_split for the linear model; find a spatially aware split)
# - add bathymetry as a feature
# - try a binary classifier
# - write a function to plot the difference between predicted and true values
# - tune hyperparameters

### Boosted Regression Trees (BRT)

Adapted from the scikit-learn example [Gradient Boosting regression](https://scikit-learn.org/stable/auto_examples/ensemble/plot_gradient_boosting_regression.html).

In [None]:
# Hold out 10% of the rows of features_df / target_df for testing, with a fixed seed.
# NOTE(review): the fitting/plotting cells below use X_train / X_test / y_train / y_test,
# not this *_t split — confirm which split the BRT model is meant to use.
X_train_t, X_test_t, y_train_t, y_test_t = train_test_split(
    features_df,
    target_df,
    test_size=0.1,
    random_state=13,
)

# Gradient-boosting hyperparameters (n_estimators is also reused by the deviance plot).
params = dict(
    n_estimators=500,
    max_depth=4,
    min_samples_split=5,
    learning_rate=0.01,
    loss="squared_error",
)

In [None]:
# Preview the training features produced by the split above.
X_train_t

In [None]:
# Imported here because the matching imports at the top of the notebook are commented
# out; without them this cell may raise NameError on a fresh kernel.
from sklearn import ensemble
from sklearn.metrics import mean_squared_error

# Fit the gradient-boosted regressor and report its hold-out mean squared error.
# NOTE(review): this fits on X_train / y_train and scores on X_test / y_test, not the
# *_t split created by the train_test_split cell above — confirm which split is intended.
reg = ensemble.GradientBoostingRegressor(**params)
reg.fit(X_train, y_train)

mse = mean_squared_error(y_test, reg.predict(X_test))
print(f"The mean squared error (MSE) on test set: {mse:.4f}")

In [None]:
# Imported here because the matching import at the top of the notebook is commented out.
from sklearn.metrics import mean_squared_error

# Test-set MSE after each boosting stage, to compare against the training loss curve.
# Built from staged_predict so its length always equals the number of fitted stages
# (a pre-sized np.zeros array would leave trailing zeros if fewer stages were fitted).
test_score = np.array(
    [mean_squared_error(y_test, stage_pred) for stage_pred in reg.staged_predict(X_test)],
    dtype=np.float64,
)
# One x-value per fitted stage; equals params["n_estimators"] when early stopping is off.
iterations = np.arange(1, len(reg.train_score_) + 1)

fig, ax = plt.subplots(figsize=(6, 6))
ax.plot(iterations, reg.train_score_, "b-", label="Training Set Deviance")
ax.plot(iterations, test_score, "r-", label="Test Set Deviance")
ax.set_title("Deviance")
ax.set_xlabel("Boosting Iterations")
ax.set_ylabel("Deviance")
ax.legend(loc="upper right")
fig.tight_layout()
plt.show()

In [None]:
# Imported here because the matching import at the top of the notebook is commented out.
from sklearn.inspection import permutation_importance

# Feature names are taken from features_df, assumed to match the columns the model was fit on.
feature_names = np.array(features_df.columns)
fig, (mdi_ax, perm_ax) = plt.subplots(1, 2, figsize=(12, 6))

# Left: impurity-based (MDI) importances from the fitted trees, sorted ascending for barh.
feature_importance = reg.feature_importances_
sorted_idx = np.argsort(feature_importance)
pos = np.arange(sorted_idx.shape[0]) + 0.5
mdi_ax.barh(pos, feature_importance[sorted_idx], align="center")
mdi_ax.set_yticks(pos)
mdi_ax.set_yticklabels(feature_names[sorted_idx])
mdi_ax.set_title("Feature Importance (MDI)")

# Right: permutation importances on the hold-out set (10 shuffles per feature).
result = permutation_importance(
    reg, X_test, y_test, n_repeats=10, random_state=42, n_jobs=2
)
sorted_idx = result.importances_mean.argsort()
perm_ax.boxplot(
    result.importances[sorted_idx].T,
    vert=False,
    labels=feature_names[sorted_idx],
)
perm_ax.set_title("Permutation Importance (test set)")
fig.tight_layout()
plt.show()

In [None]:
# Predict on every row of features_df (both the train and test rows of the split above)
# so the result can be reshaped back onto the all_data grid in the next cell.
gbr_pred = reg.predict(features_df)

In [None]:

# Reshape the flat GBR predictions onto the all_data grid, keep time index 0, and map
# the difference against gt_climate_res. The grid shape is read from all_data.sizes
# (in dimension order) instead of being hard-coded.
predicted_gbr_data = xa.DataArray(
    gbr_pred.reshape(tuple(all_data.sizes.values())),
    coords=all_data.coords,
    dims=tuple(all_data.sizes),
).isel(time=0)

spatial_plots.plot_spatial_diffs(predicted_gbr_data, gt_climate_res, figsize=(14, 13))