In [1]:
import os

In [2]:
# this makes the parent file the current working directory
# NOTE - Restart whole notebook otherwise the wrong working directory will be set
os.chdir("../")

print(os.getcwd())

/Users/yyh/Documents/selfstudy/E2E-Kidney-Disease-MLOps-Project


In [3]:
import os
import zipfile
import gdown
import torch

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from dataclasses import dataclass
from pathlib import Path

from KidneyDiseasePrediction.constants import CONFIG_FILE_PATH, PARAMS_FILE_PATH
from KidneyDiseasePrediction.utils.common import read_yaml, create_directories, get_size


@dataclass(frozen=True)
class ModelTrainingConfig:
    root_dir: Path
    modified_base_model_path: Path
    training_data: Path
    trained_model_path: Path
    params_epochs: int
    params_batch_size: int
    params_learning_rate: float
    params_augmentation: bool


class ConfigManager:
    def __init__(self, config_file_path=CONFIG_FILE_PATH, params_file_path=PARAMS_FILE_PATH):

        self.config = read_yaml(config_file_path)
        self.params = read_yaml(params_file_path)

        create_directories([self.config.artifacts_root])

    def get_training_config(self) -> ModelTrainingConfig:

        create_directories([self.config.model_training.root_dir])

        model_training_config = ModelTrainingConfig(
            root_dir=self.config.model_training.root_dir,
            modified_base_model_path=self.config.base_model.modified_base_model_path,
            training_data=os.path.join(self.config.data_ingestion.unzip_dir, "dataset"),
            trained_model_path=self.config.model_training.trained_model_path,
            params_epochs=self.params.EPOCHS,
            params_batch_size=self.params.BATCH_SIZE,
            params_learning_rate=self.params.LEARNING_RATE,
            params_augmentation=self.params.AUGMENTATION,
        )

        return model_training_config


class Training:
    def __init__(self, config: ModelTrainingConfig):
        self.config = config

    def get_base_model(self):
        self.model = torch.load(self.config.modified_base_model_path, weights_only=False)

    def train_valid_generators(self):

        transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])])
        dataset = datasets.ImageFolder(root=self.config.training_data, transform=transform)

        train_size = int(0.8 * len(dataset))
        valid_size = len(dataset) - train_size
        train_dataset, valid_dataset = random_split(dataset, [train_size, valid_size])

        self.train_loader = DataLoader(train_dataset, batch_size=self.config.params_batch_size, shuffle=True)
        self.valid_loader = DataLoader(valid_dataset, batch_size=self.config.params_batch_size, shuffle=False)

    def train_model(self):
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.params_learning_rate)

        for epoch in range(self.config.params_epochs):
            self.model.train()
            for images, labels in self.train_loader:
                optimizer.zero_grad()
                outputs = self.model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

            print(f"Epoch [{epoch + 1}/{self.config.params_epochs}], Loss: {loss.item():.4f}")

        self.save_model(self.config.trained_model_path, self.model)

    @staticmethod
    def save_model(path: Path, model: torch.nn.Module):
        torch.save(model, path)

In [4]:
try:
    config_manager = ConfigManager()
    model_training_config = config_manager.get_training_config()
    model_training = Training(model_training_config)
    model_training.get_base_model()
    model_training.train_valid_generators()
    model_training.train_model()

except Exception as e:
    raise e

2025-07-10 15:04:27,713 - INFO - common - YAML file config/config.yaml loaded successfully.
2025-07-10 15:04:27,715 - INFO - common - YAML file params.yaml loaded successfully.
2025-07-10 15:04:27,715 - INFO - common - Directory artifacts created at artifacts
2025-07-10 15:04:27,716 - INFO - common - Directory artifacts/training created at artifacts/training
Epoch [1/1], Loss: 0.0068
