# Filtered regression models, 1979-2014

# I think we'll use 8-year high- and low-pass filters. That should put ENSO in the high-frequency band and everything with longer periods in the low-frequency band.

# Should we be filtering only temperature and not sea ice? I don't think so.

### NEXT: REGRESS MONTHS 1-9 ON SEPT SIE. ALSO TRY DIFFERENT MODELS. ALSO MAKE MAPS.

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
import scipy.stats as stats 
import os as os
import warnings
import matplotlib.pyplot as plt

from IPython.display import clear_output
from scipy import signal

In [None]:
import preprocess_utils as pputils
import analysis_utils as autils 

In [None]:
# Calendar month numbers, 1 (January) through 12 (December).
all_months = list(range(1, 13))
# Plot labels: one per calendar month, followed by a 13th 'Annual' label.
month_names = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'June',
               'July', 'Aug', 'Sept', 'Oct', 'Nov', 'Dec'] + ['Annual']

### Process CMIP6 models: 

This cell creates the list of models to loop over. Some models have caused problems, so for now I've simply removed them from the list :) Eventually we'll need to deal with those issues and get them working...

In [None]:
path = '/glade/work/mkbren/cmip6-data/cmip6-preprocessed/historical/'

# Models left out of the analysis for now because they caused problems.
# Excluding via a set means a model whose files are absent is simply ignored,
# rather than raising ValueError as list.remove() would.
# Models excluded at other times: ACCESS-CM2, AWI-ESM-1-1-LR, KIOST-ESM,
# NESM3, ACCESS-ESM1-5.
excluded_models = {'CIESM', 'FGOALS-f3-L'}

preprocessed_files = sorted(os.listdir(path))
# File names start with '<model>_'; keep each model name once, sorted.
preprocessed_models = sorted({fname.split('_')[0] for fname in preprocessed_files}
                             - excluded_models)
print(len(preprocessed_models))
print(preprocessed_models)

In [None]:
# I'm not sure where the anomalies are being calculated. In preprocessing? Need to track that down.
# tas_gm_anom looks positive in every month of every year, 1979-2014, for at least one model? Weird.

In [None]:
# Butterworth filter functions

def butter_filt(x, filt_year, fs, order_butter, ftype):
    """Zero-phase (forward-backward) Butterworth filter of a series.

    Parameters
    ----------
    x : array_like
        Series to filter; ``signal.filtfilt`` filters along its last axis.
    filt_year : float or sequence of two floats
        Cutoff period(s) in years (more generally, in the time unit that
        ``fs`` is expressed in). For ``ftype='band'`` give two periods,
        longest first, so the cutoff frequencies ``1 / filt_year`` increase.
    fs : float
        Sampling frequency in samples per year, e.g. 1 for annual data or
        12 for monthly data.
    order_butter : int
        Butterworth filter order (filtfilt applies the filter twice).
    ftype : str
        Filter type passed to ``signal.butter``: "low", "high" or "band".

    Returns
    -------
    numpy.ndarray
        Filtered series, same shape as ``x``.
    """
    # Because ``fs`` is passed to signal.butter, the critical frequency is in
    # the same units as ``fs`` (cycles per year), NOT normalized to Nyquist.
    # So a cutoff period of ``filt_year`` years is simply 1 / filt_year.
    # An earlier version halved this, which put the cutoff at a period of
    # 2 * filt_year (e.g. 16 years for an "8-year" filter). Results computed
    # with that version should be regenerated.
    fc = 1.0 / np.asarray(filt_year, dtype=float)
    b, a = signal.butter(order_butter, fc, ftype, fs=fs, output='ba')

    return signal.filtfilt(b, a, x)

def filtfilt_butter(x, filt_year, fs, order_butter, ftype, dim='time'):
    """Apply ``butter_filt`` along one dimension of an xarray object.

    Parameters
    ----------
    x : xr.DataArray
        Data to filter; every 1-D slice along ``dim`` is filtered
        independently.
    filt_year, fs, order_butter, ftype
        Passed unchanged to ``butter_filt``. For ``ftype='band'``,
        ``filt_year`` is a pair of periods.
    dim : str
        Name of the dimension to filter along (e.g. 'time' or 'year').

    Returns
    -------
    xr.DataArray
        Filtered data with the same length (and coordinate) along ``dim``;
        apply_ufunc places ``dim`` last.
    """
    return xr.apply_ufunc(
        butter_filt,
        x,
        # Filter settings go in ``kwargs`` rather than as extra positional
        # arguments. With vectorize=True, positional non-xarray arguments are
        # broadcast element-wise by np.vectorize, which would split a
        # two-period band-pass ``filt_year`` into two scalar calls.
        kwargs=dict(filt_year=filt_year, fs=fs,
                    order_butter=order_butter, ftype=ftype),
        input_core_dims=[[dim]],
        # filtfilt preserves length, so ``dim`` need not be excluded. Keeping
        # it as a regular core dim also keeps its coordinate on the output.
        output_core_dims=[[dim]],
        vectorize=True,  # loop butter_filt over all non-core dimensions
    )

def filtfilt_butter_monthly(x, filt_year, fs, order_butter, ftype, dim='time'):
    """
    Use this to filter each calendar month of a timeseries separately and combine.

    Each month's values (all Januaries, all Februaries, ...) form their own
    series with one sample per year. Each is filtered with filtfilt_butter and
    written back to its original positions.

    in:
    x = unfiltered 1-D monthly time series (xr.DataArray), e.g.: era_data['t2m_arc_mean_anom']
    dim = name of x's datetime dimension; months come from x[dim].dt.month
    others as for filtfilt_butter

    out:
    xr.DataArray on the same `dim` coordinate as x. Months that are absent
    from x are left as NaN.
    """
    times = x[dim]
    month_of_sample = times.dt.month.values
    filtered = np.full(x.shape, np.nan)

    for month in range(1, 13):
        # Locate this month's samples by their calendar month, not by a fixed
        # stride, so the result stays aligned even if the series does not
        # start in January or has missing months.
        idx = np.flatnonzero(month_of_sample == month)
        if idx.size == 0:
            continue
        filtered[idx] = filtfilt_butter(x.isel({dim: idx}),
                                        filt_year=filt_year, fs=fs, order_butter=order_butter,
                                        ftype=ftype, dim=dim).values

    return xr.DataArray(filtered, coords={dim: times}, dims=[dim])

In [None]:
preprocessed_files_pan = []

# Fill these with each model's slopes
slope_annual, r_annual, intercept_annual = {}, {}, {}
slope_high_annual, r_high_annual, intercept_high_annual = {}, {}, {}
slope_low_annual, r_low_annual, intercept_low_annual = {}, {}, {}
slope, r, intercept = {}, {}, {}
slope_high, r_high, intercept_high = {}, {}, {}
slope_low, r_low, intercept_low = {}, {}, {}

# Pan-Arctic fields file-name patterns, in order of preference: the variant
# that also contains sea-ice volume (sivol) is used when both exist.
pan_suffixes = ('_historical_siconc_sivol_tas_pan_arctic_fields.nc',
                '_historical_siconc_tas_pan_arctic_fields.nc')

# Butterworth settings shared by every filtered series: 8-year cutoff period,
# order 5. fs=1 because both the annual means and each calendar month's series
# (as filtered by filtfilt_butter_monthly) have one sample per year.
filt_kwargs = dict(filt_year=8, fs=1, order_butter=5)

# Annual regressions: (suffix on the annual anomaly names, slope, intercept, r dicts)
annual_fits = (
    ('', slope_annual, intercept_annual, r_annual),
    ('_low_8', slope_low_annual, intercept_low_annual, r_low_annual),
    ('_high_8', slope_high_annual, intercept_high_annual, r_high_annual),
)
# Monthly regressions: (x variable, y variable, slope, r, intercept dicts)
monthly_fits = (
    ('tas_gm_anom', 'sie_tot_arc_anom', slope, r, intercept),
    ('tas_gm_anom_low_8', 'sie_total_anom_low_8', slope_low, r_low, intercept_low),
    ('tas_gm_anom_high_8', 'sie_total_anom_high_8', slope_high, r_high, intercept_high),
)

for mod in preprocessed_models:
    print(mod)
    # First matching pan-Arctic file for this model; models without one are skipped.
    fname = next((mod + sfx for sfx in pan_suffixes
                  if (mod + sfx) in preprocessed_files), None)
    if fname is None:
        continue
    preprocessed_files_pan.append(fname)

    data = xr.load_dataset(os.path.join(path, fname))
    data = data.sel(time=slice("1979-01-01", "2014-12-31"))

    # Annual means and their anomalies relative to the 1979-2014 mean.
    # NOTE(review): annual SIE is built from `sie_arc_tot`, while the monthly
    # regressions use `sie_tot_arc_anom`; confirm these describe the same
    # quantity in the preprocessed files.
    data['tas_gm_annual'] = data.tas_gm.groupby('time.year').mean()
    data['tas_gm_annual_anom'] = data.tas_gm_annual - data.tas_gm_annual.mean(dim='year')

    data['sie_total_annual'] = data.sie_arc_tot.groupby('time.year').mean()
    data['sie_total_annual_anom'] = data.sie_total_annual - data.sie_total_annual.mean(dim='year')

    # Low- and high-pass filter the annual anomalies
    for var in ('tas_gm_annual_anom', 'sie_total_annual_anom'):
        for band in ('low', 'high'):
            data[f'{var}_{band}_8'] = filtfilt_butter(data[var], ftype=band, dim='year',
                                                      **filt_kwargs)

    # Low- and high-pass filter the monthly anomalies, each calendar month separately
    for src, dst in (('tas_gm_anom', 'tas_gm_anom'),
                     ('sie_tot_arc_anom', 'sie_total_anom')):
        for band in ('low', 'high'):
            data[f'{dst}_{band}_8'] = filtfilt_butter_monthly(data[src], ftype=band, dim='time',
                                                              **filt_kwargs)

    # Annual regressions of SIE anomaly on global-mean tas anomaly;
    # linregress returns (slope, intercept, rvalue, pvalue, stderr).
    for sfx, slope_d, intercept_d, r_d in annual_fits:
        slope_d[mod], intercept_d[mod], r_d[mod], _, _ = stats.linregress(
            data['tas_gm_annual_anom' + sfx], data['sie_total_annual_anom' + sfx])

    # Monthly regressions; scatter_linreg's results are unpacked as
    # [slope, r, intercept] (presumably one value per month in all_months).
    for x_var, y_var, slope_d, r_d, intercept_d in monthly_fits:
        [slope_d[mod], r_d[mod], intercept_d[mod]] = autils.scatter_linreg(
            data[x_var], data[y_var], data, all_months, mod, plotflag=False)

In [None]:
months = list(range(1, 12+1))

# Panels 0 and 1 are reserved (obs and multimodel mean); models fill the rest.
n_reserved, n_cols = 2, 5
# Keep the original 6x5 grid, but add rows when there are more models than
# fit, instead of failing with an IndexError on ax[a].
n_rows = max(6, -(-(n_reserved + len(slope)) // n_cols))  # ceiling division

fig, ax = plt.subplots(n_rows, n_cols, figsize=(16, 20 * n_rows / 6))
ax = ax.flatten()

for a, mod in enumerate(slope, start=n_reserved):
    ax[a].hlines(0, 1, 13, 'grey')
    # Monthly slopes at x = 1..12: unfiltered (black), low-pass (blue),
    # high-pass (red). Divided by 1e6 for display (confirm SIE units).
    ax[a].plot(months, np.array(slope[mod]) / 1e6, 'k')
    ax[a].plot(months, np.array(slope_low[mod]) / 1e6, 'b')
    ax[a].plot(months, np.array(slope_high[mod]) / 1e6, 'r')

    # Annual-mean slopes at x = 13, same colour scheme.
    ax[a].plot(13, np.array(slope_annual[mod]) / 1e6, '*k')
    ax[a].plot(13, np.array(slope_low_annual[mod]) / 1e6, '*b')
    ax[a].plot(13, np.array(slope_high_annual[mod]) / 1e6, '*r')

    ax[a].set_ylim([-5, 1.5])
    ax[a].set_title(mod)

In [None]:
months = list(range(1, 12+1))

# Panels 0 and 1 are reserved (obs and multimodel mean); models fill the rest.
n_reserved, n_cols = 2, 5
# Keep the original 6x5 grid, but add rows when there are more models than
# fit, instead of failing with an IndexError on ax[a].
n_rows = max(6, -(-(n_reserved + len(r)) // n_cols))  # ceiling division

fig, ax = plt.subplots(n_rows, n_cols, figsize=(16, 20 * n_rows / 6))
ax = ax.flatten()

for a, mod in enumerate(r, start=n_reserved):
    # Reference lines at R^2 = 0 and R^2 = 1
    ax[a].hlines(0, 1, 13, 'grey')
    ax[a].hlines(1, 1, 13, 'grey')
    # Monthly R^2 at x = 1..12: unfiltered (black), low-pass (blue), high-pass (red)
    ax[a].plot(months, np.array(r[mod]) ** 2, 'k')
    ax[a].plot(months, np.array(r_low[mod]) ** 2, 'b')
    ax[a].plot(months, np.array(r_high[mod]) ** 2, 'r')

    # Annual-mean R^2 at x = 13, same colour scheme
    ax[a].plot(13, np.array(r_annual[mod]) ** 2, '*k')
    ax[a].plot(13, np.array(r_low_annual[mod]) ** 2, '*b')
    ax[a].plot(13, np.array(r_high_annual[mod]) ** 2, '*r')

    ax[a].set_ylim([-0.05, 1.05])
    ax[a].set_title(mod)

# Below: the corresponding plots from observations (OBS)

In [None]:
months = list(range(1, 13))

# Zero reference line across the monthly (1..12) and annual (13) positions.
plt.hlines(0, 1, 13, 'grey')

# NOTE(review): the slopes_obs* / slope_an* names are not defined in this
# notebook as shown; presumably they are carried over from the observational
# analysis. Confirm they exist before running this cell.
# Monthly slopes: unfiltered (black), low-pass (blue), high-pass (red).
for series, fmt in ((slopes_obs, 'k'), (slopes_obs_low_8, 'b'), (slopes_obs_high_8, 'r')):
    plt.plot(months, series, fmt)

# Annual-mean slopes drawn as stars at x = 13, same colours.
for value, fmt in ((slope_an, '*k'), (slope_an_low_8, '*b'), (slope_an_high_8, '*r')):
    plt.plot(13, value, fmt)

plt.xlabel('Month')
plt.ylabel('Slope');

# Why not shoulder seasons like Katie's?

In [None]:
months = np.arange(1, 13)

# Reference lines at R^2 = 0 and R^2 = 1.
plt.hlines(0, 1, 13, 'grey')
plt.hlines(1, 1, 13, 'grey')

# NOTE(review): r_obs* / r_value_annual* are not defined in this notebook as
# shown; presumably carried over from the observational analysis. Confirm.
# Monthly R^2: unfiltered (black), low-pass (blue), high-pass (red).
for r_series, fmt in ((r_obs, 'k'), (r_obs_low_8, 'b'), (r_obs_high_8, 'r')):
    plt.plot(months, np.array(r_series) ** 2, fmt)

# Annual-mean R^2 drawn as stars at x = 13, same colours.
for r_value, fmt in ((r_value_annual, '*k'), (r_value_annual_low_8, '*b'),
                     (r_value_annual_high_8, '*r')):
    plt.plot(13, np.array(r_value) ** 2, fmt)

plt.xlabel('Month')
plt.ylabel('R $^2$');

## Save results (code not written yet)

In [None]:
#sstr = 'slopes_obs_filt_5_37_years_1979_2014.npz'
#np.savez(sstr,
#         slopes_annual_mean = slopes_annual_mean,
#         slopes_mean = slopes_mean)

# Load results
# meh = np.load(sstr, allow_pickle=True)
# meh.files
# meh['slopes_annual_mean'][()]
# meh['slopes_mm_mean'].shape

In [None]:
#sstr = 'r_obs_windows_5_37_years_1979_2014.npz'
#np.savez(sstr,
#         r_annual_mean = r_annual_mean,
#         r_mean = r_mean)

# Load results
# meh = np.load(sstr, allow_pickle=True)
# meh.files
# meh['r_annual_mean'][()]
# meh['r_mm_mean'].shape