In [None]:
from pyesgf.search import SearchConnection
from pyesgf.logon import LogonManager
import pandas as pd
import xarray as xr
import numpy as np
from pathlib import Path
import pylab as plt
from joblib import Parallel, delayed
import operator
from typing import Union
from tqdm.auto import tqdm
from cdo import Cdo
from functools import reduce

## Log on to ESGF with OpenID

In [None]:
# ESGF OpenID of the account used to log on before any data is accessed.
my_id = "aaschwanden"
OPENID = f"https://esgf-node.llnl.gov/esgf-idp/openid/{my_id}"

# Log off first so the logon below starts from a fresh session.
lm = LogonManager()
lm.logoff()
lm.is_logged_on()

# password=None: presumably pyesgf prompts for it interactively rather than
# keeping it in the notebook — confirm against the pyesgf documentation.
lm.logon_with_openid(openid=OPENID, password=None, bootstrap=True)
lm.is_logged_on()

## Select Experiments and Variables

In [None]:
# CMIP6/PMIP experiment ids to search for.
experiments = [
    "lig127k",
    "midPliocene-eoi400",
]
# CMIP variable ids to search for ("pr" is rescaled to per-year in `download`).
variables = [
    "pr",
    "tas",
]

In [None]:
# Search connection to the LLNL index node; distrib=True requests a federated
# (distributed) search across ESGF nodes.
conn = SearchConnection('https://esgf-node.llnl.gov/esg-search', distrib=True)
# Facets requested from the index along with the search. (A preliminary
# `new_context(facets=...)` that was immediately overwritten has been removed.)
facets = 'project,experiment_family'
ctx = conn.new_context(
    project='CMIP6',
    activity_id="PMIP",
    realm="atmos",
    table_id="Amon",
    variable_id=variables,
    experiment_id=experiments,
    facets=facets,
)
print('Hits:', ctx.hit_count)

In [None]:
def _file_record(f):
    """Build one flat record from a single file search result."""
    # Split the filename once instead of once per parsed field.
    fields = f.filename.split("_")
    return {'variable': fields[0],
            'table_id': fields[1],
            'source_id': fields[2],
            'experiment_id': fields[3],
            # NOTE(review): under the CMIP6 filename convention this field is the
            # member_id (variant label, optionally prefixed by a sub-experiment);
            # the historical key name is kept so existing tables stay compatible.
            'sub_experiment_id': fields[4],
            'filename': f.filename,
            'url': f.opendap_url,
            'size': f.size,
            'checksum': f.checksum,
            'checksum_type': f.checksum_type}


def mapping(result):
    """
    Convert ESGF file search results into a list of flat dict records.

    The first five underscore-separated fields of each filename are parsed
    into variable, table_id, source_id, experiment_id and sub_experiment_id.

    Parameters
    ----------
    result : iterable
        File results; each item must expose ``filename``, ``opendap_url``,
        ``size``, ``checksum`` and ``checksum_type``. A filename with fewer
        than five ``_``-separated fields raises ``IndexError``.

    Returns
    -------
    list of dict
        One record per file, in input order.
    """
    return [_file_record(f) for f in result]

In [None]:
n_jobs = 16
results = ctx.search()
n_files = len(results)


def _search_dataset_files(dataset):
    """Run the file-level search for one dataset result and map its files to records."""
    return mapping(dataset.file_context().search())


# The file search itself must run inside the worker: the previous form,
# ``delayed(mapping)(results[i].file_context().search())``, evaluated the search
# eagerly in the main process while the generator was consumed, so only the
# cheap ``mapping`` call ran in parallel. The work is network-bound, so threads
# suffice and avoid pickling dataset results.
# NOTE(review): assumes the shared SearchConnection tolerates concurrent requests — confirm.
joblib_files = Parallel(n_jobs=n_jobs, prefer="threads")(
    delayed(_search_dataset_files)(results[i])
    for i in tqdm(range(n_files))
)

In [None]:
# Flatten the per-dataset file lists, keeping EVERY file of each dataset.
# Previously only the first file (``[0]``) was kept, dropping the remaining
# time chunks that the processing step later merges with ``cdo mergetime``.
# Datasets whose file search returned nothing contribute no records.
all_files = sorted(
    (record for dataset_files in joblib_files for record in dataset_files),
    key=operator.itemgetter("filename"),
)

## Generate a DataFrame with all files. Save to disk for later use.

In [None]:
# One row per unique file record. The CSV is written without the positional
# index so re-reading it does not produce an extra "Unnamed: 0" column.
df = pd.DataFrame(all_files).drop_duplicates().reset_index(drop=True)
df.to_csv("experiments.csv", index=False)

## Read Experiment Table

In [None]:
# Reload the saved file table so later steps can run without repeating the ESGF search.
experiments_csv = Path("experiments.csv")
df = pd.read_csv(experiments_csv)

In [None]:
# Seconds per (tropical) year; converts a per-second flux to a per-year one.
SECONDS_PER_YEAR = 31556925.9747


def download(row, odir: Union[Path, str] = "pmip_raw"):
    """
    Download one file over OPeNDAP and save it locally as netCDF.

    Files already present in ``odir`` are skipped without contacting the
    server. The global ``experiment_id`` attribute is stored as a variable.
    If the dataset contains ``pr``, it is multiplied by ``SECONDS_PER_YEAR``
    and its ``units`` attribute is set to ``kg m-2 yr-1`` (assumes the source
    is kg m-2 s-1 — TODO confirm per model). Failures are reported and
    otherwise ignored so one bad URL does not abort a batch run.

    Parameters
    ----------
    row : mapping-like (e.g. a ``pandas.Series`` row of the file table)
        Must provide ``"url"`` (OPeNDAP URL) and ``"filename"``.
    odir : Path or str, default "pmip_raw"
        Output directory; created if missing.

    Returns
    -------
    None
    """
    url = row["url"]
    odir = Path(odir)
    odir.mkdir(exist_ok=True)
    m_filename = odir / Path(row["filename"])
    if m_filename.exists():
        return None
    # Write under a temporary name and rename on success, so an interrupted
    # transfer never leaves a truncated file that the check above would skip.
    tmp_filename = m_filename.with_name(m_filename.name + ".part")
    try:
        with xr.open_dataset(url, decode_times=False, engine="netcdf4") as ds:
            ds["experiment_id"] = ds.attrs["experiment_id"]
            if "pr" in ds:
                ds["pr"] *= SECONDS_PER_YEAR
                # Set the attribute; ``ds["pr"]["units"] = ...`` would add a
                # scalar coordinate named "units" instead.
                ds["pr"].attrs["units"] = "kg m-2 yr-1"
            print(f"Saving {m_filename}")
            ds.to_netcdf(tmp_filename)
        tmp_filename.replace(m_filename)
    except Exception as err:
        tmp_filename.unlink(missing_ok=True)
        print(f"Failed to download {url}: {err}")
    return None

In [None]:
n_files = len(df)
print(n_files)

# Store the (all-None) download return values under their own name so the
# ESGF search results in ``joblib_files`` are not clobbered; re-running the
# flattening cell after this one would otherwise fail.
download_results = Parallel(n_jobs=10)(
    delayed(download)(row)
    for _, row in tqdm(df.iterrows(), total=n_files)
)

In [None]:
# CDO (Climate Data Operators) interface used below to merge, time-average
# and remap the downloaded files.
cdo = Cdo()

## Merge time chunks, average over time, and remap to a common 1° grid so statistics can be computed across models

In [None]:
def check_file(f: Path) -> Union[str, None]:
    """Return ``f`` as a POSIX path string if it exists on disk, otherwise ``None``."""
    if f.exists():
        return f.as_posix()
    return None


# Backward-compatible alias for the helper's original name.
check_files = check_file

n_jobs = 8
idir: Union[Path, str] = Path("pmip_raw")
odir: Union[Path, str] = Path("pmip_processed")
odir.mkdir(exist_ok=True)
processed_df = []
for (m_var, m_exp, m_source), gcm_df in df.groupby(by=["variable", "experiment_id", "source_id"]):
    # Only hand cdo files that were actually downloaded; a ``None`` entry would
    # be passed on as a bogus input name.
    ifiles = [p for p in (check_file(idir / Path(f)) for f in gcm_df["filename"]) if p is not None]
    if not ifiles:
        print(f"Skipping {m_var} {m_exp} {m_source}: no downloaded files")
        continue
    ofile = odir / Path(f"{m_var}_Amon_{m_source}_{m_exp}.nc")
    try:
        # The operator argument chains the CDO pipeline (applied right to left):
        # mergetime -> timmean -> remapycon onto the global r360x180 grid.
        cdo.remapycon("r360x180 -timmean -mergetime", input=ifiles, output=ofile.as_posix(),
                      options=f"-O -f nc -z zip_3 -P {n_jobs}")
    except Exception as err:
        print(f"cdo failed for {ofile}: {err}")
        continue
    processed_df.append(pd.DataFrame.from_dict({"variable_id": [m_var], "experiment_id": [m_exp],
                                                "source_id": [m_source], "filename": [ofile]}))

In [None]:
# Stack the one-row frames from the processing loop into a single table
# with a fresh 0..n-1 index.
proccessed_df = pd.concat(processed_df, ignore_index=True)

In [None]:
def preprocess(ds, lat_bounds=(60, 85), lon_bounds=(285, 350)):
    """
    Prepare one per-model file for concatenation along ``source_id``.

    Passed to ``xr.open_mfdataset(preprocess=...)``. Stores the global
    ``experiment_id`` attribute as a variable, attaches the ``source_id``
    attribute as a coordinate, drops ``height`` if present, and crops to a
    latitude/longitude box.

    Parameters
    ----------
    ds : xarray.Dataset
    lat_bounds, lon_bounds : tuple of (start, stop)
        Passed to ``slice`` for label-based ``sel``. Longitudes are in the
        dataset's own convention (the defaults assume 0-360). Assumes
        ascending coordinates (presumably true after the r360x180 remap —
        confirm); on a descending axis the selection would be empty.

    Returns
    -------
    xarray.Dataset
        The cropped dataset.
    """
    ds["experiment_id"] = ds.attrs["experiment_id"]
    source_id = ds.attrs["source_id"]
    # drop_vars replaces the deprecated Dataset.drop for removing variables.
    ds = ds.assign_coords({"source_id": source_id}).drop_vars("height", errors="ignore")
    return ds.sel(lat=slice(*lat_bounds), lon=slice(*lon_bounds))

In [None]:
# Models available for each (variable, experiment) combination.
intersection = proccessed_df.groupby(by=["variable_id", "experiment_id"])["source_id"].unique()
# Models that provide every variable for every experiment.
intersection_gcms = reduce(lambda acc, models: list(set(acc) & set(models)), intersection)
intersection_df = proccessed_df.loc[proccessed_df["source_id"].isin(intersection_gcms)]

In [None]:
for (m_var, m_exp), m_df in intersection_df.groupby(by=["variable_id", "experiment_id"]):
    # One processed file per model, stacked along a new ``source_id`` dimension.
    p = list(m_df["filename"])
    ds = xr.open_mfdataset(p, parallel=True, concat_dim="source_id", combine="nested",
                           data_vars='minimal', coords='minimal', compat='override',
                           preprocess=preprocess, decode_times=False)
    # Map of the time mean, one panel per model.
    f = ds[m_var].mean(dim="time").plot(col="source_id", col_wrap=6)
    # Regional mean per model. Cells of a regular lat-lon grid shrink with
    # cos(latitude), so weight by it; an unweighted mean over a high-latitude
    # box over-counts the poleward cells.
    area_weights = np.cos(np.deg2rad(ds["lat"]))
    mean = (ds[m_var].weighted(area_weights)
            .mean(dim=["time", "lat", "lon"])
            .to_dataframe(name=m_var))
    # NOTE: this is a standard deviation (not a variance) taken jointly over
    # time and models; the name is kept for any later cells that use it.
    variance = ds[m_var].std(dim=["time", "source_id"])
    # Label each panel with its model's regional mean. Panels and rows are both
    # ordered by source_id; zip ignores any unused panels left by col_wrap.
    for ax_panel, value in zip(f.axs.ravel(), mean[m_var]):
        ax_panel.text(0.1, 0.9, f"mean={np.round(value)}", color="w", horizontalalignment='left',
                      verticalalignment='center', transform=ax_panel.transAxes)
    f.fig.savefig(f"{m_var}_{m_exp}.pdf")
    fig, ax = plt.subplots(1, 1)
    f_var = variance.plot(ax=ax)
    ax.set_title(f"{m_var} {m_exp}: std over time and models")



In [None]:
plt.cle