# Simple 6D phase space reconstruction

In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt

from phase_space_reconstruction.virtual.beamlines import quad_tdc_bend
from phase_space_reconstruction.virtual.scans import run_3d_scan_2screens
from phase_space_reconstruction.diagnostics import ImageDiagnostic
from phase_space_reconstruction.visualization import plot_3d_scan_data_2screens
from phase_space_reconstruction.utils import split_2screen_dset
from phase_space_reconstruction.train import train_3d_scan_2screens

from bmadx.distgen_utils import create_beam
from bmadx.plot import plot_projections
from bmadx.constants import PI

from stats import plot_projections_with_contours

## Load data

In [None]:
# load data
# Directory holding the saved dataset file `dset.pt`.
data_dir = '/global/homes/j/jpga/AWA_DATA'
# Output directory name; not used in the cells visible here.
save_dir =  'results'
# NOTE(review): torch.load unpickles the file, so only load data from a trusted
# source. Recent PyTorch releases default to weights_only=True, which may reject
# a pickled dataset object — confirm against the installed version.
dset = torch.load(os.path.join(data_dir, 'dset.pt'))
# Split the loaded dataset into the training and test subsets used below.
train_dset, test_dset = split_2screen_dset(dset)

In [None]:
# Visualize the complete loaded dataset `dset` with the 2-screen scan plotting helper.
plot_3d_scan_data_2screens(dset)
plt.show()

In [None]:
# Visualize only the training subset returned by split_2screen_dset.
plot_3d_scan_data_2screens(train_dset)
plt.show()

In [None]:
# Visualize only the held-out test subset returned by split_2screen_dset.
plot_3d_scan_data_2screens(test_dset)
plt.show()

## Define diagnostics lattice parameters

In [None]:
# diagnostic beamline:
# Reference momentum shared by both lattices and passed to the training call.
# NOTE(review): presumably p0c is in eV (i.e. 62 MeV) — confirm units.
p0c = 62.0e6
# Two instances of the same beamline, differing only in the dipole setting.
lattice0 = quad_tdc_bend(p0c=p0c, dipole_on=False)
lattice1 = quad_tdc_bend(p0c=p0c, dipole_on=True)

# Scan over quad strength, tdc on/off and dipole on/off
# NOTE(review): presumably these are the indices of the scanned lattice
# elements that are varied during training — confirm against
# train_3d_scan_2screens.
scan_ids = [0, 2, 4] 

# create 2 diagnostic screens: 
def create_screen(size, pixels):
    """Build an ImageDiagnostic whose bins are centered on zero.

    Args:
        size: full extent of the screen; bin positions run from -size/2 to
            size/2 inclusive.
        pixels: number of evenly spaced bin positions.

    Returns:
        An ImageDiagnostic constructed with the same bin tensor as its first
        and second argument (presumably the two screen axes) and a bandwidth
        equal to half the spacing between adjacent bins.
    """
    half_extent = size / 2
    edges = torch.linspace(-half_extent, half_extent, pixels)
    # Spacing is uniform, so the first pair of bins is representative.
    spacing = edges[1] - edges[0]
    return ImageDiagnostic(edges, edges, spacing / 2)

# Two screens with different physical extents and the same 700-bin resolution;
# both are passed to the training call.
# NOTE(review): sizes presumably in meters (30.22 mm and 26.96 mm) — confirm.
screen0 = create_screen(30.22*1e-3, 700)
screen1 = create_screen(26.96*1e-3, 700)

### 10,000 particles

In [None]:
%%time

# training
# Fit a beam distribution to the training subset using both lattice
# configurations and both screens. The result is rebound to `pred_beam`.
pred_beam = train_3d_scan_2screens(
    train_dset,
    lattice0,
    lattice1,
    p0c,
    screen0,
    screen1,
    ids=scan_ids,
    n_epochs=1_000,
    n_particles=10_000,
    # Fall back to the CPU when no CUDA device is present so the cell still
    # runs (more slowly) instead of failing on machines without a GPU.
    device='cuda' if torch.cuda.is_available() else 'cpu',
)
# Replace the trained beam object with the result of its numpy_particles()
# method (presumably a NumPy-backed copy for plotting — confirm).
pred_beam = pred_beam.numpy_particles()
# Release cached GPU memory held by the allocator; this is a no-op when CUDA
# was never initialized.
torch.cuda.empty_cache()

In [None]:
# reconstructed beam projections:
# NOTE(review): `lims` is not defined in any cell visible here; running this
# cell as-is raises NameError unless `lims` is defined elsewhere. Define the
# intended axis limits before this call, or drop `custom_lims` if the default
# limits are acceptable — confirm against bmadx.plot.plot_projections.
fig, ax = plot_projections(pred_beam,
                           custom_lims = lims,
                           bins = 50
                           )
plt.show()