In [None]:
%load_ext autoreload
%autoreload 2
from multiprocessing import get_context
import os
import time
from pathlib import Path
import shutil
import numpy as np
import thuner.data as data
import thuner.data.dispatch as dispatch
import thuner.grid as grid
import thuner.option as option
import thuner.track as track
import thuner.analyze as analyze
import thuner.parallel as parallel
import thuner.visualize as visualize
import thuner.log as log

# Filename of this notebook, kept for reference; the tracking/analysis calls below do not take it as an argument.
notebook_name = "gridrad_severe_gadi_demo.ipynb"

In [None]:
# Parent directory for saving outputs
base_local = Path("/scratch/w40/esh563/THUNER_output")

# Select the GridRad event to process: event number `event_index` of the chosen year.
year = 2010
event_index = 0
event_directories = data.gridrad.get_event_directories(year)
event_directory = event_directories[event_index]
start, end, event_start = data.gridrad.get_event_times(event_directory)

# Split [start, end] into the time intervals that the parallel cell tracks one task each.
period = parallel.get_period(start, end)
intervals = parallel.get_time_intervals(start, end, period=period)

# Outputs for this event live under runs/dev/gridrad_<event_start with dashes removed>.
output_parent = base_local / f"runs/dev/gridrad_{event_start.replace('-', '')}"
# Set to True to delete any output left by a previous run of this event before starting.
# Off by default because it irreversibly removes the whole output_parent tree.
overwrite_output = False
if overwrite_output and output_parent.exists():
    shutil.rmtree(output_parent)
options_directory = output_parent / "options"

# Create the data_options: GridRad radar data plus two ERA5 option sets over the same period.
gridrad_parent = str(base_local / "input_data/raw")
# NOTE(review): presumably "save converted GridRad files, do not load pre-converted ones" —
# confirm against thuner.data.gridrad.
converted_options = {"save": True, "load": False, "parent_converted": None}
gridrad_options = data.gridrad.gridrad_data_options(
    start=start,
    end=end,
    converted_options=converted_options,
    event_start=event_start,
    parent_local=gridrad_parent,
)

era5_parent = "/g/data/rt52"
# ERA5 with the default data_format (presumably pressure levels, given the name — confirm).
era5_pl_options = data.era5.data_options(
    start=start, end=end, parent_local=era5_parent
)
# ERA5 single-level data, same period and parent directory as the pressure-level set.
era5_sl_options = data.era5.data_options(
    start=start, end=end, data_format="single-levels", parent_local=era5_parent
)

data_options = option.consolidate_options(
    [gridrad_options, era5_pl_options, era5_sl_options]
)
dispatch.check_data_options(data_options)
data.option.save_data_options(data_options, options_directory=options_directory)

# Create the grid options for this GridRad run: geographic grid, no regridding, and no
# altitude/geographic spacing supplied (presumably the data's native spacing is kept — confirm
# against thuner.grid.create_options).
grid_options = grid.create_options(
    name="geographic",
    regrid=False,
    altitude_spacing=None,
    geographic_spacing=None,
)
grid.check_options(grid_options)
grid.save_grid_options(grid_options, options_directory=options_directory)

# Start from THUNER's default GridRad tracking options, adjust the tracking settings of the
# first object at level 1, then write the result to track.yml alongside the other options.
track_options = option.default_track_options(dataset="gridrad")
level_1_tracking = track_options.levels[1].objects[0].tracking
# NOTE(review): run-specific values; units/meaning of global_flow_margin are not shown here —
# confirm in thuner.option before changing.
level_1_tracking.global_flow_margin = 70
level_1_tracking.unique_global_flow = False
track_options.to_yaml(options_directory / "track.yml")

# Visualization options for tracking: None, i.e. no visualization options are supplied to the
# serial or parallel tracking calls (presumably disables runtime figures — confirm in thuner.track).
visualize_options = None

In [None]:
# Serial tracking of the whole event in this process.
# NOTE(review): writes to the same output_parent as the parallel cell — presumably only one of
# the two is meant to be run; confirm.
# Time steps come from the GridRad dataset: it is the dataset the track options were built for
# (dataset="gridrad"), and data_options holds only GridRad and ERA5 entries.
times = data.utils.generate_times(data_options.dataset_by_name("gridrad"))
tracks = track.track(
    times,
    data_options,
    grid_options,
    track_options,
    visualize_options,
    output_directory=output_parent,
)

In [None]:
# Fixed number of worker processes.
# NOTE(review): presumably sized for the Gadi allocation — adjust for other machines.
num_processes = 6
# Track each time interval as its own pool task. With the "spawn" context every argument is
# pickled into a fresh interpreter; each task still receives its own copies of the options.
with log.logging_listener(), get_context("spawn").Pool(
    initializer=parallel.initialize_process, processes=num_processes
) as pool:
    results = []
    for i, time_interval in enumerate(intervals):
        args = (
            i,
            time_interval,
            data_options.model_copy(deep=True),
            grid_options.copy(),
            track_options.copy(),
            visualize_options,
            output_parent,
            "gridrad",
        )
        results.append(pool.apply_async(parallel.track_interval, args))
    # Stop accepting work and wait for all intervals to finish before inspecting results.
    pool.close()
    pool.join()
    parallel.check_results(results)

In [None]:
# Stitch the per-interval outputs under output_parent back into a single run
# (cleanup=True — presumably removes the per-interval intermediates; confirm in thuner.parallel).
parallel.stitch_run(
    output_parent,
    intervals,
    cleanup=True,
)

In [None]:
# Post-process the stitched run: velocities, quality control, then classification.
analysis_options = analyze.mcs.AnalysisOptions()
analyze.mcs.process_velocities(output_parent)
analyze.mcs.quality_control(output_parent, analysis_options)
analyze.mcs.classify_all(output_parent)

# Build the MCS velocity/offset attribute series figures over the full event window.
figure_options = visualize.option.horizontal_attribute_options(
    "mcs_velocity_analysis", style="gadi", attributes=["velocity", "offset"]
)
start_time = np.datetime64(start)
end_time = np.datetime64(end)
visualize.attribute.mcs_series(
    output_parent,
    start_time,
    end_time,
    figure_options,
    parallel_figure=False,
    dt=7200,  # NOTE(review): presumably seconds (2 h) between figures — confirm
    by_date=False,
)