In [16]:
# %%
import pandas as pd
from pymatgen.io.cif import CifParser
from pyxtal import pyxtal
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
import numpy as np
from pyxtal.io import write_cif

from pymatgen.core.structure import Structure
from tqdm import tqdm
import pickle
from p_tqdm import p_map
import argparse
import os
from pathlib import Path
import warnings

warnings.filterwarnings("ignore")


# %%
def _strip_cif_banner(cif_text):
    """Return ``cif_text`` starting at its first line beginning with ``data_``.

    pyxtal's ``write_cif`` emits a banner before the ``data_`` block. Cutting at
    the ``data_`` line (instead of a fixed character offset) keeps working if
    the banner length changes between pyxtal versions. If no ``data_`` line is
    present, the text is returned unchanged.
    """
    lines = cif_text.splitlines(keepends=True)
    for idx, line in enumerate(lines):
        if line.startswith("data_"):
            return "".join(lines[idx:])
    return cif_text


def process_cif_to_conventional(cif_str, tol=0.01):
    """Symmetrize a CIF with pyxtal and extract its Wyckoff-orbit operations.

    Args:
        cif_str: CIF text; only the first structure parsed from it is used.
        tol: Symmetry tolerance forwarded to ``pyxtal.from_seed``.

    Returns:
        Tuple ``(cif, sym_info, num_sites, formula)``:
        - ``cif``: CIF text written by pyxtal's ``write_cif``, starting at the
          ``data_`` line (pyxtal's leading banner removed).
        - ``sym_info``: dict with
            ``"anchors"`` -- for every expanded site, the index of the first
              expanded site of the same Wyckoff orbit;
            ``"wyckoff_ops"`` -- one affine matrix (nested lists) per expanded
              site, in the same order as ``anchors``;
            ``"spacegroup"`` -- ``pyx.group.number``.
        - ``num_sites``: number of expanded sites (== ``len(wyckoff_ops)``).
        - ``formula``: ``pyx.formula``.
    """
    # NOTE(review): recent pymatgen deprecates get_structures() (which defaults
    # to primitive=True) in favour of parse_structures(); confirm the installed
    # version before switching.
    structure = CifParser.from_str(cif_str).get_structures()[0]
    pyx = pyxtal()
    pyx.from_seed(structure, tol=tol)

    anchors = []
    matrices = []
    for site in pyx.atom_sites:
        # Every op of this orbit points back at the orbit's first expanded site.
        anchor = len(matrices)
        for op in site.wp:
            matrices.append(op.affine_matrix)
            anchors.append(anchor)

    sym_info = {
        "anchors": anchors,
        "wyckoff_ops": np.array(matrices).tolist(),
        "spacegroup": pyx.group.number,
    }
    cif = _strip_cif_banner(write_cif(pyx))
    return cif, sym_info, len(matrices), pyx.formula


# %%
def process_data(data):
    """Return a copy of one dataset record with symmetry information attached.

    Args:
        data: One dataset row as a dict; must contain a ``"cif"`` key.

    Returns:
        A new dict holding every key of ``data``. ``"cif"`` is replaced by the
        pyxtal-written CIF, and ``"sym_info"``, ``"num_sites"`` and
        ``"formula"`` are added from ``process_cif_to_conventional``. The input
        dict is left unmodified.
    """
    cif, sym_info, num_sites, formula = process_cif_to_conventional(data["cif"])
    # Overriding "cif" inside the literal keeps its original column position;
    # the new keys are appended after the existing ones.
    return {
        **data,
        "cif": cif,
        "sym_info": sym_info,
        "num_sites": num_sites,
        "formula": formula,
    }


# main code
# Run configuration. The notebook builds `args` directly; a script version can
# fill the same fields with argparse (--csv_path, --data_name, --mode, --num_cpus).
args = argparse.Namespace(
    # Root directory holding one sub-directory per dataset. Override it with the
    # MATTERGEN_DATASETS environment variable instead of editing this path.
    csv_path=os.environ.get(
        "MATTERGEN_DATASETS", "/home/holywater2/crystal_gen/mattergen/datasets"
    ),
    data_name="mp_20",
    mode="val",
    num_cpus=16,
)
# Process only the first N_SAMPLES rows for quick iteration; None = whole split.
N_SAMPLES = 100
# When True, write the processed split to conventional/<data_name>/<mode>.csv.
SAVE_OUTPUT = False

print("Starting...")
csv_path = Path(args.csv_path) / args.data_name
csv_file = csv_path / f"{args.mode}.csv"
print(f"Processing {csv_file}")
df = pd.read_csv(csv_file, index_col=0)
# .iloc[:None] selects every row, so N_SAMPLES=None processes the full split.
records = df.iloc[:N_SAMPLES].to_dict(orient="records")
new_data = p_map(process_data, records, num_cpus=args.num_cpus)
new_df = pd.DataFrame(new_data)

if SAVE_OUTPUT:
    out_dir = Path("conventional") / args.data_name
    os.makedirs(out_dir, exist_ok=True)
    out_file = out_dir / f"{args.mode}.csv"
    print(f"Saving to {out_file}")
    new_df.to_csv(out_file)
    print("Done!")


Starting...
Processing /home/holywater2/crystal_gen/mattergen/datasets/mp_20/val.csv


  0%|          | 0/100 [00:00<?, ?it/s]

In [47]:
# Preview the processed table without the bulky nested sym_info column.
new_df.drop(columns=["sym_info"])

Unnamed: 0,material_id,formation_energy_per_atom,dft_band_gap,pretty_formula,e_above_hull,elements,cif,spacegroup_number,azure_bulk_modulus,larsen_score_2d,Si_100_mismatch,azure_band_gap,dft_bulk_modulus,dft_poisson_ratio,dft_mag_density,num_sites,formula
0,mp-865981,-0.436368,0.0000,TmMgHg2,0.000000,"['Hg', 'Mg', 'Tm']",\ndata_\n\n_symmetry_space_group_name_H-M 'Fm-...,225.0,61.426781,0.000000,0.011126,0.000027,,,4.562193e-09,16,Tm4Mg4Hg8
1,mp-1103778,-2.755559,3.5845,HoWClO4,0.025353,"['Cl', 'Ho', 'O', 'W']",\ndata_\n\n_symmetry_space_group_name_H-M 'C2/...,12.0,61.712846,0.000000,0.016360,3.415163,,,2.914640e-06,28,Ho4W4Cl4O16
2,mp-39712,-3.299936,2.4439,NaCaTaTiO6,0.012439,"['Ca', 'Na', 'O', 'Ta', 'Ti']",\ndata_\n\n_symmetry_space_group_name_H-M 'Pc'...,7.0,146.963192,0.000000,0.001477,2.392171,,,3.029915e-06,20,Na2Ca2Ta2Ti2O12
3,mp-754553,-0.021981,0.2182,Zn3N2,0.009962,"['Zn', 'N']",\ndata_\n\n_symmetry_space_group_name_H-M 'Cmc...,62.0,132.293591,0.000000,0.012372,0.000570,132.048332,0.311903,4.167681e-06,20,N8Zn12
4,mp-20332,-0.684002,0.0000,GdMgPd,0.000000,"['Gd', 'Mg', 'Pd']",\ndata_\n\n_symmetry_space_group_name_H-M 'P-6...,189.0,63.058039,0.000000,0.014960,0.000023,,,1.160755e-01,9,Gd3Mg3Pd3
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,mp-1079654,-0.546751,0.0000,Sm2MnGa6,0.000000,"['Ga', 'Mn', 'Sm']",\ndata_\n\n_symmetry_space_group_name_H-M 'Fm-...,225.0,69.665882,0.000000,0.005953,-0.000019,,,1.566095e-02,36,Sm8Mn4Ga24
96,mp-1873,-2.602986,3.4773,ZnF2,0.000000,"['Zn', 'F']",\ndata_\n\n_symmetry_space_group_name_H-M 'P42...,136.0,82.017539,0.000000,0.003887,3.488967,98.493027,0.342369,0.000000e+00,6,Zn2F4
97,mp-760482,0.003466,3.0951,F2,0.003466,['F'],\ndata_\n\n_symmetry_space_group_name_H-M 'Cmc...,64.0,2.322351,0.006086,0.008052,3.045773,2.305929,0.239143,2.429159e-05,8,F8
98,mp-1248567,-3.474904,3.2663,AlVF5,0.022430,"['Al', 'F', 'V']",\ndata_\n\n_symmetry_space_group_name_H-M 'Imm...,71.0,73.055847,0.022379,0.007992,3.131949,,,3.409923e-02,14,Al2V2F10


In [17]:
# Inspect the symmetry record of the row labelled 0.
new_df.at[0, "sym_info"]

{'anchors': [0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 8, 8, 8, 8],
 'wyckoff_ops': [[[0.0, 0.0, 0.0, 0.0],
   [0.0, 0.0, 0.0, 0.0],
   [0.0, 0.0, 0.0, 0.0],
   [0.0, 0.0, 0.0, 1.0]],
  [[0.0, 0.0, 0.0, 0.0],
   [0.0, 0.0, 0.0, 0.5],
   [0.0, 0.0, 0.0, 0.5],
   [0.0, 0.0, 0.0, 1.0]],
  [[0.0, 0.0, 0.0, 0.5],
   [0.0, 0.0, 0.0, 0.0],
   [0.0, 0.0, 0.0, 0.5],
   [0.0, 0.0, 0.0, 1.0]],
  [[0.0, 0.0, 0.0, 0.5],
   [0.0, 0.0, 0.0, 0.5],
   [0.0, 0.0, 0.0, 0.0],
   [0.0, 0.0, 0.0, 1.0]],
  [[0.0, 0.0, 0.0, 0.5],
   [0.0, 0.0, 0.0, 0.5],
   [0.0, 0.0, 0.0, 0.5],
   [0.0, 0.0, 0.0, 1.0]],
  [[0.0, 0.0, 0.0, 0.5],
   [0.0, 0.0, 0.0, 1.0],
   [0.0, 0.0, 0.0, 1.0],
   [0.0, 0.0, 0.0, 1.0]],
  [[0.0, 0.0, 0.0, 1.0],
   [0.0, 0.0, 0.0, 0.5],
   [0.0, 0.0, 0.0, 1.0],
   [0.0, 0.0, 0.0, 1.0]],
  [[0.0, 0.0, 0.0, 1.0],
   [0.0, 0.0, 0.0, 1.0],
   [0.0, 0.0, 0.0, 0.5],
   [0.0, 0.0, 0.0, 1.0]],
  [[0.0, 0.0, 0.0, 0.25],
   [0.0, 0.0, 0.0, 0.25],
   [0.0, 0.0, 0.0, 0.25],
   [0.0, 0.0, 0.0, 1.0]],
  [[0.0, 0.

In [10]:
import json

In [19]:
# Map material_id -> sym_info so downstream code can look symmetry up by id.
# A duplicated id would silently overwrite an earlier entry, so refuse it.
assert new_df["material_id"].is_unique, "duplicate material_id in new_df"
# zip pairs values positionally, so this does not depend on new_df's index labels.
save_dict = dict(zip(new_df["material_id"], new_df["sym_info"]))
with open("sym_info.json", "w") as f:
    json.dump(save_dict, f)

In [20]:
# Reload the saved material_id -> sym_info mapping; `with` closes the handle.
with open("sym_info.json") as f:
    load = json.load(f)

In [25]:
import torch

In [40]:
# Row 0, column 3 of the first stored 4x4 affine op (its x translation),
# checked after converting the nested lists to a float64 tensor.
torch.tensor(load["mp-1027109"]["wyckoff_ops"], dtype=torch.float64)[0, 0, 3]

tensor(0.6667, dtype=torch.float64)

In [34]:
# Raw nested-list Wyckoff ops for one material, as stored in the JSON file.
example_sym_info = load["mp-1027109"]
example_sym_info["wyckoff_ops"]

[[[0.0, 0.0, 0.0, 0.6666666666666666],
  [0.0, 0.0, 0.0, 0.3333333333333333],
  [0.0, 0.0, 1.0, 0.0],
  [0.0, 0.0, 0.0, 1.0]],
 [[0.0, 0.0, 0.0, 0.6666666666666666],
  [0.0, 0.0, 0.0, 0.3333333333333333],
  [0.0, 0.0, 1.0, 0.0],
  [0.0, 0.0, 0.0, 1.0]],
 [[0.0, 0.0, 0.0, 0.3333333333333333],
  [0.0, 0.0, 0.0, 0.6666666666666666],
  [0.0, 0.0, 1.0, 0.0],
  [0.0, 0.0, 0.0, 1.0]],
 [[0.0, 0.0, 0.0, 0.3333333333333333],
  [0.0, 0.0, 0.0, 0.6666666666666666],
  [0.0, 0.0, 1.0, 0.0],
  [0.0, 0.0, 0.0, 1.0]],
 [[0.0, 0.0, 0.0, 0.6666666666666666],
  [0.0, 0.0, 0.0, 0.3333333333333333],
  [0.0, 0.0, 1.0, 0.0],
  [0.0, 0.0, 0.0, 1.0]],
 [[0.0, 0.0, 0.0, 0.3333333333333333],
  [0.0, 0.0, 0.0, 0.6666666666666666],
  [0.0, 0.0, 1.0, 0.0],
  [0.0, 0.0, 0.0, 1.0]],
 [[0.0, 0.0, 0.0, 0.3333333333333333],
  [0.0, 0.0, 0.0, 0.6666666666666666],
  [0.0, 0.0, 1.0, 0.0],
  [0.0, 0.0, 0.0, 1.0]],
 [[0.0, 0.0, 0.0, 0.3333333333333333],
  [0.0, 0.0, 0.0, 0.6666666666666666],
  [0.0, 0.0, 1.0, 0.0],
  [0.0, 0