In [1]:
import os

# Run from the project root so relative paths used below (config/config.yaml,
# artifacts/) resolve. The root is computed only once per kernel, so re-running
# this cell does not move one more directory up each time.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.abspath("../")
os.chdir(PROJECT_ROOT)

In [2]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)
class ModelPreparationConfig:
    """Immutable settings for the model-preparation stage.

    The path fields locate this stage's artifacts; the ``params_*`` fields
    carry model hyper-parameters. Field order defines the positional
    constructor signature, so it must not be changed.
    """

    # --- filesystem locations ---
    root_dir: Path                 # working directory for this stage's artifacts
    base_model_path: Path          # where the backbone model is saved
    updated_base_model_path: Path  # where the model with the new head is saved

    # --- model hyper-parameters ---
    params_image_size: list        # input image size as listed in the params file
    params_include_top: bool       # keep the full VGG16 model instead of only its feature layers
    params_classes: int            # number of output classes of the new head
    params_weights: str            # 'imagenet' selects pretrained weights; anything else means none
    params_learning_rate: float    # learning rate handed to the optimizer

In [3]:
from cvClassifier import logger
from cvClassifier.utils.common import get_size, read_yaml, create_directories 
from cvClassifier.constants import *

In [4]:
class ConfigurationManager:
    """Loads the project YAML files and builds typed config objects for pipeline stages.

    On construction it reads ``config_filepath`` and ``params_filepath`` with
    ``read_yaml`` and creates the ``artifacts_root`` directory named in the config.
    """

    def __init__(self, config_filepath = CONFIG_FILE_PATH, params_filepath = PARAMS_FILE_PATH):
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

        create_directories([self.config.artifacts_root])

    def get_model_preparation_config(self) -> ModelPreparationConfig:
        """Build the configuration for the model-preparation stage.

        Creates ``model_preparation.root_dir`` if needed, then combines the
        ``model_preparation`` section of the config file with the model
        hyper-parameters from the params file.

        Returns:
            ModelPreparationConfig whose path fields are ``pathlib.Path`` objects.
        """
        config = self.config.model_preparation

        create_directories([config.root_dir])

        # The YAML values arrive as plain strings; wrap them so they match the
        # ``Path`` annotations declared on ModelPreparationConfig.
        return ModelPreparationConfig(
            root_dir=Path(config.root_dir),
            base_model_path=Path(config.base_model_path),
            updated_base_model_path=Path(config.updated_base_model_path),
            params_image_size=self.params.IMAGE_SIZE,
            params_include_top=self.params.INCLUDE_TOP,
            params_classes=self.params.CLASSES,
            params_weights=self.params.WEIGHTS,
            params_learning_rate=self.params.LEARNING_RATE,
        )
    

In [5]:
import os
import urllib.request as requests
from zipfile import ZipFile
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import VGG16_Weights

In [6]:
class ModelPreparation:
    """Builds the VGG16-based model for this project and saves it to disk.

    Intended order of use: ``get_base_model()`` creates the backbone, stores it on
    ``self.model`` and saves it to ``config.base_model_path``. ``update_base_model()``
    then attaches a new classification head and saves the result to
    ``config.updated_base_model_path``.
    """

    def __init__(self, config: "ModelPreparationConfig"):
        """
        Args:
            config: paths and hyper-parameters for this stage (required).
        """
        self.config = config

    def get_base_model(self):
        """Create the torchvision VGG16 backbone, store it on ``self.model`` and save it.

        ImageNet-1K weights are used when ``params_weights == 'imagenet'``; otherwise
        the model is created without pretrained weights. When ``params_include_top``
        is false, only the layers of ``vgg16.features`` are kept, as an ``nn.Sequential``.
        """
        weights = VGG16_Weights.IMAGENET1K_V1 if self.config.params_weights == 'imagenet' else None
        self.model = models.vgg16(weights=weights)

        if not self.config.params_include_top:
            # Re-pack the feature layers into a plain Sequential (the rest of the VGG model is discarded).
            self.model = nn.Sequential(*list(self.model.features.children()))

        self.save_model(self.config.base_model_path, self.model)
        logger.info(f"Base model saved at {self.config.base_model_path}")

    @staticmethod
    def save_model(path: Path, model: nn.Module):
        """Serialize the whole module object (not only its state_dict) to ``path``."""
        torch.save(model, path)

    @staticmethod
    def prepare_full_model(model, classes, freeze_all, freeze_till, learning_rate):
        """Attach a linear classification head to ``model`` and build its optimizer and loss.

        Args:
            model: an ``nn.Sequential`` of backbone layers, or a full torchvision VGG
                model (in which case only ``model.features`` is reused).
            classes: number of outputs of the new ``nn.Linear`` head.
            freeze_all: if true, set ``requires_grad = False`` on every parameter of ``model``.
            freeze_till: used only when ``freeze_all`` is false. If it is a positive
                int, every top-level child of ``model`` except the last ``freeze_till``
                is frozen.
            learning_rate: learning rate for the SGD optimizer.

        Returns:
            ``(full_model, optimizer, criterion)``:
              - ``full_model`` is the new ``nn.Sequential``.
              - ``optimizer`` is SGD over the parameters of ``full_model`` that still
                require gradients, which always includes the new head.
              - ``criterion`` is an ``nn.CrossEntropyLoss``.

        Raises:
            ValueError: if ``freeze_till`` exceeds the number of top-level children of ``model``.
        """
        if freeze_all:
            for param in model.parameters():
                param.requires_grad = False
        elif freeze_till is not None and freeze_till > 0:
            layers = list(model.children())
            if freeze_till > len(layers):
                raise ValueError(
                    f"freeze_till={freeze_till} exceeds the number of layers ({len(layers)})"
                )
            for layer in layers[:-freeze_till]:
                for param in layer.parameters():
                    param.requires_grad = False

        # Flattened size of the backbone's last feature map: 512 channels x 7 x 7.
        # NOTE(review): 7x7 assumes VGG16 at 224x224 input; confirm against IMAGE_SIZE.
        num_features = 512 * 7 * 7

        if isinstance(model, nn.Sequential):
            base_layers = list(model.children())
        else:
            # Full VGG model: reuse only its convolutional feature layers.
            base_layers = list(model.features.children())

        full_model = nn.Sequential(
            *base_layers,
            nn.Flatten(),
            nn.Linear(num_features, classes),
            # No softmax layer: nn.CrossEntropyLoss expects raw logits.
        )

        criterion = nn.CrossEntropyLoss()
        # Optimize the assembled model, not the bare backbone, so the new head is
        # included. Frozen backbone parameters are left out.
        trainable_params = [p for p in full_model.parameters() if p.requires_grad]
        optimizer = torch.optim.SGD(trainable_params, lr=learning_rate)

        children = list(full_model.children())
        print(full_model)
        print(f"Full model has {len(children)} layers")
        print("Last few layers:")
        for index, layer in enumerate(children[-3:], start=len(children) - 3):
            print(f"  ({index}): {layer}")

        logger.info(f"Model prepared with {classes} classes, freeze_all={freeze_all}, freeze_till={freeze_till}, learning_rate={learning_rate}")

        return full_model, optimizer, criterion

    def update_base_model(self):
        """Add the classification head to ``self.model`` with the backbone fully frozen, then save it.

        Requires ``get_base_model()`` to have been called first, because that sets
        ``self.model``. The optimizer and loss built by ``prepare_full_model`` are
        not kept.
        """
        full_model, _optimizer, _criterion = self.prepare_full_model(
            model=self.model,
            classes=self.config.params_classes,
            freeze_all=True,
            freeze_till=None,
            learning_rate=self.config.params_learning_rate,
        )

        self.save_model(path=self.config.updated_base_model_path, model=full_model)

        logger.info(f'Updated base model saved at {self.config.updated_base_model_path}')

    

In [7]:
# Run the model-preparation stage end to end: load the config, build and save the
# backbone, then attach the classification head and save the updated model.
try:
    config = ConfigurationManager()
    base_model_config = config.get_model_preparation_config()
    prepare_base_model = ModelPreparation(config=base_model_config)
    prepare_base_model.get_base_model()
    prepare_base_model.update_base_model()
except Exception as e:
    # Record the failure in the project log, then re-raise with the original traceback.
    logger.exception(e)
    raise

[2025-07-05 17:11:56,993: INFO: common]: yaml file successfully load from config/config.yaml
[2025-07-05 17:11:56,994: INFO: common]: yaml file successfully load from params.yaml
[2025-07-05 17:11:56,994: INFO: common]: directory created at: artifacts
[2025-07-05 17:11:56,994: INFO: common]: directory created at: artifacts/model_preparation
[2025-07-05 17:11:57,654: INFO: 975316810]: Base model saved at artifacts/model_preparation/base_model.pth
Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilat