# EarthquakeNPP quickstart for adding a new model
This notebook walks through the minimum steps to plug a custom model into the EarthquakeNPP benchmark:

- load one of the provided catalogs
- split events into train/val/test using the benchmark time splits
- build sliding-window datasets that expose fixed-length histories
- fit a tiny spatio-temporal conditional intensity function (CIF)
- evaluate the CIF on the test split with a log-likelihood that uses numerical integration for the temporal survival term


## Set up paths and choose a dataset
Pick any catalog listed in `CATALOG_PATHS` (they match the benchmark names in the paper). The split dates come from `Datasets/README.md`.

In [None]:
from pathlib import Path
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import sys

# Relative paths from this notebook (located in Experiments/)
PROJECT_ROOT = Path('..').resolve()
AUTO_STPP_SRC = PROJECT_ROOT / 'Experiments' / 'AutoSTPP' / 'src'
sys.path.append(str(AUTO_STPP_SRC))
from data.data import SlidingWindowWrapper

# Choose your catalog here
DATASET = 'ComCat_25'  # e.g. ComCat_25 | SCEDC_25 | SCEDC_20 | SCEDC_30 | SanJac_10 | SaltonSea_10 | WHITE_06

CATALOG_PATHS = {
    'ComCat_25': PROJECT_ROOT / 'Datasets' / 'ComCat' / 'ComCat_catalog.csv',
    'SCEDC_20': PROJECT_ROOT / 'Datasets' / 'SCEDC' / 'SCEDC_catalog.csv',
    'SCEDC_25': PROJECT_ROOT / 'Datasets' / 'SCEDC' / 'SCEDC_catalog.csv',
    'SCEDC_30': PROJECT_ROOT / 'Datasets' / 'SCEDC' / 'SCEDC_catalog.csv',
    'SanJac_10': PROJECT_ROOT / 'Datasets' / 'QTM' / 'SanJac_catalog.csv',
    'SaltonSea_10': PROJECT_ROOT / 'Datasets' / 'QTM' / 'SaltonSea_catalog.csv',
    'WHITE_06': PROJECT_ROOT / 'Datasets' / 'WHITE' / 'WHITE_catalog.csv',
    'ETAS_25': PROJECT_ROOT / 'Datasets' / 'ETAS' / 'ETAS_California_catalog.csv',
    'ETAS_incomplete_25': PROJECT_ROOT / 'Datasets' / 'ETAS' / 'ETAS_California_incomplete_catalog.csv',
    'Japan_Deprecated': PROJECT_ROOT / 'Datasets' / 'Japan_Deprecated' / 'Japan_catalog.csv',
}

# Magnitude thresholds
MAG_THRESHOLDS = {
    'ComCat_25': 2.5,
    'SCEDC_20': 2.0,
    'SCEDC_25': 2.5,
    'SCEDC_30': 3.0,
    'SanJac_10': 1.0,
    'SaltonSea_10': 1.0,
    'WHITE_06': 0.6,
    'ETAS_25': 1.0,
    'ETAS_incomplete_25': 1.0,
    'Japan_Deprecated': 2.5,
}

# Time-based splits from Datasets/README.md
SPLIT_BOUNDARIES = {
    'ComCat_25': dict(train_start='1981-01-01', val_start='1998-01-01', test_start='2007-01-01', test_end='2020-01-17'),
    'SCEDC_20': dict(train_start='1985-01-01', val_start='2005-01-01', test_start='2014-01-01', test_end='2020-01-01'),
    'SCEDC_25': dict(train_start='1985-01-01', val_start='2005-01-01', test_start='2014-01-01', test_end='2020-01-01'),
    'SCEDC_30': dict(train_start='1985-01-01', val_start='2005-01-01', test_start='2014-01-01', test_end='2020-01-01'),
    'SanJac_10': dict(train_start='2009-01-01', val_start='2014-01-01', test_start='2016-01-01', test_end='2018-01-01'),
    'SaltonSea_10': dict(train_start='2009-01-01', val_start='2014-01-01', test_start='2016-01-01', test_end='2018-01-01'),
    'WHITE_06': dict(train_start='2009-01-01', val_start='2014-01-01', test_start='2017-01-01', test_end='2021-01-01'),
    'ETAS_25': dict(train_start='1981-01-01', val_start='1998-01-01', test_start='2007-01-01', test_end='2020-01-17'),
    'ETAS_incomplete_25': dict(train_start='1981-01-01', val_start='1998-01-01', test_start='2007-01-01', test_end='2020-01-17'),
    'Japan_Deprecated': dict(train_start='1992-01-01', val_start='2007-01-01', test_start='2011-01-01', test_end='2020-01-01'),
}

catalog_path = CATALOG_PATHS[DATASET]
split_dates = {k: pd.Timestamp(v) for k, v in SPLIT_BOUNDARIES[DATASET].items()}
print(f'Using catalog: {catalog_path}')
print(f"Splits: {split_dates}")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device


Using catalog: EarthquakeNPP/Datasets/ComCat/ComCat_catalog.csv
Splits: {'train_start': Timestamp('1981-01-01 00:00:00'), 'val_start': Timestamp('1998-01-01 00:00:00'), 'test_start': Timestamp('2007-01-01 00:00:00'), 'test_end': Timestamp('2020-01-17 00:00:00')}


  from .autonotebook import tqdm as notebook_tqdm


device(type='cuda')

## Load the catalog and create the benchmark splits
We keep only the `time`, `x`, and `y` columns that all baselines consume. Times are converted to days since the first event in each split so that magnitudes of the inputs stay reasonable.

In [2]:
df = pd.read_csv(catalog_path, parse_dates=['time']).sort_values('time').reset_index(drop=True)

mag_filter = MAG_THRESHOLDS.get(DATASET)
if mag_filter is not None and 'magnitude' in df.columns:
    df = df[df['magnitude'] >= mag_filter].reset_index(drop=True)

print(df[['time', 'x', 'y']].head())

train_df = df[(df.time >= split_dates['train_start']) & (df.time < split_dates['val_start'])]
val_df = df[(df.time >= split_dates['val_start']) & (df.time < split_dates['test_start'])]
test_df = df[(df.time >= split_dates['test_start']) & (df.time < split_dates['test_end'])]

print(f"Train events: {len(train_df):,}")
print(f"Val events:   {len(val_df):,}")
print(f"Test events:  {len(test_df):,}")

# Convert each split to a single continuous sequence [time, x, y], time measured in days from the split start
def df_to_sequence(split_df):
    if len(split_df) == 0:
        raise ValueError('Split is empty. Check split boundaries for the chosen dataset.')
    t0 = split_df['time'].iloc[0]
    t_days = (split_df['time'] - t0).dt.total_seconds() / 86400.0
    return np.stack([t_days.to_numpy(), split_df['x'].to_numpy(), split_df['y'].to_numpy()], axis=1)

train_seq = df_to_sequence(train_df)
val_seq = df_to_sequence(val_df)
test_seq = df_to_sequence(test_df)

train_seq.shape, val_seq.shape, test_seq.shape


                     time           x           y
0 1971-01-01 20:36:17.720 -221.828430  -52.446502
1 1971-01-02 02:19:13.010  -18.087790  112.702064
2 1971-01-02 02:37:49.820  -14.801382  114.585400
3 1971-01-02 06:27:39.120   -5.390044 -127.740486
4 1971-01-02 07:59:08.050  -24.172068  113.773295
Train events: 40,701
Val events:   14,741
Test events:  21,885


((40701, 3), (14741, 3), (21885, 3))

## Build sliding-window datasets
`SlidingWindowWrapper` converts the single long sequence into fixed-length histories (`st_X`) and their next-event targets (`st_Y`). Setting `roll=False` keeps time as the first column, which is convenient for the log-likelihood below.

In [3]:
LOOKBACK = 32   # number of past events per window
LOOKAHEAD = 1  # predict only the next event for simplicity

train_ds = SlidingWindowWrapper([train_seq], lookback=LOOKBACK, lookahead=LOOKAHEAD, roll=False, normalized=False, device=device)
val_ds = SlidingWindowWrapper([val_seq], lookback=LOOKBACK, lookahead=LOOKAHEAD, roll=False, normalized=False, device=device)
test_ds = SlidingWindowWrapper([test_seq], lookback=LOOKBACK, lookahead=LOOKAHEAD, roll=False, normalized=False, device=device)

train_loader = DataLoader(train_ds, batch_size=256, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=256)
test_loader = DataLoader(test_ds, batch_size=256)

print(f"Sliding windows: train={len(train_ds):,}, val={len(val_ds):,}, test={len(test_ds):,}")
first_x, first_y, first_x_cum, first_y_cum, idx = train_ds[0]
print('Example history shape:', first_x.shape)
print('Example next event (absolute time, x, y):', first_y_cum)


Sliding windows: train=40,669, val=14,709, test=21,853
Example history shape: torch.Size([32, 3])
Example next event (absolute time, x, y): tensor([[  4.9511, 268.8405,   8.6861]], device='cuda:0')


## Define a spatio-temporal conditional intensity
We factor the CIF into a **temporal rate** and a **spatial density**:

- `lambda_time(t | history)` is non-negative and only affects the survival integral. We integrate it over time; assuming the spatial domain is infinite means the integral over space of the spatial density is 1.
- `p(x, y | history)` is a normalized Gaussian density over the infinite plane, so its integral over space is 1. The log-probability contributes to the event log-likelihood.


In [4]:
class DummySpatioTemporalCIF(nn.Module):
    def __init__(self, input_dim=3, hidden_dim=64):
        super().__init__()
        self.encoder = nn.GRU(input_dim, hidden_dim, batch_first=True)
        self.time_head = nn.Sequential(
            nn.Linear(hidden_dim + 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )
        self.spatial_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 4),  # mean_x, mean_y, log_std_x, log_std_y
        )

    def encode_history(self, history):
        # history: [B, lookback, input_dim]
        _, h = self.encoder(history)
        return h[-1]  # [B, hidden_dim]

    def lambda_time(self, history, t_eval):
        """
        history: [B, lookback, 3]
        t_eval: [B, K] times (days) measured from the last history event
        returns: [B, K] temporal intensities (non-negative)
        """
        h = self.encode_history(history)
        t_features = torch.stack([t_eval, t_eval ** 2], dim=-1)
        h_expanded = h.unsqueeze(1).expand(-1, t_eval.shape[1], -1)
        logits = self.time_head(torch.cat([t_features, h_expanded], dim=-1)).squeeze(-1)
        return F.softplus(logits) + 1e-6

    def spatial_distribution(self, history):
        """
        Returns a factorized Gaussian over (x, y) with parameters conditioned on history.
        Integral over the infinite plane equals 1, matching the assumption of infinite spatial domain.
        """
        h = self.encode_history(history)
        params = self.spatial_head(h)
        mean = params[:, :2]
        log_std = params[:, 2:].clamp(-5, 5)  # keep std numerically sane
        std = torch.exp(log_std)
        dist_x = Normal(mean[:, 0], std[:, 0])
        dist_y = Normal(mean[:, 1], std[:, 1])
        return dist_x, dist_y

model = DummySpatioTemporalCIF().to(device)
model


DummySpatioTemporalCIF(
  (encoder): GRU(3, 64, batch_first=True)
  (time_head): Sequential(
    (0): Linear(in_features=66, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=1, bias=True)
  )
  (spatial_head): Sequential(
    (0): Linear(in_features=64, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=4, bias=True)
  )
)

## Log-likelihood with temporal integration and spatial density
For each window the log-likelihood is split into temporal and spatial parts:

$$ \underbrace{\log \lambda_\text{time}(t_i) - \int_0^T \lambda_\text{time}(t) \, dt}_{\text{temporal}} + \underbrace{\log p(x_i, y_i)}_{\text{spatial}} $$

The spatial density integrates to 1 over the infinite plane, so only the temporal part appears in the survival integral.

In [5]:
def loglik_batch(model, batch, device, integration_steps=128):
    """Returns temporal, spatial, and total log-likelihood means for a batch."""
    st_x, st_y, st_x_cum, st_y_cum, _ = batch
    history = st_x.to(device)
    history_abs = st_x_cum.to(device)
    future_abs = st_y_cum.to(device)

    last_history_time = history_abs[:, -1, 0]
    next_event_time = future_abs[:, :, 0].squeeze(-1)  # [B]
    event_xy = future_abs[:, 0, 1:3]  # [B, 2]

    horizon = torch.clamp(next_event_time - last_history_time, min=1e-5)

    base_grid = torch.linspace(0.0, 1.0, steps=integration_steps, device=device)
    t_grid = horizon.unsqueeze(1) * base_grid  # [B, K]

    lambda_grid = model.lambda_time(history, t_grid)
    integral = torch.trapz(lambda_grid, t_grid, dim=1)

    event_lambda = model.lambda_time(history, horizon.unsqueeze(1)).squeeze(-1)

    dist_x, dist_y = model.spatial_distribution(history)
    spatial_logprob = dist_x.log_prob(event_xy[:, 0]) + dist_y.log_prob(event_xy[:, 1])

    temporal_ll = torch.log(event_lambda + 1e-8) - integral
    total_ll = temporal_ll + spatial_logprob

    temporal_mean = temporal_ll.mean()
    spatial_mean = spatial_logprob.mean()
    total_mean = total_ll.mean()
    return temporal_mean, spatial_mean, total_mean

# quick smoke test
loglik_batch(model, next(iter(train_loader)), device)


(tensor(-0.6562, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(-84648.4062, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(-84649.0625, device='cuda:0', grad_fn=<MeanBackward0>))

## Train briefly
We minimize the negative total log-likelihood but track temporal and spatial parts separately.

In [6]:
EPOCHS = 10
MAX_TRAIN_BATCHES = 50
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(1, EPOCHS + 1):
    model.train()
    t_logs, s_logs, tot_logs = [], [], []
    for step, batch in enumerate(train_loader, start=1):
        optimizer.zero_grad()
        temporal_ll, spatial_ll, total_ll = loglik_batch(model, batch, device)
        loss = -total_ll
        loss.backward()
        optimizer.step()

        t_logs.append(float(temporal_ll.item()))
        s_logs.append(float(spatial_ll.item()))
        tot_logs.append(float(total_ll.item()))
        if step >= MAX_TRAIN_BATCHES:
            break
    print(f"Epoch {epoch}: train temporal={np.mean(t_logs):.4f}, spatial={np.mean(s_logs):.4f}, total={np.mean(tot_logs):.4f}")

    model.eval()
    with torch.no_grad():
        vt_logs, vs_logs, vtot_logs = [], [], []
        for batch in val_loader:
            temporal_ll, spatial_ll, total_ll = loglik_batch(model, batch, device)
            vt_logs.append(float(temporal_ll.item()))
            vs_logs.append(float(spatial_ll.item()))
            vtot_logs.append(float(total_ll.item()))
    print(f"          val temporal={np.mean(vt_logs):.4f}, spatial={np.mean(vs_logs):.4f}, total={np.mean(vtot_logs):.4f}")


Epoch 1: train temporal=0.2912, spatial=-10581.2646, total=-10580.9733
          val temporal=0.3943, spatial=-804.5598, total=-804.1655
Epoch 2: train temporal=0.8133, spatial=-377.7007, total=-376.8874
          val temporal=0.3739, spatial=-206.8791, total=-206.5052
Epoch 3: train temporal=0.8614, spatial=-110.8762, total=-110.0149
          val temporal=0.3950, spatial=-95.0211, total=-94.6260
Epoch 4: train temporal=0.8632, spatial=-58.2378, total=-57.3746
          val temporal=0.4065, spatial=-59.6350, total=-59.2284
Epoch 5: train temporal=0.8774, spatial=-37.9387, total=-37.0612
          val temporal=0.4291, spatial=-44.5775, total=-44.1484
Epoch 6: train temporal=0.8968, spatial=-29.1259, total=-28.2290
          val temporal=0.4448, spatial=-36.9842, total=-36.5395
Epoch 7: train temporal=0.9218, spatial=-25.2101, total=-24.2883
          val temporal=0.4653, spatial=-32.2704, total=-31.8051
Epoch 8: train temporal=0.9518, spatial=-22.1353, total=-21.1835
          val temp

## Evaluate on the test split
We report temporal, spatial, and total log-likelihood means.

In [7]:
model.eval()
with torch.no_grad():
    tt_logs, ts_logs, ttot_logs = [], [], []
    for batch in test_loader:
        temporal_ll, spatial_ll, total_ll = loglik_batch(model, batch, device, integration_steps=256)
        tt_logs.append(float(temporal_ll.item()))
        ts_logs.append(float(spatial_ll.item()))
        ttot_logs.append(float(total_ll.item()))

print(f"Test temporal log-likelihood: {np.mean(tt_logs):.4f}")
print(f"Test spatial log-likelihood:  {np.mean(ts_logs):.4f}")
print(f"Test total log-likelihood:    {np.mean(ttot_logs):.4f}")


Test temporal log-likelihood: 0.6687
Test spatial log-likelihood:  -35.4902
Test total log-likelihood:    -34.8215
