# Applying Cross Scattering Transform in Earth Observation imagery to fill missing values

## Authors & Contributors
### Authors
- Jean-Marc Delouis, LOPS - Laboratoire d'Oceanographie Physique et Spatiale
UMR 6523 CNRS-IFREMER-IRD-Univ.Brest-IUEM  (France), [@jmdelouis](https://github.com/jmdelouis)

### Contributors
- Anne Fouilloux, Simula (Norway), [@annefou](https://github.com/annefou)
- Tina Odaka, LOPS - Laboratoire d'Oceanographie Physique et Spatiale
UMR 6523 CNRS-IFREMER-IRD-Univ.Brest-IUEM (France), [@tinaok](https://github.com/tinaok)
- Justus Magin, LOPS - Laboratoire d'Oceanographie Physique et Spatiale
UMR 6523 CNRS-IFREMER-IRD-Univ.Brest-IUEM (France), [@keewis](https://github.com/keewis)

### Modelling publication
```{bibliography}
  :style: plain
  :list: bullet
  :filter: topic % "jmdelouis2022"
```


<div class="alert alert-info">
<i class="fa-question-circle fa" style="font-size: 22px;color:#666;"></i> Overview
    <br>
    <br>
    <b>Questions</b>
    <ul>
        <li>What is the HEALPix grid?</li>
        <li>How do I read Copernicus Marine data and convert it to a HEALPix grid?</li>
        <li>What is the Cross Scattering Transform and what can I use it for?</li>
        <li>What is foscat?</li>
        <li>How can I use the Cross Scattering Transform to fill missing values?</li>
    </ul>
    <b>Objectives</b>
    <ul>
        <li>Learn about Healpix</li>
        <li>Learn about Cross Scattering Transform and foscat Python package</li>
        <li>Learn about Dask, Dask Gateway, Dask Client, Scheduler, Workers</li>
        <li>Understand out-of-core and speed-up limitations</li>
    </ul>
</div>

## Context


Most Earth Observation images contain values that need to be "removed" or set to missing, such as pixels with clouds when we are interested in the land surface. Filling these gaps or denoising the data in a realistic way is challenging, and here we investigate the use of the Cross Scattering Transform with the [foscat](https://foscat-documentation.readthedocs.io/en/latest/index.html) Python package.

As remote sensing data can be quite large, we also use [Dask](https://docs.dask.org/) with [Xarray](https://docs.xarray.dev/en/stable/) to parallelize our data analysis. 

### Modelling approach

This notebook applies an algorithm developed using scattering transforms to denoise and distinguish dust polarization from data noise on spherical images. Implementation and testing revealed effective recovery of dust emission statistics until certain limits. This method will be used in Earth Observation imagery to fill in missing values and provide a more accurate representation. The details of this algorithm are given in the paper entitled "**Non-Gaussian modelling and statistical denoising of Planck dust polarization full-sky maps using scattering transforms**", *Delouis et al.*, 2022. DOI: [10.1051/0004-6361/202244566](http://dx.doi.org/10.1051/0004-6361/202244566)

### Data

In this episode, we will be using datasets from [Copernicus Marine Service](https://marine.copernicus.eu):
- [ODYSSEA Global Sea Surface Temperature Gridded Level 4 Daily Multi-Sensor Observations](https://data.marine.copernicus.eu/product/SST_GLO_PHY_L4_NRT_010_043/description). DOI (product): [10.48670/mds-00321](https://doi.org/10.48670/mds-00321).
- [ODYSSEA Global Ocean - Sea Surface Temperature Multi-sensor L3S Observations](https://data.marine.copernicus.eu/product/SST_GLO_SST_L3S_NRT_OBSERVATIONS_010_010/description). DOI (product): [10.48670/moi-00164](https://doi.org/10.48670/moi-00164).

Both datasets are in [Zarr](https://zarr.dev) format and will be accessed through [S3-compatible object storage](https://en.wikipedia.org/wiki/Amazon_S3).

## Setup

This episode uses the following main Python packages:

- numpy {cite:ps}`a-numpy-harris2020`   
- healpy {cite:ps}`a-healpy-zonca2024`   
- xarray {cite:ps}`a-xarray-hoyer2017` with [`netCDF4`](https://pypi.org/project/netCDF4/), [`h5netcdf`](https://pypi.org/project/h5netcdf/) and [`zarr`](https://pypi.org/project/zarr/) engines
- xdggs {cite:ps}`a-xdggs-2024`                                                                           
- foscat {cite:ps}`a-foscat-delouis2024`
- matplotlib {cite:ps}`a-matplotlib-Hunter2007`
- hvplot {cite:ps}`a-holoviews-rudiger2020`
- cmcrameri {cite:ps}`a-cmcrameri-crameri2018`
- pyinterp {cite:ps}`a-pyinterp`

Please install these packages if not already available in your Python environment.

### Packages

In this episode, Python packages are imported when we start to use them. However, as a software best practice, we recommend installing and importing all the necessary libraries at the top of your Jupyter notebook.

## Installation of required packages

In [None]:
!pip install cmcrameri foscat xdggs 

In [None]:
!mamba install pyinterp -y

## Import necessary libraries

In [None]:
import xarray as xr
import numpy as np

import pint_xarray
import cf_xarray.units  

import healpy as hp
import matplotlib.pyplot as plt
import foscat.Synthesis as synthe
import foscat.scat_cov as sc


import xdggs

import holoviews as hv
import hvplot.xarray
import cmcrameri.cm as cmc

## Create a local Dask cluster on the local machine

In [None]:
from dask.distributed import Client

client = Client()   # create a local dask cluster on the local machine.
client

Inspecting the `Cluster Info` section above gives us information about the created cluster: we have 2 or 4 workers and the same number of threads (e.g. 1 thread per worker). 



<div class="alert alert-warning">
    <i class="fa-check-circle fa" style="font-size: 22px;color:#666;"></i> <b>Go further</b>
    <br>
    <ul>
        <li> You can also create a local cluster with the <code>LocalCluster</code> constructor and use <code>n_workers</code> 
        and <code>threads_per_worker</code> to manually specify the number of processes and threads you want to use. 
        For instance, we could use <code>n_workers=2</code> and <code>threads_per_worker=2</code>.  </li>
        <li> This can be preferable for performance, and when you run this tutorial on your own PC 
        it prevents Dask from using all of its resources.  </li>
    </ul>
</div>

## Dask Dashboard

Dask comes with a really handy interface: the Dask Dashboard. It is a web interface that you can open in a separate tab of your browser.

We will learn here how to use it through the [Dask JupyterLab extension](https://github.com/dask/dask-labextension).

To use the Dask Dashboard through the JupyterLab extension on the Pangeo EOSC infrastructure, you just need to click *Launch dashboard in JupyterLab* in the Client summary displayed in your JupyterLab, and note the Dask dashboard port number, as highlighted in the figure below.

![Dashboard link](./images/dashboardlink.png)

![Dask Lab](./images/dasklab.png)

Then click the orange icon indicated in the above figure, and type your dashboard link.

You can click several buttons indicated with red arrows in the above figures, then drag and drop them to place them as per your convenience.

![Example Dask Lab](./images/exampledasklab.png)

It's really helpful to understand your computation and how it is distributed.

## Data Loading: Get data from Copernicus Marine Services

Data from Copernicus Marine Service is available in `zarr` format via S3-compatible object storage, which makes this data easily and efficiently accessible.

Let's choose the date for which we want to fill the gaps.

In [None]:
time_slice = slice('2024-06-01', '2024-06-01')

### Load Global Ocean Sea Surface Temperature L3S Observations Dataset from Copernicus Marine Services

*This product provides daily foundation sea surface temperature from multiple satellite sources on a 0.10 x 0.10 degree grid (approximately 10 x 10 km) for the Global Ocean.* It contains 'gaps' due to clouds.

In [None]:
L3S = xr.open_zarr("https://s3.waw3-1.cloudferro.com/mdl-arco-time-045/arco/SST_GLO_SST_L3S_NRT_OBSERVATIONS_010_010/IFREMER-GLOB-SST-L3-NRT-OBS_FULL_TIME_SERIE_202211/timeChunked.zarr"
).sel(time=time_slice)

L3S

<div class="alert alert-success">
    <i class="fa-check-circle fa" style="font-size: 22px;color:#666;"></i> <b>Key Points</b>
    <br>
    <ul>
        <li>Where do you find attributes?</li>
        <li>What kind of data variables do you find in this dataset? What are the coordinates and dimensions?</li>
    </ul>
</div>

Let's try to plot the sea surface temperature values.

In [None]:
L3S['sea_surface_temperature'].plot()

### Load L4 Dataset from Copernicus Marine Services

*This dataset provides a time series of gap-free maps of Sea Surface Temperature (SST) foundation at high resolution on a 0.10 x 0.10 degree grid (approximately 10 x 10 km) for the Global Ocean, updated every 24 hours.*

We load the data in `L4` and select one date.

### Geo Chunk and Time Chunk
Let's try to load data in 'geo' chunked format and 'time' chunked format to see the differences. 

In [None]:
L4 = xr.open_zarr(
 "https://s3.waw3-1.cloudferro.com/mdl-arco-geo-045/arco/SST_GLO_PHY_L4_NRT_010_043/cmems_obs-sst_glo_phy_nrt_l4_P1D-m_202303/geoChunked.zarr"
 ).sel(time=time_slice)

L4

In [None]:
L4 = xr.open_zarr(
 "https://s3.waw3-1.cloudferro.com/mdl-arco-time-045/arco/SST_GLO_PHY_L4_NRT_010_043/cmems_obs-sst_glo_phy_nrt_l4_P1D-m_202303/timeChunked.zarr"
 ).sel(time=time_slice)

L4

<div class="alert alert-success">
    <i class="fa-check-circle fa" style="font-size: 22px;color:#666;"></i> <b>Key Points</b>
    <br>
    <ul>
        <li><code>geo</code> and <code>time</code>: which chunking is suitable for our computation?</li>
    </ul>
</div>

### Preprocess the L4 gap free SST data

Let's create an Xarray dataset with the `land`, `sea ice`, `lake` bit mask variable.

In [None]:
ds = L4[['mask']]
ds

In [None]:
ds['mask'].plot(y='latitude', x='longitude')

#### Get land-sea mask (0-1)

Let's set sea to 1 (`True`) and everything else to 0 (`False`).

In [None]:
ds['mask'] = xr.where(ds.mask==1,True,False)

In [None]:
ds['mask'].plot(y='latitude', x='longitude', cmap="binary")

#### Let's check the unit


In [None]:
L4['analysed_sst'].attrs

In [None]:
L4['analysed_sst'].attrs['units']

With pint-xarray one can convert units. Let's convert SST L4 from Kelvin to Celsius.

In [None]:
L4['analysed_sst'].pint.quantify().pint.to("degC").pint.dequantify()

In [None]:
ds['SST_L4'] = L4['analysed_sst'].pint.quantify().pint.to("degC").pint.dequantify()


### Preprocess the L3 SST data

#### Convert SST L3S to Celsius and only keep data where quality_level = 5

SST L3 data is quality controlled and quality levels are assigned with the following meaning:
- no_data -> 0
- bad_data -> 1
- worst_quality -> 2
- low_quality -> 3
- acceptable_quality -> 4
- best_quality -> 5

We will filter SST L3 and only keep values with best quality (`quality_level =  5`).

In [None]:
ds['SST_L3S'] = (
    L3S['sea_surface_temperature'].pint.quantify().pint.to("degC").pint.dequantify()
).where ((L3S.quality_level ==5 ))
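
The quality filtering above can be sketched with plain NumPy on a toy array. This is a minimal illustration, not the notebook's data: the arrays `sst` and `quality_level` hold made-up values, and `np.where` stands in for the Xarray `.where()` call.

```python
import numpy as np

# Toy values (hypothetical): SST in Celsius and the matching quality flags.
sst = np.array([18.2, 19.0, 17.5, 20.1, 16.8])
quality_level = np.array([5, 3, 5, 0, 4])

# Keep only best-quality pixels; everything else becomes a missing value.
sst_best = np.where(quality_level == 5, sst, np.nan)

n_kept = int(np.count_nonzero(~np.isnan(sst_best)))
print(sst_best, n_kept)
```

Pixels with a lower quality level are not deleted; they become `np.nan`, which is what the later steps treat as gaps to fill.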


#### Mask values for land, ice and lakes

Use our binary mask and set SST to -100.

We do not set it to missing values (`np.nan`) because we want later to differentiate between pixels with clouds (missing values e.g. `np.nan`) and land.

In [None]:
ds['SST_L3S'] = ds['SST_L3S'].where(ds.mask, -100)
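
To see why a sentinel value is useful here, the following sketch separates land, cloud and observed pixels on a toy array. The arrays `sst` and `is_sea` are hypothetical, and `np.where` again stands in for the Xarray `.where()` call.

```python
import numpy as np

LAND_VALUE = -100.0  # sentinel used in this notebook for land, ice and lakes

# Toy row of pixels (hypothetical values): NaN marks a cloud gap.
sst = np.array([15.0, np.nan, 14.2, 13.9, np.nan])
is_sea = np.array([True, True, False, True, False])

# Same logic as `.where(mask, -100)`: keep sea pixels, overwrite the rest.
sst_masked = np.where(is_sea, sst, LAND_VALUE)

is_land = sst_masked == LAND_VALUE   # land, ice or lake
is_cloud = np.isnan(sst_masked)      # sea pixel hidden by clouds
is_observed = ~is_land & ~is_cloud   # sea pixel with a valid SST

print(sst_masked, is_land, is_cloud, is_observed)
```

If land were also set to `np.nan`, the last two masks could not be told apart from the data alone.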

#### Let's try `persist()`

`persist()` loads the data from disk, triggers the computation and keeps the results in memory as Dask arrays.
Watch the Dask dashboard carefully while this cell runs.

In [None]:
ds=ds.persist()

In [None]:
ds.SST_L3S.hvplot(y='latitude', x='longitude')

## Save Dataset to local Zarr

Zarr is a data format for storing chunked, compressed, N-dimensional arrays. 

In [None]:
ds.to_zarr('SST.zarr', mode='w')

In [None]:
!ls

In [None]:
!ls -lart SST.zarr

In [None]:
!cat SST.zarr/.zmetadata | head -n 30

<div class="alert alert-success">
    <i class="fa-check-circle fa" style="font-size: 22px;color:#666;"></i> <b>Key Points</b>
    <br>
    <ul>
        <li>What is 'zarr' format?  </li>
    </ul>
</div>

## Data preparation


Let's open the zarr file we prepared, and this time let's use hvplot to plot the sea surface temperature.

In [None]:
ds = xr.open_dataset("SST.zarr", engine="zarr").isel(time=0)#.persist()
ds

In [None]:
ds['SST_L3S'].hvplot(y='latitude', x='longitude', width=800, height=400)

### How can we fill the gap using Xarray functions? 

In [None]:
da=ds.SST_L3S
(da
    .interpolate_na(dim="latitude",
                  method="linear")
    .interpolate_na(dim="longitude",
                  method="linear")

    .plot()
)
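
To make explicit what linear gap filling along one dimension does, here is a minimal NumPy sketch on a single row. The helper `fill_linear_1d` is a hypothetical name introduced for illustration; it is not part of Xarray, and it deliberately leaves leading and trailing gaps untouched instead of extrapolating.

```python
import numpy as np

def fill_linear_1d(values):
    """Fill NaN gaps in a 1-D array by linear interpolation between valid neighbours."""
    values = np.asarray(values, dtype=float)
    idx = np.arange(values.size)
    valid = ~np.isnan(values)
    filled = values.copy()
    if not valid.any():
        return filled  # nothing to interpolate from
    # Only interpolate strictly between the first and last valid samples,
    # so leading/trailing gaps stay NaN (no extrapolation).
    inside = (idx > idx[valid].min()) & (idx < idx[valid].max())
    gaps = ~valid & inside
    filled[gaps] = np.interp(idx[gaps], idx[valid], values[valid])
    return filled

row = np.array([np.nan, 10.0, np.nan, np.nan, 16.0, np.nan])
print(fill_linear_1d(row))
```

Applying this first along latitude and then along longitude, as in the cell above, treats the map as a flat image: each gap is filled from its neighbours along one axis only.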

### We can also use tools like pyinterp

To learn more about pyinterp, see [https://pangeo-pyinterp.readthedocs.io/en/latest/index.html](https://pangeo-pyinterp.readthedocs.io/en/latest/index.html).

In [None]:
import pyinterp.backends.xarray
# Module that handles the filling of undefined values.
import pyinterp.fill

da=ds.SST_L3S.where(ds.mask, np.nan)
grid = pyinterp.backends.xarray.Grid2D(da.compute())

In [None]:
%%time
filled = pyinterp.fill.loess(grid, nx=30, ny=30)
# pyinterp provides another gap-filling method, Gauss-Seidel relaxation,
# which takes much longer than 'loess':
# has_converged, filled = pyinterp.fill.gauss_seidel(grid)

In [None]:
fill=da.copy()
fill[:,:]=np.transpose(filled)
ds["SST_L3S_filled_pyinterp"]=fill
ds["SST_L3S_filled_pyinterp"].where(ds.mask, -100).plot()

We can verify, for example, how many `np.nan` values still remain.

In [None]:
np.isnan(ds["SST_L3S_filled_pyinterp"].where(ds.mask, -100)).sum().item()

## Convert Data in HEALPix

The gap-filling methods above do not take into account that the Earth is a sphere. Because we work with full-globe data, we need a grid system that respects the shape of the sphere. To do that, we use one of the Discrete Global Grid Systems (DGGS), HEALPix.

HEALPix stands for Hierarchical Equal Area isoLatitude Pixelation of a sphere. This pixelation produces a subdivision of a spherical surface in which each pixel covers the same surface area as every other pixel. 
![HEALPix](https://healpix.sourceforge.io/images/gorski_f1.jpg)

See [https://healpix.sourceforge.io](https://healpix.sourceforge.io) and/or [the HEALPix Primer](https://healpix.jpl.nasa.gov/pdf/intro.pdf) for more information.

The [healpy tutorial](https://healpy.readthedocs.io/en/latest/tutorial.html#NSIDE-and-ordering) is also a very good starting point to understand more about HEALPix.

### Resolution
The resolution of the grid is expressed by the parameter `Nside`, which defines the number of divisions along the side of a base-resolution pixel that is needed to reach a desired high-resolution partition.
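
As a quick worked example, the number of pixels and the approximate pixel size follow directly from `Nside`: the sphere is split into `12 * Nside**2` equal-area pixels. The helper names below are hypothetical, and the "resolution" is taken as the square root of the pixel area, which is a common convention and only an approximation since the pixels are not square.

```python
import math

def healpix_npix(nside):
    """Number of pixels on the sphere for a given Nside."""
    return 12 * nside**2

def healpix_resolution_deg(nside):
    """Approximate pixel size: square root of the (equal) pixel area, in degrees."""
    pixel_area_sr = 4.0 * math.pi / healpix_npix(nside)
    return math.degrees(math.sqrt(pixel_area_sr))

for nside in (1, 128, 1024):
    print(nside, healpix_npix(nside), round(healpix_resolution_deg(nside), 3))
```

With `nside = 128`, the value used later in this notebook, each cell is roughly 0.46 degrees across, so several pixels of the 0.10 degree source grid fall into each HEALPix cell.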

### Ordering Systems

HEALPix supports two pixel ordering systems: `nested` and `ring`.

Detailed explanations of the two pixel ordering systems can be found at [https://healpix.jpl.nasa.gov/html/intronode4.htm](https://healpix.jpl.nasa.gov/html/intronode4.htm).

In our example we use `nested` 
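
The practical advantage of the `nested` ordering is its hierarchy: each cell has four children whose indices are consecutive, so moving between resolutions is integer arithmetic. The sketch below uses hypothetical helper names and assumes `Nside` is a power of two, as the nested scheme requires.

```python
def nested_parent(pixel, levels=1):
    """Index of the ancestor cell `levels` resolutions coarser (NESTED ordering)."""
    return pixel >> (2 * levels)

def nested_children(pixel):
    """The four child cells one resolution finer (NESTED ordering)."""
    return [4 * pixel + k for k in range(4)]

def nested_base_pixel(pixel, nside):
    """Which of the 12 base pixels a NESTED cell belongs to."""
    return pixel // nside**2

print(nested_children(5))             # children of cell 5
print(nested_parent(23))              # parent of cell 23
print(nested_base_pixel(16384, 128))  # base pixel of a cell at nside=128
```

The same integer division appears later in this notebook, where the localized masks are built by mapping every cell at `nside` to its parent at `nside_mask`.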

In [None]:
ds = xr.open_dataset("SST.zarr", engine="zarr").isel(time=0).chunk(500).persist()
ds

### Define the HEALPix resolution

In [None]:
nside = 128
nest = True
full_cell_ids = range(0, 12*nside**2)

### Compute HEALPix Cell id

In [None]:
ds_healpix = (ds
    .stack(id=("latitude", "longitude"))
    .chunk("auto")
   ).persist()
ds_healpix["latitude_cp"] = ds_healpix["latitude"]
ds_healpix["longitude_cp"] = ds_healpix["longitude"]
th = (90.0 - ds_healpix['latitude'])/180.0*np.pi
ph = 2*np.pi - (ds_healpix['longitude'])/180.0*np.pi
cell_ids = hp.ang2pix(nside,th.data,ph.data,nest=nest)
ds_healpix = ds_healpix.compute()
ds_healpix
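
The angle conversion in the cell above can be isolated as a small function. HEALPix expects a colatitude `theta` (0 at the North Pole, pi at the South Pole) and a longitude `phi`, both in radians. The helper name `latlon_to_theta_phi` is hypothetical; it simply reproduces the `2*pi - longitude` mirroring used in this notebook.

```python
import numpy as np

def latlon_to_theta_phi(lat_deg, lon_deg):
    """Convert latitude/longitude in degrees to the (theta, phi) used above."""
    lat = np.asarray(lat_deg, dtype=float)
    lon = np.asarray(lon_deg, dtype=float)
    theta = np.deg2rad(90.0 - lat)        # colatitude: 0 at North Pole, pi at South Pole
    phi = 2.0 * np.pi - np.deg2rad(lon)   # same longitude mirroring as the cell above
    return theta, phi

theta, phi = latlon_to_theta_phi([90.0, 0.0, -90.0], [0.0, 90.0, -90.0])
print(theta, phi)
```

Checking that `theta` stays within `[0, pi]` is a cheap way to catch swapped latitude and longitude arguments before calling `hp.ang2pix`.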

### Create Land-Sea mask with HEALPix cell ids
- Group by HEALPix cell id and take the mean of the mask, propagating `np.nan` values (`skipna=False`) to keep track of all missing values.

In [None]:
%%time 
ds_mask = ds_healpix[["mask"]].assign_coords({
        "cell_ids":(["id"], cell_ids)}).sortby("cell_ids"
                                  ).groupby('cell_ids'
                                           ).mean(skipna=False, keep_attrs=True).compute()
ds_mask

### Create SST L3 with latitude and longitude coordinate and HEALPix cell id
- Group by HEALPix cell id and take the mean, skipping `np.nan` values (`skipna=True`), so that missing values do not affect the temperature data.

In [None]:
%%time 
ds_healpix = ds_healpix[["SST_L3S", "SST_L4",
                         "latitude_cp", "longitude_cp", 
                      ]].assign_coords({
        "cell_ids":(["id"], cell_ids)}).sortby("cell_ids"
                                  ).groupby("cell_ids"
                                           ).mean(skipna=True, keep_attrs=True)
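
The difference between `skipna=True` and `skipna=False` in the two group-by steps can be illustrated with NumPy. This is a sketch on made-up values: `group_mean` is a hypothetical helper, with `np.nanmean` and `np.mean` standing in for the two Xarray behaviours.

```python
import numpy as np

def group_mean(cell_ids, values, skipna):
    """Mean of `values` per cell id, optionally ignoring NaN."""
    cell_ids = np.asarray(cell_ids)
    values = np.asarray(values, dtype=float)
    unique_ids = np.unique(cell_ids)
    reducer = np.nanmean if skipna else np.mean
    means = np.array([reducer(values[cell_ids == cid]) for cid in unique_ids])
    return unique_ids, means

# Three source pixels fall in cell 7 (one is cloudy), two in cell 9.
cell_ids = np.array([7, 7, 7, 9, 9])
sst = np.array([10.0, np.nan, 14.0, 20.0, 22.0])

ids, mean_skip = group_mean(cell_ids, sst, skipna=True)
_, mean_keep = group_mean(cell_ids, sst, skipna=False)
print(ids, mean_skip, mean_keep)
```

With `skipna=True` a cell is only missing when all of its source pixels are missing, whereas `skipna=False` marks the cell as missing as soon as one source pixel is.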

In [None]:
ds_healpix["ocean"] = xr.where(ds_mask["mask"]==1, True, False)
ds_healpix['clouds'] = ds_healpix.SST_L3S.isnull()

# check if there is any forgotten latitude or longitude, if zero, it is ok.
print(ds_healpix.latitude_cp.isnull().sum().data, ds_healpix.longitude_cp.isnull().sum().data)
print(ds_healpix.ocean.sum().compute().data, ds_healpix.clouds.sum().compute().data)

ds_healpix

### Re-index and use xdggs to get an Xarray dataset with DGGS grid and coordinates

In [None]:
%%time
ds_healpix = ds_healpix.reindex(cell_ids=full_cell_ids, fill_value=False)
ds_healpix = ds_healpix.persist()

ds_healpix.cell_ids.attrs = {
    "grid_name": "healpix",
    "nside": nside,
    "nest": nest,
}

ds_healpix = ds_healpix.reset_index('cell_ids').set_xindex("cell_ids", xdggs.DGGSIndex).dggs.assign_latlon_coords()

print(ds_healpix.latitude_cp.isnull().sum().data,ds_healpix.longitude_cp.isnull().sum().data)
print(ds_healpix.ocean.sum().compute().data, ds_healpix.clouds.sum().compute().data)
ds_healpix

#### Set chunks

`chunk()` resets the chunks, replacing the original chunking that was loaded from the zarr file.

In [None]:
ds = ds.chunk(1000).persist()
ds

### Save HEALPix SST L3S and SST L4 in local zarr 

In [None]:
ds_healpix.to_zarr('healpix.zarr', mode='w')

### Load HEALPix Zarr with SST L3S and SST L4 from local Zarr file

In [None]:
ds_healpix = xr.open_dataset('healpix.zarr', engine="zarr")
ds_healpix

### Visualize data using healpy

In [None]:
hp.cartview(ds_healpix.ocean.compute().data, cmap='binary', nest=True)

In [None]:
hp.cartview(ds_healpix.clouds.compute().data, cmap='binary', nest=True)

In [None]:
hp.cartview(ds_healpix.SST_L3S.where(ds_healpix.ocean,hp.UNSEEN).compute().data,
            cmap='coolwarm',
            nest=True,
            min=0, max=30)

## Fill the Clouds Using Specificity of Healpy

Healpy is designed to represent the sphere using spherical harmonics functions. We will use linear regression against spherical harmonics functions to fill the gaps caused by clouds.

### How Do Spherical Harmonics Look?

We will construct a function $F$, which is defined as follows.

$
F = [\mathbb{R}(A_{00}), \mathbb{R}(A_{10}), \mathbb{R}(A_{11}), \mathbb{I}(A_{11}), \ldots, \mathbb{R}(A_{l_{max}m_{max}}), \mathbb{I}(A_{l_{max}m_{max}})]
$

In [None]:
lmax=15
nside=128
# compute the A_lm basis maps to fit
# get the l and m available for l<=lmax
l,m=hp.Alm.getlm(lmax=lmax)

#count the number of alm map (1 for m=0 and 2 for m>0)
n_alm=(m==0).sum()+2*(m>0).sum()
function=np.zeros([n_alm,12*nside**2])

alm=np.zeros([l.shape[0]],dtype='complex')

i=0

#array to store the l and m values of the A_lm
l_func=np.zeros(n_alm,dtype='int')
m_func=np.zeros(n_alm,dtype='int')
is_real_func=np.zeros(n_alm,dtype='int')

for k in range(l.shape[0]):
    alm[k]=1.0
    function[i]=hp.reorder(hp.alm2map(alm,nside),r2n=True)
    l_func[i]=l[k]
    m_func[i]=m[k]
    is_real_func[i]=1
    i+=1
    if m[k]>0:
        alm[k]=complex(0,1)
        function[i]=hp.reorder(hp.alm2map(alm,nside),r2n=True)
        l_func[i]=l[k]
        m_func[i]=m[k]
        is_real_func[i]=0
        i+=1
    alm[k]=0.0
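
The number of basis maps built above can be checked by hand: each coefficient with `m = 0` contributes one real map, and each coefficient with `m > 0` contributes a real and an imaginary map. The sketch below counts them without healpy, assuming `mmax = lmax`; the function name is hypothetical.

```python
def count_real_basis_functions(lmax):
    """Count real-valued basis maps for all (l, m) with 0 <= m <= l <= lmax."""
    n_m0 = 0    # coefficients with m == 0: real part only
    n_mpos = 0  # coefficients with m > 0: real and imaginary parts
    for l in range(lmax + 1):
        for m in range(l + 1):
            if m == 0:
                n_m0 += 1
            else:
                n_mpos += 1
    return n_m0 + 2 * n_mpos

for lmax in (3, 15):
    print(lmax, count_real_basis_functions(lmax), (lmax + 1) ** 2)
```

The count equals `(lmax + 1)**2`, so the fit has 256 free parameters for `lmax = 15`, far fewer than the `12 * nside**2` pixels, which is why only large-scale structure is recovered.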

In [None]:
lm=3
plt.figure(figsize=(12,5))
for k in range(l_func.shape[0]):
    pos=1+l_func[k]*(2*lm+1)+2*(is_real_func[k]-0.5)*m_func[k]-1+(lm+1)
    if is_real_func[k]==1:
        title=r'$\mathbb{R}(A_{\ell=%d,m=%d})$'%(l_func[k],m_func[k])
    else:
        title=r'$\mathbb{I}(A_{\ell=%d,m=%d})$'%(l_func[k],m_func[k])
    if l_func[k]<=lm:
        hp.mollview(function[k],nest=True,hold=False,sub=(lm+1,2*lm+1,pos)
                    ,title=title,cbar=False,cmap='coolwarm')

### Let's Fit!

Let's fit the existing values to the spherical harmonics functions.

To do that, we will use healpy's `Alm` utilities. `lmax` is the maximum degree $\ell$ of the spherical harmonics; in this example, we use `lmax=15`.

First, we extract the values where it is ocean, but not masked as clouds.  

In [None]:
sst_to_fit = ds_healpix.SST_L3S.where(
    ds_healpix.ocean & ~(ds_healpix.clouds)).dropna(dim='cell_ids')
sst_l4_to_fit = ds_healpix.SST_L4.where(
    ds_healpix.ocean ).dropna(dim='cell_ids')

### Compute the A_lm fit on the known SST

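We will use a chi-squared (least-squares) fit to estimate the $A_{lm}$ coefficients, using the functions constructed in the previous step.
The idea is to approximate the SST values on a spherical harmonic basis using the known data, and then to evaluate the fitted model on the masked pixels.

With $F$ the matrix whose columns are the basis functions sampled on the observed pixels, the coefficients $R$ are given by:

$
M = F^T F
$
$
R = M^{-1} F^T \, data
$

In the notebook cell that computes `mat_l3`, the array `function` stores one basis function per row, so it plays the role of $F^T$.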
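
The normal equations for $M$ and $R$ can be tried on a small synthetic problem before applying them to the sphere. This is a sketch under stated assumptions: a toy cosine basis on a line replaces the spherical harmonic maps, the data are noise-free, and all names (`F`, `true_coeffs`, `observed`) are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

n_pix, n_func = 200, 4
# Columns of F are the basis functions sampled on the pixels (toy cosine basis,
# standing in for the spherical harmonic maps).
x = np.linspace(0.0, 1.0, n_pix)
F = np.stack([np.cos(np.pi * k * x) for k in range(n_func)], axis=1)

true_coeffs = np.array([15.0, -8.0, 2.0, 0.5])
data = F @ true_coeffs

# Hide some pixels, as clouds do, and fit only on the observed ones.
observed = rng.random(n_pix) > 0.4
F_obs = F[observed]

M = F_obs.T @ F_obs                                  # M = F^T F
R = np.linalg.pinv(M) @ (F_obs.T @ data[observed])   # R = M^{-1} F^T data

reconstruction = F @ R                               # evaluate the fit on every pixel
print(R)
```

Because the fit only needs the observed pixels, the recovered coefficients can then be evaluated everywhere, including in the gaps. Using the pseudo-inverse keeps the solve stable when the mask makes $M$ close to singular.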


In [None]:
mat_l3=function[:,sst_to_fit.cell_ids]@function[:,sst_to_fit.cell_ids].T
vec_l3=function[:,sst_to_fit.cell_ids]@sst_to_fit.data
mat_l4=function[:,sst_l4_to_fit.cell_ids]@function[:,sst_l4_to_fit.cell_ids].T
vec_l4=function[:,sst_l4_to_fit.cell_ids]@sst_l4_to_fit.data
harm_l3=np.linalg.pinv(mat_l3)@vec_l3
harm_l4=np.linalg.pinv(mat_l4)@vec_l4
print("Number of functions ",harm_l3.shape)
# compute the fitted projected data
fit_data_l3=(harm_l3.reshape(1,harm_l3.shape[0])@function).flatten()
fit_data_l4=(harm_l4.reshape(1,harm_l4.shape[0])@function).flatten()

In [None]:

hp.cartview(fit_data_l3/ds_healpix.ocean.data,cmap='coolwarm',nest=True,title='Fitted L3S model')
hp.cartview(fit_data_l4/ds_healpix.ocean.data,cmap='coolwarm',nest=True,title='Fitted L4 model')

In [None]:
filled_data_l3=ds_healpix.SST_L3S.where(ds_healpix.ocean,hp.UNSEEN).compute().data
filled_data_l4=ds_healpix.SST_L4.data
filled_data_l3[ds_healpix.clouds]=fit_data_l3[ds_healpix.clouds]
filled_data_l3[ds_healpix.ocean==False]=fit_data_l3[ds_healpix.ocean==False]
filled_data_l4[ds_healpix.ocean==False]=fit_data_l4[ds_healpix.ocean==False]
hp.cartview(filled_data_l3/ds_healpix.ocean.data,cmap='coolwarm',nest=True,title='Filled data')

# Put the computed values in SST_L3S and save the result in a DataArray SST_polyfit_filled
ds_healpix['SST_polyfit_filled'] = ds_healpix['SST_L3S'].copy()
ds_healpix['SST_polyfit_filled'].loc[{"cell_ids": ds_healpix['SST_L3S'].cell_ids}] = filled_data_l3
ds_healpix['SST_L4'].loc[{"cell_ids": ds_healpix['SST_L4'].cell_ids}] = filled_data_l4

### Plot the fitted results

In [None]:
rot=[120,-30]
plt.figure(figsize=(16,10))
hp.cartview(ds_healpix.SST_L3S.where(ds_healpix.ocean, hp.UNSEEN).compute().data, 
            title="SST L3S - 01/06/2024", nest=True, min=-3, max=30, cmap='cmc.vik',
            hold=False, sub=(2,2,1))
hp.cartview(ds_healpix.SST_polyfit_filled.where(ds_healpix.ocean,hp.UNSEEN).compute().data, 
            title="Polyfit -  01/06/2024", nest=True, min=-3, max=30, cmap='cmc.vik', hold=False,sub=(2,2,3))
hp.gnomview(ds_healpix.SST_L3S.where(ds_healpix.ocean,hp.UNSEEN).compute().data,
            nest=True, min=10,max=30, cmap='cmc.vik',rot=rot,reso=15, notext=True,
            title='01/06/2024', hold=False, sub=(2,2,2), xsize=256)
hp.gnomview(ds_healpix.SST_polyfit_filled.where(ds_healpix.ocean,hp.UNSEEN).compute().data,
            nest=True, min=10,max=30, cmap='cmc.vik', rot=rot, reso=15, notext=True,
            title='Synth. 01/06/2024', hold=False, sub=(2,2,4), xsize=256)

## What is scattering transform?

<style>
.responsive-wrap iframe{ max-width: 100%;}
</style>
<div class="responsive-wrap">
<!-- this is the embed code provided by Google -->
  <iframe src="https://docs.google.com/presentation/d/1_cLrYiDFJxouuquLi4K1WD-qWw_deq5M4VK64X68aok/embed?start=false&loop=false&delayms=3000" frameborder="0" width="960" height="569" allowfullscreen="true" mozallowfullscreen="true" webkitallowfullscreen="true"></iframe>
<!-- Google embed ends -->
</div>
	


## Apply scattering transform

#### Learning process to fill the data

In this step, the statistics are learned from the observed SST. The synthesis needs an initial guess for the missing values, so we start from `ds_healpix.SST_polyfit_filled`:
- the statistics are learned from the observed SST (ocean area without clouds):
  `ds_healpix.SST_polyfit_filled.where(ds_healpix.ocean & ~ds_healpix.clouds)`
- the gradient is computed only for the cloud area, selected with `ds_healpix.clouds`
- the scattering transform is applied to the whole ocean area
  
  

In [None]:
plt.figure(figsize=(16,10))
hp.cartview(xr.where( (ds_healpix.ocean & ~ds_healpix.clouds),1,0).compute().data,
            cmap='binary', nest=True, title="observed (value=1)\n (only ocean, without clouds and lands)",
            sub=(2,2,1))
hp.cartview(xr.where( (ds_healpix.clouds),1,0).compute().data, cmap='binary', nest=True,
            title='clouds (value=1)\n',sub=(2,2,2))
hp.cartview(xr.where( (ds_healpix.ocean),1,0).compute().data, cmap='binary', nest=True,
            title='all ocean, including clouds', sub=(2,2,3))
hp.cartview(ds_healpix.SST_polyfit_filled.where(ds_healpix.ocean,hp.UNSEEN).compute().data, cmap='cmc.vik', nest=True,
            title='observed SST (Celsius), polyfitted',  sub=(2,2,4))

## Let's use the foscat software and make use of the scattering transform

First, we set up foscat to fill the missing data.

In [None]:
scat = sc.funct(silent=True, 
                JmaxDelta=0)

### Prepare data and apply foscat on cloud data only

In [None]:
# Compute localized mask

nside_mask=1
local_mask=np.zeros([12*nside_mask**2,12*nside**2])
for k in range(12*nside_mask**2):
    l_data=hp.smoothing((np.arange(12*nside**2)*(nside_mask**2/nside**2)).astype('int')==k,60/180.*np.pi,nest=True)
    l_data/=l_data.max()
    hp.cartview(l_data,nest=True,hold=False,sub=(4,4,1+k),title='Mask #%d'%k)
    local_mask[k]=l_data

In [None]:
#our data to apply foscat
data = ds_healpix.SST_polyfit_filled.data.copy()
print(data.shape)

# here we mask non ocean point as 0
# cloud is 0, only observed points are 1. 
#mask_observe_only = (ds_healpix.ocean & ~(ds_healpix.clouds)).data.reshape(1, ds_healpix.cell_ids.size)
mask_observe_only = (ds_healpix.ocean).data.reshape(1, ds_healpix.cell_ids.size)

ref = scat.eval(ds_healpix.SST_L4.data, mask=mask_observe_only*local_mask)

In [None]:
def The_loss_function(x,scat_operator,args,return_all=False):

    ref = args[0]
    mask = args[1]

    learn=scat_operator.eval(x,mask=mask)

    loss=scat_operator.reduce_mean(scat_operator.square((ref-learn)/ref))

    return(loss)
# Mask of all ocean pixels: ocean points are 1, non-ocean points are 0.
mask_all_ocean=(ds_healpix.ocean).data.reshape(1,ds_healpix.cell_ids.size)

print(mask_all_ocean.shape)

loss=synthe.Loss(The_loss_function,scat,ref,
                 scat.backend.constant(scat.backend.bk_cast(mask_all_ocean*local_mask)))

sy = synthe.Synthesis([loss])
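
The loss defined above is a mean squared relative difference between the reference statistics and the statistics of the map being optimized. The NumPy sketch below shows only that arithmetic: plain arrays stand in for the foscat statistics objects, and `relative_squared_loss` is a hypothetical name, not a foscat function.

```python
import numpy as np

def relative_squared_loss(ref, learn):
    """Mean of ((ref - learn) / ref)**2, as in the loss function above."""
    ref = np.asarray(ref, dtype=float)
    learn = np.asarray(learn, dtype=float)
    return float(np.mean(np.square((ref - learn) / ref)))

ref = np.array([4.0, 0.5, 10.0])     # reference coefficients (toy values)
learn = np.array([2.0, 0.25, 10.0])  # current coefficients (toy values)

loss_value = relative_squared_loss(ref, learn)
print(loss_value)
```

Dividing by `ref` gives coefficients of very different magnitudes a comparable weight in the loss, at the price of requiring non-zero reference values.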

### RUN ON SYNTHESIS

Here the gradient mask is set to 1 over the clouds and 0 everywhere else, because we only want to compute the gradient (and therefore update the values) at cloud pixels.

Before running your calculation, you can check which type of GPU you have with the following command:


In [None]:
!nvidia-smi


In [None]:
mask_clouds_only=((ds_healpix.clouds)).data.reshape(1,ds_healpix.cell_ids.size)
print(mask_clouds_only.shape)

In [None]:

omap=sy.run(scat.backend.bk_cast(data),
            EVAL_FREQUENCY=100,
            grd_mask=mask_clouds_only, # only the gradient of masked data is computed
            NUM_EPOCHS = 300).numpy()
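
The role of `grd_mask` can be illustrated with a toy gradient descent in NumPy. This is not foscat's optimizer: a simple smoothness loss replaces the scattering statistics, and the update rule is plain gradient descent. It only shows the idea that pixels outside the mask never change.

```python
import numpy as np

def smoothness_gradient(x):
    """Gradient of sum((x[i+1] - x[i])**2) with respect to x."""
    diff = np.diff(x)
    grad = np.zeros_like(x)
    grad[:-1] -= 2.0 * diff
    grad[1:] += 2.0 * diff
    return grad

x = np.array([10.0, 0.0, 0.0, 16.0])              # two gap pixels start from a poor guess
grd_mask = np.array([False, True, True, False])   # True where values may be updated

learning_rate = 0.1
for _ in range(500):
    # Zero the gradient outside the mask so observed pixels never move.
    x -= learning_rate * smoothness_gradient(x) * grd_mask

print(x)
```

Only the two masked entries move during the iterations, while the observed values at both ends stay exactly as they were.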

In [None]:
# Put the computed values in SST_L3S and save the result in a DataArray SST_foscat_filled
ds_healpix['SST_foscat_filled'] = ds_healpix['SST_L3S'].copy()
ds_healpix['SST_foscat_filled'].loc[:] = omap

# Plot the results

## Comparison with L4 product

We compare SST_polyfit_filled (blue) and SST_foscat_filled (red).

In [None]:
hist_range=[-5,5]
plt.figure(figsize=(10,4))
plt.subplot(1,2,1)
hy,hx=np.histogram((ds_healpix.SST_L4 - ds_healpix.SST_polyfit_filled).where(ds_healpix.ocean).data,range=hist_range,bins=100)
hy2,hx=np.histogram((ds_healpix.SST_L4 - ds_healpix.SST_foscat_filled).where(ds_healpix.ocean).data,range=hist_range,bins=100)
plt.plot((hx[1:]+hx[:-1])/2,100*hy/hy.max(),color='b',label=r'$A_{lm}$ Filled data')
plt.plot((hx[1:]+hx[:-1])/2,100*hy2/hy2.max(),color='r',label=r'FOSCAT Filled data')
plt.legend(frameon=False,loc=2)
plt.ylim(0,120)
plt.ylabel('Histogram for all ocean pixels [percent]')
plt.xlabel(r'$\Delta t$ with L4 [degree]')
#plt.yscale('log')
plt.subplot(1,2,2)
hy,hx=np.histogram((ds_healpix.SST_L4 - ds_healpix.SST_polyfit_filled).where(ds_healpix.clouds).data,range=hist_range,bins=100)
hy2,hx=np.histogram((ds_healpix.SST_L4 - ds_healpix.SST_foscat_filled).where(ds_healpix.clouds).data,range=hist_range,bins=100)
plt.plot((hx[1:]+hx[:-1])/2,100*hy/hy.max(),color='b',label=r'$A_{lm}$ Filled data')
plt.plot((hx[1:]+hx[:-1])/2,100*hy2/hy2.max(),color='r',label=r'FOSCAT Filled data')
plt.legend(frameon=False,loc=2)
plt.ylabel('')
plt.xlabel(r'$\Delta t$ with L4 [degree]')
plt.ylabel('Histogram for clouds [percent]')
plt.ylim(0,120)
#plt.yscale('log')

In [None]:
plt.figure(figsize=(16,10))
hp.cartview(
    (abs(ds_healpix.SST_L4 - ds_healpix.SST_foscat_filled)).where(ds_healpix.ocean,hp.UNSEEN).compute().data,
             nest=True,cmap='cmc.cork',min=-5,max=5,sub=(2,1,1),title='diff foscat L4')
hp.cartview(
    (abs(ds_healpix.SST_L4 - ds_healpix.SST_polyfit_filled)).where(ds_healpix.ocean,hp.UNSEEN).compute().data,
             nest=True,cmap='cmc.cork',min=-5,max=5,sub=(2,1,2),title='diff $A_{lm}$ L4')


In [None]:

plt.figure(figsize=(10, 6))
hp.cartview(ds_healpix.SST_polyfit_filled.where(ds_healpix.ocean,hp.UNSEEN).compute().data, nest=True, min=-3, max=30, cmap='cmc.vik',
            title='SST gaps filled with $A_{lm}$', hold=False, sub=(2,2,1))
hp.cartview(ds_healpix.SST_foscat_filled.where(ds_healpix.ocean,hp.UNSEEN).compute().data, nest=True, min=-3, max=30, cmap='cmc.vik',
            title='SST gaps filled with FOSCAT', hold=False, sub=(2,2,2))
hp.cartview(ds_healpix.SST_L4.where(ds_healpix.ocean,hp.UNSEEN).compute().data, nest=True, min=-3, max=30,
            cmap='cmc.vik', title='L4S', hold=False, sub=(2,2,3))

In [None]:
rot=[130,10]
reso=15

cmap='cmc.oleron'
cmap='coolwarm'
plt.figure(figsize=(12, 4))
hp.gnomview(ds_healpix.SST_polyfit_filled.where(ds_healpix.ocean,hp.UNSEEN).compute().data, nest=True, min=18, max=30, cmap=cmap, 
            rot=rot, reso=reso, 
            notext=True, title='L3S Filled with $A_{lm}$', hold=False, sub=(1,3,1), xsize=256)
hp.gnomview(ds_healpix.SST_foscat_filled.where(ds_healpix.ocean,hp.UNSEEN).compute().data, nest=True, min=18, max=30, cmap=cmap, 
            rot=rot, reso=reso, 
            notext=True, title='L3S Filled with FOSCAT ',hold=False,sub=(1,3,2), xsize=256)
hp.gnomview(ds_healpix.SST_L4.where(ds_healpix.ocean,hp.UNSEEN).compute().data, nest=True, min=18, max=30, 
            cmap=cmap,rot=rot,reso=reso,notext=True,
            title='L4S', hold=False, sub=(1,3,3), xsize=256)

plt.figure(figsize=(12, 4))
hp.gnomview(ds_healpix.SST_polyfit_filled.where(ds_healpix.ocean,hp.UNSEEN).compute().data-ds_healpix.SST_L4.where(ds_healpix.ocean,hp.UNSEEN).compute().data, 
            nest=True, min=-1, max=1, cmap=cmap, 
            rot=rot, reso=reso, 
            notext=True, title='L3S Filled with $A_{lm}$ - L4S', hold=False, sub=(1,3,1), xsize=256)
hp.gnomview(ds_healpix.SST_foscat_filled.where(ds_healpix.ocean,hp.UNSEEN).compute().data-ds_healpix.SST_L4.where(ds_healpix.ocean,hp.UNSEEN).compute().data, 
            nest=True, min=-1, max=1, cmap=cmap, 
            rot=rot, reso=reso, 
            notext=True, title='L3S Filled with FOSCAT  - L4S',hold=False,sub=(1,3,2), xsize=256)
hp.gnomview(ds_healpix.SST_L4.where(ds_healpix.ocean,hp.UNSEEN).compute().data-ds_healpix.SST_L4.where(ds_healpix.ocean,hp.UNSEEN).compute().data, 
            nest=True, min=-1, max=1, 
            cmap=cmap,rot=rot,reso=reso,notext=True,
            title='L4S - L4S', hold=False, sub=(1,3,3), xsize=256)

In [None]:
def computespectrum(ds_data,map2=None,mask=None):
    imask=hp.reorder(mask,n2r=True)
    tmp=hp.reorder(ds_data.compute().data,n2r=True)
    tmp=(tmp-np.median(tmp[imask==1]))*imask
    if map2 is not None:
        tmp2=hp.reorder(map2.compute().data,n2r=True)
        tmp2=(tmp2-np.median(tmp2[imask==1]))*imask
    else:
        tmp2=None
    return hp.anafast(tmp,map2=tmp2)

plt.figure(figsize=(12,4))
plt.subplot(1,2,1)
plt.title('All oceans')
cl0=computespectrum(ds_healpix.SST_L4,mask=mask_all_ocean[0])
cl1=computespectrum(ds_healpix.SST_polyfit_filled-ds_healpix.SST_L4,mask=mask_all_ocean[0])
cl2=computespectrum(ds_healpix.SST_foscat_filled-ds_healpix.SST_L4,mask=mask_all_ocean[0])
plt.plot(cl0,color='black',label='L4S')
plt.plot(cl1,color='b',label='L3S Filled with $A_{lm}$ - L4')
plt.plot(cl2,color='r',label='L3S Filled with FOSCAT - L4')
plt.xscale('log')
plt.yscale('log')
plt.legend(frameon=False,loc=1)
plt.xlabel(r'Multipole $\ell$')
plt.ylabel('PowerSpectrum [${degree}^{2}$]')
plt.subplot(1,2,2)
plt.title('Only Clouds')
cl0=computespectrum(ds_healpix.SST_L4,mask=mask_clouds_only[0])
cl1=computespectrum(ds_healpix.SST_polyfit_filled-ds_healpix.SST_L4,mask=mask_clouds_only[0])
cl2=computespectrum(ds_healpix.SST_foscat_filled-ds_healpix.SST_L4,mask=mask_clouds_only[0])
plt.plot(cl0,color='black',label='L4S')
plt.plot(cl1,color='b',label='L3S Filled with $A_{lm}$ - L4')
plt.plot(cl2,color='r',label='L3S Filled with FOSCAT - L4')
plt.xscale('log')
plt.yscale('log')
plt.legend(frameon=False,loc=1)
plt.xlabel(r'Multipole $\ell$')
plt.ylabel('PowerSpectrum [${degree}^{2}$]')

In [None]:
client.close()

<div class="alert alert-success">
<i class="fa-check-circle fa" style="font-size: 22px;color:#666;"></i> <b>Key Points</b>
<br>
<ul>
   <li>HEALPix</li>
   <li>Access, read and get metadata from remote and local zarr</li>
   <li>Apply Cross Scattering Transform to fill missing values</li>
   </ul>
</div>

## References

```{bibliography}
:style: alpha
:filter: topic % "SST" and not topic % "package"
:keyprefix: a-
```

## Packages citation

```{bibliography}
:style: alpha
:filter: topic % "SST" and topic % "package"
:keyprefix: a-
```