# 1.3: Convert to extremes and binary!

In [None]:
# general use packages
%matplotlib inline
import matplotlib.pyplot as plt
import xarray as xr
import pandas as pd
import numpy as np

# packages for altering time to match up!
import sys
import cftime

# climpred packages
import climpred
from climpred import HindcastEnsemble
from climpred.tutorial import load_dataset
from climpred.stats import rm_poly

# SMYLE Utility functions
from SMYLEutils import io_utils as io
from SMYLEutils import calendar_utils as cal
from SMYLEutils import stat_utils as stat

In [None]:
def detrend_linear(dat, dim):
    """Subtract the least-squares straight-line fit along ``dim`` from ``dat``.

    A degree-1 polynomial is fitted with ``dat.polyfit``, evaluated on the
    ``dim`` coordinate of ``dat`` with ``xr.polyval``, and the fitted values
    are removed. The residual is returned.
    """
    coeffs = dat.polyfit(dim=dim, deg=1).polyfit_coefficients
    trend = xr.polyval(dat[dim], coeffs)
    return dat - trend

In [None]:
def detrend_second(dat, dim):
    """Subtract a second-order (quadratic) polynomial fit along ``dim``.

    A degree-2 polynomial is fitted with ``dat.polyfit``, evaluated on the
    ``dim`` coordinate of ``dat`` with ``xr.polyval``, and the fitted values
    are removed. The residual is returned.

    The previous body was a copy of ``detrend_linear`` and fitted ``deg=1``,
    so no quadratic component was ever removed.
    """
    coeffs = dat.polyfit(dim=dim, deg=2).polyfit_coefficients
    trend = xr.polyval(dat[dim], coeffs)
    return dat - trend

## SMYLE

In [None]:
# Run configuration used by the SMYLE cells below.
init = '11'        # initialisation tag that appears in the input/output file names
depth = 'surface'  # depth tag that appears in the input file names
var = 'TEMP'       # variable selected from the regridded file; also part of file names
var2 = 'TEMP'      # variable tag in the '.time.nc' file name (original note: var)
level = 0.9        # quantile level for the threshold: 0.1 or 0.9

In [None]:
# Select the dataset: the regridded monthly field for `var` plus the
# companion '.time.nc' file for this initialisation.
smyle = xr.open_dataset(f'/glade/derecho/scratch/smogen/SMYLE-Extreme/{var}.monthly.{depth}.{init}.regrid.nc')[var]
smyle_time = xr.open_dataset(f'/glade/derecho/scratch/smogen/SMYLE-Extreme/{var2}.monthly.{init}.time.nc')
# `drop` is deprecated for removing a variable/coordinate by name;
# `drop_vars` is the supported call and removes the same 'time' entry.
smyle = smyle.drop_vars('time')

In [None]:
%%time
# remove climatological drift from the data
smyle_anom,smyle_clim = stat.remove_drift(smyle,smyle_time,1970,2023)

In [None]:
%%time
# detrend
# smyle_anom = detrend_linear(smyle_anom.time,'Y')
smyle_anom = detrend_second(smyle_anom.time,'Y')

In [None]:
smyle_anom.isel(M=0,L=0).sel(lat=0.5,lon=-130.5).plot()

### threshold within SMYLE

In [None]:
thold_data = smyle_anom.mean('M')

In [None]:
print(level) # check the level before you run the code!

In [None]:
# calculate the threshold using a rolling 3-month average
%%time

ds_thold = []

# 1st through 23th month
for i in range(0,23):
    print(i)
    tst = thold_data.sel(L = [thold_data.L[i - 1],thold_data.L[i],thold_data.L[i+1]]).quantile(level,dim=('L','Y'),skipna=True)
    tst = tst.expand_dims('L')
    ds_thold.append(tst)
    
# 12th month
print('24')
last_month = thold_data.sel(L = [thold_data.L[22],thold_data.L[23],thold_data.L[0]]).quantile(level,dim=('L','Y'),skipna=True)
last_month = last_month.expand_dims('L')

ds_thold.append(last_month)

In [None]:
smyle_threshold = xr.concat(ds_thold,dim='L')

smyle_threshold = smyle_threshold.to_dataset(name='threshold')
smyle_threshold['L'] = smyle_threshold.L + 1

In [None]:
smyle_anom.sel(Y=1997,L=11,M=3).plot()

In [None]:
# Write the thresholds to disk. The commented lines are the alternative
# variable name / output location kept from the original notebook.
# var = 'omega_arag'
# smyle_threshold.to_netcdf('/glade/work/smogen/SMYLE-extremes/thresholds/smyle' + init +  '.' + var + '.thold.Rolling.full.nc')
smyle_threshold.to_netcdf(path=f'/glade/derecho/scratch/smogen/SMYLE-Extreme/thresholds/smyle{init}.{var}.thold.Rolling.new_run2.nc')

### SMYLE to a binary

In [None]:
depth = 'surface'

In [None]:
# smyle = xr.open_dataset('/glade/derecho/scratch/smogen/SMYLE-Extreme/'+var + '.monthly.' + depth + '.' + init + '.regrid.nc')[var]
# smyle_time = xr.open_dataset('/glade/derecho/scratch/smogen/SMYLE-Extreme/'+var2+'.monthly.' + init + '.time.nc')
# smyle = smyle.drop('time')

smyle = xr.open_dataset('/glade/derecho/scratch/smogen/SMYLE-Extreme/'+var + '.monthly.' + depth + '.live11.regrid.update.new_run.combined.nc')[var]
smyle_time = xr.open_dataset('/glade/derecho/scratch/smogen/SMYLE-Extreme/'+var2+'.monthly.live11.time.update.nc')
smyle = smyle.drop('time')

In [None]:
# Read back the 'threshold' variable written by the threshold section. The
# commented lines are the alternative variable name / location.
# var = 'H+'
# thold = xr.open_dataset('/glade/work/smogen/SMYLE-extremes/thresholds/smyle' + init + '.' + var + '.thold.Rolling.full.nc')['threshold']
thold = xr.open_dataset(
    f'/glade/derecho/scratch/smogen/SMYLE-Extreme/thresholds/smyle{init}.{var}.thold.Rolling.new_run2.nc'
)['threshold']

In [None]:
# drift correct anomalies
smyle_anom,smyle_clim = stat.remove_drift(smyle,smyle_time,1982,2023)

In [None]:
# detrend data
smyle_anom = detrend_linear(smyle_anom.time,'Y')
# smyle_anom = smyle_anom.time

# smyle_anom = detrend_second(smyle_anom.time,'Y')

In [None]:
# Sanity check at one grid point: anomalies of the first member at the first
# lead, with that lead's threshold drawn as a horizontal line.
smyle_anom.isel({'M': 0, 'L': 0}).sel({'lat': 0.5, 'lon': -130.5}).plot()
# smyle_anom_detr.isel(M=0,L=0).sel(lat=0.5,lon=-130.5).plot()

plt.axhline(y=thold.isel({'L': 0}).sel({'lat': 0.5, 'lon': -130.5}))

In [None]:
# define extremes!
# change the '<' and '>' depending on the threshold
smyle_extreme = smyle_anom.where(smyle_anom > thold)

In [None]:
smyle_extreme.sel(Y=2019).sum('M').isel(L=8).plot()

In [None]:
binary = ~np.isnan(smyle_extreme)

binary.sum(('M','lat','lon')).plot()
plt.show()

In [None]:
# save out the binary file!
binary = ~np.isnan(smyle_extreme)

# binary.to_dataset(name='binary').to_netcdf('/glade/work/smogen/SMYLE-extremes/thresholds/' + var +  '.monthly.' + depth + '.' + init + '.binary.Rolling.' + str(level) + '.nc')
binary.to_dataset(name='binary').to_netcdf('/glade/work/smogen/SMYLE-extremes/thresholds/' + var +  '.monthly.' + depth + '.binary.Rolling.live11.update.detrend.new_run.nc')

## Observations

In [None]:
# Observational variable to read (original note: temperature, ...).
var = 'omega_ar'
ds = xr.open_dataset(
    '/glade/work/smogen/SMYLE-extremes/OceanSODA-ETHZ_GRaCER_v2021a_1982-2020.nc'
)[var]

In [None]:
# remove climatology
ds = ds.groupby('time.month') - ds.groupby('time.month').mean()

# remove trend - select level of polynomial based on variable
# ds = detrend_linear(ds,'time')
ds = detrend_second(ds,'time')

In [None]:
level= 0.1

In [None]:
%%time

ds_thold = []

# 1st month (January, which is three month average of DJF)
first_month = ds[(ds.time.dt.month >= 12) | (ds.time.dt.month <= 2)].quantile(level,dim='time',skipna=True)
first_month = first_month.expand_dims('month_arr')
ds_thold.append(first_month)

# 2nd through 11th month (February to November)
for i in range(2,12):
    tst = ds[(ds.time.dt.month >= i) & (ds.time.dt.month <= i)].quantile(level,dim='time',skipna=True)
    tst = tst.expand_dims('month_arr')
    ds_thold.append(tst)
    
# 12th month (December, which is three month average of NDJ
last_month = ds[(ds.time.dt.month >= 11) | (ds.time.dt.month <= 1)].quantile(level,dim='time',skipna=True)
last_month = last_month.expand_dims('month_arr')
ds_thold.append(last_month)

In [None]:
threshold = xr.concat(ds_thold,dim='month_arr')

In [None]:
threshold['month_arr'] = threshold.month_arr + 1

In [None]:
# define the quantile, select extreme values, convert to binary
ds_extreme = ds.where(ds.groupby('time.month') < threshold.rename({'month_arr':'month'}))
ds_extreme = ~np.isnan(ds_extreme)

In [None]:
ds_extreme = ds_extreme.drop('month').to_dataset(name='threshold')

In [None]:
ds_extreme.threshold.sum(('lat','lon')).plot()

In [None]:
ds_extreme.threshold.sel(time='1997-12').plot()

In [None]:
# Save out the observational mask; mode='w' is passed explicitly.
ds_extreme.to_netcdf(
    path=f'/glade/work/smogen/SMYLE-extremes/{var}.obs.rolling.thold.Rolling2.nc',
    mode='w',
)