In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell runs, so edits to local .py files take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [23]:
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
from scipy import stats
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

In [None]:
def simple_model_test(df, stat='PTS'):
    """
    Test simple linear adjustments for rest and home court effects.

    Model structure:
    y = baseline + home_effect + rest_effect
    where:
    - baseline is team's season average
    - home_effect is single league-wide parameter
    - rest_effect is one parameter per rest category

    The response is ``stat`` minus the (SEASON, TEAM) mean of ``stat``; the
    design matrix has one dummy column per ``REST`` category plus
    ``is_home`` and is fit without an intercept. The function prints a
    textual summary, draws a 2x2 diagnostic figure (actual vs predicted,
    rest effects with 95% confidence intervals, residuals, normal Q-Q) and
    runs Levene's test on home vs away residuals.

    Parameters
    ----------
    df : pandas.DataFrame
        Input passed straight to ``reshape_data`` (defined elsewhere in this
        notebook). The reshaped frame must provide the columns ``SEASON``,
        ``TEAM``, ``REST``, ``is_home`` and ``stat``.
    stat : str, default 'PTS'
        Name of the column to model.

    Returns
    -------
    dict
        ``'model'`` (fitted LinearRegression), ``'coefficients'`` (design
        column name -> coefficient), ``'r2'``, ``'rmse'``, ``'residuals'``
        and ``'y_pred'``.
    """

    # Reshape data
    data = reshape_data(df)

    # Calculate team season baselines
    team_seasons = data.groupby(['SEASON', 'TEAM'])[stat].mean().reset_index()
    data = data.merge(team_seasons,
                      on=['SEASON', 'TEAM'],
                      suffixes=('', '_baseline'))

    # Create design matrix: one dummy per rest category, then the home flag.
    X = pd.get_dummies(data['REST'], prefix='rest')
    rest_columns = list(X.columns)
    # Recover the real category labels from the dummy names ("rest_<label>")
    # instead of assuming the categories are exactly 0..k-1.
    rest_labels = [col[len('rest_'):] for col in rest_columns]
    X['is_home'] = data['is_home']
    # get_dummies yields bool columns on recent pandas; cast to float so
    # that X'X below is a genuine cross-product matrix, not a boolean one.
    X = X.astype(float)

    # Response variable: difference from baseline
    y = data[stat] - data[f'{stat}_baseline']

    # Fit model
    model = LinearRegression(fit_intercept=False)  # rest dummies span the intercept
    model.fit(X, y)
    coefficients = dict(zip(X.columns, model.coef_))

    # Print results
    print("\n=== Simple Model Results ===")
    print("\nHome Court Effect:")
    print(f"{coefficients['is_home']:.2f} {stat}")

    print("\nRest Effects:")
    for label, col in zip(rest_labels, rest_columns):
        print(f"Rest {label}: {coefficients[col]:.2f} {stat}")

    # Calculate R-squared
    y_pred = model.predict(X)
    residuals = y - y_pred
    r2 = 1 - np.sum(residuals**2) / np.sum((y - y.mean())**2)
    print(f"\nR-squared: {r2:.3f}")

    # Standard errors from the FULL design matrix: the fitted model includes
    # is_home, so the coefficient covariance is sigma^2 * (X'X)^-1 over all
    # columns; we then keep only the entries for the rest dummies.
    dof = len(y) - len(model.coef_)
    sigma2 = np.sum(residuals**2) / dof
    design = X.to_numpy()
    # pinv rather than inv: matches the least-squares solver and does not
    # raise if is_home happens to be collinear with the rest dummies.
    xtx_inv = np.linalg.pinv(design.T @ design)
    coef_se = np.sqrt(sigma2 * np.diag(xtx_inv))

    rest_effects = pd.DataFrame({
        'effect': [coefficients[col] for col in rest_columns],
        'rest': rest_labels
    })
    rest_effects['se'] = coef_se[:len(rest_columns)]
    # Half-width of a two-sided 95% interval, so the bars match the title.
    rest_effects['ci95'] = stats.t.ppf(0.975, dof) * rest_effects['se']

    # Visualizations
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # 1. Actual vs Predicted
    axes[0,0].scatter(y, y_pred, alpha=0.1)
    axes[0,0].plot([y.min(), y.max()], [y.min(), y.max()], 'r--')
    axes[0,0].set_xlabel('Actual Difference from Baseline')
    axes[0,0].set_ylabel('Predicted Difference from Baseline')
    axes[0,0].set_title('Actual vs Predicted Differences')

    # 2. Rest Effects with Confidence Intervals
    rest_effects.plot(x='rest', y='effect',
                      yerr='ci95', kind='bar', ax=axes[0,1])
    axes[0,1].set_title('Rest Effects with 95% CI')
    axes[0,1].set_xlabel('Rest Days')
    axes[0,1].set_ylabel(f'Effect on {stat}')

    # 3. Residual Plot
    axes[1,0].scatter(y_pred, residuals, alpha=0.1)
    axes[1,0].axhline(y=0, color='r', linestyle='--')
    axes[1,0].set_xlabel('Predicted Difference')
    axes[1,0].set_ylabel('Residuals')
    axes[1,0].set_title('Residual Plot')

    # 4. QQ Plot of Residuals
    stats.probplot(residuals, dist="norm", plot=axes[1,1])
    axes[1,1].set_title('Normal Q-Q Plot of Residuals')

    plt.tight_layout()
    plt.show()

    # Additional Diagnostics
    rmse = np.sqrt(mean_squared_error(y, y_pred))
    print("\nDiagnostic Statistics:")
    print(f"RMSE: {rmse:.2f}")
    print(f"Mean Absolute Error: {np.mean(np.abs(residuals)):.2f}")
    print(f"Mean Residual: {np.mean(residuals):.2f}")
    print(f"Residual Std: {np.std(residuals):.2f}")

    # Test for homoscedasticity: equal residual variance home vs away
    _, p_value = stats.levene(residuals[X['is_home'] == 0],
                              residuals[X['is_home'] == 1])
    print(f"\nLevene's test p-value (home/away): {p_value:.3f}")

    # Return model and diagnostics
    results = {
        'model': model,
        'coefficients': coefficients,
        'r2': r2,
        'rmse': rmse,
        'residuals': residuals,
        'y_pred': y_pred
    }

    return results
