In [1]:
import numpy as np
import json
import spglib
import scipy.io
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import matplotlib.cm as cm

class SpaceGroup:
    """Symmetry-reduced ("irreducible") basis for fields on a periodic grid.

    On construction the group symbol / Hall number is resolved with spglib,
    its symmetry operations are fetched, and the grid ``nx`` is partitioned
    into orbits under those operations. One representative point per orbit
    forms the reduced basis; :meth:`to_reduced_basis` and
    :meth:`from_reduced_basis` convert field arrays between the full grid and
    that basis.
    """

    def __init__(self, nx, symbol, hall_number=None):
        """
        Args:
            nx (tuple/list): Number of grid points in each dimension (e.g., [64, 64]).
            symbol (str): Space/Plane/Line group symbol (e.g., 'Ia-3d', 'p4m', 'P 4 m m').
                Matching is case-insensitive and also accepts the symbol with
                spaces removed; 2D shorthand (e.g. 'p4m') is expanded first.
            hall_number (int, optional): Explicit spglib Hall number. When given,
                it is used instead of looking the symbol up, but ``symbol``
                must still match the Hall number's international short symbol.

        Raises:
            ValueError: If the dimension is not 1, 2 or 3, the symbol is
                unknown, or the symbol does not match ``hall_number``.
        """
        # --- Internal Mappings (based on unique axis c for 2D) ---
        # Plane-group symbol -> spglib Hall number of the 3D setting whose
        # in-plane (x, y) part is used for that plane group when no
        # hall_number is given.
        # NOTE(review): in spglib's table Hall 21 appears to be "P 1 c 1",
        # whose glide translation is along c and so vanishes in the plane;
        # an in-plane glide for 'pg' may need a different setting -- confirm.
        PLANE_TO_SERIAL = {
            'p1': 1, 'p2': 4, 'pm': 18, 'pg': 21, 'cm': 30,
            'pmm2': 125, 'pma2': 137, 'pba2': 161, 'cmm2': 173,
            'p4': 349, 'p4mm': 376, 'p4bm': 377,
            'p3': 430, 'p3m1': 446, 'p31m': 447,
            'p6': 462, 'p6mm': 477,
        }

        # Plane-group shorthand / alternative spellings -> keys of
        # PLANE_TO_SERIAL. Case is lowered and spaces removed before this
        # lookup, so 'p4m' and 'P 4 m m' both resolve to 'p4mm'.
        SHORTHAND_2D = {
            'p1': 'p1', 'p2': 'p2', 'pm': 'pm', 'p1m1': 'pm', 'pg': 'pg', 'p1c1': 'pg', 'cm': 'cm',
            'pmm': 'pmm2', 'pmg': 'pma2', 'pgg': 'pba2', 'cmm': 'cmm2',
            'p4m': 'p4mm', 'p4g': 'p4bm', 'p6m': 'p6mm'
        }

        # Line-group symbol -> spglib Hall number whose x part is used in 1D.
        LINE_TO_SERIAL = {
            'p1': 1,
            'pm': 20  # ITA 6, Setting P m 1 1 (Mirror perp to x)
        }

        self.nx = np.array(nx)
        self.dim = len(self.nx)

        print(f"---------- {self.dim}D Space Group ----------")

        # --- Resolve Hall Number and Symbol ---
        # Normalized symbol: lower case, no spaces.
        sym_clean = symbol.lower().replace(" ", "")

        if self.dim == 2:
            sym_clean = SHORTHAND_2D.get(sym_clean, sym_clean)
            if sym_clean not in PLANE_TO_SERIAL:
                raise ValueError(f"Unknown 2D symbol '{symbol}'. Use ITA or shorthand (e.g., 'p4m').")
        elif self.dim == 1:
            if sym_clean not in LINE_TO_SERIAL:
                raise ValueError(f"Unknown 1D symbol '{symbol}'. Use 'p1' or 'pm'.")
        elif self.dim != 3:
            raise ValueError("Dimension must be 1, 2, or 3.")

        def symbol_matches(international_short):
            # Accept the symbol as given or in its normalized form, so spaced
            # ('P 6 m m') and 2D shorthand ('p6m') spellings also match.
            return (international_short.strip().lower() == symbol.strip().lower()
                    or international_short.replace(" ", "").lower() == sym_clean)

        if hall_number is None:
            if self.dim == 2:
                # Use the setting listed in PLANE_TO_SERIAL rather than the
                # first spglib setting that happens to share the symbol.
                self.hall_number = PLANE_TO_SERIAL[sym_clean]
            elif self.dim == 1:
                self.hall_number = LINE_TO_SERIAL[sym_clean]
            else:
                # Find the ITA Hall number via spglib
                h_nums = [h for h in range(1, 531)
                          if symbol_matches(spglib.get_spacegroup_type(h).international_short)]
                if not h_nums:
                    raise ValueError(f"Symbol '{symbol}' not found in spglib.")
                if len(h_nums) > 1:
                    print(f"Warning: Multiple settings for {symbol}. Using Hall {h_nums[0]}.")
                self.hall_number = h_nums[0]
        else:
            # Check that the provided Hall number belongs to the given symbol.
            inter_short = spglib.get_spacegroup_type(hall_number).international_short
            if not symbol_matches(inter_short):
                raise ValueError(f"Symbol mismatch: {inter_short} != {symbol}. Please check the provided symbol and Hall number.")
            self.hall_number = hall_number

        # Use spglib to get information from the Hall number
        spg_type_info = spglib.get_spacegroup_type(self.hall_number)
        self.spacegroup_symbol = spg_type_info.international_short
        print(f"Using Hall Number: {self.hall_number}")
        print(f"Symmetry Symbol: {self.spacegroup_symbol}")

        self.spacegroup_number = spg_type_info.number

        # Map the international number to the crystal system and lattice parameters
        self.crystal_system, self.lattice_parameters = self.get_crystal_system(self.spacegroup_number)

        print(f"International space group number: {self.spacegroup_number}")
        print(f"Crystal system: {self.crystal_system}")

        # --- Symmetry and Mesh Generation ---
        self.symmetry_operations = self.get_symmetry_ops(self.hall_number, dim=self.dim)
        self.irreducible_mesh, self.indices = self.find_irreducible_mesh(self.nx, self.symmetry_operations)

        print(f"Number of symmetry operations: {len(self.symmetry_operations)}")
        print(f"Original mesh size: {np.prod(self.nx)}")
        print(f"Irreducible mesh size: {len(self.irreducible_mesh)}")

        # --- Flattened mappings for performance ---
        # multi_indices: (dim, num_irreducible_points)
        multi_indices = np.array(self.irreducible_mesh).T
        # Flat (C-order over nx) index of each representative point.
        self.reduced_basis_flat_indices_ = np.ravel_multi_index(multi_indices, self.nx)
        # For every flat grid point, the index of its representative.
        self.full_to_reduced_map_flat_ = self.indices.flatten()

    def get_symmetry_ops(self, hall_number, dim=3):
        """
        Generates the symmetry operations for a given Hall number.

        spglib always returns 3D operations. Only the leading ``dim x dim``
        block of each rotation and the first ``dim`` translation components
        are kept (a no-op for ``dim == 3``). Duplicate (rotation, translation)
        pairs are dropped; for ``dim < 3`` they arise when distinct 3D
        operations share the same lower-dimensional part.

        Args:
            hall_number (int): The Hall number (from 1 to 530).
            dim (int): Dimension of the grid the operations act on (1, 2 or 3).

        Returns:
            list: A list of (rotation_matrix, translation_vector) tuples.

        Raises:
            RuntimeError: If spglib returns no operations for ``hall_number``.
        """
        # Get symmetry operations directly from the spglib database
        symmetries = spglib.get_symmetry_from_database(hall_number)

        if symmetries is None:
            raise RuntimeError(f"Could not get symmetry operations for Hall number {hall_number}")

        unique_ops = []
        seen = set()
        for R, t in zip(symmetries['rotations'], symmetries['translations']):
            R_sub = R[:dim, :dim]
            t_sub = t[:dim]
            # Arrays are unhashable, so key on their byte representation.
            key = (R_sub.tobytes(), np.round(t_sub, 6).tobytes())
            if key not in seen:
                seen.add(key)
                unique_ops.append((R_sub, t_sub))
        return unique_ops

    def find_irreducible_mesh(self, grid_size, symmetry_operations):
        """Finds the irreducible set of points for a grid of any dimension (1D, 2D, or 3D)
        under the given symmetry operations.

        Args:
            grid_size (list or tuple): The size of the grid in each dimension.
            symmetry_operations (list): A list of symmetry operations, where each operation is a tuple
                containing a rotation matrix and a translation vector.
                Example: [(R1, t1), (R2, t2), ...]

        Returns:
            tuple: A tuple containing:
                - irreducible_points: A list of tuples representing the irreducible mesh points.
                - indices: A NumPy array of shape grid_size, where each element is the index of the
                  irreducible point that corresponds to that grid point.
        """
        # The scan runs over axes in reversed order; both outputs are
        # converted back to the original axis order at the end.
        # NOTE(review): R and t are applied to these reversed-order
        # coordinates, i.e. effectively the axis-reversed operation acts on
        # the original axes -- confirm this matches the field data layout.
        grid_size_arr = np.array(np.flip(grid_size))

        # -1 marks grid points not yet assigned to any orbit.
        indices = np.zeros(grid_size_arr, dtype=np.int32) - 1

        # Representative (first-visited) point of each orbit.
        irreducible_points = []

        count = 0
        # Iterates over 1D, 2D, or 3D indices automatically
        for point_int in np.ndindex(*grid_size_arr):
            if indices[tuple(point_int)] == -1:
                irreducible_points.append(point_int)
                p_frac = np.array(point_int) / grid_size_arr

                for R, t in symmetry_operations:
                    # p' = R*p + t (periodic), rounded to the nearest grid point.
                    new_p_frac = np.dot(R, p_frac) + t
                    new_p_int = np.mod(np.round(new_p_frac * grid_size_arr).astype(int), grid_size_arr)
                    # Assigned unconditionally: on a grid that is not
                    # commensurate with the group, rounding can land on a point
                    # of an earlier orbit, which is then reassigned here.
                    indices[tuple(new_p_int)] = count
                count += 1

        # Undo the axis reversal for any rank (.T reverses all axes).
        indices = indices.T
        irreducible_points = [point[::-1] for point in irreducible_points]

        return irreducible_points, indices

    def hall_numbers_from_ita_symbol(self, international_short):
        """Return every Hall number that belongs to the given ITA symbol (exact, case-sensitive match)."""
        return [
            h for h in range(1, 531)
            if spglib.get_spacegroup_type(h).international_short.strip() == international_short.strip()
        ]

    def hall_numbers_from_ita_number(self, ita):
        """Return every Hall number that belongs to the given ITA number."""
        return [
            h for h in range(1, 531)
            if spglib.get_spacegroup_type(h).number == ita
        ]

    def to_reduced_basis(self, fields):
        """Return the values of ``fields`` at the representative points.

        ``fields`` is reshaped to (fields.shape[0], -1), so its first axis
        indexes fields and the remaining axes must flatten (C order) to the
        grid ``nx``. The result has shape (n_fields, n_irreducible).
        """
        fields_flat = np.reshape(fields, (fields.shape[0], -1))
        return fields_flat[:, self.reduced_basis_flat_indices_].copy()

    def from_reduced_basis(self, reduced_fields):
        """Expand (n_fields, n_irreducible) values back to the full flattened grid.

        Every grid point receives the value of its orbit's representative;
        the result has shape (n_fields, prod(nx)).
        """
        return reduced_fields[:, self.full_to_reduced_map_flat_].copy()

    def get_crystal_system(self, spg_number):
        """
        Maps an international space group number to its crystal system name and lattice parameters.

        Raises:
            ValueError: If ``spg_number`` is outside 1..230.
        """
        if 1 <= spg_number <= 2:
            return "Triclinic", ["a", "b", "c", "alpha", "beta", "gamma"]
        elif 3 <= spg_number <= 15:
            return "Monoclinic", ["a", "b", "c", "beta"]
        elif 16 <= spg_number <= 74:
            return "Orthorhombic", ["a", "b", "c"]
        elif 75 <= spg_number <= 142:
            return "Tetragonal", ["a", "c"]
        elif 143 <= spg_number <= 167:
            return "Trigonal", ["a", "c"]
        elif 168 <= spg_number <= 194:
            return "Hexagonal", ["a", "c"]
        elif 195 <= spg_number <= 230:
            return "Cubic", ["a"]
        else:
            raise ValueError("Invalid space group number.")

# ==============================================================================
# Main execution block
# ==============================================================================
if __name__ == '__main__':
    # Small 2D smoke setup: one random field on a 12 x 8 grid, stored as
    # (n_fields, *nx), the layout SpaceGroup.to_reduced_basis accepts.
    nx = [12, 8]
    field_shape = [1, *nx]
    w = np.random.uniform(0.0, 1.0, size=field_shape)

    print(f"nx: {nx}")

    # Hall numbers used for the 2D plane groups (unique axis c):
    #   p1: 1, p2: 4, pm: 18, pg: 21, cm: 30,
    #   pmm2: 125, pma2: 137, pba2: 161, cmm2: 173,
    #   p4: 349, p4mm: 376, p4bm: 377,
    #   p3: 430, p3m1: 446, p31m: 447,
    #   p6: 462, p6mm: 477
    #
    # Round-trip demo (uncomment to run): symmetrize `w` with a plane group
    # and show the original next to the symmetrized field.
    #
    # hall_number = 477
    # spg_type_info = spglib.get_spacegroup_type(hall_number)
    # sg = SpaceGroup(nx, spg_type_info.international_short, hall_number=hall_number)
    # w_converted = sg.from_reduced_basis(sg.to_reduced_basis(w)).reshape(nx)
    # fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    # for ax, image in zip(axes, (w[0, :, :], w_converted)):
    #     ax.imshow(image, cmap='jet')
    #     ax.set_title(f'Image {hall_number}')
    #     ax.axis('off')
    # plt.tight_layout()
    # plt.show()
    #
    # To use a stored 2D phase instead of random data (e.g. C2D):
    # with open('../../examples/scft/phases/C2D.json', 'r', encoding='utf-8') as file:
    #     data = json.load(file)
    # w = np.swapaxes(np.reshape(np.array(data["w_A"]), (1, *data["nx"])), 1, 2)
    # nx = data["nx"][::-1]

nx: [12, 8]


In [3]:
# Phases tried: PL, C2D, HCP (change the file name below).
# Scan all 530 spglib Hall numbers and record which symmetry groups leave the
# stored phi_A field (nearly) unchanged after a reduce -> expand round trip.

# Set to True to display phi_A next to its symmetrized version for every
# scanned Hall number (plotly, 3D grids only).
SHOW_ISOSURFACES = False


def plot_isosurface_pair(phi_a, phi_a_converted, nx):
    """Show ``phi_a`` and ``phi_a_converted`` side by side as plotly isosurfaces.

    Both fields are reshaped to ``nx``, which must have three entries
    (the plot coordinates come from ``np.mgrid`` over ``nx[0..2]``).
    """
    X, Y, Z = np.mgrid[0:nx[0], 0:nx[1], 0:nx[2]]
    fig = make_subplots(
        rows=1, cols=2,
        specs=[[{'type': 'scene'}, {'type': 'scene'}]],
        subplot_titles=("phi_A", "phi_A_converted"),
    )
    # (field, colorbar x-position) for the left and right panels.
    panels = ((phi_a, 0.45), (phi_a_converted, 1.0))
    for col, (field, colorbar_x) in enumerate(panels, start=1):
        fig.add_trace(go.Isosurface(
            x=X.flatten(),
            y=Y.flatten(),
            z=Z.flatten(),
            value=np.reshape(field, nx).flatten(),
            isomin=np.mean(field),
            isomax=1.0,
            opacity=1.0,
            surface_count=2,  # number of isosurfaces; 2 = only min and max
            colorscale='RdBu',
            colorbar=dict(x=colorbar_x, thickness=30, tickfont=dict(size=30, color="black")),
            reversescale=True,
        ), row=1, col=col)
    fig.update_layout(autosize=False, width=1000, height=500)
    fig.update_scenes(
        camera_projection_type="orthographic",
        xaxis_visible=False,
        yaxis_visible=False,
        zaxis_visible=False,
    )
    fig.show()


with open('../../examples/scft/phases/HCP.json', 'r', encoding='utf-8') as file:
    data = json.load(file)

# The 'data' variable now holds a Python dictionary (or list)
print(data.keys())

nx = data["nx"]
lx = data["lx"]

# Flatten phi_A to (1, prod(nx)), the (n_fields, n_points) layout used by
# SpaceGroup.to_reduced_basis. Reshaping straight to prod(nx) works for 1D,
# 2D and 3D phases alike (a (1, nx[0], nx[1], nx[2]) step would assume 3D).
# To test other axis orders, reshape to (1, *nx), swap axes (permuting nx
# to match), then flatten.
phi_a = np.reshape(np.array(data["phi_A"]), (1, np.prod(nx)))

print(f"nx: {nx}")

# Hall number -> reduced-basis size, for each Hall number whose round trip
# reproduces phi_A (std of the difference below 1e-2).
# Despite the name, the stored values are basis sizes, not standard deviations.
hall_number_and_std = {}
# To scan a subset, replace range(1, 531) below, e.g. with
# [1, 115, 124, 434, 492, 493, 508, 509, 510] (SG) or [1, 2, 336] (Fddd).
for hall_number in range(1, 531):
    try:
        spg_type_info = spglib.get_spacegroup_type(hall_number)
        sg = SpaceGroup(nx, spg_type_info.international_short, hall_number=hall_number)

        phi_a_reduced_basis = sg.to_reduced_basis(phi_a)
        print("Reduced basis phi_a:", phi_a_reduced_basis.shape)
        phi_a_converted = sg.from_reduced_basis(phi_a_reduced_basis)
        print("Converted phi_a:", phi_a_converted.shape)

        # The round trip copies each orbit representative's value onto its
        # whole orbit, so this is ~0 when phi_A is constant on every orbit.
        phi_a_std = np.std(phi_a - phi_a_converted)
        print("phi_a_std: ", phi_a_std)
        if phi_a_std < 1e-2:
            hall_number_and_std[hall_number] = phi_a_reduced_basis.shape[1]

        if SHOW_ISOSURFACES:
            plot_isosurface_pair(phi_a, phi_a_converted, nx)

    except Exception as e:
        # Best effort: report which Hall number failed and keep scanning.
        print(f"Error occurred for Hall number {hall_number}: {e}")

print(hall_number_and_std)

dict_keys(['initial_params', 'dim', 'nx', 'lx', 'monomer_types', 'chi_n', 'chain_model', 'ds', 'eigenvalues', 'matrix_a', 'matrix_a_inverse', 'w_A', 'w_B', 'phi_A', 'phi_B'])
nx: [64, 32, 64]
---------- 3D Space Group ----------
Using Hall Number: 1
Symmetry Symbol: P1
International space group number: 1
Crystal system: Triclinic
Number of symmetry operations: 1
Original mesh size: 131072
Irreducible mesh size: 131072
Reduced basis phi_a: (1, 131072)
Converted phi_a: (1, 131072)
phi_a_std:  0.0
---------- 3D Space Group ----------
Using Hall Number: 2
Symmetry Symbol: P-1
International space group number: 2
Crystal system: Triclinic
Number of symmetry operations: 2
Original mesh size: 131072
Irreducible mesh size: 65540
Reduced basis phi_a: (1, 65540)
Converted phi_a: (1, 131072)
phi_a_std:  0.20060218841447386
---------- 3D Space Group ----------
Using Hall Number: 3
Symmetry Symbol: P2
International space group number: 3
Crystal system: Monoclinic
Number of symmetry operations: 2
Ori