<div class="alert alert-block alert-info">
<p style="font-size:24px;text-align:center"><b>Simulates a multi-crystal amplifier with BELLA Center's parameters</b>
<br>A 3-crystal configuration, with 4, 3, and 3 passes through crystals 1, 2, and 3, respectively</p>
</div>

In [None]:
# Laser pulse
# Number of slices for the laser pulse (used as `nslice` in the pulse parameters).
num_laser_slices = 10

# Crystal
# Number of slices per crystal (used as `nslice` and as the length of the n0/n2 lists).
num_crystal_slices = 10

# `gain` and `radial_n2` are forwarded positionally to every
# `crystal_N.propagate(pulse, prop_type, gain, radial_n2)` call in this notebook.
# NOTE(review): presumably on/off flags for gain and radial n2 variation -- confirm against rslaser.
gain = 1 
radial_n2 = 0
# Propagation method passed to the crystals; the crystal cells also branch on
# the value 'abcd_lct' to rebuild each crystal from its full ABCD matrix.
prop_type = 'n0n2_srw'

In [None]:
def _poster_style():
    """Return the style list giving the seaborn 'poster' look on this matplotlib.

    Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and
    later releases no longer accept the old names, so the first name found in
    `plt.style.available` is used. If neither is present an empty list is
    returned, which leaves the current style untouched.
    """
    for style_name in ('seaborn-v0_8-poster', 'seaborn-poster'):
        if style_name in plt.style.available:
            return [style_name]
    return []


def _plot_2d_map(mesh, values_2d, quantity, unit):
    """Plot one 2D quantity twice: via uti_plot, then as a matplotlib pcolormesh.

    Args:
        mesh: wavefront mesh providing xStart/xFin/nx and yStart/yFin/ny.
        values_2d: 2D array to display; it is flattened for uti_plot.
        quantity: label of the plotted quantity, also used as the figure title.
        unit: unit string handed to uti_plot for the plotted quantity.
    """
    uti_plot.uti_plot2d1d(
        values_2d.flatten(),
        [mesh.xStart, mesh.xFin, mesh.nx],
        [mesh.yStart, mesh.yFin, mesh.ny],
        0,
        0,
        ['Horizontal Position', 'Vertical Position', quantity],
        ['m', 'm', unit],
        True)

    # Mesh coordinates are scaled by 1e3 to match the [mm] axis labels below.
    x_mm = np.linspace(mesh.xStart, mesh.xFin, mesh.nx) * (1e3)
    y_mm = np.linspace(mesh.yStart, mesh.yFin, mesh.ny) * (1e3)

    with plt.style.context(_poster_style()):
        fig = plt.figure(figsize=(4.6*1.5, 3.6*1.5))
        ax = fig.gca()
        plt.pcolormesh(x_mm, y_mm, values_2d, cmap=plt.cm.viridis, shading='auto')
        plt.colorbar()
        ax.set_ylabel(r'Vertical Position [mm]')
        ax.set_xlabel(r'Horizontal Position [mm]')
        ax.set_title(quantity)
        plt.show()


def plot_all(laser_pulse):
    """Plot the total 2D intensity and total 2D phase of a laser pulse.

    The transverse axes come from the mesh of the first slice's wavefront
    (`laser_pulse.slice[0].wfr`). Intensity is computed from the real and
    imaginary parts of `laser_pulse.extract_total_2d_elec_fields()` as
    0.5 * c * epsilon_0 * (re**2 + im**2); the phase comes from
    `laser_pulse.extract_total_2d_phase()`.

    Args:
        laser_pulse: pulse object providing `slice`, and the two `extract_*`
            methods named above.
    """
    mesh = laser_pulse.slice[0].wfr.mesh

    e_total = laser_pulse.extract_total_2d_elec_fields()
    intensity_2d = 0.5 * const.c * const.epsilon_0 * (e_total.re**2.0 + e_total.im**2.0)
    # NOTE(review): the unit label is kept from the original notebook; the
    # expression above looks like a power density, so the label may need
    # updating -- confirm the field units returned by rslaser.
    _plot_2d_map(mesh, intensity_2d, 'Intensity', 'ph/s/.1%bw/mm^2')

    phase_2d = laser_pulse.extract_total_2d_phase()
    _plot_2d_map(mesh, phase_2d, 'Phase', '')

#### Imports

In [None]:
import sys, copy
import numpy as np
from pykern.pkcollections import PKDict

# The rslaser library may not be installed, so a check is required.
try:
    import rslaser
except:
    # Developers should use 'pip install -e .' from the command line.
    # Users can install directly from GitHub --
    !{sys.executable} -m pip install git+https://github.com/radiasoft/rslaser.git
    import rslaser

from rsmath import lct as rslct
from rslaser.pulse import pulse
from rslaser.optics import crystal
from rslaser.optics import drift
from rslaser.optics import lens
import rslaser.utils.srwl_uti_data as srwutil
import uti_plot

import srwlib

import scipy.constants as const
from scipy import interpolate
from scipy.interpolate import RectBivariateSpline
from scipy import special

# 2D plotting
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import cm

# reset the notebook style
mpl.rcParams.update(mpl.rcParamsDefault)
plt.rcParams['pcolor.shading'] ='auto'
%matplotlib inline

### Laser Pulse

In [None]:
# Value used for both transverse waist entries (sigx_waist, sigy_waist) below.
# NOTE(review): presumably in metres -- confirm against rslaser's LaserPulse.
w0 = 3.056e-3
# Length scale handed to the LCT-based drift and telescope elements; equals
# sqrt(pi)*w0, the same value the crystal cells use for their l_scale entry.
l_scale = np.sqrt(np.pi)*w0

# Parameter set consumed by pulse.LaserPulse(params) when the initial pulse is built.
# NOTE(review): units of pulseE, tau_fwhm and tau_0 are not stated here
# (presumably J and s) -- confirm against rslaser's defaults.
params = PKDict(
    photon_e_ev = 1.5498, # Photon energy [eV], calculated from 800nm wavelength
    nslice      = num_laser_slices,
    pulseE      = 0.002,
    tau_fwhm    = 300.0e-12 / np.sqrt(2.),
    tau_0       = 35.0e-15 / np.sqrt(2.),
    sigx_waist  = w0,
    sigy_waist  = w0,
    num_sig_trans=10,
    nx_slice=512,
    phase_flatten_cutoff=0.7,
)

### Crystal 1 ('Amp2')

In [None]:
# Parameter set for the first crystal ('Amp2'); the pop_inversion_* entries
# describe the pump and the mesh used for the population inversion.
crystal_params_1 = PKDict(
    length=0.025,  # [m]
    nslice=num_crystal_slices,
    n0=[1.76] * num_crystal_slices,
    n2=[0.00] * num_crystal_slices,
    l_scale=l_scale,
    pop_inversion_n_cells=64,
    pop_inversion_mesh_extent=0.01 / 2.0,  # [m]
    pop_inversion_crystal_alpha=106.4,  # [1/m], 1.064 1/cm
    pop_inversion_pump_waist=3.82e-3,  # [m]
    pop_inversion_pump_wavelength=532.0e-9,  # [m]
    pop_inversion_pump_energy=0.82,  # [J], pump laser energy onto the crystal
    pop_inversion_pump_type="dual",
    pop_inversion_pump_gaussian_order=8,
    pop_inversion_pump_rep_rate=1.0,
)

# Build the crystal, then compute its n0/n2 slice values and full ABCD matrix.
crystal_1 = crystal.Crystal(crystal_params_1)
n0_slice_array, n2_slice_array1, full_crystal_abcd_mat_1 = crystal_1.calc_n0n2(
    method="fenics", set_n=True
)

# For ABCD/LCT propagation, collapse the crystal to one slice described by
# the elements of the full ABCD matrix and rebuild it.
if prop_type == 'abcd_lct':
    crystal_params_1.update(
        A=full_crystal_abcd_mat_1[0][0],
        B=full_crystal_abcd_mat_1[0][1],
        C=full_crystal_abcd_mat_1[1][0],
        D=full_crystal_abcd_mat_1[1][1],
        nslice=1,
        n0=[1.76],
        n2=[0.00],
    )
    crystal_1 = crystal.Crystal(crystal_params_1)

### Crystal 2 ('Amp3')

In [None]:
# Parameter set for the second crystal ('Amp3'); the pop_inversion_* entries
# describe the pump and the mesh used for the population inversion.
crystal_params_2 = PKDict(
    length=0.025,  # [m]
    nslice=num_crystal_slices,
    n0=[1.76] * num_crystal_slices,
    n2=[0.00] * num_crystal_slices,
    l_scale=l_scale,
    pop_inversion_n_cells=64,
    pop_inversion_mesh_extent=0.015 / 2.0,  # [m]
    pop_inversion_crystal_alpha=106.4,  # [1/m], 1.064 1/cm
    pop_inversion_pump_waist=7.3e-3,  # [m]
    pop_inversion_pump_wavelength=532.0e-9,  # [m]
    pop_inversion_pump_energy=2.75,  # [J], pump laser energy onto the crystal
    pop_inversion_pump_type="dual",
    pop_inversion_pump_gaussian_order=8,
    pop_inversion_pump_rep_rate=1.0,
)

# Build the crystal, then compute its n0/n2 slice values and full ABCD matrix.
crystal_2 = crystal.Crystal(crystal_params_2)
n0_slice_array, n2_slice_array2, full_crystal_abcd_mat_2 = crystal_2.calc_n0n2(
    method="fenics", set_n=True
)

# For ABCD/LCT propagation, collapse the crystal to one slice described by
# the elements of the full ABCD matrix and rebuild it.
if prop_type == 'abcd_lct':
    crystal_params_2.update(
        A=full_crystal_abcd_mat_2[0][0],
        B=full_crystal_abcd_mat_2[0][1],
        C=full_crystal_abcd_mat_2[1][0],
        D=full_crystal_abcd_mat_2[1][1],
        nslice=1,
        n0=[1.76],
        n2=[0.00],
    )
    crystal_2 = crystal.Crystal(crystal_params_2)

### Crystal 3 ('Amp4')

In [None]:
# Parameter set for the third crystal ('Amp4'); the pop_inversion_* entries
# describe the pump and the mesh used for the population inversion.
crystal_params_3 = PKDict(
    length=0.025,  # [m]
    nslice=num_crystal_slices,
    n0=[1.76] * num_crystal_slices,
    n2=[0.00] * num_crystal_slices,
    l_scale=l_scale,
    pop_inversion_n_cells=64,
    pop_inversion_mesh_extent=0.03 / 2.0,  # [m]
    pop_inversion_crystal_alpha=106.4,  # [1/m], 1.064 1/cm
    pop_inversion_pump_waist=15.4e-3,  # [m]
    pop_inversion_pump_wavelength=532.0e-9,  # [m]
    pop_inversion_pump_energy=7.0,  # [J], pump laser energy onto the crystal
    pop_inversion_pump_type="dual",
    pop_inversion_pump_gaussian_order=8,
    pop_inversion_pump_rep_rate=1.0,
)

# Build the crystal, then compute its n0/n2 slice values and full ABCD matrix.
crystal_3 = crystal.Crystal(crystal_params_3)
n0_slice_array, n2_slice_array3, full_crystal_abcd_mat_3 = crystal_3.calc_n0n2(
    method="fenics", set_n=True
)

# For ABCD/LCT propagation, collapse the crystal to one slice described by
# the elements of the full ABCD matrix and rebuild it.
if prop_type == 'abcd_lct':
    crystal_params_3.update(
        A=full_crystal_abcd_mat_3[0][0],
        B=full_crystal_abcd_mat_3[0][1],
        C=full_crystal_abcd_mat_3[1][0],
        D=full_crystal_abcd_mat_3[1][1],
        nslice=1,
        n0=[1.76],
        n2=[0.00],
    )
    crystal_3 = crystal.Crystal(crystal_params_3)

In [None]:
# Longitudinal positions [cm], one per entry of the third crystal's n2 array,
# spanning 0 to 2.5 (the 0.025 m crystal length used above).
z = np.linspace(0, 2.5, len(n2_slice_array3))

if prop_type != 'abcd_lct':
    # Overlay the n2 profile of each crystal along the longitudinal axis.
    plt.figure()
    for n2_profile, line_style, crystal_label in (
        (n2_slice_array1, '-.k', 'Crystal 1'),
        (n2_slice_array2, '--k', 'Crystal 2'),
        (n2_slice_array3, ':k', 'Crystal 3'),
    ):
        plt.plot(z, n2_profile, line_style, label=crystal_label)
    plt.xlabel('longitudinal distance [cm]')
    plt.ylabel('n2')
    plt.legend()
    plt.title('8th Order Gaussian Pump Pulse')
    plt.show()
else:
    # In ABCD/LCT mode, report the full ABCD matrix of each crystal instead.
    for matrix_label, abcd_matrix in (
        ('full_crystal_abcd_mat_1 ', full_crystal_abcd_mat_1),
        ('full_crystal_abcd_mat_2 ', full_crystal_abcd_mat_2),
        ('full_crystal_abcd_mat_3 ', full_crystal_abcd_mat_3),
    ):
        print(matrix_label, abcd_matrix)

## Drift Matrix

In [None]:
# Drift length [m]
# This single drift is applied both before and after the crystal on every pass.
L_Drift = 2.75/2.0 # [m]

# instantiate an srw drift
# (this is the drift element used by the multi-pass loops below)
e_drift_srw = lens.Drift_srw(L_Drift)

# instantiate an lct drift
# (takes the additional l_scale argument; not used by the pass loops in this notebook section)
e_drift_lct = lens.Drift_lct(L_Drift, l_scale)

***
## Initial Intensity and Phase

In [None]:
# Build the input pulse from `params`, call zero_phase() on it, and plot its
# intensity and phase as the reference for the amplifier stages that follow.
thisPulse_initial = pulse.LaserPulse(params)
thisPulse_initial.zero_phase()
plot_all(thisPulse_initial)

In [None]:
# Rayleigh length computed as pi * sigx_waist**2 / lambda0, using the waist
# and wavelength stored on the initial pulse.
rayleigh_length = np.pi * (thisPulse_initial.sigx_waist)**2.0 / (thisPulse_initial._lambda0)
print('Rayleigh Length:', round(rayleigh_length,3), ' m')
# sig_s is read directly from the pulse object; both values are rounded to 3 decimals.
print('RMS bunch length:', round(thisPulse_initial.sig_s,3), ' m')

***
## 4 Passes Through First Crystal ('Amp2')

In [None]:
# Work on an independent copy so thisPulse_initial is left unchanged.
thisPulse_firstCrystal = copy.deepcopy(thisPulse_initial)
num_passes = 4

for prop_index in range(num_passes):
    print('Pass ', prop_index + 1)

    # One pass: drift -> crystal 1 -> drift.
    thisPulse_firstCrystal = e_drift_srw.propagate(thisPulse_firstCrystal)
    thisPulse_firstCrystal = crystal_1.propagate(
        thisPulse_firstCrystal, prop_type, gain, radial_n2
    )
    thisPulse_firstCrystal = e_drift_srw.propagate(thisPulse_firstCrystal)

    # Apply the mirror between passes, i.e. after every pass except the last.
    if prop_index + 1 < num_passes:
        thisPulse_firstCrystal.ideal_mirror_180()

plot_all(thisPulse_firstCrystal)

***
## Telescope 1

In [None]:
# Telescope 1 settings, passed positionally to lens.Telescope_lct.
# NOTE(review): presumably two focal lengths and three drift distances in
# metres -- confirm against rslaser's Telescope_lct signature.
f1, f2 = -0.1, 0.2
d1, d2, d3 = 1.9, 0.1, 0.54

telescope_1 = lens.Telescope_lct(f1, f2, d1, d2, d3, l_scale)

# Propagate a copy of the first-crystal output so that result stays available.
thisPulse_telescope1 = telescope_1.propagate(copy.deepcopy(thisPulse_firstCrystal))
plot_all(thisPulse_telescope1)

***
## 3 Passes Through Second Crystal ('Amp3')

In [None]:
# Work on an independent copy so thisPulse_telescope1 is left unchanged.
thisPulse_secondCrystal = copy.deepcopy(thisPulse_telescope1)
num_passes = 3

for prop_index in range(num_passes):
    print('Pass ', prop_index + 1)

    # One pass: drift -> crystal 2 -> drift.
    thisPulse_secondCrystal = e_drift_srw.propagate(thisPulse_secondCrystal)
    thisPulse_secondCrystal = crystal_2.propagate(
        thisPulse_secondCrystal, prop_type, gain, radial_n2
    )
    thisPulse_secondCrystal = e_drift_srw.propagate(thisPulse_secondCrystal)

    # Apply the mirror between passes, i.e. after every pass except the last.
    if prop_index + 1 < num_passes:
        thisPulse_secondCrystal.ideal_mirror_180()

plot_all(thisPulse_secondCrystal)

***
## Telescope 2

In [None]:
# Telescope 2 settings, passed positionally to lens.Telescope_lct.
# NOTE(review): presumably two focal lengths and three drift distances in
# metres -- confirm against rslaser's Telescope_lct signature.
f1, f2 = -0.3, 0.6
d1, d2, d3 = 1.0, 0.3, 0.66

telescope_2 = lens.Telescope_lct(f1, f2, d1, d2, d3, l_scale)

# Propagate a copy of the second-crystal output so that result stays available.
thisPulse_telescope2 = telescope_2.propagate(copy.deepcopy(thisPulse_secondCrystal))
plot_all(thisPulse_telescope2)

***
## 3 Passes Through Third Crystal ('Amp4')

In [None]:
# Work on an independent copy so thisPulse_telescope2 is left unchanged.
thisPulse_thirdCrystal = copy.deepcopy(thisPulse_telescope2)
num_passes = 3

for prop_index in range(num_passes):
    print('Pass ', prop_index + 1)

    # One pass: drift -> crystal 3 -> drift.
    thisPulse_thirdCrystal = e_drift_srw.propagate(thisPulse_thirdCrystal)
    thisPulse_thirdCrystal = crystal_3.propagate(
        thisPulse_thirdCrystal, prop_type, gain, radial_n2
    )
    thisPulse_thirdCrystal = e_drift_srw.propagate(thisPulse_thirdCrystal)

    # Apply the mirror between passes, i.e. after every pass except the last.
    if prop_index + 1 < num_passes:
        thisPulse_thirdCrystal.ideal_mirror_180()

plot_all(thisPulse_thirdCrystal)