# In [3]:
import warnings
warnings.filterwarnings("ignore")

import os
import sys
import numpy as np
import pandas as pd

try:
    from statsmodels.tsa.statespace.sarimax import SARIMAX
except Exception:
    raise RuntimeError("statsmodels is required. Install with: pip install statsmodels")

try:
    from sklearn.metrics import mean_squared_error, mean_absolute_error
except Exception:
    raise RuntimeError("scikit-learn is required. Install with: pip install scikit-learn")


# CSV with the input data. main() joins this name onto the script directory
# (or the current working directory) and prepare_data() requires the file to
# contain a 'y_ret_t1' column.
DATA_FILE = "SnP_daily_update_AMZN_features_with_target.csv"
# CSV naming the candidate feature columns; parsed by load_methods().
METHODS_FILE = "AMZN_methods_all_in_one.csv"
# Number of consecutive rows in each rolling training window; main() passes it
# to rolling_arimax() as `window`.
WINDOW = 14


def load_methods(path):
    """Read a list of feature names from the CSV at *path*.

    Resolution order:

    1. A one-column file: the values of that column.
    2. Otherwise the first of the columns ``feature``, ``name``, ``method``
       that exists: the values of that column.
    3. Otherwise: the CSV's own column headers.

    Values taken from a column have missing entries removed, are cast to
    ``str`` and stripped of surrounding whitespace. Returns a ``list`` of
    ``str``.
    """
    table = pd.read_csv(path)

    def _clean(series):
        # Shared normalisation for both "values of a column" cases.
        return series.dropna().astype(str).str.strip().tolist()

    # Case 1: the file is just a single column of names.
    if table.shape[1] == 1:
        return _clean(table.iloc[:, 0])

    # Case 2: a conventionally named column holds the names.
    known = next(
        (candidate for candidate in ("feature", "name", "method") if candidate in table.columns),
        None,
    )
    if known is not None:
        return _clean(table[known])

    # Case 3: treat the header row itself as the list of names.
    return table.columns.astype(str).tolist()


def prepare_data(data_path, methods_path):
    """Load the data CSV and select the target plus usable feature columns.

    Parameters
    ----------
    data_path : path of the data CSV; it must contain a ``y_ret_t1`` column.
    methods_path : path of the CSV listing candidate feature names (read with
        ``load_methods``).

    Returns
    -------
    (keep, features)
        ``keep`` is a DataFrame holding ``y_ret_t1`` followed by the selected
        feature columns, with every row containing a NaN in any of those
        columns removed and the index reset to 0..n-1. ``features`` is the
        list of selected feature column names, in the order they were listed.

    Raises
    ------
    RuntimeError
        If the data has no ``y_ret_t1`` column.
    """
    target = "y_ret_t1"
    df = pd.read_csv(data_path, parse_dates=True)
    methods = load_methods(methods_path)

    available = set(df.columns)
    # Keep only names that exist in the data. dict.fromkeys() removes repeated
    # names while preserving order, and the target is excluded explicitly:
    # either would otherwise create duplicate columns (and, for the target,
    # leak the value being predicted into the regressors).
    features = [
        name for name in dict.fromkeys(methods)
        if name in available and name != target
    ]

    if target not in available:
        raise RuntimeError("Target column 'y_ret_t1' not found in data")

    # dropna() removes a row when the target or ANY selected feature is NaN.
    keep = df[[target] + features].dropna().reset_index(drop=True)
    return keep, features


def rolling_arimax(df, features, window=14, order=(1, 0, 0)):
    """Produce one-step-ahead forecasts of ``y_ret_t1`` from a rolling window.

    For every row ``t`` with ``t >= window`` a SARIMAX model is fitted on the
    ``window`` rows immediately before ``t`` (target as endog, ``features`` as
    exog) and used to predict the target at row ``t`` from that row's feature
    values. One model is fitted per predicted row, so runtime grows linearly
    with ``len(df) - window``.

    Parameters
    ----------
    df : DataFrame with a ``y_ret_t1`` column and every column in ``features``.
        Values are converted to ``float`` up front, so non-numeric data raises
        immediately instead of producing a column of NaN predictions.
    features : list of exogenous column names; empty/None fits a pure ARIMA.
    window : number of training rows per fit (must be >= 1).
    order : ``(p, d, q)`` order passed to SARIMAX.

    Returns
    -------
    DataFrame with columns ``index`` (positional row predicted), ``y_true``
    and ``y_pred``. ``y_pred`` is NaN for windows whose fit or forecast failed.

    Raises
    ------
    ValueError
        If ``window`` is smaller than 1.
    RuntimeError
        If ``df`` does not have more than ``window`` rows.

    NOTE(review): with a short window and many exogenous columns the model has
    about as many parameters as observations; consider a longer window or
    fewer features -- confirm against the intended experiment design.
    """
    import logging  # local import: the block does not rely on file-level imports

    logger = logging.getLogger(__name__)

    if window < 1:
        raise ValueError(f"window must be a positive integer, got {window!r}")

    n = len(df)
    if n <= window:
        raise RuntimeError("Not enough rows for the given window")

    # Materialise the arrays once; the loop below only takes cheap slices
    # instead of re-slicing the DataFrame on every iteration.
    y = df["y_ret_t1"].to_numpy(dtype=float)
    X = df[list(features)].to_numpy(dtype=float) if features else None

    preds = []
    failures = 0
    for start in range(n - window):
        stop = start + window  # positional index of the row being predicted
        endog = y[start:stop]
        exog = None if X is None else X[start:stop]
        exog_next = None if X is None else X[stop:stop + 1]  # shape (1, k)

        try:
            model = SARIMAX(
                endog,
                exog=exog,
                order=order,
                enforce_stationarity=False,
                enforce_invertibility=False,
            )
            res = model.fit(disp=False)
            # First out-of-sample step is at position len(endog) == window.
            pred = res.predict(start=window, end=window, exog=exog_next)
            pred_val = float(pred[0])
        except Exception as exc:
            # Best effort: a window that cannot be fitted yields NaN rather
            # than aborting the whole run, but the failure is recorded.
            failures += 1
            logger.debug("ARIMAX failed for row %d: %r", stop, exc)
            pred_val = np.nan

        preds.append(pred_val)

    if failures:
        logger.warning("%d of %d rolling windows failed to fit", failures, len(preds))

    return pd.DataFrame(
        {
            "index": list(range(window, n)),
            "y_true": y[window:],
            "y_pred": preds,
        }
    )


def main():
    """Run the rolling ARIMAX experiment end to end.

    Resolves ``DATA_FILE`` and ``METHODS_FILE`` next to this script (or in the
    current working directory when ``__file__`` is undefined), prepares the
    data, runs ``rolling_arimax`` with ``WINDOW`` and order ``(1, 0, 0)``,
    prints RMSE/MAE over the windows that produced a prediction, and writes
    all predictions to ``arima_predictions.csv`` in the same directory.

    Exits the process with status 1 if either input file is missing.
    """
    try:
        root = os.path.dirname(os.path.abspath(__file__))
    except NameError:
        # __file__ is not defined in interactive sessions / notebooks.
        root = os.getcwd()
    data_path = os.path.join(root, DATA_FILE)
    methods_path = os.path.join(root, METHODS_FILE)

    if not os.path.exists(data_path):
        print(f"Data file not found: {data_path}")
        sys.exit(1)
    if not os.path.exists(methods_path):
        print(f"Methods file not found: {methods_path}")
        sys.exit(1)

    df, features = prepare_data(data_path, methods_path)
    print(f"Using {len(features)} features: {features}")
    print(f"Data rows after dropna: {len(df)}; window={WINDOW}")

    preds_df = rolling_arimax(df, features, window=WINDOW, order=(1, 0, 0))

    # Metrics are computed only over windows whose fit succeeded.
    mask = preds_df["y_pred"].notna()
    if not mask.any():
        print("All predictions failed; check model/convergence or try different order")
    else:
        y_true = preds_df.loc[mask, "y_true"].values
        y_pred = preds_df.loc[mask, "y_pred"].values
        # Take the square root explicitly: the `squared=` keyword of
        # mean_squared_error was deprecated and later removed from
        # scikit-learn, so `squared=False` fails on current releases.
        rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
        mae = mean_absolute_error(y_true, y_pred)
        print(f"RMSE: {rmse:.6f}  MAE: {mae:.6f}  predictions: {mask.sum()}")

    out_file = os.path.join(root, "arima_predictions.csv")
    preds_df.to_csv(out_file, index=False)
    print(f"Saved predictions to {out_file}")

# Run the pipeline only when executed as a script or notebook cell (where
# __name__ is "__main__"), not when this module is imported.
if __name__ == "__main__":
    main()

# Captured output of the run above (interrupted before completion):
# Using 15 features: ['lower_wick', 'vol_chg', 'ma_gap_20', 'vol_ratio_20', 'ret_1', 'upper_wick', 'co_ret', 'range_pct', 'ret_2', 'vol_10', 'ret_5', 'vol_20', 'ret_4', 'ma_gap_10', 'ret_3']
# Data rows after dropna: 4028; window=14
#
# KeyboardInterrupt: