In [None]:
# Autoload when refreshing notebook
%load_ext autoreload
%autoreload 2

import numpy as np
import h5py
from scipy.io import loadmat
import pandas as pd
import re
import matplotlib.pyplot as plt
from types import SimpleNamespace
import scipy
import warnings
from scipy.ndimage import median_filter, gaussian_filter
from scipy.optimize import curve_fit

# import Python functions 
import sys
sys.path.append('../../')

from Python_Functions.functions import cropProfmonImg, matstruct_to_dict, extractDAQBSAScalars, segment_centroids_and_com, plot2DbunchseparationVsCollimatorAndBLEN, extract_processed_images, apply_centroid_correction, commonIndexFromSteps

In [None]:
# XTCAV calibration constants used to turn screen pixels into time.
# NOTE(review): krf looks like the RF wavenumber in rad/m — confirm the unit.
krf = 239.26
cal = 1167  # um/deg  http://physics-elog.slac.stanford.edu/facetelog/show.jsp?dir=/2025/11/13.03&pos=2025-$  (link appears truncated)
# cal [um/deg] * 180/pi -> um/rad, times krf, times 1e-6 -> dimensionless streak factor
streakFromGUI = cal*krf*180/np.pi*1e-6  # um/um

# Main beam energy in eV (not referenced again in the visible part of this notebook)
mainbeamE_eV = 10e9
# dnom value for CHER (not referenced again in the visible part of this notebook)
dnom = 59.8e-3

# Dataset selection: these two strings build the path of the .mat file that is loaded next
experiment = 'E338'  # alternatives used so far: 'E300', 'E338'
runname = '12710'  # alternatives used so far: '12431', '12710'

In [None]:
# Load the DAQ .mat file for the selected experiment / run
dataloc = f'../../data/raw/{experiment}/{experiment}_{runname}/{experiment}_{runname}.mat'
mat = loadmat(dataloc, struct_as_record=False, squeeze_me=True)
data_struct = mat['data_struct']

# Scan steps recorded in the DAQ. With squeeze_me=True a single-step scan is
# loaded as a scalar, so normalise to a 1-D array: later cells slice stepsAll.
stepsAll = data_struct.params.stepsAll
if stepsAll is None or np.size(stepsAll) == 0:
    stepsAll = [1]
stepsAll = np.atleast_1d(stepsAll)

# Seconds per pixel: screen resolution [um/pix -> m/pix] / streak factor / c
xtcalibrationfactor = data_struct.metadata.DTOTR2.RESOLUTION*1e-6/streakFromGUI/3e8

# Half-widths handed to extract_processed_images for cropping
# (later cells reshape the flattened images to 2*yrange x 2*xrange)
xrange = 100
yrange = xrange


# Image-cleaning parameters handed to extract_processed_images
hotPixThreshold = 1e3
sigma = 1
threshold = 5

In [None]:
# Restrict the analysis to (at most) the first three scan steps
step_list = stepsAll[:3]
print("Processing steps:", step_list)
bsaScalarData, bsaVars = extractDAQBSAScalars(data_struct, step_list)

# XTCAV amplitude: first BSA variable whose name contains the amplitude PV
ampl_idx = next(i for i, var in enumerate(bsaVars) if 'TCAV_LI20_2400_A' in var)
xtcavAmpl = bsaScalarData[ampl_idx, :]

# XTCAV phase. Copy the row: the zeroing below must not write through into
# bsaScalarData, which is reused later as ML predictors and for BSA statistics.
phase_idx = next(i for i, var in enumerate(bsaVars) if 'TCAV_LI20_2400_P' in var)
xtcavPhase = bsaScalarData[phase_idx, :].copy()

# Shots with (almost) no XTCAV amplitude are treated as "XTCAV off" and
# tagged with phase 0 so they can be selected with xtcavPhase == 0
xtcavOffShots = xtcavAmpl<0.1
xtcavPhase[xtcavOffShots] = 0

# Toroid charge: TMIT value times the elementary charge -> Coulomb
isChargePV = [bool(re.search(r'TORO_LI20_2452_TMIT', pv)) for pv in bsaVars]
pvidx = [i for i, val in enumerate(isChargePV) if val]
charge = bsaScalarData[pvidx, :] * 1.6e-19  # in C

# Shot groups by XTCAV phase (within +/- 1 deg of -90 / +90, and "off")
minus_90_idx = np.where((xtcavPhase >= -91) & (xtcavPhase <= -89))[0]
plus_90_idx = np.where((xtcavPhase >= 89) & (xtcavPhase <= 91))[0]
off_idx = np.where(xtcavPhase == 0)[0]
# Streaked shots: all -90 deg shots followed by all +90 deg shots (not sorted by shot number)
all_idx = np.append(minus_90_idx,plus_90_idx)



In [None]:
# Extract the processed XTCAV images, their horizontal projections and the
# flattened 2D LPS images for the selected steps.
# The four containers below are initialised empty; LPSImage is replaced by the
# helper's return value further down.
xtcavImages_list, xtcavImages_list_raw, horz_proj_list = [], [], []
LPSImage = []

## Below value MUST be specified for DAQs with unwanted refraction patterns, etc.
roi_xrange = (400, 700)
roi_yrange = (400, 600)

(
    xtcavImages_centroid_uncorrected,
    xtcavImages_raw,
    horz_proj,
    LPSImage,
) = extract_processed_images(
    data_struct,
    experiment,
    xrange,
    yrange,
    hotPixThreshold,
    sigma,
    threshold,
    step_list,
    roi_xrange,
    roi_yrange,
)
print(LPSImage.shape)


### Common Index Debugging (Could be very confusing)

In [None]:
# Debug cell: inspect the common index returned for a hard-coded step list.
# NOTE(review): the steps [0,1,2,3] are hard-coded here and are independent of
# step_list used by the analysis cells — confirm that this is intended.
testInd = commonIndexFromSteps(data_struct, [0,1,2,3])
print(testInd)
print(len(testInd))
# NOTE(review): the "- 1" presumably converts MATLAB 1-based indices to 0-based — confirm
print(data_struct.images.DTOTR2.common_index - 1)
print(data_struct.images.DTOTR2.common_index.shape)
print(data_struct.scalars.steps.shape)
#print(np.where(data_struct.images.DTOTR2.common_index == 87)[0])
# Same call as testInd above; DTOTR2commonind is the name read again by the ML data-preparation cell
DTOTR2commonind = commonIndexFromSteps(data_struct, [0,1,2,3])
print("Number of shots after applying common index and step range:", len(DTOTR2commonind))
print(np.array(DTOTR2commonind))
# Build a lookup from original index -> position inside DTOTR2commonind.
# Example: If DTOTR2commonind = [0,1,4,6],  new index is [0,1,0,0,2,0,3,...]
# Caveat: indices that are NOT in DTOTR2commonind are mapped to 0 as well, so
# a 0 in new_index_list is ambiguous (first kept shot or "not present").
new_index_list = np.full(np.max(DTOTR2commonind) + 1, -1, dtype=int)
new_index_list[DTOTR2commonind] = np.arange(len(DTOTR2commonind))
new_index_list[new_index_list == -1] = 0
print(new_index_list)
print(len(new_index_list))

### Beam Energy Cross-Check

In [None]:
print(bsaVars)
# BPMS_LI20_2445_X is supposed to measure the beam energy right before the TCAV
# BPMS_LI14_801_X is supposed to measure the beam energy at LI14
# 'BLEN_LI14_888_BRAW' is the length of the bunch at LI14
# 'BLEN_LI11_359_BRAW' is the length of the bunch at LI11
# 'BPMS_LI11_333_X' is supposed to measure the beam energy at LI11
energy_idx = next(i for i, var in enumerate(bsaVars) if 'BPMS_LI11_333_X' in var)
energy_row = bsaScalarData[energy_idx]
beamEnergyM = energy_row[minus_90_idx]
beamEnergyP = energy_row[plus_90_idx]
beamEnergyO = energy_row[off_idx]


def _intensity_weighted_column(image_stack, shot_indices):
    """Intensity-weighted mean column index (axis 1) of image_stack[:, :, i] for each shot i."""
    column_index = np.arange(image_stack.shape[1])[np.newaxis, :]
    return np.array([
        np.sum(image_stack[:, :, i] * column_index) / np.sum(image_stack[:, :, i])
        for i in shot_indices
    ])


# One figure per image stack: centre of mass vs. the LI11 BPM reading, with
# the three XTCAV phase groups drawn in different colours.
# NOTE(review): the weighting runs along axis 1 (columns) of each image —
# confirm that this is the axis intended for the energy cross-check.
_phase_groups = (
    (beamEnergyM, minus_90_idx, 'blue'),
    (beamEnergyP, plus_90_idx, 'red'),
    (beamEnergyO, off_idx, 'green'),
)
_stacks_and_titles = (
    (xtcavImages_centroid_uncorrected, 'LPS Image Center of Mass vs Beam Energy PV, Cropped Image'),
    (xtcavImages_raw, 'LPS Image Center of Mass vs Beam Energy PV, RAW Image'),
)
for _stack, _figure_title in _stacks_and_titles:
    plt.figure(figsize=(6,4))
    for _energy, _shots, _colour in _phase_groups:
        plt.scatter(_energy, _intensity_weighted_column(_stack, _shots), c=_colour, s=5)

    plt.xlabel('BPMS_LI11_333_X')
    plt.ylabel('LPS Image Center of Mass [pix]')
    plt.title(_figure_title)
    plt.legend(['-90 deg','+90 deg','0 deg'])
    plt.grid()
    plt.show()

###  Centroid Correction (Optional)

In [None]:
# Apply the centroid correction to the cropped images, passing the XTCAV-off
# shots (off_idx) to the helper.
# Note: this REPLACES the horz_proj and LPSImage returned by
# extract_processed_images, so every later cell works on the corrected versions.
xtcavImages, horz_proj, LPSImage, centroid_corrections = apply_centroid_correction(xtcavImages_centroid_uncorrected, off_idx)
print(LPSImage.shape)

### Current Profile Generation

In [None]:
# Time axis shared by every shot: one sample per projection bin, scaled by the
# calibration factor and centred on zero.
time_axis = np.arange(1, horz_proj.shape[0] + 1) * xtcalibrationfactor
time_axis = time_axis - np.median(time_axis)

# For every streaked shot (all_idx order): scale the horizontal projection so
# that its integral over time_axis equals the measured charge, then convert to kA.
currentProfile_all = np.array([
    1e-3 * horz_proj[:, shot] * (charge[0, shot] / np.trapz(horz_proj[:, shot], time_axis))
    for shot in all_idx
])

### Check Images

In [None]:
from ipywidgets import interact, IntSlider
# Find the first shot where tcav is at -90, 0 and +90 deg
def plot_sample_images(idx):
    """Plot the idx-th available XTCAV shot at -90 deg, 0 deg and +90 deg.

    Three figures are produced: the raw images, the cropped images before
    centroid correction (with horizontal projections and, on the 0 deg panel,
    the centroid-correction overlay) and the images after centroid correction.

    Parameters
    ----------
    idx : int
        Position inside each phase group (not a global shot number). A phase
        group with too few shots leaves its panel empty instead of raising
        IndexError.

    Reads the notebook globals xtcavPhase, xtcavImages_raw,
    xtcavImages_centroid_uncorrected, xtcavImages, centroid_corrections,
    xrange, experiment and runname.
    """
    def _nth_match(mask):
        # idx-th shot index satisfying mask, or None when the group is too small
        matches = np.where(mask)[0]
        return matches[idx] if -len(matches) <= idx < len(matches) else None

    sample_image_indices = [
        _nth_match((xtcavPhase >= -90.55) & (xtcavPhase <= -89.55)),
        _nth_match(xtcavPhase == 0),
        _nth_match((xtcavPhase >= 89.55) & (xtcavPhase <= 90.55)),
    ]
    plot_titles = ['Tcav phase -90 deg', '0 deg', '+90 deg']

    # ---- Figure 1: raw images, colorbar in the narrow fourth axes ----
    fig, axs = plt.subplots(1, 4, figsize=(12, 6), gridspec_kw={'width_ratios': [1, 1, 1, 0.1]})
    fig.suptitle(f'TCAV images RAW DAQ {experiment} - {runname}', fontsize=14)

    last_image = None
    for col, shot in enumerate(sample_image_indices):
        axs[col].set_title(plot_titles[col])
        if shot is None:
            continue
        last_image = axs[col].imshow(xtcavImages_raw[:, :, shot], cmap='jet', aspect='auto')

    # The colorbar describes the right-most image that was drawn (the +90 deg
    # panel whenever that shot exists).
    if last_image is not None:
        cbar = fig.colorbar(last_image, cax = axs[3], orientation='vertical', fraction=0.05, pad=0.2)
        cbar.set_label('Intensity [a.u.]')
    else:
        axs[3].set_axis_off()

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

    def _plot_stack(image_stack, figure_title, overlay_corrections):
        # Top row: images; bottom row: their column sums (horizontal projections)
        fig, axs = plt.subplots(2, 3, figsize=(12, 6))
        fig.suptitle(figure_title, fontsize=14)

        for col, shot in enumerate(sample_image_indices):
            axs[0, col].set_title(plot_titles[col])
            axs[1, col].set_title("Horizontal Projection")
            if shot is None:
                continue

            sample_image = image_stack[:, :, shot]
            projection = np.sum(sample_image, axis=0)

            axs[0, col].imshow(sample_image, cmap='jet', aspect='auto')
            axs[1, col].plot(projection)

            # Only on the centre (0 deg) panel: mark the per-row correction and
            # the centre of mass of the projection
            if overlay_corrections and col == 1:
                for row in range(sample_image.shape[0]):
                    shift = centroid_corrections[row]
                    # xrange is the crop half-width, i.e. the image centre column
                    axs[0, col].plot(xrange - shift, row, 'wo', markersize=1)

                center_of_mass_x = np.sum(projection * np.arange(projection.shape[0])) / np.sum(projection)
                axs[0, col].axvline(center_of_mass_x, color='w', linestyle='--')

        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.show()

    # ---- Figures 2 and 3: before / after centroid correction ----
    _plot_stack(
        xtcavImages_centroid_uncorrected,
        f'TCAV images before centroid correction DAQ {experiment} - {runname}',
        overlay_corrections=True,
    )
    _plot_stack(
        xtcavImages,
        f'TCAV images after centroid correction DAQ {experiment} - {runname}',
        overlay_corrections=False,
    )
interact(plot_sample_images, idx=IntSlider(min=0, max=40, step=1, value=0))

### Filter Good Shots

In [None]:
# Filter out "bad" shots with Bi-Gaussian fit 
def bi_gaussian(x, A1, mu1, sigma1, A2, mu2, sigma2):
    """Sum of two Gaussians evaluated at x.

    Each peak is ``A * exp(-(x - mu)**2 / (2 * sigma**2))`` with its own
    amplitude A, centre mu and width sigma.
    """
    first_peak = A1 * np.exp(-(x - mu1)**2 / (2 * sigma1**2))
    second_peak = A2 * np.exp(-(x - mu2)**2 / (2 * sigma2**2))
    return first_peak + second_peak

# Per-shot fit results, aligned with all_idx / currentProfile_all.
# Every array is pre-filled with NaN so that a failed fit leaves NaN in ALL of
# them and the arrays stay aligned with each other.
n_fits = len(all_idx)
amp1 = np.full(n_fits, np.nan)
amp2 = np.full(n_fits, np.nan)
mu1 = np.full(n_fits, np.nan)
mu2 = np.full(n_fits, np.nan)
R_squared = np.full(n_fits, np.nan)

for ij in range(n_fits):
    y = currentProfile_all[ij, :]
    x = np.arange(len(y))

    # Initial guess: [A1, mu1, sigma1, A2, mu2, sigma2]; depends on the streak direction
    phase = xtcavPhase[all_idx[ij]]
    if phase < 0:
        initial_guess = [np.max(y), 100, 4, np.max(y)*0.1, 60 + ij*0.15, 4]
    elif phase > 0:
        initial_guess = [np.max(y), 100, 4, np.max(y)*0.1, 60, 4]
    else:
        # No guess is defined for phase 0; leave this shot as NaN instead of
        # reusing a stale guess from the previous iteration
        continue

    try:
        popt, _ = curve_fit(bi_gaussian, x, y, p0=initial_guess, maxfev=5000)
    except RuntimeError:
        # Fit did not converge: this shot keeps NaN in every result array
        continue

    # Store fitted parameters (the two widths are not kept)
    amp1[ij], mu1[ij], _, amp2[ij], mu2[ij], _ = popt

    # Goodness of fit: coefficient of determination of the fitted curve
    y_fit = bi_gaussian(x, *popt)
    SST = np.sum((y - np.mean(y))**2)
    SSR = np.sum((y - y_fit)**2)
    R_squared[ij] = 1 - SSR / SST if SST > 0 else np.nan

# set requirements for "good" shots. For xtcavPhase>0, we want larger (A1) peak at larger x (mu1).
# For xtcavPhase<0, we want larger (A2) peak at smaller x (mu2).
# goodShots holds POSITIONS within all_idx (not shot numbers); NaN entries never pass.
goodShots = np.where((R_squared > 0.97) & (amp1 < 50))[0]
#goodShots_twobunch_tcav = np.where((R_squared > 0.97) & (amp1 < 50) & ((mu1 > mu2) & (amp1 < amp2)))[0]

In [None]:
# Show one centroid-corrected image: the idx-th shot of the -90 deg group.
# NOTE(review): the shot is taken from minus_90_idx directly, i.e. it is not
# restricted to goodShots — confirm whether that is intended.
idx = 5
fig, ax1 = plt.subplots(figsize=(9, 6))
im1 = ax1.imshow(
    xtcavImages[:, :, minus_90_idx[idx]],
    cmap="jet",
    aspect='auto',
)
cbar1 = fig.colorbar(im1, ax=ax1)
cbar1.set_label("Charge(a.u.)")

### MLP 

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split

# goodShots holds positions within all_idx; translate them to shot indices
# (columns of bsaScalarData / rows of LPSImage) ONCE and use the same indices,
# in the same order, for labels, images and predictors so they stay aligned.
good_shot_idx = all_idx[goodShots]
index = good_shot_idx
images = LPSImage[good_shot_idx, :]
# NOTE(review): `steps` is not used below; it relies on DTOTR2commonind from the debugging cell
steps = data_struct.scalars.steps[DTOTR2commonind]
predictor = bsaScalarData[:, good_shot_idx].T

# Scale predictors and flattened images to [0, 1]
# NOTE(review): the scalers are fitted on all shots, including the later test split
x_scaler = MinMaxScaler()
iz_scaler = MinMaxScaler()
x_scaled = x_scaler.fit_transform(predictor)
Iz_scaled = iz_scaler.fit_transform(images)

# 80/20 train-test split (ntrain / ntest are the shot indices of each subset)
x_train_full, x_test_scaled, Iz_train_full, Iz_test_scaled, ntrain, ntest = train_test_split(
    x_scaled, Iz_scaled, index, test_size=0.2, random_state = 42)

# 20% validation split
x_train_scaled, x_validation, Iz_train_scaled, y_validation = train_test_split(
    x_train_full, Iz_train_full, test_size=0.2, random_state = 42)

# Compress the pixels with PCA fitted on the training images only.
# PCA cannot return more components than min(n_samples, n_features), so cap
# the requested 100 components for small data sets.
pca = PCA(n_components=min(100, *Iz_train_scaled.shape))
compressed_targets = pca.fit_transform(Iz_train_scaled)
print(Iz_train_scaled.shape, compressed_targets.shape)
y_validation = pca.transform(y_validation)

# Convert to PyTorch tensors
# (Y_train / y_validation are PCA coefficients; Y_test stays in scaled pixel space)
X_train = torch.tensor(x_train_scaled, dtype=torch.float32)
x_validation = torch.tensor(x_validation, dtype=torch.float32)
X_test = torch.tensor(x_test_scaled, dtype=torch.float32)
Y_train = torch.tensor(compressed_targets, dtype=torch.float32)
y_validation = torch.tensor(y_validation, dtype=torch.float32)
Y_test = torch.tensor(Iz_test_scaled, dtype=torch.float32)

train_ds = TensorDataset(X_train, Y_train)
train_dl = DataLoader(train_ds, batch_size=24, shuffle=True)


In [None]:
import time

# Define MLP structure
class MLP(nn.Module):
    """Fully connected network: in_dim -> 1000 -> 500 -> 500 -> out_dim.

    A ReLU follows every hidden layer; the output layer is linear.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        widths = (in_dim, 1000, 500, 500)
        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(widths[-1], out_dim))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run x through the stacked layers."""
        return self.model(x)

# Seed torch before the model is created so that weight initialisation and
# the DataLoader shuffling are reproducible (the data splits already use a
# fixed random_state of 42).
torch.manual_seed(42)
model = MLP(X_train.shape[1], Y_train.shape[1])
optimizer = optim.Adam(model.parameters(), lr=5e-4, betas=(0.9, 0.999))
# Not used by the training loop in this cell (it calls custom_loss); kept for experiments
loss_fn = nn.L1Loss()

# Weighted mean-squared-error loss used for training and validation
def custom_loss( y_pred,y_true):
    """Mean of the squared errors, with elements whose target lies outside
    [0.2, 0.8] weighted 1.7 and all other elements weighted 1.0."""
    squared_error = (y_true - y_pred)**2
    outside_band = torch.logical_or(y_true < 0.2, y_true > 0.8)
    weights = 1 + 0.7*outside_band.float()
    return torch.mean(weights*squared_error)

# Training loop with early stopping on the validation loss
n_epochs = 200
patience = 25
best_val_loss = float('inf')
early_stop_counter = 0


def _snapshot_weights(module):
    """Return an independent copy of the module's parameters and buffers.

    state_dict() hands out references to the live tensors, which the
    optimizer keeps updating in place; cloning freezes the values.
    """
    return {name: tensor.detach().clone() for name, tensor in module.state_dict().items()}


# Start from the initial weights so that a state always exists to restore,
# even if the validation loss never improves (e.g. it is NaN).
best_model_state = _snapshot_weights(model)

t0 = time.time()

# Fit the nn model on the training set
train_losses = []
val_losses = []

for epoch in range(n_epochs):
    model.train()
    train_loss = 0
    for xb, yb in train_dl:
        pred = model(xb)
        loss = custom_loss(pred, yb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    avg_train_loss = train_loss / len(train_dl)
    train_losses.append(avg_train_loss)

    # Validation loss (PCA-coefficient space, same as the training targets)
    model.eval()
    with torch.no_grad():
        val_pred = model(x_validation)
        val_loss = custom_loss(val_pred, y_validation).item()
        val_losses.append(val_loss)

    # Early stopping: remember the best weights, stop after `patience`
    # consecutive epochs without improvement
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model_state = _snapshot_weights(model)
        early_stop_counter = 0
    else:
        early_stop_counter += 1
        if early_stop_counter >= patience:
            break

# Restore the weights of the best validation epoch
model.load_state_dict(best_model_state)

# Evaluate model
model.eval()
with torch.no_grad():
    pred_train_scaled = model(X_train).numpy()
    pred_test_scaled = model(X_test).numpy()

# Predictions are PCA coefficients: undo the PCA, then undo the pixel scaling
pred_train_full = iz_scaler.inverse_transform(pca.inverse_transform(pred_train_scaled))
pred_test_full = iz_scaler.inverse_transform(pca.inverse_transform(pred_test_scaled))
Iz_train_true = iz_scaler.inverse_transform(Iz_train_scaled)
Iz_test_true = iz_scaler.inverse_transform(Iz_test_scaled)
elapsed = time.time() - t0
print("Elapsed time [mins] = {:.1f} ".format(elapsed/60))

# Compute R²
def r2_score(true, pred):
    """Coefficient of determination, 1 - RSS / TSS.

    Parameters
    ----------
    true, pred : array-like of equal shape
        Reference values and predictions; all elements are pooled.

    Returns
    -------
    float
        R² of pred with respect to true. When true is constant (TSS == 0)
        the ratio is undefined and 0.0 is returned.
    """
    RSS = np.sum((true - pred)**2)
    TSS = np.sum((true - np.mean(true))**2)
    if TSS == 0:
        return 0.0
    return 1 - RSS / TSS

print("Train R²: {:.2f} %".format(r2_score(Iz_train_true.ravel(), pred_train_full.ravel()) * 100))
print("Test R²: {:.2f} %".format(r2_score(Iz_test_true.ravel(), pred_test_full.ravel()) * 100))

# Per-sample R² on the test set. Every row is one flattened image, and R²
# pools all pixels, so the rows can be compared directly: no need to
# transpose and reshape (i.e. copy) the whole matrix once per sample.
r2_values = [r2_score(true_row, pred_row) for true_row, pred_row in zip(Iz_test_true, pred_test_full)]
# Throw away values outside 0 to 1, and count the number of throws
r2_values_new = [r2 for r2 in r2_values if 0 <= r2 <= 1]
num_throws = len(r2_values) - len(r2_values_new)
plt.figure(figsize=(8, 5))
plt.hist(r2_values_new, bins=20, color='skyblue', edgecolor='black')
plt.title('Histogram of R² Values for Test Samples')
plt.xlabel('R² Value')
# The y label doubles as a summary of how many samples survived the filter
plt.ylabel(f'Plotted Samples: {len(r2_values) - num_throws} / Total Samples: {len(r2_values)}')
plt.show()


### Effect of PCA (before MLP)

In [None]:
# Test sample to display (clamped so that small test sets do not raise IndexError)
idx = min(15, Iz_test_true.shape[0] - 1)
image_shape = (2*yrange, 2*xrange)
fig, (ax2, ax3, cx2) = plt.subplots(1,3,figsize=(15, 3), gridspec_kw={'width_ratios': [1, 1, 0.02]})
before_pca_image = Iz_test_true[idx].reshape(image_shape)
#Flip in y direction for proper visualization
im2 = ax2.imshow(np.flip(before_pca_image, axis=0), cmap = "jet",aspect='auto', vmin = 0, vmax = 400)
ax2.set(ylabel="y [pix]")
ax2.set(xlabel = "Time [fs]")
ax2.set_title('Before PCA', fontsize = 12)
ax2.set(xlim = (0,2*xrange))
ax2.set(ylim= (0,2*yrange))

# The PCA was fitted on MinMax-scaled pixels, so the round trip has to run in
# that space: scaled image -> PCA coefficients -> scaled image -> pixel units.
scaled_row = Iz_test_scaled[idx][np.newaxis, :]
after_pca_scaled = pca.inverse_transform(pca.transform(scaled_row))
after_pca_image = iz_scaler.inverse_transform(after_pca_scaled).reshape(image_shape)

im3 = ax3.imshow(np.flip(after_pca_image, axis=0), cmap = "jet",aspect='auto', vmin = 0, vmax = 400)
ax3.set(xlabel = "Time [fs]")
ax3.set(ylabel = "y [pix]")
ax3.set_title('After PCA', fontsize = 12)
ax3.set(xlim = (0,2*xrange))
ax3.set(ylim= (0,2*yrange))
# Both panels share vmin/vmax, so one colorbar (taken from the left image) serves both
fig.colorbar((im2), cax=cx2, format='%.3g')
fig.subplots_adjust(wspace=0.8)
#fig.tight_layout()


In [None]:
from ipywidgets import interact, IntSlider
def plot_xtcav_image_pred(idx):
    """Show test sample idx next to the model prediction, with its R² as figure title.

    idx is a row of Iz_test_true / pred_test_full; the left panel title shows
    the matching entry of ntest.
    """
    stack_shape = (2*yrange, 2*xrange, Iz_test_true.shape[0])

    def _image_of(flat_rows):
        # Flattened rows -> (rows, cols, sample) stack -> image of sample idx
        return flat_rows.T.reshape(stack_shape)[:, :, idx]

    true_image = _image_of(Iz_test_true)
    predicted_image = _image_of(pred_test_full)

    fig, (ax1, ax2, cx1) = plt.subplots(1,3,figsize=(10, 3), gridspec_kw={'width_ratios': [1, 1, 0.02]})
    panels = (
        (ax1, true_image, f"True(Shot Number: {ntest[idx]})"),
        (ax2, predicted_image, "Prediction"),
    )
    drawn = []
    for axis, image, panel_title in panels:
        # Images are flipped vertically for display; both use the same colour scale
        drawn.append(axis.imshow(np.flip(image, axis=0), cmap = "jet",aspect='auto', vmin = 0, vmax = 400))
        axis.set(ylabel="y [pix]")
        axis.set(xlabel = "Time [fs]")
        axis.set(title = panel_title)
        axis.set(xlim = (0,2*xrange))
        axis.set(ylim= (0,2*yrange))

    # Colorbar of the true image, drawn in the narrow third axes
    cbar = fig.colorbar(drawn[0], cax=cx1, fraction=0.16, pad=0.04)
    plt.subplots_adjust(wspace=0.4)

    # R² between the two displayed images
    r2_val = r2_score(true_image, predicted_image)
    plt.suptitle(f'R² Value: {r2_val:.4f}', fontsize=7)

# Create slider
interact(plot_xtcav_image_pred, idx=IntSlider(min=0, max=pred_test_full.shape[0]-1, step=1, value=0))

In [None]:
# Deliberate hard stop: raising here aborts "Run All" so the cells after this
# point (prediction on a different run) are only executed manually.
# The ImportError type is incidental; nothing is imported in this cell.
raise ImportError("Stop here after training and evaluating the model.")

### Load From A Different Run For Prediction

In [None]:
#Load from a different run for prediction
experiment_new='E300'
runname_new = '12431'
dataloc_new = f'../../data/raw/{experiment_new}/{experiment_new}_{runname_new}/{experiment_new}_{runname_new}.mat'
mat_new = loadmat(dataloc_new,struct_as_record=False, squeeze_me=True)
data_struct_new = mat_new['data_struct']

# Containers initialised empty (LPSImage_newRun is replaced by the helper's return value below)
xtcavImages_list_newRun = []
xtcavImages_list_raw_newRun = []
horz_proj_list_newRun = []
LPSImage_newRun = []

# Steps of the new run. Stored under its own name so that the step_list of
# the training run is not overwritten.
step_list_newRun = [1,2,3,4,5]
# Extract current profiles and 2D LPS images (no ROI is passed for this run)
xtcavImages_centroid_uncorrected_newRun, xtcavImages_raw_newRun, horz_proj_newRun, LPSImage_newRun = extract_processed_images(data_struct_new, experiment_new, xrange, yrange, hotPixThreshold, sigma, threshold, step_list_newRun)
print(LPSImage_newRun.shape)
DTOTR2commonind_newRun = commonIndexFromSteps(data_struct_new, step_list_newRun)
# Same call convention as for the training run: (data_struct, step list).
# NOTE(review): this cell used to pass DTOTR2commonind_newRun here — confirm
# against the signature of extractDAQBSAScalars.
bsaScalarData_newRun, bsaVars_newRun = extractDAQBSAScalars(data_struct_new, step_list_newRun)

# XTCAV amplitude and phase of the new run
xtcavAmpl_newRun_idx = next(i for i, var in enumerate(bsaVars_newRun) if 'TCAV_LI20_2400_A' in var)
xtcavAmpl_newRun = bsaScalarData_newRun[xtcavAmpl_newRun_idx, :]
xtcavPhase_newRun_idx = next(i for i, var in enumerate(bsaVars_newRun) if 'TCAV_LI20_2400_P' in var)
# Copy the row so the zeroing below does not write through into
# bsaScalarData_newRun, which is used as model input and for BSA statistics
xtcavPhase_newRun = bsaScalarData_newRun[xtcavPhase_newRun_idx, :].copy()
xtcavPhase_newRun[xtcavAmpl_newRun<0.1] = 0 #Set this for ease of plotting

# Toroid charge: TMIT value times the elementary charge -> Coulomb
isChargePV_newRun = [bool(re.search(r'TORO_LI20_2452_TMIT', pv)) for pv in bsaVars_newRun]
pvidx_newRun = [i for i, val in enumerate(isChargePV_newRun) if val]
charge_newRun = bsaScalarData_newRun[pvidx_newRun, :] * 1.6e-19  # in C

# Shot groups by XTCAV phase; all_idx_newRun is -90 deg shots followed by +90 deg shots
minus_90_idx_newRun = np.where((xtcavPhase_newRun >= -91) & (xtcavPhase_newRun <= -89))[0]
plus_90_idx_newRun = np.where((xtcavPhase_newRun >= 89) & (xtcavPhase_newRun <= 91))[0]
all_idx_newRun = np.append(minus_90_idx_newRun,plus_90_idx_newRun)
# Include zero phase shots as well
off_idx_newRun = np.where(xtcavPhase_newRun == 0)[0]

#Apply centroid correction (replaces horz_proj_newRun and LPSImage_newRun)
xtcavImages_newRun, horz_proj_newRun, LPSImage_newRun, centroid_corrections_newRun = apply_centroid_correction(xtcavImages_centroid_uncorrected_newRun, off_idx_newRun)

# Extract current profiles
# NOTE(review): xtcalibrationfactor was computed from the training run's
# metadata — confirm that it also applies to this run.
tvar = np.arange(1, horz_proj_newRun.shape[0] + 1) * xtcalibrationfactor
tvar = tvar - np.median(tvar)  # Center around zero

currentProfile_all_newRun = []
for idx in all_idx_newRun:
    streakedProfile = horz_proj_newRun[:,idx]

    # Scale the projection so that its integral over time equals the charge
    prefactor = charge_newRun[0, idx] / np.trapz(streakedProfile, tvar)

    currentProfile = 1e-3 * streakedProfile * prefactor  # Convert to kA
    currentProfile_all_newRun.append(currentProfile)
currentProfile_all_newRun = np.array(currentProfile_all_newRun)



In [None]:
# Prepare input for prediction.
# Labels, images and predictors all follow the order of all_idx_newRun so
# that row k of each refers to the same shot.
index_newRun = all_idx_newRun
images_newRun = LPSImage_newRun[all_idx_newRun,:]
# Same features as in training: the BSA scalars only. The model's first layer
# was sized from the training BSA matrix, so no extra column may be appended.
predictor_newRun = bsaScalarData_newRun[:, all_idx_newRun].T

n_features_expected = model.model[0].in_features
if predictor_newRun.shape[1] != n_features_expected:
    raise ValueError(
        f"New run provides {predictor_newRun.shape[1]} predictors per shot, "
        f"but the model was trained on {n_features_expected}."
    )

#x_scaled_newRun = x_scaler.transform(predictor_newRun)
# Check if scaler worked properly. Values should be between 0 and 1
# Clamp values outside 0-1 range
# x_scaled_newRun = np.clip(x_scaled_newRun, 0, 1)
# A fresh MinMaxScaler is fitted on the new run instead of reusing x_scaler.
# NOTE(review): the inputs are therefore scaled differently from the training
# inputs — confirm that this is intended.
x_scaler_newRun = MinMaxScaler()
x_scaled_newRun = x_scaler_newRun.fit_transform(predictor_newRun)
X_newRun = torch.tensor(x_scaled_newRun, dtype=torch.float32)


print("Min and Max of scaled inputs for new run:", np.min(x_scaled_newRun), np.max(x_scaled_newRun))
# Make prediction
model.eval()
with torch.no_grad():
    pred_newRun_scaled = model(X_newRun).numpy()
# Predictions are PCA coefficients: undo the PCA, then undo the training pixel scaling
pred_newRun_full = iz_scaler.inverse_transform(pca.inverse_transform(pred_newRun_scaled))

In [None]:
# Find the first shot where tcav is at -90, 0 and +90 deg
def plot_sample_images(idxz):
    """Plot the idxz-th available new-run XTCAV shot at -90 deg, 0 deg and +90 deg.

    Three figures are produced: the raw images, the cropped images before
    centroid correction (with horizontal projections and, on the 0 deg panel,
    the centroid-correction overlay) and the images after centroid correction.

    Note: this definition replaces the function of the same name defined
    earlier in the notebook for the training run.

    Parameters
    ----------
    idxz : int
        Position inside each phase group (not a global shot number). A phase
        group with too few shots leaves its panel empty instead of raising
        IndexError.

    Reads the notebook globals xtcavPhase_newRun, xtcavImages_raw_newRun,
    xtcavImages_centroid_uncorrected_newRun, xtcavImages_newRun,
    centroid_corrections_newRun, xrange, experiment_new and runname_new.
    """
    def _nth_match(mask):
        # idxz-th shot index satisfying mask, or None when the group is too small
        matches = np.where(mask)[0]
        return matches[idxz] if -len(matches) <= idxz < len(matches) else None

    sample_image_indices = [
        _nth_match((xtcavPhase_newRun >= -90.55) & (xtcavPhase_newRun <= -89.55)),
        _nth_match(xtcavPhase_newRun == 0),
        _nth_match((xtcavPhase_newRun >= 89.55) & (xtcavPhase_newRun <= 90.55)),
    ]
    plot_titles = ['Tcav phase -90 deg', '0 deg', '+90 deg']

    # ---- Figure 1: raw images, colorbar in the narrow fourth axes ----
    fig, axs = plt.subplots(1, 4, figsize=(12, 6), gridspec_kw={'width_ratios': [1, 1, 1, 0.1]})
    fig.suptitle(f'TCAV images RAW DAQ {experiment_new} - {runname_new}', fontsize=14)

    last_image = None
    for col, shot in enumerate(sample_image_indices):
        axs[col].set_title(plot_titles[col])
        if shot is None:
            continue
        last_image = axs[col].imshow(xtcavImages_raw_newRun[:, :, shot], cmap='jet', aspect='auto')

    # The colorbar describes the right-most image that was drawn (the +90 deg
    # panel whenever that shot exists).
    if last_image is not None:
        cbar = fig.colorbar(last_image, cax = axs[3], orientation='vertical', fraction=0.05, pad=0.2)
        cbar.set_label('Intensity [a.u.]')
    else:
        axs[3].set_axis_off()

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

    def _plot_stack(image_stack, figure_title, overlay_corrections):
        # Top row: images; bottom row: their column sums (horizontal projections)
        fig, axs = plt.subplots(2, 3, figsize=(12, 6))
        fig.suptitle(figure_title, fontsize=14)

        for col, shot in enumerate(sample_image_indices):
            axs[0, col].set_title(plot_titles[col])
            axs[1, col].set_title("Horizontal Projection")
            if shot is None:
                continue

            sample_image = image_stack[:, :, shot]
            projection = np.sum(sample_image, axis=0)

            axs[0, col].imshow(sample_image, cmap='jet', aspect='auto')
            axs[1, col].plot(projection)

            # Only on the centre (0 deg) panel: mark the per-row correction and
            # the centre of mass of the projection
            if overlay_corrections and col == 1:
                for row in range(sample_image.shape[0]):
                    shift = centroid_corrections_newRun[row]
                    # xrange is the crop half-width, i.e. the image centre column
                    axs[0, col].plot(xrange - shift, row, 'wo', markersize=1)

                center_of_mass_x = np.sum(projection * np.arange(projection.shape[0])) / np.sum(projection)
                axs[0, col].axvline(center_of_mass_x, color='w', linestyle='--')

        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.show()

    # ---- Figures 2 and 3: before / after centroid correction ----
    _plot_stack(
        xtcavImages_centroid_uncorrected_newRun,
        f'TCAV images before centroid correction DAQ {experiment_new} - {runname_new}',
        overlay_corrections=True,
    )
    _plot_stack(
        xtcavImages_newRun,
        f'TCAV images after centroid correction DAQ {experiment_new} - {runname_new}',
        overlay_corrections=False,
    )
interact(plot_sample_images, idxz=IntSlider(min=0, max=40, step=1, value=0))

In [None]:
# Plot R² histogram for new run.
# Every row is one flattened image, and R² pools all pixels, so the rows can
# be compared directly: no need to transpose and reshape (i.e. copy) the
# whole matrix once per sample.
r2_values_newRun = [r2_score(true_row, pred_row) for true_row, pred_row in zip(images_newRun, pred_newRun_full)]
# Throw away values outside 0 to 1, and count the number of throws
r2_values_newRun_filtered = [r2 for r2 in r2_values_newRun if 0 <= r2 <= 1]
num_throws_newRun = len(r2_values_newRun) - len(r2_values_newRun_filtered)
plt.figure(figsize=(8, 5))
plt.hist(r2_values_newRun_filtered, bins=20, color='lightgreen', edgecolor='black')
plt.title('Histogram of R² Values for New Run Samples')
plt.xlabel('R² Value')
# The y label doubles as a summary of how many samples survived the filter
plt.ylabel(f'Plotted Samples: {len(r2_values_newRun) - num_throws_newRun} / Total Samples: {len(r2_values_newRun)}')
plt.show()

In [None]:
from ipywidgets import interact, IntSlider
# Plot some prediction results from new run
# Interactive plot
def plot_xtcav_image_newRun(idx):
    """Show new-run sample idx next to the model prediction.

    idx is a row of images_newRun / pred_newRun_full, i.e. a POSITION within
    all_idx_newRun. The shot number, XTCAV phase and XTCAV amplitude shown in
    the titles are looked up through all_idx_newRun[idx] so that they belong
    to the displayed image.
    """
    shot = all_idx_newRun[idx]
    image_shape = (2*yrange, 2*xrange)
    true_image = images_newRun[idx].reshape(image_shape)
    predicted_image = pred_newRun_full[idx].reshape(image_shape)

    fig, (ax1, ax2, cx1) = plt.subplots(1,3,figsize=(10, 3), gridspec_kw={'width_ratios': [1, 1, 0.02]})
    # Images are flipped vertically for display; both use the same colour scale
    im1 = ax1.imshow(np.flip(true_image, axis=0), cmap = "jet",aspect='auto', vmin = 0, vmax = 400)
    ax1.set(ylabel="y [pix]")
    ax1.set(xlabel = "Time [fs]")
    ax1.set(title = f"True(Shot Number: {shot})")
    ax1.set(xlim = (0,2*xrange))
    ax1.set(ylim= (0,2*yrange))
    im2 = ax2.imshow(np.flip(predicted_image, axis=0), cmap = "jet",aspect='auto',vmin = 0, vmax = 400)
    ax2.set(xlabel = "Time [fs]")
    ax2.set(ylabel = "y [pix]")
    ax2.set(title = "Prediction")
    ax2.set(xlim = (0,2*xrange))
    ax2.set(ylim= (0,2*yrange))
    # Colorbar of the true image, drawn in the narrow third axes
    cbar = fig.colorbar(im1, cax=cx1, fraction=0.16, pad=0.04)

    # Phase / amplitude of the displayed shot (indexed by shot, not by slider position)
    xtcavPhase_corr = xtcavPhase_newRun[shot]
    xtcavAmpl_corr = xtcavAmpl_newRun[shot]
    fig.suptitle(f'TCAV Phase: {xtcavPhase_corr:.2f} deg, TCAV Amplitude: {xtcavAmpl_corr:.2f} MV', fontsize=7)

    plt.subplots_adjust(wspace=0.4)
    plt.show()
# Create slider
interact(plot_xtcav_image_newRun, idx=IntSlider(min=0, max=pred_newRun_full.shape[0]-1, step=1, value=0))

### BSA Variable Debugging

In [None]:
import numpy as np
import pandas as pd # For easy table printing

# 1. Per-variable mean / standard deviation over all shots of the new run
avg_new = np.mean(bsaScalarData_newRun, axis=1)
std_new = np.std(bsaScalarData_newRun, axis=1)

# 2. Same statistics for the old (training) run
avg_old = np.mean(bsaScalarData, axis=1)
std_old = np.std(bsaScalarData, axis=1)

# Difference of the means in units of the combined standard deviation.
# Stored as sigma_difference: the notebook-level name `sigma` is the Gaussian
# filter width passed to extract_processed_images and must not be overwritten.
sigma_difference = (avg_new-avg_old) / np.sqrt(std_new**2 + std_old**2)

# 3. Create a pandas DataFrame for structured output
# (rows are labelled with the new run's variable names; this assumes both runs
#  list the same BSA variables in the same order)
data = {
    'New Run Avg': avg_new,
    'New Run StdDev': std_new,
    'Old Run Avg': avg_old,
    'Old Run StdDev': std_old,
    'Sigma Difference': sigma_difference
}

comparison_df = pd.DataFrame(data, index=bsaVars_newRun)

# Round for display.
# NOTE(review): 34 decimals is effectively a no-op; an earlier comment mentioned 3 — confirm the intent.
comparison_df = comparison_df.round(34)

# 4. Print the comparison table with display limits lifted for this print
#    only, so the pandas options of the rest of the notebook are untouched
print("Comparison of BSA Variable Statistics (New Run vs. Old Run)\n")
with pd.option_context('display.max_rows', None,
                       'display.max_columns', None,
                       'display.width', 100000):
    print(comparison_df)