In [5]:
import os
os.chdir("../")

In [3]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)
class ModelPreparationConfig:
    root_dir: Path
    model_name: str
    base_model_path: Path
    updated_base_model_path: Path
    params_image_size: list
    params_include_top: bool
    params_classes: int
    params_weights: str
    params_learning_rate: float

In [16]:
from cvClassifier import logger
from cvClassifier.utils.common import get_size, read_yaml, create_directories 
from cvClassifier.constants import *
from torchvision.models import VGG16_Weights
from torchvision.models import resnet50, ResNet50_Weights

In [24]:
class ConfigurationManager:
    # this class manages the configuration of the model preparation pipeline

    def __init__(self, config_filepath = CONFIG_FILE_PATH, params_filepath = PARAMS_FILE_PATH):
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

        create_directories([self.config.artifacts_root])

    def get_model_preparation_config(self) -> ModelPreparationConfig:
        ''' Gets the config details for the model preparation ingestion pipeline '''
        config = self.config.model_preparation

        create_directories([config.root_dir])

        model_preparation_config = ModelPreparationConfig(
            root_dir = config.root_dir,
            model_name = config.model_name,
            base_model_path = f'{config.base_model_path}/base_model_{config.model_name}.pth',
            updated_base_model_path = f'{config.updated_base_model_path}/updated_base_model_{config.model_name}.pth',
            params_image_size = self.params.IMAGE_SIZE,
            params_include_top= self.params.INCLUDE_TOP,
            params_classes = self.params.CLASSES,
            params_weights = self.params.WEIGHTS,
            params_learning_rate = self.params.LEARNING_RATE
        )

        return model_preparation_config
    

In [None]:
class ModelFactory:
    @staticmethod
    def create_model(model_name: str, num_classes: int, pretrained: bool = True, include_top: bool = False, weights: str = 'imagenet'):
        """Create VGG16 or ResNet50 model"""
        
        if model_name.lower() == "vgg16":
            weights = VGG16_Weights.IMAGENET1K_V1 if weights == 'imagenet' else None
            model = models.vgg16(weights=weights)

            if not include_top:
                model = nn.Sequential(*list(model.features.children()))
                # *list() unpacks the list of layers in the model.features and passes them as separate arguments to nn.Sequential
            
            logger.info(f"Created VGG16 model with include_top={include_top}")
            
        elif model_name.lower() == "resnet50":
            weights = ResNet50_Weights.IMAGENET1K_V1 if weights == 'imagenet' else None
            model = models.resnet50(weights=weights)

            #if not include_top:
                # Remove the final classification layer
                #model = nn.Sequential(*list(model.children())[:-1])
            
            
            logger.info(f"Created ResNet50 model with include_top={include_top}")
            
        else:
            raise ValueError(f"Unsupported model: {model_name}. Only 'vgg16' and 'resnet50' are supported.")
            
        return model

In [9]:
import os
import urllib.request as requests
from zipfile import ZipFile
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import VGG16_Weights

In [34]:
class ModelPreparation:
    def __init__(self,config = ModelPreparationConfig):
        self.config = config
    
    def get_base_model(self):
        """Get base model using factory pattern"""
        
        # Create model using factory
        self.model = ModelFactory.create_model(
            model_name=self.config.model_name,
            num_classes=self.config.params_classes,
            pretrained=True,
            include_top=self.config.params_include_top,
            weights=self.config.params_weights
        )
        
        logger.info(f"{self.config.model_name} model created successfully")

    @staticmethod
    def save_model(path: Path, model: nn.Module):
        torch.save(model, path)

    @staticmethod
    def prepare_full_model(model, classes, freeze_all, freeze_till, learning_rate, model_name):
        
        if model_name.lower() == "vgg16":
            if freeze_all:
                for param in model.parameters():
                    param.requires_grad = False
            elif (freeze_till is not None) and (freeze_till > 0):
                layers = list(model.children())
                for layer in layers[:-freeze_till]:
                    for param in layer.parameters():
                        param.requires_grad = False

            num_features = 512 * 7 * 7  # VGG16 features
            
            if isinstance(model, nn.Sequential):
                base_layers = list(model.children())
            else:
                base_layers = list(model.features.children())
                
            full_model = nn.Sequential(
                *base_layers,
                nn.Flatten(), 
                nn.Linear(num_features, classes), 
            )
                
        elif model_name.lower() == "resnet50":
            if freeze_all:
                # Freeze all layers except layer4 (conv5)
                for name, param in model.named_parameters():
                    if 'layer4' not in name:
                        param.requires_grad = False
                        print(f"Frozen: {name}")
                    else:
                        param.requires_grad = True
                        print(f"Trainable: {name}")
            elif (freeze_till is not None) and (freeze_till > 0):
                for name, param in model.named_parameters():
                    if 'layer4' not in name:
                        param.requires_grad = False

            # For ResNet50, replace the fc layer
            model.fc = nn.Linear(2048, classes)
            full_model = model  # Use the full ResNet50 model
            
        else:
            raise ValueError(f"Unsupported model: {model_name}")

        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(full_model.parameters(), lr=learning_rate)

        print(full_model)
        print(f"Full model has {len(list(full_model.children()))} layers")
        
        # Print trainable parameters
        trainable_params = sum(p.numel() for p in full_model.parameters() if p.requires_grad)
        total_params = sum(p.numel() for p in full_model.parameters())
        print(f"Trainable parameters: {trainable_params:,} / {total_params:,} ({100*trainable_params/total_params:.1f}%)")

        logger.info(f"{model_name} model prepared with {classes} classes, freeze_all={freeze_all}, freeze_till={freeze_till}, learning_rate={learning_rate}")

        return full_model, optimizer, criterion

    def update_base_model(self):
        model, optimizer, criterion = self.prepare_full_model(
            model = self.model,
            classes = self.config.params_classes,
            freeze_all = True,
            freeze_till = None,
            learning_rate = self.config.params_learning_rate,
            model_name = self.config.model_name
        )

        self.save_model(path = self.config.updated_base_model_path, model = model)

        logger.info(f'Updated {self.config.model_name} base model saved at {self.config.updated_base_model_path}')

    

In [35]:
try:
    config = ConfigurationManager()
    base_model_config = config.get_model_preparation_config()
    prepare_base_model = ModelPreparation(config=base_model_config)
    prepare_base_model.get_base_model()
    prepare_base_model.update_base_model()
except Exception as e:
    raise e

[2025-07-19 18:18:38,388: INFO: common]: yaml file successfully loaded from config/config.yaml
[2025-07-19 18:18:38,393: INFO: common]: yaml file successfully loaded from params.yaml
[2025-07-19 18:18:38,393: INFO: common]: Directory created at: artifacts
[2025-07-19 18:18:38,394: INFO: common]: Directory created at: artifacts/model_preparation
[2025-07-19 18:18:38,929: INFO: 1616009934]: Created ResNet50 model with include_top=False
[2025-07-19 18:18:38,929: INFO: 4035689665]: resnet50 model created successfully
Frozen: 0.weight
Frozen: 1.weight
Frozen: 1.bias
Frozen: 4.0.conv1.weight
Frozen: 4.0.bn1.weight
Frozen: 4.0.bn1.bias
Frozen: 4.0.conv2.weight
Frozen: 4.0.bn2.weight
Frozen: 4.0.bn2.bias
Frozen: 4.0.conv3.weight
Frozen: 4.0.bn3.weight
Frozen: 4.0.bn3.bias
Frozen: 4.0.downsample.0.weight
Frozen: 4.0.downsample.1.weight
Frozen: 4.0.downsample.1.bias
Frozen: 4.1.conv1.weight
Frozen: 4.1.bn1.weight
Frozen: 4.1.bn1.bias
Frozen: 4.1.conv2.weight
Frozen: 4.1.bn2.weight
Frozen: 4.1.bn