In [None]:
import os

In [None]:
# Display the kernel's current working directory (the same value `%pwd` returns).
os.getcwd()

In [None]:
import os

# Capture the directory the kernel started in only on the first execution of this
# cell. Re-running the cell then always targets the same parent directory instead
# of climbing one more level each time (the original `os.chdir("../")` was not
# idempotent).
NOTEBOOK_DIR = globals().get("NOTEBOOK_DIR", os.getcwd())

# Work from the starting directory's parent for the rest of the session; the
# relative config/params/schema paths used below are presumably resolved from
# there — TODO confirm.
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))

In [None]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)
class ModelTrainerConfig:
    """Immutable settings consumed by the SARIMAX model-training stage.

    Field order is significant because instances may be built positionally,
    so the attributes keep their original sequence.

    NOTE(review): the path fields are annotated ``Path``, but their values come
    from the YAML config and may be plain ``str`` at runtime. The dataclass
    does not convert them.
    """

    # Filesystem locations.
    root_dirs: Path          # directory the fitted model is dumped into
    train_data_path: Path    # CSV read to fit the model
    test_data_path: Path     # CSV path of the test split
    model_name: str          # file name of the serialized model inside root_dirs

    # SARIMAX non-seasonal order (p, d, q).
    order_p: int
    order_d: int
    order_q: int

    # SARIMAX seasonal order (P, D, Q, s).
    seasonal_order_p: int
    seasonal_order_d: int
    seasonal_order_q: int
    seasonal_order_s: int

    # Dataset schema.
    target_column: str       # column fitted as the modelled series
    index_column: str        # column parsed as dates and used as the index

In [None]:
from mlproject.constants import *
from mlproject.utils.common import read_yaml, create_directories

In [None]:
class ConfigurationManager:
    """Load the project's YAML files and build per-stage configuration objects."""

    def __init__(
        self,
        config_filepath=CONFIG_FILE_PATH,
        params_filepath=PARAMS_FILE_PATH,
        schema_filepath=SCHEMA_FILE_PATH,
    ):
        """Read the config, params and schema YAML files and create the artifacts root.

        Args:
            config_filepath: pipeline config YAML (default presumably defined in
                ``mlproject.constants``).
            params_filepath: hyper-parameter YAML containing a ``SARIMAX`` section.
            schema_filepath: dataset schema YAML containing ``TARGET_COLUMN`` and
                ``INDEX_COLUMN`` entries.
        """
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)
        self.schema = read_yaml(schema_filepath)
        create_directories([self.config.artifacts_root])

    def model_trainer_config(self) -> "ModelTrainerConfig":
        """Assemble the settings for the SARIMAX training stage.

        Combines the ``model_trainer`` section of the config YAML, the ``SARIMAX``
        section of the params YAML and the target/index column names from the
        schema YAML. The trainer's ``root_dirs`` directory is created as a side
        effect.

        Returns:
            ModelTrainerConfig: frozen settings consumed by ``ModelTrainer``.
        """
        config = self.config.model_trainer
        params = self.params.SARIMAX
        schema_target = self.schema.TARGET_COLUMN
        schema_index = self.schema.INDEX_COLUMN

        # The trainer dumps the fitted model into root_dirs, so it must exist.
        create_directories([config.root_dirs])

        return ModelTrainerConfig(
            root_dirs=config.root_dirs,
            train_data_path=config.train_data_path,
            test_data_path=config.test_data_path,
            model_name=config.model_name,
            order_p=params.order_p,
            order_d=params.order_d,
            order_q=params.order_q,
            seasonal_order_p=params.seasonal_order_p,
            seasonal_order_d=params.seasonal_order_d,
            seasonal_order_q=params.seasonal_order_q,
            seasonal_order_s=params.seasonal_order_s,
            target_column=schema_target.name,
            index_column=schema_index.name,
        )

In [None]:
import pandas as pd
import os
from mlproject import logger
from statsmodels.tsa.statespace.sarimax import SARIMAX
import joblib

In [None]:
class ModelTrainer:
    """Fit a SARIMAX model on the training split and persist it with joblib."""

    def __init__(self, config: ModelTrainerConfig):
        """Store the training configuration.

        Args:
            config: settings produced by ``ConfigurationManager.model_trainer_config``.
        """
        self.config = config

    def train(self):
        """Fit SARIMAX on the training CSV and save the fitted results.

        Reads ``train_data_path``, parsing ``index_column`` as dates and using it
        as the index. Fits ``SARIMAX`` on ``target_column`` with the configured
        ``(p, d, q)`` order and seasonal ``(P, D, Q, s)`` order. Dumps the fitted
        results object to ``os.path.join(root_dirs, model_name)``.

        Returns:
            The fitted statsmodels results object, so callers can inspect it
            without reloading it from disk.
        """
        cfg = self.config
        logger.info(f"index column : {cfg.index_column}")
        logger.info(f"target column : {cfg.target_column}")

        # Only the training split is needed to fit the model; the test split is
        # not loaded here.
        train_data = pd.read_csv(
            cfg.train_data_path,
            parse_dates=[cfg.index_column],
            index_col=cfg.index_column,
        )

        model = SARIMAX(
            train_data[cfg.target_column],
            order=(cfg.order_p, cfg.order_d, cfg.order_q),
            seasonal_order=(
                cfg.seasonal_order_p,
                cfg.seasonal_order_d,
                cfg.seasonal_order_q,
                cfg.seasonal_order_s,
            ),
        )
        # disp=False suppresses the optimizer's convergence messages.
        results = model.fit(disp=False)

        model_path = os.path.join(cfg.root_dirs, cfg.model_name)
        joblib.dump(results, model_path)
        logger.info(f"SARIMAX model saved to {model_path}")
        return results

In [None]:
# Run the model-training stage end to end: build the configuration, fit SARIMAX
# and persist the fitted model. A failure is logged with its traceback and then
# re-raised unchanged, so the cell still fails visibly.
try:
    config_manager = ConfigurationManager()
    model_trainer_config = config_manager.model_trainer_config()
    model_trainer = ModelTrainer(config=model_trainer_config)
    model_trainer.train()
except Exception:
    logger.exception("Model trainer stage failed")
    raise