# Phase-space (PS) reconstruction of a Gaussian beam

In [None]:
import torch
import numpy as np
from copy import deepcopy

import lightning as L
from cheetah.particles import ParticleBeam
from cheetah.utils.bmadx import bmad_to_cheetah_coords
from cheetah.accelerator import Quadrupole, Drift

from phase_space_reconstruction.diagnostics import ImageDiagnostic
from phase_space_reconstruction.datasets import SixDReconstructionDataset
from phase_space_reconstruction.modeling import GPSR6DLattice, GPSR
from phase_space_reconstruction.train import LitGPSR
from phase_space_reconstruction.beams import NNParticleBeamGenerator


## Create synthetic ground truth beam

In [None]:
# Create the synthetic ground-truth beam from a stored particle file
# (data/gaussian_beam.pt), keeping the first `n_particles` macroparticles.
p0c = 43.36e6  # reference momentum in eV/c
n_particles = 10000  # number of macroparticles taken from the stored beam

# Load the file once (it was previously loaded twice) and truncate it in place.
# NOTE(review): newer torch releases default torch.load to weights_only=True;
# if this file stores a custom Python object, loading may need adjusting.
bmad_gt_beam = torch.load("data/gaussian_beam.pt")
bmad_gt_beam.data = bmad_gt_beam.data[:n_particles]

# Work on an independent copy so the coordinate conversion below can never
# alter bmad_gt_beam (the original code used two separately loaded tensors).
particle_data = bmad_gt_beam.data.clone()

particle_data = bmad_to_cheetah_coords(
    particle_data,
    torch.tensor(p0c),
    torch.tensor(0.511e6),  # presumably the electron rest energy mc^2 in eV
)
gt_beam = ParticleBeam(*particle_data)


## Define the diagnostic screens and lattice parameters

In [None]:
# Symmetric plotting window (lower, upper) for each of the six coordinates:
# +/-15e-3 for the first four, +/-5e-3 and +/-5e-2 for the last two, all
# scaled down by a factor of 0.2 to zoom in on the beam core.
half_widths = np.array([15e-3, 15e-3, 15e-3, 15e-3, 5e-3, 5e-2]) * 0.2
lims = np.column_stack((-half_widths, half_widths))

gt_beam.plot_distribution(custom_lims=lims)

In [None]:
# Template diagnostic screen: 200 evenly spaced bin positions from -5e-3 to
# 5e-3, used for both image axes.
bins = torch.linspace(-5, 5, 200) * 1e-3
# Bandwidth is half of the spacing between neighbouring bins
# (presumably the smoothing width of the histogram — confirm in ImageDiagnostic).
bandwidth = 0.5 * (bins[1] - bins[0])
screen = ImageDiagnostic(bins, bins, bandwidth)

Cheetah-based version of the beamline

In [None]:
# --- Upstream beamline: three quadrupoles, each followed by a drift ---
# Focusing strengths (k1) of the three upstream quadrupoles.
k1 = -24.868402
k2 = 26.179029
k3 = -26.782126

quad_length = 0.11
# Quadrupole spacings; each drift is a spacing minus one quadrupole length
# (presumably the spacings are measured center-to-center — confirm).
lq12 = 1.209548
lq23 = 0.19685
lq34 = 0.18415
ld1, ld2, ld3 = (spacing - quad_length for spacing in (lq12, lq23, lq34))

# Build the elements in beamline order q1, d1, q2, d2, q3, d3.
upstream_components = []
for strength, drift_length in ((k1, ld1), (k2, ld2), (k3, ld3)):
    upstream_components.append(
        Quadrupole(
            length=quad_length,
            k1=strength,
            num_steps=5,
            tracking_method="bmadx",
        )
    )
    upstream_components.append(Drift(drift_length))
q1, d1, q2, d2, q3, d3 = upstream_components

# --- Diagnostic section parameters for GPSR6DLattice ---
l_quad = 0.11
l_tdc = 0.01
f_tdc = 1.3e9
phi_tdc = 0.0
l_bend = 0.3018
# Dipole angle of -20 degrees converted to radians.
# NOTE(review): 3.14 is used as an approximation of pi; the dipole strengths
# in the scan cell use the same approximation, so change both together.
theta_on = -20.0 * 3.14 / 180.0
l1 = 0.790702
l2 = 0.631698
l3 = 0.889

gpsr_lattice = GPSR6DLattice(
    l_quad, l_tdc, f_tdc, phi_tdc,
    l_bend, theta_on,
    l1, l2, l3,
    # Each screen argument gets its own independent copy of the template.
    deepcopy(screen),
    deepcopy(screen),
    upstream_components,
)



In [None]:
# Scan over quad strength, TDC off/on and dipole off/on.
n_ks = 5  # number of scanning-quadrupole strengths (was assigned twice)
# NOTE(review): 3.14 approximates pi to match theta_on in the lattice cell;
# switch both to np.pi together if exact pi is wanted.
PI = 3.14
ks = torch.linspace(-3, 3, n_ks)  # quad ks
vs = torch.tensor([0, 3e6])  # TDC off/on
# Dipole off/on. "Off" is a tiny nonzero value instead of exactly 0
# (presumably to avoid a division by zero in the bend model — confirm).
# NOTE(review): the "on" strength divides the 20 degree angle by 0.365, not by
# l_bend = 0.3018 from the lattice cell; confirm which length is intended.
gs = torch.tensor([-2.22e-16, -20.0 * PI / 180.0 / 0.365])

scan_ids = [6, 8, 10]  # not used in the cells visible here

# Full-factorial grid of settings with shape (len(gs), len(vs), n_ks, 3);
# the last axis holds (g, v, k) for each setting.
train_params = torch.stack(
    torch.meshgrid(gs, vs, ks, indexing="ij"),
    dim=-1,
)

# Create training data by tracking the ground-truth beam through every setting.
gpsr_lattice.set_lattice_parameters(train_params)
final_beam = gpsr_lattice.lattice(gt_beam)
obs = gpsr_lattice.track_and_observe(gt_beam)

obs_dataset = SixDReconstructionDataset(train_params, obs, (bins, bins))

In [None]:
# Plot the simulated observations for every scan setting.
fig, ax = obs_dataset.plot_data(
    publication_size=True,
)

## Training using PyTorch Lightning

In [None]:
# Reconstruction model: a neural-network beam generator (10000 particles at
# reference momentum p0c) combined with the diagnostic lattice.
beam_generator = NNParticleBeamGenerator(10000, p0c)
gpsr_model = GPSR(beam_generator, gpsr_lattice)

# batch_size=20 matches the 2 x 2 x 5 scan grid (assumes one dataset sample
# per scan setting — confirm), so each epoch is a single batch.
train_loader = torch.utils.data.DataLoader(obs_dataset, batch_size=20)

# Lightning module, TensorBoard logs saved under the current directory, trainer.
litgpsr = LitGPSR(gpsr_model)
logger = L.pytorch.loggers.TensorBoardLogger(".")
trainer = L.Trainer(
    max_epochs=500,
    limit_train_batches=100,
    logger=logger,
)
trainer.fit(model=litgpsr, train_dataloaders=train_loader)

In [None]:
# Evaluate the trained beam generator to obtain the reconstructed beam.
trained_generator = litgpsr.gpsr_model.beam_generator
reconstructed_beam = trained_generator()

In [None]:
# Plot the reconstruction with the same axis limits as the ground-truth plot.
reconstructed_beam.plot_distribution(custom_lims=lims)