<a href="https://colab.research.google.com/github/vinayak2019/Polymer/blob/main/dope_crystal.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install pymatgen

In [None]:
from pymatgen.core import Structure, Molecule, Lattice
import numpy as np
import os
from scipy.spatial.transform import Rotation as R
from scipy.spatial import cKDTree  # For efficient nearest-neighbor searches
import sys  # For progress printing

# --- Van der Waals Radii (extend as needed for your specific elements) ---
# Element symbol -> van der Waals radius, in Angstroms (the same length unit as the
# Cartesian coordinates and the translational grid spacing used in this notebook).
# The radii are summed pairwise (times a buffer factor) and compared against
# interatomic distances to detect overlap.
# Lookups in this file use `VdW_RADII.get(<species string>, 1.5)`, so any element
# missing from this table silently falls back to a radius of 1.5 A.
# NOTE(review): the literature source of these values is not stated; they appear to be
# a mix of Bondi-style radii and other tabulations — confirm before quantitative use.
VdW_RADII = {
    "H": 1.20, "C": 1.70, "N": 1.55, "O": 1.52, "F": 1.47, "S": 1.80, "Cl": 1.75,
    "Si": 2.10, "Ge": 2.11, "Ga": 1.87, "As": 1.85, "Fe": 1.95, "Br": 1.85,
    "P": 1.80, "B": 1.92, "Mg": 1.73, "Na": 2.27, "K": 2.75, "Al": 1.84,
    "Ti": 1.87, "V": 1.75, "Cr": 1.85, "Mn": 1.79, "Co": 2.00, "Ni": 1.63,
    "Cu": 1.40, "Zn": 1.39, "Sr": 2.49, "Mo": 2.10, "Ru": 2.05, "Rh": 1.95,
    "Pd": 1.63, "Ag": 1.72, "Cd": 1.58, "Sn": 2.17, "Sb": 2.06, "I": 1.98,
    "Cs": 2.98, "Ba": 2.68, "Pt": 1.72, "Au": 1.66, "Hg": 1.55, "Pb": 2.02,
}


# --- Utility Function for Distance Calculation with PBC (Vectorized for 1-to-N) ---
def calculate_distances_with_pbc_1_to_N(coord1, coordsN, lattice_matrix, inv_lattice_matrix):
    """
    Calculates the minimum image distance between a single point (coord1)
    and multiple points (coordsN) in a periodic lattice.

    The fractional displacement is first wrapped into [-0.5, 0.5] along each
    lattice vector. For non-orthogonal (e.g. monoclinic/triclinic) cells that
    wrapped vector is not always the shortest Cartesian image, so the 26
    neighbouring lattice translations of it are also evaluated and the shortest
    one is kept. For orthogonal cells the result equals plain rounding.

    Args:
        coord1 (np.array): Cartesian coordinates of the single point (shape (3,)).
        coordsN (np.array): Cartesian coordinates of N points (shape (N, 3)); N may be 0.
        lattice_matrix (np.array): The 3x3 lattice matrix, one lattice vector per row.
        inv_lattice_matrix (np.array): The inverse of the lattice matrix.
    Returns:
        np.array: An array of minimum image distances (shape (N,)).
    """
    coord1 = np.asarray(coord1, dtype=float)
    coordsN = np.asarray(coordsN, dtype=float).reshape(-1, 3)

    # Fractional displacements, wrapped per axis (minimum image for orthogonal cells)
    frac_disp = np.dot(coordsN, inv_lattice_matrix) - np.dot(coord1, inv_lattice_matrix)
    frac_disp -= np.round(frac_disp)

    # Candidate images: the wrapped displacement plus every -1/0/+1 lattice translation.
    # Needed because rounding alone can miss the shortest image in skewed cells.
    shifts = np.array(
        [[i, j, k] for i in (-1, 0, 1) for j in (-1, 0, 1) for k in (-1, 0, 1)],
        dtype=float,
    )
    candidates = frac_disp[:, np.newaxis, :] + shifts[np.newaxis, :, :]  # (N, 27, 3)

    # Back to Cartesian, then keep the shortest image for each point
    cart_disp = np.matmul(candidates, lattice_matrix)  # (N, 27, 3)
    return np.linalg.norm(cart_disp, axis=2).min(axis=1)


def _process_single_combination_with_penalty(
        translation_coords, rotation_matrix,
        original_sm_coords, original_sm_center_of_mass, original_sm_vdw_radii, sm_species,
        expanded_host_coords, expanded_host_vdw_radii, host_kdtree, lattice_matrix, inv_lattice_matrix, buffer_factor
):
    """
    Place the small molecule for one (translation, rotation) combination and score its clash.

    The molecule is rotated rigidly about its own centre of mass and then moved so that
    its centre of mass sits exactly at ``translation_coords``.

    Args:
        translation_coords: Cartesian target position (shape (3,)) for the molecule's centre of mass.
        rotation_matrix: 3x3 rotation matrix applied about the molecule's centre of mass.
        original_sm_coords: Cartesian coordinates of the input molecule (shape (M, 3)).
        original_sm_center_of_mass: centre of mass of the input molecule (shape (3,)).
        original_sm_vdw_radii: vdW radius of each molecule atom (shape (M,)).
        sm_species: species labels used to build the returned Molecule.
        expanded_host_coords: Cartesian coordinates of host atoms and their periodic images.
        expanded_host_vdw_radii: vdW radius for each row of ``expanded_host_coords``.
        host_kdtree: cKDTree built over ``expanded_host_coords``.
        lattice_matrix: host lattice matrix (one lattice vector per row).
        inv_lattice_matrix: inverse of ``lattice_matrix``.
        buffer_factor: multiplier applied to every pairwise vdW radius sum.

    Returns:
        tuple: (penalty, min_clearance, placed_molecule).
            penalty: largest overlap depth, i.e. -clearance of the worst pair; 0.0 when no
                pair is closer than its required distance by more than 1e-7.
            min_clearance: smallest (distance - required distance) over all checked pairs;
                ``inf`` if no host atom was found within the search radius.
            placed_molecule: Molecule built from ``sm_species`` at the placed coordinates.
    """
    # Rotate about the molecule's own centre of mass, then move that centre onto the target.
    # (Rotating after translating, about the *original* centre of mass, would leave the
    # centre of mass at R @ (T - COM) + COM instead of the requested grid point T.)
    centred_coords = np.asarray(original_sm_coords) - original_sm_center_of_mass
    rotated_and_translated_sm_coords = np.dot(centred_coords, rotation_matrix.T) + translation_coords

    overlap_tolerance = 1e-7  # clearances below -tolerance count as real overlap
    max_penetration = 0.0  # penalty: maximum depth of overlap
    min_overall_clearance = float('inf')  # smallest clearance (negative when overlapping)

    # KD-tree search radius: the largest possible required contact distance plus one cell
    # span, so that the relevant periodic images in the expanded host are all returned.
    max_cell_span = np.max(np.linalg.norm(lattice_matrix, axis=1))
    max_search_radius_kdtree = (
        (np.max(original_sm_vdw_radii) + np.max(expanded_host_vdw_radii)) * buffer_factor
        + max_cell_span
    )

    for sm_coord, sm_vdw_radius in zip(rotated_and_translated_sm_coords, original_sm_vdw_radii):
        nearby_host_indices = host_kdtree.query_ball_point(sm_coord, r=max_search_radius_kdtree)
        if not nearby_host_indices:
            continue  # no host atom within the search radius of this molecule atom

        # Vectorized PBC distances and clearances to all nearby host atoms
        distances = calculate_distances_with_pbc_1_to_N(
            sm_coord, expanded_host_coords[nearby_host_indices], lattice_matrix, inv_lattice_matrix
        )
        required_min_distances = (sm_vdw_radius + expanded_host_vdw_radii[nearby_host_indices]) * buffer_factor
        atom_min_clearance = np.min(distances - required_min_distances)

        if atom_min_clearance < min_overall_clearance:
            min_overall_clearance = atom_min_clearance
        if atom_min_clearance < -overlap_tolerance and -atom_min_clearance > max_penetration:
            max_penetration = -atom_min_clearance

    placed_molecule = Molecule(sm_species, rotated_and_translated_sm_coords)
    return (max_penetration, min_overall_clearance, placed_molecule)


def _radius_lookup_symbol(specie):
    """Return the bare element symbol of a pymatgen species, used as the VdW_RADII key."""
    # str(Species("Fe", 2)) is "Fe2+", which is not a VdW_RADII key; .symbol is "Fe".
    return getattr(specie, "symbol", str(specie))


def _lattice_translation_grid(lattice_matrix, abc, grid_spacing):
    """Return Cartesian grid points covering the unit cell, ~grid_spacing apart along each lattice vector."""
    # Sample fractional coordinates in [0, 1) along a, b and c and map them through the
    # lattice matrix, so the grid follows the cell even when it is not orthogonal. For an
    # orthogonal cell with a, b, c along x, y, z this is the Cartesian grid
    # [0, a) x [0, b) x [0, c). indexing="ij" keeps the a-axis outermost, c innermost.
    fractional_axes = [np.arange(0, length, grid_spacing) / length for length in abc]
    frac_a, frac_b, frac_c = np.meshgrid(*fractional_axes, indexing="ij")
    frac_points = np.column_stack([frac_a.ravel(), frac_b.ravel(), frac_c.ravel()])
    return np.dot(frac_points, lattice_matrix)


def _unique_rotation_matrices(rotation_step_degrees):
    """Return the distinct rotation matrices on an 'xyz' Euler-angle grid, in first-seen order."""
    angles_deg = np.arange(0, 360, rotation_step_degrees)
    unique = {}
    for alpha_deg in angles_deg:
        for beta_deg in angles_deg:
            for gamma_deg in angles_deg:
                matrix = R.from_euler('xyz', [alpha_deg, beta_deg, gamma_deg], degrees=True).as_matrix()
                # Different Euler triples can describe the same rotation (e.g. (180, 180, 180)
                # is the identity); key on the rounded matrix so each rotation is scored once
                # and duplicate placements cannot crowd the top results.
                unique.setdefault(tuple(np.round(matrix, 6).ravel()), matrix)
    return list(unique.values())


def _expanded_host_images(host_coords, host_vdw_radii, lattice_matrix):
    """Return coordinates and radii of the host atoms replicated over a 3x3x3 block of cells."""
    cell_offsets = np.array([[dx, dy, dz] for dx in (-1, 0, 1) for dy in (-1, 0, 1) for dz in (-1, 0, 1)])
    expanded_coords = np.vstack([host_coords + np.dot(offset, lattice_matrix) for offset in cell_offsets])
    expanded_radii = np.tile(host_vdw_radii, len(cell_offsets))  # same order as the stacked images
    return expanded_coords, expanded_radii


def find_least_clash_positions(cif_file_path, small_molecule_pymatgen_object, grid_spacing=0.5,
                               rotation_step_degrees=90, buffer_factor=0.8, num_results=1):
    """
    Finds positions for a small molecule within a crystal structure,
    prioritizing positions with the least clash (lowest penalty).

    Candidate centre-of-mass positions form a grid along the host lattice vectors; each is
    combined with every distinct rotation from an 'xyz' Euler-angle grid. Neighbour search
    uses a cKDTree over a 3x3x3 expanded supercell, and clash is quantified as the maximum
    atomic penetration depth.

    Args:
        cif_file_path: path to the host structure file (read with ``Structure.from_file``).
        small_molecule_pymatgen_object: pymatgen Molecule to place.
        grid_spacing: approximate spacing (Angstroms) between candidate positions along
            each lattice vector.
        rotation_step_degrees: step of each of the three Euler angles over [0, 360).
        buffer_factor: multiplier on each pairwise vdW radius sum (1.0 = just touching).
        num_results: number of best placements to return (used as a slice bound).

    Returns:
        list: up to ``num_results`` placed Molecules, ordered by penalty (ascending), then
        minimum clearance (descending).
    """
    host_structure = Structure.from_file(cif_file_path)
    lattice = host_structure.lattice
    lattice_matrix = lattice.matrix
    inv_lattice_matrix = np.linalg.inv(lattice_matrix)  # pre-calculate once

    # Host atom data; radii are looked up by bare element symbol (oxidation states ignored)
    host_coords = np.array([s.coords for s in host_structure.sites])
    host_vdw_radii = np.array([VdW_RADII.get(_radius_lookup_symbol(s.specie), 1.5) for s in host_structure.sites])

    # Molecule data; full species strings are kept because they rebuild each placed Molecule
    original_sm_coords = small_molecule_pymatgen_object.cart_coords
    sm_sites = small_molecule_pymatgen_object.sites
    sm_species = [str(site.specie) for site in sm_sites]
    original_sm_vdw_radii = np.array([VdW_RADII.get(_radius_lookup_symbol(site.specie), 1.5) for site in sm_sites])
    original_sm_center_of_mass = small_molecule_pymatgen_object.center_of_mass

    translation_points = _lattice_translation_grid(lattice_matrix, lattice.abc, grid_spacing)
    rotation_matrices = _unique_rotation_matrices(rotation_step_degrees)

    total_grid_points = len(translation_points)
    total_orientations = len(rotation_matrices)
    total_combinations = total_grid_points * total_orientations

    print(
        f"Starting search with {total_grid_points} translational points and {total_orientations} orientations ({total_combinations} total combinations)...")
    print("WARNING: A small 'grid_spacing' or 'rotation_step_degrees' will lead to very long computation times.")
    print(f"Pre-calculated {len(rotation_matrices)} rotation matrices.")

    expanded_host_coords, expanded_host_vdw_radii = _expanded_host_images(host_coords, host_vdw_radii, lattice_matrix)
    host_kdtree = cKDTree(expanded_host_coords)
    print(f"Built KDTree with {len(expanded_host_coords)} expanded host atom images.")

    all_candidate_placements = []  # (penalty, min_clearance, molecule)
    progress_every = max(total_combinations // 100, 1)
    done = 0
    for candidate_com_position in translation_points:
        for rotation_matrix in rotation_matrices:
            done += 1
            if total_combinations > 1000 and done % progress_every == 0:
                sys.stdout.write(
                    f"\r  Processing combination {done}/{total_combinations} ({(done / total_combinations) * 100:.1f}%)")
                sys.stdout.flush()
            elif done == total_combinations:
                sys.stdout.write(
                    f"\r  Processing combination {done}/{total_combinations} ({(done / total_combinations) * 100:.1f}%) - Complete.")
                sys.stdout.flush()

            all_candidate_placements.append(_process_single_combination_with_penalty(
                candidate_com_position, rotation_matrix,
                original_sm_coords, original_sm_center_of_mass, original_sm_vdw_radii, sm_species,
                expanded_host_coords, expanded_host_vdw_radii, host_kdtree, lattice_matrix, inv_lattice_matrix,
                buffer_factor
            ))
    print("\nSearch complete.")

    # Lower penalty first; for equal penalty, larger clearance first
    all_candidate_placements.sort(key=lambda x: (x[0], -x[1]))

    print(f"Found {len(all_candidate_placements)} candidate positions. Returning top {num_results}.")
    return [mol for _, _, mol in all_candidate_placements[:num_results]]


def add_molecule_to_structure_and_save(host_cif_path, placed_molecule, output_cif_path):
    """
    Write a new structure file containing the host crystal plus one placed molecule.

    The host is re-read from ``host_cif_path`` and rebuilt from its lattice, per-site
    species and Cartesian site coordinates. Every atom of ``placed_molecule`` is then
    appended using its Cartesian coordinates. Atom counts are printed, and the combined
    structure is written to ``output_cif_path`` via ``Structure.to``.
    """
    host = Structure.from_file(host_cif_path)
    host_species = [site.specie for site in host.sites]
    host_cartesian = [site.coords for site in host.sites]
    combined = Structure(host.lattice, host_species, host_cartesian, coords_are_cartesian=True)

    for guest_site in placed_molecule.sites:
        combined.append(guest_site.specie, guest_site.coords, coords_are_cartesian=True)

    summary_lines = (
        f"Original host structure has {len(host)} atoms.",
        f"Added molecule has {len(placed_molecule)} atoms.",
        f"New combined structure has {len(combined)} atoms.",
    )
    for line in summary_lines:
        print(line)

    combined.to(filename=output_cif_path)
    print(f"Combined structure saved to: {output_cif_path}")


# --- Example Usage ---
if __name__ == "__main__":
    # >>>>>>>>>>>>>>>>>> USER CONFIGURATION <<<<<<<<<<<<<<<<<<<<
    host_cif_filename = "test.cif"       # path to the host crystal (CIF)
    molecule_xyz_filename = "mol.xyz"    # path to the small molecule (XYZ)
    num_top_positions_to_save = 5        # how many lowest-penalty placements to write out

    # Buffer factor for VdW radii: 1.0 = atoms just touch, >1.0 = enforce a gap,
    # <1.0 = tolerate some overlap.
    current_buffer_factor = 1.01

    current_grid_spacing = 1.0           # translational grid spacing (Angstroms)
    current_rotation_step_degrees = 180  # step for each Euler angle (degrees)
    # >>>>>>>>>>>>>>>>>> END USER CONFIGURATION <<<<<<<<<<<<<<<<<<<<

    small_molecule = Molecule.from_file(molecule_xyz_filename)

    top_least_clash_mols = find_least_clash_positions(
        host_cif_filename,
        small_molecule,
        grid_spacing=current_grid_spacing,
        rotation_step_degrees=current_rotation_step_degrees,
        buffer_factor=current_buffer_factor,
        num_results=num_top_positions_to_save,
    )

    if not top_least_clash_mols:
        print("No suitable positions found based on the specified criteria.")
    else:
        # File-name stems do not change between iterations, so compute them once
        host_name = os.path.splitext(os.path.basename(host_cif_filename))[0]
        mol_name = os.path.splitext(os.path.basename(molecule_xyz_filename))[0]
        for rank, placed_mol in enumerate(top_least_clash_mols, start=1):
            output_cif_filename = (
                f'{host_name}_with_{mol_name}_pos_{rank}'
                f'_grid{current_grid_spacing}_rot{current_rotation_step_degrees}'
                f'_buf{current_buffer_factor:.2f}_penalty.cif'
            )
            add_molecule_to_structure_and_save(host_cif_filename, placed_mol, output_cif_filename)
            print(f"Successfully saved top position {rank} to {output_cif_filename}")