# Synthetic Hawkes Calibration

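Benchmark a classical exponential-kernel Hawkes model (fitted with `tick`) against a GRU-based neural Hawkes model on synthetic multi-type event sequences, then check the neural model on held-out data with time-rescaling diagnostics.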

## 1. Generate synthetic sequences

In [None]:

import numpy as np
import torch
from torch.utils.data import DataLoader

import neural_hawkes as nh

import random

# One seed shared by every RNG and by the synthetic generator, so the whole
# notebook can be reproduced (or re-seeded) from a single constant.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

synthetic_cfg = dict(num_sequences=200, num_events=400, num_types=4, seed=SEED)
sequences = nh.generate_synthetic_sequences(**synthetic_cfg)

# NOTE(review): with stride (32) < window_size (64) the windows presumably
# overlap. If split_dataset splits individual windows rather than whole
# sequences, train/val/test windows may share events (leakage) — confirm
# split_dataset's behaviour before trusting test metrics.
dataset = nh.EventSequenceDataset(sequences, window_size=64, stride=32)
train_set, val_set, test_set = nh.split_dataset(dataset, (0.7, 0.15, 0.15))


## 2. Classical Hawkes baseline (tick)

In [None]:

# Fit a classical exponential-kernel Hawkes model with tick as a baseline.
# tick is optional: a missing install is reported on its own, and a genuine
# fit failure is reported separately instead of being mislabelled as
# "tick not available".
TICK_DECAY = 2.0  # fixed exponential-kernel decay passed to HawkesExpKern

try:
    from tick.hawkes import HawkesExpKern
except ImportError as exc:
    print('tick not available:', exc)
else:
    try:
        # NOTE(review): tick's fit() expects a list of realizations, each a
        # list with one 1-D float64 timestamp array per event type. If
        # seq.times is a single flat array, this list is instead read as ONE
        # realization with len(sequences) dimensions — confirm the layout
        # of seq.times.
        times = [seq.times for seq in sequences]
        learner = HawkesExpKern(decays=TICK_DECAY)
        learner.fit(times)
    except Exception as exc:  # baseline is best-effort; keep the notebook running
        print('Classical Hawkes fit failed:', exc)
    else:
        print('Classical Hawkes fitted:')
        print('mu =', learner.baseline)
        print('adjacency =', learner.adjacency)


## 3. Neural Hawkes training

In [None]:

# Train the neural Hawkes model and record per-epoch train/validation metrics.
NUM_EPOCHS = 5
DELTA_WEIGHT = 1.0  # passed as delta_weight to both training and evaluation

collate = nh.collate_windows
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, collate_fn=collate)
val_loader = DataLoader(val_set, batch_size=256, shuffle=False, collate_fn=collate)
test_loader = DataLoader(test_set, batch_size=256, shuffle=False, collate_fn=collate)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Read num_types from the generator config so the model always matches the
# synthetic data instead of relying on a duplicated literal.
model = nh.NeuralHawkesModel(
    num_types=synthetic_cfg['num_types'],
    embed_dim=32,
    hidden_dim=64,
    backbone='gru',
).to(device)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

history = []  # one (train_metrics, val_metrics) tuple per epoch
for epoch in range(1, NUM_EPOCHS + 1):
    train_metrics = nh.train_one_epoch(model, train_loader, optim, device, delta_weight=DELTA_WEIGHT)
    val_metrics = nh.evaluate(model, val_loader, device, delta_weight=DELTA_WEIGHT)
    history.append((train_metrics, val_metrics))
    print(f"Epoch {epoch}: train {train_metrics['loss']:.4f}, val {val_metrics['loss']:.4f}")


## 4. Diagnostics

In [None]:

# Gather the trained model's predictions on the held-out test windows,
# then summarise them with time-rescaling diagnostics (displayed below).
pred_deltas, true_deltas, mask = nh.collect_predictions(
    model,
    test_loader,
    device,
)
diag = nh.time_rescaling_diagnostics(
    pred_deltas,
    true_deltas,
    mask,
)
diag
