In [4]:
import numpy as np
import json
import glob
import os
import matplotlib.pyplot as plt

def load_data_flexible(data_path, p_fixed_name='p_proj', p_fixed_value=0.5):
    """
    Load data flexibly - handles individual JSON files, JSON arrays, and JSON Lines format
    """
    json_data = []
    json_files = glob.glob(os.path.join(data_path, f'{p_fixed_name}{p_fixed_value}/*.json'))
    
    print(f"Found {len(json_files)} JSON files")
    
    for file in json_files:
        with open(file, 'r') as f:
            try:
                # Try to load as regular JSON first (single object or array)
                data = json.load(f)
                if isinstance(data, list):
                    # If it's an array of results, extend the main list
                    json_data.extend(data)
                else:
                    # Single result
                    json_data.append(data)
            except json.JSONDecodeError:
                # If that fails, try JSON Lines format (one JSON object per line)
                f.seek(0)  # Reset file pointer
                line_count = 0
                for line in f:
                    line = line.strip()
                    if line:  # Skip empty lines
                        try:
                            data = json.loads(line)
                            json_data.append(data)
                            line_count += 1
                        except json.JSONDecodeError as e:
                            print(f"Error parsing line in {file}: {e}")
                print(f"Loaded {line_count} results from JSON Lines file: {os.path.basename(file)}")
    
    return json_data

# Load data from json_data folder
data_path = '/scratch/ty296/json_data/'
p_fixed_name = 'p_ctrl'
p_fixed_value = 0.0
json_data = load_data_flexible(data_path, p_fixed_name, p_fixed_value)
print(f"Loaded {len(json_data)} total data points")

Found 1200 JSON files
Loaded 200 results from JSON Lines file: 44474867_a0_L12.json
Loaded 200 results from JSON Lines file: 44474488_a0_L8.json
Loaded 200 results from JSON Lines file: 44474924_a0_L12.json
Loaded 200 results from JSON Lines file: 44474776_a0_L10.json
Loaded 0 results from JSON Lines file: 44476291_a0_L16.json
Loaded 200 results from JSON Lines file: 44474512_a0_L8.json
Loaded 200 results from JSON Lines file: 44474677_a0_L10.json
Loaded 200 results from JSON Lines file: 44474551_a0_L8.json
Loaded 200 results from JSON Lines file: 44474749_a0_L10.json
Loaded 200 results from JSON Lines file: 44474809_a0_L10.json
Loaded 200 results from JSON Lines file: 44474804_a0_L10.json
Loaded 200 results from JSON Lines file: 44474493_a0_L8.json
Loaded 200 results from JSON Lines file: 44476370_a0_L16.json
Loaded 200 results from JSON Lines file: 44476462_a0_L16.json
Loaded 200 results from JSON Lines file: 44474481_a0_L8.json
Loaded 200 results from JSON Lines file: 44474669_a0_L1

In [5]:
def group_data_by_params(json_data):
    """Group data by all parameter combinations for histogram analysis"""
    grouped_data = {}
    
    for data in json_data:
        # Create a key tuple with the parameters we want to group by
        key = (
            data['args']['L'],
            data['args']['ancilla'],
            data['p_ctrl'],
            data['p_proj']
        )
        
        # If this parameter combination hasn't been seen before, initialize lists
        if key not in grouped_data:
            grouped_data[key] = {
                'EE': [],
                'O': [],
                'max_bond': []
            }
        
        # Append the values to their respective lists
        grouped_data[key]['EE'].append(data['EE'])
        grouped_data[key]['O'].append(data['O'])
        grouped_data[key]['max_bond'].append(data['max_bond'])
    
    return grouped_data

def plot_histograms(json_data, save_plots=True, max_plots=5):
    """Plot histograms for each parameter combination (limited to max_plots)"""
    grouped_results = group_data_by_params(json_data)
    
    print(f"Found {len(grouped_results)} parameter combinations")
    if len(grouped_results) > max_plots:
        print(f"Limiting to first {max_plots} combinations for demonstration")
    
    count = 0
    for key, values in grouped_results.items():
        if count >= max_plots:
            break
            
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        # Plot EE histogram
        axes[0].hist(values['EE'], bins=10)
        axes[0].set_title('EE Distribution')
        axes[0].set_xlabel('EE')
        axes[0].set_ylabel('Frequency')

        # Plot O histogram
        axes[1].hist(values['O'], bins=10)
        axes[1].set_title('O Distribution')
        axes[1].set_xlabel('O')
        axes[1].set_ylabel('Frequency')

        # Plot max_bond histogram
        axes[2].hist(values['max_bond'], bins=10)
        axes[2].set_title('Max Bond Distribution')
        axes[2].set_xlabel('Max Bond')
        axes[2].set_ylabel('Frequency')

        fig.tight_layout()
        
        if save_plots:
            filename = f'/scratch/ty296/plots/histogram_{key[0]:03d}_a{key[1]:03d}_p_ctrl{key[2]:.3f}_p_proj{key[3]:.3f}.png'
            fig.savefig(filename)
            
        plt.close()

        # Print the parameters for the plotted data
        L, ancilla, p_ctrl, p_proj = key
        print(f"Histograms for: L = {L}, ancilla = {ancilla}, p_ctrl = {p_ctrl}, p_proj = {p_proj}")
        count += 1

# Optional: Run histogram analysis (commented out by default)
plot_histograms(json_data, save_plots=True, max_plots=20)




Found 100 parameter combinations
Limiting to first 20 combinations for demonstration
Histograms for: L = 12, ancilla = 0, p_ctrl = 0.0, p_proj = 0.2
Histograms for: L = 12, ancilla = 0, p_ctrl = 0.0, p_proj = 0.24210526315789474
Histograms for: L = 12, ancilla = 0, p_ctrl = 0.0, p_proj = 0.28421052631578947
Histograms for: L = 12, ancilla = 0, p_ctrl = 0.0, p_proj = 0.3263157894736842
Histograms for: L = 12, ancilla = 0, p_ctrl = 0.0, p_proj = 0.3684210526315789
Histograms for: L = 12, ancilla = 0, p_ctrl = 0.0, p_proj = 0.4105263157894737
Histograms for: L = 12, ancilla = 0, p_ctrl = 0.0, p_proj = 0.45263157894736844
Histograms for: L = 12, ancilla = 0, p_ctrl = 0.0, p_proj = 0.49473684210526314
Histograms for: L = 12, ancilla = 0, p_ctrl = 0.0, p_proj = 0.5368421052631579
Histograms for: L = 12, ancilla = 0, p_ctrl = 0.0, p_proj = 0.5789473684210527
Histograms for: L = 12, ancilla = 0, p_ctrl = 0.0, p_proj = 0.6210526315789474
Histograms for: L = 12, ancilla = 0, p_ctrl = 0.0, p_proj

In [6]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats

def group_by_L_and_varying_p(json_data):
    """
    Group data by L and the varying p parameter to calculate statistics
    
    Parameters:
    - json_data: list of data dictionaries
    - p_fixed_name: which parameter is fixed ('p_proj' or 'p_ctrl')
                   if 'p_proj' is fixed, we vary p_ctrl
                   if 'p_ctrl' is fixed, we vary p_proj
    - p_fixed_value: the value of the fixed parameter
    """
    grouped = {}
    varying_p_name = 'p_ctrl' if p_fixed_name == 'p_proj' else 'p_proj'
    for data in json_data:
        L = data['args']['L']
        varying_p_value = data[varying_p_name]
        key = (L, varying_p_value, p_fixed_value)
        
        if key not in grouped:
            grouped[key] = []
        
        grouped[key].append(data['EE'])
    
    return grouped, varying_p_name

def calculate_stats(ee_values):
    """Calculate mean and standard error of the mean"""
    ee_array = np.array(ee_values)
    mean = np.mean(ee_array)
    std_err = stats.sem(ee_array)  # Standard error of the mean
    return mean, std_err

def plot_average_EE_vs_p(json_data, p_fixed_name='p_proj', save_plot=True, show_plot=False):
    """
    Plot average EE vs the varying p parameter
    
    Parameters:
    - json_data: list of data dictionaries
    - p_fixed_name: which parameter is fixed ('p_proj' or 'p_ctrl')
    - save_plot: whether to save the plot
    - show_plot: whether to display the plot
    """
    
    # Group data by L and varying p parameter
    ee_grouped, varying_p_name = group_by_L_and_varying_p(json_data)
    
    # Organize data for plotting
    plot_data = {}
    for (L, varying_p_value, p_fixed_value), ee_values in ee_grouped.items():
        if L not in plot_data:
            plot_data[L] = {varying_p_name: [], 'mean_EE': [], 'sem_EE': []}
        
        print(f'L = {L}, {varying_p_name} = {varying_p_value}, ensemble size: {len(ee_values)}')
        mean_ee, sem_ee = calculate_stats(ee_values)
        plot_data[L][varying_p_name].append(varying_p_value)
        plot_data[L]['mean_EE'].append(mean_ee)
        plot_data[L]['sem_EE'].append(sem_ee)
    
    # Sort data by varying p parameter for each L
    for L in plot_data:
        sorted_indices = np.argsort(plot_data[L][varying_p_name])
        plot_data[L][varying_p_name] = np.array(plot_data[L][varying_p_name])[sorted_indices]
        plot_data[L]['mean_EE'] = np.array(plot_data[L]['mean_EE'])[sorted_indices]
        plot_data[L]['sem_EE'] = np.array(plot_data[L]['sem_EE'])[sorted_indices]
    
    # Create the plot
    fig, ax = plt.subplots(figsize=(10, 6))
    
    colors = ['blue', 'red', 'green', 'orange', 'purple', 'brown']
    markers = ['o', 's', '^', 'D', 'v', '<']
    
    for i, L in enumerate(sorted(plot_data.keys())):
        color = colors[i % len(colors)]
        marker = markers[i % len(markers)]
        
        ax.errorbar(plot_data[L][varying_p_name], 
                    plot_data[L]['mean_EE'], 
                    yerr=plot_data[L]['sem_EE'],
                    label=f'L = {L}',
                    color=color,
                    marker=marker,
                    markersize=6,
                    linewidth=2,
                    capsize=3,
                    capthick=1)
    
    # Set labels and title based on which parameter is varying
    ax.set_xlabel(varying_p_name, fontsize=12)
    ax.set_ylabel('Average EE', fontsize=12)
    
    if varying_p_name == 'p_ctrl':
        title = f'Average Entanglement Entropy vs Control Probability (p_proj fixed)'
        filename = f'/scratch/ty296/plots/average_EE_vs_p_ctrl_p_proj{p_fixed_value}.png'
    else:
        title = f'Average Entanglement Entropy vs Projection Probability (p_ctrl fixed)'
        filename = f'/scratch/ty296/plots/average_EE_vs_p_proj_p_ctrl{p_fixed_value}.png' 
    
    ax.set_title(title, fontsize=14)
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # Save or show the plot
    fig.tight_layout()
    if save_plot:
        fig.savefig(filename, dpi=300, bbox_inches='tight')
        print(f"Plot saved to: {filename}")
    
    if show_plot:
        plt.show()
    else:
        plt.close(fig)
    
    # Print some statistics
    print(f"\nData summary (varying {varying_p_name}, fixed {p_fixed_name}):")
    for L in sorted(plot_data.keys()):
        n_points = len(plot_data[L][varying_p_name])
        print(f"L = {L}: {n_points} data points")
        print(f"  {varying_p_name} range: {plot_data[L][varying_p_name].min():.3f} to {plot_data[L][varying_p_name].max():.3f}")
        print(f"  EE range: {plot_data[L]['mean_EE'].min():.3f} ± {plot_data[L]['sem_EE'][np.argmin(plot_data[L]['mean_EE'])]:.3f} to {plot_data[L]['mean_EE'].max():.3f} ± {plot_data[L]['sem_EE'][np.argmax(plot_data[L]['mean_EE'])]:.3f}")
    
    return plot_data

# Example usage:
# For your current data (p_proj fixed at 0.5, p_ctrl varies):
print("=== Plotting EE vs p_ctrl (p_proj fixed) ===")
plot_data_ctrl = plot_average_EE_vs_p(json_data, p_fixed_name='p_proj', save_plot=True, show_plot=False)

print("\n" + "="*50)
print("=== If you had data with p_ctrl fixed and p_proj varying, you would use: ===")
print("plot_data_proj = plot_average_EE_vs_p(json_data, p_fixed_name='p_ctrl', save_plot=True, show_plot=False)")
print("="*50)


=== Plotting EE vs p_ctrl (p_proj fixed) ===
L = 12, p_proj = 0.2, ensemble size: 2000
L = 12, p_proj = 0.24210526315789474, ensemble size: 2000
L = 12, p_proj = 0.28421052631578947, ensemble size: 2000
L = 12, p_proj = 0.3263157894736842, ensemble size: 2000
L = 12, p_proj = 0.3684210526315789, ensemble size: 2000
L = 12, p_proj = 0.4105263157894737, ensemble size: 2000
L = 12, p_proj = 0.45263157894736844, ensemble size: 2000
L = 12, p_proj = 0.49473684210526314, ensemble size: 2000
L = 12, p_proj = 0.5368421052631579, ensemble size: 2000
L = 12, p_proj = 0.5789473684210527, ensemble size: 2000
L = 12, p_proj = 0.6210526315789474, ensemble size: 2000
L = 12, p_proj = 0.6631578947368421, ensemble size: 2000
L = 12, p_proj = 0.7052631578947368, ensemble size: 2000
L = 12, p_proj = 0.7473684210526316, ensemble size: 2000
L = 12, p_proj = 0.7894736842105263, ensemble size: 2000
L = 12, p_proj = 0.8315789473684211, ensemble size: 2000
L = 12, p_proj = 0.8736842105263158, ensemble size: 20