In [21]:
import os
import sys
import signal
MDTF_ROOT = '/Users/tsj/Documents/climate/MDTF-diagnostics'
sys.path.append(MDTF_ROOT)
sys.argv = ['mdtf']

import shlex
from src import core, diagnostic, cli, util, data_sources, preprocessor

import dagster as dg

In [None]:
class dgLocalFileDataSource(data_sources.SampleLocalFileDataSource):
    def preprocess_data(self):
        """Hook to run the preprocessing function on all variables.
        """
        update = True
        # really a while-loop, but limit # of iterations to be safe
        for _ in range(5):
            if update:
                # fetch alternates for any vars that failed since last time
                self.fetch_data()
                update = False
            else:
                # preproc_harness starts here
                break
        else:
            # only hit this if we don't break
            raise util.DataRequestError(
                f"Too many iterations in preprocess_data() for {self.full_name}."
            )

    def request_data(self):
        """Top-level method to iteratively query, fetch and preprocess all data
        requested by PODs, switching to alternate requested data as needed.
        """
        # Call cleanup method if we're killed
        signal.signal(signal.SIGTERM, self.query_and_fetch_cleanup)
        signal.signal(signal.SIGINT, self.query_and_fetch_cleanup)
        self.pre_query_and_fetch_hook()
        try:
            self.preprocess_data()
        except Exception as exc:
            self.log.exception("%s at DataSource level: %r.",
                util.exc_descriptor(exc), exc)
        # clean up regardless of success/fail
        self.post_query_and_fetch_hook()
        for p in self.iter_children():
            for v in p.iter_children():
                if v.status == core.ObjectStatus.ACTIVE:
                    v.log.debug('Data request for %s completed succesfully.',
                        v.full_name)
                    # v.status = core.ObjectStatus.SUCCEEDED
                elif v.failed:
                    v.log.debug('Data request for %s failed.', v.full_name)
                else:
                    v.log.debug('Data request for %s not used.', v.full_name)
            if p.failed:
                p.log.debug('Data request for %s failed.', p.full_name)
            else:
                p.log.debug('Data request for %s completed succesfully.',
                    p.full_name)

In [22]:
@dg.solid(
    output_defs=[dg.OutputDefinition(name="case")]
)
def init_harness(context):
    os.chdir(MDTF_ROOT)
    cli_obj = cli.MDTFTopLevelArgParser(MDTF_ROOT, argv=context.solid_config['cli_args'])
    fmwk = cli_obj.dispatch()
    fmwk.DataSource = dgLocalFileDataSource

    # unroll framework.main()
    fmwk.cases = dict(list(fmwk.cases.items())[0:1])
    new_d = dict()
    for case_name, case_d in fmwk.cases.items():
        context.log.info(f"### {fmwk.full_name}: initializing case '{case_name}'.")
        case = fmwk.DataSource(case_d, parent=fmwk)
        case.setup()
        new_d[case_name] = case
    fmwk.cases = new_d
    # util.transfer_log_cache(close=True)
    for case_name, case in fmwk.cases.items():
        if not case.failed:
            context.log.info(f"### {fmwk.full_name}: requesting data for case '{case_name}'.")
            case.request_data()
            yield dg.Output(case, output_name='case')

@dg.solid(
    input_defs=[dg.InputDefinition(name="case")],
    output_defs=[dg.DynamicOutputDefinition(name="case_var")]
)
def query_fanout(context, case):
    vars_to_query = [
        v for v in case.iter_vars_only(active=True) \
            if v.stage < diagnostic.VarlistEntryStage.QUERIED
    ]
    context.log.debug(f"Query batch: [{', '.join(v.full_name for v in vars_to_query)}].")
    case.pre_query_hook(vars_to_query)
    for v in vars_to_query:
        yield dg.DynamicOutput(value=(case, v), mapping_key=v.name, output_name="case_var")

@dg.solid(
    input_defs=[dg.InputDefinition(name="case_v")],
    output_defs=[dg.OutputDefinition(name="v")]
)
def query_data(context, case_v):
    case, v = case_v
    try:
        context.log.info(f"Querying {v.translation}.")
        case.query_dataset(v) # sets v.data
        if not v.data:
            raise util.DataQueryEvent("No data found.", v)
        v.stage = diagnostic.VarlistEntryStage.QUERIED
    except util.DataQueryEvent as exc:
        v.deactivate(exc)
    except Exception as exc:
        chained_exc = util.chain_exc(exc,
            f"querying {v.translation} for {v.full_name}.",
            util.DataQueryEvent)
        v.deactivate(chained_exc)
    return v

@dg.solid(
    input_defs=[dg.InputDefinition(name="case"), dg.InputDefinition(name="vars_list")],
    output_defs=[dg.DynamicOutputDefinition(name="case_var")]
)
def fanin_select_fanout(context, case, vars_list):
    case.post_query_hook(vars_list)
    case.set_experiment()

    vars_to_fetch = [
        v for v in case.iter_vars_only(active=True) \
            if v.stage < diagnostic.VarlistEntryStage.FETCHED
    ]
    context.log.debug(f"Fetch batch: [{', '.join(v.full_name for v in vars_to_fetch)}].")
    case.pre_fetch_hook(vars_to_fetch)
    for v in vars_to_fetch:
        yield dg.DynamicOutput(value=(case, v), mapping_key=v.name, output_name="case_var")

@dg.solid(
    input_defs=[dg.InputDefinition(name="case_v")],
    output_defs=[dg.OutputDefinition(name="v")]
)
def fetch_data(context, case_v):
    case, v = case_v
    try:
        context.log.info(f"Fetching {str(v)}.")
        # fetch on a per-DataKey basis
        for d_key in v.iter_data_keys(status=core.ObjectStatus.ACTIVE):
            try:
                if not case.is_fetch_necessary(d_key):
                    continue
                v.log.debug(f"Fetching {str(d_key)}.")
                case.fetch_dataset(v, d_key)
            except Exception as exc:
                d_key.deactivate(exc)
                break # no point continuing
        # check if var received everything
        for d_key in v.iter_data_keys(status=core.ObjectStatus.ACTIVE):
            if not d_key.local_data:
                raise util.DataFetchEvent("Fetch failed.", d_key)
        v.stage = diagnostic.VarlistEntryStage.FETCHED
    except Exception as exc:
        chained_exc = util.chain_exc(exc,
            f"fetching data for {v.full_name}.",
            util.DataFetchEvent)
        v.deactivate(chained_exc)
    return v

@dg.solid(
    input_defs=[dg.InputDefinition(name="case"), dg.InputDefinition(name="vars_list")],
    output_defs=[dg.DynamicOutputDefinition(name="pod_var")]
)
def fanin_preproc_setup_fanout(context, case, vars_list):
    case.post_fetch_hook(vars_list)
    for pod in case.iter_children(status=core.ObjectStatus.ACTIVE):
        pod.preprocessor.setup(case, pod)

    for pv in case.iter_vars(active=True):
        if pv.var.stage < diagnostic.VarlistEntryStage.PREPROCESSED:
            yield dg.DynamicOutput(value=(pv.pod, pv.var), output_name="pod_var",
                mapping_key=f"{pv.pod.name}_{pv.var.name}")


@dg.solid(
    input_defs=[dg.InputDefinition(name="pod_var")],
    output_defs=[dg.OutputDefinition(name="v")]
)
def fetch_data(context, pod_var):
    pod, var = pod_var
    try:
        pvar.log.info("Preprocessing %s.", pv.var)
        pod.preprocessor.process(pv.var)
        var.stage = diagnostic.VarlistEntryStage.PREPROCESSED
    except Exception as exc:
        update = True
        self.log.exception("%s while preprocessing %s: %r",
            util.exc_descriptor(exc), pv.var.full_name, exc)
        for d_key in var.iter_data_keys(status=core.ObjectStatus.ACTIVE):
            var.deactivate_data_key(d_key, exc)







@dg.solid(
    input_defs=[dg.InputDefinition(name="case")],
    output_defs=[dg.DynamicOutputDefinition(name="pv")]
)
def mdtf_preproc_harness(context, case):
    for pv in case.iter_vars():
        print(pv.pod.name, pv.var.name, pv.var.status, pv.var.stage) #, pv.var.format_log())
    vars_to_process = [
        pv for pv in case.iter_vars(active=True) \
            if pv.var.stage < diagnostic.VarlistEntryStage.PREPROCESSED
    ]
    # if not vars_to_process:
    #     break # exit: processed everything or nothing active
    print(len(vars_to_process))
    for pod in case.iter_children(status=core.ObjectStatus.ACTIVE):
        print('XX', pod.name)
        pod.preprocessor.setup(case, pod)
    for pv in vars_to_process:
        try:
            print("Preprocessing {pv.var}")
            pv.var.stage = diagnostic.VarlistEntryStage.PREPROCESSED
            # pv.pod.preprocessor.process(pv.var)
            yield dg.DynamicOutput(pv, mapping_key=f"podvar{pv}")
        except Exception as exc:
            update = True
            context.log.exception(f"{util.exc_descriptor(exc)} while preprocessing {pv.var.full_name}")
            for d_key in pv.var.iter_data_keys(status=core.ObjectStatus.ACTIVE):
                pv.var.deactivate_data_key(d_key, exc)
            continue

@dg.solid(
    input_defs=[dg.InputDefinition(name="pv")],
    output_defs=[dg.OutputDefinition(name="var"), dg.OutputDefinition(name="preproc")]
)
def get_attrs(context, pv):
    yield dg.Output(pv.var, output_name="var")
    yield dg.Output(pv.pod.preprocessor, output_name="preproc")

@dg.solid(
    input_defs=[dg.InputDefinition(name="preproc"), dg.InputDefinition(name="var")],
    output_defs=[dg.OutputDefinition(name="ds")]
)
def load_ds(context, preproc, var):
    return preproc.load_ds(var)

@dg.solid(
    input_defs=[dg.InputDefinition(name="var"), dg.InputDefinition(name="ds")]
)
def crop_daterange(context, var, ds):
    func = preprocessor.CropDateRangeFunction(_,_)
    return func.process(var, ds)

@dg.solid(
    input_defs=[dg.InputDefinition(name="var"), dg.InputDefinition(name="ds")]
)
def precip_rate_to_flux(context, var, ds):
    func = preprocessor.PrecipRateToFluxFunction(_,_)
    return func.process(var, ds)

@dg.solid(
    input_defs=[dg.InputDefinition(name="var"), dg.InputDefinition(name="ds")]
)
def convert_units(context, var, ds):
    func = preprocessor.ConvertUnitsFunction(_,_)
    return func.process(var, ds)

@dg.solid(
    input_defs=[dg.InputDefinition(name="var"), dg.InputDefinition(name="ds")]
)
def extract_level(context, var, ds):
    func = preprocessor.ExtractLevelFunction(_,_)
    return func.process(var, ds)

@dg.solid(
    input_defs=[dg.InputDefinition(name="var"), dg.InputDefinition(name="ds")]
)
def rename_variables(context, var, ds):
    func = preprocessor.RenameVariablesFunction(_,_)
    return func.process(var, ds)

@dg.solid(
    input_defs=[
        dg.InputDefinition(name="preproc"),
        dg.InputDefinition(name="var"),
        dg.InputDefinition(name="ds")
    ]
)
def write_ds(context, preproc, var, ds):
    preproc.write_ds(var, ds)

@dg.composite_solid(
    input_defs=[dg.InputDefinition(name="pv")]
)
def preprocess_one_file(pv):
    var, preproc = get_attrs(pv)
    ds = load_ds(preproc, var)
    ds = crop_daterange(var, ds)
    ds = precip_rate_to_flux(var, ds)
    ds = convert_units(var, ds)
    ds = extract_level(var, ds)
    ds = rename_variables(var, ds)
    write_ds(preproc, var, ds)


In [20]:
@dg.pipeline
def test_preproc():
    mdtf_preproc_harness(mdtf_test_harness()).map(preprocess_one_file)

run_config = {
    "solids": {
        "mdtf_test_harness": {
            "config": {
                "cli_args": '-v -f="temp_test.jsonc" -p Wheeler_Kiladis',
            }
        }
    }
}
dg.execute_pipeline(test_preproc, run_config=run_config)

2021-08-24 21:31:35 - dagster - DEBUG - test_preproc - d7b86c17-0853-4dd0-81ff-65b0638e2f23 - 11992 - PIPELINE_START - Started execution of pipeline "test_preproc".
2021-08-24 21:31:35 - dagster - DEBUG - test_preproc - d7b86c17-0853-4dd0-81ff-65b0638e2f23 - 11992 - ENGINE_EVENT - Executing steps in process (pid: 11992)
2021-08-24 21:31:35 - dagster - DEBUG - test_preproc - d7b86c17-0853-4dd0-81ff-65b0638e2f23 - 11992 - ENGINE_EVENT - Starting initialization of resources [io_manager].
2021-08-24 21:31:35 - dagster - DEBUG - test_preproc - d7b86c17-0853-4dd0-81ff-65b0638e2f23 - 11992 - ENGINE_EVENT - Finished initialization of resources [io_manager].
2021-08-24 21:31:35 - dagster - DEBUG - test_preproc - d7b86c17-0853-4dd0-81ff-65b0638e2f23 - 11992 - mdtf_test_harness - LOGS_CAPTURED - Started capturing logs for solid: mdtf_test_harness.
2021-08-24 21:31:35 - dagster - DEBUG - test_preproc - d7b86c17-0853-4dd0-81ff-65b0638e2f23 - 11992 - mdtf_test_harness - STEP_START - Started executio

Wheeler_Kiladis rlut succeeded fetched
Wheeler_Kiladis pr failed inited
Wheeler_Kiladis pr succeeded fetched
Wheeler_Kiladis omega500 succeeded fetched
Wheeler_Kiladis omega500 inactive inited
Wheeler_Kiladis u200 succeeded fetched
Wheeler_Kiladis u200 inactive inited
Wheeler_Kiladis u850 succeeded fetched
Wheeler_Kiladis u850 inactive inited
0
XX Wheeler_Kiladis


<dagster.core.execution.results.PipelineExecutionResult at 0x7f9f70b35b80>

In [4]:
%tb

SystemExit: 1

In [3]:
os.chdir(MDTF_ROOT)
util.logs.initial_log_config()
cli_obj = cli.MDTFTopLevelArgParser(MDTF_ROOT)
fmwk = cli_obj.dispatch('-v -f="temp_test.jsonc" -p Wheeler_Kiladis')
#print(fmwk)
#fmwk.DataSource = dgLocalFileDataSource

In [4]:
fmwk = cli_obj.dispatch(args='-v -f="temp_test.jsonc" -p Wheeler_Kiladis')

x1
x2
x3
x4 data_manager
x4 environment_manager
x4 runtime_manager
x4 output_manager
x5
x1
x2 <class 'src.core.MDTFFramework'> (ArgParserHarness(prog='mdtf', usage='mdtf run [options] [CASE_ROOT_DIR]', description='Runs diagnostics in the NOAA Model Diagnostics Task Force (MDTF) package. See\ndocumentation at https://mdtf-diagnostics.rtfd.io.\n\nThe scripts runs specified diagnostics on data at CASE_ROOT_DIR, using\nconfiguration set on the command line and/or through a file passed via\n--input_file.', formatter_class=<class 'src.cli.CustomHelpFormatter'>, conflict_handler='error', add_help=True),) {}


SystemExit: 1

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)

