In [7]:
%load_ext autoreload
%autoreload 2

import sys
import os

sys.path.append("../")

# choose whether to work on a remote machine
location = "remote"
if location == "remote":
    # directory holding the cloned GitHub repository; set the CORALSHIFT_REPO_DIR
    # environment variable to point elsewhere instead of editing this line
    repo_dir = os.environ.get("CORALSHIFT_REPO_DIR", "/lustre_scratch/orlando-code/coralshift/")
    os.chdir(repo_dir)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [108]:
from __future__ import annotations

from pathlib import Path
import xarray as xa
import numpy as np
import math as m
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn import model_selection
from scipy.interpolate import interp2d

import rasterio
from rasterio.plot import show
import rioxarray as rio

from coralshift.utils import directories
from coralshift.processing import spatial_data

## Data Derivation

In [21]:
# thetao_daily is otherwise first loaded in the SST cell further down; load it here so
# this preview also works on a fresh kernel ("Restart Kernel -> Run All")
thetao_daily = xa.open_dataarray(directories.get_processed_dir() / "arrays/thetao.nc")
thetao_daily

In [101]:
def return_time_grouping_offset(period: str) -> tuple[str, str]:
    """Map a period name to an xarray groupby key and a pandas resample offset alias.

    Args:
        period: case-insensitive period name; one of "year"/"y"/"annual",
            "month"/"m" or "week"/"w".

    Returns:
        (group, offset): ``group`` is a ``"time.<component>"`` key for
        ``DataArray.groupby``; ``offset`` is a pandas frequency alias for
        ``DataArray.resample(time=...)``.

    Raises:
        ValueError: if ``period`` is not one of the recognised names (previously this
            surfaced as an UnboundLocalError).
    """
    period_key = period.lower()
    if period_key in ("year", "y", "annual"):
        # "YS" (year start) is the non-deprecated spelling of the old "AS" alias
        return "time.year", "YS"
    if period_key in ("month", "m"):
        return "time.month", "MS"
    if period_key in ("week", "w"):
        return "time.week", "W"
    raise ValueError(
        f"Unrecognised period {period!r}; expected one of "
        "'year'/'y'/'annual', 'month'/'m', 'week'/'w'."
    )


def calc_weighted_mean(
    xa_da_daily_means: xa.DataArray, period: str, weight_by_month_length: bool = False
) -> xa.DataArray:
    """Average a time series into one NaN-aware mean per ``period`` bin.

    Args:
        xa_da_daily_means: data with a datetime ``time`` coordinate, expected to hold
            one value per day.
        period: bin size understood by ``return_time_grouping_offset``
            (e.g. "y", "m", "w").
        weight_by_month_length: if True, weight every sample by the number of days in
            its calendar month. This is only appropriate when each sample is itself a
            *monthly* mean. For daily samples each value already represents one day,
            and month-length weights would over-weight 31-day months. Defaults to False.

    Returns:
        One value per resample bin of ``period``. Bins whose samples are all NaN (or
        that contain no samples) are NaN.
    """
    _, offset = return_time_grouping_offset(period)

    # per-sample weights; a scalar 1.0 yields a plain mean over the non-NaN samples
    weights = xa_da_daily_means.time.dt.days_in_month if weight_by_month_length else 1.0

    # numerator: weighted sum of the data (NaNs are skipped by `sum`)
    weighted_sum = (xa_da_daily_means * weights).resample(time=offset).sum(dim="time")
    # denominator: total weight carried by the samples that are not NaN
    weight_total = (xa_da_daily_means.notnull() * weights).resample(time=offset).sum(dim="time")

    # bins with no valid samples become NaN instead of triggering a 0/0 warning
    return weighted_sum / weight_total.where(weight_total > 0)


def calc_timeseries_params(xa_da_daily_means: xa.DataArray, period: str):
    """Summarise a daily time series at the resolution of ``period``.

    The series is first reduced to one mean per period bin (via
    ``calc_weighted_mean``). The spread and the extremes are then taken across those
    period means. For example, ``period="m"`` gives the lowest and highest monthly
    mean rather than the lowest and highest single day.

    Args:
        xa_da_daily_means: data with a datetime ``time`` coordinate.
        period: period name accepted by ``return_time_grouping_offset``.

    Returns:
        (period_means, stdev, (minimum, maximum)) where ``period_means`` keeps a
        ``time`` dimension (one entry per period), and ``stdev``, ``minimum`` and
        ``maximum`` are reduced over ``time`` (NaNs skipped).
    """
    period_means = calc_weighted_mean(xa_da_daily_means, period)

    # standard deviation of the period means
    stdev = period_means.std(dim="time", skipna=True)
    # extremes of the period means; taking them from the raw daily values would make
    # the result identical for every `period`
    minimum = period_means.min(dim="time", skipna=True)
    maximum = period_means.max(dim="time", skipna=True)

    return period_means, stdev, (minimum, maximum)


def calculate_january_july_std(xa_da_daily_means: xa.DataArray):
    """Year-to-year standard deviation of the January means and of the July means.

    Daily values are first averaged into one mean per calendar month of each year.
    The January means of all years are then reduced with ``std`` over ``time``, and
    likewise for July. The previous implementation averaged away the ``time``
    dimension before calling ``std(dim="time")`` and so raised.

    Args:
        xa_da_daily_means: data with a datetime ``time`` coordinate.

    Returns:
        (january_std, july_std): DataArrays with the ``time`` dimension reduced away.
    """
    # one mean per calendar month of each year ("MS" = month-start bins)
    monthly_means = xa_da_daily_means.resample(time="MS").mean(dim="time")
    month_of_year = monthly_means.time.dt.month

    def _std_for_month(month: int):
        # keep only this month's means (one per year) before reducing over time
        return monthly_means.sel(time=month_of_year == month).std(dim="time")

    january_std = _std_for_month(1)
    july_std = _std_for_month(7)
    return january_std, july_std


def calculate_magnitude(horizontal_data, vertical_data) -> xa.DataArray:
    """Element-wise Euclidean magnitude ``sqrt(horizontal**2 + vertical**2)``.

    Args:
        horizontal_data: first vector component (e.g. a DataArray).
        vertical_data: second vector component, broadcastable against the first.

    Returns:
        The magnitude of the two components, computed through ``xa.apply_ufunc``.
    """
    # np.hypot computes sqrt(a**2 + b**2) without the intermediate overflow/underflow
    # of squaring first. dask="allowed" lets lazily loaded (chunked) inputs pass
    # straight through, because np.hypot dispatches to dask for dask arrays.
    return xa.apply_ufunc(np.hypot, horizontal_data, vertical_data, dask="allowed")

In [93]:
### SST (sea water potential temperature)
# load in daily sea water potential temp
thetao_daily = xa.open_dataarray(directories.get_processed_dir() / "arrays/thetao.nc")

# annual average, stdev of annual averages, annual minimum, annual maximum
thetao_annual_average, thetao_annual_stdev, (thetao_annual_min, thetao_annual_max) = calc_timeseries_params(thetao_daily, "y")
# monthly average, stdev of monthly averages, monthly minimum, monthly maximum
thetao_monthly_average, thetao_monthly_stdev, (thetao_monthly_min, thetao_monthly_max) = calc_timeseries_params(thetao_daily, "m")
# annual range (monthly max - monthly min), built from the monthly extremes computed above
thetao_yearly_range = thetao_monthly_max - thetao_monthly_min
# weekly minimum, weekly maximum
_, _, (thetao_weekly_min, thetao_weekly_max) = calc_timeseries_params(thetao_daily, "w")


  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,


In [None]:
# SST variability: standard deviation of the January and of the July means
# TODO: confirm calculate_january_july_std runs and returns sensible values
january_std, july_std = calculate_january_july_std(xa_da_daily_means=thetao_daily)

In [105]:
### Salinity
# daily mean sea water salinity
salinity_daily = xa.open_dataarray(directories.get_processed_dir() / "arrays/so.nc")

# annual average: first element of the (average, stdev, (min, max)) summary
salinity_annual_average = calc_timeseries_params(salinity_daily, "y")[0]
# monthly minimum and maximum: the (min, max) pair of the monthly summary
salinity_monthly_min, salinity_monthly_max = calc_timeseries_params(salinity_daily, "m")[2]

  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,


In [106]:
### Current speed (magnitude of the horizontal current vector)
# load in daily current components: uo (longitudinal) and vo (latitudinal)
uo_daily = xa.open_dataarray(directories.get_processed_dir() / "arrays/uo.nc")
vo_daily = xa.open_dataarray(directories.get_processed_dir() / "arrays/vo.nc")
# current speed = sqrt(uo**2 + vo**2), i.e. the vector magnitude (not a dot product)
current_daily = calculate_magnitude(uo_daily, vo_daily)

# annual average
current_annual_average, _, _ = calc_timeseries_params(current_daily, "y")
# monthly min, monthly max
_, _, (current_monthly_min, current_monthly_max) = calc_timeseries_params(current_daily, "m")

  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,


In [112]:
### Ground truth coral raster, resampled onto the climate-data grid
CLIMATE_RESOLUTION_DEG = 1 / 12  # grid spacing of the climate arrays (1/12 of a degree)

gt_1000m = xa.open_dataarray(directories.get_processed_dir() / "arrays/coral_raster_1000m.nc")
gt_climate_res = spatial_data.upsample_xarray_to_target(gt_1000m, CLIMATE_RESOLUTION_DEG)

## Baseline Machine Learning Models

### Maximum Entropy Model

### Classification and Regression Trees (CART)

### Boosted Regression Trees (BRT)

In [None]:
# 