In [1]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import pickle
import random
from statsmodels.tsa.statespace.sarimax import SARIMAX
from prophet import Prophet

In [2]:
# Stack the main and the extra training CSV files into one frame.
# The original row labels of each file are kept (no ignore_index).
df = pd.concat(
    [
        pd.read_csv(csv_path)
        for csv_path in (
            '/kaggle/input/mts1-task/train.csv',
            '/kaggle/input/mts1-task/train_extra.csv',
        )
    ]
)

In [3]:
# Keep only rows whose service_date is strictly greater than '2023-06-01'.
# NOTE(review): service_date is presumably still a string here (the CSVs are
# read without date parsing), so this is a lexicographic comparison - confirm.
df = df.loc[df['service_date'] > '2023-06-01']

In [4]:
def _daily_confirmed_totals(bookings, extra_mask):
    """Aggregate confirmed bookings into one row per service date.

    Parameters
    ----------
    bookings : pd.DataFrame
        Booking-level rows. The columns read here are ``service_status``,
        ``service_date``, ``sum_price`` and ``hotel_id``.
    extra_mask : pd.Series of bool
        Additional row filter, AND-ed with the confirmed-status filter.

    Returns
    -------
    pd.DataFrame
        Indexed by ``service_date`` with two columns: ``sum_price`` (sum of
        prices for the day) and ``hotel_id`` (count of non-null ``hotel_id``
        values for the day).
    """
    selected = bookings[(bookings.service_status == 'Подтвержден') & extra_mask]
    return (
        selected[['service_date', 'sum_price', 'hotel_id']]
        .groupby('service_date')
        .agg({'sum_price': 'sum', 'hotel_id': 'count'})
    )


# One daily frame per individually modelled region.
df_dt_moscow = _daily_confirmed_totals(df, df.region_name == 'Москва')
df_dt_krkrai = _daily_confirmed_totals(df, df.region_name == 'Краснодарский край')
df_dt_spb = _daily_confirmed_totals(df, df.region_name == 'Санкт-Петербург')
df_dt_mo = _daily_confirmed_totals(df, df.region_name == 'Московская область')
df_dt_tat = _daily_confirmed_totals(df, df.region_name == 'Татарстан')
df_dt_sver = _daily_confirmed_totals(df, df.region_name == 'Свердловская область')
df_dt_rost = _daily_confirmed_totals(df, df.region_name == 'Ростовская область')
df_dt_nizh = _daily_confirmed_totals(df, df.region_name == 'Нижегородская область')

# Remaining Russian bookings: everything in Russia outside the regions below.
# NOTE(review): 'Ростовская область' and 'Нижегородская область' have their own
# frames above but are NOT excluded here, so their bookings are also counted in
# this bucket - confirm this overlap is intended before summing forecasts.
_regions_excluded_from_other_russia = [
    'Москва',
    'Краснодарский край',
    'Санкт-Петербург',
    'Татарстан',
    'Московская область',
    'Свердловская область',
]
df_dt_another_russia = _daily_confirmed_totals(
    df,
    ~df.region_name.isin(_regions_excluded_from_other_russia)
    & (df.country_name == 'Россия'),
)

# Every confirmed booking whose country is not Russia.
df_dt_another_not_russia = _daily_confirmed_totals(
    df,
    ~df.country_name.isin(['Россия']),
)

In [5]:
# Russian public holidays as (date, name) pairs, in chronological order.
_holiday_calendar = [
    ('2023-06-12', 'День России'),
    ('2023-11-04', 'День народного единства'),
    ('2024-01-01', 'Новый год'),
    ('2024-01-07', 'Рождество Христово'),
    ('2024-02-23', 'День защитника Отечества'),
    ('2024-03-08', 'Международный женский день'),
    ('2024-05-01', 'Праздник Весны и Труда'),
    ('2024-05-09', 'День Победы'),
]

# Holiday table in the layout passed to Prophet(holidays=...).
holidays = pd.DataFrame({
    'holiday': [name for _, name in _holiday_calendar],
    'ds': pd.to_datetime([day for day, _ in _holiday_calendar]),
}).assign(
    lower_window=-1,  # extra days before the holiday (the range can be widened)
    upper_window=1,   # extra days after the holiday
)

In [6]:
# Seed value kept for reproducibility; nothing in this cell applies it.
seed_value = 42

# Registry of fitted models (a dict, filled in by the fitting function).
models = dict()

# SARIMA hyper-parameters: non-seasonal order (1, 1, 2) and
# seasonal order (1, 1, 1) with a period of 7.
sarima_order = (1, 1, 2)
seasonal_order = (1, 1, 1, 7)

# Function to fit SARIMA and get forecasts
def fit_sarima_and_get_forecast(df, steps=31):
    """Fit a SARIMAX model on ``df['sum_price']`` and forecast ahead.

    The model orders come from the module-level ``sarima_order`` and
    ``seasonal_order``.

    Parameters
    ----------
    df : pd.DataFrame
        Frame holding the target series in its ``sum_price`` column.
    steps : int, default 31
        Number of out-of-sample steps to forecast.

    Returns
    -------
    tuple
        ``(forecast, fitted_values, fit_result)``: the ``steps``-long
        forecast, the in-sample fitted values, and the fitted results object.
    """
    fit_result = SARIMAX(
        df['sum_price'],
        order=sarima_order,
        seasonal_order=seasonal_order,
    ).fit(disp=False)  # disp=False silences the optimiser's console output
    horizon_forecast = fit_result.forecast(steps)
    return horizon_forecast, fit_result.fittedvalues, fit_result

# Function to fit Prophet with SARIMA as a regressor
def fit_prophet_with_sarima(df, region_name, holidays, steps=31):
    """Fit a Prophet model that uses SARIMA fitted values as a regressor.

    A SARIMA model is fitted on ``df['sum_price']`` first; its in-sample
    fitted values become the ``sarima_regressor`` column. A second 0/1
    regressor, ``myatezh_prigozhina``, marks 2023-06-23 and 2023-06-24.
    Both fitted models are stored in the module-level ``models`` dict under
    ``region_name``.

    Parameters
    ----------
    df : pd.DataFrame
        Daily frame indexed by ``service_date`` with a ``sum_price`` column.
    region_name : str
        Key under which the fitted models are stored in ``models``.
    holidays : pd.DataFrame
        Holiday table passed straight to ``Prophet(holidays=...)``.
    steps : int, default 31
        Forecast horizon forwarded to ``fit_sarima_and_get_forecast``.

    Returns
    -------
    dict
        ``{'prophet_model': ..., 'sarima_model': ...}``, the same object that
        is stored in ``models[region_name]``.
    """
    # Prophet expects the date column to be named 'ds' and the target 'y'.
    df_prophet = (
        df.reset_index()[['service_date', 'sum_price']]
        .rename(columns={'service_date': 'ds', 'sum_price': 'y'})
    )

    # Only the in-sample fitted values and the fitted results object are used
    # for training; the out-of-sample forecast is discarded here.
    _, sarima_fitted, sarima_model_fit = fit_sarima_and_get_forecast(df, steps=steps)

    # Drop the SARIMA index so the values line up positionally with the
    # 0..n-1 index produced by reset_index() above.
    df_prophet['sarima_regressor'] = pd.Series(sarima_fitted).reset_index(drop=True)

    # Event flag for 2023-06-23 / 2023-06-24. Comparing normalised datetimes
    # works whether 'ds' holds date strings or timestamps, and avoids a
    # row-by-row apply.
    event_days = pd.to_datetime(['2023-06-23', '2023-06-24'])
    df_prophet['myatezh_prigozhina'] = (
        pd.to_datetime(df_prophet['ds']).dt.normalize().isin(event_days).astype(int)
    )

    # Fit Prophet with both extra regressors.
    model = Prophet(weekly_seasonality=True, holidays=holidays)
    model.add_regressor('sarima_regressor')
    model.add_regressor('myatezh_prigozhina')
    model_fit = model.fit(df_prophet)

    # Register the pair of fitted models for this region.
    fitted_pair = {
        'prophet_model': model_fit,  # Prophet model
        'sarima_model': sarima_model_fit,  # SARIMA results object
    }
    models[region_name] = fitted_pair
    return fitted_pair

# Save the fitted models to disk
def save_models(models, file_path='fitted_models.pkl'):
    """Pickle ``models`` to ``file_path``, overwriting any existing file.

    Parameters
    ----------
    models : object
        The object to serialise (here, the dict of fitted models).
    file_path : str, default 'fitted_models.pkl'
        Destination path; opened in binary write mode.
    """
    with open(file_path, 'wb') as sink:
        pickle.dump(models, sink)

# Load models from disk
def load_models(file_path='fitted_models.pkl'):
    """Read back an object previously pickled to ``file_path``.

    Parameters
    ----------
    file_path : str, default 'fitted_models.pkl'
        Path of the pickle file; opened in binary read mode.

    Returns
    -------
    object
        Whatever object the pickle file contains.

    Notes
    -----
    Unpickling can execute arbitrary code, so only load files you trust.
    """
    with open(file_path, 'rb') as source:
        restored = pickle.load(source)
    return restored

# Mapping of region label -> daily aggregate frame, in the order the
# regions are fitted.
regions_data = dict([
    ('Moscow', df_dt_moscow),
    ('Krasnodar', df_dt_krkrai),
    ('St_Petersburg', df_dt_spb),
    ('Moscow_Region', df_dt_mo),
    ('Tatarstan', df_dt_tat),
    ('Sverdlovsk', df_dt_sver),
    ('Rostov', df_dt_rost),
    ('Nizhny_Novgorod', df_dt_nizh),
    ('Another Not Russia', df_dt_another_not_russia),
    ('Another Russia1', df_dt_another_russia),
])

# Fit Prophet models with SARIMA regressors for all regions.
# The loop variable is deliberately not named `df`: that would rebind the
# module-level bookings frame to the last region's aggregate and break any
# earlier cell that is re-run afterwards.
for region, region_daily in regions_data.items():
    fit_prophet_with_sarima(region_daily, region, holidays)

# Save the models after fitting
save_models(models)


  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
17:17:47 - cmdstanpy - INFO - Chain [1] start processing
17:17:47 - cmdstanpy - INFO - Chain [1] done processing
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
17:17:48 - cmdstanpy - INFO - Chain [1] start processing
17:17:48 - cmdstanpy - INFO - Chain [1] done processing
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
17:17:49 - cmdstanpy - INFO - Chain [1] start processing
17:17:49 - cmdstanpy - INFO - Chain [1] done processing
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  warn('Non-invertible starting MA parameters found.'
17:17:50 - cmdstanpy - INFO - Chain [1] start processing
17:17:50 - cmdstanpy - INFO - Chain [1] done processing
  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
17:17:51 - cmdstanpy - INFO - Chain [1] start processing
17:17:51 - cmdstanpy - INFO - Chain [1] done processing
  self._init_dates(dates, freq)
  self._init_dates(dates, fre