In [None]:
import json
import warnings

In [None]:
# Load the grain-pair -> list-of-environment-dicts mapping from disk.
with open("grain_pair_mapping_2p6.json") as mapping_file:
    grain_pair_mapping_2p6 = json.load(mapping_file)

## Prompt 
I have a dict that maps a tuple represented as a string (e.g. '(3,5)') to a list of dicts. The tuple represents a pair of grains in an atomistic system, and the list of dicts is a list of atomic environments. Different grain pairs have different numbers of environments (i.e. the list of dicts is of different lengths). I want to sub-select some of these environments in a way that keeps the relative proportion of these different grain pairs the same. Here is an example of a function that does this:
...previous function...
This is easy enough. However, there are two modifications.
1. I want to make sure that every grain pair has a minimum number of selected environments (e.g., 10), even if that disrupts the relative proportion.
2. I only want a total of N (e.g., N=200) environments selected (though there can be a buffer of +/- some amount, e.g., 5).

Thus I want an algorithm that, as far as possible, maintains the same relative proportions of environments among grain pairs, subject to a minimum number for each grain pair, and that results in a total number of environments close to some target value (e.g., 200), within a buffer.
## Response
What you want is essentially a **constrained proportional allocation** problem:

- You have groups (grain pairs)
- Each group has a capacity $C_i$ (number of valid environments available)
- You want to select $n_i$ from each group
- Subject to:
  - **Minimum per group:** $n_i \ge n_{\min}$
  - **Total:** $\sum_i n_i \approx N_{\text{target}}$ (within a tolerance)
  - **Proportionality:** $n_i$ should be as close as possible to being proportional to $C_i$ (or to some score-weighted size, if you prefer)

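This is very similar to *apportionment* problems (Hamilton / largest-remainder methods), with lower bounds. The implementation below first gives every pair $\min(n_{\min}, C_i)$ environments. It then splits the remaining quota across pairs in proportion to their *leftover* capacity $C_i - \min(n_{\min}, C_i)$, using the largest-remainder method for rounding. Within each pair, the environments with the highest score are kept.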

In [None]:
import math
import numpy as np

def select_envs_proportional_with_min(
    grain_pair_mapping,
    value_key="first_shell_fract_hcp_c",
    n_total=200,
    min_per_pair=10,
    buffer=5,
):
    """
    Select environments proportionally across grain pairs, with a per-pair
    minimum and a total target count.

    Allocation (largest-remainder apportionment with lower bounds):
      1. Each pair gets ``min(min_per_pair, capacity)`` environments, where
         capacity is the number of valid environments for that pair.
      2. The remaining quota ``n_total - sum(minimums)`` (limited to what is
         actually left) is split proportionally to each pair's *leftover*
         capacity ``capacity - minimum``, flooring each share.
      3. Units lost to flooring go, one each, to the pairs with the largest
         fractional parts.

    Within each pair, environments are sorted by ``value_key`` (descending)
    and the top ones are kept, so the result is deterministic.

    Parameters
    ----------
    grain_pair_mapping : dict
        Maps a grain-pair key (e.g. ``"(3, 5)"``) to a list of environment
        dicts.
    value_key : str
        Key used to rank environments. Environments lacking this key, or
        whose value is ``None``, are ignored.
    n_total : int
        Target total number of selected environments.
    min_per_pair : int
        Minimum number selected per pair (capped at the pair's capacity).
    buffer : int
        Allowed deviation of the final total from ``n_total``.

    Returns
    -------
    dict
        Maps each pair key with at least one valid environment to its list of
        selected environment dicts. These are the same dict objects as in the
        input, not copies. Pairs with no valid environments are omitted, and
        ``{}`` is returned when no pair has any.

    Raises
    ------
    ValueError
        If the minimum allocation alone exceeds ``n_total + buffer``.

    Warns
    -----
    UserWarning
        If the final total falls outside ``n_total +/- buffer``, e.g. when
        there are not enough valid environments overall.
    """
    # --- Step 1: collect valid envs, best-first ---
    pair_envs = {}
    for pair_key, envs in grain_pair_mapping.items():
        valid = [
            env for env in envs
            if value_key in env and env[value_key] is not None
        ]
        if valid:
            pair_envs[pair_key] = sorted(
                valid,
                key=lambda e: e[value_key],
                reverse=True,
            )

    if not pair_envs:
        return {}

    capacities = {k: len(v) for k, v in pair_envs.items()}

    # --- Step 2: minimum allocation (never more than a pair has) ---
    n_alloc = {k: min(min_per_pair, cap) for k, cap in capacities.items()}
    n_min_total = sum(n_alloc.values())

    if n_min_total > n_total + buffer:
        raise ValueError(
            f"Minimum allocation {n_min_total} exceeds target {n_total} + buffer"
        )

    # --- Step 3: remaining quota, limited to what is actually available ---
    remaining_caps = {
        k: capacities[k] - n_alloc[k]
        for k in capacities
        if capacities[k] > n_alloc[k]
    }
    total_remaining_cap = sum(remaining_caps.values())
    # Clamping makes every proportional share <= that pair's leftover
    # capacity, so no pair is ever allocated more than it has.
    remaining = min(n_total - n_min_total, total_remaining_cap)

    if remaining > 0:
        # --- Step 4: proportional allocation with largest-remainder rounding ---
        fractional = {
            k: remaining * cap / total_remaining_cap
            for k, cap in remaining_caps.items()
        }
        floor_alloc = {k: math.floor(v) for k, v in fractional.items()}
        for k, v in floor_alloc.items():
            n_alloc[k] += v

        leftover = remaining - sum(floor_alloc.values())
        # Stable sort: ties keep the mapping's insertion order.
        by_fraction = sorted(
            fractional,
            key=lambda k: fractional[k] - floor_alloc[k],
            reverse=True,
        )
        for k in by_fraction:
            if leftover <= 0:
                break
            if n_alloc[k] < capacities[k]:
                n_alloc[k] += 1
                leftover -= 1

    # --- Step 5: final selection (top-ranked envs per pair) ---
    result = {k: envs[:n_alloc[k]] for k, envs in pair_envs.items()}

    n_selected = sum(len(v) for v in result.values())
    if abs(n_selected - n_total) > buffer:
        warnings.warn(
            f"Selected {n_selected} environments, outside target "
            f"{n_total} +/- {buffer} (total valid available: "
            f"{sum(capacities.values())})."
        )

    return result

In [None]:
# Selection settings: ~200 environments in total (+/- 5), at least 5 per pair,
# ranked by first_shell_fract_hcp_c.
SELECTION_PARAMS = dict(
    value_key="first_shell_fract_hcp_c",
    n_total=200,
    min_per_pair=5,
    buffer=5,
)
proposed_selection = select_envs_proportional_with_min(
    grain_pair_mapping_2p6, **SELECTION_PARAMS
)

In [None]:
# Full selected-environment mapping (large output; per-pair counts are
# printed further below).
proposed_selection

Quickly re-running one of the previous plots on the proposed selection as a sanity check.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.cm as cm

def plot_max_grain_fract_hist_colored(
    envs,
    *,
    bins=50,
    ax=None,
    grain_pair_label=None,
    cmap="viridis",
    vmin=0.0,
    vmax=1.0,
    value_key="first_shell_fract_hcp_c",
):
    """
    Histogram of each environment's max ``grain_fract`` value, with every bar
    colored by the mean ``value_key`` of the environments in that bin.

    Parameters
    ----------
    envs : list of dict
        Environment dicts. Entries whose ``"grain_fract"`` is missing or empty
        are skipped. The rest need a numeric ``value_key``.
    bins : int or sequence
        Passed to ``np.histogram``.
    ax : matplotlib Axes, optional
        Axes to draw on. If omitted, a new figure is created and axis labels
        plus a colorbar are added (standalone mode). If given, labels and the
        colorbar are left to the caller (e.g. a shared grid figure).
    grain_pair_label : str, optional
        Axes title.
    cmap : str or Colormap
        Colormap for the bar colors.
    vmin, vmax : float
        Color normalization range.
    value_key : str
        Environment key averaged per bin to color the bars.

    Returns
    -------
    matplotlib Axes
        The axes that were drawn on.
    """
    # x and c use the same filter so they stay aligned element-for-element.
    plotted = [env for env in envs if env.get("grain_fract")]
    x = np.array([max(env["grain_fract"].values()) for env in plotted])
    c = np.array([env[value_key] for env in plotted])

    standalone = ax is None
    if standalone:
        _, ax = plt.subplots(figsize=(5, 4))

    # --- Histogram bins ---
    counts, bin_edges = np.histogram(x, bins=bins)
    bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
    bin_widths = np.diff(bin_edges)

    # --- Per-bin mean of value_key ---
    # np.histogram drops values outside the edges (possible with explicit
    # bins); exclude them here too so colors match the bar counts.
    in_range = (x >= bin_edges[0]) & (x <= bin_edges[-1])
    # np.digitize puts x == last edge one past the end; clip it into the last
    # bin to match np.histogram, whose final bin is closed on the right.
    bin_indices = np.clip(np.digitize(x, bin_edges) - 1, 0, len(counts) - 1)
    bin_means = np.full(len(counts), np.nan)
    for i in range(len(counts)):
        mask = in_range & (bin_indices == i)
        if np.any(mask):
            bin_means[i] = c[mask].mean()

    # --- Colormap ---
    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
    # accepts either a colormap name or a Colormap instance.
    cmap = plt.get_cmap(cmap)
    colors = cmap(norm(bin_means))

    # --- Draw bars ---
    ax.bar(
        bin_centers,
        counts,
        width=bin_widths,
        color=colors,
        edgecolor="black",
        align="center",
    )

    if grain_pair_label is not None:
        ax.set_title(grain_pair_label, fontsize=10)

    if standalone:
        ax.set_xlabel("Max grain fraction")
        ax.set_ylabel("Count")

        sm = cm.ScalarMappable(norm=norm, cmap=cmap)
        sm.set_array([])
        plt.colorbar(
            sm,
            ax=ax,
            label=f"Mean {value_key}",
            pad=0.01,
        )

    return ax

In [None]:
# Grain-pair keys in the order they appear in the loaded mapping.
gpm_keys = [*grain_pair_mapping_2p6]
len(gpm_keys)

In [None]:
# One histogram panel per grain pair. The row count is derived from the
# number of pairs so the grid never runs out of axes.
n_cols = 4
n_rows = max(1, math.ceil(len(gpm_keys) / n_cols))
fig, axes = plt.subplots(
    n_rows,
    n_cols,
    figsize=(7 * n_cols, 1.5 * n_rows),
    sharex=True,
    sharey=False,
    squeeze=False,
)

axes = axes.flatten()

for ax, pair_key in zip(axes, gpm_keys):
    # The selector omits pairs with no valid environments; plot those empty
    # instead of raising KeyError.
    envs = proposed_selection.get(pair_key, [])

    plot_max_grain_fract_hist_colored(
        envs,
        bins=50,
        grain_pair_label=pair_key,
        ax=ax,
    )

    ax.tick_params(labelsize=8)

# Hide unused subplots (if any)
for ax in axes[len(gpm_keys):]:
    ax.axis("off")

fig.supxlabel("Max grain fraction", fontsize=14)
fig.supylabel("Count", fontsize=14)

fig.subplots_adjust(
    wspace=0.25,
    hspace=0.35,
)

plt.show()

In [None]:
# Per-pair selection sizes, followed by the overall total.
for pair_key, selected_envs in proposed_selection.items():
    print(f"{pair_key}: {len(selected_envs)}")
num_total_envs = sum(map(len, proposed_selection.values()))
print(num_total_envs)

(6, 11): 5
(3, 11): 6
(2, 4): 8
(3, 6): 6
(2, 3): 12
(2, 5): 7
(4, 5): 6
(1, 3): 11
(3, 10): 7
(6, 10): 6
(1, 4): 9
(2, 11): 5
(2, 6): 6
(7, 9): 6
(2, 9): 6
(2, 7): 6
(1, 7): 6
(1, 6): 6
(1, 5): 6
(2, 10): 6
(1, 11): 6
(1, 9): 6
(2, 8): 6
(1, 10): 6
(4, 7): 5
(8, 9): 6
(5, 7): 6
(4, 9): 6
(5, 8): 6
(4, 8): 5
(1, 8): 6
200

In [None]:
def to_jsonable(obj):
    """
    Recursively convert ``obj`` into something ``json.dump`` can serialize.

    - numpy arrays -> (nested) lists
    - numpy scalars (``np.generic``) -> the equivalent Python scalar
    - dicts -> dicts with converted values; numpy-scalar keys are converted to
      Python scalars, since ``json`` rejects e.g. ``np.int64`` keys
    - lists and tuples -> lists with converted elements

    Anything else is returned unchanged.
    """
    if isinstance(obj, np.ndarray):
        # tolist() gives Python scalars for numeric dtypes. Recursing also
        # converts numpy values held in object-dtype arrays.
        return to_jsonable(obj.tolist())
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, dict):
        return {
            (k.item() if isinstance(k, np.generic) else k): to_jsonable(v)
            for k, v in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(v) for v in obj]
    return obj

In [None]:
# Persist the selection. to_jsonable guards against numpy values that json
# cannot serialize.
with open("proposed_selected_envs_v1.json", "w", encoding="utf-8") as f:
    json.dump(to_jsonable(proposed_selection), f, indent=2)