In [2]:
import os

In [3]:
# Show the kernel's current working directory before it is changed below.
%pwd

'c:\\Users\\49179\\Desktop\\Food_image_classification\\research'

In [4]:
# Move from the notebook folder up to the project root, where the relative
# config/params YAML paths resolve. Guarded so that re-running this cell does
# not climb one more directory level each time: once `params.yaml` is visible
# in the current directory we are already at the root.
# NOTE(review): assumes params.yaml lives in the project root (the loader log
# further down reports it being read from there) - confirm.
if not os.path.exists("params.yaml"):
    os.chdir("../")

In [5]:
# Confirm the working directory is now the project root.
%pwd

'c:\\Users\\49179\\Desktop\\Food_image_classification'

In [6]:
import torch

In [7]:
# Report which device PyTorch can use: 'cuda' when a CUDA device is available, else 'cpu'.
'cuda' if torch.cuda.is_available() else 'cpu'

'cuda'

In [8]:
#Update the entity
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)
class PrepareBasemodelConfig:
    """Immutable settings for the "prepare base model" stage.

    Instances are built by ``ConfigurationManager.get_base_model_config`` and
    consumed by ``PrepareBaseModel``. ``frozen=True`` makes attribute
    assignment after construction raise ``FrozenInstanceError``.
    """
    # Directory belonging to this stage (created by the configuration manager).
    root_dir: Path
    # File the freshly instantiated base model is saved to (filled from the
    # `base_model_path` config key).
    base_model_dir: Path
    # File the model with the replaced classifier head is saved to.
    updated_base_model: Path
    # Value of params.IMAGE_SIZE; not read by the code in this notebook.
    # NOTE(review): element order/layout is not visible here - confirm.
    params_image_size: list
    # Device the base model is moved to via `.to(...)`.
    params_device: str
    # Passed as `weights=` to torchvision's efficientnet_b4.
    params_weight: str
    # Number of output features of the replacement classifier layer.
    params_classes: int

In [12]:
# Update the configuration manager in src
from Food_Classification.utils.common import read_yaml,create_directory
from Food_Classification.constants import *


class ConfigurationManager:
    """Reads the project's YAML files and builds per-stage config objects."""

    def __init__(self,
                 config_file_path= CONFIG_FILE_PATH,
                 params_file_path= PARAMS_FILE_PATH):
        """Load both YAML files and run ``create_directory`` on the artifacts root.

        Args:
            config_file_path: Path of the pipeline configuration YAML.
            params_file_path: Path of the parameters YAML.
        """
        self.config = read_yaml(config_file_path)
        self.params = read_yaml(params_file_path)

        artifacts_root = self.config.artifacts_root
        create_directory([artifacts_root])

    def get_base_model_config(self) -> PrepareBasemodelConfig:
        """Assemble the settings for the "prepare base model" stage.

        Runs ``create_directory`` on the stage's ``root_dir`` before building
        the config object.

        Returns:
            A ``PrepareBasemodelConfig`` whose path fields are wrapped in
            ``Path`` and whose ``params_*`` fields are taken unchanged from
            the parameters YAML.
        """
        stage = self.config.prepare_base_model
        hyperparams = self.params

        create_directory([stage.root_dir])

        fields = dict(
            root_dir=Path(stage.root_dir),
            base_model_dir=Path(stage.base_model_path),
            updated_base_model=Path(stage.updated_base_model),
            params_image_size=hyperparams.IMAGE_SIZE,
            params_device=hyperparams.DEVICE,
            params_weight=hyperparams.WEIGHTS,
            params_classes=hyperparams.CLASSES,
        )
        return PrepareBasemodelConfig(**fields)


In [13]:
# Update the component

import torch
import torchvision
from torchinfo import summary
# NOTE(review): a dangling `from Food_Classification` statement stood here. It
# is a SyntaxError that prevents this whole cell from running, and every name
# the class below uses is already imported or defined earlier in the notebook,
# so it has been dropped. Re-add the complete statement
# (`from Food_Classification.<module> import <name>`) if an import was intended.

class PrepareBaseModel:
    """Creates an EfficientNet-B4 base model and swaps in a new classifier head.

    Intended call order: ``get_base_model()`` then ``update_base_model()``;
    the latter reads ``self.base_model`` set by the former.
    """

    def __init__(self, config: "PrepareBasemodelConfig"):
        """Keep the stage configuration.

        Args:
            config: Stage settings (save locations, device, weights, number
                of classes).
        """
        self.config = config

    def get_base_model(self):
        """Instantiate EfficientNet-B4, move it to the configured device and save it.

        The model is stored on ``self.base_model`` and written to
        ``config.base_model_dir``.
        """
        self.base_model = torchvision.models.efficientnet_b4(
            weights=self.config.params_weight
        ).to(self.config.params_device)
        self.save_model(path=self.config.base_model_dir, model=self.base_model)

    @staticmethod
    def preparebasemodel(model, classes, freeze: bool, input_size=(1, 3, 224, 224)):
        """Replace ``model.classifier`` with a fresh Dropout + Linear head.

        The model is modified in place and also returned.

        Args:
            model: Network whose ``classifier`` is indexable, with a dropout
                layer at index 0 (provides ``p``) and a linear layer at
                index 1 (provides ``in_features``).
            classes: Number of output features of the new linear layer.
            freeze: When true, ``requires_grad`` is switched off for every
                existing parameter before the head is replaced, so only the
                new head remains trainable.
            input_size: Input shape handed to ``torchinfo.summary``. Defaults
                to the previously hard-coded ``(1, 3, 224, 224)``.

        Returns:
            Tuple ``(model, info)`` where ``info`` is the object returned by
            ``torchinfo.summary``.
        """
        if freeze:
            for param in model.parameters():
                param.requires_grad = False

        old_dropout, old_linear = model.classifier[0], model.classifier[1]
        # Build the new head on the same device as the layer it replaces so
        # the model never ends up with parameters split across devices.
        model.classifier = torch.nn.Sequential(
            torch.nn.Dropout(p=old_dropout.p),
            torch.nn.Linear(in_features=old_linear.in_features,
                            out_features=classes),
        ).to(old_linear.weight.device)

        info = summary(
            model=model,
            input_size=input_size,
            col_names=['input_size', 'output_size', 'num_params', "trainable"],
            col_width=20,
            row_settings=["var_names"],
        )
        return model, info

    def update_base_model(self):
        """Freeze the base model, attach the new head and save the result.

        Requires ``get_base_model()`` to have been called first. The updated
        network is stored on ``self.model``, its torchinfo summary on
        ``self.model_summary``, and the network is written to
        ``config.updated_base_model``.
        """
        # preparebasemodel returns (model, summary). Unpack it so that
        # self.model - and the file written below - is the network itself
        # rather than the whole tuple.
        self.model, self.model_summary = self.preparebasemodel(
            model=self.base_model,
            classes=self.config.params_classes,
            freeze=True,
        )
        self.save_model(path=self.config.updated_base_model, model=self.model)

    @staticmethod
    def save_model(path: Path, model: torch.nn.Module):
        """Serialize the whole ``model`` object to ``path`` with ``torch.save``."""
        torch.save(model, path)

In [14]:
# Run the stage end to end: load configuration, build and save the base model,
# then attach the new classifier head and save the updated model.
# The former `try/except Exception as e: raise e` wrapper only re-raised the
# same exception, so errors are simply left to propagate.
config = ConfigurationManager()
prepare_base_model_config = config.get_base_model_config()
prepare_base_model = PrepareBaseModel(config=prepare_base_model_config)
prepare_base_model.get_base_model()
prepare_base_model.update_base_model()

[2024-05-02 10:34:43,916: INFO: common: yaml file config\config.yaml loaded successfully]
[2024-05-02 10:34:43,919: INFO: common: yaml file params.yaml loaded successfully]
[2024-05-02 10:34:43,921: INFO: common: directory artifacts created successfully]
[2024-05-02 10:34:43,921: INFO: common: directory artifacts/prepare_base_model created successfully]


In [None]:
# Exploration: load a stock EfficientNet-B4 with its default pretrained weights
# (the checkpoint is downloaded on first use, see the output below) and display
# the module tree to see how the network is laid out.
from torchvision.models import efficientnet_b4,EfficientNet_B4_Weights

effb4 = efficientnet_b4(weights = EfficientNet_B4_Weights.DEFAULT)
effb4

Downloading: "https://download.pytorch.org/models/efficientnet_b4_rwightman-23ab8bcd.pth" to C:\Users\49179/.cache\torch\hub\checkpoints\efficientnet_b4_rwightman-23ab8bcd.pth
100%|██████████| 74.5M/74.5M [00:40<00:00, 1.95MB/s]


EfficientNet(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): MBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=48, bias=False)
            (1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
          (1): SqueezeExcitation(
            (avgpool): AdaptiveAvgPool2d(output_size=1)
            (fc1): Conv2d(48, 12, kernel_size=(1, 1), stride=(1, 1))
            (fc2): Conv2d(12, 48, kernel_size=(1, 1), stride=(1, 1))
            (activation): SiLU(inplace=True)
            (scale_activation): Sigmoid()
          )
          (2): Conv2dNormActiv

In [None]:
# Inspect the stock classification head: per the output below it is a Dropout
# followed by a Linear(1792 -> 1000) layer, i.e. the two layers that
# PrepareBaseModel.preparebasemodel reads from and replaces.
effb4.classifier

Sequential(
  (0): Dropout(p=0.4, inplace=True)
  (1): Linear(in_features=1792, out_features=1000, bias=True)
)