# 4D GPSR reconstruction from MATLAB emittance-scan data

In [None]:
import matlab_parser
import numpy as np
from lcls_tools.common.data.model_general_calcs import bdes_to_kmod
from gpsr.data_processing import process_images
import torch
from ml_tto.automatic_emittance.image_projection_fit import RecursiveImageProjectionFit

from gpsr.modeling import GPSR
from gpsr.train import LitGPSR
from gpsr.beams import NNParticleBeamGenerator, NNTransform

import pandas as pd
import matplotlib.pyplot as plt

import torch
import lightning as L
import time

matlab_fname = "Emittance-scan-OTRS_HTR_330-2025-03-03-231948.mat"  # scan to reconstruct
data = matlab_parser.loadmat(matlab_fname)  # everything below is read from data["data"][...]

### Get the beamline information needed for the reconstruction

In [None]:
# Beamline information for the reconstruction, read from the scan file.
# NOTE(review): the units below are inferred from the scale factors -- confirm.
resolution = data["data"]["dataList"][0]["res"][0] * 1e-6  # pixel size; presumably um in the file -> m
rmat = torch.tensor(np.array(data["data"]["rMatrix"]))  # transfer matrices as a torch tensor
energy = data["data"]["energy"] * 1e9  # presumably GeV in the file -> eV
quad_strengths = data["data"]["quadVal"]  # scanned quadrupole settings

### Process measured images

In [None]:
# Stack the raw images of every scan entry into one array. The array is 4-D
# (the transpose names four axes), and the two trailing image axes are swapped.
images = np.stack([np.array(entry["img"]) for entry in data["data"]["dataList"]])
images = images.transpose(0, 1, -1, -2)


def fit(image):
    """Fit ``image`` with ``RecursiveImageProjectionFit``.

    Used as the fitting callback for ``process_images`` below.

    Returns
    -------
    tuple of np.ndarray
        ``(rms_size, centroid)`` taken from the fit result, each converted
        to a NumPy array.
    """
    result = RecursiveImageProjectionFit().fit_image(np.array(image))
    rms_size = np.array(result.rms_size)
    centroid = np.array(result.centroid)
    return rms_size, centroid


# Process every image with `fit` as the callback. The pixel size is passed
# back in microns. Then average over axis 1, which presumably holds the
# repeated shots at each quad setting.
final_images, meshgrid = process_images(
    images,
    resolution * 1e6,
    fit,
    pool_size=None,
    median_filter_size=None,
)
final_images = np.mean(final_images, axis=1)

print("Final images shape: ", final_images.shape)
print("Quad strengths shape: ", quad_strengths.shape)

### Set up the reconstruction dataset and Cheetah model

In [None]:
from gpsr.datasets import QuadScanDataset
from cheetah.accelerator import Screen

# Simulated screen whose pixel grid matches the processed images. It uses KDE
# binning with a bandwidth equal to one pixel.
screen = Screen(
    is_active=True,
    method="kde",
    resolution=final_images.shape[1:],
    pixel_size=torch.ones(2) * resolution,
    kde_bandwidth=torch.tensor(resolution, dtype=torch.float32),
)

# Coordinate vectors along x and y extracted from the processing meshgrid
# (presumably np.meshgrid-style arrays -- confirm).
xbins = torch.tensor(meshgrid[0][0])
ybins = torch.tensor(meshgrid[1].T[0])

# Training data: one scan parameter per sample (the trailing unsqueeze adds a
# length-1 parameter axis) paired with its averaged image.
train_dset = QuadScanDataset(
    torch.tensor(quad_strengths, dtype=torch.float32).unsqueeze(-1),
    torch.tensor(final_images, dtype=torch.float32),
    screen,
)

In [None]:
# Show the processed training images (presumably one panel per quad setting).
train_dset.plot_data()

In [None]:
from gpsr.modeling import GPSRLattice
from cheetah.accelerator import CustomTransferMap, Segment
from typing import Tuple


class RMatLattice(GPSRLattice):
    """GPSR lattice: fixed transfer matrices followed by a screen.

    The lattice is a Cheetah ``Segment`` containing a ``CustomTransferMap``
    built from ``rmat`` and then the screen. The transfer matrices never
    change, so ``set_lattice_parameters`` does nothing.

    Parameters
    ----------
    rmat : torch.Tensor
        Transfer matrices handed to ``CustomTransferMap``. They are presumably
        batched, one matrix per scan setting -- confirm against the caller.
    screen_element : cheetah.accelerator.Screen, optional
        Screen placed at the end of the lattice. If omitted, the notebook-level
        ``screen`` is used, so existing ``RMatLattice(rmat)`` calls keep working.
    """

    def __init__(self, rmat, screen_element=None):
        super().__init__()
        if screen_element is None:
            # Backward-compatible fallback to the notebook-global screen.
            screen_element = screen
        # Trainable threshold; track_and_observe subtracts threshold * 1e-3
        # from the screen reading before clipping at zero.
        self.register_parameter(
            "threshold", torch.nn.Parameter(torch.tensor(0.0, dtype=torch.float32))
        )
        self.lattice = Segment([CustomTransferMap(rmat), screen_element])

    def set_lattice_parameters(self, x: torch.Tensor) -> None:
        """No-op: the transfer matrices are fixed, so there is nothing to set."""

    def track_and_observe(self, beam) -> Tuple[torch.Tensor, ...]:
        """Track ``beam`` through the lattice and return the screen observation.

        Every particle is given unit charge before tracking. The transposed
        screen reading is reduced by ``threshold * 1e-3`` and clipped at zero.

        Returns
        -------
        results: Tuple[Tensor]
            A 1-tuple holding the thresholded screen reading.
        """
        screen_element = self.lattice.elements[-1]
        # Keep the pixel size on the same dtype/device as the beam coordinates.
        screen_element.pixel_size = screen_element.pixel_size.to(beam.x)
        # ones_like already matches beam.x's dtype and device.
        beam.particle_charges = torch.ones_like(beam.x)
        self.lattice.track(beam)

        observations = screen_element.reading.transpose(-1, -2)

        # Subtract the learned threshold and clip negative intensities.
        observations = torch.clip(observations - self.threshold * 1e-3, 0, None)

        return tuple(observations.unsqueeze(0))

### Set up training

In [None]:
## Reconstruction hyperparameters
learning_rate = 0.01  # optimizer learning rate
num_epochs = 1_000  # number of training epochs
n_hidden = 4  # hidden layers in the beam-generator network (more -> more expressive)

# Length scale of the generated beam distribution (passed to NNTransform as
# output_scale). Choose it on the order of the beam size: for example, a beam
# of ~100 um calls for output_scale ~1e-4.
output_scale = 1e-4

In [None]:
# Embed rmat in the upper-left 6x6 block of a batch of 7x7 identity matrices,
# with one matrix per entry of rmat.
R = torch.eye(7).repeat(len(rmat), 1, 1)
R[:, :6, :6] = rmat

gpsr_lattice = RMatLattice(R.to(dtype=torch.float32))

# `energy` is used as the reference momentum p0c of the generated beam.
p0c = torch.tensor(energy, dtype=torch.float32)
gpsr_model = GPSR(
    NNParticleBeamGenerator(
        10000,  # presumably the number of generated particles -- confirm
        p0c,
        # 20: presumably the hidden-layer width of the transform network
        transformer=NNTransform(n_hidden, 20, output_scale=output_scale),
    ),
    gpsr_lattice,
)
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=20)

litgpsr = LitGPSR(gpsr_model, lr=learning_rate)

### Run training

In [None]:
# Create a PyTorch Lightning trainer. Metrics are written by a CSV logger
# rooted at the current directory; the loss-plotting cell reads them back.
logger = L.pytorch.loggers.CSVLogger(".")
trainer = L.Trainer(
    max_epochs=num_epochs,
    limit_train_batches=100,  # at most 100 training batches per epoch
    logger=logger,
)

In [None]:
# Run the training and report how long it took. perf_counter is a monotonic
# clock, so the duration is not affected by system clock adjustments.
start = time.perf_counter()
trainer.fit(model=litgpsr, train_dataloaders=train_loader)
print(f"Training took {time.perf_counter() - start:.1f} s")

In [None]:
# Plot the training loss versus epoch for one or more Lightning runs.
# The default is the run that was just trained; add other version numbers
# (the version_<N> directories under the logger's root) to compare runs.
trial_indices = [logger.version]

fig, ax = plt.subplots()
min_losses = []
for version in trial_indices:
    metrics = pd.read_csv(f"{logger.root_dir}/version_{version}/metrics.csv")
    # The CSV can contain rows where loss_epoch is empty (e.g. step-level
    # rows); drop them so the epoch curve is drawn as a continuous line.
    epoch_metrics = metrics.dropna(subset=["loss_epoch"])
    ax.plot(epoch_metrics.epoch, epoch_metrics.loss_epoch, label=f"version {version}")
    min_losses.append(epoch_metrics.loss_epoch.min())

ax.legend()
ax.set_yscale("log")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")

### Analyze reconstruction results

In [None]:
# Sample the reconstructed beam distribution from the trained generator.
reconstructed_beam = litgpsr.gpsr_model.beam_generator()

# Predict the screen images for every training setting. Use the same model
# handle as above; gradients are not needed here, so no autograd graph is built.
with torch.no_grad():
    pred = litgpsr.gpsr_model(train_dset.parameters)[0]
pred_dset = QuadScanDataset(train_dset.parameters, pred, screen)

In [None]:
# Overlay the predicted images (presumably drawn as grey contour levels) on
# the measured training images.
fig, ax = train_dset.plot_data(
    overlay_data=pred_dset,
    overlay_kwargs={"levels": [0.01, 0.25, 0.75], "cmap": "Greys"},
    filter_size=0,
)
fig.set_size_inches(20, 4)
fig.tight_layout()

# Then show the measured and the predicted images in separate figures.
for dset in (train_dset, pred_dset):
    fig, ax = dset.plot_data()
    fig.set_size_inches(20, 4)
    fig.tight_layout()

In [None]:
# plot the 4D phase space of the reconstructed beam
# (the trailing semicolon suppresses the notebook's echo of the return value)
reconstructed_beam.plot_distribution(dimensions=["x", "px", "y", "py"]);

In [None]:
# Report the reconstructed beam's emittances and Twiss parameters.
# Normalized emittance = geometric emittance * beta*gamma, where
# beta*gamma = p0c / (m_e c^2). `energy` is the value used as p0c for the
# beam generator; it is presumably in eV, given the 1e9 scaling on load.
ELECTRON_REST_ENERGY_EV = 0.51099895e6  # CODATA 2018 m_e c^2 in eV
beta_gamma = energy / ELECTRON_REST_ENERGY_EV

print("normalized emittance x:", reconstructed_beam.emittance_x * beta_gamma)
print("normalized emittance y:", reconstructed_beam.emittance_y * beta_gamma)
print("beta x:", reconstructed_beam.beta_x)
print("beta y:", reconstructed_beam.beta_y)
print("alpha x:", reconstructed_beam.alpha_x)
print("alpha y:", reconstructed_beam.alpha_y)