GitHub Copilot autocomplete was used to help produce this code.

ChatGPT was used to help in some cases.

In [None]:
import numpy as np
import spglib
from ase import io
from ase import build
import math
from sympy import *
from matplotlib import pyplot as plt

def get_cell(atoms):
    """Build an spglib-style cell tuple from an ASE ``Atoms`` object.

    Returns:
        tuple: ``(lattice, scaled_positions, atomic_numbers)`` exactly as returned by
        ``atoms.get_cell()``, ``atoms.get_scaled_positions()`` and
        ``atoms.get_atomic_numbers()``.
    """
    return (
        atoms.get_cell(),
        atoms.get_scaled_positions(),
        atoms.get_atomic_numbers(),
    )

def get_symmetry_operations(cell):
    """Query spglib for the symmetry of ``cell``.

    Returns:
        tuple: ``(spglib.get_symmetry(cell), spglib.get_spacegroup(cell),
        spglib.get_symmetry_dataset(cell))``, evaluated in that order.
    """
    return (
        spglib.get_symmetry(cell),
        spglib.get_spacegroup(cell),
        spglib.get_symmetry_dataset(cell),
    )

def build_supercell(atoms, scaling):
    """Return a supercell of ``atoms`` built with ``ase.build.make_supercell``.

    The transformation matrix is the 3x3 identity multiplied by ``scaling``,
    so an integer ``scaling`` repeats the cell that many times along each axis.
    """
    transformation = np.eye(3) * scaling
    return build.make_supercell(atoms, transformation)
# def build_supercell(atoms, scaling):
#     """ Takes in an ASE atoms object and a scaling factor and returns a supercell. """
#     supercell = build.make_supercell(atoms, np.diag([3,3,1]))
#     return supercell

def fractional_to_cartesian(h, rho):
    """Convert fractional coordinates ``rho`` to Cartesian using lattice matrix ``h``.

    Computes ``h.T @ rho``, i.e. ``sum_i rho[i] * h[i]``; the rows of ``h`` are
    therefore treated as the lattice vectors.
    """
    lattice_columns = h.T
    return lattice_columns @ rho

def grouping(distances, tolerance=1e-6, decimals=4):
    """Assign every distance to a neighbour-shell index.

    Distances that are equal after rounding to ``decimals`` decimal places share a
    shell. Shells are numbered in order of increasing distance starting at 0. If no
    distance lies within ``tolerance`` of zero (i.e. the self-distance is absent),
    every index is shifted up by one, so the nearest non-zero distance is always
    shell 1 either way.

    Args:
        distances: 1-D sequence/array of distances.
        tolerance: Absolute tolerance used to detect a zero (self) distance.
        decimals: Rounding precision used to decide which distances share a shell.

    Returns:
        numpy.ndarray: integer shell index for each entry of ``distances``, in the
        original order.
    """
    # ChatGPT helped with the original version of this code.
    distances = np.asarray(distances)
    rounded = np.round(distances, decimals=decimals)
    # return_inverse gives, for each element, the index of its value in the sorted
    # unique array -- exactly the shell number. Reshape guards against NumPy versions
    # that differ on the shape of the inverse array.
    _, shells = np.unique(rounded, return_inverse=True)
    shells = shells.reshape(distances.shape)
    if not np.any(np.isclose(distances, 0, atol=tolerance)):
        shells = shells + 1
    return shells

In [None]:
def plot_unit_cell_3d(positons):
    """Show a 3D scatter plot of ``positons`` (columns 0, 1, 2 are x, y, z)."""
    figure = plt.figure()
    axes_3d = figure.add_subplot(111, projection='3d')
    xs, ys, zs = positons[:, 0], positons[:, 1], positons[:, 2]
    axes_3d.scatter(xs, ys, zs)
    plt.show()

In [None]:
def plot_unit_cell(positions):
    """Plot ``positions`` projected onto the xy, xz and yz planes, side by side."""
    # (horizontal label, vertical label, horizontal column, vertical column)
    projections = (('x', 'y', 0, 1), ('x', 'z', 0, 2), ('y', 'z', 1, 2))
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (h_label, v_label, h_col, v_col) in zip(axes, projections):
        ax.plot(positions[:, h_col], positions[:, v_col], '.')
        ax.set_xlabel(h_label)
        ax.set_ylabel(v_label)
    plt.tight_layout()
    plt.show()

In [None]:
def get_point_distances(point, positions, radius=None, tolerance=1e-6):
    """Return the Cartesian distances from ``point`` to the other ``positions``.

    Any position whose distance to ``point`` is within ``tolerance`` of zero is
    treated as the point itself and skipped.

    Args:
        point: Coordinates of the reference point.
        positions: Sequence/array of positions, one per row.
        radius: If given, keep only positions strictly closer than ``radius``;
            if None, keep every (non-self) position.
        tolerance: Absolute tolerance used to recognise the point itself.

    Returns:
        distances: 1-D float array of the kept distances, in ``positions`` order.
        index_list: 1-D integer array giving the row of ``positions`` for each
            kept distance.
    """
    point = np.asarray(point, dtype=float)
    positions = np.asarray(positions, dtype=float)
    if positions.size == 0:
        return np.array([]), np.array([], dtype=int)
    positions = np.atleast_2d(positions)

    # One vectorised norm per row instead of a Python loop (and no duplicate norm).
    distances = np.linalg.norm(point - positions, axis=1)
    keep = ~np.isclose(distances, 0, atol=tolerance)  # drop the point itself
    if radius is not None:
        keep &= distances < radius
    index_list = np.flatnonzero(keep)
    return distances[keep], index_list

In [None]:
def centering(point, scale):
    """Map ``point`` to the equivalent point inside the central unit cell of a supercell.

    The integer part of each of the first three coordinates is removed (bringing the
    point into the origin cell) and ``floor(scale / 2)`` is added to every
    coordinate, which is the origin of the central cell.

    Args:
        point (numpy array): Fractional coordinates of the point.
        scale (int): Scaling factor of the supercell.

    Returns:
        numpy array: The equivalent point within the middle unit cell.
    """
    cell_of_point = np.array([math.floor(point[0]), math.floor(point[1]), math.floor(point[2])])
    in_origin_cell = point - cell_of_point
    centre_offset = math.floor(scale / 2)
    return in_origin_cell + np.full(3, centre_offset)

In [None]:
# This code was produced by Dr Joseph Barker (private communication) and is used with permission

from itertools import product
from numpy.typing import ArrayLike
import ase.io

class SymOp:
    """A symmetry operation acting as ``x -> rotation @ x + translation``.

    Operations compose with ``*``, apply to a point with ``@``, and compare
    equal (and hash identically) when their entries agree after scaling by
    1e8 and rounding to integers.
    """

    def __init__(self, rotation: ArrayLike, translation: ArrayLike):
        self.translation = np.asarray(translation)
        self.rotation = np.asarray(rotation)

    def __mul__(self, other: 'SymOp') -> 'SymOp':
        """Compose two operations: ``(self * other) @ x == self @ (other @ x)``."""
        return SymOp(self.rotation @ other.rotation, self.translation + self.rotation @ other.translation)

    def __imul__(self, other: 'SymOp') -> 'SymOp':
        """In-place composition (``self *= other``), equivalent to ``self = self * other``."""
        # The composed translation uses the *current* rotation, so compute it before
        # overwriting self.rotation (this also stays correct when ``other is self``).
        translation = self.translation + self.rotation @ other.translation
        self.rotation = self.rotation @ other.rotation
        self.translation = translation
        return self

    def __matmul__(self, pos: ArrayLike) -> np.ndarray:
        """Apply the operation to a point ``pos``."""
        x = np.asarray(pos)
        return self.rotation @ x + self.translation

    def __add__(self, translation: ArrayLike) -> 'SymOp':
        """Return a new operation with ``translation`` added to this one's translation."""
        t = np.asarray(translation)
        return SymOp(self.rotation, self.translation + t)

    def __pow__(self, exponent: int) -> 'SymOp':
        """Return the operation applied ``exponent`` times (``exponent >= 1``).

        ``self`` is never modified.

        Raises:
            ValueError: if ``exponent < 1``.
        """
        if exponent < 1:
            raise ValueError("exponent must be greater than or equal to 1.")

        result = self
        for _ in range(exponent - 1):
            # Out-of-place product: ``result *= self`` would mutate self through the alias.
            result = result * self
        return result

    def _quantized_key(self) -> tuple:
        """Entries scaled by 1e8 and rounded to int64; shared by ``__eq__`` and ``__hash__``."""
        rot_quantized = np.rint(np.asarray(self.rotation, dtype=float) * 1e8).astype(np.int64)
        trans_quantized = np.rint(np.asarray(self.translation, dtype=float) * 1e8).astype(np.int64)
        return (rot_quantized.shape, rot_quantized.tobytes(),
                trans_quantized.shape, trans_quantized.tobytes())

    def __eq__(self, other: object) -> bool:
        """Two operations are equal when their quantized rotation and translation match."""
        if not isinstance(other, SymOp):
            return NotImplemented
        return self._quantized_key() == other._quantized_key()

    def __hash__(self) -> int:
        """Hash consistent with ``__eq__`` (built from the same quantized key)."""
        return hash(self._quantized_key())

    def __repr__(self) -> str:
        return f"SymOp(rotation={self.rotation.tolist()!r}, translation={self.translation.tolist()!r})"

    def __str__(self) -> str:
        """Returns a human-readable string representation. """
        return "\n".join(
            " ".join(f"{x: 3.3f}" for x in row) + f"  | {t: 3.3f}"
            for row, t in zip(self.rotation, self.translation)
        )

    @staticmethod
    def inversion() -> 'SymOp':
        """Return the inversion operation r -> -r (rotation -I, zero translation).

        Declared static so it works both as ``SymOp.inversion()`` and on an instance.
        """
        return SymOp(-np.eye(3), np.zeros(3))

def symop_inverse(symop: SymOp) -> SymOp:
    """Return the inverse operation: {R|t}^-1 = {R^-1 | -R^-1 t}.

    R^-1 is computed with ``np.linalg.inv``. The transpose is only the inverse for
    orthogonal matrices, which fractional-coordinate rotations of non-orthogonal
    (e.g. hexagonal) lattices are not.
    """
    rotation = np.asarray(symop.rotation)
    inverse_rotation = np.linalg.inv(rotation)
    if np.issubdtype(rotation.dtype, np.integer) and round(abs(np.linalg.det(rotation))) == 1:
        # An integer matrix with determinant +-1 has an integer inverse: strip inv()'s
        # floating-point noise and keep the caller's integer dtype.
        inverse_rotation = np.rint(inverse_rotation).astype(rotation.dtype)
    return SymOp(inverse_rotation, -inverse_rotation @ symop.translation)

def spgcell(ase_atoms, magnetic=False):
    """Build the spglib ``cell`` tuple from an ASE atoms object.

    Returns ``(cell, scaled_positions, atomic_numbers)``; when ``magnetic`` is True
    the initial magnetic moments are appended as a fourth element.
    """
    parts = [
        ase_atoms.get_cell(),
        ase_atoms.get_scaled_positions(),
        ase_atoms.get_atomic_numbers(),
    ]
    if magnetic:
        parts.append(ase_atoms.get_initial_magnetic_moments())
    return tuple(parts)

def squared_distance(a, b):
    """Return ``|a - b|^2``, the squared Euclidean distance between two points."""
    delta = a - b
    return np.dot(delta, delta)

def get_space_group_symops(atoms) -> list[SymOp]:
    """Return the space-group operations of an ASE atoms object as ``SymOp`` objects.

    The operations come from ``spglib.get_symmetry`` applied to ``spgcell(atoms)``.
    """
    symmetry = spglib.get_symmetry(spgcell(atoms))
    return list(map(SymOp, symmetry['rotations'], symmetry['translations']))

def get_space_group_symops_cell(cell) -> list[SymOp]:
    """Return the space-group operations of an spglib ``cell`` tuple as ``SymOp`` objects."""
    symmetry = spglib.get_symmetry(cell)
    return list(map(SymOp, symmetry['rotations'], symmetry['translations']))

def get_point_group_name(symops: list[SymOp]) -> str:
    """Return the point-group symbol spglib assigns to the rotation parts of ``symops``."""
    symbol = spglib.get_pointgroup([op.rotation for op in symops])[0]
    return symbol

In [None]:
def reduce_to_unit_cell(x):
    """Return the coordinate x (or each element of x) reduced to the interval [0, 1)."""
    # Originally written with ChatGPT's help.
    reduced = np.mod(x, 1)
    # np.mod rounds a tiny negative input (e.g. -1e-17) up to exactly 1.0, which is
    # outside [0, 1); fold that value back to 0 so equivalent sites compare equal.
    return reduced - (reduced >= 1)

def difference(x, y):
    """Return the displacement vector pointing from ``x`` to ``y`` (i.e. ``y - x``)."""
    displacement = y - x
    return displacement

In [None]:
def symmetry_operations_for_bond_2_points(space_group: list, x1, x2, tol: float = 1e-5, *, verbose: bool = True):
    """Find the operations of ``space_group`` that leave the bond x1-x2 invariant.

    Each operation S is applied to both endpoints. Sites are compared modulo integer
    lattice translations, but the bond itself must be reproduced exactly (within ``tol``):

    * direct:   S(x1) = x1 + T and S(x2) = x2 + T for an integer vector T
    * reversal: S(x2) = x1 + T and S(x1) = x2 + T (the bond is flipped end to end)

    Args:
        space_group: Operations supporting ``op @ point`` (e.g. ``SymOp``). Their
            ``.rotation`` and ``.translation`` are printed when ``verbose`` is True.
        x1, x2: Bond endpoints; presumably fractional coordinates, since integer
            offsets are treated as lattice translations.
        tol: Absolute tolerance for every position comparison.
        verbose: Print each matching operation (previous behaviour); False silences it.

    Returns:
        (direct_ops, reversal_ops): lists of matching operations in ``space_group`` order.
    """
    direct_ops = []
    reversal_ops = []

    x1 = np.asarray(x1)
    x2 = np.asarray(x2)

    def is_lattice_vector(v):
        # Compare the raw difference with its nearest integer vector. Comparing
        # np.mod-reduced positions instead fails at cell boundaries, where the same
        # site can reduce to 0.0 on one side and 0.9999999 on the other.
        return np.allclose(v, np.rint(v), atol=tol)

    for op in space_group:  # loops over all symmetry operations S = (R, T)
        xp1 = op @ x1
        xp2 = op @ x2

        # Direct: x1 maps onto a lattice image of itself and x2 onto the same image of x2.
        shift = xp1 - x1
        if is_lattice_vector(shift) and np.allclose(xp2 - shift, x2, atol=tol):
            if verbose:
                print("operation leaves bond invariant")
                print(np.array(op.rotation))
                print(np.array(op.translation))
            direct_ops.append(op)
            continue

        # Reversal: x2 maps onto a lattice image of x1 and x1 onto the same image of x2.
        shift = xp2 - x1
        if is_lattice_vector(shift) and np.allclose(xp1 - shift, x2, atol=tol):
            if verbose:
                print("operation reverses bond")
                print(np.array(op.rotation))
                print(np.array(op.translation))
            reversal_ops.append(op)

    return direct_ops, reversal_ops

In [None]:
def symmetry_operations_for_bond_mapping_3_points(space_group: list, x1, x2, x3, tol: float = 1e-5, *, verbose: bool = True):
    """Find the operations of ``space_group`` that map the bond x1-x2 onto the bond x1-x3.

    Each operation S is applied to x1 and x2. Sites are compared modulo integer
    lattice translations, but the mapped bond must land exactly (within ``tol``):

    * direct:   S(x1) = x1 + T and S(x2) = x3 + T for an integer vector T
    * reversal: S(x2) = x1 + T and S(x1) = x3 + T (the bond is mapped end-reversed)

    Args:
        space_group: Operations supporting ``op @ point`` (e.g. ``SymOp``).
        x1: Shared endpoint of both bonds.
        x2: Other endpoint of the source bond.
        x3: Other endpoint of the target bond. All points are presumably
            fractional coordinates, since integer offsets are treated as lattice translations.
        tol: Absolute tolerance for every position comparison.
        verbose: Print a message for each matching operation (previous behaviour).

    Returns:
        (direct_ops, reversal_ops): lists of matching operations in ``space_group`` order.
    """
    direct_ops = []
    reversal_ops = []

    x1 = np.asarray(x1)
    x2 = np.asarray(x2)
    x3 = np.asarray(x3)

    def is_lattice_vector(v):
        # Compare the raw difference with its nearest integer vector; comparing
        # np.mod-reduced positions fails at cell boundaries (0.0 vs 0.9999999).
        return np.allclose(v, np.rint(v), atol=tol)

    for op in space_group:
        xp1 = op @ x1
        xp2 = op @ x2

        # Direct: x1 stays on a lattice image of itself; x2 lands on the same image of x3.
        shift = xp1 - x1
        if is_lattice_vector(shift) and np.allclose(xp2 - shift, x3, atol=tol):
            if verbose:
                print("operation has mapped bond directly")
            direct_ops.append(op)
            continue

        # Reversal: x2 lands on a lattice image of x1; x1 lands on the same image of x3.
        shift = xp2 - x1
        if is_lattice_vector(shift) and np.allclose(xp1 - shift, x3, atol=tol):
            if verbose:
                print("operation maps the bond in reverse")
            reversal_ops.append(op)

    return direct_ops, reversal_ops

In [None]:
def symmetry_operations_for_different_bond_mapping_4_points(space_group: list, x1, x2, x3, x4, tol: float = 1e-4, *, verbose: bool = True):
    """Find the operations of ``space_group`` that map the bond x1-x2 onto the bond x3-x4.

    Each operation S is applied to x1 and x2. Sites are compared modulo integer
    lattice translations, but the mapped bond must land exactly (within ``tol``):

    * direct:   S(x1) = x3 + T and S(x2) = x4 + T for an integer vector T
    * reversal: S(x2) = x3 + T and S(x1) = x4 + T

    Both returned lists hold the operations themselves. Previously the reversal list
    held only ``op.rotation``, inconsistent with the direct list and the sibling functions.

    Args:
        space_group: Operations supporting ``op @ point`` (e.g. ``SymOp``); ``.rotation``
            is printed for matches when ``verbose`` is True.
        x1, x2: Endpoints of the source bond.
        x3, x4: Endpoints of the target bond. All points are presumably fractional
            coordinates, since integer offsets are treated as lattice translations.
        tol: Absolute tolerance for every position comparison.
        verbose: Print each matching operation (previous behaviour).

    Returns:
        (direct_ops, reversal_ops): lists of matching operations in ``space_group`` order.
    """
    direct_ops = []
    reversal_ops = []

    x1 = np.asarray(x1)
    x2 = np.asarray(x2)
    x3 = np.asarray(x3)
    x4 = np.asarray(x4)

    def is_lattice_vector(v):
        # Compare the raw difference with its nearest integer vector; comparing
        # np.mod-reduced positions fails at cell boundaries (0.0 vs 0.9999999).
        return np.allclose(v, np.rint(v), atol=tol)

    for op in space_group:
        xp1 = op @ x1
        xp2 = op @ x2

        # Direct: x1 lands on a lattice image of x3 and x2 on the same image of x4.
        shift = xp1 - x3
        if is_lattice_vector(shift) and np.allclose(xp2 - shift, x4, atol=tol):
            if verbose:
                print("operation has mapped bond to the other bond")
                print(np.array(op.rotation))
            direct_ops.append(op)
            continue

        # Reversal: x2 lands on a lattice image of x3 and x1 on the same image of x4.
        shift = xp2 - x3
        if is_lattice_vector(shift) and np.allclose(xp1 - shift, x4, atol=tol):
            if verbose:
                print("operation maps the bond to the other bond in reverse")
                print(np.array(op.rotation))
            reversal_ops.append(op)

    return direct_ops, reversal_ops

In [None]:
def plotting_graphs(lattice, position, supercell_positions, scale, radius=None):
    """Plot the supercell projected onto the xy, xz and yz planes.

    Supercell positions are drawn as default-coloured dots and the point whose
    bonds are being analysed as a red dot. Each panel outlines the face of the
    central unit cell (the cell starting at ``floor(scale / 2)`` along every
    lattice vector), which is where the point should lie. If ``radius`` is
    given, a green circle of that radius around the point shows the bond
    search region; with ``radius=None`` (the default) no circle is drawn.

    Args:
        lattice: 3x3 lattice matrix; rows are treated as lattice vectors
            (consistent with ``fractional_to_cartesian``).
        position: Cartesian coordinates of the analysed point.
        supercell_positions: Cartesian positions of the supercell sites, one per row.
        scale: Scaling factor of the supercell.
        radius: Optional bond search radius.
    """
    lattice = np.asarray(lattice)
    supercell_positions = np.asarray(supercell_positions)

    centre_scale = math.floor(scale / 2)
    start_point = fractional_to_cartesian(lattice, np.full(3, centre_scale)).reshape(-1)

    # (horizontal axis, vertical axis, labels). The cell face outlined in each
    # projection is spanned by the lattice vectors with the same two indices.
    planes = ((0, 1, 'x', 'y'), (0, 2, 'x', 'z'), (1, 2, 'y', 'z'))

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (i, j, xlabel, ylabel) in zip(axes, planes):
        ax.plot(supercell_positions[:, i], supercell_positions[:, j], '.', zorder=1)
        ax.plot(position[i], position[j], '.', color='red', zorder=2)

        corners = (start_point,
                   start_point + lattice[i],
                   start_point + lattice[i] + lattice[j],
                   start_point + lattice[j])
        ax.add_patch(plt.Polygon(xy=[(c[i], c[j]) for c in corners], fill=False))

        if radius is not None:  # plt.Circle cannot be built with a None radius
            ax.add_patch(plt.Circle((position[i], position[j]), radius,
                                    color='g', fill=False, linewidth=2, zorder=3, alpha=0.5))

        ax.set_aspect('equal')
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

    plt.tight_layout()
    plt.show()

In [None]:
def plotting_graphs_no_radius(lattice, position, supercell_positions):
    """Plot the supercell projected onto the xy, xz and yz planes.

    Supercell positions are drawn as dots and the analysed point in red (a star
    in the xy panel, a dot in the other two). ``lattice`` is accepted but not
    used by this function.
    """
    axis_names = 'xyz'
    # (horizontal column, vertical column, marker for the analysed point)
    panels = ((0, 1, '*'), (0, 2, '.'), (1, 2, '.'))

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (h_col, v_col, marker) in zip(axes, panels):
        ax.plot(supercell_positions[:, h_col], supercell_positions[:, v_col], '.', zorder=1)
        ax.plot(position[h_col], position[v_col], marker, color='red', zorder=2)
        ax.set_aspect('equal')
        ax.set_xlabel(axis_names[h_col])
        ax.set_ylabel(axis_names[v_col])

    plt.tight_layout()
    plt.show()

In [None]:
from sympy import symbols, Matrix, linsolve, init_printing
import sympy as sy
# Turn on SymPy pretty-printing using Unicode glyphs (affects how symbolic results are displayed).
init_printing(use_unicode=True)

def tensor_manipulation_first_bond(ops, rev_ops, p, n, atom_number=None, *, labelled_symbols=False):
    """Return the most general exchange tensor J allowed by a bond's symmetry.

    J starts as a 3x3 matrix of nine independent symbols. Each operation that
    leaves the bond invariant imposes a linear constraint:

        J = R J R^T        for direct operations (``ops``)
        J = (R J R^T)^T    for operations that reverse the bond (``rev_ops``)

    All constraints are collected and solved together with ``sympy.linsolve``.
    The solution is substituted for the unknown symbols, so the returned matrix
    is written in terms of the remaining free components only.

    Args:
        ops: Operations that leave the bond invariant. Each is either a 3x3
            rotation matrix or an object with a ``.rotation`` attribute (e.g. ``SymOp``).
        rev_ops: Operations that reverse the bond, in the same formats.
        p: Shell of the bond; only used in symbol names when ``labelled_symbols``.
        n: Index of the bond within its shell; only used when ``labelled_symbols``.
        atom_number: Optional atom index, added to symbol names when ``labelled_symbols``.
        labelled_symbols: If True, name components LaTeX-style, e.g.
            ``J_{xx}^{a_{0}p_{1}A_{2}}``; otherwise use ``J_xx``, ``J_xy``, ...

    Returns:
        sympy.Matrix: the symmetry-constrained 3x3 exchange tensor. With no
        operations at all, the unconstrained symbolic tensor is returned.
    """
    components = ("xx", "xy", "xz", "yx", "yy", "yz", "zx", "zy", "zz")
    if labelled_symbols:
        if atom_number is not None:
            superscript = f"^{{a_{{{atom_number}}}p_{{{p}}}A_{{{n}}}}}"
        else:
            superscript = f"^{{p_{{{p}}}A_{{{n}}}}}"
        names = [f"J_{{{c}}}{superscript}" for c in components]
    else:
        names = [f"J_{c}" for c in components]

    unknowns = [sy.Symbol(name) for name in names]
    J = sy.Matrix(3, 3, unknowns)  # row-major: [[J_xx, J_xy, J_xz], [J_yx, ...], ...]

    def as_rotation(op):
        # Accept bare rotation matrices or SymOp-like objects carrying ``.rotation``.
        # NOTE(review): R must be in the same basis as J (presumably Cartesian);
        # fractional rotations of non-orthogonal lattices would need converting first.
        return sy.Matrix(np.asarray(getattr(op, "rotation", op)))

    equations = []
    for op in ops:
        R = as_rotation(op)
        equations.extend(R * J * R.T - J)  # element-wise: each entry must vanish
    for op in rev_ops:
        R = as_rotation(op)
        equations.extend((R * J * R.T).T - J)
    equations = [eq for eq in (sy.expand(e) for e in equations) if eq != 0]

    if not equations:
        return J

    solutions = list(sy.linsolve(equations, unknowns))
    if not solutions:
        # Defensive: the constraints are homogeneous, so J = 0 always solves them.
        return J
    # Substitute symbol -> solution simultaneously so substitutions cannot chain.
    return J.subs(dict(zip(unknowns, solutions[0])), simultaneous=True)

In [None]:
def tensor_manipulation_corresponding_bond(J, R, reversal = False):
    """Map the exchange tensor ``J`` of one bond onto a symmetry-related bond.

    Uses J' = R J R^T for a direct mapping, or J' = (R J R^T)^T when the operation
    reverses the bond (``reversal`` truthy).

    Args:
        J: 3x3 exchange tensor (numpy array or sympy Matrix).
        R: 3x3 rotation matrix, or an object with a ``.rotation`` attribute (e.g. ``SymOp``).
        reversal: Whether the operation maps the bond in reverse.

    Returns:
        The transformed tensor, of the type produced by ``R @ (J @ R.T)``.
    """
    R = getattr(R, "rotation", R)
    J_transformed = R @ (J @ R.T)
    if reversal:
        J_transformed = J_transformed.T
    return J_transformed

In [None]:

def plot_2_bonds(atom, atom2, positions, numbers, atom1_number, atom2_number, bond2_atom, bond2_atom2):
    """
    Plot two bonds against the surrounding Ru/O structure.

    Shows a row of three 2D projections (xy, xz, yz) followed by a 3D view.
    Atoms with atomic number 44 (ruthenium) are drawn in red and atoms with
    atomic number 8 (oxygen) in blue. The first bond (``atom`` -> ``atom2``) is
    drawn as a blue line in the 2D panels and an orange line in 3D, with
    ``atom`` marked in red. The second bond (``bond2_atom`` -> ``bond2_atom2``)
    is drawn as a translucent black line.

    Parameters:
      atom, atom2             : (x, y, z) coordinates of the first bond's end points.
      positions               : (N, 3) numpy array of atomic positions.
      numbers                 : (N,) numpy array of atomic numbers matching ``positions``.
      atom1_number            : Atomic number assigned to ``atom`` when it is appended
                                to ``positions``/``numbers`` for the scatter plots.
      atom2_number            : Not used; kept so existing calls keep working.
      bond2_atom, bond2_atom2 : (x, y, z) coordinates of the second bond's end points.
    """
    # NOTE(review): a later notebook cell redefines plot_2_bonds with a
    # keyword-based signature; when cells run in order this version is shadowed.

    # Add the first bond atom to the scatter data so it is drawn with the structure.
    positions = np.vstack((positions, np.atleast_2d(atom)))
    numbers = np.append(numbers, atom1_number)

    ruthenium_number = 44
    oxygen_number = 8
    elem_color = {ruthenium_number: 'red', oxygen_number: 'blue'}
    elem_label = {ruthenium_number: 'Ru2+', oxygen_number: 'O-'}
    mask_atom1 = (numbers == ruthenium_number)
    mask_atom2 = (numbers == oxygen_number)
    axis_names = ("x", "y", "z")

    def draw_projection(ax, i, j):
        """Draw the structure and both bonds projected onto coordinates (i, j)."""
        for Z, mask in ((ruthenium_number, mask_atom1), (oxygen_number, mask_atom2)):
            ax.scatter(positions[mask, i], positions[mask, j],
                       color=elem_color[Z], label=elem_label[Z], zorder=1, alpha=0.5)
        ax.plot(atom[i], atom[j], 'o', color=elem_color[ruthenium_number], zorder=2)
        ax.plot([atom[i], atom2[i]], [atom[j], atom2[j]], 'b-', zorder=3, alpha=0.5)
        ax.plot([bond2_atom[i], bond2_atom2[i]], [bond2_atom[j], bond2_atom2[j]],
                'k-', zorder=3, alpha=0.5)
        ax.set_title(f"Bonding in {axis_names[i]}{axis_names[j]} plane")
        ax.set_xlabel(axis_names[i])
        ax.set_ylabel(axis_names[j])
        ax.legend()

    # --- 2D projections: xy, xz, yz ---
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))
    for ax, (i, j) in zip(axes, ((0, 1), (0, 2), (1, 2))):
        draw_projection(ax, i, j)
    plt.tight_layout()
    plt.show()

    # --- 3D view of the same structure and bonds ---
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(positions[mask_atom1, 0], positions[mask_atom1, 1], positions[mask_atom1, 2],
               color=elem_color[ruthenium_number], label=elem_label[ruthenium_number])
    ax.scatter(positions[mask_atom2, 0], positions[mask_atom2, 1], positions[mask_atom2, 2],
               color=elem_color[oxygen_number], label=elem_label[oxygen_number])
    ax.scatter(atom[0], atom[1], atom[2], color='red')
    ax.plot([atom[0], atom2[0]], [atom[1], atom2[1]], [atom[2], atom2[2]],
            color='orange', linestyle='-', lw=1)
    ax.plot([bond2_atom[0], bond2_atom2[0]], [bond2_atom[1], bond2_atom2[1]],
            [bond2_atom[2], bond2_atom2[2]], 'k-', zorder=3, alpha=0.5)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")
    ax.legend()
    plt.show()

In [None]:

def plot_2_bonds(
    atom_a1, atom_a2,            # bond 1 endpoints (x,y,z)
    atom_b1, atom_b2,            # bond 2 endpoints (x,y,z)
    positions,                   # (N,3) array of all atomic positions
    numbers=None,                # (N,) array of atomic numbers (optional)
    species_map=None,            # {Z: {"color":..., "label":...}} (optional)
    planes=("xy", "xz", "yz"),   # which 2D planes to show
    zslice_at=None,              # optional XY slice at given z
    zslice_tol=1e-5,
    show3d=True,
    line1_kwargs=None,           # style for bond 1 line
    line2_kwargs=None,           # style for bond 2 line
    scatter_kwargs=None,
    equal_aspect=True,
    angstrom=True,
):
    """
    Plot two bonds across 2D projections (xy/xz/yz), an optional XY z-slice and an optional 3D view.

    Parameters
    ----------
    atom_a1, atom_a2 : array-like, shape (3,)
        Cartesian end points of bond 1.
    atom_b1, atom_b2 : array-like, shape (3,)
        Cartesian end points of bond 2.
    positions : array-like, shape (N, 3)
        All atomic positions to scatter. An empty sequence is allowed.
    numbers : array-like, shape (N,), optional
        Atomic numbers for ``positions``. Lists are accepted as well as arrays.
    species_map : dict, optional
        ``{Z: {"color": str, "label": str}}`` used to colour and label each species.
        Points use a single style when ``numbers`` or ``species_map`` is None.
    planes : tuple of {"xy", "xz", "yz"}
        2D projections to draw, one subplot each.
    zslice_at : float, optional
        If given, also draw an XY plot of only the atoms with |z - zslice_at| < zslice_tol.
    zslice_tol : float
        Tolerance for ``zslice_at``.
    show3d : bool
        If True, also draw a 3D view.
    line1_kwargs, line2_kwargs : dict, optional
        Matplotlib line kwargs for bond 1 / bond 2 (defaults: solid line, alpha 0.85).
    scatter_kwargs : dict, optional
        Matplotlib scatter kwargs (defaults: alpha 0.5, zorder 1, s 12).
    equal_aspect : bool
        If True, use an equal aspect ratio on the 2D axes.
    angstrom : bool
        If True, append " (Å)" to the axis labels.

    Returns
    -------
    figs : dict
        Keys: "2d" (Figure), optionally "zslice" and "3d".

    Raises
    ------
    ValueError
        If ``numbers`` is given and its length differs from the number of positions.
    """
    # --- sanitize inputs ---
    P = np.asarray(positions, dtype=float)
    if P.size == 0:
        # Give an empty input a (0, 3) shape so the column indexing below still works.
        P = np.empty((0, 3))
    P = np.atleast_2d(P)
    a1 = np.asarray(atom_a1, dtype=float).ravel()
    a2 = np.asarray(atom_a2, dtype=float).ravel()
    b1 = np.asarray(atom_b1, dtype=float).ravel()
    b2 = np.asarray(atom_b2, dtype=float).ravel()

    if numbers is not None:
        # An ndarray is required so boolean masks (e.g. the z-slice) can index it.
        numbers = np.asarray(numbers).ravel()
        if len(numbers) != len(P):
            raise ValueError(
                f"numbers has {len(numbers)} entries but positions has {len(P)} rows"
            )

    figs = {}
    unit = " (Å)" if angstrom else ""
    axis_names = ("x", "y", "z")
    show_legend = numbers is not None and bool(species_map)

    # --- defaults (caller-supplied values override) ---
    line1_kwargs = {"linestyle": "-", "alpha": 0.85, **(line1_kwargs or {})}
    line2_kwargs = {"linestyle": "-", "alpha": 0.85, **(line2_kwargs or {})}
    scatter_kwargs = {"alpha": 0.5, "zorder": 1, "s": 12, **(scatter_kwargs or {})}

    # --- helpers ---
    def scatter_by_species(ax, coords, nums):
        """Scatter the columns of ``coords`` (2 for 2D axes, 3 for 3D axes), one call per species when styled."""
        if nums is None or species_map is None:
            ax.scatter(*coords.T, **scatter_kwargs)
            return
        for Z in np.unique(nums):
            m = (nums == Z)
            style = species_map.get(int(Z), {})
            ax.scatter(*coords[m].T,
                       label=style.get("label", f"Z={int(Z)}"),
                       color=style.get("color", None),
                       **scatter_kwargs)

    def draw_bonds(ax, idx):
        """Draw both bonds using the coordinate components listed in ``idx``."""
        ax.plot(*[[a1[k], a2[k]] for k in idx], **line1_kwargs)
        ax.plot(*[[b1[k], b2[k]] for k in idx], **line2_kwargs)

    proj_indices = {"xy": (0, 1), "xz": (0, 2), "yz": (1, 2)}

    # --- 2D projections ---
    ncols = len(planes)
    fig2d, axes = plt.subplots(nrows=1, ncols=ncols, figsize=(5*ncols, 5))
    if ncols == 1:
        axes = [axes]

    for ax, plane in zip(axes, planes):
        i, j = proj_indices[plane]
        scatter_by_species(ax, P[:, [i, j]], numbers)
        draw_bonds(ax, (i, j))
        ax.set_xlabel(axis_names[i] + unit)
        ax.set_ylabel(axis_names[j] + unit)
        ax.set_title(f"Bonding in {plane} plane")
        if equal_aspect:
            ax.set_aspect("equal", adjustable="box")
        if show_legend:
            ax.legend(loc="best")

    fig2d.tight_layout()
    figs["2d"] = fig2d

    # --- optional XY z-slice near zslice_at ---
    if zslice_at is not None:
        mask = np.abs(P[:, 2] - float(zslice_at)) < float(zslice_tol)
        nums = numbers[mask] if numbers is not None else None

        figz, axz = plt.subplots(figsize=(5, 5))
        scatter_by_species(axz, P[mask][:, [0, 1]], nums)
        draw_bonds(axz, (0, 1))
        axz.set_xlabel("x" + unit)
        axz.set_ylabel("y" + unit)
        axz.set_title(f"XY slice at z ≈ {zslice_at}")
        if equal_aspect:
            axz.set_aspect("equal", adjustable="box")
        if show_legend:
            axz.legend(loc="best")
        figz.tight_layout()
        figs["zslice"] = figz

    # --- optional 3D ---
    if show3d:
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (needed for projection="3d" on older Matplotlib)
        fig3d = plt.figure(figsize=(6, 6))
        ax3d = fig3d.add_subplot(111, projection="3d")
        scatter_by_species(ax3d, P, numbers)
        draw_bonds(ax3d, (0, 1, 2))
        ax3d.set_xlabel("x" + unit)
        ax3d.set_ylabel("y" + unit)
        ax3d.set_zlabel("z" + unit)
        if show_legend:
            ax3d.legend(loc="best")
        figs["3d"] = fig3d

    return figs


In [None]:
# Colour and legend label for each atomic number in the Ru-O structure
# (44 = Ru, 8 = O), in the {Z: {"color": ..., "label": ...}} form accepted by
# the keyword-style plot_2_bonds helper.
species_map = dict([
    (44, dict(color="red", label="Ru")),
    (8, dict(color="blue", label="O")),
])

# usage:
# figs = plot_2_bonds(
#     atom_a1=atom, atom_a2=atom2,
#     atom_b1=bond2_atom, atom_b2=bond2_atom2,
#     positions=positions, numbers=numbers,
#     species_map=species_map,
#     line1_kwargs={"color":"orange"},
#     line2_kwargs={"color":"black", "alpha":0.6},
#     zslice_at=atom[2], show3d=True
# )


In [None]:
def plot_bond(atom, atom2, positions, numbers, atom1_number, atom2_number):
    """
    Plot one bond against the surrounding Mn/Ge/Ni structure.

    Shows four figures in turn:
      1. An xy projection with Mn (Z=25, blue), Ge (Z=32, red) and Ni (Z=28, green)
         atoms, with the bond as an orange line and the axes labelled in angstrom.
      2. xy / xz / yz projections of the Mn atoms only, with the bond in orange.
      3. xy / xz / yz projections of Mn (red) and Ge (blue) atoms, with ``atom``
         marked and the bond as a blue line.
      4. A 3D view of Mn (red) and Ge (blue) atoms with the bond in orange.

    Parameters:
      atom, atom2  : (x, y, z) coordinates of the two bond end points.
      positions    : (N, 3) numpy array of atomic positions.
      numbers      : (N,) numpy array of atomic numbers matching ``positions``.
      atom1_number : Atomic number assigned to ``atom`` when it is appended to
                     ``positions``/``numbers`` for the scatter plots.
      atom2_number : Not used; kept so existing calls keep working.
    """
    # NOTE(review): a later notebook cell redefines plot_bond with a
    # keyword-based signature; when cells run in order this version is shadowed.

    # Add the first bond atom to the scatter data so it is drawn with the structure.
    positions = np.vstack((positions, np.atleast_2d(atom)))
    numbers = np.append(numbers, atom1_number)

    manganese_number = 25
    ge_number = 32
    ni_number = 28
    mask_mn = (numbers == manganese_number)
    mask_ge = (numbers == ge_number)
    mask_ni = (numbers == ni_number)

    axis_names = ("x", "y", "z")
    projections = ((0, 1), (0, 2), (1, 2))
    orange_bond = dict(color='orange', linestyle='-', zorder=2, alpha=0.5)

    # --- Figure 1: xy overview of every species ---
    overview = (
        (mask_mn, 'blue', r'Mn$^{2+}$'),
        (mask_ge, 'red', 'Ge'),
        (mask_ni, 'green', 'Ni'),
    )
    plt.figure(figsize=(5, 5))
    for mask, color, label in overview:
        plt.scatter(positions[mask, 0], positions[mask, 1],
                    color=color, label=label, zorder=1, alpha=0.5)
    plt.plot([atom[0], atom2[0]], [atom[1], atom2[1]], **orange_bond)
    plt.xlabel("x (\u212B)")
    plt.ylabel("y (\u212B)")
    plt.legend()
    plt.show()

    # --- Figure 2: Mn sublattice only, three projections ---
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))
    for ax, (i, j) in zip(axes, projections):
        ax.scatter(positions[mask_mn, i], positions[mask_mn, j],
                   color='blue', label=r'Mn$^{2+}$', zorder=1, alpha=0.5)
        ax.plot([atom[i], atom2[i]], [atom[j], atom2[j]], **orange_bond)
        ax.set_xlabel(axis_names[i])
        ax.set_ylabel(axis_names[j])
        ax.legend()
    plt.tight_layout()
    plt.show()

    # --- Figure 3: Mn and Ge, three projections ---
    # Legend labels match the atomic numbers (25 = Mn, 32 = Ge).
    mn_ge = ((mask_mn, 'red', 'Mn2+'), (mask_ge, 'blue', 'Ge'))
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))
    for ax, (i, j) in zip(axes, projections):
        for mask, color, label in mn_ge:
            ax.scatter(positions[mask, i], positions[mask, j],
                       color=color, label=label, zorder=1, alpha=0.5)
        ax.plot(atom[i], atom[j], 'o', color='red', zorder=2)
        ax.plot([atom[i], atom2[i]], [atom[j], atom2[j]], 'b-', zorder=3, alpha=0.5)
        ax.set_title(f"Bonding in {axis_names[i]}{axis_names[j]} plane")
        ax.set_xlabel(axis_names[i])
        ax.set_ylabel(axis_names[j])
        ax.legend()
    plt.tight_layout()
    plt.show()

    # --- Figure 4: 3D view with the same Mn/Ge styling as figure 3 ---
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    for mask, color, label in mn_ge:
        ax.scatter(positions[mask, 0], positions[mask, 1], positions[mask, 2],
                   color=color, label=label)
    ax.scatter(atom[0], atom[1], atom[2], color='red')
    ax.plot([atom[0], atom2[0]], [atom[1], atom2[1]], [atom[2], atom2[2]],
            color='orange', linestyle='-', lw=1)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")
    ax.legend()
    plt.show()

In [None]:
def plot_bond(
    atom_a, atom_b, positions, numbers=None,
    species_map=None,
    planes=("xy", "xz", "yz"),
    zslice_at=None, zslice_tol=1e-5,
    show3d=False,
    line_kwargs=None,
    scatter_kwargs=None,
    equal_aspect=True,
    angstrom=True,
):
    """
    Plot the bond between two atoms across selected planes and (optionally) in 3D,
    with optional colouring by atomic number.

    Parameters
    ----------
    atom_a, atom_b : array-like shape (3,)
        Cartesian coordinates (x,y,z) of the two bond atoms.
    positions : array-like shape (N,3)
        All atomic positions. An empty sequence is allowed.
    numbers : array-like shape (N,) or None
        Atomic numbers for `positions` (lists are accepted). If None, all points use a single style.
    species_map : dict or None
        Mapping {Z: {"color": str, "label": str}} for colouring/legend by species.
        Example for Mn5Ge3: {25: {"color":"red","label":"Mn"}, 32: {"color":"blue","label":"Ge"}}
    planes : tuple of {"xy","xz","yz"}
        Which 2D projections to draw.
    zslice_at : float or None
        If set, also draw a 2D scatter of points with |z - zslice_at| < zslice_tol (XY view).
    zslice_tol : float
        Tolerance used with `zslice_at`.
    show3d : bool
        If True, also render a 3D plot.
    line_kwargs : dict or None
        Matplotlib kwargs for the bond line (defaults: {"linestyle": "-", "alpha": 0.8}).
    scatter_kwargs : dict or None
        Matplotlib kwargs for scatter points (default: {"alpha":0.5, "zorder":1, "s":12}).
    equal_aspect : bool
        If True, set equal aspect on 2D axes.
    angstrom : bool
        If True, annotate axes with Å units.

    Returns
    -------
    figs : dict
        Keys "2d" → matplotlib Figure with projections, optionally "zslice" and/or "3d".

    Raises
    ------
    ValueError
        If `numbers` is given and its length differs from the number of positions.
    """
    atom_a = np.asarray(atom_a, dtype=float).ravel()
    atom_b = np.asarray(atom_b, dtype=float).ravel()

    P = np.asarray(positions, dtype=float)
    if P.size == 0:
        # Give an empty input a (0, 3) shape so the column indexing below still works.
        P = np.empty((0, 3))
    P = np.atleast_2d(P)

    if numbers is not None:
        # An ndarray is required so boolean masks (e.g. the z-slice) can index it.
        numbers = np.asarray(numbers).ravel()
        if len(numbers) != len(P):
            raise ValueError(
                f"numbers has {len(numbers)} entries but positions has {len(P)} rows"
            )

    figs = {}
    unit = " (Å)" if angstrom else ""
    axis_names = ("x", "y", "z")
    show_legend = numbers is not None and bool(species_map)

    # Styling defaults (caller-supplied values override)
    line_kwargs = {"linestyle": "-", "alpha": 0.8, **(line_kwargs or {})}
    scatter_kwargs = {"alpha": 0.5, "zorder": 1, "s": 12, **(scatter_kwargs or {})}

    def scatter_by_species(ax, coords, nums):
        """Scatter the columns of ``coords`` (2 for 2D axes, 3 for 3D axes), one call per species when styled."""
        if nums is None or species_map is None:
            ax.scatter(*coords.T, **scatter_kwargs)
            return
        for Z in np.unique(nums):
            mask = (nums == Z)
            style = species_map.get(int(Z), {})
            ax.scatter(*coords[mask].T,
                       label=style.get("label", f"Z={int(Z)}"),
                       color=style.get("color", None),
                       **scatter_kwargs)

    def draw_bond(ax, idx):
        """Draw the bond using the coordinate components listed in ``idx``."""
        ax.plot(*[[atom_a[k], atom_b[k]] for k in idx], **line_kwargs)

    # Build 2D projections
    proj_indices = {
        "xy": (0, 1),
        "xz": (0, 2),
        "yz": (1, 2),
    }
    ncols = len(planes)
    fig2d, axes = plt.subplots(nrows=1, ncols=ncols, figsize=(5*ncols, 5))
    if ncols == 1:
        axes = [axes]

    for ax, plane in zip(axes, planes):
        i, j = proj_indices[plane]
        scatter_by_species(ax, P[:, [i, j]], numbers)
        draw_bond(ax, (i, j))
        ax.set_xlabel(axis_names[i] + unit)
        ax.set_ylabel(axis_names[j] + unit)
        ax.set_title(f"Bonding in {plane} plane")
        if equal_aspect:
            ax.set_aspect("equal", adjustable="box")
        if show_legend:
            ax.legend(loc="best")

    fig2d.tight_layout()
    figs["2d"] = fig2d

    # Optional XY plot restricted to atoms near z = zslice_at
    if zslice_at is not None:
        mask = np.abs(P[:, 2] - float(zslice_at)) < float(zslice_tol)
        figz, axz = plt.subplots(figsize=(5, 5))
        scatter_by_species(axz, P[mask][:, [0, 1]], numbers[mask] if numbers is not None else None)
        draw_bond(axz, (0, 1))
        axz.set_xlabel("x" + unit)
        axz.set_ylabel("y" + unit)
        axz.set_title(f"XY slice at z ≈ {zslice_at}")
        if equal_aspect:
            axz.set_aspect("equal", adjustable="box")
        if show_legend:
            axz.legend(loc="best")
        figz.tight_layout()
        figs["zslice"] = figz

    # Optional 3D
    if show3d:
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (needed for projection="3d" on older Matplotlib)
        fig3d = plt.figure(figsize=(6, 6))
        ax3d = fig3d.add_subplot(111, projection="3d")
        scatter_by_species(ax3d, P, numbers)
        draw_bond(ax3d, (0, 1, 2))
        ax3d.set_xlabel("x" + unit)
        ax3d.set_ylabel("y" + unit)
        ax3d.set_zlabel("z" + unit)
        if show_legend:
            ax3d.legend(loc="best")
        figs["3d"] = fig3d

    return figs


In [None]:
# Colour and legend label for each atomic number in the Mn-Ge structure
# (25 = Mn, 32 = Ge), in the {Z: {"color": ..., "label": ...}} form accepted
# by the keyword-style plot_bond helper.
species_map = {
    Z: {"color": colour, "label": label}
    for Z, colour, label in ((25, "red", "Mn"), (32, "blue", "Ge"))
}
# usage:
# plot_bond(atom_a, atom_b, positions, numbers, species_map, zslice_at=atom_a[2], show3d=True)

In [None]:
def seperate_tensor_for_each_bond(
    atom_data, space_group, i,
    numbers_to_skip=None,
    min_distance=None, max_distance=None,        # optional distance filters
    plot_cart=True,                              # choose cart vs frac plotting
    plot_frac=False
):
    """Find the symmetry-constrained exchange tensor of every bond of one atom.

    Each neighbour bond is treated on its own. The symmetry operations that
    leave that bond invariant are found and used to constrain its tensor. No
    bond is related to any other bond.

    Parameters
    ----------
    atom_data : dict
        Record for the central atom. It must contain:

        - ``'frac_positions_sorted'``, ``'cart_positions_sorted'``,
          ``'numbers_sorted'``, ``'shells_sorted'``, ``'distances_sorted'``:
          neighbour arrays, all in the same (shell-sorted) order.
        - ``'fractional position'``, ``'cartesian position'``, ``'number'``:
          the central atom itself.
    space_group :
        Passed unchanged to ``symmetry_operations_for_bond_2_points``.
    i : int
        0-based index of the central atom; ``i + 1`` is used as its label.
    numbers_to_skip : container of int, optional
        Atomic numbers of neighbours to ignore.
    min_distance, max_distance : float, optional
        Bonds shorter than ``min_distance`` or longer than ``max_distance``
        are ignored.
    plot_cart, plot_frac : bool
        Plot each bond in cartesian and/or fractional coordinates.

    Returns
    -------
    dict
        Keyed by the neighbour's index ``n`` in the sorted arrays. Each value
        describes one bond and its constrained tensor.
    """
    frac_positions = atom_data['frac_positions_sorted']
    cart_positions = atom_data['cart_positions_sorted']
    numbers = atom_data['numbers_sorted']
    shells = atom_data['shells_sorted']
    distances = atom_data['distances_sorted']
    frac_pos = atom_data['fractional position']
    cart_pos = atom_data['cartesian position']
    unit_atom_number = atom_data['number']

    exchange_data = {}
    for n, shell in enumerate(shells):
        print("\nShell", shell, ":")
        current_atom = numbers[n]  # element of the neighbour in this bond

        if numbers_to_skip is not None and current_atom in numbers_to_skip:
            print("Skipping atom:", current_atom)
            continue

        current_distance = distances[n]
        if (min_distance is not None and current_distance < min_distance) or \
           (max_distance is not None and current_distance > max_distance):
            continue

        print("  Distance:", current_distance, " Element number:", current_atom)

        # The arrays are shell-sorted. The 1-based position inside the shell
        # is therefore the overall position minus the bonds in earlier shells.
        bonds_before = len(np.where(shells < shell)[0])
        bond_number = n + 1 - bonds_before
        # 1-based count of this element inside the shell: 1 for the first
        # neighbour of this element in the shell, 2 for the second, and so on.
        bond_number_shell_same_atom = 1 + sum(
            1 for j in range(n) if shells[j] == shell and numbers[j] == current_atom
        )

        neighbour_frac = frac_positions[n]
        neighbour_cart = cart_positions[n]

        # Constrain the tensor with the operations that leave this bond invariant.
        direct_ops, reversal_ops = symmetry_operations_for_bond_2_points(space_group, frac_pos, neighbour_frac)
        direct_rotations = np.array([op.rotation for op in direct_ops])
        reversal_rotations = np.array([op.rotation for op in reversal_ops])
        simplified_tensor = tensor_manipulation_first_bond(
            direct_rotations, reversal_rotations, shell, current_atom, atom_number=i+1)

        print("[Bond", bond_number, "in shell", shell, "]")
        print("bond number for this atom in the shell:", bond_number_shell_same_atom)
        print("Direct ops:", direct_rotations, "reversal ops:", reversal_rotations)
        print("Bond between:")
        print("Central atom:", frac_pos, "element:", unit_atom_number)
        print("neighbour atom:", neighbour_frac, "(element:", current_atom, ")")
        print("central atom cartesian", cart_pos)
        print("neighbour atom cartesian", neighbour_cart)
        display(simplified_tensor)

        if plot_frac:
            plot_bond(frac_pos, neighbour_frac, frac_positions)
        if plot_cart:
            # NOTE(review): the usage example for the newer plot_bond passes
            # species_map as the 5th positional argument. Passing
            # unit_atom_number / current_atom here may not match that
            # signature; confirm.
            plot_bond(cart_pos, neighbour_cart, cart_positions, numbers, unit_atom_number, current_atom)

        bond_vector_cart = np.array(difference(cart_pos, neighbour_cart))
        bond_vector_frac = np.array(difference(frac_pos, neighbour_frac))

        exchange_data[n] = {
            "fractional position": frac_pos,
            "cartesian position": cart_pos,
            "element 1": unit_atom_number,
            "neighbour_frac": neighbour_frac,
            "neighbour_cart": neighbour_cart,
            "element 2": current_atom,
            "distance": current_distance,
            "bond_vector_frac": bond_vector_frac,
            "bond_vector_cart": bond_vector_cart,
            "shell": shell,
            "bond_number_shell": bond_number,
            "bond_number_shell_same_atom": bond_number_shell_same_atom,
            "is first tensor": True,
            "first bond tensor": simplified_tensor,
            "corresponding tensor for this bond": simplified_tensor,
            "direct_ops": direct_rotations,
            "reversal_ops": reversal_rotations,
        }

    return exchange_data

In [None]:
def first_tensor_for_bond_then_others_from_same_atom(
    atom_data, space_group, i, numbers_to_skip=None,
    min_distance=None, max_distance=None,
    plot_frac=True, plot_cart=False
):
    """Find one atom's bond tensors, deriving later bonds from the first bond of each kind.

    The first neighbour of each element in a shell has its tensor constrained
    directly, using the symmetry operations that leave that bond invariant.

    Every later neighbour of the same element in the same shell is related to
    that first bond. The symmetry operations mapping the first bond onto the
    later bond are found, and one of them is applied to the first bond's
    tensor. Bonds that no such operation reaches are skipped.

    Parameters
    ----------
    atom_data : dict
        Record for the central atom. It must contain:

        - ``'frac_positions_sorted'``, ``'cart_positions_sorted'``,
          ``'numbers_sorted'``, ``'shells_sorted'``, ``'distances_sorted'``:
          neighbour arrays, all in the same (shell-sorted) order.
        - ``'fractional position'``, ``'cartesian position'``, ``'number'``:
          the central atom itself.
    space_group :
        Passed unchanged to the symmetry-operation helpers.
    i : int
        0-based index of the central atom; ``i + 1`` is used as its label.
    numbers_to_skip : container of int, optional
        Atomic numbers of neighbours to ignore.
    min_distance, max_distance : float, optional
        Bonds outside ``[min_distance, max_distance]`` are ignored.
    plot_frac, plot_cart : bool
        Plot each bond in fractional and/or cartesian coordinates.

    Returns
    -------
    dict
        Keyed by neighbour index ``n``. For a derived bond, ``"is first tensor"``
        is False and ``"first bond tensor"`` is the tensor of the first bond of
        the same element in the same shell.
    """
    frac_positions = atom_data['frac_positions_sorted']
    cart_positions = atom_data['cart_positions_sorted']
    numbers = atom_data['numbers_sorted']
    shells = atom_data['shells_sorted']
    distances = atom_data['distances_sorted']
    frac_pos = atom_data['fractional position']
    cart_pos = atom_data['cartesian position']
    unit_atom_number = atom_data['number']

    exchange_data = {}
    # shell -> {atomic number -> (first neighbour's frac position, its tensor)}
    shell_data = {}

    for n, shell in enumerate(shells):
        print("\nShell", shell, ":")
        current_atom = numbers[n]

        if numbers_to_skip is not None and current_atom in numbers_to_skip:
            print("Skipping atom:", current_atom)
            continue
        if (min_distance is not None and distances[n] < min_distance) or \
           (max_distance is not None and distances[n] > max_distance):
            continue

        current_distance = distances[n]
        print("  Distance:", current_distance, " Element number:", current_atom)

        # Shell-sorted arrays: position inside the shell = overall position
        # minus the number of bonds in earlier shells.
        bonds_before = len(np.where(shells < shell)[0])
        bond_number = n + 1 - bonds_before
        bond_number_shell_same_atom = 1 + sum(
            1 for j in range(n) if shells[j] == shell and numbers[j] == current_atom
        )

        neighbour_frac = frac_positions[n]
        neighbour_cart = cart_positions[n]
        shell_data.setdefault(shell, {})

        if bond_number == 1 or current_atom not in shell_data[shell]:
            # First bond to this element in this shell: constrain its tensor
            # with the operations that leave the bond invariant.
            direct_ops, reversal_ops = symmetry_operations_for_bond_2_points(space_group, frac_pos, neighbour_frac)
            direct_rotations = np.array([op.rotation for op in direct_ops])
            reversal_rotations = np.array([op.rotation for op in reversal_ops])
            print("Direct ops:", direct_rotations, "reversal ops:", reversal_rotations)
            print("  [Bond", bond_number, "in shell", shell, "] First instance for atomic number", current_atom)
            print("bond number for this atom in the shell:", bond_number_shell_same_atom)
            first_tensor = tensor_manipulation_first_bond(
                direct_rotations, reversal_rotations, shell, current_atom, atom_number=i+1)
            bond_tensor = first_tensor
            is_first = True
            print("Bond between:")
            print("  Central atom:", frac_pos, "element:", unit_atom_number)
            print("  Neighbour atom:", neighbour_frac, "(element:", current_atom, ")")
            display(first_tensor)
            shell_data[shell][current_atom] = (neighbour_frac, first_tensor)
        else:
            # Later bond: map the stored first bond onto this one and
            # transform the first bond's tensor with one mapping operation.
            first_bond_frac, first_tensor = shell_data[shell][current_atom]
            direct_ops, reversal_ops = symmetry_operations_for_bond_mapping_3_points(
                space_group, frac_pos, first_bond_frac, neighbour_frac)
            direct_rotations = np.array(direct_ops)
            reversal_rotations = np.array(reversal_ops)
            if direct_rotations.size == 0 and reversal_rotations.size == 0:
                print("Bond is not mapped at all (skip)")
                continue
            print("Direct ops:", direct_rotations, "reversal ops:", reversal_rotations)
            # Prefer a direct operation. Otherwise use a reversal one, and tell
            # the tensor transform so (the sibling
            # seperate_tensor_for_bonds_checks_other_atoms passes reversal=rev
            # in the same way).
            reversal = direct_rotations.size == 0
            chosen_rotation = reversal_rotations[0] if reversal else direct_rotations[0]
            print("chosen rotation:", chosen_rotation)
            print("[Bond", bond_number, "in shell", shell, "] Subsequent instance for atomic number", current_atom)
            print("bond number for this atom in the shell:", bond_number_shell_same_atom)
            bond_tensor = tensor_manipulation_corresponding_bond(first_tensor, chosen_rotation, reversal=reversal)
            is_first = False
            print("Bond between:")
            print("  Central atom:", frac_pos, "element:", unit_atom_number)
            print("  Neighbour atom:", neighbour_frac, "(element:", current_atom, ")")
            display(bond_tensor)

        if plot_frac:
            plot_bond(frac_pos, neighbour_frac, frac_positions)
        if plot_cart:
            plot_bond(cart_pos, neighbour_cart, cart_positions)

        bond_vector_cart = np.array(difference(cart_pos, neighbour_cart))
        bond_vector_frac = np.array(difference(frac_pos, neighbour_frac))

        exchange_data[n] = {
            "fractional position": frac_pos,
            "cartesian position": cart_pos,
            "element 1": unit_atom_number,
            "neighbour_frac": neighbour_frac,
            "neighbour_cart": neighbour_cart,
            "element 2": current_atom,
            "distance": current_distance,
            "bond_vector_frac": bond_vector_frac,
            "bond_vector_cart": bond_vector_cart,
            "shell": shell,
            "bond_number_shell": bond_number,
            "bond_number_shell_same_atom": bond_number_shell_same_atom,
            "is first tensor": is_first,
            # Always the tensor of the first bond of this element in this
            # shell. Previously this could be a stale tensor from another
            # shell or element.
            "first bond tensor": first_tensor,
            "corresponding tensor for this bond": bond_tensor,
            "direct_ops": direct_rotations,
            "reversal_ops": reversal_rotations,
        }

    return exchange_data

In [None]:
def _stack_rotations(*rotation_sets):
    """Concatenate the non-empty rotation arrays along axis 0.

    An empty set built with ``np.array([])`` has shape ``(0,)``. That cannot be
    concatenated with a stack of matrices (presumably ``(k, 3, 3)``), so empty
    sets are dropped first. Returns an empty ``(0, 3, 3)`` array if every set
    is empty.
    """
    non_empty = [np.asarray(r) for r in rotation_sets if np.asarray(r).size > 0]
    if not non_empty:
        return np.empty((0, 3, 3))
    return np.concatenate(non_empty, axis=0)


def _sorted_neighbours_in_window(centre_cart, cart_positions, frac_positions, numbers,
                                 low, high, radius=None):
    """Return the neighbours of ``centre_cart``, distance-sorted and cut to a fractional window.

    Neighbours are found with ``get_point_distances`` and grouped into shells
    with ``grouping``. They are then sorted by distance rounded to 4 decimals.
    Finally, only neighbours whose fractional coordinates all lie in
    ``[low, high]`` are kept.

    Returns
    -------
    tuple
        ``(frac_positions, cart_positions, distances, shells, numbers)`` for the
        kept neighbours, all in the same sorted order.
    """
    distances, idx_within = get_point_distances(centre_cart, cart_positions, radius=radius)
    shells = grouping(distances)
    rounded = np.round(distances, decimals=4)
    order = np.argsort(rounded)

    frac_sorted = frac_positions[idx_within][order]
    cart_sorted = cart_positions[idx_within][order]
    numbers_sorted = numbers[idx_within][order]
    distances_sorted = rounded[order]
    shells_sorted = shells[order]

    inside = np.all((frac_sorted >= low) & (frac_sorted <= high), axis=1)
    return (frac_sorted[inside], cart_sorted[inside], distances_sorted[inside],
            shells_sorted[inside], numbers_sorted[inside])


def seperate_tensor_for_bonds_checks_other_atoms(
    atom_data, space_group, i,
    radius=None, numbers_to_skip=None,
    min_distance=4.1, max_distance=4.4,
    cut_window=(-0.5, 1.2),
    mapping_atol=1e-3,
    plot_cart=True, plot_frac=False
):
    """Find each bond's exchange tensor and map it onto equivalent bonds of other atoms.

    For every neighbour bond of the central atom that survives the element and
    distance filters, the tensor is constrained by the symmetry operations
    that leave the bond invariant.

    Every other atom with the central atom's atomic number is then examined.
    Each of its bonds to the same element at the same distance (within
    ``mapping_atol``) is tested for space-group operations that map the
    reference bond onto it. If any exist, the reference tensor is transformed
    with them.

    Parameters
    ----------
    atom_data : dict
        Record for the central atom. It must contain:

        - ``'frac_positions_sorted'``, ``'cart_positions_sorted'``,
          ``'numbers_sorted'``, ``'shells_sorted'``, ``'distances_sorted'``:
          neighbour arrays, all in the same (shell-sorted) order.
        - ``'fractional position'``, ``'cartesian position'``, ``'number'``:
          the central atom itself.
    space_group :
        Passed unchanged to the symmetry-operation helpers.
    i : int
        0-based index of the central atom. The central atom is re-inserted at
        this index in the candidate list, and ``i + 1`` is used as its label.
    radius : float, optional
        Forwarded to ``get_point_distances`` when building the other atoms'
        neighbour lists. ``None`` (the default) applies no cutoff.
    numbers_to_skip : container of int, optional
        Atomic numbers of the central atom's neighbours to ignore.
    min_distance, max_distance : float or None
        Only reference bonds with length in ``[min_distance, max_distance]``
        are processed.
    cut_window : (float, float)
        ``(low, high)``. Neighbours of other atoms are kept only when all of
        their fractional coordinates lie in this range.
    mapping_atol : float
        Absolute tolerance for matching a candidate bond's length to the
        reference bond's length.
    plot_cart, plot_frac : bool
        Plot bonds in cartesian and/or fractional coordinates.

    Returns
    -------
    dict
        Keyed by reference-bond index ``n``. Each entry has a ``"mappings"``
        dict keyed by the index ``j`` of the other atom.
    """
    frac_positions = atom_data['frac_positions_sorted']
    cart_positions = atom_data['cart_positions_sorted']
    numbers = atom_data['numbers_sorted']
    shells = atom_data['shells_sorted']
    distances = atom_data['distances_sorted']
    frac_pos = atom_data['fractional position']
    cart_pos = atom_data['cartesian position']
    unit_atom_number = atom_data['number']

    exchange_data = {}
    low, high = cut_window

    # Neighbour arrays with the central atom re-inserted at index i, built on
    # first use. Index i is skipped when scanning the other atoms.
    frac_positions_with_original = None
    cart_positions_with_original = None
    numbers_with_original = None
    # Atom j's neighbour list does not depend on the reference bond, so it is
    # computed once per j and reused for every reference bond.
    neighbour_cache = {}

    for n, shell in enumerate(shells):
        current_atom = numbers[n]

        if numbers_to_skip is not None and current_atom in numbers_to_skip:
            print("Skipping atom:", current_atom)
            continue

        current_distance = distances[n]
        if (min_distance is not None and current_distance < min_distance) or \
           (max_distance is not None and current_distance > max_distance):
            continue

        print(n)
        print("\nShell", shell, ":")
        print("  Distance:", current_distance, "Element number:", current_atom)

        # Shell-sorted arrays: position inside the shell = overall position
        # minus the number of bonds in earlier shells.
        bonds_before = len(np.where(shells < shell)[0])
        bond_number = n + 1 - bonds_before
        bond_number_shell_same_atom = 1 + sum(
            1 for j in range(n) if shells[j] == shell and numbers[j] == current_atom
        )

        neighbour_frac = frac_positions[n]
        neighbour_cart = cart_positions[n]

        # Reference bond: constrain its tensor with the operations that leave it invariant.
        direct_ops, reversal_ops = symmetry_operations_for_bond_2_points(space_group, frac_pos, neighbour_frac)
        direct_rotations = np.array([op.rotation for op in direct_ops])
        reversal_rotations = np.array([op.rotation for op in reversal_ops])
        print("  [Bond", bond_number, "in shell", shell, "] First instance for current_atom", current_atom)
        print("bond number for this atom in the shell:", bond_number_shell_same_atom)
        simplified_tensor = tensor_manipulation_first_bond(
            direct_rotations, reversal_rotations, shell, current_atom, atom_number=i+1)
        display(simplified_tensor)

        if reversal_rotations.size > 0:
            print("with rev ops")
        all_bond_ops = list(_stack_rotations(direct_rotations, reversal_rotations))
        pointgroup = spglib.get_pointgroup(all_bond_ops)
        print("Point group of original bond:", pointgroup)

        if plot_cart:
            # NOTE(review): the usage example for the newer plot_bond passes
            # species_map as its 5th positional argument; confirm these extra
            # arguments still match.
            plot_bond(cart_pos, neighbour_cart, cart_positions, numbers, unit_atom_number, current_atom)
        if plot_frac:
            plot_bond(frac_pos, neighbour_frac, frac_positions)

        print("Bond between:")
        print("    Central atom:", frac_pos)
        print("    Neighbour atom:", neighbour_frac, "(element:", current_atom, ")")
        print("In shell", shell)
        print("Central atom cartesian", cart_pos)
        print("Neighbour atom cartesian", neighbour_cart)

        bond_vector_cart = np.array(difference(cart_pos, neighbour_cart))
        bond_vector_frac = np.array(difference(frac_pos, neighbour_frac))
        print("bond vector frac", bond_vector_frac)
        print("bond vector cart", bond_vector_cart)

        exchange_data[n] = {
            "fractional position": frac_pos,
            "cartesian position": cart_pos,
            "element 1": unit_atom_number,
            "neighbour_frac": neighbour_frac,
            "neighbour_cart": neighbour_cart,
            "element 2": current_atom,
            "distance": current_distance,
            "bond_vector_frac": bond_vector_frac,
            "bond_vector_cart": bond_vector_cart,
            "shell": shell,
            "bond_number_shell": bond_number,
            "bond_number_shell_same_atom": bond_number_shell_same_atom,
            "is first tensor": True,
            "first bond tensor": simplified_tensor,
            "corresponding tensor for this bond": simplified_tensor,
            "direct_ops": direct_rotations,
            "reversal_ops": reversal_rotations,
            "mappings": {},
        }

        if frac_positions_with_original is None:
            frac_positions_with_original = np.insert(frac_positions, i, frac_pos, axis=0)
            cart_positions_with_original = np.insert(cart_positions, i, cart_pos, axis=0)
            numbers_with_original = np.insert(numbers, i, unit_atom_number, axis=0)

        print(np.shape(frac_positions_with_original))
        print(np.shape(cart_positions_with_original))
        print(np.shape(numbers_with_original))

        # Look for bonds of other atoms of the same element that the
        # reference bond maps onto by symmetry.
        for j, frac_pos2 in enumerate(frac_positions_with_original):
            if j == i:
                print("Skipping the same atom.")
                continue
            if numbers_with_original[j] != unit_atom_number:
                print("Different atomic numbers, skipping bond mapping.")
                continue

            print("\n--> Mapping from different atom:")
            print("Processing atom", j+1, "with fractional position:", frac_pos2, "and atomic number:", numbers_with_original[j],
                  "comparing bonds to central atom bonds", i+1, "with fractional position:", frac_pos, "and atomic number:", unit_atom_number)
            print("Currently checking equivalnce to bond", frac_pos, "-", neighbour_frac)

            cart_pos2 = cart_positions_with_original[j]
            if j not in neighbour_cache:
                neighbour_cache[j] = _sorted_neighbours_in_window(
                    cart_pos2, cart_positions_with_original, frac_positions_with_original,
                    numbers_with_original, low, high, radius=radius)
            (frac_positions_sorted2, cart_positions_sorted2, distances_sorted2,
             shells_sorted2, numbers_sorted2) = neighbour_cache[j]

            print("For atom", j+1, "found neighbour shells (sorted):", shells_sorted2, "atomic numbers:", numbers_sorted2)

            for m, shell2 in enumerate(shells_sorted2):
                current_atom2 = numbers_sorted2[m]
                current_distance2 = distances_sorted2[m]
                if current_atom2 != current_atom:
                    print("Different element numbers, skipping bond mapping.")
                    continue
                if not np.allclose(current_distance2, current_distance, atol=mapping_atol):
                    print("Distances do not match, skipping bond mapping.")
                    continue

                print("\nShell", shell2, "for atom", j+1, ":")
                print("  Distance:", current_distance2, "Element number:", current_atom2)

                bonds_before2 = len(np.where(shells_sorted2 < shell2)[0])
                bond_number2 = m + 1 - bonds_before2
                bond_number_shell_same_atom2 = 1 + sum(
                    1 for k in range(m) if shells_sorted2[k] == shell2 and numbers_sorted2[k] == current_atom2
                )

                neighbour_frac2 = frac_positions_sorted2[m]
                neighbour_cart2 = cart_positions_sorted2[m]

                # Operations mapping the reference bond (frac_pos -> neighbour_frac)
                # onto this bond (frac_pos2 -> neighbour_frac2).
                direct_ops2, reversal_ops2 = symmetry_operations_for_different_bond_mapping_4_points(
                    space_group, frac_pos, neighbour_frac, frac_pos2, neighbour_frac2)
                direct_rotations2 = np.array([op.rotation for op in direct_ops2])
                reversal_rotations2 = np.array([op.rotation for op in reversal_ops2])

                if direct_rotations2.size == 0 and reversal_rotations2.size == 0:
                    print("Bond is not mapped at all (skip)")
                    continue
                print("Direct ops:", direct_rotations2, "reversal ops:", reversal_rotations2)

                # Prefer direct operations; fall back to reversal ones.
                rev = direct_rotations2.size == 0
                candidate_rotations = reversal_rotations2 if rev else direct_rotations2

                print("bond giving the original simplified tensor is")
                print("frac _ pos:", frac_pos)
                print("neighbour:", neighbour_frac)
                print("bond vector cartesian:", bond_vector_cart)
                print("bond vector fractional:", bond_vector_frac)
                display(simplified_tensor)

                # Show the tensor obtained with each of up to three mapping
                # operations (presumably as a consistency check). The last one
                # computed is the one stored below.
                for k, rotation in enumerate(candidate_rotations[:3], start=1):
                    print(f"chosen rotation {k}:", rotation)
                    print("bond number for this atom in the shell:", bond_number_shell_same_atom2)
                    corresponding_tensor2 = tensor_manipulation_corresponding_bond(
                        simplified_tensor, rotation, reversal=rev)
                    print("")
                    print("  Bond between:")
                    print("    Central atom:", frac_pos2)
                    print("    Neighbour atom:", neighbour_frac2, "(element:", current_atom2, ")")
                    print("In shell", shell2)
                    print("corresponding tensor for this bond:")
                    display(corresponding_tensor2)

                # Empty direct sets are dropped before concatenating.
                # np.concatenate cannot join a (0,) array with a stack of
                # matrices, which previously crashed here.
                if reversal_rotations2.size > 0:
                    print("with rev ops")
                all_bond_ops_2 = _stack_rotations(direct_rotations2, reversal_rotations2)
                pointgroup2 = spglib.get_pointgroup(all_bond_ops_2)
                print("Point group of operations that map the first bond to this bond:", pointgroup2)

                if plot_cart:
                    plot_bond(cart_pos2, neighbour_cart2, cart_positions_sorted2, numbers_sorted2, unit_atom_number, current_atom2)
                if plot_frac:
                    plot_bond(frac_pos2, neighbour_frac2, frac_positions_sorted2)

                bond_vector_cart2 = np.array(difference(cart_pos2, neighbour_cart2))
                bond_vector_frac2 = np.array(difference(frac_pos2, neighbour_frac2))

                # NOTE(review): mappings are keyed by the other atom's index j
                # only. If atom j has several bonds that map onto this
                # reference bond, only the last one is kept. The key is left
                # unchanged because callers may depend on it; confirm intent.
                exchange_data[n]["mappings"][j] = {
                    "first atom frac": frac_pos2,
                    "first atom cart": cart_pos2,
                    "element 1": unit_atom_number,
                    "neighbour_frac": neighbour_frac2,
                    "neighbour_cart": neighbour_cart2,
                    "element 2": current_atom2,
                    "distance": current_distance2,
                    "bond_vector_frac": bond_vector_frac2,
                    "bond_vector_cart": bond_vector_cart2,
                    "shell": shell2,
                    "bond_number_shell": bond_number2,
                    "bond_number_shell_same_atom": bond_number_shell_same_atom2,
                    "is first tensor": False,
                    "first bond tensor": simplified_tensor,
                    "corresponding tensor for this bond": corresponding_tensor2,
                    "direct_ops": direct_rotations2,
                    "reversal_ops": reversal_rotations2,
                }

            print("stopping loop over second atom shells")
        print("stopping loop over positons")
    print('stopping all loops (over original shells)')

    return exchange_data

In [None]:
def run_code_unit_cell(file, skip_elements=None):
    """Find the exchange tensors for each bond using only the atoms of the unit cell.

    For every atom whose atomic number is not in ``skip_elements``, the
    distances to the unit-cell atoms are computed with ``get_point_distances``,
    grouped into shells with ``grouping``, rounded to 4 decimals and sorted
    nearest-first. ``seperate_tensor_for_each_bond`` is then run on that
    neighbour data.

    Parameters
    ----------
    file : str
        Path to a structure file readable by ``ase.io.read`` (e.g. a CIF).
    skip_elements : list of int, optional
        Atomic numbers of atoms that are not processed as bond centres.
        The list is also forwarded as ``numbers_to_skip`` to
        ``seperate_tensor_for_each_bond``. ``None`` (the default) means ``[9]``.

    Returns
    -------
    dict
        ``'lattice'`` maps to the lattice from ``get_cell``. Every processed
        atom index ``i`` maps to a dict with its cartesian/fractional position,
        atomic number, the sorted neighbour arrays and ``'exchange_data'``.
        Skipped atoms have no entry.
    """
    if skip_elements is None:
        skip_elements = [9]

    atoms = io.read(file)  # lattice vectors, positions and atomic numbers
    space_group = get_space_group_symops(atoms)
    lattice, scaled_positions, numbers = get_cell(atoms)
    positions = atoms.get_positions()  # cartesian positions

    print(scaled_positions)

    all_data = {'lattice': lattice}
    processed = []  # indices that actually get an entry in all_data

    for i, frac_pos in enumerate(scaled_positions):
        if numbers[i] in skip_elements:
            print("Skipping atom:", numbers[i])
            continue
        print("Processing atom", i+1, "of", len(scaled_positions), "with fractional position:", frac_pos, "and atomic number:", numbers[i])
        cart_pos = positions[i]

        distances, indices = get_point_distances(cart_pos, positions)
        shells = grouping(distances)  # shell label per neighbour, by cartesian distance
        rounded_distances = np.round(distances, decimals=4)  # absorb floating-point noise
        order = np.argsort(rounded_distances)  # nearest neighbour first

        all_data[i] = {
            'cartesian position': cart_pos,
            'fractional position': frac_pos,
            'number': numbers[i],
            'distances_sorted': rounded_distances[order],
            'shells_sorted': shells[order],
            'numbers_sorted': numbers[indices][order],
            'cart_positions_sorted': positions[indices][order],
            'frac_positions_sorted': scaled_positions[indices][order],
        }
        processed.append(i)

    # Iterate only over stored atoms: skipped atoms have no entry, so looping
    # over range(len(scaled_positions)) would raise KeyError.
    for i in processed:
        atom_data = all_data[i]
        atom_data['exchange_data'] = seperate_tensor_for_each_bond(
            atom_data, space_group, i, numbers_to_skip=skip_elements
        )

    return all_data

In [None]:
def _match_supercell_index(target, candidates, rel_tol=1e-3):
    """Return the index of the first row of ``candidates`` matching ``target``.

    Each of the three coordinates is compared with ``math.isclose`` using only
    a relative tolerance, so a coordinate of exactly 0 must match exactly.
    Returns ``None`` when no row matches.
    """
    for j, candidate in enumerate(candidates):
        if all(math.isclose(target[k], candidate[k], rel_tol=rel_tol) for k in range(3)):
            return j
    return None


def run_code_supercell(
    file,
    scale=3,                                      # size of the supercell along each lattice vector
    pick_atom_index=None,                         # None → loop all atoms; e.g. 7 → only atom i==7
    skip_elements=None,                           # atomic numbers to skip; None → [32] (Ge), [] → keep all
    cut_window=(-0.5, 1.2),                       # fractional window min/max for each coord
    radius_mode="norm_lengths",                   # "norm_lengths" or "max_length"
    radius_factor=1.01,                           # inflate radius slightly (safety margin)
    do_seperate_tensor_for_each_bond=True,
    do_first_tensor_for_bond_then_others_from_same_atom=False,
    do_seperate_tensor_for_bonds_checks_other_atoms=False,
    plotting=True                                 # master switch for plotting_graphs
):
    """Find the exchange tensors for each bond of each unit-cell atom using a supercell.

    Each selected unit-cell atom is located in the central cell of a
    ``scale`` x ``scale`` x ``scale`` supercell. Its neighbours within
    ``radius`` are collected, grouped into shells and sorted nearest-first.
    Neighbours whose recentred, upscaled fractional coordinates fall outside
    ``cut_window`` are dropped. The selected exchange-tensor routine(s) are
    then run on the result.

    Parameters
    ----------
    file : str
        Structure file readable by ``ase.io.read`` (e.g. a CIF).
    scale : int
        Supercell size along each lattice vector.
    pick_atom_index : int or None
        If given, only the unit-cell atom with this index is processed.
    skip_elements : list of int or None
        Atomic numbers of atoms that are not processed as bond centres. The
        list is also forwarded as ``numbers_to_skip`` to the tensor routines.
        ``None`` means ``[32]``; pass ``[]`` to keep every element.
    cut_window : (float, float)
        Inclusive (low, high) bounds applied to every recentred, upscaled
        fractional coordinate of a neighbour.
    radius_mode : {"norm_lengths", "max_length"}
        Base search radius: ``sqrt(a**2 + b**2 + c**2)`` or ``max(a, b, c)``.
    radius_factor : float
        Multiplier applied to the base radius.
    do_seperate_tensor_for_each_bond, do_first_tensor_for_bond_then_others_from_same_atom, do_seperate_tensor_for_bonds_checks_other_atoms : bool
        Which routine(s) to run. If several are True, each later routine
        overwrites ``'exchange_data'``.
    plotting : bool
        Call ``plotting_graphs`` for each processed atom.

    Returns
    -------
    dict
        Maps unit-cell atom index to a dict of positions, atomic number and
        sorted neighbour arrays, plus ``'exchange_data'`` if a routine ran.
        ``'frac_positions_sorted'`` holds the recentred, upscaled coordinates.

    Raises
    ------
    ValueError
        If ``radius_mode`` is unknown, or a unit-cell atom cannot be located
        in the supercell.
    """
    if radius_mode not in ("norm_lengths", "max_length"):
        raise ValueError(
            f"radius_mode must be 'norm_lengths' or 'max_length', got {radius_mode!r}"
        )
    if skip_elements is None:
        skip_elements = [32]
    picked_index = None if pick_atom_index is None else int(pick_atom_index)
    low, high = cut_window

    atoms = io.read(file)  # lattice vectors, positions and atomic numbers
    space_group = get_space_group_symops(atoms)
    lattice, frac_positions, numbers = get_cell(atoms)
    print(lattice)
    print(frac_positions)

    supercell = build_supercell(atoms, scale)
    _, supercell_frac_positions, supercell_numbers = get_cell(supercell)
    supercell_cart_positions = supercell.get_positions()
    # Upscale so one unit-cell length corresponds to 1.0 in fractional units.
    supercell_frac_positions_upscaled = supercell_frac_positions * scale
    # Shift the origin to the bottom-left corner of the central cell.
    center_index = (scale - 1) / 2
    recentred_supercell_frac_positions_upscaled = supercell_frac_positions_upscaled - center_index

    lengths = lattice.lengths()  # a, b, c
    if radius_mode == "max_length":
        radius = np.max(lengths) * radius_factor
    else:  # "norm_lengths"
        radius = np.linalg.norm(lengths) * radius_factor
    print(lengths)
    print(radius)

    all_data = {}

    for i, frac_pos in enumerate(frac_positions):
        if picked_index is not None and i != picked_index:
            continue
        if numbers[i] in skip_elements:
            print("Skipping atom:", numbers[i])
            continue

        # Find this atom's copy in the central cell (e.g. scale = 3: 0.2 -> 1.2)
        # so that neighbours exist on every side of it.
        centred_unit_cell_position = centering(frac_pos, scale)
        j = _match_supercell_index(centred_unit_cell_position, supercell_frac_positions_upscaled)
        if j is None:
            # Fail loudly instead of reusing a stale position from a previous atom.
            raise ValueError(
                f"Could not locate unit-cell atom {i} (fractional position {frac_pos}) "
                "in the supercell"
            )
        cart_pos = supercell_cart_positions[j]

        if plotting:
            # Show the atom within the supercell together with the search radius.
            plotting_graphs(lattice, cart_pos, supercell_cart_positions, scale, radius)

        # Cartesian distances to the supercell atoms within `radius`, and their indices.
        supercell_distances, indices = get_point_distances(cart_pos, supercell_cart_positions, radius)

        shells = grouping(supercell_distances)  # shell label per neighbour
        rounded_supercell_distances = np.round(supercell_distances, decimals=4)  # absorb float noise
        s_i = np.argsort(rounded_supercell_distances)  # nearest neighbour first

        supercell_distances_sorted = rounded_supercell_distances[s_i]
        shells_sorted = shells[s_i]
        supercell_numbers_sorted = supercell_numbers[indices][s_i]
        supercell_cart_positions_sorted = supercell_cart_positions[indices][s_i]
        recentred_sorted = recentred_supercell_frac_positions_upscaled[indices][s_i]

        print("Processing unit cell atom", i+1, "of", len(frac_positions), "with fractional position:", frac_pos, "and atomic number:", numbers[i])

        # Keep neighbours whose every recentred coordinate lies in [low, high].
        inside = np.all((recentred_sorted >= low) & (recentred_sorted <= high), axis=1)

        all_data[i] = {
            'cartesian position': cart_pos,
            'fractional position': frac_pos,
            'number': numbers[i],
            'distances_sorted': supercell_distances_sorted[inside],
            'shells_sorted': shells_sorted[inside],
            'numbers_sorted': supercell_numbers_sorted[inside],
            'cart_positions_sorted': supercell_cart_positions_sorted[inside],
            'frac_positions_sorted': recentred_sorted[inside],
        }
        atom_data = all_data[i]

        if do_seperate_tensor_for_each_bond:
            atom_data['exchange_data'] = seperate_tensor_for_each_bond(
                atom_data, space_group, i, numbers_to_skip=skip_elements)

        if do_first_tensor_for_bond_then_others_from_same_atom:
            atom_data['exchange_data'] = first_tensor_for_bond_then_others_from_same_atom(
                atom_data, space_group, i, numbers_to_skip=skip_elements)

        if do_seperate_tensor_for_bonds_checks_other_atoms:
            atom_data['exchange_data'] = seperate_tensor_for_bonds_checks_other_atoms(
                atom_data, space_group, i, radius=radius, numbers_to_skip=skip_elements)

    return all_data

In [None]:
def run_code_supercell_all(
    file,
    scale=3,                                      # size of the supercell along each lattice vector
    skip_elements=None,                           # atomic numbers to skip; None → [31, 28]
    cut_window=(-0.01, 1.01),                     # upscaled fractional window min/max for each coord
    radius_factor=1.01,                           # inflate radius slightly (safety margin)
    do_seperate_tensor_for_each_bond=True,
    do_first_tensor_for_bond_then_others_from_same_atom=False,
    do_seperate_tensor_for_bonds_checks_other_atoms=False,
    plotting=True,                                # master switch for plotting_graphs
):
    """Find the exchange tensors for every atom inside one window of a supercell.

    A ``scale`` x ``scale`` x ``scale`` supercell is built, and its fractional
    positions are multiplied by ``scale``. The supercell is then trimmed to
    atoms whose upscaled coordinates all lie within ``cut_window``; by default
    that is the cell at the supercell origin plus a 0.01 margin. Every trimmed
    atom not in ``skip_elements`` is treated as a bond centre. Its neighbours
    are searched within ``radius``, but only among the trimmed atoms. They are
    grouped into shells, sorted nearest-first, and passed to the selected
    exchange-tensor routine(s).

    Parameters
    ----------
    file : str
        Structure file readable by ``ase.io.read`` (e.g. a CIF).
    scale : int
        Supercell size along each lattice vector.
    skip_elements : list of int or None
        Atomic numbers of atoms that are not processed as bond centres. The
        list is also forwarded as ``numbers_to_skip`` to the tensor routines.
        ``None`` means ``[31, 28]``.
    cut_window : (float, float)
        Inclusive (low, high) bounds on every upscaled fractional coordinate.
    radius_factor : float
        Multiplier applied to ``sqrt(a**2 + b**2 + c**2)`` to get the radius.
    do_seperate_tensor_for_each_bond, do_first_tensor_for_bond_then_others_from_same_atom, do_seperate_tensor_for_bonds_checks_other_atoms : bool
        Which routine(s) to run. If several are True, each later routine
        overwrites ``'exchange_data'``.
    plotting : bool
        Call ``plotting_graphs`` for each processed atom.

    Returns
    -------
    dict
        Maps the index within the trimmed supercell to a per-atom dict of
        positions, atomic number, sorted neighbour arrays and, if a routine
        ran, ``'exchange_data'``. The position keys hold upscaled coordinates.
    """
    if skip_elements is None:
        skip_elements = [31, 28]
    low, high = cut_window

    atoms = io.read(file)  # lattice vectors, positions and atomic numbers
    space_group = get_space_group_symops(atoms)
    lattice, frac_positions, _ = get_cell(atoms)
    print(lattice)
    print(frac_positions)

    supercell = build_supercell(atoms, scale)
    _, supercell_frac_positions, supercell_numbers = get_cell(supercell)
    supercell_cart_positions = supercell.get_positions()
    # Upscale so one unit-cell length corresponds to 1.0 in fractional units.
    supercell_frac_positions_upscaled = supercell_frac_positions * scale

    lengths = lattice.lengths()  # a, b, c
    radius = np.linalg.norm(lengths) * radius_factor  # sqrt(a^2 + b^2 + c^2), slightly inflated
    print(lengths)
    print(radius)

    # Trim the supercell to atoms whose every upscaled coordinate lies in [low, high].
    inside = np.all(
        (supercell_frac_positions_upscaled >= low) & (supercell_frac_positions_upscaled <= high),
        axis=1,
    )
    supercell_frac_positions_upscaled = supercell_frac_positions_upscaled[inside]
    supercell_cart_positions = supercell_cart_positions[inside]
    supercell_numbers = supercell_numbers[inside]
    n_atoms = len(supercell_frac_positions_upscaled)

    print(supercell_frac_positions_upscaled)

    all_data = {}

    for i, frac_pos in enumerate(supercell_frac_positions_upscaled):
        if supercell_numbers[i] in skip_elements:
            print("Skipping atom:", supercell_numbers[i])
            continue

        cart_pos = supercell_cart_positions[i]

        if plotting:
            # Show the atom within the supercell together with the search radius.
            plotting_graphs(lattice, cart_pos, supercell_cart_positions, scale, radius)

        # Cartesian distances to the trimmed atoms within `radius`, and their indices.
        supercell_distances, indices = get_point_distances(cart_pos, supercell_cart_positions, radius)

        shells = grouping(supercell_distances)  # shell label per neighbour
        rounded_supercell_distances = np.round(supercell_distances, decimals=4)  # absorb float noise
        s_i = np.argsort(rounded_supercell_distances)  # nearest neighbour first

        # The denominator counts trimmed atoms (what this loop iterates), not unit-cell atoms.
        print("Processing unit cell atom", i+1, "of", n_atoms, "with fractional position:", frac_pos, "and atomic number:", supercell_numbers[i])

        # Neighbours come from the already-trimmed arrays, so they all lie inside
        # cut_window and no second window cut is needed here.
        all_data[i] = {
            'cartesian position': cart_pos,
            'fractional position': frac_pos,
            'number': supercell_numbers[i],
            'distances_sorted': rounded_supercell_distances[s_i],
            'shells_sorted': shells[s_i],
            'numbers_sorted': supercell_numbers[indices][s_i],
            'cart_positions_sorted': supercell_cart_positions[indices][s_i],
            'frac_positions_sorted': supercell_frac_positions_upscaled[indices][s_i],
        }
        atom_data = all_data[i]

        if do_seperate_tensor_for_each_bond:
            # Separate tensor for each bond.
            atom_data['exchange_data'] = seperate_tensor_for_each_bond(
                atom_data, space_group, i, numbers_to_skip=skip_elements)

        if do_first_tensor_for_bond_then_others_from_same_atom:
            # First bond tensor of a shell, then the others in that shell derived from it.
            atom_data['exchange_data'] = first_tensor_for_bond_then_others_from_same_atom(
                atom_data, space_group, i, numbers_to_skip=skip_elements)

        if do_seperate_tensor_for_bonds_checks_other_atoms:
            # Also checks whether bonds map onto other atoms' bonds.
            atom_data['exchange_data'] = seperate_tensor_for_bonds_checks_other_atoms(
                atom_data, space_group, i, radius=radius, numbers_to_skip=skip_elements)

    return all_data

In [None]:
# Display settings (colour and label) keyed by atomic number; presumably read
# by the plotting helpers — see the usage note below.
species_map = {
    atomic_number: {"color": color, "label": label}
    for atomic_number, color, label in (
        (25, "red", "Mn"),
        (32, "blue", "Ge"),
    )
}
# usage:
# plot_bond(atom_a, atom_b, positions, numbers, species_map, zslice_at=atom_a[2], show3d=True)

In [None]:
# Structure analysed by this notebook; point this at another CIF to analyse a different material.
MN5GE3_CIF_PATH = "/Users/theobeevers/Documents/leeds/year 4/Research project/coding/materials to test/Mn5Ge3.cif"
run_code_supercell(MN5GE3_CIF_PATH)  # run the code for the supercell