In [1]:
%load_ext autoreload
%autoreload 2

import sys
import os
sys.path.append("../")


In [2]:
# choose whether to work on a remote machine
location = "remote"
# location = "local"

if location == "remote":
    # Directory in which the GitHub repository is located on the remote machine.
    # Set the CORALSHIFT_REPO_DIR environment variable to override the default
    # without editing the notebook.
    remote_repo_dir = os.environ.get(
        "CORALSHIFT_REPO_DIR", "/lustre_scratch/orlando-code/coralshift/"
    )
    os.chdir(remote_repo_dir)

# Data storage setup

In [3]:
# import relevant packages

from __future__ import annotations

from pathlib import Path
import xarray as xa
import numpy as np
# import math as m
# import pandas as pd
# import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# import wandb
from tqdm import tqdm
from sklearn import model_selection
from sklearn.preprocessing import normalize
from scipy.interpolate import interp2d
from sklearn.utils import class_weight
from scipy.ndimage import gaussian_gradient_magnitude
# import xbatcher

import rasterio
from rasterio.plot import show
import rioxarray as rio

# from bs4 import BeautifulSoup
# import requests


#issues with numpy deprecation in pytorch_env
from coralshift.processing import spatial_data
from coralshift.utils import file_ops, directories, utils
# from coralshift.plotting import spatial_plots, model_results
from coralshift.dataloading import data_structure, climate_data, bathymetry, reef_extent

In [None]:
# Create the project's dataset container and tell it which storage location
# ("remote" or "local", chosen in the configuration cell) to use.
ds_man = data_structure.MyDatasets()
ds_man.set_location(location)

## Specify your area of interest

The availability of high-resolution (30m) bathymetric data means that areas of interest are currently confined to 4 areas on the Great Barrier Reef (GBR). The following code generates a GeoJSON file specifying which area (A-D) you'd like to investigate:

| Reef Area Name                	| Latitudes 	| Longitudes 	|
|-------------------------------	|-----------	|------------	|
| Great Barrier Reef A 2020 30m 	| 10-17°S   	| 143-147°E  	|
| Great Barrier Reef B 2020 30m 	| 16-23°S   	| 144-149°E  	|
| Great Barrier Reef C 2020 30m 	| 18-24°S   	| 148-154°E  	|
| Great Barrier Reef D 2020 30m 	| 23-29°S   	| 150-156°E  	|


Download your required area from here: https://ecat.ga.gov.au/geonetwork/srv/eng/catalog.search#/metadata/115066

Due to the computational load required to run ML models on such high-resolution data, bathymetric data is currently downsampled to 4km grid cells and areas are limited to a quarter of the GBR's total area.

In [4]:
# choose the target grid resolution (should be at least 1000m for processing in decent time).
# The two returned values are used below as the resolution in metres and in degrees.
target_resolution_m, target_resolution_d = spatial_data.choose_resolution(
    resolution=1000, unit="m")

# convert distance to degrees:
# _,_,av_degrees = spatial_data.distance_to_degrees(target_resolution)
print(f"Data will be downsampled to {target_resolution_d:.05f} degrees (~{target_resolution_m}m).")

Data will be downsampled to 0.00923 degrees (~1000m).


## Bathymetry

In [5]:
# select your area (one of "A"-"D", see the table above)
area_name = "A"

# download .tif if not downloaded already
reef_areas = bathymetry.ensure_bathymetry_downloaded(area_name)
# cast tif to processed xarray
xa_bath = spatial_data.tif_to_xarray(
    directories.get_bathymetry_datasets_dir() / reef_areas.get_filename(area_name), reef_areas.get_xarray_name(area_name)
    )
# bath_name = f"{reef_areas.get_xarray_name(area_name)}"
# save to a .nc file in the bathymetry directory; the value returned by save_nc is kept
# as the dataset's name (NOTE(review): confirm save_nc returns the saved name)
bath_name = file_ops.save_nc(
    directories.get_bathymetry_datasets_dir(),
    f"{reef_areas.get_xarray_name(area_name)}", 
    xa_bath)


# resample to the specified resolution. Despite the "upsample" naming, the 30m source
# grid is being coarsened to the (much larger) target cell size chosen above.
xa_bath_upsampled = spatial_data.upsample_xarray_to_target(xa_bath, target_resolution=target_resolution_d)
# upsampled_bath_name = file_ops.replace_dot_with_dash(
#     f"{reef_areas.get_xarray_name(area_name)}_{target_resolution_d:.05f}_upsampled")
# save file TODO: return name, move replacedot to utils
upsampled_bath_name = file_ops.save_nc(
    directories.get_bathymetry_datasets_dir(), 
    f"{reef_areas.get_xarray_name(area_name)}_{target_resolution_d:.05f}_upsampled", 
    xa_bath_upsampled)

Already exists: /lustre_scratch/orlando-code/datasets/bathymetry/Great_Barrier_Reef_A_2020_30m_MSL_cog.tif
bathymetry_A already exists in /lustre_scratch/orlando-code/datasets/bathymetry.
bathymetry_A_0-00923_upsampled already exists in /lustre_scratch/orlando-code/datasets/bathymetry.


In [None]:
# TODO: better visualisation, investigate speed of plotting. Cell seems not to load when calling values...
# xa_bath_upsampled.plot()

## Calculate slopes

In [6]:
def calculate_and_save_gradient_magnitude(dataarray, sigma=1, output_path=None):
    """Compute the Gaussian gradient magnitude of a DataArray and optionally save it.

    Parameters
        dataarray (xarray.DataArray): Input array (e.g. bathymetry). Its coordinates,
            dimensions and attributes are copied onto the result.
        sigma (float): Standard deviation of the Gaussian kernel, in grid cells, passed
            to scipy.ndimage.gaussian_gradient_magnitude. Defaults to 1.
        output_path (Path | str | None): Where to write the result as a netCDF file.
            If None (default), nothing is written to disk.

    Returns:
        xarray.DataArray: Gradient magnitude on the same grid as `dataarray`.
    """
    # Calculate the gradient magnitude on the underlying numpy values
    gradient_magnitude = gaussian_gradient_magnitude(dataarray.values, sigma=sigma)

    # Wrap the result in a DataArray that mirrors the input's grid and metadata.
    # (xarray is imported as `xa` in this notebook; `xr` is not defined.)
    grad_magnitude_dataarray = xa.DataArray(
        gradient_magnitude,
        coords=dataarray.coords,
        dims=dataarray.dims,
        attrs=dataarray.attrs,
    )

    if output_path is not None:
        grad_magnitude_dataarray.to_netcdf(output_path)

    return grad_magnitude_dataarray

In [None]:
# `sigma` and `output_path` are part of the function's signature and must be supplied.
# NOTE(review): the output file name below is a suggestion — confirm the desired location.
calculate_and_save_gradient_magnitude(
    xa_bath_upsampled,
    sigma=1,
    output_path=directories.get_gradients_dir() / f"{upsampled_bath_name}_gradient_magnitude.nc",
)

In [None]:
def generate_gradient_nc(
    bathymetry_name,
    kernel_size: int = 1,
    return_array: bool = False,
) -> xa.DataArray:
    """Compute (or load from disk) the gradient magnitude of a saved bathymetry file.

    Parameters
        bathymetry_name (str): Name (without directory) of a bathymetry .nc file in the
            bathymetry datasets directory. This is a file name, not an array.
        kernel_size (int): Sigma of the Gaussian kernel, in grid cells, passed to
            scipy.ndimage.gaussian_gradient_magnitude. Defaults to 1.
        return_array (bool): If True, return the gradient array; otherwise return None.

    Returns:
        xarray.DataArray | None: Gradient magnitude named "gradients" on the bathymetry
            grid when `return_array` is True, else None.
    """
    gradient_dir = directories.get_gradients_dir()
    gradient_stem = f"{bathymetry_name}_gradients"
    # The existence check must include the ".nc" suffix, otherwise it can never match
    # the saved file and the gradients are recomputed on every call.
    # NOTE(review): assumes file_ops.save_nc writes "<dir>/<name>.nc" — confirm.
    gradient_path = gradient_dir / f"{gradient_stem}.nc"

    if gradient_path.is_file():
        xa_gradients = xa.open_dataarray(gradient_path)
    else:
        bath_path = (
            directories.get_bathymetry_datasets_dir() / bathymetry_name
        ).with_suffix(".nc")
        # open_dataarray raises a ValueError if the file holds more than one data variable
        xa_bath = xa.open_dataarray(bath_path)
        # scipy operates on plain arrays, so hand it the numpy values
        gradients = gaussian_gradient_magnitude(xa_bath.values, sigma=kernel_size)
        # copy(data=...) keeps coords/dims/attrs while swapping in the new values
        xa_gradients = xa_bath.copy(data=gradients).rename("gradients")
        file_ops.save_nc(gradient_dir, gradient_stem, xa_gradients)

    if return_array:
        return xa_gradients

In [None]:
# xa_bath.values
# xa_bath_upsampled.values
# compute() returns a new object with any lazily-held data loaded into memory
xa_bath_upsampled_vals = xa_bath_upsampled.compute()

In [None]:
# NOTE(review): `gradients` is only assigned at notebook level in a later cell, so this
# line raises NameError on a fresh "Restart & Run All" — move it below that cell or remove it.
gradients.shape

In [None]:
# def calculate_cell_slopes(values, sigma: int = 1):
#     grad_func = gaussian_gradient_magnitude(values, sigma=sigma)

#     return xa.apply_ufunc(grad_func, values, kwargs={"sigma": sigma}, dask="parallelized", vectorize=True)

# out = calculate_cell_slopes(xa_bath_upsampled)

def apply_gaussian_gradient_magnitude(da: xa.DataArray, sigma: float) -> xa.DataArray:
    """Return the Gaussian gradient magnitude of a DataArray on the same grid.

    Parameters
        da (xarray.DataArray): Input array; its values are loaded as a numpy array.
        sigma (float): Standard deviation of the Gaussian kernel, in grid cells.

    Returns:
        xarray.DataArray: Gradient magnitude with the coords, dims and attrs of `da`.
    """
    # scipy operates on plain numpy arrays, so extract the values first
    array = da.values

    gradient_array = gaussian_gradient_magnitude(array, sigma=sigma)

    # Re-attach the original grid and metadata to the computed values
    gradient_da = xa.DataArray(gradient_array, coords=da.coords, dims=da.dims, attrs=da.attrs)

    return gradient_da

# def apply_gaussian_gradient_magnitude(da: xr.DataArray, sigma: float) -> xr.DataArray:
#     # Convert DataArray to Dask array
#     dask_array = da.data
    
#     # Apply gaussian_gradient_magnitude to the Dask array
#     gradient_array = da.map_blocks(lambda block: gaussian_gradient_magnitude(block, sigma=sigma))
    
#     # Create a new DataArray with the Dask array
#     gradient_da = xr.DataArray(gradient_array, coords=da.coords, dims=da.dims, attrs=da.attrs)
    
#     return gradient_da

# gradient magnitude of the resampled bathymetry, using a Gaussian kernel with sigma = 1.0
result = apply_gaussian_gradient_magnitude(xa_bath_upsampled, sigma=1.0)


In [None]:
# side-by-side comparison: gradient magnitude (left) and resampled bathymetry (right)
f,a =plt.subplots(1,2, figsize=[14,5])
result.plot(ax=a[0])
xa_bath_upsampled_vals.plot(ax=a[1])

In [None]:
# gradient magnitude of the resampled bathymetry (Gaussian kernel, sigma = 1)
gradients = gaussian_gradient_magnitude(xa_bath_upsampled_vals.values, sigma=1)
# copy(data=...) on a DataArray expects an array of the original shape, not a
# {name: array} mapping; rename() labels the result instead.
# NOTE(review): assumes xa_bath_upsampled is a DataArray (it is used as one elsewhere) — confirm.
xa_gradients = xa_bath_upsampled.copy(data=gradients).rename("gradients")


In [None]:
# calculate slopes of upsampled bathymetry and save to nc file.
# generate_gradient_nc builds file paths from a dataset *name*, so pass the name of the
# saved upsampled bathymetry file rather than the in-memory array.
# NOTE(review): relies on file_ops.save_nc having returned that name — confirm.
generate_gradient_nc(upsampled_bath_name, return_array=True)

## Coral ground truth: Allen Coral Atlas


There is currently no API for accessing data directly from your local machine. Please follow the instructions* below:
1. Make an account on the [Allen Coral Atlas](https://allencoralatlas.org/atlas/#6.00/-13.5257/144.5000) webpage
2. Generate a geojson file using the code cell below (generated in the `reef_baseline` directory)

*Instructions correct as of 30.06.23

In [None]:
# generate a geojson file for the chosen area in the reef_baseline directory; this file is
# what gets uploaded to the Allen Coral Atlas to request the area's data (see steps below)
geojson_path = reef_extent.generate_area_geojson(
    area_class = reef_areas, area_name=area_name, save_dir=directories.get_reef_baseline_dir())

3. Upload the geojson file via:

    \> My Areas > Upload a GeoJSON or KML file
4. Wait for the area to be processed, and select "Benthic Map (OGC GeoPackage (.gpkg))". Sign the terms and conditions 
and select "Prepare Download". After ~two minutes a sequence of emails will arrive notifying you that your download is ready.
5. Download the file and unzip it using an unzipping utility. Then,
    - add the `benthic.gpkg` file to the `reef_baseline` directory
    
    
6. Continue with the subsequent code cells.

----

You have now downloaded:

**`benthic.gpkg`**

This is a dataframe of Shapely objects ("geometry" polygons) defining the boundaries of different benthic classes:

| Class           	| Number of polygons 	|
|-----------------	|--------------------	|
| Coral/Algae     	| 877787             	|
| Rock            	| 766391             	|
| Rubble          	| 568041             	|
| Sand            	| 518805             	|
| Microalgal Mats 	| 27569              	|

In [None]:
# load the benthic polygons from the reef_baseline directory. Judging by the helper's name
# it reads "benthic.pkl" if present and otherwise falls back to the .gpkg — TODO confirm.
benthic_df = file_ops.check_pkl_else_read_gpkg(directories.get_reef_baseline_dir(), filename = "benthic.pkl")
benthic_df.head()

### Rasterize polygons

Rasterized arrays are necessary to process the geospatial data, e.g. to align different grid cells. Doing this locally through rasterio requires such significant compute that cloud computing is the only reasonable option. A JavaScript file for use with Google Earth Engine (GEE) is provided in the `coralshift` repo. Visit [this page](https://developers.google.com/earth-engine/guides/getstarted) for information regarding setting up a GEE account and getting started.

GEE requires a shapefile (.shp) to ingest data. This is generated in the following cell:


In [None]:
# process df to gpd.GeoDataFrame. We are interested only in the "Coral/Algae" class, so gdf is limited to these rows
gdf_coral = reef_extent.process_benthic_pd(benthic_df)
# save as a shapefile (.shp) for rasterisation in GEE, if not already present
reef_extent.generate_coral_shp(gdf_coral)

1. Ingest the shapefile (and all accompanying files: .cpg, .dbf, .prj, .shx) as an asset.
2. Import the subsequent `Table` into the script.
3. Run the script, and submit the `coral_A_north` and `coral_A_south` tasks. Sit back and wait! After ~1 hour the rasters will be available to download from your Google Drive as GeoTIFFS: after this, please add them to the reef_baseline directory and carry on with the following cells.

In [None]:
# fetch data: list every .tif file in `gbr_30_dir`
# NOTE(review): `gbr_30_dir` is not defined anywhere in this notebook, so this cell fails on a
# fresh kernel — define it (the directory holding the downloaded GeoTIFFs) before this cell.
gbr_30_files = file_ops.return_list_filepaths(gbr_30_dir, ".tif")
# generate dictionary of file names and arrays: {filename: xarray.DataArray, ...}
gbr_30_dict_preprocess = spatial_data.tifs_to_xa_array_dict(gbr_30_files)
# process xa_arrays
# gbr_30_dict = spatial_data.process_xa_arrays_in_dict(gbr_30_dict_preprocess)

In [None]:
# example ground-truth raster: open the GeoTIFF and drop its "band" dimension
tif_eg = rio.open_rasterio(rasterio.open(directories.get_reef_baseline_dir() / "gt_tifs/coral_raster_1000m.tif")).squeeze("band")
# ground-truth array previously saved as netCDF (file name indicates a 0.03691-degree grid)
gt_4km = xa.open_dataarray(directories.get_reef_baseline_dir() / "gt_tifs/gt_nc_dir/concatenated_0-03691_degree.nc")

In [None]:
# NOTE(review): `xa_4km` is never assigned in this notebook (only `gt_4km` is) — this raises
# NameError on a fresh kernel; confirm which object was meant.
xa_4km

In [None]:
# NOTE(review): repeated display of `xa_4km`, a name that is never assigned in this notebook.
xa_4km

In [None]:
# compare the same region (142-142.4 E, 11-10.6 S) in the GeoTIFF raster (left) and in `xa_4km` (right).
# The y slice runs high-to-low while the latitude slice runs low-to-high, i.e. the two
# objects are indexed with opposite latitude ordering.
# NOTE(review): `xa_4km` is not assigned anywhere in this notebook.
f, ax = plt.subplots(1,2, figsize=[14,8])
tif_eg.sel({"x":slice(142,142.4), "y":slice(-10.6,-11)}).plot(ax=ax[0])
xa_4km["__xarray_dataarray_variable__"].sel({"longitude":slice(142,142.4), "latitude":slice(-11,-10.6)}).plot(ax=ax[1])
# equal aspect so both panels share the same x/y scaling
ax[0].set_aspect("equal")
ax[1].set_aspect("equal")

In [None]:
# open a single coral raster tile (saved as netCDF) from the reef_baseline directory
test=xa.open_dataarray(directories.get_reef_baseline_dir() / "coral_A_south-0000065536-0000065536.nc")

In [None]:
# plot the tile with every value greater than 0 replaced by 1 (other values left unchanged)
xa.where(test>0, 1, test).plot()

In [None]:
# def process_coral_gt_tifs(tif_dir_name=None, target_resolution_d:float=None):
#     if not tif_dir_name:
#         tif_dir = directories.get_reef_baseline_dir()
#     else:
#         tif_dir = directories.get_reef_baseline_dir() / tif_dir_name
    
#     nc_dir = file_ops.guarantee_existence(tif_dir / "gt_nc_dir")
#     # save tifs to ncs in new dir
#     tif_paths = tifs_to_ncs(nc_dir, target_resolution_d)
#     # get list of nc paths in dir
#     xa_arrays_list = [tif_to_xa_array(tif_path) for tif_path in tif_paths]
#     # merge ncs into one mega nc file
#     if len(xa_arrays_list) > 1:
#         concatted = xa.concat(xa_arrays_list, dim=["latitude","longitude"])
#     else:
#         concatted = xa_arrays_list[0]
#     file_ops.save_nc(nc_dir, f"concatenated_{target_resolution_d:.05f}_degree", concatted)


# def tifs_to_ncs(nc_dir: Path | list[str], target_resolution_d: float=None) -> None:

#     tif_dir = nc_dir.parent
#     tif_paths = file_ops.return_list_filepaths(tif_dir, ".tif")
#     xa_array_dict = {}
#     for tif_path in tqdm(tif_paths, total=len(tif_paths), desc="Writing tifs to nc files"):
#         # filename = str(file_ops.get_n_last_subparts_path(tif, 1))
#         filename = tif_path.stem
#         tif_array = tif_to_xa_array(tif_path)
#         # xa_array_dict[filename] = tif_array.rename(filename)
#         if target_resolution_d:
#             tif_array = spatial_data.upsample_xarray_to_target(xa_array=tif_array, target_resolution=target_resolution_d)
#         # save array to nc file
#         file_ops.save_nc(tif_dir, filename, tif_array)

#     return tif_paths
#     print(f"All tifs converted to xarrays and stored as .nc files in {nc_dir}.")


# def tif_to_xa_array(tif_path) -> xa.DataArray:
#     return spatial_data.process_xa_d(rio.open_rasterio(rasterio.open(tif_path)))

In [None]:
# process the ground-truth GeoTIFFs in the "gt_tifs" directory at the chosen target resolution.
# Presumably this follows the commented-out draft above (tifs -> .nc files -> one combined
# .nc file) — TODO confirm against coralshift.dataloading.reef_extent.
reef_extent.process_coral_gt_tifs(tif_dir_name="gt_tifs", target_resolution_d=target_resolution_d)

In [None]:
# NOTE(review): `tifs_to_ncs` exists in this notebook only as commented-out code and
# `gbr_30_dir` is never assigned, so this cell raises NameError on a fresh kernel.
tifs_to_ncs(gbr_30_dir, target_resolution_d=target_resolution_d)

In [None]:
# load one tile as an xarray object for inspection
# NOTE(review): `gbr_30_dir` is never assigned and `tif_to_xa_array` exists only as
# commented-out code in this notebook — this cell fails on a fresh kernel.
tif_path = gbr_30_dir / "coral_A_south-0000000000-0000000000.tif"
check = tif_to_xa_array(tif_path)
check

In [None]:
# plot a small window of the tile (143.3-143.6 E, 13.7-13.5 S)
check.sel({"longitude":slice(143.3,143.6), "latitude":slice(-13.7,-13.5)}).plot()

In [None]:
# NOTE(review): neither `combine_tifs_to_xarray` nor `gbr_30_dir` is defined or imported in
# this notebook — this cell raises NameError on a fresh kernel.
out = combine_tifs_to_xarray(gbr_30_dir)

In [None]:
# display the combined object produced in the previous cell
out

In [None]:
# plot the first 100 x 100 cells of one tile, selected by its file name
out["coral_A_south-0000065536-0000032768.tif"].isel({"x":slice(0,100),"y":slice(0,100)}).plot()

In [None]:
# sum of all cell values in the example raster
np.sum(tif_eg.values)

In [None]:
# plot a small window of the example raster (143.5-143.7 E, 13.6-13.7 S) in black and white
tif_eg.sel({"x":slice(143.5,143.7), "y":slice(-13.6,-13.7)}).plot(cmap="binary")

In [None]:
# plot one unprocessed tile from the {filename: array} dictionary built earlier
gbr_30_dict_preprocess['coral_A_south-0000065536-0000032768.tif'].plot()

In [None]:
# TODO: visualisation

## Global Ocean Physics Reanalysis

The dataset metadata can be accessed [here](https://doi.org/10.48670/moi-00021).

### Download data

You're required to set up an account with the [Copernicus Marine Service](https://marine.copernicus.eu/). 


**Warning:** this is a large amount of data, and the only way to gather it is to query the Copernicus API via motu. Requests are queued, and smaller requests are floated to the top of the queue. The following functions take advantage of this by splitting a single request up by date and variable before amalgamating the files, but this can still take a **very long time**, and the duration varies significantly depending on overall website traffic. For those who aren't interested in the entire database, it's highly recommended that you use the toy dataset provided as a `.npy` file in the GitHub repository.


In [6]:
# download monthly data. Can be adjusted to specify subset of variables, dates, and depths to download.
# Values generated here are those reported in the accompanying paper.
# The latitude/longitude limits come from the selected reef area; the call returns the
# dataset together with the path of the merged file.
xa_cmems_monthly, cmems_monthly_path = climate_data.download_reanalysis(download_dir=directories.get_monthly_cmems_dir(), final_filename = "cmems_gopr_monthly",
    lat_lims = reef_areas.get_lat_lon_limits(area_name)[0], lon_lims = reef_areas.get_lat_lon_limits(area_name)[1], 
    product_id = "cmems_mod_glo_phy_my_0.083_P1M-m")   


Merged file already exists at /lustre_scratch/orlando-code/datasets/global_ocean_reanalysis/monthly_means/cmems_gopr_monthly.nc


In [7]:
# download daily data for the same area (daily product id, saved in the daily CMEMS directory).
# The call returns the dataset together with the path of the merged file.
xa_cmems_daily, cmems_daily_path = climate_data.download_reanalysis(download_dir=directories.get_daily_cmems_dir(), final_filename = "cmems_gopr_daily.nc",
    lat_lims = reef_areas.get_lat_lon_limits(area_name)[0], lon_lims = reef_areas.get_lat_lon_limits(area_name)[1], 
    product_id = "cmems_mod_glo_phy_my_0.083_P1D-m")   

Merged file already exists at /lustre_scratch/orlando-code/datasets/global_ocean_reanalysis/daily_means/cmems_gopr_daily.nc


### Spatially pad the data

TODO: add visual explanation

In [8]:
def spatially_buffer_timeseries(
    xa_ds: xa.Dataset,
    buffer_size: int = 3,
    exclude_vars: list[str] = ["spatial_ref", "coral_algae_1-12_degree"],
) -> xa.Dataset:
    """Build a new dataset in which each selected data variable has been spatially buffered.

    Parameters
        xa_ds (xarray.Dataset): Dataset whose data variables are to be buffered.
        buffer_size (int): Buffer size in grid cells, forwarded to
            spatial_data.buffer_nans as its `size` keyword.
        exclude_vars (list[str]): Names of data variables to skip. Skipped variables are
            not copied into the returned dataset. The default list is only read, never
            modified.

    Returns:
        xarray.Dataset: New dataset holding only the buffered (non-excluded) variables.
    """
    names_to_buffer = [name for name in xa_ds.data_vars if name not in exclude_vars]
    progress = tqdm(
        names_to_buffer, desc=f"Buffering variables by {buffer_size} pixel(s)"
    )

    result = xa.Dataset()
    for name in progress:
        # run buffer_nans over the variable's data, one variable at a time
        result[name] = xa.apply_ufunc(
            spatial_data.buffer_nans,
            xa_ds[name],
            input_core_dims=[[]],
            output_core_dims=[[]],
            kwargs={"size": buffer_size},
            dask="parallelized",
        )

    return result


def spatially_buffer_nc_file(nc_path: Path | str, buffer_size: int = 3):
    """Buffer every data variable of a netCDF file and cache the result beside it.

    The buffered dataset is written to "<stem>_buffered_<buffer_size>_pixel.nc" in the
    same directory as `nc_path`. If that file already exists it is opened instead of
    being recomputed.

    Parameters
        nc_path (Path | str): Path to the source netCDF file.
        buffer_size (int): Buffer size in grid cells.

    Returns:
        tuple[xarray.Dataset, Path]: The buffered dataset and the path of its file.
    """
    # TODO: specify distance buffer
    nc_path = Path(nc_path)
    # Append ".nc" explicitly: Path.with_suffix would treat everything after the last dot
    # of a dotted stem (e.g. "foo_0.083") as the suffix and truncate the file name.
    buffered_name = f"{nc_path.stem}_buffered_{buffer_size}_pixel.nc"
    buffered_path = nc_path.parent / buffered_name

    # if buffered file doesn't already exist
    if not buffered_path.is_file():
        nc_file = xa.open_dataset(nc_path)
        buffered_ds = spatially_buffer_timeseries(
            nc_file, buffer_size=buffer_size
        )
        buffered_ds.to_netcdf(buffered_path)
    else:
        buffered_ds = xa.open_dataset(buffered_path)
        print(
            f"Area buffered by {buffer_size} pixel(s) already exists at {buffered_path}."
        )

    return buffered_ds, buffered_path

In [9]:
# xa_cmems_monthly_buffered, _ = spatial_data.spatially_buffer_nc_file(cmems_monthly_path, buffer_size=5)
# buffer the daily data by 5 pixels (uses the notebook-level function defined above);
# the returned path is not needed here
xa_cmems_daily_buffered, _ = spatially_buffer_nc_file(cmems_daily_path, buffer_size=5)

Buffering variables by 5 pixel(s):   0%|          | 0/7 [00:00<?, ?it/s]

In [None]:
# spatial_plots is not imported at the top of the notebook (that import is commented out)
from coralshift.plotting import spatial_plots

# Plot the first time step of "mlotst" from the buffered monthly file. The buffering helper
# writes its output with a ".nc" suffix, so the suffix must be part of the file name.
# NOTE(review): this file only exists if the monthly data has been buffered with
# buffer_size=5 (that call is currently commented out in the previous cell).
spatial_plots.plot_DEM(xa.open_dataset(directories.get_monthly_cmems_dir() / "cmems_gopr_monthly_buffered_5_pixel.nc")["mlotst"].isel(time=0), "")
# spatial_plots.plot_DEM(buffered["mlotst"].isel(time=0), "")

## Combine datasets into single netcdf file

In [None]:
# ground truth (coral raster saved as netCDF)
gt_xa = xa.open_dataset(directories.get_reef_baseline_dir() / "gt_tifs/coral_raster_1000m.nc")
# bathymetry: the resampled array created earlier is displayed as this cell's output
# bath_xa = xa.open_dataset(directories.get_bathymetry_datasets_dir() / "gt_tifs/coral_raster_1000m.nc")
xa_bath_upsampled
# TODO: still to be added to the combined file:
# slopes
# cmems global ocean reanalysis
# 

In [None]:
# Assemble all datasets used downstream in a single dataset manager.
ds_man = data_structure.MyDatasets()
ds_man.set_location(location)

# names of the reanalysis variables used as model features
noaa_features = ['mlotst', 'bottomT', 'uo', 'so', 'zos', 'thetao', 'vo']

# TODO: transparency in preprocessing to get to this (probably split into separate gt datarray)
ds_man.add_dataset(
    "monthly_climate_1_12", xa.open_dataset(
        ds_man.get_location() / "global_ocean_reanalysis/monthly_means/coral_climate_1_12.nc")
)

# feature array (X) and ground-truth array (y) derived from the monthly dataset
ds_man.add_datasets(
    ["monthly_climate_1_12_X", "monthly_climate_1_12_y"], 
        spatial_data.process_xa_ds_for_ml(ds_man.get_dataset("monthly_climate_1_12"), 
        feature_vars=noaa_features, gt_var="coral_algae_1-12_degree")
)

# daily dataset: take the first depth level and attach the monthly dataset's ground truth
# TODO: handle depth
ds_man.add_dataset(
    "daily_climate_1_12", spatial_data.generate_and_add_gt_to_xa_d(xa.open_dataset(
        Path(ds_man.get_location() / "global_ocean_reanalysis/daily_means/dailies_combined.nc")).isel(depth=0),
        ds_man.get_dataset("monthly_climate_1_12")["coral_algae_1-12_degree"])
)

# TODO: streamline checking and saving process
daily_climate_1_12_X_file_path = ds_man.get_location() / "global_ocean_reanalysis/daily_means/daily_climate_1_12_X.npy"
# if daily_climate_1_12_X numpy array doesn't exist, generate and save
if not file_ops.check_file_exists(filepath = daily_climate_1_12_X_file_path):
    daily_climate_1_12_X = spatial_data.process_xa_ds_for_ml(ds_man.get_dataset("daily_climate_1_12"),
        feature_vars = noaa_features)
    np.save(daily_climate_1_12_X_file_path, daily_climate_1_12_X) 
# in both cases the array registered is the one read back from disk
ds_man.add_dataset("daily_climate_1_12_X", np.load(daily_climate_1_12_X_file_path))

daily_climate_1_12_padded_1_file_path = ds_man.get_location() / "global_ocean_reanalysis/daily_means/daily_climate_1_12_padded_1.nc"
# if daily_climate_1_12_padded_1 .nc file doesn't exist, generate and save
if not file_ops.check_file_exists(filepath = daily_climate_1_12_padded_1_file_path):
    daily_climate_1_12_padded_1 = spatial_data.spatially_buffer_timeseries(
        ds_man.get_dataset("daily_climate_1_12"), buffer_size=1, exclude_vars = ["spatial_ref", "coral_algae_gt"])
    # xarray's to_netcdf takes the destination as `path`; `filepath` is not a valid keyword
    daily_climate_1_12_padded_1.to_netcdf(path = daily_climate_1_12_padded_1_file_path)
# in both cases the dataset registered is the one opened from disk
ds_man.add_dataset("daily_climate_1_12_padded_1", xa.open_dataset(daily_climate_1_12_padded_1_file_path))

# add in ground truth to padded
ds_man.add_dataset(
    "daily_climate_1_12_padded_1_gt", spatial_data.generate_and_add_gt_to_xa_d(
        ds_man.get_dataset("daily_climate_1_12_padded_1"),
        ds_man.get_dataset("monthly_climate_1_12")["coral_algae_1-12_degree"])
)

# raw 30m bathymetry for area A, with x/y renamed to longitude/latitude
ds_man.add_dataset(
    "bathymetry_A", rio.open_rasterio(
        rasterio.open(ds_man.get_location() / "bathymetry/GBR_30m/Great_Barrier_Reef_A_2020_30m_MSL_cog.tif"),
        ).rename("bathymetry_A").rename({"x": "longitude", "y": "latitude"})
)

In [None]:
# second, initially empty dataset container, configured for the same storage location
ds_man_ml = data_structure.MyDatasets()
ds_man_ml.set_location(location)
