In [9]:
from statsmodels.tsa.statespace.sarimax import SARIMAX
from sklearn.metrics import mean_absolute_error
import itertools
import pandas as pd
import numpy as np

In [10]:
def grid_search_sarima(data, p_values, d_values, q_values, P_values, D_values, Q_values, m_values):
    best_score, best_cfg = float("inf"), None
    train_size = int(len(data) * 0.8)
    train, test = data[:train_size], data[train_size:]
    
    for p, d, q, P, D, Q, m in itertools.product(p_values, d_values, q_values, P_values, D_values, Q_values, m_values):
        try:
            model = SARIMAX(train, order=(p, d, q), seasonal_order=(P, D, Q, m))
            model_fit = model.fit(disp=False)
            predictions = model_fit.forecast(steps=len(test))
            mae = mean_absolute_error(test, predictions)
            if mae < best_score:
                best_score, best_cfg = mae, (p, d, q, P, D, Q, m)
        except:
            continue
    return best_cfg, best_score

def evaluate_sarima(data, best_cfg):
    train_size = int(len(data) * 0.8)
    train, test = data[:train_size], data[train_size:]
    p, d, q, P, D, Q, m = best_cfg
    model = SARIMAX(train, order=(p, d, q), seasonal_order=(P, D, Q, m))
    model_fit = model.fit(disp=False)
    predictions = model_fit.forecast(steps=len(test))
    mae = mean_absolute_error(test, predictions)
    return mae


In [11]:
data = pd.read_csv('/workspaces/benchmark_ts_model_test/datasets/candy_production.csv', index_col = 0, parse_dates = True)

print(data.head())

                  IPG3113N
observation_date          
1972-01-01         85.6945
1972-02-01         71.8200
1972-03-01         66.0229
1972-04-01         64.5645
1972-05-01         65.0100


In [None]:
# from statsmodels.tsa.statespace.sarimax import SARIMAX
# from sklearn.metrics import mean_absolute_error
# import itertools
# import pandas as pd
# import numpy as np

# # Function to infer seasonal cycle (m) based on time index frequency
# def infer_seasonality(data):
#     freq = pd.infer_freq(data.index)
#     if freq in ['D', 'B']:  # Daily or Business Daily
#         return [7, 365]  # Weekly and yearly seasonality
#     elif freq == 'W':  # Weekly
#         return [52]  # Yearly seasonality
#     elif freq == 'M':  # Monthly
#         return [12]  # Yearly seasonality
#     elif freq == 'Q':  # Quarterly
#         return [4]  # Yearly seasonality
#     elif freq == 'A':  # Yearly
#         return [1]  # No seasonality or multi-year cycle if needed
#     else:
#         return [1]  # Default to no seasonality if unknown

# # Function to perform SARIMA grid search
# def grid_search_sarima(data, p_values, d_values, q_values, P_values, D_values, Q_values):
#     best_score, best_cfg = float("inf"), None
#     train_size = int(len(data) * 0.8)
#     train, test = data[:train_size], data[train_size:]

#     # Infer the seasonal periods from the data index
#     m_values = infer_seasonality(data)
    
#     # Iterate over all combinations of parameters
#     for p, d, q, P, D, Q, m in itertools.product(p_values, d_values, q_values, P_values, D_values, Q_values, m_values):
#         try:
#             model = SARIMAX(train, order=(p, d, q), seasonal_order=(P, D, Q, m))
#             model_fit = model.fit(disp=False)
#             predictions = model_fit.forecast(steps=len(test))
#             mae = mean_absolute_error(test, predictions)
#             if mae < best_score:
#                 best_score, best_cfg = mae, (p, d, q, P, D, Q, m)
#         except Exception as e:
#             continue  # Skip any invalid parameter combinations
#     return best_cfg, best_score

# # Function to evaluate SARIMA on the test set with best config
# def evaluate_sarima(data, best_cfg):
#     train_size = int(len(data) * 0.8)
#     train, test = data[:train_size], data[train_size:]
#     p, d, q, P, D, Q, m = best_cfg
#     model = SARIMAX(train, order=(p, d, q), seasonal_order=(P, D, Q, m))
#     model_fit = model.fit(disp=False)
#     predictions = model_fit.forecast(steps=len(test))
#     mae = mean_absolute_error(test, predictions)
#     return mae

# # Sample parameter grid
# sarima_params = {
#     'p_values': [0, 1],
#     'd_values': [0, 1],
#     'q_values': [0, 1],
#     'P_values': [0, 1],
#     'D_values': [0, 1],
#     'Q_values': [0, 1]
# }


In [8]:
# from statsmodels.tsa.statespace.sarimax import SARIMAX
# from sklearn.metrics import mean_absolute_error
# import itertools
# import pandas as pd
# import numpy as np

# # Function to detect frequency and return the appropriate m value
# def get_seasonal_period(data):
#     freq = pd.infer_freq(data.index)
#     if freq is None:
#         raise ValueError("Frequency could not be inferred. Make sure your data has a valid DateTimeIndex.")
    
#     if freq in ['D']:       # Daily data
#         return [7, 365]     # Weekly and yearly seasonality
#     elif freq in ['W']:     # Weekly data
#         return [52]         # Yearly seasonality
#     elif freq in ['M']:     # Monthly data
#         return [12]         # Yearly seasonality
#     elif freq in ['Q']:     # Quarterly data
#         return [4]          # Yearly seasonality
#     elif freq in ['A', 'Y']: # Yearly data
#         return [1]          # No seasonality (or multi-year cycles)
#     else:
#         raise ValueError("Unsupported frequency: {}".format(freq))

# # SARIMA grid search function
# def grid_search_sarima(data, p_values, d_values, q_values, P_values, D_values, Q_values):
#     best_score, best_cfg = float("inf"), None
#     train_size = int(len(data) * 0.8)
#     train, test = data[:train_size], data[train_size:]
    
#     # Get adaptive m_values based on the data's frequency
#     m_values = get_seasonal_period(data)
    
#     for p, d, q, P, D, Q, m in itertools.product(p_values, d_values, q_values, P_values, D_values, Q_values, m_values):
#         try:
#             model = SARIMAX(train, order=(p, d, q), seasonal_order=(P, D, Q, m))
#             model_fit = model.fit(disp=False)
#             predictions = model_fit.forecast(steps=len(test))
#             mae = mean_absolute_error(test, predictions)
#             if mae < best_score:
#                 best_score, best_cfg = mae, (p, d, q, P, D, Q, m)
#         except:
#             continue
#     return best_cfg, best_score

# # Evaluation function
# def evaluate_sarima(data, best_cfg):
#     train_size = int(len(data) * 0.8)
#     train, test = data[:train_size], data[train_size:]
#     p, d, q, P, D, Q, m = best_cfg
#     model = SARIMAX(train, order=(p, d, q), seasonal_order=(P, D, Q, m))
#     model_fit = model.fit(disp=False)
#     predictions = model_fit.forecast(steps=len(test))
#     mae = mean_absolute_error(test, predictions)
#     return mae





In [12]:
# Example gridsearch space, adaptive to data type
sarima_params = {
    'p_values': [0, 1],
    'd_values': [0, 1],
    'q_values': [0, 1],
    'P_values': [0, 1],
    'D_values': [0, 1],
    'Q_values': [0, 1], 
    'm_values': [1, 4, 7, 12, 52, 365]

}

# Usage example (assuming `data` is a pandas DataFrame with a datetime index)
best_cfg, best_score = grid_search_sarima(data, **sarima_params)
sarima_mae = evaluate_sarima(data, best_cfg)

  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  warn('Non-invertible starting seasonal moving average'
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  warn('Non-invertible starting seasonal moving average'
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  warn('Too few observations to estimate starting parameters%s.'
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  self._init_dates(da