In [2]:
import os 
import sys 
from dataclasses import dataclass
from pathlib import Path 
import urllib.request as request 
from zipfile import ZipFile
import tensorflow as tf
from poxVisionDetection import logging,CustomException

In [3]:
# os.chdir('d:\\codes\\DeepLearning_Proj\\poxVision_detection\\research')
currdir = os.getcwd()
print(currdir)

d:\codes\DeepLearning_Proj\poxVision_detection\research


In [4]:
os.chdir('../') # MOVING ONE LEVEL UP 
print(os.getcwd())

d:\codes\DeepLearning_Proj\poxVision_detection


In [17]:
@dataclass(frozen = True)
class PrepareBaseModelConfig:
    root_dir                   : Path 
    base_model_path            : Path
    updated_base_model_path    : Path
    params_image_size          : list
    params_learning_rate       : float 
    params_include_top         : bool
    params_weight              : str
    params_classes             : int

In [18]:
from poxVisionDetection.constants import *
from poxVisionDetection.utils.common import read_yaml,create_directory

In [19]:
class ConfigurationManager:
    def __init__(
            self,
            config_filepath = CONFIG_FILE_PATH,
            params_filepath = PARAMS_FILE_PATH):
        
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

        create_directory([self.config.artifacts_root])

    def get_prepare_base_model(self) -> PrepareBaseModelConfig:
        config = self.config.prepare_base_model

        create_directory([config.root_dir])

        prepare_base_model_config = PrepareBaseModelConfig(
            root_dir                 = Path(config.root_dir),
            base_model_path          = Path(config.base_model_path),
            updated_base_model_path  = Path(config.updated_base_model_path),
            params_image_size        = self.params.IMAGE_SIZE,
            params_learning_rate     = self.params.LEARNING_RATE,
            params_include_top       = self.params.INCLUDE_TOP,
            params_weight            = self.params.WEIGHTS,
            params_classes           = self.params.CLASSES
        )

        return prepare_base_model_config

In [26]:
class PrepareBaseModel:
    def __init__(self, config : PrepareBaseModelConfig):
        self.config = config

    def get_base_model(self):
        self.model              = tf.keras.applications.ResNet50(
            include_top         = self.config.params_include_top,
            weights             = self.config.params_weight,
            input_shape         = self.config.params_image_size,
        )

        # THE BASE MODEL WILL GET SAVED IN THE PATH PROVIDED
        self.save_model(path    = self.config.base_model_path,
                        model   = self.model)

    # THE WEIGHTS THAT ARE PRESENT IN THE ResNet50 MODEL ARE GOING TO BE USED AS SUCH ONLY THE INPUT AND OUTPUT LAYERS ARE GOING TO BE TRAINED
    @staticmethod
    def _prepare_full_model(model, classes, freeze_all, freeze_till, learning_rate):
        if freeze_all:
            for layer in model.layers:
                model.trainable = False
        elif(freeze_till is not None) and (freeze_till > 0):
            for layer in model.layers[:-freeze_till]:
                model.trainable = False

        flatten = model.output

        Globalavgpool2D   = tf.keras.layers.GlobalAveragePooling2D()(flatten)

        Dlayer1            = tf.keras.layers.Dense(
            units           = 64,
            activation      = 'relu'
        )(Globalavgpool2D)

        pred_layer            = tf.keras.layers.Dense(
            units           = classes,
            activation      = 'softmax'
        )(Dlayer1)

        full_model        = tf.keras.models.Model(
            inputs          = model.input,
            outputs         = pred_layer
        )

        print(full_model)
        print('-------------------------------------------')

        full_model.compile(
            optimizer           = tf.keras.optimizers.SGD(learning_rate = learning_rate),
            loss                = tf.keras.losses.CategoricalCrossentropy(),
            metrics             = ['accuracy']
        )

        full_model.summary()
        return full_model

    def updated_base_model(self):
        self.full_model        = self._prepare_full_model(
            model              = self.model,
            classes            = self.config.params_classes,
            freeze_all         = True,
            freeze_till        = None,
            learning_rate      = self.config.params_learning_rate
        )

        self.save_model(path   = self.config.updated_base_model_path,
                        model  = self.full_model)

    @staticmethod
    def save_model(path : Path, model : tf.keras.Model):
        print(model.summary)
        model.save(path)

CREATING THE PIPELINE

In [27]:
try:
    config                       = ConfigurationManager()
    prepare_base_model_config    = config.get_prepare_base_model()
    prepare_base_model           = PrepareBaseModel(config = prepare_base_model_config)
    prepare_base_model.get_base_model()
    prepare_base_model.updated_base_model()
except Exception as e:
    logging.exception(CustomException(e,sys))

[2023-11-13 01:01:11,047] 31 root - INFO - YAML FILE {path_to_yaml} LOADED SUCCESSFULLY 
[2023-11-13 01:01:11,050] 31 root - INFO - YAML FILE {path_to_yaml} LOADED SUCCESSFULLY 
[2023-11-13 01:01:11,051] 47 root - INFO - CREATED DIRECTORY AT : artifacts
[2023-11-13 01:01:11,052] 47 root - INFO - CREATED DIRECTORY AT : artifacts/prepare_base_model
<bound method Model.summary of <keras.src.engine.functional.Functional object at 0x000002338219A310>>
<keras.src.engine.functional.Functional object at 0x0000023381FD27D0>
-------------------------------------------
Model: "model_5"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_8 (InputLayer)        [(None, 224, 224, 3)]        0         []                            
                                                                                                  
 conv1_pad (ZeroPaddin