In [15]:
import os, sys
sys.path.insert(0, os.path.abspath('/Users/nneveu/github/lume-astra'))
sys.path.insert(0, os.path.abspath('/Users/nneveu/github/distgen'))

In [16]:
import numpy as np
import matplotlib.pyplot as plt
#plt.style.use('petrstyle.txt')
from matplotlib.ticker import NullFormatter
import glob, sys, h5py
import scipy.io

from astra import Astra#, template_dir
from astra.plot import plot_fieldmaps, plot_stats, plot_stats_with_layout
import distgen
from distgen import Generator
from distgen.writers import *
from pmd_beamphysics import ParticleGroup
from pmd_beamphysics.plot import marginal_plot

#slice plots
# from h5py import File
# from pmd_beamphysics.interfaces import opal
# from pmd_beamphysics.plot import slice_plot
# from pmd_beamphysics.plot import marginal_plot, density_plot

In [17]:
def parse_opal_emitted_dist(filename, names=['x', 'px','y','py','t','pz']):
    '''Read in particle distribution used 
    in OPAL-T simulation. Used to describe the 
    beam distribution as it leaves the cathode.'''
    dist = {}
    data = np.loadtxt(filename, skiprows=1)
    for i,name in enumerate(names):
        dist[name] = data[:,i]
    return dist
def parse_astra_dist(filename, header=['x', 'y', 'z', 'px','py', 'pz', 't', 'Q', 'ptype', 'flag']):
    '''
    Read in initial particle distribution used 
    in ASTRA simulation. Used to describe the 
    beam distribution as it leaves the cathode.
   
    t     = time in ns
    Q     = macro charge
    ptype = particle type (electron: 1)
    flag  = particle location (cathode: -1)
    '''
    #print(filename)
    data = pd.read_csv(filename, delim_whitespace=True, names=header)
    #print(data)
    #only return non traj probe particles at cathodeprint(astradist['z'])
    dist = data[data.flag == -1]
    return dist #dist

In [18]:
def make_tri(n, xmin, xmax, x):

    half = int(n/2)
    yr  = np.zeros(half*2)
    #step 1
    nr = int(n)
    xr = np.random.rand(nr)

    for j in range(0,nr):
        #step 3
        if j < half:
            yr[j] = (1- np.sqrt(1-xr[j]))*(xmax-xmin)  
        elif j >= half:
        #step 4 
            yr[j] = (-1 + np.sqrt(1-xr[j]))*(xmax-xmin) 

    #showplot(xr, yr)
    for k in range(0,len(x)):
        #print(y[k])
        if (-yr[k]/8 <= x[k] <= yr[k]/8):
            pass
            #x[k]=0
        else:
            #print(x[k])
            x[k]=0
            #pass

    #showplot(xr,x)
    return x, xr

In [19]:
# From Chris: https://github.com/slaclab/lcls-lattice/tree/master/distgen/models/cu_inj/vcc_image
def write_distgen_xy_dist(filename, image, resolution, resolution_units='m'):
    """
    Writes image data in distgen's xy_dist format
    
    Returns the absolute path to the file written
    
    """
    
    # Get width of each dimension
    widths = resolution * np.array(image.shape)
    
    # Form header
    header = f"""x {widths[1]} {widths[1]/2} [{resolution_units}]
y {widths[0]} {widths[0]/2}  [{resolution_units}]"""
    
    # Save with the correct orientation
    np.savetxt(filename, np.flip(image, axis=0), header=header, comments='')
    
    return os.path.abspath(filename)

In [20]:
#Edit this line to match your path
%env ASTRA_BIN=/Users/nneveu/Code/astra/Astra

env: ASTRA_BIN=/Users/nneveu/Code/astra/Astra


# Laser 

In [21]:
vcc = glob.glob('../laser_images/*.mat')

In [22]:
mat = scipy.io.loadmat(vcc[0])

IndexError: list index out of range

In [None]:
# mat
# figure;imagesc(reshape(X(1,:),[45 45]))

In [None]:
# mat['X'].shape

In [None]:
# mat['Y'].shape

In [None]:
arr = mat['X'][1,:]
dim = arr.shape[0]
print(dim)
xy = int(np.sqrt(dim))
nrow = xy
ncol = xy
IMAGE = arr.reshape(nrow, ncol)


In [None]:
plt.imshow(IMAGE)#, extent=[0,1,0,1])
plt.ylabel('Pixel', size=20)
plt.xlabel('Pixel', size=20)
plt.savefig('DMD_image.png', dpi=300, bbox_inches='tight')

In [None]:
FOUT = write_distgen_xy_dist('laser_image.txt', IMAGE, xy, resolution_units='um')

In [None]:
dist_file = 'distgen_laser.yaml'
dist = Generator(input=dist_file, verbose=False)
dist.input['n_particle'] = int(1e5)
# dist.input['start']['MTE']['value'] = 150 #330 

fwhm = 0.06/2.355 # 60 fs = 0.06 ps
dist.input['t_dist']['sigma_t']['value'] = fwhm
dist.input['total_charge']['value'] = 10

# dist.input['r_dist']['max_r']['value'] = 0.5
dist.run()
particles = dist.particles 
particles.write_astra('astra_particles.txt')
print(dist)


In [None]:
particles.plot('x','y')

In [None]:
num_bins = 100
plt.figure(figsize=(5, 5))
astradist = parse_opal_emitted_dist('astra_particles.txt', names=['x', 'y', 'z', 'px','py', 'pz', 't', 'Q', 'ptype', 'flag'])
xy = plt.hist2d(astradist['x']*10**3, astradist['y']*10**3, num_bins, facecolor='blue', cmin=2)#, alpha=0.5)
plt.ylabel('[mm]', size=20)
plt.xlabel('[mm]', size=20)
plt.savefig('DMD_electrons.png', dpi=300)


# Run ASTRA

In [None]:
astra_file = 'xta.in'
xta = Astra(initial_particles=particles, input_file=astra_file, verbose=True)

In [None]:


xta.input['solenoid']['maxb(1)'] = 0.45
xta.input['solenoid']

In [None]:
xta.input['charge']

In [None]:
xta.input['newrun']['zstart'] = 0.0
xta.input['newrun']['zstop'] = 5.5
xta.input['charge']['lspch'] = False
xta.input['newrun']['zphase'] = 50

In [None]:
xta.input['charge']

In [None]:
xta.run()

In [None]:
# xta.input_file

In [None]:
#xta.archive('xta_60fs_1mm_10pC_noSC_laser_weak_solenoid_0.45_zstop_5.5.h5')

In [None]:
xta.load_archive('archives/xta_60fs_1mm_10pC_SC_laser_weak_solenoid_0.45_zstop_5.5.h5')

In [None]:

test = xta.particles[-1].plot('energy')
plt.savefig('test.png', dpi=250)

In [None]:
#xta.load_archive('archives/xta_60fs_1mm_1pC_SC_laser_test.h5')
plt.figure(figsize=(5, 5))
# xta.particles[-1].plot('x','y')
xtapart = xta.particles[-1]
num_bins = 100
xy = plt.hist2d(xtapart.x*10**3, xtapart.pz*10**3, num_bins, facecolor='blue', cmin=1)#, alpha=0.5)
print(np.mean(xtapart.z))
plt.ylabel('[mm]', size=20)
plt.xlabel('X [mm]', size=20)
plt.savefig('DMD_xta_SC_sol0.45_5.5meters_LDRDpxpy.png', dpi=300, bbox_inches='tight')

In [None]:
len(xta.particles)

In [None]:
plt.plot(xta.stat('mean_z'), xta.stat('sigma_z')*10**3, '-', label="$\sigma_z$")
plt.plot(xta.stat('mean_z'), xta.stat('sigma_x')*10**3, '-', label="$\sigma_x$")
plt.xlabel('z (m)')
plt.ylabel('Beam size (mm)')
plt.legend(loc='upper right')
plt.ylim(0,0.6)
plt.xlim(0,6)
plt.grid()

In [None]:
plt.plot(xta.stat('mean_z'), xta.stat('mean_kinetic_energy')*10**-6, '-', label="Energy")
# plt.plot(xta.stat('mean_z'), xta.stat('sigma_x')*10**3, '-', label="$\sigma_x$")
plt.xlabel('z (m)')
plt.ylabel('Beam size (mm)')
plt.legend(loc='upper right')
# plt.ylim(0,0.6)
# plt.xlim(0,0.7)
plt.grid()

# Movie