In [None]:
import pytrinamic
from pytrinamic.connections import ConnectionManager
from pytrinamic.modules import TMCM6110
import time
import scipy.io
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import trange
from itertools import chain

# Motor Control 

In [None]:
class XYZ():
    """
    Three-axis stage driven by a Trinamic TMCM-6110 module via pytrinamic.

    Creating an instance opens a connection through ``ConnectionManager`` and
    keeps handles to motors 0, 1 and 2 of the module as ``motor_0``,
    ``motor_1`` and ``motor_2``.
    """

    def __init__(self):
        self.connectionManager = ConnectionManager()
        self.interface = self.connectionManager.connect()

        # Wrap the open interface in the TMCM-6110 module driver
        self.module = TMCM6110(self.interface)

        # Expose the first three motor channels as motor_0 .. motor_2
        for channel in range(3):
            setattr(self, f"motor_{channel}", self.module.motors[channel])
        print("Preparing parameters")

    def XYZ_setup(self, max_current=500, standby_current=200, boost_current=0,
                  velocity=1500, acceleration=1000, position=0):
        """
        Apply the same drive configuration to motors 0, 1 and 2.

        Each motor gets the given currents, 256-microstep resolution, and the
        given maximum acceleration and velocity; each configured motor is printed.

        Parameters:
        - max_current, standby_current, boost_current: drive currents.
        - velocity: value written to ``max_velocity``.
        - acceleration: value written to ``max_acceleration``.
        - position: accepted but currently not applied to the motors.

        Returns:
        - (motor_2, motor_1, motor_0, interface): note the reversed motor order.
        """
        configured = []
        for motor in (self.motor_0, self.motor_1, self.motor_2):
            motor.drive_settings.max_current = max_current
            motor.drive_settings.standby_current = standby_current
            motor.drive_settings.boost_current = boost_current
            motor.drive_settings.microstep_resolution = motor.ENUM.microstep_resolution_256_microsteps
            motor.max_acceleration = acceleration
            motor.max_velocity = velocity
            print(motor)
            configured.append(motor)
        # Callers receive the motors highest-index first, followed by the interface
        return (*reversed(configured), self.interface)
        
  

# Capturing functions


In [None]:
def capture_location(scope, pt_exp, exp_name, x, y, num=1000):
    """
    Capture a fixed-vs-random TVLA trace set at grid location (x, y) and store it.

    The first ``num`` rows of the ``keys``, ``plaintexts`` and ``fixed_pt``
    datasets are read from ``pt_exp``. ``scope.capture_traces_tvla`` is then
    called with the same key array for both the fixed and the random group.
    The two returned trace sets are written to ``exp_name`` as float32 datasets
    named ``fixed_{x}_{y}`` and ``random_{x}_{y}``.

    Parameters:
    scope: Object providing ``capture_traces_tvla``.
    pt_exp: Dataset handler holding the keys and plaintexts.
    exp_name: Dataset handler the captured traces are added to.
    x, y: Grid coordinates used in the stored dataset names.
    num (int, optional): Number of traces to capture (default 1000).

    Returns:
    None
    """
    def _first_rows(dataset_name):
        # Read the first `num` rows of one dataset from the plaintext/key handler
        return pt_exp.get_dataset(dataset_name).read_data(0, num)

    keys_pt = _first_rows("keys")
    random_pt = _first_rows("plaintexts")
    fixed_pt = _first_rows("fixed_pt")

    # Same key array for both groups; only the plaintexts differ
    fixed_traces, random_traces = scope.capture_traces_tvla(num, keys_pt, fixed_pt, keys_pt, random_pt)

    location = f"{x!s}_{y!s}"
    print(f"Storing for location: {location}")
    for prefix, captured in (("fixed_", fixed_traces), ("random_", random_traces)):
        exp_name.add_dataset(prefix + location, captured, datatype="float32")

    print("Traces stored")


def _wait_for_position(motor, target, poll_interval=0.1):
    """Poll ``motor`` every ``poll_interval`` seconds until it reports exactly ``target``.

    NOTE(review): there is no timeout. A move that never lands exactly on
    ``target`` blocks forever.
    """
    while motor.get_actual_position() != target:
        time.sleep(poll_interval)


def Grid_Tracing_scapegoat(X_range, Y_range, X_number_of_step, Y_number_of_step, X, Y, Z, interface, scope, pt_exp, exp_store, number_of_traces):
    """
    Scan a rectangular grid in a serpentine pattern and capture traces at every point.

    Starting from the current position (grid index (0, 0)), X is stepped
    ``X_number_of_step`` times forward. Then Y is stepped once, and X is stepped
    the same number of times back. This repeats until Y has been stepped
    ``Y_number_of_step`` times. ``capture_location`` is called at each of the
    ``(X_number_of_step + 1) * (Y_number_of_step + 1)`` grid points, using the
    integer grid indices (X_moment, Y_moment) as dataset coordinates. Finally
    both axes are driven back to their starting positions.

    Parameters:
    X_range, Y_range: Relative move (in motor position units) for one X / Y step.
    X_number_of_step, Y_number_of_step: Number of steps along X and Y.
    X, Y: Motor objects for the X and Y axes.
    Z: Z-axis motor; accepted but not used.
    interface: Device interface; accepted but not used.
    scope: Object responsible for capturing traces.
    pt_exp: Experiment dataset handler for plaintexts and keys.
    exp_store: Experiment object the captured traces are stored in.
    number_of_traces: Number of traces to capture at each grid point.

    Returns:
    None
    """
    X_moment = 0
    Y_moment = 0

    # Capture at the starting grid point (0, 0) before any movement
    capture_location(scope, pt_exp, exp_store, X_moment, Y_moment, number_of_traces)

    # Remember where we started so we can return there at the end
    X_start_position = X.get_actual_position()
    Y_start_position = Y.get_actual_position()
    print(f"Starting Position - ({X_start_position}, {Y_start_position})")

    # Each iteration scans one row forward and (if rows remain) one row back
    while Y_moment <= Y_number_of_step:
        X_initial_position = X.get_actual_position()
        Y_initial_position = Y.get_actual_position()

        # Forward sweep along +X
        for _ in range(X_number_of_step):
            target = X_initial_position + X_range
            X.move_by(X_range)
            print(f'Moving X to {target}')
            _wait_for_position(X, target)
            X_moment += 1
            capture_location(scope, pt_exp, exp_store, X_moment, Y_moment, number_of_traces)
            X_initial_position = X.get_actual_position()

        if Y_moment == Y_number_of_step:
            break  # Last row done

        # Step one row along +Y and capture at the row start
        target = Y_initial_position + Y_range
        Y.move_by(Y_range)
        Y_moment += 1
        print(f'Moving Y to {target}')
        _wait_for_position(Y, target)
        capture_location(scope, pt_exp, exp_store, X_moment, Y_moment, number_of_traces)
        Y_initial_position = Y.get_actual_position()

        # Backward sweep along -X
        for _ in range(X_number_of_step):
            target = X_initial_position - X_range
            X.move_by(-X_range)
            print(f'Moving X to {target}')
            _wait_for_position(X, target)
            X_moment -= 1
            capture_location(scope, pt_exp, exp_store, X_moment, Y_moment, number_of_traces)
            X_initial_position = X.get_actual_position()

        if Y_moment == Y_number_of_step:
            break  # Last row done

        # Step one more row along +Y before the next forward sweep
        target = Y_initial_position + Y_range
        Y.move_by(Y_range)
        Y_moment += 1
        print(f'Moving Y to {target}')
        _wait_for_position(Y, target)
        capture_location(scope, pt_exp, exp_store, X_moment, Y_moment, number_of_traces)
        Y_initial_position = Y.get_actual_position()

    # Return both axes to where the scan started
    X.move_to(X_start_position)
    _wait_for_position(X, X_start_position)

    Y.move_to(Y_start_position)
    _wait_for_position(Y, Y_start_position)

    print(f"Final Position - ({X.get_actual_position()}, {Y.get_actual_position()})")

    return None


# Metrics (one-click)

In [None]:
def plot_CEMA_heatmap(test, pt_exp, num, target_byte=0, grid_size=5):
    """
    Run a single-byte CPA at every grid point and draw the peak correlations as a heatmap.

    Parameters:
    - test: Dataset handler holding per-location traces named ``random_{i}_{j}``.
    - pt_exp: Dataset handler holding the ``keys`` and ``plaintexts`` datasets.
    - num: Number of traces (and key/plaintext rows) read per location.
    - target_byte: Key byte index attacked by ``scapegoat_cpa_byte`` (default 0).
    - grid_size: Number of grid points per axis (default 5).

    Returns:
    - (CEMA_guesses_rotated, CEMA_values_rotated): the best key guess and its
      peak correlation per grid point. Both grids are rotated 90 degrees
      clockwise, the same way as the plotted heatmap.
    """
    def _load(dataset_name):
        return pt_exp.get_dataset(dataset_name).read_data(0, num)

    keys = _load("keys")
    plaintexts = _load("plaintexts")

    grid_shape = (grid_size, grid_size)
    peaks = np.zeros(grid_shape)
    guesses = np.zeros(grid_shape)

    for i in range(grid_size):
        for j in trange(grid_size):
            print(f"Processing grid position ({i}, {j})")
            location_traces = test.get_dataset(f"random_{i}_{j}").read_data(0, num)
            guesses[i, j], peaks[i, j] = scapegoat_cpa_byte(location_traces, keys, plaintexts, target_byte)

    # rot90 with k=3 (three counter-clockwise quarter turns) == one clockwise turn
    peaks_rot = np.rot90(peaks, k=3)
    guesses_rot = np.rot90(guesses, k=3)

    plt.figure(figsize=(8, 6))
    sns.heatmap(peaks_rot, annot=True, cbar=True, square=True)
    plt.title("CEMA Heatmap")
    # NOTE(review): after rot90(k=3) the plotted row index follows j and the
    # column index follows i in reverse, so these labels look swapped — confirm.
    plt.xlabel("Grid Column (j)")
    plt.ylabel("Grid Row (i)")
    plt.show()

    return guesses_rot, peaks_rot



def plot_SNR_heatmap(test, pt_exp, num, target_byte=0, grid_size=5, SNR_type="BYTE"):
    """
    Compute the peak SNR at every grid location and draw it as a heatmap.

    The traces at each location are grouped by a per-trace label. The groups
    are passed to ``signal_to_noise_ratio``, and the largest absolute
    (NaN-ignoring) value is kept.

    Label modes (``SNR_type``):
    - "BYTE": the S-box output ``Sbox[key[target_byte] ^ pt[target_byte]]``.
    - "HW": the Hamming weight of that S-box output.
    - "FULL": ``Sbox`` applied to the whole ``plaintexts ^ keys`` array, keeping
      column 1 (``target_byte`` is not used in this mode).

    Parameters:
    - test: Dataset handler holding traces named ``random_{i}_{j}``.
    - pt_exp: Dataset handler holding the ``keys`` and ``plaintexts`` datasets.
    - num: Number of traces to process per location.
    - target_byte: Key byte index used by the "BYTE" and "HW" labels (default 0).
    - grid_size: Number of grid points per axis (default 5).
    - SNR_type: "BYTE", "HW" or "FULL".

    Returns:
    - (SNR_values_rotated, SNR_dB): the peak-SNR grid rotated 90 degrees
      clockwise, and ``10 * log10`` of it; or -1 for an unknown ``SNR_type``.
    """
    keys = pt_exp.get_dataset("keys").read_data(0, num)
    plaintexts = pt_exp.get_dataset("plaintexts").read_data(0, num)

    if SNR_type in ("BYTE", "HW"):
        make_labels = sbox_snr if SNR_type == "BYTE" else leakage_model_hamming_weight_snr
        labels = make_labels(num_traces=num, plaintexts=plaintexts, subkey_guess=keys, target_byte=target_byte)
    elif SNR_type == "FULL":
        # NOTE(review): keeps S-box output byte 1 regardless of target_byte — confirm intended.
        labels = Sbox[plaintexts ^ keys][:, 1]
    else:
        print("Incorrect SNR type specified.")
        return -1

    peak_snr = np.zeros((grid_size, grid_size))
    label_values = np.unique(labels)

    for i in range(grid_size):
        for j in trange(grid_size):
            groups = {value: [] for value in label_values}
            print(f"Processing grid position ({i}, {j})")
            location_traces = test.get_dataset(f"random_{i}_{j}").read_data(0, num)

            # Bucket each trace under its label
            for idx, label in enumerate(labels):
                groups[label].append(np.array(location_traces[idx]))

            peak_snr[i, j] = np.nanmax(np.abs(signal_to_noise_ratio(groups)))

    # k=3 -> rotate 90 degrees clockwise
    SNR_values_rotated = np.rot90(peak_snr, k=3)

    plt.figure(figsize=(8, 6))
    sns.heatmap(SNR_values_rotated, annot=True, cbar=True, square=True)
    plt.title(f"SNR Heatmap ({SNR_type} mode)")
    # NOTE(review): after rot90(k=3) the plotted row index follows j and the
    # column index follows i in reverse, so these labels look swapped — confirm.
    plt.xlabel("Grid Column (j)")
    plt.ylabel("Grid Row (i)")
    plt.show()

    return SNR_values_rotated, 10 * np.log10(SNR_values_rotated)


def plot_t_statistic_heatmap(test, grid_size=5):
    """
    Draw the peak |t| of a fixed-vs-random t-test at every grid location.

    For each location, ``test.calculate_t_test("fixed_{i}_{j}", "random_{i}_{j}")``
    is called. The largest absolute (NaN-ignoring) value of its first result
    is kept.

    Parameters:
    - test: Object providing ``calculate_t_test(fixed_name, random_name)``,
      which returns a pair whose first element is the t-statistic array.
    - grid_size: Number of grid points per axis (default 5).

    Returns:
    - t_values_rotated: The peak-|t| grid rotated 90 degrees clockwise.
    """
    peak_t = np.zeros((grid_size, grid_size))

    # np.ndindex walks (i, j) in row-major order, i outer and j inner
    for i, j in np.ndindex(grid_size, grid_size):
        t_stat, _ = test.calculate_t_test(f"fixed_{i}_{j}", f"random_{i}_{j}")
        peak_t[i, j] = np.nanmax(np.abs(t_stat))

    t_values_rotated = np.rot90(peak_t, k=3)  # one clockwise quarter turn

    plt.figure(figsize=(8, 6))
    sns.heatmap(t_values_rotated, annot=True, cbar=True, square=True)
    plt.title("Heatmap of t-statistics")
    # NOTE(review): after rot90(k=3) the plotted row index follows j and the
    # column index follows i in reverse, so these labels look swapped — confirm.
    plt.xlabel("Grid Column (j)")
    plt.ylabel("Grid Row (i)")
    plt.show()

    return t_values_rotated


def generate_box_plots(test, pt_exp, num_list, target_byte=0, grid_size=5):
    """
    Draw one box plot of SNR values per trace count.

    For every trace count in ``num_list``, ``plot_SNR_heatmap_byte_bp`` is run
    over the whole grid. All SNR values it returns (every sample at every grid
    location) are flattened into one list. Counts for which it does not return
    a list (e.g. -1 for an invalid SNR type) are reported and skipped.

    Parameters:
    - test: Dataset handler passed through to ``plot_SNR_heatmap_byte_bp``.
    - pt_exp: Dataset handler for keys and plaintexts, passed through.
    - num_list: Trace counts to evaluate.
    - target_byte: Target key byte for the SNR labels.
    - grid_size: Number of grid points per axis (default 5).

    Returns:
    - all_values: One flattened list of SNR values per successfully processed
      trace count, in ``num_list`` order.
    """
    all_values = []  # One flattened list of SNR values per valid trace count
    labels = []  # Matching x-tick labels

    for num in num_list:
        print(f"Processing num_traces = {num}...")

        snr_per_point = plot_SNR_heatmap_byte_bp(test, pt_exp, num, target_byte, grid_size)

        if not isinstance(snr_per_point, list):
            print(f"Invalid data format for num_traces={num}: {snr_per_point}")
            continue

        all_values.append(list(chain.from_iterable(snr_per_point)))
        labels.append(f"{num} traces")

    if not all_values:
        print("No valid data to plot.")
        return all_values

    plt.figure(figsize=(10, 6))
    sns.boxplot(data=all_values)

    # One tick per plotted box. Using len(num_list) here made matplotlib reject
    # the tick/label count mismatch whenever a trace count was skipped above.
    plt.xticks(ticks=range(len(labels)), labels=labels)
    plt.xlabel("Number of Traces")
    plt.ylabel("SNR Values")
    plt.title("Box Plot of SNR Values for Different Trace Counts")

    plt.show()

    return all_values

    

def plot_CEMA_wr(test, pt_exp, num, target_byte=0, x=0, y=0, div=10, correct_key=None):
    """
    Plot peak CPA correlation versus trace count for all 256 key-byte guesses.

    The traces stored at grid location (x, y) are attacked on ``target_byte``
    using growing prefixes of ``div``, ``2*div``, ... traces. The resulting
    curves are drawn with the correct key byte in red, all other guesses in
    grey, and the 4 / sqrt(N) threshold as a dotted line.

    Parameters:
    - test: Dataset handler holding traces named ``random_{x}_{y}``.
    - pt_exp: Dataset handler holding the ``keys`` and ``plaintexts`` datasets.
    - num: Total number of traces to load.
    - target_byte: Key byte index to attack (default 0).
    - x, y: Grid location whose traces are analysed.
    - div: Trace-count step between successive CPA evaluations.
    - correct_key: Key byte value to highlight. If None (default), it is read
      from the first stored key, ``keys[0][target_byte]``. This assumes the key
      is the same for every trace — TODO confirm for your dataset. (The value
      was previously hard-coded to 43; pass ``correct_key=43`` to reproduce that.)

    Returns:
    - maxcpa_matrix: Array of shape ``(num // div, 256)``. Row ``r`` (r >= 1)
      holds each guess's max |correlation| over the first ``r * div`` traces.
      Row 0 is never computed and stays zero.
    """
    keys = pt_exp.get_dataset("keys").read_data(0, num)
    plaintexts = pt_exp.get_dataset("plaintexts").read_data(0, num)

    traces = test.get_dataset(f"random_{x}_{y}").read_data(0, num)

    if correct_key is None and len(keys) > 0:
        correct_key = int(keys[0][target_byte])

    # Iterate only over rows that exist. Stepping i over range(1, num) would
    # index row num // div whenever num is not a multiple of div.
    n_rows = num // div
    maxcpa_matrix = np.zeros((n_rows, 256))

    for row in trange(1, n_rows):
        n_traces = row * div
        prefix = traces[:n_traces]  # same prefix for every key guess
        for k in range(256):
            leakage = leakage_model_hamming_weight(num_traces=n_traces, plaintexts=plaintexts, subkey_guess=k, target_byte=target_byte)
            correlation = pearson_correlation(leakage, prefix)
            maxcpa_matrix[row, k] = np.nanmax(np.abs(correlation))

    print("Shape of maxcpa_matrix:", maxcpa_matrix.shape)

    # Rows 0 and 1 are skipped in the plot; x value r means r * div traces
    xp = np.arange(2, n_rows)

    plt.figure(figsize=(10, 6))

    if n_rows > 2:
        plt.plot(xp, 4 / np.sqrt(xp * div),
                 color="black", linestyle='dotted', linewidth=1.5, label="Threshold")

        wrong_labelled = False  # label only the first grey curve in the legend
        for k in range(256):
            if k == correct_key:
                # zorder keeps the red curve above the grey ones drawn after it
                plt.plot(xp, maxcpa_matrix[2:, k], color="red",
                         alpha=0.9, linewidth=1.5, label="Correct key", zorder=3)
            else:
                plt.plot(xp, maxcpa_matrix[2:, k], color="grey",
                         alpha=0.1, linewidth=0.5, label="" if wrong_labelled else "Wrong keys")
                wrong_labelled = True

        plt.legend()

    plt.xlabel(f"No. of traces × {div}", fontsize=14)
    plt.ylabel("Max CPA Value", fontsize=14)
    plt.title(f"Correlation Power Analysis: ({x}, {y})", fontsize=18)
    plt.yticks(fontsize=12)
    plt.xticks(fontsize=12)

    plt.show()

    return maxcpa_matrix



## helper functions  

In [None]:
def scapegoat_cpa(experiment):
    """
    Run a CPA on all 16 key bytes of a captured experiment.

    For each key byte and each of the 256 candidate values, the hypothetical
    leakage from ``leakage_model_hamming_weight`` (defined elsewhere) is
    correlated with the traces via ``pearson_correlation``. The candidate with
    the highest peak |correlation| is kept.

    Parameters:
    - experiment: Object holding the ``CW_Capture_Traces`` and
      ``CW_Capture_Plaintexts`` datasets.

    Returns:
    - best_guess: int array of 16 best key-byte guesses (0 for a byte whose
      scores are all NaN).
    - cpa_refs: float array of the corresponding peak correlations (NaN when
      every score for that byte is NaN).
    """
    num_bytes = 16  # AES-128 key length in bytes
    cpa_refs = np.zeros(num_bytes)
    best_guess = np.zeros(num_bytes, dtype=int)

    # The key dataset is not needed by the attack, so it is not loaded
    traces = experiment.get_dataset("CW_Capture_Traces").read_all()
    plaintexts = experiment.get_dataset("CW_Capture_Plaintexts").read_all()
    n_traces = len(plaintexts)

    max_cpa = np.zeros(256)  # per-candidate score, overwritten for every byte
    for byte_idx in trange(num_bytes, desc="CPA on key bytes"):
        for k in range(256):
            leakage = leakage_model_hamming_weight(
                num_traces=n_traces,
                plaintexts=plaintexts,
                subkey_guess=k,
                target_byte=byte_idx
            )
            correlation = pearson_correlation(leakage, traces)
            max_cpa[k] = np.nanmax(np.abs(correlation))

        # np.argmax returns the first NaN's index, disagreeing with nanmax.
        # nanargmax raises on all-NaN input, so that case keeps (0, nan).
        if np.isnan(max_cpa).all():
            best_guess[byte_idx] = 0
            cpa_refs[byte_idx] = np.nan
        else:
            best_guess[byte_idx] = np.nanargmax(max_cpa)
            cpa_refs[byte_idx] = np.nanmax(max_cpa)

    return best_guess, cpa_refs

def scapegoat_cpa_byte(traces, keys, plaintexts, target_byte):
    """
    Run a CPA on one key byte and return the best candidate.

    For each of the 256 candidate values, the hypothetical leakage from
    ``leakage_model_hamming_weight`` (defined elsewhere) is correlated with
    ``traces`` via ``pearson_correlation``, and the peak |correlation| is kept.

    Parameters:
    - traces: Captured traces, one per plaintext.
    - keys: Accepted for interface compatibility; not used by the attack.
    - plaintexts: Plaintexts matching ``traces``.
    - target_byte: Index of the key byte under attack.

    Returns:
    - best_guess: Candidate with the highest peak |correlation|, ignoring NaN
      scores; 0 if every score is NaN.
    - cpa_ref: That highest peak |correlation| (NaN if every score is NaN).
    """
    n_traces = len(plaintexts)
    max_cpa = np.zeros(256)  # peak |correlation| per candidate

    for k in range(256):
        leakage = leakage_model_hamming_weight(
            num_traces=n_traces,
            plaintexts=plaintexts,
            subkey_guess=k,
            target_byte=target_byte
        )
        correlation = pearson_correlation(leakage, traces)
        max_cpa[k] = np.nanmax(np.abs(correlation))

    # np.argmax returns the first NaN's index, disagreeing with nanmax.
    # nanargmax raises on all-NaN input, so that case keeps (0, nan).
    if np.isnan(max_cpa).all():
        return 0, np.nan
    return np.nanargmax(max_cpa), np.nanmax(max_cpa)

def leakage_model_hamming_weight_snr(num_traces: int, plaintexts: list | np.ndarray, subkey_guess: object, target_byte: int) -> np.ndarray:
    """
    Label each trace with the Hamming weight of its S-box output byte.

    For trace ``i`` the label is
    ``HW(Sbox[subkey_guess[i][target_byte] ^ plaintexts[i][target_byte]])``.
    Each trace's own key row is used, rather than a single key hypothesis.
    ``Sbox`` is a module-level lookup table defined elsewhere (presumably the
    AES S-box — confirm). The labels are used to group traces for SNR.

    :param num_traces: Number of traces to label (the first ``num_traces`` rows are used)
    :type num_traces: int
    :param plaintexts: Per-trace plaintexts, indexable as ``plaintexts[i][byte]``
    :type plaintexts: list | np.ndarray
    :param subkey_guess: Per-trace keys, indexable as ``subkey_guess[i][byte]``
    :type subkey_guess: object
    :param target_byte: Byte index within the key/plaintext
    :type target_byte: int
    :return: 1-D object array of ``num_traces`` Hamming-weight labels (ints 0-8)
    :rtype: np.ndarray
    :Authors: Samuel Karkache (swkarkache@wpi.edu)
    """
    leakage = np.empty(num_traces, dtype=object)

    for i in range(num_traces):
        # Count the set bits of the S-box output
        leakage[i] = bin(Sbox[subkey_guess[i][target_byte] ^ plaintexts[i][target_byte]]).count('1')

    return leakage

def sbox_snr(num_traces: int, plaintexts: list | np.ndarray, subkey_guess: object, target_byte: int) -> np.ndarray:
    """
    Label each trace with its S-box output byte value.

    For trace ``i`` the label is
    ``Sbox[subkey_guess[i][target_byte] ^ plaintexts[i][target_byte]]``.
    Each trace's own key row is used, rather than a single key hypothesis.
    ``Sbox`` is a module-level lookup table defined elsewhere (presumably the
    AES S-box — confirm). The labels are used to group traces for SNR.

    :param num_traces: Number of traces to label (the first ``num_traces`` rows are used)
    :type num_traces: int
    :param plaintexts: Per-trace plaintexts, indexable as ``plaintexts[i][byte]``
    :type plaintexts: list | np.ndarray
    :param subkey_guess: Per-trace keys, indexable as ``subkey_guess[i][byte]``
    :type subkey_guess: object
    :param target_byte: Byte index within the key/plaintext
    :type target_byte: int
    :return: 1-D object array of ``num_traces`` S-box output labels
    :rtype: np.ndarray
    :Authors: Samuel Karkache (swkarkache@wpi.edu)
    """
    leakage = np.empty(num_traces, dtype=object)

    for i in range(num_traces):
        leakage[i] = Sbox[subkey_guess[i][target_byte] ^ plaintexts[i][target_byte]]

    return leakage

def plot_heatmap(heatmap_values, text, anno=False, cba=True, squar=True):
    """
    Show a 2-D array as a seaborn heatmap on a new 8x6 figure.

    Parameters:
    - heatmap_values: 2-D array-like of values to draw.
    - text: Figure title.
    - anno: Write each cell's value into the cell (default False).
    - cba: Draw the colour bar (default True).
    - squar: Force square cells (default True).

    Returns:
    - None
    """
    heatmap_options = {"annot": anno, "cbar": cba, "square": squar}

    plt.figure(figsize=(8, 6))
    sns.heatmap(heatmap_values, **heatmap_options)

    plt.title(text)
    plt.xlabel("X")
    plt.ylabel("Y")

    plt.show()

    
def save_mat(matr, file_name):
    """
    Write ``matr`` to a MATLAB ``.mat`` file under the variable name ``matrix``.

    Parameters:
    - matr: Array-like to store.
    - file_name: Destination path (e.g. ``"result.mat"``).

    Returns:
    - None
    """
    payload = {'matrix': matr}
    scipy.io.savemat(file_name, payload)


def save_fig(filename):
    """
    Save the current matplotlib figure to ``filename`` in SVG format.

    The output is always SVG whatever extension ``filename`` has. The bounding
    box is trimmed to the drawn content (``bbox_inches="tight"``).

    Parameters:
    - filename: Destination path, normally ending in ``.svg``.

    Returns:
    - None
    """
    current_figure = plt.gcf()
    current_figure.savefig(filename, format="svg", bbox_inches="tight")



def plot_SNR_heatmap_byte_bp(test, pt_exp, num, target_byte=0, grid_size=5, SNR_type="BYTE"):
    """
    Compute the full SNR result at every grid location, without plotting.

    Despite the name, nothing is drawn. The raw per-location outputs of
    ``signal_to_noise_ratio`` are returned (``generate_box_plots`` flattens
    them into box plots).

    Parameters:
    - test: Dataset handler holding traces named ``random_{i}_{j}``.
    - pt_exp: Dataset handler holding the ``keys`` and ``plaintexts`` datasets.
    - num: Number of traces to use per location.
    - target_byte: Key byte index used by the "BYTE" and "HW" labels (default 0).
    - grid_size: Number of grid points per axis (default 5).
    - SNR_type: "BYTE" (S-box output), "HW" (its Hamming weight) or "FULL"
      (S-box of the whole ``plaintexts ^ keys`` array, column 1).

    Returns:
    - List of ``signal_to_noise_ratio`` results, one per grid location in
      row-major (i, j) order; or -1 if ``SNR_type`` is not recognised.
    """
    keys = pt_exp.get_dataset("keys").read_data(0, num)
    plaintexts = pt_exp.get_dataset("plaintexts").read_data(0, num)

    snr_results = []

    if SNR_type in ("BYTE", "HW"):
        make_labels = sbox_snr if SNR_type == "BYTE" else leakage_model_hamming_weight_snr
        labels = make_labels(num_traces=num, plaintexts=plaintexts, subkey_guess=keys, target_byte=target_byte)
    elif SNR_type == "FULL":
        # NOTE(review): keeps S-box output byte 1 regardless of target_byte — confirm intended.
        labels = Sbox[plaintexts ^ keys][:, 1]
    else:
        print("Incorrect SNR type")
        return -1

    distinct_labels = np.unique(labels)

    def _snr_at(i, j):
        # Group the traces of location (i, j) by label and compute their SNR
        buckets = {value: [] for value in distinct_labels}
        print(f"loop_{i}_{j}")
        location_traces = test.get_dataset(f"random_{i}_{j}").read_data(0, num)
        for idx, label in enumerate(labels):
            buckets[label].append(np.array(location_traces[idx]))
        return signal_to_noise_ratio(buckets)

    for i in range(grid_size):
        for j in trange(grid_size):
            snr_results.append(_snr_at(i, j))

    return snr_results


# deprecated functions


In [None]:
      
        
# def Grid_Tracing(X_range,Y_range,X_number_of_step,Y_number_of_step,X,Y,Z,interface,sco,number_of_traces):
#     cordinate_traces={}
#     X_moment=0
#     Y_moment=0
#     traces = [capture_nopt(sco,num_of_samples=500) for i in trange(number_of_traces)]
#     cordinate = (X_moment, Y_moment)
#     cordinate_traces[cordinate] = traces
#     X_start_position=X.get_actual_position()
#     Y_start_position=Y.get_actual_position()
#     print(f"Starting Postion - ({X_start_position},{Y_start_position})")
#     while Y_moment<=Y_number_of_step:
#         X_initial_position=X.get_actual_position()
#         Y_initial_position=Y.get_actual_position()

#         for _ in range(X_number_of_step):
#             X.move_by(X_range)
#             print(f'Moving X to {X_initial_position + X_range}')

#             while X.get_actual_position() != X_initial_position + X_range:


#                 time.sleep(0.1)
#             traces = [capture_nopt(sco,num_of_samples=500) for i in trange(number_of_traces)]
#             X_moment += 1
#             cordinate = (X_moment, Y_moment)
#             cordinate_traces[cordinate] = traces
#             X_initial_position=X.get_actual_position()

#         if Y_moment==Y_number_of_step:
#             break
#         Y.move_by(Y_range)
#         Y_moment+=1
#         print(f'Moving Y to {Y_initial_position + Y_range}')
#         while Y.get_actual_position() != Y_initial_position + Y_range:

#             time.sleep(0.1)
#         cordinate = (X_moment, Y_moment)
#         cordinate_traces[cordinate] = [capture_nopt(sco,num_of_samples=500) for i in trange(number_of_traces)]
#         Y_initial_position=Y.get_actual_position()

#         for _ in range(X_number_of_step):
#             X.move_by(-X_range)
#             print(f'Moving X to {X_initial_position - X_range}')
#             while X.get_actual_position() != X_initial_position - X_range:

#                 time.sleep(0.1)
#             traces = [capture_nopt(sco,num_of_samples=500) for i in trange(number_of_traces)]
#             X_moment -= 1
#             cordinate = (X_moment, Y_moment)
#             cordinate_traces[cordinate] = traces
#             X_initial_position=X.get_actual_position()
#         if Y_moment==Y_number_of_step:
#             break
#         # Move down 1 step
#         Y.move_by(Y_range)
#         Y_moment += 1
#         print(f'Moving Y to {Y_initial_position + Y_range}')
#         while Y.get_actual_position() != Y_initial_position + Y_range:
#             time.sleep(0.1)

#         cordinate = (X_moment, Y_moment)
#         cordinate_traces[cordinate] = [capture_nopt(sco,num_of_samples=500) for i in trange(number_of_traces)]
#         Y_initial_position=Y.get_actual_position()
    
#     X.move_to(X_start_position)
#     while X.get_actual_position() != X_start_position:
#         time.sleep(0.1)
#     Y.move_to(Y_start_position)
#     while Y.get_actual_position() != Y_start_position:
#         time.sleep(0.1)
#     print(f"Final Postion - ({X.get_actual_position()},{Y.get_actual_position()})")

#     return cordinate_traces


def _snr_heatmap(test, labels, num, grid_size, title):
    """
    Compute the peak |SNR| at every grid point and plot it as a heatmap.

    Shared implementation of ``plot_SNR_heatmap_byte`` and
    ``plot_SNR_heatmap_hw_byte``. The two public functions differ only in how
    ``labels`` is computed and in the plot title.

    Parameters:
    - test: Dataset handler; the traces for grid point (i, j) are read with
      ``test.get_dataset(f"random_{i}_{j}").read_data(0, num)``.
    - labels: One class label per trace, used to group the traces of each
      grid point before calling ``signal_to_noise_ratio``.
    - num: Number of traces to read per grid point.
    - grid_size: The grid is ``grid_size`` x ``grid_size`` points.
    - title: Title of the heatmap.

    Returns:
    - CEMA_values_rotated: ``grid_size`` x ``grid_size`` matrix of peak |SNR|
      values, rotated 90 degrees clockwise.
    - Log-transformed values: ``10 * log10(CEMA_values_rotated)``. Zero
      entries map to ``-inf``.
    """
    # The labels are the same for every grid point, so compute once which
    # trace indices belong to each label (in ascending order) instead of
    # re-scanning all labels at every grid point.
    label_indices = {label: [] for label in np.unique(labels)}
    for index, label in enumerate(labels):
        label_indices[label].append(index)

    CEMA_values = np.zeros((grid_size, grid_size))

    for i in range(grid_size):
        for j in trange(grid_size):
            print(f"Processing grid point ({i}, {j})")

            traces = test.get_dataset(f"random_{i}_{j}").read_data(0, num)

            # Group this grid point's traces by label; keys are in sorted
            # label order and each list keeps the original trace order.
            sorted_labels = {
                label: [np.array(traces[index]) for index in indices]
                for label, indices in label_indices.items()
            }

            snr_value = signal_to_noise_ratio(sorted_labels)

            # Keep only the strongest |SNR| sample for this grid point.
            CEMA_values[i, j] = np.nanmax(np.abs(snr_value))

    # Rotate 90 degrees clockwise: rotated[r, c] == CEMA_values[n - 1 - c, r].
    # Heatmap rows therefore follow j (ascending) and columns follow i
    # (descending); the tick labels and axis names below reflect that.
    CEMA_values_rotated = np.rot90(CEMA_values, k=3)

    plt.figure(figsize=(8, 6))
    sns.heatmap(
        CEMA_values_rotated,
        annot=True,
        cbar=True,
        square=True,
        xticklabels=list(range(grid_size - 1, -1, -1)),
        yticklabels=list(range(grid_size)),
    )
    plt.title(title)
    plt.xlabel("i")
    plt.ylabel("j")
    plt.show()

    # log10(0) is -inf by design here; silence the divide-by-zero warning.
    with np.errstate(divide="ignore"):
        CEMA_values_db = 10 * np.log10(CEMA_values_rotated)

    return CEMA_values_rotated, CEMA_values_db


def plot_SNR_heatmap_byte(test, pt_exp, num, target_byte=0, grid_size=5):
    """
    Calculate Signal-to-Noise Ratio (SNR) for a grid and plot the results as a heatmap.

    Traces are labelled with ``sbox_snr`` for the chosen target byte.

    Parameters:
    - test: Dataset handler; grid point (i, j) is read from dataset ``random_{i}_{j}``.
    - pt_exp: Experiment data handler containing the "keys" and "plaintexts" datasets.
    - num: Number of traces to consider.
    - target_byte: The target byte of the AES S-box (default is 0).
    - grid_size: The size of the grid (default is 5x5).

    Returns:
    - CEMA_values_rotated: The matrix of peak |SNR| values, rotated 90 degrees clockwise.
    - Log-transformed CEMA values: ``10 * log10`` of the rotated matrix (zeros become -inf).
    """
    keys = pt_exp.get_dataset("keys").read_data(0, num)
    plaintexts = pt_exp.get_dataset("plaintexts").read_data(0, num)

    labels = sbox_snr(num_traces=num, plaintexts=plaintexts, subkey_guess=keys, target_byte=target_byte)

    return _snr_heatmap(test, labels, num, grid_size, "Heatmap of SNR Byte")


def plot_SNR_heatmap_hw_byte(test, pt_exp, num, target_byte=0, grid_size=5):
    """
    Calculate the Hamming Weight-based SNR for a grid and plot the results as a heatmap.

    Traces are labelled with ``leakage_model_hamming_weight_snr`` for the
    chosen target byte.

    Parameters:
    - test: Dataset handler; grid point (i, j) is read from dataset ``random_{i}_{j}``.
    - pt_exp: Experiment data handler containing the "keys" and "plaintexts" datasets.
    - num: Number of traces to consider.
    - target_byte: The target byte of the AES S-box (default is 0).
    - grid_size: The size of the grid (default is 5x5).

    Returns:
    - CEMA_values_rotated: The matrix of peak |SNR| values, rotated 90 degrees clockwise.
    - Log-transformed CEMA values: ``10 * log10`` of the rotated matrix (zeros become -inf).
    """
    keys = pt_exp.get_dataset("keys").read_data(0, num)
    plaintexts = pt_exp.get_dataset("plaintexts").read_data(0, num)

    labels = leakage_model_hamming_weight_snr(num_traces=num, plaintexts=plaintexts, subkey_guess=keys, target_byte=target_byte)

    return _snr_heatmap(test, labels, num, grid_size, "Heatmap of SNR Byte HW")


