In [1]:
import os
os.chdir("../") # to the prev. dir
%pwd

'c:\\Users\\15600\\Desktop\\PY\\kidney-disease-classification-project'

## Entity

In [2]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)
class PrepareBaseModelConfig:
    """Immutable settings for the prepare-base-model stage.

    Built by ``ConfigurationManager.get_base_model_config`` from the
    ``prepare_base_model`` section of the config file and ``NUM_CLASSES`` from params.

    Raises:
        ValueError: if ``params_num_classes`` is smaller than 1.
    """

    root_dir: Path  # directory holding this stage's artifacts
    base_model_path: Path  # file the freshly built model is written to
    params_num_classes: int  # width of the classifier head (number of output logits)

    def __post_init__(self):
        # Values read from YAML may be plain strings; coerce so the ``Path`` annotations
        # actually hold. ``object.__setattr__`` is needed because the dataclass is frozen.
        object.__setattr__(self, "root_dir", Path(self.root_dir))
        object.__setattr__(self, "base_model_path", Path(self.base_model_path))
        if self.params_num_classes < 1:
            raise ValueError(
                f"params_num_classes must be >= 1, got {self.params_num_classes!r}"
            )

## Configuration

In [3]:
from src.KDClassifier.constants import *
from src.KDClassifier.utils.common import read_yaml, create_directories

In [4]:
class ConfigurationManager:
    """Load the project's YAML config and params, and build per-stage config entities.

    Constructing an instance reads ``config_filepath`` and ``params_filepath`` with
    ``read_yaml`` and creates the ``artifacts_root`` directory named in the config.
    """

    def __init__(
        self,
        config_filepath=CONFIG_FILE_PATH,
        params_filepath=PARAMS_FILE_PATH,
    ):
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)
        artifacts_root = self.config.artifacts_root
        create_directories([artifacts_root])

    def get_base_model_config(self) -> PrepareBaseModelConfig:
        """Return the prepare-base-model stage config, creating its ``root_dir`` first.

        Paths come from the ``prepare_base_model`` section of the config file; the
        class count comes from ``NUM_CLASSES`` in the params file.
        """
        stage = self.config.prepare_base_model
        create_directories([stage.root_dir])
        return PrepareBaseModelConfig(
            root_dir=stage.root_dir,
            base_model_path=stage.base_model_path,
            params_num_classes=self.params.NUM_CLASSES,
        )

In [5]:
# Load config.yaml and params.yaml and create the artifacts root, for the interactive
# check in the next cell. The pipeline cell below builds its own manager.
model_config = ConfigurationManager()

[2024-02-09 14:58:52,075: INFO: common: yaml file: config\config.yaml loaded successfully]
[2024-02-09 14:58:52,076: INFO: common: yaml file: params.yaml loaded successfully]
[2024-02-09 14:58:52,077: INFO: common: created directory at: artifacts]


In [6]:
# get_base_model_config() returns the stage's PrepareBaseModelConfig (paths + class count),
# not a batch size; display it so the resolved values can be eyeballed.
base_model_config = model_config.get_base_model_config()
base_model_config

[2024-02-09 14:58:54,045: INFO: common: created directory at: artifacts/prepare_base_model]


## Components

In [7]:
from src.KDClassifier import logger

In [8]:
import torch
import torch.nn as nn
from transformers import ViTForImageClassification

  from .autonotebook import tqdm as notebook_tqdm


In [9]:
class ViTClassifier(nn.Module):
    """Hugging Face ViT backbone whose classification head is resized to ``num_classes``.

    Args:
        num_classes: number of logits produced by ``forward``.
        model_name: checkpoint passed to ``ViTForImageClassification.from_pretrained``;
            defaults to the checkpoint this project was built with.
    """

    def __init__(self, num_classes, model_name='google/vit-base-patch16-224-in21k'):
        super().__init__()
        self.vit = ViTForImageClassification.from_pretrained(model_name)
        # Replace the loaded head with a fresh Linear of the requested width; in_features is
        # read from the existing head so it follows the backbone's hidden size.
        # NOTE(review): self.vit.config.num_labels is not updated here. Calling self.vit
        # with labels=... may then use a stale label count — TODO confirm if that path is used.
        self.vit.classifier = nn.Linear(self.vit.classifier.in_features, num_classes)

    def forward(self, x):
        """Return raw class logits for a batch of images.

        ``x`` is forwarded as ``pixel_values``; presumably a float tensor of shape
        (batch, 3, 224, 224) from the matching image processor — TODO confirm.
        """
        return self.vit(pixel_values=x).logits



In [10]:
class PrepareBaseModel:
    """Build the base ViT classifier for this stage and write it to disk.

    The pretrained backbone is fetched inside ``ViTClassifier`` (via ``from_pretrained``).
    This class wires the stage config into it and saves the result.
    """

    def __init__(self, config: PrepareBaseModelConfig):
        self.config = config

    def get_base_model(self):
        """Create ``ViTClassifier(params_num_classes)``, keep it on ``self.model`` and save it.

        The model is written to ``config.base_model_path``; nothing is returned.
        """
        cfg = self.config
        target_path = cfg.base_model_path
        n_classes = cfg.params_num_classes
        logger.info("model creating")
        self.model = ViTClassifier(n_classes)
        self.save_model(path=target_path, model=self.model)
        logger.info("model saved")

    @staticmethod
    def save_model(path, model):
        """Serialize ``model`` to ``path`` with ``torch.save``."""
        # NOTE(review): this pickles the whole nn.Module. Loading it back needs
        # ViTClassifier importable under the module path it was saved from (here the
        # notebook's __main__). Saving model.state_dict() would be more portable, but it may
        # change the artifact format other stages expect — TODO confirm before switching.
        torch.save(model, path)


## Pipeline

In [11]:
# Run the prepare-base-model stage end to end: load config -> build entity -> build & save model.
# The entity gets its own lowercase name. Rebinding `PrepareBaseModelConfig` would shadow
# the dataclass that get_base_model_config() looks up at call time, breaking any re-run.
try:
    config_manager = ConfigurationManager()
    prepare_base_model_config = config_manager.get_base_model_config()
    prepare_base_model = PrepareBaseModel(config=prepare_base_model_config)
    prepare_base_model.get_base_model()
except Exception:
    # Record the failure in the project log, then re-raise with the original traceback.
    logger.exception("prepare base model stage failed")
    raise

[2024-02-09 14:59:01,447: INFO: common: yaml file: config\config.yaml loaded successfully]
[2024-02-09 14:59:01,448: INFO: common: yaml file: params.yaml loaded successfully]
[2024-02-09 14:59:01,449: INFO: common: created directory at: artifacts]
[2024-02-09 14:59:01,449: INFO: common: created directory at: artifacts/prepare_base_model]
[2024-02-09 14:59:01,450: INFO: 3822608664: model creating]


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[2024-02-09 14:59:02,342: INFO: 3822608664: model saved]
