In [None]:
# Autoload when refreshing notebook
%load_ext autoreload
%autoreload 2

import numpy as np
import h5py
from scipy.io import loadmat
import pandas as pd
import re
import matplotlib.pyplot as plt
from types import SimpleNamespace
import scipy
import warnings
from scipy.ndimage import median_filter, gaussian_filter
from scipy.optimize import curve_fit

# import Python functions 
import sys
sys.path.append('../../')

from Python_Functions.functions import cropProfmonImg, matstruct_to_dict, extractDAQBSAScalars, segment_centroids_and_com, plot2DbunchseparationVsCollimatorAndBLEN

In [None]:
# Data location: experiment name and DAQ run number
experiment = 'E300'
runname = '12431'

# XTCAV calibration
# NOTE(review): krf is presumably the X-band RF wavenumber 2*pi*f/c [1/m]
# (~239.3 1/m at 11.424 GHz) - confirm.
krf = 239.26
# Streak calibration from the XTCAV GUI [um/deg]; elog entry (URL truncated in source):
# http://physics-elog.slac.stanford.edu/facetelog/show.jsp?dir=/2025/11/13.03&pos=2025-
cal = 1167
# Streak converted from the GUI calibration [um/um]
streakFromGUI = cal * krf * 180 / np.pi * 1e-6

# Main beam energy [eV]
mainbeamE_eV = 10e9
# dnom value for CHER
dnom = 59.8e-3

In [None]:
# Load the DAQ summary .mat file of the selected run
dataloc = f'../../data/raw/{experiment}/{experiment}_{runname}/{experiment}_{runname}.mat'
mat = loadmat(dataloc, struct_as_record=False, squeeze_me=True)
data_struct = mat['data_struct']

# Scan steps; a run without a scan list is treated as a single step
stepsAll = data_struct.params.stepsAll
if stepsAll is None or not len(np.atleast_1d(stepsAll)):
    stepsAll = [1]

# XTCAV time calibration factor (time per pixel along the streak axis).
# RESOLUTION is presumably in um/pix (hence * 1e-6 -> m/pix); dividing by the
# dimensionless streak and by c gives seconds per pixel.
# NOTE(review): 3e8 approximates c = 299792458 m/s (~0.07 % high).
xtcalibrationfactor = (data_struct.metadata.DTOTR2.RESOLUTION * 1e-6) / streakFromGUI / 3e8

# Crop half-widths [pix]; later cells reshape flattened images to (2*yrange, 2*xrange)
xrange = 100
yrange = xrange

# Image-cleaning parameters used by the processing loop
hotPixThreshold = 1e3  # median-filtered pixels above this value count as "hot"
sigma = 1              # Gaussian smoothing sigma [pix]
threshold = 5          # smoothed values below this are set to zero

In [None]:
# Extract current profiles and 2D LPS images from the per-step DTOTR2 HDF5 files.

# Background cast once (it is the same for every step).
dtotr2_background = data_struct.backgrounds.DTOTR2.astype(np.float64)[:, :, np.newaxis]


def _dtotr2_step_path(step):
    """Local path of the DTOTR2 HDF5 file for position `step` in stepsAll."""
    locs = data_struct.images.DTOTR2.loc
    # With a single step, squeeze_me=True makes `loc` a plain string rather than an array
    raw_path = locs if len(stepsAll) == 1 else locs[step]
    match = re.search(rf'({experiment}_\d+/images/DTOTR2/DTOTR2_data_step\d+\.h5)', raw_path)
    if not match:
        raise ValueError(f"Path format invalid or not matched: {raw_path}")
    return '../../data/raw/' + experiment + '/' + match.group(0)


def _process_xtcav_image(image):
    """Crop, median-filter, Gaussian-smooth and threshold one background-subtracted frame."""
    image_cropped, _ = cropProfmonImg(image, xrange, yrange, plot_flag=False)
    img_filtered = median_filter(image_cropped, size=3)
    hotPixels = img_filtered > hotPixThreshold
    # NOTE(review): scipy.ndimage.gaussian_filter converts its input with np.asarray,
    # which drops the mask, so hot pixels are NOT actually excluded here. Decide
    # whether they should be zeroed/in-painted before changing this.
    img_filtered = np.ma.masked_array(img_filtered, hotPixels)
    processed = gaussian_filter(img_filtered, sigma=sigma, radius=6*sigma + 1)
    processed[processed < threshold] = 0.0
    return processed


xtcavImages_list = []
xtcavImages_list_raw = []
for step in range(len(stepsAll)):
    with h5py.File(_dtotr2_step_path(step), 'r') as f:
        data_raw = f['entry']['data']['data'][:].astype(np.float64)
    # NOTE(review): earlier comments called the file layout (N, H, W) and the result
    # (H, W, N). transpose(2, 1, 0) of (N, H, W) gives (W, H, N), so presumably the
    # file is stored (N, W, H) - confirm.
    xtcavImages_step = np.transpose(data_raw, (2, 1, 0)) - dtotr2_background

    for shot in range(xtcavImages_step.shape[2]):
        image = xtcavImages_step[:, :, shot]
        xtcavImages_list_raw.append(image[:, :, np.newaxis])
        xtcavImages_list.append(_process_xtcav_image(image)[:, :, np.newaxis])

xtcavImages = np.concatenate(xtcavImages_list, axis=2)
xtcavImages_raw = np.concatenate(xtcavImages_list_raw, axis=2)

# Keep only shots with a common index (the -1 presumably converts MATLAB 1-based indices)
DTOTR2commonind = data_struct.images.DTOTR2.common_index - 1
xtcavImages = xtcavImages[:, :, DTOTR2commonind]
xtcavImages_raw = xtcavImages_raw[:, :, DTOTR2commonind]

# Horizontal projection (current profile) of every processed image: shape (cols, shots)
horz_proj = xtcavImages.sum(axis=0)
# One row-major flattened image per shot: shape (shots, rows*cols)
LPSImage = xtcavImages.transpose(2, 0, 1).reshape(xtcavImages.shape[2], -1)

# Make a copy of xtcavImages for comparison of before and after centroid correction
xtcavImages_centroid_uncorrected = xtcavImages.copy()

In [None]:
bsaScalarData, bsaVars = extractDAQBSAScalars(data_struct)

ampl_idx = next(i for i, var in enumerate(bsaVars) if 'TCAV_LI20_2400_A' in var)
xtcavAmpl = bsaScalarData[ampl_idx, :]

phase_idx = next(i for i, var in enumerate(bsaVars) if 'TCAV_LI20_2400_P' in var)
# .copy(): a row of bsaScalarData is a view, so zeroing the off-shot phases below
# would otherwise also overwrite the phase PV inside bsaScalarData (an MLP input later).
xtcavPhase = bsaScalarData[phase_idx, :].copy()

# Amplitude < 0.1 is treated as TCAV off
xtcavOffShots = xtcavAmpl < 0.1
xtcavPhase[xtcavOffShots] = 0  # Set this for ease of plotting

# Bunch charge from TORO_LI20_2452 (TMIT presumably in electrons; * e -> C)
pvidx = [i for i, pv in enumerate(bsaVars) if re.search(r'TORO_LI20_2452_TMIT', pv)]
charge = bsaScalarData[pvidx, :] * 1.6e-19  # in C

# Shots grouped by TCAV phase: -90 deg, +90 deg (+/-1 deg window) and off
minus_90_idx = np.where((xtcavPhase >= -91) & (xtcavPhase <= -89))[0]
plus_90_idx = np.where((xtcavPhase >= 89) & (xtcavPhase <= 91))[0]
off_idx = np.where(xtcavPhase == 0)[0]
all_idx = np.append(minus_90_idx, plus_90_idx)

# Time axis [s] of a projection, centred on zero; it is the same for every shot
tvar = np.arange(1, horz_proj.shape[0] + 1) * xtcalibrationfactor
tvar = tvar - np.median(tvar)

# Current profile [kA] of every +/-90 deg shot. The projection is scaled so that its
# integral over tvar equals the shot charge (-> A); 1e-3 converts to kA.
currentProfile_all = []
for idx in all_idx:
    streakedProfile = horz_proj[:, idx]
    prefactor = charge[0, idx] / np.trapezoid(streakedProfile, tvar)
    currentProfile_all.append(1e-3 * streakedProfile * prefactor)

currentProfile_all = np.array(currentProfile_all)

### Energy Correction (Optional)

In [None]:
print(bsaVars)
# BPMS_LI20_2445_X is supposed to measure the beam energy right before the TCAV
# BPMS_LI14_801_X is supposed to measure the beam energy at LI14
# 'BLEN_LI14_888_BRAW' is the length of the bunch at LI14
# 'BLEN_LI11_359_BRAW' is the length of the bunch at LI11
# 'BPMS_LI11_333_X' is supposed to measure the beam energy at LI11
energy_pv = 'BPMS_LI11_333_X'  # used for both the PV lookup and the x-axis label
energy_idx = next(i for i, var in enumerate(bsaVars) if energy_pv in var)
beamEnergyM = bsaScalarData[energy_idx, minus_90_idx]
beamEnergyP = bsaScalarData[energy_idx, plus_90_idx]
beamEnergyO = bsaScalarData[energy_idx, off_idx]


def _com_x(images, shot_indices):
    """Horizontal (column) centre of mass [pix] of images[:, :, i] for each i in shot_indices."""
    cols = np.arange(images.shape[1])[np.newaxis, :]
    # Per-image loop (not one fancy-indexed stack) to avoid copying many raw frames at once
    return np.array([np.sum(images[:, :, i] * cols) / np.sum(images[:, :, i])
                     for i in shot_indices])


def _plot_com_vs_energy(images, image_kind):
    """Scatter the image COM of each TCAV group against the energy PV reading."""
    groups = ((minus_90_idx, beamEnergyM, 'blue', '-90 deg'),
              (plus_90_idx, beamEnergyP, 'red', '+90 deg'),
              (off_idx, beamEnergyO, 'green', '0 deg'))
    fig, ax = plt.subplots(figsize=(6, 4))
    for shots, energy, colour, label in groups:
        ax.scatter(energy, _com_x(images, shots), c=colour, s=5, label=label)
    ax.set_xlabel(energy_pv)
    ax.set_ylabel('LPS Image Center of Mass [pix]')
    ax.set_title(f'LPS Image Center of Mass vs Beam Energy PV, {image_kind}')
    ax.legend()
    ax.grid()
    plt.show()


# Create LPS image center of mass vs beam energy plots (processed crop, then raw frame)
_plot_com_vs_energy(xtcavImages, 'Cropped Image')
_plot_com_vs_energy(xtcavImages_raw, 'RAW Image')

###  Centroid Correction (Optional)

In [None]:
import numpy as np

def construct_centroid_function(images, off_idx, smoothing_window_size=5, max_degree=1):
    """
    Build a per-row integer horizontal shift that moves each row's centre of mass
    (COM) to the image centre.

    For every row, the horizontal COM is computed in each selected image. Rows whose
    COM is stable across the selection are used to build a smooth COM-vs-row profile:

    1.  **Interpolation**: rows whose COM standard deviation exceeds 15 % of the image
        width (or that are empty in every selected image) are unreliable; their COM
        is linearly interpolated from the stable rows. ``np.interp`` holds the end
        values constant beyond the outermost stable rows.
    2.  **Smoothing**: a centred moving average of width ``smoothing_window_size``.
        The profile is edge-padded first, so the first/last rows are not pulled
        toward column 0 by implicit zero padding.
    3.  **Local extrapolation**: rows above the first stable row and below the last
        stable row are replaced by separate polynomial fits (degree
        ``<= max_degree``). Each fit uses the ``max_degree + 5`` stable rows nearest
        that edge, so non-uniform trends at the two ends are handled independently.

    Args:
        images (np.ndarray): 3D array indexed ``images[row, col, shot]``.
        off_idx (list or np.ndarray): Non-empty shot indices (third axis) to use.
        smoothing_window_size (int, optional): Positive odd moving-average width.
                                               Defaults to 5.
        max_degree (int, optional): Maximum polynomial degree for the edge
                                    extrapolation. Defaults to 1 (linear).

    Returns:
        np.ndarray: 1D int array with one entry per row. Each entry is
                    ``round(num_cols / 2 - row_COM)``, i.e. the ``np.roll`` shift
                    that centres that row. All zeros (with a warning) if fewer
                    than 2 stable rows are found.

    Raises:
        ValueError: If ``off_idx`` is empty or not a list/array, or if
                    ``smoothing_window_size`` is not a positive odd integer.
    """
    if not isinstance(off_idx, (list, np.ndarray)) or np.size(off_idx) == 0:
        raise ValueError("off_idx must be a non-empty list or array of indices.")
    if smoothing_window_size < 1 or smoothing_window_size % 2 != 1:
        raise ValueError("smoothing_window_size must be a positive odd number.")

    # 1. Select images -> (num_images, num_rows, num_cols)
    selected_images = np.array([images[:, :, i] for i in off_idx])
    num_images, num_rows, num_cols = selected_images.shape
    image_center = num_cols / 2.0

    # 2. Per-image, per-row horizontal COM
    col_indices = np.arange(num_cols)
    epsilon = 1e-9  # avoids 0/0; rows with zero total are set to NaN just below
    row_sums = selected_images.sum(axis=2)
    all_row_coms = np.sum(selected_images * col_indices, axis=2) / (row_sums + epsilon)
    all_row_coms[row_sums == 0] = np.nan

    # 3. Identify stable ("good") rows
    with warnings.catch_warnings():
        # Rows empty in every image are all-NaN: nanmean/nanstd return NaN for them
        # (and warn). NaN fails the threshold test, so those rows are excluded.
        warnings.simplefilter("ignore", category=RuntimeWarning)
        mean_coms = np.nanmean(all_row_coms, axis=0)
        std_dev_coms = np.nanstd(all_row_coms, axis=0)
    variance_threshold = 0.15 * num_cols  # standard-deviation threshold [pix]
    good_rows_indices = np.where(std_dev_coms <= variance_threshold)[0]
    all_row_indices = np.arange(num_rows)

    if len(good_rows_indices) < 2:
        warnings.warn("Fewer than 2 stable rows found. Cannot perform reliable "
                      "analysis; returning zero shifts.")
        return np.zeros(num_rows, dtype=int)

    # 4. Interpolate and smooth
    good_com_values = mean_coms[good_rows_indices]
    interpolated_coms = np.interp(all_row_indices, good_rows_indices, good_com_values)

    if smoothing_window_size > 1:
        # Edge padding + 'valid' convolution keeps the output length == num_rows.
        # mode='same' would pad with zeros and bias the edge rows toward column 0.
        half_window = smoothing_window_size // 2
        padded = np.pad(interpolated_coms, half_window, mode='edge')
        kernel = np.ones(smoothing_window_size) / smoothing_window_size
        smoothed_coms = np.convolve(padded, kernel, mode='valid')
    else:
        smoothed_coms = interpolated_coms

    # 5. Local extrapolation beyond the outermost stable rows
    final_coms = np.copy(smoothed_coms)
    n_fit = max_degree + 5
    min_good_idx, max_good_idx = good_rows_indices[0], good_rows_indices[-1]

    def _extrapolate(target_rows, fit_rows, fit_values):
        # A degree-d fit needs d + 1 points; fall back to a lower degree if short.
        if target_rows.size == 0 or len(fit_rows) < 2:
            return
        degree = min(max_degree, len(fit_rows) - 1)
        poly_func = np.poly1d(np.polyfit(fit_rows, fit_values, degree))
        final_coms[target_rows] = poly_func(target_rows)

    # Top edge: fit the first stable rows; bottom edge: fit the last stable rows
    _extrapolate(np.arange(0, min_good_idx),
                 good_rows_indices[:n_fit], good_com_values[:n_fit])
    _extrapolate(np.arange(max_good_idx + 1, num_rows),
                 good_rows_indices[-n_fit:], good_com_values[-n_fit:])

    # 6. Shift that moves each row's COM to the image centre
    horizontal_correction = image_center - final_coms
    return np.round(horizontal_correction).astype(int)

In [None]:
# Per-row horizontal shifts estimated from the TCAV-off shots. They are computed from
# the uncorrected copy, so re-running this cell never stacks a second correction.
centroid_corrections = construct_centroid_function(xtcavImages_centroid_uncorrected, off_idx)

# Apply the shifts to the already processed, common-index images instead of re-reading
# and re-filtering every HDF5 file (the per-image processing is deterministic).
# np.roll(row, s)[j] == row[(j - s) % n_cols]; build that gather index once for all rows.
n_rows, n_cols, n_shots = xtcavImages_centroid_uncorrected.shape
assert centroid_corrections.shape == (n_rows,), "expected one shift per image row"
row_idx = np.arange(n_rows)[:, np.newaxis]
src_cols = (np.arange(n_cols)[np.newaxis, :] - centroid_corrections[:, np.newaxis]) % n_cols
xtcavImages = xtcavImages_centroid_uncorrected[row_idx, src_cols, :]

# Current-profile projections (cols, shots) and flattened LPS images (shots, rows*cols).
# LPS is overwritten in this optional step. If not needed, this block can be skipped.
horz_proj = xtcavImages.sum(axis=0)
LPSImage = xtcavImages.transpose(2, 0, 1).reshape(n_shots, -1)

In [None]:
bsaScalarData, bsaVars = extractDAQBSAScalars(data_struct)

ampl_idx = next(i for i, var in enumerate(bsaVars) if 'TCAV_LI20_2400_A' in var)
xtcavAmpl = bsaScalarData[ampl_idx, :]

phase_idx = next(i for i, var in enumerate(bsaVars) if 'TCAV_LI20_2400_P' in var)
# .copy(): a row of bsaScalarData is a view, so zeroing the off-shot phases below
# would otherwise also overwrite the phase PV inside bsaScalarData (an MLP input later).
xtcavPhase = bsaScalarData[phase_idx, :].copy()

# Amplitude < 0.1 is treated as TCAV off
xtcavOffShots = xtcavAmpl < 0.1
xtcavPhase[xtcavOffShots] = 0  # Set this for ease of plotting

# Bunch charge from TORO_LI20_2452 (TMIT presumably in electrons; * e -> C)
pvidx = [i for i, pv in enumerate(bsaVars) if re.search(r'TORO_LI20_2452_TMIT', pv)]
charge = bsaScalarData[pvidx, :] * 1.6e-19  # in C

# Shots grouped by TCAV phase: -90 deg, +90 deg (+/-1 deg window) and off
minus_90_idx = np.where((xtcavPhase >= -91) & (xtcavPhase <= -89))[0]
plus_90_idx = np.where((xtcavPhase >= 89) & (xtcavPhase <= 91))[0]
off_idx = np.where(xtcavPhase == 0)[0]
all_idx = np.append(minus_90_idx, plus_90_idx)

# Time axis [s] of a projection, centred on zero; it is the same for every shot
tvar = np.arange(1, horz_proj.shape[0] + 1) * xtcalibrationfactor
tvar = tvar - np.median(tvar)

# Current profile [kA] of every +/-90 deg shot, from the centroid-corrected projections.
# The projection is scaled so its integral over tvar equals the shot charge (-> A).
currentProfile_all = []
for idx in all_idx:
    streakedProfile = horz_proj[:, idx]
    prefactor = charge[0, idx] / np.trapezoid(streakedProfile, tvar)
    currentProfile_all.append(1e-3 * streakedProfile * prefactor)

currentProfile_all = np.array(currentProfile_all)

### Check Images

In [None]:
# Pick one sample shot per TCAV setting (-90 deg, off = 0 deg, +90 deg).
# sample_pos selects which matching shot is shown (0 = first, 1 = second, ...).
sample_pos = 1
near_minus_90_idx = np.where((xtcavPhase >= -90.55) & (xtcavPhase <= -89.55))[0][sample_pos]
near_plus_90_idx = np.where((xtcavPhase >= 89.55) & (xtcavPhase <= 90.55))[0][sample_pos]
zero_idx = np.where(xtcavPhase == 0)[0][sample_pos]

sample_image_indices = [near_minus_90_idx, zero_idx, near_plus_90_idx]
plot_titles = ['Tcav phase -90 deg', '0 deg', '+90 deg']

# Raw (uncropped, background-subtracted) frames
fig, axs = plt.subplots(1, 4, figsize=(12, 6), gridspec_kw={'width_ratios': [1, 1, 1, 0.1]})
fig.suptitle(f'TCAV images RAW DAQ {experiment} - {runname}', fontsize=14)
for ax, shot, title in zip(axs[:3], sample_image_indices, plot_titles):
    ax.imshow(xtcavImages_raw[:, :, shot], cmap='jet', aspect='auto')
    ax.set_title(title)

# Colorbar in the narrow fourth axis. NOTE: each panel is auto-scaled independently,
# so the colorbar reflects the +90 deg panel only.
cbar = fig.colorbar(axs[2].images[0], cax=axs[3], orientation='vertical')
cbar.set_label('Intensity [a.u.]')
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()


def _plot_samples_with_projection(images, title, centroid_shifts=None):
    """Plot the sample shots (top row) and their horizontal projections (bottom row).

    If centroid_shifts is given, the 0-deg panel also shows the per-row centroid
    implied by the shifts (dots) and the image's overall horizontal COM (dashed line).
    """
    fig, axs = plt.subplots(2, 3, figsize=(12, 6))
    fig.suptitle(title, fontsize=14)
    for col, (shot, panel_title) in enumerate(zip(sample_image_indices, plot_titles)):
        sample_image = images[:, :, shot]
        # Local name on purpose: do not clobber the global horz_proj (cols x shots)
        sample_proj = np.sum(sample_image, axis=0)

        axs[0, col].imshow(sample_image, cmap='jet', aspect='auto')
        axs[0, col].set_title(panel_title)
        axs[1, col].plot(sample_proj)
        axs[1, col].set_title("Horizontal Projection")

        if centroid_shifts is not None and col == 1:
            # construct_centroid_function returns round(num_cols/2 - row_COM), so the
            # row centroid it used is num_cols/2 - shift.
            rows = np.arange(sample_image.shape[0])
            axs[0, col].plot(sample_image.shape[1] / 2.0 - centroid_shifts, rows,
                             'wo', markersize=1)
            # Draw a vertical line at the center of mass x
            center_of_mass_x = (np.sum(sample_proj * np.arange(sample_proj.shape[0]))
                                / np.sum(sample_proj))
            axs[0, col].axvline(center_of_mass_x, color='w', linestyle='--')

    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()


# centroid_corrections only exists if the optional centroid-correction section was run
centroid_shifts = globals().get('centroid_corrections')
_plot_samples_with_projection(
    xtcavImages_centroid_uncorrected,
    f'TCAV images before centroid correction DAQ {experiment} - {runname}',
    centroid_shifts)
if centroid_shifts is not None:
    _plot_samples_with_projection(
        xtcavImages,
        f'TCAV images after centroid correction DAQ {experiment} - {runname}')

### Filter Good Shots

In [None]:
# Filter out "bad" shots with Bi-Gaussian fit
def bi_gaussian(x, A1, mu1, sigma1, A2, mu2, sigma2):
    """Sum of two Gaussians (amplitudes A1/A2, centres mu1/mu2, widths sigma1/sigma2)."""
    return (A1 * np.exp(-(x - mu1)**2 / (2 * sigma1**2)) +
            A2 * np.exp(-(x - mu2)**2 / (2 * sigma2**2)))


# Fit results, one entry per row of currentProfile_all (i.e. per shot in all_idx)
amp1, amp2, mu1, mu2, R_squared = [], [], [], [], []
fit_phases = xtcavPhase[all_idx]

for ij, (y, phase) in enumerate(zip(currentProfile_all, fit_phases)):
    x = np.arange(len(y))

    # Initial guess: [A1, mu1, sigma1, A2, mu2, sigma2]
    # (all_idx holds only +/-90 deg shots, so phase is never 0 here)
    if phase < 0:
        initial_guess = [np.max(y), 100, 4, np.max(y)*0.1, 60 + ij*0.15, 4]
    else:
        initial_guess = [np.max(y), 100, 4, np.max(y)*0.1, 60, 4]

    try:
        popt, pcov = curve_fit(bi_gaussian, x, y, p0=initial_guess, maxfev=5000)
    except (RuntimeError, ValueError):
        # RuntimeError: no convergence; ValueError: profile contains NaN/inf.
        # Append NaN to *every* result list so they all stay aligned with all_idx.
        for results in (amp1, amp2, mu1, mu2, R_squared):
            results.append(np.nan)
        continue

    # Extract parameters
    A1, mu1_val, sig1, A2, mu2_val, sig2 = popt
    amp1.append(A1)
    amp2.append(A2)
    mu1.append(mu1_val)
    mu2.append(mu2_val)

    # Evaluate fit (R^2 is undefined for a constant profile)
    y_fit = bi_gaussian(x, *popt)
    SST = np.sum((y - np.mean(y))**2)
    SSR = np.sum((y - y_fit)**2)
    R_squared.append(1 - SSR / SST if SST > 0 else np.nan)

# Convert results to arrays (NaN entries compare False below, so they are rejected)
amp1 = np.array(amp1)
amp2 = np.array(amp2)
mu1 = np.array(mu1)
mu2 = np.array(mu2)
R_squared = np.array(R_squared)
# set requirements for "good" shots. For xtcavPhase>0, we want larger (A1) peak at larger x (mu1).
# For xtcavPhase<0, we want larger (A2) peak at smaller x (mu2).
# goodShots are positions in all_idx (use all_idx[goodShots] for shot indices).
goodShots = np.where((R_squared > 0.97) & (amp1 < 50))[0]
#goodShots_twobunch_tcav = np.where((R_squared > 0.97) & (amp1 < 50) & ((mu1 > mu2) & (amp1 < amp2)))[0]

In [None]:
# Plot some good shots xtcavOffShots
idx = 5
fig, (ax1) = plt.subplots(1,1,figsize=(9, 6))
im1 = ax1.imshow(xtcavImages[:,:,minus_90_idx[idx]], cmap = "jet",aspect='auto')
# ax1.suptitle(f"Current Profile Index: {idx}")
cbar1 = plt.colorbar(im1, ax=ax1)
cbar1.set_label("Charge(a.u.)")

### MLP surrogate: predict LPS images from BSA scalars and scan step

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split

# Shot indices (on the common-index axis) of the good +/-90 deg shots.
# goodShots indexes positions in all_idx, so every per-shot array has to be indexed
# through all_idx[goodShots]. The order is kept (not sorted) so that index[i],
# images[i] and predictor[i] describe the same shot; ntrain/ntest from the split
# below are therefore real shot indices.
index = all_idx[goodShots]
images = LPSImage[index, :]
steps = data_struct.scalars.steps[DTOTR2commonind]
predictor = np.vstack((bsaScalarData[:, index], steps[index])).T

# 80/20 train-test split on the *unscaled* data, so the scalers below are fitted on
# training shots only (fitting on all shots leaks test-set min/max into training).
x_train_raw, x_test_raw, Iz_train_raw, Iz_test_raw, ntrain, ntest = train_test_split(
    predictor, images, index, test_size=0.2, random_state=42)

x_scaler = MinMaxScaler()
iz_scaler = MinMaxScaler()
x_train_full = x_scaler.fit_transform(x_train_raw)
x_test_scaled = x_scaler.transform(x_test_raw)
Iz_train_full = iz_scaler.fit_transform(Iz_train_raw)
Iz_test_scaled = iz_scaler.transform(Iz_test_raw)

# 20% validation split (used for early stopping)
x_train_scaled, X_val, Iz_train_scaled, Y_val = train_test_split(
    x_train_full, Iz_train_full, test_size=0.2, random_state=42)

# Compress pixels to PCA coefficients (the MLP targets).
# n_components may not exceed min(n_samples, n_features).
n_pca = min(100, *Iz_train_scaled.shape)
pca = PCA(n_components=n_pca)
compressed_targets = pca.fit_transform(Iz_train_scaled)
print(Iz_train_scaled.shape, compressed_targets.shape)
Y_val = pca.transform(Y_val)

# Convert to PyTorch tensors
X_train = torch.tensor(x_train_scaled, dtype=torch.float32)
X_val = torch.tensor(X_val, dtype=torch.float32)
X_test = torch.tensor(x_test_scaled, dtype=torch.float32)
Y_train = torch.tensor(compressed_targets, dtype=torch.float32)
Y_val = torch.tensor(Y_val, dtype=torch.float32)
Y_test = torch.tensor(Iz_test_scaled, dtype=torch.float32)

train_ds = TensorDataset(X_train, Y_train)
train_dl = DataLoader(train_ds, batch_size=24, shuffle=True)


In [None]:
import time

# Define MLP structure
class MLP(nn.Module):
    """Fully connected regressor: scaled predictors -> PCA coefficients of the LPS image."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(in_dim, 1000),
            nn.ReLU(),
            nn.Linear(1000, 500),
            nn.ReLU(),
            nn.Linear(500, 500),
            nn.ReLU(),
            nn.Linear(500, out_dim)
        )

    def forward(self, x):
        return self.model(x)


# Define custom weighted MSE loss function
def custom_loss(y_pred, y_true):
    """Mean squared error with 1.7x weight where the target is < 0.2 or > 0.8.

    NOTE(review): the targets here are PCA coefficients, which are not confined to
    [0, 1]; the 0.2 / 0.8 thresholds look chosen for MinMax-scaled pixels - confirm.
    """
    mse = (y_true - y_pred)**2
    weights = 1 + 0.7*((y_true < 0.2) | (y_true > 0.8)).float()
    return torch.mean(weights*mse)


def r2_score(true, pred):
    """Coefficient of determination of `pred` vs `true` (0.0 when `true` is constant)."""
    RSS = np.sum((true - pred)**2)
    TSS = np.sum((true - np.mean(true))**2)
    return 1 - RSS / TSS if TSS != 0 else 0.0


def _snapshot(module):
    """Detached copy of a module's state; state_dict() alone returns live tensor references."""
    return {k: v.detach().clone() for k, v in module.state_dict().items()}


torch.manual_seed(42)  # reproducible weight init and DataLoader shuffling
model = MLP(X_train.shape[1], Y_train.shape[1])
optimizer = optim.Adam(model.parameters(), lr=5e-4, betas=(0.9, 0.999))

# Training loop
n_epochs = 200
patience = 25
best_val_loss = float('inf')
best_model_state = _snapshot(model)  # fallback if validation never improves (e.g. NaN)
early_stop_counter = 0

t0 = time.time()

# Fit the nn model on the training set
train_losses = []
val_losses = []

for epoch in range(n_epochs):
    model.train()
    train_loss = 0
    for xb, yb in train_dl:
        pred = model(xb)
        loss = custom_loss(pred, yb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    avg_train_loss = train_loss / len(train_dl)
    train_losses.append(avg_train_loss)

    # Validation loss
    model.eval()
    with torch.no_grad():
        val_pred = model(X_val)
        val_loss = custom_loss(val_pred, Y_val).item()
        val_losses.append(val_loss)

    # Early stopping logic
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # Copy: the live state_dict would keep changing and end up equal to the final weights
        best_model_state = _snapshot(model)
        early_stop_counter = 0
    else:
        early_stop_counter += 1
        if early_stop_counter >= patience:
            break

model.load_state_dict(best_model_state)

# Evaluate model
model.eval()
with torch.no_grad():
    pred_train_scaled = model(X_train).numpy()
    pred_test_scaled = model(X_test).numpy()

# Inverse transform predictions (PCA coefficients -> scaled pixels -> original units)
pred_train_full = iz_scaler.inverse_transform(pca.inverse_transform(pred_train_scaled))
pred_test_full = iz_scaler.inverse_transform(pca.inverse_transform(pred_test_scaled))
Iz_train_true = iz_scaler.inverse_transform(Iz_train_scaled)
Iz_test_true = iz_scaler.inverse_transform(Iz_test_scaled)
elapsed = time.time() - t0
print("Elapsed time [mins] = {:.1f} ".format(elapsed/60))

# Compute R²
print("Train R²: {:.2f} %".format(r2_score(Iz_train_true.ravel(), pred_train_full.ravel()) * 100))
print("Test R²: {:.2f} %".format(r2_score(Iz_test_true.ravel(), pred_test_full.ravel()) * 100))


In [None]:
# Alias of the MLP test-set predictions: one flattened image per test shot (row),
# already mapped back through PCA and the MinMax scaler. Same array object, not a copy.
PCA_pred = pred_test_full 

In [None]:
idx = 10
image_shape = (2*yrange, 2*xrange)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 3))

# Left: measured test image; right: MLP prediction. PCA_pred is the same array as
# pred_test_full, so the measured image must come from Iz_test_true.
panels = ((ax1, Iz_test_true, 'True'), (ax2, PCA_pred, 'Prediction'))
for ax, flat_images, title in panels:
    # Row idx of a flattened-image matrix reshaped (row-major) back to 2D.
    # Flip in y so the orientation matches the other LPS plots.
    im = ax.imshow(np.flip(flat_images[idx].reshape(image_shape), axis=0),
                   cmap="jet", aspect='auto', vmin=0, vmax=600)
    ax.set(xlabel="Time axis [pix]", ylabel="y [pix]",
           xlim=(0, 2*xrange), ylim=(0, 2*yrange))
    ax.set_title(title, fontsize=12)

cbar = fig.colorbar(im, ax=[ax1, ax2], fraction=0.16, pad=0.04)
plt.suptitle(f'Test Set LPS Image: Shot {ntest[idx]}', fontsize=12)


### Effect of PCA (before MLP)

In [None]:
idx = 15
image_shape = (2*yrange, 2*xrange)
fig, (ax2, ax3, cx2) = plt.subplots(1, 3, figsize=(15, 3), gridspec_kw={'width_ratios': [1, 1, 0.02]})

before_pca_image = Iz_test_true[idx].reshape(image_shape)
# The PCA was fitted on MinMax-scaled images (iz_scaler). So scale, project onto
# the kept components, reconstruct, and unscale to compare in the original units.
# Projecting the unscaled image directly would mix the two spaces.
after_pca_flat = iz_scaler.inverse_transform(
    pca.inverse_transform(pca.transform(iz_scaler.transform(Iz_test_true[idx:idx + 1]))))
after_pca_image = after_pca_flat.reshape(image_shape)

for ax, image, title in ((ax2, before_pca_image, 'Before PCA'),
                         (ax3, after_pca_image, 'After PCA')):
    # Flip in y direction for proper visualization
    im = ax.imshow(np.flip(image, axis=0), cmap="jet", aspect='auto', vmin=0, vmax=400)
    ax.set(xlabel="Time axis [pix]", ylabel="y [pix]",
           xlim=(0, 2*xrange), ylim=(0, 2*yrange))
    ax.set_title(title, fontsize=12)

# Both panels share vmin/vmax, so one colorbar serves both
fig.colorbar(im, cax=cx2, format='%.3g')
fig.subplots_adjust(wspace=0.8)


In [None]:
from ipywidgets import interact, IntSlider
def plot_lps_true_vs_pred(idx):
    """Plot the true and the MLP-predicted LPS image of test shot `idx` side by side.

    Named differently from the later current-profile `plot_xtcav_image` so neither
    definition shadows the other.
    """
    image_shape = (2*yrange, 2*xrange)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 3))
    for ax, flat_images, title in ((ax1, Iz_test_true, "True"),
                                   (ax2, pred_test_full, "Prediction")):
        # Row idx reshaped (row-major) back to 2D, flipped in y for display
        im = ax.imshow(np.flip(flat_images[idx].reshape(image_shape), axis=0),
                       cmap="jet", aspect='auto', vmin=0, vmax=600)
        ax.set(xlabel="Time axis [pix]", ylabel="y [pix]", title=title,
               xlim=(0, 2*xrange), ylim=(0, 2*yrange))
    fig.colorbar(im, ax=[ax1, ax2], fraction=0.16, pad=0.04)
    plt.show()


# Create slider
interact(plot_lps_true_vs_pred, idx=IntSlider(min=0, max=pred_test_full.shape[0]-1, step=1, value=0))

In [None]:
# Quick check: current profile [kA] of one shot, computed like currentProfile_all.
# shot_idx is on the common-index (shot) axis. 15 is the value the leftover `idx`
# had here on a top-to-bottom run; it is now explicit.
# The projection is taken from xtcavImages because the plotting cells may have
# replaced the global horz_proj with a single 1-D projection.
shot_idx = 15
streakedProfile = xtcavImages[:, :, shot_idx].sum(axis=0)

tvar = np.arange(1, len(streakedProfile) + 1) * xtcalibrationfactor
tvar = tvar - np.median(tvar)  # Center around zero

prefactor = charge[0, shot_idx] / np.trapezoid(streakedProfile, tvar)

# Not appended to currentProfile_all: that is an ndarray aligned with all_idx
# (and the bi-Gaussian fit results), so appending would fail or misalign it.
currentProfile = 1e-3 * streakedProfile * prefactor  # Convert to kA

In [None]:
from ipywidgets import interact, IntSlider
def plot_xtcav_image(idx):
    """Plot the true vs predicted current profile [kA] of test shot `idx` against time [fs].

    Each horizontal projection is scaled so that its integral over the time axis
    equals the measured charge of that shot, then converted A -> kA.
    """
    image_shape = (2*yrange, 2*xrange)
    true_image = Iz_test_true[idx].reshape(image_shape)
    pred_image = pred_test_full[idx].reshape(image_shape)

    # Time axis [s], centred on zero
    tvar = np.arange(1, true_image.shape[1] + 1) * xtcalibrationfactor
    tvar = tvar - np.median(tvar)

    # NOTE: ntest must hold shot indices aligned with the rows of Iz_test_true, i.e. the
    # `index` passed to train_test_split must be in the same order as the images.
    shot_charge = charge[0, ntest[idx]]

    fig, ax = plt.subplots()
    for image, label in ((true_image, "True"), (pred_image, "Prediction")):
        projection = np.sum(image, axis=0)
        currentProfile = 1e-3 * projection * shot_charge / np.trapezoid(projection, tvar)
        ax.plot(tvar * 1e15, currentProfile, label=label, alpha=0.5)
    ax.set_xlabel("Time [fs]")
    ax.set_ylabel("Current [kA]")
    ax.legend()
    plt.show()


# Create slider
interact(plot_xtcav_image, idx=IntSlider(min=0, max=pred_test_full.shape[0]-1, step=1, value=0))