# 4D phase space reconstruction
Here we test 4D phase space reconstruction using a basic quadrupole scan and a screen. The test uses a dummy beam distribution in the 4D transverse phase space.

### Python package imports

In [None]:
import torch
import lightning as L
from gpsr.custom_cheetah.screen import Screen
from cheetah import ParticleBeam
from gpsr.modeling import GPSR, GPSRQuadScanLattice
from gpsr.train import LitGPSR
from gpsr.beams import NNParticleBeamGenerator
from gpsr.datasets import QuadScanDataset

### Import test beam distribution and visualize 6d phase space

In [None]:
# load the synthetic ground truth beam from a saved distribution file
p0c = 43.36e6  # reference momentum in eV/c

gt_beam = torch.load("../example_data/example_distributions/complex_beam.pt")
gt_beam = ParticleBeam(gt_beam.particles, gt_beam.energy)

In [None]:
# plot synthetic ground truth beam projections

lims = [
    [-15e-3, 15e-3],
    [-10e-3, 10e-3],
    [-20e-3, 20e-3],
    [-10e-3, 10e-3],
    [-5e-3, 5e-3],
    [-5e-2, 5e-2],
]
gt_beam.plot_distribution(bin_ranges=lims);

### Create the screen diagnostic
Here we use the differentiable Cheetah `Screen`. This screen uses kernel density estimation to approximate the histogram, which makes the image calculation differentiable and vectorized.

In [None]:
# create diagnostic screen:
n_pixels = 100
pixel_size = torch.tensor(1e-3)
screen = Screen(
    resolution=(n_pixels, n_pixels),
    pixel_size=torch.tensor((pixel_size, pixel_size)),
    method="kde",
    kde_bandwidth=pixel_size,
    is_active=True,
)
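
To illustrate why a kernel density estimate helps here, the sketch below compares hard binning with a Gaussian-smoothed histogram in one dimension. It is a minimal pure-Python illustration, not the Cheetah implementation: the helper names `hard_histogram` and `kde_histogram` are hypothetical, and the bandwidth is set to one pixel only to mirror the `kde_bandwidth=pixel_size` choice above.

```python
import math

def hard_histogram(samples, centers, width):
    """Count samples that fall inside bins of the given width centred on `centers`."""
    counts = [0.0] * len(centers)
    for x in samples:
        for i, c in enumerate(centers):
            if c - width / 2 <= x < c + width / 2:
                counts[i] += 1.0
    return counts

def kde_histogram(samples, centers, bandwidth):
    """Smooth histogram: every sample adds a Gaussian bump evaluated at each bin centre."""
    norm = 1.0 / (bandwidth * math.sqrt(2.0 * math.pi))
    values = [0.0] * len(centers)
    for x in samples:
        for i, c in enumerate(centers):
            values[i] += norm * math.exp(-0.5 * ((c - x) / bandwidth) ** 2)
    return values

pixel_size = 1e-3
centers = [(i - 2) * pixel_size for i in range(5)]  # five pixels from -2 mm to 2 mm

# Shift a single particle by 10 um, staying inside the same pixel.
before = [0.10e-3]
after = [0.11e-3]

hard_before = hard_histogram(before, centers, pixel_size)
hard_after = hard_histogram(after, centers, pixel_size)
kde_before = kde_histogram(before, centers, pixel_size)
kde_after = kde_histogram(after, centers, pixel_size)

print("hard binning changed:", hard_before != hard_after)
print("KDE image changed:   ", kde_before != kde_after)
```

The hard-binned image does not respond to the small shift, so it carries no gradient information about the particle position. The smoothed image changes continuously, which is the property a gradient-based reconstruction relies on.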

### Create synthetic data
We track the dummy beam distribution through a simple lattice (a quadrupole followed by a drift) at several quadrupole strengths to produce the training images for GPSR.

In [None]:
# create synthetic data
# get model
gpsr_lattice = GPSRQuadScanLattice(l_quad=0.1, l_drift=1.0, diagnostic=screen)

# quadrupole strengths for the scan (shape after unsqueeze: n_quad_strengths x 1)
n_ks = 20
ks = torch.linspace(-25, 15, n_ks).unsqueeze(-1)
gpsr_lattice.set_lattice_parameters(ks)

# propagate beam and get images
images = gpsr_lattice.track_and_observe(gt_beam)[0]

# split into training and testing sets by taking alternating scan points
train_ks = ks[::2]
test_ks = ks[1::2]

train_images = images[::2]
test_images = images[1::2]

train_dset = QuadScanDataset(train_ks, train_images, screen=screen)
test_dset = QuadScanDataset(test_ks, test_images, screen=screen)

train_dset.plot_data();
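
As a rough picture of what the scan does, the sketch below computes the rms beam size on the screen as a function of quadrupole strength using a thin-lens quadrupole followed by a drift. This is a simplified linear model and not what `GPSRQuadScanLattice` computes internally. The function `screen_rms_size`, the beam moments, and the sign convention (positive `k` focuses in x) are assumptions made for illustration; only `l_quad`, `l_drift`, and the scan range are taken from the cell above.

```python
import math

def screen_rms_size(k, sigma11, sigma12, sigma22, l_quad=0.1, l_drift=1.0, plane="x"):
    """RMS beam size at the screen for a thin-lens quadrupole followed by a drift.

    sigma11, sigma12, sigma22 are the second moments <x^2>, <x x'>, <x'^2>
    of the beam at the quadrupole entrance.
    """
    sign = 1.0 if plane == "x" else -1.0  # assumed: positive k focuses in x
    r11 = 1.0 - sign * k * l_quad * l_drift
    r12 = l_drift
    variance = r11**2 * sigma11 + 2.0 * r11 * r12 * sigma12 + r12**2 * sigma22
    return math.sqrt(variance)

# Assumed uncorrelated beam: 1 mm rms size, 0.1 mrad rms divergence.
sigma11, sigma12, sigma22 = 1e-6, 0.0, 1e-8

# Same scan range as above: 20 strengths between -25 and 15.
n_ks = 20
ks = [-25.0 + 40.0 * i / (n_ks - 1) for i in range(n_ks)]

sizes_x = [screen_rms_size(k, sigma11, sigma12, sigma22, plane="x") for k in ks]
sizes_y = [screen_rms_size(k, sigma11, sigma12, sigma22, plane="y") for k in ks]

k_min_x = ks[sizes_x.index(min(sizes_x))]
k_min_y = ks[sizes_y.index(min(sizes_y))]
print(f"smallest x size near k = {k_min_x:.2f}, smallest y size near k = {k_min_y:.2f}")
```

In this model the two planes reach their waists at quadrupole strengths of opposite sign, so a scan that covers both positive and negative strengths, as the one above does, samples both transverse planes through focus.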

In [None]:
torch.save(train_dset, "../example_data/example_datasets/reconstruction_4D_train.dset")
torch.save(test_dset, "../example_data/example_datasets/reconstruction_4D_test.dset")

### Define the GPSR model for training
The GPSR model combines an ML-based parameterization of the initial beam distribution (`NNParticleBeamGenerator`, here with 10,000 particles) with the differentiable simulation of the diagnostic lattice (the same one used above to generate the training data).

In [None]:
gpsr_model = GPSR(NNParticleBeamGenerator(10000, p0c), gpsr_lattice)
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=20)

litgpsr = LitGPSR(gpsr_model)
logger = L.pytorch.loggers.TensorBoardLogger(
    ".",
)

### Perform the reconstruction
This cell performs the reconstruction by varying the parameters of `NNParticleBeamGenerator` to minimize the difference between the predicted measurements and the training data. This step takes some time on a CPU but can be greatly accelerated (by 1-2 orders of magnitude) on a GPU. If you are limited to a CPU, I would recommend reducing `max_epochs` to between 500 and 1000 to shorten the computation time.

In [None]:
trainer = L.Trainer(limit_train_batches=100, max_epochs=2000, logger=logger)
trainer.fit(model=litgpsr, train_dataloaders=train_loader)
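
The idea behind this fitting loop can be shown with a much smaller toy problem. The sketch below is not GPSR: it recovers a single beam moment by plain gradient descent on the mismatch between predicted and "measured" beam variances, using a thin-lens model and a hand-written gradient. All names and numbers apart from the lattice lengths and the scan range are assumptions for illustration. GPSR instead fits the parameters of `NNParticleBeamGenerator` to full screen images and obtains gradients by automatic differentiation.

```python
# Toy stand-in for the reconstruction loop (not GPSR itself): recover one
# unknown beam parameter by gradient descent on a measurement mismatch.

L_QUAD, L_DRIFT = 0.1, 1.0  # same lengths as the lattice above, in m
SIGMA22 = 0.01              # assumed known divergence variance, mrad^2
TRUE_SIGMA11 = 1.0          # "unknown" position variance to recover, mm^2

def predicted_variance(k, sigma11):
    """Thin-lens prediction of the beam variance on the screen, in mm^2."""
    r11 = 1.0 - k * L_QUAD * L_DRIFT
    return r11**2 * sigma11 + L_DRIFT**2 * SIGMA22

# Synthetic "measurements" over the same scan range as above.
ks = [-25.0 + 40.0 * i / 19 for i in range(20)]
observed = [predicted_variance(k, TRUE_SIGMA11) for k in ks]

sigma11 = 0.2         # deliberately poor initial guess
learning_rate = 0.01  # small enough for this quadratic loss to converge
for step in range(200):
    # Gradient of the mean squared error with respect to sigma11.
    grad = 0.0
    for k, obs in zip(ks, observed):
        r11 = 1.0 - k * L_QUAD * L_DRIFT
        grad += 2.0 * (predicted_variance(k, sigma11) - obs) * r11**2
    grad /= len(ks)
    sigma11 -= learning_rate * grad

print(f"recovered sigma11 = {sigma11:.6f} mm^2 (true value {TRUE_SIGMA11})")
```

The structure is the same as in the full reconstruction: a forward model maps beam parameters to predicted measurements, a loss compares them to data, and the gradient of that loss updates the parameters.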

### Get the reconstructed beam distribution

In [None]:
reconstructed_beam = litgpsr.gpsr_model.beam_generator()

### Evaluate model on samples to compare predictions
Here we use the trained GPSR model to make predictions that should agree with the training data. The plot below shows the training data as a colormap, with contour lines showing the predicted measurements at the 10th, 50th, and 95th percentiles.

In [None]:
test_pred = gpsr_model(train_dset.parameters)[0].detach()
test_pred_dset = QuadScanDataset(train_dset.parameters, test_pred, screen=screen)
test_pred_dset.observations[0].shape

In [None]:
# the predictions above use the training parameters, so overlay them on the training data
fig, ax = train_dset.plot_data(overlay_data=test_pred_dset)
fig.set_size_inches(20, 3)

### Compare reconstructed distribution to ground truth distribution
Here we compare 2D projections of the two distributions to check their agreement. Because we only use a quadrupole to rotate the phase space, projections that involve the longitudinal phase space coordinates will be inaccurate and should be ignored. See the 6D reconstruction examples for a case where those distributions can be reconstructed as well.

In [None]:
reconstructed_beam.plot_distribution(bin_ranges=lims)
gt_beam.plot_distribution(bin_ranges=lims)