In [1]:
%load_ext autoreload
%autoreload 2

import sys
import os

sys.path.append("../")
# TODO: hacky, shouldn't be necessary
os.chdir("/lustre_scratch/orlando-code/coralshift/")
os.environ["WANDB_NOTEBOOK_NAME"] = "lustre_scratch/coralshift/notebooks/rnn.ipynb"

In [5]:
from __future__ import annotations

from pathlib import Path
import xarray as xa
import numpy as np
import math as m
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.patches as patches

import wandb
from tqdm import tqdm
from sklearn import model_selection
from sklearn.preprocessing import normalize
from scipy.interpolate import interp2d
from sklearn.utils import class_weight


import rasterio
from rasterio.plot import show
import rioxarray as rio

#issues with numpy deprecation in pytorch_env
from coralshift.processing import spatial_data
from coralshift.utils import file_ops, directories
from coralshift.plotting import spatial_plots
from coralshift.dataloading import data_structure, climate_data

## Load in data

In [6]:
ds_man = data_structure.MyDatasets()

# add datasets
ds_man.set_location("remote")

ds_man.add_dataset(
    "monthly_climate_1_12", xa.open_dataset(
        ds_man.get_location() / "global_ocean_reanalysis/monthly_means/coral_climate_1_12.nc")
)

coral_climate_feature_vars = list(
    set(ds_man.get_dataset("monthly_climate_1_12").data_vars) - {'spatial_ref', 'coral_algae_1-12_degree', 'output'})
ds_man.add_dataset(
    "monthly_climate_features", ds_man.get_dataset("monthly_climate_1_12")[coral_climate_feature_vars]
)

# TODO: sort numpy assignment with new one-hot encoding and noormalisation
ds_man.add_dataset(
    "monthly_climate_1_12_X_y_np", spatial_data.filter_out_nans(
        spatial_data.xa_ds_to_3d_numpy(ds_man.get_dataset("monthly_climate_1_12")), 
        np.array(ds_man.get_dataset("monthly_climate_1_12")["coral_algae_1-12_degree"].isel(time=-1)).reshape(-1, 1))
)

ds_man.add_dataset(
    "monthly_climate_1_12_X_np", ds_man.get_dataset("monthly_climate_1_12_X_y_np")[0]
)

ds_man.add_dataset(
    "monthly_climate_1_12_y_np", ds_man.get_dataset("monthly_climate_1_12_X_y_np")[1]
)

# TODO: handle depth
ds_man.add_dataset(
    "daily_climate_1_12", xa.open_dataset(
        Path(ds_man.get_location() / "global_ocean_reanalysis/daily_means/dailies_combined.nc")).isel(depth=0)
)

ds_man.add_dataset(
    "daily_climate_1_12_X", 
)

# same target as monthly
ds_man.add_dataset(
    "daily_climate_1_12_y_np", ds_man.get_dataset("monthly_climate_1_12_y_np")
)

ds_man.add_dataset(
    "bathymetry_A", rio.open_rasterio(
        rasterio.open(ds_man.get_location() / "bathymetry/GBR_30m/Great_Barrier_Reef_A_2020_30m_MSL_cog.tif"),
        ).rename("bathymetry_A").rename({"x": "longitude", "y": "latitude"})
)

converting xarray Dataset to numpy arrays: 100%|██████████| 13/13 [00:00<00:00, 17453.89it/s]


TypeError: MyDatasets.add_dataset() missing 1 required positional argument: 'data'

In [7]:
ds_daily_1_12 = ds_man.get_dataset("daily_climate_1_12")
ds_daily_1_12

In [None]:
# daily climate to numpy array

In [None]:
monthly_climate = ds_man.get_dataset("monthly_climate_1_12")
nan_eg, _ = sample_spatial_batch(monthly_climate, window_dims=(10,10))

In [None]:
nan_eg

In [None]:
nan_eg["bottomT"].isel(time=-1).plot()

In [None]:
array = spatial_data.xa_ds_to_3d_numpy(nan_eg, 
    exclude_vars = ["spatial_ref", "coral_algae_1-12_degree", "latitude", "longitude", "depth", "time"])


In [None]:
one_hot_nans = spatial_data.encode_nans_one_hot(array).shape

In [None]:
# TODO: fix boolean indexing
# # For now, shallowest depth is taken (0.45)
# # TODO: process this and export it to new file since takes a while to run
# ds_man.add_dataset(
#     "daily_climate_1_12_X_np", filter_out_nans(
#         spatial_data.xa_ds_to_3d_numpy(ds_man.get_dataset("daily_climate_1_12").isel(depth=0)), ds_man.get_dataset("daily_climate_1_12_y_np"))[0]
# )

In [None]:
# TODO: put this merge into data processing pipeline
# merge daily mean files
# var_daily_dir = Path("lustre_scratch/datasets/global_ocean_reanalysis/daily_means")
# save_combined_dailies_path = Path("lustre_scratch/datasets/global_ocean_reanalysis/daily_means/dailies_combined.nc")
# daily_file_paths = file_ops.return_list_filepaths(var_daily_dir, ".nc")
# combined_dailies = xa.open_mfdataset(daily_file_paths)
# combined_dailies.to_netcdf(save_combined_dailies_path)

In [None]:
# # create 3D array from xarray dataset variables. Shape: (num_samples, num_parameters, sequence_len)
# X_with_nans = spatial_data.xa_ds_to_3d_numpy(xa_coral_climate_1_12_features)
# print(f'X_with_nans shape (num_samples: {X_with_nans.shape[0]}, total num_parameters (includes nans parameters): {X_with_nans.shape[1]}, sequence_len: {X_with_nans.shape[2]})')

# for i, param in enumerate(xa_coral_climate_1_12_features.data_vars):
#     print(f"{i}: {param}")

#### Remove observations for which there are nan values

99% sure these are are just gridcells containing land. Would be a good thing to investigate, however.

In [None]:
# X = X_with_nans
# # problem, probably with sea ice features

In [None]:
# # filter out columns that contain entirely NaN values
# col_mask = ~np.all(np.isnan(X), axis=(0,2)) # boolean mask indicating which columns to keep
# masked_cols = X[:, col_mask, :] # keep only the columns that don't contain entirely NaN values
# print("masked_cols shape:", masked_cols.shape)

In [None]:
# # filter out all rows which contain any NaN values
# row_mask = ~np.any(np.isnan(masked_cols), axis=1) # boolean mask indicating which rows to keep
# masked_cols_rows = masked_cols[row_mask[:,0], :, :] # keep only the rows that don't contain any NaN values
# masked_cols_rows.shape

In [None]:
# # filter out all depths which contain any NaN values
# depth_mask = ~np.any(np.isnan(masked_cols_rows), axis=(0,1)) # boolean mask indicating which depths to keep
# X = masked_cols_rows[:, :, depth_mask] # keep only the depths that don't contain any NaN values
# X = np.swapaxes(X, 1, 2)
# print(f"X shape: {X.shape}")

In [None]:
# # create target from coral ground truth. Shape: (num_samples, 1)
# # TODO: not sure if this is shuffling the values when reshaping
# y_with_nans = np.array(xa_coral_climate_1_12["coral_algae_1-12_degree"].sel(
#     time=xa_coral_climate_1_12.time[-1])).reshape(-1, 1)
# # remove ys with nan values in other variables
# y = y_with_nans[row_mask[:,0]]

# print(f"y_with_nans shape: {y_with_nans.shape}")
# print(f"y shape: {y.shape}")

In [None]:
# X, y = filter_out_nans(X_with_nans, np.array(xa_coral_climate_1_12["coral_algae_1-12_degree"].isel(time=-1)).reshape(-1, 1))
# print(f"X shape: {X.shape}")
# print(f"y shape: {y.shape}")

In [None]:
# X.shape

## GRU function definitions 

In [11]:
xa_coral_climate_1_12_features = ds_man.get_dataset("monthly_climate_features")
xa_coral_climate_1_12 = ds_man.get_dataset("monthly_climate_1_12")

xa_coral_climate_1_12_working = xa_coral_climate_1_12

In [31]:
xa_coral_climate_1_12

In [119]:
def check_ds(obj):
    if isinstance(obj, xa.Dataset):
        return True
    else:
        return False


def check_da(obj):
    if isinstance(obj, xa.DataArray):
        return True
    else:
        return False

def sample_spatial_batch(
    xa_ds: xa.Dataset,
    lat_lon_starts: tuple = (0, 0),
    window_dims: tuple[int, int] = (6, 6),
    coord_range: tuple[float] = None,
    variables: list[str] = None,
) -> np.ndarray:
    """Sample a spatial batch from an xarray Dataset.

    Parameters
    ----------
    xa_ds (xa.Dataset): The input xarray Dataset.
    lat_lon_starts (tuple): Tuple specifying the starting latitude and longitude indices of the batch.
    window_dims (tuple[int, int]): Tuple specifying the dimensions (number of cells) of the spatial window.
    coord_range (tuple[float], optional): Tuple specifying the latitude and longitude range (in degrees) of the spatial
        window. If provided, it overrides the window_dims parameter.
    variables (list[str], optional): List of variable names to include in the spatial batch. If None, includes all
        variables.

    Returns
    -------
    np.ndarray: The sampled spatial batch as a NumPy array.

    Notes
    -----
    - The function selects a subsample of the input dataset based on the provided latitude, longitude indices, and window
        dimensions.
    - If a coord_range is provided, it is used to compute the latitude and longitude indices of the spatial window.
    - The function returns the selected subsample as a NumPy array.

    Example
    -------
    # Sample a spatial batch from an xarray Dataset
    dataset = ...
    lat_lon_starts = (2, 3)
    window_dims = (6, 6)
    coord_range = (2.5, 3.5)
    variables = ['var1', 'var2', 'var3']
    spatial_batch = sample_spatial_batch(dataset, lat_lon_starts, window_dims, coord_range, variables)
    """
    # if selection of variables specified
    if variables is not None:
        xa_ds = xa_ds[variables]

    # N.B. have to be careful when providing coordinate ranges for areas with negative coords. TODO: make universal
    lat_start, lon_start = lat_lon_starts[0], lat_lon_starts[1]

    if not coord_range:
        subsample = xa_ds.isel(
            {
                "latitude": slice(lat_start, window_dims[0]),
                "longitude": slice(lon_start, window_dims[1]),
            }
        )
    else:
        lat_cells, lon_cells = coord_range[0], coord_range[1]
        subsample = xa_ds.sel(
            {
                "latitude": slice(lat_start, lat_start + lat_cells),
                "longitude": slice(lon_start, lon_start + lon_cells),
            }
        )

    lat_slice = subsample["latitude"].values
    lon_slice = subsample["longitude"].values
    time_slice = subsample["time"].values

    return subsample, {"latitude": lat_slice, "longitude": lon_slice, "time": time_slice}

    
def generate_patch(
    xa_ds,
    lat_lon_starts,
    coord_range,
    feature_vars: list[str] = ["bottomT", "so", "mlotst", "uo", "vo", "zos", "thetao"],
    gt_var: str = "coral_algae_1-12_degree",
    normalise: bool = True,
    onehot: bool=True
):
    """Generate a patch for training or evaluation.

    Parameters
    ----------
    xa_ds (xa.Dataset): The input xarray dataset.
    lat_lon_starts (tuple): The starting latitude and longitude indices for sampling the patch.
    coord_range (tuple): The latitude and longitude range for sampling the patch.
    feature_vars (list[str], optional): List of variable names to be used as features. 
        Default is ["bottomT", "so", "mlotst", "uo", "vo", "zos", "thetao"].
    gt_var (str, optional): The variable name for the ground truth. Default is "coral_algae_1-12_degree".
    normalise (bool, optional): Flag indicating whether to normalize each variable between 0 and 1. Default is True.
    onehot (bool, optional): Flag indicating whether to encode NaN values using the one-hot method. Default is True.

    Returns
    -------
    tuple: A tuple containing the feature array, ground truth array, subsampled dataset, and latitude/longitude values.
    """
    subsample, lat_lon_vals_dict = spatial_data.sample_spatial_batch(
        xa_ds, lat_lon_starts=lat_lon_starts, coord_range=coord_range
    )
    # assign features
    Xs = xa_d_to_np_array(subsample[feature_vars])
    # assign ground truth
    ys = xa_d_to_np_array(subsample[gt_var])

    # convert to column vectors
    Xs, ys = spatial_array_to_column(Xs), spatial_array_to_column(ys)

    # if normalise = True, normalise each variable between 0 and 1
    if normalise:
        Xs = spatial_data.normalise_3d_array(Xs)

    # remove columns containing only nans. TODO: enable all nan dims
    nans_array = exclude_all_nan_dim(Xs, dim=1)

    # if encoding nans using onehot method
    if onehot:
        Xs = encode_nans_one_hot(nans_array)
    Xs = naive_nan_replacement(Xs)
    
    # this shouldn't ever be necessary
    ys = naive_nan_replacement(ys)
    # take single time slice (since broadcasted back through time)
    ys = ys[:, 0]

    return Xs, ys, subsample, lat_lon_vals_dict


def exclude_all_nan_dim(array, dim):
    # TODO: check performance. Currently only able to hand columns (see generalisation comment)
    # filter out columns that contain entirely NaN values
    num_dims = len(array.shape)
    axes = tuple(set(np.arange(0, num_dims)) - {dim})

    dim_mask = ~np.all(
        np.isnan(array), axis=axes
    )  # boolean mask indicating which columns to keep
    # TODO: need to generalise this
    return array[
        :, dim_mask, :
    ]  # keep only the columns that don't contain entirely NaN values


def encode_nans_one_hot(array: np.ndarray, all_nan_dims: int = 1) -> np.ndarray:
    """One-hot encode NaN values in a 3D array.

    Parameters
    ----------
    array (np.ndarray) The input 3D array.
    all_nan_dims (int, optional): The number of dimensions (starting from the second dimension) to consider when 
        determining if all values are NaN. Default is 1.

    Returns
    -------
    np.ndarray: The one-hot encoded array with NaN information.
    """
    # boolean mask of land (where all variable values are nan throughout all time)
    land_mask = np.all(np.isnan(array), (1, 2))
    # binary land mask
    onehot_column = np.where(land_mask, 1, 0)
    # binary land mask expanded to target dimensions
    onehot_expanded = np.expand_dims(onehot_column, axis=(1, 2))
    # binary land mask broadcast back through time
    onehot_broadcast = np.repeat(onehot_expanded, array.shape[1], axis=1)

    return np.concatenate((array,onehot_broadcast), axis=2)


def subsample_to_array(
    xa_ds,
    lat_lon_starts,
    coord_range,
    variables: list[str]
):
    subsample, lat_lon_vals_dict = sample_spatial_batch(
        xa_ds[variables], lat_lon_starts=lat_lon_starts, coord_range=coord_range
    )
    return xa_d_to_np_array(subsample), subsample, lat_lon_vals_dict


def xa_d_to_np_array(
    xa_d: xa.Dataset | xa.DataArray
) -> np.ndarray:
    """Converts an xarray dataset or data array to a NumPy array.

    Parameters
    ----------
    xa_d (xarray.Dataset or xarray.DataArray): The xarray dataset or data array to convert.

    Returns
    -------
    np.ndarray: The converted NumPy array.

    Raises
    ------
    TypeError: If the provided object is neither an xarray Dataset nor an xarray DataArray.
    """
    # if xa.DataArray
    if check_da(xa_d):
        return np.array(xa_d.values)

    # else if dataset
    elif check_ds(xa_d):
        # transpose coordinates for consistency
        ds = xa_d.transpose("latitude","longitude","time")
        # send to array
        array = ds.to_array().values
        # reorder dimensions to (lat x lon x var x time)
        return np.moveaxis(array, 0, 3)
    else:
        return TypeError("object provideed was neither an xarray Dataset nor xarray DataArray.")


def spatial_array_to_column(array: np.ndarray) -> np.ndarray:
    """Reshape the first two dimensions of a 3D NumPy array to a column vector.

    Parameters
    ----------
    array (np.ndarray): The input 3D NumPy array.

    Returns
    -------
    np.ndarray: The reshaped column vector.

    Examples
    --------
    array = np.random.rand(lat, lon, var, ...)
    column_vector = spatial_array_to_column(array)
    print(column_vector.shape)
    (lat x lon, var, ...)
    """
    array_shape = array.shape
    new_shape = (array_shape[0] * array_shape[1], *array.shape[2:])
    return np.reshape(array, new_shape)


def naive_nan_replacement(array:np.ndarray, replacement: float=0) -> np.ndarray:
    """Replace NaN values in a NumPy array with a specified replacement value.

    Parameters
    ----------
    array (np.ndarray): The input array.
    replacement (float, optional): The value to replace NaNs with. Default is 0.

    Returns
    -------
    np.ndarray: The array with NaN values replaced.
    """
    # replace nans with "replacement"
    array[np.isnan(array)] = 0
    return array

In [123]:
# all_Xs_onehot, all_lat_lon_dict_onehot = sample_spatial_batch(xa_coral_climate_1_12, lat_lon_starts=(-8,140), coord_range=(-20,13))
# all_Xs_onehot, all_lat_lon_dict_onehot = subsample_to_array(xa_coral_climate_1_12, lat_lon_starts=(-8,140), coord_range=((-20,13)))
# all_Xs_onehot = naive_X_nan_replacement(all_Xs_onehot)
# all_ys_onehot, _ = subsample_to_array(xa_coral_climate_1_12, lat_lon_starts=(-8,140), coord_range=(-20,13), variables = ["coral_algae_1-12_degree"])
# all_ys_onehot = naive_y_nan_replacement(all_ys_onehot)
# all_ys_onehot = all_ys_onehot[:,:,0]

train_onehot_Xs, train_onehot_ys, train_onehot_subsample, train_onehot_lat_lons_vals_dict = generate_patch(xa_ds=xa_coral_climate_1_12, lat_lon_starts=(-10,142), coord_range=(-6,6), onehot=False)
test_onehot_Xs, test_onehot_ys, test_onehot_subsample, test_onehot_lat_lons_vals_dict = generate_patch(xa_ds=xa_coral_climate_1_12, lat_lon_starts=(-16,148), coord_range=(-6,6))


In [124]:
train_onehot_Xs.shape

(5184, 336, 7)

In [125]:
train_onehot_Xs[:5,0,:]

array([[0.89920235, 0.49554792, 0.01470677, 0.58688444, 0.48317835,
        0.6076122 , 0.72998065],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        ],
       [0.89826053, 0.5114261 , 0.02744764, 0.58721715, 0.46181637,
        0.61703223, 0.72711533],
       [0.89819086, 0.51152647, 0.03303663, 0.5938123 , 0.4667428 ,
        0.6196865 , 0.72736686]], dtype=float32)

In [23]:
xa_coral_climate_1_12

In [None]:
print("train_onehot_Xs shape: ", train_onehot_Xs.shape)
print("train_onehot_ys shape: ", train_onehot_ys.shape)

In [None]:
# load bathymetry
bath_A = ds_man.get_dataset("bathymetry_A")

# 1 km. Struggles displaying/processing 100m, but have yet to try saving to this/inferring
target_resolution = 1000
_,_,av_degrees = spatial_data.distance_to_degrees(target_resolution)
bath_A_1km = spatial_data.upsample_xarray_to_target(bath_A, av_degrees)
# im = bath_A_1km.plot(ax=ax)

spatial_plots.plot_DEM(bath_A_1km, f" DEM upsampled to {target_resolution} meters", vmin=-100, vmax=0)
# spatial_plots.format_spatial_plot(im, fig, ax, f"Upsampled to {target_resolution} degrees")

In [None]:
# def min_max_index_extreme_values(array, min_val=-100, max_val=0):
#     # Find indices of values above min and below max values
#     above_max_indices = np.where(array > max_val)
#     below_min_indices = np.where(array < min_val)

#     # Find the minimum and maximum indices among the above criteria
#     lon_min_index = np.min(np.concatenate((above_max_indices[1], below_min_indices[1])))
#     lat_min_index = np.max(np.concatenate((above_max_indices[1], below_min_indices[1])))
#     return lon_min_index, lat_min_index


# def slice_coast(bath_ds, num_slices=10, bathymetric_range: tuple[float] = (-100,0)):

#     # lat_lims = spatial_data.xarray_coord_limits(bath_ds, "latitude")
#     # could use values or indices


#     # slice_ranges = np.linspace(lat_lims[0], lat_lims[1], num_slices)
#     lat_num = len(list(bath_ds["latitude"]))
#     slice_ranges = np.arange(0, lat_num, (lat_num // num_slices))
#     # slice_height = int(np.floor(np.abs(np.diff(lat_lims)) / num_slices))
#     # for vertical centre of each slice, find limits of relevant bathymetry
#     for i in range(len(slice_ranges)-1):
#         slice_ds = bath_ds.isel(latitude=slice(slice_ranges[i],slice_ranges[i+1]))
#         #don't think I need values here
#         lon_min_index, lat_min_index = min_max_index_extreme_values(slice_ds.values)
#         print(lon_min_index, lat_min_index)



In [None]:
# slice_coast(bath_1_12_degree)
chunk_size = 20
vmin, vmax = -100, 0
threshold_percent = 10
chunk_coords = spatial_data.find_chunks_with_percentage(bath_A_1km.values[0,:,:], -100, 0, chunk_size, threshold_percent)

print("array_shape", bath_A_1km.values[0,:,:].shape)
print("extreme chunk", chunk_coords[-1])

In [None]:
chunk_coords[:4]

In [None]:
out = bath_A_1km.isel(band=0)
out

In [None]:
get_vars_from_ds_or_da(out)

In [None]:
def ds_subsample_from_coord(xa_ds, chunk_coords):
    lat_start, lat_end = chunk_coords[0][0], chunk_coords[0][1]
    lon_start, lon_end = chunk_coords[1][0], chunk_coords[1][1]

    return xa_ds.isel({"latitude": slice(lat_start, lat_end), "longitude": slice(lon_start, lon_end)})

def get_vars_from_ds_or_da(xa_d: xa.DataArray | xa.Dataset) -> str | list[str]:
    if type(xa_d) == xa.core.dataarray.DataArray:
        vars = xa_d.name
    elif type(xa_d) == xa.core.dataarray.Dataset:
        vars = list(xa_d.data_vars)
    else:
        raise TypeError("Format was neither an xarray Dataset nor a DataArray")

    return vars


def nc_chunk_files(dest_dir_path: Path | str, xa_ds: xa.Dataset, chunk_size: int = 20, 
    threshold_percent: float=10, vmin: float=-100, vmax: float=0):
    
    chunk_coord_pairs = spatial_data.find_chunks_with_percentage(
        xa_ds, vmin, vmax, chunk_size, threshold_percent)
    
    for coord_pair in chunk_coord_pairs:
        sub_ds = ds_subsample_from_coord(xa_ds, coord_pair)
        # make filename
        vars = get_vars_from_ds_or_da(xa_ds)
        # convert coord indices to absolute coords
        filename = climate_data.generate_spatiotemporal_var_filename_from_dict(
            {"vars": vars, "lats": coord_pair[0], "lons": coord_pair[1]})

        # save file
    print(filename)
    return sub_ds

nc_chunk_files("asdf", bath_A_1km.isel(band=0))

In [None]:
chunk_coords

In [None]:
fig, ax = plt.subplots(figsize=[10,10])
da = bath_A_1km
da.plot(ax=ax, vmin=vmin, vmax=vmax, cmap="BrBG")
ax.set_aspect("equal")

for coord in chunk_coords:
    xy = index_to_coord(da, coord[0])
    height, width = delta_index_to_distance(da, coord[1], coord[0])
    rect = patches.Rectangle(xy, width, height, linewidth=2, edgecolor='r', facecolor='none')
    ax.add_patch(rect)

plt.show()

In [None]:
# attempt downsampling climate in each square
# save as individual nc file
# set up tf dataloader: load each nc file and take batches from it

In [None]:
bath_1_12_degree = spatial_data.upsample_xarray_to_target(bath_A, 1/12)

bath_1_12_degree.values[0,:,:].shape

In [None]:
min_max_index_extreme_values(bath_1_12_degree.sel(latitude=slice(-12,-15)))

In [None]:
bath_1_12_degree

In [None]:
monthly_climate.sel(latitude=slice(-10,-17.05), longitude=slice(141.95,147.05))

In [None]:
no_band_bath = bath_1_12_degree.isel(band=0)

# downsample climate data to 1km
monthly_climate = ds_man.get_dataset("monthly_climate_1_12")

# get limits of bathymetry
lat_lims = spatial_data.xarray_coord_limits(bath_1_12_degree, "latitude")
lon_lims = spatial_data.xarray_coord_limits(bath_1_12_degree, "longitude")

restricted_monthly_climate = monthly_climate.sel(latitude=slice(-10,-17), longitude=slice(142,147))


# padded_restricted_monthly_climate = spatial_data.buffer_nans(restricted_monthly_climate, size=1
km_monthly = restricted_monthly_climate.interp_like(bath_1_12_degree, method="linear")

coral_climate_1km = xa.combine_by_coords([km_monthly.drop("spatial_ref"),no_band_bath], coords=["time", "latitude", "longitude"])
(coral_climate_1km,) = xa.broadcast(coral_climate_1km)
coral_climate_1km

In [None]:
# TODO: fix ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (13,) + inhomogeneous part.

# train_comb_Xs, train_comb_ys, train_comb_subsample, train_comb_lat_lons_vals_dict = generate_patch(xa_ds=coral_climate_1km, lat_lon_starts=(-10,142), coord_range=(-6,6))
# test_comb_Xs, test_comb_ys, test_comb_subsample, test_comb_lat_lons_vals_dict = generate_patch(xa_ds=coral_climate_1km, lat_lon_starts=(-16,148), coord_range=(-6,6))

In [None]:
print("all_Xs_onehot shape: ", all_Xs_onehot.shape)
print("all_ys_onehot shape: ", all_ys_onehot.shape)

In [None]:
# TODO: normalise along variable axes

In [None]:
all_Xs, all_lat_lon_dict = sample_spatial_batch(xa_coral_climate_1_12, lat_lon_starts=(-8,140), coord_range=(-20,13))
all_Xs, all_lat_lon_dict = subsample_to_array(xa_coral_climate_1_12, lat_lon_starts=(-8,140), coord_range=((-20,13)))
all_Xs = naive_X_nan_replacement(all_Xs)
all_ys, _ = subsample_to_array(xa_coral_climate_1_12, lat_lon_starts=(-10,142), coord_range=(-7,5), variables = ["coral_algae_1-12_degree"])
all_ys = naive_y_nan_replacement(all_ys)
all_ys = all_ys[:,:,0]

In [None]:
print("Xs shape: ", Xs.shape)
print("ys shape: ", ys.shape)

In [None]:
Xs, ys, all_subsample, all_lat_lons_vals_dict = generate_patch(xa_coral_climate_1_12, lat_lon_starts=(-10,142), coord_range=(-7,5))
patch1_Xs, patch1_ys, patch1_subsample, patch1_lat_lons_vals_dict = generate_patch(xa_coral_climate_1_12, (-10,142), (-7,5))
patch2_Xs, patch2_ys, patch2_subsample, patch2_lat_lons_vals_dict = generate_patch(xa_coral_climate_1_12, (-17,147), (-7,5))
patch3_Xs, patch3_ys, patch3_subsample, patch3_lat_lons_vals_dict = generate_patch(xa_coral_climate_1_12, (-16,146), (-7,5))



In [None]:
wandb.finish()

In [None]:
wandb.init(
    project="coralshift",
    entity="orlando-code",
    settings=wandb.Settings(start_method="fork")
    # config={    }
    )

# initialize optimiser: will need hyperparameter scan for learning rate and others
# https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adam
optimizer = tf.keras.optimizers.Adam(3e-4)

# X = ds_man.get_dataset("monthly_climate_1_12_X_np")
# y = ds_man.get_dataset("monthly_climate_1_12_y_np")
# # check that untrained model runs (should output array of non-nan values)
# # why values change?
# # g_model(X[:32])

# X_train, X_test, y_train, y_test = model_selection.train_test_split(
#     X, y, test_size=0.2, random_state=42)

# X_train, X_test, y_train, y_test = model_selection.train_test_split(
#     sub_X, sub_y, test_size=0.2, random_state=42)

# Define Gated Recurrent Unit model class in TensorFlow
class gru_model(tf.keras.Model):
    # initialise class instance to define layers of the model
    def __init__(self, rnn_units: list[int], num_layers: int, 
        # dff: int
        ):
        """Sets up a GRU model architecture with multiple layers and dense layers for mapping the outputs of the GRU 
        layers to a desired output shape

        Parameters
        ----------
        rnn_units (list[int]): list containing the number of neurons to use in each layer
        num_layers (int): number of layers in GRU model
        """
        super(gru_model, self).__init__()   # initialise GRU model as subclass of tf.keras.Model
        # store values for later use
        self.num_layers = num_layers    # number of layers in GRU model
        self.rnn_units = rnn_units
        # self.dff = dff
        # define model layers: creating new `tf.keras.layers.GRU` layer for each iteration
        self.grus = [tf.keras.layers.GRU(rnn_units[i],  # number (integer) of rnn units/neurons to use in each model layer
                                   return_sequences=True,   # return full sequence of outputs for each timestep
                                   return_state=True) for i in range(num_layers)] # return last hidden state of RNN at end of sequence
        
        # dense layers are linear mappings of RNN layer outputs to desired output shape
        # self.w1 = tf.keras.layers.Dense(dff) # 10 units
        self.w1 = tf.keras.layers.Dense(10) # 10 units

        self.w2 = tf.keras.layers.Dense(1)  # 1 unit (dimension 1 required before final sigmoid function)
        # self.A = tf.keras.layers.Dense(30)
        # self.B = tf.keras.layers.Dense(dff)



    def call(self, inputs: np.ndarray, training: bool=False):
        """Processes an input sequence of data through several layers of GRU cells, followed by a couple of
        fully-connected dense layers, and outputs the probability of an event happening.
        
        Parameters
        ----------
        inputs (np.ndarray): input tensor of shape (batch_size, seq_length, features)
            batch_size - defines the size of the sample drawn from datapoints
            seq_length - number of timesteps in sequence
            features - number of features associated with each datapoint
        training (bool, defaults to False): True if model is in training, False if in inference mode

        Returns
        -------
        target: probability of an event occuring, with shape (batch_size, 1)
        """
        # input shape: (batch_size, seq_length, features)
       
        assert self.num_layers == len(self.rnn_units)

        # check that input tensor has correct shape
        if (len(inputs.shape) != 3):
            print(f"Incorrect shape of input tensor. Expected 3D array. Recieved {len(inputs.shape)}D array.")

        # print('input dim ({}, {}, {})'.format(inputs.shape[0], inputs.shape[1], inputs.shape[2]))
        # whole_seq, static_input = inputs
        whole_seq = inputs


        # iteratively passes input tensor to GRU layers, overwritting preceding sequence 'whole_seq'
        for layer_num in range(self.num_layers):
            whole_seq, final_s = self.grus[layer_num](whole_seq, training=training)

        # adding extra layers
        # static = self.B(tf.nn.gelu(self.A(static_input)))
        # target = self.w1(final_s)  + static # final hidden state of last layer used as input to fully connected dense layers...
        target = self.w1(final_s)   # final hidden state of last layer used as input to fully connected dense layers...

        target = tf.nn.relu(target) # via ReLU activation function
        target = self.w2(target)    # final hidden layer must have dimension 1 
        
        # obtain a probability value between 0 and 1
        target = tf.nn.sigmoid(target)
        
        return target


# initialise GRU model with 500 hidden layers, one GRU unit per layer 
g_model = gru_model([500], 1) # N.B. [x] is number of hidden layers in GRU network


def negative_log_likelihood(y: np.ndarray, y_pred: np.ndarray, class_weights: np.ndarray = None) -> float:
    """Compute binary cross-entropy loss between ground-truth binary labels and predicted probabilities,
    incorporating class weights.

    Parameters
    ----------
    y (np.ndarray): true binary labels, where 0 represents the negative class
    y_pred (np.ndarray): predicted labels (as probability value between 0 and 1)
    class_weights (np.ndarray): weights for each class. If None, no class weights will be applied.

    Returns
    -------
    float: negative log likelihood loss computed using binary cross-entropy loss between 'y' and 'y_pred',
    incorporating class weights if provided
    """
    bce = tf.keras.losses.BinaryCrossentropy()  

    if class_weights is not None:
        sample_weights = tf.gather(class_weights, np.asarray(y,dtype=np.int32))
        return bce(y, y_pred, sample_weight=sample_weights)

    return bce(y, y_pred)


def training_batches(X: np.ndarray, y: np.ndarray, batch_num: int, batch_size: int=32):
    start_idx = batch_num * batch_size
    end_idx = (batch_num + 1) * batch_size

    X_batch = X[start_idx:end_idx]
    y_batch = y[start_idx:end_idx]
    
    return X_batch, y_batch

# https://stackoverflow.com/questions/52357542/attributeerror-tensor-object-has-no-attribute-numpy
# should aim to delete the following to speed up training: but can't figure out a way to make wandb reporting work
# without it
tf.config.run_functions_eagerly(True)

def build_graph():
    
    # compile function as graph using tf's autograph feature: leads to faster execution times, at expense of limitations
    # to Python objects/certain control flow structures (somewhat relaxed by experimental_relax_shapes)
    @tf.function(experimental_relax_shapes=True)
    def train_step(gru: tf.keras.Model, optimizer: tf.keras.optimizers.Optimizer, X: np.ndarray, y: np.ndarray, 
        training: bool=True, class_weights=class_weights, batch_num:int=None, batch_size: int=None) -> tuple[np.ndarray, float]:
        """Train model using input `X` and target data `y` by computing gradients of the loss (via 
        negative_log_likelihood)
        
        Parameters
        ----------
        y (np.ndarray): true binary labels, where 0 represents the negative class
        y_pred (np.ndarray): predicted labels (as probability value between 0 and 1)

        Returns
        -------
        float: negative log likelihood loss computed using binary cross-entropy loss between 'y' and 'y_pred'
        """
        if training:
            num_samples = X.shape[0]
            num_batches = num_samples // batch_size
            # num_batches = batch_num
            total_epoch_loss = 0.0
            for batch_num in tqdm(range(num_batches), desc="batches", position=0, leave=True):
                X_batch, y_batch = training_batches(X, y, batch_num=batch_num, batch_size=batch_size)

                with tf.GradientTape(persistent=True) as tape:
                    y_pred = gru(X_batch, training) 
                    xent = negative_log_likelihood(y_batch, y_pred, class_weights)
                    # y_pred = gru(X, training) # TO DELETE
                    # xent = negative_log_likelihood(y, y_pred)
                
                gradients = tape.gradient(xent, gru.trainable_variables)
                optimizer.apply_gradients(zip(gradients, gru.trainable_variables))
                # print("xent", xent.numpy())
                # print("total_epoch_loss", total_epoch_loss)
                total_epoch_loss += xent
                # learning rate?
                wandb.log({"batch": batch_num, "loss": xent, "total_epoch_loss": total_epoch_loss})

            average_loss = total_epoch_loss / num_batches
            # return predicted output values and total loss value
            return y_pred, xent, total_epoch_loss

    # set default float type
    tf.keras.backend.set_floatx('float32')
    # TODO: this isn't assigned... What should it return otherwise?
    return train_step


X_train = train_onehot_Xs[:]
y_train = train_onehot_ys[:]

class_weights = class_weight.compute_class_weight('balanced', classes=np.unique(y_train), y=np.reshape(y_train,-1))

with tf.device("/GPU:0"):
    num_epochs = 50
    # will update so that subsamples are fed in from which batches are taken: will require recomputation
    # of class_weight for each subsample
    num_batches = 2
    batch_size = 512
    tr_step = build_graph()
    for epoch in tqdm(range(num_epochs), desc= " epochs", position = 0, leave=True):
        y_pred, xent, total_epoch_loss = tr_step(
            g_model, optimizer, X_train[:], y_train[:], class_weights=class_weights, 
            batch_size=batch_size, batch_num=num_batches, training=True)
        wandb.log({"epoch": epoch, "epoch_loss": total_epoch_loss})

wandb.finish()

# check (with prints) that wandb is functioning
# check against known timeseries task for correct implementation
# find timeseries which get bad loss and debug why
# log best loss which can be logged: save weights (and do a inference plot with best weights)

In [None]:
X_train.shape

In [None]:
### train/predict
# if training:
    # assign epoch output to dataset
# if testing:
# patch1_pred = g_model(patch1_Xs, training=False)
# patch2_pred = g_model(patch2_Xs, training=False)
# patch3_pred = g_model(patch3_Xs, training=False)
pred = g_model(test_onehot_Xs, training=False)

# all_predicted = g_model(all_Xs, training=False)
    # assign output predictions to dataset

In [None]:
patch_pred_xa_ds_onehot_lim = reformat_prediction(xa_coral_climate_1_12_working, test_onehot_subsample, pred, test_onehot_lat_lons_vals_dict)

mask = patch_pred_xa_ds_onehot_lim["output"] > 0.6

spatial_plots.plot_var(mask)

In [None]:
patch_pred_xa_ds_onehot_lim = reformat_prediction(xa_coral_climate_1_12_working, test_onehot_subsample, pred, test_onehot_lat_lons_vals_dict)

spatial_plots.plot_var(patch_pred_xa_ds_onehot_lim["output"])



In [None]:
patch_Xs_onehot, patch_ys_onehot, patch_subsample_onehot, patch_lat_lons_vals_dict_onehot = generate_patch(xa_coral_climate_1_12, (-10,142), (-7,5))

patch_pred_xa_ds_onehot = reformat_prediction(xa_coral_climate_1_12_working, patch_subsample_onehot, pred, patch_lat_lons_vals_dict_onehot)

spatial_plots.plot_var(patch_pred_xa_ds_onehot["output"])


In [None]:
np.array_equal(patch1_pred_xa_ds["output"].values,patch2_pred_xa_ds["output"])

In [None]:
# patch1_pred_xa_ds = reformat_prediction(xa_coral_climate_1_12_working, patch1_subsample, patch1_pred, patch1_lat_lons_vals_dict)
# patch2_pred_xa_ds = reformat_prediction(xa_coral_climate_1_12_working, patch2_subsample, patch2_pred, patch2_lat_lons_vals_dict)
# patch3_pred_xa_ds = reformat_prediction(xa_coral_climate_1_12_working, patch3_subsample, patch3_pred, patch3_lat_lons_vals_dict)


f, a0 = spatial_plots.plot_var_at_time(xa_coral_climate_1_12["coral_algae_1-12_degree"], "2020-12-16")
# # visualise result
# spatial_plots.plot_var(patch1_pred_xa_ds["output"])
# spatial_plots.plot_var(patch2_pred_xa_ds["output"])
# spatial_plots.plot_var(patch3_pred_xa_ds["output"])

# pred_xa_ds["coral_algae_1-12_degree"].isel(time=-1).plot(ax=ax[0])
# pred_xa_ds["output"].plot(ax=ax[1])


In [None]:
xa_coral_climate_1_12["bottomT"].isel(time=-1).plot()

# Adding in bathymetry

In [None]:
# downsample climate data to 1km
monthly_climate = ds_man.get_dataset("monthly_climate_1_12")

# get limits of bathymetry
lat_lims = spatial_data.xarray_coord_limits(coarsened_bath_A, "latitude")
lon_lims = spatial_data.xarray_coord_limits(coarsened_bath_A, "longitude")

restricted_monthly_climate = monthly_climate.sel(latitude=slice(-10,-17), longitude=slice(142,147))


# padded_restricted_monthly_climate = spatial_data.buffer_nans(restricted_monthly_climate, size=1

In [None]:
buffer_size = 3
exclude_vars = ["spatial_ref", "coral_algae_1-12_degree", "siconc", "usi", "vsi", "sithick"]
buffered_ds = spatially_buffer_timeseries(monthly_climate, buffer_size=buffer_size, exclude_vars=exclude_vars)

buffered_ds.to_netcdf(
    ds_man.get_location() / f"global_ocean_reanalysis/monthly_means/monthly_climate_{buffer_size}_buffer.nc")

In [None]:
buffered_ds

In [None]:
f,a = plt.subplots(1,2, figsize=[10,5])
monthly_climate["mlotst"].isel(time=99).plot(ax=a[0], cmap="jet")
buffer_attempt["mlotst"].isel(time=99).plot(ax=a[1],cmap="jet")

In [None]:
coral_climate_1km

In [None]:
buffer_attempt.equals(monthly_climate.isel(time=slice(0,2)))

In [None]:
buffer_attempt.isel(time=1)["mlotst"].plot()

In [None]:
coarsened_bath_A.isel(band=0)

In [None]:
coral_climate_1km["bathymetry_A"]

In [None]:
monthly_climate

In [None]:
f,a = plt.subplots(1,2, figsize=[10,5])
coral_climate_1km.isel(time=-1)["bathymetry_A"].plot(ax=a[0], vmin=-100, vmax=0)
coral_climate_1km.isel(time=-1)["mlotst"].plot(ax=a[1])

In [None]:
# attempt.isel(time=-1)["mlotst"].plot()
eg_data = buffer_attempt.isel(time=-1)["mlotst"]

spatial_plots.plot_DEM(eg_data, f" DEM upsampled to {target_resolution} meters", 
    landmask=False, vmin=np.nanmin(eg_data.values), vmax=np.nanmax(eg_data.values), cmap="jet")

In [None]:
# TESTING

# for longitude in array
# get 
sub_X.shape

In [None]:
X.shape

In [None]:
wandb.finish()

In [None]:
# TODO: optionally replace batching with spatial batching

### Test GRU functions

In [None]:
print(tf.config.list_physical_devices())
!nvidia-smi

In [None]:
with tf.device("/GPU:0"):
    num_epochs = 5
    num_batches = 100
    tr_step = build_graph()
    for epoch in tqdm(range(num_epochs), desc= " epochs", position = 0):
        y_pred, average_loss = tr_step(g_model, optimizer, X_train[:1000], y_train[:1000], batch_size=32, training=True)
        
        
        
        # for batch in range(num_batches):
        #     array, y  = batcher_fun(X, 32, 276 
        #     #training = True)# shapes: (batch_s, seq_l, features), (batch_s, 1)
        #     )
        #     y_pred, xent = tr_step(g_model, optimizer, X[:32], y, training=True)
            
        #  ## validation set 
        #  ## test_set 

In [None]:
# negative_log_likelihood(y_test,g_model(X_test))
y_test=y_test[:1000]
y_pred = g_model(X[:1000])

In [None]:
plt.scatter(y_test,y_pred.numpy())

In [None]:
# made unnecssary due to isel indexing
# def pixels_to_coord_diff(xa_ds: xa.Dataset | xa.DataArray, window_dim: int, coord: str) -> list[float, float]:
#     return float(window_dim * np.diff(spatial_data.min_max_of_coords(xa_ds, coord)) / len(list(xa_coral_climate_1_12[coord])))

In [None]:
subsample["bottomT"].isel(time=-1).plot()

In [None]:
sample_spatial_batch(xa_coral_climate_1_12,lat_lon_starts=(-16,144), coord_range=(-4,2))["coral_algae_1-12_degree"]

In [None]:
xa_coral_climate_1_12_working = xa_coral_climate_1_12

In [None]:
# [lat_lon_vals_dict.items() for key in ["latitude", "longitude"]]
{key: lat_lon_vals_dict[key] for key in ["latitude", "longitude"]}

In [None]:
sub_y_nans.shape

In [None]:
(list(subsample.data_vars))

In [None]:
# sub_X = np.moveaxis(np.array(test_array), 2, 1)
sub_X.shape

In [None]:
col_mask = ~np.all(np.isnan(test_array), axis=(0,2))
sub_X = test_array[:, col_mask, :]

In [None]:
sub_X.shape

In [None]:
xa_coral_climate_1_12_features = ds_man.get_dataset("monthly_climate_features")
xa_coral_climate_1_12 = ds_man.get_dataset("monthly_climate_1_12")

In [None]:
Xs, lat_lon_dict = subsample_to_array(xa_coral_climate_1_12, lat_lon_starts=(-10,142), coord_range=(-7,5))
Xs_nonans = naive_X_nan_replacement(Xs)
ys, _ = subsample_to_array(xa_coral_climate_1_12, lat_lon_starts=(-10,142), coord_range=(-7,5), variables = ["coral_algae_1-12_degree"])
ys = ys[:,:,0]

In [None]:
ys.shape

In [None]:
# sample subset
# subsample, lat_lon_vals_dict = sample_spatial_batch(xa_coral_climate_1_12_features,lat_lon_starts=(-12,144), coord_range=(-4,2))
# subsample, lat_lon_vals_dict = sample_spatial_batch(xa_coral_climate_1_12_features,lat_lon_starts=(-10,142), coord_range=(-7,5))
# convert to ndarray
# # test_array = spatial_data.xa_ds_to_3d_numpy(subsample)
# # subsample_all, _ = sample_spatial_batch(xa_coral_climate_1_12,lat_lon_starts=(-16,144), coord_range=(-4,2))
# subsample_all, _ = sample_spatial_batch(xa_coral_climate_1_12,lat_lon_starts=(-10,142), coord_range=(-7,5))
# sub_y_nans = (np.array(subsample_all["coral_algae_1-12_degree"].isel(time=-1))).reshape(-1, 1)
# # remove nans
# #sub_X, sub_y = filter_out_nans(test_array, sub_y_nans)
# # testing, so replace nans with -1
# # filter out columns that contain entirely NaN values
# # col_mask = ~np.all(np.isnan(test_array), axis=(0,2)) # boolean mask indicating which columns to keep
# # sub_X = test_array[:, col_mask, :] # keep only the columns that don't contain entirely NaN values

# # sub_X = np.moveaxis(np.array(sub_X), 2, 1)
# sub_y = sub_y_nans
# # sub_X[np.isnan(sub_X)] = -10000
# sub_y[np.isnan(sub_y)] = -10000




In [None]:
def subset_to_dataset_var(xa_ds: xa.Dataset | xa.DataArray, subset_vals: np.ndarray, dims: list=['latitude', 'longitude', "time"]):

In [None]:
predicted

In [None]:
test_x_train = np.append(100*np.ones((50,50,5)), 1*np.ones((50,50,5)), 0)
test_y_train = np.append(np.ones(50,), np.zeros(50,))


print("test_x_train:", test_x_train.shape)
# print("x_test:", x_test.shape)
print("test_y_train:", test_y_train.shape)
# print("y_test:", y_test.shape)


In [None]:
array = np.random.normal(size = (32, 20, 1))    # shape: (num_samples, sequence_length, num_features)
y_dud = np.random.choice([0, 1], size = 32)
print("array shape:", array.shape)
print("y_dud shape:", y_dud.shape)

x_train, y_train = X[:500], y[:500].reshape(500,)
# x_test, y_test = X[5000:6000], y[5000:6000].reshape((1000,))

print("x_train:", x_train.shape)
# print("x_test:", x_test.shape)
print("y_train:", y_train.shape)
# print("y_test:", y_test.shape)


In [None]:
plt.plot(g_model(test_x_train[:],training=False).numpy())

In [None]:
predicted = g_model(X[:5610],training=False)

In [None]:
X.shape

In [None]:
xa_coral_climate_1_12["coral_algae_1-12_degree"].isel(time=-1)

In [None]:
30*187

In [None]:
fig, ax = plt.subplots()
out = ax.imshow(predicted.numpy().reshape(30,187))
fig.colorbar(out, ax=ax)

In [None]:
np.sqrt(500)

In [None]:
xa_coral_climate_1_12_features

In [None]:
xa_coral_climate_1_12["coral_algae_1-12_degree"].isel(time=-1)

In [None]:
fig, ax = plt.subplots()
out = ax.imshow(y_pred.numpy().reshape(20,25))
fig.colorbar(out, ax=ax)

In [None]:
sum((y_pred > 0.5).numpy())

In [None]:
# check log likelihood is computable
negative_log_likelihood(y[:32], g_model(X[:32]))

## Train and test GRU

In [None]:
# define batcher function (by space and time)

In [None]:
def batcher_fun(X, y, batch_size, seq_len):
    """
    A function to prepare the data for training the LSTM.
    
    :param data: The input data to the LSTM.
    :param batch_size: The number of samples in each batch.
    :param seq_len: The sequence length of each sample.
    
    :return: A tuple of (batch_x, batch_y), where batch_x is a numpy array of shape (batch_size, seq_len, num_features) 
             and batch_y is a numpy array of shape (batch_size, num_classes).
    """
    num_samples = len(data)
    num_batches = int(num_samples / batch_size)
    num_features = spatial_data.shape[1]
    
    for i in range(num_batches):
        start_idx = i * batch_size
        end_idx = (i + 1) * batch_size
        
        batch_x = np.zeros((batch_size, seq_len, num_features))
        batch_y = np.zeros((batch_size, 1))
        
        for j in range(start_idx, end_idx):
            sample = data[j]
            X = sample[:-1]
            y = y[]
            
            batch_x[j - start_idx] = x.reshape((seq_len, num_features))
            batch_y[j - start_idx, y] = 1
            
        yield batch_x, batch_y


In [None]:
with tf.device("/CPU:0"):
    num_epochs = 1
    num_batches = 100
    tr_step = build_graph()
    for epoch in range(num_epochs):
        for batch in range(num_batches):
            array, y  = batcher_fun(X, 32, 276 
            #training = True)# shapes: (batch_s, seq_l, features), (batch_s, 1)
            )
            y_pred, xent = tr_step(g_model, optimizer, X[:32], y, training=True)
            
         ## validation set 
         ## test_set 

In [None]:
y_pred

# Copypasta

[Source](https://github.com/christianversloot/machine-learning-articles/blob/main/build-an-lstm-model-with-tensorflow-and-keras.md)

In [None]:
X.shape

In [None]:
import tensorflow as tf
from tensorflow.keras.datasets import imdb
from tensorflow.keras.layers import Embedding, Dense, LSTM
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers.legacy import Adam # https://stackoverflow.com/questions/75356826/attributeerror-adam-object-has-no-attribute-get-updates
from tensorflow.keras.preprocessing.sequence import pad_sequences

# Model configuration
additional_metrics = ["accuracy"]
# batch_size = 128
batch_size = 32
# embedding_output_dims = 15
# embedding_output_dims = 10
loss_function = BinaryCrossentropy()
# max_sequence_length = 300
max_sequence_length = 276
# num_distinct_words = 5000
# num_distinct_words = 10000
number_of_epochs = 5
optimizer = Adam()
validation_split = 0.20
verbosity_mode = 1

# Disable eager execution
tf.compat.v1.disable_eager_execution()

In [None]:
# Load dataset
# (x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=num_distinct_words)
x_train, y_train = X[:5000], y[:5000].reshape((5000,))
x_test, y_test = X[5000:6000], y[5000:6000].reshape((1000,))

print("x_train:", x_train.shape)
print("x_test:", x_test.shape)
print("y_train:", y_train.shape)
print("y_test:", y_test.shape)

In [None]:
# Pad all sequences: keras requires sequences of equal lengths. Should be handled in pre-processing, but here for now for security
padded_inputs = pad_sequences(x_train, maxlen=max_sequence_length, value = 0.0) # 0.0 because it corresponds with <PAD>
padded_inputs_test = pad_sequences(x_test, maxlen=max_sequence_length, value = 0.0) # 0.0 because it corresponds with <PAD>

# (number_samples, sequence_length, num_features)
print("padded_inputs:", padded_inputs.shape)
print("padded_inputs_test:", padded_inputs_test.shape)

In [None]:
padded_inputs = pad_sequences(x_train[:,:,0], maxlen=max_sequence_length, value = 0.0)
padded_inputs_test = pad_sequences(x_test[:,:,0], maxlen=max_sequence_length, value = 0.0)

# (number_samples, sequence_length)
print("padded_inputs:", padded_inputs.shape)
print("padded_inputs_test:", padded_inputs_test.shape)

In [None]:
# Define the Keras model
model = Sequential()
model.add(
    Embedding(
        num_distinct_words+1, embedding_output_dims, input_length=max_sequence_length
    )
)
model.add(LSTM(10))
model.add(Dense(1, activation="sigmoid"))

# Compile the model
model.compile(optimizer=optimizer, loss=loss_function, metrics=additional_metrics)

# Give a summary
model.summary()

In [None]:
y_train.shape

In [None]:
# Train the model
history = model.fit(
    padded_inputs,
    y_train,
    batch_size=batch_size,
    epochs=number_of_epochs,
    verbose=verbosity_mode,
    validation_split=validation_split,
)

# Test the model after training
test_results = model.evaluate(padded_inputs_test, y_test, verbose=False)
print(f"Test results - Loss: {test_results[0]} - Accuracy: {100*test_results[1]}%")

In [None]:
def plot_score_timeseries(history) -> None:
    fig, ax = plt.subplots()
    ax.plot(history.history["accuracy"])
    ax.plot(history.history["val_accuracy"])

    ax.set_title("Model accuracy against epoch")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Accuracy")
    ax.legend(['train set', 'validation set'], loc='upper left')

plot_score_timeseries(history)

In [None]:
model.metrics_names

# Multivariate model

[Source](https://medium.com/@canerkilinc/hands-on-multivariate-time-series-sequence-to-sequence-predictions-with-lstm-tensorflow-keras-ce86f2c0e4fa)

In [None]:
X_toy = X[:32*10,:10,:3]
print("X_toy:", X_toy.shape)
y_toy = y[:32*10]
print("y_toy:", y_toy.shape)


In [None]:
#import packages
import tensorflow
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Input, LSTM
from tensorflow.keras.models import Model
#####################################
#Before do anything else do not forget to reset the backend for the next iteration (rerun the model)
tensorflow.keras.backend.clear_session()
#####################################
# Initialising the LSTM Model with MAE Loss-Function
batch_size = 32
epochs = 120
timesteps = 10
num_features = 3
input_1 = Input(batch_shape=(batch_size,timesteps,num_features))
#each layer is the input of the next layer
lstm_hidden_layer_1 = LSTM(10, stateful=True, return_sequences=True)(input_1)
lstm_hidden_layer_2 = LSTM(10, stateful=True, return_sequences=True)(lstm_hidden_layer_1)
output_1 = Dense(units = 1)(lstm_hidden_layer_2)
regressor_mae = Model(inputs=input_1, outputs = output_1)
#adam is fast starting off and then gets slower and more precise
#mae -> mean absolute error loss function
regressor_mae.compile(optimizer='adam', loss = 'mae')
#####################################
#Summarize and observe the layers as well as paramter configurations
regressor_mae.summary()

In [None]:
regressor_mae.fit(
    
)