In [None]:
import xarray as xr
import numpy as np
import warnings
from tqdm.autonotebook import tqdm
from datetime import datetime
#from tqdm import tqdm
from matplotlib import pyplot as plt
%matplotlib inline
# Silence xarray SerializationWarning messages for the rest of the session
warnings.simplefilter(action='ignore', category=xr.SerializationWarning)
xr.__version__

In [None]:
from time import sleep

# Quick check that the tqdm progress bar renders in this frontend.
for step in tqdm(range(5)):
    sleep(0.1)

In [None]:
#!/usr/bin/env python
from __future__ import print_function
import requests
import numpy
import pandas as pd
from collections import OrderedDict

# API AT: https://github.com/ESGF/esgf.github.io/wiki/ESGF_Search_REST_API#results-pagination

def check_doc_for_malformed_id(d):
    """Raise ValueError unless ``d['id']`` contains ``<source_id>_<experiment_id>``.

    ``source_id`` and ``experiment_id`` are read from the first element of the
    corresponding list-valued fields of the search doc ``d``; the id is checked
    with a plain substring test. Returns None when the id looks well-formed.
    """
    expected = "{}_{}".format(d['source_id'][0], d['experiment_id'][0])
    if expected not in d['id']:
        raise ValueError(f"Dataset id {d['id']} is malformed")
                         
def maybe_squeze_values(d):
    """Return a new dict in which one-element sequence values are unwrapped.

    Search docs hold most fields as one-element lists; such values are
    replaced by their single element. Strings, values without ``len()``
    (numbers, None, ...) and sequences of any other length (including empty)
    are passed through unchanged.

    Parameters
    ----------
    d : mapping
        A search result document.

    Returns
    -------
    dict
        Same keys as ``d`` with squeezed values.
    """
    def _maybe_squeeze(value):
        # strings have len() but must never be reduced to their first character
        if isinstance(value, str):
            return value
        try:
            if len(value) == 1:
                return value[0]
        except TypeError:
            # no len() or not indexable: keep the value as-is
            pass
        # any other length is kept unchanged rather than silently turned into None
        return value

    return {k: _maybe_squeeze(v) for k, v in d.items()}
                         
def get_request(client, server, verbose=False, **payload):
    """Send one GET search request and return the ``response`` part of its JSON body.

    The ``payload`` items are joined verbatim (insertion order, no URL-encoding)
    as ``key=value`` pairs into the query string appended to ``server``.
    ``client`` must provide a requests-style ``get``; HTTP error statuses are
    raised through ``raise_for_status``.
    """
    query = "&".join(f"{key}={value}" for key, value in payload.items())
    url = f"{server}/?{query}"
    if verbose:
        print(url)
    response = client.get(url)
    response.raise_for_status()
    return response.json()["response"]

def _rows_from_doc(d, files_type, filter_server_url):
    """Turn one search doc into zero or more rows, one per matching access URL."""
    try:
        check_doc_for_malformed_id(d)
    except ValueError:
        # skip docs whose id does not embed source_id/experiment_id
        return []
    target_urls = d.pop('url', [])
    item = OrderedDict(dataset_id=d['dataset_id'], id=d['id'])
    item.update(maybe_squeze_values(d))
    rows = []
    for entry in target_urls:
        parts = entry.split("|")
        if parts[-1] != files_type:
            continue
        file_url = parts[0].replace('.html', '')
        if filter_server_url is None or filter_server_url in file_url:
            # copy per URL so rows never share (and overwrite) one dict
            row = OrderedDict(item)
            row[f'{files_type}_url'] = file_url
            rows.append(row)
    return rows


def esgf_search(server="https://esgf-node.llnl.gov/esg-search/search",
                files_type="OPENDAP", local_node=True, project="CMIP6",
                # this option should not be necessary with local_node=True
                filter_server_url=None,
                verbose=False, format="application%2Fsolr%2Bjson",
                use_csrf=False, **search):
    """Query an ESGF search node for File records and return them as a DataFrame.

    Every extra keyword argument in ``search`` becomes a search facet
    (e.g. ``variable='ta'``). Pages are fetched by ``offset`` until
    ``numFound`` docs have been read or the server returns an empty page.

    Parameters
    ----------
    server : str
        Search endpoint URL.
    files_type : str
        Case-insensitive access type; only ``url`` entries of each doc whose
        last ``|``-separated field equals it are kept (e.g. ``"OPENDAP"``).
    local_node : bool
        If True, send ``distrib=false``.
    project : str
        Value of the ``project`` facet.
    filter_server_url : str or None
        If given, keep only URLs containing this substring.
    verbose : bool
        Print every request URL.
    format : str
        Response format, inserted verbatim (already URL-encoded) into the query.
    use_csrf : bool
        Fetch ``server`` once first and send its CSRF cookie as
        ``csrfmiddlewaretoken``.

    Returns
    -------
    pandas.DataFrame
        One row per matching URL: the squeezed doc fields plus a
        ``<FILES_TYPE>_url`` column, de-duplicated on ``checksum``.
        Empty DataFrame if nothing matched.
    """
    files_type = files_type.upper()

    payload = search
    payload["project"] = project
    payload["type"] = "File"
    if local_node:
        payload["distrib"] = "false"

    all_files = []
    # the session is closed on exit, even if a request raises
    with requests.session() as client:
        if use_csrf:
            client.get(server)
            if 'csrftoken' in client.cookies:
                # Django 1.6 and up
                csrftoken = client.cookies['csrftoken']
            else:
                # older versions
                csrftoken = client.cookies['csrf']
            payload["csrfmiddlewaretoken"] = csrftoken

        payload["format"] = format

        # the first page also reports numFound; its docs are processed below
        # instead of requesting offset=0 a second time
        resp = get_request(client, server, offset=0, verbose=verbose, **payload)
        num_found = int(resp["numFound"])
        offset = 0
        with tqdm(total=num_found, desc='ESGF Search', unit='docs') as pbar:
            while True:
                docs = resp["docs"]
                if not docs:
                    # stop instead of looping forever if the server runs out
                    # of docs before numFound is reached
                    break
                offset += len(docs)
                pbar.update(len(docs))
                for d in docs:
                    all_files.extend(_rows_from_doc(d, files_type, filter_server_url))
                if offset >= num_found:
                    break
                resp = get_request(client, server, offset=offset, verbose=verbose, **payload)

    files = pd.DataFrame(all_files)
    if files.empty or 'checksum' not in files:
        # nothing to de-duplicate (and drop_duplicates would raise KeyError)
        return files
    # dropping duplicates on checksum removes all identical files
    return files.drop_duplicates(subset='checksum')

In [None]:
# Alternative queries, smallest to largest (file counts as noted by the author):
#   233 files:   esgf_search(mip_era='CMIP6', activity_drs='CMIP', variable="ta", table_id='Amon', filter_server_url='aims3.llnl.gov')
#   1179 files:  esgf_search(mip_era='CMIP6', activity_drs='CMIP', variable="ta")
#   13453 files: esgf_search(mip_era='CMIP6', activity_drs='CMIP', table_id='Amon', latest='true', filter_server_url='aims3.llnl.gov')
#   83273 files: esgf_search(mip_era='CMIP6', activity_drs='CMIP', latest='true', filter_server_url='aims3.llnl.gov')

# Active query: IPSL CMIP Amon files, latest versions, served from aims3 (2615 files noted)
files = esgf_search(
    mip_era='CMIP6',
    activity_drs='CMIP',
    institution_id='IPSL',
    table_id='Amon',
    latest='true',
    filter_server_url='aims3.llnl.gov',
)

files.head()

In [None]:
files['experiment_id'].unique()

In [None]:
def set_bnds_as_coords(ds):
    """Promote every data variable whose name contains 'bnds' or 'bounds' to a coordinate.

    Returns the dataset produced by ``ds.set_coords``; ``ds`` itself is not modified.
    """
    bounds_names = []
    for name in ds.data_vars:
        if any(tag in name for tag in ('bnds', 'bounds')):
            bounds_names.append(name)
    return ds.set_coords(bounds_names)

def fix_climatology_time(ds):
    """Rename each dimension whose variable has a 'climatology' attribute to '<dim>_climatology'.

    The list of dimension names is taken once up front; each check is made
    against the dataset as renamed so far. Returns the (possibly) renamed dataset.
    """
    original_dims = list(ds.dims)
    for name in original_dims:
        if 'climatology' not in ds[name].attrs:
            continue
        ds = ds.rename({name: name + '_climatology'})
    return ds

def set_coords(ds):
    """Keep ``ds.attrs['variable_id']`` as the only data variable, then fix climatology dims.

    Every other data variable is turned into a coordinate (each file is
    expected to carry one variable of interest), and the result is passed
    through ``fix_climatology_time``.
    """
    keep = ds.attrs['variable_id']
    demoted = set(ds.data_vars) - {keep}
    return fix_climatology_time(ds.set_coords(demoted))

def concat_timesteps(urls):
    """Open the files of one (variable, member) series as a single dataset along ``time``.

    Parameters
    ----------
    urls : iterable of str
        Paths/URLs of the time slices of one series.

    Returns
    -------
    xarray.Dataset
        Dataset opened with ``chunks={'time': 'auto'}``, passed through
        ``set_coords``, with a fresh ``history`` attribute recording how it
        was opened.

    Raises
    ------
    ValueError
        If ``urls`` is empty.
    """
    # Nested concatenation uses list order. Sorting assumes file names end in a
    # lexically sortable date range (e.g. ..._185001-194912.nc) -- TODO confirm
    # for non-CMIP6 sources.
    urls = sorted(urls)
    if not urls:
        raise ValueError("concat_timesteps needs at least one url")
    if len(urls) > 1:
        # current xarray only accepts concat_dim together with combine='nested'
        ds = xr.open_mfdataset(urls, combine='nested', concat_dim='time',
                               chunks={'time': 'auto'},
                               preprocess=set_coords)
        # start history fresh
        ds.attrs['history'] = (f"{datetime.now()} xarray.open_mfdataset({urls}, "
                               f"combine='nested', concat_dim='time')")
    else:
        ds = xr.open_dataset(urls[0], chunks={'time': 'auto'})
        ds.attrs['history'] = f"{datetime.now()} xarray.open_dataset('{urls[0]}')"
        ds = set_coords(ds)
    return ds

def dict_union(*dicts, merge_keys=('history', 'further_info_url'),
               drop_keys=('DODS_EXTRA.Unlimited_Dimension',)):
    """Combine attribute dicts, keeping values they agree on.

    For each key (in first-seen order):

    - keys in ``drop_keys`` are omitted;
    - a value of None counts as missing; a key missing everywhere is omitted;
    - a value present in only one dict, or equal in both, is kept;
    - differing values are joined with ``'\\n'`` if the key is in
      ``merge_keys`` (they must be strings), otherwise the key is dropped.

    More than two dicts are folded left to right with the same
    ``merge_keys``/``drop_keys``. The result has the type of the first dict.

    Returns
    -------
    dict
        The union; ``{}`` when called with no dicts.
    """
    from functools import reduce

    def _union_pair(d1, d2):
        d = type(d1)()
        # dict.fromkeys gives a deterministic, order-preserving key union
        for k in dict.fromkeys([*d1, *d2]):
            if k in drop_keys:
                continue
            v1 = d1.get(k)
            v2 = d2.get(k)
            if v1 is None and v2 is None:
                continue
            if v1 is None:
                d[k] = v2
            elif v2 is None:
                d[k] = v1
            elif v1 == v2:
                d[k] = v1
            elif k in merge_keys:
                d[k] = '\n'.join([v1, v2])
            # otherwise the values conflict and the key is dropped
        return d

    if not dicts:
        return {}
    if len(dicts) == 1:
        # same rules as a union with an empty dict (drop_keys / None removed)
        return _union_pair(dicts[0], {})
    return reduce(_union_pair, dicts)

def concat_ensembles(member_dsets, member_ids, join='outer'):
    """Stack per-member datasets along a new ``member_id`` dimension.

    Parameters
    ----------
    member_dsets : sequence of xarray.Dataset
        One dataset per ensemble member, all for the same variable.
    member_ids : sequence
        Member labels, in the same order as ``member_dsets``.
    join : str
        Join passed to ``xr.align`` first: ``'outer'`` gives the longest
        possible record, ``'inner'`` keeps only the overlapping segment.

    Returns
    -------
    xarray.Dataset
        Dataset with a ``member_id`` dimension (also when there is only one
        member) and attributes combined with ``dict_union``.

    Raises
    ------
    ValueError
        If no datasets are given or the two sequences differ in length.
    """
    member_dsets = list(member_dsets)
    member_ids = list(member_ids)
    if not member_dsets:
        raise ValueError("concat_ensembles needs at least one member dataset")
    if len(member_dsets) != len(member_ids):
        raise ValueError(
            f"got {len(member_dsets)} datasets but {len(member_ids)} member ids")

    if len(member_dsets) == 1:
        # still add member_id so single-member variables have the same layout
        # as multi-member ones (downstream merge and member_id chunking)
        return member_dsets[0].expand_dims({'member_id': member_ids})

    concat_dim = xr.DataArray(member_ids, dims='member_id', name='member_id')

    # merge attributes
    attrs = dict_union(*[ds.attrs for ds in member_dsets])

    # align first to deal with the fact that some ensemble members have different lengths
    # inner join keeps only overlapping segments of each ensemble
    # outer join gives us the longest possible record
    member_dsets_aligned = xr.align(*member_dsets, join=join)

    # keep only coordinates from first ensemble member to simplify merge
    first = member_dsets_aligned[0]
    rest = [mds.reset_coords(drop=True) for mds in member_dsets_aligned[1:]]

    ds = xr.concat([first] + rest, dim=concat_dim, coords='minimal')
    entry = f"{datetime.now()} xarray.concat(<ALL_MEMBERS>, dim='member_id', coords='minimal')"
    history = attrs.get('history')
    # avoid KeyError when no member carried a history attribute
    attrs['history'] = f"{history}\n{entry}" if history else entry
    ds.attrs = attrs
    return ds

def merge_vars(ds1, ds2):
    """Merge two datasets, combining their attributes with ``dict_union``.

    Designed to be folded over a list of datasets with ``functools.reduce``.
    Variables of ``ds2`` that are already non-dimension coordinates of
    ``ds1`` are dropped from ``ds2`` before merging.

    Returns
    -------
    xarray.Dataset
        The merged dataset with ``dict_union(ds1.attrs, ds2.attrs)`` as attrs.
    """
    # xr.merge knows nothing about dict_union's merge_keys/drop_keys rules,
    # so attributes are combined separately and re-attached afterwards
    attrs = dict_union(ds1.attrs, ds2.attrs)

    # non dimension coords
    ds1_ndcoords = set(ds1.coords) - set(ds1.dims)

    # edge case for variable 'ps', which is a coordinate in some datasets
    # and a data_var in its own dataset
    ds2_dropvars = sorted(set(ds2.variables) & ds1_ndcoords, key=str)
    # Dataset.drop is deprecated in favour of drop_vars (xarray >= 0.14);
    # fall back to drop on older installs
    drop_vars = getattr(ds2, 'drop_vars', None) or ds2.drop
    ds2_drop = drop_vars(ds2_dropvars)

    # pass the long-standing defaults explicitly so a change of xr.merge
    # defaults in newer xarray cannot silently alter the result
    ds = xr.merge([ds1, ds2_drop], compat='no_conflicts', join='outer')
    ds.attrs = attrs
    return ds


from functools import reduce
def merge_recursive(dsets):
    dsm = reduce(merge_vars, dsets)
    dsm.attrs['history'] += f"\n{datetime.now()} xarray.merge(<ALL_VARIABLES>)"
    # fix further_info_url
    fi_urls = set(dsm.attrs['further_info_url'].split('\n'))
    dsm.attrs['further_info_url'] = '\n'.join(fi_urls)
    return dsm

In [None]:
# Build one merged, rechunked dataset per (institution, source, experiment, table, grid).
dataset_fields = ['institution_id', 'source_id', 'experiment_id', 'table_id', 'grid_label']
all_dsets = {}
dset_groups = files.groupby(dataset_fields)
for dset_keys, dset_files in tqdm(dset_groups, total=dset_groups.ngroups):
    dset_id = '.'.join(dset_keys)
    print(dset_id)
    all_vars = []
    for var_id, var_files in dset_files.groupby('variable_id'):
        print('-', var_id)
        member_dsets = []
        member_ids = []
        for m_id, m_files in var_files.groupby('member_id'):
            print('  -', m_id, len(m_files))
            member_ids.append(m_id)
            member_dsets.append(concat_timesteps(m_files.OPENDAP_url))
        dset = concat_ensembles(member_dsets, member_ids)
        all_vars.append(dset)
    ds_merged = merge_recursive(all_vars)
    # rechunk the merged dataset of this iteration; only request dims that
    # exist, since e.g. climatology datasets have 'time' renamed
    wanted_chunks = {'member_id': 1, 'time': 'auto'}
    chunks = {dim: size for dim, size in wanted_chunks.items() if dim in ds_merged.dims}
    ds_rechunk = ds_merged.chunk(chunks)
    all_dsets[dset_id] = ds_rechunk