In [None]:
import os
from copy import deepcopy

import xarray as xr
import numpy as np
import pandas as pd
import scipy.stats
import scipy.ndimage
from tqdm import tqdm_notebook as tqdm
import scipy.optimize


import distributed

from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
import matplotlib.pyplot as plt
import matplotlib.colors as mplc
import cartopy.crs as ccrs
import matplotlib.dates as mdates
import matplotlib.lines as mpll

In [None]:
# Seeded random generator (legacy RandomState API); not used by the cells shown here.
rnd = np.random.RandomState(seed=42)

In [None]:
# Apply the (presumably user-installed) 'paper' and 'egu_journals' style sheets in
# order, then switch off LaTeX text rendering.
for style_name in ('paper', 'egu_journals'):
    plt.style.use(style_name)
plt.rcParams['text.usetex'] = False

# Map projections: the rotated-pole frame of the rlon/rlat fields (T2m, RH2m)
# and plain geographic longitude/latitude.
rotated_pole = ccrs.RotatedPole(pole_latitude=41.5, pole_longitude=-171.0)
plate_carree = ccrs.PlateCarree()

In [None]:
# Spread correction factor sqrt(41 / 40); presumably sqrt((N + 1) / N) for an
# N = 40 member ensemble. Not used by the cells shown here.
std_corr = np.sqrt(41 / 40)

# Physical constants (SI units).
DENSITY = 1000     # presumably density of liquid water, kg m^-3 (unused below)
RL = 287.05        # specific gas constant of dry air, J kg^-1 K^-1
RW = 461.45        # specific gas constant of water vapour, J kg^-1 K^-1
LAM_VAP = 2.501e6  # latent heat of vaporisation, J kg^-1
CP = 1004.64       # specific heat of dry air at constant pressure, J kg^-1 K^-1 (unused below)

In [None]:
# Local Dask cluster (worker scratch directory /tmp) and a client attached to it;
# the bare `client` on the last line displays the client's summary in the notebook.
cluster = distributed.LocalCluster(local_directory='/tmp')
client = distributed.Client(address=cluster)
client

# Load data

In [None]:
# Root directory of the EnKF soil-DA run. Set the environment variable
# DA_ENKF_SOIL_BASE_PATH to point the notebook at another copy of the data;
# without it, the original absolute path is used unchanged.
base_path = os.environ.get(
    'DA_ENKF_SOIL_BASE_PATH',
    '/work/um0203/u300636/for2131/runs/da_enkf_for_soil',
)

## Soil water (H2O) and surface fluxes

In [None]:
# Soil water H2OSOI at soil-level index 4, chunked positionally and cropped by
# 30 points on each side of the last two axes (presumably lat/lon).
da_h2o_path = os.path.join(base_path, '020', 'h2o_cleaned.nc')
h2o_level = (
    xr.open_dataset(da_h2o_path)['H2OSOI']
    .squeeze(drop=True)
    .isel(levsoi=4)
    .chunk((40, 1, 302, 267))
)
da_h2o = h2o_level[..., 30:-30, 30:-30]

In [None]:
# Surface fluxes of the same run, chunked by dimension name and cropped by
# 30 grid points on every lat/lon edge (same crop as da_h2o).
da_fluxes_path = os.path.join(base_path, '020', 'fluxes.nc')
flux_chunks = dict(ensemble=40, time=1, lat=302, lon=267)
da_fluxes = (
    xr.open_dataset(da_fluxes_path)
    .squeeze(drop=True)
    .chunk(flux_chunks)
    .isel(lat=slice(30, -30), lon=slice(30, -30))
)

## 2 m temperature (T2m) and relative humidity (RH2m)

In [None]:
# 2 m temperature T_2M on its native (rotated-pole) grid, chunked positionally.
da_t2m_path = os.path.join(base_path, '020', 't2m_cleaned.nc')
da_t2m = (
    xr.open_dataset(da_t2m_path)['T_2M']
    .squeeze(drop=True)
    .chunk((40, 1, 109, 99))
)

In [None]:
# 2 m relative humidity RELHUM_2M on the same grid/chunking as T_2M.
da_rh2m_path = os.path.join(base_path, '020', 'rh2m.nc')
da_rh2m = (
    xr.open_dataset(da_rh2m_path)['RELHUM_2M']
    .squeeze(drop=True)
    .chunk((40, 1, 109, 99))
)

## Additional data

In [None]:
# Time-invariant CLM fields (first time slice). WATSAT at soil level 4 is cropped
# exactly like da_h2o so that both share one grid.
const_path = '/work/um0203/u300636/for2131/runs/utilities/clm_aux.nc'
const_data = xr.open_dataset(const_path).isel(time=0)
watsat_level = const_data['WATSAT'].isel(levsoi=4)
sat_point = watsat_level[..., 30:-30, 30:-30]

## Pre-process data

In [None]:
# Keep only the first occurrence of every timestamp, then drop the first remaining
# step along axis 1 (presumably time).
first_time_occurrence = ~da_t2m.indexes['time'].duplicated(keep='first')
da_bg_t2m = da_t2m.isel(time=first_time_occurrence)[:, 1:]

## Get CLM grid coordinates in the rotated-pole frame

In [None]:
# Stack the CLM (lat, lon) grid into one 'grid' dimension and express every grid
# point in the rotated-pole frame (columns 0/1 of the result are rlon/rlat).
prep_clm = da_h2o.stack(grid=['lat', 'lon'])
clm_coords_rotated = rotated_pole.transform_points(plate_carree, prep_clm.lon.values, prep_clm.lat.values)
clm_rlon_values = clm_coords_rotated[:, 0]
clm_rlat_values = clm_coords_rotated[:, 1]
# Not used by the cells shown here.
clm_rot_index = pd.MultiIndex.from_arrays([clm_rlon_values, clm_rlat_values], names=['rlon', 'rlat'])
# Pointwise indexers along 'grid' for interpolating rlon/rlat fields onto CLM points.
clm_rlon = xr.DataArray(clm_rlon_values, coords={'grid': prep_clm.grid}, dims=['grid'])
clm_rlat = xr.DataArray(clm_rlat_values, coords={'grid': prep_clm.grid}, dims=['grid'])

In [None]:
def _interp_to_clm_grid(field):
    """Interpolate an rlon/rlat field onto the CLM points and unstack to (lat, lon).

    Uses pointwise linear interpolation with the `clm_rlon`/`clm_rlat` indexers
    (dimension 'grid'), drops the leftover rlon/rlat coordinates and unstacks
    'grid' back into the CLM lat/lon dimensions.
    """
    on_clm_points = field.interp(rlon=clm_rlon, rlat=clm_rlat, method='linear')
    return on_clm_points.drop(['rlon', 'rlat']).unstack('grid')


# RH2m gets the same first-occurrence de-duplication along time as T2m: otherwise
# the later `.sel(time=da_h2o.time)` fails on a non-unique time index. When the
# times are already unique, this is a no-op.
da_rh2m_unique = da_rh2m.isel(time=~da_rh2m.indexes['time'].duplicated())

# NOTE: this cell overwrites da_bg_t2m, so it must run exactly once per kernel.
da_bg_t2m = _interp_to_clm_grid(da_bg_t2m)
da_bg_rh2m = _interp_to_clm_grid(da_rh2m_unique)

In [None]:
# Sensible heat flux (CLM output variable FSH).
da_sensible = da_fluxes.FSH

## Diurnal cycle of correlations to T2m and heat fluxes (Fig. 10)

In [None]:
def get_cov(x, y, dim='ensemble', ddof=1):
    """Covariance of ``x`` and ``y`` along ``dim``.

    Parameters
    ----------
    x, y : xarray objects
        Fields that share the dimension ``dim``.
    dim : str
        Dimension to reduce over (default ``'ensemble'``).
    ddof : int
        Delta degrees of freedom. The summed product of anomalies is divided by
        ``x.count(dim) - ddof``, i.e. by the number of non-missing values of ``x``.

    Returns
    -------
    Covariance with ``dim`` reduced away.
    """
    anomalies_x = x - x.mean(dim=dim)
    anomalies_y = y - y.mean(dim=dim)
    summed_products = xr.dot(anomalies_x, anomalies_y, dims=dim)
    return summed_products / (x.count(dim) - ddof)

In [None]:
def get_corr(x, y, dim='ensemble'):
    """Pearson correlation of ``x`` and ``y`` along ``dim``.

    The covariance and both standard deviations use ``ddof=0``, so they share
    the same normalisation.
    """
    std_x = x.std(dim, ddof=0)
    std_y = y.std(dim, ddof=0)
    return get_cov(x, y, dim=dim, ddof=0) / std_x / std_y

In [None]:
def estimate_water_pressure(temp):
    """Saturation water-vapour pressure from temperature (Magnus form).

    Parameters
    ----------
    temp : float or array-like
        Temperature in Kelvin (converted to degrees Celsius internally).

    Returns
    -------
    Saturation vapour pressure in Pa: ``611.2 * exp(17.62 t / (t + 243.12))``
    with ``t`` in degrees Celsius.
    """
    celsius = temp - 273.15
    return 611.2 * np.exp(17.62 * celsius / (celsius + 243.12))

In [None]:
# Along axis 1 (presumably time): keep every 4th step starting at index 3, then
# discard the first 6 of those. The background fields are then matched to the
# remaining H2O timestamps.
# NOTE: this cell overwrites its inputs, so running it twice sub-samples twice.
every_fourth_step = slice(3, None, 4)
skip_first_six = slice(6, None)
da_sensible = da_sensible[:, every_fourth_step][:, skip_first_six]
da_h2o = da_h2o[:, every_fourth_step][:, skip_first_six]
da_bg_t2m = da_bg_t2m.sel(time=da_h2o.time)
da_bg_rh2m = da_bg_rh2m.sel(time=da_h2o.time)

In [None]:
# Soil water relative to the saturation value from WATSAT (presumably a soil
# moisture index); not used by the cells shown here.
da_smi = da_h2o / sat_point
# Actual vapour pressure in Pa: saturation pressure at T2m times relative
# humidity converted from percent to a fraction.
da_e_sat = estimate_water_pressure(da_bg_t2m)
da_e_press = da_e_sat * da_bg_rh2m / 100

In [None]:
# Specific humidity at 2 m from vapour pressure e, assuming a constant surface
# pressure p = 101325 Pa:
#     q = eps * e / (p - (1 - eps) * e),   eps = R_d / R_v
# The former expression eps * e / (p - e) is the water-vapour *mixing ratio*
# r = q / (1 - q), which overstates the "QV2m" that is plotted and correlated below.
SURFACE_PRESSURE = 101325  # Pa, constant reference pressure (approximation)
EPSILON = RL / RW
da_qv = EPSILON * da_e_press / (SURFACE_PRESSURE - (1 - EPSILON) * da_e_press)

In [None]:
def _hourly_cycle(field):
    """Hour-of-day mean of ``field`` over time, lat, lon and ensemble, loaded into memory.

    The result is rolled by one position along 'hour' (coordinates included), so
    the last hour of the day comes first.
    """
    hourly_mean = field.groupby('time.hour').mean(['time', 'lat', 'lon', 'ensemble'])
    return hourly_mean.roll(hour=1, roll_coords=True).load()


# Latent heat content lambda_vap * q and the sensible heat flux per hour of day.
cycle_qv = LAM_VAP * _hourly_cycle(da_qv)
cycle_sensible = _hourly_cycle(da_sensible)

In [None]:
def _hourly_mean_corr(field, reference):
    """Hour-of-day mean of the ensemble correlation between ``field`` and ``reference``.

    Correlations are averaged in Fisher-z space (arctanh -> mean over time, lat,
    lon -> tanh). The result is rolled by one position along 'hour' (coordinates
    included) and loaded into memory.
    """
    fisher_z = np.arctanh(get_corr(field, reference))
    hourly_z = fisher_z.groupby('time.hour').mean(['time', 'lat', 'lon'])
    return np.tanh(hourly_z).roll(hour=1, roll_coords=True).load()


corr_sens = _hourly_mean_corr(da_sensible, da_bg_t2m)
corr_qv = _hourly_mean_corr(da_qv, da_bg_t2m)
corr_h2o = _hourly_mean_corr(da_h2o, da_bg_t2m)

In [None]:
def align_yaxis(ax1, v1, ax2, v2):
    """Shift the y-limits of ``ax2`` so its value ``v2`` lines up with ``v1`` on ``ax1``.

    Both limits of ``ax2`` move by the same data offset, so its y-span is kept.
    The offset is derived from a display-space (pixel) difference, which is
    presumably only exact for linear y-scales.
    """
    target_px = ax1.transData.transform((0, v1))[1]
    current_px = ax2.transData.transform((0, v2))[1]
    display_to_data = ax2.transData.inverted()
    offset = display_to_data.transform((0, 0)) - display_to_data.transform((0, target_px - current_px))
    shift = offset[1]
    lower, upper = ax2.get_ylim()
    ax2.set_ylim(lower + shift, upper + shift)

In [None]:
# Figure 10
#   a) hour-of-day mean correlation to T2m of sensible heat flux, QV2m and soil water;
#   b) hour-of-day change of lambda_vap * QV2m (left axis) together with the
#      sensible heat flux on a second, outer-left axis.


def _plot_hours(diurnal):
    """Return x positions for a DataArray whose ``hour`` coordinate was rolled.

    Hour 23 maps to -1, so the rolled cycle (23, 0, 1, ..., 22) is drawn as one
    continuous line from -1 to 22. Other hours keep their value. Taking the
    positions from the coordinate keeps x and y aligned even if hours are missing.
    """
    return (diurnal['hour'].values + 1) % 24 - 1


figure_dir = os.path.join('..', 'figures')
os.makedirs(figure_dir, exist_ok=True)  # savefig fails if the directory is missing

figsize = list(plt.rcParams['figure.figsize'])
figsize[0] *= 0.48  # narrower than the style's default figure width
fig, ax = plt.subplots(nrows=2, figsize=figsize)

ax[0].axhline(c='black', lw=0.5)
ax[1].axhline(c='black', lw=0.5)

# Panel a): correlations to T2m.
ax[0].plot(_plot_hours(corr_sens), corr_sens, lw=0.8, c='salmon', label='Sensible')
ax[0].plot(_plot_hours(corr_qv), corr_qv, lw=0.8, c='C0', label='QV2m')
h2o_plt = ax[0].plot(_plot_hours(corr_h2o), corr_h2o, lw=0.8, c='black', label=r'H2O$_{soil}$')

# Panel b): latent heat content relative to the first value of the rolled cycle ...
lam_plt = ax[1].plot(_plot_hours(cycle_qv), cycle_qv - cycle_qv[0], lw=0.8, c='C0',
                     label=r'$\lambda_{vap}$ * QV2m')

# ... and the sensible heat flux on a twin y-axis.
flux_ax = ax[1].twinx()
sen_plt = flux_ax.plot(_plot_hours(cycle_sensible), cycle_sensible, lw=0.8, c='salmon', label='Sensible')

ax[0].text(x=0.02, y=0.98, s='a)', transform=ax[0].transAxes, va='top', ha='left')
ax[1].text(x=0.02, y=0.98, s='b)', transform=ax[1].transAxes, va='top', ha='left')

# One shared legend above panel a); line colours are the same in both panels.
legend = ax[0].legend(handles=[sen_plt[0], lam_plt[0], h2o_plt[0]], loc=8, fancybox=False, edgecolor='black',
                      ncol=3, bbox_to_anchor=(0.42, 1.0))
legend.get_frame().set_linewidth(0.8)

ax[1].set_ylabel(r'$\Delta$ Heat content (J/kg)')
# Move the flux axis to the left, outside the heat-content axis.
flux_ax.spines["left"].set_position(("axes", -0.3))
flux_ax.yaxis.tick_left()
flux_ax.yaxis.set_label_position('left')
flux_ax.set_ylabel('Heat flux (W/m$^2$)')
flux_ax.set_ylim(-100, 200)

# Put zero flux at the same height as zero change in heat content.
align_yaxis(ax[1], 0, flux_ax, 0)

ax[0].set_ylabel('Correlation to T2m')
ax[0].set_ylim(-1, 1)
ax[0].set_yticks(np.arange(-0.75, 1, 0.25))
ax[0].spines['bottom'].set_visible(True)
ax[0].set_xticks([])
ax[0].set_xlim(-2, 23)

ax[1].set_xlim(-2, 23)
ax[1].set_xlabel('Time of day (UTC)')
ax[1].set_xticks(np.arange(0, 24, 6))

fig.align_ylabels(ax)
fig.subplots_adjust(hspace=0.005, wspace=0)

fig.savefig(os.path.join(figure_dir, 'fig_10_cycle_corr.png'), dpi=300)