In [None]:
import os
from copy import deepcopy

import xarray as xr
import numpy as np
import pandas as pd

from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
import matplotlib.pyplot as plt
import matplotlib.colors as mplc
import cartopy.crs as ccrs
import matplotlib.dates as mdates
import matplotlib.lines as mpll
import matplotlib.gridspec as mplgs
from pandas.plotting import register_matplotlib_converters

In [None]:
# Seeded legacy NumPy generator so any random draws are reproducible.
# NOTE(review): `rnd` is not used in the cells shown here.
rnd = np.random.RandomState(seed=42)

In [None]:
# Apply the style sheets in order (later entries override earlier ones).
# NOTE(review): 'paper' and 'egu_journals' are presumably custom style sheets
# that are not shipped with matplotlib; they must exist in the style library.
plt.style.use(['paper', 'egu_journals'])
# Let matplotlib handle pandas Timestamp / datetime64 values on plot axes.
register_matplotlib_converters()
# Cartopy projections: the rotated-pole model grid and plain lat/lon.
rotated_pole = ccrs.RotatedPole(pole_latitude=41.5, pole_longitude=-171.0)
plate_carree = ccrs.PlateCarree()

# Load data

In [None]:
# Root directory of the experiment runs (the path name suggests the EnKF soil
# data-assimilation experiment). Set DA_ENKF_SOIL_BASE to run the notebook
# against a copy of the data elsewhere; the default is the original location.
base_path = os.environ.get(
    'DA_ENKF_SOIL_BASE',
    '/work/um0203/u300636/for2131/runs/da_enkf_for_soil',
)

In [None]:
h2o_path = os.path.join(base_path, '016', 'h2o_cleaned.nc')
# 'H2OSOI' from run 016: drop length-1 dims (and their coordinates), then keep
# only soil level index 4.
ds_h2o = (
    xr.open_dataset(h2o_path)['H2OSOI']
    .squeeze(drop=True)
    .isel(levsoi=4)
)

In [None]:
t_path = os.path.join(base_path, '016', 'temp.nc')
# 'T' from run 016 with length-1 dims squeezed away; keep the last entry
# along `level`.
# NOTE(review): used downstream as "t2m"; presumably the last level is the
# one nearest the surface — confirm.
ds_t = (
    xr.open_dataset(t_path)['T']
    .squeeze(drop=True)
    .isel(level=-1)
)

In [None]:
precip_path = os.path.join(base_path, '016', 'precip_cleaned.nc')
# 'TOT_PREC' from run 016 with length-1 dims (and their coordinates) dropped.
ds_precip = (
    xr.open_dataset(precip_path)['TOT_PREC']
    .squeeze(drop=True)
)

In [None]:
# Static fields are taken from the first time step of the ens001 history file
# of the initial/orig run. The path is built from `base_path` so the
# experiment root is defined in one place (same resulting path as before).
const_path = os.path.join(
    base_path, 'initial', 'orig', 'ens001',
    'clmoas.clm2.h0.2015-07-30-00900.nc',
)
const_data = xr.open_dataset(const_path).isel(time=0)
# WATSAT at soil level index 4, matching the `levsoi=4` selection of H2OSOI;
# used as the divisor that normalises H2OSOI in the cleaning step.
sat_point = const_data['WATSAT'].isel(levsoi=4)

## Data cleaning

In [None]:
# Keep time steps that fall on a full hour; the first step is always kept.
h2o_time_ind = ds_h2o.indexes['time'].minute == 0
h2o_time_ind[0] = True
ds_h2o_sel = ds_h2o.sel(time=h2o_time_ind)
# Normalise by WATSAT and trim a 30-cell frame from each lateral edge.
# Slicing by dimension name does not depend on the dimension order produced
# by the broadcast division (positional [:, 30:-30, 30:-30] did).
# NOTE(review): land fields are trimmed by 30 cells but atmospheric fields by
# 10 — presumably because they live on different grids; confirm.
ds_smi_sel = (ds_h2o_sel / sat_point).isel(
    lat=slice(30, -30), lon=slice(30, -30)
)

In [None]:
# Keep time steps that fall on a full hour; the first step is always kept.
t2m_time_ind = ds_t.indexes['time'].minute == 0
t2m_time_ind[0] = True
ds_t2m_sel = ds_t.sel(time=t2m_time_ind)
# De-duplicate time stamps, keeping the first occurrence of each.
ds_t2m_sel = ds_t2m_sel.sel(time=~ds_t2m_sel.indexes['time'].duplicated())
# Trim a 10-cell frame from each lateral edge, selecting by dimension name
# rather than relying on the (time, rlat, rlon) positional order.
ds_t2m_sel = ds_t2m_sel.isel(rlat=slice(10, -10), rlon=slice(10, -10))

In [None]:
# Number of leading time steps that are differenced below.
# NOTE(review): presumably TOT_PREC is accumulated over these first steps and
# already per-output-step afterwards — confirm against the model output setup.
n_accum_steps = 36
# Step-to-step increments for the first n_accum_steps; `diff` labels each
# increment with the later time stamp, so the very first step is dropped.
# The remaining steps are appended unchanged.
ds_precip_sel = ds_precip.isel(time=slice(None, n_accum_steps)).diff('time')
ds_precip_sel = xr.concat(
    [ds_precip_sel, ds_precip.isel(time=slice(n_accum_steps, None))],
    dim='time',
)
# Rain / no-rain mask (> 0), with a 10-cell frame trimmed from each lateral
# edge; named dims avoid relying on the array's dimension order.
ds_precip_sel = (ds_precip_sel > 0).isel(rlat=slice(10, -10), rlon=slice(10, -10))

# Plot domain-aggregated time series

In [None]:
# Spatial aggregates per time step.
# NOTE(review): despite their names, smi_mean and t2m_mean are spatial
# *medians*; precip_vals is the mean of a boolean mask, i.e. the fraction of
# grid cells with precipitation > 0. Confirm the median is intended.
smi_mean = ds_smi_sel.median(dim=['lat', 'lon'])
t2m_mean = ds_t2m_sel.median(dim=['rlat', 'rlon'])
precip_vals = ds_precip_sel.mean(dim=['rlat', 'rlon'])

In [None]:
# x-axis ticks every second day starting 30 Jul 2015; the end bound is
# 8 Aug, so the last tick falls on 7 Aug.
time_ticks = pd.date_range(start='2015-07-30', end='2015-08-08', freq='2D')

In [None]:
# Two-panel overview figure:
#   (a) domain-median temperature (left axis) and the fraction of grid cells
#       with precipitation (filled area on an offset left axis);
#   (b) domain-median H2OSOI/WATSAT ratio, labelled 'SAT'.
# The grey band marks the spin-up period.
XLIM = (pd.to_datetime('2015-07-29 18:00'), pd.to_datetime('2015-08-08 00:00'))
SPINUP_END = pd.to_datetime('2015-07-31 12:00')
FIG_PATH = '../figures/fig_03_overview.png'

# Scale the active style's default size. The comprehension already builds a
# new list, so rcParams itself is not modified (no deepcopy needed).
figsize = [f * 0.48 for f in plt.rcParams['figure.figsize']]
grid_spec = mplgs.GridSpec(nrows=10, ncols=1)
fig = plt.figure(figsize=figsize)

# --- Panel (a): temperature and precipitation fraction ---
ax_t2m = fig.add_subplot(grid_spec[:7, :])
ax_precip = ax_t2m.twinx()
# A twin axes is drawn after (on top of) its host, and artist zorder only
# orders artists *within* one axes, so zorder=-999 alone never placed the
# precipitation fill behind the temperature line. Raise the host axes above
# the twin and move the opaque background patch onto the now-lower twin.
ax_t2m.set_zorder(ax_precip.get_zorder() + 1)
ax_t2m.patch.set_visible(False)
ax_precip.patch.set_visible(True)

# Spin-up shading, kept underneath the precipitation fill.
ax_precip.axvspan(XLIM[0], SPINUP_END, color='0.8', alpha=0.35, lw=0, zorder=-1000)
ax_precip.fill_between(x=precip_vals.time.values, y1=precip_vals, color='deepskyblue', zorder=-999, alpha=0.5, lw=0)
ax_precip.set_ylim(0, 1)
ax_precip.set_yticks([0.1, 0.3, 0.5, 0.7, 0.9])
ax_precip.set_ylabel('Precipitation')
# Put the precipitation axis on an outward-offset left spine so both y axes
# sit on the left side of the panel.
ax_precip.spines["left"].set_position(("axes", -0.25))
ax_precip.yaxis.tick_left()
ax_precip.yaxis.set_label_position('left')

ax_t2m.plot(t2m_mean.time, t2m_mean, c='firebrick', lw=1)
ax_t2m.text(x=0.02, y=0.98, s='(a)', transform=ax_t2m.transAxes, va='top', ha='left')
ax_t2m.set_xticks([])
ax_t2m.set_xlim(*XLIM)
ax_t2m.set_yticks([285, 290, 295, 300, 305])
ax_t2m.set_ylim(282, 308)
ax_t2m.set_ylabel('Temp (K)')

# --- Panel (b): saturation ratio ---
ax_land = fig.add_subplot(grid_spec[7:, :])

ax_land.axvspan(XLIM[0], SPINUP_END, color='0.8', alpha=0.35, lw=0)
# Label centred in the visible part of the spin-up band (2015-07-30 15:00).
ax_land.text(x=XLIM[0] + (SPINUP_END - XLIM[0]) / 2, y=0.03, s='spin-up', ha='center', va='bottom')

ax_land.plot(smi_mean.time, smi_mean, c='sienna', lw=1)
ax_land.text(x=0.02, y=0.95, s='(b)', transform=ax_land.transAxes, va='top', ha='left')
ax_land.set_ylim(-0.05, 1.05)
ax_land.set_yticks([0.1, 0.5, 0.9])
ax_land.set_ylabel('SAT')
ax_land.set_xlim(*XLIM)
ax_land.set_xlabel('Date (2015 UTC)')
ax_land.set_xticks(time_ticks)
ax_land.set_xticklabels(time_ticks.strftime('%m-%d'))
fig.align_ylabels([ax_t2m, ax_land])
fig.subplots_adjust(wspace=0.1, hspace=0.02)
# savefig does not create missing directories.
os.makedirs(os.path.dirname(FIG_PATH), exist_ok=True)
fig.savefig(FIG_PATH, dpi=300)
plt.show()