In [5]:
import pennylane as qml
import jax

from jax import numpy as np
from qiskit.circuit.library import *
from qiskit import *
from qiskit.quantum_info import *
import matplotlib.pyplot as plt

from pennylane.wires import Wires
from functools import partial
from multiprocessing import Pool, cpu_count
from pathlib import Path
import pandas as pd
import pickle
import base64

import optax
from jax import config
import os
import jax
import time
import jax.numpy as jnp
from jax import jit, value_and_grad, vmap
import pennylane.numpy as pnp
#os.environ['OPENBLAS_NUM_THREADS'] = '1'
has_jax = True

# Show full (unfiltered) JAX tracebacks.
os.environ['JAX_TRACEBACK_FILTERING'] = 'off'

# Enable 64-bit floating point support in JAX.
config.update("jax_enable_x64", True)

# JIT compilation is switched off globally through the flag below.
diable_jit = True
config.update('jax_disable_jit', diable_jit)
#config.parse_flags_with_absl()

In [3]:
import os
from pathlib import Path
import pickle
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import iqr
def is_valid_pickle_file(file_path):
    """Check whether ``file_path`` is an existing, non-empty file that unpickles cleanly.

    Args:
        file_path: ``pathlib.Path`` pointing at the candidate pickle file.

    Returns:
        bool: ``True`` only if the file exists, has a non-zero size and
        ``pickle.load`` succeeds on it; ``False`` in every other case
        (missing file, empty file, truncated/corrupted data, or any other
        error raised while probing the file).
    """
    try:
        # A missing or zero-byte file can never be a valid pickle.
        if not (file_path.exists() and file_path.stat().st_size > 0):
            return False
        with open(file_path, 'rb') as f:
            # NOTE: pickle.load can execute arbitrary code; only probe trusted files.
            pickle.load(f)
        return True
    except (EOFError, pickle.UnpicklingError):
        # Truncated stream or bytes that are not a pickle at all.
        print(f"File {file_path} is corrupted.")
        return False
    except Exception as e:
        print(f"An error occurred: {e}")
        return False
#print("all_gradients shape:", all_gradients.shape)
def calculate_gradient_stats(gradients):
    """Return the mean and population variance of a batch of gradients.

    Both statistics are reduced over axis 0 of ``gradients``.

    Args:
        gradients: Array-like of gradient samples; axis 0 is the axis that is
            averaged over.

    Returns:
        tuple: ``(mean_grad, var_grad)`` where ``var_grad`` is the mean squared
        deviation from ``mean_grad`` (no Bessel correction).
    """
    grad_matrix = jnp.asarray(gradients)
    mean_grad = jnp.mean(grad_matrix, axis=0)
    # Two-pass form E[(x - mean)^2]. The one-pass shortcut E[x^2] - E[x]^2
    # subtracts two nearly equal numbers when the spread is small relative to
    # the mean, which loses precision and can even yield negative variances.
    var_grad = jnp.mean((grad_matrix - mean_grad) ** 2, axis=0)
    return mean_grad, var_grad

def calculate_inv_qfim(eigen_vals, eigen_vecs, n_params):
    """Assemble ``V @ diag(1/lambda) @ V^dagger`` with small eigenvalues cut off.

    The first ``n_params`` entries of ``eigen_vals`` are inverted; any entry
    below ``1e-12`` contributes zero instead of its reciprocal.

    Args:
        eigen_vals: Indexable sequence of eigenvalues.
        eigen_vecs: Matrix whose conjugate transpose is used on the right-hand side.
        n_params: Number of eigenvalues to process.

    Returns:
        The matrix product ``eigen_vecs . diag(inverted) . conj(eigen_vecs)^T``.
    """
    cutoff_eigvals = 10**-12
    # Start from all zeros so that cut-off eigenvalues need no explicit write.
    eigvals_inv = np.zeros(n_params)
    for idx in range(n_params):
        current = eigen_vals[idx]
        if not (current < cutoff_eigvals):
            eigvals_inv[idx] = 1 / current

    vecs_dagger = np.transpose(np.conjugate(eigen_vecs))
    scaled_dagger = np.dot(np.diag(eigvals_inv), vecs_dagger)
    qfi_inv_matrix = np.dot(eigen_vecs, scaled_dagger)
    return qfi_inv_matrix

def calculate_gradient_variance(gradients):
    """Population variance of ``gradients`` reduced over axis 0.

    Args:
        gradients: Array-like of gradient samples; converted with ``jnp.array``.

    Returns:
        The mean squared deviation from the axis-0 mean.
    """
    samples = jnp.array(gradients)
    deviations = samples - jnp.mean(samples, axis=0)
    return jnp.mean(deviations ** 2, axis=0)

def process_data_combined(df, threshold, by_test, N_R, trot, print_bool, weight_median=0.5, weight_iqr=0.5):
    """Aggregate QFIM eigenvalue statistics over every entry of a nested result mapping.

    The function does not read any file itself; it expects the already loaded
    object and walks ``df[fixed_params_key][test_key]``.

    Args:
        df: Nested mapping whose leaves provide ``'qfim_eigvals'`` (array-like
            of eigenvalues) and ``'entropies'``.
        threshold: Eigenvalues strictly greater than this value count as non-zero.
        by_test: Unused; kept so existing callers keep working.
        N_R: Unused; kept so existing callers keep working.
        trot: Unused; kept so existing callers keep working.
        print_bool: Unused; kept so existing callers keep working.
        weight_median: Weight of the median in the ``weighted_avg_*`` outputs.
        weight_iqr: Weight of the percentile spread in the ``weighted_avg_*`` outputs.

    Returns:
        dict: Summary statistics (means, medians, percentile spreads and
        weighted combinations) plus ``'all_qfim_eigvals'``, the list of the
        eigenvalue collections exactly as found in ``df``.
    """
    trace_eigvals = []
    norm_trace_eigvals = []
    var_eigval = []
    var_log_eigval = []
    ranks = []
    ratios = []
    redundancies = []
    entropies = []
    qfim_eigval_list = []

    for fixed_params_dict in df.keys():
        for test in df[fixed_params_dict].keys():
            entry = df[fixed_params_dict][test]
            raw_eigvals = entry['qfim_eigvals']
            qfim_eigval_list.append(raw_eigvals)

            # Coerce to an ndarray so boolean masking also works for plain lists.
            qfim_eigvals = np.asarray(raw_eigvals)
            n_params = len(qfim_eigvals)
            nonzero_eigvals = qfim_eigvals[qfim_eigvals > threshold]
            num_nonzero_eigvals = len(nonzero_eigvals)
            has_nonzero = num_nonzero_eigvals > 0
            trace = np.sum(qfim_eigvals)

            ranks.append(num_nonzero_eigvals)
            ratios.append(num_nonzero_eigvals / n_params)
            redundancies.append((n_params - num_nonzero_eigvals) / n_params)
            trace_eigvals.append(trace)
            # Statistics over the non-zero spectrum are undefined when nothing
            # exceeds the threshold; record NaN instead of dividing by zero.
            norm_trace_eigvals.append(trace / num_nonzero_eigvals if has_nonzero else np.nan)
            var_eigval.append(np.var(nonzero_eigvals) if has_nonzero else np.nan)
            var_log_eigval.append(np.var(np.log10(nonzero_eigvals)) if has_nonzero else np.nan)
            entropies.append(entry['entropies'])

    median_trace = np.median(trace_eigvals)
    median_var_eigval = np.median(var_eigval)
    median_var_log_eigval = np.median(var_log_eigval)

    # Despite the "iqr" naming, the spread is taken between the 10th and 90th
    # percentiles (rng=(10, 90)), not the 25th and 75th.
    iqr_trace = iqr(trace_eigvals, rng=(10, 90))
    iqr_var_eigval = iqr(var_eigval, rng=(10, 90))
    iqr_var_log_eigval = iqr(var_log_eigval, rng=(10, 90))

    weighted_avg_trace = weight_median * median_trace + weight_iqr * iqr_trace
    weighted_avg_var_eigval = weight_median * median_var_eigval + weight_iqr * iqr_var_eigval
    weighted_avg_var_log_eigval = weight_median * median_var_log_eigval + weight_iqr * iqr_var_log_eigval

    return {
        'mean_trace_eigvals': np.mean(trace_eigvals),
        'mean_entropy': np.mean(entropies),
        'quantum_dim': np.mean(ranks),
        'ratios': np.mean(ratios),
        'redundancies': np.mean(redundancies),
        'mean_norm_trace_eigvals': np.mean(norm_trace_eigvals),
        'median_trace_eigvals': median_trace,
        'iqr_trace_eigvals': iqr_trace,
        'weighted_avg_trace_eigvals': weighted_avg_trace,
        'mean_var_eigval': np.mean(var_eigval),
        'median_var_eigval': median_var_eigval,
        'iqr_var_eigval': iqr_var_eigval,
        'weighted_avg_var_eigval': weighted_avg_var_eigval,
        'mean_var_log_eigval': np.mean(var_log_eigval),
        'median_var_log_eigval': median_var_log_eigval,
        'iqr_var_log_eigval': iqr_var_log_eigval,
        'weighted_avg_var_log_eigval': weighted_avg_var_log_eigval,
        'all_qfim_eigvals': qfim_eigval_list,
    }

# --- Analysis settings ---
N_ctrls = [1, 2]
num_states_sampled = 50
by_test = False
threshold = 1e-14

# --- Data location ---
# NOTE(review): absolute, machine-specific path; presumably this should be configurable.
base_path = Path('/Users/sophieblock/QRCCapstone/')
base_state = 'GHZ_state/1xK'
model_type = 'gate_model_DQFIM'

# Collects one result dictionary per processed data file.
all_data = []

# Function to check if a file is a valid pickle file (you might have this already)
def is_valid_pickle_file(file_path):
    """Cheap pre-check that ``file_path`` could hold a pickle, without loading it.

    Note: this redefinition replaces the earlier ``is_valid_pickle_file`` in
    this cell, which actually unpickles the file to validate it.

    Args:
        file_path: ``pathlib.Path`` of the candidate file.

    Returns:
        bool: ``True`` if the path is an existing regular file with non-zero size.
    """
    # is_file() is already False for a missing path, and a zero-byte file can
    # never be unpickled (pickle.load raises EOFError on it).
    return file_path.is_file() and file_path.stat().st_size > 0
 

In [None]:
import os
import pickle
import pandas as pd
from pathlib import Path

# Analysis settings
by_test = False
threshold = 1e-14
num_states_sampled = 50
N_ctrls = [1, 2]

# Where the results live on disk.
# NOTE(review): absolute, machine-specific path; presumably this should be configurable.
base_path = Path('/Users/sophieblock/QRCCapstone/')
model_type = 'gate_model_DQFIM'
base_state = 'GHZ_state/1xK'  # the directory structure contains a '1xK' level

# Filled with one result dictionary per processed data file.
all_data = []

# Function to check if a file is a valid pickle file
def is_valid_pickle_file(file_path):
    """Cheap pre-check that ``file_path`` could hold a pickle, without loading it.

    Args:
        file_path: ``pathlib.Path`` of the candidate file.

    Returns:
        bool: ``True`` if the path is an existing regular file with non-zero size.
    """
    # is_file() is already False for a missing path, and a zero-byte file can
    # never be unpickled (pickle.load raises EOFError on it).
    return file_path.is_file() and file_path.stat().st_size > 0

# Function to check and skip hidden system files like .DS_Store
def is_hidden_file(file_name):
    """Tell whether ``file_name`` begins with a dot (e.g. ``.DS_Store``)."""
    hidden_prefix = '.'
    return file_name.startswith(hidden_prefix)

# Walk <base>/QFIM_traced_final_results/<model>/Nc_<N_ctrl>/L_<L>/1xK/<Nr dir>/<trotter dir>/L_<L>/
# and turn every data_pi_range.pickle found there into one summary record.
for N_ctrl in N_ctrls:

    # Model path includes the '1xK' folder after the 'L_{num_states_sampled}'
    model_path = base_path / 'QFIM_traced_final_results' / model_type / f'Nc_{N_ctrl}' / f'L_{num_states_sampled}' / '1xK'

    print(f"Processing model path: {model_path}")

    if not model_path.exists():
        print(f"Model path not found: {model_path}")
        continue

    for Nr in sorted(os.listdir(model_path)):
        if is_hidden_file(Nr):  # Skip hidden files
            continue
        Nr_path = model_path / Nr
        if not Nr_path.is_dir():
            print(f"Nr path not a directory: {Nr_path}")
            continue

        for trotter_step in sorted(os.listdir(Nr_path)):
            if is_hidden_file(trotter_step):  # Skip hidden files
                continue
            trotter_step_path = Nr_path / trotter_step
            if not trotter_step_path.is_dir():
                print(f"Trotter step path not a directory: {trotter_step_path}")
                continue

            # The data file sits in a nested L_{num_states_sampled} folder inside each trotter step.
            folder_gate = trotter_step_path / f'L_{num_states_sampled}'
            data_file = folder_gate / 'data_pi_range.pickle'

            if not is_valid_pickle_file(data_file):
                print(f"Data file not found or invalid: {data_file}")
                continue

            print(f"Reading data from: {data_file}")
            try:
                # NOTE: pickle.load can execute arbitrary code; only read trusted result files.
                with open(data_file, 'rb') as f:
                    df = pickle.load(f)
            except (EOFError, pickle.UnpicklingError) as load_error:
                # A truncated or corrupted file should not abort the whole scan.
                print(f"Data file not found or invalid: {data_file} ({load_error})")
                continue

            # Directory names are expected to end in '_<integer>'; skip anything else
            # instead of crashing on an unrelated directory.
            try:
                n_reserv = int(Nr.split('_')[-1])
                trotter_step_num = int(trotter_step.split('_')[-1])
            except ValueError:
                print(f"Skipping {data_file}: cannot parse numeric suffix from '{Nr}' / '{trotter_step}'")
                continue

            processed_data = process_data_combined(df, threshold, by_test, n_reserv, trotter_step_num, False)
            processed_data.update({
                'N_ctrl': N_ctrl,
                'N_reserv': n_reserv,
                'Trotter_Step': trotter_step_num
            })
            all_data.append(processed_data)

# Convert the list of dictionaries to a DataFrame
if all_data:
    df_all = pd.DataFrame(all_data)
    print(df_all)
else:
    print("No data found.")


In [None]:
threshold = 1e-12
# Exponent of the power-law rescaling; adjust between 0 and 1 depending on how
# much scaling is wanted (0.5 would be a square root, 0.75 is milder).
alpha = 0.75

# Zero out every eigenvalue below the threshold, then average over axis 0 of
# the stored eigenvalue collections.
df_all['avg_qfim_eigvals'] = df_all['all_qfim_eigvals'].apply(
    lambda eigval_sets: np.where(
        np.array(eigval_sets) < threshold, 0, np.array(eigval_sets)
    ).mean(axis=0)
)

# Power-law transformed averages: once for all entries, once for the entries
# that exceed the threshold.
df_all['power_scaled_avg_qfim_eigvals'] = df_all['avg_qfim_eigvals'].apply(
    lambda mean_vals: np.power(np.array(mean_vals), alpha)
)
df_all['power_scaled_avg_qfim_eigvals_nonzero'] = df_all['avg_qfim_eigvals'].apply(
    lambda mean_vals: np.power(
        np.array([entry for entry in mean_vals if entry > threshold]), alpha
    )
)

# Variance of each power-scaled spectrum, stored in its own column.
df_all['var_power_scaled_avg_qfim_eigvals'] = df_all['power_scaled_avg_qfim_eigvals'].apply(
    lambda spectrum: np.var(spectrum)
)
df_all['var_power_scaled_avg_qfim_eigvals_nonzero'] = df_all['power_scaled_avg_qfim_eigvals_nonzero'].apply(
    lambda spectrum: np.var(spectrum)
)
df_all.head()

In [None]:
# Keep only the rows with N_reserv == 3 and N_ctrl == 2.
df_res3 = df_all[(df_all['N_reserv'] == 3) & (df_all['N_ctrl'] == 2)]
df_res3.head()

In [None]:
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Draw one annotated heatmap (Trotter step vs. N_reserv) per metric for a single N_ctrl.
N_ctrl = 2

metric_keys = ['weighted_avg_trace_eigvals', 'mean_entropy',  'redundancies','var_power_scaled_avg_qfim_eigvals', 'var_power_scaled_avg_qfim_eigvals_nonzero']
# Title and colour-bar label for each metric. The first title is a raw
# f-string so that the LaTeX command \mathcal is not read as a (invalid)
# Python escape sequence.
metrics_info = {
    'weighted_avg_trace_eigvals': {
        'title': rf'Weighted Avg ${{Tr}}[\mathcal{{Q}}]$ [$N_C = {N_ctrl}$]',
        'cbar': ''
    },
    'quantum_dim': {
        'title': f'$G_C$ (Rank) [$N_C = {N_ctrl}$]',
        'cbar': ''
    },
    'mean_entropy': {
        'title': f'VN Entropy [$N_C = {N_ctrl}$]',
        'cbar': ''
    },
    'redundancies': {
        'title': '$R$ (Redundancy)',
        'cbar': ''
    },
    'var_power_scaled_avg_qfim_eigvals': {
        'title': f'Variance of QFIM Eigenvalues Power Scaled [$N_C = {N_ctrl}$] (iqr)',
        'cbar': ''
    },
    'var_power_scaled_avg_qfim_eigvals_nonzero': {
        'title': f'Variance of QFIM Eigenvalues (Non-Zero) Power Scaled [$N_C = {N_ctrl}$]',
        'cbar': ''
    }
}

# Only referenced by the commented-out N_reserv filter inside the loop.
resies = [1, 2, 3, 4, 5, 6]

# Trotter steps to show; depends only on N_ctrl, so it is computed once.
if N_ctrl == 1:
    trots = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
else:
    trots = np.arange(1, 45, 1)

for metric_key in metric_keys:
    # Retrieve title and color bar label from the dictionary
    metric_title = metrics_info[metric_key]['title']
    metric_cbar = metrics_info[metric_key]['cbar']

    # Rows for the current N_ctrl, restricted to the columns needed for the heatmap
    df_filtered = df_all.loc[df_all['N_ctrl'] == N_ctrl, ['N_reserv', 'Trotter_Step', metric_key]]
    df_filtered = df_filtered[df_filtered['Trotter_Step'].isin(trots)]
    # df_filtered = df_filtered[df_filtered['N_reserv'].isin(resies)]

    # Average duplicates per (Trotter_Step, N_reserv) so the pivot below is well defined
    df_heatmap = df_filtered.groupby(['Trotter_Step', 'N_reserv']).agg({metric_key: 'mean'}).reset_index()
    heatmap_pivot = df_heatmap.pivot(index='Trotter_Step', columns='N_reserv', values=metric_key)

    fig, ax = plt.subplots(figsize=(16, 10))

    # cbar=False: the colour bar is added manually below so its size can be controlled.
    # xticklabels=True: force exactly one tick per column so the relabelling below lines up.
    sns.heatmap(heatmap_pivot, ax=ax, cmap='magma', annot=True, fmt=".2f",
                annot_kws={"size": 22, "weight": 'bold'}, cbar=False, xticklabels=True)

    # Narrow colour bar attached to the right edge of the heatmap axes
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="3%", pad=0.05)  # Adjust pad for proximity
    cbar = fig.colorbar(ax.collections[0], cax=cax)
    cbar.set_label(metric_cbar, rotation=0, labelpad=20, fontsize=22, weight="bold")
    cbar.ax.tick_params(labelsize=20)

    ax.invert_yaxis()
    ax.set_title(metric_title, fontsize=30, pad=20)
    ax.set_ylabel('$d$', labelpad=30, fontsize=28, rotation=0)
    ax.set_xlabel('', fontsize=20)
    ax.tick_params(axis='y', labelrotation=0, labelsize=18)
    ax.tick_params(axis='x', labelsize=18)
    # Label each column with its actual N_reserv value rather than its position,
    # so a missing reservoir size cannot shift the labels.
    ax.set_xticklabels([f'$N_R = {n_reserv}$' for n_reserv in heatmap_pivot.columns],
                       fontweight='bold', fontsize=28, rotation=0)

    plt.tight_layout()
    plt.show()