In [1]:
import os

os.chdir("../")

In [29]:
from pydantic import BaseModel , FilePath , FileUrl , DirectoryPath , AnyUrl
from pydantic.dataclasses import dataclass
from pathlib import Path


class PrepareCallbacksConfig (BaseModel):
    root_dir: DirectoryPath
    tensorboard_root_log_dir: DirectoryPath
    checkpoint_model_filepath: Path
    early_stopping_patience : int
    early_stopping__monitor : str

In [30]:
from BirdClassifier.constants import *
from BirdClassifier.utils import create_directories, read_yaml

In [31]:
class ConfigurationManager:
    def __init__(
        self, config_file_path=CONFIG_FILE_PATH, param_file_path=PARAMS_FILE_PATH
    ) -> None:
        self.config = read_yaml(config_file_path)
        self.params = read_yaml(param_file_path)
        create_directories([self.config.artifacts_root])

    def get_prepare_callbacks_config(self) -> PrepareCallbacksConfig:
        config = self.config.prepare_callbacks
        model_checkpoint_dir = os.path.dirname(config.checkpoint_model_filepath)
        create_directories([config.tensorboard_root_log_dir, model_checkpoint_dir])
        prepare_callback_config = PrepareCallbacksConfig(
            root_dir=Path(config.root_dir),
            tensorboard_root_log_dir=Path(config.tensorboard_root_log_dir),
            checkpoint_model_filepath=Path(config.checkpoint_model_filepath),
            early_stopping_patience= self.params.EARLY_STOPPING_PATIENCE,
            early_stopping__monitor = self.params.EARLY_STOPPING_MONITOR,
        )
        return prepare_callback_config

In [32]:
import tensorflow as tf 

In [37]:
import time


class PrepareCallbacks:
    def __init__(self, config: PrepareCallbacksConfig):
        self.config = config

    @property
    def _create_tb_callbacks(self):
        timestamp = time.strftime("%Y-%m-%d-%H-%M-%S")
        tb_running_log_dir = os.path.join(
            self.config.tensorboard_root_log_dir, f"tb_log_at_{timestamp}"
        )
        return tf.keras.callbacks.TensorBoard(log_dir=tb_running_log_dir)

    @property
    def _create_ckpt_callbacks(self):
        return tf.keras.callbacks.ModelCheckpoint(
            filepath=self.config.checkpoint_model_filepath , 
            save_best_only=True
        )
    @property
    def _create_early_stopping_callbacks(self):
        return tf.keras.callbacks.EarlyStopping(
            monitor=self.config.early_stopping__monitor,
            patience=self.config.early_stopping_patience,
            restore_best_weights=True
        )

    def get_tb_ckpt_callback(self):
        return [self._create_tb_callbacks, self._create_ckpt_callbacks , self._create_early_stopping_callbacks]

In [38]:
try:
    config = ConfigurationManager()
    prepare_callbacks_config = config.get_prepare_callbacks_config()
    prepare_callbacks = PrepareCallbacks(config=prepare_callbacks_config)
    prepare_callbacks.get_tb_ckpt_callback()
except Exception as e:
    raise e

2022-09-30 01:16:49.380 | INFO     | BirdClassifier.utils.common:read_yaml:30 - yaml file: configs/config.yaml loaded successfully
2022-09-30 01:16:49.382 | INFO     | BirdClassifier.utils.common:read_yaml:30 - yaml file: params.yaml loaded successfully
2022-09-30 01:16:49.383 | INFO     | BirdClassifier.utils.common:create_directories:49 - created directory at: artifacts
2022-09-30 01:16:49.384 | INFO     | BirdClassifier.utils.common:create_directories:49 - created directory at: artifacts/prepare_callbacks/tensorboard_log_dir
2022-09-30 01:16:49.385 | INFO     | BirdClassifier.utils.common:create_directories:49 - created directory at: artifacts/prepare_callbacks/checkpoint_dir
