# 4-D phase space reconstruction
Here we test 4-D phase space reconstruction using a basic quadrupole scan and a screen. The test uses a dummy (synthetic) beam distribution in the 4-D transverse phase space (x, px, y, py).
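For readers new to the setting, the sketch below shows what a synthetic 4-D transverse distribution can look like: particles drawn from a correlated Gaussian in (x, px, y, py). The covariance values are illustrative assumptions, not the parameters of the dummy distribution stored in the example dataset. The small x-y coupling terms are what make the problem genuinely 4-D rather than two independent 2-D problems.

```python
import numpy as np

# Hypothetical example: N particles from a correlated 4-D Gaussian in
# (x, px, y, py). These covariance values are illustrative only and are
# NOT the parameters of the dummy distribution in the dataset.
rng = np.random.default_rng(seed=0)
n_particles = 10_000

# Beam covariance matrix (units assumed: m and rad). The nonzero x-y
# entries couple the two transverse planes.
cov = np.array([
    [1.0e-6, -2.0e-7, 1.0e-7, 0.0],
    [-2.0e-7, 1.0e-7, 0.0, 0.0],
    [1.0e-7, 0.0, 1.0e-6, -1.0e-7],
    [0.0, 0.0, -1.0e-7, 1.0e-7],
])

particles = rng.multivariate_normal(mean=np.zeros(4), cov=cov, size=n_particles)

# The sample covariance approaches `cov` as n_particles grows
sample_cov = np.cov(particles, rowvar=False)
print(particles.shape)  # (10000, 4)
```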

### Python package imports

In [None]:
import numpy as np
import torch
import lightning as L
from gpsr.modeling import GPSR, GPSRQuadScanLattice
from gpsr.train import LitGPSR
from gpsr.beams import NNParticleBeamGenerator
from gpsr.datasets import QuadScanDataset, split_dataset

### Load data

Load the measurement dataset and split it into training and test datasets.

In [None]:
dset = torch.load(
    "example_data/example_datasets/reconstruction_4D.dset", weights_only=False
)
print(
    dset.parameters.shape,
    dset.observations[0].shape,
    dset.screen,
)
dset.plot_data();

In [None]:
# train on every other scan setting (even indices); the remaining settings form the test set
train_k_ids = np.arange(0, len(dset.parameters), 2)
train_dset, test_dset = split_dataset(dset, train_k_ids)
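The interleaved split can be sketched in plain NumPy. `split_indices` is a hypothetical helper, not part of `gpsr`; it only mirrors the index bookkeeping, assuming `split_dataset` keeps the listed indices for training and assigns the rest to testing. Taking every other setting keeps the held-out points inside the scanned range, so the test set checks interpolation between training settings rather than extrapolation.

```python
import numpy as np

def split_indices(n_settings, step=2):
    """Hypothetical helper: every `step`-th scan point (starting at
    index 0) goes to training, the remaining points go to testing."""
    all_ids = np.arange(n_settings)
    train_ids = np.arange(0, n_settings, step)
    # setdiff1d returns the sorted indices that are not in train_ids
    test_ids = np.setdiff1d(all_ids, train_ids)
    return train_ids, test_ids

train_ids, test_ids = split_indices(7)
print(train_ids)  # [0 2 4 6]
print(test_ids)   # [1 3 5]
```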

In [None]:
train_dset.plot_data();

In [None]:
test_dset.plot_data();

### Create the quadrupole scan lattice
Here we use the differentiable Cheetah `Screen`. This screen uses kernel density estimation (KDE) to approximate the histogram of particle positions, which makes the simulated image both differentiable and vectorized.
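To illustrate why KDE helps, here is a minimal 1-D sketch of a kernel-density histogram. It assumes a Gaussian kernel with a fixed, hand-picked bandwidth; Cheetah's `Screen` works on 2-D images, and its kernel and bandwidth choices may differ. Because every particle contributes a smooth kernel to every bin, the result is a differentiable function of the particle positions computed with array operations only, whereas a hard-binned histogram has zero gradient almost everywhere.

```python
import numpy as np

def kde_histogram(positions, bin_centers, bandwidth):
    """Smooth 1-D histogram: each particle adds a Gaussian kernel at
    every bin center instead of a hard count in a single bin."""
    # Shape (n_bins, n_particles): offset of every particle from every bin center
    diff = bin_centers[:, None] - positions[None, :]
    kernels = np.exp(-0.5 * (diff / bandwidth) ** 2)
    density = kernels.sum(axis=1)
    # Normalize so the bins sum to 1, like a probability mass per bin
    return density / density.sum()

rng = np.random.default_rng(1)
x = rng.normal(0.0, 1.0e-3, size=5_000)  # particle positions in m (assumed)
centers = np.linspace(-5e-3, 5e-3, 101)
hist = kde_histogram(x, centers, bandwidth=2e-4)
print(round(hist.sum(), 6))  # 1.0
```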

In [None]:
# print screen information
print(train_dset.screen)
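# create diagnostic lattice: a quadrupole (length l_quad) followed by a drift
# (length l_drift) to the screen; lengths presumably in meters (Cheetah convention)
p0c = 43.36e6  # reference momentum in eV/c
gpsr_lattice = GPSRQuadScanLattice(l_quad=0.1, l_drift=1.0, screen=train_dset.screen)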
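A linear-optics sketch of a quadrupole-plus-drift beamline is below, under stated assumptions: lengths in meters taken from `l_quad=0.1` and `l_drift=1.0` in the cell above, the scan variable assumed to be the quadrupole geometric strength `k1` in 1/m², and a sign convention where `k1 > 0` focuses in x and defocuses in y. This is not Cheetah's implementation; it ignores chromatic and higher-order effects and only propagates the beam matrix via `sigma_screen = R @ sigma0 @ R.T`.

```python
import numpy as np

def drift_matrix(length):
    """2x2 transfer matrix of a field-free drift."""
    return np.array([[1.0, length], [0.0, 1.0]])

def quad_matrix(k1, length):
    """2x2 thick-quadrupole matrix in one plane.
    k1 > 0 focuses in this plane, k1 < 0 defocuses, k1 == 0 is a drift."""
    if k1 > 0:
        s = np.sqrt(k1)
        return np.array([[np.cos(s * length), np.sin(s * length) / s],
                         [-s * np.sin(s * length), np.cos(s * length)]])
    if k1 < 0:
        s = np.sqrt(-k1)
        return np.array([[np.cosh(s * length), np.sinh(s * length) / s],
                         [s * np.sinh(s * length), np.cosh(s * length)]])
    return drift_matrix(length)

def quad_scan_matrix(k1, l_quad=0.1, l_drift=1.0):
    """4x4 matrix in (x, px, y, py): quadrupole, then drift to the screen.
    The y plane sees -k1 (sign convention assumed)."""
    r = np.zeros((4, 4))
    r[:2, :2] = drift_matrix(l_drift) @ quad_matrix(k1, l_quad)
    r[2:, 2:] = drift_matrix(l_drift) @ quad_matrix(-k1, l_quad)
    return r

sigma0 = np.diag([1e-6, 1e-7, 1e-6, 1e-7])  # hypothetical initial beam matrix
for k1 in (-10.0, 0.0, 10.0):
    r = quad_scan_matrix(k1)
    sigma_screen = r @ sigma0 @ r.T
    rms_x, rms_y = np.sqrt(sigma_screen[0, 0]), np.sqrt(sigma_screen[2, 2])
    print(f"k1={k1:+.1f}  rms_x={rms_x:.3e}  rms_y={rms_y:.3e}")
```

Scanning `k1` changes how the initial correlations map onto the screen, which is why a set of screen images at different strengths carries information about the full transverse distribution.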

### Define the GPSR model for training
The GPSR model combines two parts: the ML-based parameterization of the initial beam distribution, `NNParticleBeamGenerator` (here with 10,000 particles), and the differentiable simulation of the diagnostic lattice (the same lattice that was used to generate the training data).

In [None]:
gpsr_model = GPSR(NNParticleBeamGenerator(10000, p0c), gpsr_lattice)
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=10)

litgpsr = LitGPSR(gpsr_model)
logger = L.pytorch.loggers.TensorBoardLogger(
    ".",
)
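Conceptually, a neural-network beam generator transforms samples from a simple base distribution into particle coordinates, so that training the network weights reshapes the particle cloud. The sketch below shows only that forward map, using an untrained two-layer MLP in NumPy. The architecture, layer sizes, `tanh` activation, 4-D output, and single `scale` factor are all illustrative assumptions and do not describe the actual `NNParticleBeamGenerator`.

```python
import numpy as np

rng = np.random.default_rng(seed=0)

# Hypothetical sizes: 10k particles, 4-D coordinates, 20 hidden units
n_particles, n_dim, n_hidden = 10_000, 4, 20

# Base samples from a standard normal. In this sketch they stay fixed and
# only the network weights would be trained.
base = rng.standard_normal((n_particles, n_dim))

# Randomly initialized (untrained) weights of a two-layer MLP
w1 = rng.normal(0.0, 1.0 / np.sqrt(n_dim), (n_dim, n_hidden))
b1 = np.zeros(n_hidden)
w2 = rng.normal(0.0, 1.0 / np.sqrt(n_hidden), (n_hidden, n_dim))
b2 = np.zeros(n_dim)

def generate_beam(base_samples, scale=1e-3):
    """Map base samples to phase-space coordinates through the MLP."""
    hidden = np.tanh(base_samples @ w1 + b1)
    # `scale` sets one illustrative physical scale for all coordinates (assumed)
    return scale * (hidden @ w2 + b2)

particles = generate_beam(base)
print(particles.shape)  # (10000, 4)
```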

### Perform the reconstruction
This cell performs the reconstruction by optimizing the parameters of `NNParticleBeamGenerator` to minimize the difference between the predicted measurements and the training data. This step takes some time on a CPU but can be accelerated by 1-2 orders of magnitude on a GPU. If you are limited to a CPU, I would recommend keeping `max_epochs` between 500 and 1000 (the cell below uses 500) to limit computation time.

In [None]:
trainer = L.Trainer(limit_train_batches=100, max_epochs=500, logger=logger)
trainer.fit(model=litgpsr, train_dataloaders=train_loader)
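GPSR fits a neural-network beam to full screen images by gradient descent. A much simpler, classical analog of the same principle (minimize the mismatch between predicted and measured data) is the thin-lens quadrupole-scan emittance fit below, which reduces to linear least squares on rms beam sizes. It is included only to illustrate that principle: it recovers just the three second moments of one plane, not the full 4-D distribution. The thin-lens approximation, the "true" moments, and the scan range are assumptions.

```python
import numpy as np

L_DRIFT = 1.0  # quad-to-screen distance in m (assumed, matches l_drift above)
L_QUAD = 0.1   # quad length in m, used as the thin-lens integrated length

def screen_sigma11(k1, s11, s12, s22):
    """Thin-lens model of <x^2> at the screen for quad strength k1."""
    kl = k1 * L_QUAD
    # R = Drift(L) @ ThinQuad(kl) = [[1 - L*kl, L], [-kl, 1]]
    r11, r12 = 1.0 - L_DRIFT * kl, L_DRIFT
    return r11**2 * s11 + 2 * r11 * r12 * s12 + r12**2 * s22

# Hypothetical "true" beam moments used to generate noise-free measurements
true_moments = (1.0e-6, -2.0e-7, 1.0e-7)
k_values = np.linspace(-20.0, 20.0, 11)
measured = np.array([screen_sigma11(k, *true_moments) for k in k_values])

# <x^2> is linear in (s11, s12, s22), so minimizing
# ||A @ s - measured||^2 is a linear least-squares problem
r11 = 1.0 - L_DRIFT * k_values * L_QUAD
a = np.column_stack([r11**2, 2 * r11 * L_DRIFT, np.full_like(r11, L_DRIFT**2)])
fit, *_ = np.linalg.lstsq(a, measured, rcond=None)

# rms emittance from the fitted 2x2 beam matrix
emittance = np.sqrt(fit[0] * fit[2] - fit[1] ** 2)
print(fit, emittance)
```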

### Get the reconstructed beam distribution

In [None]:
reconstructed_beam = litgpsr.gpsr_model.beam_generator()
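A quick sanity check on a reconstructed distribution is to compute its second-moment statistics. The sketch below assumes the particle coordinates have already been extracted into an (N, 4) NumPy array ordered as (x, px, y, py); how to extract them depends on the Cheetah version and is not shown. `rms_emittances` is a hypothetical helper.

```python
import numpy as np

def rms_emittances(coords):
    """Projected rms emittances and the 4-D emittance from an (N, 4)
    array of (x, px, y, py) coordinates."""
    cov = np.cov(coords, rowvar=False)
    eps_x = np.sqrt(np.linalg.det(cov[:2, :2]))
    eps_y = np.sqrt(np.linalg.det(cov[2:, 2:]))
    # The 4-D emittance uses the full covariance. By Fischer's inequality it
    # is at most eps_x * eps_y, with equality when x and y are uncoupled.
    eps_4d = np.sqrt(np.linalg.det(cov))
    return eps_x, eps_y, eps_4d

# Illustrative uncoupled test beam (not the reconstructed beam)
rng = np.random.default_rng(2)
coords = rng.standard_normal((20_000, 4)) * np.array([1e-3, 1e-4, 2e-3, 5e-5])
eps_x, eps_y, eps_4d = rms_emittances(coords)
print(f"eps_x={eps_x:.2e}, eps_y={eps_y:.2e}, eps_4d={eps_4d:.2e}")
```

Comparing `eps_4d` with `eps_x * eps_y` gives a compact measure of how much x-y coupling the reconstruction contains.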

### Evaluate the model on the test samples
Here we use the trained GPSR model to predict the measurements at the held-out test settings; these predictions should agree with the test data. The plot below shows the test data as a colormap, with contour lines showing the predicted measurements at the 10th, 50th, and 95th percentiles.

In [None]:
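# predict screen images at the held-out scan settings (no gradients needed)
with torch.no_grad():
    test_pred = gpsr_model(test_dset.parameters)[0].detach()
# wrap the predictions in a dataset so they can be overlaid on the test data
test_pred_dset = QuadScanDataset(test_dset.parameters, (test_pred,), train_dset.screen)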

In [None]:
fig, ax = test_dset.plot_data(overlay_data=test_pred_dset)
fig.set_size_inches(20, 3)
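Percentile contours are one way to summarize a 2-D image. One common definition, assumed here, picks the level whose enclosed region holds a given fraction of the total intensity; `plot_data` may define its levels differently (for example, as percentiles of the pixel values). `fraction_levels` is a hypothetical helper.

```python
import numpy as np

def fraction_levels(image, fractions):
    """For each fraction f, return the pixel value such that pixels at or
    above it hold at least a fraction f of the total intensity."""
    values = np.sort(image.ravel())[::-1]          # brightest pixel first
    cumulative = np.cumsum(values) / values.sum()  # enclosed intensity fraction
    # First index at which the enclosed fraction reaches f
    idx = np.searchsorted(cumulative, fractions, side="left")
    return values[np.minimum(idx, len(values) - 1)]

img = np.array([[4.0, 3.0], [2.0, 1.0]])
print(fraction_levels(img, [0.1, 0.5, 0.95]))  # [4. 3. 1.]
```

Larger fractions give lower levels and therefore wider contours, so the 95% contour traces the outer halo of the predicted image.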

In [None]:
reconstructed_beam.plot_distribution();