# Using Causal Inference to Improve Time Series Forecasting with CausalVAR

This notebook demonstrates how to use the `CausalVAR` forecaster in `sktime`. `CausalVAR` is a Vector Autoregressive (VAR) model that uses a known causal graph to decide which lagged variables enter each equation, pruning all others.

**The Problem with Standard VAR:** A standard VAR model regresses each variable on the past values of *all* variables in the system, including its own. In a system with many variables (e.g., sales, marketing spend, inventory, price, competitor price), the number of coefficients grows quickly, which can lead to overfitting and poor interpretability.
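A quick coefficient count makes the overfitting risk concrete. With $k$ variables, $p$ lags and an intercept, each of the $k$ equations regresses on all $kp$ lagged values, so an unrestricted VAR($p$) has $k(kp+1)$ coefficients. The helper below is a hypothetical illustration written for this notebook, not part of `sktime`:

```python
def var_coefficient_count(n_vars: int, n_lags: int, intercept: bool = True) -> int:
    """Coefficients in an unrestricted VAR(p): every equation uses p lags of all k series."""
    per_equation = n_vars * n_lags + (1 if intercept else 0)
    return n_vars * per_equation


# Five series (sales, marketing spend, inventory, price, competitor price), 5 lags each.
print(var_coefficient_count(n_vars=5, n_lags=5))
```

For the five-variable example this is 5 × (5 × 5 + 1) = 130 coefficients, or 26 per equation, which is a lot to estimate from a short weekly history.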

**The `CausalVAR` Solution:** By providing a causal graph, we can tell the model to only use the past values of a variable's *direct causes* as predictors. For example, we might specify that `sales` is caused by `marketing` and `price`, but not by `inventory`.
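The pruning rule can be pictured as a boolean mask over (equation, predictor) pairs, built from the graph's edges. The sketch below uses the four variables of this notebook's graph. `lag_mask` is a hypothetical helper, not part of `sktime` or `pgmpy`, and whether `CausalVAR` also keeps each series' own lags is an assumption, exposed here as the `include_own_lags` flag.

```python
import numpy as np


def lag_mask(variables, edges, include_own_lags=True):
    """Boolean (k x k) mask: mask[i, j] is True if lags of variables[j]
    may appear in the equation for variables[i].

    `include_own_lags` is a hypothetical option; this notebook does not
    state whether CausalVAR keeps each series' own lags.
    """
    pos = {name: i for i, name in enumerate(variables)}
    mask = np.zeros((len(variables), len(variables)), dtype=bool)
    for cause, effect in edges:
        # An edge cause -> effect lets past values of `cause` predict `effect`.
        mask[pos[effect], pos[cause]] = True
    if include_own_lags:
        np.fill_diagonal(mask, True)
    return mask


variables = ["marketing", "price", "sales", "inventory"]
edges = [("marketing", "sales"), ("price", "sales"), ("sales", "inventory")]
mask = lag_mask(variables, edges)
for name, row in zip(variables, mask):
    print(name, "<-", [v for v, keep in zip(variables, row) if keep])
```

With own lags kept, `mask.sum()` is 7, so a VAR(5) estimates 7 × 5 = 35 lag coefficients instead of the 4 × 4 × 5 = 80 of an unrestricted model (intercepts not counted).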

This notebook will walk through:
1. Defining a causal structure with `pgmpy`.
2. Generating synthetic time series data that respects this structure.
3. Training the `CausalVAR` forecaster.
4. Comparing its forecast to a standard `VAR` model.

In [None]:
%pip install sktime pgmpy statsmodels matplotlib scikit-learn

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pgmpy.models import BayesianNetwork

# Import sktime tools
from sktime.forecasting.var import VAR
from sktime.split import temporal_train_test_split
from sktime.performance_metrics.forecasting import mean_absolute_percentage_error

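# Import our new forecaster. It is loaded from the private module
# `sktime.forecasting._causal_var`, so make sure `_causal_var.py` is present
# in your local sktime source tree.
from sktime.forecasting._causal_var import CausalVAR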

### Step 1: Define the Causal Graph

We will model a simple retail scenario:
- `marketing` efforts drive `sales`.
- past `sales` drive `inventory` (higher sales draw stock down).
- `price` has no causes in the graph (it is exogenous) and also drives `sales`.

We use `pgmpy` to define this structure.

In [None]:
# Each edge (cause, effect) states that past values of `cause` help predict `effect`.
causal_graph = BayesianNetwork([
    ("marketing", "sales"),
    ("price", "sales"),
    ("sales", "inventory")
])

### Step 2: Generate Synthetic Data

Now we create a `pd.DataFrame` of time series data that respects these relationships.
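Written out, the generating process in the next cell is as follows, using a tilde for the noise-free series at week $t = 0, \dots, N-1$ with $N = 200$:

$$
\begin{aligned}
\widetilde{\text{marketing}}_t &= 10 + 5 \sin\!\left(\frac{2\pi t}{N-1}\right), \\
\widetilde{\text{price}}_t &= 50 - 0.1\,t, \\
\widetilde{\text{sales}}_t &= 100 + 2\,\widetilde{\text{marketing}}_{t-1} - 1.5\,\widetilde{\text{price}}_{t-1}, \\
\widetilde{\text{inventory}}_t &= 200 - 0.8\,\widetilde{\text{sales}}_{t-1}.
\end{aligned}
$$

Each observed series is $x_t = \tilde{x}_t + \varepsilon_t$ with independent $\varepsilon_t \sim \mathcal{N}(0, 2^2)$, and the first two weeks are dropped because the lagged terms are undefined there. Every edge of the graph therefore becomes a one-week lag, and because the noise is added to the observations afterwards, it does not propagate from `sales` into `inventory`.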

In [None]:
n_periods = 200
idx = pd.date_range("2023-01-01", periods=n_periods, freq="W")

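# Create base series. Each one must carry `idx` as its index: pandas aligns
# Series on their index labels when building the DataFrame below, and a
# default RangeIndex would turn every column into NaN.
marketing = pd.Series(10 + 5 * np.sin(np.linspace(0, 2 * np.pi, n_periods)), index=idx)
price = pd.Series(50 - 0.1 * np.arange(n_periods), index=idx)
rng = np.random.default_rng(42)  # fixed seed for reproducible noise
noise = rng.normal(0, 2, size=(n_periods, 4))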

# Create sales based on its causes (marketing and price from the previous period)
sales = 2 * marketing.shift(1) - 1.5 * price.shift(1) + 100

# Create inventory based on its cause (sales from the previous period)
inventory = -0.8 * sales.shift(1) + 200

data = pd.DataFrame(
    {"marketing": marketing, "price": price, "sales": sales, "inventory": inventory},
    index=idx
)
# Add noise and remove initial NaNs from shifting
data = (data + noise).dropna()

data.plot(subplots=True, figsize=(12, 8), title="Synthetic Time Series Data");
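pandas aligns Series on their index labels when it assembles a DataFrame, which is easy to trip over when a default RangeIndex meets a DatetimeIndex. A minimal, self-contained demonstration on five weekly dates:

```python
import numpy as np
import pandas as pd

idx = pd.date_range("2023-01-01", periods=5, freq="W")

# A Series built without an index gets a default RangeIndex 0..4 ...
misaligned = pd.Series(np.arange(5.0))
# ... which shares no labels with `idx`, so the DataFrame constructor
# reindexes it to `idx` and every value becomes NaN.
bad = pd.DataFrame({"x": misaligned}, index=idx)

# Giving the Series the same DatetimeIndex keeps the values.
aligned = pd.Series(np.arange(5.0), index=idx)
good = pd.DataFrame({"x": aligned}, index=idx)

print(bad["x"].isna().all())
print(good["x"].tolist())
```

If every column comes back as NaN, a later `dropna()` silently removes all rows, so checking `data.shape` after generating synthetic data is a cheap safeguard.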

### Step 3: Train and Compare Models

We will split our data and train two models:
1. `CausalVAR`: our new forecaster, which uses the causal graph.
2. `VAR`: the standard `sktime` VAR forecaster, which uses the lags of all variables as predictors in every equation.
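How `CausalVAR` fits internally is not shown in this notebook. As an illustration of the idea only, a causally restricted VAR($p$) can be estimated by ordinary least squares one equation at a time, using just the lags that a boolean mask allows (for example, a mask built from the graph's parent sets). The function `fit_restricted_var` below is a hypothetical numpy sketch, not `CausalVAR`'s implementation, and it is checked on simulated data from a known two-series VAR(1):

```python
import numpy as np


def fit_restricted_var(y: np.ndarray, mask: np.ndarray, p: int):
    """Fit a VAR(p) by OLS, one equation at a time, using only the lags allowed by `mask`.

    y    : array of shape (T, k), one column per series.
    mask : boolean array of shape (k, k); mask[i, j] lets lags of series j enter equation i.
    Returns (coef, intercept); coef has shape (k, k, p) and is zero wherever mask is False.
    """
    n_obs, k = y.shape
    target = y[p:]  # rows t = p, ..., T - 1
    # lagged[:, j, lag] holds series j at lag (lag + 1), row-aligned with `target`.
    lagged = np.stack([y[p - lag - 1 : n_obs - lag - 1] for lag in range(p)], axis=2)
    coef = np.zeros((k, k, p))
    intercept = np.zeros(k)
    for i in range(k):
        parents = np.flatnonzero(mask[i])
        X = lagged[:, parents, :].reshape(len(target), -1)
        X = np.column_stack([np.ones(len(target)), X])  # intercept column first
        beta, *_ = np.linalg.lstsq(X, target[:, i], rcond=None)
        intercept[i] = beta[0]
        coef[i, parents, :] = beta[1:].reshape(len(parents), p)
    return coef, intercept


# Simulate a stable VAR(1) in which series 0 drives series 1 but not vice versa.
rng = np.random.default_rng(0)
n_obs = 2000
y = np.zeros((n_obs, 2))
for t in range(1, n_obs):
    shock = rng.normal(size=2)
    y[t, 0] = 0.5 * y[t - 1, 0] + shock[0]                      # own past only
    y[t, 1] = 0.8 * y[t - 1, 0] + 0.2 * y[t - 1, 1] + shock[1]  # series 0 -> series 1
mask = np.array([[True, False], [True, True]])  # no arrow from series 1 to series 0
coef, intercept = fit_restricted_var(y, mask, p=1)
print(np.round(coef[:, :, 0], 2))  # coef[0, 1, 0] stays 0: that lag is never estimated
```

A design note: when every equation has the same regressors (the unrestricted VAR), per-equation OLS coincides with GLS; once the regressor sets differ, per-equation OLS remains consistent but a seemingly unrelated regressions (SUR) estimator can be more efficient.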

In [None]:
y_train, y_test = temporal_train_test_split(data, test_size=30)
fh = y_test.index  # absolute horizon: forecast exactly the test-set dates

# Train CausalVAR
causal_forecaster = CausalVAR(causal_graph=causal_graph, maxlags=5)
causal_forecaster.fit(y_train)
y_pred_causal = causal_forecaster.predict(fh)

# Train standard VAR
standard_forecaster = VAR(maxlags=5)
standard_forecaster.fit(y_train)
y_pred_standard = standard_forecaster.predict(fh)

causal_mape = mean_absolute_percentage_error(y_test, y_pred_causal)
standard_mape = mean_absolute_percentage_error(y_test, y_pred_standard)

print(f"MAPE from CausalVAR: {causal_mape:.4f}")
print(f"MAPE from Standard VAR: {standard_mape:.4f}")
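Called on DataFrames, sktime's `mean_absolute_percentage_error` averages the per-column errors by default (`multioutput="uniform_average"`), so the two MAPE values printed above mix all four series. To see which series drives a difference, one can pass `multioutput="raw_values"` or compute the error per column directly. The pandas sketch below does the latter on made-up toy numbers; `mape_per_column` is a hypothetical helper:

```python
import numpy as np
import pandas as pd


def mape_per_column(y_true: pd.DataFrame, y_pred: pd.DataFrame) -> pd.Series:
    """Mean absolute percentage error per column: mean(|y - y_hat| / |y|).

    Undefined (infinite) wherever y_true is zero.
    """
    return (y_true - y_pred).abs().div(y_true.abs()).mean()


y_true = pd.DataFrame({"sales": [100.0, 200.0], "inventory": [50.0, 50.0]})
y_pred = pd.DataFrame({"sales": [110.0, 180.0], "inventory": [50.0, 55.0]})
per_col = mape_per_column(y_true, y_pred)
print(per_col)
print(per_col.mean())  # unweighted average across columns
```

Here `sales` has a MAPE of 0.10 and `inventory` 0.05, and the averaged score of 0.075 hides that split.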

### Step 4: Visualize the Results

Let's plot the forecasts for the `sales` variable to see how each model performed.

In [None]:
fig, ax = plt.subplots(figsize=(12, 6))
y_train['sales'].plot(ax=ax, label='Training Data')
y_test['sales'].plot(ax=ax, label='True Values (Test)', style='--')
y_pred_causal['sales'].plot(ax=ax, label='CausalVAR Forecast')
y_pred_standard['sales'].plot(ax=ax, label='Standard VAR Forecast', style=':')
ax.set_title("Comparing Forecasts for 'sales'")
ax.legend();