In [5]:

import config
# Display the kernel's current working directory.
# NOTE(review): the config/params paths printed further down are relative,
# so they presumably resolve against this directory - confirm before running.
%pwd

'/Users/anshujoshi/PycharmProjects/CancerDetection'

In [6]:
from dataclasses import dataclass
from pathlib import Path
from typing import List


@dataclass(frozen=True)
class TrainingConfig:
    """Immutable bundle of paths and hyper-parameters consumed by the training step.

    Instances are built by ``ConfigManager.get_training_config`` from the
    loaded YAML config and params objects.
    """

    # Directory for training artifacts; created before the config is built.
    root_dir: Path
    # Destination the trained model is saved to after fitting.
    trained_model_path: Path
    # Location of the prepared base model that training starts from.
    updated_base_model_path: Path
    # Directory of images passed to image_dataset_from_directory.
    training_data: Path
    # Number of epochs passed to model.fit.
    params_epoch: int
    # Batch size used for both the training and the validation dataset.
    params_batch_size: int
    # When truthy, random augmentation is applied to the training dataset.
    params_is_augmented: bool
    # The first two entries are used as the image size; the third is compared
    # with 3 to choose between 'rgb' and 'grayscale'.
    # NOTE(review): presumably [height, width, channels] - confirm against params.yaml.
    params_image_size: List

In [7]:
from cnnClassifier.utils.common import *
from cnnClassifier.constants import *

CONFIG_FILE_PATH set to: config/config.yaml
PARAMS_FILE_PATH set to: params.yaml


In [8]:
class ConfigManager:
    """Loads the YAML config and params files and builds typed config objects.

    Both files are read with ``read_yaml`` at construction time, and the
    artifacts root directory named in the config is created immediately.
    """

    def __init__(self, config_file_path=CONFIG_FILE_PATH, params_file_path=PARAMS_FILE_PATH):
        """Read both YAML files and make sure the artifacts root exists.

        Args:
            config_file_path: Path of the main config YAML.
            params_file_path: Path of the hyper-parameter YAML.
        """
        self.config = read_yaml(config_file_path)
        self.params = read_yaml(params_file_path)

        artifacts_root = self.config.artifacts_root
        create_directories([artifacts_root])

    def get_training_config(self) -> TrainingConfig:
        """Assemble the ``TrainingConfig`` for the training stage.

        Creates the training root directory as a side effect.

        Returns:
            A ``TrainingConfig`` filled from the ``training``,
            ``prepare_base_model`` and ``data_ingestion`` config sections and
            from the params file.
        """
        # Read every section up front so a missing key fails before any
        # directory is created.
        training_section = self.config.training
        base_model_section = self.config.prepare_base_model
        hyper_params = self.params
        dataset_dir = os.path.join(
            self.config.data_ingestion.unzip_dir,
            'Chest-CT-Scan-data',
        )

        create_directories([training_section.root_dir])

        return TrainingConfig(
            root_dir=Path(training_section.root_dir),
            trained_model_path=Path(training_section.trained_model_path),
            updated_base_model_path=Path(base_model_section.updated_base_model),
            training_data=Path(dataset_dir),
            params_epoch=hyper_params.EPOCH,
            params_batch_size=hyper_params.BATCH_SIZE,
            params_is_augmented=hyper_params.AUGMENTATION,
            params_image_size=hyper_params.IMAGE_SIZE,
        )


In [9]:
import tensorflow as tf
from cnnClassifier import logger

In [18]:
class Train:
    """Train a previously prepared Keras model on a directory of images.

    Intended call order: ``get_base_model()``, ``train_valid_generator()``,
    then ``train()``. Failures while loading the model or building a dataset
    are logged instead of raised; the affected attribute is left as ``None``
    and ``train()`` then logs an error and does nothing.
    """

    # Sub-directory names of the training data used as class labels.
    _CLASS_NAMES = ['adenocarcinoma', 'normal']
    # Fraction of the files reserved for the validation subset.
    _VALIDATION_SPLIT = 0.2
    # The same seed is passed for both subsets so they are cut from one shuffle.
    _SPLIT_SEED = 132

    def __init__(self, config: TrainingConfig):
        """Store the config; model and datasets are created by later calls.

        Args:
            config: Paths and hyper-parameters for this training run.
        """
        self.config = config
        self.model = None
        # Defined up front so train() can check them even when
        # train_valid_generator() has not been called yet.
        self.train_generator = None
        self.val_generator = None

    def get_base_model(self):
        """Load the model stored at ``config.updated_base_model_path``.

        On failure the error is logged and ``self.model`` is left unchanged.
        """
        try:
            self.model = tf.keras.models.load_model(self.config.updated_base_model_path)
            logger.info("Model loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")

    def _load_split(self, subset, image_size):
        """Build the un-cached dataset for one subset ('training' or 'validation')."""
        channels = self.config.params_image_size[2]
        return tf.keras.utils.image_dataset_from_directory(
            directory=self.config.training_data,
            batch_size=self.config.params_batch_size,
            image_size=image_size,
            labels='inferred',
            label_mode='categorical',
            class_names=list(self._CLASS_NAMES),
            subset=subset,
            seed=self._SPLIT_SEED,
            color_mode='rgb' if channels == 3 else 'grayscale',
            validation_split=self._VALIDATION_SPLIT,
        )

    def train_valid_generator(self):
        """Create ``self.train_generator`` and ``self.val_generator``.

        Each dataset is built independently; if one fails, the error is logged
        and that attribute is set to ``None``. When
        ``config.params_is_augmented`` is truthy, random flip/rotation/zoom is
        applied to the training batches only.
        """
        image_size = tuple(self.config.params_image_size[:2])

        try:
            logger.info('Trying train generator')

            # Cache the decoded images BEFORE augmenting. Caching after the
            # random map would freeze the first epoch's augmented batches and
            # replay them every epoch, which defeats the augmentation.
            train_ds = self._load_split('training', image_size).cache()

            if self.config.params_is_augmented:
                data_augmentation = tf.keras.Sequential([
                    tf.keras.layers.RandomFlip("horizontal_and_vertical"),
                    tf.keras.layers.RandomRotation(0.2),
                    tf.keras.layers.RandomZoom(0.2),
                ])

                def preprocess(image, label):
                    # training=True keeps the random layers active inside
                    # Dataset.map, where no fit() call context is available.
                    return data_augmentation(image, training=True), label

                train_ds = train_ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)

            self.train_generator = train_ds.prefetch(tf.data.AUTOTUNE)

            logger.info('Train generator created successfully')
        except Exception as e:
            logger.error(f'Failed train generator: {e}')
            self.train_generator = None

        try:
            logger.info('Trying validation generator')
            self.val_generator = (
                self._load_split('validation', image_size)
                .cache()
                .prefetch(tf.data.AUTOTUNE)
            )

            logger.info('Validation generator created successfully')
        except Exception as e:
            logger.error(f'Failed validation generator: {e}')
            self.val_generator = None

    def train(self):
        """Fit the model for ``config.params_epoch`` epochs and save it.

        If the model or either dataset is missing, an error is logged and
        nothing is trained or saved.
        """
        ready = (
            self.model is not None
            and self.train_generator is not None
            and self.val_generator is not None
        )
        if not ready:
            logger.error("Training or validation generator not initialized properly or model not loaded.")
            return

        self.model.fit(
            self.train_generator,
            validation_data=self.val_generator,
            epochs=self.config.params_epoch
        )
        logger.info("Model training completed.")

        self.save_model(path=self.config.trained_model_path, model=self.model)

    @staticmethod
    def save_model(path : Path,model : tf.keras.Model):
        """Save ``model`` to ``path`` using the model's own ``save`` method."""
        model.save(path)
        

In [19]:
# Build the training configuration from the YAML files, then load the base
# model and create the training/validation datasets. `trainer` is reused by
# the cells below, which start the actual training.
obj = ConfigManager()
Train_config = obj.get_training_config()
trainer = Train(Train_config)
# Both calls below log failures instead of raising, so check the log output
# before moving on to training.
trainer.get_base_model()
trainer.train_valid_generator()

[2024-06-19 15:26:18,752: INFO: common: yaml config/config.yaml loaded successfully]
[2024-06-19 15:26:18,755: INFO: common: yaml params.yaml loaded successfully]
[2024-06-19 15:26:18,756: INFO: common: created directory at: artifacts]
[2024-06-19 15:26:18,757: INFO: common: created directory at: artifacts/training]
[2024-06-19 15:26:18,967: INFO: 2552327953: Model loaded successfully.]
[2024-06-19 15:26:18,968: INFO: 2552327953: Trying train generator]
Found 343 files belonging to 2 classes.
Using 275 files for training.
[2024-06-19 15:26:19,003: INFO: 2552327953: Train generator created successfully]
[2024-06-19 15:26:19,004: INFO: 2552327953: Trying validation generator]
Found 343 files belonging to 2 classes.
Using 68 files for validation.
[2024-06-19 15:26:19,023: INFO: 2552327953: Validation generator created successfully]


In [20]:
import mlflow
# Turn on MLflow autologging for TensorFlow before any training starts.
mlflow.tensorflow.autolog()
# Select the 'BaseLine' experiment for the runs started below.
mlflow.set_experiment(experiment_name='BaseLine')

# Train inside an MLflow run named 'Default'.
with mlflow.start_run(run_name='Default'):
    trainer.train()



Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
INFO:tensorflow:Assets written to: /var/folders/4n/gbrmm4bs17ldvh4_ypty4z5c0000gn/T/tmpllwi4doa/model/data/model/assets
[2024-06-19 15:33:05,284: INFO: builder_impl: Assets written to: /var/folders/4n/gbrmm4bs17ldvh4_ypty4z5c0000gn/T/tmpllwi4doa/model/data/model/assets]




[2024-06-19 15:33:05,315: INFO: 2552327953: Model training completed.]


  saving_api.save_model(


In [17]:
# NOTE(review): trainer.train() calls fit() on the same trainer.model object
# every time, so running this cell after an earlier train() continues from the
# already-updated weights instead of the base model, and the result is saved
# to the same trained_model_path again. Re-create `trainer` first if an
# independent run is intended.
with mlflow.start_run(run_name='epoch-10'):
    trainer.train()

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
INFO:tensorflow:Assets written to: /var/folders/4n/gbrmm4bs17ldvh4_ypty4z5c0000gn/T/tmp51kggwvx/model/data/model/assets
[2024-06-19 15:10:17,023: INFO: builder_impl: Assets written to: /var/folders/4n/gbrmm4bs17ldvh4_ypty4z5c0000gn/T/tmp51kggwvx/model/data/model/assets]




[2024-06-19 15:10:17,047: INFO: 2552327953: Model training completed.]


  saving_api.save_model(
