In [None]:
%pwd

In [None]:
%cd ..


In [None]:
%pwd

In [None]:
from dataclasses import dataclass
import matplotlib.pyplot as plt
import os
import tensorflow as tf
from pathlib import Path
import pickle

@dataclass
class TransferLearning:
    """Fine-tune a saved Keras model on datasets previously saved to disk.

    The dataclass fields hold the configuration. ``__post_init__`` eagerly
    loads the train/validation datasets, the base model and the pickled
    callbacks, so a bad path fails at construction time rather than
    mid-training.
    """

    root_dir: Path  # output directory for model.keras, accuracy.png and loss.png
    train_dir: Path  # path loadable by tf.data.Dataset.load (training split)
    val_dir: Path  # path loadable by tf.data.Dataset.load (validation split)
    base_model_path: Path  # Keras model file to fine-tune
    epochs: int  # number of epochs passed to Model.fit
    # NOTE(review): batch_size is not read anywhere in this class; the loaded
    # datasets are presumably already batched — confirm before relying on it.
    batch_size: int
    learning_rate: float  # learning rate for the Adam optimizer
    callback_path: Path  # pickle file whose contents are passed as callbacks= to fit

    def __post_init__(self):
        """Load the datasets, the base model and the callbacks from disk."""
        self.train_dataset = tf.data.Dataset.load(self.train_dir)
        self.val_dataset = tf.data.Dataset.load(self.val_dir)
        # safe_mode=False permits deserializing unsafe content such as Lambda
        # layers, so base_model_path must come from a trusted source.
        self.base_model = tf.keras.models.load_model(self.base_model_path, safe_mode=False)
        # SECURITY: pickle.load can execute arbitrary code; callback_path must
        # point to a trusted file (presumably produced by an earlier stage).
        with open(self.callback_path, 'rb') as handle:
            self.callbacks = pickle.load(handle)

    def train(self):
        """Compile and fit the base model, then save it to ``root_dir``.

        The returned History is also stored on ``self.history`` for
        :meth:`save_plots`.

        Returns:
            tf.keras.callbacks.History: The object returned by ``Model.fit``.
        """
        # Binary cross-entropy with the default from_logits=False and a 0.5
        # accuracy threshold — presumably the model ends in a sigmoid; confirm.
        self.base_model.compile(
            loss=tf.keras.losses.BinaryCrossentropy(),
            optimizer=tf.keras.optimizers.Adam(learning_rate=self.learning_rate),
            metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0.5, name='accuracy')],
        )

        self.history = self.base_model.fit(
            self.train_dataset,
            epochs=self.epochs,
            validation_data=self.val_dataset,
            callbacks=self.callbacks,
        )
        self.save_model(self.base_model)
        return self.history

    def save_plots(self, *, close: bool = True):
        """Save training/validation accuracy and loss curves as PNG files.

        Writes ``accuracy.png`` and ``loss.png`` into ``root_dir``. The data
        comes from the History recorded by :meth:`train`.

        Args:
            close: Close each figure after saving it. This keeps repeated calls
                from accumulating open pyplot figures. Pass ``False`` to keep
                them open, e.g. for inline display in a notebook.

        Raises:
            AttributeError: If :meth:`train` has not been run yet.
        """
        history = getattr(self, 'history', None)
        if history is None:
            # Same exception type as before, but with an actionable message.
            raise AttributeError(
                "No training history available; call train() before save_plots()."
            )
        metrics = history.history
        self._save_metric_plot(
            metrics, key='accuracy', title='Model Accuracy', ylabel='Accuracy',
            legend_loc='upper left', filename='accuracy.png', close=close,
        )
        self._save_metric_plot(
            metrics, key='loss', title='Model Loss', ylabel='Loss',
            legend_loc='upper right', filename='loss.png', close=close,
        )

    def _save_metric_plot(self, metrics, *, key, title, ylabel, legend_loc, filename, close):
        """Plot ``metrics[key]`` vs ``metrics['val_' + key]`` and save to root_dir/filename."""
        fig = plt.figure(figsize=(8, 6))
        plt.plot(metrics[key])
        plt.plot(metrics['val_' + key])
        plt.title(title)
        plt.ylabel(ylabel)
        plt.xlabel('Epoch')
        plt.legend(['Train', 'Validation'], loc=legend_loc)
        plt.savefig(os.path.join(self.root_dir, filename))
        if close:
            # Release the figure from pyplot's global registry.
            plt.close(fig)

    def save_model(self, model):
        """
        Saves the trained model to the specified output directory.

        Args:
            model (tf.keras.Model): The trained model to be saved; it is
                written to ``root_dir/model.keras``.
        """
        model.save(os.path.join(self.root_dir, 'model.keras'))


In [None]:
from brainMRI.constants import *
from brainMRI.utils.helpers import load_config, create_directories
class ConfigHandler:
    """Read the project config and params files and build pipeline stages from them."""

    def __init__(self, file_path=CONFIG_FILE_PATH, params_path=PARAMS_FILE_PATH):
        """Load both files with ``load_config`` and create the top-level ``root_dir``.

        Args:
            file_path: Main configuration file (defaults to ``CONFIG_FILE_PATH``).
            params_path: Hyper-parameter file (defaults to ``PARAMS_FILE_PATH``).
        """
        self.config = load_config(file_path)
        self.params = load_config(params_path)
        create_directories([self.config.root_dir])

    def get_transfer_learning_config(self) -> TransferLearning:
        """Return a ``TransferLearning`` built from the ``transfer_learning`` sections.

        Paths come from the config file and ``epochs``/``batch_size``/
        ``learning_rate`` come from the params file. The stage's ``root_dir``
        is created first. Constructing ``TransferLearning`` immediately loads
        its datasets, model and callbacks.
        """
        paths = self.config.transfer_learning
        hyper = self.params.transfer_learning

        create_directories([paths.root_dir])
        return TransferLearning(
            root_dir=paths.root_dir,
            train_dir=paths.train_dir,
            val_dir=paths.val_dir,
            base_model_path=paths.base_model_path,
            callback_path=paths.callback_path,
            epochs=hyper.epochs,
            batch_size=hyper.batch_size,
            learning_rate=hyper.learning_rate,
        )

In [None]:
# Run the transfer-learning stage end to end: read the config, train the
# model (which also saves model.keras), then save the accuracy/loss plots.
# Any exception propagates with its original traceback.
config = ConfigHandler()
transfer_learning_config = config.get_transfer_learning_config()
transfer_learning_config.train()
transfer_learning_config.save_plots()