# ARIMA Modeling for Variant-Specific Death Predictions with Spatial Features

This notebook implements ARIMA/ARIMAX models to predict COVID-19 deaths by variant, incorporating spatial connectivity through adjacency matrices.

## Approach:
1. Calculate variant-specific deaths using variant prevalence proportions
2. Use a separate time window for each variant, based on its active period
3. Incorporate spatial features from adjacency matrices (border, airport, highway)
4. Build separate ARIMA/ARIMAX models for each variant
5. Evaluate and compare predictions

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.tsa.statespace.sarimax import SARIMAX
from statsmodels.tsa.stattools import adfuller, acf, pacf
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import warnings
warnings.filterwarnings('ignore')

# Plotting settings
# Apply the matplotlib style sheet and the seaborn colour palette used by every figure below.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

# Show up to 50 columns when a DataFrame is displayed.
pd.set_option('display.max_columns', 50)

## 1. Load Data and Adjacency Matrices

In [None]:
def load_adjacency_matrix(path):
    """
    Read an adjacency matrix CSV and give it integer row and column labels.

    The first CSV column is used as the index. Column labels are always read
    as strings from the header, and the index dtype depends on what pandas
    infers, so both are cast explicitly; otherwise label-based alignment
    against df['location'] can silently match nothing.

    Parameters:
    - path: location of the CSV file

    Returns:
    - DataFrame whose index and columns are both integers
    """
    adj = pd.read_csv(path, index_col=0)
    adj.index = adj.index.astype(int)
    adj.columns = adj.columns.astype(int)
    return adj

# Load the main dataset
df = pd.read_csv('../processed data/combined_prevalence_and_exogenous.csv')
df['date'] = pd.to_datetime(df['date'])

# Load adjacency matrices with both index and columns forced to integers
border_adj = load_adjacency_matrix('../processed data/border_adj_matrix.csv')
airport_adj = load_adjacency_matrix('../processed data/airport_adj_matrix.csv')
highway_adj = load_adjacency_matrix('../processed data/highway_adj_matrix.csv')

print(f"Dataset shape: {df.shape}")
print(f"Date range: {df['date'].min()} to {df['date'].max()}")
print(f"Counties: {df['location'].nunique()}")
print(f"\nAdjacency matrix shape: {border_adj.shape}")
print(f"Adjacency index dtype: {border_adj.index.dtype}")
print(f"Adjacency columns dtype: {border_adj.columns.dtype}")

## 2. Define Variant Time Windows

Based on prevalence analysis, we use these time windows:

In [3]:
# Modelling window per variant: name -> (first day, last day) as ISO date strings
variant_windows = dict(
    Alpha=('2021-01-17', '2021-08-18'),
    Delta=('2021-01-10', '2022-02-18'),
    Epsilon=('2020-10-27', '2021-05-25'),
    Iota=('2021-02-10', '2021-07-10'),
)

# Report each window together with its length in days
print("Variant Time Windows for Modeling:")
print("=" * 60)
for variant, window in variant_windows.items():
    start, end = window
    days = (pd.Timestamp(end) - pd.Timestamp(start)).days
    print(f"{variant:10s}: {start} to {end} ({days} days)")

Variant Time Windows for Modeling:
Alpha     : 2021-01-17 to 2021-08-18 (213 days)
Delta     : 2021-01-10 to 2022-02-18 (404 days)
Epsilon   : 2020-10-27 to 2021-05-25 (210 days)
Iota      : 2021-02-10 to 2021-07-10 (150 days)


## 3. Calculate Variant-Specific Deaths

Deaths attributed to each variant on a given day = daily new deaths × that variant's prevalence proportion

In [4]:
# Order rows by county, then date, so per-county differences run forward in time
df = df.sort_values(by=['location', 'date'])

# Daily new deaths: row-to-row change of 'deaths' within each county.
# Negative changes are floored at 0 and each county's first row (no predecessor) becomes 0.
deaths_by_county = df.groupby('location')['deaths']
df['new_deaths'] = deaths_by_county.diff().clip(lower=0).fillna(0)

# Attribute deaths to each variant by multiplying with that variant's column
variants = ['Alpha', 'Delta', 'Epsilon', 'Iota']
variant_death_cols = [f'{name}_deaths' for name in variants]

for variant in variants:
    df[f'{variant}_deaths'] = df['new_deaths'].mul(df[variant])

print("Variant-specific death columns created:")
print(variant_death_cols)

# Display sample
df[['date', 'location', 'new_deaths'] + variant_death_cols].head(10)

Variant-specific death columns created:
['Alpha_deaths', 'Delta_deaths', 'Epsilon_deaths', 'Iota_deaths']


Unnamed: 0,date,location,new_deaths,Alpha_deaths,Delta_deaths,Epsilon_deaths,Iota_deaths
8182,2021-04-22,17001,0.0,0.0,0.0,0.0,0.0
8183,2021-04-23,17001,0.0,0.0,0.0,0.0,0.0
8184,2021-04-25,17001,2.0,0.857143,0.0,0.0,0.0
8185,2021-04-26,17001,0.0,0.0,0.0,0.0,0.0
8186,2021-04-27,17001,0.0,0.0,0.0,0.0,0.0
8187,2021-05-04,17001,1.0,0.0,0.0,0.0,0.0
8188,2021-05-05,17001,1.0,0.5,0.0,0.0,0.0
8189,2021-05-06,17001,1.0,0.75,0.0,0.0,0.0
8190,2021-05-10,17001,1.0,0.833333,0.0,0.0,0.0
8191,2021-05-11,17001,0.0,0.0,0.0,0.0,0.0


## 4. Create Spatial Features from Adjacency Matrices

For each county and date, we calculate the adjacency-weighted (row-normalized) deaths of its neighboring counties, once per adjacency matrix (border, airport, highway).

In [None]:
def calculate_spatial_lag(df, date_col, county_col, value_col, adjacency_matrix, normalize=True):
    """
    Calculate spatial lag (weighted neighbor values) for each row of df.

    For every date, the values of all counties on that date are multiplied by
    the adjacency weights (W x y). Each row receives the lag of its own county
    on its own date, whatever order the rows of df are in.

    Parameters:
    - df: DataFrame with time series data (one row per county and date;
      duplicate county/date pairs make the per-date reindex raise)
    - date_col: name of date column
    - county_col: name of county identifier column
    - value_col: name of value column to calculate spatial lag for
    - adjacency_matrix: DataFrame with adjacency weights; rows are the
      counties receiving the lag, columns are their neighbors
    - normalize: whether to normalize by row sum

    Returns:
    - Series of spatial lag values aligned with df.index. Counties missing
      from the adjacency matrix rows get 0.
    """
    # Normalize adjacency matrix if requested
    adj = adjacency_matrix.copy()
    if normalize:
        row_sums = adj.sum(axis=1)
        row_sums[row_sums == 0] = 1  # Avoid division by zero
        adj = adj.div(row_sums, axis=0)

    counties = df[county_col].to_numpy()
    values = df[value_col].to_numpy()

    # Results are written by row position, so the output lines up with df
    # even when df is not sorted by date (e.g. sorted by county, then date).
    spatial_lags = np.zeros(len(df), dtype=float)

    for positions in df.groupby(date_col).indices.values():
        date_counties = counties[positions]

        # Values for this date, aligned to the neighbor (column) labels;
        # counties with no observation on this date contribute 0.
        date_values = pd.Series(values[positions], index=date_counties).reindex(
            adj.columns, fill_value=0
        )

        # Calculate spatial lag: W x y
        lag = adj.dot(date_values)

        # Look each row's county up in the result; unknown counties get 0
        spatial_lags[positions] = lag.reindex(date_counties, fill_value=0).to_numpy()

    return pd.Series(spatial_lags, index=df.index)

print("Calculating spatial lag features...")

# NOTE(review): assumes df['location'] and the adjacency matrix labels share the
# same integer dtype; mismatched labels would not align - confirm after loading.

# Connectivity network -> adjacency matrix; dict order fixes column creation order
adjacency_by_network = {
    'border': border_adj,
    'airport': airport_adj,
    'highway': highway_adj,
}

# One row-normalized spatial lag column per (variant, network) pair
for variant in variants:
    print(f"  Processing {variant}...")
    death_col = f'{variant}_deaths'

    for network, adjacency in adjacency_by_network.items():
        df[f'{variant}_deaths_{network}_lag'] = calculate_spatial_lag(
            df, 'date', 'location', death_col, adjacency, normalize=True
        )

print("\nSpatial lag features created!")
lag_columns = [col for col in df.columns if '_lag' in col]
print(f"New columns: {lag_columns[:3]}...")

## 5. Aggregate Data to State Level for Modeling

We model at the state level first; the same approach can be extended to the county level later if needed.

In [None]:
# Columns summed over all counties for each date, listed in output order
sum_cols = (
    ['new_deaths']
    + [f'{v}_deaths' for v in variants]
    + [f'{v}_deaths_border_lag' for v in variants]
    + [f'{v}_deaths_airport_lag' for v in variants]
    + [f'{v}_deaths_highway_lag' for v in variants]
)

# Aggregate to state level: one row per date, sorted chronologically
state_data = (
    df.groupby('date')
    .agg({col: 'sum' for col in sum_cols})
    .reset_index()
    .sort_values('date')
)

print(f"State-level aggregated data shape: {state_data.shape}")
state_data.head()

## 6. Visualize Variant-Specific Deaths

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(16, 10))
axes = axes.flatten()

for idx, variant in enumerate(variants):
    ax = axes[idx]
    death_col = f'{variant}_deaths'

    # Keep only the variant's active period (both endpoints inclusive)
    start, end = variant_windows[variant]
    variant_data = state_data[state_data['date'].between(start, end)]

    # Smooth on a copy so the helper column never reaches state_data.
    # The window is 7 rows, which equals 7 days only when no dates are missing.
    variant_data_plot = variant_data.copy()
    variant_data_plot['deaths_7d'] = (
        variant_data_plot[death_col].rolling(window=7, min_periods=1).mean()
    )

    dates = variant_data_plot['date']
    ax.plot(dates, variant_data_plot[death_col],
            alpha=0.3, label='Daily', color='lightblue')
    ax.plot(dates, variant_data_plot['deaths_7d'],
            label='7-day avg', color='darkblue', linewidth=2)
    ax.set_title(f'{variant} Variant Deaths', fontsize=12, fontweight='bold')
    ax.set_xlabel('Date')
    ax.set_ylabel('Daily Deaths')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 7. Stationarity Tests

In [None]:
def adf_test(series, name=''):
    """
    Run an Augmented Dickey-Fuller test and print a short report.

    Missing values are dropped before testing. The series is called
    stationary when the p-value is at most 0.05.

    Parameters:
    - series: pandas Series to test
    - name: label used in the report header

    Returns:
    - True if the p-value is <= 0.05, otherwise False
    """
    result = adfuller(series.dropna())
    adf_stat, p_value, critical_values = result[0], result[1], result[4]
    is_stationary = p_value <= 0.05

    print(f'{name} ADF Test:')
    print(f'  ADF Statistic: {adf_stat:.4f}')
    print(f'  p-value: {p_value:.4f}')
    print('  Critical Values:')
    for level, threshold in critical_values.items():
        print(f'    {level}: {threshold:.4f}')

    if is_stationary:
        verdict = 'STATIONARY (reject H0)'
    else:
        verdict = 'NON-STATIONARY (fail to reject H0)'
    print(f'  Result: Series is {verdict}\n')

    return is_stationary

banner = "=" * 70
print(banner)
print("STATIONARITY TESTS FOR VARIANT DEATHS")
print(banner)

# Test each variant's state-level death series inside its own active window
for variant in variants:
    start, end = variant_windows[variant]
    in_window = state_data['date'].between(start, end)
    variant_data = state_data[in_window]

    death_col = f'{variant}_deaths'
    series = variant_data[death_col]

    adf_test(series, f'{variant} Deaths')

## 8. ACF and PACF Analysis

In [None]:
# One row of (ACF, PACF) panels per variant; squeeze=False keeps axes 2-D
# even if the variant list ever shrinks to a single entry.
fig, axes = plt.subplots(len(variants), 2, figsize=(16, 4 * len(variants)), squeeze=False)

for idx, variant in enumerate(variants):
    start, end = variant_windows[variant]
    variant_data = state_data[(state_data['date'] >= start) & (state_data['date'] <= end)]

    death_col = f'{variant}_deaths'
    series = variant_data[death_col].dropna()

    # statsmodels only computes partial autocorrelations for lags up to 50% of
    # the sample size, and its error message requires nlags < nobs // 2, so
    # stay strictly below half the sample; 40 lags remains the upper limit.
    max_lags = min(40, len(series) // 2 - 1)
    if max_lags < 1:
        print(f"{variant}: too few observations ({len(series)}) for ACF/PACF plots")
        continue

    # ACF
    plot_acf(series, lags=max_lags, ax=axes[idx, 0])
    axes[idx, 0].set_title(f'{variant} - Autocorrelation Function', fontweight='bold')

    # PACF
    plot_pacf(series, lags=max_lags, ax=axes[idx, 1])
    axes[idx, 1].set_title(f'{variant} - Partial Autocorrelation Function', fontweight='bold')

plt.tight_layout()
plt.show()

## 9. Build ARIMA Models for Each Variant

In [None]:
def evaluate_model(actual, predicted):
    """
    Score predictions against observed values.

    Parameters:
    - actual: observed values; must support len() and .var()
    - predicted: predicted values, same length as actual

    Returns:
    - dict with keys 'MSE', 'RMSE', 'MAE' and 'R2'. R2 is NaN when actual is
      empty or has zero variance, because the score is undefined there.
    """
    scores = {}
    scores['MSE'] = mean_squared_error(actual, predicted)
    scores['RMSE'] = np.sqrt(scores['MSE'])
    scores['MAE'] = mean_absolute_error(actual, predicted)

    # R2 divides by the variance of the observations, so guard that case
    has_variance = len(actual) > 0 and actual.var() > 0
    scores['R2'] = r2_score(actual, predicted) if has_variance else np.nan

    return scores

def train_arima_model(data, target_col, order=(1,1,1), train_ratio=0.8):
    """
    Fit an ARIMA model on the early part of a series and score it on the rest.

    Parameters:
    - data: DataFrame with a 'date' column and the target column
    - target_col: name of the column to model
    - order: ARIMA (p, d, q) order
    - train_ratio: fraction of rows, in date order, used for training

    Returns:
    - fitted ARIMA model
    - train and test DataFrames (chronological split)
    - forecasts covering the test rows
    - evaluation metrics from evaluate_model
    """
    # Chronological order with a fresh 0..n-1 index
    ordered = data.sort_values('date').reset_index(drop=True)

    # First train_ratio of the rows train the model, the remainder tests it
    split_at = int(len(ordered) * train_ratio)
    train, test = ordered.iloc[:split_at], ordered.iloc[split_at:]

    fitted_model = ARIMA(train[target_col], order=order).fit()

    # Multi-step forecast across the whole test horizon
    predictions = fitted_model.forecast(steps=len(test))

    metrics = evaluate_model(test[target_col].values, predictions)

    return fitted_model, train, test, predictions, metrics

# Store results: variant -> dict with model, order, train, test, predictions, metrics
arima_results = {}

print("="*70)
print("TRAINING ARIMA MODELS")
print("="*70)

for variant in variants:
    print(f"\n{variant} Variant:")
    print("-" * 40)

    # Filter data for this variant's time window
    start, end = variant_windows[variant]
    variant_data = state_data[(state_data['date'] >= start) & (state_data['date'] <= end)].copy()

    if len(variant_data) < 20:
        print(f"  Insufficient data ({len(variant_data)} observations)")
        continue

    death_col = f'{variant}_deaths'

    # Try different ARIMA orders and keep the fit with the lowest AIC.
    # NOTE(review): the candidates mix d=0 and d=1; AIC values are generally not
    # comparable across different differencing orders - confirm this selection
    # rule is intended.
    orders_to_try = [(1,1,1), (2,1,2), (1,0,1), (2,0,2), (3,1,3)]
    best_aic = np.inf
    # Reset every "best" slot so nothing carries over from the previous variant
    best_model = None
    best_order = None
    best_train = best_test = best_preds = best_metrics = None

    for order in orders_to_try:
        try:
            model, train, test, preds, metrics = train_arima_model(
                variant_data, death_col, order=order, train_ratio=0.8
            )

            if model.aic < best_aic:
                best_aic = model.aic
                best_model = model
                best_order = order
                best_train = train
                best_test = test
                best_preds = preds
                best_metrics = metrics
        except Exception as exc:
            # Report the failed order instead of hiding it; Exception (not a
            # bare except) lets KeyboardInterrupt still stop the cell.
            print(f"  ARIMA{order} failed: {exc}")
            continue

    if best_model is not None:
        print(f"  Best ARIMA order: {best_order}")
        print(f"  AIC: {best_aic:.2f}")
        print(f"  Train size: {len(best_train)}")
        print(f"  Test size: {len(best_test)}")
        print(f"  \nTest Set Metrics:")
        for metric, value in best_metrics.items():
            print(f"    {metric}: {value:.4f}")

        arima_results[variant] = {
            'model': best_model,
            'order': best_order,
            'train': best_train,
            'test': best_test,
            'predictions': best_preds,
            'metrics': best_metrics
        }
    else:
        print(f"  Failed to fit model")

## 10. Build ARIMAX Models with Spatial Features

In [None]:
def train_arimax_model(data, target_col, exog_cols, order=(1,1,1), train_ratio=0.8):
    """
    Fit a SARIMAX model with exogenous regressors and score it on a hold-out.

    Parameters:
    - data: DataFrame with a 'date' column, the target and the exogenous columns
    - target_col: name of the column to model
    - exog_cols: list of exogenous column names
    - order: (p, d, q) order passed to SARIMAX
    - train_ratio: fraction of rows, in date order, used for training

    Returns:
    - fitted model, train DataFrame, test DataFrame, forecasts for the test
      rows, and the metrics dict from evaluate_model

    The forecast is given the test-period values of the exogenous columns.
    """
    # Chronological order with a fresh 0..n-1 index
    ordered = data.sort_values('date').reset_index(drop=True)

    split_at = int(len(ordered) * train_ratio)
    train, test = ordered.iloc[:split_at], ordered.iloc[split_at:]

    # Fit SARIMAX (ARIMAX) on the training rows only
    fitted_model = SARIMAX(train[target_col], exog=train[exog_cols], order=order).fit(disp=False)

    # Forecast the whole test horizon, supplying that horizon's regressors
    predictions = fitted_model.forecast(steps=len(test), exog=test[exog_cols])

    metrics = evaluate_model(test[target_col].values, predictions)

    return fitted_model, train, test, predictions, metrics

# variant -> dict with model, order, train, test, predictions, metrics, exog_cols
arimax_results = {}

print("="*70)
print("TRAINING ARIMAX MODELS WITH SPATIAL FEATURES")
print("="*70)

for variant in variants:
    print(f"\n{variant} Variant:")
    print("-" * 40)

    # Restrict to the variant's active window (both endpoints inclusive)
    start, end = variant_windows[variant]
    variant_data = state_data[state_data['date'].between(start, end)].copy()

    if len(variant_data) < 20:
        print(f"  Insufficient data ({len(variant_data)} observations)")
        continue

    death_col = f'{variant}_deaths'

    # Spatial lag features as exogenous variables.
    # NOTE(review): these lags are built from the same variant's deaths on the
    # same date as the target, and the forecast is given their test-period
    # values. That looks like target leakage and may inflate the ARIMAX
    # scores relative to ARIMA - confirm.
    exog_cols = [
        f'{variant}_deaths_{network}_lag'
        for network in ('border', 'airport', 'highway')
    ]

    # Reuse the order chosen for plain ARIMA when one exists, else a default
    best_order = arima_results[variant]['order'] if variant in arima_results else (1, 1, 1)

    try:
        model, train, test, preds, metrics = train_arimax_model(
            variant_data, death_col, exog_cols, order=best_order, train_ratio=0.8
        )

        print(f"  ARIMAX order: {best_order}")
        print(f"  AIC: {model.aic:.2f}")
        print(f"  Exogenous features: {len(exog_cols)}")
        print(f"  Train size: {len(train)}")
        print(f"  Test size: {len(test)}")
        print("  \nTest Set Metrics:")
        for metric, value in metrics.items():
            print(f"    {metric}: {value:.4f}")

        arimax_results[variant] = dict(
            model=model,
            order=best_order,
            train=train,
            test=test,
            predictions=preds,
            metrics=metrics,
            exog_cols=exog_cols,
        )

    except Exception as e:
        print(f"  Failed to fit ARIMAX: {str(e)}")

## 11. Compare ARIMA vs ARIMAX Performance

In [None]:
# Build one comparison row per variant that has both an ARIMA and an ARIMAX fit
comparison = []

for variant in variants:
    if variant not in arima_results or variant not in arimax_results:
        continue

    arima_metrics = arima_results[variant]['metrics']
    arimax_metrics = arimax_results[variant]['metrics']

    # Relative RMSE reduction of ARIMAX over ARIMA, in percent (positive = ARIMAX better)
    rmse_gain = (arima_metrics['RMSE'] - arimax_metrics['RMSE']) / arima_metrics['RMSE'] * 100

    row = {'Variant': variant}
    for metric in ('RMSE', 'MAE', 'R2'):
        row[f'ARIMA_{metric}'] = arima_metrics[metric]
        row[f'ARIMAX_{metric}'] = arimax_metrics[metric]
    row['Improvement_RMSE'] = rmse_gain
    comparison.append(row)

comparison_df = pd.DataFrame(comparison)

print("="*70)
print("MODEL COMPARISON: ARIMA vs ARIMAX")
print("="*70)
print(comparison_df.to_string(index=False))

# Visualize comparison: one grouped bar chart per metric
if len(comparison_df) > 0:
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    metrics_to_plot = ['RMSE', 'MAE', 'R2']
    for idx, metric in enumerate(metrics_to_plot):
        ax = axes[idx]
        metric_cols = [f'ARIMA_{metric}', f'ARIMAX_{metric}']
        comparison_df.plot(x='Variant', y=metric_cols, kind='bar', ax=ax)
        ax.set_title(f'{metric} Comparison', fontweight='bold')
        ax.set_ylabel(metric)
        ax.legend(['ARIMA', 'ARIMAX'])
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

## 12. Visualize Predictions

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(18, 12))
axes = axes.flatten()

for idx, variant in enumerate(variants):
    # Variants without a fitted ARIMA model leave their panel empty
    if variant not in arima_results:
        continue

    ax = axes[idx]
    fitted = arima_results[variant]
    train, test = fitted['train'], fitted['test']
    arima_preds = fitted['predictions']

    death_col = f'{variant}_deaths'
    test_dates = test['date']

    # Observed series: training part, then the held-out part
    ax.plot(train['date'], train[death_col], label='Training Data', color='blue')
    ax.plot(test_dates, test[death_col], label='Actual Test Data', color='green', linewidth=2)

    # Forecasts over the held-out dates
    ax.plot(test_dates, arima_preds, label='ARIMA Predictions',
            color='red', linestyle='--', linewidth=2)

    if variant in arimax_results:
        arimax_preds = arimax_results[variant]['predictions']
        ax.plot(test_dates, arimax_preds, label='ARIMAX Predictions',
                color='orange', linestyle=':', linewidth=2)

    ax.set_title(f'{variant} Variant - Death Predictions', fontsize=12, fontweight='bold')
    ax.set_xlabel('Date')
    ax.set_ylabel('Daily Deaths')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 13. Residual Analysis

In [None]:
n_variants = len(variants)
fig, axes = plt.subplots(n_variants, 2, figsize=(16, 4 * n_variants))

for idx, variant in enumerate(variants):
    # Variants without a fitted ARIMA model leave their row empty
    if variant not in arima_results:
        continue

    # ARIMA residuals from the fitted training model
    model = arima_results[variant]['model']
    residuals = model.resid

    timeline_ax, hist_ax = axes[idx, 0], axes[idx, 1]

    # Left panel: residuals in sequence, with a zero reference line
    timeline_ax.plot(residuals)
    timeline_ax.axhline(y=0, color='r', linestyle='--')
    timeline_ax.set_title(f'{variant} - ARIMA Residuals Over Time', fontweight='bold')
    timeline_ax.set_ylabel('Residuals')
    timeline_ax.grid(True, alpha=0.3)

    # Right panel: histogram of the same residuals
    hist_ax.hist(residuals, bins=30, edgecolor='black')
    hist_ax.set_title(f'{variant} - Residuals Distribution', fontweight='bold')
    hist_ax.set_xlabel('Residuals')
    hist_ax.set_ylabel('Frequency')
    hist_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 14. Summary and Conclusions

In [None]:
print("="*70)
print("SUMMARY OF RESULTS")
print("="*70)

print("\n1. VARIANT TIME WINDOWS:")
for variant, (start, end) in variant_windows.items():
    print(f"   {variant}: {start} to {end}")

print("\n2. MODELS TRAINED:")
print(f"   ARIMA models: {len(arima_results)}")
print(f"   ARIMAX models: {len(arimax_results)}")

print("\n3. BEST PERFORMING VARIANT (by R²):")
if len(comparison_df) > 0:
    best_arima = comparison_df.nlargest(1, 'ARIMA_R2')[['Variant', 'ARIMA_R2']]
    best_arimax = comparison_df.nlargest(1, 'ARIMAX_R2')[['Variant', 'ARIMAX_R2']]
    print(f"   ARIMA: {best_arima.iloc[0]['Variant']} (R² = {best_arima.iloc[0]['ARIMA_R2']:.4f})")
    print(f"   ARIMAX: {best_arimax.iloc[0]['Variant']} (R² = {best_arimax.iloc[0]['ARIMAX_R2']:.4f})")

print("\n4. SPATIAL FEATURES IMPACT:")
if len(comparison_df) > 0:
    avg_improvement = comparison_df['Improvement_RMSE'].mean()
    print(f"   Average RMSE improvement with spatial features: {avg_improvement:.2f}%")

# Derive the longest window from variant_windows instead of hard-coding it, so
# the finding always agrees with the windows printed in section 1.
window_days = {
    name: (pd.to_datetime(last_day) - pd.to_datetime(first_day)).days
    for name, (first_day, last_day) in variant_windows.items()
}
longest_variant = max(window_days, key=window_days.get)

print("\n5. KEY FINDINGS:")
print("   - Different variants showed distinct temporal patterns")
print(f"   - {longest_variant} variant had the longest active period ({window_days[longest_variant]} days)")
print("   - Spatial connectivity features from adjacency matrices were incorporated")
print("   - ARIMAX models account for neighbor county spillover effects")

print("\n" + "="*70)

## 15. Save Results

In [None]:
# Write the comparison table (skipped when no variant has both models)
if len(comparison_df) > 0:
    comparison_path = 'arima_arimax_comparison.csv'
    comparison_df.to_csv(comparison_path, index=False)
    print(f"Comparison results saved to '{comparison_path}'")

# Write one predictions file per variant that has an ARIMA fit
for variant in variants:
    if variant not in arima_results:
        continue

    # Copy the test split so the stored results are not modified
    test = arima_results[variant]['test'].copy()
    test['arima_predictions'] = arima_results[variant]['predictions']

    # ARIMAX forecasts are added only when that model was fitted
    if variant in arimax_results:
        test['arimax_predictions'] = arimax_results[variant]['predictions']

    predictions_path = f'{variant}_predictions.csv'
    test.to_csv(predictions_path, index=False)
    print(f"{variant} predictions saved to '{predictions_path}'")