# Example of OceanBench use

### First install needed libraries

In [None]:
!pip install --quiet condacolab einops pytorch-lightning cmocean pint-xarray numpy_groupies loguru cffi==1.15.1
!mamba env update --quiet -f /home/onyxia/work/oceanbench/environments/linux.yaml
!mamba install --yes --quiet -c conda-forge  -c pyviz kornia metpy xrft pyinterp funcy=1.18 fsspec=2022.11.0 dvc-s3=2.21.0 hvplot gcm_filters esmpy xesmf

!pip --quiet install  "git+https://github.com/jejjohnson/ocn-tools.git"  
!pip --quiet install torch==2.0.1+cpu torchvision==0.15.2+cpu torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cpu

!export PYTHONPATH="${PYTHONPATH}:oceanbench/"

# Fetch data using CMC

In [None]:
import copernicus_marine_client as cmc
# Query the Copernicus Marine catalogue, list the matching datasets, then print
# the variables offered by the dataset used in the rest of the notebook.
DATASET_ID = "cmems_mod_glo_phy_my_0.083_P1D-m"
query_keywords = ["Global Ocean Physics Reanalysis"]

query_result = cmc.describe(contains=query_keywords, include_datasets=True)

# Flatten products -> datasets once so the listing and the lookup share it.
datasets = [
    dataset
    for product in query_result.get('products') or []
    for dataset in product.get('datasets') or []
]

for dataset in datasets:
    print(f"{dataset.get('dataset_id')} - {dataset.get('dataset_name')}")

matching_datasets = [d for d in datasets if d.get('dataset_id') == DATASET_ID]
if not matching_datasets:
    # Fail with an explicit message instead of a NameError further down.
    raise LookupError(f"Dataset {DATASET_ID!r} not found in the catalogue query result")
# Keep the last match, as the original nested loop did.
chosen_dataset = matching_datasets[-1]

print("---Variables---")
for variable in chosen_dataset.get("variables"):
    print(f"{variable.get('standard_name')} - {variable.get('short_name')}")

In [None]:
from datetime import datetime
# Time window of the request.
start_datetime = datetime(2020, 1, 1)
end_datetime = datetime(2020, 1, 5)

# Subset definition: dataset, variables, lon/lat/depth bounds and time window.
subset_request = {
    "dataset_id": "cmems_mod_glo_phy_my_0.083_P1D-m",
    "variables": ["uo", "vo", "zos"],
    "minimal_longitude": -27,
    "maximal_longitude": 0,
    "minimal_latitude": 21,
    "maximal_latitude": 39,
    "minimal_depth": 0.0,
    "maximal_depth": 2.0,
    "start_datetime": start_datetime,
    "end_datetime": end_datetime,
}
test_data = cmc.load_xarray_dataset(**subset_request)

In [None]:
# Display the dataset returned by cmc.load_xarray_dataset (last-expression output).
test_data

# XRDAPatcher (oceanbench patching/stitching tool) for Machine Learning

##  green part in:

![](scheme_ob.png)




In [None]:
import torch
from xrpatcher import XRDAPatcher

import matplotlib.pyplot as plt
import itertools
import xarray as xr
import collections
import numpy as np


In [None]:
class XrTorchDataset(torch.utils.data.Dataset):
    """Map-style torch dataset that serves the patches of a patcher.

    Args:
        patcher: object indexed by patch number; each patch must support
            ``.load().values``. It must also provide ``__len__`` and
            ``reconstruct``.
        item_postpro: optional callable applied to the raw values of each
            patch before they are returned.
    """

    def __init__(self, patcher: XRDAPatcher, item_postpro=None):
        self.postpro = item_postpro
        self.patcher = patcher

    def __len__(self):
        """Return the number of patches reported by the patcher."""
        return len(self.patcher)

    def __getitem__(self, idx):
        """Return the loaded values of patch ``idx``, post-processed if requested."""
        values = self.patcher[idx].load().values
        return self.postpro(values) if self.postpro else values

    def reconstruct_from_batches(self, batches, **rec_kws):
        """Flatten ``batches`` by one level and hand the items to the patcher.

        Args:
            batches: iterable of batches, each itself an iterable of items.
            **rec_kws: forwarded unchanged to ``patcher.reconstruct``.

        Returns:
            Whatever ``patcher.reconstruct`` returns.
        """
        flat_items = list(itertools.chain.from_iterable(batches))
        return self.patcher.reconstruct(flat_items, **rec_kws)

In [None]:
# Build the training array: first time step and first elevation level of the
# variables named by TrainingItem, cropped to 325 x 217 grid points
# (325 = 13 * 25 and 217 = 7 * 31, i.e. whole multiples of the patch size below).
raw_data = test_data
TrainingItem = collections.namedtuple('TrainingItem', ('uo', 'vo'))
surface_snapshot = raw_data[list(TrainingItem._fields)].isel(
    longitude=slice(None, 325),
    latitude=slice(None, 217),
    time=0,
    elevation=0,
)
ordered_snapshot = surface_snapshot.sortby('longitude').sortby('latitude')
data = ordered_snapshot.to_array().transpose('variable', 'latitude', 'longitude').load()

# Patching logic: strides equal to the patch size means no overlap.
patches = {"longitude": 25, "latitude": 31}
patcher = XRDAPatcher(da=data, patches=patches, strides=patches, check_full_scan=True)

# Torch dataset (each patch unpacked into a TrainingItem) and its dataloader.
torch_ds = XrTorchDataset(patcher, item_postpro=TrainingItem._make)
dataloader = torch.utils.data.DataLoader(torch_ds, batch_size=4, shuffle=False)

# Every patch, plus one example item and one example batch for inspection.
items = list(map(torch_ds.__getitem__, range(len(torch_ds))))
ex_item = items[0]
batch = next(iter(dataloader))

print(f"Item shape: {ex_item.uo.shape=}, {ex_item.vo.shape=}")
print(f"Batch shape: {batch.uo.shape=}, {batch.vo.shape=}")

In [None]:
# One panel per entry of the 'variable' dimension of the training array.
data.plot(row='variable', figsize=(5, 6))


In [None]:
def plot_patches(items_to_plot, nbaxes=(7, 13)):
    """Draw 2-D patches on a grid of axes, filling rows from the bottom up.

    Args:
        items_to_plot: iterable of 2-D arrays accepted by ``Axes.imshow``.
        nbaxes: ``(n_rows, n_cols)`` of the subplot grid.

    Patches beyond the grid capacity are ignored (``zip`` stops at the shorter
    input); panels that receive no patch are left blank.
    """
    # squeeze=False keeps `axs` two-dimensional even for a single row or
    # column, so the row reversal below works for every grid shape.
    fig, axs = plt.subplots(*nbaxes, figsize=(5, 5), squeeze=False)

    # Bottom row first, so patch order matches origin='lower' in each panel.
    ordered_axes = axs[::-1].ravel()

    # Blank every panel up front: unused axes must not show empty frames.
    for ax in ordered_axes:
        ax.set_axis_off()

    for item, ax in zip(items_to_plot, ordered_axes):
        ax.imshow(item, cmap='RdBu_r', origin='lower')

        
# Patch mosaics of each velocity component, one figure per component.
uo_patches = [item.uo for item in items]
print("Patches of uo")
plot_patches(uo_patches)
plt.show()

vo_patches = [item.vo for item in items]
print("\n\nPatches of vo")
plot_patches(vo_patches)


### Reconstructing the amplitude of the speed from the patches

In [None]:
# Per-batch speed amplitude sqrt(uo^2 + vo^2), stitched back onto the full grid.
speed_batches = (np.sqrt(batch.uo**2 + batch.vo**2) for batch in dataloader)
rec_ds = torch_ds.reconstruct_from_batches(
    speed_batches,
    dims_labels=['latitude', 'longitude'],
)
rec_ds.plot(figsize=(5, 3))

### Reconstructing the laplacian (~ vorticity) from the patches

In [None]:
# Per-batch first differences (uo along axis 1, vo along axis 2), summed and
# stitched back onto the full grid.
derivative_batches = (
    np.diff(batch.uo, axis=1, prepend=0) + np.diff(batch.vo, axis=2, prepend=0)
    for batch in dataloader
)
rec_ds = torch_ds.reconstruct_from_batches(
    derivative_batches,
    dims_labels=['latitude', 'longitude'],
)
rec_ds.plot(figsize=(5, 3))

### Above we see that the patch borders create artifacts in the derivative. This can be fixed by using overlapping patches (stride smaller than the patch size).

In [None]:
# Overlapping patches: strides smaller than the patch size.
patches = dict(longitude=25, latitude=31)
strides = dict(longitude=5, latitude=6)
patcher = XRDAPatcher(
    da=data, patches=patches, strides=strides, check_full_scan=True
)

# Per-pixel weight of one patch, shaped (latitude, longitude). The shape is
# derived from `patches` so it cannot drift when the patch size is changed.
rec_weight = np.ones((patches["latitude"], patches["longitude"]))
# Do not use the border pixels during the reconstruction.
rec_weight[[0, -1], :] = 0
rec_weight[:, [0, -1]] = 0


torch_ds = XrTorchDataset(patcher, item_postpro=TrainingItem._make)
dataloader = torch.utils.data.DataLoader(torch_ds, batch_size=4, shuffle=False)
rec_ds = torch_ds.reconstruct_from_batches(
    ((np.diff(batch.uo, axis=1, prepend=0) + np.diff(batch.vo,axis=2, prepend=0)) for batch in dataloader),
    dims_labels=['latitude', 'longitude'],
    weight=rec_weight,
)
rec_ds.plot(figsize=(5, 3))
# TODO(review): NaN handling in this reconstruction has not been verified.

# Geoprocessing examples using ocn_tools tool of oceanbench


##  green part in:

![](scheme_ob.png)


In [None]:
from ocn_tools._src.geoprocessing.spatial import latlon_deg2m
from ocn_tools._src.geoprocessing.temporal import time_rescale
import pandas as pd


In [None]:
# Display the time coordinate of the dataset before any rescaling.
test_data.time

#### Deg to meters

In [None]:
# Rename the horizontal coordinates to lat/lon. Only names that are still
# present are renamed, so re-running this cell is a no-op instead of an error.
coord_renames = {'latitude': 'lat', 'longitude': 'lon'}
test_data = test_data.rename(
    {old: new for old, new in coord_renames.items() if old in test_data.dims}
)


In [None]:

# Presumably converts the lat/lon coordinates from degrees to metres (going by
# the helper's name); the resulting `lat` coordinate is displayed for inspection.
da_scaled = latlon_deg2m(test_data, mean=False)
da_scaled.lat
# TODO: add arguments so the lat/lon coordinate names can be passed in.

##### DateTime 2 Seconds

In [None]:
# Reference date and step (1 "D") handed to time_rescale.
t0 = "2020-01-01"
freq_dt = 1
freq_unit = "D"

# Display the rescaled time coordinate for inspection.
test_data_trescal = time_rescale(test_data, freq_dt=freq_dt, freq_unit=freq_unit, t0=t0)
test_data_trescal.time


### Compute geostrophic components

In [None]:
from ocn_tools._src.geoprocessing.geostrophic import geostrophic_velocities


In [None]:
# Derive velocities from the `zos` variable; the result exposes `u` and `v`.
da = geostrophic_velocities(test_data, variable="zos")
# TODO: add arguments so the lat/lon coordinate names can be passed in.

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
fig, ax = plt.subplots(ncols=2, figsize=(10,4))

da.u.isel(time=1).plot.pcolormesh(ax=ax[0], cmap="jet")
da.v.isel(time=1).plot.pcolormesh(ax=ax[1], cmap="jet")

plt.tight_layout()
plt.show()

### Compute the kinetic energy from the model current and the geostrophic current

In [None]:
from ocn_tools._src.geoprocessing.geostrophic import kinetic_energy

# da: kinetic energy from the velocities (u, v) derived from zos.
# da2: kinetic energy from the dataset's own velocity variables (uo, vo).
da = geostrophic_velocities(test_data, variable="zos")
da = kinetic_energy(da, variables=["u", "v"])
da2 = kinetic_energy(test_data, variables=["uo", "vo"])


In [None]:
%matplotlib inline
fig, ax = plt.subplots(ncols=2, figsize=(10,4))

da.ke.isel(time=1).plot.pcolormesh(ax=ax[0], cmap="jet")
da2.ke.isel(time=1, elevation=1).plot.pcolormesh(ax=ax[1], cmap="jet")

plt.tight_layout()
plt.show()

### Relative Vorticity


In [None]:
from ocn_tools._src.geoprocessing.geostrophic import relative_vorticity

# da: relative vorticity from the velocities (u, v) derived from zos.
# da2: relative vorticity from the dataset's own velocity variables (uo, vo).
da = geostrophic_velocities(test_data, variable="zos")
da = relative_vorticity(da, variables=["u", "v"])
da2 = relative_vorticity(test_data, variables=["uo", "vo"])

In [None]:
%matplotlib inline
fig, ax = plt.subplots(ncols=2, figsize=(10,4))

da.vort_r.isel(time=1).plot.pcolormesh(ax=ax[0], cmap="jet")
da2.vort_r.isel(time=1, elevation=1).plot.pcolormesh(ax=ax[1], cmap="jet")

plt.tight_layout()
plt.show()

### Absolute Vorticity

In [None]:
from ocn_tools._src.geoprocessing.geostrophic import absolute_vorticity

# da: absolute vorticity from the velocities (u, v) derived from zos.
# da2: absolute vorticity from the dataset's own velocity variables (uo, vo).
da = geostrophic_velocities(test_data, variable="zos")
da = absolute_vorticity(da, variables=["u", "v"])
da2 = absolute_vorticity(test_data, variables=["uo", "vo"])


In [None]:
%matplotlib inline
fig, ax = plt.subplots(ncols=2, figsize=(10,4))

da.vort_a.isel(time=1).plot.pcolormesh(ax=ax[0], cmap="jet")
da2.vort_a.isel(time=1, elevation=1).plot.pcolormesh(ax=ax[1], cmap="jet")

plt.tight_layout()
plt.show()

### Divergence

In [None]:
from ocn_tools._src.geoprocessing.geostrophic import divergence, coriolis_normalized

da = geostrophic_velocities(test_data, variable="zos")
da = divergence(da, variables=["u", "v"])
da2 = divergence(test_data, variables=["uo", "vo"])

da = coriolis_normalized(da, "div")
da2 = coriolis_normalized(da2, "div")


%matplotlib inline
fig, ax = plt.subplots(ncols=2, figsize=(10,4))

da.div.isel(time=1).plot.pcolormesh(ax=ax[0], cmap="jet")
da2.div.isel(time=1, elevation=1).plot.pcolormesh(ax=ax[1], cmap="jet")

plt.tight_layout()
plt.show()

### Enstrophy

In [None]:
from ocn_tools._src.geoprocessing.geostrophic import enstrophy, coriolis_normalized

da = geostrophic_velocities(test_data, variable="zos")
da = relative_vorticity(da, variables=["u", "v"])
da2 = relative_vorticity(test_data, variables=["uo", "vo"])

da = enstrophy(da)
da2 = enstrophy(da2)

da = coriolis_normalized(da, "ens")
da2 = coriolis_normalized(da2, "ens")


%matplotlib inline
fig, ax = plt.subplots(ncols=2, figsize=(10,4))

da.ens.isel(time=1).plot.pcolormesh(ax=ax[0], cmap="jet")
da2.ens.isel(time=1, elevation=1).plot.pcolormesh(ax=ax[1], cmap="jet")

plt.tight_layout()
plt.show()

### Strain

In [None]:
from ocn_tools._src.geoprocessing.geostrophic import shear_strain, coriolis_normalized

da = geostrophic_velocities(test_data, variable="zos")
da = shear_strain(da, variables=["u", "v"])
da2 = shear_strain(test_data, variables=["uo", "vo"])



%matplotlib inline
fig, ax = plt.subplots(ncols=2, figsize=(10,4))

da.shear_strain.isel(time=1).plot.pcolormesh(ax=ax[0], cmap="jet")
da2.shear_strain.isel(time=1, elevation=1).plot.pcolormesh(ax=ax[1], cmap="jet")

plt.tight_layout()
plt.show()

# PostProcessings/Evaluations
## assessing the effective resolution

In [None]:
from ocn_tools._src.metrics import power_spectrum as psdcalc
from ocn_tools._src.preprocessing.mean import xr_cond_average
from ocn_tools._src.geoprocessing.temporal import time_rescale

In [None]:
# test_data already carries lat/lon coordinate names (renamed in an earlier cell).
# Replace missing values with 0 before the spectral analysis below.
test_data0=test_data.fillna(0)
da_scaled = latlon_deg2m(test_data0, mean=False)


In [None]:
# Reference date for the rescaled time axis. It must match the loaded window
# (2020-01-01 to 2020-01-05); the previous "2020-10-22" lies after the data
# and is inconsistent with the t0 used when rescaling time earlier on.
t0 = "2020-01-01"
freq_dt = 1
freq_unit = "D"


da_scaled = time_rescale(da_scaled, freq_dt=freq_dt, freq_unit=freq_unit, t0=t0)



## Isotropic Assumption

In [None]:
# Isotropic PSD of zos over the (lat, lon) dimensions; the second call
# presumably averages the result over time (per the helper's name and dims).
da_psd_iso = psdcalc.psd_isotropic(test_data0, "zos", ["lat", "lon"])
da_psd_iso = xr_cond_average(da_psd_iso, dims=["time"], drop=True)
# NOTE: computed on the degree-based test_data0 because this step does not
# work once the coordinates have been rescaled to metres.

In [None]:
class PlotPSDIsotropic:
    """Line plots of an isotropic power spectral density on a single Axes.

    Call ``init_fig`` first: it stores the ``self.fig`` / ``self.ax`` pair that
    every ``plot_*`` method draws on. The plotted DataArray must expose a
    ``freq_r`` coordinate and a ``name``.
    """

    def init_fig(self, ax=None, figsize=None):
        """Create a new figure/axes pair, or adopt an existing ``ax``.

        Args:
            ax: existing Axes to draw on; a new figure is created when None.
            figsize: size of the new figure, (5, 4) by default. Ignored when
                ``ax`` is given.
        """
        if ax is None:
            figsize = (5, 4) if figsize is None else figsize
            self.fig, self.ax = plt.subplots(figsize=figsize)
        else:
            self.ax = ax
            # Use the figure that owns `ax`; plt.gcf() may be an unrelated
            # figure (or create a new empty one).
            self.fig = ax.figure

    def _add_legend(self):
        """Draw a legend only when at least one plotted artist has a label."""
        _, labels = self.ax.get_legend_handles_labels()
        if labels:
            self.ax.legend()

    def plot_wavenumber(self, da, freq_scale=1.0, units=None, **kwargs):
        """Plot the PSD against wavenumber on log-log axes.

        Args:
            da: PSD DataArray with a ``freq_r`` coordinate.
            freq_scale: factor applied to ``da.freq_r`` before plotting.
            units: length unit shown in the x label, if any.
            **kwargs: forwarded to ``Axes.plot`` (e.g. ``label``).
        """
        xlabel = "Wavenumber" if units is None else f"Wavenumber [cycles/{units}]"

        self.ax.plot(da.freq_r * freq_scale, da, **kwargs)

        self.ax.set(
            yscale="log", xscale="log",
            xlabel=xlabel,
            ylabel=f"PSD [{da.name}]"
        )

        self._add_legend()
        self.ax.grid(which="both", alpha=0.5)

    def plot_wavelength(self, da, freq_scale=1.0, units=None, **kwargs):
        """Plot the PSD against wavelength (1 / wavenumber) on log-log axes.

        The x axis is inverted so that long wavelengths appear on the left.

        Args:
            da: PSD DataArray with a ``freq_r`` coordinate.
            freq_scale: factor applied to ``da.freq_r`` before inversion.
            units: length unit shown in the x label, if any.
            **kwargs: forwarded to ``Axes.plot`` (e.g. ``label``).
        """
        xlabel = "Wavelength" if units is None else f"Wavelength [{units}]"

        self.ax.plot(1/(da.freq_r * freq_scale), da, **kwargs)

        self.ax.set(
            yscale="log", xscale="log",
            xlabel=xlabel,
            ylabel=f"PSD [{da.name}]"
        )

        self.ax.xaxis.set_major_formatter("{x:.0f}")
        self.ax.invert_xaxis()

        self._add_legend()
        self.ax.grid(which="both", alpha=0.5)

    def plot_both(self, da, freq_scale=1.0, units=None, **kwargs):
        """Plot against wavenumber and add a wavelength axis on top.

        Args:
            da: PSD DataArray with a ``freq_r`` coordinate.
            freq_scale: factor applied to ``da.freq_r`` before plotting.
            units: length unit shown in both x labels, if any.
            **kwargs: forwarded to ``Axes.plot`` (e.g. ``label``).
        """
        xlabel = "Wavelength" if units is None else f"Wavelength [{units}]"

        self.plot_wavenumber(da=da, units=units, freq_scale=freq_scale, **kwargs)

        # The small offset avoids a division by zero at wavenumber 0.
        self.secax = self.ax.secondary_xaxis(
            "top", functions=(lambda x: 1 / (x + 1e-20), lambda x: 1 / (x + 1e-20))
        )
        self.secax.xaxis.set_major_formatter("{x:.0f}")
        self.secax.set(xlabel=xlabel)

In [None]:
psd_iso_plot = PlotPSDIsotropic()
psd_iso_plot.init_fig()
# da_psd_iso was computed on test_data0, whose lat/lon are in degrees, so
# freq_r is in cycles/degree: plot it unscaled and label it accordingly
# (the previous 1e3 scale / "km" label only applies to metre coordinates).
psd_iso_plot.plot_both(
    da_psd_iso.zos,
    freq_scale=1.0,
    units="degree",
    label="GLORYS12"
)
plt.show()

## Spatial Temporal (Time vs Longitude)

In [None]:
# Use the zero-filled copy (test_data0), as the isotropic spectrum does:
# a spectrum computed through missing values comes out as NaN.
da_psd_st = psdcalc.psd_spacetime(test_data0, "zos", ["time", "lon"])
da_psd_st = xr_cond_average(da_psd_st, dims=["lat"], drop=True)


In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.ticker as ticker


class PlotPSDSpaceTime:
    """Filled-contour plots of a space-time power spectral density.

    Call ``init_fig`` first: it stores the ``self.fig`` / ``self.ax`` pair that
    every ``plot_*`` method draws on. The plotted DataArray must have the
    ``freq_time`` and ``freq_lon`` dimensions/coordinates.
    """

    def init_fig(self, ax=None, figsize=None):
        """Create a new figure/axes pair, or adopt an existing ``ax``.

        Args:
            ax: existing Axes to draw on; a new figure is created when None.
            figsize: size of the new figure, (5, 4) by default. Ignored when
                ``ax`` is given.
        """
        if ax is None:
            figsize = (5, 4) if figsize is None else figsize
            self.fig, self.ax = plt.subplots(figsize=figsize)
        else:
            self.ax = ax
            # Use the figure that owns `ax`; plt.gcf() may be an unrelated
            # figure (or create a new empty one).
            self.fig = ax.figure

    def _contourf_psd(self, x, y, da, xlabel, ylabel, psd_units, **kwargs):
        """Shared log-log filled-contour rendering with a log colorbar.

        Args:
            x, y: horizontal / vertical coordinates of the contour plot.
            da: PSD DataArray, transposed to (freq_time, freq_lon) for plotting.
            xlabel, ylabel: axis labels.
            psd_units: unit shown in the colorbar label, if any.
            **kwargs: forwarded to ``Axes.contourf``; ``norm``, ``locator``,
                ``cmap`` and ``extend`` only get a default when not supplied.
        """
        cbar_label = "PSD" if psd_units is None else f"PSD [{psd_units}]"

        # setdefault (rather than fixed keywords) lets callers override these
        # without triggering a duplicate-keyword error.
        kwargs.setdefault("norm", colors.LogNorm())
        kwargs.setdefault("locator", ticker.LogLocator())
        kwargs.setdefault("cmap", "RdYlGn")
        kwargs.setdefault("extend", "both")

        pts = self.ax.contourf(
            x,
            y,
            da.transpose("freq_time", "freq_lon"),
            **kwargs
        )

        self.ax.set(
            yscale="log",
            xscale="log",
            xlabel=xlabel,
            ylabel=ylabel,
        )
        # colorbar
        fmt = ticker.LogFormatterMathtext(base=10)
        cbar = plt.colorbar(
            pts,
            ax=self.ax,
            pad=0.02,
            format=fmt,
        )
        cbar.ax.set_ylabel(cbar_label)
        self.ax.grid(which="both", linestyle="--", linewidth=1, color="black", alpha=0.2)

    def plot_wavenumber(
        self,
        da,
        space_scale: float=1.0,
        space_units: str=None,
        time_units: str=None,
        psd_units: str=None,
        **kwargs):
        """Plot the PSD against wavenumber (x) and frequency (y).

        The coordinates are plotted as wavenumber / frequency so that the tick
        values agree with the axis labels. On log axes this gives the same
        picture as plotting the reciprocals on inverted axes.

        Args:
            da: PSD DataArray with ``freq_lon`` and ``freq_time`` coordinates.
            space_scale: factor applied to ``da.freq_lon`` before plotting.
            space_units: length unit shown in the x label, if any.
            time_units: time unit shown in the y label, if any.
            psd_units: unit shown in the colorbar label, if any.
            **kwargs: forwarded to ``Axes.contourf``.
        """
        if space_units is not None:
            xlabel = f"Wavenumber [cycles/{space_units}]"
        else:
            xlabel = "Wavenumber"
        if time_units is not None:
            ylabel = f"Frequency [cycles/{time_units}]"
        else:
            ylabel = "Frequency"

        self._contourf_psd(
            da.freq_lon * space_scale,
            da.freq_time,
            da,
            xlabel,
            ylabel,
            psd_units,
            **kwargs
        )

    def plot_wavelength(
        self,
        da,
        space_scale: float=1.0,
        space_units: str=None,
        time_units: str=None,
        psd_units: str=None,
        **kwargs
    ):
        """Plot the PSD against wavelength (x) and period (y).

        Both axes are inverted so that large wavelengths / long periods sit at
        the left / bottom, mirroring the wavenumber plot.

        Args:
            da: PSD DataArray with ``freq_lon`` and ``freq_time`` coordinates.
            space_scale: factor applied to ``da.freq_lon`` before inversion.
            space_units: length unit shown in the x label, if any.
            time_units: time unit shown in the y label, if any.
            psd_units: unit shown in the colorbar label, if any.
            **kwargs: forwarded to ``Axes.contourf``.
        """
        if space_units is not None:
            xlabel = f"Wavelength [{space_units}]"
        else:
            xlabel = "Wavelength"

        if time_units is not None:
            ylabel = f"Period [{time_units}]"
        else:
            ylabel = "Period"

        self._contourf_psd(
            1/(da.freq_lon*space_scale),
            1/da.freq_time,
            da,
            xlabel,
            ylabel,
            psd_units,
            **kwargs
        )

        self.ax.invert_xaxis()
        self.ax.invert_yaxis()
        self.ax.xaxis.set_major_formatter("{x:.0f}")
        self.ax.yaxis.set_major_formatter("{x:.0f}")

In [None]:
psd_st_plot = PlotPSDSpaceTime()
psd_st_plot.init_fig()
# The spectrum was computed along `lon` in degrees, so the spatial axis is
# plotted unscaled and labelled in degrees (the previous 1e3 scale / "km"
# label only applies to metre coordinates).
# NOTE(review): time_units="days" assumes freq_time is in cycles/day; the PSD
# was computed on the raw datetime axis, so confirm the unit actually returned.
psd_st_plot.plot_wavenumber(
    da_psd_st.zos,
    space_scale=1.0,
    space_units="degree",
    time_units="days",
    psd_units="SSH"
)
plt.show()