# Training Script
### [N-DEPTH: Neural Depth Encoding for Compression-Resilient 3D Video Streaming](https://www.mdpi.com/2079-9292/13/13/2557)
### by [Stephen Siemonsma](https://github.com/ssiemonsma)

## Imports

In [None]:
# Standard library imports
import os
import shutil
import math
import random
import time
import re
import itertools
import copy
from datetime import datetime
from abc import ABC
from enum import Enum
from io import BytesIO
from pathlib import Path
from glob import glob
from typing import Callable, List, Optional, Tuple, Union
from functools import partial

# Third-party library imports
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import pickle
from tqdm import tqdm  # Compatible with Papermill (for running notebook in command line), but may result in very large .ipynb files if you print updates too often
# from tqdm.notebook import tqdm  # Note that using tqdm.notebook (as opposed to the standard tqdm library) will make the print logs less persistent if you ever move the notebook file.  The regular tqdm library will make the printed logs be more persistent, but this can result in ridiculously large notebook files that will load very slowly if you aren't careful.

# PyTorch and related imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

# Torchvision imports
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import _read_pfm
import torchvision.transforms as T

# IPython-specific imports
from IPython.display import display, HTML

## Setup

### Jupyter View Tweaks

In [None]:
# Widen the notebook's cells to fill the browser window by injecting a CSS rule for the `.container` class.
# Cosmetic only. NOTE(review): this targets the classic Notebook layout and may have no visible effect in JupyterLab.
display(HTML("<style>.container { width:100% !important; }</style>"))

### Performance Fixes

In [None]:
# Limit PyTorch's intra-op CPU parallelism to a single thread for this process.
# This is a guard against a PyTorch bug where the CPU can sometimes be bogged down with nearly 100% utilization, despite doing little work (the data loader is not nearly that intensive).
torch.set_num_threads(1)

### Configuration Enums

In [None]:
class PrecisionMode(Enum):
    """Numerical-precision presets understood by set_performance_and_precision_settings()."""
    MIXED_PRECISION = 1
    PYTORCH_DEFAULT = 2
    FLOAT32_PRECISION = 3
    # Referenced by set_performance_and_precision_settings(). It disables TF32 the same way FLOAT32_PRECISION does;
    # NOTE(review): any actual float64 casting of the model/data must happen elsewhere (not visible here).
    FLOAT64_PRECISION = 4

class QuantizerProxyType(Enum):
    """Gradient-friendly substitute for hard rounding, chosen via config.differentiable_rounding_type (see diff_round_verbose())."""
    STRAIGHT_THROUGH = "straight-through"  # Generally trains faster
    SOFT = "soft"

class ImageCompressionType(Enum):
    """Compression applied to encoded images (see config.image_compression_type and config.gradient_validation_image_compression_type)."""
    DIFF_JPEG = "DiffJPEG"  # Differentiable JPEG approximation
    JPEG = "JPEG"
    LOSSLESS = "lossless"

class MaskType(Enum):
    """Kind of mask target used by the mask loss (see config.mask_type and config.mask_loss_function)."""
    BINARY = "binary"
    ERROR_THRESHOLDED_BINARY = "error-thresholded binary"  # This is what I had the best luck with.  The binary mask logits can be treated as a "confidence map" in practice.
    ERROR_MAP = "error map"
    
class FinalEncoderActivationFunction(Enum):
    """Activation applied at the encoder output (see config.final_encoder_activation_function).

    NONE leaves the output unbounded, which the configuration notes may produce values outside DiffJPEG's valid range.
    """
    SINE = "sine"
    SIGMOID = "sigmoid"
    NONE = "none"

### Function to change parameters relevant to PyTorch's numerical precision

In [None]:
def set_performance_and_precision_settings(precision_mode):
    match precision_mode:
        case PrecisionMode.MIXED_PRECISION:
            # See: https://pytorch.org/docs/stable/notes/cuda.html#tf32-on-ampere
            torch.backends.cuda.matmul.allow_tf32 = False  # False by default.  This flag controls whether to allow TF32 on matmul.
            torch.backends.cudnn.allow_tf32 = True  # True by default.  This flag controls whether to allow TF32 on cuDNN (for convolutions).
            scaler = torch.amp.GradScaler('cuda')
        case PrecisionMode.PYTORCH_DEFAULT:
            torch.backends.cuda.matmul.allow_tf32 = False  # False by default.  This flag controls whether to allow TF32 on matmul.
            torch.backends.cudnn.allow_tf32 = True  # True by default.  This flag controls whether to allow TF32 on cuDNN (for convolutions).
            scaler = None
        case PrecisionMode.FLOAT32_PRECISION:
            torch.backends.cuda.matmul.allow_tf32 = False  # False by default.  This flag controls whether to allow TF32 on matmul.
            torch.backends.cudnn.allow_tf32 = False  # True by default.  This flag controls whether to allow TF32 on cuDNN (for convolutions).\
            scaler = None
        case PrecisionMode.FLOAT64_PRECISION:
            torch.backends.cuda.matmul.allow_tf32 = False  # False by default.  This flag controls whether to allow TF32 on matmul.
            torch.backends.cudnn.allow_tf32 = False  # True by default.  This flag controls whether to allow TF32 on cuDNN (for convolutions).
            scaler = None
    
    torch.backends.cudnn.benchmark = True  # Allows cuDNN to dynamically select the most efficient convolution algorithms based on runtime benchmarks
    torch.backends.cudnn.deterministic = False  # Set to True for deterministic behavior (helpful when reproducibility is key)

    return scaler

### Papermill Parameters

In [None]:
# If running without Jupyter Notebook or Jupyter Lab, Papermill can be configured to provide the file path of the notebook it is running (so that you can use this file name for naming checkpoint files, etc.).
# Papermill is great for running the notebook entirely from the command line while retaining the ability to save all Jupyter cell outputs.

# Defaults
# NOTE(review): presumably this is the cell tagged "parameters" that Papermill overrides; confirm in the notebook's cell metadata.
PAPERMILL_INPUT_PATH = None

In [None]:
# Parameters
# NOTE(review): looks like a Papermill-injected parameters cell that re-assigns the default above; a command-line run would inject the notebook path here.
PAPERMILL_INPUT_PATH = None

### Parsing Notebook Name and Logging Run Time

In [None]:
# Parse the notebook name; it is used as the base for checkpoint, configuration, and image file names.
if PAPERMILL_INPUT_PATH is not None:
    notebook_path = PAPERMILL_INPUT_PATH
else:
    # If run interactively, Jupyter exposes the notebook path through this environment variable
    notebook_path = os.getenv('JPY_SESSION_NAME')

if not notebook_path:
    # Fail loudly rather than guessing a name.  A made-up fallback could collide with another run's files,
    # and start_from_scratch deletes files derived from this name.
    raise RuntimeError(
        "Could not determine the notebook name: PAPERMILL_INPUT_PATH is None and the "
        "JPY_SESSION_NAME environment variable is not set.  Run via Papermill or set PAPERMILL_INPUT_PATH."
    )

notebook_name = os.path.splitext(os.path.basename(notebook_path))[0]
print("Notebook name:", notebook_name)
print("Date Ran:", datetime.now().strftime("%d/%m/%Y %I:%M:%S %p"))

### Config Class

In [None]:
class Config:
    """Free-form attribute bag holding the run configuration.

    New settings are added simply by assigning attributes on an instance
    (e.g. ``config.batch_size = 1``), giving dictionary-like storage with
    dot notation.  The object is pickled and saved alongside the model
    checkpoints.
    """

## Configuration

In [None]:
config = Config()

# Debugging
config.dry_run = False  # If enabled, model checkpointing and other logging will not be enabled (for debugging and evaluating performance)
config.debug_mode = False  # This will limit the size of each "run" so that the script can be tested end-to-end.  This will write out checkpoints, etc. to validate those features.
config.debug_mode_loop_length = 3

# Important
config.start_from_scratch = True  # If True, all model checkpoints whose names are derived from the notebook name will be deleted.  If False, training will continue from a checkpoint file.
config.weight_initialization_source = None  # For loading a pretrained model from a different folder.  Not necessary if simply continuing training from a checkpoint in the folder expected from the notebook name.
# config.weight_initialization_source = "weights/N-DEPTH_training_run3_best_L1_rangeGradient_val_loss.pt"
config.ignore_mismatched_weight_sizes = True  # Just in case you are loading some weights from a model where other model components have been changed (e.g., loading pre-trained encoder weights to help train a new decoder structure).  Note: when True, load_checkpoint_file() restores only the model weights, not the optimizer/scheduler/scaler state.
config.dataset_path = "/path_to_parent_directory_of_FlyingThings3D"  # Set this to the path of the parent directory of the FlyingThings3D dataset

# Precision
config.precision_mode = PrecisionMode.MIXED_PRECISION
# config.precision_mode = PrecisionMode.PYTORCH_DEFAULT
# config.precision_mode = PrecisionMode.FLOAT32_PRECISION
scaler = set_performance_and_precision_settings(config.precision_mode)

# Quantization
config.differentiable_rounding_type = QuantizerProxyType.STRAIGHT_THROUGH
# config.differentiable_rounding_type = QuantizerProxyType.SOFT

# Encoder
config.encoder_activation_function = nn.Mish(inplace=True)
# config.encoder_activation_function = nn.ReLU(inplace=True)
# config.encoder_activation_function = nn.LeakyReLU(negative_slope=0.1, inplace=True)
config.final_encoder_activation_function = FinalEncoderActivationFunction.SINE
# config.final_encoder_activation_function = FinalEncoderActivationFunction.SIGMOID
# config.final_encoder_activation_function = FinalEncoderActivationFunction.NONE  # Probably should only use with MWD encoder, otherwise value range will be invalid for DiffJPEG
config.neural_encoder_width = 2048  # Number of neurons in the largest 1st layer of the encoder MLP

# Decoder
config.decoder_activation_function = nn.Mish(inplace=True)  # Generally converged best
# config.decoder_activation_function = nn.ReLU(inplace=True)
# config.decoder_activation_function = nn.LeakyReLU(negative_slope=0.1, inplace=True)
config.neural_decoder_width = 2048  # Number of neurons in the largest initial layer of the decoder MLP (other layer sizes scaled from this)

# Weight Freezing
config.freeze_encoder = False
config.freeze_decoder = False

# Misc
config.expanded_training_range = 1.02  # "Overprovisioning" the depth range ensures that you can avoid the high-error regions at the beginning and end of the range

# Training Parameters
# (config.num_training_epochs is set once, below, under "Training Loop Parameters")
config.batch_size = 1  # Note: Batches greater than 1 will likely need to edit a bit of code.  1 image is effectively thousands of samples for this network, so larger batch sizes are less necessary.
config.height = 224  # DiffJPEG-compatible resolutions are, at a minimum, multiples of 8 (due to 8x8 image blocks).  NOTE(review): with 4:2:0 chroma subsampling, multiples of 16 are presumably required; confirm.
config.width = 224
config.random_crops = True
config.depth_normalization_lower_bound = 0  # I didn't notice any benefits to other normalization ranges (e.g., [-1, 1])
config.depth_normalization_upper_bound = 1
config.device = torch.device("cuda") # For NVIDIA cards
# config.device = torch.device("mps") # For M-series Macs
# config.device = torch.device("cpu") # For slow training on CPU

# Image Settings
config.image_compression_type = ImageCompressionType.DIFF_JPEG
# config.image_compression_type = ImageCompressionType.LOSSLESS  # Can train without JPEG compression
config.low_JPEG_quality = 85  # Note: This JPEG quality training range is somewhat arbitrary.
config.high_JPEG_quality = 99.999
config.subsampling = True  # 4:2:0 chroma subsampling, if enabled, otherwise 4:4:4 chroma sampling.  Generally 4:2:0 chroma subsampling will result in encoding functions that are more visually pleasing when plotted in 3D.
config.color_space_conversions = True  # True by default for JPEG, but you can disable this to save directly to YCbCr.  Note: Chroma subsampling is not yet an option when saving directly into YCbCr without color space conversions.
config.gradient_validation_image_compression_type = ImageCompressionType.LOSSLESS  # Note: Lossy compression not yet supported for 100x100 gradient image due to DiffJPEG restrictions on image resolutions.  The easiest workaround would be using a compatible gradient image resolution.
config.target_validation_JPEG_quality = 99.999
config.interpolation_mode = T.functional.InterpolationMode.NEAREST_EXACT  # Important!: Using anything except nearest neighbor interpolation will result in "fake" interpolated depth values, which will greatly affect JPEG artifacting at edges.

# Loss Parameters
config.depth_loss_enabled = True
config.mask_loss_enabled = True
config.depth_gradient_val_interval = 10  # This is the number of mini-batches between the lossless evaluations of a "tilted plane" with 100x100 resolution
config.JPEG_normalization_interval = 10  # This is the number of mini-batches between each calibration step that determines how to normalize losses given a particular JPEG quality setting
config.depth_loss_weighting = 6
config.mask_loss_weighting = 0.06  # Set to a low value to prioritize depth reconstruction quality
config.depth_loss_function = nn.L1Loss()  # Encourages better reconstructions of interior regions (not as distracted by JPEG artifacts along edges)
# config.depth_loss_function = nn.MSELoss()  # Will likely decrease JPEG artifacting, but at the cost of interior depth reconstruction quality.
# config.mask_type = MaskType.BINARY
config.mask_type = MaskType.ERROR_THRESHOLDED_BINARY
# config.mask_type = MaskType.ERROR_MAP
config.mask_loss_function = nn.BCEWithLogitsLoss()
# config.mask_loss_function = nn.L1Loss()  # Would be an appropriate option if attempting to generate an error map (i.e., not a binary mask)
config.error_threshold_for_masking = 0.03  # If decoded error is >3% of normalized range, the network will be trained to filter out this type of corrupted pixel
config.weight_decay = 0  # Regularization, if needed

# Optimizer/Scheduler
config.optimizer_type = optim.Adam
# config.optimizer_type = optim.SGD
config.scheduler_type = optim.lr_scheduler.CosineAnnealingWarmRestarts  # Cyclic and more aggressive (since overfitting is not a high risk)
# config.scheduler_type = optim.lr_scheduler.ReduceLROnPlateau
config.starting_lr = 1e-4
config.cosine_annealing_scheduler_period = 5000
config.cosine_annealing_scheduler_restart_factor = 1  # Set to less than 1 if you want the peak learning rate to decrease after every cycle
config.lr_gamma = 0.5 # (for metric-based scheduler only) This is the factor the learning rate decreases by after the metric doesn't improve for some time
config.patience = 50000  # (for metric-based scheduler only) The number of iterations that must pass without metric improvement for the learning rate to decrease

# Training Loop Parameters
config.stalled_val_patience = 10  # Training will end after this many epochs if the validation loss has not improved by more than the threshold amount over that time
config.stalled_val_threshold = 1e-6
config.max_epochs_per_run = 100  # Training run will end after this number of epochs, even if progress is still somehow being made
config.num_training_epochs = 3 * config.max_epochs_per_run  # Maximum total number of training epochs among all training runs (3 runs' worth); also used to initialize some metric arrays

# Logging
config.encoded_image_log_interval = 500  # Measured in mini-batches
config.tensorboard_log_interval = 500

# If running in Papermill, the TQDM print frequency will be decreased
if PAPERMILL_INPUT_PATH is not None:
    config.print_interval = 500  # Lower print frequency when ran with Papermill
else:
    config.print_interval = 20  # Higher print frequency for interactive runs

if config.debug_mode:
    config.print_interval = 1  # Highest print frequency for debug runs (since they may be very short)

## Functions and Classes

In [None]:
# Since the network can improve very rapidly in the first few epochs (causing the generation of a metric-based checkpoint file), caching checkpoint files in memory until the end of the epoch is significantly faster.
def cache_checkpoint(config, epoch, net, optimizer, scheduler, scaler,
                    best_depth_range_gradient_L1_loss, best_max_depth_range_gradient_L1_loss, 
                    best_avg_L1_loss_on_JPEG_normalization_image, best_avg_RMSE_loss_on_JPEG_normalization_image, best_avg_mask_loss_on_JPEG_normalization_image,
                    best_depth_range_gradient_L1_losses_by_epoch, best_max_depth_range_gradient_L1_losses_by_epoch, best_avg_L1_losses_on_JPEG_normalization_by_epoch, 
                    best_avg_RMSE_losses_on_JPEG_normalization_by_epoch, best_avg_mask_loss_on_JPEG_normalization_image_by_epoch):
    """Snapshot the current training state into an in-memory checkpoint dictionary.

    The returned dictionary is a deep copy, so later training steps do not alter it.
    It can be written to disk afterwards with save_checkpoint().

    Args:
        config: Run configuration; only ``config.precision_mode`` is read.
        epoch: Current epoch index.
        net, optimizer, scheduler: Objects whose ``state_dict()`` is stored.
        scaler: Gradient scaler; its state is stored only in mixed-precision mode.
        best_*: Best-so-far metric values and per-epoch metric histories, stored as-is.

    Returns:
        dict: Deep copy of the checkpoint contents.  It also includes the
        module-level ``global_step``.
    """
    checkpoint = {
        'epoch': epoch,
        'global_step': global_step,  # Read from module scope
        'model_state_dict': net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
    }

    # A GradScaler only exists in mixed-precision mode (it is None otherwise), so only save it then.
    if config.precision_mode == PrecisionMode.MIXED_PRECISION:
        checkpoint['scaler_state_dict'] = scaler.state_dict()

    checkpoint.update({
        'best_depth_range_gradient_L1_loss': best_depth_range_gradient_L1_loss,
        'best_max_depth_range_gradient_L1_loss': best_max_depth_range_gradient_L1_loss,
        'best_avg_L1_loss_on_JPEG_normalization_image': best_avg_L1_loss_on_JPEG_normalization_image,
        'best_avg_RMSE_loss_on_JPEG_normalization_image': best_avg_RMSE_loss_on_JPEG_normalization_image,
        'best_avg_mask_loss_on_JPEG_normalization_image': best_avg_mask_loss_on_JPEG_normalization_image,
        'best_depth_range_gradient_L1_losses_by_epoch': best_depth_range_gradient_L1_losses_by_epoch,
        'best_max_depth_range_gradient_L1_losses_by_epoch': best_max_depth_range_gradient_L1_losses_by_epoch,
        'best_avg_L1_losses_on_JPEG_normalization_by_epoch': best_avg_L1_losses_on_JPEG_normalization_by_epoch,
        'best_avg_RMSE_losses_on_JPEG_normalization_by_epoch': best_avg_RMSE_losses_on_JPEG_normalization_by_epoch,
        'best_avg_mask_loss_on_JPEG_normalization_image_by_epoch': best_avg_mask_loss_on_JPEG_normalization_image_by_epoch,
    })

    # state_dict() returns references to live tensors, so deep-copy to freeze the snapshot.
    return copy.deepcopy(checkpoint)

In [None]:
# Saves a cached checkpoint to disk (with some retry attempts if this fails for some reason)
def save_checkpoint(cached_checkpoint, checkpoint_save_name, config, weight_save_directory, notebook_name, training_run_number):
    """Write a cached checkpoint to disk, retrying on failure.

    The file is written to
    ``<weight_save_directory><notebook_name>_run<training_run_number>_<checkpoint_save_name>.pt``.
    Nothing is written when ``config.dry_run`` is set.

    Args:
        cached_checkpoint: Object to save (e.g. the dict from cache_checkpoint()).
        checkpoint_save_name: Suffix identifying the checkpoint (e.g. ``"epoch3"``).
        config: Run configuration; only ``config.dry_run`` is read.
        weight_save_directory: Directory prefix, expected to end with a path separator.
        notebook_name: Base name used for the file.
        training_run_number: Run index embedded in the file name.

    Returns:
        None.  If every attempt raises RuntimeError, a message is printed and the save is skipped.
    """
    if config.dry_run:
        return

    max_retries = 10
    retry_delay = 5  # seconds

    full_checkpoint_save_name = weight_save_directory + notebook_name + "_run%i_" % training_run_number + checkpoint_save_name + ".pt"

    for attempt in range(max_retries):
        try:
            torch.save(cached_checkpoint, full_checkpoint_save_name)
        except RuntimeError as e:
            if attempt < max_retries - 1:
                print(f"Error occurred while saving checkpoint ({e}). Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
            else:
                print(f"Max retries reached. Skipping checkpoint save. Last error: {e}")
        else:
            # Saved successfully.  Stop here; otherwise the loop would rewrite the same file max_retries times.
            return

In [None]:
# Loads a model checkpoint file (model weights, optimizer/scheduler states, and some metrics that have been saved alongside it)
def load_checkpoint_file(checkpoint_path):
    """Restore training state from a checkpoint written by cache_checkpoint()/save_checkpoint().

    Works on module-level globals.  It loads weights into ``net``.  Unless
    ``config.ignore_mismatched_weight_sizes`` is set, it also restores the
    ``optimizer``, ``scheduler`` and (if present) ``scaler`` states.  It then
    overwrites ``starting_epoch``, ``global_step`` and the best-metric globals
    with the saved values.

    Args:
        checkpoint_path: Path to a ``.pt`` checkpoint file.
    """
    global global_step
    global net
    global optimizer
    global scheduler
    global scaler
    global starting_epoch
    global best_max_depth_range_gradient_L1_loss
    global best_depth_range_gradient_L1_loss
    global best_avg_L1_loss_on_JPEG_normalization_image
    global best_avg_RMSE_loss_on_JPEG_normalization_image
    global best_avg_mask_loss_on_JPEG_normalization_image
    global best_depth_range_gradient_L1_losses_by_epoch
    global best_max_depth_range_gradient_L1_losses_by_epoch
    global best_avg_L1_losses_on_JPEG_normalization_by_epoch
    global best_avg_RMSE_losses_on_JPEG_normalization_by_epoch
    global best_avg_mask_loss_on_JPEG_normalization_image_by_epoch

    # weights_only=False is pinned explicitly because the checkpoint also stores metric values/histories, not just tensors.
    # Newer PyTorch releases default to weights_only=True, which can reject such objects.
    # NOTE: this unpickles arbitrary objects, so only load checkpoints you trust (including config.weight_initialization_source).
    checkpoint = torch.load(checkpoint_path, weights_only=False)

    # Restore the state dictionaries
    if config.ignore_mismatched_weight_sizes:
        # Load only tensors whose shapes match the current model, and keep the current values elsewhere.
        # NOTE(review): entries are paired by POSITION (zip of current keys with loaded values), not by name.
        # This tolerates renamed modules, but it can mis-assign same-shaped tensors if layers were added, removed or reordered.
        # Optimizer/scheduler/scaler states are not restored on this path.
        current_model_dict = net.state_dict()
        loaded_state_dict = checkpoint['model_state_dict']
        new_state_dict = {k: v if v.size() == current_model_dict[k].size() else current_model_dict[k]
                          for k, v in zip(current_model_dict.keys(), loaded_state_dict.values())}
        net.load_state_dict(new_state_dict, strict=False)
    else:
        net.load_state_dict(checkpoint['model_state_dict'], strict=False)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        if 'scaler_state_dict' in checkpoint and scaler is not None:
            scaler.load_state_dict(checkpoint['scaler_state_dict'])

    # Restore other variables; training resumes after the saved epoch/step
    starting_epoch = checkpoint['epoch'] + 1
    global_step = checkpoint['global_step'] + 1

    best_max_depth_range_gradient_L1_loss = checkpoint['best_max_depth_range_gradient_L1_loss']
    best_depth_range_gradient_L1_loss = checkpoint['best_depth_range_gradient_L1_loss']
    best_avg_L1_loss_on_JPEG_normalization_image = checkpoint['best_avg_L1_loss_on_JPEG_normalization_image']
    best_avg_RMSE_loss_on_JPEG_normalization_image = checkpoint['best_avg_RMSE_loss_on_JPEG_normalization_image']

    best_depth_range_gradient_L1_losses_by_epoch = checkpoint['best_depth_range_gradient_L1_losses_by_epoch']
    best_max_depth_range_gradient_L1_losses_by_epoch = checkpoint['best_max_depth_range_gradient_L1_losses_by_epoch']
    best_avg_L1_losses_on_JPEG_normalization_by_epoch = checkpoint['best_avg_L1_losses_on_JPEG_normalization_by_epoch']
    best_avg_RMSE_losses_on_JPEG_normalization_by_epoch = checkpoint['best_avg_RMSE_losses_on_JPEG_normalization_by_epoch']

    # cache_checkpoint() also saves the mask-loss metrics.  Restore them when present
    # (checkpoints from other sources may lack these keys) so that resuming does not reset them.
    if 'best_avg_mask_loss_on_JPEG_normalization_image' in checkpoint:
        best_avg_mask_loss_on_JPEG_normalization_image = checkpoint['best_avg_mask_loss_on_JPEG_normalization_image']
    if 'best_avg_mask_loss_on_JPEG_normalization_image_by_epoch' in checkpoint:
        best_avg_mask_loss_on_JPEG_normalization_image_by_epoch = checkpoint['best_avg_mask_loss_on_JPEG_normalization_image_by_epoch']

    print("Checkpoint loaded successfully.")
    print("Epoch:", starting_epoch)
    print("Global step:", global_step)
    print("Best Max Depth Range Gradient Loss:", best_max_depth_range_gradient_L1_loss * 1000)
    print("Best Depth Range Gradient L1 Loss:", best_depth_range_gradient_L1_loss * 1000)
    print("Best Avg L1 Loss on JPEG Normalization Image:", best_avg_L1_loss_on_JPEG_normalization_image * 1000)
    print("Best Avg RMSE Loss on JPEG Normalization Image:", best_avg_RMSE_loss_on_JPEG_normalization_image * 1000)

In [None]:
# Helper function to find the latest epoch checkpoint
def find_highest_epoch_checkpoint(weight_save_directory, notebook_name, training_run_number):
    """Return the path of the highest-numbered per-epoch checkpoint of a run.

    Looks for files named
    ``<weight_save_directory><notebook_name>_run<training_run_number>_epoch<E>.pt``
    and picks the one with the largest integer ``E``.

    Args:
        weight_save_directory: Directory prefix, expected to end with a path separator.
        notebook_name: Base name used in the checkpoint file names.
        training_run_number: Run index embedded in the file names.

    Returns:
        The path of the matching file as returned by ``glob``, or ``None`` if no
        file with a numeric epoch exists.
    """
    from glob import escape as glob_escape  # The module-level name `glob` is the function, not the module

    # Escape the literal prefix so glob metacharacters (e.g. "[") in the directory
    # or notebook name are matched literally instead of as wildcards.
    checkpoint_prefix = f"{weight_save_directory}{notebook_name}_run{training_run_number}_epoch"
    epoch_checkpoint_files = glob(glob_escape(checkpoint_prefix) + "*.pt")

    # Take the epoch number from the end of the file name only.  Searching the whole path
    # could pick up "epoch<digits>" from a directory or notebook name instead.
    epoch_suffix_pattern = re.compile(r"_epoch(\d+)\.pt$")

    highest_epoch = -1
    highest_epoch_checkpoint = None
    for epoch_checkpoint_file in epoch_checkpoint_files:
        match = epoch_suffix_pattern.search(os.path.basename(epoch_checkpoint_file))
        if match is None:
            continue  # No numeric epoch (e.g. "..._epochfinal.pt")
        epoch_number = int(match.group(1))
        if epoch_number > highest_epoch:
            highest_epoch = epoch_number
            highest_epoch_checkpoint = epoch_checkpoint_file

    # Return the file that actually exists.  Rebuilding the name from the integer would break
    # for zero-padded names such as "epoch007.pt".
    return highest_epoch_checkpoint

In [None]:
def create_images_folder_structure(config, notebook_name):
    """Ensure ``./images/<notebook_name>/`` and its neural-encoding subfolders exist.

    When starting from scratch, any existing image folder for this notebook is
    removed first.  This does not happen for dry runs or debug runs.
    """
    images_root = "./images/" + notebook_name

    wipe_previous_images = (
        config.start_from_scratch
        and os.path.exists(images_root)
        and not (config.dry_run or config.debug_mode)
    )
    if wipe_previous_images:
        shutil.rmtree(images_root)

    for folder_name in ("neural_encoding_for_ones_mask", "neural_encoding_for_zeros_mask"):
        target_folder = os.path.join(images_root, folder_name)
        if os.path.exists(target_folder):
            continue
        os.makedirs(target_folder)

In [None]:
# Deletes the contents of a directory
def clear_directory(directory):
    """Delete everything inside ``directory`` while keeping the directory itself.

    Files and symlinks (including links that point at directories) are unlinked.
    Real subdirectories are removed recursively.  Any other entry type is left in place.
    """
    for entry in Path(directory).iterdir():
        if entry.is_file() or entry.is_symlink():
            entry.unlink()
        elif entry.is_dir():
            shutil.rmtree(entry)

In [None]:
# If training is interrupted, this function finds a recent checkpoint file.  It's also used on subsequent training runs to ascertain the current training run number.
def _find_resume_checkpoint(weight_save_directory, training_run_number, not_found_message):
    """Return the checkpoint path to resume (or initialize) from for one training run.

    Prefers the checkpoint with the best L1 depth loss on the range-gradient image
    (no image compression).  Otherwise falls back to the highest-numbered epoch
    checkpoint.  Relies on the module-level ``notebook_name``.

    Raises:
        FileNotFoundError: With ``not_found_message`` if neither checkpoint exists.
    """
    best_checkpoint_path = weight_save_directory + notebook_name + "_run%i_" % training_run_number + "best_L1_rangeGradient_val_loss.pt"
    if os.path.exists(best_checkpoint_path):
        return best_checkpoint_path

    # find_highest_epoch_checkpoint() returns None when there are no epoch checkpoints.
    # os.path.exists(None) raises TypeError instead of returning False, so check for None first.
    latest_epoch_checkpoint_path = find_highest_epoch_checkpoint(weight_save_directory, notebook_name, training_run_number)
    if latest_epoch_checkpoint_path is not None and os.path.exists(latest_epoch_checkpoint_path):
        return latest_epoch_checkpoint_path

    raise FileNotFoundError(not_found_message)


def determine_training_run_number(config, save_directory_all_runs):
    """Choose the training run number and the checkpoint (if any) to start from.

    From scratch: existing runs are deleted (except in dry runs), and run 1 is used.
    Otherwise ``run_1``, ``run_2``, ... are inspected in order:

    * A missing or empty run directory is used for this run (it is created unless this is a dry run).
    * A run directory without a saved configuration is restarted from scratch.
    * An interrupted run is resumed from its checkpoint.  If the precision mode has
      changed, a new run is started instead, initialized from that checkpoint.
    * A completed run provides the initialization checkpoint for the next run.

    Relies on the module-level ``notebook_name``.

    Returns:
        tuple: ``(training_run_number, weight_save_directory, config, checkpoint_file_path)``,
        where ``checkpoint_file_path`` is ``None`` when there is nothing to load.
    """
    os.makedirs(save_directory_all_runs, exist_ok=True)

    training_run_number = 1
    checkpoint_file_path = None

    # If starting from scratch, delete all weight files and start over
    if config.start_from_scratch:
        if config.dry_run:
            # Dry runs write no checkpoints, so they should not destroy existing ones either
            print("Starting from scratch (dry run).  Existing runs are left untouched.")
        else:
            print("Starting from scratch.  Deleting existing runs.")
            clear_directory(save_directory_all_runs)
        weight_save_directory = save_directory_all_runs + "/run_%i/" % training_run_number
        if not config.dry_run:
            os.makedirs(weight_save_directory, exist_ok=True)
        return training_run_number, weight_save_directory, config, checkpoint_file_path

    while True:
        weight_save_directory = save_directory_all_runs + "/run_%i/" % training_run_number

        # A missing or empty run directory becomes this run
        if not (os.path.exists(weight_save_directory) and os.listdir(weight_save_directory)):
            if not config.dry_run:
                os.makedirs(weight_save_directory, exist_ok=True)
            break

        config_file_path = weight_save_directory + notebook_name + "_training_configuration_run%i.pt" % training_run_number
        if not os.path.exists(config_file_path):
            print(config_file_path)
            print("Couldn't find config file " + config_file_path + ", so starting this run over from scratch")
            break

        with open(config_file_path, 'rb') as handle:
            loaded_config = pickle.load(handle)

        if not loaded_config.ran_to_completion:
            checkpoint_file_path = _find_resume_checkpoint(
                weight_save_directory, training_run_number,
                "Last run not ran to completion, but pre-trained checkpoint file not found!")
            # Resume this run, unless the precision mode changed.  In that case start a new run
            # initialized from this checkpoint.
            previous_precision_mode = getattr(loaded_config, "precision_mode", config.precision_mode)
            if previous_precision_mode == config.precision_mode:
                break
            training_run_number += 1
        else:
            # Completed run: it provides the initialization checkpoint for the next run
            checkpoint_file_path = _find_resume_checkpoint(
                weight_save_directory, training_run_number,
                "Pre-trained checkpoint file not found!")
            training_run_number += 1

        # Prevent a runaway loop in case of any issues
        if training_run_number >= 10:
            raise Exception("Training run number greater than 10 seems to indicate an issue in determine_training_run_number()")

    return training_run_number, weight_save_directory, config, checkpoint_file_path

In [None]:
# This simply allows the form of differentiable rounding to be easily configurable
def diff_round_verbose(x, differentiable_rounding_type = QuantizerProxyType.STRAIGHT_THROUGH):
    """Round ``x`` to the nearest integer using a gradient-friendly proxy.

    Args:
        x: Tensor to round.
        differentiable_rounding_type: ``QuantizerProxyType.STRAIGHT_THROUGH`` returns
            exactly ``round(x)`` in the forward pass with an identity gradient.
            ``QuantizerProxyType.SOFT`` returns ``round(x) + (x - round(x))**3``,
            a smooth approximation.

    Returns:
        Tensor with the same shape as ``x``.

    Raises:
        ValueError: For an unrecognised rounding type (instead of silently returning None).
    """
    rounded = torch.round(x)
    if differentiable_rounding_type == QuantizerProxyType.SOFT:
        return rounded + (x - rounded) ** 3
    if differentiable_rounding_type == QuantizerProxyType.STRAIGHT_THROUGH:
        # The forward value equals round(x).  Detaching the correction hides it from autograd, so dy/dx = 1.
        return x + (rounded - x).detach()
    raise ValueError(f"Unsupported differentiable rounding type: {differentiable_rounding_type!r}")

# Make sure to use this instead of real rounding when training
# The rounding proxy from config.differentiable_rounding_type is bound once, when this cell runs.
# Changing the config afterwards has no effect on diff_round until this cell is re-run.
diff_round = partial(diff_round_verbose, differentiable_rounding_type=config.differentiable_rounding_type)

### DiffJPEG Functions (lightly modified from https://github.com/mlomnitz/DiffJPEG)

#### DiffJPEG License:
MIT License

Copyright (c) 2021 Michael R Lomnitz

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

In [None]:
# From utils.py of DiffJPEG

# Luminance (Y) quantization table.  The values match the example luminance table of the JPEG standard (Annex K).
# It is stored transposed (.T), as in DiffJPEG.
y_table = np.array(
    [[16, 11, 10, 16, 24, 40, 51, 61],
     [12, 12, 14, 19, 26, 58, 60, 55], 
     [14, 13, 16, 24, 40, 57, 69, 56],
     [14, 17, 22, 29, 51, 87, 80, 62],
     [18, 22, 37, 56, 68, 109, 103, 77],
     [24, 35, 55, 64, 81, 104, 113, 92],
     [49, 64, 78, 87, 103, 121, 120, 101],
     [72, 92, 95, 98, 112, 100, 103, 99]],
    dtype=np.float32).T
# Wrapped as nn.Parameter (requires_grad defaults to True); how it is consumed is outside this view.
y_table = nn.Parameter(torch.from_numpy(y_table))

# Chrominance (Cb/Cr) quantization table: 99 everywhere except the top-left 4x4 block,
# which holds the low-frequency values (also stored transposed).
c_table = np.empty((8, 8), dtype=np.float32)
c_table.fill(99)
c_table[:4, :4] = np.array([[17, 18, 24, 47],
                            [18, 21, 26, 66],
                            [24, 26, 56, 99],
                            [47, 66, 99, 99]]).T
c_table = nn.Parameter(torch.from_numpy(c_table))

def quality_to_factor(quality):
    """ Calculate the quantization-table scale factor for a JPEG quality level.

    Follows libjpeg's ``jpeg_quality_scaling`` (which PIL uses for the real JPEG
    path): qualities <= 0 are treated as 1, qualities > 100 as 100, then the
    scale is ``5000 / q`` below 50 and ``200 - 2q`` otherwise, returned as a
    fraction (so quality 50 -> 1.0).

    NOTE(review): quality 100 yields a factor of 0.0. y_quantize / c_quantize
    divide by ``table * factor``, so DiffJPEG at quality 100 produces inf/NaN.
    libjpeg avoids this by flooring each scaled table entry at 1.

    Input:
        quality(float): Quality for jpeg compression
    Output:
        factor(float): Compression factor
    """
    # Same safety clamp as libjpeg: avoids a zero divide at 0 and negative factors above 100.
    if quality <= 0:
        quality = 1
    elif quality > 100:
        quality = 100

    if quality < 50:
        scale = 5000. / quality
    else:
        scale = 200. - quality * 2
    return scale / 100.

In [None]:
# From compression.py of DiffJPEG

class rgb_to_ycbcr_jpeg(nn.Module):
    """ Converts RGB image to (JFIF, full-range) YCbCr and moves channels last
    Input:
        image(tensor): batch x 3 x height x width
    Output:
        result(tensor): batch x height x width x 3
    With color_space_conversions=False the input is only permuted to channels-last.
    """
    def __init__(self, color_space_conversions):  # Modification to DiffJPEG
        super(rgb_to_ycbcr_jpeg, self).__init__()
        matrix = np.array(
            [[0.299, 0.587, 0.114],
             [-0.168736, -0.331264, 0.5],
             [0.5, -0.418688, -0.081312]], dtype=np.float32).T
        # Offsets added after the matrix product: Cb and Cr are centred on 128.
        self.shift = nn.Parameter(torch.tensor([0., 128., 128.]))

        self.matrix = nn.Parameter(torch.from_numpy(matrix))
        self.color_space_conversions = color_space_conversions  # Modification to DiffJPEG

    def forward(self, image):
        result = image.permute(0, 2, 3, 1)

        if self.color_space_conversions:  # Modification to DiffJPEG
            # tensordot over the trailing channel axis applies the 3x3 matrix to every pixel.
            # (Upstream's discarded `result.view(image.shape)` was a no-op and is omitted.)
            result = torch.tensordot(result, self.matrix, dims=1) + self.shift

        return result

class chroma_subsampling(nn.Module):
    """ Chroma subsampling on CbCr channels
    Input:
        image(tensor): batch x height x width x 3
    Output:
        y(tensor): batch x height x width
        cb(tensor): batch x height/2 x width/2  (batch x height x width if subsampling is off)
        cr(tensor): batch x height/2 x width/2  (batch x height x width if subsampling is off)
    """
    def __init__(self, subsampling):
        super(chroma_subsampling, self).__init__()  # Modification to DiffJPEG
        self.subsampling = subsampling  # Modification to DiffJPEG

    def forward(self, image):
        y = image[:, :, :, 0]
        if not self.subsampling:  # Modification to DiffJPEG
            return y, image[:, :, :, 1], image[:, :, :, 2]

        # 4:2:0: average each 2x2 neighbourhood of both chroma planes in one call.
        # avg_pool2d expects channels-first input, hence the permute.
        chroma = image[:, :, :, 1:3].permute(0, 3, 1, 2)
        pooled = nn.functional.avg_pool2d(chroma, kernel_size=2, stride=2, count_include_pad=False)
        # Indexing the channel axis yields batch x H/2 x W/2, matching the documented
        # shape (upstream DiffJPEG squeezes the channel axis the same way).
        return y, pooled[:, 0], pooled[:, 1]


class block_splitting(nn.Module):
    """ Split an image plane into non-overlapping 8x8 patches (row-major block order)
    Input:
        image(tensor): batch x height x width (presumably multiples of 8)
    Output:
        patch(tensor): batch x h*w/64 x 8 x 8
    """
    def __init__(self):
        super(block_splitting, self).__init__()
        self.k = 8

    def forward(self, image):
        k = self.k
        batch_size, height, _ = image.shape[:3]
        # (B, H/k, k, W/k, k): -1 absorbs the width (and any trailing singleton axis).
        tiles = image.view(batch_size, height // k, k, -1, k)
        # Put the two block-index axes next to each other: (B, H/k, W/k, k, k).
        tiles = tiles.permute(0, 1, 3, 2, 4).contiguous()
        return tiles.view(batch_size, -1, k, k)
    

class dct_8x8(nn.Module):
    """ Discrete Cosine Transformation of 8x8 blocks (after the JPEG -128 level shift)
    Input:
        image(tensor): ... x 8 x 8, e.g. batch x blocks x 8 x 8 from block_splitting
    Output:
        dcp(tensor): same shape as the input, DCT coefficients per block
    """
    def __init__(self):
        super(dct_8x8, self).__init__()
        # basis[x, u] = cos((2x + 1) * u * pi / 16), x = pixel index, u = frequency index.
        idx = np.arange(8)
        basis = np.cos((2 * idx[:, None] + 1) * idx[None, :] * np.pi / 16)
        # Separable 2-D kernel: tensor[x, y, u, v] = basis[x, u] * basis[y, v].
        tensor = (basis[:, None, :, None] * basis[None, :, None, :]).astype(np.float32)
        alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
        self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
        self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float())

    def forward(self, image):
        image = image - 128
        # Contract the two trailing pixel axes with tensor's (x, y) axes.
        # (Upstream's discarded `result.view(image.shape)` was a no-op and is omitted.)
        return self.scale * torch.tensordot(image, self.tensor, dims=2)


class y_quantize(nn.Module):
    """ JPEG quantization with the luminance table
    Input:
        image(tensor): DCT coefficients, ... x 8 x 8 (the table broadcasts over leading axes)
        factor(float): table scale from quality_to_factor
        rounding(function): rounding function to use
    Output:
        image(tensor): quantized coefficients, same shape as the input
    """
    def __init__(self):
        super(y_quantize, self).__init__()
        self.y_table = y_table

    def forward(self, image, factor, rounding):  # Modification to DiffJPEG
        step = self.y_table * factor
        return rounding(image.float() / step)


class c_quantize(nn.Module):
    """ JPEG quantization with the chrominance table
    Input:
        image(tensor): DCT coefficients, ... x 8 x 8 (the table broadcasts over leading axes)
        factor(float): table scale from quality_to_factor
        rounding(function): rounding function to use
    Output:
        image(tensor): quantized coefficients, same shape as the input
    """
    def __init__(self):
        super(c_quantize, self).__init__()
        self.c_table = c_table

    def forward(self, image, factor, rounding):  # Modification to DiffJPEG
        step = self.c_table * factor
        return rounding(image.float() / step)


class compress_jpeg(nn.Module):
    """ Full JPEG compression algorithm, minus entropy coding
    Input:
        image(tensor): batch x 3 x height x width (scaled by 255 internally)
        factor(float): quantization-table scale from quality_to_factor
        rounding(function): rounding function to use
    Output:
        (y, cb, cr): tuple of quantized DCT coefficient tensors, each batch x blocks x 8 x 8
    """
    def __init__(self, subsampling=True, color_space_conversions=True):  # Modification to DiffJPEG
        super(compress_jpeg, self).__init__()
        colour_stage = [rgb_to_ycbcr_jpeg(color_space_conversions), chroma_subsampling(subsampling)]
        self.l1 = nn.Sequential(*colour_stage)
        block_stage = [block_splitting(), dct_8x8()]
        self.l2 = nn.Sequential(*block_stage)
        self.c_quantize = c_quantize()
        self.y_quantize = y_quantize()
        self.subsampling = subsampling  # Modification to DiffJPEG

    def forward(self, image, factor, rounding):
        y, cb, cr = self.l1(image * 255)
        quantized = []
        for name, plane in (('y', y), ('cb', cb), ('cr', cr)):
            coefficients = self.l2(plane)
            # Modification to DiffJPEG: the chroma table is used only when subsampling is on;
            # otherwise all three planes use the luma table.
            # NOTE(review): real JPEG encoders quantize Cb/Cr with the chroma table even at 4:4:4.
            use_chroma_table = self.subsampling and name != 'y'
            quantizer = self.c_quantize if use_chroma_table else self.y_quantize
            quantized.append(quantizer(coefficients, factor, rounding))
        return tuple(quantized)

In [None]:
# From decompression.py of DiffJPEG

class y_dequantize(nn.Module):
    """ Undo luminance quantization
    Inputs:
        image(tensor): quantized coefficients, ... x 8 x 8
        factor(float): the same table scale used for quantization
    Outputs:
        image(tensor): rescaled coefficients, same shape as the input
    """
    def __init__(self):
        super(y_dequantize, self).__init__()
        self.y_table = y_table

    def forward(self, image, factor):  # Modification to DiffJPEG
        step = self.y_table * factor
        return image * step


class c_dequantize(nn.Module):
    """ Undo chrominance quantization
    Inputs:
        image(tensor): quantized coefficients, ... x 8 x 8
        factor(float): the same table scale used for quantization
    Outputs:
        image(tensor): rescaled coefficients, same shape as the input
    """
    def __init__(self):
        super(c_dequantize, self).__init__()
        self.c_table = c_table

    def forward(self, image, factor):  # Modification to DiffJPEG
        step = self.c_table * factor
        return image * step


class idct_8x8(nn.Module):
    """ Inverse discrete Cosine Transformation of 8x8 blocks (adds back the 128 level shift)
    Input:
        dcp(tensor): ... x 8 x 8 DCT coefficients, e.g. batch x blocks x 8 x 8
    Output:
        image(tensor): same shape as the input, pixel values
    """
    def __init__(self):
        super(idct_8x8, self).__init__()
        alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
        self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())
        # basis[x, u] = cos((2u + 1) * x * pi / 16), x = frequency index, u = pixel index.
        idx = np.arange(8)
        basis = np.cos((2 * idx[None, :] + 1) * idx[:, None] * np.pi / 16)
        # Separable 2-D kernel: tensor[x, y, u, v] = basis[x, u] * basis[y, v].
        tensor = (basis[:, None, :, None] * basis[None, :, None, :]).astype(np.float32)
        self.tensor = nn.Parameter(torch.from_numpy(tensor).float())

    def forward(self, image):
        weighted = image * self.alpha
        # (Upstream's discarded `result.view(image.shape)` was a no-op and is omitted.)
        return 0.25 * torch.tensordot(weighted, self.tensor, dims=2) + 128


class block_merging(nn.Module):
    """ Reassemble 8x8 patches (row-major block order) into an image plane
    Inputs:
        patches(tensor): batch x height*width/64 x 8 x 8
        height(int)
        width(int)
    Output:
        image(tensor): batch x height x width
    """
    def __init__(self):
        super(block_merging, self).__init__()

    def forward(self, patches, height, width):
        k = 8
        n = patches.shape[0]
        grid = patches.view(n, height // k, width // k, k, k)
        # (B, rows, cols, k, k) -> (B, rows, k, cols, k): pixel rows become adjacent.
        interleaved = grid.permute(0, 1, 3, 2, 4).contiguous()
        return interleaved.view(n, height, width)


class chroma_upsampling(nn.Module):
    """ Upsample chroma planes (when subsampled) and stack Y, Cb, Cr channels-last
    Input:
        y(tensor): batch x height x width
        cb(tensor): batch x height/2 x width/2 (or full size if subsampling is off)
        cr(tensor): batch x height/2 x width/2 (or full size if subsampling is off)
    Output:
        image(tensor): batch x height x width x 3
    """
    def __init__(self, subsampling):
        super(chroma_upsampling, self).__init__()
        self.subsampling = subsampling

    @staticmethod
    def _upsample2x(plane):
        # Modification to DiffJPEG: bilinear instead of pixel repetition. interpolate wants
        # (N, C, H, W), so the batch axis is treated as channels and a dummy N=1 is added.
        expanded = nn.functional.interpolate(plane.unsqueeze(0), scale_factor=2.0,
                                             mode='bilinear', align_corners=False)
        return expanded.squeeze(0)

    def forward(self, y, cb, cr):
        if self.subsampling:  # Modification to DiffJPEG
            cb = self._upsample2x(cb)
            cr = self._upsample2x(cr)
        return torch.stack((y, cb, cr), dim=3)


class ycbcr_to_rgb_jpeg(nn.Module):
    """ Converts (JFIF, full-range) YCbCr to RGB and moves channels first
    Input:
        image(tensor): batch x height x width x 3
    Output:
        result(tensor): batch x 3 x height x width
    With color_space_conversions=False the input is only permuted to channels-first.
    """
    def __init__(self, color_space_conversions=True):  # Modification to DiffJPEG
        super(ycbcr_to_rgb_jpeg, self).__init__()

        matrix = np.array(
            [[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]],
            dtype=np.float32).T
        # Re-centre Cb/Cr on zero before applying the matrix.
        self.shift = nn.Parameter(torch.tensor([0, -128., -128.]))
        self.matrix = nn.Parameter(torch.from_numpy(matrix))

        self.color_space_conversions = color_space_conversions  # Modification to DiffJPEG

    def forward(self, image):
        if self.color_space_conversions:  # Modification to DiffJPEG
            # (Upstream's discarded `result.view(image.shape)` was a no-op and is omitted.)
            result = torch.tensordot(image + self.shift, self.matrix, dims=1)
        else:
            result = image  # Modification to DiffJPEG

        return result.permute(0, 3, 1, 2)


class decompress_jpeg(nn.Module):
    """ Full JPEG decompression algorithm, minus entropy decoding
    Input:
        y, cb, cr(tensor): quantized DCT coefficients, each batch x blocks x 8 x 8
        factor(float): the quantization-table scale used for compression
        rounding(function): rounding applied to the final 8-bit pixel values
    Output:
        image(tensor): batch x 3 x height x width, in [0, 1]
    """
    def __init__(self, height, width, subsampling=True, color_space_conversions=True):  # Modification to DiffJPEG
        super(decompress_jpeg, self).__init__()
        self.c_dequantize = c_dequantize()
        self.y_dequantize = y_dequantize()
        self.idct = idct_8x8()
        self.merging = block_merging()
        self.chroma = chroma_upsampling(subsampling)  # Modification to DiffJPEG
        self.colors = ycbcr_to_rgb_jpeg(color_space_conversions)  # Modification to DiffJPEG

        self.height, self.width = height, width
        self.subsampling = subsampling

    def forward(self, y, cb, cr, factor, rounding):
        planes = []
        for name, coefficients in (('y', y), ('cb', cb), ('cr', cr)):
            # Modification to DiffJPEG: mirror compress_jpeg -- chroma table and half-size
            # planes only when subsampling is on; otherwise luma table at full size.
            if self.subsampling and name != 'y':
                dequantized = self.c_dequantize(coefficients, factor)
                plane_height, plane_width = int(self.height / 2), int(self.width / 2)
            else:
                dequantized = self.y_dequantize(coefficients, factor)
                plane_height, plane_width = self.height, self.width
            planes.append(self.merging(self.idct(dequantized), plane_height, plane_width))

        image = self.colors(self.chroma(*planes))

        # Clip to the valid 8-bit range, then round like a real decoder (Modification to DiffJPEG).
        image = torch.min(255 * torch.ones_like(image), torch.max(torch.zeros_like(image), image))
        return rounding(image) / 255

In [None]:
# DiffJPEG with some modifications:
#     -subsampling: True = 4:2:0 chroma subsampling, False = 4:4:4 chroma sampling (i.e., no chroma subsampling)
#     -color_space_conversions: True = RGB->YCbCr and YCbCr->RGB color space conversions (typical JPEG behavior), False = Input image expected to be in YCbCr, so no color space conversions performed
#     -forward() takes the JPEG quality and whether differentiable rounding is enabled, so both can change per call
#     -Under the hood, chroma upsampling uses bilinear interpolation (much closer to a typical JPEG implementation).
class DiffJPEG(nn.Module):
    def __init__(self, height, width, subsampling=True, color_space_conversions=True):
        ''' Initialize the DiffJPEG layer
        Inputs:
            height(int): Original image height
            width(int): Original image width
            subsampling(bool): 4:2:0 chroma subsampling if true, 4:4:4 otherwise
            color_space_conversions(bool): convert RGB<->YCbCr around the codec if true
        '''
        super(DiffJPEG, self).__init__()
        self.compress = compress_jpeg(subsampling=subsampling, color_space_conversions=color_space_conversions)
        self.decompress = decompress_jpeg(height, width, subsampling=subsampling, color_space_conversions=color_space_conversions)

    def forward(self, x, quality, differentiable):
        ''' Simulate a JPEG round trip.
        Inputs:
            x(tensor): batch x 3 x height x width image in [0, 1]
            quality(float): JPEG quality (see quality_to_factor)
            differentiable(bool): use diff_round instead of torch.round
        Output:
            (recovered image, y, cb, cr) where y/cb/cr are the quantized DCT coefficients
        '''
        rounding = diff_round if differentiable else torch.round
        factor = quality_to_factor(quality)
        y, cb, cr = self.compress(x, factor, rounding)
        return self.decompress(y, cb, cr, factor, rounding), y, cb, cr

## FlyingThings3D Dataset

In [None]:
# torchvision's PFM reader bound to keep only the first channel of each file.
# NOTE(review): _read_pfm is a private torchvision helper (leading underscore), so its
# signature may change between torchvision versions.
_read_pfm_file = partial(_read_pfm, slice_channels=1)  # For reading PFM (Portable FloatMap) files from FlyingThings3D dataset

# Adapted from torchvision's stereo-matching dataset base class, trimmed down to loading
# FlyingThings3D disparity maps and returning them as inverse-disparity "depth" maps.
class FlyingThingsDataset(ABC, VisionDataset):
    """FlyingThings3D disparity maps served as single-channel depth maps.

    Each item is one view (left or right, chosen at random) of a stereo pair's
    disparity map. It is optionally randomly cropped and transformed, then
    converted to depth as ``1 / disparity``.
    """

    _has_built_in_disparity_mask = False

    def __init__(self, root: str, train_or_test: str, random_crops, transforms: Optional[Callable] = None) -> None:
        """
        Args:
            root(str): Root directory that contains the ``FlyingThings3D`` folder.
            train_or_test(str): Split sub-directory under ``FlyingThings3D/disparity``.
            random_crops(bool): If true, every item is randomly cropped (see ``__getitem__``).
            transforms(callable, optional): Applied to the (possibly cropped) disparity
                tensor before it is converted to depth; a natural place for data augmentation.

        Raises:
            FileNotFoundError: If no left or right disparity files are found.
            ValueError: If the numbers of left and right disparity files differ.
        """
        super().__init__(root=root)
        self.transforms = transforms  # You can load up the transforms argument with your data augmentation

        self._images = []  # type: ignore  # unused here; kept from the torchvision base-class layout
        self._disparities = []  # type: ignore

        self.train_or_test = train_or_test
        self.random_crops = random_crops

        split_dir = Path(root) / "FlyingThings3D/disparity" / train_or_test
        left_disparity_pattern = str(split_dir / "*/*/left/*.pfm")
        right_disparity_pattern = str(split_dir / "*/*/right/*.pfm")

        self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern)

    def _read_img(self, file_path: Union[str, Path]) -> Image.Image:
        """Open an image file as an RGB PIL image (converting other modes)."""
        img = Image.open(file_path)
        if img.mode != "RGB":
            img = img.convert("RGB")
        return img

    def _scan_pairs(self, paths_left_pattern: str, paths_right_pattern: Optional[str] = None,
    ) -> List[Tuple[str, Optional[str]]]:
        """Glob left/right files and pair them up by sorted order.

        Pairing by sort order assumes the left and right file lists correspond
        one-to-one. That presumably holds for FlyingThings3D, where the paths differ
        only in the ``left``/``right`` directory name.
        Without a right pattern, every left path is paired with ``None``.
        """
        left_paths = sorted(glob(paths_left_pattern))

        right_paths: List[Optional[str]]
        if paths_right_pattern:
            right_paths = sorted(glob(paths_right_pattern))
        else:
            right_paths = [None] * len(left_paths)

        if not left_paths:
            raise FileNotFoundError(f"Could not find any files matching the patterns: {paths_left_pattern}")

        if not right_paths:
            raise FileNotFoundError(f"Could not find any files matching the patterns: {paths_right_pattern}")

        if len(left_paths) != len(right_paths):
            raise ValueError(
                f"Found {len(left_paths)} left files but {len(right_paths)} right files using:\n "
                f"left pattern: {paths_left_pattern}\n"
                f"right pattern: {paths_right_pattern}\n"
            )

        return list(zip(left_paths, right_paths))

    def _read_disparity(self, file_path: str) -> np.ndarray:
        """Read a PFM disparity map and return its absolute value.

        The array presumably has shape (1, H, W), as produced by ``_read_pfm_file``.
        """
        disparity_map = _read_pfm_file(file_path)
        return np.abs(disparity_map)  # Ensure that the disparity is positive

    def __getitem__(self, index: int) -> torch.Tensor:
        """Return the depth map (``1 / disparity``) for the stereo pair at ``index``.

        One view of the pair is chosen with probability 0.5 each. When
        ``random_crops`` is set, a random crop is taken whose height is 30-100% and
        whose width is 16.875-100% of the map. For 540x960 FlyingThings3D frames,
        that gives a minimum crop of 162x162. ``transforms`` (if any) is then applied.
        Non-positive disparities are replaced by 1e-6 so the depth stays finite.
        """
        left_path, right_path = self._disparities[index]
        dsp_maps = torch.from_numpy(self._read_disparity(left_path if random.random() < 0.5 else right_path))

        if self.random_crops:
            # Derive the crop from the actual map size (identical to the former
            # hard-coded 540x960 for FlyingThings3D, but also safe for other sizes).
            full_height, full_width = dsp_maps.shape[-2:]
            x_crop_ratio = random.uniform(0.16875, 1.0)
            y_crop_ratio = random.uniform(0.3, 1.0)
            crop_size = (int(full_height * y_crop_ratio), int(full_width * x_crop_ratio))

            # Random crop
            i, j, h, w = T.RandomCrop.get_params(dsp_maps, output_size=crop_size)
            dsp_maps = T.functional.crop(dsp_maps, i, j, h, w)

        if self.transforms is not None:
            dsp_maps = self.transforms(dsp_maps)

        # Zero disparities would give infinite depth (and NaNs after depth normalization).
        dsp_maps[dsp_maps <= 0] = 1e-6

        return 1 / dsp_maps

    def __len__(self) -> int:
        return len(self._disparities)

### Neural Depth Autoencoder

In [None]:
class NeuralDepthAutoencoder(nn.Module):
    """Per-pixel autoencoder: (normalized depth, mask) -> 3-channel 8-bit image -> (depth, mask).

    Encoder and decoder are stacks of 1x1 convolutions, so each pixel is coded
    independently. Between them, the encoding is quantized to 256 levels. It is
    then passed through DiffJPEG, a real (PIL) JPEG round trip, or nothing
    ("lossless"), depending on ``image_compression_type``.
    """
    def __init__(self, config, height, width, subsampling=True, color_space_conversions=True, device=torch.device("cuda")):
        super(NeuralDepthAutoencoder, self).__init__()
        self.device = device

        # Save a copy of relevant config variables to the model object
        self.update_config(config)

        # ENCODER: 2 input channels (depth, mask) -> 3 "colour" channels
        self.encoder = self._conv1x1_stack(2, self.neural_encoder_width, 3, self.encoder_activation_function)

        # DECODER: 3 colour channels -> 2 output channels (depth, mask)
        self.decoder = self._conv1x1_stack(3, self.neural_decoder_width, 2, self.decoder_activation_function)

        self.JPEG = DiffJPEG(height, width, subsampling, color_space_conversions)

        # Weight Freezing
        self.freeze_weights(config)
        self.JPEG.requires_grad_(False)  # VERY IMPORTANT TO INCLUDE: DiffJPEG includes some parameters that could be "trained", which would give the illusion of good results.

    @staticmethod
    def _conv1x1_stack(in_channels, width, out_channels, activation):
        """Build the 8-layer 1x1-conv stack shared by encoder and decoder.

        Widths: in -> w -> w/2 -> w/4 -> w/4 -> w/4 -> w/4 -> w/32 -> out, with
        ``activation`` (the same module instance) between consecutive convs and
        none after the last. The layer order matches the original hand-written
        Sequential, so state_dict keys (encoder.0.weight, ...) are unchanged.
        """
        channels = [in_channels, width, width // 2, width // 4, width // 4,
                    width // 4, width // 4, width // 32, out_channels]
        layers = []
        for i, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:])):
            if i > 0:
                layers.append(activation)
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=1, stride=1, padding=0))
        return nn.Sequential(*layers)

    def update_config(self, config):
        """Copy the config fields this model reads onto the instance."""
        self.neural_encoder_width = config.neural_encoder_width
        self.encoder_activation_function = config.encoder_activation_function
        self.final_encoder_activation_function = config.final_encoder_activation_function

        self.neural_decoder_width = config.neural_decoder_width
        self.decoder_activation_function = config.decoder_activation_function

        self.depth_loss_enabled = config.depth_loss_enabled
        self.mask_loss_enabled = config.mask_loss_enabled

        self.freeze_encoder = config.freeze_encoder
        self.freeze_decoder = config.freeze_decoder

    def freeze_weights(self, config):
        """Refresh config fields, then (un)freeze encoder/decoder per freeze_encoder / freeze_decoder."""
        self.update_config(config)

        self.encoder.requires_grad_(not self.freeze_encoder)
        self.decoder.requires_grad_(not self.freeze_decoder)

    def _apply_final_encoder_activation(self, encoding):
        """Map raw encoder output to the configured final activation (SINE/SIGMOID give [0, 1])."""
        if self.final_encoder_activation_function is FinalEncoderActivationFunction.SINE:
            return (torch.sin(2 * torch.pi * encoding) + 1) / 2  # Sine activation with [0,1]-normalized range
        if self.final_encoder_activation_function is FinalEncoderActivationFunction.SIGMOID:
            return torch.sigmoid(encoding)
        if self.final_encoder_activation_function is FinalEncoderActivationFunction.NONE:
            # Excluding this without enforcing a [0,1] output range elsewhere in the model will result in incompatibility with DiffJPEG
            return encoding
        raise ValueError("Invalid final encoder activation function selected!")

    def _pil_jpeg_round_trip(self, encoding, quality, height, width):
        """Compress ``encoding`` with PIL's real JPEG codec in memory and decode it back.

        Only a batch of one image is supported: the batch axis is squeezed away
        for PIL and restored as 1 afterwards. Returns (decoded tensor of shape
        [1, 3, height, width] on ``encoding``'s device, compressed size in bytes).
        """
        img = toPILImage(encoding.squeeze())
        jpeg_in_memory = BytesIO()
        # PIL subsampling codes: 2 = 4:2:0, 0 = 4:4:4 (no chroma subsampling).
        # Read the setting from the DiffJPEG layer so both compression paths agree.
        pil_subsampling = 2 if self.JPEG.compress.subsampling else 0
        img.save(jpeg_in_memory, quality=int(round(quality)), format="jpeg", subsampling=pil_subsampling)
        jpeg_size = jpeg_in_memory.getbuffer().nbytes  # JPEG size in bytes
        decoded = toTensor(Image.open(jpeg_in_memory))
        return torch.reshape(decoded, [1, 3, height, width]).to(encoding.device), jpeg_size

    def forward(self, x, quality, differentiable_rounding=True, image_compression_type=ImageCompressionType.DIFF_JPEG, perform_encoding=True, perform_decoding=True):
        """Encode, compress and/or decode a batch.

        Args:
            x: With ``perform_encoding``, the 2-channel encoder input (normalized depth, mask)
                stacked on dim 1. Otherwise an already-compressed 3-channel image that is
                passed straight to the decoder.
            quality: JPEG quality for the DiffJPEG / real JPEG paths (rounded to an int for PIL).
            differentiable_rounding: Use ``diff_round`` (training) instead of ``torch.round``
                (testing/validation), both for 8-bit quantization and inside DiffJPEG.
            image_compression_type: Which compression sits between encoder and decoder.
            perform_encoding / perform_decoding: Run only part of the model, e.g. on video
                frames during testing.

        Returns:
            (recovered_depth, recovered_mask, encoding, y, cb, cr, JPEG_size). Entries are None
            when the producing stage is skipped: y/cb/cr come only from DiffJPEG, JPEG_size
            (bytes) only from the real JPEG path, and the decoder outputs only when decoding runs
            and a depth or mask loss is enabled.

        Raises:
            ValueError: On an unknown compression type or final encoder activation.
        """
        y, cb, cr = None, None, None
        JPEG_size = None
        if perform_encoding:
            # Depth + Mask layers => Color Pixels
            encoding = self._apply_final_encoder_activation(self.encoder(x))

            # Quantize to 256 levels like an 8-bit image.
            # Differentiable rounding should be used during training, but not during testing or validation.
            if differentiable_rounding:
                encoding = diff_round(encoding * 255) / 255
            else:
                encoding = torch.round(encoding * 255) / 255

            # JPEG or DiffJPEG compression (if enabled, otherwise "lossless image compression")
            # The actual JPEG codec option should be used during testing.
            if image_compression_type == ImageCompressionType.DIFF_JPEG:
                # JPEG_size stays None: the lossless (entropy-coding) parts of JPEG are not simulated.
                JPEG_compressed, y, cb, cr = self.JPEG(encoding, quality, differentiable_rounding)
            elif image_compression_type == ImageCompressionType.JPEG:
                JPEG_compressed, JPEG_size = self._pil_jpeg_round_trip(encoding, quality, x.shape[2], x.shape[3])
            elif image_compression_type == ImageCompressionType.LOSSLESS:
                JPEG_compressed = encoding  # "Lossless compression".  A bit of a misnomer in terms of variable names if JPEG compression is not enabled.
            else:
                raise ValueError(f"Invalid image compression type selected! Got: {image_compression_type!r}")
        else:
            JPEG_compressed = x
            encoding = None

        if perform_decoding and (self.depth_loss_enabled or self.mask_loss_enabled):
            recovered_depth_and_mask = self.decoder(JPEG_compressed)
            recovered_depth = recovered_depth_and_mask[:, 0, :, :].reshape([x.shape[0], 1, x.shape[2], x.shape[3]])
            recovered_mask = recovered_depth_and_mask[:, 1, :, :].reshape([x.shape[0], 1, x.shape[2], x.shape[3]])
        else:
            recovered_depth = None
            recovered_mask = None

        return recovered_depth, recovered_mask, encoding, y, cb, cr, JPEG_size

In [None]:
# I didn't use a rate-distortion loss, so I periodically calculated some normalization factors to approximately equalize losses across different JPEG quality levels.
# I elected to use differentiable rounding here, but it wasn't a very important choice.  Not really relevant to how you should do validation.
def _calibration_losses_at_quality(config, calibration_depth_map, calibration_mask, quality, image_compression_type, depth_loss_function, mask_type, mask_loss_function, error_threshold_for_masking):
    """Return ``(depth_L1, depth_RMSE, depth_loss, mask_loss)`` for the calibration sample at one JPEG quality.

    Runs the module-level ``net`` (with differentiable rounding enabled) and scores the result with the
    module-level ``L1_criterion`` / ``MSE_criterion``. The three depth entries are None when
    ``config.depth_loss_enabled`` is False; ``mask_loss`` is None when ``config.mask_loss_enabled`` is False.

    Raises:
        ValueError: if the mask loss is enabled and ``mask_type`` is not a supported ``MaskType``.
    """
    recovered_depth, recovered_mask, _, _, _, _, _ = net(torch.cat((calibration_depth_map, calibration_mask), 1), quality, True, image_compression_type)

    depth_L1_loss = depth_RMSE_loss = depth_loss = mask_loss = None

    if config.depth_loss_enabled:
        depth_L1_loss = L1_criterion(recovered_depth, calibration_depth_map)
        depth_RMSE_loss = math.sqrt(MSE_criterion(recovered_depth, calibration_depth_map))
        depth_loss = depth_loss_function(recovered_depth, calibration_depth_map)

    if config.mask_loss_enabled:
        if mask_type == MaskType.BINARY:
            mask_target = calibration_mask
        elif mask_type == MaskType.ERROR_THRESHOLDED_BINARY:
            # Drop foreground pixels whose recovered depth error exceeds the threshold
            mask_target = calibration_mask * (torch.abs(calibration_depth_map - recovered_depth * calibration_mask) < error_threshold_for_masking)
        elif mask_type == MaskType.ERROR_MAP:
            # Soft error map: tanh(30 * |error|) on the foreground, 1 on the background
            # mask_target = torch.abs(calibration_depth_map - recovered_depth) * calibration_mask
            mask_target = torch.nn.Tanh()(torch.abs(calibration_depth_map - recovered_depth) * 30) * calibration_mask + (1 - calibration_mask)
        else:
            raise ValueError("Invalid mask type selected!")
        mask_loss = mask_loss_function(recovered_mask, mask_target)

    return depth_L1_loss, depth_RMSE_loss, depth_loss, mask_loss


def calculate_error_normalization_by_JPEG_quality(config, calibration_depth_map, calibration_mask, image_compression_type, subsampling, low_JPEG_quality, high_JPEG_quality, depth_loss_function, mask_type, mask_loss_function, error_threshold_for_masking):
    """Compute loss normalization factors that roughly equalize losses across JPEG qualities.

    The calibration sample is compressed at ``low_JPEG_quality`` and at ``high_JPEG_quality``. Each
    factor is ``(loss_low - loss_high) / loss_high``, so weighting a loss by ``1 + factor`` at the high
    quality scales the high-quality loss up to the low-quality one. ``subsampling`` is accepted for
    interface compatibility but is not used.

    Returns:
        ``(depth_RMSE_factor, depth_L1_factor, depth_loss_factor, mask_factor,
        avg_depth_L1_loss, avg_depth_RMSE_loss, avg_mask_loss)``. Each average is the mean of the low-
        and high-quality losses.

        - When ``config.image_compression_type`` is LOSSLESS, or neither loss is enabled, every factor
          is 1 and every average is ``torch.nan``.
        - Otherwise, if the depth loss is disabled, the depth factors are 1 and the depth averages are
          ``torch.nan``.
        - If the mask loss is disabled, the mask factor and the mask average are ``torch.nan``.

    Raises:
        ValueError: if the mask loss is enabled and ``mask_type`` is not a supported ``MaskType``.
    """
    with torch.no_grad():
        # Only meaningful when JPEG compression is enabled and a depth or mask loss is being trained
        if config.image_compression_type == ImageCompressionType.LOSSLESS or not (config.depth_loss_enabled or config.mask_loss_enabled):
            return 1, 1, 1, 1, torch.nan, torch.nan, torch.nan

        L1_low, RMSE_low, depth_loss_low, mask_loss_low = _calibration_losses_at_quality(
            config, calibration_depth_map, calibration_mask, low_JPEG_quality, image_compression_type,
            depth_loss_function, mask_type, mask_loss_function, error_threshold_for_masking)
        L1_high, RMSE_high, depth_loss_high, mask_loss_high = _calibration_losses_at_quality(
            config, calibration_depth_map, calibration_mask, high_JPEG_quality, image_compression_type,
            depth_loss_function, mask_type, mask_loss_function, error_threshold_for_masking)

        if config.depth_loss_enabled:
            depth_RMSE_JPEG_quality_normalization_factor = (RMSE_low - RMSE_high) / RMSE_high
            depth_L1_JPEG_quality_normalization_factor = (L1_low - L1_high) / L1_high
            depth_loss_JPEG_quality_normalization_factor = (depth_loss_low - depth_loss_high) / depth_loss_high
            avg_depth_L1_loss = (L1_low + L1_high) / 2
            avg_depth_RMSE_loss = (RMSE_low + RMSE_high) / 2
        else:
            depth_loss_JPEG_quality_normalization_factor, depth_RMSE_JPEG_quality_normalization_factor, depth_L1_JPEG_quality_normalization_factor, avg_depth_L1_loss, avg_depth_RMSE_loss = 1, 1, 1, torch.nan, torch.nan

        if config.mask_loss_enabled:
            mask_JPEG_quality_normalization_factor = (mask_loss_low - mask_loss_high) / mask_loss_high
            avg_mask_loss = (mask_loss_low + mask_loss_high) / 2
        else:
            mask_JPEG_quality_normalization_factor, avg_mask_loss = torch.nan, torch.nan

    return depth_RMSE_JPEG_quality_normalization_factor, depth_L1_JPEG_quality_normalization_factor, depth_loss_JPEG_quality_normalization_factor, mask_JPEG_quality_normalization_factor, avg_depth_L1_loss, avg_depth_RMSE_loss, avg_mask_loss

In [None]:
# I was really focused on ~equal performance across the whole normalized depth range, so this was effectively my "validation loss".
# Note that I'm not using differentiable rounding since there is no need for backpropagation.
def validate_depth_gradient_image_performance(depth_gradient_image, depth_gradient_image_and_ones_mask, depth_gradient_image_and_zeros_mask, target_validation_JPEG_quality, subsampling, gradient_validation_image_compression_type, save_RGB_encodings=False):
    """Score the global ``net`` on the depth-gradient ("tilted plane") validation image.

    Reads the module-level ``config`` and ``net``. Runs without gradients and with differentiable
    rounding disabled. ``subsampling`` is accepted for interface compatibility but is not used.

    Args:
        depth_gradient_image: ground-truth depth that the decoded depth is compared against.
        depth_gradient_image_and_ones_mask: network input with an all-foreground mask channel.
        depth_gradient_image_and_zeros_mask: network input with an all-background mask channel.
        target_validation_JPEG_quality: JPEG quality passed to ``net``.
        gradient_validation_image_compression_type: compression type passed to ``net``.
        save_RGB_encodings: when True, also produce the encoding of the all-background input.

    Returns:
        ``(mean L1 error, max absolute error, encoding of the all-foreground input,
        encoding of the all-background input)``.

        - When neither the depth nor the mask loss is enabled, both errors are ``torch.nan`` and the
          all-foreground encoding is None.
        - The all-background encoding is None unless ``save_RGB_encodings`` is True.
    """
    with torch.no_grad():
        if config.depth_loss_enabled or config.mask_loss_enabled:
            # neural_encoding_for_ones_mask holds the RGB encodings of every depth in the gradient image
            # (10,000 values in [0, 1] for the 100x100 image built in this notebook); saved as an image,
            # it can be plotted in 3D to inspect the learned encoding.
            decoded_depth, _, neural_encoding_for_ones_mask, _, _, _, _ = net(depth_gradient_image_and_ones_mask, target_validation_JPEG_quality, False, gradient_validation_image_compression_type)

            depth_range_gradient_L1_loss = nn.L1Loss()(depth_gradient_image, decoded_depth)
            max_depth_range_gradient_L1_loss = torch.max(torch.abs(depth_gradient_image - decoded_depth))
        else:
            depth_range_gradient_L1_loss = torch.nan
            max_depth_range_gradient_L1_loss = torch.nan
            # Bind the encoding on this path too, so the return below never hits an unassigned name
            neural_encoding_for_ones_mask = None

        if save_RGB_encodings:
            # The encoding of an all-background input should converge to a mostly uniform color, since a
            # single color is likely selected to denote "background" pixels.
            # NOTE(review): the trailing (True, False) arguments presumably enable encoding and skip
            # decoding; confirm against the network's forward signature.
            _, _, neural_encoding_for_zeros_mask, _, _, _, _ = net(depth_gradient_image_and_zeros_mask, target_validation_JPEG_quality, False, gradient_validation_image_compression_type, True, False)
        else:
            neural_encoding_for_zeros_mask = None

        return depth_range_gradient_L1_loss, max_depth_range_gradient_L1_loss, neural_encoding_for_ones_mask, neural_encoding_for_zeros_mask

In [None]:
# Useful if you want to check on the current learning rate (due to use of learning rate schedulers)
def get_lr(optimizer):
    """Return the learning rate of ``optimizer``'s first parameter group, or None if it has no groups."""
    return next((group['lr'] for group in optimizer.param_groups), None)

In [None]:
# This is the main model inference function, which is called from the training loop.
def inference_and_loss_calculation(config, depth_maps, ground_truth_mask, quality, depth_RMSE_JPEG_quality_normalization_factor, depth_L1_JPEG_quality_normalization_factor, depth_loss_JPEG_quality_normalization_factor, mask_JPEG_quality_normalization_factor):
    """Run the global ``net`` on one training batch and compute its JPEG-quality-reweighted losses.

    Each enabled loss is scaled by ``1 + factor * t``, where:

    - ``t`` is the position of ``quality`` within ``[config.low_JPEG_quality, config.high_JPEG_quality]``
      (0 at the low end, 1 at the high end);
    - ``factor`` is the matching normalization factor.

    If the quality range is degenerate (low == high), losses are left unscaled instead of dividing by zero.

    Reads the module-level ``net``, ``MSE_criterion`` and ``L1_criterion``.

    Returns:
        ``(combined_loss, mask_loss, reweighted_depth_RMSE_loss, reweighted_depth_L1_loss, reweighted_mask_loss)``.

        - ``combined_loss`` is the weighted sum used for backpropagation (the int 0 if no loss is enabled).
        - Disabled losses are reported as ``torch.nan``.
        - The reweighted RMSE and L1 values are for logging only.

    Raises:
        ValueError: if the mask loss is enabled and ``config.mask_type`` is not a supported ``MaskType``.
    """
    # Inference (third argument True: differentiable rounding, since this result is backpropagated)
    recovered_depth, recovered_mask, _, _, _, _, _ = net(torch.cat((depth_maps, ground_truth_mask), 1), quality, True, config.image_compression_type)

    quality_span = config.high_JPEG_quality - config.low_JPEG_quality

    def reweight(loss, normalization_factor):
        # Normalization of losses across JPEG qualities; a zero-width quality range leaves the loss unscaled
        if quality_span == 0:
            return loss
        return loss * (1 + normalization_factor * (quality - config.low_JPEG_quality) / quality_span)

    combined_loss = 0

    # Depth loss (only ground-truth foreground pixels contribute)
    if config.depth_loss_enabled:
        masked_recovered_depth = recovered_depth * ground_truth_mask
        masked_target_depth = depth_maps * ground_truth_mask
        depth_loss = config.depth_loss_function(masked_recovered_depth, masked_target_depth)
        depth_RMSE_loss = torch.sqrt(MSE_criterion(masked_recovered_depth, masked_target_depth))
        depth_L1_loss = L1_criterion(masked_recovered_depth, masked_target_depth)

        reweighted_depth_loss = reweight(depth_loss, depth_loss_JPEG_quality_normalization_factor)
        reweighted_depth_RMSE_loss = reweight(depth_RMSE_loss, depth_RMSE_JPEG_quality_normalization_factor)  # Just for printing losses (not used for backpropagation)
        reweighted_depth_L1_loss = reweight(depth_L1_loss, depth_L1_JPEG_quality_normalization_factor)  # Just for printing losses (not used for backpropagation)

        # Add reweighted depth loss to combined loss (used for backpropagation)
        combined_loss += config.depth_loss_weighting * reweighted_depth_loss
    else:
        reweighted_depth_RMSE_loss = torch.nan
        reweighted_depth_L1_loss = torch.nan

    # Mask loss
    mask_loss = torch.nan
    if config.mask_loss_enabled:
        if config.mask_type == MaskType.BINARY:
            mask_target = ground_truth_mask.float()
        elif config.mask_type == MaskType.ERROR_THRESHOLDED_BINARY:
            # Remove depth entries with recovered depth errors above the specified error threshold from the ground truth mask
            depth_error = torch.abs(depth_maps * ground_truth_mask - recovered_depth * ground_truth_mask)
            mask_target = (ground_truth_mask.float() * (depth_error < config.error_threshold_for_masking)).detach()
        elif config.mask_type == MaskType.ERROR_MAP:
            # Soft error map: tanh(30 * |error|) on the foreground, 1 on the background.
            # Detached, like the thresholded target above, so the mask loss cannot push gradients into
            # recovered_depth through its own target.
            depth_error = torch.abs(depth_maps * ground_truth_mask - recovered_depth * ground_truth_mask)
            mask_target = (torch.nn.Tanh()(depth_error * 30) + (1 - ground_truth_mask.float())).detach()
        else:
            raise ValueError("Invalid mask type selected!")
        mask_loss = config.mask_loss_function(recovered_mask, mask_target)

        reweighted_mask_loss = reweight(mask_loss, mask_JPEG_quality_normalization_factor)

        # Add reweighted mask loss to combined loss (used for backpropagation)
        combined_loss += config.mask_loss_weighting * reweighted_mask_loss
    else:
        reweighted_mask_loss = torch.nan

    return combined_loss, mask_loss, reweighted_depth_RMSE_loss, reweighted_depth_L1_loss, reweighted_mask_loss

In [None]:
# I wanted equal performance across a normalized depth range, so this image histogram equalization was important to allow me to equally sample across [0,1]
def image_histogram_equalization(image, number_bins=1024):
    """Histogram-equalize ``image`` so that its values are spread approximately uniformly.

    Adapted from http://www.janeriksolem.net/histogram-equalization-with-python-and.html

    Args:
        image: array-like of numeric values (any shape).
        number_bins: number of histogram bins.

    Returns:
        ``(equalized, cdf)``.

        - ``equalized`` has ``image``'s shape, with values in ``[0, number_bins - 1]``.
        - ``cdf`` is the normalized cumulative distribution, of length ``number_bins`` and with last
          entry ``number_bins - 1``.
    """
    flat = np.asarray(image).ravel()

    # Get image histogram
    image_histogram, bins = np.histogram(flat, number_bins, density=True)

    # Cumulative distribution function. Any constant scale on the histogram cancels in the
    # normalization below, so no extra scaling factor is needed.
    cdf = image_histogram.cumsum()
    cdf = (number_bins - 1) * cdf / cdf[-1]

    # Use linear interpolation of the cdf (over the left bin edges) to find new pixel values
    image_equalized = np.interp(flat, bins[:-1], cdf)

    return image_equalized.reshape(np.shape(image)), cdf

In [None]:
def initialize_optimizer(config, net):
    """Build the optimizer class named by ``config.optimizer_type`` over ``net``'s parameters.

    Supported types are ``optim.Adam`` and ``optim.SGD``; SGD uses momentum 0.9 without Nesterov.
    Both use ``config.starting_lr`` and ``config.weight_decay``.

    Raises:
        ValueError: if ``config.optimizer_type`` is not one of the supported classes.
    """
    requested_type = config.optimizer_type

    if requested_type == optim.Adam:
        return optim.Adam(net.parameters(), lr=config.starting_lr, weight_decay=config.weight_decay)

    if requested_type == optim.SGD:
        return optim.SGD(net.parameters(), lr=config.starting_lr, momentum=0.9, weight_decay=config.weight_decay, nesterov=False)

    raise ValueError("Invalid optimizer type selected!")

In [None]:
def initialize_scheduler(config, optimizer):
    """Build the learning-rate scheduler class named by ``config.scheduler_type`` for ``optimizer``.

    Supported types:

    - ``CosineAnnealingWarmRestarts``: uses ``config.cosine_annealing_scheduler_period`` as ``T_0`` and
      ``config.cosine_annealing_scheduler_restart_factor`` as ``T_mult``.
    - ``ReduceLROnPlateau``: minimizes the monitored value, using ``config.lr_gamma``, ``config.patience``
      and ``config.cooldown``.

    Raises:
        ValueError: if ``config.scheduler_type`` is not one of the supported classes.
    """
    lr_schedulers = optim.lr_scheduler
    requested_type = config.scheduler_type

    if requested_type == lr_schedulers.CosineAnnealingWarmRestarts:
        return lr_schedulers.CosineAnnealingWarmRestarts(
            optimizer,
            config.cosine_annealing_scheduler_period,
            config.cosine_annealing_scheduler_restart_factor,
        )

    if requested_type == lr_schedulers.ReduceLROnPlateau:
        return lr_schedulers.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=config.lr_gamma,
            patience=config.patience,
            threshold=0.0001,
            threshold_mode='rel',
            cooldown=config.cooldown,
            min_lr=0,
            eps=1e-10,
        )

    raise ValueError("Invalid scheduler type selected!")

### Prepare Save Locations

In [None]:
# Set the naming format for checkpoint files and prepare the folder structures.
# Strip any file extension from the notebook name; the bare name is reused for the weight directory,
# the image folders, and the TensorBoard filename suffix in later cells.
notebook_name = os.path.splitext(notebook_name)[0]
save_directory_all_runs = "./weights/" + notebook_name
# NOTE(review): determine_training_run_number is defined elsewhere; it appears to choose the run number and
# any checkpoint to resume from, and may return an updated config. Confirm against its definition.
training_run_number, weight_save_directory, config, checkpoint_file_path = determine_training_run_number(config, save_directory_all_runs)
create_images_folder_structure(config, notebook_name)
print(f'weight_save_directory = {weight_save_directory}')

In [None]:
# Report whether this run starts from scratch or may resume from an existing checkpoint (labeled like the weight_save_directory print)
print(f'start_from_scratch = {config.start_from_scratch}')

## Instantiation

In [None]:
# TensorBoard logging (disabled for dry runs and debug runs).
# `writer` is always bound: on disabled runs it is None, so re-running this cell in the same kernel
# cannot leave a stale writer from an earlier run, and later references never hit an undefined name.
if config.dry_run or config.debug_mode:
    writer = None
else:
    writer = SummaryWriter(log_dir="TensorBoard_logs", filename_suffix=notebook_name)

In [None]:
# Some data augmentation: random flips, then resize to the configured training resolution
transforms = T.Compose([
                T.RandomHorizontalFlip(),
                T.RandomVerticalFlip(),
                T.Resize((config.height, config.width), interpolation=config.interpolation_mode)
                # More augmentation transforms can be added here
            ])

# Initialize training dataset
dataset = FlyingThingsDataset(config.dataset_path, "TRAIN", config.random_crops, transforms)

# drop_last=True: the training loop reshapes every batch with view(config.batch_size, 1, height, width),
# which fails on a smaller final batch when the dataset size is not a multiple of batch_size.
train_loader = torch.utils.data.DataLoader(dataset, batch_size=config.batch_size, shuffle=True, drop_last=True, pin_memory=True, num_workers=1, persistent_workers=True)

### Neural Network Instantiation

In [None]:
# Construct the autoencoder with the configured input height/width, chroma subsampling, color-space
# conversions and device, then move it to that device.
net = NeuralDepthAutoencoder(config, config.height, config.width, config.subsampling, config.color_space_conversions, config.device).to(config.device)

net.train()  # Important for layers such as batch normalization, if you end up adding them

### Optimizer and Scheduler Instantiation

In [None]:
# Optimizer and LR scheduler classes are chosen by config.optimizer_type / config.scheduler_type;
# both helpers raise ValueError for unsupported types.
optimizer = initialize_optimizer(config, net)
scheduler = initialize_scheduler(config, optimizer)

### Instantiate various parameters

In [None]:
starting_epoch = 0  # First epoch index to run; 0 for a fresh training run
global_step = 0  # Keeps track of the overall mini-batch number

# Best-so-far metrics start at +inf because they are being minimized (any finite loss counts as an improvement)
best_depth_range_gradient_L1_loss = torch.inf
best_max_depth_range_gradient_L1_loss = torch.inf
best_avg_L1_loss_on_JPEG_normalization_image = torch.inf
best_avg_RMSE_loss_on_JPEG_normalization_image = torch.inf
best_avg_mask_loss_on_JPEG_normalization_image = torch.inf

# Per-epoch records (one entry per training epoch), prepopulated with zeros (list appending or sole use of the TensorBoard logs may be a preferable alternative to you)
best_depth_range_gradient_L1_losses_by_epoch = np.zeros(config.num_training_epochs)
best_max_depth_range_gradient_L1_losses_by_epoch = np.zeros(config.num_training_epochs)
best_avg_L1_losses_on_JPEG_normalization_by_epoch = np.zeros(config.num_training_epochs)
best_avg_RMSE_losses_on_JPEG_normalization_by_epoch = np.zeros(config.num_training_epochs)
best_avg_mask_loss_on_JPEG_normalization_image_by_epoch = np.zeros(config.num_training_epochs)

# May override some of the above parameters if continuing an interrupted training session.
# An explicit weight-initialization source takes precedence over this run's own checkpoint.
# NOTE(review): presumably load_checkpoint_file reassigns the module-level names above; confirm against its definition.
if config.weight_initialization_source is not None:
    load_checkpoint_file(config.weight_initialization_source)  # Load a checkpoint from a different source directory
elif not config.start_from_scratch and checkpoint_file_path is not None:
    load_checkpoint_file(checkpoint_file_path)  # Load a checkpoint file from an interrupted training session

### Instantiate a depth map for normalization of losses across JPEG qualities

In [None]:
# I didn't really use this test dataset for testing.
# In this code cell I'm just using a single disparity map to use for error normalization purposes across JPEG quality levels (since I elected not to use a rate-distortion loss).
test_dataset = FlyingThingsDataset(config.dataset_path, "TEST", False, T.Resize((config.height, config.width), interpolation=config.interpolation_mode))

# A single calibration sample (index 5), shaped (batch=1, channel=1, height, width).
# Uses the configured resolution rather than a hard-coded 224x224 so other input sizes work.
calibration_depth_map = test_dataset[5].reshape([1, 1, config.height, config.width])

# Histogram-equalize so the calibration depths cover the range evenly (as is done for training batches)
calibration_depth_map_normalized, _ = image_histogram_equalization(calibration_depth_map.detach().cpu().numpy(), 1024)
calibration_depth_map = torch.from_numpy(calibration_depth_map_normalized).float()

minimum = torch.min(calibration_depth_map)
maximum = torch.max(calibration_depth_map)
val_range = maximum - minimum

# Cutoff values so that there is some "background" region to mask out
low_cutoff = minimum + 0.1*val_range
high_cutoff = maximum - 0.1*val_range

calibration_mask1 = calibration_depth_map > low_cutoff
calibration_mask2 = calibration_depth_map < high_cutoff
calibration_mask = calibration_mask1 * calibration_mask2
calibration_depth_map = calibration_depth_map * calibration_mask

# Rescale the kept range [low_cutoff, high_cutoff] to [0, 1] and zero the background again.
# This is a single sample, so the batch dimension is 1 (viewing it as config.batch_size fails whenever batch_size > 1).
calibration_depth_map = (calibration_depth_map.view(1, 1, config.height, config.width) - low_cutoff) / (high_cutoff - low_cutoff)
calibration_depth_map = (calibration_depth_map * calibration_mask).to(config.device)

calibration_mask = calibration_mask.float().to(config.device)

In [None]:
# Visualization of the depth map used for normalizing the losses across different JPEG qualities
# (reshaped to the configured resolution instead of a hard-coded 224x224)
plt.imshow(calibration_depth_map.cpu().numpy().reshape([config.height, config.width]))
plt.colorbar()

### Instantiate a depth gradient image (i.e., tilted plane) for lossless validation

In [None]:
# 100x100 depth range gradient image (a "tilted plane") to serve as a sort of validation image.
# Pixel k (row-major) holds k/9999, so the 10,000 pixels sample [0, 1] uniformly, endpoints included.
# Built with one vectorized op instead of 10,000 single-element tensor writes. The division is done in
# float64 and then rounded to float32, which matches assigning each Python float i/9999 individually.
depth_gradient_image = (torch.arange(10000, dtype=torch.float64) / 9999).float().to(config.device)
depth_gradient_image = depth_gradient_image.reshape([1, 1, 100, 100])
depth_gradient_image_ones_mask = torch.ones([1, 1, 100, 100]).to(config.device)
depth_gradient_image_zeros_mask = torch.zeros([1, 1, 100, 100]).to(config.device)

depth_gradient_image_and_ones_mask = torch.cat((depth_gradient_image, depth_gradient_image_ones_mask), 1)  # All "foreground"
depth_gradient_image_and_zeros_mask = torch.cat((depth_gradient_image, depth_gradient_image_zeros_mask), 1)  # All "background"

In [None]:
# Show the tilted-plane validation image: drop the two leading singleton dims, (1, 1, 100, 100) -> (100, 100)
plt.imshow(depth_gradient_image.squeeze().cpu().numpy())
plt.colorbar()

### Instantiate some functions for shorthand

In [None]:
# Shared loss modules read as module-level globals by the loss-normalization and inference helpers
# defined earlier; they only need to exist before those helpers are first called.
MSE_criterion = nn.MSELoss()
L1_criterion = nn.L1Loss()
BCE_criterion = nn.BCEWithLogitsLoss()  # Not referenced in this section; presumably selectable as config.mask_loss_function (confirm)
toPILImage = T.ToPILImage()  # Converts encoded RGB tensors to PIL images so they can be saved as PNG

### For verification of histogram equalization of depth maps (tends to improve training results)

In [None]:
# One counter per histogram bin (1024, matching the np.histogram call in the training loop), used to verify
# that histogram equalization spreads training depths evenly.
# NOTE(review): presumably accumulated per batch inside run_training_loop, which receives it as an argument; confirm.
test_overall_histogram = np.zeros(1024)

## Training

### Training Loop

In [None]:
def run_training_loop(config, starting_epoch, weight_save_directory, notebook_name, training_run_number, scaler, calibration_depth_map, calibration_mask, depth_gradient_image, depth_gradient_image_and_ones_mask, depth_gradient_image_and_zeros_mask, test_overall_histogram):
    if not config.dry_run:
        config.ran_to_completion = False
        with open(f'{weight_save_directory}{notebook_name}_training_configuration_run{training_run_number}.pt', 'wb') as handle:
            pickle.dump(config, handle)

    # These were instantiated earlier and they will persist between training runs
    global best_depth_range_gradient_L1_loss
    global best_max_depth_range_gradient_L1_loss
    global best_avg_L1_loss_on_JPEG_normalization_image
    global best_avg_RMSE_loss_on_JPEG_normalization_image
    global best_avg_mask_loss_on_JPEG_normalization_image
    global best_depth_range_gradient_L1_losses_by_epoch
    global best_max_depth_range_gradient_L1_losses_by_epoch
    global best_avg_L1_losses_on_JPEG_normalization_by_epoch
    global best_avg_RMSE_losses_on_JPEG_normalization_by_epoch
    global best_avg_mask_loss_on_JPEG_normalization_image_by_epoch
    global global_step

    last_epoch_saved = -1  # Used to tell if a model checkpoint has occurred for the current epoch
    
    for epoch in range(starting_epoch, config.num_training_epochs):
        progress_bar = tqdm(train_loader, miniters=config.print_interval, maxinterval=1e9)
        num_samples = len(progress_bar)

        # For performance reasons, training checkpoints are cached into a dictionary and saved to disk at the end of every epoch
        cached_checkpoints = {}

        # Loop over all training samples (1 epoch)
        for i, depth_maps in enumerate(progress_bar):
            # Randomly select a JPEG quality level from the specified range
            quality = random.uniform(config.low_JPEG_quality, config.high_JPEG_quality)

            # Histogram equalization of depth map to improve depth range sampling pattern
            depth_maps_histogram_normalized, _ = image_histogram_equalization(depth_maps.detach().cpu().numpy(), 1024)
            depth_maps = torch.from_numpy(depth_maps_histogram_normalized).float()

            minimum = torch.min(depth_maps)
            maximum = torch.max(depth_maps)
            val_range = maximum - minimum

            # Randomly cut off a portion of the depth range to create "background" regions for the ground truth mask
            low_cutoff = minimum + random.uniform(0.0, 0.3)*val_range
            high_cutoff = maximum - random.uniform(0.0, 0.3)*val_range

            # Just in case
            if high_cutoff <= low_cutoff:
                raise Exception("The depth range low and high cutoff values are invalid")

            # Creating a ground truth mask and applying it to the ground truth depth map(s)
            mask1 = depth_maps > low_cutoff
            mask2 = depth_maps < high_cutoff
            ground_truth_mask = mask1 * mask2
            depth_maps = depth_maps * ground_truth_mask

            # Extended normalized depth range by +1%/-1% (configurable) during training to achieve accurate results near extrema during inference
            depth_maps = ((depth_maps.view(config.batch_size, 1, config.height, config.width) - low_cutoff) / (high_cutoff - low_cutoff)* config.expanded_training_range - (config.expanded_training_range - 1) / 2) * (config.depth_normalization_upper_bound - config.depth_normalization_lower_bound) + config.depth_normalization_lower_bound
            depth_maps = depth_maps * ground_truth_mask

            # Zero gradients so they don't accumulate
            optimizer.zero_grad()

            depth_maps = depth_maps.to(config.device)
            ground_truth_mask = ground_truth_mask.to(config.device)

            # Periodically calculate loss normalization factors since I didn't use a rate-distortion loss and am randomizing the JPEG quality
            if i % config.JPEG_normalization_interval == 0:
                depth_RMSE_JPEG_quality_normalization_factor, depth_L1_JPEG_quality_normalization_factor, depth_loss_JPEG_quality_normalization_factor, mask_JPEG_quality_normalization_factor, avg_L1_loss_on_JPEG_normalization_image, avg_RMSE_loss_on_JPEG_normalization_image, avg_mask_loss_on_JPEG_normalization_image = calculate_error_normalization_by_JPEG_quality(config, calibration_depth_map, calibration_mask, config.image_compression_type, config.subsampling, config.low_JPEG_quality, config.high_JPEG_quality, config.depth_loss_function, config.mask_type, config.mask_loss_function, config.error_threshold_for_masking)

            # Bool to control periodic image saving (plot the values from the resulting images in 3D to see the encoding pattern in the RGB color space)
            log_encoded_images = (global_step % config.encoded_image_log_interval == 0 and not (config.dry_run or config.debug_mode))
                
            # Depth gradient "validation" losses on a depth gradient image that samples full [0,1] depth range (i.e., a "tilted plane")
            if i % config.depth_gradient_val_interval == 0 or global_step % config.encoded_image_log_interval == 0:
                depth_range_gradient_L1_loss, max_depth_range_gradient_L1_loss, neural_encoding_for_ones_mask, neural_encoding_for_zeros_mask = validate_depth_gradient_image_performance(depth_gradient_image, depth_gradient_image_and_ones_mask, depth_gradient_image_and_zeros_mask, config.target_validation_JPEG_quality, config.subsampling, config.gradient_validation_image_compression_type, log_encoded_images)                
            
            # Log lossless encoding images
            if not config.dry_run:
                if log_encoded_images:
                    # This will log the encoded colors for the depth range [0, 1.0] for 10,000 entries.  Plot these in 3D for best visualization.
                    image_save_name = "./images/" + notebook_name + "/neural_encoding_for_ones_mask/" + notebook_name + "_neural_encoding_for_ones_mask_%08d.png" % global_step 
                    toPILImage(neural_encoding_for_ones_mask.squeeze()).save(image_save_name, "PNG")

                    # This should result in "background" encoding color (likely just a single color, more or less).
                    image_save_name = "./images/" + notebook_name + "/neural_encoding_for_zeros_mask/" + notebook_name + "_neural_encoding_for_zeros_mask_%08d.png" % global_step 
                    toPILImage(neural_encoding_for_zeros_mask.squeeze()).save(image_save_name, "PNG")

            # Inference and loss calculations
            if config.precision_mode == PrecisionMode.MIXED_PRECISION:
                with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
                    combined_loss, mask_loss, reweighted_depth_RMSE_loss, reweighted_depth_L1_loss, reweighted_mask_loss = inference_and_loss_calculation(config, depth_maps, ground_truth_mask, quality, depth_RMSE_JPEG_quality_normalization_factor, depth_L1_JPEG_quality_normalization_factor, depth_loss_JPEG_quality_normalization_factor, mask_JPEG_quality_normalization_factor)
            else:
                combined_loss, mask_loss, reweighted_depth_RMSE_loss, reweighted_depth_L1_loss, reweighted_mask_loss = inference_and_loss_calculation(config, depth_maps, ground_truth_mask, quality, depth_RMSE_JPEG_quality_normalization_factor, depth_L1_JPEG_quality_normalization_factor, depth_loss_JPEG_quality_normalization_factor, mask_JPEG_quality_normalization_factor)
                
            # Saving best losses and checkpointing
            if depth_range_gradient_L1_loss < best_depth_range_gradient_L1_loss:
                best_depth_range_gradient_L1_loss = depth_range_gradient_L1_loss

                # This is my favorite loss to checkpoint on (I largely do not use the others, but they can be useful if you encounter overfitting)
                cached_checkpoints["best_L1_rangeGradient_val_loss"] = cache_checkpoint(config, epoch, net, optimizer, scheduler, scaler, 
                    best_depth_range_gradient_L1_loss, best_max_depth_range_gradient_L1_loss, 
                    best_avg_L1_loss_on_JPEG_normalization_image, best_avg_RMSE_loss_on_JPEG_normalization_image, best_avg_mask_loss_on_JPEG_normalization_image,
                    best_depth_range_gradient_L1_losses_by_epoch, best_max_depth_range_gradient_L1_losses_by_epoch, best_avg_L1_losses_on_JPEG_normalization_by_epoch, 
                    best_avg_RMSE_losses_on_JPEG_normalization_by_epoch, best_avg_mask_loss_on_JPEG_normalization_image_by_epoch)
                
            if max_depth_range_gradient_L1_loss < best_max_depth_range_gradient_L1_loss:
                best_max_depth_range_gradient_L1_loss = max_depth_range_gradient_L1_loss

                cached_checkpoints["best_max_rangeGradient_val_loss"] = cache_checkpoint(config, epoch, net, optimizer, scheduler, scaler, 
                    best_depth_range_gradient_L1_loss, best_max_depth_range_gradient_L1_loss, 
                    best_avg_L1_loss_on_JPEG_normalization_image, best_avg_RMSE_loss_on_JPEG_normalization_image, best_avg_mask_loss_on_JPEG_normalization_image,
                    best_depth_range_gradient_L1_losses_by_epoch, best_max_depth_range_gradient_L1_losses_by_epoch, best_avg_L1_losses_on_JPEG_normalization_by_epoch, 
                    best_avg_RMSE_losses_on_JPEG_normalization_by_epoch, best_avg_mask_loss_on_JPEG_normalization_image_by_epoch)
                
            if avg_L1_loss_on_JPEG_normalization_image < best_avg_L1_loss_on_JPEG_normalization_image:
                best_avg_L1_loss_on_JPEG_normalization_image = avg_L1_loss_on_JPEG_normalization_image

                cached_checkpoints["best_avg_L1_val_loss_on_JPEG_normalization_image"] = cache_checkpoint(config, epoch, net, optimizer, scheduler, scaler, 
                    best_depth_range_gradient_L1_loss, best_max_depth_range_gradient_L1_loss, 
                    best_avg_L1_loss_on_JPEG_normalization_image, best_avg_RMSE_loss_on_JPEG_normalization_image, best_avg_mask_loss_on_JPEG_normalization_image,
                    best_depth_range_gradient_L1_losses_by_epoch, best_max_depth_range_gradient_L1_losses_by_epoch, best_avg_L1_losses_on_JPEG_normalization_by_epoch, 
                    best_avg_RMSE_losses_on_JPEG_normalization_by_epoch, best_avg_mask_loss_on_JPEG_normalization_image_by_epoch)
                
            if avg_RMSE_loss_on_JPEG_normalization_image  < best_avg_RMSE_loss_on_JPEG_normalization_image:
                best_avg_RMSE_loss_on_JPEG_normalization_image = avg_RMSE_loss_on_JPEG_normalization_image 

                cached_checkpoints["best_avg_RMSE_loss_on_JPEG_normalization_image"] = cache_checkpoint(config, epoch, net, optimizer, scheduler, scaler, 
                    best_depth_range_gradient_L1_loss, best_max_depth_range_gradient_L1_loss, 
                    best_avg_L1_loss_on_JPEG_normalization_image, best_avg_RMSE_loss_on_JPEG_normalization_image, best_avg_mask_loss_on_JPEG_normalization_image,
                    best_depth_range_gradient_L1_losses_by_epoch, best_max_depth_range_gradient_L1_losses_by_epoch, best_avg_L1_losses_on_JPEG_normalization_by_epoch, 
                    best_avg_RMSE_losses_on_JPEG_normalization_by_epoch, best_avg_mask_loss_on_JPEG_normalization_image_by_epoch)
                
            if avg_mask_loss_on_JPEG_normalization_image < best_avg_mask_loss_on_JPEG_normalization_image:
                best_avg_mask_loss_on_JPEG_normalization_image = avg_mask_loss_on_JPEG_normalization_image

                cached_checkpoints["best_avg_mask_loss_on_JPEG_normalization_image"] = cache_checkpoint(config, epoch, net, optimizer, scheduler, scaler, 
                    best_depth_range_gradient_L1_loss, best_max_depth_range_gradient_L1_loss, 
                    best_avg_L1_loss_on_JPEG_normalization_image, best_avg_RMSE_loss_on_JPEG_normalization_image, best_avg_mask_loss_on_JPEG_normalization_image,
                    best_depth_range_gradient_L1_losses_by_epoch, best_max_depth_range_gradient_L1_losses_by_epoch, best_avg_L1_losses_on_JPEG_normalization_by_epoch, 
                    best_avg_RMSE_losses_on_JPEG_normalization_by_epoch, best_avg_mask_loss_on_JPEG_normalization_image_by_epoch)
           
            # Also save once per epoch, regardless of performance on any metrics
            if (scheduler.get_last_lr()[0] == config.starting_lr and last_epoch_saved < epoch) or (config.debug_mode and i == config.debug_mode_loop_length - 1):
                last_epoch_saved = epoch
                
                cached_checkpoints["epoch%i" % epoch] = cache_checkpoint(config, epoch, net, optimizer, scheduler, scaler, 
                    best_depth_range_gradient_L1_loss, best_max_depth_range_gradient_L1_loss, 
                    best_avg_L1_loss_on_JPEG_normalization_image, best_avg_RMSE_loss_on_JPEG_normalization_image, best_avg_mask_loss_on_JPEG_normalization_image,
                    best_depth_range_gradient_L1_losses_by_epoch, best_max_depth_range_gradient_L1_losses_by_epoch, best_avg_L1_losses_on_JPEG_normalization_by_epoch, 
                    best_avg_RMSE_losses_on_JPEG_normalization_by_epoch, best_avg_mask_loss_on_JPEG_normalization_image_by_epoch)

            # Backpropagation and model updates
            if config.precision_mode == PrecisionMode.MIXED_PRECISION:
                scaler.scale(combined_loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                combined_loss.backward()
                optimizer.step()

            # Learning rate scheduler step
            if config.scheduler_type is optim.lr_scheduler.ReduceLROnPlateau:
                scheduler.step(best_depth_range_gradient_L1_loss)
            else:
                scheduler.step()
            
            # Checking that histogram equalization worked
            histogram, bin_edges = np.histogram(depth_maps.reshape([config.height, config.width]).detach().cpu().numpy(), bins=1024, range=(-0.01, 1.01))
            test_overall_histogram += histogram

            # Update progress bar
            if global_step % config.print_interval == 0 or i == (num_samples - 1):  # Rate limiting TQDM updates to prevent Jupyter Notebook from complaining
                scheduler_lr = scheduler.get_last_lr()[0]
                
                progress_bar.set_description(f'Run {training_run_number}, Epoch {epoch}, Batch {i}')
                progress_metrics = {
                    'LR': f'{scheduler_lr:.2E}',
                    'Best val': f'{best_depth_range_gradient_L1_loss*1000:.3f}',
                    'Best max val': f'{best_max_depth_range_gradient_L1_loss*1000:.3f}',
                    'Cal. depth': f'{best_avg_mask_loss_on_JPEG_normalization_image*1000:.3f}',
                    'Cal. mask': f'{avg_mask_loss_on_JPEG_normalization_image*1000:.3f}',
                    'Val': f'{depth_range_gradient_L1_loss*1000:.3f}',
                    'Max val': f'{max_depth_range_gradient_L1_loss*1000:.3f}',
                    'RMSE': f'{reweighted_depth_RMSE_loss*1000:.3f}',
                    'L1': f'{reweighted_depth_L1_loss*1000:.3f}',
                    'Mask loss': f'{reweighted_mask_loss*1000:.3f}',
                    'Quality': f'{quality:.3f}'
                }
                progress_bar.set_postfix(progress_metrics)

            # Tensorboard logging
            if global_step % config.tensorboard_log_interval == 0 and not (config.dry_run or config.debug_mode):
                writer.add_scalar('Training_loss/reweighted_depth_L1_loss', reweighted_depth_L1_loss, global_step)
                writer.add_scalar('Training_loss/reweighted_depth_RMSE_loss', reweighted_depth_RMSE_loss, global_step)
                writer.add_scalar('Training_loss/reweighted_mask_loss', reweighted_mask_loss, global_step)
                
                # Most recent validation losses
                writer.add_scalar('Validation_loss/depth_range_gradient_L1_loss', depth_range_gradient_L1_loss, global_step)
                writer.add_scalar('Validation_loss/max_depth_range_gradient_L1_loss', max_depth_range_gradient_L1_loss, global_step)
                writer.add_scalar('Validation_loss/avg_L1_loss_on_JPEG_normalization_image', avg_L1_loss_on_JPEG_normalization_image, global_step)
                writer.add_scalar('Validation_loss/avg_RMSE_loss_on_JPEG_normalization_image', avg_RMSE_loss_on_JPEG_normalization_image, global_step)
                writer.add_scalar('Validation_loss/avg_mask_loss_on_JPEG_normalization_image', avg_mask_loss_on_JPEG_normalization_image, global_step)
                
                # Best validation losses
                writer.add_scalar('Validation_loss/best_depth_range_gradient_L1_loss', best_depth_range_gradient_L1_loss, global_step)
                writer.add_scalar('Validation_loss/best_max_depth_range_gradient_L1_loss', best_max_depth_range_gradient_L1_loss, global_step)
                writer.add_scalar('Validation_loss/best_avg_L1_loss_on_JPEG_normalization_image', best_avg_L1_loss_on_JPEG_normalization_image, global_step)
                writer.add_scalar('Validation_loss/best_avg_RMSE_loss_on_JPEG_normalization_image', best_avg_RMSE_loss_on_JPEG_normalization_image, global_step)
                writer.add_scalar('Validation_loss/best_avg_mask_loss_on_JPEG_normalization_image', best_avg_mask_loss_on_JPEG_normalization_image, global_step)
            
            global_step += 1
            
            # Use artificially short epochs if running in debug mode
            if config.debug_mode:
                if i >= (config.debug_mode_loop_length - 1):
                    break

        # Save cached checkpoints to disk at the end of the epoch
        for checkpoint_save_name, cached_checkpoint in cached_checkpoints.items():
            save_checkpoint(cached_checkpoint, checkpoint_save_name, config, weight_save_directory, notebook_name, training_run_number)

        # Save best losses to their by-epoch list versions
        best_depth_range_gradient_L1_losses_by_epoch[epoch] = best_depth_range_gradient_L1_loss
        best_max_depth_range_gradient_L1_losses_by_epoch[epoch] = best_max_depth_range_gradient_L1_loss
        best_avg_L1_losses_on_JPEG_normalization_by_epoch[epoch] = best_avg_L1_loss_on_JPEG_normalization_image
        best_avg_RMSE_losses_on_JPEG_normalization_by_epoch[epoch] = best_avg_RMSE_loss_on_JPEG_normalization_image
        best_avg_mask_loss_on_JPEG_normalization_image_by_epoch[epoch] = best_avg_mask_loss_on_JPEG_normalization_image
        
        # Flush the TensorBoard outputs after every epoch
        if not config.dry_run and not config.debug_mode:
            writer.flush()

        # End run if progress stalls or we hit the epoch limit for the current run
        end_run = False
        if config.depth_loss_enabled:
            end_run = (epoch - starting_epoch) > (config.stalled_val_patience + 1) and (best_depth_range_gradient_L1_losses_by_epoch[epoch - config.stalled_val_patience] - best_depth_range_gradient_L1_loss) < config.stalled_val_threshold
        if config.debug_mode or (epoch - starting_epoch) >= config.max_epochs_per_run:
            end_run = True
        
        if end_run:
            config.ran_to_completion = True
            with open(f'{weight_save_directory}{notebook_name}_training_configuration_run{training_run_number}.pt', 'wb') as handle:
                pickle.dump(config, handle)
            break

        # Only 1 epoch per run in debug mode
        if config.debug_mode:
            break
        
    epoch = epoch + 1
    
    return epoch

In [None]:
# Sanity check: report whether the encoder and decoder will actually be trained.
# A sub-network counts as "frozen" when at least one of its parameters has
# requires_grad == False (fully OR partially frozen); a sub-network with no
# parameters is reported as not frozen, since all() over an empty iterable is True.
encoder_frozen = not all(param.requires_grad for param in net.encoder.parameters())
print(f'encoder_frozen = {encoder_frozen}')
# Store the decoder result under its own name. Previously it overwrote
# `encoder_frozen`, leaving that global holding the decoder's status.
decoder_frozen = not all(param.requires_grad for param in net.decoder.parameters())
print(f'decoder_frozen = {decoder_frozen}')

In [None]:
# Report how many scalar parameters each half of the network holds.
# Every parameter tensor is counted, whether or not it currently requires gradients.
encoder_size = sum(tensor.numel() for tensor in net.encoder.parameters())
print(f'encoder_size = {encoder_size:,} parameters')
decoder_size = sum(tensor.numel() for tensor in net.decoder.parameters())
print(f'decoder_size = {decoder_size:,} parameters')

### Mixed Precision Training

In [None]:
# Stage 1 of 3: train with mixed precision (runs only if the config is currently in that mode).
if config.precision_mode is PrecisionMode.MIXED_PRECISION:
    starting_epoch = run_training_loop(
        config,
        starting_epoch,
        weight_save_directory,
        notebook_name,
        training_run_number,
        scaler,
        calibration_depth_map,
        calibration_mask,
        depth_gradient_image,
        depth_gradient_image_and_ones_mask,
        depth_gradient_image_and_zeros_mask,
        test_overall_histogram,
    )

    # Switch to PyTorch's default precision for the next stage.
    # start_from_scratch is cleared before re-querying the run bookkeeping
    # (presumably so the next run continues from this one — confirm against
    # determine_training_run_number). The returned config replaces the current one,
    # so the precision mode is overridden only afterwards.
    config.start_from_scratch = False
    (
        training_run_number,
        weight_save_directory,
        config,
        checkpoint_file_path,
    ) = determine_training_run_number(config, save_directory_all_runs)
    config.precision_mode = PrecisionMode.PYTORCH_DEFAULT
    scaler = set_performance_and_precision_settings(config.precision_mode)

    # run_training_loop is not handed the optimizer/scheduler as arguments. It appears
    # to read these module-level names, so rebinding them gives the next stage fresh ones.
    optimizer = initialize_optimizer(config, net)
    scheduler = initialize_scheduler(config, optimizer)

### Default PyTorch Precision Training

In [None]:
# Stage 2 of 3: train at PyTorch's default precision (runs only if the config is in that mode).
if config.precision_mode is PrecisionMode.PYTORCH_DEFAULT:
    starting_epoch = run_training_loop(
        config,
        starting_epoch,
        weight_save_directory,
        notebook_name,
        training_run_number,
        scaler,
        calibration_depth_map,
        calibration_mask,
        depth_gradient_image,
        depth_gradient_image_and_ones_mask,
        depth_gradient_image_and_zeros_mask,
        test_overall_histogram,
    )

    # Switch to full float32 precision for the final stage.
    # start_from_scratch is cleared before re-querying the run bookkeeping
    # (presumably so the next run continues from this one — confirm against
    # determine_training_run_number). The returned config replaces the current one,
    # so the precision mode is overridden only afterwards.
    config.start_from_scratch = False
    (
        training_run_number,
        weight_save_directory,
        config,
        checkpoint_file_path,
    ) = determine_training_run_number(config, save_directory_all_runs)
    config.precision_mode = PrecisionMode.FLOAT32_PRECISION
    scaler = set_performance_and_precision_settings(config.precision_mode)

    # run_training_loop is not handed the optimizer/scheduler as arguments. It appears
    # to read these module-level names, so rebinding them gives the next stage fresh ones.
    optimizer = initialize_optimizer(config, net)
    scheduler = initialize_scheduler(config, optimizer)

### Full float32 Precision Training

In [None]:
# Stage 3 of 3: train at full float32 precision (runs only if the config is in that mode).
if config.precision_mode is PrecisionMode.FLOAT32_PRECISION:
    starting_epoch = run_training_loop(
        config,
        starting_epoch,
        weight_save_directory,
        notebook_name,
        training_run_number,
        scaler,
        calibration_depth_map,
        calibration_mask,
        depth_gradient_image,
        depth_gradient_image_and_ones_mask,
        depth_gradient_image_and_zeros_mask,
        test_overall_histogram,
    )