<a href="https://colab.research.google.com/github/thunderhoser/loss_function_paper_2022/blob/main/loss_functions_journal_paper_2022.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# README

This notebook goes with the following journal paper.
<br /><br />

Lagerquist, Ryan, and Imme Ebert-Uphoff, 2022: "Can we integrate spatial verification methods into neural-network loss functions for atmospheric science?" *Artificial Intelligence for the Earth Systems*, [https://doi.org/10.1175/AIES-D-22-0021.1](https://doi.org/10.1175/AIES-D-22-0021.1).

This notebook gives you a chance to experiment with the loss functions developed in the paper, which include spatial verification tools such as neighbourhood filtering and spectral filtering.
<br /><br />
Chances are, most of this notebook will not be interesting to you.  Most of this notebook contains utility methods that do things like find files (containing predictor and target values), read files, convert time stamps from one format to another, etc.  Utility methods are buried in the sections titled "Define private methods" and "Define public methods".  Typically, I use a Python library called ml4convection (which I developed for this project), where the utility methods are defined in separate files and thus easier to ignore.  However, I wanted to separate this notebook from ml4convection, providing an end-to-end resource that people can use to implement spatially enhanced loss functions, without the hassle of downloading and installing ml4convection.
<br /><br />
**In this notebook, most of the section titles are in <font color='red'>red</font> and followed by the word "required".  This means that you must run the corresponding code cell before running later code cells (further down) in the notebook.  If you do not run a code cell marked "required," later code cells will not work.**
<br /><br />
Unless you *really* care about my utility methods, you can auto-pilot your way through most of this notebook, just running each code cell and then moving on to the next one.  Exceptions to this rule are the following sections:

 - "Define loss functions" and "Define Fourier/wavelet methods" (if you care about exactly how I've implemented spatially enhanced loss functions)
 - "Experiment with different loss functions" (should be interesting to everyone who clicked on this notebook)

# <font color='red'>Make sure to use a GPU!</font>

The CNNs in this notebook run $\sim$10 times faster with a graphics-processing unit (GPU).  To request a GPU, do the following:

 1. Click on "Runtime" in the menu at the top (next to "File," "Edit," "View," etc.).
 2. Click on "Change runtime type".
 3. Click on "GPU" under the "Hardware accelerator" dropdown menu.
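
After switching the runtime type, it can be worth confirming that TensorFlow actually sees the GPU.  The snippet below is a minimal sketch and is not part of the original notebook.  The helper name `count_visible_gpus` is hypothetical, and the snippet assumes TensorFlow 2.1 or later, where `tensorflow.config.list_physical_devices` is available.

```python
def count_visible_gpus():
    """Returns number of GPUs visible to TensorFlow (None if TF is missing).

    This helper is illustrative only; it is not defined elsewhere in the
    notebook.
    """

    try:
        import tensorflow
    except ImportError:
        return None

    return len(tensorflow.config.list_physical_devices('GPU'))


num_gpus = count_visible_gpus()

if num_gpus is None:
    print('TensorFlow is not installed in this environment.')
elif num_gpus == 0:
    print('No GPU found; the CNNs will train on the CPU.')
else:
    print('Found {0:d} GPU(s).'.format(num_gpus))
```

If the message reports no GPU, repeat the three steps above and re-run the check.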

# <font color='red'>Install WaveTF (required)</font>

In [None]:
!pip install git+https://github.com/thunderhoser/WaveTF.git

Collecting git+https://github.com/thunderhoser/WaveTF.git
  Cloning https://github.com/thunderhoser/WaveTF.git to /tmp/pip-req-build-t9898lqb
  Running command git clone -q https://github.com/thunderhoser/WaveTF.git /tmp/pip-req-build-t9898lqb
Building wheels for collected packages: WaveTF
  Building wheel for WaveTF (setup.py) ... done
  Created wheel for WaveTF: filename=WaveTF-0.1-py3-none-any.whl size=25398 sha256=99c9a6b9e2210c56c09e0c637bb83928b1b22284f1ba8adba042db6dd479ea8e
  Stored in directory: /tmp/pip-ephem-wheel-cache-phws62ui/wheels/f5/79/aa/8d2d7d532ed44ea39663ac5ee680c67564849646d729ab00bf
Successfully built WaveTF
Installing collected packages: WaveTF
Successfully installed WaveTF-0.1


# <font color='red'>Restart runtime (required)</font>

Before doing anything else, click on "Runtime" (in the top-left menu) and then "Restart Runtime".  Otherwise, WaveTF will not be properly installed in the Python environment.
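
To check whether the restart worked, you can ask Python if it can locate the package without importing it.  This is a minimal sketch using only the standard library; the helper name `is_installed` is hypothetical and not part of the original notebook.

```python
import importlib.util


def is_installed(module_name):
    """Returns True if a top-level module can be found, without importing it."""

    return importlib.util.find_spec(module_name) is not None


if is_installed('wavetf'):
    print('WaveTF is available.')
else:
    print('WaveTF not found.  Re-run the install cell, then restart the runtime.')
```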

# <font color='red'>Import packages (required)</font>

The next cell imports all Python packages used in this notebook.  If the notebook crashes anywhere, it will most likely be here.

In [None]:
import os
import copy
import glob
import dill
import errno
import time
import calendar
import gzip
import shutil
import tarfile
from urllib.request import urlretrieve
import netCDF4
import numpy
from scipy.ndimage import binary_erosion
import keras
import tensorflow
import tensorflow.keras as tf_keras
from tensorflow.keras import backend as K
from wavetf import WaveTFFactory

# <font color='red'>Define constants (required)</font>

The next cell defines constants used in this notebook.

In [None]:
TOLERANCE = 1e-6

DAYS_TO_SECONDS = 86400
DATE_STRING_FORMAT = '%Y%m%d'
GZIP_FILE_EXTENSION = '.gz'

NUM_RADARS = 4
GRID_SPACING_DEG = 0.0125
MAX_TARGET_RESOLUTION_DEG = 6.4

PREDICTOR_MATRIX_UNNORM_KEY = 'predictor_matrix_unnorm'
PREDICTOR_MATRIX_NORM_KEY = 'predictor_matrix_norm'
PREDICTOR_MATRIX_UNIF_NORM_KEY = 'predictor_matrix_unif_norm'
VALID_TIMES_KEY = 'valid_times_unix_sec'
BAND_NUMBERS_KEY = 'band_numbers'
LATITUDES_KEY = 'latitudes_deg_n'
LONGITUDES_KEY = 'longitudes_deg_e'
NORMALIZATION_FILE_KEY = 'normalization_file_name'
MASK_MATRIX_KEY = 'mask_matrix'
FULL_MASK_MATRIX_KEY = 'full_mask_matrix'
FULL_LATITUDES_KEY = 'full_latitudes_deg_n'
FULL_LONGITUDES_KEY = 'full_longitudes_deg_e'

ONE_PER_PREDICTOR_TIME_KEYS = [
    PREDICTOR_MATRIX_UNNORM_KEY, PREDICTOR_MATRIX_NORM_KEY,
    PREDICTOR_MATRIX_UNIF_NORM_KEY, VALID_TIMES_KEY
]

TARGET_MATRIX_KEY = 'target_matrix'
PREDICTOR_MATRIX_KEY = 'predictor_matrix'
ONE_PER_TARGET_TIME_KEYS = [TARGET_MATRIX_KEY, VALID_TIMES_KEY]

BATCH_SIZE_KEY = 'num_examples_per_batch'
MAX_DAILY_EXAMPLES_KEY = 'max_examples_per_day_in_batch'
BAND_NUMBERS_KEY = 'band_numbers'
LEAD_TIME_KEY = 'lead_time_seconds'
LAG_TIMES_KEY = 'lag_times_seconds'
INCLUDE_TIME_DIM_KEY = 'include_time_dimension'
FIRST_VALID_DATE_KEY = 'first_valid_date_string'
LAST_VALID_DATE_KEY = 'last_valid_date_string'
NORMALIZE_FLAG_KEY = 'normalize'
UNIFORMIZE_FLAG_KEY = 'uniformize'
FOURIER_TRANSFORM_KEY = 'fourier_transform_targets'
WAVELET_TRANSFORM_KEY = 'wavelet_transform_targets'
MIN_TARGET_RESOLUTION_KEY = 'min_target_resolution_deg'
MAX_TARGET_RESOLUTION_KEY = 'max_target_resolution_deg'
PREDICTOR_DIRECTORY_KEY = 'top_predictor_dir_name'
TARGET_DIRECTORY_KEY = 'top_target_dir_name'

DEFAULT_GENERATOR_OPTION_DICT = {
    BATCH_SIZE_KEY: 256,
    MAX_DAILY_EXAMPLES_KEY: 64,
    LAG_TIMES_KEY: numpy.array([0], dtype=int),
    NORMALIZE_FLAG_KEY: True,
    UNIFORMIZE_FLAG_KEY: True,
    FOURIER_TRANSFORM_KEY: False,
    WAVELET_TRANSFORM_KEY: False
}

NUM_EPOCHS_KEY = 'num_epochs'
NUM_TRAINING_BATCHES_KEY = 'num_training_batches_per_epoch'
TRAINING_OPTIONS_KEY = 'training_option_dict'
NUM_VALIDATION_BATCHES_KEY = 'num_validation_batches_per_epoch'
VALIDATION_OPTIONS_KEY = 'validation_option_dict'
EARLY_STOPPING_KEY = 'do_early_stopping'
PLATEAU_LR_MUTIPLIER_KEY = 'plateau_lr_multiplier'
LOSS_FUNCTION_KEY = 'loss_function_name'
METRIC_NAMES_KEY = 'metric_names'
MASK_MATRIX_KEY = 'mask_matrix'

METADATA_KEYS = [
    NUM_EPOCHS_KEY, NUM_TRAINING_BATCHES_KEY,
    TRAINING_OPTIONS_KEY, NUM_VALIDATION_BATCHES_KEY, VALIDATION_OPTIONS_KEY,
    EARLY_STOPPING_KEY, PLATEAU_LR_MUTIPLIER_KEY, LOSS_FUNCTION_KEY,
    METRIC_NAMES_KEY, MASK_MATRIX_KEY
]

SCORE_NAME_KEY = 'score_name'
HALF_WINDOW_SIZE_KEY = 'half_window_size_px'

FSS_NAME = 'fss'
BRIER_SCORE_NAME = 'brier'
CSI_NAME = 'csi'
HEIDKE_SCORE_NAME = 'heidke'
GERRITY_SCORE_NAME = 'gerrity'
PEIRCE_SCORE_NAME = 'peirce'
FREQUENCY_BIAS_NAME = 'bias'
IOU_NAME = 'iou'
ALL_CLASS_IOU_NAME = 'all-class-iou'
DICE_COEFF_NAME = 'dice'
CROSS_ENTROPY_NAME = 'xentropy'

PLATEAU_PATIENCE_EPOCHS = 10
DEFAULT_LEARNING_RATE_MULTIPLIER = 0.5
PLATEAU_COOLDOWN_EPOCHS = 0
EARLY_STOPPING_PATIENCE_EPOCHS = 30
LOSS_PATIENCE = 0.

ONLINE_TARGET_AND_MODEL_FILE_NAME_ZIPPED = (
    'https://storage.googleapis.com/loss-functions-paper-2022-colab/targets_and_model.tar'
)
TARGET_AND_MODEL_FILE_NAME_ZIPPED = '/content/targets_and_model.tar'
TARGET_AND_MODEL_DIR_NAME = '/content/targets_and_model'

ONLINE_PREDICTOR_FILE_NAMES_ZIPPED = [
    'https://storage.googleapis.com/loss-functions-paper-2022-colab/'
    'predictors_2016_radars1and3.tar',
    'https://storage.googleapis.com/loss-functions-paper-2022-colab/'
    'predictors_2017_radars1and3.tar'
]

PREDICTOR_FILE_NAMES_ZIPPED = [
    '/content/predictors/predictors_2016.tar',
    '/content/predictors/predictors_2017.tar'
]

PREDICTOR_DIR_NAME = '/content/predictors'
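
To make the time-related constants concrete, here is a small worked example that converts a date string to Unix time and then steps back one day.  It is a sketch only: the two constants are re-declared so that the snippet runs on its own, and the date `'20160101'` is an arbitrary choice for illustration.  It uses the same standard-library calls (`calendar.timegm`, `time.strptime`, `time.strftime`, `time.gmtime`) as the private time-conversion methods in this notebook.

```python
import calendar
import time

# These two values mirror the constants defined in the cell above.
DAYS_TO_SECONDS = 86400
DATE_STRING_FORMAT = '%Y%m%d'

# Date string -> Unix time (seconds since 0000 UTC 1 Jan 1970).
unix_time_sec = calendar.timegm(
    time.strptime('20160101', DATE_STRING_FORMAT)
)

# Step back one day, then convert back to a date string.
previous_date_string = time.strftime(
    DATE_STRING_FORMAT, time.gmtime(unix_time_sec - DAYS_TO_SECONDS)
)

print(unix_time_sec)          # 1451606400
print(previous_date_string)   # 20151231
```

Because both conversions work in UTC, the result does not depend on the time zone of the machine running the notebook.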

# <font color='red'>Define private methods (required)</font>

The next cell defines private methods used in this notebook.  Private methods have a name starting with an underscore (*e.g.*, `_ceiling_to_nearest`) and should not be called directly by the user (*i.e.*, you).  Private methods should be called only by public methods, whose names do not start with an underscore.

In [None]:
def _ceiling_to_nearest(input_value, rounding_base):
    """Rounds numbers *up* to nearest x, where x is a positive real number.

    :param input_value: Either numpy array of real numbers or scalar real
        number.
    :param rounding_base: Numbers will be rounded *up* to this base.
    :return: output_value: Same as input_value, except rounded.
    """

    return rounding_base * numpy.ceil(input_value / rounding_base)


def _floor_to_nearest(input_value, rounding_base):
    """Rounds numbers *down* to nearest x, where x is a positive real number.

    :param input_value: Either numpy array of real numbers or scalar real
        number.
    :param rounding_base: Numbers will be rounded *down* to this base.
    :return: output_value: Same as input_value, except rounded.
    """

    return rounding_base * numpy.floor(input_value / rounding_base)


def _time_string_to_unix_sec(time_string, time_directive):
    """Converts time from string to Unix format.

    Unix format = seconds since 0000 UTC 1 Jan 1970.

    :param time_string: Time string.
    :param time_directive: Format of time string (examples: "%Y%m%d" if string
        is "yyyymmdd", "%Y-%m-%d-%H%M%S" if string is "yyyy-mm-dd-HHMMSS",
        etc.).
    :return: unix_time_sec: Time in Unix format.
    """

    return calendar.timegm(time.strptime(time_string, time_directive))


def _time_unix_sec_to_string(unix_time_sec, time_directive):
    """Converts time from Unix format to string.

    Unix format = seconds since 0000 UTC 1 Jan 1970.

    :param unix_time_sec: Time in Unix format.
    :param time_directive: Format of time string (examples: "%Y%m%d" if string
        is "yyyymmdd", "%Y-%m-%d-%H%M%S" if string is "yyyy-mm-dd-HHMMSS",
        etc.).
    :return: time_string: Time string.
    """

    return time.strftime(time_directive, time.gmtime(unix_time_sec))


def _time_range_to_list(start_time_unix_sec, end_time_unix_sec,
                        time_interval_sec, include_endpoint=True):
    """Converts time period from range and interval to list of exact times.

    N = number of exact times

    :param start_time_unix_sec: Start time (Unix format).
    :param end_time_unix_sec: End time (Unix format).
    :param time_interval_sec: Interval (seconds) between successive exact times.
    :param include_endpoint: Boolean flag.  If True, endpoint will be included
        in list of time steps.  If False, endpoint will be excluded.
    :return: unix_times_sec: length-N numpy array of exact times (Unix format).
    """

    start_time_unix_sec = int(_floor_to_nearest(
        float(start_time_unix_sec), time_interval_sec
    ))
    end_time_unix_sec = int(_ceiling_to_nearest(
        float(end_time_unix_sec), time_interval_sec
    ))

    if not include_endpoint:
        end_time_unix_sec -= time_interval_sec

    num_time_steps = 1 + int(numpy.round(
        (end_time_unix_sec - start_time_unix_sec) / time_interval_sec
    ))

    return numpy.linspace(
        start_time_unix_sec, end_time_unix_sec, num=num_time_steps, dtype=int
    )


def _get_dates_in_range(first_date_string, last_date_string):
    """Returns list of dates in range.

    :param first_date_string: First date in range (format "yyyymmdd").
    :param last_date_string: Last date in range (format "yyyymmdd").
    :return: date_strings: 1-D list of dates (format "yyyymmdd").
    """

    first_date_unix_sec = _time_string_to_unix_sec(
        first_date_string, DATE_STRING_FORMAT
    )
    last_date_unix_sec = _time_string_to_unix_sec(
        last_date_string, DATE_STRING_FORMAT
    )
    dates_unix_sec = _time_range_to_list(
        start_time_unix_sec=first_date_unix_sec,
        end_time_unix_sec=last_date_unix_sec,
        time_interval_sec=DAYS_TO_SECONDS, include_endpoint=True
    )

    return [
        _time_unix_sec_to_string(t, DATE_STRING_FORMAT) for t in dates_unix_sec
    ]


def _get_previous_date(date_string):
    """Returns previous date.

    :param date_string: Date (format "yyyymmdd").
    :return: prev_date_string: Previous date (format "yyyymmdd").
    """

    unix_time_sec = _time_string_to_unix_sec(date_string, DATE_STRING_FORMAT)
    return _time_unix_sec_to_string(
        unix_time_sec - DAYS_TO_SECONDS, DATE_STRING_FORMAT
    )


def _find_predictor_file(
        top_directory_name, date_string, radar_number=None, prefer_zipped=True,
        allow_other_format=True, raise_error_if_missing=True):
    """Finds NetCDF file with predictors.

    :param top_directory_name: Name of top-level directory where file is
        expected.
    :param date_string: Date (format "yyyymmdd").
    :param radar_number: Radar number (non-negative integer).  If you are
        looking for data on the full grid, leave this alone.
    :param prefer_zipped: Boolean flag.  If True, will look for zipped file
        first.  If False, will look for unzipped file first.
    :param allow_other_format: Boolean flag.  If True, will allow opposite of
        preferred file format (zipped or unzipped).
    :param raise_error_if_missing: Boolean flag.  If file is missing and
        `raise_error_if_missing == True`, will throw error.  If file is missing
        and `raise_error_if_missing == False`, will return *expected* file path.
    :return: predictor_file_name: File path.
    :raises: ValueError: if file is missing
        and `raise_error_if_missing == True`.
    """

    predictor_file_name = '{0:s}/{1:s}/predictors_{2:s}{3:s}.nc{4:s}'.format(
        top_directory_name, date_string[:4], date_string,
        '' if radar_number is None else '_radar{0:d}'.format(radar_number),
        GZIP_FILE_EXTENSION if prefer_zipped else ''
    )

    if os.path.isfile(predictor_file_name):
        return predictor_file_name

    if allow_other_format:
        if prefer_zipped:
            predictor_file_name = (
                predictor_file_name[:-len(GZIP_FILE_EXTENSION)]
            )
        else:
            predictor_file_name += GZIP_FILE_EXTENSION

    if os.path.isfile(predictor_file_name) or not raise_error_if_missing:
        return predictor_file_name

    error_string = 'Cannot find file.  Expected at: "{0:s}"'.format(
        predictor_file_name
    )
    raise ValueError(error_string)


def _find_many_predictor_files(
        top_directory_name, first_date_string, last_date_string,
        radar_number=None, prefer_zipped=True, allow_other_format=True,
        raise_error_if_all_missing=True, raise_error_if_any_missing=False,
        test_mode=False):
    """Finds many NetCDF files with predictors.

    :param top_directory_name: See doc for `_find_predictor_file`.
    :param first_date_string: First date (format "yyyymmdd").
    :param last_date_string: Last date (format "yyyymmdd").
    :param radar_number: See doc for `_find_predictor_file`.
    :param prefer_zipped: Same.
    :param allow_other_format: Same.
    :param raise_error_if_any_missing: Boolean flag.  If any file is missing and
        `raise_error_if_any_missing == True`, will throw error.
    :param raise_error_if_all_missing: Boolean flag.  If all files are missing
        and `raise_error_if_all_missing == True`, will throw error.
    :param test_mode: Leave this alone.
    :return: predictor_file_names: 1-D list of paths to predictor files.  This
        list does *not* contain expected paths to non-existent files.
    :raises: ValueError: if any file is missing and
        `raise_error_if_any_missing == True`, or if all files are missing and
        `raise_error_if_all_missing == True`.
    """

    date_strings = _get_dates_in_range(first_date_string, last_date_string)
    predictor_file_names = []

    for this_date_string in date_strings:
        this_file_name = _find_predictor_file(
            top_directory_name=top_directory_name, date_string=this_date_string,
            radar_number=radar_number,
            prefer_zipped=prefer_zipped, allow_other_format=allow_other_format,
            raise_error_if_missing=raise_error_if_any_missing
        )

        if test_mode or os.path.isfile(this_file_name):
            predictor_file_names.append(this_file_name)

    if raise_error_if_all_missing and len(predictor_file_names) == 0:
        error_string = (
            'Cannot find any file in directory "{0:s}" from dates {1:s} to '
            '{2:s}.'
        ).format(
            top_directory_name, first_date_string, last_date_string
        )
        raise ValueError(error_string)

    return predictor_file_names


def _find_target_file(
        top_directory_name, date_string, radar_number=None, prefer_zipped=True,
        allow_other_format=True, raise_error_if_missing=True):
    """Finds NetCDF file with targets.

    :param top_directory_name: Name of top-level directory where file is
        expected.
    :param date_string: Date (format "yyyymmdd").
    :param radar_number: Radar number (non-negative integer).  If you are
        looking for data on the full grid, leave this alone.
    :param prefer_zipped: Boolean flag.  If True, will look for zipped file
        first.  If False, will look for unzipped file first.
    :param allow_other_format: Boolean flag.  If True, will allow opposite of
        preferred file format (zipped or unzipped).
    :param raise_error_if_missing: Boolean flag.  If file is missing and
        `raise_error_if_missing == True`, will throw error.  If file is missing
        and `raise_error_if_missing == False`, will return *expected* file path.
    :return: target_file_name: File path.
    :raises: ValueError: if file is missing
        and `raise_error_if_missing == True`.
    """

    target_file_name = '{0:s}/{1:s}/targets_{2:s}{3:s}.nc{4:s}'.format(
        top_directory_name, date_string[:4], date_string,
        '' if radar_number is None else '_radar{0:d}'.format(radar_number),
        GZIP_FILE_EXTENSION if prefer_zipped else ''
    )

    if os.path.isfile(target_file_name):
        return target_file_name

    if allow_other_format:
        if prefer_zipped:
            target_file_name = target_file_name[:-len(GZIP_FILE_EXTENSION)]
        else:
            target_file_name += GZIP_FILE_EXTENSION

    if os.path.isfile(target_file_name) or not raise_error_if_missing:
        return target_file_name

    error_string = 'Cannot find file.  Expected at: "{0:s}"'.format(
        target_file_name
    )
    raise ValueError(error_string)


def _find_many_target_files(
        top_directory_name, first_date_string, last_date_string,
        radar_number=None, prefer_zipped=True, allow_other_format=True,
        raise_error_if_all_missing=True, raise_error_if_any_missing=False,
        test_mode=False):
    """Finds many NetCDF files with targets.

    :param top_directory_name: See doc for `_find_target_file`.
    :param first_date_string: First date (format "yyyymmdd").
    :param last_date_string: Last date (format "yyyymmdd").
    :param radar_number: See doc for `_find_target_file`.
    :param prefer_zipped: Same.
    :param allow_other_format: Same.
    :param raise_error_if_any_missing: Boolean flag.  If any file is missing and
        `raise_error_if_any_missing == True`, will throw error.
    :param raise_error_if_all_missing: Boolean flag.  If all files are missing
        and `raise_error_if_all_missing == True`, will throw error.
    :param test_mode: Leave this alone.
    :return: target_file_names: 1-D list of paths to target files.  This list
        does *not* contain expected paths to non-existent files.
    :raises: ValueError: if any file is missing and
        `raise_error_if_any_missing == True`, or if all files are missing and
        `raise_error_if_all_missing == True`.
    """

    date_strings = _get_dates_in_range(first_date_string, last_date_string)
    target_file_names = []

    for this_date_string in date_strings:
        this_file_name = _find_target_file(
            top_directory_name=top_directory_name, date_string=this_date_string,
            radar_number=radar_number,
            prefer_zipped=prefer_zipped, allow_other_format=allow_other_format,
            raise_error_if_missing=raise_error_if_any_missing
        )

        if test_mode or os.path.isfile(this_file_name):
            target_file_names.append(this_file_name)

    if raise_error_if_all_missing and len(target_file_names) == 0:
        error_string = (
            'Cannot find any file in directory "{0:s}" from dates {1:s} to '
            '{2:s}.'
        ).format(
            top_directory_name, first_date_string, last_date_string
        )
        raise ValueError(error_string)

    return target_file_names


def _find_days_with_both_inputs(
        predictor_file_names, target_file_names, lead_time_seconds,
        lag_times_seconds):
    """Finds days with both inputs (predictor and target file) available.

    :param predictor_file_names: See doc for `_read_inputs_one_day`.
    :param target_file_names: Same.
    :param lead_time_seconds: Same.
    :param lag_times_seconds: Same.
    :return: valid_date_strings: List of valid dates (target dates) for which
        both predictors and targets exist, in format "yyyymmdd".
    """

    max_time_diff_seconds = numpy.max(lag_times_seconds) + lead_time_seconds

    predictor_date_strings = [
        file_name_to_date(f) for f in predictor_file_names
    ]
    target_date_strings = [
        file_name_to_date(f) for f in target_file_names
    ]
    valid_date_strings = []

    for this_target_date_string in target_date_strings:
        if this_target_date_string not in predictor_date_strings:
            continue

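        # NOTE: The check below is currently a no-op, because the `continue`
        # statement is commented out.  Thus, a target date is kept even when
        # the previous day's predictor file is missing.
        if max_time_diff_seconds > 0:
            if (
                    _get_previous_date(this_target_date_string)
                    not in predictor_date_strings
            ):
                pass
                # continue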

        valid_date_strings.append(this_target_date_string)

    return valid_date_strings


def _read_targets(dataset_object):
    """Reads targets from NetCDF file.

    This method should be called only from `_read_target_file`.

    :param dataset_object: Instance of `netCDF4.Dataset`.
    :return: target_dict: See doc for `_read_target_file`.
    """

    target_dict = {
        TARGET_MATRIX_KEY: dataset_object.variables[TARGET_MATRIX_KEY][:],
        VALID_TIMES_KEY: dataset_object.variables[VALID_TIMES_KEY][:],
        LATITUDES_KEY: dataset_object.variables[LATITUDES_KEY][:],
        LONGITUDES_KEY: dataset_object.variables[LONGITUDES_KEY][:],
        MASK_MATRIX_KEY:
            dataset_object.variables[MASK_MATRIX_KEY][:].astype(bool)
    }

    if FULL_MASK_MATRIX_KEY in dataset_object.variables:
        target_dict[FULL_MASK_MATRIX_KEY] = (
            dataset_object.variables[FULL_MASK_MATRIX_KEY][:].astype(bool)
        )
        target_dict[FULL_LATITUDES_KEY] = (
            dataset_object.variables[FULL_LATITUDES_KEY][:]
        )
        target_dict[FULL_LONGITUDES_KEY] = (
            dataset_object.variables[FULL_LONGITUDES_KEY][:]
        )
    else:
        target_dict[FULL_MASK_MATRIX_KEY] = copy.deepcopy(
            target_dict[MASK_MATRIX_KEY]
        )
        target_dict[FULL_LATITUDES_KEY] = target_dict[LATITUDES_KEY] + 0.
        target_dict[FULL_LONGITUDES_KEY] = target_dict[LONGITUDES_KEY] + 0.

    if numpy.any(numpy.diff(target_dict[LATITUDES_KEY]) < 0):
        target_dict[LATITUDES_KEY] = target_dict[LATITUDES_KEY][::-1]
        target_dict[TARGET_MATRIX_KEY] = numpy.flip(
            target_dict[TARGET_MATRIX_KEY], axis=1
        )
        target_dict[MASK_MATRIX_KEY] = numpy.flip(
            target_dict[MASK_MATRIX_KEY], axis=0
        )

    if numpy.any(numpy.diff(target_dict[FULL_LATITUDES_KEY]) < 0):
        target_dict[FULL_LATITUDES_KEY] = target_dict[FULL_LATITUDES_KEY][::-1]
        target_dict[FULL_MASK_MATRIX_KEY] = numpy.flip(
            target_dict[FULL_MASK_MATRIX_KEY], axis=0
        )

    return target_dict


def _read_target_file(netcdf_file_name):
    """Reads targets from NetCDF file.

    E = number of examples
    M = number of rows in grid
    N = number of columns in grid

    :param netcdf_file_name: Path to input file.
    :return: target_dict: Dictionary with the following keys.
    target_dict['target_matrix']: E-by-M-by-N numpy array of target values
        (0 or 1), indicating when and where convection occurs.
    target_dict['valid_times_unix_sec']: length-E numpy array of valid times.
    target_dict['latitudes_deg_n']: length-M numpy array of latitudes
        (deg north).
    target_dict['longitudes_deg_e']: length-N numpy array of longitudes
        (deg east).
    target_dict['mask_matrix']: M-by-N numpy array of Boolean flags.  False
        means that the grid cell is masked out.
    """

    if netcdf_file_name.endswith(GZIP_FILE_EXTENSION):
        with gzip.open(netcdf_file_name) as gzip_handle:
            with netCDF4.Dataset(
                    'dummy', mode='r', memory=gzip_handle.read()
            ) as dataset_object:
                return _read_targets(dataset_object)

    # Context manager closes the file even if reading fails.
    with netCDF4.Dataset(netcdf_file_name) as dataset_object:
        return _read_targets(dataset_object)


def _read_predictors(dataset_object, read_unnormalized, read_normalized,
                     read_unif_normalized):
    """Reads predictors from NetCDF file.

    This method should be called only from `_read_predictor_file`.

    :param dataset_object: Instance of `netCDF4.Dataset`.
    :param read_unnormalized: See doc for `_read_predictor_file`.
    :param read_normalized: Same.
    :param read_unif_normalized: Same.
    :return: predictor_dict: Same.
    """

    predictor_dict = {
        PREDICTOR_MATRIX_UNNORM_KEY: None,
        PREDICTOR_MATRIX_NORM_KEY: None,
        PREDICTOR_MATRIX_UNIF_NORM_KEY: None,
        VALID_TIMES_KEY: dataset_object.variables[VALID_TIMES_KEY][:],
        LATITUDES_KEY: dataset_object.variables[LATITUDES_KEY][:],
        LONGITUDES_KEY: dataset_object.variables[LONGITUDES_KEY][:],
        BAND_NUMBERS_KEY: dataset_object.variables[BAND_NUMBERS_KEY][:],
        NORMALIZATION_FILE_KEY:
            str(getattr(dataset_object, NORMALIZATION_FILE_KEY))
    }

    if read_unnormalized:
        predictor_dict[PREDICTOR_MATRIX_UNNORM_KEY] = (
            dataset_object.variables[PREDICTOR_MATRIX_UNNORM_KEY][:]
        )

    if read_normalized:
        predictor_dict[PREDICTOR_MATRIX_NORM_KEY] = (
            dataset_object.variables[PREDICTOR_MATRIX_NORM_KEY][:]
        )

    if read_unif_normalized:
        predictor_dict[PREDICTOR_MATRIX_UNIF_NORM_KEY] = (
            dataset_object.variables[PREDICTOR_MATRIX_UNIF_NORM_KEY][:]
        )

    return predictor_dict


def _read_predictor_file(netcdf_file_name, read_unnormalized, read_normalized,
                         read_unif_normalized):
    """Reads predictors from NetCDF file.

    E = number of examples
    M = number of rows in grid
    N = number of columns in grid
    C = number of channels (spectral bands)

    :param netcdf_file_name: Path to input file.
    :param read_unnormalized: Boolean flag.  If True, will read unnormalized
        predictors.  If False, key `predictor_matrix_unnorm` in the output
        dictionary will be None.
    :param read_normalized: Boolean flag.  If True, will read normalized
        predictors.  If False, key `predictor_matrix_norm` in the output
        dictionary will be None.
    :param read_unif_normalized: Boolean flag.  If True, will read
        uniformized/normalized predictors.  If False, key
        `predictor_matrix_unif_norm` in the output dictionary will be None.

    :return: predictor_dict: Dictionary with the following keys.
    predictor_dict['predictor_matrix_unnorm']: E-by-M-by-N-by-C numpy array of
        unnormalized predictor values.
    predictor_dict['predictor_matrix_norm']: E-by-M-by-N-by-C numpy array of
        normalized predictor values.
    predictor_dict['predictor_matrix_unif_norm']: E-by-M-by-N-by-C numpy array
        of uniformized, then normalized, predictor values.
    predictor_dict['valid_times_unix_sec']: length-E numpy array of valid times.
    predictor_dict['latitudes_deg_n']: length-M numpy array of latitudes
        (deg north).
    predictor_dict['longitudes_deg_e']: length-N numpy array of longitudes
        (deg east).
    predictor_dict['band_numbers']: length-C numpy array of spectral bands
        (integers).
    """

    if netcdf_file_name.endswith(GZIP_FILE_EXTENSION):
        with gzip.open(netcdf_file_name) as gzip_handle:
            with netCDF4.Dataset(
                    'dummy', mode='r', memory=gzip_handle.read()
            ) as dataset_object:
                return _read_predictors(
                    dataset_object=dataset_object,
                    read_unnormalized=read_unnormalized,
                    read_normalized=read_normalized,
                    read_unif_normalized=read_unif_normalized
                )

    # Context manager closes the file even if reading fails.
    with netCDF4.Dataset(netcdf_file_name) as dataset_object:
        return _read_predictors(
            dataset_object=dataset_object,
            read_unnormalized=read_unnormalized,
            read_normalized=read_normalized,
            read_unif_normalized=read_unif_normalized
        )


def _concat_predictor_data(predictor_dicts):
    """Concatenates many dictionaries with predictor data into one.

    :param predictor_dicts: List of dictionaries, each in the format returned by
        `_read_predictor_file`.
    :return: predictor_dict: Single dictionary, also in the format returned by
        `_read_predictor_file`.
    :raises: ValueError: if any two dictionaries have different band numbers,
        latitudes, or longitudes.
    """

    predictor_dict = copy.deepcopy(predictor_dicts[0])
    keys_to_match = [
        BAND_NUMBERS_KEY, LATITUDES_KEY, LONGITUDES_KEY, NORMALIZATION_FILE_KEY
    ]

    for i in range(1, len(predictor_dicts)):
        for this_key in keys_to_match:
            if this_key == BAND_NUMBERS_KEY:
                if numpy.array_equal(
                        predictor_dict[this_key], predictor_dicts[i][this_key]
                ):
                    continue
            elif this_key == NORMALIZATION_FILE_KEY:
                if predictor_dict[this_key] == predictor_dicts[i][this_key]:
                    continue
            else:
                if numpy.allclose(
                        predictor_dict[this_key], predictor_dicts[i][this_key],
                        atol=TOLERANCE
                ):
                    continue

            error_string = (
                '1st and {0:d}th dictionaries have different values for '
                '"{1:s}".  1st dictionary:\n{2:s}\n\n'
                '{0:d}th dictionary:\n{3:s}'
            ).format(
                i + 1, this_key,
                str(predictor_dict[this_key]),
                str(predictor_dicts[i][this_key])
            )

            raise ValueError(error_string)

    for i in range(1, len(predictor_dicts)):
        for this_key in ONE_PER_PREDICTOR_TIME_KEYS:
            if predictor_dict[this_key] is None:
                continue

            predictor_dict[this_key] = numpy.concatenate((
                predictor_dict[this_key], predictor_dicts[i][this_key]
            ), axis=0)

    return predictor_dict


def _subset_by_index(predictor_or_target_dict, desired_indices):
    """Subsets predictor or target data by index.

    :param predictor_or_target_dict: Dictionary with keys listed in either
        `_read_predictor_file` or `_read_target_file`.
    :param desired_indices: 1-D numpy array of desired indices.
    :return: predictor_or_target_dict: Same as input but maybe with fewer
        examples.
    """

    if PREDICTOR_MATRIX_UNNORM_KEY in predictor_or_target_dict:
        expected_keys = ONE_PER_PREDICTOR_TIME_KEYS
    else:
        expected_keys = ONE_PER_TARGET_TIME_KEYS

    for this_key in expected_keys:
        if predictor_or_target_dict[this_key] is None:
            continue

        predictor_or_target_dict[this_key] = (
            predictor_or_target_dict[this_key][desired_indices, ...]
        )

    return predictor_or_target_dict


def _subset_by_time(predictor_or_target_dict, desired_times_unix_sec):
    """Subsets predictor or target data by time.

    T = number of desired times

    :param predictor_or_target_dict: Dictionary with keys listed in either
        `_read_predictor_file` or `_read_target_file`.
    :param desired_times_unix_sec: length-T numpy array of desired times.
    :return: predictor_or_target_dict: Same as input but maybe with fewer
        examples.
    :return: desired_indices: length-T numpy array with indices corresponding to
        desired times.
    """

    desired_indices = numpy.array([
        numpy.where(predictor_or_target_dict[VALID_TIMES_KEY] == t)[0][0]
        for t in desired_times_unix_sec
    ], dtype=int)

    predictor_or_target_dict = _subset_by_index(
        predictor_or_target_dict=predictor_or_target_dict,
        desired_indices=desired_indices
    )

    return predictor_or_target_dict, desired_indices


def _predictor_matrix_to_keras(predictor_matrix, num_lag_times,
                               add_time_dimension):
    """Reshapes predictor matrix into format required by Keras.

    E = number of examples
    M = number of rows in grid
    N = number of columns in grid
    B = number of spectral bands
    L = number of lag times

    :param predictor_matrix: numpy array (EL x M x N x B) of predictors.
    :param num_lag_times: Number of lag times.
    :param add_time_dimension: Boolean flag.  If True, will add time dimension,
        so output will be E x M x N x L x B.  If False, will just reorder data,
        so output will be E x M x N x BL.
    :return: predictor_matrix: numpy array of predictors.
    :raises: ValueError: if length of first axis of predictor matrix is not an
        integer multiple of L.
    """

    # Check for errors.
    first_axis_length = predictor_matrix.shape[0]
    num_examples = float(first_axis_length) / num_lag_times
    this_diff = numpy.absolute(numpy.round(num_examples) - num_examples)

    if this_diff > TOLERANCE:
        error_string = (
            'Length of first axis of predictor matrix ({0:d}) must be an '
            'integer multiple of the number of lag times ({1:d}).'
        ).format(first_axis_length, num_lag_times)

        raise ValueError(error_string)

    # Do actual stuff.
    predictor_matrix_by_lag = [
        predictor_matrix[j::num_lag_times, ...] for j in range(num_lag_times)
    ]

    if add_time_dimension:
        return numpy.stack(predictor_matrix_by_lag, axis=-2)

    num_bands = predictor_matrix.shape[-1]
    predictor_matrix = numpy.stack(predictor_matrix_by_lag, axis=-1)

    num_channels = num_bands * num_lag_times
    these_dim = predictor_matrix.shape[:-2] + (num_channels,)
    return numpy.reshape(predictor_matrix, these_dim)


def _read_inputs_one_day(
        valid_date_string, predictor_file_names,
        normalize, uniformize, target_file_names, lead_time_seconds,
        lag_times_seconds, include_time_dimension, num_examples_to_read,
        fourier_transform_targets, wavelet_transform_targets,
        min_target_resolution_deg, max_target_resolution_deg):
    """Reads inputs (predictor and target files) for one day.

    E = number of examples
    M = number of rows in grid
    N = number of columns in grid

    :param valid_date_string: Valid date (format "yyyymmdd").
    :param predictor_file_names: 1-D list of paths to predictor files (readable
        by `_read_predictor_file`).
    :param normalize: See doc for `_generator_partial_grids`.
    :param uniformize: Same.
    :param target_file_names: 1-D list of paths to target files (readable by
        `_read_target_file`).
    :param lead_time_seconds: See doc for `_generator_partial_grids`.
    :param lag_times_seconds: Same.
    :param include_time_dimension: Same.
    :param num_examples_to_read: Number of examples to read.
    :param fourier_transform_targets: See doc for `_generator_partial_grids`.
    :param wavelet_transform_targets: Same.
    :param min_target_resolution_deg: Same.
    :param max_target_resolution_deg: Same.
    :return: data_dict: Dictionary with the following keys.
    data_dict['predictor_matrix']: See doc for `_generator_partial_grids`.
    data_dict['target_matrix']: Same.
    data_dict['valid_times_unix_sec']: length-E numpy array of valid times.
    """

    uniformize = uniformize and normalize

    target_date_strings = [
        file_name_to_date(f) for f in target_file_names
    ]
    index = target_date_strings.index(valid_date_string)
    desired_target_file_name = target_file_names[index]

    predictor_date_strings = [
        file_name_to_date(f) for f in predictor_file_names
    ]
    index = predictor_date_strings.index(valid_date_string)
    desired_predictor_file_names = [predictor_file_names[index]]

    if lead_time_seconds > 0 or numpy.any(lag_times_seconds > 0):
        desired_predictor_file_names.insert(0, predictor_file_names[index - 1])

    print('Reading data from: "{0:s}"...'.format(desired_target_file_name))
    target_dict = _read_target_file(
        netcdf_file_name=desired_target_file_name
    )

    predictor_dicts = []

    for this_file_name in desired_predictor_file_names:
        print('Reading data from: "{0:s}"...'.format(this_file_name))
        this_predictor_dict = _read_predictor_file(
            netcdf_file_name=this_file_name,
            read_unnormalized=not normalize,
            read_normalized=normalize and not uniformize,
            read_unif_normalized=normalize and uniformize
        )
        predictor_dicts.append(this_predictor_dict)

    predictor_dict = _concat_predictor_data(predictor_dicts)

    assert numpy.allclose(
        target_dict[LATITUDES_KEY],
        predictor_dict[LATITUDES_KEY],
        atol=TOLERANCE
    )

    assert numpy.allclose(
        target_dict[LONGITUDES_KEY],
        predictor_dict[LONGITUDES_KEY],
        atol=TOLERANCE
    )

    valid_times_unix_sec = target_dict[VALID_TIMES_KEY]

    num_valid_times = len(valid_times_unix_sec)
    num_lag_times = len(lag_times_seconds)
    init_time_matrix_unix_sec = numpy.full(
        (num_valid_times, num_lag_times), -1, dtype=int
    )

    for i in range(num_valid_times):
        these_init_times_unix_sec = (
            valid_times_unix_sec[i] - lead_time_seconds - lag_times_seconds
        )

        if not all([
                t in predictor_dict[VALID_TIMES_KEY]
                for t in these_init_times_unix_sec
        ]):
            continue

        init_time_matrix_unix_sec[i, :] = these_init_times_unix_sec

    good_indices = numpy.where(
        numpy.all(init_time_matrix_unix_sec >= 0, axis=1)
    )[0]

    if len(good_indices) == 0:
        return None

    valid_times_unix_sec = valid_times_unix_sec[good_indices]
    init_time_matrix_unix_sec = init_time_matrix_unix_sec[good_indices, :]
    num_examples = len(valid_times_unix_sec)

    if num_examples >= num_examples_to_read:
        desired_indices = numpy.linspace(
            0, num_examples - 1, num=num_examples, dtype=int
        )
        desired_indices = numpy.random.choice(
            desired_indices, size=num_examples_to_read, replace=False
        )

        valid_times_unix_sec = valid_times_unix_sec[desired_indices]
        init_time_matrix_unix_sec = (
            init_time_matrix_unix_sec[desired_indices, :]
        )

    predictor_dict = _subset_by_time(
        predictor_or_target_dict=predictor_dict,
        desired_times_unix_sec=numpy.ravel(init_time_matrix_unix_sec)
    )[0]
    target_dict = _subset_by_time(
        predictor_or_target_dict=target_dict,
        desired_times_unix_sec=valid_times_unix_sec
    )[0]

    if normalize:
        if uniformize:
            predictor_matrix = (
                predictor_dict[PREDICTOR_MATRIX_UNIF_NORM_KEY]
            )
        else:
            predictor_matrix = (
                predictor_dict[PREDICTOR_MATRIX_NORM_KEY]
            )
    else:
        predictor_matrix = (
            predictor_dict[PREDICTOR_MATRIX_UNNORM_KEY]
        )

    predictor_matrix = _predictor_matrix_to_keras(
        predictor_matrix=predictor_matrix, num_lag_times=num_lag_times,
        add_time_dimension=include_time_dimension
    )

    target_matrix = target_dict[TARGET_MATRIX_KEY]
    print('Number of target values in batch = {0:d} ... mean = {1:.3g}'.format(
        target_matrix.size, numpy.mean(target_matrix)
    ))

    if fourier_transform_targets:
        target_matrix = target_matrix.astype(float)
        num_examples = target_matrix.shape[0]

        target_matrix = numpy.stack([
            _fourier_taper_spatial_data(target_matrix[i, ...])
            for i in range(num_examples)
        ], axis=0)

        blackman_matrix = _fourier_apply_blackman_window(
            numpy.ones(target_matrix.shape[1:])
        )
        target_matrix = numpy.stack([
            target_matrix[i, ...] * blackman_matrix for i in range(num_examples)
        ], axis=0)

        target_tensor = tensorflow.constant(
            target_matrix, dtype=tensorflow.complex128
        )
        target_weight_tensor = tensorflow.signal.fft2d(target_tensor)
        target_weight_matrix = K.eval(target_weight_tensor)

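        # All spatial arguments here are in degrees, despite the "_metres"
        # suffixes.  This is safe because the filter uses only ratios of
        # wavenumbers (and compares resolutions in the same units), so the
        # units cancel.
        butterworth_matrix = _fourier_apply_butterworth_filter(
            coefficient_matrix=numpy.ones(target_matrix.shape[1:]),
            filter_order=2, grid_spacing_metres=GRID_SPACING_DEG,
            min_resolution_metres=min_target_resolution_deg,
            max_resolution_metres=max_target_resolution_deg
        )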

        target_weight_matrix = numpy.stack([
            target_weight_matrix[i, ...] * butterworth_matrix
            for i in range(num_examples)
        ], axis=0)

        target_weight_tensor = tensorflow.constant(
            target_weight_matrix, dtype=tensorflow.complex128
        )
        target_tensor = tensorflow.signal.ifft2d(target_weight_tensor)
        target_tensor = tensorflow.math.real(target_tensor)
        target_matrix = K.eval(target_tensor)

        target_matrix = numpy.stack([
            _fourier_untaper_spatial_data(target_matrix[i, ...])
            for i in range(num_examples)
        ], axis=0)

        target_matrix = numpy.maximum(target_matrix, 0.)
        target_matrix = numpy.minimum(target_matrix, 1.)

        print((
            'Number of target values and mean after Fourier transform = '
            '{0:d}, {1:.3g}'
        ).format(
            target_matrix.size, numpy.mean(target_matrix)
        ))

    if wavelet_transform_targets:
        target_matrix = target_matrix.astype(float)
        target_matrix, padding_arg = _wavelet_taper_spatial_data(
            target_matrix
        )

        coeff_tensor_by_level = _wavelet_do_forward_transform(
            target_matrix
        )
        coeff_tensor_by_level = _wavelet_filter_coefficients(
            coeff_tensor_by_level=coeff_tensor_by_level,
            grid_spacing_metres=GRID_SPACING_DEG,
            min_resolution_metres=min_target_resolution_deg,
            max_resolution_metres=max_target_resolution_deg, verbose=True
        )

        inverse_dwt_object = WaveTFFactory().build('haar', dim=2, inverse=True)
        target_tensor = inverse_dwt_object.call(coeff_tensor_by_level[0])
        target_matrix = K.eval(target_tensor)[..., 0]

        target_matrix = _wavelet_untaper_spatial_data(
            spatial_data_matrix=target_matrix, numpy_pad_width=padding_arg
        )
        target_matrix = numpy.maximum(target_matrix, 0.)
        target_matrix = numpy.minimum(target_matrix, 1.)

        print((
            'Number of target values and mean after wavelet transform = '
            '{0:d}, {1:.3g}'
        ).format(
            target_matrix.size, numpy.mean(target_matrix)
        ))

    return {
        PREDICTOR_MATRIX_KEY: predictor_matrix.astype('float16'),
        TARGET_MATRIX_KEY: numpy.expand_dims(target_matrix, axis=-1),
        VALID_TIMES_KEY: valid_times_unix_sec
    }


def _check_generator_args(option_dict):
    """Error-checks input arguments for generator.

    :param option_dict: See doc for `_generator_partial_grids`.
    :return: option_dict: Same as input, except defaults may have been added.
    """

    orig_option_dict = option_dict.copy()
    option_dict = DEFAULT_GENERATOR_OPTION_DICT.copy()
    option_dict.update(orig_option_dict)

    if option_dict[FOURIER_TRANSFORM_KEY]:
        option_dict[WAVELET_TRANSFORM_KEY] = False

    if option_dict[FOURIER_TRANSFORM_KEY] or option_dict[WAVELET_TRANSFORM_KEY]:
        if option_dict[MAX_TARGET_RESOLUTION_KEY] > MAX_TARGET_RESOLUTION_DEG:
            option_dict[MAX_TARGET_RESOLUTION_KEY] = numpy.inf
    else:
        option_dict[MIN_TARGET_RESOLUTION_KEY] = numpy.nan
        option_dict[MAX_TARGET_RESOLUTION_KEY] = numpy.nan

    return option_dict


def _generator_partial_grids(option_dict):
    """Generates training data on partial, radar-centered grids.

    E = number of examples per batch
    M = number of rows in grid
    N = number of columns in grid
    L = number of lag times
    B = number of spectral bands

    :param option_dict: Dictionary with the following keys.
    option_dict['top_predictor_dir_name']: Name of top-level directory with
        predictors.  Files therein will be found by `_find_predictor_file` and
        read by `_read_predictor_file`.
    option_dict['top_target_dir_name']: Name of top-level directory with
        targets.  Files therein will be found by `_find_target_file` and read
        by `_read_target_file`.
    option_dict['num_examples_per_batch']: Batch size.
    option_dict['max_examples_per_day_in_batch']: Max number of examples from
        the same day in one batch.
    option_dict['lead_time_seconds']: Lead time (valid time minus forecast
        time).
    option_dict['lag_times_seconds']: 1-D numpy array of lag times.  Each lag
        time is forecast time minus predictor time, so must be >= 0.
    option_dict['include_time_dimension']: Boolean flag.  If True, predictor
        matrix will include a time dimension.  If False, times and spectral
        bands will be combined into the last axis.
    option_dict['first_valid_date_string']: First valid date (format
        "yyyymmdd").  Will not generate examples with earlier valid times.
    option_dict['last_valid_date_string']: Last valid date (format
        "yyyymmdd").  Will not generate examples with later valid times.
    option_dict['normalize']: Boolean flag.  If True (False), will use
        normalized (unnormalized) predictors.
    option_dict['uniformize']: [used only if `normalize == True`]
        Boolean flag.  If True, will use uniformized and normalized predictors.
        If False, will use only normalized predictors.
    option_dict['fourier_transform_targets']: Boolean flag.  If True, will use
        Fourier transform to apply band-pass filter to targets.
    option_dict['wavelet_transform_targets']: Boolean flag.  If True, will use
        wavelet transform to apply band-pass filter to targets.
    option_dict['min_target_resolution_deg']: Minimum resolution (degrees) to
        allow through band-pass filter.
    option_dict['max_target_resolution_deg']: Max resolution (degrees) to allow
        through band-pass filter.

    :return: predictor_matrix: numpy array (E x M x N x LB or E x M x N x L x B)
        of predictor values, based on satellite data.
    :return: target_matrix: E-by-M-by-N-by-1 numpy array of target values
        (floats in 0...1, indicating whether or not convection occurs at
        the given lead time).
    :raises: ValueError: if no valid date can be found for which predictors and
        targets are available.
    """

    option_dict = _check_generator_args(option_dict)

    top_predictor_dir_name = option_dict[PREDICTOR_DIRECTORY_KEY]
    top_target_dir_name = option_dict[TARGET_DIRECTORY_KEY]
    num_examples_per_batch = option_dict[BATCH_SIZE_KEY]
    max_examples_per_day_in_batch = option_dict[MAX_DAILY_EXAMPLES_KEY]
    lead_time_seconds = option_dict[LEAD_TIME_KEY]
    lag_times_seconds = option_dict[LAG_TIMES_KEY]
    include_time_dimension = option_dict[INCLUDE_TIME_DIM_KEY]
    first_valid_date_string = option_dict[FIRST_VALID_DATE_KEY]
    last_valid_date_string = option_dict[LAST_VALID_DATE_KEY]
    normalize = option_dict[NORMALIZE_FLAG_KEY]
    uniformize = option_dict[UNIFORMIZE_FLAG_KEY]
    fourier_transform_targets = option_dict[FOURIER_TRANSFORM_KEY]
    wavelet_transform_targets = option_dict[WAVELET_TRANSFORM_KEY]
    min_target_resolution_deg = option_dict[MIN_TARGET_RESOLUTION_KEY]
    max_target_resolution_deg = option_dict[MAX_TARGET_RESOLUTION_KEY]

    if lead_time_seconds > 0 or numpy.any(lag_times_seconds > 0):
        first_init_date_string = _get_previous_date(first_valid_date_string)
    else:
        first_init_date_string = copy.deepcopy(first_valid_date_string)

    these_predictor_file_names = _find_many_predictor_files(
        top_directory_name=top_predictor_dir_name,
        first_date_string=first_init_date_string,
        last_date_string=last_valid_date_string,
        radar_number=1, prefer_zipped=False, allow_other_format=True,
        raise_error_if_any_missing=False
    )
    these_target_file_names = _find_many_target_files(
        top_directory_name=top_target_dir_name,
        first_date_string=first_init_date_string,
        last_date_string=last_valid_date_string,
        radar_number=1, prefer_zipped=False, allow_other_format=True,
        raise_error_if_any_missing=False
    )

    predictor_date_strings = [
        file_name_to_date(f) for f in these_predictor_file_names
    ]
    target_date_strings = [
        file_name_to_date(f) for f in these_target_file_names
    ]

    predictor_file_name_matrix = numpy.full(
        (len(these_predictor_file_names), NUM_RADARS), '', dtype=object
    )
    target_file_name_matrix = numpy.full(
        (len(these_target_file_names), NUM_RADARS), '', dtype=object
    )
    predictor_file_name_matrix[:, 1] = numpy.array(these_predictor_file_names)
    target_file_name_matrix[:, 1] = numpy.array(these_target_file_names)

    valid_date_strings = _find_days_with_both_inputs(
        predictor_file_names=predictor_file_name_matrix[:, 1],
        target_file_names=target_file_name_matrix[:, 1],
        lead_time_seconds=lead_time_seconds,
        lag_times_seconds=lag_times_seconds
    )

    if len(valid_date_strings) == 0:
        raise ValueError(
            'Cannot find any valid date for which both predictors and targets '
            'are available.'
        )
    
    for k in range(1, NUM_RADARS, 2):
        these_predictor_file_names = [
            _find_predictor_file(
                top_directory_name=top_predictor_dir_name, date_string=d,
                radar_number=k, prefer_zipped=False, allow_other_format=True,
                raise_error_if_missing=True
            ) for d in predictor_date_strings
        ]

        these_target_file_names = [
            _find_target_file(
                top_directory_name=top_target_dir_name, date_string=d,
                radar_number=k, prefer_zipped=False, allow_other_format=True,
                raise_error_if_missing=True
            ) for d in target_date_strings
        ]

        predictor_file_name_matrix[:, k] = numpy.array(
            these_predictor_file_names
        )
        target_file_name_matrix[:, k] = numpy.array(these_target_file_names)

    predictor_file_name_matrix = predictor_file_name_matrix[:, [1, 3]]
    target_file_name_matrix = target_file_name_matrix[:, [1, 3]]

    num_radars = predictor_file_name_matrix.shape[1]
    radar_indices = numpy.linspace(0, num_radars - 1, num=num_radars, dtype=int)
    date_indices = numpy.linspace(
        0, len(valid_date_strings) - 1, num=len(valid_date_strings), dtype=int
    )

    date_index_matrix, radar_index_matrix = numpy.meshgrid(
        date_indices, radar_indices
    )
    date_indices_1d = numpy.ravel(date_index_matrix)
    radar_indices_1d = numpy.ravel(radar_index_matrix)

    random_indices = numpy.linspace(
        0, len(radar_indices_1d) - 1, num=len(radar_indices_1d), dtype=int
    )
    numpy.random.shuffle(random_indices)
    date_indices_1d = date_indices_1d[random_indices]
    radar_indices_1d = radar_indices_1d[random_indices]

    current_index = 0

    while True:
        predictor_matrix = None
        target_matrix = None
        num_examples_in_memory = 0

        while num_examples_in_memory < num_examples_per_batch:
            if current_index == len(radar_indices_1d):
                current_index = 0

            num_examples_to_read = min([
                max_examples_per_day_in_batch,
                num_examples_per_batch - num_examples_in_memory
            ])

            current_date_index = date_indices_1d[current_index]
            current_radar_index = radar_indices_1d[current_index]

            this_data_dict = _read_inputs_one_day(
                valid_date_string=valid_date_strings[current_date_index],
                predictor_file_names=
                predictor_file_name_matrix[:, current_radar_index],
                normalize=normalize, uniformize=uniformize,
                target_file_names=
                target_file_name_matrix[:, current_radar_index],
                lead_time_seconds=lead_time_seconds,
                lag_times_seconds=lag_times_seconds,
                include_time_dimension=include_time_dimension,
                num_examples_to_read=num_examples_to_read,
                fourier_transform_targets=fourier_transform_targets,
                wavelet_transform_targets=wavelet_transform_targets,
                min_target_resolution_deg=min_target_resolution_deg,
                max_target_resolution_deg=max_target_resolution_deg
            )

            current_index += 1
            if this_data_dict is None:
                continue

            this_predictor_matrix = this_data_dict[PREDICTOR_MATRIX_KEY]
            this_target_matrix = this_data_dict[TARGET_MATRIX_KEY]

            if predictor_matrix is None:
                predictor_matrix = this_predictor_matrix + 0.
                target_matrix = this_target_matrix + 0
            else:
                predictor_matrix = numpy.concatenate(
                    (predictor_matrix, this_predictor_matrix), axis=0
                )
                target_matrix = numpy.concatenate(
                    (target_matrix, this_target_matrix), axis=0
                )

            num_examples_in_memory = predictor_matrix.shape[0]

        predictor_matrix = predictor_matrix.astype('float16')
        target_matrix = target_matrix.astype('float16')
        yield predictor_matrix, target_matrix


def _find_nn_metafile(model_file_name, raise_error_if_missing=True):
    """Finds metafile for neural net.

    :param model_file_name: Path to trained model.
    :param raise_error_if_missing: Boolean flag.  If file is missing and
        `raise_error_if_missing == True`, will throw error.  If file is missing
        and `raise_error_if_missing == False`, will return *expected* file path.
    :return: metafile_name: Path to metafile.
    """

    metafile_name = '{0:s}/model_metadata.dill'.format(
        os.path.split(model_file_name)[0]
    )

    if raise_error_if_missing and not os.path.isfile(metafile_name):
        error_string = 'Cannot find file.  Expected at: "{0:s}"'.format(
            metafile_name
        )
        raise ValueError(error_string)

    return metafile_name


def _read_nn_metafile(dill_file_name):
    """Reads metadata for neural net from Dill file.

    :param dill_file_name: Path to input file.
    :return: metadata_dict: Dictionary with the following keys.
    metadata_dict['num_epochs']: See doc for `train_model`.
    metadata_dict['num_training_batches_per_epoch']: Same.
    metadata_dict['training_option_dict']: Same.
    metadata_dict['num_validation_batches_per_epoch']: Same.
    metadata_dict['validation_option_dict']: Same.
    metadata_dict['do_early_stopping']: Same.
    metadata_dict['plateau_lr_multiplier']: Same.
    metadata_dict['loss_function_name']: Same.
    metadata_dict['metric_names']: Same.
    metadata_dict['mask_matrix']: Same.

    :raises: ValueError: if any expected key is not found in dictionary.
    """

    with open(dill_file_name, 'rb') as dill_file_handle:
        metadata_dict = dill.load(dill_file_handle)

    missing_keys = list(set(METADATA_KEYS) - set(metadata_dict.keys()))
    if len(missing_keys) == 0:
        return metadata_dict

    error_string = (
        '\n{0:s}\nKeys listed above were expected, but not found, in file '
        '"{1:s}".'
    ).format(str(missing_keys), dill_file_name)

    raise ValueError(error_string)


def _metric_params_to_name(score_name, half_window_size_px):
    """Converts parameters for evaluation metric to name.

    :param score_name: Name of score.
    :param half_window_size_px: Half-window size (pixels) for neighbourhood.
    :return: metric_name: Metric name (string).
    """

    half_window_size_px = int(numpy.round(half_window_size_px))
    return '{0:s}_neigh{1:d}'.format(score_name, half_window_size_px)


def _metric_name_to_params(metric_name):
    """Converts name of evaluation metric to parameters.

    This method is the inverse of `_metric_params_to_name`.

    :param metric_name: Metric name (string).
    :return: param_dict: Dictionary with the following keys.
    param_dict['score_name']: See doc for `_metric_params_to_name`.
    param_dict['half_window_size_px']: Same.
    """

    metric_name_parts = metric_name.split('_')
    score_name = metric_name_parts[0]

    assert len(metric_name_parts) == 2
    assert metric_name_parts[1].startswith('neigh')
    half_window_size_px = int(
        metric_name_parts[1].replace('neigh', '')
    )

    return {
        SCORE_NAME_KEY: score_name,
        HALF_WINDOW_SIZE_KEY: half_window_size_px
    }
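
The reshaping done by `_predictor_matrix_to_keras` (defined in the cell above) is easy to get wrong, so here is a minimal standalone sketch that reproduces its indexing on a toy array.  The dimensions and variable names are hypothetical, chosen only for illustration.  The sketch assumes, as the method does, that the first axis is ordered example-major and lag-minor, *i.e.*, entry `i * L + j` holds lag time `j` of example `i`.

```python
import numpy

# Hypothetical toy dimensions (not taken from the paper).
num_examples, num_lag_times = 2, 3
num_rows, num_columns, num_bands = 4, 5, 2

# Input is EL x M x N x B, with lag times interleaved along the first axis.
flat_matrix = numpy.random.rand(
    num_examples * num_lag_times, num_rows, num_columns, num_bands
)

# Entry j of this list is E x M x N x B and contains only the [j]th lag time.
matrix_by_lag = [
    flat_matrix[j::num_lag_times, ...] for j in range(num_lag_times)
]

# With a time dimension: E x M x N x L x B.
matrix_with_time = numpy.stack(matrix_by_lag, axis=-2)

# Without a time dimension: stack lag times last (E x M x N x B x L), then
# merge the last two axes.  Channel index b * L + j holds band b at lag j.
matrix_no_time = numpy.stack(matrix_by_lag, axis=-1)
matrix_no_time = numpy.reshape(
    matrix_no_time, matrix_no_time.shape[:-2] + (num_bands * num_lag_times,)
)

print(matrix_with_time.shape)
print(matrix_no_time.shape)
```

In the merged version, the bands vary slowest and the lag times vary fastest along the last axis, which matters if you want to pick out a particular band or lag time later.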

# <font color='red'>Define Fourier/wavelet methods (required)</font>

The next cell defines the methods that perform Fourier and wavelet decomposition.  These are also private methods, but I have put them in a separate code cell because they have special importance -- *i.e.*, they are key to implementing our spatially enhanced loss functions.
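
To give a rough idea of what the Fourier methods do together, here is a simplified sketch of a band-pass filter built from zero-padding, a 2-D Fourier transform, and a Butterworth gain.  It is not the notebook's implementation: the names `butterworth_gain` and `band_pass_filter` are hypothetical, the sketch uses `numpy.fft` rather than TensorFlow, it takes wavenumbers from `numpy.fft.fftfreq` (which may differ slightly from the convention used in the next cell), and it omits the Blackman window.

```python
import numpy


def butterworth_gain(num_rows, num_columns, grid_spacing,
                     min_resolution, max_resolution, filter_order=2):
    """Returns Butterworth band-pass gain for each Fourier coefficient."""

    # Wavenumbers (cycles per unit distance), laid out as in numpy.fft.fft2.
    y_wavenumbers = numpy.fft.fftfreq(num_rows, d=grid_spacing)
    x_wavenumbers = numpy.fft.fftfreq(num_columns, d=grid_spacing)
    wavenumber_matrix = numpy.sqrt(
        y_wavenumbers[:, numpy.newaxis] ** 2 +
        x_wavenumbers[numpy.newaxis, :] ** 2
    )
    gain_matrix = numpy.ones((num_rows, num_columns))

    # High-pass part: suppress scales coarser than max_resolution.
    if numpy.isfinite(max_resolution):
        ratio_matrix = wavenumber_matrix * (2 * max_resolution)
        gain_matrix *= 1. - 1. / (1. + ratio_matrix ** (2 * filter_order))

    # Low-pass part: suppress scales finer than min_resolution.
    if min_resolution > grid_spacing:
        ratio_matrix = wavenumber_matrix * (2 * min_resolution)
        gain_matrix *= 1. / (1. + ratio_matrix ** (2 * filter_order))

    return gain_matrix


def band_pass_filter(field_matrix, grid_spacing, min_resolution,
                     max_resolution):
    """Band-pass-filters one 2-D field with values in 0...1."""

    num_rows, num_columns = field_matrix.shape

    # Zero-pad by one full grid length on every side, to reduce edge effects.
    padded_matrix = numpy.pad(
        field_matrix,
        pad_width=((num_rows, num_rows), (num_columns, num_columns)),
        mode='constant', constant_values=0.
    )

    coefficient_matrix = numpy.fft.fft2(padded_matrix)
    gain_matrix = butterworth_gain(
        num_rows=padded_matrix.shape[0], num_columns=padded_matrix.shape[1],
        grid_spacing=grid_spacing, min_resolution=min_resolution,
        max_resolution=max_resolution
    )
    filtered_matrix = numpy.real(
        numpy.fft.ifft2(coefficient_matrix * gain_matrix)
    )

    # Remove padding and force values back into 0...1.
    filtered_matrix = filtered_matrix[
        num_rows:-num_rows, num_columns:-num_columns
    ]
    return numpy.clip(filtered_matrix, 0., 1.)


# Toy example: one square "storm" on a 33 x 33 grid (values are made up).
toy_field_matrix = numpy.zeros((33, 33))
toy_field_matrix[14:19, 14:19] = 1.

toy_filtered_matrix = band_pass_filter(
    field_matrix=toy_field_matrix, grid_spacing=0.0125,
    min_resolution=0.05, max_resolution=numpy.inf
)
print(toy_filtered_matrix.shape)
```

With `max_resolution=numpy.inf` the high-pass part is skipped, so the example above acts as a low-pass filter: it smooths the edges of the square but keeps the output on the original grid.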

In [None]:
def _get_spatial_resolutions(num_grid_rows, num_grid_columns,
                             grid_spacing_metres):
    """Computes spatial resolution for each Fourier coefficient.

    M = number of rows in spatial grid
    N = number of columns in spatial grid

    Matrices returned by this method correspond to matrices of Fourier
    coefficients returned by `numpy.fft.fft2`.  The x-coordinate increases with
    column index, and the y-coordinate increases with row index.

    :param num_grid_rows: M in the above discussion.
    :param num_grid_columns: N in the above discussion.
    :param grid_spacing_metres: Grid spacing (for which I use "resolution" as a
        synonym).
    :return: x_resolution_matrix_metres: M-by-N numpy array of resolutions in
        x-direction.
    :return: y_resolution_matrix_metres: Same but for y-direction.
    """

    num_half_rows_float = float(num_grid_rows - 1) / 2
    num_half_rows = int(numpy.round(num_half_rows_float))
    assert numpy.isclose(num_half_rows, num_half_rows_float, atol=TOLERANCE)

    num_half_columns_float = float(num_grid_columns - 1) / 2
    num_half_columns = int(numpy.round(num_half_columns_float))
    assert numpy.isclose(
        num_half_columns, num_half_columns_float, atol=TOLERANCE
    )

    # Find resolutions in x-direction.
    unique_x_wavenumbers = numpy.linspace(
        0, num_half_columns, num=num_half_columns + 1, dtype=int
    )
    x_wavenumbers = numpy.concatenate((
        unique_x_wavenumbers, unique_x_wavenumbers[1:][::-1]
    ))
    x_wavenumber_matrix = numpy.expand_dims(x_wavenumbers, axis=0)
    x_wavenumber_matrix = numpy.repeat(
        x_wavenumber_matrix, axis=0, repeats=num_grid_rows
    )

    x_grid_length_metres = grid_spacing_metres * (num_grid_columns - 1)
    x_resolution_matrix_metres = (
        0.5 * x_grid_length_metres / x_wavenumber_matrix
    )

    # Find resolutions in y-direction.
    unique_y_wavenumbers = numpy.linspace(
        0, num_half_rows, num=num_half_rows + 1, dtype=int
    )
    y_wavenumbers = numpy.concatenate((
        unique_y_wavenumbers, unique_y_wavenumbers[1:][::-1]
    ))
    y_wavenumber_matrix = numpy.expand_dims(y_wavenumbers, axis=1)
    y_wavenumber_matrix = numpy.repeat(
        y_wavenumber_matrix, axis=1, repeats=num_grid_columns
    )

    y_grid_length_metres = grid_spacing_metres * (num_grid_rows - 1)
    y_resolution_matrix_metres = (
        0.5 * y_grid_length_metres / y_wavenumber_matrix
    )

    return x_resolution_matrix_metres, y_resolution_matrix_metres


def _fourier_apply_butterworth_filter(
        coefficient_matrix, filter_order, grid_spacing_metres,
        min_resolution_metres, max_resolution_metres):
    """Applies Butterworth band-pass filter to Fourier coefficients.

    :param coefficient_matrix: M-by-N numpy array of coefficients in format
        returned by `numpy.fft.fft2`.
    :param filter_order: Order of Butterworth filter (same as input arg `N` for
        `scipy.signal.butter`).
    :param grid_spacing_metres: Grid spacing (resolution).
    :param min_resolution_metres: Minimum resolution to preserve.
    :param max_resolution_metres: Max resolution to preserve.
    :return: coefficient_matrix: Same as input but after filtering.
    """

    # Determine horizontal, vertical, and total wavenumber for each Fourier
    # coefficient.
    x_resolution_matrix_metres, y_resolution_matrix_metres = (
        _get_spatial_resolutions(
            num_grid_rows=coefficient_matrix.shape[0],
            num_grid_columns=coefficient_matrix.shape[1],
            grid_spacing_metres=grid_spacing_metres
        )
    )

    x_wavenumber_matrix_metres01 = (2 * x_resolution_matrix_metres) ** -1
    y_wavenumber_matrix_metres01 = (2 * y_resolution_matrix_metres) ** -1
    wavenumber_matrix_metres01 = numpy.sqrt(
        x_wavenumber_matrix_metres01 ** 2 + y_wavenumber_matrix_metres01 ** 2
    )

    # High-pass part.
    if not numpy.isinf(max_resolution_metres):
        min_wavenumber_metres01 = (2 * max_resolution_metres) ** -1
        ratio_matrix = wavenumber_matrix_metres01 / min_wavenumber_metres01
        gain_matrix = 1 - (1 + ratio_matrix ** (2 * filter_order)) ** -1
        coefficient_matrix = coefficient_matrix * gain_matrix

    # Low-pass part.
    if min_resolution_metres > grid_spacing_metres:
        max_wavenumber_metres01 = (2 * min_resolution_metres) ** -1
        ratio_matrix = wavenumber_matrix_metres01 / max_wavenumber_metres01
        gain_matrix = (1 + ratio_matrix ** (2 * filter_order)) ** -1
        coefficient_matrix = coefficient_matrix * gain_matrix

    return coefficient_matrix


def _fourier_taper_spatial_data(spatial_data_matrix):
    """Tapers spatial data by putting zeros along the edge.

    M = number of rows in grid
    N = number of columns in grid

    :param spatial_data_matrix: M-by-N numpy array of real numbers.
    :return: spatial_data_matrix: Same but after tapering.
    """

    num_rows = spatial_data_matrix.shape[0]
    num_columns = spatial_data_matrix.shape[1]

    padding_arg = (
        (num_rows, num_rows),
        (num_columns, num_columns)
    )

    spatial_data_matrix = numpy.pad(
        spatial_data_matrix, pad_width=padding_arg, mode='constant',
        constant_values=0.
    )

    return spatial_data_matrix


def _fourier_untaper_spatial_data(spatial_data_matrix):
    """Removes zeros along the edge of spatial data.

    This method is the inverse of `_fourier_taper_spatial_data`.

    :param spatial_data_matrix: See output doc for
        `_fourier_taper_spatial_data`.
    :return: spatial_data_matrix: See input doc for
        `_fourier_taper_spatial_data`.
    """

    num_rows = spatial_data_matrix.shape[0]
    num_columns = spatial_data_matrix.shape[1]

    num_third_rows_float = float(num_rows) / 3
    num_third_rows = int(numpy.round(num_third_rows_float))
    assert numpy.isclose(num_third_rows, num_third_rows_float, atol=TOLERANCE)

    num_third_columns_float = float(num_columns) / 3
    num_third_columns = int(numpy.round(num_third_columns_float))
    assert numpy.isclose(
        num_third_columns, num_third_columns_float, atol=TOLERANCE
    )

    return spatial_data_matrix[
        num_third_rows:-num_third_rows,
        num_third_columns:-num_third_columns
    ]


def _fourier_apply_blackman_window(spatial_data_matrix):
    """Applies Blackman window to 2-D spatial data.

    M = number of rows in grid
    N = number of columns in grid

    :param spatial_data_matrix: M-by-N numpy array of real numbers.
    :return: spatial_data_matrix: Same but after smoothing via Blackman window.
    """

    num_rows = spatial_data_matrix.shape[0]
    num_columns = spatial_data_matrix.shape[1]
    num_half_rows = float(num_rows - 1) / 2
    num_half_columns = float(num_columns - 1) / 2

    row_indices = numpy.linspace(0, num_rows - 1, num=num_rows, dtype=float)
    column_indices = numpy.linspace(
        0, num_columns - 1, num=num_columns, dtype=float
    )

    y_distances = numpy.absolute(row_indices - num_half_rows)
    x_distances = numpy.absolute(column_indices - num_half_columns)
    x_distance_matrix, y_distance_matrix = numpy.meshgrid(
        x_distances, y_distances
    )

    distance_matrix = numpy.sqrt(
        x_distance_matrix ** 2 + y_distance_matrix ** 2
    )
    max_distance = numpy.maximum(
        numpy.max(x_distance_matrix),
        numpy.max(y_distance_matrix)
    )
    fractional_distance_matrix = distance_matrix / max_distance
    fractional_distance_matrix = numpy.minimum(fractional_distance_matrix, 1.)

    weight_matrix = (
        0.42 -
        0.5 * numpy.cos(numpy.pi * (1 + fractional_distance_matrix)) +
        0.08 * numpy.cos(2 * numpy.pi * (1 + fractional_distance_matrix))
    )

    return spatial_data_matrix * weight_matrix


def _wavelet_taper_spatial_data(spatial_data_matrix):
    """Tapers spatial data by putting zeros along the edge.

    E = number of examples
    M = number of rows in grid
    N = number of columns in grid

    :param spatial_data_matrix: E-by-M-by-N numpy array of real numbers.
    :return: spatial_data_matrix: Same but after tapering.
    :return: numpy_pad_width: Argument `pad_width` used for `numpy.pad`.
    """

    num_rowcols = max(spatial_data_matrix.shape[1:])
    num_transform_levels = int(numpy.ceil(
        numpy.log2(num_rowcols)
    ))
    num_rowcols_needed = int(numpy.round(
        2 ** num_transform_levels
    ))

    num_rows = spatial_data_matrix.shape[1]
    num_columns = spatial_data_matrix.shape[2]
    num_padding_rows = num_rowcols_needed - num_rows
    num_padding_columns = num_rowcols_needed - num_columns

    if numpy.mod(num_padding_rows, 2) == 0:
        num_start_rows = int(numpy.round(
            float(num_padding_rows) / 2
        ))
        num_end_rows = num_start_rows + 0
    else:
        num_start_rows = int(numpy.floor(
            float(num_padding_rows) / 2
        ))
        num_end_rows = num_start_rows + 1

    if numpy.mod(num_padding_columns, 2) == 0:
        num_start_columns = int(numpy.round(
            float(num_padding_columns) / 2
        ))
        num_end_columns = num_start_columns + 0
    else:
        num_start_columns = int(numpy.floor(
            float(num_padding_columns) / 2
        ))
        num_end_columns = num_start_columns + 1

    padding_arg = (
        (0, 0),
        (num_start_rows, num_end_rows),
        (num_start_columns, num_end_columns)
    )

    spatial_data_matrix = numpy.pad(
        spatial_data_matrix, pad_width=padding_arg, mode='constant',
        constant_values=0.
    )

    return spatial_data_matrix, padding_arg


def _wavelet_untaper_spatial_data(spatial_data_matrix, numpy_pad_width):
    """Removes zeros along the edge of spatial data.

    This method is the inverse of `_wavelet_taper_spatial_data`.

    :param spatial_data_matrix: See output doc for
        `_wavelet_taper_spatial_data`.
    :param numpy_pad_width: Same.
    :return: spatial_data_matrix: See input doc for
        `_wavelet_taper_spatial_data`.
    """

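    # Compute end indices explicitly.  A slice like `0:-0` would be empty if no
    # padding had been added along a dimension.
    end_row = spatial_data_matrix.shape[1] - numpy_pad_width[1][-1]
    end_column = spatial_data_matrix.shape[2] - numpy_pad_width[2][-1]

    return spatial_data_matrix[
        :,
        numpy_pad_width[1][0]:end_row,
        numpy_pad_width[2][0]:end_column
    ]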


def _wavelet_do_forward_transform(spatial_data_matrix):
    """Does forward multi-level wavelet transform.

    E = number of examples
    N = number of rows in grid = number of columns in grid
    K = number of levels in wavelet transform = log_2(N)

    :param spatial_data_matrix: E-by-N-by-N numpy array of real numbers.
    :return: coeff_tensor_by_level: length-K list of tensors, each containing
        coefficients in format returned by WaveTF library.
    """

    num_levels = int(numpy.round(
        numpy.log2(spatial_data_matrix.shape[1])
    ))

    spatial_data_tensor = tensorflow.constant(
        spatial_data_matrix, dtype=tensorflow.float64
    )
    spatial_data_tensor = tensorflow.expand_dims(spatial_data_tensor, axis=-1)

    dwt_object = WaveTFFactory().build('haar', dim=2)
    coeff_tensor_by_level = [None] * num_levels

    for k in range(num_levels):
        if k == 0:
            coeff_tensor_by_level[k] = dwt_object.call(spatial_data_tensor)
        else:
            coeff_tensor_by_level[k] = dwt_object.call(
                coeff_tensor_by_level[k - 1][..., :1]
            )

    return coeff_tensor_by_level


def _wavelet_filter_coefficients(
        coeff_tensor_by_level, grid_spacing_metres, min_resolution_metres,
        max_resolution_metres, verbose=True):
    """Filters wavelet coeffs (zeroes out coeffs at undesired wavelengths).

    :param coeff_tensor_by_level: See documentation for
        `_wavelet_do_forward_transform`.
    :param grid_spacing_metres: Grid spacing (resolution).
    :param min_resolution_metres: Minimum resolution to preserve.
    :param max_resolution_metres: Max resolution to preserve.
    :param verbose: Boolean flag.
    :return: coeff_tensor_by_level: Same as input but maybe with more zeros.
    """

    inverse_dwt_object = WaveTFFactory().build('haar', dim=2, inverse=True)
    num_levels = len(coeff_tensor_by_level)

    level_indices = numpy.linspace(0, num_levels - 1, num=num_levels, dtype=int)
    detail_res_by_level_metres = grid_spacing_metres * (2 ** level_indices)
    mean_res_by_level_metres = grid_spacing_metres * (2 ** (level_indices + 1))

    max_index = numpy.searchsorted(
        a=mean_res_by_level_metres, v=max_resolution_metres, side='right'
    )
    min_index = -1 + numpy.searchsorted(
        a=detail_res_by_level_metres, v=min_resolution_metres, side='left'
    )

    if max_index < num_levels:
        k = num_levels - 1

        while k >= max_index:
            if verbose:
                print((
                    'Zeroing out low-frequency coefficients at level '
                    '{0:d} of {1:d} (resolutions = {2:.4f} and {3:.4f} m)...'
                ).format(
                    k + 1, num_levels,
                    mean_res_by_level_metres[k],
                    detail_res_by_level_metres[k]
                ))

            coeff_tensor_by_level[k] = tensorflow.concat([
                tensorflow.zeros_like(coeff_tensor_by_level[k][..., :1]),
                coeff_tensor_by_level[k][..., 1:]
            ], axis=-1)

            k -= 1

        k = max_index + 0

        while k > 0 and k > min_index:
            if verbose:
                print((
                    'Reconstructing low-frequency coefficients at level '
                    '{0:d} of {1:d} (resolutions = {2:.4f} and {3:.4f} m)...'
                ).format(
                    k, num_levels,
                    mean_res_by_level_metres[k - 1],
                    detail_res_by_level_metres[k - 1]
                ))

            coeff_tensor_by_level[k - 1] = tensorflow.concat([
                inverse_dwt_object.call(coeff_tensor_by_level[k]),
                coeff_tensor_by_level[k - 1][..., 1:]
            ], axis=-1)

            k -= 1

    if min_index > 0:
        if verbose:
            print((
                'Zeroing out high-frequency coefficients at level '
                '{0:d} of {1:d} (resolutions = {2:.4f} and {3:.4f} m)...'
            ).format(
                min_index + 1, num_levels,
                mean_res_by_level_metres[min_index],
                detail_res_by_level_metres[min_index]
            ))

        coeff_tensor_by_level[min_index] = tensorflow.concat([
            coeff_tensor_by_level[min_index][..., :1],
            tensorflow.zeros_like(coeff_tensor_by_level[min_index][..., 1:])
        ], axis=-1)

    k = min_index + 0

    while k > 0:
        if verbose:
            print((
                'Reconstructing low-frequency coefficients at level '
                '{0:d} of {1:d} (resolutions = {2:.4f} and {3:.4f} m)...'
            ).format(
                k, num_levels,
                mean_res_by_level_metres[k - 1],
                detail_res_by_level_metres[k - 1]
            ))

        coeff_tensor_by_level[k - 1] = tensorflow.concat([
            inverse_dwt_object.call(coeff_tensor_by_level[k]),
            coeff_tensor_by_level[k - 1][..., 1:]
        ], axis=-1)

        if verbose:
            print((
                'Zeroing out high-frequency coefficients at level '
                '{0:d} of {1:d} (resolutions = {2:.4f} and {3:.4f} m)...'
            ).format(
                k, num_levels,
                mean_res_by_level_metres[k - 1],
                detail_res_by_level_metres[k - 1]
            ))

        coeff_tensor_by_level[k - 1] = tensorflow.concat([
            coeff_tensor_by_level[k - 1][..., :1],
            tensorflow.zeros_like(coeff_tensor_by_level[k - 1][..., 1:])
        ], axis=-1)

        k -= 1

    return coeff_tensor_by_level
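
The level-selection logic in `_wavelet_filter_coefficients` can be checked by hand with plain NumPy. The sketch below repeats only the resolution bookkeeping (no WaveTF calls); the grid spacing, number of levels, and band limits are hypothetical values chosen for illustration, not settings from the paper.

```python
import numpy

# Hypothetical settings: 1-km grid, 16 x 16 domain (4 transform levels),
# and a band-pass filter that should keep 4-8 km.
grid_spacing_metres = 1000.
num_levels = 4
min_resolution_metres = 4000.
max_resolution_metres = 8000.

level_indices = numpy.arange(num_levels)

# Detail coefficients at level k describe features at spacing * 2^k;
# mean coefficients at level k describe features at spacing * 2^(k + 1).
detail_res_by_level_metres = grid_spacing_metres * (2 ** level_indices)
mean_res_by_level_metres = grid_spacing_metres * (2 ** (level_indices + 1))

# Same index arithmetic as in `_wavelet_filter_coefficients`.
max_index = numpy.searchsorted(
    a=mean_res_by_level_metres, v=max_resolution_metres, side='right'
)
min_index = -1 + numpy.searchsorted(
    a=detail_res_by_level_metres, v=min_resolution_metres, side='left'
)

print(detail_res_by_level_metres)
print(mean_res_by_level_metres)
print(max_index, min_index)
```

With these inputs, mean coefficients are zeroed from level index `max_index` upward (the 16-km mean), and detail coefficients are zeroed from level index `min_index` downward (the 1-km and 2-km details), so only the 4-km and 8-km details survive.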

# <font color='red'>Define loss functions (required)</font>

The next cell defines all loss functions.
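
As a point of reference for the neighbourhood-based scores, here is a minimal NumPy-only sketch of the fractions skill score for a single 2-D grid. It assumes no mask and zero padding at the edges, and the helper names `mean_filter_2d` and `fss_numpy` are hypothetical (they are not part of this notebook). It is meant for building intuition, not as a replacement for the Keras version.

```python
import numpy
from numpy.lib.stride_tricks import sliding_window_view


def mean_filter_2d(data_matrix, half_window_size_px):
    """Zero-padded moving average over a (2K + 1)-by-(2K + 1) window."""
    window_size_px = 2 * half_window_size_px + 1
    padded_matrix = numpy.pad(
        data_matrix, pad_width=half_window_size_px, mode='constant',
        constant_values=0.
    )
    # Shape is M x N x window x window; average over the last two axes.
    window_matrix = sliding_window_view(
        padded_matrix, (window_size_px, window_size_px)
    )
    return numpy.mean(window_matrix, axis=(-2, -1))


def fss_numpy(target_matrix, prediction_matrix, half_window_size_px):
    """FSS = 1 - MSE(smoothed fields) / reference MSE (no masking)."""
    smoothed_target_matrix = mean_filter_2d(
        target_matrix, half_window_size_px
    )
    smoothed_prediction_matrix = mean_filter_2d(
        prediction_matrix, half_window_size_px
    )
    actual_mse = numpy.mean(
        (smoothed_target_matrix - smoothed_prediction_matrix) ** 2
    )
    reference_mse = numpy.mean(
        smoothed_target_matrix ** 2 + smoothed_prediction_matrix ** 2
    )
    return 1. - actual_mse / reference_mse


# One observed event, and a forecast that is displaced by one grid cell.
target_matrix = numpy.zeros((9, 9))
target_matrix[4, 4] = 1.
shifted_prediction_matrix = numpy.zeros((9, 9))
shifted_prediction_matrix[4, 5] = 1.

fss_pixelwise = fss_numpy(target_matrix, shifted_prediction_matrix, 0)
fss_neighbourhood = fss_numpy(target_matrix, shifted_prediction_matrix, 1)
fss_perfect = fss_numpy(target_matrix, target_matrix, 1)

print(fss_pixelwise, fss_neighbourhood, fss_perfect)
```

The one-pixel displacement gets no credit with a 1-by-1 window (FSS = 0) but partial credit with a 3-by-3 window (FSS = 2/3), which is the behaviour that makes neighbourhood scores tolerant of small location errors.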

In [None]:
def _log2(input_tensor):
    """Computes logarithm in base 2.

    :param input_tensor: Keras tensor.
    :return: logarithm_tensor: Keras tensor with the same shape as
        `input_tensor`.
    """

    return K.log(K.maximum(input_tensor, 1e-6)) / K.log(2.)


def _create_mean_filter(half_num_rows, half_num_columns, num_channels):
    """Creates convolutional filter that computes mean.

    M = number of rows in filter
    N = number of columns in filter
    C = number of channels

    :param half_num_rows: Number of rows on either side of center.  This is
        (M - 1) / 2.
    :param half_num_columns: Number of columns on either side of center.  This
        is (N - 1) / 2.
    :param num_channels: Number of channels.
    :return: weight_matrix: M-by-N-by-C-by-C numpy array of filter weights.
    """

    num_rows = 2 * half_num_rows + 1
    num_columns = 2 * half_num_columns + 1
    weight = 1. / (num_rows * num_columns)

    return numpy.full(
        (num_rows, num_columns, num_channels, num_channels), weight,
        dtype=numpy.float32
    )


def _erode_mask(mask_matrix, half_window_size_px):
    """Erodes binary mask.

    :param mask_matrix: See doc for `pod`.
    :param half_window_size_px: Same.
    :return: eroded_mask_matrix: Eroded version of input.
    """

    window_size_px = 2 * half_window_size_px + 1
    structure_matrix = numpy.full(
        (window_size_px, window_size_px), 1, dtype=bool
    )

    eroded_mask_matrix = binary_erosion(
        mask_matrix.astype(int), structure=structure_matrix, iterations=1,
        border_value=1
    )

    return numpy.expand_dims(
        eroded_mask_matrix.astype(float), axis=(0, -1)
    )


def _apply_max_filter(input_tensor, half_window_size_px):
    """Applies maximum-filter to tensor.

    :param input_tensor: Keras tensor.
    :param half_window_size_px: Number of pixels in half of filter window (on
        either side of center).  If this argument is K, the window size will be
        (1 + 2 * K) by (1 + 2 * K).
    :return: output_tensor: Filtered version of `input_tensor`.
    """

    window_size_px = 2 * half_window_size_px + 1

    return K.pool2d(
        x=input_tensor, pool_mode='max',
        pool_size=(window_size_px, window_size_px), strides=(1, 1),
        padding='same', data_format='channels_last'
    )


def cross_entropy(mask_matrix, function_name=None):
    """Cross-entropy.

    M = number of rows in grid
    N = number of columns in grid

    :param mask_matrix: M-by-N numpy array of Boolean flags.  Grid cells marked
        "False" are masked out and not used to compute the loss.
    :param function_name: Function name (string).
    :return: loss: Loss function (defined below).
    """

    mask_matrix_4d = copy.deepcopy(mask_matrix)
    mask_matrix_4d = numpy.expand_dims(
        mask_matrix_4d.astype(float), axis=(0, -1)
    )

    def loss(target_tensor, prediction_tensor):
        """Computes loss (cross-entropy).

        :param target_tensor: Tensor of target (actual) values.
        :param prediction_tensor: Tensor of predicted values.
        :return: loss: Cross-entropy (scalar).
        """

        filtered_target_tensor = target_tensor * mask_matrix_4d
        filtered_prediction_tensor = prediction_tensor * mask_matrix_4d

        xentropy_tensor = (
            filtered_target_tensor * _log2(filtered_prediction_tensor) +
            (1. - filtered_target_tensor) *
            _log2(1. - filtered_prediction_tensor)
        )

        return -K.mean(xentropy_tensor)

    if function_name is not None:
        loss.__name__ = function_name

    return loss


def fractions_skill_score(
        half_window_size_px, use_as_loss_function, mask_matrix,
        function_name=None, test_mode=False):
    """Fractions skill score (FSS).

    M = number of rows in grid
    N = number of columns in grid

    :param half_window_size_px: Number of pixels (grid cells) in half of
        smoothing window (on either side of center).  If this argument is K, the
        window size will be (1 + 2 * K) by (1 + 2 * K).
    :param use_as_loss_function: Boolean flag.  FSS is positively oriented
        (higher is better), but if using it as loss function, we want it to be
        negatively oriented.  Thus, if `use_as_loss_function == True`, will
        return 1 - FSS.  If `use_as_loss_function == False`, will return just
        FSS.
    :param mask_matrix: M-by-N numpy array of Boolean flags.  Grid cells marked
        "False" are masked out and not used to compute the loss.
    :param function_name: Function name (string).
    :param test_mode: Leave this alone.
    :return: loss: Loss function (defined below).
    """

    # TODO(thunderhoser): Allow multiple channels.

    weight_matrix = _create_mean_filter(
        half_num_rows=half_window_size_px,
        half_num_columns=half_window_size_px, num_channels=1
    )

    if test_mode:
        eroded_mask_matrix = copy.deepcopy(mask_matrix)
    else:
        eroded_mask_matrix = _erode_mask(
            mask_matrix=copy.deepcopy(mask_matrix),
            half_window_size_px=half_window_size_px
        )

    # eroded_mask_tensor = K.variable(eroded_mask_matrix)

    def loss(target_tensor, prediction_tensor):
        """Computes loss (fractions skill score).

        :param target_tensor: Tensor of target (actual) values.
        :param prediction_tensor: Tensor of predicted values.
        :return: loss: FSS, or 1 - FSS if `use_as_loss_function == True`.
        """

        smoothed_target_tensor = K.conv2d(
            x=target_tensor, kernel=weight_matrix,
            padding='same', strides=(1, 1), data_format='channels_last'
        )

        smoothed_prediction_tensor = K.conv2d(
            x=prediction_tensor, kernel=weight_matrix,
            padding='same', strides=(1, 1), data_format='channels_last'
        )

        smoothed_target_tensor = smoothed_target_tensor * eroded_mask_matrix
        smoothed_prediction_tensor = (
            smoothed_prediction_tensor * eroded_mask_matrix
        )

        actual_mse = K.mean(
            (smoothed_target_tensor - smoothed_prediction_tensor) ** 2
        )
        reference_mse = K.mean(
            smoothed_target_tensor ** 2 + smoothed_prediction_tensor ** 2
        )

        if use_as_loss_function:
            return actual_mse / reference_mse

        return 1. - actual_mse / reference_mse

    if function_name is not None:
        loss.__name__ = function_name

    return loss


def heidke_score(mask_matrix, use_as_loss_function, function_name=None):
    """Creates function to compute Heidke score.

    :param mask_matrix: See doc for `fractions_skill_score`.
    :param use_as_loss_function: Same.
    :param function_name: Same.
    :return: heidke_function: Function (defined below).
    """

    mask_matrix = numpy.expand_dims(mask_matrix, axis=0).astype(float)

    def heidke_function(target_tensor, prediction_tensor):
        """Computes Heidke score.

        :param target_tensor: Tensor of target (actual) values.
        :param prediction_tensor: Tensor of predicted values.
        :return: heidke_value: Heidke score (scalar).
        """

        num_true_positives = K.sum(
            mask_matrix * target_tensor * prediction_tensor
        )
        num_false_positives = K.sum(
            mask_matrix * (1 - target_tensor) * prediction_tensor
        )
        num_false_negatives = K.sum(
            mask_matrix * target_tensor * (1 - prediction_tensor)
        )
        num_true_negatives = K.sum(
            mask_matrix * (1 - target_tensor) * (1 - prediction_tensor)
        )

        random_num_correct = (
            (num_true_positives + num_false_positives) *
            (num_true_positives + num_false_negatives) +
            (num_false_negatives + num_true_negatives) *
            (num_false_positives + num_true_negatives)
        )
        num_examples = (
            num_true_positives + num_false_positives +
            num_false_negatives + num_true_negatives
        )
        random_num_correct = random_num_correct / num_examples

        numerator = num_true_positives + num_true_negatives - random_num_correct
        denominator = num_examples - random_num_correct + K.epsilon()
        heidke_value = numerator / denominator

        if use_as_loss_function:
            return 1. - heidke_value

        return heidke_value

    if function_name is not None:
        heidke_function.__name__ = function_name

    return heidke_function


def peirce_score(mask_matrix, use_as_loss_function, function_name=None):
    """Creates function to compute Peirce score.

    :param mask_matrix: See doc for `fractions_skill_score`.
    :param use_as_loss_function: Same.
    :param function_name: Same.
    :return: peirce_function: Function (defined below).
    """

    mask_matrix = numpy.expand_dims(mask_matrix, axis=0).astype(float)

    def peirce_function(target_tensor, prediction_tensor):
        """Computes Peirce score.

        :param target_tensor: Tensor of target (actual) values.
        :param prediction_tensor: Tensor of predicted values.
        :return: peirce_value: Peirce score (scalar).
        """

        num_true_positives = K.sum(
            mask_matrix * target_tensor * prediction_tensor
        )
        num_false_positives = K.sum(
            mask_matrix * (1 - target_tensor) * prediction_tensor
        )
        num_false_negatives = K.sum(
            mask_matrix * target_tensor * (1 - prediction_tensor)
        )
        num_true_negatives = K.sum(
            mask_matrix * (1 - target_tensor) * (1 - prediction_tensor)
        )

        pod_value = (
            num_true_positives /
            (num_true_positives + num_false_negatives + K.epsilon())
        )
        pofd_value = (
            num_false_positives /
            (num_false_positives + num_true_negatives + K.epsilon())
        )
        peirce_value = pod_value - pofd_value

        if use_as_loss_function:
            return 1. - peirce_value

        return peirce_value

    if function_name is not None:
        peirce_function.__name__ = function_name

    return peirce_function


def gerrity_score(mask_matrix, use_as_loss_function, function_name=None):
    """Creates function to compute Gerrity score.

    :param mask_matrix: See doc for `fractions_skill_score`.
    :param use_as_loss_function: Same.
    :param function_name: Same.
    :return: gerrity_function: Function (defined below).
    """

    mask_matrix = numpy.expand_dims(mask_matrix, axis=0).astype(float)

    def gerrity_function(target_tensor, prediction_tensor):
        """Computes Gerrity score.

        :param target_tensor: Tensor of target (actual) values.
        :param prediction_tensor: Tensor of predicted values.
        :return: gerrity_value: Gerrity score (scalar).
        """

        num_true_positives = K.sum(
            mask_matrix * target_tensor * prediction_tensor
        )
        num_false_positives = K.sum(
            mask_matrix * (1 - target_tensor) * prediction_tensor
        )
        num_false_negatives = K.sum(
            mask_matrix * target_tensor * (1 - prediction_tensor)
        )
        num_true_negatives = K.sum(
            mask_matrix * (1 - target_tensor) * (1 - prediction_tensor)
        )

        event_ratio = (
            (num_false_positives + num_true_negatives) /
            (num_true_positives + num_false_negatives + K.epsilon())
        )
        num_examples = (
            num_true_positives + num_false_positives +
            num_false_negatives + num_true_negatives
        )

        numerator = (
            num_true_positives * event_ratio
            + num_true_negatives * (1. / event_ratio)
            - num_false_positives - num_false_negatives
        )

        gerrity_value = numerator / num_examples

        if use_as_loss_function:
            return 1. - gerrity_value

        return gerrity_value

    if function_name is not None:
        gerrity_function.__name__ = function_name

    return gerrity_function


def pod(half_window_size_px, mask_matrix, function_name=None, test_mode=False):
    """Creates function to compute probability of detection.

    M = number of rows in grid
    N = number of columns in grid

    :param half_window_size_px: See doc for `_apply_max_filter`.
    :param mask_matrix: M-by-N numpy array of Boolean flags.  Grid cells marked
        "False" are masked out and not used to compute the loss.
    :param function_name: Function name (string).
    :param test_mode: Leave this alone.
    :return: pod_function: Function (defined below).
    """

    eroded_mask_matrix = _erode_mask(
        mask_matrix=copy.deepcopy(mask_matrix),
        half_window_size_px=half_window_size_px
    )
    # eroded_mask_tensor = K.variable(eroded_mask_matrix)

    def pod_function(target_tensor, prediction_tensor):
        """Computes probability of detection.

        :param target_tensor: Tensor of target (actual) values.
        :param prediction_tensor: Tensor of predicted values.
        :return: pod: Probability of detection.
        """

        filtered_prediction_tensor = _apply_max_filter(
            input_tensor=prediction_tensor,
            half_window_size_px=half_window_size_px
        )

        masked_prediction_tensor = (
            eroded_mask_matrix * filtered_prediction_tensor
        )
        masked_target_tensor = eroded_mask_matrix * target_tensor

        num_actual_oriented_true_positives = K.sum(
            masked_target_tensor * masked_prediction_tensor
        )
        num_false_negatives = K.sum(
            masked_target_tensor * (1 - masked_prediction_tensor)
        )

        denominator = (
            num_actual_oriented_true_positives + num_false_negatives +
            K.epsilon()
        )
        return num_actual_oriented_true_positives / denominator

    if function_name is not None:
        pod_function.__name__ = function_name

    return pod_function


def success_ratio(half_window_size_px, mask_matrix, function_name=None,
                  test_mode=False):
    """Creates function to compute success ratio.

    :param half_window_size_px: See doc for `_apply_max_filter`.
    :param mask_matrix: See doc for `pod`.
    :param function_name: Function name (string).
    :param test_mode: Leave this alone.
    :return: success_ratio_function: Function (defined below).
    """

    eroded_mask_matrix = _erode_mask(
        mask_matrix=copy.deepcopy(mask_matrix),
        half_window_size_px=half_window_size_px
    )
    # eroded_mask_tensor = K.variable(eroded_mask_matrix)

    def success_ratio_function(target_tensor, prediction_tensor):
        """Computes success ratio.

        :param target_tensor: Tensor of target (actual) values.
        :param prediction_tensor: Tensor of predicted values.
        :return: success_ratio: Success ratio.
        """

        filtered_target_tensor = _apply_max_filter(
            input_tensor=target_tensor, half_window_size_px=half_window_size_px
        )

        masked_target_tensor = eroded_mask_matrix * filtered_target_tensor
        masked_prediction_tensor = eroded_mask_matrix * prediction_tensor

        num_prediction_oriented_true_positives = K.sum(
            masked_target_tensor * masked_prediction_tensor
        )
        num_false_positives = K.sum(
            (1 - masked_target_tensor) * masked_prediction_tensor
        )

        denominator = (
            num_prediction_oriented_true_positives + num_false_positives +
            K.epsilon()
        )
        return num_prediction_oriented_true_positives / denominator

    if function_name is not None:
        success_ratio_function.__name__ = function_name

    return success_ratio_function


def csi(half_window_size_px, mask_matrix, use_as_loss_function=False,
        function_name=None, test_mode=False):
    """Creates function to compute critical success index.

    :param half_window_size_px: See doc for `_apply_max_filter`.
    :param mask_matrix: See doc for `pod`.
    :param use_as_loss_function: Boolean flag.  If True (False), will use CSI as
        loss function (metric).
    :param function_name: Function name (string).
    :param test_mode: Leave this alone.
    :return: csi_function: Function (defined below).
    """

    def csi_function(target_tensor, prediction_tensor):
        """Computes critical success index.

        :param target_tensor: Tensor of target (actual) values.
        :param prediction_tensor: Tensor of predicted values.
        :return: csi: Critical success index.
        """

        pod_function = pod(
            half_window_size_px=half_window_size_px, mask_matrix=mask_matrix,
            test_mode=test_mode
        )
        pod_value = K.epsilon() + pod_function(
            target_tensor=target_tensor, prediction_tensor=prediction_tensor
        )

        success_ratio_function = success_ratio(
            half_window_size_px=half_window_size_px, mask_matrix=mask_matrix,
            test_mode=test_mode
        )
        success_ratio_value = K.epsilon() + success_ratio_function(
            target_tensor=target_tensor, prediction_tensor=prediction_tensor
        )

        csi_value = (pod_value ** -1 + success_ratio_value ** -1 - 1) ** -1

        if use_as_loss_function:
            return 1. - csi_value

        return csi_value

    if function_name is not None:
        csi_function.__name__ = function_name

    return csi_function


def iou(half_window_size_px, mask_matrix, use_as_loss_function=False,
        function_name=None, test_mode=False):
    """Creates function to compute intersection over union.

    :param half_window_size_px: See doc for `_apply_max_filter`.
    :param mask_matrix: See doc for `pod`.
    :param use_as_loss_function: Boolean flag.  If True (False), will use IOU as
        loss function (metric).
    :param function_name: Function name (string).
    :param test_mode: Leave this alone.
    :return: iou_function: Function (defined below).
    """

    eroded_mask_matrix = _erode_mask(
        mask_matrix=copy.deepcopy(mask_matrix),
        half_window_size_px=half_window_size_px
    )
    # eroded_mask_tensor = K.variable(eroded_mask_matrix)

    def iou_function(target_tensor, prediction_tensor):
        """Computes intersection over union.

        :param target_tensor: Tensor of target (actual) values.
        :param prediction_tensor: Tensor of predicted values.
        :return: iou: Intersection over union.
        """

        filtered_target_tensor = _apply_max_filter(
            input_tensor=target_tensor,
            half_window_size_px=half_window_size_px
        )

        masked_target_tensor = eroded_mask_matrix * filtered_target_tensor
        masked_prediction_tensor = eroded_mask_matrix * prediction_tensor

        masked_target_tensor = masked_target_tensor[..., 0]
        masked_prediction_tensor = masked_prediction_tensor[..., 0]

        intersection_tensor = K.sum(
            masked_target_tensor * masked_prediction_tensor, axis=(1, 2)
        )
        union_tensor = (
            K.sum(masked_target_tensor, axis=(1, 2)) +
            K.sum(masked_prediction_tensor, axis=(1, 2)) -
            intersection_tensor
        )

        iou_value = K.mean(
            intersection_tensor / (union_tensor + K.epsilon())
        )

        if use_as_loss_function:
            return 1. - iou_value

        return iou_value

    if function_name is not None:
        iou_function.__name__ = function_name

    return iou_function


def dice_coeff(half_window_size_px, mask_matrix, use_as_loss_function=False,
               function_name=None, test_mode=False):
    """Creates function to compute Dice coefficient.

    :param half_window_size_px: See doc for `_apply_max_filter`.
    :param mask_matrix: See doc for `pod`.
    :param use_as_loss_function: Boolean flag.  If True (False), will use Dice
        coefficient as loss function (metric).
    :param function_name: Function name (string).
    :param test_mode: Leave this alone.
    :return: dice_function: Function (defined below).
    """

    eroded_mask_matrix = _erode_mask(
        mask_matrix=copy.deepcopy(mask_matrix),
        half_window_size_px=half_window_size_px
    )
    # eroded_mask_tensor = K.variable(eroded_mask_matrix)

    def dice_function(target_tensor, prediction_tensor):
        """Computes Dice coefficient.

        :param target_tensor: Tensor of target (actual) values.
        :param prediction_tensor: Tensor of predicted values.
        :return: dice_coeff: Dice coefficient.
        """

        filtered_target_tensor = _apply_max_filter(
            input_tensor=target_tensor,
            half_window_size_px=half_window_size_px
        )

        masked_target_tensor = eroded_mask_matrix * filtered_target_tensor
        masked_prediction_tensor = eroded_mask_matrix * prediction_tensor
        positive_intersection_tensor = K.sum(
            masked_target_tensor[..., 0] * masked_prediction_tensor[..., 0],
            axis=(1, 2)
        )

        masked_target_tensor = (
            eroded_mask_matrix * (1. - filtered_target_tensor)
        )
        masked_prediction_tensor = eroded_mask_matrix * (1. - prediction_tensor)
        negative_intersection_tensor = K.sum(
            masked_target_tensor[..., 0] * masked_prediction_tensor[..., 0],
            axis=(1, 2)
        )

        eroded_mask_tensor = eroded_mask_matrix * K.ones_like(prediction_tensor)
        num_pixels_tensor = K.sum(eroded_mask_tensor, axis=(1, 2, 3))

        dice_value = K.mean(
            (positive_intersection_tensor + negative_intersection_tensor) /
            num_pixels_tensor
        )

        if use_as_loss_function:
            return 1. - dice_value

        return dice_value

    if function_name is not None:
        dice_function.__name__ = function_name

    return dice_function


def brier_score(half_window_size_px, mask_matrix, function_name=None,
                test_mode=False):
    """Creates function to compute Brier score.

    :param half_window_size_px: See doc for `_apply_max_filter`.
    :param mask_matrix: See doc for `pod`.
    :param function_name: Function name (string).
    :param test_mode: Leave this alone.
    :return: brier_function: Function (defined below).
    """

    eroded_mask_matrix = _erode_mask(
        mask_matrix=copy.deepcopy(mask_matrix),
        half_window_size_px=half_window_size_px
    )
    # eroded_mask_tensor = K.variable(eroded_mask_matrix)

    def brier_function(target_tensor, prediction_tensor):
        """Computes Brier score.

        :param target_tensor: Tensor of target (actual) values.
        :param prediction_tensor: Tensor of predicted values.
        :return: brier_score: Brier score.
        """

        filtered_target_tensor = _apply_max_filter(
            input_tensor=target_tensor,
            half_window_size_px=half_window_size_px
        )

        masked_target_tensor = eroded_mask_matrix * filtered_target_tensor
        masked_prediction_tensor = eroded_mask_matrix * prediction_tensor

        squared_error_tensor = K.sum(
            (masked_target_tensor - masked_prediction_tensor) ** 2,
            axis=(1, 2, 3)
        )

        eroded_mask_tensor = eroded_mask_matrix * K.ones_like(prediction_tensor)
        num_pixels_tensor = K.sum(eroded_mask_tensor, axis=(1, 2, 3))

        return K.mean(squared_error_tensor / num_pixels_tensor)

    if function_name is not None:
        brier_function.__name__ = function_name

    return brier_function


def frequency_bias(half_window_size_px, mask_matrix, function_name=None,
                   test_mode=False):
    """Creates function to compute frequency bias.

    :param half_window_size_px: See doc for `_apply_max_filter`.
    :param mask_matrix: See doc for `pod`.
    :param function_name: Function name (string).
    :param test_mode: Leave this alone.
    :return: frequency_bias_function: Function (defined below).
    """

    def frequency_bias_function(target_tensor, prediction_tensor):
        """Computes frequency bias.

        :param target_tensor: Tensor of target (actual) values.
        :param prediction_tensor: Tensor of predicted values.
        :return: frequency_bias: Frequency bias.
        """

        pod_function = pod(
            half_window_size_px=half_window_size_px, mask_matrix=mask_matrix,
            test_mode=test_mode
        )
        pod_value = pod_function(
            target_tensor=target_tensor, prediction_tensor=prediction_tensor
        )

        success_ratio_function = success_ratio(
            half_window_size_px=half_window_size_px, mask_matrix=mask_matrix,
            test_mode=test_mode
        )
        success_ratio_value = success_ratio_function(
            target_tensor=target_tensor, prediction_tensor=prediction_tensor
        )

        return pod_value / (success_ratio_value + K.epsilon())

    if function_name is not None:
        frequency_bias_function.__name__ = function_name

    return frequency_bias_function


def all_class_iou(half_window_size_px, mask_matrix, use_as_loss_function=False,
                  function_name=None, test_mode=False):
    """Creates function to compute all-class intersection over union (IOU).

    :param half_window_size_px: See doc for `_apply_max_filter`.
    :param mask_matrix: See doc for `pod`.
    :param use_as_loss_function: Boolean flag.  If True (False), will use
        all-class IOU as loss function (metric).
    :param function_name: Function name (string).
    :param test_mode: Leave this alone.
    :return: all_class_iou_function: Function (defined below).
    """

    eroded_mask_matrix = _erode_mask(
        mask_matrix=copy.deepcopy(mask_matrix),
        half_window_size_px=half_window_size_px
    )
    # eroded_mask_tensor = K.variable(eroded_mask_matrix)

    def all_class_iou_function(target_tensor, prediction_tensor):
        """Computes all-class IOU.

        :param target_tensor: Tensor of target (actual) values.
        :param prediction_tensor: Tensor of predicted values.
        :return: iou_value: All-class IOU.
        """

        filtered_target_tensor = _apply_max_filter(
            input_tensor=target_tensor,
            half_window_size_px=half_window_size_px
        )

        masked_target_tensor = eroded_mask_matrix * filtered_target_tensor
        masked_prediction_tensor = eroded_mask_matrix * prediction_tensor
        positive_intersection_tensor = K.sum(
            masked_target_tensor[..., 0] * masked_prediction_tensor[..., 0],
            axis=(1, 2)
        )
        positive_union_tensor = (
            K.sum(masked_target_tensor[..., 0], axis=(1, 2)) +
            K.sum(masked_prediction_tensor[..., 0], axis=(1, 2)) -
            positive_intersection_tensor
        )

        masked_target_tensor = (
            eroded_mask_matrix * (1. - filtered_target_tensor)
        )
        masked_prediction_tensor = eroded_mask_matrix * (1. - prediction_tensor)
        negative_intersection_tensor = K.sum(
            masked_target_tensor[..., 0] * masked_prediction_tensor[..., 0],
            axis=(1, 2)
        )
        negative_union_tensor = (
            K.sum(masked_target_tensor[..., 0], axis=(1, 2)) +
            K.sum(masked_prediction_tensor[..., 0], axis=(1, 2)) -
            negative_intersection_tensor
        )

        positive_iou = K.mean(
            positive_intersection_tensor / (positive_union_tensor + K.epsilon())
        )
        negative_iou = K.mean(
            negative_intersection_tensor / (negative_union_tensor + K.epsilon())
        )
        iou_value = (positive_iou + negative_iou) / 2

        if use_as_loss_function:
            return 1. - iou_value

        return iou_value

    if function_name is not None:
        all_class_iou_function.__name__ = function_name

    return all_class_iou_function
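
The overlap-based scores defined above (IOU, all-class IOU, Dice coefficient, and Brier score) can be checked by hand on a tiny example.  The following is a minimal NumPy sketch of the same arithmetic, not the notebook's implementation: it skips the max-filter and the mask erosion (so it corresponds to a half-window size of 0), and the names `soft_overlap_scores` and `EPSILON` are made up for illustration (`EPSILON` stands in for `K.epsilon()` and its value is an assumption).

```python
import numpy

EPSILON = 1e-7  # Assumed stand-in for K.epsilon().


def soft_overlap_scores(target_matrix, prediction_matrix, mask_matrix):
    """Computes overlap-based scores without neighbourhood filtering.

    E = number of examples
    M = number of rows in grid
    N = number of columns in grid

    :param target_matrix: E-by-M-by-N numpy array of labels (0 or 1).
    :param prediction_matrix: E-by-M-by-N numpy array of probabilities.
    :param mask_matrix: M-by-N numpy array of flags (0 or 1).
    :return: score_dict: Dictionary with one value per score.
    """

    # Positive class: overlap between targets and predicted probabilities.
    pos_target_matrix = mask_matrix * target_matrix
    pos_prediction_matrix = mask_matrix * prediction_matrix
    pos_intersections = numpy.sum(
        pos_target_matrix * pos_prediction_matrix, axis=(1, 2)
    )
    pos_unions = (
        numpy.sum(pos_target_matrix, axis=(1, 2)) +
        numpy.sum(pos_prediction_matrix, axis=(1, 2)) -
        pos_intersections
    )

    # Negative class: same thing with labels and probabilities flipped.
    neg_target_matrix = mask_matrix * (1. - target_matrix)
    neg_prediction_matrix = mask_matrix * (1. - prediction_matrix)
    neg_intersections = numpy.sum(
        neg_target_matrix * neg_prediction_matrix, axis=(1, 2)
    )
    neg_unions = (
        numpy.sum(neg_target_matrix, axis=(1, 2)) +
        numpy.sum(neg_prediction_matrix, axis=(1, 2)) -
        neg_intersections
    )

    # Number of unmasked pixels per example.
    num_pixels_by_example = numpy.sum(
        mask_matrix * numpy.ones_like(prediction_matrix), axis=(1, 2)
    )
    squared_errors = numpy.sum(
        (pos_target_matrix - pos_prediction_matrix) ** 2, axis=(1, 2)
    )

    positive_iou = numpy.mean(pos_intersections / (pos_unions + EPSILON))
    negative_iou = numpy.mean(neg_intersections / (neg_unions + EPSILON))

    return {
        'iou': positive_iou,
        'all_class_iou': (positive_iou + negative_iou) / 2,
        'dice': numpy.mean(
            (pos_intersections + neg_intersections) / num_pixels_by_example
        ),
        'brier': numpy.mean(squared_errors / num_pixels_by_example)
    }


# One example on a 2-by-2 grid, with every pixel unmasked.
target_matrix = numpy.array([[[1., 0.], [1., 0.]]])
prediction_matrix = numpy.array([[[0.5, 0.5], [1., 0.]]])
mask_matrix = numpy.ones((2, 2))

score_dict = soft_overlap_scores(target_matrix, prediction_matrix, mask_matrix)
print(score_dict)
```

For this example the positive intersection is 1.5 and the positive union is 2.5, so the IOU is about 0.6; the Dice coefficient (as defined above) is 0.75 and the Brier score is 0.125.  When a positively oriented score is used as a loss function, the value returned is one minus the score.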

# <font color='red'>Define public methods (required)</font>

The next cell defines public methods used in this notebook.

In [None]:
def mkdir_recursive_if_necessary(directory_name=None, file_name=None):
    """Creates directory if necessary (i.e., doesn't already exist).

    This method checks for the argument `directory_name` first.  If
    `directory_name` is None, this method checks for `file_name` and extracts
    the directory.

    :param directory_name: Path to local directory.
    :param file_name: Path to local file.
    """

    if directory_name is None:
        assert isinstance(file_name, str)
        directory_name = os.path.dirname(file_name)
    else:
        assert isinstance(directory_name, str)

    if directory_name == '':
        return

    try:
        os.makedirs(directory_name)
    except OSError as this_error:
        if this_error.errno == errno.EEXIST and os.path.isdir(directory_name):
            pass
        else:
            raise


def file_name_to_date(predictor_file_name):
    """Parses date from name of predictor or target file.

    :param predictor_file_name: Path to predictor or target file (see
        `_find_predictor_file` or `_find_target_file` for naming convention).
    :return: valid_date_string: Valid date (format "yyyymmdd").
    """

    pathless_file_name = os.path.split(predictor_file_name)[-1]

    valid_date_string = pathless_file_name.split('.')[0].split('_')[1]

    # Called only as a format check on the date string; the return value is
    # not needed.
    _time_string_to_unix_sec(valid_date_string, DATE_STRING_FORMAT)

    return valid_date_string


def decompress_gzip_file(gzip_file_name):
    """Decompresses gzip file.

    :param gzip_file_name: Path to gzip file.
    """

    assert gzip_file_name.endswith(GZIP_FILE_EXTENSION)
    new_file_name = gzip_file_name[:-len(GZIP_FILE_EXTENSION)]

    with gzip.open(gzip_file_name, 'rb') as gzip_handle:
        with open(new_file_name, 'wb') as new_handle:
            shutil.copyfileobj(gzip_handle, new_handle)


def read_model(hdf5_file_name):
    """Reads neural-net model from HDF5 file.

    :param hdf5_file_name: Path to input file.
    :return: model_object: Instance of `keras.models.Model`.
    """

    metafile_name = _find_nn_metafile(
        model_file_name=hdf5_file_name, raise_error_if_missing=True
    )

    metadata_dict = _read_nn_metafile(metafile_name)
    mask_matrix = metadata_dict[MASK_MATRIX_KEY]
    loss_function_name = metadata_dict[LOSS_FUNCTION_KEY]
    metric_names = metadata_dict[METRIC_NAMES_KEY]

    metric_list, custom_object_dict = get_metrics(
        metric_names=metric_names, mask_matrix=mask_matrix,
        use_as_loss_function=False
    )

    loss_function = None

    if loss_function_name is not None:
        loss_function = get_metrics(
            metric_names=[loss_function_name], mask_matrix=mask_matrix,
            use_as_loss_function=True
        )[0][0]

        custom_object_dict['loss'] = loss_function

    model_object = tf_keras.models.load_model(
        hdf5_file_name, custom_objects=custom_object_dict, compile=False
    )

    # If the metafile names no loss function, `loss_function` stays None
    # (previously this case raised a KeyError on `custom_object_dict['loss']`).
    model_object.compile(
        loss=loss_function, optimizer=tf_keras.optimizers.Adam(),
        metrics=metric_list
    )

    return model_object


def train_model(
        model_object, output_dir_name, num_epochs,
        num_training_batches_per_epoch, training_option_dict,
        num_validation_batches_per_epoch, validation_option_dict,
        mask_matrix, loss_function_name, metric_names,
        do_early_stopping=True, plateau_lr_multiplier=0.6,
        save_every_epoch=True):
    """Trains neural net on either full grid or partial grids.

    m = number of rows in prediction grid
    n = number of columns in prediction grid

    :param model_object: Untrained neural net (instance of `keras.models.Model`
        or `keras.models.Sequential`).
    :param output_dir_name: Path to output directory (model and training history
        will be saved here).
    :param num_epochs: Number of training epochs.
    :param num_training_batches_per_epoch: Number of training batches per epoch.
    :param training_option_dict: See doc for
        `_generator_partial_grids`.  This dictionary will be used to
        generate training data.
    :param num_validation_batches_per_epoch: Number of validation batches per
        epoch.
    :param validation_option_dict: See doc for
        `_generator_partial_grids`.  For validation only, the following
        values will replace corresponding values in `training_option_dict`:
    validation_option_dict['top_predictor_dir_name']
    validation_option_dict['top_target_dir_name']
    validation_option_dict['first_valid_date_string']
    validation_option_dict['last_valid_date_string']

    :param mask_matrix: m-by-n numpy array of Boolean flags.  Grid cells labeled
        True (False) are (not) used for model evaluation.
    :param loss_function_name: Name of loss function.  Must be accepted by
        `_metric_name_to_params`.
    :param metric_names: 1-D list of metric names.  Each name must be accepted
        by `_metric_name_to_params`.
    :param do_early_stopping: Boolean flag.  If True, will stop training early
        if validation loss has not improved over last several epochs (see
        constants at top of file for what exactly this means).
    :param plateau_lr_multiplier: Multiplier for learning rate.  Learning
        rate will be multiplied by this factor upon plateau in validation
        performance.
    :param save_every_epoch: Boolean flag.  If True, will save new model after
        every epoch.
    """

    mkdir_recursive_if_necessary(directory_name=output_dir_name)
    training_option_dict = _check_generator_args(training_option_dict)

    validation_keys_to_keep = [
        PREDICTOR_DIRECTORY_KEY, TARGET_DIRECTORY_KEY,
        FIRST_VALID_DATE_KEY, LAST_VALID_DATE_KEY
    ]

    for this_key in list(training_option_dict.keys()):
        if this_key in validation_keys_to_keep:
            continue

        validation_option_dict[this_key] = training_option_dict[this_key]

    validation_option_dict = _check_generator_args(validation_option_dict)

    if save_every_epoch:
        model_file_name = (
            output_dir_name +
            '/model_epoch={epoch:03d}_val-loss={val_loss:.6f}.h5'
        )
    else:
        model_file_name = '{0:s}/model.h5'.format(output_dir_name)

    history_object = keras.callbacks.CSVLogger(
        filename='{0:s}/history.csv'.format(output_dir_name),
        separator=',', append=False
    )
    checkpoint_object = keras.callbacks.ModelCheckpoint(
        filepath=model_file_name, monitor='val_loss', verbose=1,
        save_best_only=not save_every_epoch, save_weights_only=False,
        mode='min', period=1
    )
    list_of_callback_objects = [history_object, checkpoint_object]

    if do_early_stopping:
        early_stopping_object = keras.callbacks.EarlyStopping(
            monitor='val_loss', min_delta=LOSS_PATIENCE,
            patience=EARLY_STOPPING_PATIENCE_EPOCHS, verbose=1, mode='min'
        )
        list_of_callback_objects.append(early_stopping_object)

        plateau_object = keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss', factor=plateau_lr_multiplier,
            patience=PLATEAU_PATIENCE_EPOCHS, verbose=1, mode='min',
            min_delta=LOSS_PATIENCE, cooldown=PLATEAU_COOLDOWN_EPOCHS
        )
        list_of_callback_objects.append(plateau_object)

    # metafile_name = _find_nn_metafile(
    #     model_file_name=model_file_name, raise_error_if_missing=False
    # )
    # print('Writing metadata to: "{0:s}"...'.format(metafile_name))
    #
    # _write_metafile(
    #     dill_file_name=metafile_name, num_epochs=num_epochs,
    #     num_training_batches_per_epoch=num_training_batches_per_epoch,
    #     training_option_dict=training_option_dict,
    #     num_validation_batches_per_epoch=num_validation_batches_per_epoch,
    #     validation_option_dict=validation_option_dict,
    #     do_early_stopping=do_early_stopping,
    #     plateau_lr_multiplier=plateau_lr_multiplier,
    #     loss_function_name=loss_function_name, metric_names=metric_names,
    #     mask_matrix=mask_matrix
    # )

    training_generator = _generator_partial_grids(training_option_dict)
    validation_generator = _generator_partial_grids(validation_option_dict)

    model_object.fit_generator(
        generator=training_generator,
        steps_per_epoch=num_training_batches_per_epoch,
        epochs=num_epochs, verbose=1, callbacks=list_of_callback_objects,
        validation_data=validation_generator,
        validation_steps=num_validation_batches_per_epoch
    )


def get_metrics(metric_names, mask_matrix, use_as_loss_function):
    """Returns metrics used for on-the-fly monitoring.

    K = number of metrics
    M = number of rows in grid
    N = number of columns in grid

    :param metric_names: length-K list of metric names (each must be accepted by
        `_metric_name_to_params`).
    :param mask_matrix: M-by-N numpy array of Boolean flags.  Only pixels marked
        "True" are considered in each metric.
    :param use_as_loss_function: Boolean flag.  If True (False), will return
        each metric to be used as a loss function (metric).
    :return: metric_function_list: length-K list of functions.
    :return: metric_function_dict: Dictionary, where each key is a function name
        and each value is a function itself.
    """

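    # If more than one name is given, the functions are always returned as
    # metrics, regardless of the input flag.
    if len(metric_names) > 1:
        use_as_loss_function = False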

    metric_function_list = []
    metric_function_dict = dict()

    for this_metric_name in metric_names:
        this_param_dict = _metric_name_to_params(this_metric_name)

        if this_param_dict[SCORE_NAME_KEY] == FSS_NAME:
            this_function = fractions_skill_score(
                half_window_size_px=this_param_dict[HALF_WINDOW_SIZE_KEY],
                mask_matrix=mask_matrix,
                use_as_loss_function=use_as_loss_function,
                function_name=this_metric_name
            )
        elif this_param_dict[SCORE_NAME_KEY] == HEIDKE_SCORE_NAME:
            this_function = heidke_score(
                mask_matrix=mask_matrix,
                use_as_loss_function=use_as_loss_function,
                function_name=this_metric_name
            )
        elif this_param_dict[SCORE_NAME_KEY] == PEIRCE_SCORE_NAME:
            this_function = peirce_score(
                mask_matrix=mask_matrix,
                use_as_loss_function=use_as_loss_function,
                function_name=this_metric_name
            )
        elif this_param_dict[SCORE_NAME_KEY] == GERRITY_SCORE_NAME:
            this_function = gerrity_score(
                mask_matrix=mask_matrix,
                use_as_loss_function=use_as_loss_function,
                function_name=this_metric_name
            )
        elif this_param_dict[SCORE_NAME_KEY] == CROSS_ENTROPY_NAME:
            this_function = cross_entropy(
                mask_matrix=mask_matrix, function_name=this_metric_name
            )
        elif this_param_dict[SCORE_NAME_KEY] == BRIER_SCORE_NAME:
            this_function = brier_score(
                half_window_size_px=this_param_dict[HALF_WINDOW_SIZE_KEY],
                mask_matrix=mask_matrix, function_name=this_metric_name
            )
        elif this_param_dict[SCORE_NAME_KEY] == CSI_NAME:
            this_function = csi(
                half_window_size_px=this_param_dict[HALF_WINDOW_SIZE_KEY],
                mask_matrix=mask_matrix,
                use_as_loss_function=use_as_loss_function,
                function_name=this_metric_name
            )
        elif this_param_dict[SCORE_NAME_KEY] == FREQUENCY_BIAS_NAME:
            this_function = frequency_bias(
                half_window_size_px=this_param_dict[HALF_WINDOW_SIZE_KEY],
                mask_matrix=mask_matrix, function_name=this_metric_name
            )
        elif this_param_dict[SCORE_NAME_KEY] == IOU_NAME:
            this_function = iou(
                half_window_size_px=this_param_dict[HALF_WINDOW_SIZE_KEY],
                mask_matrix=mask_matrix,
                use_as_loss_function=use_as_loss_function,
                function_name=this_metric_name
            )
        elif this_param_dict[SCORE_NAME_KEY] == ALL_CLASS_IOU_NAME:
            this_function = all_class_iou(
                half_window_size_px=this_param_dict[HALF_WINDOW_SIZE_KEY],
                mask_matrix=mask_matrix,
                use_as_loss_function=use_as_loss_function,
                function_name=this_metric_name
            )
        else:
            this_function = dice_coeff(
                half_window_size_px=this_param_dict[HALF_WINDOW_SIZE_KEY],
                mask_matrix=mask_matrix,
                use_as_loss_function=use_as_loss_function,
                function_name=this_metric_name
            )

        metric_function_list.append(this_function)
        metric_function_dict[this_metric_name] = this_function

    return metric_function_list, metric_function_dict

# <font color='red'>Download input data (required)</font>

The next two cells download all input data for this notebook.

## <font color='red'>Download targets and model (required)</font>

The next cell downloads targets (actual convection masks) and the model.  Later in this notebook, we will change the loss function for said model, but the architecture will remain the same.

In [None]:
mkdir_recursive_if_necessary(file_name=TARGET_AND_MODEL_FILE_NAME_ZIPPED)
mkdir_recursive_if_necessary(directory_name=TARGET_AND_MODEL_DIR_NAME)

print('Downloading file: "{0:s}"...'.format(
    ONLINE_TARGET_AND_MODEL_FILE_NAME_ZIPPED
))
urlretrieve(
    ONLINE_TARGET_AND_MODEL_FILE_NAME_ZIPPED, TARGET_AND_MODEL_FILE_NAME_ZIPPED
)

print('Unzipping file: "{0:s}" to "{1:s}"...'.format(
    TARGET_AND_MODEL_FILE_NAME_ZIPPED, TARGET_AND_MODEL_DIR_NAME
))
tar_file_handle = tarfile.open(TARGET_AND_MODEL_FILE_NAME_ZIPPED)
tar_file_handle.extractall(TARGET_AND_MODEL_DIR_NAME)
tar_file_handle.close()

os.remove(TARGET_AND_MODEL_FILE_NAME_ZIPPED)

gzip_file_names = glob.glob('{0:s}/201*/*.nc.gz'.format(
    TARGET_AND_MODEL_DIR_NAME
))
gzip_file_names.sort()

for this_file_name in gzip_file_names:
    this_date_string = file_name_to_date(this_file_name)
    if this_date_string.startswith('2018'):
        continue

    print('Decompressing gzip file: "{0:s}"...'.format(this_file_name))
    decompress_gzip_file(this_file_name)
    os.remove(this_file_name)

Downloading file: "https://storage.googleapis.com/loss-functions-paper-2022-colab/targets_and_model.tar"...
Unzipping file: "/content/targets_and_model.tar" to "/content/targets_and_model"...
Decompressing gzip file: "/content/targets_and_model/2016/targets_20160101_radar0.nc.gz"...
Decompressing gzip file: "/content/targets_and_model/2016/targets_20160101_radar1.nc.gz"...
Decompressing gzip file: "/content/targets_and_model/2016/targets_20160101_radar2.nc.gz"...
Decompressing gzip file: "/content/targets_and_model/2016/targets_20160101_radar3.nc.gz"...
Decompressing gzip file: "/content/targets_and_model/2016/targets_20160202_radar0.nc.gz"...
Decompressing gzip file: "/content/targets_and_model/2016/targets_20160202_radar1.nc.gz"...
Decompressing gzip file: "/content/targets_and_model/2016/targets_20160202_radar2.nc.gz"...
Decompressing gzip file: "/content/targets_and_model/2016/targets_20160202_radar3.nc.gz"...
Decompressing gzip file: "/content/targets_and_model/2016/targets_201603
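
The download cell above does two things after fetching the archive: it extracts a tar file, then decompresses the gzip files inside.  If you want to try those two steps without downloading anything, here is a minimal offline sketch that uses only the standard library.  The function name `extract_and_decompress` and the fake file contents are hypothetical; the real files are NetCDF files and the real cell skips the 2018 files.

```python
import gzip
import os
import shutil
import tarfile
import tempfile


def extract_and_decompress(tar_file_name, output_dir_name):
    """Extracts tar archive, then decompresses every gzip file inside it.

    :param tar_file_name: Path to tar archive.
    :param output_dir_name: Path to output directory.
    :return: decompressed_file_names: 1-D list of paths to decompressed files.
    """

    os.makedirs(output_dir_name, exist_ok=True)

    with tarfile.open(tar_file_name) as tar_file_handle:
        tar_file_handle.extractall(output_dir_name)

    decompressed_file_names = []

    for this_dir_name, _, these_file_names in os.walk(output_dir_name):
        for this_pathless_name in sorted(these_file_names):
            if not this_pathless_name.endswith('.gz'):
                continue

            gzip_file_name = os.path.join(this_dir_name, this_pathless_name)
            new_file_name = gzip_file_name[:-len('.gz')]

            # Stream the decompressed bytes to the new file.
            with gzip.open(gzip_file_name, 'rb') as gzip_handle:
                with open(new_file_name, 'wb') as new_handle:
                    shutil.copyfileobj(gzip_handle, new_handle)

            os.remove(gzip_file_name)
            decompressed_file_names.append(new_file_name)

    return decompressed_file_names


# Build a tiny fake archive in a temporary directory, then unpack it.
with tempfile.TemporaryDirectory() as temp_dir_name:
    source_dir_name = os.path.join(temp_dir_name, 'source')
    os.makedirs(source_dir_name)

    fake_gzip_file_name = os.path.join(
        source_dir_name, 'example_20160101.nc.gz'
    )
    with gzip.open(fake_gzip_file_name, 'wb') as gzip_handle:
        gzip_handle.write(b'fake netCDF bytes')

    fake_tar_file_name = os.path.join(temp_dir_name, 'example.tar')
    with tarfile.open(fake_tar_file_name, 'w') as tar_file_handle:
        tar_file_handle.add(
            fake_gzip_file_name, arcname='2016/example_20160101.nc.gz'
        )

    new_file_names = extract_and_decompress(
        tar_file_name=fake_tar_file_name,
        output_dir_name=os.path.join(temp_dir_name, 'output')
    )

    with open(new_file_names[0], 'rb') as new_handle:
        demo_contents = new_handle.read()

print(demo_contents)
```

The `with` statements close the file handles even if an error occurs partway through, which is the main difference from calling `open` and `close` by hand.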

## <font color='red'>Download predictors (required)</font>

The next cell downloads predictors (gridded brightness-temperature maps from the Himawari-8 satellite).

**Note: this will take about 5 minutes.**  You might want to go grab a beer/coffee.

In [None]:
mkdir_recursive_if_necessary(directory_name=PREDICTOR_DIR_NAME)

for i in range(len(ONLINE_PREDICTOR_FILE_NAMES_ZIPPED)):
    print('\nDownloading file: "{0:s}"...'.format(
        ONLINE_PREDICTOR_FILE_NAMES_ZIPPED[i]
    ))
    urlretrieve(
        ONLINE_PREDICTOR_FILE_NAMES_ZIPPED[i], PREDICTOR_FILE_NAMES_ZIPPED[i]
    )

    this_pathless_file_name = os.path.split(PREDICTOR_FILE_NAMES_ZIPPED[i])[1]
    this_year = int(
        this_pathless_file_name.split('_')[-1].split('.')[0]
    )
    this_directory_name = '{0:s}/{1:04d}'.format(PREDICTOR_DIR_NAME, this_year)

    print('Unzipping file: "{0:s}" to "{1:s}"...'.format(
        PREDICTOR_FILE_NAMES_ZIPPED[i], this_directory_name
    ))
    tar_file_handle = tarfile.open(PREDICTOR_FILE_NAMES_ZIPPED[i])
    tar_file_handle.extractall(this_directory_name)
    tar_file_handle.close()

    os.remove(PREDICTOR_FILE_NAMES_ZIPPED[i])

    if this_year == 2018:
        continue
    
    gzip_file_names = glob.glob('{0:s}/*.nc.gz'.format(this_directory_name))
    gzip_file_names.sort()

    for this_file_name in gzip_file_names:
        print('Decompressing gzip file: "{0:s}"...'.format(this_file_name))
        decompress_gzip_file(this_file_name)
        os.remove(this_file_name)


Downloading file: "https://storage.googleapis.com/loss-functions-paper-2022-colab/predictors_2016_radars1and3.tar"...
Unzipping file: "/content/predictors/predictors_2016.tar" to "/content/predictors/2016"...
Decompressing gzip file: "/content/predictors/2016/predictors_20160101_radar1.nc.gz"...
Decompressing gzip file: "/content/predictors/2016/predictors_20160101_radar3.nc.gz"...
Decompressing gzip file: "/content/predictors/2016/predictors_20160202_radar1.nc.gz"...
Decompressing gzip file: "/content/predictors/2016/predictors_20160202_radar3.nc.gz"...
Decompressing gzip file: "/content/predictors/2016/predictors_20160303_radar1.nc.gz"...
Decompressing gzip file: "/content/predictors/2016/predictors_20160303_radar3.nc.gz"...
Decompressing gzip file: "/content/predictors/2016/predictors_20160404_radar1.nc.gz"...
Decompressing gzip file: "/content/predictors/2016/predictors_20160404_radar3.nc.gz"...
Decompressing gzip file: "/content/predictors/2016/predictors_20160505_radar1.nc.gz"..

# Train U-net with different loss functions

Finally, we have reached the fun part of the notebook!  The remaining code cells will teach you how to train a U-net with different loss functions.  Keep in mind the following:

 - U-net predictors are brightness-temperature maps at lag times of 0, 20, and 40 minutes (*i.e.*, 0, 20, and 40 minutes before the forecast-issue time).  You can change the lag times if you want, but that's not really the purpose of this notebook.  The purpose is to experiment with loss functions, not other hyperparameters.
 - The U-net target is the convection mask at a lead time of 60 minutes (*i.e.*, 60 minutes after the forecast-issue time).
 - The training period is Jan 1 - Dec 24 2016.
 - The validation period is Jan 1 - Dec 24 2017.
 - I have not included testing data in this notebook, because Colab virtual machines do not have enough disk space to store training, validation, and testing data.  As it is, I have included only a small subset (12 days per year and 2 of 3 radars) of the training and validation data.
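
To make the lag and lead times concrete, here is a minimal sketch that converts one forecast-issue time into the predictor times and the target time.  The issue time is made up, and the constant and function names are hypothetical; only the values (lag times of 0, 1200, and 2400 seconds and a lead time of 3600 seconds) match the training options used later in the notebook.

```python
from datetime import datetime, timedelta, timezone

LAG_TIMES_SEC = [0, 1200, 2400]  # 0, 20, and 40 minutes before issue time.
LEAD_TIME_SEC = 3600             # 60 minutes after issue time.


def predictor_and_target_times(issue_time):
    """Returns predictor times and target time for one forecast.

    :param issue_time: Forecast-issue time (datetime object).
    :return: predictor_times: List of predictor times (datetime objects).
    :return: target_time: Target time (datetime object).
    """

    predictor_times = [
        issue_time - timedelta(seconds=this_lag_sec)
        for this_lag_sec in LAG_TIMES_SEC
    ]
    target_time = issue_time + timedelta(seconds=LEAD_TIME_SEC)

    return predictor_times, target_time


# Hypothetical issue time: 1200 UTC 1 Jan 2016.
issue_time = datetime(2016, 1, 1, 12, 0, tzinfo=timezone.utc)
predictor_times, target_time = predictor_and_target_times(issue_time)

for this_time in predictor_times:
    print('Predictor time: {0:s}'.format(this_time.strftime('%Y-%m-%d-%H%M')))

print('Target time: {0:s}'.format(target_time.strftime('%Y-%m-%d-%H%M')))
```

For an issue time of 1200 UTC, the predictors come from 1200, 1140, and 1120 UTC, and the target is the convection mask at 1300 UTC.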

## Train with neighbourhood loss function

The next cell trains the U-net with the 9-by-9-pixel fractions skill score (FSS; half-width of 4 pixels) as the loss function.  Code lines that help determine the loss function are marked with "`# IMPORTANT`".

Recall that when a positively oriented evaluation score $s$ (where higher is better, such as the FSS) is used in the loss function, the loss function is actually $1 - s$, because loss functions are negatively oriented.
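
To see why the neighbourhood matters, here is a minimal NumPy sketch of the FSS on a single grid, using the standard definition (one minus the ratio of the actual to the reference mean squared error, both computed on window-averaged fields).  This is not the function used in the notebook, which also handles the mask and works on Keras tensors; the names `_window_mean` and `fss_sketch` are hypothetical, and zero-padding at the grid edges is my assumption.

```python
import numpy


def _window_mean(field_matrix, half_window_size_px):
    """Averages field over square window, with zero-padding at the edges.

    :param field_matrix: M-by-N numpy array.
    :param half_window_size_px: Half-width of window (pixels).
    :return: mean_matrix: M-by-N numpy array of window means.
    """

    window_size_px = 2 * half_window_size_px + 1
    num_rows, num_columns = field_matrix.shape

    padded_matrix = numpy.pad(
        field_matrix, half_window_size_px, mode='constant', constant_values=0.
    )
    sum_matrix = numpy.zeros((num_rows, num_columns))

    # Add up one shifted copy of the field per pixel in the window.
    for i in range(window_size_px):
        for j in range(window_size_px):
            sum_matrix += padded_matrix[i:(i + num_rows), j:(j + num_columns)]

    return sum_matrix / window_size_px ** 2


def fss_sketch(target_matrix, prediction_matrix, half_window_size_px):
    """Computes fractions skill score for one grid.

    Undefined (division by zero) if both fields are zero everywhere.

    :param target_matrix: M-by-N numpy array of labels (0 or 1).
    :param prediction_matrix: M-by-N numpy array of probabilities.
    :param half_window_size_px: Half-width of window (pixels).
    :return: fss_value: Fractions skill score.
    """

    mean_target_matrix = _window_mean(
        target_matrix.astype(float), half_window_size_px
    )
    mean_prediction_matrix = _window_mean(
        prediction_matrix.astype(float), half_window_size_px
    )

    actual_mse = numpy.mean(
        (mean_target_matrix - mean_prediction_matrix) ** 2
    )
    reference_mse = numpy.mean(
        mean_target_matrix ** 2 + mean_prediction_matrix ** 2
    )

    return 1. - actual_mse / reference_mse


# One convective pixel in the target; the prediction is displaced by 2 pixels.
target_matrix = numpy.zeros((21, 21))
target_matrix[10, 10] = 1.
prediction_matrix = numpy.zeros((21, 21))
prediction_matrix[10, 12] = 1.

pixelwise_fss = fss_sketch(target_matrix, prediction_matrix, 0)
neigh_fss = fss_sketch(target_matrix, prediction_matrix, 4)

print('Pixelwise FSS = {0:.4f} ... loss = {1:.4f}'.format(
    pixelwise_fss, 1. - pixelwise_fss
))
print('9-by-9 FSS = {0:.4f} ... loss = {1:.4f}'.format(
    neigh_fss, 1. - neigh_fss
))
```

With no neighbourhood, the displaced forecast gets an FSS of 0 (loss of 1), the same as a forecast that is wrong by any distance.  With the 9-by-9 window, the two smoothed fields overlap and the FSS rises to 7/9 (loss of 2/9), so the loss function rewards forecasts that are close in space.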

**Other options for neighbourhood loss functions are commented out.  Uncomment them to see what happens!!**

In [None]:
# Define training options.
training_option_dict = {
    PREDICTOR_DIRECTORY_KEY: PREDICTOR_DIR_NAME,
    TARGET_DIRECTORY_KEY: TARGET_AND_MODEL_DIR_NAME,
    BATCH_SIZE_KEY: 60,
    MAX_DAILY_EXAMPLES_KEY: 8,
    LEAD_TIME_KEY: 3600,
    LAG_TIMES_KEY: numpy.array([0, 1200, 2400], dtype=int),
    INCLUDE_TIME_DIM_KEY: False,
    FIRST_VALID_DATE_KEY: '20160101',
    LAST_VALID_DATE_KEY: '20161224',
    NORMALIZE_FLAG_KEY: True,
    UNIFORMIZE_FLAG_KEY: True,
    FOURIER_TRANSFORM_KEY: False,  # IMPORTANT: do not use Fourier decomp, because we have a neigh loss function, not a scale-separation loss function
    WAVELET_TRANSFORM_KEY: False   # IMPORTANT: do not use wavelet decomp, for the same reason
}

# Validation options are the same, except the time period.
validation_option_dict = {
    PREDICTOR_DIRECTORY_KEY: PREDICTOR_DIR_NAME,
    TARGET_DIRECTORY_KEY: TARGET_AND_MODEL_DIR_NAME,
    FIRST_VALID_DATE_KEY: '20170101',
    LAST_VALID_DATE_KEY: '20171224'
}

# Read U-net model and print architecture.
model_file_name = '{0:s}/model.h5'.format(TARGET_AND_MODEL_DIR_NAME)
model_object = read_model(model_file_name)
model_object.summary()

print('\n\n**************************************************\n\n')

# Read U-net metafile and create list of metrics to use on validation data.
# KEEP IN MIND: There is only one difference between a "loss function" and a
# "metric".  A loss function is what the optimizer minimizes during training,
# while a metric is only monitored (here, on validation data after every
# epoch) and does not affect the model weights.

metafile_name = _find_nn_metafile(
    model_file_name=model_file_name, raise_error_if_missing=True
)
metadata_dict = _read_nn_metafile(metafile_name)

mask_matrix = metadata_dict[MASK_MATRIX_KEY]
metric_names = metadata_dict[METRIC_NAMES_KEY]
metric_list, _ = get_metrics(
    metric_names=metric_names, mask_matrix=mask_matrix,
    use_as_loss_function=False
)

# Compile the model with the desired loss function.
loss_function_name = 'fss_neigh4'  # IMPORTANT.  Other options for neighbourhood loss functions are commented out.  Uncomment to see what happens!!
# loss_function_name = 'brier_neigh1'
# loss_function_name = 'csi_neigh2'
# loss_function_name = 'iou_neigh3'
# loss_function_name = 'dice_neigh12'

loss_function = get_metrics(
    metric_names=[loss_function_name], mask_matrix=mask_matrix,
    use_as_loss_function=True
)[0][0]

model_object.compile(
    loss=loss_function, optimizer=tf_keras.optimizers.Adam(),
    metrics=metric_list
)

# Train the model with the desired loss function.  NOTE: Even with a GPU,
# training is a bit slow.  If you want training to finish faster, decrease
# num_epochs, num_training_batches_per_epoch, or
# num_validation_batches_per_epoch.
new_model_dir_name = '{0:s}/{1:s}'.format(
    TARGET_AND_MODEL_DIR_NAME, loss_function_name
)
train_model(
    model_object=model_object, output_dir_name=new_model_dir_name,
    num_epochs=10, num_training_batches_per_epoch=32,
    training_option_dict=training_option_dict,
    num_validation_batches_per_epoch=16,
    validation_option_dict=validation_option_dict,
    mask_matrix=mask_matrix, loss_function_name=loss_function_name,
    metric_names=metric_names, do_early_stopping=True,
    plateau_lr_multiplier=0.6, save_every_epoch=False
)

Model: "model_6"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_7 (InputLayer)           [(None, 205, 205, 2  0           []                               
                                1)]                                                               
                                                                                                  
 conv2d_174 (Conv2D)            (None, 205, 205, 16  3040        ['input_7[0][0]']                
                                )                                                                 
                                                                                                  
 leaky_re_lu_138 (LeakyReLU)    (None, 205, 205, 16  0           ['conv2d_174[0][0]']             
                                )                                                           



Reading data from: "/content/targets_and_model/2016/targets_20160606_radar3.nc"...
Reading data from: "/content/predictors/2016/predictors_20160505_radar3.nc"...




Streaming output truncated to the last 5000 lines.
Number of target values in batch = 336200 ... mean = 0
Reading data from: "/content/targets_and_model/2017/targets_20170808_radar3.nc"...
Reading data from: "/content/predictors/2017/predictors_20170707_radar3.nc"...
Reading data from: "/content/predictors/2017/predictors_20170808_radar3.nc"...
Number of target values in batch = 336200 ... mean = 0.000241
Reading data from: "/content/targets_and_model/2017/targets_20170909_radar3.nc"...
Reading data from: "/content/predictors/2017/predictors_20170808_radar3.nc"...
Reading data from: "/content/predictors/2017/predictors_20170909_radar3.nc"...
Number of target values in batch = 336200 ... mean = 0.000155
Reading data from: "/content/targets_and_model/2017/targets_20170606_radar1.nc"...
Reading data from: "/content/predictors/2017/predictors_20170505_radar1.nc"...
Reading data from: "/content/predictors/2017/predictors_20170606_radar1.nc"...
Number of target values in batch 

## Train with scale-separation loss function, using Fourier decomposition

The next cell trains the U-net with the pixelwise (1-by-1) fractions skill score (FSS) as the loss function.  However, in the data generator, Fourier decomposition is used to emphasize the desired scales: wavelengths of 0.1-0.4$^{\circ}$ lat/long, which correspond to resolutions (half-wavelengths) of 0.05-0.2$^{\circ}$.

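Keep in mind that for "scale-separation loss functions," as discussed in the paper, scale separation is actually done *outside* the loss function: the data generator applies the filter, and the loss function itself is still just the pixelwise score.  This is why I phrased the first sentence above so weirdly.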
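
To make the filtering step concrete, here is a minimal sketch of a Fourier band-pass filter in NumPy.  It is not the data generator's actual code: the function name `fourier_bandpass`, its arguments, the 0.0125$^{\circ}$ grid spacing in the example, and the sharp (untapered) frequency cut-off are all assumptions made for illustration.  The sketch keeps only wavelengths between twice the minimum resolution and twice the maximum resolution.

```python
import numpy


def fourier_bandpass(field_matrix, grid_spacing_deg, min_resolution_deg,
                     max_resolution_deg):
    """Keeps only wavelengths between 2 * min and 2 * max resolution.

    Hypothetical helper for illustration (not the notebook's implementation).
    Uses a sharp cut-off in frequency space, with no tapering or windowing.
    """
    num_rows, num_columns = field_matrix.shape

    # Spatial frequency (cycles per degree) of every Fourier coefficient.
    row_freqs = numpy.fft.fftfreq(num_rows, d=grid_spacing_deg)
    column_freqs = numpy.fft.fftfreq(num_columns, d=grid_spacing_deg)
    freq_matrix = numpy.sqrt(
        row_freqs[:, numpy.newaxis] ** 2 + column_freqs[numpy.newaxis, :] ** 2
    )

    # Wavelength = 2 * resolution, and frequency = 1 / wavelength.
    min_freq = 1. / (2 * max_resolution_deg)
    max_freq = (
        numpy.inf if min_resolution_deg == 0
        else 1. / (2 * min_resolution_deg)
    )
    keep_matrix = numpy.logical_and(
        freq_matrix >= min_freq, freq_matrix <= max_freq
    )

    # Zero out the unwanted coefficients, then return to physical space.
    coeff_matrix = numpy.fft.fft2(field_matrix)
    return numpy.real(numpy.fft.ifft2(coeff_matrix * keep_matrix))


# Example: a 0.2-deg wave (inside the 0.1-to-0.4-deg band) plus a 0.8-deg
# wave (outside the band), on an assumed 64-by-64 grid with 0.0125-deg spacing.
x_coords_deg = 0.0125 * numpy.arange(64)
short_wave_matrix = numpy.tile(
    numpy.sin(2 * numpy.pi * x_coords_deg / 0.2), (64, 1)
)
long_wave_matrix = numpy.tile(
    numpy.sin(2 * numpy.pi * x_coords_deg / 0.8), (64, 1)
)

filtered_matrix = fourier_bandpass(
    field_matrix=short_wave_matrix + long_wave_matrix,
    grid_spacing_deg=0.0125, min_resolution_deg=0.05, max_resolution_deg=0.2
)

# The filter should return the short wave and discard the long wave.
print(numpy.allclose(filtered_matrix, short_wave_matrix))
```

A sharp cut-off like this can cause ringing near strong gradients, so treat the sketch as a way to understand the resolution keys, not as a replacement for the data generator.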

**If you want to train the U-net on different scales, just change the values for `MIN_TARGET_RESOLUTION_KEY` and `MAX_TARGET_RESOLUTION_KEY` below.**  Both values are in degrees.  Make sure that the value for `MIN_TARGET_RESOLUTION_KEY` is >= 0 and that the value for `MAX_TARGET_RESOLUTION_KEY` is greater than the value for `MIN_TARGET_RESOLUTION_KEY`.

**If you want to train the U-net with a different score in the loss function, uncomment the relevant code below.**
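
For reference, the default score in the next cell is the pixelwise FSS.  The sketch below writes out the standard FSS formula with a 1-by-1 neighbourhood in NumPy; it is not the Keras loss function that `get_metrics` returns.  The function name `pixelwise_fss_loss`, the choice to return 1 - FSS (so that lower is better), and the handling of the all-zero case are assumptions made for illustration.

```python
import numpy


def pixelwise_fss_loss(target_matrix, prediction_matrix, mask_matrix=None):
    """Returns 1 - FSS with a 1-by-1 neighbourhood.

    Hypothetical helper for illustration.  All arrays have the same shape;
    mask_matrix is 1 where grid points should count and 0 elsewhere.
    For non-negative fields, 0 is a perfect forecast and 1 means no overlap.
    """
    if mask_matrix is None:
        mask_matrix = numpy.ones_like(target_matrix, dtype=float)

    # Numerator of (1 - FSS): actual sum of squared errors.
    actual_sse = numpy.sum(
        mask_matrix * (target_matrix - prediction_matrix) ** 2
    )

    # Denominator: the largest SSE obtainable with no overlap between fields.
    reference_sse = numpy.sum(
        mask_matrix * (target_matrix ** 2 + prediction_matrix ** 2)
    )

    # FSS is undefined when both fields are all zero; this sketch treats
    # that case as a perfect forecast.
    if reference_sse == 0:
        return 0.
    return actual_sse / reference_sse


target_matrix = numpy.array([[0., 1.], [1., 0.]])
perfect_loss = pixelwise_fss_loss(target_matrix, target_matrix)
worst_loss = pixelwise_fss_loss(target_matrix, 1. - target_matrix)
print(perfect_loss, worst_loss)
```

With a 1-by-1 neighbourhood no spatial smoothing happens inside the score, which is why any scale selection has to come from the data generator.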

In [None]:
# Define training options.
training_option_dict = {
    PREDICTOR_DIRECTORY_KEY: PREDICTOR_DIR_NAME,
    TARGET_DIRECTORY_KEY: TARGET_AND_MODEL_DIR_NAME,
    BATCH_SIZE_KEY: 60,
    MAX_DAILY_EXAMPLES_KEY: 8,
    LEAD_TIME_KEY: 3600,
    LAG_TIMES_KEY: numpy.array([0, 1200, 2400], dtype=int),
    INCLUDE_TIME_DIM_KEY: False,
    FIRST_VALID_DATE_KEY: '20160101',
    LAST_VALID_DATE_KEY: '20161224',
    NORMALIZE_FLAG_KEY: True,
    UNIFORMIZE_FLAG_KEY: True,
    FOURIER_TRANSFORM_KEY: True,      # IMPORTANT
    WAVELET_TRANSFORM_KEY: False,     # IMPORTANT
    MIN_TARGET_RESOLUTION_KEY: 0.05,  # IMPORTANT
    MAX_TARGET_RESOLUTION_KEY: 0.2    # IMPORTANT
}

# Validation options are the same, except the time period.
validation_option_dict = {
    PREDICTOR_DIRECTORY_KEY: PREDICTOR_DIR_NAME,
    TARGET_DIRECTORY_KEY: TARGET_AND_MODEL_DIR_NAME,
    FIRST_VALID_DATE_KEY: '20170101',
    LAST_VALID_DATE_KEY: '20171224'
}

# Read U-net model.
model_file_name = '{0:s}/model.h5'.format(TARGET_AND_MODEL_DIR_NAME)
model_object = read_model(model_file_name)

# Read U-net metafile and create list of metrics to use on validation data.
metafile_name = _find_nn_metafile(
    model_file_name=model_file_name, raise_error_if_missing=True
)
metadata_dict = _read_nn_metafile(metafile_name)

mask_matrix = metadata_dict[MASK_MATRIX_KEY]
metric_names = metadata_dict[METRIC_NAMES_KEY]
metric_list, _ = get_metrics(
    metric_names=metric_names, mask_matrix=mask_matrix,
    use_as_loss_function=False
)

# Compile the model with the desired loss function.
# IMPORTANT: Other options for scale-separation loss functions are commented
# out below.  Uncomment one to see what happens!
loss_function_name = 'fss_neigh0'
# loss_function_name = 'brier_neigh0'
# loss_function_name = 'csi_neigh0'
# loss_function_name = 'iou_neigh0'
# loss_function_name = 'dice_neigh0'
# loss_function_name = 'heidke_neigh0'
# loss_function_name = 'peirce_neigh0'
# loss_function_name = 'gerrity_neigh0'

loss_function = get_metrics(
    metric_names=[loss_function_name], mask_matrix=mask_matrix,
    use_as_loss_function=True
)[0][0]

model_object.compile(
    loss=loss_function, optimizer=tf_keras.optimizers.Adam(),
    metrics=metric_list
)

# Train the model with the desired loss function.  NOTE: Even with a GPU,
# training is a bit slow.  If you want training to finish faster, decrease
# num_epochs, num_training_batches_per_epoch, or
# num_validation_batches_per_epoch.
new_model_dir_name = '{0:s}/fourier_wavelengths=0.1-0.4deg'.format(
    TARGET_AND_MODEL_DIR_NAME
)
train_model(
    model_object=model_object, output_dir_name=new_model_dir_name,
    num_epochs=10, num_training_batches_per_epoch=32,
    training_option_dict=training_option_dict,
    num_validation_batches_per_epoch=16,
    validation_option_dict=validation_option_dict,
    mask_matrix=mask_matrix, loss_function_name=loss_function_name,
    metric_names=metric_names, do_early_stopping=True,
    plateau_lr_multiplier=0.6, save_every_epoch=False
)

Reading data from: "/content/targets_and_model/2016/targets_20160909_radar1.nc"...
Reading data from: "/content/predictors/2016/predictors_20160808_radar1.nc"...
Reading data from: "/content/predictors/2016/predictors_20160909_radar1.nc"...
Number of target values in batch = 336200 ... mean = 0.00716

Streaming output truncated to the last 5000 lines.
Reading data from: "/content/predictors/2016/predictors_20161010_radar3.nc"...
Reading data from: "/content/predictors/2016/predictors_20161111_radar3.nc"...
Number of target values in batch = 336200 ... mean = 0
Number of target values and mean after Fourier transform = 336200, 0
Reading data from: "/content/targets_and_model/2016/targets_20160808_radar1.nc"...
Reading data from: "/content/predictors/2016/predictors_20160707_radar1.nc"...
Reading data from: "/content/predictors/2016/predictors_20160808_radar1.nc"...
Number of target values in batch = 336200 ... mean = 0.00341
Number of target values and mean after Fourier transform = 336200, 0.0018
Reading data from: "/content/targets_and_model/2016/targets_20161212_radar1.nc"...
Reading data from: "/content/predictors/2016/predictors_20161111_radar1.nc"...
Reading data from: "/content/predictors/2016/predictors_20161212_radar1.nc"...
Number of target values in batch = 3

## Train with scale-separation loss function, using wavelet decomposition

The next cell trains the U-net with pixelwise FSS as the loss function.  However, in the data generator, wavelet decomposition is used to emphasize the desired scales: wavelengths of 0.1-0.4$^{\circ}$ lat/long, which correspond to resolutions (half-wavelengths) of 0.05-0.2$^{\circ}$.

**If you want to train the U-net on different scales, just change the values for `MIN_TARGET_RESOLUTION_KEY` and `MAX_TARGET_RESOLUTION_KEY` below.**  Both values are in degrees.  Make sure that the value for `MIN_TARGET_RESOLUTION_KEY` is >= 0 and that the value for `MAX_TARGET_RESOLUTION_KEY` is greater than the value for `MIN_TARGET_RESOLUTION_KEY`.

**If you want to train the U-net with a different score in the loss function, uncomment the relevant code below.**

In [None]:
# Define training options.
training_option_dict = {
    PREDICTOR_DIRECTORY_KEY: PREDICTOR_DIR_NAME,
    TARGET_DIRECTORY_KEY: TARGET_AND_MODEL_DIR_NAME,
    BATCH_SIZE_KEY: 60,
    MAX_DAILY_EXAMPLES_KEY: 8,
    LEAD_TIME_KEY: 3600,
    LAG_TIMES_KEY: numpy.array([0, 1200, 2400], dtype=int),
    INCLUDE_TIME_DIM_KEY: False,
    FIRST_VALID_DATE_KEY: '20160101',
    LAST_VALID_DATE_KEY: '20161224',
    NORMALIZE_FLAG_KEY: True,
    UNIFORMIZE_FLAG_KEY: True,
    FOURIER_TRANSFORM_KEY: False,     # IMPORTANT
    WAVELET_TRANSFORM_KEY: True,      # IMPORTANT
    MIN_TARGET_RESOLUTION_KEY: 0.05,  # IMPORTANT
    MAX_TARGET_RESOLUTION_KEY: 0.2    # IMPORTANT
}

# Validation options are the same, except the time period.
validation_option_dict = {
    PREDICTOR_DIRECTORY_KEY: PREDICTOR_DIR_NAME,
    TARGET_DIRECTORY_KEY: TARGET_AND_MODEL_DIR_NAME,
    FIRST_VALID_DATE_KEY: '20170101',
    LAST_VALID_DATE_KEY: '20171224'
}

# Read U-net model.
model_file_name = '{0:s}/model.h5'.format(TARGET_AND_MODEL_DIR_NAME)
model_object = read_model(model_file_name)

# Read U-net metafile and create list of metrics to use on validation data.
metafile_name = _find_nn_metafile(
    model_file_name=model_file_name, raise_error_if_missing=True
)
metadata_dict = _read_nn_metafile(metafile_name)

mask_matrix = metadata_dict[MASK_MATRIX_KEY]
metric_names = metadata_dict[METRIC_NAMES_KEY]
metric_list, _ = get_metrics(
    metric_names=metric_names, mask_matrix=mask_matrix,
    use_as_loss_function=False
)

# Compile the model with the desired loss function.
# IMPORTANT: Other options for scale-separation loss functions are commented
# out below.  Uncomment one to see what happens!
loss_function_name = 'fss_neigh0'
# loss_function_name = 'brier_neigh0'
# loss_function_name = 'csi_neigh0'
# loss_function_name = 'iou_neigh0'
# loss_function_name = 'dice_neigh0'
# loss_function_name = 'heidke_neigh0'
# loss_function_name = 'peirce_neigh0'
# loss_function_name = 'gerrity_neigh0'

loss_function = get_metrics(
    metric_names=[loss_function_name], mask_matrix=mask_matrix,
    use_as_loss_function=True
)[0][0]

model_object.compile(
    loss=loss_function, optimizer=tf_keras.optimizers.Adam(),
    metrics=metric_list
)

# Train the model with the desired loss function.  NOTE: Even with a GPU,
# training is a bit slow.  If you want training to finish faster, decrease
# num_epochs, num_training_batches_per_epoch, or
# num_validation_batches_per_epoch.
new_model_dir_name = '{0:s}/wavelet_wavelengths=0.1-0.4deg'.format(
    TARGET_AND_MODEL_DIR_NAME
)
train_model(
    model_object=model_object, output_dir_name=new_model_dir_name,
    num_epochs=10, num_training_batches_per_epoch=32,
    training_option_dict=training_option_dict,
    num_validation_batches_per_epoch=16,
    validation_option_dict=validation_option_dict,
    mask_matrix=mask_matrix, loss_function_name=loss_function_name,
    metric_names=metric_names, do_early_stopping=True,
    plateau_lr_multiplier=0.6, save_every_epoch=False
)

Reading data from: "/content/targets_and_model/2016/targets_20160909_radar1.nc"...

Streaming output truncated to the last 5000 lines.
Reading data from: "/content/predictors/2016/predictors_20160909_radar3.nc"...
Reading data from: "/content/predictors/2016/predictors_20161010_radar3.nc"...
Number of target values in batch = 336200 ... mean = 0.0249
Zeroing out low-frequency coefficients at level 8 of 8 (resolutions = 3.2000 and 1.6000 deg)...
Zeroing out low-frequency coefficients at level 7 of 8 (resolutions = 1.6000 and 0.8000 deg)...
Zeroing out low-frequency coefficients at level 6 of 8 (resolutions = 0.8000 and 0.4000 deg)...
Zeroing out low-frequency coefficients at level 5 of 8 (resolutions = 0.4000 and 0.2000 deg)...
Reconstructing low-frequency coefficients at level 4 of 8 (resolutions = 0.2000 and 0.1000 deg)...
Reconstructing low-frequency coefficients at level 3 of 8 (resolutions = 0.1000 and 0.0500 deg)...
Reconstructing low-frequency coefficients at level 2 of 8 (resolutions = 0.0500 and 0.0250 deg)...
Zeroing out high-frequency coefficie