In [1]:
%load_ext autoreload
%autoreload 2

import sys
import os

# make the parent directory importable so `coralshift.*` resolves without installing the package
sys.path.append("../")
# TODO: hacky, shouldn't be necessary
# NOTE(review): hard-coded absolute working directory. A later cell runs os.chdir(os.path.expanduser("~"))
# and MyDataset.set_location("remote") chdirs to /home/jovyan, so relative "lustre_scratch/..." paths
# resolve differently depending on which cells have been run.
os.chdir("/lustre_scratch/orlando-code/coralshift/")

In [2]:
# import nctoolkit as nc
from pathlib import Path
import numpy as np
import haversine
import xarray as xa
import pandas as pd
# from haversine import haversine, Units, inverse_haversine

from coralshift.utils import directories, file_ops, utils
from coralshift.processing import spatial_data
from coralshift.dataloading import climate_data

import cdsapi

In [3]:
def fetch_weather_data(
    download_dest_dir, weather_params, years,
    months: list[int] | int = np.arange(1,13),
    days: list[int] | int = np.arange(1,32),
    hours: list[int] | int = np.arange(0,24),
    lat_lims=(-10,-17), lon_lims=(142,147),
    dataset_tag: str="reanalysis-era5-single-levels", format: str="grib"):
    """Download data from the Copernicus CDS, one file per (parameter, year).

    Files are written to ``<download_dest_dir>/<param>/<generated filename>.<format>``. Any file that already
    exists is skipped, so the function can be re-run to resume an interrupted download.

    Parameters
    ----------
        download_dest_dir (str | Path): root directory under which one subdirectory per parameter is created.
        weather_params (list[str]): CDS variable names to download.
        years (Iterable[int]): years to download (one request per year per parameter).
        months, days, hours (list[int] | int | np.ndarray): temporal selection used for every request.
        lat_lims, lon_lims (tuple[float, float]): latitude/longitude limits, in either order.
        dataset_tag (str): CDS dataset name.
        format (str): file format requested from the CDS; also used as the file suffix.
    """
    c = cdsapi.Client()

    # bounding box ordered [north, west, south, east]
    area = [max(lat_lims), min(lon_lims), min(lat_lims), max(lon_lims)]

    for param in weather_params:
        param_download_dest = file_ops.guarantee_existence(
            Path(download_dest_dir) / param)
        for year in years:
            # TODO: more descriptive filename
            filename = climate_data.generate_spatiotemporal_var_filename_from_dict({
                "var": param,
                "lats": lat_lims,
                "lons": lon_lims,
                "year": str(year)
            })
            filepath = str(generate_filepath(param_download_dest, filename, format))

            if not Path(filepath).is_file():
                # forward `hours` too, otherwise the caller's hour selection is silently ignored
                time_info_dict = return_times_info(year, months, days, hours)
                ecmwf_api_call(c, filepath, param, time_info_dict, area, dataset_tag, format)
            else:
                print(f"Filepath already exists: {filepath}")


def pad_suffix(suffix: str) -> str:
    """Pads the given file suffix with a leading period if necessary.

    Parameters
    ----------
        suffix (str): file suffix to pad, e.g. "nc", ".nc" or "tar.gz".

    Returns
    -------
        str: The suffix with a leading period (e.g. ".nc", ".tar.gz"). An empty suffix is returned unchanged,
            so that ``Path.with_suffix`` strips the suffix rather than raising on a bare ".".
    """
    # test the first character rather than `"." in suffix`: compound suffixes such as "tar.gz" contain a
    # period but still need a leading one for `Path.with_suffix`
    if suffix and not suffix.startswith("."):
        suffix = "." + suffix
    return suffix


def generate_filepath(
    dir_path: str | Path, filename: str = None, suffix: str = None
) -> Path:
    """Return a directory path, or a file path within it.

    Parameters
    ----------
        dir_path (str | Path): directory in which the file lives.
        filename (str, optional): name of the file. If omitted, the directory itself is created if
            non-existent (via ``file_ops.guarantee_existence``) and returned.
        suffix (str, optional): file suffix, with or without a leading period. If given, it replaces any
            existing suffix on ``filename``.

    Returns
    -------
        Path: the directory path (when no filename is given) or the file path.
    """
    # only a directory requested: make sure it exists on disk. The helper lives in `file_ops`;
    # the bare name `guarantee_existence` is not defined in this notebook.
    if not filename:
        return file_ops.guarantee_existence(dir_path)
    filepath = Path(dir_path) / filename
    # filename presumably already carries its suffix
    if not suffix:
        return filepath
    # NOTE(review): `with_suffix` replaces whatever follows the last "." in filename, so names containing
    # decimal points (e.g. "..._LATS_-10.5") would be truncated — confirm filenames never contain periods
    return filepath.with_suffix(pad_suffix(suffix))


def generate_ecmwf_api_dict(
    weather_params: list[str], time_info_dict: dict, area: list[float], format: str
) -> dict:
    """Generate api dictionary format for single month of event"""

    # if weather_params

    api_call_dict = {
        "product_type": "reanalysis",
        "variable": [weather_params],
        "area": area,
        "format": format,
    } | time_info_dict

    return api_call_dict


def generate_month_day_hour_list(items_range, min_value: int = 0, max_value: int = 31):
    """Normalise a month/day/hour selection into a list of integer values.

    Parameters
    ----------
        items_range (int | list | tuple | range | np.ndarray): a single value or a collection of integer values.
        min_value (int): smallest accepted value (inclusive). Defaults to 0.
        max_value (int): largest accepted value (inclusive). Defaults to 31, which admits months, days and hours
            alike; pass tighter bounds to validate one specific field.

    Returns
    -------
        list: the values, in input order (NumPy arrays are converted with ``tolist``).

    Raises
    ------
        ValueError: if the input type is unsupported, an element is not an integer, or a value lies outside
            [min_value, max_value].
    """
    if isinstance(items_range, (int, np.integer)):
        items_range = [items_range]
    elif isinstance(items_range, np.ndarray):
        items_range = items_range.tolist()
    elif not isinstance(items_range, (list, tuple, range)):
        raise ValueError("Invalid input format. Please provide an integer, a list, or a NumPy array.")

    items = []
    for item in items_range:
        if not isinstance(item, (int, np.integer)):
            raise ValueError("Invalid input format. Please provide an integer, a list, or a NumPy array.")
        if item < min_value or item > max_value:
            raise ValueError("Invalid items value: {}.".format(item))
        items.append(item)

    return items


def return_times_info(year: int,
    months: list[int] | int = np.arange(1,13),
    days: list[int] | int = np.arange(1,32),
    hours: list[int] | int = np.arange(0,24)):
    """Build the temporal part of a CDS request for one year.

    Parameters
    ----------
        year (int): year of the request (converted with ``str``).
        months, days, hours (list[int] | int | np.ndarray): values to request. Each is normalised with
            ``generate_month_day_hour_list`` and each entry passed through ``utils.pad_number_with_zeros``.

    Returns
    -------
        dict: {"year": ..., "month": [...], "day": [...], "time": [...]}, where every "time" entry is the
            padded hour followed by ":00".
    """
    def padded(values):
        # normalise the selection to a list, then pad each entry
        return [utils.pad_number_with_zeros(value) for value in generate_month_day_hour_list(values)]

    year_text = str(year)
    padded_months = padded(months)
    padded_days = padded(days)
    time_strings = [f"{hour_text}:00" for hour_text in padded(hours)]

    return {"year": year_text, "month": padded_months, "day": padded_days, "time": time_strings}


def ecmwf_api_call(
    c,
    filepath: str,
    parameter: str,
    time_info_dict: dict,
    area: list[float],
    dataset_tag: str = "reanalysis-era5-single-levels",
    format: str = "grib",
):
    """Submit one CDS retrieve request and download the result to ``filepath``.

    Parameters
    ----------
        c: CDS client whose ``retrieve`` method is called.
        filepath (str): destination file for the download.
        parameter (str): CDS variable name to request.
        time_info_dict (dict): temporal selection, as returned by ``return_times_info``.
        area (list[float]): bounding box for the request.
        dataset_tag (str): CDS dataset name.
        format (str): requested file format.

    Notes
    -----
        A ``ConnectionAbortedError`` is reported and swallowed so a batch of downloads can carry on with the next
        request; any other exception propagates to the caller.
    """
    api_call_dict = generate_ecmwf_api_dict(parameter, time_info_dict, area, format)
    try:
        c.retrieve(dataset_tag, api_call_dict, filepath)
    # the clause must name the exception class: an instance (`ConnectionAbortedError()`) makes Python raise
    # TypeError whenever any exception reaches it, so nothing was ever caught
    except ConnectionAbortedError:
        print(f"API call failed for {parameter}.")

In [4]:
from tqdm import tqdm

def hourly_means_to_daily(hourly_dir: Path | str, suffix: str="netcdf"):
    """Resample every hourly file under ``hourly_dir`` to daily means.

    Each input ``<stem>.<suffix>`` produces ``<hourly_dir>/daily_means/<stem>_daily.<suffix>``.

    Parameters
    ----------
        hourly_dir (Path | str): directory searched (including subdirectories) for hourly files.
        suffix (str): suffix of the files to process; also used for the output files.
    """
    # create subdirectory to store averaged files
    daily_means_dir = file_ops.guarantee_existence(Path(hourly_dir) / "daily_means")
    daily_means_resolved = Path(daily_means_dir).resolve()
    filepaths = [
        Path(filepath)
        for filepath in file_ops.return_list_filepaths(hourly_dir, suffix, incl_subdirs=True)
        # the search includes subdirectories, so skip earlier outputs in daily_means/ — otherwise a re-run
        # would average them again into "*_daily_daily" files
        if daily_means_resolved not in Path(filepath).resolve().parents
    ]
    for filepath in tqdm(filepaths, desc="Converting hourly means to daily means"):
        filename = "_".join((str(filepath.stem), "daily"))
        save_path = (daily_means_dir / filename).with_suffix(file_ops.pad_suffix(suffix))
        # the context manager closes the source file once the daily file has been written
        with xa.open_dataset(filepath, chunks={"time": 100}) as hourly:
            # average all time steps falling on the same calendar day
            daily = hourly.resample(time="1D").mean()
            daily.to_netcdf(save_path)

def merge_files_in_dir(dir: Path | str, suffix: str="netcdf", concat_dim: str="time"):
    """Concatenate the files in ``dir`` along ``concat_dim`` into a single netCDF file.

    The merged file is written to ``<dir>/<dir name>_time_merged.nc``; if it already exists nothing is
    recomputed.

    Parameters
    ----------
        dir (Path | str): directory whose files (subdirectories excluded) are merged.
        suffix (str): suffix of the files to merge.
        concat_dim (str): dimension to concatenate, chunk and sort along.

    Returns
    -------
        xa.Dataset | None: the merged dataset when it was created by this call, otherwise None.
    """
    dir = Path(dir)
    merged_path = dir / f"{dir.stem}_time_merged.nc"
    if merged_path.is_file():
        print(f"{merged_path} already exists.")
        return None

    print(f"Merging .nc files into {merged_path}")
    filepaths = file_ops.return_list_filepaths(dir, suffix, incl_subdirs=False)
    # chunk and sort along `concat_dim` rather than a hard-coded "time", so other dimensions work too
    merged_ds = xa.open_mfdataset(
        filepaths, chunks={concat_dim: 100}, concat_dim=concat_dim, combine="nested"
    ).sortby(concat_dim, ascending=True)
    merged_ds.to_netcdf(merged_path)
    return merged_ds

        

In [None]:
# NOTE(review): this merges daily_means/ before the next cell creates those files — run the hourly-to-daily
# conversion first. The relative "lustre_scratch/..." paths also only resolve after the later
# `os.chdir(os.path.expanduser("~"))` cell has been run.
merge_files_in_dir("lustre_scratch/datasets/era5/surface_net_solar_radiation/daily_means")

In [None]:
# resample each hourly file to daily means (written to a daily_means/ subdirectory)
hourly_means_to_daily("lustre_scratch/datasets/era5/surface_net_solar_radiation")

In [None]:
# NOTE(review): absolute local (macOS) path — not reproducible on the remote machine
xa.open_dataset("/Users/orlandotimmerman/Downloads/19930101022011-NCEI-L3C_GHRSST-SSTskin-AVHRR_Pathfinder-PFV5.3_NOAA11_G_1993001_night-v02.0-fv01.0.nc")

In [None]:
# inspect one downloaded yearly file, chunked along time
test_file = xa.open_dataset("lustre_scratch/datasets/era5/surface_net_solar_radiation/VAR_surface_net_solar_radiation_LATS_-10_-17_LONS_142_147_YEAR_1993.netcdf", 
    # chunks={"latitude": 10, "longitude": 10}
    chunks = {"time": 100}
    )

In [None]:
test_file

In [None]:
# quick check of the hourly-to-daily resampling on the test file
test_file.resample(time="1D").mean()

In [None]:
# years requested in the download cell below (np.arange excludes the stop value, so 1993-2020)
np.arange(1993,2021)

In [5]:
# os.chdir(os.path.expanduser("lustre_scratch/datasets/era5test"))

# relative paths below (and in earlier cells) are resolved from the home directory
os.chdir(os.path.expanduser("~"))


# download four single-level variables for 1993-2020 as netCDF, one file per variable per year;
# files that already exist are skipped
fetch_weather_data(
    "lustre_scratch/datasets/era5",
    ['evaporation', 'significant_height_of_combined_wind_waves_and_swell', 
    'surface_net_solar_radiation', 'surface_pressure'], 
    np.arange(1993,2021), format="netcdf"
)

Filepath already exists: /lustre_scratch/orlando-code/datasets/era5/evaporation/VAR_evaporation_LATS_-10_-17_LONS_142_147_YEAR_1993.netcdf
Filepath already exists: /lustre_scratch/orlando-code/datasets/era5/evaporation/VAR_evaporation_LATS_-10_-17_LONS_142_147_YEAR_1994.netcdf
Filepath already exists: /lustre_scratch/orlando-code/datasets/era5/evaporation/VAR_evaporation_LATS_-10_-17_LONS_142_147_YEAR_1995.netcdf
Filepath already exists: /lustre_scratch/orlando-code/datasets/era5/evaporation/VAR_evaporation_LATS_-10_-17_LONS_142_147_YEAR_1996.netcdf
Filepath already exists: /lustre_scratch/orlando-code/datasets/era5/evaporation/VAR_evaporation_LATS_-10_-17_LONS_142_147_YEAR_1997.netcdf
Filepath already exists: /lustre_scratch/orlando-code/datasets/era5/evaporation/VAR_evaporation_LATS_-10_-17_LONS_142_147_YEAR_1998.netcdf
Filepath already exists: /lustre_scratch/orlando-code/datasets/era5/evaporation/VAR_evaporation_LATS_-10_-17_LONS_142_147_YEAR_1999.netcdf
Filepath already exists: /l

2023-06-09 17:07:29,165 INFO Welcome to the CDS
2023-06-09 17:07:29,167 INFO Sending request to https://cds.climate.copernicus.eu/api/v2/resources/reanalysis-era5-single-levels
2023-06-09 17:07:29,451 INFO Request is completed
2023-06-09 17:07:29,452 INFO Downloading https://download-0006-clone.copernicus-climate.eu/cache-compute-0006/cache/data3/adaptor.mars.internal-1686314405.472787-20106-4-594420a8-d78b-43f1-9ec7-0142107c7534.nc to /lustre_scratch/orlando-code/datasets/era5/surface_pressure/VAR_surface_pressure_LATS_-10_-17_LONS_142_147_YEAR_1995.netcdf (10.2M)
2023-06-09 17:07:30,249 INFO Download rate 12.8M/s  
2023-06-09 17:07:30,348 INFO Welcome to the CDS
2023-06-09 17:07:30,349 INFO Sending request to https://cds.climate.copernicus.eu/api/v2/resources/reanalysis-era5-single-levels
2023-06-09 17:07:30,446 INFO Request is queued
2023-06-09 17:07:31,496 INFO Request is running
2023-06-09 17:26:09,621 INFO Request is completed
2023-06-09 17:26:09,623 INFO Downloading https://down

In [None]:
# NOTE(review): a bare path is not a statement — it evaluates as division of undefined names (NameError)
lustre_scratch/datasets/era5test

In [None]:
# NOTE(review): absolute local (macOS) path
xa.open_dataset("/Users/orlandotimmerman/Desktop/test2/air_density_over_the_oceans/air_density_over_the_oceans_1993.netcdf", engine="netcdf4")

In [None]:
# NOTE(review): `dir` shadows the builtin for the rest of the session
dir = "/Users/orlandotimmerman/Desktop/test2/air_density_over_the_oceans"
files = file_ops.return_list_filepaths(dir, "netcdf")
# files
xa.open_mfdataset(files, chunks={"time": 100})

In [None]:
# NOTE(review): uses lat_lims before the next cell defines it — only works when run out of order
min(lat_lims)

In [None]:
# reproduce the bounding box built in fetch_weather_data: [max lat, min lon, min lat, max lon]
lat_lims=(-10,-17)
lon_lims=(142,147)

area = [max(lat_lims), min(lon_lims), min(lat_lims), max(lon_lims)]

In [None]:
area

In [None]:
# check utils.pad_number_with_zeros on the first hour (a NumPy integer)
utils.pad_number_with_zeros(np.arange(0,24)[0])

In [None]:
import numpy as np

# Create a 3D NumPy array with NaN values
arr = np.array([
    [[1, 2, 3], [4, 5, 2]],
    [[np.nan, 7, 8], [np.nan, 3, 11]]
])

# Check for NaN values along axis 1 only: is_nan[i, k] is True when every arr[i, :, k] is NaN (shape (2, 3))
is_nan = np.isnan(arr).all(axis=(1))

In [None]:
# Add a new column (v+1) with binary values based on the NaN check
new_col = np.where(is_nan, 1, 0)
new_col

In [None]:

# NOTE(review): new_col is (2, 3), so new_col[:, np.newaxis, np.newaxis] is 4-D (2, 1, 1, 3); concatenating it
# with the 3-D `arr` raises ValueError — new_col[:, np.newaxis, :] (shape (2, 1, 3)) is presumably what was meant
arr_with_new_col = np.concatenate((arr, new_col[:, np.newaxis, np.newaxis]), axis=1)

# Print the original array and the array with the new column
print("Original Array:")
print(arr)
print("\nArray with New Column:")
print(arr_with_new_col)

In [None]:
### EXAMPLE: https://nctoolkit.readthedocs.io/en/latest/interpolation.html
# NOTE(review): `nc` (nctoolkit) is commented out in the imports cell, so this raises NameError on a fresh kernel
ds1 = nc.open_thredds("https://psl.noaa.gov/thredds/dodsC/Datasets/COBE2/sst.mon.mean.nc")
ds1.subset(timestep = 0)
ds1.subset(lat = [0, 90])

ds2 = nc.open_thredds("https://psl.noaa.gov/thredds/dodsC/Datasets/COBE2/sst.mon.mean.nc")
ds2.subset(timestep = 0)
ds2.regrid(ds1, recycle=True)

In [None]:
class MyDataset:
    """Handle the variety of datasets required to test and train model.

    Datasets are stored by name in ``self.datasets``. ``files_location`` is the root directory that dataset
    paths are built from (see ``set_location``).
    """
    # TODO: add in declaration of filepath root
    def __init__(self):
        self.datasets = {}
        self.files_location = Path()

    def set_location(self, location="remote"):
        """Set ``files_location`` for the machine in use.

        Parameters
        ----------
            location (str): "remote" changes the working directory to /home/jovyan and uses the relative root
                "lustre_scratch/datasets/"; "local" uses ``directories.get_volume_dir()``.

        Raises
        ------
            ValueError: if ``location`` is neither "remote" nor "local".
        """
        if location == "remote":
            # change directory to home so the relative root below resolves. TODO: make less hacky
            os.chdir("/home/jovyan")
            self.files_location = Path("lustre_scratch/datasets/")
        elif location == "local":
            self.files_location = directories.get_volume_dir()
        else:
            raise ValueError(f"Unrecognised location {location!r}: expected 'remote' or 'local'.")

    def get_location(self):
        """Return the root directory set by ``set_location`` (``Path()`` until it has been called)."""
        return self.files_location

    def add_dataset(self, name, data):
        """Store ``data`` under ``name``, replacing any existing entry."""
        self.datasets[name] = data

    def get_dataset(self, name):
        """Return the dataset stored under ``name``, or None if there is none."""
        return self.datasets.get(name, None)

    def remove_dataset(self, name):
        """Remove the dataset stored under ``name``; unknown names are ignored."""
        self.datasets.pop(name, None)

    def list_datasets(self):
        """Return the names of all stored datasets, in insertion order."""
        return list(self.datasets.keys())

In [None]:
ds_man = MyDataset()

# add datasets
# NOTE(review): "remote" also chdirs to /home/jovyan, so every relative path afterwards resolves from there
ds_man.set_location("remote")

ds_man.add_dataset(
    "monthly_climate_1_12", xa.open_dataset(
        ds_man.get_location() / "global_ocean_reanalysis/monthly_means/coral_climate_1_12.nc")
)

# feature variables = all data variables except the target ('coral_algae_1-12_degree') and 'spatial_ref'/'output'
# NOTE(review): list(set(...)) order is not stable across interpreter sessions (string hash randomisation),
# so the feature order can change between runs — consider sorted(...)
coral_climate_feature_vars = list(
    set(ds_man.get_dataset("monthly_climate_1_12").data_vars) - {'spatial_ref', 'coral_algae_1-12_degree', 'output'})
ds_man.add_dataset(
    "monthly_climate_features", ds_man.get_dataset("monthly_climate_1_12")[coral_climate_feature_vars]
)

# ds_man.add_dataset(
#     "monthly_climate_1_12_y_np", np.array(ds_man.get_dataset("monthly_climate_1_12")["coral_algae_1-12_degree"].isel(time=-1)).reshape(-1, 1)
# )

# y is the target at the last time step reshaped to a column vector.
# NOTE(review): `filter_out_nans` is not defined or imported in the visible notebook; given the [0]/[1]
# indexing below it presumably returns (X, y) — confirm.
# NOTE(review): X is built from the full "monthly_climate_1_12" dataset, which still contains the target
# 'coral_algae_1-12_degree' — possibly "monthly_climate_features" was intended (target leakage); confirm.
ds_man.add_dataset(
    "monthly_climate_1_12_X_y_np", filter_out_nans(
        spatial_data.xa_ds_to_3d_numpy(ds_man.get_dataset("monthly_climate_1_12")), 
        np.array(ds_man.get_dataset("monthly_climate_1_12")["coral_algae_1-12_degree"].isel(time=-1)).reshape(-1, 1))
)

ds_man.add_dataset(
    "monthly_climate_1_12_X_np", ds_man.get_dataset("monthly_climate_1_12_X_y_np")[0]
)

ds_man.add_dataset(
    "monthly_climate_1_12_y_np", ds_man.get_dataset("monthly_climate_1_12_X_y_np")[1]
)

ds_man.add_dataset(
    "daily_climate_1_12", xa.open_dataset(
        Path(ds_man.get_location() / "global_ocean_reanalysis/daily_means/dailies_combined.nc"))
)

# same target as monthly
ds_man.add_dataset(
    "daily_climate_1_12_y_np", ds_man.get_dataset("monthly_climate_1_12_y_np")
)

# bathymetry raster, renamed and with x/y renamed to longitude/latitude.
# NOTE(review): `rio` (presumably rioxarray) and `rasterio` are not imported in the visible notebook
ds_man.add_dataset(
    "bathymetry_A", rio.open_rasterio(
        rasterio.open(ds_man.get_location() / "bathymetry/GBR_30m/Great_Barrier_Reef_A_2020_30m_MSL_cog.tif"),
        ).rename("bathymetry_A").rename({"x": "longitude", "y": "latitude"})
)

In [None]:
def sample_spatial_batch(xa_ds: xa.Dataset, lat_lon_starts: tuple=(0,0), window_dims: tuple[int,int] = (6,6),
    coord_range: tuple[float]=None, variables: list[str] = None) -> tuple[xa.Dataset, dict]:
    """Sample a spatial batch from an xarray Dataset.

    Parameters
    ----------
        xa_ds (xa.Dataset): The input xarray Dataset, with "latitude", "longitude" and "time" coordinates.
        lat_lon_starts (tuple): Starting latitude and longitude of the batch. These are integer indices when
            sampling by ``window_dims``, and coordinate values (degrees) when ``coord_range`` is given.
        window_dims (tuple[int, int]): Number of latitude and longitude cells in the window.
        coord_range (tuple[float]): Latitude and longitude extent (in degrees) of the window. If provided, it
            overrides ``window_dims`` and the window is selected by coordinate value.
        variables (list[str]): Variable names to keep. If None (or empty), all variables are kept.

    Returns
    -------
        tuple[xa.Dataset, dict]: The sampled subset, and a dict holding its "latitude", "longitude" and "time"
            coordinate values.

    Notes
    -----
        - Coordinate-range selection uses ``.sel`` with slice(start, start + extent). For a coordinate stored in
          descending order (e.g. latitudes north to south) this presumably yields an empty selection unless the
          extent is negative. TODO: make universal.

    Example
    -------
        # Sample a 6x6-cell spatial batch of two variables, starting at index (2, 3)
        dataset = ...
        subset, coords = sample_spatial_batch(dataset, lat_lon_starts=(2, 3), window_dims=(6, 6),
                                              variables=['var1', 'var2'])
    """
    lat_start, lon_start = lat_lon_starts[0], lat_lon_starts[1]
    if not coord_range:
        # window_dims is the window size, so the stop index is offset from the start index
        subset = xa_ds.isel({"latitude": slice(lat_start, lat_start + window_dims[0]),
                             "longitude": slice(lon_start, lon_start + window_dims[1])})
    else:
        lat_extent, lon_extent = coord_range[0], coord_range[1]
        subset = xa_ds.sel({"latitude": slice(lat_start, lat_start + lat_extent),
                            "longitude": slice(lon_start, lon_start + lon_extent)})

    # honour the documented `variables` argument
    if variables:
        subset = subset[variables]

    lat_slice = subset["latitude"].values
    lon_slice = subset["longitude"].values
    time_slice = subset["time"].values

    return subset, {"latitude": lat_slice, "longitude": lon_slice, "time": time_slice}


In [None]:
# load bathymetry
# (registered by the dataset-manager cell, where its x/y dims were renamed to longitude/latitude)
bath_A = ds_man.get_dataset("bathymetry_A")
bath_A

In [None]:
def degrees_to_distances(
    target_lat_res: float,
    target_lon_res: float = None,
    approx_lat: float = -18,
    approx_lon: float = 145,
) -> tuple[float, float]:
    """Convert a resolution in degrees into haversine distances in metres near a given location.

    Parameters
    ----------
        target_lat_res (float): latitudinal resolution in degrees.
        target_lon_res (float, optional): longitudinal resolution in degrees. Defaults to ``target_lat_res``.
        approx_lat (float): latitude (degrees) from which the distances are measured. Defaults to -18.
        approx_lon (float): longitude (degrees) from which the distances are measured. Defaults to 145.

    Returns
    -------
        tuple[float, float]: the distance in metres spanned by ``target_lat_res`` degrees of latitude, and by the
            longitudinal resolution, both measured from (approx_lat, approx_lon).
    """
    # only fall back to the latitude resolution when none is given (an explicit 0 is a valid value)
    if target_lon_res is None:
        target_lon_res = target_lat_res

    start_coord = (approx_lat, approx_lon)
    lat_end_coord = (approx_lat + target_lat_res, approx_lon)
    lon_end_coord = (approx_lat, approx_lon + target_lon_res)

    return (haversine.haversine(start_coord, lat_end_coord, unit=haversine.Unit.METERS),
        haversine.haversine(start_coord, lon_end_coord, unit=haversine.Unit.METERS))


def distance_to_degrees(
    distance_lat: float, distance_lon: float = None, approx_lat: float = -18, approx_lon: float = 145
) -> tuple[float, float, float]:
    """Convert distances in metres into the corresponding extents in degrees near a given location.

    Parameters
    ----------
    distance_lat (float): north-south distance in metres.
    distance_lon (float, optional): east-west distance in metres. Defaults to ``distance_lat``.
    approx_lat (float, optional): The approximate latitude of the location in degrees. Defaults to -18.
    approx_lon (float, optional): The approximate longitude of the location in degrees. Defaults to 145.

    Returns
    -------
    tuple[float, float, float]:
        - the latitude difference (degrees) between the location and the point ``distance_lat`` metres due south;
        - the longitude difference (degrees) between the location and the point ``distance_lon`` metres due west;
        - the hypotenuse (degrees) of the lat/lon offsets of the point the mean of the two distances to the
          southwest, a single measure combining both directions.
    """
    # TODO: enable specification of distance in different orthogonal directions
    # only fall back to the latitude distance when none is given (an explicit 0 is a valid value)
    if distance_lon is None:
        distance_lon = distance_lat

    origin = (approx_lat, approx_lon)
    # point due south: only its latitude is needed
    south_lat, _ = haversine.inverse_haversine(
        origin, distance_lat, haversine.Direction.SOUTH, unit=haversine.Unit.METERS)
    # point due west: only its longitude is needed
    _, west_lon = haversine.inverse_haversine(
        origin, distance_lon, haversine.Direction.WEST, unit=haversine.Unit.METERS)

    # point at the mean distance to the southwest (chosen to give a measure of both lat and lon)
    av_distance = (distance_lat + distance_lon) / 2
    sw_lat, sw_lon = haversine.inverse_haversine(
        origin, av_distance, haversine.Direction.SOUTHWEST, unit=haversine.Unit.METERS)
    delta_lat, delta_lon = abs(sw_lat - approx_lat), abs(sw_lon - approx_lon)

    return (approx_lat - south_lat, approx_lon - west_lon, np.hypot(delta_lat, delta_lon))

In [None]:
# degrees_to_distances(0.00898315)
distance_to_degrees(4000)

In [None]:
# 1 km. Struggles displaying/processing 100m, but have yet to try saving to this/inferring
_,_,av_degrees = distance_to_degrees(1000)
coarsened_bath_A = spatial_data.upsample_xarray_to_target(bath_A, av_degrees)
# im = coarsened_bath_A.plot(ax=ax)

# NOTE(review): `spatial_plots` and `target_resolution` are not defined in the visible notebook
# (presumably `av_degrees` was meant for the title)
spatial_plots.plot_DEM(coarsened_bath_A, f" DEM upsampled to {target_resolution} degrees", vmin=-100, vmax=0)
# spatial_plots.format_spatial_plot(im, fig, ax, f"Upsampled to {target_resolution} degrees")

In [None]:
# carving off small slice

# 1km upscale ground truth, bathymetry

# regrid ground truth, bathymetry, climate 

In [None]:
# NOTE(review): `tf` is not imported in the visible notebook; `c` also shadows the CDS client name used elsewhere
c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
c

In [None]:
c.value()

In [None]:
# NOTE(review): test_A.nc is only written by a later cell (bath_A.to_netcdf(...)), so this depends on run order;
# unlike the original registration, x/y are not renamed here (the next cell does it)
ds_man.add_dataset(
    "bathymetry_A", rio.open_rasterio(
        rasterio.open(ds_man.get_location() / "bathymetry/GBR_30m/test_A.nc"),
        ).rename("bathymetry_A")
)

In [None]:
bath_A = ds_man.get_dataset("bathymetry_A")
bath_A = bath_A.rename({"x": "longitude", "y": "latitude"})
bath_A

In [None]:
ds_climate_monthly_1_12 = ds_man.get_dataset("monthly_climate_1_12")
# NOTE(review): slice(-10, -10.1) only selects data if latitude is stored in descending order — confirm
ds_climate_monthly_1_12_small = ds_climate_monthly_1_12.sel(latitude=slice(-10,-10.1), longitude=slice(142,142.1))

In [None]:
data_var_list = list(ds_climate_monthly_1_12_small.data_vars)
# NOTE(review): `out` is not defined anywhere in the visible notebook
variables = out + data_var_list

In [None]:
# NOTE(review): overwrites the name it reads from, so re-running this cell is not idempotent
ds_climate_monthly_1_12_small = ds_climate_monthly_1_12_small[variables]
ds_climate_monthly_1_12_small

In [None]:
climate_path = str(ds_man.get_location() / "global_ocean_reanalysis/daily_means/dailies_combined.nc")

In [None]:
climate_xa = xa.open_dataset(climate_path)
climate_xa

In [None]:
# write a tiny subset (10 time steps, 2x2 cells, first depth level) to experiment with nctoolkit
restrict_xa = climate_xa.isel(time=slice(0,10), latitude=slice(0,2), longitude=slice(0,2), depth=0)
restrict_xa.to_netcdf("lustre_scratch/datasets/global_ocean_reanalysis/daily_means/restrict_xa.nc")
# NOTE(review): `nc` (nctoolkit) is commented out in the imports cell
restrict_nc = nc.open_data("lustre_scratch/datasets/global_ocean_reanalysis/daily_means/restrict_xa.nc")

In [None]:
ds = nc.open_data(climate_path)

In [None]:
# NOTE(review): `ds_static` is not defined in the visible notebook
ds_static.times

In [None]:
restrict_nc.regrid(restrict_nc)

In [None]:
out

In [None]:
ds_climate_monthly_1_12_small.dims

In [None]:
# NOTE(review): ds_bath is only assigned in the next cell
ds_bath.dims

In [None]:
ds_climate = nc.from_xarray(ds_climate_monthly_1_12_small)
ds_bath = nc.from_xarray(bath_A)

In [None]:
ds_climate.months

In [None]:
ds_bath.variables

In [None]:
ds_climate.regrid(ds_climate)

In [None]:
import cf

In [None]:
cf.example_fields(2)

In [None]:
# save the bathymetry as netCDF so cf-python can read it as the regridding destination below
bath_A.to_netcdf(ds_man.get_location() / "bathymetry/GBR_30m/test_A.nc")

In [None]:
# NOTE(review): `dst` is only assigned two cells below
dst, 

In [None]:
dst = cf.read(ds_man.get_location() / "bathymetry/GBR_30m/test_A.nc")

In [None]:
dst.select_by_ncvar("bathymetry_A")

In [None]:
srces = cf.read(ds_man.get_location() / "global_ocean_reanalysis/monthly_means/coral_climate_1_12.nc")

In [None]:
srces[5]

In [None]:
# NOTE(review): lat/lon are never used below, and "lats"/"lons" are not CF unit strings
# (presumably "degrees_north"/"degrees_east" were meant) — confirm
lat = cf.DimensionCoordinate(data = cf.Data(np.arange(-17, -10, 0.01), "lats"))
lon = cf.DimensionCoordinate(data = cf.Data(np.arange(142, 147, 0.01), "lons"))

In [None]:
# NOTE(review): `dst` is the list returned by cf.read; regrids presumably needs a single field (e.g. dst[0]) — confirm
g = srces[5].regrids(dst, method='linear')


In [None]:
bath_A = ds_man.get_dataset("bathymetry_A")
# NOTE(review): only valid if the test_A.nc registration (x/y not renamed) ran last; the original
# registration already renamed x/y, in which case this rename fails — depends on run order
bath_A = bath_A.rename({"x": "longitude", "y": "latitude"})
bath_A

In [None]:
bath_A.values[:, None, :].shape

In [None]:
ds_climate_monthly_1_12 = ds_man.get_dataset("monthly_climate_1_12")
ds_climate_monthly_1_12_A = ds_climate_monthly_1_12.sel(latitude=slice(-10,-17), longitude=slice(142,147))

In [None]:
ds_climate_monthly_1_12_A["mlotst"].isel(time=-1).plot()

In [None]:
### Testing interp_like
climate_m_small = ds_climate_monthly_1_12_A.sel(latitude=slice(-10,-10.2), longitude=slice(142,142.2))
# NOTE(review): this longitude window (145-145.2) differs from the climate window above (142-142.2), so
# interp_like is asked for points outside climate_m_small's extent — presumably a typo (a later cell uses 142-142.2)
bath_A_small = bath_A.sel(latitude=slice(-10,-10.2), longitude=slice(145,145.2))

In [None]:
check = climate_m_small.interp_like(bath_A_small)

In [None]:
check

In [None]:
check["mlotst"].isel(time=-1).plot()

In [None]:
# NOTE(review): rebinds both input names, so cells above/below behave differently after this has run
ds_climate_monthly_1_12_A, bath_A = xa.broadcast(ds_climate_monthly_1_12_A, bath_A)

In [None]:
# Assuming you have the high-resolution and low-resolution xarray datasets named 'bath_A' and 'ds_climate_monthly_1_12_A'

# Extract the latitude and longitude coordinates from the high-resolution dataset
lat_hr = bath_A['latitude'].values
lon_hr = bath_A['longitude'].values

# Extract the latitude and longitude coordinates from the low-resolution dataset
lat_lr = ds_climate_monthly_1_12_A['latitude'].values
lon_lr = ds_climate_monthly_1_12_A['longitude'].values

# Get the values of the low-resolution array
ds_climate_monthly_1_12_A_values = ds_climate_monthly_1_12_A['bottomT'].values

# Create a new high-resolution array with the same values for each time step
# NOTE(review): the target shape is taken from the climate array, not the bathymetry grid — confirm intended
bath_A_values = np.broadcast_to(bath_A.values[:, None, :], ds_climate_monthly_1_12_A_values.shape)


# # Create a 2D meshgrid for the high-resolution coordinates
# lon_hr_mesh, lat_hr_mesh = np.meshgrid(lon_hr, lat_hr)

# # Create a 2D meshgrid for the low-resolution coordinates
# lon_lr_mesh, lat_lr_mesh = np.meshgrid(lon_lr, lat_lr)

# # Get the values of the low-resolution array
# low_res_values = ds_climate_monthly_1_12_A['bottomT'].values

In [None]:
# NOTE(review): low_res_values is only assigned in commented-out code above (NameError on a fresh kernel)
low_res_values.shape

In [None]:
# Interpolate the low-resolution values onto the high-resolution grid
# NOTE(review): interp2d is not imported, lon_lr_mesh/lat_lr_mesh are commented out above, and
# scipy.interpolate.interp2d has been removed from recent SciPy releases
interp_func = interp2d(lon_lr_mesh, lat_lr_mesh, low_res_values, kind='linear')
high_res_values = interp_func(lon_hr, lat_hr)

# # Create a new xarray dataset with the interpolated values
# combined_ds = xa.Dataset(
#     {
#         'data_variable': (('lat', 'lon'), high_res_values),
#     },
#     coords={'lat': lat_hr, 'lon': lon_hr}
# )

In [None]:
# will need to try with GEE or do in smaller chunks
# cut out nan vars before this. Huge memory consumption and not even the daily...
# reindex the small monthly climate window onto the bathymetry grid by nearest neighbour
climate_m_small = ds_climate_monthly_1_12_A.sel(latitude=slice(-10,-10.2), longitude=slice(142,142.2))
# NOTE(review): if the xa.broadcast cell above has run, bath_A now also carries the climate dataset's dims
bath_A_small = bath_A.sel(latitude=slice(-10,-10.2), longitude=slice(142,142.2))
climate_resolution_monthly_reindexed = climate_m_small.reindex_like(bath_A_small, method='nearest')
# takes ~ 40s

In [None]:
ds_climate_daily_1_12 = ds_man.get_dataset("daily_climate_1_12")
ds_climate_daily_1_12_A = ds_climate_daily_1_12.sel(latitude=slice(-10,-17), longitude=slice(142,147))

In [None]:
# completely destroys kernel memory usage
climate_d_small = ds_climate_daily_1_12_A.sel(latitude=slice(-10,-10.2), longitude=slice(142,142.2))
climate_resolution_daily_reindexed = climate_d_small.reindex_like(bath_A_small, method='nearest')

In [None]:
test = climate_resolution_monthly_reindexed.merge(bath_A_small)
# climate_resolution_daily_reindexed

In [None]:
# NOTE(review): climate_resolution_reindexed is not defined until the next cell
climate_resolution_reindexed

In [None]:
# NOTE(review): binds the name to 0, so the broadcast_like call in the next cell raises AttributeError
climate_resolution_reindexed = 0

In [None]:
climate_resolution_broadcasted = climate_resolution_reindexed.broadcast_like(test)
climate_resolution_broadcasted