## 14. Inference helper — Ray Data batch predictor on GPU  

Define a Ray Data batch predictor class that loads the trained `TimeSeriesTransformer` once per GPU actor and keeps it resident in memory, so the cost of loading the model isn't paid again for every batch.  
Each actor processes batches of input windows (for example, recent time series segments) in parallel and produces a forecast for the next horizon.  

Ray Data inference gives you scalable, fault-tolerant prediction pipelines that reuse the loaded model across many batches, which suits large-scale batch or near-real-time forecasting workloads.
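
The load-once pattern can be sketched without Ray. The toy class below is a hypothetical stand-in (`ToyPredictor` and its `scale` "weights" are assumptions for illustration, not part of the tutorial): the expensive setup runs in `__init__`, and `__call__` reuses that state for every batch. When you pass a class to `map_batches`, Ray Data constructs it once per actor and then calls it per batch in the same way.

```python
class ToyPredictor:
    """Hypothetical stand-in for a stateful per-actor predictor."""

    load_count = 0  # counts how many times the "model" was loaded

    def __init__(self, scale: float):
        # Expensive setup (for example, loading weights) goes here: it runs once per instance.
        ToyPredictor.load_count += 1
        self.scale = scale

    def __call__(self, batch: list) -> list:
        # Cheap per-batch work that reuses the state built in __init__.
        return [x * self.scale for x in batch]


predictor = ToyPredictor(scale=2.0)        # one "actor"
batches = [[1.0, 2.0], [3.0], [4.0, 5.0]]  # three incoming batches
outputs = [predictor(b) for b in batches]

print(outputs)                  # [[2.0, 4.0], [6.0], [8.0, 10.0]]
print(ToyPredictor.load_count)  # 1
```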

In [None]:
# 14. Ray Data inference helper — stateful per-actor predictor

class TimeSeriesBatchPredictor:
    """
    Keeps the TimeSeriesTransformer in memory per actor (GPU if available).
    Expects a Pandas batch with a 'past' column containing np.ndarray of shape (INPUT_WINDOW,).
    Returns a batch with a 'pred' column (np.ndarray of shape (HORIZON,)).
    """
    def __init__(self, checkpoint_path: str, model_kwargs: dict):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Recreate model with the *same* hyperparams used during training
        self.model = TimeSeriesTransformer(
            input_window=model_kwargs["input_window"],
            horizon=model_kwargs["horizon"],
            d_model=model_kwargs["d_model"],
            nhead=model_kwargs["nhead"],
            num_layers=model_kwargs["num_layers"],
        ).to(self.device).eval()

        # Load checkpoint weights once per actor
        ckpt = Checkpoint.from_directory(checkpoint_path)
        with ckpt.as_directory() as ckpt_dir:
            state_dict = torch.load(os.path.join(ckpt_dir, "model.pt"), map_location="cpu")
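            # Strip a leading DDP "module." prefix from each key, if present
            state_dict = {
                (k[len("module."):] if k.startswith("module.") else k): v
                for k, v in state_dict.items()
            }
            self.model.load_state_dict(state_dict)

        # Inference only: disable autograd globally in this actor process
        torch.set_grad_enabled(False)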

    def __call__(self, batch):
        past_list = batch["past"]  # each entry: np.ndarray shape (INPUT_WINDOW,)
        # Stack into (B, INPUT_WINDOW); the feature axis is added on the next line
        x = np.stack([p.astype(np.float32) for p in past_list], axis=0)
        x = torch.from_numpy(x).unsqueeze(-1).to(self.device)  # (B, INPUT_WINDOW, 1)

        # Inference path uses the model's "zeros as decoder input" forward
        preds = self.model(x).detach().cpu().numpy()  # (B, HORIZON)

        out = batch.copy()
        out["pred"] = list(preds)  # each row: np.ndarray (HORIZON,)
        return out[["pred"]]
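
Weights saved from a model wrapped in `DistributedDataParallel` carry a `module.` prefix on every key, which the unwrapped model doesn't expect. A minimal sketch of stripping only a leading prefix, using plain dicts in place of real tensors (the `strip_ddp_prefix` helper name is an assumption for illustration):

```python
def strip_ddp_prefix(state_dict: dict, prefix: str = "module.") -> dict:
    """Return a copy of state_dict with a leading DDP prefix removed from each key."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Plain values stand in for tensors; only the key handling matters here.
ddp_style = {"module.encoder.weight": 1, "module.head.bias": 2}
plain_style = {"encoder.weight": 1, "head.submodule.bias": 2}

print(strip_ddp_prefix(ddp_style))    # {'encoder.weight': 1, 'head.bias': 2}
print(strip_ddp_prefix(plain_style))  # {'encoder.weight': 1, 'head.submodule.bias': 2}
```

Matching only at the start of the key avoids mangling names such as `head.submodule.bias`, which contain `module.` in the middle.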


### 15. Run distributed inference and visualize results  

Use **Ray Data** to perform GPU-based batch inference with the trained model.  
The model runs on a Ray worker, generates a forecast for the latest input window, and returns predictions to the driver.  
De-normalize and plot the forecast against the ground truth to visually assess model performance.
Because this tutorial trains on a very small amount of data, expect the model to learn a near-constant forecast.
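
De-normalization only recovers the original scale if it uses the same mean and standard deviation that produced the normalized series. The round-trip sketch below uses made-up values and the standard library rather than pandas. Note that pandas `Series.std()` defaults to the sample standard deviation (`ddof=1`), which `statistics.stdev` matches, while NumPy's `std()` defaults to the population version (`ddof=0`).

```python
import statistics

# Made-up passenger counts, for illustration only.
passengers = [120.0, 135.0, 128.0, 150.0, 142.0, 131.0]

mean = statistics.fmean(passengers)
std = statistics.stdev(passengers)  # sample std (ddof=1), like pandas Series.std()

normalized = [(p - mean) / std for p in passengers]
restored = [z * std + mean for z in normalized]

# The round trip recovers the original values up to floating-point error.
max_error = max(abs(r - p) for r, p in zip(restored, passengers))
print(f"max round-trip error: {max_error:.2e}")

# Inverting with a different std (population, ddof=0) shifts every value.
wrong_std = statistics.pstdev(passengers)
wrong = [z * wrong_std + mean for z in normalized]
wrong_error = max(abs(w - p) for w, p in zip(wrong, passengers))
print(wrong_error > 1e-6)  # True
```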

In [None]:
# 15. Run inference on the latest window with Ray Data and plot

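# 1) Prepare the input window on the driver. Hold out the final HORIZON steps as
#    ground truth so the input window and the target don't overlap.
past_norm = hourly["norm"].iloc[-(INPUT_WINDOW + HORIZON):-HORIZON].to_numpy().astype(np.float32)
future_true = hourly["passengers"].iloc[-HORIZON:].to_numpy()  # for visualization only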

# 2) Get the directory of the checkpoint attached to the training result
with result.checkpoint.as_directory() as ckpt_dir:
    best_ckpt_path = ckpt_dir  # must be on storage the workers can read

# 3) Build a tiny Ray Dataset and run inference on a GPU actor
model_kwargs = {
    "input_window": INPUT_WINDOW,
    "horizon": HORIZON,
    "d_model": 128,
    "nhead": 4,
    "num_layers": 3,
}

ds = rdata.from_items([{"past": past_norm}])
pred_ds = ds.map_batches(
    TimeSeriesBatchPredictor,
    fn_constructor_args=(best_ckpt_path, model_kwargs),
    batch_size=1,
    batch_format="pandas",
    concurrency=1,
    num_gpus=1,  # reserve one GPU per actor; requires a GPU node in the cluster
)

pred_row = pred_ds.take(1)[0]
pred_norm = pred_row["pred"]  # np.ndarray (HORIZON,)

# 4) De-normalize on the driver (assumes "norm" was built with these same statistics)
mean, std = hourly["passengers"].mean(), hourly["passengers"].std()
pred = pred_norm * std + mean
past = past_norm * std + mean

# 5) Plot

STEP_SIZE_HOURS = 0.5  # 30-minute sampling interval
t_past   = np.arange(-INPUT_WINDOW, 0) * STEP_SIZE_HOURS
t_future = np.arange(0, HORIZON) * STEP_SIZE_HOURS

plt.figure(figsize=(10, 4))
plt.plot(t_past, past, label="History", marker="o")
plt.plot(t_future, future_true, "--", label="Ground Truth")
plt.plot(t_future, pred, "-.", label="Forecast")
plt.axvline(0)
plt.xlabel("Hours relative to forecast start")
plt.ylabel("# trips")
plt.title("NYC-Taxi Forecast (Ray Data Inference)")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()
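
A plot can be complemented with a number. The sketch below compares a forecast's mean absolute error with that of a constant baseline that always predicts the historical mean. If the two are close, the model has probably learned little beyond a near-constant solution. All values are made up, and `mae` is a hypothetical helper, not part of the tutorial.

```python
def mae(y_true, y_pred):
    """Mean absolute error between two equal-length sequences."""
    if len(y_true) != len(y_pred):
        raise ValueError("sequences must have the same length")
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


# Made-up numbers standing in for de-normalized values.
history = [100.0, 120.0, 140.0, 120.0]
future_true = [110.0, 130.0, 150.0, 130.0]
forecast = [118.0, 121.0, 123.0, 120.0]       # nearly flat forecast

baseline_value = sum(history) / len(history)  # constant "predict the mean" baseline
baseline = [baseline_value] * len(future_true)

print(f"model MAE:    {mae(future_true, forecast):.2f}")  # 13.50
print(f"baseline MAE: {mae(future_true, baseline):.2f}")  # 15.00
```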


### 16. Cleanup: remove all training artifacts  
Finally, tidy up by deleting temporary checkpoint folders and any intermediate result directories. Clearing out old artifacts frees disk space and leaves your workspace clean for whatever comes next.

In [None]:
# 16. Cleanup – optionally remove all artifacts to free space
if os.path.exists(DATA_DIR):
    shutil.rmtree(DATA_DIR)
    print(f"Deleted {DATA_DIR}")
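
Because `shutil.rmtree` deletes a directory tree permanently, it can help to try the cleanup pattern on a throwaway directory first. A minimal sketch using a temporary directory (the `remove_dir` helper and the file names are hypothetical):

```python
import os
import shutil
import tempfile


def remove_dir(path: str) -> bool:
    """Delete a directory tree if it exists. Returns True if something was removed."""
    if os.path.isdir(path):
        shutil.rmtree(path)
        return True
    return False


# Build a throwaway directory with a placeholder checkpoint file inside.
scratch = tempfile.mkdtemp(prefix="cleanup_demo_")
os.makedirs(os.path.join(scratch, "checkpoints"))
with open(os.path.join(scratch, "checkpoints", "model.pt"), "w") as f:
    f.write("placeholder")

print(remove_dir(scratch))      # True: the directory existed and was deleted
print(remove_dir(scratch))      # False: nothing left to delete
print(os.path.exists(scratch))  # False
```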

### Wrap up and next steps

You built a robust, distributed forecasting workflow using **Ray Train on Anyscale** that:

* Trains a Transformer model across **multiple GPUs** using **Ray Train with Distributed Data Parallel (DDP)**, abstracting away low-level orchestration.
* Recovers automatically from failures with **built-in checkpointing and resume**, even across re-launches or node churn.
* Logs and reports per-epoch metrics using **Ray Train’s reporting APIs**, enabling real-time monitoring and seamless plotting.
* Performs inference using **Ray Data**, which lets you scale forecasting across GPUs or nodes without changing model code.

---

### Where can you take this next?

The following are a few directions you can explore to extend or adapt this workload:

1. **Hyperparameter sweeps**  
   * Wrap the `TorchTrainer` with **Ray Tune** to search over `d_model`, `nhead`, learning rate, and window sizes.  

2. **Probabilistic forecasting**  
   * Output percentiles or fit a distribution head (for example, Gaussian) to capture prediction uncertainty.  

3. **Multivariate and exogenous features**  
   * Add weather, holidays, or ride-sharing surge multipliers as extra input channels.  

4. **Early-stopping and LR scheduling**  
   * Monitor validation loss and reduce the learning rate on plateau, or stop when the improvement drops below 1%.  

5. **Model compression**  
   * Distill the large Transformer into a lightweight LSTM or Tiny-Transformer for edge deployment.  

6. **Streaming and online learning**  
   * Use **Ray Serve** to deploy the model and update weights periodically with the latest data.  

7. **Interpretability**  
   * Visualize attention maps to see which time lags the model focuses on, which can help build stakeholder trust.  

8. **End-to-end MLOps**  
   * Schedule nightly retraining with **Ray jobs**, log artifacts to MLflow or Weights & Biases, and automate model promotion.  
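
As one concrete example from the list above, the early-stopping rule in item 4 can be sketched in plain Python. The `EarlyStopper` class, its `patience` parameter, and the loss values are assumptions for illustration; in a training loop you would feed it the validation loss computed each epoch.

```python
class EarlyStopper:
    """Stop when validation loss fails to improve by a relative margin for `patience` epochs."""

    def __init__(self, min_rel_improvement: float = 0.01, patience: int = 2):
        self.min_rel_improvement = min_rel_improvement
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def should_stop(self, val_loss: float) -> bool:
        # An epoch counts as an improvement only if it beats the best loss
        # by at least the relative threshold (1% by default).
        if val_loss < self.best * (1.0 - self.min_rel_improvement):
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopper(min_rel_improvement=0.01, patience=2)
val_losses = [1.00, 0.80, 0.70, 0.698, 0.697, 0.60]  # made-up values

stopped_at = None
for epoch, loss in enumerate(val_losses):
    if stopper.should_stop(loss):
        stopped_at = epoch
        break

print(stopped_at)  # 4
```

The `patience` counter keeps a single noisy epoch from ending training; setting it to 1 gives the strict "stop as soon as improvement is below 1%" rule.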
