In [2]:
import os

In [3]:
# Show the current working directory (the notebook starts in the research/ folder).
%pwd

'c:\\Users\\supre\\PycharmProjects\\RenalHealth-AI\\research'

In [4]:
# The notebook starts inside research/; move up to the project root so relative
# paths such as config/config.yml resolve. Guarded so re-running this cell does
# not climb yet another directory level.
if os.path.basename(os.getcwd()) == "research":
    os.chdir("../")

In [32]:
from dataclasses import dataclass
from pathlib import Path
from typing import List


@dataclass(frozen=True)
class PrepareBaseModelConfig:
    """Immutable settings for the "prepare base model" pipeline stage.

    Instances are built by ``ConfigurationManager.get_prepare_base_model_config``
    from the project's config and params YAML files.
    """

    # Directory created for this stage's artifacts.
    root_dir: Path
    # File the backbone model is saved to before the new head is attached.
    base_model_path: Path
    # File the model is saved to after the new classification head is attached.
    updated_base_model_path: Path
    # Input image size, e.g. [224, 224, 3]; presumably (H, W, C) — TODO confirm.
    params_image_size: List[int]
    # Whether to keep the backbone's original final classifier layer.
    params_include_top: bool
    # Whether to load pretrained backbone weights.
    params_weights: bool
    # Number of output units for the new classification head.
    params_classes: int

In [33]:
from cnn_classifier.constants import *
from cnn_classifier.utils.common import read_yaml, create_directories

In [34]:
class ConfigurationManager:
    """Read the project's YAML config/params and build per-stage config objects."""

    def __init__(
        self,
        config_file_path: Path = CONFIG_FILE_PATH,
        params_file_path: Path = PARAMS_FILE_PATH,
    ):
        """Load both YAML files and ensure the artifacts root directory exists.

        Args:
            config_file_path: path of the pipeline config YAML.
            params_file_path: path of the hyper-parameter YAML.
        """
        self.config = read_yaml(config_file_path)
        self.params = read_yaml(params_file_path)

        artifacts_root = self.config.artifacts_root
        create_directories([artifacts_root])

    def get_prepare_base_model_config(self) -> PrepareBaseModelConfig:
        """Create the stage directory and return the stage's typed config.

        Paths come from the ``prepare_base_model`` section of the config file;
        model hyper-parameters come from the params file.
        """
        stage = self.config.prepare_base_model

        create_directories([stage.root_dir])

        hp = self.params
        return PrepareBaseModelConfig(
            root_dir=Path(stage.root_dir),
            base_model_path=Path(stage.base_model_path),
            updated_base_model_path=Path(stage.updated_base_model_path),
            params_image_size=hp.IMAGE_SIZE,
            params_include_top=hp.INCLUDE_TOP,
            params_weights=hp.WEIGHTS,
            params_classes=hp.CLASSES,
        )

In [54]:
import os

# from urllib.request import request
# from zipfile import ZipFile
# import tensorflow as tf

from torchvision import models
from torch import (
    device as torch_device,
    cuda as torch_cuda,
    nn,
    save as torch_save,
)
from torchsummary import summary

In [63]:
class PrepareBaseModel:
    """Build a VGG16 backbone and extend it with a new classification head.

    Call ``get_base_model()`` first; ``update_base_model()`` reads the
    ``self.model`` attribute that it creates.
    """

    def __init__(self, config: PrepareBaseModelConfig):
        self.config = config
        # Use the GPU when PyTorch can see one, otherwise fall back to CPU.
        self.device = torch_device(
            "cuda" if torch_cuda.is_available() else "cpu"
        )

    def get_base_model(self):
        """Load VGG16, optionally strip its final layer, and save it.

        ``config.params_weights`` is forwarded to torchvision's ``pretrained``
        flag. When ``config.params_include_top`` is False the last layer of
        ``model.classifier`` is removed. The model is stored on ``self.model``
        and written to ``config.base_model_path``.
        """
        # NOTE(review): ``pretrained=`` is deprecated in torchvision >= 0.13 in
        # favour of ``weights=``; kept here for compatibility with older versions.
        self.model = models.vgg16(
            pretrained=self.config.params_weights
        ).to(self.device)
        if not self.config.params_include_top:
            self.model.classifier = nn.Sequential(
                *list(self.model.classifier.children())[:-1]
            )
        self.save_model(path=self.config.base_model_path, model=self.model)

    @staticmethod
    def _input_shape(img_size):
        """Convert a channels-last size, e.g. [224, 224, 3], to (3, 224, 224)."""
        # Unlike reversed(), this keeps height before width for non-square inputs.
        *spatial, channels = img_size
        return (channels, *spatial)

    @staticmethod
    def _prepare_full_model(
        model,
        classes,
        freeze_all,
        freeze_till,
        img_size,
        device,
    ):
        """Freeze backbone parameters and append a new ``classes``-way head.

        Args:
            model: VGG-style model exposing a ``classifier`` ``nn.Sequential``.
            classes: number of output units of the appended Linear layer.
            freeze_all: if True, set ``requires_grad = False`` on every
                existing parameter.
            freeze_till: used only when ``freeze_all`` is False; if a positive
                int, freeze the first ``freeze_till`` parameter tensors.
            img_size: channels-last input size (e.g. [224, 224, 3]) used for
                the printed summary.
            device: torch.device (or device string) for the new layers.

        Returns:
            The same ``model`` object, mutated in place.

        Raises:
            ValueError: if ``model.classifier`` contains no ``nn.Linear``.
        """
        if freeze_all:
            for param in model.parameters():
                param.requires_grad = False
        elif (freeze_till is not None) and (freeze_till > 0):
            # parameters() returns a generator; materialise it before slicing.
            for param in list(model.parameters())[:freeze_till]:
                param.requires_grad = False

        last_layer = None
        for layer in model.classifier.children():
            if isinstance(layer, nn.Linear):
                last_layer = layer
        if last_layer is None:
            raise ValueError("No linear layer found in the classifier.")
        # The head is appended after the whole classifier, so its input width is
        # the last Linear's OUTPUT width. This assumes any layers after it
        # (ReLU/Dropout in VGG16) preserve width. in_features only coincided
        # with this when the original top layer had been stripped.
        num_features = last_layer.out_features

        # Added after freezing, so the new head remains trainable.
        model.classifier.append(nn.Linear(num_features, classes).to(device))
        # NOTE(review): if training uses nn.CrossEntropyLoss, this Softmax gets
        # applied twice; confirm against the training stage.
        model.classifier.append(nn.Softmax(dim=1).to(device))

        print(model)
        # torchsummary defaults to device="cuda"; pass the real device type so
        # the summary also works on CPU-only machines.
        summary(
            model,
            PrepareBaseModel._input_shape(img_size),
            device=torch_device(device).type,
        )
        return model

    def update_base_model(self):
        """Attach the classification head to ``self.model`` and save the result.

        All backbone parameters are frozen. The resulting model is stored on
        ``self.full_model`` and written to ``config.updated_base_model_path``.
        """
        self.full_model = self._prepare_full_model(
            model=self.model,
            classes=self.config.params_classes,
            freeze_all=True,
            freeze_till=None,
            img_size=self.config.params_image_size,
            device=self.device,
        )

        self.save_model(
            path=self.config.updated_base_model_path,
            model=self.full_model
        )

    @staticmethod
    def save_model(path: Path, model: nn.Module):
        """Serialise the whole module object (not just its state_dict) to ``path``."""
        torch_save(model, path)

In [64]:
# Run the prepare-base-model stage end to end: build the stage config, load the
# VGG16 backbone, then attach the new classification head and save both models.
# Exceptions propagate unchanged, with their original traceback.
config = ConfigurationManager()
prepare_base_model_config = config.get_prepare_base_model_config()
prepare_base_model = PrepareBaseModel(config=prepare_base_model_config)
prepare_base_model.get_base_model()
prepare_base_model.update_base_model()

[2024-04-02 11:44:43,900: INFO: common: yaml file: config\config.yml loaded successfully]
[2024-04-02 11:44:43,902: INFO: common: yaml file: params.yml loaded successfully]
[2024-04-02 11:44:43,903: INFO: common: created directory at: artifacts]
[2024-04-02 11:44:43,904: INFO: common: created directory at: artifacts/prepare_base_model]




VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [75]:
# Layer-by-layer summary of the prepared model. It uses the configured image size
# (assumed channels-last, e.g. [224, 224, 3]) rather than a variable from a later
# cell, so it works on a fresh kernel; the device is passed explicitly because
# torchsummary defaults to "cuda".
image_size = prepare_base_model_config.params_image_size
summary(
    prepare_base_model.model,
    (image_size[-1], *image_size[:-1]),
    device=prepare_base_model.device.type,
)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           1,792
              ReLU-2         [-1, 64, 224, 224]               0
            Conv2d-3         [-1, 64, 224, 224]          36,928
              ReLU-4         [-1, 64, 224, 224]               0
         MaxPool2d-5         [-1, 64, 112, 112]               0
            Conv2d-6        [-1, 128, 112, 112]          73,856
              ReLU-7        [-1, 128, 112, 112]               0
            Conv2d-8        [-1, 128, 112, 112]         147,584
              ReLU-9        [-1, 128, 112, 112]               0
        MaxPool2d-10          [-1, 128, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]         295,168
             ReLU-12          [-1, 256, 56, 56]               0
           Conv2d-13          [-1, 256, 56, 56]         590,080
             ReLU-14          [-1, 256,

In [72]:
# Example image size; assumed channels-last (H, W, C) — TODO confirm against params.
l = [224, 224, 3]

In [74]:
tuple(l[::-1])

(3, 224, 224)