In [3]:
import pandas as pd
import xgboost as xgb
import pandas as pd
import numpy as np
import logging
#from app.core.config import DATA_PATH, MODEL_PATH

In [11]:
# Relative file locations consumed by load_model() and get_predictions().
MODEL_PATH = "xgboost_final.json"
DATA_PATH = "final_df_w_fe.csv"

In [12]:
# Logger used by the helper functions defined in this cell.
logger = logging.getLogger(__name__)

def load_data(path):
    """Read the feature CSV at ``path`` and return it sorted by date.

    The ``'Unnamed: 0'`` column is removed when present (presumably a
    leftover index column from an earlier ``to_csv`` call -- the file itself
    is not visible here), ``'Date'`` is parsed to datetimes, and the rows are
    sorted by ``'Date'``.

    Args:
        path: Location of the CSV file; passed straight to ``pandas.read_csv``.

    Returns:
        The prepared ``DataFrame``, or ``None`` if any step fails. Failures
        are logged, not raised.
    """
    try:
        df = pd.read_csv(path, low_memory=False)
        # errors='ignore': a CSV written without an index column has no
        # 'Unnamed: 0'; that must not turn a valid file into a load failure.
        df = df.drop(columns='Unnamed: 0', errors='ignore')
        df['Date'] = pd.to_datetime(df['Date'])
        df = df.sort_values(by='Date')
        logger.info('Daten erfolgreich geladen')
        return df
    except Exception as e:
        # Deliberate best-effort contract: callers check for None.
        logger.error(f'Data Loading fehlgeschlagen: {e}')
        return None

def create_holdout_set(dataframe, feature_list, anzahl_wochen, get_dates=False):
    """Split the data per store into a training part and a trailing holdout part.

    Rows with ``Open == 0`` or ``Sales <= 0`` are discarded first. For every
    store, the last ``7 * anzahl_wochen`` remaining rows (ordered by date)
    form the holdout set and all earlier rows the training set. Because the
    filtered rows do not count, the holdout size is measured in rows, not in
    calendar days. Targets are ``log1p(Sales)``.

    Args:
        dataframe: Frame containing at least the columns in ``feature_list``.
        feature_list: Columns to keep. Must include ``'Store'``, ``'Date'``,
            ``'Open'`` and ``'Sales'``, which the function itself reads.
        anzahl_wochen: Non-negative number of weeks (7 rows each) to hold out
            per store. ``0`` yields an empty holdout set.
        get_dates: When true, the ``Date`` columns of both splits are
            returned as two extra values.

    Returns:
        ``(X_train, y_train, X_test, y_test)``, followed by
        ``(train_dates, test_dates)`` when ``get_dates`` is true. If anything
        fails, the error is logged and a tuple of ``None`` values with the
        same length as the success result is returned.
    """
    try:
        if anzahl_wochen < 0:
            raise ValueError('anzahl_wochen darf nicht negativ sein')
        holdout_rows = 7 * anzahl_wochen

        dataframe = dataframe[feature_list]
        dataframe = dataframe[dataframe["Open"] != 0]
        dataframe = dataframe[dataframe["Sales"] > 0]
        dataframe = dataframe.sort_values(by=['Store', 'Date'])

        train_list = []
        test_list = []
        for _, group in dataframe.groupby('Store'):
            # Explicit split position instead of negative slicing: with
            # anzahl_wochen == 0, `group[:-0]` is empty and `group[-0:]` is the
            # whole group, which would invert the split.
            split_at = max(len(group) - holdout_rows, 0)
            train_list.append(group.iloc[:split_at])
            test_list.append(group.iloc[split_at:])

        train = pd.concat(train_list)
        test = pd.concat(test_list)

        y_train = np.log1p(train['Sales'])
        # NOTE(review): 'Date' is kept in X_train but removed from X_test.
        # Left as-is because callers may rely on it -- confirm this is intended.
        X_train = train.drop(columns=['Sales', 'Open'])
        y_test = np.log1p(test['Sales'])
        X_test = test.drop(columns=['Sales', 'Open', 'Date'])

        if get_dates:
            return X_train, y_train, X_test, y_test, train.Date, test.Date
        return X_train, y_train, X_test, y_test
    except Exception as e:
        logger.error(f'Fehler beim Erstellen des Holdout-Sets: {e}')
        # Match the arity of the success path so that tuple unpacking at the
        # call site does not fail with a second, misleading error.
        return (None,) * (6 if get_dates else 4)

def load_model():
    """Load the persisted XGBoost booster from ``MODEL_PATH``.

    Returns:
        The loaded ``xgb.Booster``, or ``None`` when loading fails. The
        failure is logged, not raised.
    """
    success_msg = 'XGBoost Model erfolgreich geladen'
    try:
        booster = xgb.Booster()
        booster.load_model(MODEL_PATH)
        logger.info(success_msg)
        # Echoed to stdout in addition to the log record.
        print(success_msg)
        return booster
    except Exception as err:
        logger.error(f'XGBoost Model konnte nicht geladen werden: {err}')
        return None

def get_predictions(model, store_id, forecast_horizon):
    """Predict sales for one store on its holdout rows.

    The data set is reloaded from ``DATA_PATH`` and split with
    ``create_holdout_set`` (six weeks held out per store). The holdout rows
    whose ``'Store'`` equals ``store_id`` are then scored with ``model``.
    The raw scores are passed through ``np.expm1``, the inverse of the
    ``log1p`` transform that ``create_holdout_set`` applies to the target.

    Args:
        model: Object whose ``predict`` method accepts an ``xgb.DMatrix``.
        store_id: Value matched against the ``'Store'`` column.
        forecast_horizon: Currently not used by the function body.

    Returns:
        A list of predicted values, or an empty list if any step fails. The
        failure is logged, not raised.
    """
    wanted_columns = [
        'Date', 'Store',
        'AvgLastMonthSales', 'AvgLastYearSales', 'AvgPromoSales', 'AvgHolidaySales',
        'Open', 'Sales',
        'Day', 'Month', 'DayOfWeek', 'WeekOfYear',
        'Promo', 'StateHoliday', 'SchoolHoliday',
        'CompetitionDistance', 'MonthsSinceCompetitionOpen',
        'PromoWeeks', 'Assortment', 'StoreType',
    ]
    try:
        frame = load_data(DATA_PATH)
        if frame is None:
            raise ValueError("Daten konnten nicht geladen werden")

        # Only the holdout features are needed; the other five results are
        # unpacked and discarded.
        _, _, holdout_features, _, _, _ = create_holdout_set(
            frame, wanted_columns, anzahl_wochen=6, get_dates=True
        )

        rows_for_store = holdout_features[holdout_features['Store'] == store_id]
        raw_scores = model.predict(xgb.DMatrix(rows_for_store, enable_categorical=True))
        return np.expm1(raw_scores).tolist()
    except Exception as err:
        logger.error(f'Fehler bei der Vorhersage: {err}')
        return []


In [17]:
df = load_data(DATA_PATH)
# Columns handed to create_holdout_set in the next cell.
feature_group = [
    'Date', 'Store',
    'AvgLastMonthSales', 'AvgLastYearSales', 'AvgPromoSales', 'AvgHolidaySales',
    'Open', 'Sales',
    'Day', 'Month', 'DayOfWeek', 'WeekOfYear',
    'Promo', 'StateHoliday', 'SchoolHoliday',
    'CompetitionDistance', 'MonthsSinceCompetitionOpen',
    'PromoWeeks', 'Assortment', 'StoreType',
]

In [18]:
# Hold out the last six weeks (7 rows each) per store and keep the date columns.
(X_train, y_train, X_test, y_test, train_dates, test_dates) = create_holdout_set(
    df,
    feature_group,
    anzahl_wochen=6,
    get_dates=True,
)

In [21]:
# Holdout rows of store 144, wrapped for XGBoost.
store_data = X_test.loc[X_test['Store'] == 144]
dmatrix = xgb.DMatrix(data=store_data, enable_categorical=True)

In [22]:
# load_model() returns None on failure, so `model` may be None after this cell.
model = load_model()

XGBoost Model erfolgreich geladen


In [23]:
# Raw model output for the store selected above. get_predictions() applies
# np.expm1 to such values before returning them, so these are presumably on
# the log1p(Sales) scale -- TODO confirm against the training code.
prediction = model.predict(dmatrix)
print(prediction)

[8.616707  9.134683  8.95622   8.967396  9.003748  9.051197  8.593282
 8.748238  8.697527  8.666334  8.689075  8.812992  8.473973  9.199252
 9.1237335 9.022075  9.061629  8.933664  8.624006  8.91811   8.841375
 8.799709  8.799153  8.954138  8.628993  9.218123  8.981782  9.007857
 9.018217  9.015262  8.616912  8.934781  8.856712  8.80793   8.80798
 8.902411  8.617696  9.264584  9.035755  9.069745  9.102515  9.136642 ]
