## Plot the climatological correction applied to each model in LDEO-HPD (seasonal means)
## Plot the standard deviation of that correction as well

In [9]:
# Modules
import pandas as pd
import numpy as np
import xarray as xr
%run _Val_Mapping.ipynb
import glob
import os
%matplotlib inline
%config InlineBackend.figure_format = 'pdf'
%config InlineBackend.print_figure_kwargs = {'dpi':300}

In [10]:
# Base directory of this LDEO-HPD reconstruction run.
root_dir = "/data/artemis/workspace/vbennington/LDEO_HPD/models/XGB/GCB_2021"
# Reconstruction NetCDF files are read from here; figures saved by this notebook also go here.
recon_output_dir = f"{root_dir}/reconstructions"

# Earlier spco2-based model list, kept for reference only (note: it included csiro,
# and used 'fesom' rather than 'fesom2'):
# models = ['cesm_spco2_1x1_A', 'csiro_spco2_1x1_A', 'fesom_spco2_1x1_A', 'mpi_spco2_1x1_A',
#           'cnrm_spco2_1x1_A', 'ipsl_spco2_1x1_A', 'planktom_spco2_1x1_A',
#           'noresm_spco2_1x1_A', 'princeton_spco2_1x1_A']

# Models whose corrections are plotted. Each name is used to select along the
# `model` coordinate of the climatology file and to build per-model file names.
models = [
    'cesm_sfco2_1x1_A',
    'fesom2_sfco2_1x1_A',
    'mpi_sfco2_1x1_A',
    'cnrm_sfco2_1x1_A',
    'ipsl_sfco2_1x1_A',
    'planktom_sfco2_1x1_A',
    'noresm_sfco2_1x1_A',
    'princeton_sfco2_1x1_A',
]

#-----------------------------------------------------------------------------------------
# Input files read by this notebook:
# - Climatology of the correction (per model and calendar month):
#     f"{recon_output_dir}/pCO2_cc2000-2020_1x1_recon_1959-2020.nc"
# - Full, time-varying corrections (model output, one file per model):
#     f"{recon_output_dir}/{mod}_recon_198201-202012.nc"
#-----------------------------------------------------------------------------------------

In [11]:
# Load the monthly climatology of the correction applied to each model
# (`correction` is dimensioned model x month x lat x lon).
# open_dataset is lazy. pCO2 and pCO2cc are each (model, time, lat, lon) =
# 8 x 744 x 180 x 360 float64 (~3 GB apiece), and nothing in this notebook reads them.
# Only the small `correction` variable is pulled into memory; the other variables
# are read from disk on demand, so the file stays open while `clim` is alive.
clim_file = f"{recon_output_dir}/pCO2_cc2000-2020_1x1_recon_1959-2020.nc"
clim = xr.open_dataset(clim_file)
clim["correction"] = clim["correction"].load()
print(clim)

<xarray.Dataset>
Dimensions:     (lat: 180, lon: 360, model: 8, month: 12, time: 744)
Coordinates:
  * model       (model) object 'cesm_sfco2_1x1_A' ... 'princeton_sfco2_1x1_A'
  * month       (month) int64 1 2 3 4 5 6 7 8 9 10 11 12
  * time        (time) datetime64[ns] 1959-01-15 1959-02-15 ... 2020-12-15
  * lat         (lat) float64 -89.5 -88.5 -87.5 -86.5 ... 86.5 87.5 88.5 89.5
  * lon         (lon) float64 -179.5 -178.5 -177.5 -176.5 ... 177.5 178.5 179.5
Data variables:
    pCO2        (model, time, lat, lon) float64 nan nan nan nan ... nan nan nan
    pCO2cc      (model, time, lat, lon) float64 nan nan nan nan ... nan nan nan
    correction  (model, month, lat, lon) float64 nan nan nan nan ... nan nan nan
Attributes:
    title:         LDEO-HPD Clim Correct 2000-2020
    history:       XGBoost results and Clim Corrections by Val Bennington
    institution:   Lamont Doherty Earth Observatory at Columbia
    references:    /home/vbennington/LDEO_HPD/model_output_processing/recon

In [12]:
# Collapse the 12-month climatology of the correction into four seasonal means,
# stored on `clim` as variables 'DJF', 'MAM', 'JJA', 'SON' (dims model x lat x lon).
# Because these are calendar months of a climatology, DJF simply averages months
# 12, 1 and 2; no year boundary is involved.
season_months = {
    "DJF": [12, 1, 2],
    "MAM": [3, 4, 5],
    "JJA": [6, 7, 8],
    "SON": [9, 10, 11],
}
for season, months in season_months.items():
    clim[season] = clim["correction"].sel(month=months).mean("month")

# Uncomment to USE

In [6]:
# Plot the climatology of the error (the correction used prior to 1982):
# _clim_error
# Plotted by season so we don't have 12 months per model: one figure with one
# row per model and one column per season (DJF, MAM, JJA, SON).

region = 'world'  # NOTE(review): not referenced in this cell
cmap = cm.cm.balance
plot_style = 'seaborn-talk'
# NOTE(review): matplotlib 3.6 renamed the seaborn styles to 'seaborn-v0_8-*' and 3.8
# removed the old names; switch to 'seaborn-v0_8-talk' if plt.style.context rejects this.
vrange = [-100, 100, 50]  # passed to SpatialMap2.add_plot; TODO confirm what the third value controls

# Set to True to regenerate allmodels_climcorrection_seasons.eps.
# (This was commented-out code; a flag keeps it syntax-checked.)
PLOT_ALL_MODELS_SEASONS = False

if PLOT_ALL_MODELS_SEASONS:
    fig = plt.figure(figsize=(12, 30))
    fig_shape = (len(models), 4)
    with plt.style.context(plot_style):
        dia = SpatialMap2(nrows_ncols=fig_shape, fig=fig, cbar_location='bottom', cbar_orientation='horizontal')
        i = 0  # running panel index across all models and seasons
        for mod in models:
            for season in ['DJF', 'MAM', 'JJA', 'SON']:
                season_map = clim[season].sel(model=mod)
                # season_map = xr_add_cyclic_point(season_map, cyclic_coord='lon')
                sub = dia.add_plot(data=season_map, vrange=vrange[0:3], cmap=cmap, ax=i)
                if mod == models[0]:
                    # Season titles only on the top row.
                    dia.set_title(season, i, fontsize=14)
                i += 1

        col = dia.add_colorbar(sub)
        # Raw string: "\m" is an invalid escape sequence in a normal string literal.
        dia.set_cbar_xlabel(col, r"$\mu$atm", fontsize=12)
    fig.savefig(f"{recon_output_dir}/allmodels_climcorrection_seasons.eps")
    plt.show()
    plt.close(fig)

In [13]:
# Plot each model separately (one 1x4 figure of DJF/MAM/JJA/SON per model),
# so individual models can be selected for publication.
#########################################################################
cmap = cm.cm.balance
plot_style = 'seaborn-talk'
# NOTE(review): matplotlib 3.6 renamed the seaborn styles to 'seaborn-v0_8-*' and 3.8
# removed the old names; switch to 'seaborn-v0_8-talk' if plt.style.context rejects this.

vrange = [-100, 100, 50]  # passed to SpatialMap2.add_plot; TODO confirm what the third value controls

# Set to True to write {mod}_correction_seasons.eps for every model.
# (This was commented-out code; a flag keeps it syntax-checked.)
PLOT_EACH_MODEL_SEASONS = False

if PLOT_EACH_MODEL_SEASONS:
    with plt.style.context(plot_style):
        for mod in models:
            # New figure and new SpatialMap2 for each model.
            fig = plt.figure(figsize=(12, 5))
            fig_shape = (1, 4)
            dia = SpatialMap2(nrows_ncols=fig_shape, fig=fig, cbar_location='bottom', cbar_orientation='horizontal')
            for i, season in enumerate(['DJF', 'MAM', 'JJA', 'SON']):
                season_map = clim[season].sel(model=mod)
                sub = dia.add_plot(data=season_map, vrange=vrange[0:3], cmap=cmap, ax=i)
                dia.set_title(season, i, fontsize=12)
            col = dia.add_colorbar(sub)
            # Raw string: "\m" is an invalid escape sequence in a normal string literal.
            dia.set_cbar_xlabel(col, r"$\mu$atm", fontsize=12)
            # Save each model figure separately:
            fig.savefig(f"{recon_output_dir}/{mod}_correction_seasons.eps")
            plt.show()
            # Release each figure before the next model to keep memory flat.
            plt.close(fig)

In [14]:
# Seasonal maps of the standard deviation of the full (time-varying) correction.
# For each model: std over time of `error_{mod}` for each calendar month, then the
# mean of those monthly std maps within each season. One 1x4 figure per model.
region = 'world'  # NOTE(review): not referenced in this cell
plot_style = 'seaborn-talk'
# NOTE(review): matplotlib 3.6 renamed the seaborn styles to 'seaborn-v0_8-*' and 3.8
# removed the old names; switch to 'seaborn-v0_8-talk' if plt.style.context rejects this.
fig_shape = (1, 4)
vrange = [-100, 100, 51]  # NOTE(review): not used by the std plots, which use [0, 25, 26]

# Set to True to write {mod}_correction_std_seasons.eps for every model.
# (This was commented-out code; a flag keeps it syntax-checked.)
PLOT_CORRECTION_STD = False

if PLOT_CORRECTION_STD:
    # Separate name so the shared `cmap` used by the other plotting cells is not clobbered.
    std_cmap = cm.cm.thermal
    season_months = {
        "DJF": [12, 1, 2],
        "MAM": [3, 4, 5],
        "JJA": [6, 7, 8],
        "SON": [9, 10, 11],
    }

    for mod in models:
        # `with` closes each per-model file. The original opened one file per model and
        # never closed it. .load() materialises the small result before the file closes.
        with xr.open_dataset(f"{recon_output_dir}/{mod}_recon_198201-202012.nc") as ds:
            monthly_std = (
                ds[f"error_{mod}"]
                .groupby("time.month")
                .std("time")
                .transpose("month", "ylat", "xlon")
                .load()
            )
        seasonal_std = {
            season: monthly_std.sel(month=months).mean("month")
            for season, months in season_months.items()
        }

        fig = plt.figure(figsize=(12, 5))
        with plt.style.context(plot_style):
            dia = SpatialMap2(nrows_ncols=fig_shape, fig=fig, cbar_location='bottom', cbar_orientation='horizontal')
            for i, (season, season_map) in enumerate(seasonal_std.items()):
                sub = dia.add_plot(data=season_map, vrange=[0, 25, 26], cmap=std_cmap, ax=i)
                dia.set_title(season, i, fontsize=14)
            col = dia.add_colorbar(sub)
            # Raw string: "\m" is an invalid escape sequence in a normal string literal.
            dia.set_cbar_xlabel(col, r"$\mu$atm", fontsize=12)
            fig.savefig(f"{recon_output_dir}/{mod}_correction_std_seasons.eps")
            plt.show()
        # Release each figure before the next model to keep memory flat.
        plt.close(fig)