In [None]:
import json
import numpy as np
from matplotlib import pyplot as plt
from collections import defaultdict

In [None]:
# Environment records used by every later cell. Each record is read below as a
# dict with at least 'grain_fract' (grain ID -> fraction) and 'fract_hcp_c'.
OATOM_ENVS_PATH = "oatom_envs_jp_dio-orig_min4.json"
# JSON is UTF-8 by spec; be explicit rather than relying on the platform default encoding.
with open(OATOM_ENVS_PATH, "r", encoding="utf-8") as f:
    oatom_envs = json.load(f)

In [None]:
def filter_by_env_composition(envs, min_num_grains=2, max_num_grains=3, smallest_max_fract=0.05):
    filtered_envs = []
    for env in envs:
        grain_fract = env['grain_fract']
        if len(grain_fract) > max_num_grains or len(grain_fract) < min_num_grains:
            continue

        sorted_fracts = sorted(grain_fract.values(), reverse=True)

        if len(sorted_fracts) == max_num_grains and sorted_fracts[-1] > smallest_max_fract:
            continue
        filtered_envs.append(env)
    return filtered_envs

In [None]:
def extract_partition_value(env, partition_type="fract_hcp_c"):
    """
    Return the scalar used to assign one environment to a partition category.

    Parameters:
    -----------
    env : dict
        Environment record; reads ``env["fract_hcp_c"]`` or
        ``env["grain_fract"]`` (grain ID -> fraction) depending on ``partition_type``.
    partition_type : str
        ``"fract_hcp_c"``: returns ``env["fract_hcp_c"]`` unchanged.
        ``"third_grain_fract"``: returns the third-largest grain fraction, or
        0.0 when the environment has fewer than three grains.

    Raises:
    -------
    ValueError
        For an unknown ``partition_type``, or for ``"third_grain_fract"`` on an
        environment with more than three grains.
    """
    if partition_type == "fract_hcp_c":
        return env["fract_hcp_c"]
    if partition_type == "third_grain_fract":
        grain_fract = env["grain_fract"]
        if len(grain_fract) > 3:
            raise ValueError(f"partition type {partition_type} not meaningful for systems with > 3 grains")
        sorted_fracts = sorted(grain_fract.values(), reverse=True)
        return sorted_fracts[2] if len(sorted_fracts) == 3 else 0.0
    raise ValueError(f"Unknown partition type {partition_type}")

def prepare_histogram_data(envs, partition_bins, partition_type='fract_hcp_c'):
    """
    Collect largest-grain fractions and partition categories for plotting.

    Parameters:
    -----------
    envs : list of dict
        Environments; each must provide ``"grain_fract"`` plus whatever
        ``extract_partition_value`` reads for ``partition_type``.
    partition_bins : sequence of float
        Monotonically increasing category edges. Category ``k`` is
        ``[partition_bins[k], partition_bins[k+1])``, except that the last
        category is closed on the right (same convention as ``np.histogram``).
    partition_type : str
        Passed through to ``extract_partition_value``.

    Returns:
    --------
    tuple of (largest_grain_fracts, partition_categories, partition_bins, partition_type)
        largest_grain_fracts : float ndarray, one entry per environment
        partition_categories : int ndarray of category indices; values below
            the first edge get -1 and values above the last edge get
            ``len(partition_bins) - 1`` (i.e. outside every category)
        partition_bins, partition_type : returned unchanged for convenience
    """
    largest_grain_fracts = np.array(
        [max(env["grain_fract"].values()) for env in envs], dtype=float)
    partition_values = np.array(
        [extract_partition_value(env, partition_type=partition_type) for env in envs],
        dtype=float)

    bin_edges = np.asarray(partition_bins, dtype=float)
    partition_categories = np.digitize(partition_values, bin_edges) - 1

    # np.digitize treats every bin as half-open, so a value exactly on the last
    # edge (e.g. fract_hcp_c == 1.0, or a third-grain fraction equal to the
    # filter threshold) would land one past the final category and silently
    # vanish from the plot. Close the last bin instead.
    num_categories = len(bin_edges) - 1
    if num_categories >= 1:
        partition_categories[partition_values == bin_edges[-1]] = num_categories - 1

    return largest_grain_fracts, partition_categories, partition_bins, partition_type

In [None]:
# Updated dual histogram plot with modular partition support

def plot_dual_histogram(largest_fracts, partition_categories, partition_bins,
                           partition_type='fract_hcp_c',
                           x_range=(0.48, 1.0), num_bins=20,
                           figsize=(12, 8), colors=None):
    """
    Create a dual histogram with partitioned bars showing distribution by partition type.

    A grey bar per bin shows how many environments have their largest grain
    fraction in that bin. On top of each grey bar, a thin stacked bar shows the
    share of each partition category within the bin. The stack is a proportion
    scaled to 20% of the tallest grey bar, so it encodes composition, not
    absolute counts.

    Values of ``largest_fracts`` outside ``x_range`` are not counted
    (``np.histogram`` ignores them). Entries whose category index lies outside
    ``0 .. len(partition_bins) - 2`` count toward the grey total but toward no
    stacked segment.

    Parameters:
    -----------
    largest_fracts : array-like of float
        Largest grain fraction for each environment (lists are accepted)
    partition_categories : array-like of int
        Category index for each environment, aligned with ``largest_fracts``
    partition_bins : list of float
        Bin edges for partition categorization
    partition_type : str
        Type of partition ('fract_hcp_c' or 'third_grain_fract'); controls labels
    x_range : tuple
        (min, max) range for x-axis (default: (0.48, 1.0))
    num_bins : int
        Number of bins for histogram (default: 20)
    figsize : tuple
        Figure size (width, height) (default: (12, 8))
    colors : list
        Colors for each partition category (optional). Reused cyclically when
        there are more categories than colors.

    Returns:
    --------
    fig, ax : matplotlib figure and axes objects
    """
    # Boolean-mask indexing below requires ndarrays; accept plain lists too.
    largest_fracts = np.asarray(largest_fracts, dtype=float)
    partition_categories = np.asarray(partition_categories)

    if colors is None:
        # Default colors for categories
        colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b']

    hist_bins = np.linspace(x_range[0], x_range[1], num_bins + 1)
    fig, ax = plt.subplots(figsize=figsize)

    # Total histogram, then one histogram per partition category.
    total_counts, _ = np.histogram(largest_fracts, bins=hist_bins)

    num_categories = len(partition_bins) - 1
    category_counts = np.zeros((num_categories, num_bins))
    for i in range(num_categories):
        counts, _ = np.histogram(largest_fracts[partition_categories == i], bins=hist_bins)
        category_counts[i, :] = counts

    # Main (grey) histogram
    bin_centers = (hist_bins[:-1] + hist_bins[1:]) / 2
    bin_width = hist_bins[1] - hist_bins[0]
    ax.bar(bin_centers, total_counts, width=bin_width * 0.95,
           color='lightgray', edgecolor='black', linewidth=0.5,
           label='Total count', alpha=0.7)

    # Legend labels for each category
    category_labels = []
    for i in range(num_categories):
        if partition_type == 'fract_hcp_c':
            label = f"[{partition_bins[i]:.2f}, {partition_bins[i+1]:.2f})"
            if i == 0:
                label = f"<{partition_bins[i+1]:.2f}"
            elif i == num_categories - 1:
                # The last category includes its lower edge.
                label = f"≥{partition_bins[i]:.2f}"
        elif partition_type == 'third_grain_fract':
            label = f"[{partition_bins[i]:.3f}, {partition_bins[i+1]:.3f})"
        else:
            label = f"[{partition_bins[i]}, {partition_bins[i+1]})"
        category_labels.append(label)

    # Thin stacked composition bars above each main bar
    thin_bar_width = bin_width * 0.3
    stack_height = np.max(total_counts) * 0.2
    bottom = np.zeros(num_bins)
    for i in range(num_categories):
        # Share of each bin that belongs to category i; empty bins get 0.
        # np.divide(where=...) skips the division there, so no divide-by-zero warnings.
        proportions = np.divide(category_counts[i, :], total_counts,
                                out=np.zeros(num_bins), where=total_counts > 0)
        normalized_counts = proportions * stack_height

        ax.bar(bin_centers, normalized_counts, width=thin_bar_width,
               bottom=total_counts + bottom, color=colors[i % len(colors)],
               edgecolor='black', linewidth=0.3,
               label=f'{partition_type} {category_labels[i]}')
        bottom += normalized_counts

    # Formatting
    ax.set_xlabel('Largest Grain Fraction', fontsize=12, fontweight='bold')
    ax.set_ylabel('Count', fontsize=12, fontweight='bold')
    title_text = f'Distribution of Largest Grain Fraction\nPartitioned by {partition_type}'
    ax.set_title(title_text, fontsize=14, fontweight='bold')
    ax.set_xlim(x_range[0], x_range[1])
    ax.grid(True, alpha=0.3, linestyle=':', axis='y')
    ax.legend(loc='upper left', fontsize=9)

    fig.tight_layout()
    return fig, ax


In [None]:
# Two strictness levels for 3-grain environments: the smallest grain may hold
# at most 5% (default) or 7.5% (looser) of the environment.
third_grain_thresh_default = 0.05
third_grain_thresh_loose = 0.075

filtered_envs_default = filter_by_env_composition(
    oatom_envs, smallest_max_fract=third_grain_thresh_default)
filtered_envs_looser = filter_by_env_composition(
    oatom_envs, smallest_max_fract=third_grain_thresh_loose)

In [None]:
# Largest-grain-fraction histogram partitioned by fract_hcp_c.
# Uses the looser filter; assign filtered_envs_default instead for the stricter one.
filtered_envs = filtered_envs_looser
hcp_c_bins = [0.0, 0.2, 0.3, 0.4, 0.5, 1.0]

# Plot settings
x_min, x_max = 0.48, 1.0
num_bins = 20

fracts_hcp, cats_hcp, bins_hcp, ptype_hcp = prepare_histogram_data(
    filtered_envs, partition_bins=hcp_c_bins, partition_type='fract_hcp_c')

fig_hcp, ax_hcp = plot_dual_histogram(
    fracts_hcp, cats_hcp, bins_hcp,
    partition_type=ptype_hcp, x_range=(x_min, x_max), num_bins=num_bins)

In [None]:
# Same histogram, partitioned by the third-largest grain fraction.
# The last edge (0.075) matches the looser filter threshold; with
# filtered_envs_default use [0.0, 0.002, 0.01, 0.02, 0.03, 0.04, 0.05] instead.
filtered_envs = filtered_envs_looser
third_grain_bins = [0.0, 0.002, 0.01, 0.02, 0.03, 0.04, 0.075]

# Plot settings
x_min, x_max = 0.48, 1.0
n_bins = 20

fracts_3rd, cats_3rd, bins_3rd, ptype_3rd = prepare_histogram_data(
    filtered_envs, partition_bins=third_grain_bins, partition_type='third_grain_fract')

fig_3rd, ax_3rd = plot_dual_histogram(
    fracts_3rd, cats_3rd, bins_3rd,
    partition_type=ptype_3rd, x_range=(x_min, x_max), num_bins=n_bins)

# Per-category environment counts
print("\nThird grain fraction distribution:")
for idx, (lo, hi) in enumerate(zip(third_grain_bins[:-1], third_grain_bins[1:])):
    count = np.sum(cats_3rd == idx)
    print(f"  [{lo:.3f}, {hi:.3f}): {count} environments")
plt.show()

In [None]:
def filter_envs_for_heat_map(envs, min_num_grains=2, max_num_grains=3, smallest_max_fract=0.05, max_largest_fract=0.6):
    """
    Select environments for the grain-pair heat map.

    Applies the same grain-count and smallest-grain rules as
    ``filter_by_env_composition``. It also drops environments whose largest
    grain fraction exceeds ``max_largest_fract``.

    Parameters:
    -----------
    envs : list of dict
        Environments; each is read via ``env['grain_fract']`` (grain ID -> fraction).
    min_num_grains, max_num_grains : int
        Inclusive bounds on the number of grains.
    smallest_max_fract : float
        With exactly ``max_num_grains`` grains, the smallest fraction must be <= this.
    max_largest_fract : float or None
        Largest grain fraction must be <= this. ``None`` disables the check.

    Returns:
    --------
    list of dict
        The retained environments, in their original order.
    """
    filtered_envs = []
    for env in envs:
        grain_fract = env['grain_fract']
        n_grains = len(grain_fract)
        if n_grains > max_num_grains or n_grains < min_num_grains:
            continue

        sorted_fracts = sorted(grain_fract.values(), reverse=True)

        # `sorted_fracts and ...` guards against environments with no grains,
        # which would otherwise raise IndexError when min/max_num_grains allow 0.
        if sorted_fracts and n_grains == max_num_grains and sorted_fracts[-1] > smallest_max_fract:
            continue

        if max_largest_fract is not None and sorted_fracts and sorted_fracts[0] > max_largest_fract:
            continue
        filtered_envs.append(env)
    return filtered_envs

In [None]:
hmap_filtered_envs_default = filter_envs_for_heat_map(
    oatom_envs,
    smallest_max_fract=0.05,
    max_largest_fract=0.6,
)
n_kept, n_total = len(hmap_filtered_envs_default), len(oatom_envs)
print(f"Filtered to {n_kept} environments (from {n_total} total)")

In [None]:
# Prepare heat map data

def prepare_heatmap_data(environments):
    """
    Count how often each unordered pair of grains is the top-two pair of an environment.

    Parameters:
    -----------
    environments : list of dict
        Environments whose ``'grain_fract'`` maps grain-ID keys (anything
        ``int()`` accepts) to fractions. Each needs at least two grains;
        otherwise IndexError is raised.

    Returns:
    --------
    tuple of (grain_ids, count_matrix)
        grain_ids : sorted list of integer grain IDs that appear in any top-two pair
        count_matrix : symmetric 2D int numpy array; entry [a, b] is the number of
            environments whose two largest grains are grain_ids[a] and grain_ids[b]
    """
    tally = defaultdict(int)
    seen = set()

    for environment in environments:
        fractions = environment['grain_fract']
        # Keys ranked by fraction, largest first (stable sort: ties keep dict order).
        ranked = sorted(fractions, key=fractions.get, reverse=True)
        top_a, top_b = int(ranked[0]), int(ranked[1])
        seen.update((top_a, top_b))
        # Key by (smaller ID, larger ID) so (a, b) and (b, a) share one counter.
        tally[(min(top_a, top_b), max(top_a, top_b))] += 1

    grain_ids = sorted(seen)
    position = {gid: k for k, gid in enumerate(grain_ids)}

    count_matrix = np.zeros((len(grain_ids), len(grain_ids)), dtype=int)
    for (low, high), n_pairs in tally.items():
        a, b = position[low], position[high]
        count_matrix[a, b] = n_pairs
        count_matrix[b, a] = n_pairs

    return grain_ids, count_matrix

In [None]:
grain_ids, count_matrix = prepare_heatmap_data(hmap_filtered_envs_default)
n_ids = len(grain_ids)
print(f"Heat map dimensions: {n_ids} x {n_ids}")
print(f"Grain IDs: {grain_ids}")
# Each pair is stored in two mirrored cells, so halve the non-zero count.
# (The diagonal is normally zero, since the two largest grains are distinct keys.)
n_pairs = np.count_nonzero(count_matrix) // 2
print(f"Total pairs (non-zero entries / 2): {n_pairs}")

In [None]:
# Format count values for display

def format_count(count):
    """
    Format a count compactly for a heat-map cell annotation.

    Zero becomes '' so empty cells stay blank. Values >= 1000 are shown in
    whole thousands with a 'K' suffix, with the remainder truncated
    (e.g. 1999 -> '1K'). Everything else is shown as-is.

    Parameters:
    -----------
    count : int
        Count value

    Returns:
    --------
    str
        Formatted count string
    """
    if count >= 1000:
        thousands = count // 1000
        return f'{thousands}K'
    return '' if count == 0 else str(count)


# Plot heat map with counts

def plot_heatmap(grain_ids, count_matrix, figsize=(12, 10), cmap='YlOrRd',
                 upper_triangle_only=True, diagonal_color='lightgray',
                 annotate_zeros=False):
    """
    Create a heat map of grain pair counts with values displayed in each cell.

    Parameters:
    -----------
    grain_ids : list
        Sorted list of grain IDs (used as tick labels)
    count_matrix : ndarray
        Symmetric count matrix, shape (len(grain_ids), len(grain_ids))
    figsize : tuple
        Figure size (width, height)
    cmap : str
        Colormap name (default: 'YlOrRd')
    upper_triangle_only : bool
        If True, only show upper right triangle (default: True)
    diagonal_color : str
        Color for diagonal cells (default: 'lightgray')
    annotate_zeros : bool
        If True, also write "0" in cells with no pairs. Default False: only
        non-zero counts are annotated.

    Returns:
    --------
    fig, ax : matplotlib figure and axes objects
    """
    count_matrix = np.asarray(count_matrix)
    fig, ax = plt.subplots(figsize=figsize)

    n = len(grain_ids)
    # astype returns a copy, so the caller's matrix is never modified.
    display_matrix = count_matrix.astype(float)
    if upper_triangle_only:
        # Strictly-lower-triangle cells become NaN, which imshow draws with the
        # colormap's "bad" color (transparent by default).
        display_matrix[np.tril_indices(n, k=-1)] = np.nan

    im = ax.imshow(display_matrix, cmap=cmap, aspect='auto')

    # Paint the diagonal with a flat color on top of the image.
    for i in range(n):
        ax.add_patch(plt.Rectangle((i - 0.5, i - 0.5), 1, 1,
                                   facecolor=diagonal_color, edgecolor='white', linewidth=1))

    ticks = np.arange(n)
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(grain_ids)
    ax.set_yticklabels(grain_ids)
    # Rotate x-axis labels for better readability
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Above half the maximum, cells are dark enough that white text reads better.
    half_max = count_matrix.max() / 2 if count_matrix.size else 0
    for i in range(n):
        for j in range(n):
            if upper_triangle_only and i > j:
                continue
            count = count_matrix[i, j]
            # str(0) is truthy, so zeros must be skipped explicitly.
            if count == 0 and not annotate_zeros:
                continue
            if i == j:
                text_color = 'black'  # drawn over the flat diagonal patch
            else:
                text_color = 'white' if count > half_max else 'black'
            ax.text(j, i, str(count), ha="center", va="center",
                    color=text_color, fontsize=10, fontweight='bold')

    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label('Count', rotation=270, labelpad=20, fontsize=12)

    ax.set_xlabel('Grain ID', fontsize=12, fontweight='bold')
    ax.set_ylabel('Grain ID', fontsize=12, fontweight='bold')
    title_suffix = ' (Upper Triangle)' if upper_triangle_only else ''
    ax.set_title(f'Grain Pair Heat Map{title_suffix}\n(Two Largest Grains in Each Environment)',
                 fontsize=14, fontweight='bold', pad=20)

    fig.tight_layout()
    return fig, ax

In [None]:
# Render the grain-pair heat map for the filtered environments.
fig, ax = plot_heatmap(
    grain_ids,
    count_matrix,
    figsize=(12, 10),
    upper_triangle_only=True,
    diagonal_color='lightgray',
)