# Southern Ocean Codes
## Environment Settings

In [None]:
# Silence every Python warning for the rest of the session.
import warnings
warnings.filterwarnings("ignore") 

import os
# Set before numpy is first imported (numpy is imported in the following cell).
# NOTE(review): presumably this turns off NumPy's experimental
# __array_function__ dispatch — confirm it is still required.
os.environ['NUMPY_EXPERIMENTAL_ARRAY_FUNCTION'] = '0'

In [None]:
import numpy as np
import xesmf as xe
import xarray as xr
import seawater as sw

import copy

Dask cluster settings

In [None]:
from dask.distributed import Client
from dask_gateway import Gateway

# Connect to the Dask Gateway and request a new cluster.
gateway = Gateway()
cluster = gateway.new_cluster()
# Adaptive scaling: the cluster may shrink to 0 and grow up to 40.
cluster.adapt(minimum = 0, maximum = 40)

# Attach a distributed client to the new cluster (timeout of '50s').
client = Client(cluster, timeout='50s') 
# Bare expression: as the last line of the cell it displays the cluster object.
cluster

In [None]:
# Shut down the Dask client and then the cluster.
# NOTE(review): this cell appears before the data-processing cells; running
# the notebook top to bottom would close the cluster before it is used.
# Presumably it is meant to be run manually once the work is done — confirm.
client.close()
cluster.close()

## Data Access and processing
### 1. Data reading and calculation
#### a) Read CMIP6 data from Google Cloud using intake

In [None]:
import intake
# Location of the Pangeo master intake catalog (YAML file on GitHub).
url = "https://raw.githubusercontent.com/pangeo-data/pangeo-datastore/master/intake-catalogs/master.yaml"
cat = intake.open_catalog(url)
# `col` is the catalog entry `climate.cmip6_gcs`; it is passed to
# func_get_data below, which queries it with `col.search(...)`.
col = cat.climate.cmip6_gcs()

In [None]:
def rename_coords(ds):
    """Return a copy of `ds` with harmonised coordinate names, CF-decoded.

    Renames applied (each only when its trigger name is in ``ds.coords``):

    * depth: ``olevel`` -> ``lev``, ``lev_partial`` -> ``lev``
    * horizontal: ``latitude``/``longitude`` -> ``lat``/``lon``,
      ``nav_lat``/``nav_lon`` -> ``lat``/``lon``

    For the horizontal pairs the latitude name is the trigger and both
    members of the pair are renamed together, as in the original code.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to normalise. It is not modified; a copy is worked on.

    Returns
    -------
    xarray.Dataset
        The renamed dataset after ``xr.decode_cf`` has been applied.
    """
    # Work on a copy so the caller's object is left untouched.
    ds = ds.copy()

    # (trigger coordinate, renames to apply when the trigger is present).
    # Order matters: it is the order in which the renames are attempted.
    rename_rules = (
        ("olevel", {"olevel": "lev"}),
        ("lev_partial", {"lev_partial": "lev"}),
        ("latitude", {"longitude": "lon", "latitude": "lat"}),
        ("nav_lat", {"nav_lon": "lon", "nav_lat": "lat"}),
    )
    for trigger, mapping in rename_rules:
        if trigger in ds.coords:
            ds = ds.rename(mapping)

    # Decode according to CF conventions (this is where time decoding happens
    # when the dataset was opened with decoding switched off).
    return xr.decode_cf(ds)

def func_get_data(col, model, var, expe, freq, mem = 'r1i1p1f1', grid = 'gr'):
    """Load the first catalog match for the given query as a lazy dataset.

    Parameters
    ----------
    col : catalog collection
        Object exposing ``search(**query)``; the search result must provide
        ``.df``, ``.keys()`` and item access by key.
    model, var, expe, freq, mem, grid : str
        Values for the ``source_id``, ``variable_id``, ``experiment_id``,
        ``table_id``, ``member_id`` and ``grid_label`` search fields.

    Returns
    -------
    dataset or None
        The dataset of the first matching key, opened with ``rename_coords``
        as preprocessing step and with the ``member_id`` dimension squeezed
        out and its coordinate dropped. ``None`` (after printing
        ``'No data'``) when the search has no match.
    """
    query = dict(
        variable_id=var,
        table_id=freq,
        experiment_id=expe,
        source_id=model,
        member_id=mem,
        grid_label=grid,
    )
    matches = col.search(**query)

    # Guard clause: nothing found for this combination.
    if matches.df.empty:
        print('No data')
        return None

    # Only the first matching key is used.
    first_key = matches.keys()[0]
    source = matches[first_key]

    # Time decoding is disabled at open time; rename_coords decodes afterwards.
    opened = source(
        zarr_kwargs={'consolidated': True, 'decode_times': False},
        preprocess=rename_coords,
    ).to_dask()

    # Remove the length-one member dimension and its leftover coordinate.
    return opened.squeeze('member_id').reset_coords('member_id', drop=True)

In [None]:
# Query settings shared by all variables.
model_name = 'GFDL-CM4'
# Temperature, Salinity, Heat Flux, Sea Ice Concentration
variables = ['thetao', 'so', 'hfds', 'siconc']
experiment = 'piControl'
frequency = 'Omon'

# One dataset per variable, keyed by variable name.
datasets = {}
for var in variables:
    print(var)
    # 'siconc' is requested from the 'SImon' table; all others use `frequency`.
    datasets[var] = func_get_data(
        col, model_name, var, experiment,
        'SImon' if var == 'siconc' else frequency,
    )

#### b) Calculate density
The functions `smow` and `dens0` are adapted from [python-seawater](https://github.com/pyoceans/python-seawater/tree/master/seawater).

In [None]:
def smow(t):
    """Evaluate the pure-water density polynomial used by `dens0`.

    Adapted from python-seawater's ``smow`` (density of Standard Mean Ocean
    Water). The input is scaled by 1.00024 and a fifth-order polynomial is
    evaluated in that scaled temperature.

    Parameters
    ----------
    t : scalar or array-like supporting arithmetic
        Temperature. NOTE(review): python-seawater expects degrees Celsius
        (ITS-90) — confirm the inputs follow that convention.

    Returns
    -------
    Same kind as `t`
        Polynomial value (density in python-seawater's convention).
    """
    coeffs = (999.842594, 6.793952e-2, -9.095290e-3, 1.001685e-4,
              -1.120083e-6, 6.536332e-9)
    t68 = t * 1.00024

    # Horner scheme, from the highest-order coefficient down to the constant.
    density = coeffs[-1]
    for coeff in reversed(coeffs[:-1]):
        density = coeff + density * t68
    return density

def dens0(s, t):
    """Evaluate the density polynomial in salinity `s` and temperature `t`.

    Adapted from python-seawater's ``dens0`` (density of sea water at
    atmospheric pressure). The result is the `smow(t)` value plus terms in
    ``s``, ``s**1.5`` and ``s**2`` whose coefficients are polynomials in the
    scaled temperature ``t * 1.00024``.

    Parameters
    ----------
    s : scalar or array-like supporting arithmetic
        Salinity.
    t : scalar or array-like supporting arithmetic
        Temperature, passed unchanged to `smow`.

    Returns
    -------
    Result of broadcasting `s` and `t`
        Polynomial value (density in python-seawater's convention).
    """
    t68 = t * 1.00024
    linear_coeffs = (8.24493e-1, -4.0899e-3, 7.6438e-5, -8.2467e-7, 5.3875e-9)
    three_half_coeffs = (-5.72466e-3, 1.0227e-4, -1.6546e-6)
    quadratic_coeff = 4.8314e-4

    def horner(coeffs):
        # Polynomial in t68, highest-order coefficient first.
        acc = coeffs[-1]
        for coeff in reversed(coeffs[:-1]):
            acc = coeff + acc * t68
        return acc

    # Terms are accumulated in the order: pure water, s, s**1.5, s**2.
    rho = smow(t) + horner(linear_coeffs) * s
    rho = rho + horner(three_half_coeffs) * s * s ** 0.5
    return rho + quadratic_coeff * s ** 2

def func_calc_dens(ds, start_index, stop_index):
    """Compute `dens0` for a time window of the temperature/salinity data.

    Parameters
    ----------
    ds : mapping
        Mapping with entries ``'thetao'`` and ``'so'``; each entry is itself
        indexed by the same name to obtain the variable (i.e.
        ``ds['thetao']['thetao']`` and ``ds['so']['so']``).
    start_index, stop_index : int
        Bounds of the positional slice taken along the ``time`` dimension.

    Returns
    -------
    Result of ``dens0(salinity, temperature)`` for the selected window.
    """
    window = slice(start_index, stop_index)
    temperature = ds['thetao']['thetao'].isel(time=window)
    salinity = ds['so']['so'].isel(time=window)
    return dens0(salinity, temperature)

In [None]:
# Density for time indices 0 .. 12*500 - 1 (500 years if the data are monthly,
# as `frequency = 'Omon'` suggests — TODO confirm).
da_dens = func_calc_dens(datasets, 0, 12*500)

In [None]:
def func_regrid(ds, ds_out, reuse=False, clear=True):
    """Regrid `ds` onto the grid of `ds_out` with a bilinear xESMF regridder.

    Parameters
    ----------
    ds : xarray object
        Input data; also used to define the source grid.
    ds_out : xarray.Dataset
        Defines the destination grid.
    reuse : bool, optional
        Passed to the regridder as ``reuse_weights``.
    clear : bool, optional
        When true, ``clean_weight_file()`` is called on the regridder after
        the data have been regridded.

    Returns
    -------
    The regridded data returned by calling the regridder on `ds`.
    """
    regridder = xe.Regridder(
        ds,
        ds_out,
        'bilinear',
        periodic=True,
        reuse_weights=reuse,
        ignore_degenerate=True,
    )
    # Drop the regridder's references to its grid objects before applying it.
    # NOTE(review): presumably done to free memory or to make the regridder
    # usable with dask — confirm.
    regridder._grid_in = regridder._grid_out = None

    regridded = regridder(ds)
    if clear:
        regridder.clean_weight_file()
    return regridded

In [None]:
# Target grid for regridding: latitude -90..90 and longitude 0..360,
# both inclusive, in steps of 1.0.
ds_out = xr.Dataset(
    dict(
        lat=('lat', np.arange(-90, 91, 1.0)),
        lon=('lon', np.arange(0, 361, 1.0)),
    )
)

In [None]:
# Regrid the density field onto the target grid `ds_out` (weight file is
# removed afterwards, since `clear` defaults to True).
da_dens_gr = func_regrid(da_dens, ds_out)