## Visual QC notebook
This notebook is used for a visual QC check of regridded CMIP6 data. The regridded files are counted, and a percentage of them (here 10%) is randomly selected for QC. Using only the output filenames, we will:

- locate the original CMIP6 source data
- plot source data alongside regridded data to compare visually

#### How to use with `prefect` via `papermill`:
This notebook should be run as the final step of the prefect regridding flow. The output will be saved as a new notebook in the QC directory created during the flow. To accomplish this, create a task in the prefect flow that will execute this notebook from the command line using `papermill`, e.g.:

```
papermill path/to/repo/regridding/visual_qc.ipynb path/to/qc/output/output.ipynb -r output_directory "/path/to/output/dir" -r cmip6_directory "/path/to/cmip6/dir"
```

The first argument is this notebook's location, which can be constructed from the flow run's `{output_directory}` parameter (i.e., the notebook's location within the downloaded repo directory). The second argument is the desired location of the output notebook, which can also be constructed from `{output_directory}`. The remaining arguments pass the working and input directories of the flow run as raw string parameters: `-r NAME VALUE` sets the notebook parameter `NAME` to the string `VALUE` without papermill's type inference.
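A minimal sketch of what such a task might run, using the standard-library `subprocess` module. `build_papermill_command`, `run_visual_qc`, the `repo_dir`/`qc_dir` arguments, and the output notebook name `visual_qc_output.ipynb` are hypothetical; in the flow, `run_visual_qc` would be wrapped in whatever prefect task the flow already uses. Passing the command as a list avoids shell quoting problems with paths that contain spaces.

```python
import shlex
import subprocess
from pathlib import Path


def build_papermill_command(repo_dir, qc_dir, output_directory, cmip6_directory):
    """Build the papermill argv list for running the visual QC notebook.

    repo_dir and qc_dir are hypothetical arguments standing in for paths the
    flow would construct from its {output_directory} parameter.
    """
    notebook = Path(repo_dir) / "regridding" / "visual_qc.ipynb"
    output_nb = Path(qc_dir) / "visual_qc_output.ipynb"
    return [
        "papermill",
        str(notebook),
        str(output_nb),
        # -r passes each value as a raw string, skipping papermill's type inference
        "-r", "output_directory", str(output_directory),
        "-r", "cmip6_directory", str(cmip6_directory),
    ]


def run_visual_qc(repo_dir, qc_dir, output_directory, cmip6_directory):
    """Run the notebook; raises CalledProcessError if papermill exits non-zero."""
    cmd = build_papermill_command(repo_dir, qc_dir, output_directory, cmip6_directory)
    print("Running:", shlex.join(cmd))
    subprocess.run(cmd, check=True)
```

papermill also provides a Python API (`papermill.execute_notebook`) that could be called from the task directly instead of going through the command line.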

Papermill parameter cell:

In [1]:
# this cell is tagged "parameters" and contains default parameter values for this notebook
# any parameters injected by papermill during the prefect flow will be written into a new cell directly beneath this one
# and will override the values in this cell
output_directory = "/beegfs/CMIP6/snapdata/cmip6_regridding"
cmip6_directory = "/beegfs/CMIP6/arctic-cmip6/CMIP6"
vars = "tas"
freqs = "mon"
models = "GFDL-ESM4"
scenarios = "ssp370"

In [61]:
# this cell overrides the default parameter values above (e.g. for interactive testing)
# papermill injects parameters directly after the first cell tagged "parameters", so if this cell is kept
# it will overwrite the injected values; remove it before running the notebook via papermill
output_directory = "/beegfs/CMIP6/kmredilla/cmip6_regridding"
cmip6_directory = "/beegfs/CMIP6/arctic-cmip6/CMIP6"
vars = "tas"
freqs = "mon"
models = "MRI-ESM2-0"
scenarios = "ssp370"

Import packages

In [3]:
import cftime
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
import random
from pandas.errors import OutOfBoundsDatetime
from pathlib import Path
from config import *
from regrid import *
from qc import *

import concurrent.futures

#### Setup

Define data sources and parameters for QC. This notebook is expected to QC only the data processed in the flow run, i.e., only files derived from the source files listed in the existing batch files. We therefore verify that the supplied parameters correspond to these regridded files.

Determine which regridded files to check:

In [5]:
# set up the input, output, and working directories used in the flow run
cmip6_dir = Path(cmip6_directory)
output_dir = Path(output_directory)
regrid_dir = output_dir.joinpath("regrid")
regrid_batch_dir = output_dir.joinpath("regrid_batch")
slurm_dir = output_dir.joinpath("slurm")
slurm_regrid_dir = slurm_dir.joinpath("regrid")

In [62]:
src_fps = get_source_fps_from_batch_files(regrid_batch_dir)

Make sure the expected source files match the parameters supplied to the notebook. If they do not, the notebook was not run with the expected parameters.

In [86]:
src_params = [extract_params_from_src_filepath(fp) for fp in src_fps]
for p_name, p_str in zip(
    ["model", "scenario", "frequency", "variable_id"], [models, scenarios, freqs, vars]
):
    assert all(
        [params[p_name] in p_str for params in src_params]
    ), f"Source files submitted for regridding contain values for the {p_name} parameter ({', '.join(list(set([params[p_name] for params in src_params])))}) which were not supplied for QC in this notebook ({p_str})."
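The assertion above tests membership with `in` on a string, which is a substring match: `"tas" in "tasmax"` is `True`, so a `tas` source file would pass if only `tasmax` were supplied. A minimal sketch of an exact-match check, assuming the parameter strings separate values with whitespace and/or commas (the flow's actual delimiter is not shown here); `split_param` and `find_unexpected` are hypothetical helpers.

```python
import re


def split_param(p_str):
    """Split a parameter string like "tas pr" or "tas,pr" into a set of exact values.

    Assumes whitespace and/or commas as delimiters.
    """
    return {tok for tok in re.split(r"[\s,]+", p_str.strip()) if tok}


def find_unexpected(src_params, p_name, p_str):
    """Return the sorted values of p_name found in src_params but not supplied in p_str."""
    allowed = split_param(p_str)
    return sorted({params[p_name] for params in src_params} - allowed)


# substring matching accepts "tas" even though only "tasmax" was supplied
assert "tas" in "tasmax"
src_params = [{"variable_id": "tas"}, {"variable_id": "tasmax"}]
print(find_unexpected(src_params, "variable_id", "tasmax"))  # ['tas']
```

With this approach, the assertion would require `find_unexpected(...)` to return an empty list for each parameter.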

Ignore certain files based on results in slurm output files:

In [75]:
# check slurm files
fps_to_ignore = summarize_slurm_out_files(slurm_dir)
for fp in fps_to_ignore:
    if fp in src_fps:
        src_fps.remove(fp)

Now compare the expected files to the existing files and make sure the values are OK. This opens and checks the files in parallel, so it can take a while.
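`compare_expected_to_existing_and_check_values` comes from the local `qc` module, which is not shown here. As a hedged illustration of the general pattern (not that function's implementation), this sketch uses `concurrent.futures`, which the notebook already imports, to run a hypothetical per-file `check_one` callable across many files and collect the failures. Exceptions are recorded rather than raised, so one unreadable file does not abort the whole run.

```python
import concurrent.futures


def check_files_in_parallel(fps, check_one, max_workers=8):
    """Apply check_one to every file path concurrently and collect failures.

    check_one is a hypothetical callable that returns None when a file passes
    and an error message string when it fails. Exceptions raised while checking
    are recorded as failures instead of aborting the run.
    """
    errors = {}
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # map each future back to the file it is checking
        futures = {executor.submit(check_one, fp): fp for fp in fps}
        for future in concurrent.futures.as_completed(futures):
            fp = futures[future]
            try:
                msg = future.result()
            except Exception as exc:
                msg = f"{type(exc).__name__}: {exc}"
            if msg is not None:
                errors[fp] = msg
    return errors
```

A thread pool keeps the sketch simple; for CPU-bound checks a `ProcessPoolExecutor` may be faster, but it requires `check_one` to be picklable (e.g. defined at module level).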

In [91]:
ds_errors, value_errors = compare_expected_to_existing_and_check_values(
    regrid_dir,
    regrid_batch_dir,
    vars,
    freqs,
    models,
    scenarios,
    fps_to_ignore,
)


Here is a summary of the errors:

In [None]:
# print summary messages
error_count = len(ds_errors) + len(value_errors)
print(f"QC process complete: {error_count} errors found.")
if len(ds_errors) > 0:
    print(
        f"Errors in opening some datasets. {len(ds_errors)} files could not be opened. See {str(error_file)} for error log."
    )
if len(value_errors) > 0:
    print(
        f"Errors in dataset values. {len(value_errors)} files have regridded values outside of source file range. See {str(error_file)} for error log."
    )
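The summary above reports regridded files whose values fall outside the source file's range. For regridding schemes whose weights are non-negative and sum to one (bilinear interpolation, for example), each output value is a weighted average of source values, so it cannot leave the source range except through floating-point round-off. A minimal standard-library sketch of that test, with a hypothetical relative tolerance `rel_tol`; the actual check in the `qc` module may differ.

```python
import math


def finite_range(values):
    """Return (min, max) of the non-NaN values, or None if there are none."""
    vals = [v for v in values if not math.isnan(v)]
    if not vals:
        return None
    return min(vals), max(vals)


def regrid_within_source_range(src_values, regrid_values, rel_tol=1e-6):
    """Check that regridded values fall inside the source value range.

    rel_tol is a hypothetical tolerance for float round-off, scaled by the
    magnitude of the source values.
    """
    src_rng = finite_range(src_values)
    regrid_rng = finite_range(regrid_values)
    if src_rng is None or regrid_rng is None:
        # nothing to compare: treat as a failure so the file gets looked at
        return False
    src_min, src_max = src_rng
    pad = rel_tol * max(abs(src_min), abs(src_max), 1.0)
    return src_min - pad <= regrid_rng[0] and regrid_rng[1] <= src_max + pad
```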

#### Visual assessment

Here we build some helper functions to help us perform a qualitative assessment of the regridding. Using only the regridded filename and the main CMIP6 source directory, we reconstruct the source file path and plot the source data alongside the regridded data.
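A simplified, pure-Python sketch of the path reconstruction performed by `generate_cmip6_filepath_from_regrid_filename` and `get_matching_time_filepath` below. It assumes regridded filenames follow the pattern `<var>_<table>_<model>_<scenario>_regrid_<start>-<end>.nc` and that source files sit in the standard CMIP6 directory layout. The `institution` and `activity_id` arguments stand in for the `model_inst_lu` and `prod_scenarios` lookups from `config`. It uses `datetime.date` rather than pandas timestamps, which sidesteps the `OutOfBoundsDatetime` limit for years past 2262, but it cannot represent 360-day calendar dates such as 30 February.

```python
import calendar
from datetime import date
from pathlib import Path


def parse_regrid_filename(fn):
    """Split a regridded filename into (var_id, table_id, model, scenario, timespan).

    Assumes the pattern <var>_<table>_<model>_<scenario>_regrid_<start>-<end>.nc.
    """
    var_id, table_id, model, scenario, _, timespan = Path(fn).stem.split("_")
    return var_id, table_id, model, scenario, timespan


def source_glob(cmip6_dir, institution, activity_id, fn):
    """Return the search directory and glob pattern for candidate source files.

    The CMIP6 layout is <activity>/<institution>/<model>/<experiment>/
    <variant>/<table>/<var>/<grid>/<version>/<file>; variant, grid, and
    version are not in the regridded filename, so they become wildcards.
    """
    var_id, table_id, model, scenario, _ = parse_regrid_filename(fn)
    var_dir = Path(cmip6_dir) / activity_id / institution / model / scenario
    pattern = f"*/{table_id}/{var_id}/*/*/{var_id}_{table_id}_{model}_{scenario}_*.nc"
    return var_dir, pattern


def parse_timespan_token(token, is_end):
    """Convert a YYYYMM or YYYYMMDD filename token to a datetime.date.

    A YYYYMM token maps to the first day of the month, or to the last day
    of the month when it ends a timespan.
    """
    year, month = int(token[:4]), int(token[4:6])
    if len(token) == 8:
        return date(year, month, int(token[6:8]))
    day = calendar.monthrange(year, month)[1] if is_end else 1
    return date(year, month, day)


def file_covers(fp, test_date):
    """True if the timespan at the end of a CMIP6 filename contains test_date (inclusive)."""
    start_tok, end_tok = Path(fp).stem.split("_")[-1].split("-")
    start = parse_timespan_token(start_tok, is_end=False)
    end = parse_timespan_token(end_tok, is_end=True)
    return start <= test_date <= end
```

With the real lookups, the candidates would be `list(var_dir.glob(pattern))`, filtered to those where `file_covers(fp, start_date)` is true; exactly one should remain.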

In [None]:
# set min and max number of files to QC
min_qc = 20
max_qc = 75

# pick a percentage of the regridded files for visual QC, and count those files
pct = 10
pct_count = round(len(regrid_fps) * (pct / 100))

# use all files if no more than the minimum are available;
# otherwise sample the percentage of files, but at least min_qc and at most max_qc
if len(regrid_fps) <= min_qc:
    qc_files = regrid_fps
else:
    qc_count = min(max(pct_count, min_qc), max_qc)
    qc_files = random.sample(regrid_fps, qc_count)

In [4]:
def get_matching_time_filepath(fps, test_date):
    """Find the file in a given list of raw CMIP6 filepaths whose filename timespan contains the test date."""
    matching_fps = []
    for fp in fps:
        start_str, end_str = fp.name.split(".nc")[0].split("_")[-1].split("-")
        try:
            # monthly timespans (YYYYMM) start on the first day of the start month
            #  and end on the last day of the end month; daily timespans are YYYYMMDD
            if len(start_str) == 6:
                start_dt = pd.to_datetime(start_str, format="%Y%m")
            else:
                start_dt = pd.to_datetime(start_str, format="%Y%m%d")
            if len(end_str) == 6:
                end_dt = pd.to_datetime(end_str, format="%Y%m") + pd.offsets.MonthEnd(0)
            else:
                end_dt = pd.to_datetime(end_str, format="%Y%m%d")
        except OutOfBoundsDatetime:
            # we should not be regridding files with time values that cause this (2300 etc)
            continue

        # the end date in the filename is inclusive
        if start_dt <= test_date <= end_dt:
            matching_fps.append(fp)

    # there should only be one
    assert (
        len(matching_fps) == 1
    ), f"Test date {test_date} matched {len(matching_fps)} files (from {fps})."

    return matching_fps[0]


def generate_cmip6_filepath_from_regrid_filename(fn):
    """Get the path to the original CMIP6 filename from a regridded file name.

    Because the original CMIP6 filenames were split up during the processing,
    this method finds the original filename based on matching all possible attributes,
    then testing for inclusion of regrid file start date within the date range formed by the CMIP6 file timespan.
    """
    var_id, freq, model, scenario, _, timespan = fn.split(".nc")[0].split("_")
    institution = model_inst_lu[model]
    # the top-level CMIP6 directory is the activity ID: ScenarioMIP for production scenarios, CMIP otherwise
    activity_id = "ScenarioMIP" if scenario in prod_scenarios else "CMIP"
    # Construct the original CMIP6 filepath from the filename.
    # Use glob because the variant label, grid label, and version are not encoded in the regridded filename.
    var_dir = cmip6_dir.joinpath(f"{activity_id}/{institution}/{model}/{scenario}")
    glob_str = f"*/{freq}/{var_id}/*/*/{var_id}_{freq}_{model}_{scenario}_*.nc"
    candidate_fps = list(var_dir.glob(glob_str))

    assert (
        candidate_fps
    ), f"No files found for regridded file {fn} in {var_dir} with {glob_str}."

    start_str = timespan.split("-")[0]
    format_str = "%Y%m" if len(start_str) == 6 else "%Y%m%d"
    start_dt = pd.to_datetime(start_str, format=format_str)
    cmip6_fp = get_matching_time_filepath(candidate_fps, start_dt)

    return cmip6_fp


def plot_comparison(regrid_fp):
    """For a given regridded file, find the source file and plot side by side."""
    src_fp = generate_cmip6_filepath_from_regrid_filename(regrid_fp.name)
    src_ds = open_and_crop_dataset(src_fp, lat_slice=prod_lat_slice)
    # if the regridded dataset cannot be opened, print a message and skip this file
    #  instead of raising an error
    try:
        regrid_ds = xr.open_dataset(regrid_fp)
    except Exception:
        print(f"Regridded dataset could not be opened: {regrid_fp}")
        return

    # lat axis is flipped in regrid files
    src_lat_slice = slice(55, 75)
    regrid_lat_slice = slice(75, 55)
    lon_slice_src = slice(200, 240)
    lon_slice_regrid = slice(-160, -120)
    time_val = regrid_ds.time.values[0]
    var_id = src_ds.attrs["variable_id"]
    assert get_var_id(src_ds) == var_id, "Variable ID mismatch"
    assert get_var_id(regrid_ds) == var_id, "Variable ID mismatch"

    fig, axes = plt.subplots(1, 2, figsize=(15, 4))
    fig.suptitle(
        f"Variable: {var_id}     Model: {src_ds.attrs['source_id']}     Scenario: {src_ds.attrs['experiment_id']}"
    )

    # now, there are multiple possible time formats for the source dataset.
    # convert the chosen time value to that matching format for subsetting.
    sel_method = None
    if isinstance(src_ds.time.values[0], cftime.Datetime360Day):
        # It seems like monthly data use 16 for the day
        src_hour = src_ds.time.dt.hour[0].values.item()
        src_time = cftime.Datetime360Day(
            year=time_val.year,
            month=time_val.month,
            day=time_val.day,
            hour=src_hour,
        )
    elif isinstance(src_ds.time.values[0], (pd.Timestamp, np.datetime64)):
        src_hour = src_ds.time.dt.hour[0].values.item()
        src_time = pd.Timestamp(
            year=time_val.year, month=time_val.month, day=time_val.day, hour=src_hour
        )
    else:
        if time_val not in src_ds.time.values:
            src_hour = src_ds.time.dt.hour[0].values.item()
            src_time = cftime.DatetimeNoLeap(
                year=time_val.year,
                month=time_val.month,
                day=time_val.day,
                hour=src_hour,
            )
        else:
            src_time = time_val
    if src_time not in src_ds.time.values:
        print(f"Sample timestamp not found in source file ({src_fp}). Using nearest.")
        # probably safe to just use nearest method in any event
        # since there can be incorrectly labeled frequency attributes
        sel_method = "nearest"
        if src_ds.attrs["frequency"] == "mon":
            # We expect that the file will be monthly if the source time chosen is not actually in the dataset
            # This is because we make the time values consistent in the regridded files,
            # and monthly source files might have used e.g. 16th day
            pass
        else:
            # OK this happens and there is nothing we can do about the source data not having consistent attributes.
            # Don't need to fail. Just print a message.
            print("Expected monthly file but frequency attribute does not match.")

    # ensure units are consistent with regridded dataset
    src_ds = convert_units(src_ds)

    # get a vmin and vmax from src dataset to use for both plots, if a map
    try:
        src_vals = (
            src_ds[var_id]
            .sel(time=src_time, method=sel_method)
            .sel(lat=src_lat_slice, lon=lon_slice_src)
            .values
        )
        # ignore NaNs (e.g. masked cells) so they do not propagate into the color limits
        vmin = np.nanmin(src_vals)
        vmax = np.nanmax(src_vals)
    except Exception:
        print("Error getting vmin and vmax values from source data.")
        # fall back to matplotlib's automatic color limits
        vmin, vmax = None, None

    try:  # maps
        src_ds[var_id].sel(time=src_time, method=sel_method).sel(
            lat=src_lat_slice, lon=lon_slice_src
        ).plot(ax=axes[0], vmin=vmin, vmax=vmax)
        axes[0].set_title(f"Source dataset (timestamp: {src_time})")
        regrid_ds[var_id].sel(time=time_val).sel(
            lat=regrid_lat_slice,
            lon=lon_slice_regrid,
            # explicitly set the x axis to be the standard longitude for regridded data
            #  because grid is (confusingly) oriented time, lon, lat for rasdaman
        ).plot(ax=axes[1], vmin=vmin, vmax=vmax, x="lon")
        axes[1].set_title(f"Regridded dataset (timestamp: {time_val})")
        axes[1].set_xlabel("longitude [standard]")

    except Exception:  # histograms
        src_ds[var_id].sel(time=src_time, method=sel_method).sel(
            lat=src_lat_slice, lon=lon_slice_src
        ).plot(ax=axes[0])
        axes[0].set_title(f"Source dataset (timestamp: {src_time})")
        regrid_ds[var_id].sel(time=time_val).sel(
            lat=regrid_lat_slice, lon=lon_slice_regrid
        ).plot(ax=axes[1])
        axes[1].set_title(f"Regridded dataset (timestamp: {time_val})")

    plt.show()
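`plot_comparison` hard-codes the same domain twice: `slice(55, 75)` and `slice(200, 240)` for the source data (0 to 360 longitudes; the source slice implies ascending latitude), and `slice(75, 55)` and `slice(-160, -120)` for the regridded data (-180 to 180 longitudes, descending latitude). A small sketch with hypothetical helpers `to_signed_lon` and `oriented_slice` that would let the regridded slices be derived from the source ones. xarray's label-based `.sel` with a slice follows the coordinate's stored order, so a descending coordinate needs the bounds reversed.

```python
def to_signed_lon(lon):
    """Convert a longitude in [0, 360) to the [-180, 180) convention."""
    return ((lon + 180) % 360) - 180


def oriented_slice(coord_values, lo, hi):
    """Return a label slice selecting [lo, hi] on a monotonic coordinate.

    A descending coordinate (e.g. latitude stored north to south) needs
    slice(hi, lo); an ascending one needs slice(lo, hi).
    """
    descending = coord_values[0] > coord_values[-1]
    return slice(hi, lo) if descending else slice(lo, hi)


# e.g. derive the regridded slices from the source bounds
regrid_lon = slice(to_signed_lon(200), to_signed_lon(240))  # slice(-160, -120)
regrid_lat = oriented_slice([75, 70, 65, 60, 55], 55, 75)  # slice(75, 55)
```

For a domain that crosses the 180° meridian, the converted bounds would wrap (e.g. 170 and 200 become 170 and -160), so that case would need separate handling.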

For each regridded file in the random QC selection, plot a side-by-side comparison over a spatial domain that includes Alaska and western Canada (55 to 75°N, 160 to 120°W).

In [5]:
for fp in qc_files:
    plot_comparison(fp)