# In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from statsmodels.tsa.statespace.sarimax import SARIMAX
from sklearn.metrics import mean_squared_error
from statsmodels.stats.diagnostic import acorr_ljungbox
from scipy.stats import chi2

# Load and clean data, setting 'Epoch' as the index.
# Raw string: the Windows path contains backslash sequences (\V, \S, \O, \A)
# that are invalid escapes in an ordinary string literal (SyntaxWarning on
# Python 3.12+, and a silent corruption risk for names starting with n, t, ...).
DATA_PATH = r"E:\Vibration\Stats\Overlapping Epochs\All Frequency EMdelay\StatsGeneral_5_AF.csv"  # Replace with your actual file path
data = (
    pd.read_csv(DATA_PATH)
    # Treat +/-inf as missing, then drop rows lacking any modelled variable.
    .replace([float('inf'), float('-inf')], float('nan'))
    .dropna(subset=['CVForce', 'CVISI', 'SDCST'])
    # 'Epoch' as the index keeps observations aligned in time.
    .set_index('Epoch')
)

# Common (p, d, q) order shared by the ARIMAX fits below.
common_p = 1
common_d = 2
common_q = 1

def fit_arimax_model(endog, exog, order=(common_p, common_d, common_q)):
    """Fit a SARIMAX model of the given (p, d, q) order and return the fit.

    Parameters
    ----------
    endog : array-like
        Response series (in this script, a pandas Series of CVForce).
    exog : array-like or None
        Exogenous regressors aligned with ``endog``; ``None`` fits a model
        without regressors.
    order : tuple of int
        ``(p, d, q)`` order; defaults to the module-level common order.

    Returns
    -------
    The fitted results object from ``SARIMAX(...).fit(disp=False)``.
    """
    # Stationarity and invertibility constraints are switched off, so the
    # AR/MA parameters are estimated without those transformations.
    spec = SARIMAX(
        endog,
        exog=exog,
        order=order,
        enforce_stationarity=False,
        enforce_invertibility=False,
    )
    return spec.fit(disp=False)

# Candidate predictor sets, keyed by the label reported in the summary table.
PREDICTOR_SETS = {
    'Both Predictors': ['CVISI', 'SDCST'],
    'CVISI Only': ['CVISI'],
    'SDCST Only': ['SDCST'],
}

# Number of leading one-step-ahead predictions/residuals to ignore. With d > 0
# the integrated states start from an uninformative initialisation, so the
# first predictions do not reflect model fit.
# NOTE(review): with enforce_stationarity=False the ARMA states may also need a
# few extra observations to settle -- confirm against the fitted models.
BURN_IN = common_d


def _evaluate_subject(subject_data):
    """Fit every candidate ARIMAX model for one subject and score it.

    Parameters
    ----------
    subject_data : pandas.DataFrame
        One subject's rows, in Epoch order, with a default integer index and
        the 'CVForce', 'CVISI' and 'SDCST' columns.

    Returns
    -------
    dict
        Maps model name -> dict holding the pseudo R-squared, the
        likelihood-ratio p-value, the Ljung-Box p-value, the full prediction
        series and the index of the observations used for scoring.
    """
    y = subject_data['CVForce']

    # The null model (same differencing, no ARMA terms, no regressors) does not
    # depend on the candidate, so fit it once per subject.
    null_model_fit = fit_arimax_model(y, None, order=(0, common_d, 0))
    null_pred = null_model_fit.predict()

    model_stats = {}
    for model_name, columns in PREDICTOR_SETS.items():
        model_fit = fit_arimax_model(y, subject_data[columns], order=(common_p, common_d, common_q))

        # In-sample predict() returns one-step-ahead predictions already aligned
        # with y. Shifting them would compare y[t] with the forecast of y[t+1]
        # (which itself uses y[t]); instead, just skip the burn-in period.
        y_pred = model_fit.predict()
        valid_idx = y_pred.iloc[BURN_IN:].dropna().index
        y_valid = y.loc[valid_idx]

        # MSE-based pseudo R-squared relative to the null model.
        # NOTE: this is not McFadden's R-squared despite the reported key name,
        # which is kept unchanged for output compatibility.
        mse_model = mean_squared_error(y_valid, y_pred.loc[valid_idx])
        mse_baseline = mean_squared_error(y_valid, null_pred.loc[valid_idx])
        mse_r_squared = 1 - (mse_model / mse_baseline) if mse_baseline != 0 else float('nan')

        # Likelihood-ratio test of the full model against the nested null model;
        # degrees of freedom = number of extra estimated parameters.
        lr_stat = 2 * (model_fit.llf - null_model_fit.llf)
        df = len(model_fit.params) - len(null_model_fit.params)
        model_p_value = chi2.sf(lr_stat, df)

        # Ljung-Box test for residual autocorrelation, excluding burn-in residuals
        # that reflect the initialisation rather than the model.
        lb_test = acorr_ljungbox(model_fit.resid.iloc[BURN_IN:], lags=[2], return_df=True)
        lb_p_value = lb_test['lb_pvalue'].values[0]

        model_stats[model_name] = {
            'McFadden R-squared': mse_r_squared,
            'Model p-value': model_p_value,
            'Ljung-Box p-value': lb_p_value,
            'Predictions': y_pred,
            'Valid index': valid_idx,
        }
    return model_stats


def _select_best_model(model_stats, lb_alpha=0.05, model_alpha=0.1):
    """Return the admissible model name with the highest pseudo R-squared.

    A model is admissible when its residuals show no significant
    autocorrelation (Ljung-Box p > lb_alpha) and it beats the null model
    (likelihood-ratio p <= model_alpha). Ties keep the first model seen; a
    NaN R-squared never wins. Returns None when no model is admissible.
    """
    best_model_name = None
    highest_pseudo_r_squared = -float('inf')
    for model_name, stats in model_stats.items():
        admissible = stats['Ljung-Box p-value'] > lb_alpha and stats['Model p-value'] <= model_alpha
        if admissible and stats['McFadden R-squared'] > highest_pseudo_r_squared:
            highest_pseudo_r_squared = stats['McFadden R-squared']
            best_model_name = model_name
    return best_model_name


def _plot_best_model(subject_id, model_name, epochs, y, stats):
    """Plot actual vs. predicted CVForce, then the prediction residuals."""
    # Use this model's own scoring index (not a value left over from another model).
    valid_idx = stats['Valid index']
    predictions = stats['Predictions'].loc[valid_idx]

    plt.figure(figsize=(12, 5))
    plt.plot(epochs, y, label='Actual CVForce', color='blue', alpha=0.6)
    plt.plot(epochs.loc[valid_idx], predictions, label=f'Predicted CVForce ({model_name})', color='red', linestyle='--')
    plt.title(f"Best Model for Subject {subject_id}: {model_name}")
    plt.xlabel('Epoch')
    plt.ylabel('CVForce')
    plt.legend()
    plt.show()

    residuals = y.loc[valid_idx] - predictions
    plt.figure(figsize=(12, 4))
    plt.plot(epochs.loc[valid_idx], residuals, label='Residuals', color='purple')
    plt.axhline(0, color='gray', linestyle='--')
    plt.title(f"Residuals for Best Model ({model_name}) for Subject {subject_id}")
    plt.xlabel('Epoch')
    plt.ylabel('Residuals (Actual - Predicted)')
    plt.legend()
    plt.show()


# Prepare results list
results = []

# Loop through each subject
for subject_id, subject_data in data.groupby('SubjectID'):
    print(f"\nProcessing SubjectID: {subject_id}")

    # Move 'Epoch' back into a column and put the series in time order
    # (groupby keeps the file's row order, which is not guaranteed sorted).
    subject_data = (
        subject_data.reset_index()
        .sort_values('Epoch', kind='mergesort')
        .reset_index(drop=True)
    )
    y = subject_data['CVForce']

    model_stats = _evaluate_subject(subject_data)
    best_model_name = _select_best_model(model_stats)

    # Collect every model's statistics, marking the best one with '*'.
    for model_name, stats in model_stats.items():
        highlight = "*" if model_name == best_model_name else ""
        results.append({
            'SubjectID': subject_id,
            'Model': model_name + highlight,
            'McFadden R-squared': f"{stats['McFadden R-squared']:.3f}" + highlight,
            'Model p-value': f"{stats['Model p-value']:.3f}" + highlight,
            'Ljung-Box p-value': f"{stats['Ljung-Box p-value']:.3f}" + highlight,
        })

        if model_name == best_model_name:
            _plot_best_model(subject_id, model_name, subject_data['Epoch'], y, stats)

# Set display options for non-truncated DataFrame output
pd.set_option("display.max_rows", None)       # Show all rows
pd.set_option("display.max_columns", None)    # Show all columns
pd.set_option("display.width", 1000)          # Set the width to avoid line breaks
pd.set_option("display.max_colwidth", None)   # Display full column width if needed

# Convert results to DataFrame and display
results_df = pd.DataFrame(results)
print("\nModel Summary for All Subjects:")
print(results_df)