# Plots of atmospheric heat transport

In [1]:
# Purpose: Calculate the moist static energy (MSE) transport for output from the CESM model - both to verify control climate
# and to find difference after 4xCO2

# By: Ty Janoski
# Updated: 11.15.2021

## Setup

In [2]:
# import statments
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import scipy as sp
from cftime import DatetimeNoLeap
from dask.diagnostics import ProgressBar


%matplotlib inline
%config InlineBackend.figure_format = "pdf"

In [3]:
def read_in(exp, mon, ens, var,
            base_dir='/dx02/janoski/cesm/ctrl_4xCO2_transports'):
    """
    Use xarray to open one ensemble member's netCDF output file.

    The path is built as
    ``{base_dir}/b40.1850.cam5-lens.{exp}.{mon:02d}.{ens:02d}.h1_{var}.nc``.

    Keyword arguments:
    exp -- CO2 scenario (e.g. 'ctrl' or '4xCO2')
    mon -- starting month in which CO2 is altered (zero-padded to 2 digits)
    ens -- ensemble number (zero-padded to 2 digits)
    var -- model output variable (file-name suffix after 'h1_')
    base_dir -- directory holding the files; defaults to the directory this
                function originally hard-coded

    Returns:
    The dataset from xr.open_dataset(..., chunks=None), i.e. not dask-chunked.
    """
    filein = f"{base_dir}/b40.1850.cam5-lens.{exp}.{mon:02d}.{ens:02d}.h1_{var}.nc"
    return xr.open_dataset(filein, chunks=None)

In [4]:
def preprocess(ds):
    """
    Replace a dataset's time coordinate with integer day indices 1..N.

    Used as the ``preprocess`` hook of ``xr.open_mfdataset`` so every ensemble
    member gets an identical time axis before being concatenated along 'ens'.
    N is the length of the file's 'time' dimension (the original code
    hard-coded 730).

    Keyword arguments:
    ds -- Dataset with a 'time' dimension

    Returns:
    A copy of ``ds`` with 'time' replaced; ``ds`` itself is not modified.
    """
    dsnew = ds.copy()
    # derive the axis length from the data rather than assuming 730 steps
    dsnew['time'] = np.arange(1, ds.sizes['time'] + 1)
    return dsnew

In [5]:
# read in cell area file for taking spatial averages
areacella = xr.open_dataarray('/dx01/janoski/cesm/output/gridarea.nc')
area = areacella


# create function for taking spatial averages, while weighting for latitude
def spav(ds_in, areacella=areacella, lat_bound_s=-91, lat_bound_n=91):
    """
    Area-weighted mean over 'lat' and 'lon' inside a latitude band.

    Grid cells are weighted by their area from `areacella`, normalized so the
    weights inside the band sum to one. The band is a label-based slice, so
    `lat` is assumed to be ascending.

    Keyword arguments:
    ds_in -- Dataset or DataArray with coords lat and lon to be averaged
    areacella -- Dataset or DataArray of grid-cell areas with coords lat and lon
                 (defaults to the grid-area file read above)
    lat_bound_s -- float, southern edge of the band
    lat_bound_n -- float, northern edge of the band

    Returns:
    ds_in reduced over 'lat' and 'lon'.
    """
    band = slice(lat_bound_s, lat_bound_n)
    cell_area = areacella.sel(lat=band)
    weights = cell_area / cell_area.sum(dim=['lat', 'lon'])
    return (ds_in.sel(lat=band) * weights).sum(dim=['lat', 'lon'])

## January initializations

In [6]:
# calculated as residual
# Control runs, January initializations: open every ensemble member matching the
# glob and stack them along a new 'ens' dimension. `preprocess` replaces each
# file's time coordinate with day indices so the members line up.

ctrl_Fconv_r = xr.open_mfdataset('/dx02/janoski/cesm/ctrl_4xCO2_transports/b40.1850.cam5-lens.ctrl.01.*.h1_Fconv_resi.nc',
                              preprocess=preprocess,combine='nested',concat_dim='ens')
ctrl_Wconv_r = xr.open_mfdataset('/dx02/janoski/cesm/ctrl_4xCO2_transports/b40.1850.cam5-lens.ctrl.01.*.h1_Wconv.nc',
                              preprocess=preprocess,combine='nested',concat_dim='ens')
# NOTE(review): ctrl_Sconv_r is not used by any later cell in this notebook
ctrl_Sconv_r = xr.open_mfdataset('/dx02/janoski/cesm/ctrl_4xCO2_transports/b40.1850.cam5-lens.ctrl.01.*.h1_Sconv.nc',
                              preprocess=preprocess,combine='nested',concat_dim='ens')

# calculated explicitly
# (flux through the 70°S wall, per the file name)
ctrl_Fconv_e = xr.open_mfdataset('/dx02/janoski/cesm/ctrl_4xCO2_transports/b40.1850.cam5-lens.ctrl.01.*.h1_Fwall_70S_expl.nc',
                              preprocess=preprocess,combine='nested',concat_dim='ens')

In [7]:
# Ensemble means. The residual convergence and Wconv are also area-averaged over
# the cap south of 70°S (spav from -91 up to lat_bound_n=-70, using grid-cell
# areas). The explicit wall flux is only ensemble-averaged (presumably it is
# already a single wall-integrated series — confirm).
ctrl_MSE_r = spav(ctrl_Fconv_r.mean(dim='ens'),lat_bound_n=-70)
ctrl_MSE_e = ctrl_Fconv_e.mean(dim='ens')
ctrl_Wconv = spav(ctrl_Wconv_r.mean(dim='ens'),lat_bound_n=-70)
# open_mfdataset arrays are lazy; evaluate them here with a progress bar
with ProgressBar():
    ctrl_MSE_r = ctrl_MSE_r.compute()
    ctrl_Wconv = ctrl_Wconv.compute()
    ctrl_MSE_e = ctrl_MSE_e.compute()

[########################################] | 100% Completed |  3min 58.9s
[########################################] | 100% Completed |  1min 43.8s
[########################################] | 100% Completed |  1.7s


In [8]:
# control, Jan init: residual-based estimate (AHT) vs explicit wall flux (MSE)
ctrl_MSE_r.AHT.plot()
ctrl_MSE_e.MSE.plot()

[<matplotlib.lines.Line2D at 0x7f8fb4fa44f0>]

<Figure size 432x288 with 1 Axes>

In [9]:
# control, Jan init: explicit VQ vs area-averaged Wconv
# (presumably the moisture counterparts of the comparison above — confirm)
ctrl_MSE_e.VQ.plot()
ctrl_Wconv.Wconv.plot()

[<matplotlib.lines.Line2D at 0x7f8fb4a6f400>]

<Figure size 432x288 with 1 Axes>

In [10]:
# calculated as residual
# diff_* fields, January initializations, from the vert_int_feedbacks directory,
# stacked along 'ens' (presumably 4xCO2 minus control — confirm).

diff_Fconv_r = xr.open_mfdataset('/dx02/janoski/cesm/vert_int_feedbacks/b40.1850.cam5-lens.01.*.h1_AHT_resi.nc',
                              preprocess=preprocess,combine='nested',concat_dim='ens')
diff_Wconv_r = xr.open_mfdataset('/dx02/janoski/cesm/vert_int_feedbacks/b40.1850.cam5-lens.01.*.h1_Wconv.nc',
                              preprocess=preprocess,combine='nested',concat_dim='ens')
# FIXME(review): diff_Sconv_r opens the *Wconv* files (copy of the line above),
# whereas the control cell opens h1_Sconv. The correct Sconv file name in this
# directory is not visible here; confirm and correct. Not used by later cells.
diff_Sconv_r = xr.open_mfdataset('/dx02/janoski/cesm/vert_int_feedbacks/b40.1850.cam5-lens.01.*.h1_Wconv.nc',
                              preprocess=preprocess,combine='nested',concat_dim='ens')

# calculated explicitly
diff_Fconv_e = xr.open_mfdataset('/dx02/janoski/cesm/vert_int_feedbacks/b40.1850.cam5-lens.01.*.h1_AHT_70S_expl.nc',
                              preprocess=preprocess,combine='nested',concat_dim='ens')

In [11]:
# Same reductions as for the control runs: ensemble mean, with cap averages
# south of 70°S for the residual fields; then evaluate the lazy arrays.
diff_MSE_r = spav(diff_Fconv_r.mean(dim='ens'),lat_bound_n=-70)
diff_MSE_e = diff_Fconv_e.mean(dim='ens')
diff_Wconv = spav(diff_Wconv_r.mean(dim='ens'),lat_bound_n=-70)
with ProgressBar():
    diff_MSE_r = diff_MSE_r.compute()
    diff_Wconv = diff_Wconv.compute()
    diff_MSE_e = diff_MSE_e.compute()

[########################################] | 100% Completed | 17min 40.6s
[########################################] | 100% Completed |  2min 50.4s
[########################################] | 100% Completed |  1.5s


In [19]:
# diff, Jan init: residual AHT (solid) vs explicit MSE wall flux (dotted)
diff_MSE_r.AHT.plot()
diff_MSE_e.MSE.plot(linestyle='dotted')

[<matplotlib.lines.Line2D at 0x7f8fb4fa3700>]

<Figure size 432x288 with 1 Axes>

In [20]:
# diff, Jan init: explicit VQ (solid) vs area-averaged Wconv (dotted)
diff_MSE_e.VQ.plot()
diff_Wconv.Wconv.plot(linestyle='dotted')

[<matplotlib.lines.Line2D at 0x7f8fb5169df0>]

<Figure size 432x288 with 1 Axes>

## July initializations

In [14]:
# calculated as residual
# Control runs, July initializations ('.07.' in the glob): same layout as the
# January cell, members stacked along 'ens'.

ctrl_Fconv_r = xr.open_mfdataset('/dx02/janoski/cesm/ctrl_4xCO2_transports/b40.1850.cam5-lens.ctrl.07.*.h1_Fconv_resi.nc',
                              preprocess=preprocess,combine='nested',concat_dim='ens')
ctrl_Wconv_r = xr.open_mfdataset('/dx02/janoski/cesm/ctrl_4xCO2_transports/b40.1850.cam5-lens.ctrl.07.*.h1_Wconv.nc',
                              preprocess=preprocess,combine='nested',concat_dim='ens')
# NOTE(review): ctrl_Sconv_r is not used by any later cell in this notebook
ctrl_Sconv_r = xr.open_mfdataset('/dx02/janoski/cesm/ctrl_4xCO2_transports/b40.1850.cam5-lens.ctrl.07.*.h1_Sconv.nc',
                              preprocess=preprocess,combine='nested',concat_dim='ens')

# calculated explicitly
ctrl_Fconv_e = xr.open_mfdataset('/dx02/janoski/cesm/ctrl_4xCO2_transports/b40.1850.cam5-lens.ctrl.07.*.h1_Fwall_70S_expl.nc',
                              preprocess=preprocess,combine='nested',concat_dim='ens')

In [15]:
# July-init control reductions (same as the January cell).
# NOTE: in the saved notebook this cell was interrupted (KeyboardInterrupt), so
# the July cells that follow have no recorded output.
ctrl_MSE_r = spav(ctrl_Fconv_r.mean(dim='ens'),lat_bound_n=-70)
ctrl_MSE_e = ctrl_Fconv_e.mean(dim='ens')
ctrl_Wconv = spav(ctrl_Wconv_r.mean(dim='ens'),lat_bound_n=-70)
with ProgressBar():
    ctrl_MSE_r = ctrl_MSE_r.compute()
    ctrl_Wconv = ctrl_Wconv.compute()
    ctrl_MSE_e = ctrl_MSE_e.compute()

[##                                      ] | 5% Completed | 18.1s


KeyboardInterrupt: 

In [None]:
# control, Jul init: residual AHT vs explicit MSE wall flux
ctrl_MSE_r.AHT.plot()
ctrl_MSE_e.MSE.plot()

In [None]:
# control, Jul init: explicit VQ vs area-averaged Wconv
ctrl_MSE_e.VQ.plot()
ctrl_Wconv.Wconv.plot()

In [None]:
# calculated as residual
# diff_* fields, July initializations ('.07.'), stacked along 'ens'.

diff_Fconv_r = xr.open_mfdataset('/dx02/janoski/cesm/vert_int_feedbacks/b40.1850.cam5-lens.07.*.h1_AHT_resi.nc',
                              preprocess=preprocess,combine='nested',concat_dim='ens')
diff_Wconv_r = xr.open_mfdataset('/dx02/janoski/cesm/vert_int_feedbacks/b40.1850.cam5-lens.07.*.h1_Wconv.nc',
                              preprocess=preprocess,combine='nested',concat_dim='ens')
# FIXME(review): diff_Sconv_r opens the *Wconv* files (copy of the line above),
# whereas the control cell opens h1_Sconv. The correct Sconv file name in this
# directory is not visible here; confirm and correct. Not used by later cells.
diff_Sconv_r = xr.open_mfdataset('/dx02/janoski/cesm/vert_int_feedbacks/b40.1850.cam5-lens.07.*.h1_Wconv.nc',
                              preprocess=preprocess,combine='nested',concat_dim='ens')

# calculated explicitly
diff_Fconv_e = xr.open_mfdataset('/dx02/janoski/cesm/vert_int_feedbacks/b40.1850.cam5-lens.07.*.h1_AHT_70S_expl.nc',
                              preprocess=preprocess,combine='nested',concat_dim='ens')

In [None]:
# July-init diff reductions (same as the January cell), then evaluate.
diff_MSE_r = spav(diff_Fconv_r.mean(dim='ens'),lat_bound_n=-70)
diff_MSE_e = diff_Fconv_e.mean(dim='ens')
diff_Wconv = spav(diff_Wconv_r.mean(dim='ens'),lat_bound_n=-70)
with ProgressBar():
    diff_MSE_r = diff_MSE_r.compute()
    diff_Wconv = diff_Wconv.compute()
    diff_MSE_e = diff_MSE_e.compute()

In [None]:
# diff, Jul init: residual AHT vs explicit MSE wall flux
diff_MSE_r.AHT.plot()
diff_MSE_e.MSE.plot()

In [None]:
# diff, Jul init: explicit VQ vs area-averaged Wconv
diff_MSE_e.VQ.plot()
diff_Wconv.Wconv.plot()

In [None]:
# read in all control datasets, overwrite time axes, and produce one dataset with an ensemble coordinate
# NOTE(review): `read_in` is defined twice in this notebook with different base
# directories; which files are read depends on which definition ran last.
hold = []
for e in range(1,101,1):
    ds = read_in('ctrl',1,e,'flux_70N_lon')
    
    # 0-based day index (0..729); the day-of-year cell below relies on this
    ds['time'] = np.arange(0,730,1)
    
    hold.append(ds)
    
# combine
ctrl = xr.concat(hold,dim='ensemble')
ctrl['ensemble'] = np.arange(1,101,1)

# repeat for 4xCO2 simulations

hold = []
for e in range(1,101,1):
    ds = read_in('4xCO2',1,e,'flux_70N_lon')
    
    ds['time'] = np.arange(0,730,1)
    
    hold.append(ds)
    
# combine
exp = xr.concat(hold,dim='ensemble')
exp['ensemble'] = np.arange(1,101,1)

In [None]:
# day-of-year index 0..364 from the 0-based day counter set when reading
ctrl['dayofyear'] = ctrl.time%365
# average the two years for each day of year (ensemble dimension is kept)
clim = ctrl.groupby(ctrl.dayofyear).mean(dim='time')
# NOTE(review): replaces the grid-area DataArray `area` from the setup cell with a
# scalar. Presumably this is the area (m^2) used to turn the 70°N flux into W/m^2 —
# TODO confirm its derivation. `spav` is unaffected because its default is bound
# to `areacella`.
area = 0.15e14

In [None]:
# Ensemble-mean climatological MSE transport at 70°N as a longitude vs day-of-year map
fig,ax = plt.subplots()

# ticks at 0-based day-of-year indices of the first of Jan/Apr/Jul/Oct (365-day year)
yvals = np.array([0,90,181,273])
ylabs = (['Jan 1','Apr 1','Jul 1','Oct 1'])

plot = (clim.MSE/area).mean(dim='ensemble').plot(x="lon")
ax.set_yticks(yvals)
ax.set_yticklabels(ylabs)
plt.title('Climatological MSE Transport (W/m$^2$) at 70°N')


plt.tight_layout()
plt.show()

In [None]:
# take difference as 4xCO2 - piControl
# (both share the 0..729 day axis and 1..100 ensemble labels assigned above)

diff = exp - ctrl

In [None]:
# Ensemble-mean change in MSE transport (top) and in VT + VZ (bottom), longitude vs day
fig,ax = plt.subplots(nrows=2,sharex=True,sharey=True,figsize=(5,6))

# first-of-season day indices repeated for the second year (days 365..729)
yvals = np.array([0,90,181,273])
yvals = np.concatenate([yvals,yvals+365])
ylabs = np.tile(['Jan 1','Apr 1','Jul 1','Oct 1'],2)

(diff.MSE/area).mean(dim='ensemble').plot(x="lon",ax=ax[0],cmap='RdBu_r',vmin=-65,vmax=65)
ax[0].set_yticks(yvals)
ax[0].set_yticklabels(ylabs)
ax[0].set_title('ΔMSE (W/m$^2$)')

# NOTE(review): DSE is taken as VT + VZ here, while the component plots later
# divide VZ by 9.81 — confirm which scaling of VZ is intended.
((diff.VT + diff.VZ)/area).mean(dim='ensemble').plot(x="lon",ax=ax[1],cmap='RdBu_r',vmin=-65,vmax=65)
ax[1].set_title('ΔDSE (W/m$^2$)')
ax[1].set_xticks(np.arange(0,351,50))


plt.tight_layout()
plt.show()

In [None]:
# Ensemble-mean change in VQ transport at 70°N, longitude vs day (two years)
fig,ax = plt.subplots()

yvals = np.array([0,90,181,273])
yvals = np.concatenate([yvals,yvals+365])
ylabs = np.tile(['Jan 1','Apr 1','Jul 1','Oct 1'],2)

plot = (diff.VQ/area).mean(dim='ensemble').plot(x="lon")
ax.set_yticks(yvals)
ax.set_yticklabels(ylabs)
plt.title('ΔVQ (W/m$^2$)')


plt.tight_layout()
plt.show()

In [None]:
# Ensemble mean of the change, with the day index replaced by a daily noleap cftime
# axis (0001-01-01 .. 0002-12-31, i.e. 730 days) so it can be resampled by month.
ens_mean = diff.mean(dim='ensemble')
ens_mean['time'] = xr.cftime_range(start="0001-01-01 12:00:00",end="0002-12-31 12:00:00",freq='D',calendar='noleap')

In [None]:
# monthly means of the ensemble-mean change (24 months)
ens_mean_mon = ens_mean.resample(time='1M').mean(dim='time')

In [None]:
# Monthly-mean change in MSE (top) and VT + VZ (bottom) transport, longitude vs month
fig,ax = plt.subplots(nrows=2,sharex=True,sharey=True,figsize=(5,6))

# NOTE(review): ticks are integer positions 0, 3, ..., 21 while the y coordinate is
# the cftime monthly time axis; confirm these land on the intended months.
yvals = np.arange(0,24,3)
ylabs = np.tile(['Jan','Apr','Jul','Oct'],2)

(ens_mean_mon.MSE/area).plot(x="lon",ax=ax[0],cmap='RdBu_r',vmin=-65,vmax=65)
ax[0].set_yticks(yvals)
ax[0].set_yticklabels(ylabs)
ax[0].set_title('ΔMSE (W/m$^2$)')

((ens_mean_mon.VT + ens_mean_mon.VZ)/area).plot(x="lon",ax=ax[1],cmap='RdBu_r',vmin=-65,vmax=65)
ax[1].set_title('ΔDSE (W/m$^2$)')
ax[1].set_xticks(np.arange(0,351,50))



plt.tight_layout()
plt.show()

In [None]:
# Time-mean MSE transport of each ensemble member as a one-column strip plot.
# NOTE(review): `trans` is not defined in any visible cell (presumably an
# (ensemble, time) dataset built in a deleted cell) — restore its definition.
fig, ax = plt.subplots(figsize=(2,4))

xs = np.zeros(len(trans.MSE.ensemble))
plt.scatter(xs,trans.MSE.mean(dim='time'),alpha=0.4,color='palevioletred')

plt.grid()
plt.tight_layout()
plt.show()

In [None]:
# Ensemble mean ± 1 std of each 70°N transport component over the two-year record.
# NOTE(review): `trans` is not defined in any visible cell (presumably an
# (ensemble, time) dataset built in a deleted cell) — restore its definition.
# create figure object
fig, ax = plt.subplots(ncols=1,nrows=4,figsize=(9,8),sharey=False,sharex=True)

xs = np.arange(1,731,1)

# 1-based day indices of the first of Jan/Apr/Jul/Oct, for both years
first_days = np.array([1,91,182,274])
x_vals = np.concatenate((first_days,first_days+365))
x_labs = np.tile(['Jan 1','Apr 1','Jul 1','Oct 1'], 2)

# (axis, variable, colour, y label, extra divisor); VZ is additionally divided
# by 9.81 as in the original cell
panels = [
    (ax[0], 'MSE', 'black', r'MSE at 70°N (W/m$^2$)', 1.0),
    (ax[1], 'VT', 'red', r'VT', 1.0),
    (ax[2], 'VZ', 'green', r'VZ', 9.81),
    (ax[3], 'VQ', 'blue', r'VQ', 1.0),
]

for panel, var, color, ylabel, divisor in panels:
    mean = trans[var].mean(dim='ensemble')/area / divisor
    std = (trans[var]/area).std(dim='ensemble') / divisor

    # label each line with its own variable (previously every line said 'MSE')
    panel.plot(xs, mean, color=color, label=var)
    panel.fill_between(xs, mean-std, mean+std, color=color, alpha=0.4)
    panel.set_xlim([1,731])
    # same tick-label size on every panel (the VQ panel previously used 8)
    panel.tick_params(axis='x', rotation=45, labelsize=10)
    panel.tick_params(axis='y', labelsize=10)
    panel.set_ylabel(ylabel, fontsize=10)
    panel.set_xticks(x_vals)
    panel.set_xticklabels(x_labs)
    panel.grid()

plt.tight_layout()
plt.show()

In [None]:
# Seasonal cycle of MSE transport at 70°N: ensemble mean ± 1 std over the first
# 365 days of the record, saved as Fwall_arc_clim.svg.
# NOTE(review): `trans` is not defined in any visible cell — restore its definition.
# create figure object
fig, ax = plt.subplots(figsize=(6,3))

xs = np.arange(1,366,1)

# 1-based day indices of the first of Jan/Apr/Jul/Oct
first_days = np.array([1,91,182,274])

x_vals = first_days
x_labs = ['Jan 1','Apr 1','Jul 1','Oct 1']
mean = trans.MSE.mean(dim='ensemble')/area
std = (trans.MSE/area).std(dim='ensemble')

# only the first 365 time steps are plotted
ax.plot(xs, mean.isel(time=slice(None,365)), color='black',label='MSE')
ax.fill_between(xs, mean.isel(time=slice(None,365))-std.isel(time=slice(None,365)), mean.isel(time=slice(None,365))+std.isel(time=slice(None,365)),
                     color='black',alpha=0.4)
ax.set_xlim([1,366])
# ax[0,0].set_yticks(np.arange(-3,16,3))

ax.tick_params(axis='x', rotation=45,labelsize=10)
ax.tick_params(axis='y',labelsize=10)
ax.set_ylabel(r'MSE at 70°N (W/m$^2$)',fontsize=10)
ax.set_xticks(x_vals)
ax.set_xticklabels(x_labs)
# ax[0].legend(loc='upper left')
ax.grid()

plt.tight_layout()
plt.savefig('Fwall_arc_clim.svg')

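Now that we have accurately depicted the seasonal cycle of MSE transport into the Arctic, we can examine how it changes under 4xCO2. For each ensemble member, the next cell reads in both the control and 4xCO2 experiments and takes their difference while each still carries its original calendar, so the subtraction aligns the two runs on their own time coordinates as a consistency check. Only afterwards is the time axis of the difference replaced with day indices 0–729.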

In [None]:
# read in all control datasets, overwrite time axes, and produce one dataset with an ensemble coordinate
# For each member: difference 4xCO2 - ctrl on the original calendars, then replace
# the time axis with day indices 0..729 and stack along a new 'ensemble' dimension.
# Per-member names (ctrl_e/exp_e/diff_e) avoid clobbering the `ctrl`, `exp` and
# `diff` ensemble datasets built in earlier cells.
# NOTE(review): `read_in` is defined twice with different base directories; which
# files are read depends on which definition ran last.
hold = []
for e in range(1,101,1):
    ctrl_e = read_in('ctrl',1,e,'flux_70N')
    exp_e = read_in('4xCO2',1,e,'flux_70N')

    diff_e = exp_e - ctrl_e

    diff_e['time'] = np.arange(0,730,1)
    hold.append(diff_e)

# combine
trans_diff = xr.concat(hold,dim='ensemble')
trans_diff['ensemble'] = np.arange(1,101,1)

In [None]:
# Ensemble mean ± 1 std of the change (4xCO2 - ctrl) in each 70°N transport component.
# create figure object
fig, ax = plt.subplots(ncols=1,nrows=4,figsize=(8,8),sharey=False,sharex=True)

xs = np.arange(1,731,1)

# 1-based day indices of the first of Jan/Apr/Jul/Oct, for both years
first_days = np.array([1,91,182,274])
x_vals = np.concatenate((first_days,first_days+365))
x_labs = np.tile(['Jan 1','Apr 1','Jul 1','Oct 1'], 2)

# (axis, variable, colour, y label, extra divisor); VZ is additionally divided
# by 9.81 as in the original cell
panels = [
    (ax[0], 'MSE', 'black', r'ΔMSE transport at 70°N (W/m$^2$)', 1.0),
    (ax[1], 'VT', 'red', r'ΔVT', 1.0),
    (ax[2], 'VZ', 'green', r'ΔVZ', 9.81),
    (ax[3], 'VQ', 'blue', r'ΔVQ', 1.0),
]

for panel, var, color, ylabel, divisor in panels:
    mean = trans_diff[var].mean(dim='ensemble')/area / divisor
    std = (trans_diff[var]/area).std(dim='ensemble') / divisor

    # label each line with its own variable (the VQ line was previously labelled 'MSE')
    panel.plot(xs, mean, color=color, label=var)
    panel.fill_between(xs, mean-std, mean+std, color=color, alpha=0.4)
    panel.set_xlim([1,731])
    panel.tick_params(axis='x', rotation=45, labelsize=10)
    panel.tick_params(axis='y', labelsize=10)
    panel.set_ylabel(ylabel, fontsize=10)
    panel.set_xticks(x_vals)
    panel.set_xticklabels(x_labs)
    panel.grid()

plt.tight_layout()
plt.show()

In [None]:
# Ensemble mean ± 1 std of the change in MSE transport at 70°N, divided by `norm`.
# NOTE(review): `norm` is not defined in any visible cell — restore its definition.
# If `norm` is not dimensionless, the W/m^2 y-label below is wrong.
# create figure object
# (sharex/sharey have no effect with a single axes)
fig, ax = plt.subplots(figsize=(8,3),sharey=False,sharex=True)

xs = np.arange(1,731,1)

first_days = np.array([1,91,182,274])

x_vals = np.concatenate((first_days,first_days+365))
x_labs = np.tile(['Jan 1','Apr 1','Jul 1','Oct 1'],
                 (2))
mean = (trans_diff.MSE.mean(dim='ensemble')/area)/norm
std = ((trans_diff.MSE/area)/norm).std(dim='ensemble')

ax.plot(xs, mean, color='black',label='MSE')
ax.fill_between(xs, mean-std, mean+std,
                     color='black',alpha=0.4)
ax.set_xlim([1,731])
# ax[0,0].set_yticks(np.arange(-3,16,3))

ax.tick_params(axis='x', rotation=45,labelsize=10)
ax.tick_params(axis='y',labelsize=10)
ax.set_ylabel(r'ΔMSE transport at 70°N (W/m$^2$)',fontsize=10)
ax.set_xticks(x_vals)
ax.set_xticklabels(x_labs)
# ax[0].legend(loc='upper left')
ax.grid()

plt.tight_layout()
# plt.savefig('dAHT_70N.svg')
plt.show()

In [None]:
# Three stacked panels of the normalized change at 70°N: MSE (top), DSE = VT + VZ
# (middle) and VQ (bottom), each ensemble mean ± 1 std. Saved as dAHT_70N.svg.
# NOTE(review): `norm` is not defined in any visible cell — restore its definition.
# create figure object
fig, ax = plt.subplots(nrows=3,figsize=(8,7),sharex=True)

xs = np.arange(1,731,1)

first_days = np.array([1,91,182,274])

x_vals = np.concatenate((first_days,first_days+365))
x_labs = np.tile(['Jan 1','Apr 1','Jul 1','Oct 1'],
                 (2))

mean = (trans_diff.MSE.mean(dim='ensemble')/area)/norm
std = ((trans_diff.MSE/area)/norm).std(dim='ensemble')

# plot 4xCO2 starting in Jan
ax[0].plot(xs, mean, color='black',label='MSE transport at 70°N')
ax[0].fill_between(xs, mean-std, mean+std,
                     color='black',alpha=0.4)
# sharex=True, so this x-limit applies to all three panels
ax[2].set_xlim([1,731])

# NOTE(review): DSE here uses VT + VZ without the /9.81 applied to VZ in the
# component plots — confirm which scaling is intended.
mean = ((trans_diff.VT + trans_diff.VZ).mean(dim='ensemble')/area)/norm
std = (((trans_diff.VT + trans_diff.VZ)/area)/norm).std(dim='ensemble')

# plot 4xCO2 starting in Jan
ax[1].plot(xs, mean, color='lightsalmon',label='DSE')
ax[1].fill_between(xs, mean-std, mean+std,
                     color='lightsalmon',alpha=0.4)

mean = (trans_diff.VQ.mean(dim='ensemble')/area)/norm
std = ((trans_diff.VQ/area)/norm).std(dim='ensemble')

# plot 4xCO2 starting in Jan
ax[2].plot(xs, mean, color='turquoise',label='VQ')
ax[2].fill_between(xs, mean-std, mean+std,
                     color='turquoise',alpha=0.4)


ax[2].tick_params(axis='x', rotation=45,labelsize=10)
ax[2].tick_params(axis='y',labelsize=10)
# ax[1].set_ylabel(r'Diff in warming contr. (K)',fontsize=10)
ax[0].legend(loc='upper left',fontsize=10)
ax[1].legend(loc='upper left',fontsize=10)
ax[2].legend(loc='upper left',fontsize=10)
ax[2].set_xticks(x_vals)
ax[2].set_xticklabels(x_labs)
ax[1].grid()
ax[0].grid()
ax[2].grid()

ax[0].set_ylim([-60,60])
ax[1].set_ylim([-60,60])
ax[2].set_ylim([-10,10])

# FIXME(review): this label ('warming contr.', units K) does not describe a
# normalized energy-transport change; it looks copied from another figure.
ax[1].set_ylabel('Diff in warming contr. (K)',fontsize=12)
# ax[1].set_ylabel('Diff in warming contr. (K)')

plt.tight_layout()
# plt.show()           
plt.savefig('dAHT_70N.svg')

In [None]:
# create function for taking spatial averages, while weighting for latitude
def spatial_mean(ds_in, lat_bound_s=-91, lat_bound_n=91):
    """
    Zonal mean followed by a cos(latitude)-weighted meridional mean.

    The field is averaged over 'lon' and restricted to the band
    [lat_bound_s, lat_bound_n] (label-based slice, so `lat` is assumed to be
    ascending). It is then summed over 'lat' with cos(lat) weights normalized
    to sum to one.

    Keyword arguments:
    ds_in -- Dataset or DataArray with 'lon' and 'lat' to be averaged
    lat_bound_s -- float, southern edge of the band
    lat_bound_n -- float, northern edge of the band

    Returns:
    ds_in reduced over 'lon' and 'lat'.
    """
    band = slice(lat_bound_s, lat_bound_n)
    zonal = ds_in.mean(dim='lon').sel(lat=band)
    coslat = np.cos(np.deg2rad(zonal.lat))
    weights = coslat / np.sum(coslat)
    return (zonal * weights).sum(dim='lat')

## Compare AHT calculated as a residual vs. calculated explicitly

In [None]:
def read_in(exp, mon, ens, var,
            base_dir='/dx02/janoski/cesm-LE/output'):
    """
    Use xarray to open one ensemble member's netCDF output file.

    NOTE(review): this redefinition shadows the earlier `read_in` in this
    notebook. Cells run after it read from /dx02/janoski/cesm-LE/output unless
    `base_dir` is passed explicitly.

    The path is built as
    ``{base_dir}/b40.1850.cam5-lens.{exp}.{mon:02d}.{ens:02d}.h1_{var}.nc``.

    Keyword arguments:
    exp -- CO2 scenario (e.g. 'ctrl' or '4xCO2')
    mon -- starting month in which CO2 is altered (zero-padded to 2 digits)
    ens -- ensemble number (zero-padded to 2 digits)
    var -- model output variable (file-name suffix after 'h1_')
    base_dir -- directory holding the files; defaults to the directory this
                definition originally hard-coded

    Returns:
    The dataset from xr.open_dataset(..., chunks=None), i.e. not dask-chunked.
    """
    filein = f"{base_dir}/b40.1850.cam5-lens.{exp}.{mon:02d}.{ens:02d}.h1_{var}.nc"
    return xr.open_dataset(filein, chunks=None)

In [None]:
# Gaussian latitude weights for the CAM5 grid (per the file name); passed to gw_mean as its weights
gw = xr.open_dataarray('/dx02/janoski/cesm-LE/output/cam5_gauss_weights.nc')
# create function for taking spatial averages, while weighting for latitude
def gw_mean(ds_in, gw, lat_bound_s=-91, lat_bound_n=91):
    """
    Gaussian-weighted meridional mean of an already zonally averaged field.

    Both the field and the weights are restricted to the band
    [lat_bound_s, lat_bound_n] (label-based slice, so `lat` is assumed to be
    ascending). The field is then summed over 'lat' with the weights
    normalized to sum to one inside the band.

    Keyword arguments:
    ds_in -- Dataset or DataArray to average, ALREADY ZONALLY AVERAGED
    gw -- array of Gaussian weights; should only have a latitude dimension
    lat_bound_s -- float, southern edge of the band
    lat_bound_n -- float, northern edge of the band

    Returns:
    ds_in reduced over 'lat'.
    """
    band = slice(lat_bound_s, lat_bound_n)
    field = ds_in.sel(lat=band)
    weights = gw.sel(lat=band)
    return (field * weights / weights.sum(dim='lat')).sum(dim='lat')

In [None]:
# One control member (January init, member 01 per the '.01.01.' in the names):
# the explicit 70°N flux file and the residual-AHT file, for a direct comparison.
test_exp = xr.open_dataset('/dx02/janoski/cesm-LE/output/b40.1850.cam5-lens.ctrl.01.01.h1_flux_70N.nc')
test_res = xr.open_dataset('/dx05/janoski/d10/Arctic_Research/cesm-LE/vert_int_feedbacks/b40.1850.cam5-lens.01.01.h1_ctrl_AHT_residual.nc')

In [None]:
# zonal mean, then Gaussian-weighted mean over latitudes from 70°N up to the default northern bound
res_70N = gw_mean(test_res.mean(dim='lon'),gw, lat_bound_s=70)

In [None]:
# Explicit 70°N MSE transport (divided by the same area constant used earlier)
# vs Gaussian-weighted residual AHT north of 70°N, for one control member.
area = 0.15e14
(test_exp.MSE/area).plot()
# '.' is a marker, not a valid matplotlib linestyle (it raises ValueError);
# use a dotted line, matching the earlier residual-vs-explicit plots.
res_70N.AHT.plot(linestyle='dotted')

In [None]:
# display the area-mean residual dataset
res_70N