In [1]:
import matplotlib.pyplot as plt
from distgen import Generator

# Create the beam: build a distgen Generator from the Gaussian YAML config
# and run it.
beam_config = "beams/gaussian.yaml"
gen = Generator(beam_config)
gen.run()

# Bare last expression so the notebook displays the generator's repr.
gen

ModuleNotFoundError: No module named 'distgen'

In [None]:
import numpy as np
# Grab the particle object held by the generator that was run in the first cell.
particles = gen.particles
# Plot "x" against "px"; as the cell's last expression its result is displayed.
particles.plot("x","px")

In [None]:
# build the ground-truth input beam as a `Particle` from the distgen particles
from track import Particle
import torch

# Attributes read off `particles`, in the positional order handed to
# `Particle`. Each transverse plane needs its own slope: "xp" for x and "yp"
# for y (listing "xp" twice would feed the x slope in as the y coordinate).
# NOTE(review): assumes `Particle`'s first six positional fields are the x, y
# and z phase-space pairs in this order -- confirm against `track.Particle`.
keys = ["x", "xp", "y", "yp", "z", "pz"]

# Keyword fields shared by every `Particle` built in this notebook.
defaults = {
    "s": torch.tensor(0.0),
    # mean over the beam of gamma * mass
    "p0c": torch.mean(torch.tensor(particles.gamma) * particles.mass),
    "mc2": torch.tensor(particles.mass),
}

# One tensor per coordinate, unpacked positionally, plus the shared defaults.
ground_truth_in = Particle(
    *[torch.tensor(getattr(particles, key)) for key in keys], **defaults
)


In [None]:
from normalizing_flow import get_images, image_difference_loss

# get ground truth images
sizes = []
k_in = torch.linspace(-1, 1, 5) * 1e7
bins = torch.linspace(-20, 20, 100) * 1e-3
bandwidth = torch.tensor(1e-3)
images = get_images(ground_truth_in, k_in, bins, bandwidth)

# The plotting grid depends only on `bins`, so build it once outside the loop.
# indexing="ij" is the layout torch.meshgrid already uses when the argument is
# omitted; passing it explicitly avoids the deprecation warning.
xx = torch.meshgrid(bins, bins, indexing="ij")

# preview only the first image
for image in images[:1]:
    fig, ax = plt.subplots()
    ax.pcolor(xx[0], xx[1], image)

# test loss: the ground-truth beam is compared against itself
# (presumably this should come out near zero -- confirm against
# image_difference_loss)
image_difference_loss(ground_truth_in, ground_truth_in, k_in, bins, bandwidth)

In [None]:
# define a unit multivariate normal: 1000 six-dimensional samples, scaled by 1e-3
from torch.distributions import MultivariateNormal
normal_samples = MultivariateNormal(torch.zeros(6), torch.eye(6)).sample([1000]) * 1e-3

# define nonparametric transform
from normalizing_flow import NonparametricTransform
tnf = NonparametricTransform()

# plot initial beam after transform
# The transformed samples are transposed so each row unpacks into one
# positional `Particle` field, and cast to double.
p_in_guess = Particle(*tnf(normal_samples).T.double(), **defaults)

# `tnf` has trainable parameters, so its output is attached to the autograd
# graph; matplotlib converts its inputs to NumPy arrays, which raises on
# tensors that require grad. Detach a view for plotting only -- `p_in_guess`
# itself is left untouched.
plt.hist2d(
    p_in_guess.x.detach().numpy(),
    p_in_guess.px.detach().numpy(),
    bins=100,
)



In [None]:
# perform optimization with Adam
optim = torch.optim.Adam(tnf.parameters(), lr=0.001)
n_iter = 1
losses = []
for i in range(n_iter):
    # clear gradients accumulated by the previous step
    optim.zero_grad()

    # push the base samples through the current transform to get the trial beam
    trial_beam = Particle(*tnf(normal_samples).T.double(), **defaults)
    loss = image_difference_loss(
        ground_truth_in, trial_beam, k_in, bins=bins, bandwidth=bandwidth, plot_images=True
    )

    # store a graph-free copy of the loss, then backpropagate and update
    losses.append(loss.detach())
    loss.backward()
    optim.step()

    # progress report every tenth iteration
    if not i % 10:
        print(i)

# loss history over the iterations
fig, ax = plt.subplots()
ax.plot(losses)