In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

from gpsr.datasets import QuadScanDataset
from gpsr.datasets import split_dataset
from gpsr.ensemble import (
    train_ensemble,
    plot_2d_distribution,
    compute_mean_and_confidence_interval,
)
from gpsr.modeling import GPSRQuadScanLattice

In [None]:
# Load the pre-built 4D reconstruction example dataset from disk.
# SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
# objects (needed here because the file holds a dataset object with
# attributes and methods, not just tensors) — only load .dset files from a
# trusted source.
DSET_PATH = "example_data/example_datasets/reconstruction_4D.dset"
dset = torch.load(DSET_PATH, weights_only=False)

# Sanity check: shape of the parameters, shape of the first observation,
# and the screen description attached to the dataset.
first_observation = dset.observations[0]
print(dset.parameters.shape, first_observation.shape, dset.screen)
dset.plot_data();

In [None]:
# Hold out every other measurement: even indices go to training; presumably
# split_dataset assigns the remaining (odd) indices to the test set.
TRAIN_STRIDE = 2
n_measurements = len(dset.parameters)
train_k_ids = np.arange(0, n_measurements, TRAIN_STRIDE)
train_dset, test_dset = split_dataset(dset, train_k_ids)

In [None]:
train_plot = train_dset.plot_data()  # visual check of the training subset; assignment suppresses the repr

In [None]:
# Diagnostic lattice for the quad scan, built from a quadrupole length
# (`l_quad`) and a drift length (`l_drift`) — presumably metres, confirm
# against the gpsr docs — ending at the screen the data were taken on.
p0c = 43.36e6  # reference momentum in eV/c
# NOTE(review): p0c is not referenced by any other cell in this notebook —
# confirm whether it should be passed to the reconstruction or removed.
gpsr_lattice = GPSRQuadScanLattice(
    l_quad=0.1,
    l_drift=1.0,
    screen=train_dset.screen,
)

In [None]:
# Reproducibility: seed the RNGs before training so that Restart & Run All
# reproduces the same ensemble (train_ensemble presumably draws its model
# initialisation from these generators; GPU kernels may still introduce
# small run-to-run differences).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

train_loader = torch.utils.data.DataLoader(train_dset, batch_size=20)

# Train an ensemble of 10 reconstructions. Predictions from every member are
# combined below into a mean and a normalized confidence interval.
result = train_ensemble(
    gpsr_lattice,
    train_loader,
    n_models=10,
    n_epochs=500,
    checkpoint_period_epochs=100,
)

In [None]:
# Draw one reconstructed beam from each ensemble member's beam generator.
# This is inference only, so skip building an autograd graph.
with torch.no_grad():
    reconstructed_beams = [model.beam_generator() for model in result]

In [None]:
# x–py projection of the reconstructed beams across the ensemble.
# ci_kws is presumably forwarded to the confidence-interval panel; vmax=2
# matches the colour limit used for the NCI maps at the end of the notebook.
# The trailing ';' keeps the return value's repr out of the cell output.
plot_2d_distribution(
    reconstructed_beams, "x", "py", smoothing_factor=1, ci_kws={"vmax": 2}
);

## Evaluate the ensemble on held-out test samples and compare its predictions with the measurements

In [None]:
# Predict the test-set observations with every ensemble member, stacked along
# a new leading axis (one entry per model). Inference only, so run under
# no_grad to avoid building and holding an autograd graph for all members.
# [0] selects the first element of each model's output (presumably the
# single screen's images — confirm).
with torch.no_grad():
    test_pred = torch.stack(
        [model(test_dset.parameters)[0].detach() for model in result]
    )

# Ensemble mean and normalized confidence interval (NCI) of the predictions.
mean, nci = compute_mean_and_confidence_interval(test_pred.cpu())

# Wrap the statistics as datasets so the dataset plotting helpers can be reused.
mean_dset = QuadScanDataset(test_dset.parameters, (mean,), train_dset.screen)
nci_dset = QuadScanDataset(test_dset.parameters, (nci,), train_dset.screen)

In [None]:
# Held-out measurements with the ensemble-mean prediction overlaid,
# resized to a wide 20 x 3 inch figure.
fig, axes = test_dset.plot_data(overlay_data=mean_dset)
fig.set_size_inches((20, 3))

### Visualize the normalized confidence interval (NCI) of the ensemble predictions

In [None]:
# One NCI map per held-out sample, all on a shared colour scale so the
# panels are directly comparable.
NCI_VMAX = 2  # same upper limit as the ci_kws used for the x–py projection plot
n_test = len(test_dset.parameters)

# squeeze=False keeps `axes` 2-D even for a single test sample, so the
# axes[0, ...] indexing below never fails on a bare Axes object.
fig, axes = plt.subplots(
    1, n_test, figsize=(20, 3), sharex=True, sharey=True, squeeze=False
)

xbins, ybins = nci_dset.screen.pixel_bin_centers
# x1e3 rescales the bin centres; assumes they are in metres (SI), giving mm.
mesh_x, mesh_y = torch.meshgrid(xbins * 1e3, ybins * 1e3, indexing="ij")

for panel_ax, nci_image in zip(axes[0], nci_dset.observations[0]):
    quadmesh = panel_ax.pcolormesh(
        mesh_x, mesh_y, nci_image, vmin=0, vmax=NCI_VMAX
    )
    panel_ax.set_xlabel("x (mm)")
axes[0, 0].set_ylabel("y (mm)")

fig.colorbar(quadmesh, ax=axes, label="nci");