In [None]:
from collections import namedtuple
from itertools import combinations, combinations_with_replacement, permutations, product

import xarray as xr
from matplotlib import pyplot as plt

In [None]:
%matplotlib inline
import matplotlib as mpl
mpl.rcParams['figure.figsize']=(10,10)
nc = namedtuple('File', 'ds dims mean')
xr.open_dataset('../../finite/geo/RC1SD-base-08_ECHAM5_2000-2013_variable_by_lat_over_lev.nc')

In [None]:
from typing import NamedTuple
import pandas as pd

class Distr(NamedTuple):
    """Variance profile of one variable, identified by ``(var, along, over)``.

    Hashing and equality deliberately ignore ``data``: DataFrames are not
    hashable, and the same ``(var, along, over)`` combination can be produced
    from more than one input file, so ``set(...)`` over these records keeps a
    single profile per combination.
    """
    var: str  # data-variable name, e.g. 'tm1'
    along: str  # dimension the profile is grouped by / indexed along
    over: str  # remaining dimension the variance is reported "over"
    data: pd.DataFrame  # the computed profile (output of .to_dataframe())

    def _key(self):
        """Return the identity tuple used for hashing and comparison."""
        return (self.var, self.along, self.over)

    def __hash__(self):
        # Hash the tuple rather than a concatenated string so that e.g.
        # ('ab', 'c', ...) and ('a', 'bc', ...) do not collide.
        return hash(self._key())

    def __eq__(self, other):
        # Compare keys, not hashes: equal hashes do not imply equal keys, and
        # non-Distr objects must not compare equal just by hash coincidence.
        if not isinstance(other, Distr):
            return NotImplemented
        return self._key() == other._key()

    def __ne__(self, other):
        # Must be overridden explicitly: the inherited tuple.__ne__ would
        # compare the DataFrame fields element-wise and raise.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

In [None]:
# One `nc` record per input file, built from rows of
# (file path, dims the analysis iterates over, dim absent from the file).
ncfiles = [
    nc(ds=xr.open_dataset(path), dims=dims, mean=absent)
    for path, dims, absent in [
        ('../../finite/geo/RC1SD-base-08_ECHAM5_2000-2013_variable_by_lat_over_lon.nc',
         ['lat', 'lon', 'month'], 'lev'),
        ('../../finite/geo/RC1SD-base-08_ECHAM5_2000-2013_variable_by_lat_over_lev.nc',
         ['lat', 'lev', 'month'], 'lon'),
        ('../../finite/geo/RC1SD-base-08_ECHAM5_2000-2013_variable_by_lon_over_lev.nc',
         ['lev', 'lon', 'month'], 'lat'),
    ]
]

In [None]:
# Display the first input Dataset (the file with dims lat/lon/month).
ncfiles[0].ds

In [None]:
# For every variable of every file and every ordered pair of distinct dims
# (mean2, group): average over `mean2`, group by `group`, take the variance,
# and record the one remaining dimension as `over`.
# NOTE(review): `groupby(group).var()` without `dim=` reduced over *all*
# remaining dims in older xarray releases but only over the grouped dim in
# newer ones - confirm the installed version still gives one value per label.
results = []
for file in ncfiles:
    for var in file.ds.data_vars:
        # Same ordered pairs, in the same order, as product() minus the diagonal.
        for mean2, group in permutations(file.dims, 2):
            # Each file lists exactly three dims, so exactly one is left over;
            # unpacking fails loudly instead of mislabelling if that changes.
            (over,) = [dim for dim in file.dims if dim not in (mean2, group)]
            # Index by name: attribute access breaks for variables whose names
            # clash with Dataset attributes/methods (e.g. 'mean', 'var').
            data = file.ds[var].mean(mean2).groupby(group).var().to_dataframe()
            results.append(Distr(var, group, over, data))

In [None]:
# [x.data.plot(title='var() over {}'.format(x.over), logy=True) for x in set(results) if x.var=='tm1' and x.along=='lat']

In [None]:
# Variable examined by the single-variable cells; the plotting loop rebinds it.
var = 'tm1'

# Human-readable labels for variable and dimension names, used for plot
# titles, axis labels, legend columns and exported file names.
mapping = dict(
    tm1='Temperature',
    month='Time',
    lon='Longitude',
    lat='Latitude',
    lev='Altitude',
    um1='Zonal Wind',
    vm1='Meridional Wind',
    qm1='Specific Humidity',
    press='Pressure',
    geopot='Geopotential Height',
)

In [None]:
# Output directory for the exported figures (absolute, machine-specific path).
FIGURE_DIR = '/home/ucyo/Developments/dissertation/Figures/analysis'
# Variables to plot; one figure (a row of panels) per variable.
PLOT_VARIABLES = ['tm1', 'um1', 'vm1', 'qm1', 'press', 'geopot']
# Left-to-right panel order: each panel shows the variance profiles along this dim.
PANEL_DIMS = ['lat', 'lon', 'month', 'lev']


def variance_frame(distrs, var, along):
    """Collect all variance profiles of `var` along `along` into one frame.

    Parameters
    ----------
    distrs : iterable of Distr
        Variance records to select from.
    var : str
        Data-variable name, e.g. 'tm1'.
    along : str
        Dimension the profiles are indexed by.

    Returns
    -------
    pd.DataFrame
        Indexed like the first matching record's data, with one column per
        `over` dimension (labelled via `mapping`), columns sorted by label.

    Raises
    ------
    ValueError
        If no record matches `var` and `along`.
    """
    runs = [d for d in distrs if d.var == var and d.along == along]
    if not runs:
        raise ValueError('no variance records for var={!r} along={!r}'.format(var, along))
    frame = pd.DataFrame(None, index=runs[0].data.index)
    for run in runs:
        # Column assignment aligns on the index; presumably each record's
        # data is a single-column frame (pandas rejects multi-column ones).
        frame[mapping[run.over]] = run.data
    return frame.sort_index(axis=1)


def plot_variance_panel(ax, frame, var, along):
    """Draw one log-scale variance panel of `frame` onto `ax`.

    The time panel gets a " [months]" unit suffix on its x label, and the
    level panel has its x axis inverted.
    """
    frame.plot(ax=ax, logy=True, linewidth=3, fontsize=14,
               title="Variance of {} along {}".format(mapping[var], mapping[along]))
    xlabel = mapping[along]
    if along == 'month':
        xlabel += " [months]"
    elif along == 'lev':
        ax.invert_xaxis()
    ax.set_xlabel(xlabel)


# Deduplicate once: Distr hashing/equality ignores `data`, so records sharing
# (var, along, over) collapse to one.
distrs = set(results)

for var in PLOT_VARIABLES:
    fig, axes = plt.subplots(1, len(PANEL_DIMS), sharey=True, figsize=(20, 5))
    for ax, along in zip(axes.flat, PANEL_DIMS):
        plot_variance_panel(ax, variance_frame(distrs, var, along), var, along)
    fig.tight_layout()
    fig.savefig('{}/variance-global-{}-2000-2013.svg'.format(
        FIGURE_DIR, mapping[var].replace(' ', '-')), dpi=96)
    plt.show()
    # Release the figure so non-inline backends do not accumulate open figures.
    plt.close(fig)

In [None]:
# from matplotlib import pyplot as plt
# var = 'um1'
# along = 'lat'
# tm1_lev = [x for x in set(results) if x.var==var and x.along==along]
# df = pd.DataFrame(None, index=tm1_lev[0].data.index)
# for run in tm1_lev:
#     df[var+'_'+run.over] = run.data
# df.plot(logy=True, ylim=(0,1.5*10**3), linewidth=3, fontsize=14, ax=ax1, title="Variance of {} along {}".format(var, along))
# # plt.savefig('../../../pasc/egu/var-{}-{}.svg'.format(var, along), dpi=96, format='svg',bbox_inches='tight')
# # plt.show();

In [None]:
# from scipy.interpolate import spline
# import numpy as np
# x_smooth = np.linspace(df.index.min(), df.index.max(), 800)
# y_smooth = spline(df.index, df['tm1_month'], x_smooth)
# df_n = pd.DataFrame({x:spline(df.index, getattr(df,x), x_smooth) for x in df.columns}, index=x_smooth)
# df_n.plot(logy=True, ylim=(0,1.5*10**3), linewidth=3, fontsize=14, title="Variance of {} along {}".format(var, along))

In [None]:
# along = 'lon'
# tm1_lev = [x for x in set(results) if x.var==var and x.along==along]
# df = pd.DataFrame(None, index=tm1_lev[0].data.index)
# for run in tm1_lev:
#     df[var+'_'+run.over] = run.data
# df.plot(logy=True, ylim=(0,1.5*10**3), linewidth=3, fontsize=14, title="Variance of {} along {}".format(var, along))
# # plt.savefig('../../../pasc/egu/var-{}-{}.svg'.format(var, along), dpi=96, format='svg',bbox_inches='tight')
# plt.show();

In [None]:
# along = 'month'
# tm1_lev = [x for x in set(results) if x.var==var and x.along==along]
# df = pd.DataFrame(None, index=tm1_lev[0].data.index)
# for run in tm1_lev:
#     df[var+'_'+run.over] = run.data
# df.plot(logy=True, ylim=(0,1.5*10**3), linewidth=3, fontsize=14, title="Variance of {} along {}".format(var, along))
# # plt.savefig('../../../pasc/egu/var-{}-{}.svg'.format(var, along), dpi=96, format='svg',bbox_inches='tight')
# plt.show();

In [None]:
# along = 'lev'
# tm1_lev = [x for x in set(results) if x.var==var and x.along==along]
# df = pd.DataFrame(None, index=tm1_lev[0].data.index)
# for run in tm1_lev:
#     df[var+'_'+run.over] = run.data
# df.plot(logy=True, ylim=(0,1.5*10**3), linewidth=3, fontsize=14, title="Variance of {} along {}".format(var, along))
# # plt.savefig('../../../pasc/egu/var-{}-{}.svg'.format(var, along), dpi=96, format='svg',bbox_inches='tight')
# plt.show();

In [None]:

# plt.show()