# 6D phase space (PS) reconstruction, Gaussian beam

In [None]:
import numpy as np
import torch
import lightning as L

from gpsr.datasets import SixDReconstructionDataset, split_dataset
from gpsr.modeling import GPSR6DLattice, GPSR
from gpsr.train import LitGPSR
from gpsr.beams import NNParticleBeamGenerator

## Import data

In [None]:
# Load the example 6D reconstruction dataset from disk.
# NOTE(review): weights_only=False makes torch.load unpickle arbitrary Python
# objects, so only load dataset files that come from a trusted source.
dset = torch.load(
    "example_data/example_datasets/reconstruction_6D.dset", weights_only=False
)

# Sanity check: shapes of the parameter tensors and of the first observation
# entries, plus the type of the first screen object stored with the dataset.
print(
    dset.parameters.shape,
    dset.six_d_parameters.shape,
    dset.observations[0].shape,
    dset.six_d_observations[0].shape,
    type(dset.screens[0]),
)

# Plot the full dataset; the trailing semicolon keeps the notebook from echoing
# the call's return value.
dset.plot_data(publication_size=True);

In [None]:
# Even indices (0, 2, 4, ...) along the first axis of six_d_parameters are
# selected for training.
train_k_ids = np.arange(0, len(dset.six_d_parameters), 2)
# presumably the remaining (odd) indices form the test split -- confirm against
# split_dataset
train_dset, test_dset = split_dataset(dset, train_k_ids)

In [None]:
# Plot the training split; the trailing semicolon suppresses the notebook echo.
train_dset.plot_data(publication_size=True);

In [None]:
# Plot the held-out test split; the trailing semicolon suppresses the notebook echo.
test_dset.plot_data(publication_size=True);

## Set up diagnostic lattice in Cheetah
These parameters match those at the Argonne Wakefield Accelerator (AWA) and were used
 to generate the synthetic example dataset.

In [None]:
# Beamline parameters that are passed positionally to GPSR6DLattice below.
# NOTE(review): apart from p0c, units are not stated in this notebook; lengths
# are presumably metres, the frequency Hz and the phase/angle radians -- confirm
# against the GPSR6DLattice documentation.
p0c = 43.36e6  # reference momentum in eV/c

# Reuse the screen objects stored with the training split.
screens = train_dset.screens

l_quad = 0.11  # presumably the quadrupole length
l_tdc = 0.01  # presumably the transverse deflecting cavity (TDC) length
f_tdc = 1.3e9  # presumably the TDC frequency
phi_tdc = 0.0  # presumably the TDC phase
l_bend = 0.3018  # presumably the dipole (bend) length
# -20 degrees converted to radians.
# NOTE(review): the conversion uses 3.14 rather than np.pi. It is left as-is
# because the notebook text states that these values were used to generate the
# synthetic dataset; confirm before switching to np.pi.
theta_on = -20.0 * 3.14 / 180.0
# presumably the lengths of the three drift sections -- confirm against
# GPSR6DLattice
l1 = 0.790702
l2 = 0.631698
l3 = 0.889

# Build the diagnostic lattice; the screens are appended as trailing positional
# arguments.
gpsr_lattice = GPSR6DLattice(
    l_quad, l_tdc, f_tdc, phi_tdc, l_bend, theta_on, l1, l2, l3, *screens
)

## Training using PyTorch Lightning

In [None]:
# Serve the training split in batches of 10 samples.
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=10)

# TensorBoard logs are written beneath the current working directory.
logger = L.pytorch.loggers.TensorBoardLogger(".")

# Trainable model: a neural-network beam generator (constructed with 10000,
# presumably the particle count, and the reference momentum p0c) paired with
# the diagnostic lattice, then wrapped for Lightning.
gpsr_model = GPSR(NNParticleBeamGenerator(10000, p0c), gpsr_lattice)
litgpsr = LitGPSR(gpsr_model)

# Fit for 500 epochs on the training loader.
trainer = L.Trainer(max_epochs=500, logger=logger)
trainer.fit(model=litgpsr, train_dataloaders=train_loader)

In [None]:
# Call the beam generator held by the Lightning-wrapped model to obtain the
# reconstructed beam.
reconstructed_beam = litgpsr.gpsr_model.beam_generator()

In [None]:
# Evaluate the model at the held-out test parameters and detach each output
# from the autograd graph before handing it to the dataset/plotting code.
test_params = test_dset.six_d_parameters
pred = tuple(output.detach() for output in gpsr_model(test_params))

# Wrap the predictions in a dataset that shares the test split's screens.
pred_dset = SixDReconstructionDataset(test_params, pred, test_dset.screens)

# Plot the test data with the predictions overlaid using the given overlay
# options (grey colormap, levels 0.1 / 0.5 / 0.9).
test_dset.plot_data(
    publication_size=True,
    overlay_data=pred_dset,
    overlay_kwargs=dict(cmap="Greys", levels=[0.1, 0.5, 0.9]),
);

In [None]:
# Plot the reconstructed beam distribution; the trailing semicolon suppresses
# the notebook echo.
reconstructed_beam.plot_distribution();