Before inference, compute the distribution of the number of atoms per crystal for the functional materials dataset and add it to the `~/DiffCSP/scripts/generation.py` file. The `InvDesFlow.ipynb` notebook provides an example of computing these statistics.

In [None]:

from tqdm import tqdm
from pymatgen.core.structure import Structure
import numpy as np
import os
import csv


def read_csv(file_path):
    """Read a CSV file and return all of its rows, header included.

    Args:
        file_path: Path of the CSV file to open.

    Returns:
        A list with one entry per CSV record, each a list of field strings.
    """
    with open(file_path, "r", newline="") as handle:
        return list(csv.reader(handle))

def get_pN(name='magnetic_150', cif_column=8):
    """Return the empirical distribution of the number of atoms per structure.

    Reads ``./data/{name}/data_materials.csv``, skips its first (header) row,
    parses the CIF text stored in column ``cif_column`` of every remaining row
    with pymatgen, and counts the sites of each structure.

    Args:
        name: Dataset folder under ``./data``.
        cif_column: Zero-based index of the CSV column holding the CIF string
            (defaults to 8, the column used previously).

    Returns:
        A list ``p`` of length ``max_atoms + 1`` where ``p[k]`` is the fraction
        of structures that contain exactly ``k`` atoms.

    Raises:
        ValueError: If the CSV has no data rows.
    """
    csv_path = f'./data/{name}/data_materials.csv'
    rows = read_csv(csv_path)[1:]  # drop the header row
    if not rows:
        raise ValueError(f'no data rows found in {csv_path}')

    number_list = []
    for row in rows:
        crystal = Structure.from_str(row[cif_column], fmt='cif')
        number_list.append(len(crystal.atomic_numbers))

    # Histogram in a single pass (index = atom count) instead of calling
    # list.count() once per possible size.
    count_list = [0] * (max(number_list) + 1)
    for num in number_list:
        count_list[num] += 1

    total = len(number_list)  # equals sum(count_list); computed once
    return [count / total for count in count_list]
# Atoms-per-structure distribution of the default dataset, printed so it can be
# copied into the generation script.
p_list = get_pN(name='magnetic_150')
print(p_list)


The generated crystals are stored as `.pt` files. The example below converts these `.pt` files into `.cif` files.

In [None]:

# Directory that holds the generated ``.pt`` files (fill in before running).
pt_dir = ''
# Directory the converted ``.cif`` files are written to (fill in before running).
step2_dir = ''
# Path to a pickled formula table. The conversion code below does not read this
# variable; the commented-out dedup snippet opens the same path directly.
mp_project = '/home/xqhan/DiffCSP/data/mpid_formula.pkl'

import pickle
import torch
import csv
from pymatgen.core.structure import Structure
from pymatgen.io.cif import CifWriter
import torch 
from pymatgen.core.lattice import Lattice
from pymatgen.io.cif import CifWriter
import os
from tqdm import tqdm
import json
import numpy as np
from scipy.spatial.distance import cdist


# Optional: load a set of known reduced formulas so that duplicate chemical formulas can be skipped.
# with open('/home/xqhan/DiffCSP/data/mpid_formula.pkl', 'rb') as f:
#     mp_project = pickle.load(f)
# mp_project_formula_set = set(mp_project.values())
# formula_set1 = '/home/xqhan/DiffCSP/DP_cup/formula_list.json'
# with open(formula_set1, 'r') as f:
#     dp_name_set1 = json.load(f)
# dp_set = set(dp_name_set1)
# mp_project_formula_set = mp_project_formula_set | dp_set
# print('len(mp_project_formula_set):',len(mp_project_formula_set))

# Element symbols indexed by atomic number Z; index 0 is the placeholder 'X',
# so chemical_symbols[26] == 'Fe'. One period of the periodic table per line.
chemical_symbols = ['X'] + (
    'H He '
    'Li Be B C N O F Ne '
    'Na Mg Al Si P S Cl Ar '
    'K Ca Sc Ti V Cr Mn Fe Co Ni Cu Zn Ga Ge As Se Br Kr '
    'Rb Sr Y Zr Nb Mo Tc Ru Rh Pd Ag Cd In Sn Sb Te I Xe '
    'Cs Ba La Ce Pr Nd Pm Sm Eu Gd Tb Dy Ho Er Tm Yb Lu '
    'Hf Ta W Re Os Ir Pt Au Hg Tl Pb Bi Po At Rn '
    'Fr Ra Ac Th Pa U Np Pu Am Cm Bk Cf Es Fm Md No Lr '
    'Rf Db Sg Bh Hs Mt Ds Rg Cn Nh Fl Mc Lv Ts Og'
).split()

# Elements a generated crystal may contain to be saved: every element with
# Z = 1..94 except Po, At, Rn, Fr and Ra. Grouped by period for readability.
cup_e_set = set(
    (
        'H He '
        'Li Be B C N O F Ne '
        'Na Mg Al Si P S Cl Ar '
        'K Ca Sc Ti V Cr Mn Fe Co Ni Cu Zn Ga Ge As Se Br Kr '
        'Rb Sr Y Zr Nb Mo Tc Ru Rh Pd Ag Cd In Sn Sb Te I Xe '
        'Cs Ba La Ce Pr Nd Pm Sm Eu Gd Tb Dy Ho Er Tm Yb Lu '
        'Hf Ta W Re Os Ir Pt Au Hg Tl Pb Bi '
        'Ac Th Pa U Np Pu'
    ).split()
)
def get_atoms_types(atom_types_onehot):
    """Decode per-atom type scores into element symbols and atomic numbers.

    Each row along the first dimension is reduced with ``torch.argmax``. The
    winning column index ``c`` is mapped to atomic number ``Z = c + 1`` (so
    column 0 is hydrogen), and ``Z`` is then looked up in ``chemical_symbols``.

    Args:
        atom_types_onehot: Tensor whose first dimension indexes atoms
            (presumably (num_atoms, num_elements) one-hot or logits — TODO
            confirm against the generator output).

    Returns:
        ``(symbols, atomic_numbers)``: two equal-length lists with one entry
        per row, holding symbol strings and integer atomic numbers.
    """
    atomic_numbers = [
        torch.argmax(atom_types_onehot[row]).item() + 1
        for row in range(atom_types_onehot.shape[0])
    ]
    symbols = [chemical_symbols[z] for z in atomic_numbers]
    return symbols, atomic_numbers


# Convert every generated ``.pt`` batch into one ``.cif`` file per crystal,
# keeping only crystals whose elements all belong to ``cup_e_set``.
save_dir = step2_dir
if not pt_dir or not save_dir:
    # Both paths default to '' in the config cell. Fail early with a clear
    # message instead of scanning "/device_4090_..." or calling makedirs('').
    raise ValueError('Set `pt_dir` and `step2_dir` before running this cell.')
os.makedirs(save_dir, exist_ok=True)
save_num = 0
save_set = set()  # reduced formulas written so far (for the optional dedup check)
for device in [0, 1, 2]:
    for i in range(160):
        print('device:', device, 'i:', i)
        gen_pt = f'{pt_dir}/device_4090_{device}_{i}_1000.pt'
        if not os.path.exists(gen_pt):
            # Files for a device are assumed to be numbered contiguously, so
            # the first missing index ends the scan for that device.
            break
        # map_location='cpu' lets files saved from GPU runs load on any
        # machine, and guarantees `.numpy()` below works on the tensors.
        gen_data = torch.load(gen_pt, map_location='cpu')
        num_atoms = [n.item() for n in gen_data['num_atoms']]
        # Atoms of all crystals are concatenated; walk them with a running offset.
        idx_start = 0
        for j, num in tqdm(enumerate(num_atoms), total=len(num_atoms)):
            idx_end = idx_start + num
            frac_coord = gen_data['frac_coords'][idx_start:idx_end].numpy()
            atom_types, _ = get_atoms_types(gen_data['atom_types'][idx_start:idx_end])
            idx_start = idx_end
            if not set(atom_types) <= cup_e_set:
                continue  # contains an element outside the allowed set
            lengths = gen_data['lengths'][j]
            angles = gen_data['angles'][j]
            crystal = Structure(
                lattice=Lattice.from_parameters(*(lengths.tolist() + angles.tolist())),
                species=atom_types,
                coords=frac_coord,
                coords_are_cartesian=False,
            )
            formula_str = crystal.composition.reduced_formula
            # Optional dedup against known formulas (requires the commented-out loader above):
            # if formula_str in mp_project_formula_set or formula_str in save_set:
            #     continue
            save_num += 1
            save_file_name = f'{save_dir}/formula_{formula_str}_device{device}_ith{i}_{j}.cif'
            CifWriter(crystal).write_file(save_file_name)
            save_set.add(formula_str)

# Counts every file in save_dir (including output of earlier runs), next to
# the number written by this run.
cif_files_num = len(os.listdir(save_dir))
print(cif_files_num, save_num)