In [7]:
%load_ext autoreload
%autoreload 2

import sys
import os
sys.path.append("../")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
# Choose whether to work on a remote machine.
location = "remote"
if location == "remote":
    # Move into the root of the coralshift GitHub repository. The default is the original
    # hard-coded location; set the CORALSHIFT_REPO_DIR environment variable to point at a
    # different clone without editing the notebook.
    os.chdir(os.environ.get("CORALSHIFT_REPO_DIR", "/lustre_scratch/orlando-code/coralshift/"))

In [9]:
# import relevant packages

from __future__ import annotations

from pathlib import Path
import xarray as xa
import numpy as np
# import math as m
# import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.patches as patches

import wandb
from tqdm import tqdm
from sklearn import model_selection
from sklearn.preprocessing import normalize
from scipy.interpolate import interp2d
from sklearn.utils import class_weight
from scipy.ndimage import gaussian_gradient_magnitude
import xbatcher

# import rasterio
# from rasterio.plot import show
# import rioxarray as rio

# from bs4 import BeautifulSoup
# import requests


#issues with numpy deprecation in pytorch_env
from coralshift.processing import spatial_data
from coralshift.utils import file_ops, directories
from coralshift.plotting import spatial_plots, model_results
from coralshift.dataloading import data_structure, climate_data, bathymetry, reef_extent

In [10]:
# Pick the working grid resolution; values below ~1000 m make processing very slow.
# Two values come back — judging by their names, the resolution in metres and in degrees.
target_resolution_m, target_resolution_d = spatial_data.choose_resolution(
    unit="m",
    resolution=1000,
)



# Combine the datasets and process them into an ML-ready form

In [11]:
def split_nc_files(dir_path, new_dir: str | None = None, incl_subdirs: bool = False):
    """Split every NetCDF file found in ``dir_path`` into one file per data variable.

    Parameters
    ----------
    dir_path : str or Path
        Directory searched for ``.nc`` files.
    new_dir : str, optional
        Name of a subdirectory of ``dir_path`` to write the per-variable files into; it is
        created via ``file_ops.guarantee_existence``. When omitted, files are written
        directly into ``dir_path``.
    incl_subdirs : bool, optional
        Forwarded to ``file_ops.return_list_filepaths`` to also search subdirectories.

    Returns
    -------
    None. Does nothing when no ``.nc`` files are found.
    """
    nc_paths = file_ops.return_list_filepaths(dir_path, ".nc", incl_subdirs=incl_subdirs)
    if not nc_paths:
        # Nothing to split; avoids indexing into an empty list.
        return None

    # Anchor outputs on the searched directory itself rather than on the first file's parent,
    # which can be a nested folder when incl_subdirs=True.
    save_dir = Path(dir_path)
    if new_dir:
        save_dir = file_ops.guarantee_existence(save_dir / new_dir)

    for path in nc_paths:
        split_nc_file(path, save_dir)

def split_nc_file(path, save_dir):
    """Write each data variable of a NetCDF file to its own ``VARIABLE.nc`` file in ``save_dir``.

    Output files that already exist are left untouched and a message is printed instead.
    A file holding a single data variable is saved whole (as a dataset), named after that variable.

    Parameters
    ----------
    path : str or Path
        NetCDF file to split.
    save_dir : Path
        Directory the per-variable files are written into.

    Returns
    -------
    None
    """
    path = Path(path)
    # The context manager releases the underlying NetCDF/HDF5 file handle once writing is done.
    with xa.open_dataset(path) as dataset:
        var_names = list(dataset.data_vars)

        if len(var_names) == 1:
            # Already a single variable: save a copy under the variable's name.
            _write_nc_if_absent(dataset, _variable_nc_path(save_dir, var_names[0]))
            return None

        for variable in tqdm(dataset.data_vars, desc = f" Unpacking data_vars in {path}"):
            _write_nc_if_absent(dataset[variable], _variable_nc_path(save_dir, variable))
    return None


def _variable_nc_path(save_dir, variable) -> Path:
    """Build the ``save_dir / 'VARIABLE.nc'`` path without treating dots in the variable name as a suffix."""
    return Path(save_dir) / f"{variable}.nc"


def _write_nc_if_absent(xa_obj, new_file_path: Path) -> None:
    """Write ``xa_obj`` to ``new_file_path`` unless that file already exists (then print a notice)."""
    if not new_file_path.is_file():
        xa_obj.to_netcdf(new_file_path)
    else:
        print(f"{new_file_path} already exists.")
    

# def xa_ds_to_tensor(nc_path):
    


In [12]:
processed_path = directories.get_processed_dir()
# Split each multi-variable NetCDF file in the processed directory into one file per
# variable, written into its "arrays" subdirectory.
split_nc_files(processed_path, "arrays")

 Unpacking data_vars in /lustre_scratch/orlando-code/datasets/processed/combined_daily_1-12.nc:   0%|          | 0/9 [00:00<?, ?it/s]

/lustre_scratch/orlando-code/datasets/processed/arrays/bathymetry_A.nc already exists.
/lustre_scratch/orlando-code/datasets/processed/arrays/bottomT.nc already exists.
/lustre_scratch/orlando-code/datasets/processed/arrays/mlotst.nc already exists.
/lustre_scratch/orlando-code/datasets/processed/arrays/so.nc already exists.
/lustre_scratch/orlando-code/datasets/processed/arrays/thetao.nc already exists.
/lustre_scratch/orlando-code/datasets/processed/arrays/uo.nc already exists.
/lustre_scratch/orlando-code/datasets/processed/arrays/vo.nc already exists.
/lustre_scratch/orlando-code/datasets/processed/arrays/zos.nc already exists.


 Unpacking data_vars in /lustre_scratch/orlando-code/datasets/processed/combined_daily_1-12.nc: 100%|██████████| 9/9 [00:01<00:00,  6.86it/s]

/lustre_scratch/orlando-code/datasets/processed/arrays/bathymetry_A.nc already exists.





In [13]:
# Collect the per-variable NetCDF files that live in the "arrays" subdirectory.
nc_files = file_ops.return_list_filepaths(
    processed_path / "arrays",
    ".nc",
)

NameError: name 'upsample_xarray_to_target' is not defined

In [14]:
def ncfile_to_resolution_tensor(ncfile, target_resolution_d: float = target_resolution_d, name="") -> tf.Tensor:
    """Load a single-variable NetCDF file, resample it to ``target_resolution_d`` and return it as a tensor.

    Parameters
    ----------
    ncfile : str or Path
        NetCDF file holding one data variable (read with ``xa.load_dataarray``).
    target_resolution_d : float
        Target resolution in degrees, passed to ``spatial_data.upsample_xarray_to_target``.
        Defaults to the notebook-level ``target_resolution_d`` value at the time this function was defined.
    name : str, optional
        Name for the returned tensor. When empty, the file stem is used.

    Returns
    -------
    tf.Tensor
        The resampled values.
    """
    # Per the xarray docs, load_dataarray loads the data eagerly and closes the file.
    da = xa.load_dataarray(ncfile, engine="netcdf4")
    upsampled_da = spatial_data.upsample_xarray_to_target(da, target_resolution_d)
    # Use the caller-supplied name when given, otherwise fall back to the file stem.
    tensor_name = name or Path(ncfile).stem
    return tf.convert_to_tensor(upsampled_da.data, name=tensor_name)


def ncfiles_to_tf_dataset(ncfiles, target_resolution_d: float, name: str = None) -> tf.data.Dataset:
    """Build a ``tf.data.Dataset`` yielding one resampled float32 tensor per NetCDF file.

    Parameters
    ----------
    ncfiles : iterable of str or Path
        NetCDF files to load, in order. The generator re-iterates this collection each time the
        dataset is iterated, so pass a re-iterable object (e.g. a list), not a one-shot iterator.
    target_resolution_d : float
        Target resolution in degrees applied to every file.
    name : str, optional
        Name given to the dataset op.

    Returns
    -------
    tf.data.Dataset
        Elements of unknown shape and dtype float32.
    """
    def gen():
        for ncfile in ncfiles:
            # Forward the requested resolution explicitly instead of relying on the default.
            tensor = ncfile_to_resolution_tensor(ncfile, target_resolution_d)
            # Cast so every element matches the float32 signature declared below.
            yield tf.cast(tensor, tf.float32)

    return tf.data.Dataset.from_generator(
        gen,
        output_signature=tf.TensorSpec(shape=None, dtype=tf.float32),
        name=name,
    )


In [2]:
import numpy as np

In [5]:
# 31 evenly spaced values from 0 to 3 inclusive (step 0.1). linspace pins both endpoints, whereas
# arange with a float step can gain or lose the final point through rounding.
np.linspace(0, 3, 31)

array([0. , 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1. , 1.1, 1.2,
       1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2. , 2.1, 2.2, 2.3, 2.4, 2.5,
       2.6, 2.7, 2.8, 2.9, 3. ])

In [15]:
def generate_interp_lat_lons(
    xa_high_res, lat_lims: tuple = None, lon_lims: tuple = None, target_resolution_d: float = None
):
    """
    Generate latitude and longitude values for interpolation from optional limits and an optional resolution.

    Args:
        xa_high_res: Object indexable by "latitude" and "longitude" (e.g. an xarray Dataset/DataArray). Its
            coordinates are used wherever limits or a resolution are not supplied.
        lat_lims: Latitude limits (either order). If None, the full latitude extent of `xa_high_res` is used.
        lon_lims: Longitude limits (either order). If None, the full longitude extent of `xa_high_res` is used.
        target_resolution_d: Spacing (degrees) of the generated values. If None, the native coordinate values
            of `xa_high_res` (restricted to any given limits) are returned instead of a regular grid.

    Returns:
        lat_vals: Latitude values for interpolation.
        lon_vals: Longitude values for interpolation.
        When no limits and no resolution are given, these are `xa_high_res["latitude"]` and
        `xa_high_res["longitude"]` returned unchanged.
    """
    if not (lat_lims or lon_lims or target_resolution_d):
        # Nothing requested: reuse the native coordinates as-is.
        return xa_high_res["latitude"], xa_high_res["longitude"]

    lat_vals = _axis_interp_values(xa_high_res["latitude"], lat_lims, target_resolution_d)
    lon_vals = _axis_interp_values(xa_high_res["longitude"], lon_lims, target_resolution_d)
    return lat_vals, lon_vals


def _axis_interp_values(native_coord, lims, resolution):
    """Return interpolation points along one axis from optional (min, max) limits and an optional step."""
    if not resolution:
        # No step requested: keep the native values, trimmed to the limits if any were given.
        values = np.asarray(native_coord)
        if lims:
            lo, hi = min(lims), max(lims)
            values = values[(values >= lo) & (values <= hi)]
        return values

    if lims:
        lo, hi = min(lims), max(lims)
    else:
        lo, hi = float(np.min(native_coord)), float(np.max(native_coord))
    # Half a step of slack includes `hi` despite float rounding without overshooting by a full step.
    return np.arange(lo, hi + resolution / 2, resolution)

# No limits or resolution are passed, so the native grid of the 1000 m coral raster is reused.
lat_vals, lon_vals = generate_interp_lat_lons(
    xa.open_dataarray(processed_path / "arrays/coral_raster_1000m.nc")
)

In [19]:
def downsample_interp(xa_d, lat_vals, lon_vals, interp_method: str):
    """Interpolate ``xa_d`` onto the given latitude/longitude points.

    Parameters
    ----------
    xa_d : xa.DataArray or xa.Dataset
        Object with ``latitude`` and ``longitude`` coordinates.
    lat_vals, lon_vals : array-like
        Target coordinate values.
    interp_method : str
        Method forwarded to ``.interp`` (e.g. "nearest").

    Returns
    -------
    The result of ``xa_d.interp`` on the new grid.
    """
    target_coords = {"latitude": lat_vals, "longitude": lon_vals}
    return xa_d.interp(target_coords, method=interp_method)

# downsample_interp(xa.open_dataarray(processed_path / "arrays/bottomT.nc"), lat_vals, lon_vals, "nearest")

In [21]:
def load_upsample_nc_dir(
    nc_dir: Path | str,
    lat_vals: np.ndarray | list[float],
    lon_vals: np.ndarray | list[float],
    interp_method: str = "nearest",
) -> tf.data.Dataset:
    """
    Lazily load every NetCDF file in a directory, interpolate it onto a lat/lon grid and yield float32 tensors.

    Whether this up- or down-samples depends on `lat_vals`/`lon_vals` relative to each file's own grid.

    Parameters
    ----------
        nc_dir (Path | str): Directory containing the ``*.nc`` files; each is opened with
            ``xa.open_dataarray``, so it must hold a single data variable.
        lat_vals (array-like): Latitude values to interpolate onto.
        lon_vals (array-like): Longitude values to interpolate onto.
        interp_method (str, optional): Interpolation method passed to ``.interp``. Defaults to "nearest".

    Returns
    -------
        tf.data.Dataset: One float32 tensor per file, in a deterministic (non-shuffled) file order.
            Element shapes are not known statically and can differ between files.
    """
    def open_path(path_tensor: tf.Tensor) -> tf.Tensor:
        """Open one file, interpolate it onto the target grid and return its values as a float32 tensor."""
        file_path = path_tensor.numpy().decode()
        # The context manager closes the file handle once the values have been read.
        with xa.open_dataarray(file_path) as da:
            regridded = downsample_interp(da.astype("float32"), lat_vals, lon_vals, interp_method)
            # Force float32 so the result always matches the Tout declared in tf.py_function.
            values = np.asarray(regridded.values, dtype=np.float32)
        return tf.convert_to_tensor(values)

    # list_files shuffles file names by default; disable that so element i maps to the same file on every run.
    file_dataset = tf.data.Dataset.list_files(str(Path(nc_dir) / "*.nc"), shuffle=False)
    return file_dataset.map(lambda path: tf.py_function(open_path, [path], Tout=tf.float32))

In [22]:
test_tf_ds = load_upsample_nc_dir(
    processed_path / "arrays", lat_vals=lat_vals, lon_vals=lon_vals
)
test_tf_ds.element_spec

TensorSpec(shape=<unknown>, dtype=tf.float32, name=None)

In [23]:
# Shape of each per-file tensor (iterating loads and interpolates every file).
[element.shape for element in test_tf_ds]

HDF5-DIAG: Error detected in HDF5 (1.12.2) thread 1:
  #000: H5A.c line 528 in H5Aopen_by_name(): can't open attribute
    major: Attribute
    minor: Can't open object
  #001: H5VLcallback.c line 1091 in H5VL_attr_open(): attribute open failed
    major: Virtual Object Layer
    minor: Can't open object
  #002: H5VLcallback.c line 1058 in H5VL__attr_open(): attribute open failed
    major: Virtual Object Layer
    minor: Can't open object
  #003: H5VLnative_attr.c line 130 in H5VL__native_attr_open(): can't open attribute
    major: Attribute
    minor: Can't open object
  #004: H5Aint.c line 545 in H5A__open_by_name(): unable to load attribute info from object header
    major: Attribute
    minor: Unable to initialize object
  #005: H5Oattribute.c line 476 in H5O__attr_open_by_name(): can't open attribute
    major: Attribute
    minor: Can't open object
  #006: H5Adense.c line 394 in H5A__dense_open(): can't locate attribute in name index
    major: Attribute
    minor: Object not 

[TensorShape([780, 557, 9863]),
 TensorShape([780, 557, 9863]),
 TensorShape([780, 557]),
 TensorShape([780, 557, 9863]),
 TensorShape([780, 557, 9863]),
 TensorShape([780, 557, 9863]),
 TensorShape([780, 557]),
 TensorShape([780, 557, 9863]),
 TensorShape([780, 557, 9863]),
 TensorShape([780, 557, 9863]),
 TensorShape([780, 557, 9863])]

In [None]:
# Re-inspect the element signature; its shape is unknown because element shapes vary by file.
test_tf_ds.element_spec

In [None]:
# Generator-backed dataset built from the list of per-variable files.
tf_dataset = ncfiles_to_tf_dataset(
    nc_files,
    target_resolution_d=target_resolution_d,
    name="upsampled",
)

In [None]:
# Element signature of the generator-backed dataset.
tf_dataset.element_spec

In [None]:
# Size of the leading axis of each per-file tensor.
[element.shape[0] for element in test_tf_ds]

In [None]:
# NOTE(review): duplicates an earlier `tf_dataset.element_spec` cell; consider deleting one of them.
tf_dataset.element_spec

In [None]:
# tf.data batching demo on the public Titanic CSV (downloaded via Keras).
titanic_file = tf.keras.utils.get_file(
    "train.csv",
    "https://storage.googleapis.com/tf-datasets/titanic/train.csv",
)
titanic_lines = tf.data.TextLineDataset(titanic_file)

def plot_batch_sizes(ds, ax=None):
    """Draw a bar chart of the leading-dimension size of every element in ``ds``.

    Parameters
    ----------
    ds : iterable
        Iterable of batches exposing ``.shape`` (e.g. a batched ``tf.data.Dataset``).
    ax : matplotlib Axes, optional
        Axes to draw on. Defaults to the current axes (``plt.gca()``), which is where the
        pyplot-style calls drew previously.

    Returns
    -------
    matplotlib Axes
        The axes that were drawn on, so callers can add titles or further styling.
    """
    if ax is None:
        ax = plt.gca()
    batch_sizes = [batch.shape[0] for batch in ds]
    ax.bar(range(len(batch_sizes)), batch_sizes)
    ax.set_xlabel('Batch number')
    ax.set_ylabel('Batch size')
    return ax

# Repeat the lines three times, group them into batches of 128 and plot each batch's size.
titanic_batches = titanic_lines.repeat(count=3).batch(batch_size=128)
plot_batch_sizes(titanic_batches)

In [None]:
# Dataset manager from coralshift.dataloading.data_structure.
# NOTE(review): `ds_man` only appears in commented-out code in the next cell; this cell may be removable.
ds_man = data_structure.MyDatasets()

In [None]:
# NOTE(review): `chunks` is not used by any cell in this notebook.
chunks = {"latitude": 10, "longitude": 10, "time": 100}

# Climate data (daily).
# TODO: spatial_padded_1.nc needs renaming and putting with dailies. Get rid of coral_climate_1_12.nc
# and replace with bathA area only.
climate_daily_1_12_padded = spatial_data.process_xa_d(
    xa.open_dataset(directories.get_monthly_cmems_dir() / "spatial_padded_1.nc"),
    squeeze_coords=False,
)

# Reef ground truth: the coral/algae layer cropped to the study box, first time step only.
coral_gt = (
    spatial_data.process_xa_d(
        xa.open_dataset(directories.get_monthly_cmems_dir() / "coral_climate_1_12.nc"),
        squeeze_coords=False,
    )["coral_algae_1-12_degree"]
    .sel({"latitude": slice(-17.01, -9.99), "longitude": slice(141.99, 147.01)})
    .isel(time=0)
)

# Bathymetry (the file name indicates it was already upsampled), with the size-1 "band" dimension dropped.
bath_area = xa.open_dataset(
    directories.get_bathymetry_datasets_dir() / "bathymetry_A_0-08333_upsampled.nc"
).squeeze("band")

In [None]:
def add_data_to_chunk(xa_chunk_data, xa_new_data, interp_method: str="nearest", broadcast: bool=True):
    """Regrid ``xa_new_data`` onto the lat/lon grid of ``xa_chunk_data`` and combine the two.

    Parameters
    ----------
    xa_chunk_data : xa.Dataset
        Reference data whose ``latitude``/``longitude`` coordinates define the target grid.
    xa_new_data : xa.Dataset or xa.DataArray
        Data interpolated onto that grid and combined in.
    interp_method : str, optional
        Method forwarded to ``.interp``. Default "nearest".
    broadcast : bool, optional
        If True, pass the combined result through ``xa.broadcast``.

    Returns
    -------
    The output of ``xa.combine_by_coords`` (join="inner"), optionally broadcast.
    """
    # TODO: add in checking about relative resolutions
    # TODO: issue with time broadcasting?
    regridded = xa_new_data.interp(
        {"latitude": xa_chunk_data["latitude"], "longitude": xa_chunk_data["longitude"]},
        method=interp_method,
    )
    merged_ds = xa.combine_by_coords(
        [xa_chunk_data, regridded], coords=["latitude", "longitude", "time"], join="inner"
    )
    if not broadcast:
        return merged_ds
    # xa.broadcast returns a tuple with one entry per input; there is a single input here.
    return xa.broadcast(merged_ds)[0]

In [None]:
# Merge the ground truth, then the bathymetry, onto the climate data grid.
gt_climate = add_data_to_chunk(climate_daily_1_12_padded, coral_gt)
all_data = add_data_to_chunk(gt_climate, bath_area)

# Save the combined dataset; return_array=True also hands the data back for later cells.
xa_all_1_12 = file_ops.save_nc(
    directories.get_processed_dir(),
    "combined_daily_1-12",
    all_data,
    return_array=True,
)


In [None]:
# Coral/algae ground truth at the first time step.
xa_all_1_12["coral_algae_1-12_degree"].isel(time=0).plot()

In [None]:
# Inspect the cropped coral ground-truth layer.
coral_gt

## Convert to NumPy arrays for machine learning

In [None]:
def generate_patch(
    xa_ds: xa.DataArray | xa.Dataset,
    lat_lon_starts: tuple[float, float],
    coord_lengths: tuple[int, int] | None = None,
    coord_range: tuple[float, float] | None = None,
    feature_vars: list[str] = ["bottomT", "so", "mlotst", "uo", "vo", "zos", "thetao"],
    gt_var: str = "coral_algae_1-12_degree",
    normalise: bool = True,
    onehot: bool = True,
) -> tuple[tuple[np.ndarray, ...], xa.Dataset | xa.DataArray, dict]:
    """Generate a patch for training or evaluation.

    Parameters
    ----------
    xa_ds (xa.Dataset): The input xarray dataset.
    lat_lon_starts (tuple): Start of the window: integer indices when `coord_range` is None,
        coordinate values (degrees) otherwise.
    coord_lengths (tuple[int, int], optional): Number of latitude and longitude cells in the window.
        When None, the default of `sample_spatial_batch` is used.
    coord_range (tuple[float, float], optional): Signed latitude/longitude extent in degrees added to
        `lat_lon_starts`. When given, `coord_lengths` is ignored.
    feature_vars (list[str], optional): Variable names to use as features.
        Default is ["bottomT", "so", "mlotst", "uo", "vo", "zos", "thetao"]. The list is never mutated here;
        pass None to skip features entirely.
    gt_var (str, optional): The variable name for the ground truth. Default is "coral_algae_1-12_degree".
    normalise (bool, optional): Whether to normalise each feature variable between 0 and 1. Default is True.
    onehot (bool, optional): Whether to encode NaN values using the one-hot method. Default is True.

    Returns
    -------
    tuple: (output, subsample, lat_lon_vals_dict) where
        output is the tuple from `process_xa_ds_for_ml` (feature array and/or ground-truth array),
        subsample is the windowed xarray object, and
        lat_lon_vals_dict maps "latitude", "longitude" and "time" to the window's coordinate values.
    """
    # Only forward coord_lengths when given, so sample_spatial_batch keeps its own default window size
    # instead of receiving None (which fails when indexing the window length without a coord_range).
    window_kwargs = {"lat_lon_starts": lat_lon_starts, "coord_range": coord_range}
    if coord_lengths is not None:
        window_kwargs["coord_lengths"] = coord_lengths
    subsample, lat_lon_vals_dict = sample_spatial_batch(xa_ds, **window_kwargs)

    output = process_xa_ds_for_ml(
        xa_ds=subsample,
        feature_vars=feature_vars,
        gt_var=gt_var,
        normalise=normalise,
        onehot=onehot,
    )

    return output, subsample, lat_lon_vals_dict

def sample_spatial_batch(
    xa_d: xa.Dataset,
    lat_lon_starts: tuple = (0, 0),
    coord_lengths: tuple[int, int] = (6, 6),
    coord_range: tuple[float] = None,
    variables: list[str] = None,
) -> tuple[xa.Dataset, dict[str, np.ndarray]]:
    """Sample a spatial window from an xarray Dataset and report its coordinate values.

    Parameters
    ----------
    xa_d (xa.Dataset): The input xarray Dataset.
    lat_lon_starts (tuple): Start of the window: integer indices when `coord_range` is None,
        coordinate values (degrees) otherwise.
    coord_lengths (tuple[int, int]): Number of latitude and longitude cells in the window.
        Ignored when `coord_range` is provided.
    coord_range (tuple[float], optional): Signed latitude and longitude extent (degrees) of the window.
        If provided, it overrides `coord_lengths`.
    variables (list[str], optional): Variable names to keep. If None, all variables are kept.

    Returns
    -------
    tuple: (subsample, coord_values) where subsample is the windowed xarray object and coord_values maps
        "latitude", "longitude" and "time" to NumPy arrays of the window's coordinate values.
        The subsample must have a "time" coordinate, otherwise indexing it raises.
    """
    # if selection of variables specified
    if variables is not None:
        xa_d = xa_d[variables]

    subsample = xa_region_from_window(xa_d, lat_lon_starts, coord_lengths, coord_range)

    coord_values = {coord: subsample[coord].values for coord in ("latitude", "longitude", "time")}
    return subsample, coord_values


def xa_region_from_window(
    xa_d: xa.Dataset | xa.DataArray,
    lat_lon_starts: tuple = (0, 0),
    coord_lengths: tuple[int, int] = (6, 6),
    coord_range: tuple[float, float] = None,
):
    """Cut a rectangular latitude/longitude window out of an xarray Dataset or DataArray.

    Parameters
    ----------
        xa_d (xarray.Dataset or xarray.DataArray): The object to sample from.
        lat_lon_starts (tuple, optional): Start of the window. Interpreted as integer *indices* when
            `coord_range` is not given, and as coordinate *values* (degrees) when it is. Default is (0, 0).
        coord_lengths (tuple[int, int], optional): Number of latitude and longitude cells in the window,
            counted from the start indices. Only used when `coord_range` is not given. Default is (6, 6).
        coord_range (tuple[float, float], optional): Signed latitude and longitude extent in degrees, added
            to the start values; negative values extend towards smaller coordinates. When provided,
            `coord_lengths` is ignored. Default is None.

    Returns
    -------
        xarray.Dataset or xarray.DataArray: The sampled window.
    """
    lat_start, lon_start = lat_lon_starts[0], lat_lon_starts[1]
    if coord_range:
        # Order each pair so negative extents still produce an increasing (start, stop) slice.
        lat_bounds = sorted((lat_start, lat_start + coord_range[0]))
        lon_bounds = sorted((lon_start, lon_start + coord_range[1]))
        # NOTE(review): label slices assume ascending coordinates; a descending axis would select
        # nothing here — confirm the data's ordering.
        return xa_d.sel(
            {
                "latitude": slice(*lat_bounds),
                "longitude": slice(*lon_bounds),
            }
        )
    # The window spans `coord_lengths` cells from the start index, so the stop index is start + length;
    # using the length alone as the stop would shrink or empty any window not starting at 0.
    return xa_d.isel(
        {
            "latitude": slice(lat_start, lat_start + coord_lengths[0]),
            "longitude": slice(lon_start, lon_start + coord_lengths[1]),
        }
    )


def process_xa_ds_for_ml(
    xa_ds: xa.Dataset,
    feature_vars: list[str] = None,
    gt_var: str = None,
    normalise: bool = True,
    onehot: bool = True,
) -> tuple[np.ndarray, ...]:
    """
    Process xarray Dataset for machine learning.

    Parameters
    ----------
    xa_ds : xa.Dataset
        The input xarray dataset.
    feature_vars : list[str], optional
        List of variable names to be used as features. If None, no feature array is produced.
    gt_var : str, optional
        The variable name for the ground truth. If None/empty, no ground-truth array is produced.
    normalise : bool, optional
        Whether to normalise each feature variable between 0 and 1. Default is True.
    onehot : bool, optional
        Whether to encode NaN values using the one-hot method. Default is True.

    Returns
    -------
    tuple[np.ndarray, ...]
        (features, ground_truth), or only whichever of the two was requested; empty if neither.
    """
    # NOTE(review): xa_d_to_np_array, spatial_array_to_column, normalise_3d_array, exclude_all_nan_dim,
    # encode_nans_one_hot and naive_nan_replacement are not defined or imported in this notebook;
    # presumably they live in coralshift.processing.spatial_data — confirm and qualify them.
    to_return = []
    if feature_vars is not None:
        # Stack the feature variables into an array and flatten lat, lon into a lat x lon column axis.
        Xs = spatial_array_to_column(xa_d_to_np_array(xa_ds[feature_vars]))

        if normalise:
            Xs = normalise_3d_array(Xs)
        # Remove columns containing only NaNs. Keep the filtered array on both paths (one-hot or not)
        # so all-NaN columns are always dropped. TODO: enable removal of all nan dims
        Xs = exclude_all_nan_dim(Xs, dim=1)

        if onehot:
            Xs = encode_nans_one_hot(Xs)
        to_return.append(naive_nan_replacement(Xs))

    if gt_var:
        # Ground truth flattened the same way as the features.
        ys = spatial_array_to_column(xa_d_to_np_array(xa_ds[gt_var]))
        to_return.append(naive_nan_replacement(ys))

    return tuple(to_return)

In [None]:
# Train and test patches; lat_lon_starts / coord_range are interpreted by spatial_data.generate_patch.
(train_Xs_lstm, train_ys_lstm), train_subsample, train_lat_lons_vals_dict = spatial_data.generate_patch(
    xa_ds=xa_all_1_12,
    lat_lon_starts=(-10, 142),
    coord_range=(-3, 3),
    onehot=True,
)
(test_Xs_lstm, test_ys_lstm), test_subsample, test_lat_lons_vals_dict = spatial_data.generate_patch(
    xa_ds=xa_all_1_12,
    lat_lon_starts=(-14, 145),
    coord_range=(-1, 1),
    onehot=True,
)

print("train_Xs_lstm shape: ", train_Xs_lstm.shape, "train_ys_lstm shape: ", train_ys_lstm.shape)
print("test_Xs_lstm shape: ", test_Xs_lstm.shape, "test_ys_lstm shape: ", test_ys_lstm.shape)

In [None]:
def _to_square_grid(arr: np.ndarray) -> np.ndarray:
    """Reshape ``arr`` so its first axis becomes two equal spatial axes, keeping trailing axes as-is.

    Assumes axis 0 enumerates the cells of a square lat x lon patch flattened in row-major order
    (TODO confirm against spatial_data.generate_patch).

    Raises
    ------
    ValueError
        If the length of axis 0 is not a perfect square.
    """
    import math

    n_cells = arr.shape[0]
    side = math.isqrt(n_cells)
    if side * side != n_cells:
        raise ValueError(
            f"Expected a square number of grid cells on axis 0, got {n_cells}; cannot reshape to (side, side, ...)."
        )
    return arr.reshape(side, side, *arr.shape[1:])


train_Xs_convlstm = _to_square_grid(train_Xs_lstm)
train_ys_convlstm = _to_square_grid(train_ys_lstm)
test_Xs_convlstm = _to_square_grid(test_Xs_lstm)
test_ys_convlstm = _to_square_grid(test_ys_lstm)

print("train_Xs_convlstm shape: ", train_Xs_convlstm.shape, "train_ys_convlstm shape: ", train_ys_convlstm.shape)
print("test_Xs_convlstm shape: ", test_Xs_convlstm.shape, "test_ys_convlstm shape: ", test_ys_convlstm.shape)


In [None]:
# Bathymetry over the test patch at the first time step.
test_subsample.isel(time=0)["bathymetry_A"].plot()

In [None]:
# Presumably reshapes flattened training arrays back onto a 72 x 72 lat/lon grid.
# NOTE(review): `train_Xs` / `train_ys` are not defined anywhere in this notebook (likely stale kernel
# state), so this cell fails on a fresh kernel. The sizes 72, 336 and 8 are hard-coded — presumably grid
# side, time steps and feature count; confirm, or derive them from the array shapes instead.
spatial_Xs = train_Xs.reshape((72,72,336,8))
spatial_ys = train_ys.reshape((72,72))

print("spatial_Xs shape: ", spatial_Xs.shape)
print("spatial_ys shape: ", spatial_ys.shape)