In [1]:
import os
import itertools
import logging
logging.basicConfig()
import json
import tempfile
from copy import deepcopy
from collections import defaultdict
from functools import partial

import xarray as xr
import numpy as np
import matplotlib.pyplot as plt

import ipywidgets as widgets
from ipywidgets import GridspecLayout, Layout, Label, VBox, HBox, Box, Output, Button, SelectMultiple, Checkbox, IntText, Textarea

In [2]:
# Show INFO-level messages from the cicliminds_lib package (basicConfig above installs the handler).
logging.getLogger("cicliminds_lib").setLevel(level=logging.INFO)
from cicliminds_lib.query.files import get_datasets
from cicliminds_lib.query.datasets import get_list_of_files
from cicliminds_lib.query.models import list_model_configurations

from cicliminds_lib.bindings import remove_grid_from_data

from cicliminds_lib.masks.masks import get_land_mask
from cicliminds_lib.masks.masks import get_antarctica_mask
from cicliminds_lib.masks.masks import iter_reference_region_masks
from cicliminds_lib.masks.loaders import load_reference_regions_meta

from cicliminds_lib.plotting.plot_recipes import plot_means_of_hists
from cicliminds_lib.plotting.plot_recipes import plot_means_of_hists_diff
from cicliminds_lib.plotting.plot_recipes import plot_hists_of_means
from cicliminds_lib.plotting.plot_recipes import plot_hists_of_means_diff
from cicliminds_lib.plotting.plot_recipes import plot_hist_of_timeavgs
from cicliminds_lib.plotting.plot_recipes import plot_hist_of_timeavgs_diff

In [3]:
from cicliminds.widgets.filter import FilterWidget

In [4]:
# Notebook-wide matplotlib style: 12x8 figures with constrained layout,
# inward-facing ticks, larger fonts/lines, and white non-transparent savefigs.
plt.rcParams.update({
    "figure.figsize": (12, 8),
    "figure.constrained_layout.use": True,
    "xtick.direction": "in",
    "xtick.major.size": 8,
    "xtick.major.width": 1.6,
    "xtick.minor.width": 0.8,
    "xtick.minor.size": 4,
    "ytick.direction": "in",
    "ytick.major.width": 1.6,
    "ytick.minor.width": 0.8,
    "ytick.major.size": 8,
    "ytick.minor.size": 4,
    "font.size": 16,
    "lines.linewidth": 3,
    "lines.markersize": 5,
    "savefig.dpi": 300 / 2.4,
    "savefig.transparent": False,
    "savefig.facecolor": "white",
})

In [5]:
# Root directory of the Climdex NetCDF files. It is taken from the environment
# so the notebook itself carries no machine-specific path.
if "DATA_DIR" not in os.environ:
    raise KeyError(
        "DATA_DIR environment variable is not set; point it at the Climdex data "
        "directory before starting the kernel"
    )
DATA_DIR = os.environ["DATA_DIR"]
display(DATA_DIR)

'/home/viktoana/projects/cicero/data/Climdex_base1981-2010'

In [6]:
# Index of the NetCDF files under DATA_DIR: one row per file path, with the
# variable, frequency, model, scenario, init_params and timespan parsed from it.
DATASETS = get_datasets(DATA_DIR)

In [7]:
# Inspect the file index (pandas elides the middle rows when rendering it).
DATASETS

Unnamed: 0_level_0,variable,frequency,model,scenario,init_params,timespan
path,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
/home/viktoana/projects/cicero/data/Climdex_base1981-2010/altcddETCCDI_yr_CNRM-ESM2-1_historical_r1i1p1f2_1850-2014.nc,altcddETCCDI,yr,CNRM-ESM2-1,historical,r1i1p1f2,1850-2014
/home/viktoana/projects/cicero/data/Climdex_base1981-2010/altcddETCCDI_yr_ACCESS-CM2_historical_r1i1p1f1_1850-2014.nc,altcddETCCDI,yr,ACCESS-CM2,historical,r1i1p1f1,1850-2014
/home/viktoana/projects/cicero/data/Climdex_base1981-2010/altcddETCCDI_yr_CNRM-ESM2-1_ssp126_r1i1p1f2_2015-2100.nc,altcddETCCDI,yr,CNRM-ESM2-1,ssp126,r1i1p1f2,2015-2100
/home/viktoana/projects/cicero/data/Climdex_base1981-2010/altcddETCCDI_yr_ACCESS-CM2_ssp126_r1i1p1f1_2015-2100.nc,altcddETCCDI,yr,ACCESS-CM2,ssp126,r1i1p1f1,2015-2100
/home/viktoana/projects/cicero/data/Climdex_base1981-2010/altcddETCCDI_yr_CNRM-ESM2-1_ssp245_r1i1p1f2_2015-2100.nc,altcddETCCDI,yr,CNRM-ESM2-1,ssp245,r1i1p1f2,2015-2100
...,...,...,...,...,...,...
/home/viktoana/projects/cicero/data/Climdex_base1981-2010/wsdiETCCDI_yr_UKESM1-0-LL_historical_r1i1p1f2_1850-2014.nc,wsdiETCCDI,yr,UKESM1-0-LL,historical,r1i1p1f2,1850-2014
/home/viktoana/projects/cicero/data/Climdex_base1981-2010/wsdiETCCDI_yr_UKESM1-0-LL_ssp126_r1i1p1f2_2015-2100.nc,wsdiETCCDI,yr,UKESM1-0-LL,ssp126,r1i1p1f2,2015-2100
/home/viktoana/projects/cicero/data/Climdex_base1981-2010/wsdiETCCDI_yr_UKESM1-0-LL_ssp245_r1i1p1f2_2015-2100.nc,wsdiETCCDI,yr,UKESM1-0-LL,ssp245,r1i1p1f2,2015-2100
/home/viktoana/projects/cicero/data/Climdex_base1981-2010/wsdiETCCDI_yr_UKESM1-0-LL_ssp370_r1i1p1f2_2015-2100.nc,wsdiETCCDI,yr,UKESM1-0-LL,ssp370,r1i1p1f2,2015-2100


In [8]:
# "<LAB> :: <NAME>" labels for every reference region, used as the options of
# the region selector; the part before ":" is later parsed back as the region id.
region_names = [
    f"{lab} :: {name}"
    for lab, name in load_reference_regions_meta()[["LAB", "NAME"]].itertuples(index=False, name=None)
]

In [17]:
def update_filters(filters_widget, change, max_listed=200, max_rows=20):
    """Refresh the "Filtered configurations" list after the filter changed.

    Registered via ``filter_widget.observe``. Asks the filter widget for the
    rows of the dataset index that match the current selection, lets the
    widget update its own state from that subset, then fills the module-level
    ``filtered_configurations`` SelectMultiple. Each matching row becomes one
    comma-joined entry holding its ``FILTER_FIELDS`` values in order.

    Parameters
    ----------
    filters_widget : FilterWidget
        The filter widget that fired the notification.
    change : object
        The change notification; unused.
    max_listed : int, optional
        If more rows than this match, the list is cleared instead of filled.
    max_rows : int, optional
        Maximum visible height of the list, in rows.
    """
    filtered_dataset = filters_widget.get_filtered_dataset()
    filters_widget.update_state_from_dataset(filtered_dataset)
    if filtered_dataset.shape[0] > max_listed:
        # Clear rather than return early. Leaving the previous entries would let
        # "Stage" (which falls back to every listed option when nothing is
        # selected) stage configurations that no longer match the filter.
        options = []
    else:
        fields = list(filters_widget.FILTER_FIELDS)
        options = [
            ",".join(map(str, values))
            for values in filtered_dataset[fields].itertuples(index=False, name=None)
        ]
    rows = min(max_rows, len(options))
    filtered_configurations.options = options
    filtered_configurations.rows = rows
    # Explicitly re-send both traits, as the original implementation did, to
    # make sure the front-end picks up the new list.
    filtered_configurations.notify_change({"type": "change", "name": "rows", "new": rows})
    filtered_configurations.notify_change({"type": "change", "name": "options", "new": options})

In [32]:
def _join_time_lines(runs, join_years):
    """Yield staged queries for one group of runs of the same model configuration.

    Every yielded query has ``scenario`` and ``timespan`` turned into lists.
    If ``join_years`` is true and the group has both a historical run and
    scenario runs, the first historical run is prepended to each scenario run
    (``["historical", <scenario>]``). Any further historical runs are yielded
    on their own. Otherwise every run is yielded on its own.
    """
    def as_time_line(run):
        line = deepcopy(run)
        line["scenario"] = [run["scenario"]]
        line["timespan"] = [run["timespan"]]
        return line

    historical = [run for run in runs if run["scenario"] == "historical"]
    scenarios = [run for run in runs if run["scenario"] != "historical"]
    if not (join_years and historical and scenarios):
        for run in historical + scenarios:
            yield as_time_line(run)
        return

    head = as_time_line(historical[0])
    for run in historical[1:]:
        yield as_time_line(run)
    for run in scenarios:
        joined = deepcopy(head)
        joined["scenario"].append(run["scenario"])
        joined["timespan"].append(run["timespan"])
        yield joined


def aggregate_models(models):
    """Expand selected model configurations into staged plot queries.

    Reads the current values of the staging widgets (module globals
    ``plot_types``, ``select_regions``, ``aggregate_regions``,
    ``aggregate_years``, ``subtract_reference``, ``reference_window_size``,
    ``sliding_window_size``, ``slide_step``) and
    ``filter_widget.FILTER_FIELDS``.

    Parameters
    ----------
    models : iterable of str
        Comma-separated configuration strings as listed in
        ``filtered_configurations``; values are in ``FILTER_FIELDS`` order.

    Yields
    ------
    dict
        One query per plot type x region group x model time line. A query
        holds the window settings, ``binsize``/``bincount``, ``plot_type``,
        ``regions`` (region labels parsed from the selector, possibly empty),
        every ``FILTER_FIELDS`` value, and ``scenario``/``timespan`` as lists.
        With "years" aggregation on, a historical run is joined with each
        scenario run of the same configuration; otherwise (or when there is no
        historical run) each run is yielded separately.
    """
    # The models are iterated once per (plot type, region group).
    models = list(models)
    fields = list(filter_widget.FILTER_FIELDS)
    # Fields identifying one model run independently of its time slice; runs
    # sharing these values can be concatenated along time.
    run_fields = [field for field in fields if field not in ("scenario", "timespan")]

    base = {
        "reference_window_size": reference_window_size.value,
        "sliding_window_size": sliding_window_size.value,
        "slide_step": slide_step.value,
        "subtract_reference": subtract_reference.value,
        "binsize": None,
        "bincount": 10,
    }

    region_ids = [name.split(":")[0].strip() for name in (select_regions.value or [])]
    if aggregate_regions.value or not region_ids:
        region_groups = [region_ids]
    else:
        region_groups = [[region] for region in region_ids]

    for plot_type in plot_types.value:
        for regions in region_groups:
            query = dict(base, plot_type=plot_type, regions=list(regions))
            # Group runs separately for each (plot type, region group), so
            # queries for different plot types / regions never get mixed into
            # one time line.
            groups = defaultdict(list)
            for model_str in models:
                run = deepcopy(query)
                run.update(zip(fields, model_str.strip().split(",")))
                groups[tuple(run[field] for field in run_fields)].append(run)
            for runs in groups.values():
                yield from _join_time_lines(runs, aggregate_years.value)

In [33]:
from cicliminds_lib.bindings import cdo_cat
from cicliminds_lib.bindings import remove_grid


def _get_dataset(query, tmpfile):
    common_mask = (DATASETS["model"] == query["model"]) \
               & (DATASETS["init_params"] == query["init_params"]) \
               & (DATASETS["frequency"] == query["frequency"]) \
               & (DATASETS["variable"] == query["variable"]) \
               & (DATASETS["scenario"].isin(query["scenario"])) \
               & (DATASETS["timespan"].isin(query["timespan"]))
    files = DATASETS[common_mask].sort_values(by=["scenario"], key=lambda x: x.apply(lambda y: query["scenario"].index(y)))
    cdo_cat(tmpfile, files.index.values)


def _mask_regions(data, regions):
    mask = get_land_mask(data)
    
    if not regions:
        mask = mask & (~get_antarctica_mask(data))
        return data.where(mask)
    
    all_reg_mask = np.any([reg_mask for _, reg_mask in iter_reference_region_masks(data, regions)])
    mask = mask & all_reg_mask
    return data.where(mask)    


# Plot functions for each "Plot type" option. Index 0 is used when no
# reference is subtracted, index 1 (the *_diff variant) when
# "Subtract reference" is checked (see plot_by_query).
PLOT_FUNCS = dict([
    ("fldmean first", [plot_hists_of_means, plot_hists_of_means_diff]),
    ("fldmean last", [plot_means_of_hists, plot_means_of_hists_diff]),
    ("avg time", [plot_hist_of_timeavgs, plot_hist_of_timeavgs_diff]),
])


def plot_by_query(query, ax):
    """Load, mask and draw the data described by a staged query onto ``ax``.

    The matching files are concatenated into one temporary file and passed
    through ``remove_grid`` into a second temporary file. That file is loaded
    fully into memory before both temporaries are deleted. The dataset is then
    masked to ``query["regions"]``, and ``query["variable"]`` is drawn with the
    function chosen by ``query["plot_type"]`` and ``query["subtract_reference"]``.
    """
    with tempfile.NamedTemporaryFile("r") as concatenated:
        with tempfile.NamedTemporaryFile("r") as gridless:
            _get_dataset(query, concatenated.name)
            remove_grid(gridless.name, concatenated.name)
            # load_dataset reads everything eagerly, so the file may vanish afterwards.
            dataset = xr.load_dataset(gridless.name)
    masked = _mask_regions(dataset, query["regions"])
    variants = PLOT_FUNCS[query["plot_type"]]
    draw = variants[int(query["subtract_reference"])]
    draw(ax, masked[query["variable"]])

In [34]:
def rebuild_one_block_action(block, change):
    """Re-plot one staged block from the JSON in its config textarea.

    Bound with ``functools.partial`` to the block's "Rebuild" button; ``change``
    is the clicked button and is unused. ``block`` is the VBox built by
    ``staged_block``: ``children[0].children[0]`` is the Textarea holding the
    query JSON, and ``children[1]`` is the Output the figure is shown in.
    """
    cfg = json.loads(block.children[0].children[0].value)
    output = block.children[1]
    fig, ax = plt.subplots()
    try:
        plot_by_query(cfg, ax)
    finally:
        # Close this exact figure even when plotting fails, so pyplot does not
        # accumulate open figures across rebuilds.
        plt.close(fig)
    with output:
        output.clear_output()
        display(fig)

In [35]:
def staged_block(model_config):
    """Build the widget group for one staged query and return it.

    Layout: an editable Textarea with the query as JSON, next to a column with
    "Unstage" and "Rebuild" buttons, and below them an Output area for the
    plot. Rebuild calls ``rebuild_one_block_action`` with this block bound;
    Unstage calls ``unstage_one_block_action``.
    """
    editor = Textarea(
        value=json.dumps(model_config, indent=True),
        layout={"flex": "6 1 260px", "height": "10em", "overflow": "hidden", "margin": "0 20px 0 0"},
    )
    unstage_btn = Button(description="Unstage", button_style="danger", icon="trash")
    rebuild_btn = Button(description="Rebuild", button_style="success", icon="redo")
    plot_area = Output(layout={"flex": "1 1 0px"})

    button_column = VBox([unstage_btn, rebuild_btn], layout={"flex": "1 1 100px", "margin": "0 20px 0 0"})
    header = HBox([editor, button_column])
    group = VBox([header, plot_area], layout={"margin": "5px 0"})

    rebuild_btn.on_click(partial(rebuild_one_block_action, group))
    unstage_btn.on_click(unstage_one_block_action)
    return group

In [36]:
def stage_action(change):
    """Stage one block per query built from the chosen configurations.

    Uses the selected entries of ``filtered_configurations``, or every listed
    entry when nothing is selected. New blocks are placed above the blocks
    that are already staged. ``change`` (the clicked button) is unused.
    """
    chosen = filtered_configurations.value or filtered_configurations.options
    fresh_blocks = tuple(staged_block(query) for query in aggregate_models(chosen))
    staged_list.children = fresh_blocks + staged_list.children

In [37]:
def rebuild_all_action(change):
    """Press the "Rebuild" button of every staged block, top to bottom."""
    for staged in staged_list.children:
        # staged.children[0] is the header row; its second child holds [Unstage, Rebuild].
        button_column = staged.children[0].children[1]
        button_column.children[1].click()
        
def unstage_all_action(change):
    """Remove every staged block; ``change`` (the clicked button) is unused."""
    staged_list.children = ()
    
def unstage_one_block_action(change):
    """Remove the staged block whose "Unstage" button was clicked.

    ``change`` is the Button passed by ``on_click``. It is matched against
    ``block.children[0].children[1].children[0]`` of each staged block. If no
    staged block owns the button (empty list, or a repeated click on a block
    that was already removed), nothing is removed.
    """
    children = staged_list.children
    for i, block in enumerate(children):
        if block.children[0].children[1].children[0] is change:
            staged_list.children = children[:i] + children[i + 1:]
            return

In [38]:
# Assemble the application, top to bottom: configuration filter, list of
# matching configurations, staging options, and staged plots.


def _column(*children):
    """Return a VBox styled as one column of the staging panel."""
    return VBox(list(children), layout={"flex": "1 1 100px", "width": "auto", "margin": "0 20px 0 0"})


# --- Configuration filter: update_filters runs on every filter change.
filter_widget = FilterWidget(DATASETS)
filter_widget.observe(update_filters)

# --- Configurations matching the filter (filled in by update_filters).
filtered_configurations = SelectMultiple(disabled=False, layout={"width": "auto"})
filtered_widget = VBox([Label("Filtered configurations:"), filtered_configurations])

# --- Staging options, read by aggregate_models when "Stage" is pressed.
select_regions = SelectMultiple(options=region_names, value=region_names[:1], rows=10, layout={"width": "auto"})
aggregate_years = Checkbox(description="years", indent=False, value=True, layout={"width": "auto"})
aggregate_regions = Checkbox(description="regions", indent=False, layout={"width": "auto"})
plot_types = SelectMultiple(
    options=["fldmean first", "fldmean last", "avg time"],
    value=("fldmean last",),
    rows=6,
    layout={"width": "auto"},
)
subtract_reference = Checkbox(description="Subtract reference", indent=False, layout={"width": "auto"})
reference_window_size = IntText(value=50, layout={"width": "auto"})
sliding_window_size = IntText(value=20, layout={"width": "auto"})
slide_step = IntText(value=10, layout={"width": "auto"})
button_stage = Button(description="Stage", button_style="success", icon="plus")
button_stage.on_click(stage_action)

staging_panel = HBox([
    _column(Label("Regions"), select_regions),
    _column(Label("Aggregate"), aggregate_years, aggregate_regions),
    _column(Label("Plot type"), plot_types, subtract_reference),
    _column(
        Label("Reference window size"), reference_window_size,
        Label("Sliding window size"), sliding_window_size,
        Label("Slide step"), slide_step,
    ),
    _column(Label(), button_stage),
])

# --- Staged plots and their bulk controls.
button_rebuild_all = Button(description="Rebuild all", icon="redo", button_style="success")
button_rebuild_all.on_click(rebuild_all_action)
button_unstage_all = Button(description="Unstage all", icon="trash", button_style="danger")
button_unstage_all.on_click(unstage_all_action)
staged_controls = HBox([button_rebuild_all, button_unstage_all])
staged_list = VBox()
staged_widget = VBox([Label("Staged for plotting:"), staged_controls, staged_list])

app = VBox([filter_widget.render(), filtered_widget, staging_panel, staged_widget])

output = Output()
display(output, app)
# Reset the filter once at start-up; presumably this fires the observers so the
# filtered list gets populated (TODO confirm against FilterWidget).
filter_widget.button_reset.click()

Output()

VBox(children=(VBox(children=(Label(value='Configuration filter:'), HBox(children=(VBox(children=(Label(value=…