In [None]:
# ---------------- 1. Imports & params ----------------
import math
import random
from collections import defaultdict
from copy import deepcopy

import numpy as np
from pymatgen.core import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

In [None]:
"""
Monte Carlo SQS generator for Mn0.2Zn0.8Fe2O4
Input: CIF of the parent spinel (e.g., "Mn(FeO2)2.cif")
Output: best SQS written as CIF ("SQS_Mn0.2Zn0.8Fe2O4.cif")
Author: Generated for user
Requires: pymatgen, numpy
"""

# --- User parameters (edit as needed) ---
# These module-level constants are read as defaults by the functions defined
# further down this file.
PARENT_FILE = "Mn(FeO2)2.cif"  # input CIF of the parent structure
SUPERCELL_SCALE = (2, 2, 2)  # passed to make_supercell; adjust to get desired atom count
A_ELEMENT_IN_PARENT = "Mn"  # element in parent CIF that marks A-site (tetrahedral)
B_ELEMENT_IN_PARENT = "Fe"  # element marking B-site (octahedral)
# Target site fractions per sublattice. NOTE(review): random_pair_target and
# random_triplet_target read these dicts in insertion order, so the key order
# matters for the sign of the triplet target.
A_frac = {"Mn": 0.2, "Zn": 0.8}  # target A-site fractions
B_frac = {"Fe": 1.0}  # keep B-site ordered for Stage 1 (Fe only)
PAIR_CUTOFF = 4.5  # angstrom cutoff to consider pair clusters
TRIPLET_CUTOFF = 5.0  # angstrom cutoff; a triplet counts only if all three pair distances are within it
NUM_MC = 50000  # Monte Carlo steps
# Annealing temperatures divide the cost difference in the Metropolis test,
# so they are in the same (dimensionless) units as the cost function.
T_START = 0.2  # starting temperature for simulated annealing
T_END = 0.01  # final temperature
SAVE_FILENAME = "SQS_Mn0.2Zn0.8Fe2O4.cif"  # output CIF path for the best structure

# ---------------- 2. Load parent, make supercell ----------------
# Read the parent structure from disk.
parent = Structure.from_file(PARENT_FILE)
# Work in the conventional standard cell; the site mapping below is done
# against this cell, not against the cell as stored in the CIF.
sga = SpacegroupAnalyzer(parent)
parent_conv = sga.get_conventional_standard_structure()

# Copy first so parent_conv stays available as the reference unit cell,
# then expand the copy into the supercell.
supercell = parent_conv.copy()
supercell.make_supercell(SUPERCELL_SCALE)


# ---------------- 3. Helper: map site indices by element ----------------
def map_sites_by_symbol(unit_struct, super_struct, symbol, tol=1e-3):
    """Return supercell site indices that are images of unit-cell sites of `symbol`.

    Matching is purely geometric (oxidation states are ignored): every
    supercell site is expressed in the fractional basis of the unit cell and
    compared, modulo lattice translations, with the unit-cell positions that
    carry `symbol`.

    Args:
        unit_struct: reference structure; iterated for sites exposing
            `species` and `frac_coords`, and read for `lattice.matrix`.
        super_struct: supercell built from `unit_struct`; same attributes.
        symbol: element symbol to look for in `unit_struct`.
        tol: threshold on the *squared* fractional displacement (in unit-cell
            fractional coordinates) below which two positions are the same.

    Returns:
        Sorted list of unique integer indices into `super_struct`.

    Raises:
        ValueError: if no supercell site maps onto a `symbol` position.
    """
    unit_coords = [
        np.asarray(site.frac_coords, dtype=float)
        for site in unit_struct
        if any(sp.symbol == symbol for sp, occ in site.species.items())
    ]

    mapped = []
    if unit_coords:
        targets = np.array(unit_coords)
        # Supercell rows are S @ unit rows, so cart = f_super @ S @ L_unit and
        # therefore f_unit = f_super @ S. Without this rescaling the two sets
        # of fractional coordinates live in different bases and cannot be
        # compared.
        scale = np.asarray(super_struct.lattice.matrix, dtype=float) @ np.linalg.inv(
            np.asarray(unit_struct.lattice.matrix, dtype=float)
        )
        for i, s in enumerate(super_struct):
            in_unit = np.asarray(s.frac_coords, dtype=float) @ scale
            # Wrap each component into [-0.5, 0.5) so lattice translations
            # (including the supercell repeats) are ignored.
            d = ((in_unit - targets + 0.5) % 1.0) - 0.5
            if np.any(np.sum(d * d, axis=1) < tol):
                mapped.append(i)

    if not mapped:
        raise ValueError(f"No sites found for element {symbol} in supercell mapping.")
    return mapped


# Locate the A- and B-sublattice sites in the supercell from the positions of
# the marker elements in the conventional cell. Note that `tol` is compared
# against the squared fractional displacement inside map_sites_by_symbol, so
# 0.02 is a much looser match than the function's 1e-3 default.
A_indices = map_sites_by_symbol(parent_conv, supercell, A_ELEMENT_IN_PARENT, tol=0.02)
B_indices = map_sites_by_symbol(parent_conv, supercell, B_ELEMENT_IN_PARENT, tol=0.02)

# Sanity check: report how many sites landed on each sublattice.
print(
    f"Mapped {len(A_indices)} A-sites and {len(B_indices)} B-sites in supercell (total atoms: {len(supercell)})"
)

# Show which element symbols the conventional cell contains.
print(parent_conv.symbol_set)


# ---------------- 4. Initialize configuration honoring fractions ----------------
def init_config(struct, indices, frac_map, rng=random):
    """
    Return a copy of `struct` whose sites at `indices` carry species drawn
    from `frac_map` in (rounded) proportion.

    Each species gets round(fraction * len(indices)) sites. If rounding leaves
    the pool short it is topped up with the species of largest fraction; if it
    overshoots, the surplus is cut from the end of the pool. The pool is then
    shuffled with `rng` and written onto the selected sites, each at full
    occupancy.
    """
    total = len(indices)

    # Build the label pool in frac_map order.
    pool = []
    for label, fraction in frac_map.items():
        pool.extend([label] * int(round(fraction * total)))

    # Undershoot: pad with the dominant species.
    shortfall = total - len(pool)
    if shortfall > 0:
        dominant = max(frac_map.items(), key=lambda kv: kv[1])[0]
        pool.extend([dominant] * shortfall)

    # Overshoot: drop the tail so the pool matches the site count.
    del pool[total:]

    rng.shuffle(pool)

    decorated = struct.copy()
    for position, label in zip(indices, pool):
        decorated[position].species = {label: 1.0}
    return decorated


# Start from a randomly shuffled configuration that matches the target
# fractions: decorate the A sublattice first, then the B sublattice on top of
# that result.
current_struct = init_config(supercell, A_indices, A_frac)
current_struct = init_config(
    current_struct, B_indices, B_frac
)  # here B_frac is Fe-only for stage1, so every B site ends up as Fe


# ---------------- 5. Correlation calculators (binary: ±1 mapping) ----------------
def sigma_map_for_set(struct, indices, species_order=None):
    """
    Build the +1/-1 occupation variables for the sites listed in `indices`.

    The first species found on each site is used as its label. When
    `species_order` ([sp_plus, sp_minus]) is omitted, the two labels present
    are sorted and used as that order; anything other than exactly two
    distinct labels raises ValueError in that case. Sites holding sp_plus map
    to +1, every other site to -1.

    Returns a tuple (dict idx -> sigma, species order used).
    """
    # First species on each selected site.
    labels = {idx: list(struct[idx].species.items())[0][0] for idx in indices}
    observed = sorted(set(labels.values()))

    if species_order is None:
        if len(observed) != 2:
            raise ValueError(
                "Expected exactly 2 species on the active set for binary mapping."
            )
        species_order = observed

    plus_label, _minus_label = species_order[0], species_order[1]
    sigma = {idx: (1 if label == plus_label else -1) for idx, label in labels.items()}
    return sigma, species_order


def compute_pair_correlations(struct, indices, cutoff=PAIR_CUTOFF):
    """
    Average sigma_i * sigma_j per distance shell for site pairs in `indices`.

    Every unordered pair whose struct.get_distance value is at most `cutoff`
    is binned by that distance rounded to 3 decimals. The result maps each
    shell distance to the mean spin product in it. Fewer than two indices, or
    no pair within the cutoff, gives an empty dict. The +1/-1 mapping comes
    from sigma_map_for_set, which raises ValueError when the set is not binary.
    """
    if len(indices) < 2:
        return {}

    spins, _ = sigma_map_for_set(struct, indices)

    products_by_shell = defaultdict(list)
    for pos in range(len(indices) - 1):
        first = indices[pos]
        for second in indices[pos + 1 :]:
            separation = struct.get_distance(first, second)
            if separation <= cutoff:
                products_by_shell[round(separation, 3)].append(
                    spins[first] * spins[second]
                )

    return {
        shell: sum(products) / len(products)
        for shell, products in products_by_shell.items()
    }


def compute_triplet_correlations(struct, indices, cutoff=TRIPLET_CUTOFF):
    """
    Average sigma_i * sigma_j * sigma_k over compact triplets of `indices`.

    A triplet (i<j<k by position in `indices`) is counted when all three of
    its pair distances, as given by struct.get_distance, are <= `cutoff`
    (equivalently, its largest pair distance is <= `cutoff`).

    Returns {"trip_avg": mean product}, or an empty dict when there are fewer
    than three indices or no triplet qualifies. The +1/-1 mapping comes from
    sigma_map_for_set, which raises ValueError when the set is not binary.
    """
    n = len(indices)
    if n < 3:
        return {}
    sig, _ = sigma_map_for_set(struct, indices)

    # Evaluate each pair distance once. close[a] holds the positions b > a
    # whose site lies within the cutoff of site a, so only O(n^2) distance
    # calls are made instead of three per candidate triplet.
    close = [set() for _ in range(n)]
    for a in range(n):
        site_a = indices[a]
        for b in range(a + 1, n):
            if struct.get_distance(site_a, indices[b]) <= cutoff:
                close[a].add(b)

    total = 0
    count = 0
    for a in range(n):
        sigma_a = sig[indices[a]]
        for b in close[a]:
            sigma_ab = sigma_a * sig[indices[b]]
            # c must neighbour both a and b; close[b] only holds positions
            # greater than b, so each triplet is visited exactly once.
            for c in close[a] & close[b]:
                total += sigma_ab * sig[indices[c]]
                count += 1

    if not count:
        return {}
    return {"trip_avg": total / count}


# ---------------- 6. Random-target formulas for ideal random alloy ----------------
def random_pair_target(frac_map, species_order=None):
    """
    Pair correlation of an ideal random binary alloy under the +1/-1 mapping:
    the squared difference of the two fractions in `frac_map`.

    `species_order` is accepted but not used. Raises ValueError unless
    `frac_map` has exactly two entries.
    """
    fractions = list(frac_map.values())
    if len(fractions) != 2:
        raise ValueError("random_pair_target expects binary fraction map.")
    first, second = fractions
    return (first - second) ** 2


def random_triplet_target(frac_map):
    """
    Triplet correlation of an ideal random binary alloy under the +1/-1
    mapping: (first fraction - second fraction) cubed, taking the fractions in
    the insertion order of `frac_map`.

    The sign therefore depends on which species is listed first.
    Raises ValueError unless `frac_map` has exactly two entries.
    """
    fractions = list(frac_map.values())
    if len(fractions) != 2:
        raise ValueError("random_triplet_target expects binary fraction map.")
    first, second = fractions
    return (first - second) ** 3


# ---------------- 7. Cost function ----------------
def cost_function(struct):
    """
    Score how far `struct` is from the ideal random alloy.

    The score is the sum of three terms: the mean squared deviation of the
    A-site pair correlations from the random pair target, the squared
    deviation of the A-site triplet average from the random triplet target,
    and the same pair term for the B sublattice. A term contributes nothing
    when its sublattice is not binary (the helpers signal that with
    ValueError) or when no cluster falls inside the cutoff.

    NOTE(review): the triplet target takes its sign from the key order of the
    fraction map, while the sigma mapping uses the sorted species order;
    confirm the two agree for the species in use.
    """

    def pair_term(site_indices, fractions):
        # Mean over shells of (correlation - random target)^2.
        try:
            shells = compute_pair_correlations(struct, site_indices)
            if not shells:
                return 0.0
            target = random_pair_target(fractions)
        except ValueError:
            # Sublattice is not binary (e.g. Fe-only B in Stage 1): skip it.
            return 0.0
        squared = [(value - target) ** 2 for value in shells.values()]
        return sum(squared) / max(1, len(shells))

    def triplet_term(site_indices, fractions):
        # Squared deviation of the single triplet average from its target.
        try:
            triplets = compute_triplet_correlations(struct, site_indices)
            if not triplets:
                return 0.0
            target = random_triplet_target(fractions)
        except ValueError:
            return 0.0
        return (triplets["trip_avg"] - target) ** 2

    return (
        pair_term(A_indices, A_frac)
        + triplet_term(A_indices, A_frac)
        + pair_term(B_indices, B_frac)
    )


# ---------------- 8. Swap operator ----------------
def swap_two_sites(struct, indices, rng=random):
    """
    Exchange the species of two distinct sites picked at random from
    `indices`, modifying `struct` in place.

    Only the first species on each site is read, and each site is rewritten
    at full occupancy. Returns the pair of indices that was swapped.
    """

    def leading_species(position):
        return list(struct[position].species.items())[0][0]

    first, second = rng.sample(indices, 2)
    label_first = leading_species(first)
    label_second = leading_species(second)
    struct[first].species = {label_second: 1.0}
    struct[second].species = {label_first: 1.0}
    return first, second


# ---------------- 9. Monte Carlo loop with simulated annealing ----------------
def run_montecarlo(
    initial_struct, num_steps=NUM_MC, t_start=T_START, t_end=T_END, rng=random
):
    """
    Simulated-annealing search for the configuration with the lowest cost.

    Each step swaps two sites on one sublattice of a trial copy, re-evaluates
    cost_function, and accepts the trial with the Metropolis rule at a
    temperature that falls linearly from `t_start` to `t_end`.

    Args:
        initial_struct: starting configuration; it is copied, not modified.
        num_steps: number of Monte Carlo steps.
        t_start: temperature at the first step (same units as the cost).
        t_end: temperature at the last step.
        rng: object providing `choices`, `sample` and `random`.

    Returns:
        Tuple (best structure seen, its cost).
    """
    current = initial_struct.copy()
    best = current.copy()
    current_cost = cost_function(current)
    best_cost = current_cost

    # Swaps conserve composition, so the set of sublattices that can change
    # at all is fixed by the starting structure. A single-species sublattice
    # (Fe-only B in Stage 1) would only produce no-op swaps that still pay for
    # a full cost evaluation, so it is excluded from the move set.
    sublattices = []
    for idxs in (A_indices, B_indices):
        species_seen = {list(current[i].species.items())[0][0] for i in idxs}
        if len(species_seen) >= 2:
            sublattices.append(idxs)
    if not sublattices:
        # Nothing can be rearranged; the starting structure is the answer.
        return best, best_cost
    # Pick sublattices in proportion to their number of sites.
    weights = [len(idxs) for idxs in sublattices]

    log_every = max(1, num_steps // 20)

    for step in range(1, num_steps + 1):
        # linear annealing schedule (can change to exponential)
        frac = step / float(num_steps)
        T = t_start * (1 - frac) + t_end * frac

        idxs = rng.choices(sublattices, weights=weights)[0]

        trial = current.copy()
        swap_two_sites(trial, idxs, rng=rng)
        trial_cost = cost_function(trial)
        dE = trial_cost - current_cost

        # Metropolis acceptance; the exponential is only evaluated for
        # dE >= 0, so it cannot overflow.
        if dE < 0.0 or rng.random() < math.exp(-dE / max(T, 1e-12)):
            current = trial
            current_cost = trial_cost
            if current_cost < best_cost:
                best = current.copy()
                best_cost = current_cost

        # occasional logging
        if step % log_every == 0:
            print(
                f"MC step {step}/{num_steps}  current_cost={current_cost:.6e}  best_cost={best_cost:.6e}  T={T:.5f}"
            )

    return best, best_cost


# ---------------- 10. Run & save ----------------
# Run the annealing search from the initial configuration and write the best
# structure found to SAVE_FILENAME in CIF format.
# NOTE(review): no random seed is set in the visible code, so repeated runs
# will presumably give different structures; confirm whether that is intended.
print("Starting Monte Carlo SQS generation...")
best_struct, best_cost = run_montecarlo(current_struct, NUM_MC, T_START, T_END)
print(f"Done. Best cost: {best_cost:.6e}. Writing CIF: {SAVE_FILENAME}")
best_struct.to(fmt="cif", filename=SAVE_FILENAME)