In [1]:
# compare transport matrices and higher-order map derivatives for quad and sextupole elements
import torch
from phase_space_reconstruction.virtual.beamlines import sextupole_drift, quad_drift
from bmadx.bmad_torch.taylor_map import get_transport_matrix

# Beamline from `sextupole_drift()`; override the K2 strength of its first element.
# Assigning to `.data` replaces the tensor's value without recording the change
# in autograd history.
lattice = sextupole_drift()
sextupole_element = lattice.elements[0]
sextupole_element.K2.data = torch.tensor(1000.0)


In [2]:
from torch.autograd.functional import jacobian, hessian
from bmadx.structures import Particle

def get_taylor_coeff(lattice, s, p0c, mc2, x0=None):
    """Return the first-, second- and third-order derivatives of a lattice map.

    The map ``f`` does three things:
    - takes a vector of 6 phase-space coordinates;
    - packs it into a ``Particle`` together with ``s``, ``p0c`` and ``mc2``, and
      tracks it through ``lattice``;
    - keeps the first six entries of the result.

    Higher orders are obtained by nesting ``torch.autograd.functional.jacobian``.

    Note: these are raw partial derivatives, not Taylor coefficients in the
    ``1/n!`` convention. For those, divide ``H`` by 2 and ``T`` by 6.

    Args:
        lattice: callable mapping a ``Particle`` to an indexable result whose
            first six entries are the output coordinates.
        s, p0c, mc2: forwarded unchanged as the trailing ``Particle`` fields.
            These are presumably position, reference momentum and rest energy;
            confirm against ``bmadx.structures.Particle``.
        x0: floating-point expansion point with 6 coordinates (tensor or
            sequence). Defaults to ``torch.zeros(6)``, i.e. the origin.

    Returns:
        Tuple ``(J, H, T)``. Because ``f`` returns six outputs, each element is
        itself a tuple of six tensors, one per output coordinate. For scalar
        output coordinates these have shapes ``(6,)``, ``(6, 6)`` and
        ``(6, 6, 6)`` respectively.
    """
    x0 = torch.zeros(6) if x0 is None else torch.as_tensor(x0)

    def f(x):
        return lattice(Particle(*x, s, p0c, mc2))[:6]

    # create_graph=True at every level keeps each derivative differentiable,
    # which the next nested jacobian requires. It also leaves grad_fn on the
    # returned tensors. That is presumably so they can be differentiated
    # w.r.t. lattice parameters; confirm intent.
    def df(x):
        return jacobian(f, x, create_graph=True)

    def ddf(x):
        return jacobian(df, x, create_graph=True)

    def dddf(x):
        return jacobian(ddf, x, create_graph=True)

    J = df(x0)
    H = ddf(x0)
    T = dddf(x0)
    return J, H, T

In [3]:
# Derivatives of the sextupole map about the origin.
# p0c and mc2 are presumably in eV (0.511e6 matches the electron rest energy);
# confirm.
J, H, T = get_taylor_coeff(
    lattice,
    s=torch.tensor(0.0),
    p0c=torch.tensor(10e6),
    mc2=torch.tensor(0.511e6),
)

In [4]:
# First derivatives (Jacobian) of the sextupole map: one length-6 tensor per output coordinate.
J

(tensor([1.0000, 1.1000, 0.0000, 0.0000, 0.0000, -0.0000],
        grad_fn=<ViewBackward0>),
 tensor([0., 1., 0., 0., 0., -0.], grad_fn=<ViewBackward0>),
 tensor([0.0000, 0.0000, 1.0000, 1.1000, 0.0000, -0.0000],
        grad_fn=<ViewBackward0>),
 tensor([0., 0., 0., 1., 0., -0.], grad_fn=<ViewBackward0>),
 tensor([0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0029],
        grad_fn=<ViewBackward0>),
 tensor([0., 0., 0., 0., 0., 1.]))

In [5]:
# Second derivatives of the sextupole map: one 6x6 tensor per output coordinate.
H

(tensor([[-104.0000,   -4.0800,    0.0000,    0.0000,    0.0000,    0.0000],
         [  -4.0800,   -0.2432,    0.0000,    0.0000,    0.0000,   -1.0000],
         [   0.0000,    0.0000,  104.0000,    4.0800,    0.0000,    0.0000],
         [   0.0000,    0.0000,    4.0800,    0.2432,    0.0000,    0.0000],
         [   0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000],
         [   0.0000,   -1.0000,    0.0000,    0.0000,    0.0000,    0.0000]],
        grad_fn=<ViewBackward0>),
 tensor([[-100.0000,   -4.0000,    0.0000,    0.0000,    0.0000,    0.0000],
         [  -4.0000,   -0.2400,    0.0000,    0.0000,    0.0000,    0.0000],
         [   0.0000,    0.0000,  100.0000,    4.0000,    0.0000,    0.0000],
         [   0.0000,    0.0000,    4.0000,    0.2400,    0.0000,    0.0000],
         [   0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000],
         [   0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000]],
        grad_fn=<ViewBackward0>),
 tenso

In [6]:
# Third derivatives of the sextupole map: one 6x6x6 tensor per output coordinate.
T

(tensor([[[ 2.4240e+02,  7.2480e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            2.0400e+02],
          [ 7.2480e+00,  1.4445e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            8.0800e+00],
          [ 0.0000e+00,  0.0000e+00,  8.0800e+01,  5.6416e+00,  0.0000e+00,
            0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  5.6416e+00,  9.9520e-02,  0.0000e+00,
            0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [ 2.0400e+02,  8.0800e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00]],
 
         [[ 7.2480e+00,  1.4445e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            8.0800e+00],
          [ 1.4445e-01,  3.0052e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            4.8320e-01],
          [ 0.0000e+00,  0.0000e+00, -4.0352e+00,  2.2464e-02,  0.0000e+00,
            0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  2.2464e-02,  1.0017e+00,  0.0000e+00,
            

In [7]:
# Beamline from `quad_drift()`; override the K1 strength of its first element.
# Assigning to `.data` replaces the tensor's value without recording the change
# in autograd history.
qlattice = quad_drift()
quad_element = qlattice.elements[0]
quad_element.K1.data = torch.tensor(10.0)

In [8]:
# Derivatives of the quadrupole map about the origin, using the same reference
# particle as the sextupole case. p0c and mc2 are presumably in eV; confirm.
# Note: this overwrites the sextupole J, H, T computed earlier.
J, H, T = get_taylor_coeff(
    qlattice,
    s=torch.tensor(0.0),
    p0c=torch.tensor(10e6),
    mc2=torch.tensor(0.511e6),
)

In [9]:
# First derivatives (Jacobian) of the quadrupole map: one length-6 tensor per output coordinate.
J

(tensor([-0.0330,  1.0488,  0.0000,  0.0000,  0.0000,  0.0000],
        grad_fn=<ViewBackward0>),
 tensor([-0.9834,  0.9504, -0.0000,  0.0000,  0.0000,  0.0000],
        grad_fn=<ViewBackward0>),
 tensor([0.0000, 0.0000, 2.0672, 1.1521, 0.0000, 0.0000],
        grad_fn=<ViewBackward0>),
 tensor([-0.0000, 0.0000, 1.0168, 1.0504, 0.0000, 0.0000],
        grad_fn=<ViewBackward0>),
 tensor([0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0029],
        grad_fn=<ViewBackward0>),
 tensor([0., 0., 0., 0., 0., 1.]))

In [10]:
# Second derivatives of the quadrupole map: one 6x6 tensor per output coordinate.
H

(tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  1.0161],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.9979],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 1.0161, -0.9979,  0.0000,  0.0000,  0.0000,  0.0000]],
        grad_fn=<ViewBackward0>),
 tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.0165],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0492],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [-0.0165,  0.0492,  0.0000,  0.0000,  0.0000,  0.0000]],
        grad_fn=<ViewBackward0>),
 tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  

In [11]:
# Third derivatives of the quadrupole map: one 6x6x6 tensor per output coordinate.
T

(tensor([[[-2.8532,  2.7575,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 2.7575, -2.6649,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -1.0166, -1.0503,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -1.0503, -1.0851,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -1.9985]],
 
         [[ 2.7575, -2.6649,  0.0000,  0.0000,  0.0000,  0.0000],
          [-2.6649,  2.5755,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.9825,  1.0151,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  1.0151,  1.0487,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  1.8951]],
 
         [[ 0.0000,  0.0000, -1.0166, -1.0503,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.9825,  1.0151,  0.0000,  0.0000],
          [-1.0166,  0.9825,  0.0000,  0.0000,  0.0000,  0.0000],
    