In [None]:
import json
import numpy as np
from matplotlib import pyplot as plt

In [None]:
# Load the environment records (indexed and iterated as a sequence of dicts below).
# JSON text is UTF-8 by specification, so the encoding is pinned explicitly
# instead of relying on the platform's locale default.
with open("oatom_envs_jp_dio-orig_min4.json", "r", encoding="utf-8") as f:
    oatom_envs = json.load(f)

In [None]:
# Display the first record as a quick look at the structure of one environment.
oatom_envs[0]

In [None]:
def filter_by_env_composition(envs, min_num_grains=2, max_num_grains=3, smallest_max_fract=0.05):
    """
    Keep environments by grain count, capping the smallest grain at the maximum count.

    Parameters:
    -----------
    envs : iterable of dict
        Environment dictionaries containing a 'grain_fract' mapping
    min_num_grains : int
        Fewest grains an environment may have (default: 2)
    max_num_grains : int
        Most grains an environment may have (default: 3)
    smallest_max_fract : float
        Applied only to environments with exactly max_num_grains grains:
        their smallest fraction must not exceed this value (default: 0.05)

    Returns:
    --------
    list of dict
        The accepted environments, in input order (same objects, not copies)
    """
    def _accept(env):
        fract_by_grain = env['grain_fract']
        grain_count = len(fract_by_grain)

        # Reject anything outside the allowed grain-count window.
        if not (min_num_grains <= grain_count <= max_num_grains):
            return False

        descending = sorted(fract_by_grain.values(), reverse=True)

        # Only environments sitting at the maximum count are subject to the cap
        # on their smallest fraction.
        at_max_count = len(descending) == max_num_grains
        return not (at_max_count and descending[-1] > smallest_max_fract)

    return [env for env in envs if _accept(env)]

In [None]:
# original that was mis-specified
def filter_by_grain_composition(environments, num_grains=3, third_grain_max_fract=0.05):
    """
    Select environments with exactly `num_grains` grains and a small third grain.

    An environment is returned only when its 'grain_fract' mapping has exactly
    `num_grains` entries AND it has at least three fractions AND the third
    largest fraction is strictly below `third_grain_max_fract`. As a result,
    environments with fewer than three grains are never returned, whatever
    `num_grains` is.

    Parameters:
    -----------
    environments : list of dict
        Environment dictionaries containing a 'grain_fract' key
    num_grains : int
        Exact number of grains required (default: 3)
    third_grain_max_fract : float
        Strict upper bound on the third largest fraction (default: 0.05)

    Returns:
    --------
    list of dict
        The selected environment dictionaries, in input order
    """
    selected = []

    for candidate in environments:
        fract_by_grain = candidate['grain_fract']

        if len(fract_by_grain) == num_grains:
            # Largest fraction first, so index 2 is the third largest.
            ranked = sorted(fract_by_grain.values(), reverse=True)

            if len(ranked) > 2 and ranked[2] < third_grain_max_fract:
                selected.append(candidate)

    return selected

In [None]:
# Run both filters with one shared threshold so their results can be compared.
third_grain_thresh = 0.05
filtered_envs = filter_by_env_composition(
    oatom_envs,
    smallest_max_fract=third_grain_thresh,
)
incorrect_filtered_envs = filter_by_grain_composition(
    oatom_envs,
    third_grain_max_fract=third_grain_thresh,
)

In [None]:
# Number of environments kept by filter_by_env_composition.
len(filtered_envs)

In [None]:
# Number of environments kept by the mis-specified filter_by_grain_composition, for comparison.
len(incorrect_filtered_envs)

In [None]:
def extract_partition_value(env, partition_type="fract_hcp_c"):
    """
    Return the value used to assign one environment to a partition category.

    Parameters:
    -----------
    env : dict
        Environment dictionary; must hold 'fract_hcp_c' or 'grain_fract',
        depending on partition_type
    partition_type : str
        'fract_hcp_c' returns env['fract_hcp_c'] unchanged.
        'third_grain_fract' returns the third largest value of
        env['grain_fract'] when there are exactly three grains, else 0.0.

    Returns:
    --------
    The partition value for this environment

    Raises:
    -------
    ValueError
        For an unknown partition_type, or for 'third_grain_fract' when the
        environment has more than three grains
    """
    if partition_type == "fract_hcp_c":
        return env["fract_hcp_c"]

    if partition_type != "third_grain_fract":
        raise ValueError(f"Unknown partition type {partition_type}")

    fract_by_grain = env["grain_fract"]
    if len(fract_by_grain) > 3:
        raise ValueError(f"partition type {partition_type} not meaningful for systems with > 3 grains")

    # Largest first; environments with fewer than three grains have no third grain.
    ranked = sorted(fract_by_grain.values(), reverse=True)
    if len(ranked) == 3:
        return ranked[2]
    return 0.0

def prepare_histogram_data(envs, partition_bins, partition_type='fract_hcp_c'):
    """
    Collect the largest grain fraction and a partition-category index per environment.

    Parameters:
    -----------
    envs : iterable of dict
        Environment dictionaries; each needs a 'grain_fract' mapping plus
        whatever extract_partition_value reads for `partition_type`
    partition_bins : sequence of float
        Increasing bin edges used to categorize the partition value
    partition_type : str
        Passed through to extract_partition_value (default: 'fract_hcp_c')

    Returns:
    --------
    largest_grain_fracts : np.ndarray
        Largest value of 'grain_fract' for each environment
    partition_categories : np.ndarray of int
        Category index per environment. Category i covers
        [partition_bins[i], partition_bins[i+1]); the last category also
        includes its right edge, matching np.histogram. Values below the
        first edge get -1 and values above the last edge get
        len(partition_bins) - 1.
    partition_bins, partition_type
        Returned unchanged for convenient unpacking into the plot call
    """
    largest_grain_fracts = []
    partition_values = []

    for env in envs:
        largest_grain_fracts.append(max(env["grain_fract"].values()))
        partition_values.append(extract_partition_value(env, partition_type=partition_type))

    # Downstream code indexes this with a boolean mask, so it must be an array.
    largest_grain_fracts = np.array(largest_grain_fracts)
    partition_values = np.asarray(partition_values)

    partition_categories = np.digitize(partition_values, partition_bins) - 1

    # np.digitize treats every bin as right-open, so a value sitting exactly on
    # the final edge (e.g. a fraction of 1.0) would land one past the last
    # category and be silently dropped by per-category consumers. Fold it back
    # into the last category.
    edges = np.asarray(partition_bins)
    if edges.size >= 2 and edges[-1] > edges[0]:
        partition_categories[partition_values == edges[-1]] = edges.size - 2

    return largest_grain_fracts, partition_categories, partition_bins, partition_type

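Some quick testing of the functions defined above (`extract_partition_value` and `prepare_histogram_data`), run on the filtered environments.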

In [None]:
# Partition the filtered environments by their 'fract_hcp_c' value.
hcp_c_bins = [0.0, 0.2, 0.3, 0.4, 0.5, 1.0]
fracts_hcp, cats_hcp, bins_hcp, ptype_hcp = prepare_histogram_data(
    filtered_envs, hcp_c_bins, partition_type='fract_hcp_c'
)

In [None]:
# Display the largest-grain fractions; the commented names are the other outputs
# of prepare_histogram_data that can be inspected instead.
fracts_hcp
#cats_hcp
#bins_hcp

In [None]:
# Partition the filtered environments by their third-grain fraction.
third_grain_bins = [0.0, 0.002, 0.01, 0.02, 0.03, 0.04, 0.05]
fracts_3rd, cats_3rd, bins_3rd, ptype_3rd = prepare_histogram_data(
    filtered_envs, third_grain_bins, partition_type='third_grain_fract'
)

In [None]:
# Display the bin edges returned for the third-grain partition; the commented
# names are the other outputs that can be inspected instead.
#fracts_3rd
#cats_3rd
bins_3rd

In [None]:
# Scratch check of the per-category masking step used by plot_dual_histogram:
# build the same histogram scaffolding, then pull out the largest-grain
# fractions belonging to one partition category.
x_range = (0.48, 1.0)
num_bins = 20

# Edges of the largest-grain-fraction histogram (num_bins bins across x_range).
hist_bins = np.linspace(x_range[0], x_range[1], num_bins + 1)

# One row of counts per partition category, one column per histogram bin.
num_categories = len(bins_3rd) - 1
category_counts = np.zeros((num_categories, num_bins))

# Select the environments assigned to category i and show their fractions.
# (A bare `len(fracts_3rd)` used to sit mid-cell; its value was discarded because
# only a cell's final expression is displayed, so it has been removed.)
i = 1
mask = cats_3rd == i
category_data = fracts_3rd[mask]
category_data

In [None]:
# Updated dual histogram plot with modular partition support

def plot_dual_histogram(largest_fracts, partition_categories, partition_bins,
                           partition_type='fract_hcp_c',
                           x_range=(0.48, 1.0), num_bins=20,
                           figsize=(12, 8), colors=None):
    """
    Create a dual histogram with partitioned bars showing distribution by partition type.

    The main (gray) bars show how many environments fall in each
    largest-grain-fraction bin. On top of every main bar sits a thin stacked
    strip whose segments show the share of that bin belonging to each
    partition category; the strip height is scaled to 20% of the tallest
    main bar.

    Parameters:
    -----------
    largest_fracts : array-like
        Largest grain fraction of each environment
    partition_categories : array-like of int
        Category index of each environment (same length as largest_fracts).
        Entries outside 0 .. len(partition_bins) - 2 are counted in the main
        bars but do not appear in any strip segment.
    partition_bins : list of float
        Bin edges for partition categorization
    partition_type : str
        Type of partition ('fract_hcp_c' or 'third_grain_fract'); controls
        label formatting and appears in the legend and title
    x_range : tuple
        (min, max) range for x-axis (default: (0.48, 1.0))
    num_bins : int
        Number of bins for histogram (default: 20)
    figsize : tuple
        Figure size (width, height) (default: (12, 8))
    colors : list
        Colors for each partition category (optional). Reused cyclically
        when there are more categories than colors.

    Returns:
    --------
    fig, ax : matplotlib figure and axes objects
    """
    if colors is None:
        # Default colors for categories
        colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b']

    # Boolean-mask indexing below needs arrays, so accept plain lists too.
    largest_fracts = np.asarray(largest_fracts)
    partition_categories = np.asarray(partition_categories)

    # Create histogram bins
    hist_bins = np.linspace(x_range[0], x_range[1], num_bins + 1)

    # Create figure
    fig, ax = plt.subplots(figsize=figsize)

    # Calculate total histogram
    total_counts, _ = np.histogram(largest_fracts, bins=hist_bins)

    # Calculate counts for each partition category in each bin
    num_categories = len(partition_bins) - 1
    category_counts = np.zeros((num_categories, num_bins))

    for i in range(num_categories):
        in_category = partition_categories == i
        counts, _ = np.histogram(largest_fracts[in_category], bins=hist_bins)
        category_counts[i, :] = counts

    # Plot main histogram
    bin_centers = (hist_bins[:-1] + hist_bins[1:]) / 2
    bin_width = hist_bins[1] - hist_bins[0]

    ax.bar(bin_centers, total_counts, width=bin_width * 0.95,
           color='lightgray', edgecolor='black', linewidth=0.5,
           label='Total count', alpha=0.7)

    # Generate category labels based on partition type
    category_labels = []
    for i in range(num_categories):
        if partition_type == 'fract_hcp_c':
            label = f"[{partition_bins[i]:.2f}, {partition_bins[i+1]:.2f})"
            if i == 0:
                label = f"<{partition_bins[i+1]:.2f}"
            elif i == num_categories - 1:
                label = f">{partition_bins[i]:.2f}"
        elif partition_type == 'third_grain_fract':
            label = f"[{partition_bins[i]:.3f}, {partition_bins[i+1]:.3f})"
        else:
            label = f"[{partition_bins[i]}, {partition_bins[i+1]})"
        category_labels.append(label)

    # Create thin stacked bars above each main bar
    thin_bar_width = bin_width * 0.3
    bottom = np.zeros(num_bins)

    # Loop-invariant pieces of the strip scaling.
    peak_count = np.max(total_counts)
    nonempty = total_counts > 0

    for i in range(num_categories):
        # Share of each bin belonging to this category. np.divide with `where`
        # skips empty bins entirely, so no 0/0 is evaluated (np.where would
        # still compute the division everywhere and emit a RuntimeWarning).
        share = np.divide(category_counts[i, :], total_counts,
                          out=np.zeros(num_bins), where=nonempty)

        # Normalize to show as proportion of total in each bin
        normalized_counts = share * peak_count * 0.2

        ax.bar(bin_centers, normalized_counts, width=thin_bar_width,
               bottom=total_counts + bottom,
               color=colors[i % len(colors)],  # cycle if categories outnumber colors
               edgecolor='black', linewidth=0.3,
               label=f'{partition_type} {category_labels[i]}')
        bottom += normalized_counts

    # Formatting
    ax.set_xlabel('Largest Grain Fraction', fontsize=12, fontweight='bold')
    ax.set_ylabel('Count', fontsize=12, fontweight='bold')

    title_text = f'Distribution of Largest Grain Fraction\nPartitioned by {partition_type}'
    ax.set_title(title_text, fontsize=14, fontweight='bold')
    ax.set_xlim(x_range[0], x_range[1])
    ax.grid(True, alpha=0.3, linestyle=':', axis='y')
    ax.legend(loc='upper left', fontsize=9)

    fig.tight_layout()
    return fig, ax


In [None]:
# Partition data for the hcp-c plot (same inputs as the quick test earlier).
hcp_c_bins = [0.0, 0.2, 0.3, 0.4, 0.5, 1.0]
fracts_hcp, cats_hcp, bins_hcp, ptype_hcp = prepare_histogram_data(
    filtered_envs, hcp_c_bins, partition_type='fract_hcp_c'
)

# Full-range view of the largest-grain-fraction histogram.
x_min, x_max = 0.48, 1.0
num_bins = 20
fig_hcp, ax_hcp = plot_dual_histogram(
    fracts_hcp,
    cats_hcp,
    bins_hcp,
    partition_type=ptype_hcp,
    x_range=(x_min, x_max),
    num_bins=num_bins,
)

In [None]:
# Partition data for the hcp-c plot (same inputs as the full-range plot).
hcp_c_bins = [0.0, 0.2, 0.3, 0.4, 0.5, 1.0]
fracts_hcp, cats_hcp, bins_hcp, ptype_hcp = prepare_histogram_data(
    filtered_envs, hcp_c_bins, partition_type='fract_hcp_c'
)

# Zoomed view: x-axis limited to 0.48-0.7 instead of the full 0.48-1.0 range.
x_min, x_max = 0.48, 0.7
num_bins = 20
fig_hcp, ax_hcp = plot_dual_histogram(
    fracts_hcp,
    cats_hcp,
    bins_hcp,
    partition_type=ptype_hcp,
    x_range=(x_min, x_max),
    num_bins=num_bins,
)

In [None]:
# Plot with third grain fraction partitions

# Define partition bins for third grain fraction
third_grain_bins = [0.0, 0.002, 0.01, 0.02, 0.03, 0.04, 0.05]

# Prepare data
fracts_3rd, cats_3rd, bins_3rd, ptype_3rd = prepare_histogram_data(
    filtered_envs,
    partition_bins=third_grain_bins,
    partition_type='third_grain_fract'
)

# Plot settings
x_min, x_max = 0.48, 1.0
n_bins = 20

# Create plot
fig_3rd, ax_3rd = plot_dual_histogram(
    fracts_3rd, cats_3rd, bins_3rd,
    partition_type=ptype_3rd,
    x_range=(x_min, x_max),
    num_bins=n_bins
)

# Text summary: number of environments assigned to each partition category.
# (Plain string here: the original f-string had no placeholders.)
print("\nThird grain fraction distribution:")
for i in range(len(third_grain_bins) - 1):
    count = np.sum(cats_3rd == i)
    print(f"  [{third_grain_bins[i]:.3f}, {third_grain_bins[i+1]:.3f}): {count} environments")
plt.show()