# JACTUS: GPU/TPU Portfolio Benchmark

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pedronahum/JACTUS/blob/main/examples/notebooks/05_gpu_tpu_portfolio_benchmark.ipynb)
[![GitHub](https://img.shields.io/badge/GitHub-JACTUS-blue?logo=github)](https://github.com/pedronahum/JACTUS)
[![PyPI](https://img.shields.io/pypi/v/jactus)](https://pypi.org/project/jactus/)
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)

This notebook benchmarks JACTUS's **array-mode PAM simulation** across CPU and GPU/TPU backends.

You will:
1. Generate portfolios of 100 to 50K contracts
2. Compare **Python sequential** vs **array-mode batch** simulation
3. Measure the impact of JIT compilation and hardware acceleration
4. Run a discount-rate scenario sweep at portfolio scale and compute gradient-based sensitivities with `jax.grad`

> **Colab tip:** To run on GPU, go to **Runtime > Change runtime type > T4 GPU**.

## 1. Setup

In [None]:
!pip install -q jactus

In [None]:
import random
import time

import jax
import jax.numpy as jnp

import jactus

from jactus.contracts import create_contract
from jactus.contracts.pam_array import (
    batch_simulate_pam,
    batch_simulate_pam_auto,
    batch_simulate_pam_vmap,
    prepare_pam_batch,
    simulate_pam_portfolio,
)
from jactus.core import (
    ActusDateTime,
    ContractAttributes,
    ContractRole,
    ContractType,
    DayCountConvention,
)
from jactus.observers import ConstantRiskFactorObserver

print(f"JACTUS version: {jactus.__version__}")
print(f"JAX version:    {jax.__version__}")
print(f"Backend:        {jax.default_backend()}")
print(f"Devices:        {jax.devices()}")

## 2. Generate Test Portfolios

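We generate random PAM (Principal at Maturity) loan portfolios. Loans that have already matured by the valuation date are skipped, so each portfolio holds somewhat fewer active contracts than requested. The attributes vary as follows: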
- Notional: $50K – $500K
- Rate: 3% – 8%
- Term: 3 – 10 years
- Interest payment cycle: monthly, quarterly, or semi-annual
- Day count convention: A360, A365, 30E/360, or 30/360 US
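Day count conventions determine the year fractions the simulation uses to accrue interest. As a rough illustration, here is a minimal sketch of three of the conventions above using their standard textbook definitions. The `year_fraction` helper and its string labels are hypothetical and are not part of JACTUS. The 30/360 US variant is omitted because its end-of-month rules need more care.

```python
from datetime import date


def year_fraction(start: date, end: date, convention: str) -> float:
    """Textbook year fractions for three conventions (illustrative sketch, not JACTUS code)."""
    if convention == "A360":
        # Actual days elapsed over a 360-day year
        return (end - start).days / 360.0
    if convention == "A365":
        # Actual days elapsed over a fixed 365-day year
        return (end - start).days / 365.0
    if convention == "30E/360":
        # Eurobond basis: a day-of-month of 31 is treated as 30 on both ends
        d1, d2 = min(start.day, 30), min(end.day, 30)
        days = 360 * (end.year - start.year) + 30 * (end.month - start.month) + (d2 - d1)
        return days / 360.0
    raise ValueError(f"unsupported convention: {convention}")


# Six months from mid-January: 181 actual days, but exactly 180 "30E/360" days
start, end = date(2025, 1, 15), date(2025, 7, 15)
fractions = {c: year_fraction(start, end, c) for c in ("A360", "A365", "30E/360")}
```

The same calendar period yields slightly different accruals under each convention, which is why the portfolio generator mixes them.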

In [None]:
VALUATION_DATE = ActusDateTime(2025, 6, 1)
DISCOUNT_RATE = 0.045


def generate_portfolio(
    n_loans: int, seed: int = 42
) -> list[tuple[ContractAttributes, ConstantRiskFactorObserver]]:
    """Generate a random portfolio of PAM contracts."""
    rng = random.Random(seed)
    dccs = [
        DayCountConvention.A360,
        DayCountConvention.A365,
        DayCountConvention.E30360,
        DayCountConvention.B30360,
    ]
    obs = ConstantRiskFactorObserver(0.0)
    contracts = []

    for i in range(n_loans):
        orig_year = rng.randint(2020, 2024)
        orig_month = rng.randint(1, 12)
        origination = ActusDateTime(orig_year, orig_month, 15)
        term_years = rng.randint(3, 10)
        maturity = ActusDateTime(orig_year + term_years, orig_month, 15)
        if maturity <= VALUATION_DATE:
            continue

        attrs = ContractAttributes(
            contract_id=f"LOAN-{i:06d}",
            contract_type=ContractType.PAM,
            contract_role=ContractRole.RPA,
            status_date=VALUATION_DATE,
            initial_exchange_date=origination,
            maturity_date=maturity,
            notional_principal=round(rng.uniform(50_000, 500_000), -3),
            nominal_interest_rate=round(rng.uniform(0.03, 0.08), 4),
            day_count_convention=rng.choice(dccs),
            interest_payment_cycle=rng.choice(["1M", "3M", "6M"]),
        )
        contracts.append((attrs, obs))

    return contracts


# Pre-generate portfolios of increasing size
SIZES = [100, 500, 1_000, 5_000, 10_000, 50_000]
portfolios = {}
for n in SIZES:
    portfolios[n] = generate_portfolio(n)
    print(f"  {n:>6,} requested → {len(portfolios[n]):>6,} active contracts")

## 3. Baseline: Python Sequential Simulation

The standard ACTUS path uses `contract.simulate()`, which runs a Python `for` loop over each contract's events. Summing every contract's undiscounted payoffs this way is the **baseline** we want to beat.

In [None]:
def python_sequential_total_cf(contracts):
    """Simulate each contract individually and sum its undiscounted payoffs."""
    total_cf = 0.0
    for attrs, obs in contracts:
        contract = create_contract(attrs, obs)
        result = contract.simulate()  # Python loop over this contract's events
        total_cf += sum(float(e.payoff) for e in result.events)
    return total_cf


# Benchmark only the smaller portfolios (the Python path is slow on large ones)
python_sizes = [s for s in SIZES if s <= 1_000]
python_results = {}

print("Python Sequential Path")
print("=" * 55)
print(f"{'Contracts':>10}  {'Time (s)':>10}  {'Throughput':>15}  {'Total CF':>14}")
print("-" * 55)

for n in python_sizes:
    contracts = portfolios[n]
    t0 = time.perf_counter()
    total = python_sequential_total_cf(contracts)
    elapsed = time.perf_counter() - t0
    throughput = len(contracts) / elapsed
    python_results[n] = {"time": elapsed, "throughput": throughput, "total": total}
    print(f"{len(contracts):>10,}  {elapsed:>10.3f}  {throughput:>12,.0f}/sec  ${total:>12,.0f}")

print("=" * 55)

## 4. Array-Mode Batch Simulation

The array-mode path has two phases:

1. **Pre-computation** (`prepare_pam_batch`): Generates event schedules and year fractions as JAX arrays, together with initial states, risk factors, contract parameters, and a mask marking the valid event slots. When all contracts are batch-eligible, this uses a **JAX-native path** (GPU/TPU-ready).

2. **JIT kernel** (`batch_simulate_pam_auto`): Runs the event-by-event simulation with `jax.lax.scan` for all contracts at once and returns the final states and per-event payoffs. It automatically selects:
   - `vmap` on GPU/TPU (parallel across accelerator cores)
   - Manual batching on CPU (lower dispatch overhead)
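The two phases can be sketched with a toy model. This is a simplified illustration, not JACTUS's implementation: `prepare_toy_batch`, `simulate_one`, and the integer event codes are hypothetical, and each toy contract only pays simple interest on a fixed notional plus the principal at maturity. Phase 1 pads variable-length schedules into `[N, T]` arrays with a mask; phase 2 runs a `jax.lax.scan` over one contract's events, and `jax.vmap` lifts it to the whole batch.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Toy event codes for this sketch only (not JACTUS's real encoding)
IP, MD, PAD = 0, 1, 2


def prepare_toy_batch(schedules):
    """Phase 1 (toy): pad per-contract year-fraction lists into [N, T] arrays plus a mask."""
    n, t = len(schedules), max(len(s) for s in schedules)
    et = np.full((n, t), PAD, dtype=np.int32)
    yf = np.zeros((n, t), dtype=np.float32)
    mask = np.zeros((n, t), dtype=np.float32)
    for i, s in enumerate(schedules):
        k = len(s)
        et[i, : k - 1] = IP  # interest payments
        et[i, k - 1] = MD  # the last real event is maturity
        yf[i, :k] = s
        mask[i, :k] = 1.0
    return jnp.asarray(et), jnp.asarray(yf), jnp.asarray(mask)


def simulate_one(notional, rate, et, yf):
    """Phase 2 (toy): scan over one contract's events, carrying the outstanding notional."""

    def step(nt, x):
        e, tau = x
        interest = nt * rate * tau
        payoff = jnp.where(e == MD, interest + nt, jnp.where(e == IP, interest, 0.0))
        return jnp.where(e == MD, 0.0, nt), payoff  # notional drops to zero after maturity

    return jax.lax.scan(step, notional, (et, yf))


# vmap adds the contract axis; jit compiles the whole batch into one program
batch_kernel = jax.jit(jax.vmap(simulate_one))

et, yf, mask = prepare_toy_batch([[0.25] * 4, [0.5, 0.5]])
final_nt, payoffs = batch_kernel(
    jnp.array([100_000.0, 50_000.0]), jnp.array([0.04, 0.06]), et, yf
)
```

Padded slots carry a zero year fraction and a `PAD` code, so in this toy they already produce zero payoffs; multiplying by the mask is a safeguard for kernels where that is not guaranteed.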

### 4.1 End-to-End Benchmark

In [None]:
array_results = {}

print(f"Array-Mode Batch Path (backend: {jax.default_backend()})")
print("=" * 75)
print(f"{'Contracts':>10}  {'Prep (s)':>10}  {'Kernel (s)':>12}  {'Total (s)':>10}  {'Throughput':>15}")
print("-" * 75)

for n in SIZES:
    contracts = portfolios[n]
    nc = len(contracts)

    # Phase 1: Pre-computation
    t0 = time.perf_counter()
    batched = prepare_pam_batch(contracts)
    batched_states, batched_et, batched_yf, batched_rf, batched_params, batched_masks = batched
    t_prep = time.perf_counter() - t0

    # Phase 2: JIT kernel (warm-up + measure)
    # First call includes JIT compilation
    final_states, payoffs = batch_simulate_pam_auto(
        batched_states, batched_et, batched_yf, batched_rf, batched_params
    )
    payoffs.block_until_ready()  # Force materialization

    # Steady-state: average of 10 runs
    times_kernel = []
    for _ in range(10):
        t1 = time.perf_counter()
        final_states, payoffs = batch_simulate_pam_auto(
            batched_states, batched_et, batched_yf, batched_rf, batched_params
        )
        payoffs.block_until_ready()
        times_kernel.append(time.perf_counter() - t1)

    t_kernel = sum(times_kernel) / len(times_kernel)
    t_total = t_prep + t_kernel
    throughput = nc / t_total

    array_results[n] = {
        "prep": t_prep,
        "kernel": t_kernel,
        "total": t_total,
        "throughput": throughput,
        "payoffs": payoffs,
        "masks": batched_masks,
    }
    print(
        f"{nc:>10,}  {t_prep:>10.4f}  {t_kernel:>12.6f}  {t_total:>10.4f}  {throughput:>12,.0f}/sec"
    )

print("=" * 75)

### 4.2 Speedup Summary

In [None]:
print("Speedup: Array-Mode vs Python Sequential")
print("=" * 65)
print(f"{'Contracts':>10}  {'Python (s)':>12}  {'Array (s)':>12}  {'Speedup':>10}")
print("-" * 65)

for n in python_sizes:
    if n in array_results:
        py_t = python_results[n]["time"]
        ar_t = array_results[n]["total"]  # prep + steady-state kernel; JIT compile excluded
        speedup = py_t / ar_t
        print(f"{len(portfolios[n]):>10,}  {py_t:>12.3f}  {ar_t:>12.4f}  {speedup:>9.1f}x")

print("=" * 65)

# Extrapolate for large sizes
if python_sizes:
    largest_py = max(python_sizes)
    py_per_contract = python_results[largest_py]["time"] / len(portfolios[largest_py])
    print(f"\nProjected Python time for large portfolios (at {py_per_contract*1000:.1f} ms/contract):")
    for n in SIZES:
        if n > largest_py and n in array_results:
            nc = len(portfolios[n])
            py_est = nc * py_per_contract
            ar_t = array_results[n]["total"]
            print(f"  {nc:>6,} contracts:  Python ~{py_est:.1f}s  vs  Array {ar_t:.3f}s  ({py_est/ar_t:.0f}x)")

## 5. Kernel Performance Deep-Dive

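The pre-computation runs once per portfolio, while the JIT kernel can be re-run **many times** with new inputs of the same shape without recompiling. Amortizing the one-off preparation and compilation over many kernel calls is where GPU/TPU acceleration pays off most.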
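Repeated kernel calls are cheap partly because `jax.jit` caches the compiled program by input shapes and dtypes, not by values. The sketch below uses a hypothetical `toy_kernel` and a Python counter, which only runs while JAX traces the function, to show that new rate values reuse the compiled program while a new array shape triggers recompilation. This is also why each portfolio size in Section 4.1 gets its own warm-up call.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only when JAX (re)traces the function


@jax.jit
def toy_kernel(rate, yf):
    global trace_count
    trace_count += 1  # Python side effect: executes at trace time, not on every call
    return jnp.sum(rate * yf, axis=-1)


yf = jnp.full((1_000, 40), 0.25)  # hypothetical [contracts, events] year fractions
for r in (0.03, 0.04, 0.05):
    toy_kernel(jnp.float32(r), yf).block_until_ready()
print(trace_count)  # same shapes and dtypes, so traced once: 1

toy_kernel(jnp.float32(0.03), jnp.full((2_000, 40), 0.25)).block_until_ready()
print(trace_count)  # a new input shape triggers a retrace: 2
```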

### 5.1 Kernel-Only Throughput

In [None]:
print(f"Kernel-Only Throughput (backend: {jax.default_backend()})")
print("=" * 60)
print(f"{'Contracts':>10}  {'Kernel (ms)':>12}  {'Throughput':>18}")
print("-" * 60)

for n in SIZES:
    r = array_results[n]
    nc = len(portfolios[n])
    kernel_ms = r["kernel"] * 1000
    kernel_throughput = nc / r["kernel"]
    print(f"{nc:>10,}  {kernel_ms:>12.3f}  {kernel_throughput:>15,.0f}/sec")

print("=" * 60)

### 5.2 Dispatch Strategy Comparison

JACTUS provides three batch strategies:

| Strategy | Best For | How It Works |
|----------|----------|-------------|
| `batch_simulate_pam` | CPU | Manual batching, single scan over `[B]` arrays |
| `batch_simulate_pam_vmap` | GPU/TPU | `jax.vmap` over per-contract scan |
| `batch_simulate_pam_auto` | Any | Auto-selects based on `jax.default_backend()` |
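To make the first two strategies concrete, here is a toy sketch with hypothetical names and simplified PAM payoffs; it is not JACTUS's code. `manual_batch` runs one scan over the event axis with all contracts in the carry, `vmap_batch` maps a per-contract scan over the contract axis, and `pick_strategy` mimics the backend-based choice described in the table.

```python
import jax
import jax.numpy as jnp


def toy_step(nt, rate, tau, is_md):
    """One toy PAM event: accrue simple interest, repay the notional at maturity."""
    payoff = nt * rate * tau + jnp.where(is_md, nt, 0.0)
    return jnp.where(is_md, 0.0, nt), payoff


def manual_batch(notionals, rates, yf, is_md):
    """One scan over the event axis; the carry holds all B notionals at once."""

    def step(nt, x):
        tau, md = x
        return toy_step(nt, rates, tau, md)

    # lax.scan iterates over the leading axis, so transpose [B, T] -> [T, B] and back
    _, payoffs = jax.lax.scan(step, notionals, (yf.T, is_md.T))
    return payoffs.T


def vmap_batch(notionals, rates, yf, is_md):
    """jax.vmap over contracts of a per-contract scan."""

    def one(nt0, rate, yf_i, md_i):
        def step(nt, x):
            tau, md = x
            return toy_step(nt, rate, tau, md)

        _, payoffs = jax.lax.scan(step, nt0, (yf_i, md_i))
        return payoffs

    return jax.vmap(one)(notionals, rates, yf, is_md)


def pick_strategy(backend=None):
    """Hypothetical selector: vmap on GPU/TPU, manual batching elsewhere."""
    backend = backend or jax.default_backend()
    return vmap_batch if backend in ("gpu", "tpu") else manual_batch


notionals = jnp.array([100_000.0, 50_000.0])
rates = jnp.array([0.04, 0.06])
yf = jnp.array([[0.25, 0.25, 0.25, 0.25], [0.5, 0.5, 0.0, 0.0]])
is_md = jnp.array([[False, False, False, True], [False, True, False, False]])

kernel = jax.jit(pick_strategy())
payoffs = kernel(notionals, rates, yf, is_md)  # shape [2, 4]
```

JAX lowers `vmap` of a `lax.scan` to a scan with a batched carry, so the two toy versions can compile to similar programs; any speed difference in the real functions comes from implementation details, which the benchmark measures.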

Let's compare them directly:

In [None]:
# Use 5K portfolio for dispatch comparison
test_n = 5_000
contracts = portfolios[test_n]
batched = prepare_pam_batch(contracts)
batched_states, batched_et, batched_yf, batched_rf, batched_params, batched_masks = batched

strategies = {
    "manual (CPU-optimized)": batch_simulate_pam,
    "vmap (GPU/TPU-optimized)": batch_simulate_pam_vmap,
    "auto": batch_simulate_pam_auto,
}

print(f"Dispatch Strategy Comparison — {len(contracts):,} contracts (backend: {jax.default_backend()})")
print("=" * 65)

for name, fn in strategies.items():
    # Warm-up
    _, p = fn(batched_states, batched_et, batched_yf, batched_rf, batched_params)
    p.block_until_ready()

    # Measure
    times = []
    for _ in range(20):
        t0 = time.perf_counter()
        _, p = fn(batched_states, batched_et, batched_yf, batched_rf, batched_params)
        p.block_until_ready()
        times.append(time.perf_counter() - t0)

    median = sorted(times)[len(times) // 2]
    throughput = len(contracts) / median
    print(f"  {name:<28}  median: {median*1000:>8.3f} ms  ({throughput:>10,.0f} contracts/sec)")

print("=" * 65)

## 6. Scenario Analysis at Scale

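Once the payoffs have been simulated, running **hundreds of discount-rate scenarios** is cheap: each scenario only re-discounts the stored `[contracts, events]` payoff matrix. No re-simulation is needed, because the cash flows of these fixed-rate PAM loans do not depend on the discount rate.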
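A minimal sketch of a fully compiled sweep, assuming a masked payoff matrix and cumulative year fractions of shape `[N, T]` as in this section's cell. The hypothetical `make_scenario_sweep` helper compiles the per-scenario PV once and uses `jax.lax.map` to evaluate all rates inside one program. This avoids a host round-trip per rate without materializing an `[S, N, T]` array.

```python
import jax
import jax.numpy as jnp


def make_scenario_sweep(masked_payoffs, cum_yfs):
    """Build a jitted map from a vector of discount rates to portfolio PVs.

    Uses the same simple-interest discount factor, 1 / (1 + r * t), as this section's cell.
    """

    def pv_at(rate):
        return jnp.sum(masked_payoffs / (1.0 + rate * cum_yfs))

    # lax.map evaluates scenarios one at a time inside a single compiled program,
    # so peak memory stays O(N * T) instead of O(S * N * T) for full broadcasting
    return jax.jit(lambda rates: jax.lax.map(pv_at, rates))


# Toy [N=1, T=2] portfolio: a coupon at t=0.5 and coupon + principal at t=1.0
sweep = make_scenario_sweep(jnp.array([[1_000.0, 101_000.0]]), jnp.array([[0.5, 1.0]]))
pvs = sweep(jnp.array([0.0, 0.10]))
```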

In [None]:
# Use largest portfolio
largest_n = max(SIZES)
r = array_results[largest_n]
nc = len(portfolios[largest_n])
payoffs = r["payoffs"]
masks = r["masks"]
masked_payoffs = payoffs * masks

# Retrieve year fractions for discounting
batched = prepare_pam_batch(portfolios[largest_n])
_, _, batched_yf, _, _, _ = batched
cum_yfs = jnp.cumsum(batched_yf, axis=1)

# Run 200 discount rate scenarios
rates = jnp.linspace(0.01, 0.10, 200)

t0 = time.perf_counter()
pvs = []
for rate in rates:
    disc = 1.0 / (1.0 + float(rate) * cum_yfs)
    pv = float(jnp.sum(masked_payoffs * disc))
    pvs.append(pv)

t_scenarios = time.perf_counter() - t0

print(f"Scenario Analysis — {nc:,} contracts x 200 discount rates")
print("=" * 55)
print(f"Total time:         {t_scenarios:.4f}s")
print(f"Per scenario:       {t_scenarios/200*1000:.3f} ms")
print(f"Scenarios/sec:      {200/t_scenarios:,.0f}")
print(f"Contract-scenarios: {nc*200:,} in {t_scenarios:.3f}s")
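print()
i_base = int(jnp.argmin(jnp.abs(rates - DISCOUNT_RATE)))  # grid point nearest DISCOUNT_RATE
print(f"PV at {float(rates[0]):>6.2%}:  ${pvs[0]:>14,.0f}")
print(f"PV at {float(rates[i_base]):>6.2%}:  ${pvs[i_base]:>14,.0f}")
print(f"PV at {float(rates[-1]):>6.2%}:  ${pvs[-1]:>14,.0f}")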
print("=" * 55)

## 7. Gradient-Based Sensitivities

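Because the simulation kernel is pure JAX, we can use `jax.grad` to compute exact derivatives, with no finite-difference bumps needed.

Here we compute **dPV/dRate** for a single representative contract: the sensitivity of its PV (discounted at a fixed 4.5%) to its own nominal interest rate. This is a coupon-rate sensitivity rather than a DV01, which would bump the discount rate instead.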
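To see that `jax.grad` returns the exact derivative, consider a toy bullet loan whose PV is linear in its coupon rate, so the derivative has a closed form. The schedule, notional, and `pv_of_coupon` function are hypothetical and only mirror the structure of `pv_of_rate` in the cell that follows; the sketch compares autodiff with the analytic result and with a central finite difference.

```python
import jax
import jax.numpy as jnp

NOTIONAL = 100_000.0
YF = jnp.array([0.5, 0.5, 0.5, 0.5])  # toy schedule: four semi-annual periods
DISC = 1.0 / (1.0 + 0.045 * jnp.cumsum(YF))  # simple-interest discount factors at 4.5%


def pv_of_coupon(rate):
    """PV of a toy bullet loan as a function of its coupon (nominal) rate."""
    coupons = NOTIONAL * rate * YF
    principal = jnp.zeros_like(YF).at[-1].set(NOTIONAL)  # principal repaid at the last date
    return jnp.sum((coupons + principal) * DISC)


rate = jnp.float32(0.05)
dpv_autodiff = jax.grad(pv_of_coupon)(rate)

# PV is linear in the coupon rate here, so the exact derivative is sum(N * yf * disc)
dpv_analytic = jnp.sum(NOTIONAL * YF * DISC)

# Central finite difference for comparison (the bump size is an arbitrary choice)
h = 1e-3
dpv_fd = (pv_of_coupon(rate + h) - pv_of_coupon(rate - h)) / (2 * h)
```

The finite-difference estimate matches only up to float32 rounding and depends on the chosen bump, while autodiff needs neither a bump nor extra PV evaluations.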

In [None]:
from jactus.contracts.pam_array import simulate_pam_array, precompute_pam_arrays

# Use a single representative contract for gradient demo
sample_attrs, sample_obs = portfolios[100][0]
init_state, et, yf, rf, params = precompute_pam_arrays(sample_attrs, sample_obs)


def pv_of_rate(rate):
    """PV of the sample contract as a function of its nominal interest rate."""
    new_params = params._replace(nominal_interest_rate=rate)
    new_state = init_state._replace(ipnr=rate)
    _, sim_payoffs = simulate_pam_array(new_state, et, yf, rf, new_params)
    cum_yf = jnp.cumsum(yf)
    disc = 1.0 / (1.0 + DISCOUNT_RATE * cum_yf)
    return jnp.sum(sim_payoffs * disc)


# Compute gradient
grad_fn = jax.jit(jax.grad(pv_of_rate))

base_rate = jnp.float32(sample_attrs.nominal_interest_rate)
# Warm-up
_ = grad_fn(base_rate).block_until_ready()  # includes JIT compilation

t0 = time.perf_counter()
dPV_dRate = grad_fn(base_rate)
dPV_dRate.block_until_ready()
t_grad = time.perf_counter() - t0

pv_val = pv_of_rate(base_rate)

print("Gradient-Based Sensitivity (Automatic Differentiation)")
print("=" * 55)
print(f"Contract:        {sample_attrs.contract_id}")
print(f"Notional:        ${sample_attrs.notional_principal:,.0f}")
print(f"Rate:            {float(base_rate):.4%}")
print(f"PV:              ${float(pv_val):,.2f}")
print(f"dPV/dRate:       ${float(dPV_dRate):,.2f}")
print(f"Gradient time:   {t_grad*1000:.3f} ms")
print()
print(f"Interpretation: a 1bp increase in the contract rate changes PV by about ${float(dPV_dRate)*0.0001:,.2f}")
print("=" * 55)

## 8. Scaling Analysis

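How does performance scale with portfolio size? The table below breaks down preparation and kernel time and reports the kernel cost per contract; the plots follow in Section 9.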
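One way to summarize the table is to model kernel time as a fixed overhead plus a per-contract cost. The `fit_scaling` helper below is a hypothetical addition that uses ordinary least squares via `numpy.linalg.lstsq`. Treat the linear model as a rough approximation, especially on accelerators, where small batches may not keep the device busy.

```python
import numpy as np


def fit_scaling(n_contracts, seconds):
    """Least-squares fit of seconds ~ overhead + per_contract * n; returns both terms."""
    n = np.asarray(n_contracts, dtype=float)
    t = np.asarray(seconds, dtype=float)
    design = np.column_stack([np.ones_like(n), n])  # columns: intercept, slope
    (overhead, per_contract), *_ = np.linalg.lstsq(design, t, rcond=None)
    return float(overhead), float(per_contract)
```

For this notebook you could call it as `fit_scaling([len(portfolios[n]) for n in SIZES], [array_results[n]["kernel"] for n in SIZES])`.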

In [None]:
print(f"Scaling Analysis (backend: {jax.default_backend()})")
print("=" * 80)
print(f"{'Contracts':>10}  {'Prep (ms)':>10}  {'Kernel (ms)':>12}  {'Kernel/Contract':>16}  {'Total (ms)':>12}")
print("-" * 80)

for n in SIZES:
    r = array_results[n]
    nc = len(portfolios[n])
    prep_ms = r["prep"] * 1000
    kernel_ms = r["kernel"] * 1000
    per_contract_us = (r["kernel"] / nc) * 1_000_000
    total_ms = r["total"] * 1000
    print(f"{nc:>10,}  {prep_ms:>10.1f}  {kernel_ms:>12.3f}  {per_contract_us:>13.2f} us  {total_ms:>12.1f}")

print("=" * 80)
print("\nus = microseconds per contract (kernel only)")

## 9. Visualization

In [None]:
try:
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(1, 3, figsize=(16, 5))

    sizes = [len(portfolios[n]) for n in SIZES]
    prep_times = [array_results[n]["prep"] * 1000 for n in SIZES]
    kernel_times = [array_results[n]["kernel"] * 1000 for n in SIZES]
    throughputs = [array_results[n]["throughput"] for n in SIZES]

    # Plot 1: Prep vs Kernel time
    ax = axes[0]
    ax.bar(range(len(sizes)), prep_times, label="Pre-computation", alpha=0.8)
    ax.bar(range(len(sizes)), kernel_times, bottom=prep_times, label="JIT Kernel", alpha=0.8)
    ax.set_xticks(range(len(sizes)))
    ax.set_xticklabels([f"{s:,}" for s in sizes], rotation=45)
    ax.set_ylabel("Time (ms)")
    ax.set_title("Execution Time Breakdown")
    ax.legend()

    # Plot 2: Throughput
    ax = axes[1]
    ax.plot(sizes, throughputs, "o-", linewidth=2, markersize=8)
    ax.set_xscale("log")
    ax.set_xlabel("Portfolio Size")
    ax.set_ylabel("Contracts/sec")
    ax.set_title(f"End-to-End Throughput ({jax.default_backend().upper()})")
    ax.grid(True, alpha=0.3)

    # Plot 3: Scenario PV curve
    ax = axes[2]
    rates_pct = [float(r) * 100 for r in rates]
    pvs_millions = [pv / 1e6 for pv in pvs]
    ax.plot(rates_pct, pvs_millions, linewidth=2)
    ax.set_xlabel("Discount Rate (%)")
    ax.set_ylabel("Portfolio PV ($M)")
    ax.set_title(f"PV vs Discount Rate ({len(portfolios[largest_n]):,} Contracts)")
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

except ImportError:
    print("Install matplotlib for plots: pip install matplotlib")

## 10. Summary

### Architecture

```
Python attributes ──→ prepare_pam_batch() ──→ batch_simulate_pam_auto() ──→ payoffs
    (N contracts)       JAX-native schedule      jax.lax.scan + vmap         [N, T]
                         + year fractions         (GPU/TPU parallel)
```

### Key Takeaways

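- **Pre-computation** is typically the dominant end-to-end cost (compare the Prep and Kernel columns above); once compiled, the JIT kernel is fast on any backend.
- **GPU/TPU** acceleration benefits grow with portfolio size (>1K contracts).
- **Scenario sweeps** over discount rates reuse the simulated payoffs, so each scenario costs only a re-discounted sum rather than a new simulation.
- **`jax.grad`** gives exact sensitivities without finite-difference bumps.
- No code changes are needed: install `jax[cuda13]` (NVIDIA GPUs) or `jax[tpu]` and the same notebook runs on the accelerator.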

### Next Steps

| Notebook | Topic |
|----------|-------|
| [`00_getting_started_pam.ipynb`](https://colab.research.google.com/github/pedronahum/JACTUS/blob/main/examples/notebooks/00_getting_started_pam.ipynb) | PAM basics |
| [`01_annuity_mortgage.ipynb`](https://colab.research.google.com/github/pedronahum/JACTUS/blob/main/examples/notebooks/01_annuity_mortgage.ipynb) | Mortgages (ANN) |

### Resources

- **Documentation**: [pedronahum.github.io/JACTUS](https://pedronahum.github.io/JACTUS/)
- **GitHub**: [github.com/pedronahum/JACTUS](https://github.com/pedronahum/JACTUS)
- **JAX GPU Install**: `pip install -U "jax[cuda13]"`