<a href="https://colab.research.google.com/github/psu-rdmap/unet-compare/blob/main/unet_compare.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Instructions

1. Click *Open in Colab* at the top

2. Save a copy of this notebook via *File* > *Save a copy in Drive*

3. Go to [Google Drive](https://drive.google.com) and create a new folder called `data/`.

4. Create your dataset with the structure shown in the *Training & Inference* section of the GitHub repository and place it in `data/`.

5. Navigate to the notebook copy and select the drop-down menu next to *Connect* in the top-right. Select *Change runtime type*, choose an available GPU, and select *Save*. Then connect to a runtime.

6. Run all cells in *Source*. Note that you will have to authorize mounting Google Drive.

7. Open the *Training/Inference* section, set the configs using the information [here](https://github.com/psu-rdmap/unet-compare/tree/main/configs), and run all cells in this section

8. The operation will begin and a results folder will be created in your Google Drive containing the files mentioned in the *Training & Inference* section of the GitHub repository.

9. To run it again, repeat steps 5-8.

# Source

## Prepare Environment

In [2]:
# mount Google Drive at /content/drive; the rest of the notebook reads and writes under /content/drive/MyDrive
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from pydantic import BaseModel, PositiveInt, PositiveFloat, NonNegativeFloat, ConfigDict, Field, field_validator, model_validator
from typing import Literal, List, Optional, Tuple
from pathlib import Path
from warnings import warn
import numpy as np
from natsort import os_sorted
from datetime import datetime
import cv2 as cv
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, Conv2DTranspose, MaxPooling2D, Concatenate
from keras.regularizers import l2
import shutil, random, re, gc, keras, math, json
from glob import glob
import tensorflow as tf
from tensorflow.keras.applications import EfficientNetB7
from keras.preprocessing.image import save_img
import pandas as pd
import matplotlib.pyplot as plt
from keras.api.optimizers import Adam
from keras.api.callbacks import CSVLogger, ModelCheckpoint, EarlyStopping
from keras.api.saving import load_model
from keras import backend as K

# seed Python's `random` module so shuffles done with it start from a fixed state
random.seed(229)
# shorthand used for `num_parallel_calls` / prefetch buffer sizes in the dataloader
AUTOTUNE = tf.data.AUTOTUNE
# seed TensorFlow's global random generator
tf.random.set_seed(3051)

# print the Keras and TensorFlow versions of the runtime
print(keras.__version__)
print(tf.__version__)

3.8.0
2.17.1


## Input Validator

In [None]:
"""
Aiden Ochoa, 4/2025, RDMAP PSU Research Group
This module validates the user input configs file and modifies it as necessary
"""

# base directory for datasets (`data/`), models and results: the user's Drive folder under the /content/drive mount
ROOT_DIR = Path('/content') / 'drive' / 'MyDrive'


class General(BaseModel):
    """Most general configs that apply to all possible use cases. Only basic validation is done here."""

    # base directory; defaults to the module-level ROOT_DIR
    root_dir: Path = ROOT_DIR
    # selects which mode-specific model (`Train` or `Inference`) `validate` applies next
    operation_mode: Literal['train', 'inference'] = Field(
        default='train',
        description="General parameter that defines whether a model will be trained, or if a model will be applied for inference"
    )
    # required; must be at least two characters long
    dataset_name: str = Field(
        min_length=2,
        description="Dataset subdirectory prefix corresponding to unet-compare/data/<dataset_name>/"
    )
    # optional; the mode-specific models generate a name when this is left as None
    results_dir: Optional[str] = Field(
        default=None,
        description="Path for results directory relative to /path/to/unet-compare/. Give it `null` or ignore it to use default naming scheme"
    )
    # keep unknown keys so that mode-specific options survive `model_dump()` and reach `Train`/`Inference`
    model_config = ConfigDict(
        extra='allow',
    )


class Train(BaseModel):
    """Configs used when `operation_mode` is `train`.

    The field validators check the dataset on disk and the batch size. The model
    validators then reconcile options that depend on each other (backbone settings,
    early stopping, the train/validation split), resolve the results directory and
    record the common image shape in `input_shape`.

    Invalid user input is reported with `ValueError` rather than `assert`, so the
    checks still run when Python is started with `-O`.
    """

    root_dir : Path
    operation_mode: str
    dataset_name: str
    results_dir: Optional[str | Path]
    # filled in by `get_image_shape`; it is None until that validator has run
    input_shape: Optional[Tuple[int, int, int]] = None
    encoder_name: Literal['UNet', 'EfficientNetB7'] = Field(
        default='UNet',
        description="Type of model architecture forming the encoder section of U-Net"
    )
    decoder_name: Literal['UNet', 'UNet++'] = Field(
        default='UNet',
        description="Type of model architecture forming the decoder section of U-Net"
    )
    encoder_filters: Optional[List[PositiveInt]] = Field(
        default=None,
        min_length=5, 
        max_length=5,
        description="Number of filters to learn for each convolution. Should be `null` or ignored when `encoder_name` is `EfficientNetB7`"
    )
    decoder_filters: List[PositiveInt] = Field(
        default = [512, 256, 128, 64, 32],
        description="Number of filters to learn at each resolution level. The final item is only used when `encoder_name` is `EfficientNetB7`"
    )
    backbone_weights: Optional[Literal['random', 'imagenet']] = Field(
        default=None,
        description="Weights to be loaded when using a pretrained backbone. This should be ignored when `encoder_name` is `UNet`"
    )
    backbone_finetuning: Optional[bool | List[PositiveInt]] = Field(
        default=None,
        description="Controls the finetuning of a pretrained backbone. " \
        "This should be ignored when random weights are used like when `encoder_name` is `UNet` and when `backbone_weights` is `random`." \
        "If the entire backbone is to be unfrozen, this should be `True`. Otherwise, `False` indicates the model is frozen or selected blocks (array of ints)"
    )
    learning_rate: PositiveFloat = Field(
        default=1e-4,
        lt=1.0,
        description="Learning rate for the Adam optimizer. Should be between 0 and 1"
    )
    L2_regularization_strength: NonNegativeFloat = Field(
        default=0.0,
        lt=1.0,
        description="Strength of L2 regularization used during training. Should be between 0 and 1, with 0 meaning no L2 regularization is used"
    )
    batch_size: PositiveInt = Field(
        default=4,
        description="Number of image-annotation pairs to use in a single batch. Each batch represents one weight vector update"
    )
    num_epochs: PositiveInt = Field(
        default=50,
        description="Number of epochs to train for. Each epoch is one pass through the entire dataset"
    )
    augment: bool = Field(
        default=True,
        description="Augment the training subset eightfold by flipping and rotating by 90 deg intervals"
    )
    cross_validation: bool = Field(
        default=False,
        description="Performs a k-fold cross validation study where many models are trained using different non-overlapping validation sets"
    )
    num_folds: Optional[PositiveInt] = Field(
        default=None,
        gt=1,
        description="Number of models to train for cross validation. Should be ignored when `cross_validation` is `false`"
    )
    early_stopping: Optional[bool] = Field(
        default=None,
        description="Stop the training if the validation loss does not improve after a given number of epochs provided by `patience`. " \
        "Does not apply when `cross_validation` is `true`"
    )
    patience: Optional[PositiveInt] = Field(
        default=None,
        description="Number of epochs before training is stopped automatically. Only applies when `early_stopping` is `true`"
    )
    training_set: Optional[List[str]] = Field(
        default=None,
        description="Array of image filenames to be used for the training set. Reference the logical tree to see how it may be defined"
    )
    validation_set: Optional[List[str]] = Field(
        default=None,
        description="Array of image filenames to be used for the validation set. Reference the logical tree to see how it may be defined"
    )
    auto_split: Optional[PositiveFloat] = Field(
        default=None, 
        lt=1,
        description="Validation hold out percentage used when automatically splitting the dataset. Reference the logical tree to see how it may be defined"
    )
    model_summary: bool = Field(
        default=True,
        description="Print out the model summary from Keras to a log file in the results directory"
    )
    batchnorm: bool = Field(
        default=False,
        description="Option to use batch normalization after convolution layers in the UNet encoder/decoder"
    )

    @model_validator(mode='after')
    def pretrained_backbone(self) -> 'Train':
        """Validation specific to pretrained backbones (EfficientNet).

        Fills in default encoder filters for the UNet encoder, resets backbone options
        that do not apply (with a warning), and checks block-level finetuning indices.

        Raises:
            ValueError: `backbone_weights` is missing for an EfficientNetB7 encoder, or the
                block indices in `backbone_finetuning` are empty, too many, repeated or too large.
        """

        # set the encoder filters if not supplied and not using EfficientNet encoder
        if self.encoder_name == 'UNet' and self.encoder_filters is None:
            self.encoder_filters = [64, 128, 256, 512, 1024]

        # backbone weights should only be specified when using EfficientNet
        if self.encoder_name == 'UNet' and self.backbone_weights is not None:
            # assignment (the original compared with `==`, which left the value untouched)
            self.backbone_weights = None
            warn("`backbone_weights` should be `null` or ignored when `encoder_name` is `UNet`")
        elif self.encoder_name == 'EfficientNetB7' and self.backbone_weights is None:
            raise ValueError("`backbone_weights` should be `random` or `imagenet` when `encoder_name` is `EfficientNetB7`")

        # backbone finetuning should be none when using random weights (UNet or random EfficientNet)
        if self.encoder_name == 'UNet' and self.backbone_finetuning is not None:
            self.backbone_finetuning = None
            warn("`backbone_finetuning` should be `null` or ignored when `encoder_name` is `UNet`")
        elif self.encoder_name == 'EfficientNetB7' and self.backbone_weights == 'random' and self.backbone_finetuning is not None:
            self.backbone_finetuning = None
            warn("`backbone_finetuning` should be `null` when `backbone_weights` is `random`")

        # block level unfreezing should be an array of unique block ints smaller than 8
        # NOTE(review): the field type is List[PositiveInt], so 0 is rejected before this point
        # even though the message below mentions 0 - confirm the intended index range
        if self.encoder_name == 'EfficientNetB7' and isinstance(self.backbone_finetuning, list):
            blocks = self.backbone_finetuning
            if not 0 < len(blocks) < 8:
                raise ValueError("There must be at least 1 and at most 7 block indices in `backbone_finetuning`")
            if len(set(blocks)) != len(blocks):
                raise ValueError("All block indices must be unique in `backbone_finetuning`")
            if any(block >= 8 for block in blocks):
                raise ValueError("Block indices must be from 0 to 7 in `backbone_finetuning`")

        # default value is True (unfrozen backbone) if null
        if self.encoder_name == 'EfficientNetB7' and self.backbone_finetuning is None and self.backbone_weights == 'imagenet':
            warn("Expected `backbone_finetuning` to be an `array`, `true`, or `false`. Got `null` and defaulting to `True` (unfrozen)")
            self.backbone_finetuning = True

        return self

    @field_validator('dataset_name', mode='after')
    @classmethod
    def check_dataset(cls, dataset_name : str) -> str:
        """Check that the dataset at `ROOT_DIR/data/<dataset_name>` is complete and consistent.

        The dataset needs `images/` and `annotations/` subdirectories that contain only
        files, hold the same number of files (at least 2), use a single allowed file type
        each, and pair every image with an annotation of the same stem.

        Raises:
            ValueError: any of the above conditions is not met.
        """

        # check if the dataset directory exists
        abs_path = ROOT_DIR / 'data' / dataset_name
        if not abs_path.exists():
            raise ValueError(f"Dataset can not be found at `{abs_path}`")

        # check if the dataset has the proper subdirectories
        img_subdir = abs_path / 'images'
        ann_subdir = abs_path / 'annotations'
        if not img_subdir.exists():
            raise ValueError("Dataset is missing the `images/` subdirectory")
        elif not ann_subdir.exists():
            raise ValueError("Dataset is missing the `annotations/` subdirectory")

        # list each subdirectory once and reuse the listing for every check below
        img_files = list(img_subdir.iterdir())
        ann_files = list(ann_subdir.iterdir())

        # check if the dataset subdirectories have no child directories themselves (only files)
        if any(path.is_dir() for path in img_files):
            raise ValueError("`images/` subdirectory should contain files, not directories")
        elif any(path.is_dir() for path in ann_files):
            raise ValueError("`annotations/` subdirectory should contain files, not directories")

        # check if the dataset subdirectories have at least 2 files
        num_imgs = len(img_files)
        num_anns = len(ann_files)
        if num_imgs != num_anns:
            raise ValueError(
                f"There must be the same number of images and annotations. Got {num_imgs} image files and {num_anns} annotation files")
        elif num_imgs < 2:
            raise ValueError("There must be at least 2 image/annotation file pairs")
        elif len(set(img_files)) != num_imgs: # set removes duplicates
            raise ValueError("Every image/annotation filename must be unique")

        # check if image/annotations have mixed file types
        img_ext = {file.suffix for file in img_files}
        ann_ext = {file.suffix for file in ann_files}
        if len(img_ext) != 1:
            raise ValueError(f'Images must have the same file type. Got types {img_ext}')
        elif len(ann_ext) != 1:
            raise ValueError(f'Annotations must have the same file type. Got types {ann_ext}')

        # make sure image/annotations are JPEGs or PNGs
        img_ext = next(iter(img_ext)) # gets element of singleton set
        ann_ext = next(iter(ann_ext))
        allowed_file_types = ['.jpg', '.jpeg', '.png']
        if img_ext not in allowed_file_types:
            raise ValueError(f'Expected image filetype to be one of {allowed_file_types}, got {img_ext}')
        if ann_ext not in allowed_file_types:
            raise ValueError(f'Expected annotation filetype to be one of {allowed_file_types}, got {ann_ext}')

        # check if every image has a corresponding annotation (by name)
        img_stems = {file.stem for file in img_files}
        ann_stems = {file.stem for file in ann_files}
        if img_stems != ann_stems:
            unpaired_stems = img_stems ^ ann_stems # get disjointed elements (unique to each set)
            raise ValueError(f'Found unpaired image or annotation files with filenames {unpaired_stems}')

        return dataset_name

    @field_validator('batch_size', mode='after')
    @classmethod
    def batch_size_warning(cls, batch_size : int) -> int:
        """Warn the user if batch size is not a power of 2"""
        # a positive power of two has a single set bit, so clearing the lowest set bit leaves zero
        if batch_size & (batch_size - 1):
            warn("`batch_size` is not a power of two. Efficiency may be reduced")

        return batch_size

    @model_validator(mode='after')
    def check_early_stopping(self) -> 'Train':
        """Checks early_stopping and patience fields and validate when doing cross validation.

        Raises:
            ValueError: `early_stopping` is true without a `patience` (outside cross
                validation), or `patience` is not smaller than `num_epochs`.
        """

        # patience should only be provided when using early stopping and should be less than the number of epochs
        if self.early_stopping is False and self.patience is not None:
            self.patience = None
            warn("`patience` should be `null` if `early_stopping` is `false`. Changed `patience` to `null`")
        elif self.early_stopping is True:
            if self.patience is None:
                # under cross validation both fields are discarded below, so a missing
                # `patience` is only an error for single-model training
                if not self.cross_validation:
                    raise ValueError("`patience` must be provided when `early_stopping` is `true`")
            elif self.patience >= self.num_epochs:
                raise ValueError("`patience` can not be greater than `num_epochs`")

        # early_stopping should be null when doing cross validation
        if self.cross_validation is True:
            if self.early_stopping is not None or self.patience is not None:
                self.early_stopping = None
                self.patience = None
                warn("`early_stopping` and `patience` should be `null` or ignored when `cross_validation` is `true`")

        return self

    @model_validator(mode='after')
    def check_train_val(self) -> 'Train':
        """Checks many aspects of the train-val splitting when training single models and doing cross validation. It applies the logical tree from the docs.

        Depending on which of `training_set`, `validation_set`, `auto_split` and
        `cross_validation` are given, the missing set is derived as the complement of the
        given one, or `auto_split` is checked (defaulting to 0.4 when nothing is given).

        Raises:
            ValueError: the combination of options is inconsistent, a listed file does not
                exist, the sets overlap, or a set would be left empty.
        """

        # dataset has already been validated since field_validators run first
        data_dir = ROOT_DIR / 'data' / self.dataset_name / 'images'
        img_stems = [file.stem for file in data_dir.iterdir()]
        img_stems_set = set(img_stems)

        # basic cross validation checks
        if self.cross_validation:
            # val sets will be determined algorithmically
            if self.validation_set is not None:
                raise ValueError("`validation_set` should be `null` or ignored when `cross_validation` is `true`")
            # fail with a clear message instead of a TypeError from comparing None below
            if self.num_folds is None:
                raise ValueError("`num_folds` must be provided when `cross_validation` is `true`")

        # logical tree with train at the top (see docs)

        # ------------ FORK 1: train set provided or not ------------ #
        if self.training_set is not None:
            # make sure at least one train file is given and that all of them exist
            training_set = set(self.training_set)
            if not training_set:
                raise ValueError("`training_set` can not be empty")
            check_files(training_set, img_stems_set, 'train')

            # ------------ FORK 2: cross validation is true or not ------------ #
            if self.cross_validation:
                # check number of folds
                if self.num_folds >= len(self.training_set):
                    raise ValueError("`num_folds` can not be greater than the number of training images used for cross validation")

            else:
                # ------------ FORK 3: val set is provided or not ------------ #
                if self.validation_set is not None:
                    # check overlap between train and val sets and make sure val set files exist
                    validation_set = set(self.validation_set)
                    train_val_overlap = training_set & validation_set
                    if train_val_overlap:
                        raise ValueError(f"Files with the names {train_val_overlap} were found in both `training_set` and `validation_set`")
                    check_files(validation_set, img_stems_set, 'validation')
                else:

                    # ------------ FORK 4: auto split is provided or not ------------ #
                    if self.auto_split:
                        # make sure auto_split is not too large where no train set is created
                        max_split = 1-(1/len(self.training_set))
                        if self.auto_split >= max_split:
                            raise ValueError(f"auto_split validation hold-out percentage must be less than {max_split} for the `training_set` provided")
                    else:
                        # val set is the complement of train; make sure train does not have all available images
                        if not training_set < img_stems_set:
                            raise ValueError("`training_set` can not have all images in the dataset. Some must be left over for `validation_set`")
                        # generate val set (sort them naturally)
                        self.validation_set = os_sorted(list(img_stems_set - training_set))

        # ------------ FORK 1: back to the top (train set not provided) ------------ #
        else:

            # ------------ FORK 2: cross validation is true or not ------------ #
            if self.cross_validation:
                # use every image; sorted like the other derived sets because directory order is arbitrary
                self.training_set = os_sorted(img_stems)
            else:

                # ------------ FORK 3: val_set is provided or not ------------ #
                if self.validation_set is not None:
                    # train set is the complement to val set
                    validation_set = set(self.validation_set)
                    if not validation_set < img_stems_set:
                        raise ValueError("`validation_set` can not have all images in the dataset. Some must be left over for `training_set`")
                    # generate train set (sort them naturally)
                    self.training_set = os_sorted(list(img_stems_set - validation_set))
                else:
                    # define auto_split if it is not provided, or just check its value
                    if not self.auto_split:
                        self.auto_split = 0.4
                        warn("`auto_split` validation hold-out percentage not provided even though `training_set` and `validation_set` are `null` and `cross_validation` is `false`. Defaulting to 40%")
                    else:
                        max_split = 1-(1/len(img_stems_set))
                        if self.auto_split >= max_split:
                            raise ValueError(f"`auto_split` validation hold-out percentage must be less {max_split} for the `dataset_name` provided")

        return self

    @model_validator(mode='after')
    def generate_results_dir(self) -> 'Train':
        """Create the results directory following a naming scheme if one is not provided.

        The generated name combines the dataset name, operation mode, encoder and decoder
        names, an optional `_crossval` tag and a timestamp. The final value is stored as a
        path under `ROOT_DIR`.
        """

        if self.results_dir is None:
            now = datetime.now()
            self.results_dir = 'results_' + self.dataset_name + '_' + self.operation_mode + '_' + self.encoder_name + '_' + self.decoder_name
            if self.cross_validation:
                self.results_dir += '_crossval'
            self.results_dir += now.strftime('_(%Y-%m-%d)_(%H-%M-%S)')

        self.results_dir = ROOT_DIR / self.results_dir

        return self

    @model_validator(mode='after')
    def get_image_shape(self) -> 'Train':
        """Get the image shape for model instantiation and make sure it is consistent.

        Every image and annotation is loaded with OpenCV as a colour image and the shared
        array shape is stored in `input_shape`.

        Raises:
            ValueError: a file could not be read as an image.
            KeyError: the shapes differ between images, between annotations, or between
                images and annotations (exception type kept from the original code).
        """

        data_dir = ROOT_DIR / 'data' / self.dataset_name

        def read_shapes(directory: Path) -> set:
            """Collect the distinct array shapes of all files in `directory`."""
            shapes = set()
            for path in directory.iterdir():
                array = cv.imread(str(path), cv.IMREAD_COLOR)
                # cv.imread signals an unreadable file by returning None instead of raising
                if array is None:
                    raise ValueError(f"Could not read `{path}` as an image")
                shapes.add(array.shape)
            return shapes

        # load each image or annotation and get the array shape; a consistent dataset gives a single shape
        img_shapes = read_shapes(data_dir / 'images')
        ann_shapes = read_shapes(data_dir / 'annotations')
        img_shape = next(iter(img_shapes))
        ann_shape = next(iter(ann_shapes))

        # make sure all images and annotations have only one shape
        if len(img_shapes) > 1:
            raise KeyError(f"Expected all images to have the same shape. Got shapes {img_shapes}")
        elif len(ann_shapes) > 1:
            raise KeyError(f"Expected all annotations to have the same shape. Got shapes {ann_shapes}")
        elif img_shape != ann_shape:
            raise KeyError(f"Expected images and annotations to have the same shape. Got an image shape of `{img_shape}` and an annotation shape of `{ann_shape}`")
        else:
            self.input_shape = img_shape

        return self

            
class Inference(BaseModel):
    """Configs used when `operation_mode` is `inference`.

    Checks the directory of images to run inference on, checks that `model_path` points
    to an existing `.keras` file, and resolves the results directory.
    """

    root_dir: Path
    operation_mode: str
    dataset_name: str
    results_dir: Optional[str | Path]
    model_path: str = Field(
        description="Path relative to /path/to/unet-compare/ to an existing model to be used for inference"
    )

    @field_validator('dataset_name', mode='after')
    @classmethod
    def check_dataset(cls, dataset_name : str) -> str:
        """Check that `ROOT_DIR/data/<dataset_name>` holds at least one image of a single allowed type.

        Raises:
            ValueError: the directory is missing, contains subdirectories, is empty, or
                contains files that are not JPEGs or PNGs.
            KeyError: the files mix several file types (exception type kept from the
                original code).
        """

        # check if the dataset directory exists
        abs_path = ROOT_DIR / 'data' / dataset_name
        if not abs_path.exists():
            raise ValueError(f"Dataset can not be found at `{abs_path}`")

        # list the directory once and reuse the listing for every check below
        img_files = list(abs_path.iterdir())

        # check if the dataset directory has no child directories (only files)
        if any(path.is_dir() for path in img_files):
            raise ValueError("Dataset subdirectory should contain files, not directories")

        # check if the dataset has at least 1 file
        if not img_files:
            raise ValueError("There must be at least 1 image to inference in the dataset directory")

        # make sure images are JPEGs or PNGs
        allowed_file_types = ['.jpg', '.jpeg', '.png']
        unallowed_files = [file for file in img_files if file.suffix not in allowed_file_types]
        if unallowed_files:
            raise ValueError(f'The files {unallowed_files} have invalid file types')

        # make sure images have the same file type
        img_exts = {file.suffix for file in img_files}
        if len(img_exts) > 1:
            raise KeyError(f"Expected all files to have the same file type. Got mixed types {img_exts}")

        return dataset_name

    @field_validator('model_path', mode='after')
    @classmethod
    def check_model(cls, model_path: str) -> str:
        """Check that `ROOT_DIR/<model_path>` exists and is a `.keras` model file.

        The file type is taken from the final suffix, so names that contain additional
        dots (for example `best.v2.keras`) are accepted.

        Raises:
            ValueError: no file exists at the path, or it is not a `.keras` file.
        """
        # check if the model file exists
        abs_path = ROOT_DIR / model_path
        if not abs_path.exists():
            raise ValueError(f'No model file exists at `{abs_path}`')

        # weights-only checkpoints get a dedicated message
        if abs_path.name.endswith('.weights.h5'):
            raise ValueError('Expected a .keras model file, got a weights.h5 file instead')

        # check if it is a .keras model file
        if abs_path.suffix != '.keras':
            ext = abs_path.suffix.lstrip('.')
            raise ValueError(f'Expected a .keras file, got a .{ext} file')

        return model_path

    @model_validator(mode='after')
    def generate_results_dir(self) -> 'Inference':
        """Create the results directory following a naming scheme if one is not provided.

        The generated name combines the dataset name, the operation mode and a timestamp.
        The final value is stored as a path under `ROOT_DIR`.
        """

        if self.results_dir is None:
            now = datetime.now()
            self.results_dir = 'results_' + self.dataset_name + '_' + self.operation_mode + now.strftime('_(%Y-%m-%d)_(%H-%M-%S)')

        self.results_dir = ROOT_DIR / self.results_dir

        return self
    

def check_files(file_set: set, master_set: set, set_type: str):
    """Check that every name in `file_set` is also present in `master_set`.

    Args:
        file_set: names requested by the user.
        master_set: names that actually exist in the dataset.
        set_type: prefix of the config key being checked (e.g. `train` for `train_set`),
            used only to build the error message.

    Raises:
        ValueError: at least one name in `file_set` is missing from `master_set`.
    """
    missing_fns = file_set - master_set
    if missing_fns:
        # the closing backtick after `_set` was missing in the original message
        raise ValueError(f"Could not find files with the names {missing_fns} in the dataset given in `{set_type}_set`")


def validate(input_configs: dict) -> dict:
    """Validate the raw configs and return the validated configs as a dictionary.

    The configs are first checked against `General`; the result is then validated by
    `Train` when `operation_mode` is `train` and by `Inference` otherwise.
    """
    general = General.model_validate(input_configs)

    # pick the validator that matches the operation mode
    mode_validator = Train if general.operation_mode == 'train' else Inference

    validated = mode_validator.model_validate(general.model_dump())
    return validated.model_dump()

## Blocks

In [None]:
"""
Aiden Ochoa, 4/2025, RDMAP PSU Research Group
This module handles the definition of the custom convolution layers used in UNet models
"""

def ConvBlock(inputs, filters, batchnorm, l2_reg, index):
    """Apply two 3x3 convolution units in sequence, each followed by optional batchnorm and ReLU.

    Layer names are built from `index` plus an `a` (first unit) or `b` (second unit) suffix.
    The convolution bias is disabled when batch normalization is used.
    """
    x = inputs
    for suffix in ('a', 'b'):
        layer_index = index + suffix
        conv = Conv2D(
            filters,
            3,
            padding='same',
            name='conv_'+layer_index,
            use_bias=not batchnorm,
            kernel_initializer='he_normal',
            kernel_regularizer=l2(l2_reg),
        )
        x = conv(x)
        if batchnorm:
            x = BatchNormalization(name='bn_'+layer_index)(x)
        x = Activation('relu', name='relu_'+layer_index)(x)

    return x


def UpsampleBlock(inputs, filters, batchnorm, l2_reg, index):
    """Upsample with a stride-2 transposed convolution, then optional batchnorm and ReLU.

    Layer names are built from `index`. The transposed convolution bias is disabled when
    batch normalization is used.
    """
    upsample = Conv2DTranspose(
        filters,
        2,
        padding='same',
        name='up_'+index,
        use_bias=not batchnorm,
        kernel_initializer='he_normal',
        kernel_regularizer=l2(l2_reg),
        strides=2,
    )
    out = upsample(inputs)

    if batchnorm:
        out = BatchNormalization(name='bn_'+index)(out)

    relu = Activation('relu', name='relu_'+index)
    return relu(out)

## Dataloader

In [None]:
"""
Aiden Ochoa, 4/2025, RDMAP PSU Research Group
This module handles dataset processing for training and inference
"""

def create_train_dataset(configs: dict) -> dict:
    """Creates the training dataset based on the configs.

    Rebuilds the working tree `<root_dir>/dataset` from `<root_dir>/data/<dataset_name>`,
    optionally augments the training images, and wraps both subsets in repeating
    `tf.data` pipelines.

    Args:
        configs: validated training configs. `training_set` and `validation_set` are
            overwritten in place with the split that is actually used.

    Returns:
        A dict with the keys `train_dataset`, `val_dataset`, `train_steps` and `val_steps`.
    """

    data_dir: Path = configs['root_dir'] / 'data' / configs['dataset_name']
    dataset_dir: Path = configs['root_dir'] / 'dataset'

    # Step 1: remove existing directory if it exists
    if dataset_dir.exists():
        shutil.rmtree(dataset_dir)

    # Step 2: generate train/val filename lists
    configs['training_set'], configs['validation_set'] = split_data(configs)

    # Step 3: convert filename stems to full paths
    img_ext = next(iter({file.suffix for file in (data_dir / 'images').iterdir()}))
    ann_ext = next(iter({file.suffix for file in (data_dir / 'annotations').iterdir()}))

    train_img_paths = [data_dir / 'images' / (file + img_ext) for file in configs['training_set']]
    val_img_paths = [data_dir / 'images' / (file + img_ext) for file in configs['validation_set']]
    train_ann_paths = [data_dir / 'annotations' / (file + ann_ext) for file in configs['training_set']]
    val_ann_paths = [data_dir / 'annotations' / (file + ann_ext) for file in configs['validation_set']]

    # Step 4: populate dataset tree
    train_img_dir = dataset_dir / 'images' / 'train'
    val_img_dir = dataset_dir / 'images' / 'val'
    copy_files(train_img_paths, train_img_dir)
    copy_files(val_img_paths, val_img_dir)
    copy_files(train_ann_paths, dataset_dir / 'annotations' / 'train')
    copy_files(val_ann_paths, dataset_dir / 'annotations' / 'val')

    # Step 5: augment training set
    if configs['augment']:
        augment_dataset(train_img_dir, img_ext, ann_ext)

    # Step 6: create Dataset Tensors
    train_dataset = tf.data.Dataset.list_files(str(train_img_dir / '*'))
    val_dataset = tf.data.Dataset.list_files(str(val_img_dir / '*'))

    # replace each path string with the output of parse_image for that path
    train_dataset = train_dataset.map(lambda x: parse_image(x, img_ext, ann_ext), num_parallel_calls=AUTOTUNE)
    val_dataset = val_dataset.map(lambda x: parse_image(x, img_ext, ann_ext), num_parallel_calls=AUTOTUNE)

    BUFFER_SIZE = 48

    # shuffle the training dataset and batch it
    train_dataset = train_dataset.shuffle(buffer_size=BUFFER_SIZE)
    train_dataset = train_dataset.repeat()
    train_dataset = train_dataset.batch(configs['batch_size'])
    train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)

    # batch the validation dataset
    val_dataset = val_dataset.repeat()
    val_dataset = val_dataset.batch(1)
    val_dataset = val_dataset.prefetch(buffer_size=AUTOTUNE)

    # determine number of steps to take in an epoch
    num_train_files = len(list(train_img_dir.iterdir()))
    num_val_files = len(list(val_img_dir.iterdir()))
    # floor division gives 0 when there are fewer training images than `batch_size`;
    # the dataset repeats, so one step per epoch is still well defined in that case
    train_steps = max(1, num_train_files // configs['batch_size'])
    val_steps = num_val_files

    return {'train_dataset' : train_dataset, 'val_dataset' : val_dataset, 'train_steps' : train_steps, 'val_steps' : val_steps}


def create_train_val_inference_dataset(configs: dict) -> dict:
    """Build image-only datasets over the training and validation sets, in list order.

    Reads the filename stems in `configs['training_set']` and `configs['validation_set']`
    and loads the matching files from `<root_dir>/data/<dataset_name>/images` through
    `parse_image` (called without an annotation extension) in batches of one.

    Returns:
        A dict with the keys `train_dataset`, `train_paths`, `val_dataset` and `val_paths`.
    """
    images_dir: Path = configs['root_dir'] / 'data' / configs['dataset_name'] / 'images'
    suffixes = {entry.suffix for entry in images_dir.iterdir()}
    img_ext = next(iter(suffixes))

    # Step 1: turn the filename stems into full image paths
    train_paths = [images_dir / (stem + img_ext) for stem in configs['training_set']]
    val_paths = [images_dir / (stem + img_ext) for stem in configs['validation_set']]

    # Step 2: wrap the path strings in tensorflow datasets
    train_dataset = tf.data.Dataset.from_tensor_slices(list(map(str, train_paths)))
    val_dataset = tf.data.Dataset.from_tensor_slices(list(map(str, val_paths)))

    # Step 3: load the images and batch them one at a time
    load_image = lambda x: parse_image(x, img_ext, None)
    train_dataset = train_dataset.map(load_image, num_parallel_calls=AUTOTUNE)
    val_dataset = val_dataset.map(load_image, num_parallel_calls=AUTOTUNE)

    train_dataset = train_dataset.batch(1)
    val_dataset = val_dataset.batch(1)

    return {
        'train_dataset' : train_dataset,
        'train_paths' : train_paths,
        'val_dataset' : val_dataset,
        'val_paths' : val_paths,
    }


def create_inference_dataset(configs: dict) -> dict:
    """Build a batch-size-1 dataset over every file in `<root_dir>/data/<dataset_name>`.

    Each file is loaded through `parse_image` (called without an annotation extension).

    Returns:
        A dict with the keys `dataset` (the tensorflow dataset) and `data_paths` (the
        file paths, in the same order as the dataset elements).

    Raises:
        ValueError: the dataset directory contains no files.
    """
    # define path to images
    data_dir: Path = configs['root_dir'] / 'data' / configs['dataset_name']
    data_paths = list(data_dir.iterdir())
    if not data_paths:
        raise ValueError(f"No images were found in `{data_dir}`")

    # pass the file extension to parse_image, as the other dataset builders do
    # (the original passed a whole path here)
    img_ext = next(iter({path.suffix for path in data_paths}))

    # create tensorflow dataset
    dataset = tf.data.Dataset.from_tensor_slices([str(path) for path in data_paths])
    dataset = dataset.map(lambda x: parse_image(x, img_ext, None), num_parallel_calls=AUTOTUNE)
    dataset = dataset.batch(1)

    return {'dataset' : dataset, 'data_paths' : data_paths}


"""Miscellaneous functions"""


def split_data(configs: dict) -> Tuple[List[str], List[str]]:
    """Resolve the training and validation filename lists from the configs.

    Reads ``root_dir``, ``dataset_name``, ``training_set``, ``validation_set`` and
    ``auto_split`` from ``configs``. Filenames are extension-less stems of the files in
    ``<root_dir>/data/<dataset_name>/images``.

    The result depends on which of the two sets the user supplied:

    1. only training files: split them by ``auto_split`` when it is truthy, otherwise
       validate on every remaining file of the dataset
    2. only validation files: train on every remaining file of the dataset
    3. both: returned unchanged
    4. neither: shuffle all files of the dataset and split them by ``auto_split``

    ``auto_split`` is the fraction of files held out for validation. The lists stored in
    ``configs`` are never modified.

    Returns:
        ``(train_fns, val_fns)``
    """
    # get current state
    data_dir = configs['root_dir'] / 'data' / configs['dataset_name'] / 'images'
    all_fns = [file.stem for file in data_dir.iterdir()]
    train_fns, val_fns = configs['training_set'], configs['validation_set']

    # Case 1: only training files provided
    if train_fns is not None and val_fns is None:
        # split it with the given percentage
        if configs['auto_split']:
            # shuffle a copy so the caller's list keeps its order and contents
            shuffled = list(train_fns)
            random.shuffle(shuffled)
            train_upper = int(len(shuffled) * (1 - configs['auto_split']))
            return shuffled[:train_upper], shuffled[train_upper:]
        # or use remaining files
        val_fns = os_sorted(list(set(all_fns) - set(train_fns)))
        return train_fns, val_fns

    # Case 2: only validation files provided
    if train_fns is None and val_fns is not None:
        # use remaining files
        train_fns = os_sorted(list(set(all_fns) - set(val_fns)))
        return train_fns, val_fns

    # Case 3: both training and validation files provided
    if train_fns is not None and val_fns is not None:
        # do nothing
        return train_fns, val_fns

    # Case 4: neither training nor validation files provided
    # all_fns is local to this function, so shuffling it in place is safe
    random.shuffle(all_fns)
    train_upper = int(len(all_fns) * (1 - configs['auto_split']))
    return all_fns[:train_upper], all_fns[train_upper:]


def copy_files(file_paths: list[Path], dest_dir: Path):
    """Copy every file in ``file_paths`` into ``dest_dir``, keeping the file names.

    ``dest_dir`` (including missing parents) is created when it does not exist yet.
    """
    # exist_ok avoids the gap between a separate exists() check and the mkdir() call
    dest_dir.mkdir(parents=True, exist_ok=True)

    for path in file_paths:
        shutil.copy(path, dest_dir / path.name)


def augment_dataset(train_image_dir: Path, img_ext: str, ann_ext: str):
    """Augment every training image together with its annotation.

    The annotation path is derived from the image path by replacing ``images`` with
    ``annotations`` and the trailing ``img_ext`` with ``ann_ext``.

    Args:
        train_image_dir: directory holding the training images.
        img_ext: image file extension, e.g. ``'.png'``.
        ann_ext: annotation file extension.
    """
    # match the extension literally and only at the very end of the path; used as a raw
    # regex, the leading '.' of e.g. '.png' would match any character anywhere in the path
    ext_pattern = re.compile(re.escape(img_ext) + r'\Z')

    # glob() returns the complete list up front, so files written while augmenting are not revisited
    for img in glob(str(train_image_dir / '*')):
        # replace images in path with annotations and image extension to annotation extension
        ann = re.sub('images', 'annotations', img)
        ann = ext_pattern.sub(lambda _match: ann_ext, ann)

        # augment image and annotation
        augment_single_image(Path(img))
        augment_single_image(Path(ann))


def augment_single_image(path : Path):
    """Expand one image file into eight files: the original plus seven rotated/flipped copies.

    The original file is renamed to ``<stem>_1<suffix>``; the transformed copies are written
    next to it as ``<stem>_2<suffix>`` ... ``<stem>_8<suffix>``.
    """
    source = cv.imread(str(path))

    # every transformation is computed before anything on disk is touched
    rot90 = cv.rotate(source, cv.ROTATE_90_CLOCKWISE)
    rot180 = cv.rotate(source, cv.ROTATE_180)
    variants = {
        2: rot90,                                               # rot90
        3: rot180,                                              # rot180
        4: cv.rotate(source, cv.ROTATE_90_COUNTERCLOCKWISE),    # rot270
        5: cv.flip(source, 1),                                  # xflip
        6: cv.flip(rot90, 1),                                   # rot90 + xflip
        7: cv.flip(rot90, 0),                                   # rot90 + yflip
        8: cv.flip(rot180, 1),                                  # rot180 + xflip
    }

    def _numbered(tag):
        """Path of the sibling file carrying the given augmentation number"""
        return path.parent / f'{path.stem}_{tag}{path.suffix}'

    # the untouched original only changes its name
    path.rename(_numbered(1))

    # write the transformed copies
    for tag, pixels in variants.items():
        cv.imwrite(str(_numbered(tag)), pixels)


def parse_image(img_path: tf.Tensor, img_ext: str, ann_ext: Optional[str]) -> tuple[tf.Tensor, tf.Tensor]:
    """Load an image and, when requested, its annotation.

    The image is decoded as a 3-channel PNG and scaled to ``[0, 1]`` as float32. The
    annotation path is derived from the image path by replacing ``images`` with
    ``annotations`` and the trailing ``img_ext`` with ``ann_ext``; the annotation is
    decoded as a 1-channel PNG and scaled the same way.

    Args:
        img_path: path of the image file.
        img_ext: image file extension, e.g. ``'.png'`` (ignored when ``ann_ext`` is None).
        ann_ext: annotation file extension, or ``None`` to load the image only.

    Returns:
        ``(image, annotation)``, or just ``image`` when ``ann_ext`` is ``None``.
    """
    # read image and load it into 3 channels (pre-trained backbones require 3) and normalize it
    image = tf.io.read_file(img_path)
    image = tf.image.decode_png(image, channels=3)
    image = tf.cast(image, tf.float32) / 255.0

    # if ann_ext is None, we just want the image; checking this explicitly keeps real
    # loading errors from being swallowed by a catch-all handler
    if ann_ext is None:
        return image

    # adjust path to lead to the corresponding annotation and load it; the extension is
    # matched literally and only at the end of the path
    ann_path = tf.strings.regex_replace(img_path, 'images', 'annotations')
    ann_path = tf.strings.regex_replace(ann_path, re.escape(img_ext) + '$', ann_ext)
    annotation = tf.io.read_file(ann_path)
    annotation = tf.image.decode_png(annotation, channels=1)
    annotation = tf.cast(annotation, tf.float32) / 255.0
    return image, annotation

## Models

In [None]:
"""
Aiden Ochoa, 4/2025, RDMAP PSU Research Group
This module handles the definition of encoder/decoder subnetworks and their connection
"""

def load_UNet(configs : dict) -> keras.Model:
    """U-Net built with Functional API using either U-Net or EfficientNetB7 encoders and either U-Net or U-Net++ decoders

    Config keys read: `encoder_filters`, `decoder_filters`, `batchnorm`, `L2_regularization_strength`,
    `input_shape`, `encoder_name`, `decoder_name`, and for the EfficientNetB7 encoder also
    `backbone_weights` and `backbone_finetuning`.

    Returns a keras.Model with a single-channel sigmoid output named `main_output`. The model name is
    the encoder name followed by `-UNet` or `-UNetpp`.
    """

    enc_filters = configs['encoder_filters']
    dec_filters = configs['decoder_filters']
    batchnorm = configs['batchnorm']
    l2_reg = configs['L2_regularization_strength']

    input = keras.Input(shape = configs['input_shape'], name = 'main_input')

    # encoder
    if configs['encoder_name'] == 'UNet':
        model_name = 'UNet'
        x = input
        enc_outputs = []
        # repeatedly adds a convolution layer and saves the current outputs
        for idx, filters in enumerate(enc_filters):
            name_idx = f'{idx}0'
            # conv
            x = ConvBlock(x, filters, batchnorm, l2_reg, name_idx)
            enc_outputs.append(x)
            # pool except for the last block
            # NOTE(review): the hard-coded 4 assumes five encoder blocks (indices 0-4)
            if idx < 4:
                x = MaxPooling2D(pool_size=2, name = 'pool_'+name_idx)(x)
        
    elif configs['encoder_name'] == 'EfficientNetB7':
        model_name = 'EfficientNetB7'
        # any value other than 'random' loads the ImageNet weights
        if configs['backbone_weights'] == 'random':
            weights = None
        else:
            weights = 'imagenet'
        backbone = EfficientNetB7(include_top = False, weights = weights, input_tensor = input)

        # handles model freezing (batchnorm layers must always stay frozen)
        # freeze backbone
        if configs['backbone_finetuning'] == False:
            for layer in backbone.layers:
                layer.trainable = False
        else:
            # when a list of block indices is given, freeze every layer whose name does not start
            # with one of the listed 'block<N>' prefixes (otherwise the whole model is left unfrozen)
            if type(configs['backbone_finetuning']) == list:
                block_strs = ['block' + str(block_idx) for block_idx in configs['backbone_finetuning']]
                for layer in backbone.layers:
                    if layer.name[:6] not in block_strs:
                        layer.trainable = False
            # freeze batchnorm layers
            for layer in backbone.layers:
                if isinstance(layer, tf.keras.layers.BatchNormalization):
                    layer.trainable = False

        # backbone layers whose outputs serve as the encoder outputs, first stage first
        enc_stages = ['stem_activation', 'block2g_add', 'block3g_add', 'block5j_add', 'block7d_add']
        enc_outputs = [backbone.get_layer(stage).output for stage in enc_stages]
    
    # decoder
    # reversed so that index 0 is the last encoder output
    enc_outputs.reverse()
    if configs['decoder_name'] == 'UNet':
        model_name += '-UNet'
        x = enc_outputs[0]
        # each step upsamples, concatenates the next encoder output, and convolves
        for idx, filters in enumerate(dec_filters[:-1]):
            name_idx = f'{3-idx}{idx+1}' # 31, 22, 13, 04
            x = UpsampleBlock(x, dec_filters[idx], batchnorm, l2_reg, name_idx)
            x = Concatenate(name='cat_'+name_idx)([x, enc_outputs[idx+1]])
            x = ConvBlock(x, dec_filters[idx], batchnorm, l2_reg, name_idx)

    elif configs['decoder_name'] == 'UNet++':
        model_name += '-UNetpp'
        # outputs of the previously built row, starting with the last encoder output
        prev_row_outs = [enc_outputs[0]]
        for row in range(4): # number of rows
            # each row starts from its encoder output and grows by one node per iteration
            current_row_outs = [enc_outputs[row+1]]
            for node in range(row+1): # 1, 2, 3, 4 nodes per row
                name_idx = f'{3-row}{node+1}' # 31, 21, 22, 11, 12, 13, 01, 02, 03, 04
                x = UpsampleBlock(prev_row_outs[node], dec_filters[row], batchnorm, l2_reg, name_idx)
                # concatenate with every earlier output of the current row
                x = Concatenate(name='cat_'+name_idx)([x] + current_row_outs[:(node+1)])
                x = ConvBlock(x, dec_filters[row], batchnorm, l2_reg, name_idx)
                current_row_outs.append(x)
            prev_row_outs = current_row_outs

    # final layers
    if configs['encoder_name'] == 'EfficientNetB7':
        x = UpsampleBlock(x, dec_filters[4], batchnorm, l2_reg, 'final') # upsamples back to original resolution
    # 1x1 convolution producing the single-channel sigmoid output
    sigmoid = Conv2D(1, 1, activation='sigmoid', name='main_output', kernel_initializer='he_normal', padding='same', kernel_regularizer=l2(l2_reg))(x)

    return keras.Model(inputs=input, outputs=sigmoid, name = model_name)

## Utils

In [None]:
"""
Aiden Ochoa, 4/2025, RDMAP PSU Research Group
This module handles all accessory operations such as plotting
"""

def save_preds(preds : list[np.array], save_paths : list[Path]):
    """Write each prediction from model.predict() to the path at the same position in `save_paths`"""

    for position, prediction in enumerate(preds):
        destination = str(save_paths[position])
        save_img(destination, prediction)


def print_save_configs(configs : dict):
    """Creates the results directory, prints the input configs after validation, and saves a copy to results

    Reads `results_dir` and `root_dir` (both Path objects) from `configs`. The configs are
    written to `<results_dir>/configs.json` with these two paths converted to strings.
    The dict passed in is left unmodified.
    """

    # create top-level results directory
    results_dir = configs['results_dir']
    results_dir.mkdir()

    # print input to user for confirmation
    print('-'*60 + ' User Input ' + '-'*60)
    for key, val in configs.items():
        print(key + ':', val)
    print('-'*132)

    # make Path objects strings for serialization, working on a copy so the caller's
    # configs keep their Path objects
    serializable = dict(configs)
    serializable['root_dir'] = str(configs['root_dir'])
    serializable['results_dir'] = str(results_dir)

    # save configs into results dir for reference
    with open(results_dir / 'configs.json', 'w') as con:
        json.dump(serializable, con)


def plot_results(configs : dict):
    """Plots the metrics logged during a single training loop and saves the figure as `metrics.png` in the results directory"""

    # paths
    results_dir = configs['results_dir']
    plot_save_path = results_dir / 'metrics.png'

    # read metrics into dataframe
    metrics = pd.read_csv(results_dir / 'metrics.csv')

    # count the logged epochs, then shift the epoch column by one for plotting
    num_epochs = metrics['epoch'].count()
    metrics['epoch'] = metrics['epoch'] + 1

    # row position(s) where the val loss is numerically equal to its minimum
    lowest_val_loss = min(metrics['val_loss'])
    best_idx = np.where(np.isclose(metrics['val_loss'], lowest_val_loss))[0]

    # add f1 columns to the dataframe
    metrics['f1'] = add_f1(metrics)
    metrics['val_f1'] = add_f1(metrics, val = True)

    # one subplot per metric: (axis label, train column name)
    fig, axs = plt.subplots(4, 1, figsize=(12,20))
    panels = zip(axs, ['BCE Loss', 'Precision', 'Recall', 'F1-Score'], ['loss', 'Precision', 'Recall', 'f1'])

    for i, (ax, title, train_col) in enumerate(panels):
        val_col = 'val_' + train_col

        # metric curves
        ax.plot(metrics['epoch'], metrics[train_col], '-o',  label='Train')
        ax.plot(metrics['epoch'], metrics[val_col], '-o', label='Val')

        # mark the epoch of lowest val loss on both curves
        ax.plot(best_idx + 1, metrics[train_col].iloc[best_idx], 'D', color='purple')
        ax.plot(best_idx + 1, metrics[val_col].iloc[best_idx], 'D', color='purple', label='Min Val Loss')

        # axis formatting: the loss panel is logarithmic, the others get ticks every 0.1
        ax.set_xlabel('Epoch')
        ax.set_ylabel(title)
        ax.set_xlim([1,num_epochs])
        if i == 0:
            ax.set_yscale('log')
        else:
            ax.set_yticks(ticks=np.arange(0,1.1,0.1))
        ax.grid(visible=True)
        ax.legend()

    fig.savefig(str(plot_save_path), bbox_inches="tight")


def add_f1(metrics : pd.DataFrame, val = False) -> pd.DataFrame:
    """Computes the element-wise f1-score from the precision and recall columns of a metrics dataframe

    With `val` set to True the `val_Precision`/`val_Recall` columns are used instead of
    `Precision`/`Recall`. The scores come back as the array produced by `np.where`.
    """

    prefix = 'val_' if val == True else ''
    precision = metrics[prefix + 'Precision']
    recall = metrics[prefix + 'Recall']
    total = precision + recall

    # harmonic mean of precision and recall, defined as 0 where both are 0
    return np.where(total == 0, 0, 2  * (precision * recall) / total)


def cv_plot_results(configs : dict):
    """Collects the metrics of every cross validation fold and saves two figures to the results directory

    `loss.png` overlays the train and val loss curves of all folds. `metrics.png` shows, per epoch,
    the mean over the folds of loss, precision, recall, and f1 with the standard deviation as error bars.
    Metrics are addressed by column position in each fold's `metrics.csv` after the two f1 columns
    have been appended.
    """

    # paths
    results_dir = configs['results_dir']
    loss_save_path = results_dir / 'loss.png'
    metrics_save_path = results_dir / 'metrics.png'

    # fold directories in natural order
    fold_dirs = os_sorted(glob(str(results_dir / 'fold_*')))

    # one array (epochs x columns) per fold
    fold_arrays = []
    for fold_dir in fold_dirs:
        frame = pd.read_csv(str(Path(fold_dir) / 'metrics.csv'))

        # append the f1 columns before converting
        frame['f1'] = add_f1(frame)
        frame['val_f1'] = add_f1(frame, val = True)

        fold_arrays.append(frame.to_numpy())

    # combined array indexed as [fold, epoch, column]
    all_metrics = np.stack(fold_arrays, axis=0)

    # epoch numbers are taken from the first fold and shifted by one
    epochs = all_metrics[0, :, 0].astype(int) + 1

    # figure 1: loss curves of every fold, train on top and val below
    fig, (train_ax, val_ax) = plt.subplots(2, 1, figsize=(12, 10))

    for fold in range(configs['num_folds']):
        train_loss = all_metrics[fold, :, 4]
        val_loss = all_metrics[fold, :, 8]

        fold_label = f'Fold {fold+1}'
        train_ax.plot(epochs, train_loss, label = fold_label)
        val_ax.plot(epochs, val_loss, label = fold_label)

    train_ax.set_ylabel('Train Loss (BCE)')
    val_ax.set_ylabel('Val Loss (BCE)')
    for ax in (train_ax, val_ax):
        ax.set_yscale('log')
        ax.set_xlim([1, configs['num_epochs']])
        ax.set_xlabel('Epoch')
        ax.grid(visible=True)
        ax.legend()

    fig.savefig(str(loss_save_path), bbox_inches="tight")

    # mean and stdev over the folds for every (epoch, column) pair
    num_metrics = np.shape(all_metrics)[-1]
    stats_shape = (configs['num_epochs'], num_metrics)
    metrics_mean = np.zeros(stats_shape)
    metrics_std = np.zeros(stats_shape)

    for epoch in range(configs['num_epochs']):
        for metric in range(num_metrics):
            across_folds = all_metrics[:, epoch, metric]
            metrics_mean[epoch, metric] = np.mean(across_folds)
            metrics_std[epoch, metric] = np.std(across_folds)

    # figure 2: averaged metrics with std as error bars
    fig, axs = plt.subplots(4, 1, figsize=(12, 20))

    # (axis label, train column position, val column position) for each subplot
    panels = [
        ('BCE Loss', 4, 8),
        ('Precision', 1, 5),
        ('Recall', 2, 6),
        ('F1-Score', 9, 10),
    ]

    for i, (ax, (title, train_col, val_col)) in enumerate(zip(axs, panels)):
        train_mean = metrics_mean[:, train_col]
        val_mean = metrics_mean[:, val_col]
        train_std = metrics_std[:, train_col]
        val_std = metrics_std[:, val_col]

        ax.errorbar(epochs, train_mean, yerr=train_std, fmt='-o', capsize=3, capthick=1, label='Train')
        ax.errorbar(epochs, val_mean, yerr=val_std, fmt='-o', capsize=3, capthick=1, label='Val')

        # the loss panel is logarithmic, the others get ticks every 0.1
        ax.set_xlabel('Epoch')
        ax.set_ylabel(title)
        ax.set_xlim([1, configs['num_epochs']])
        if i == 0:
            ax.set_yscale('log')
        else:
            ax.set_yticks(ticks=np.arange(0,1.1,0.1))
        ax.grid(visible=True)
        ax.legend()

    fig.savefig(str(metrics_save_path), bbox_inches="tight")


def create_folds(img_list : list, num_folds : int) -> tuple[list, list]:
    """Given images and the number of folds, create training/validation sets with the most even distribution possible

    The images are shuffled with a fixed seed, then cut into `num_folds` consecutive validation
    sets whose sizes differ by at most one (the first folds receive the remainder). The training
    set of a fold is every image not in its validation set.

    Returns:
        `(train_sets, val_sets)`: two lists of length `num_folds`, where element `i` of each
        is the list of images for fold `i`.
    """

    # randomly shuffle the image list given a numpy seed to prevent sequence bias
    # NOTE: this reseeds NumPy's global random state
    np.random.seed(203)
    shuffled = np.random.permutation(img_list)

    # number of hold out images per fold: integer quotient, plus one for the first
    # (len % num_folds) folds so the remainder is spread evenly
    total = len(shuffled)
    fold_sizes = [total // num_folds + (1 if i < total % num_folds else 0) for i in range(num_folds)]

    # convert the fold sizes to slice boundaries: fold i holds out shuffled[bounds[i]:bounds[i+1]]
    bounds = np.cumsum([0] + fold_sizes)

    train_sets, val_sets = [], []
    for low, up in zip(bounds[:-1], bounds[1:]):
        # fold val set, as a plain list like the train set
        val_sets.append(shuffled[low:up].tolist())

        # fold train set is the complement
        train_sets.append(np.delete(shuffled, np.arange(low, up)).tolist())

    return train_sets, val_sets

## Run

In [None]:
"""
Aiden Ochoa, 4/2025, RDMAP PSU Research Group
This module handles all training and inference operations. It has __main__
"""

class Operations:
    """Runs training (single loop or cross validation) and inference for one set of configs

    NOTE(review): previously described as a singleton, but nothing in the class enforces a
    single instance.
    """

    def __init__(self, configs: dict):
        """Initialize configs, dataset, and model

        `dataset` and `model` start out as None and are filled in by the training/inference methods.
        """
        self.configs = configs
        self.dataset = None
        self.model = None

    def single_loop(self):
        """Trains a single model using a training and validation set

        Steps: build the training dataset, compile and fit the model while logging metrics and
        checkpointing the best model into `results_dir`, reload the best model to inference the
        training and validation images, plot the metrics, and finally delete the temporary
        dataset directory and free memory.
        """

        # load dataset and model
        print(f"\nCreating dataset from `{self.configs['dataset_name']}`...\n")
        self.dataset = create_train_dataset(self.configs)
        print(f"Training images: {os_sorted(self.configs['training_set'])}")
        print(f"Validation images: {os_sorted(self.configs['validation_set'])}\n")
        print(f"Loading and compiling `{self.configs['encoder_name']}-{self.configs['decoder_name']}`...\n")
        self.model = load_UNet(self.configs)
        self.model.compile(
            optimizer = Adam(learning_rate=self.configs['learning_rate']), 
            loss = 'binary_crossentropy', 
            metrics = ['accuracy', 'Precision', 'Recall']
        )

        # define training callbacks
        # metrics go to metrics.csv (overwritten, not appended) and only the best model is kept
        callbacks = [
            CSVLogger(
                str(self.configs['results_dir'] / 'metrics.csv'), 
                separator=',', 
                append=False
            ),
            ModelCheckpoint(
                str(self.configs['results_dir'] / 'best_model.keras'), 
                verbose=1, 
                save_best_only=True, 
                save_weights_only=False
            )
        ]
        # early stopping is optional and uses the configured patience
        if self.configs['early_stopping']:
            callbacks.append(EarlyStopping(patience = self.configs['patience']))
        
        # start training
        print("Training model...\n")
        self.model.fit(
            self.dataset['train_dataset'],
            epochs = self.configs['num_epochs'],
            steps_per_epoch = self.dataset['train_steps'],
            validation_data = self.dataset['val_dataset'],
            validation_steps = self.dataset['val_steps'],
            callbacks=callbacks,
            verbose=2
        )

        # load best model and inference train/val sets
        # self.dataset is replaced by the image-only train/val datasets that inference() expects
        print("\nInferencing training and validation images with best model...\n")
        self.model = load_model(str(self.configs['results_dir'] / 'best_model.keras'))
        self.dataset = create_train_val_inference_dataset(self.configs)
        self.inference()

        # plot metrics
        print("\nPlotting metrics...\n")
        plot_results(self.configs)

        print("Cleaning up...\n")
        # remove training dataset and clear memory
        # (presumably `<root_dir>/dataset` is what create_train_dataset wrote - TODO confirm)
        shutil.rmtree(self.configs['root_dir'] / 'dataset')
        K.clear_session()
        gc.collect()


    def crossval_loop(self):
        """Trains num_folds models using different non-overlapping validation sets

        Each fold runs `single_loop` with its own `training_set`/`validation_set` and writes into
        `<results_dir>/fold_<n>`. `results_dir` is restored after every fold, whereas
        `training_set` and `validation_set` keep the values of the last fold.
        """

        # determine all training and val set combinations given the number of folds
        train_sets, val_sets = create_folds(self.configs['training_set'], self.configs['num_folds'])
        
        # save original results directory
        top_level_results = self.configs['results_dir']
        
        print(f"\nStarting cross validation with {self.configs['num_folds']} folds...\n")

        # loop through folds
        for fold in range(self.configs['num_folds']):
            print('-'*62 + ' Fold {} '.format(fold+1) + '-'*62)

            # update train/val sets
            self.configs.update({'training_set' : train_sets[fold]})
            self.configs.update({'validation_set' : val_sets[fold]})

            # create results directory for this fold
            results_dir = self.configs['results_dir'] / ('fold_' + str(fold+1))
            self.configs.update({'results_dir' : results_dir})
            results_dir.mkdir()

            # train fold
            self.single_loop()

            # reset results directory for next fold
            self.configs.update({'results_dir' : top_level_results})

        # plot metrics over all folds
        print("Plotting cross validation results...\n")
        cv_plot_results(self.configs)


    def inference(self):
        """Either inferences a training-validation pair of images, or just a single set of images

        In `train` operation mode, `self.model` and `self.dataset` must already be set; predictions go
        to `<results_dir>/train_preds` and `<results_dir>/val_preds`. In any other mode the model is
        loaded from `<root_dir>/<model_path>`, the inference dataset is built here, and predictions
        go to `<results_dir>/preds`. Predictions are saved as `.png` files named after the input files.
        """

        if self.configs['operation_mode'] == 'train':
            # process train and val images
            train_preds = self.model.predict(self.dataset['train_dataset'], verbose=2)
            val_preds = self.model.predict(self.dataset['val_dataset'], verbose=2)

            # define and make output directories
            train_save_dir = self.configs['results_dir'] / 'train_preds'
            val_save_dir = self.configs['results_dir'] / 'val_preds'
            train_save_dir.mkdir()
            val_save_dir.mkdir()

            # define pred save file paths
            train_save_paths = [train_save_dir / (file.stem + '.png') for file in self.dataset['train_paths']]
            val_save_paths = [val_save_dir / (file.stem + '.png') for file in self.dataset['val_paths']]

            # save predictions
            save_preds(train_preds, train_save_paths)
            save_preds(val_preds, val_save_paths)

        else:
            # load model and dataset
            print("\nLoading data and model...\n")
            self.model = load_model(str(self.configs['root_dir'] / self.configs['model_path']))
            self.dataset = create_inference_dataset(self.configs)

            # process dataset
            print("Generating model predictions...\n")
            preds = self.model.predict(self.dataset['dataset'], verbose=2)

            # define output directories and paths
            print("\nSaving model predictions...\n")
            save_dir = self.configs['results_dir'] / 'preds'
            save_dir.mkdir()
            save_paths = [save_dir / (file.stem + '.png') for file in self.dataset['data_paths']]

            # save predictions
            save_preds(preds, save_paths)


    def save_model_summary(self):
        """Writes the model summary to a file after training and writes a file about the trainable layers

        Both files are written to `results_dir`: `model_summary.out` holds the output of
        `model.summary()`, and `trainable.out` is a table of layer names and their `trainable` flag.
        """

        with open(self.configs['results_dir'] / 'model_summary.out', 'w') as f:
            # route every summary line into the file instead of stdout
            self.model.summary(print_fn=lambda x: f.write(x + '\n'))

        with open(self.configs['results_dir'] / 'trainable.out', 'w') as f:
            f.write(f"{'Layer':<35} {'Trainable':<20}\n")
            f.write("=" * 50 + "\n")
            for layer in self.model.layers:
                f.write(f"{layer.name:<35} {str(layer.trainable):<20}\n")


def main(configs: dict):
    """Validates the configs, records them in the results directory, and runs the requested operation"""

    # validate and print configs
    configs = validate(configs)
    # hand over a copy so the helper cannot alter the configs used below
    print_save_configs(configs.copy())

    # start operations
    operations = Operations(configs)

    # anything other than training is an inference run
    if configs['operation_mode'] != 'train':
        operations.inference()
    else:
        # pick the training routine: cross validation or a single loop
        run_training = operations.crossval_loop if configs['cross_validation'] else operations.single_loop
        run_training()

        if configs['model_summary']:
            operations.save_model_summary()

    print("Done.")

# Training/Inference

In [None]:
# configs example
# Only the keys set here are listed; presumably validate() fills in the remaining
# options with defaults - see the configs page linked in the instructions.
configs = {
    "L2_regularization_strength" : 1e-4,  # strength of the L2 kernel regularizer used when building the model
    "num_epochs" : 200,                   # number of epochs passed to model.fit
    "patience" : 25,                      # EarlyStopping patience (used when early stopping is enabled)
    "dataset_name" : "gb_512",            # dataset folder name inside <root_dir>/data
    "batch_size" : 2,                     # training batch size
    "auto_split" : 0.3                    # fraction of files held out for validation when splitting automatically
}

In [None]:
# validate the configs above and start the training or inference run
main(configs)