In [1]:
!pip install pymatgen
!pip install scipy

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
# Mount Google Drive: the working directory and model weights used below live under it.
import google.colab.drive as drive

drive.mount(mountpoint='/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


Change into the working directory on Drive that holds the local modules (`model`, `simple_dist`) and the data files used below.

In [3]:
import os
# Every relative path used later ('data/...', the output CIF folders) resolves against this
# directory; the local `model` / `simple_dist` modules are presumably imported from here too.
os.chdir('/content/gdrive/My Drive/Colab Notebooks/')

Import the packages and local modules used to generate materials.

In [4]:
import warnings
warnings.filterwarnings("ignore")  # silences every warning for the rest of the session; comment out when debugging
import json
import glob
try:
    # pickle5 backports pickle protocol 5 to Python < 3.8; from 3.8 on the stdlib module reads it natively.
    import pickle5 as pickle
except ImportError:
    import pickle
import shutil
import argparse
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from collections import Counter
from model import Generator
from simple_dist import sp_lookup
np.set_printoptions(precision=4,suppress=True)

Set the generation parameters: batch size, number of samples, and the output folder for the raw CIFs.

In [5]:
batch_size = 32              # structures per generator forward pass
n_samples = 500              # number of (space group, elements) inputs to sample
matdir = 'ternary_gen_cifs'  # output folder for the raw generated CIFs

# Fix the RNG state so the sampled inputs and latent vectors are reproducible on re-run.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

Define the function that prepares inputs to the generator: randomly sampled space groups and element combinations.

In [6]:
def fake_generator(n_samples, n_spacegroup, sp_prob, ele_num,
                   element_file='data/elements_id_NoPo.json', rng=None):
    """Sample random generator inputs: space-group labels and element combinations.

    Args:
        n_samples: number of input samples to draw.
        n_spacegroup: number of space-group classes; labels are drawn from range(n_spacegroup).
        sp_prob: sampling probability of each space-group label (length n_spacegroup).
        ele_num: number of distinct elements per sample (3 for ternary materials).
        element_file: JSON file mapping element symbol -> element id.
        rng: optional numpy random source (np.random.Generator or RandomState) for
            reproducible sampling; defaults to numpy's global RNG.

    Returns:
        (label_sp, label_elements): space-group labels of shape (n_samples,) and element
        ids of shape (n_samples, ele_num); the ids within one row are distinct.
    """
    if rng is None:
        rng = np.random

    # Same draw order as before (space groups first), so the global-RNG stream is unchanged.
    label_sp = rng.choice(n_spacegroup, n_samples, p=sp_prob)

    with open(element_file, 'r') as f:
        element_ids = list(json.load(f).values())

    # Elements must be distinct within a sample, hence replace=False per row.
    # The explicit reshape keeps the 2-D layout even when n_samples == 0.
    label_elements = np.array(
        [rng.choice(element_ids, ele_num, replace=False) for _ in range(n_samples)]
    ).reshape(n_samples, ele_num)
    return label_sp, label_elements

In [7]:
# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print("Using {} device".format(device))

Using cpu device


In [8]:
# Load the auxiliary tables. AUX_DATA is indexed positionally; judging by how it is used below:
# [0]/[1] lattice-length / angle patterns, [2] length scale (lr), [3] coordinate scale (cr),
# [-2] space-group -> id map, [-1] space-group sampling probabilities.
with open('data/paras.pickle', 'rb') as handle:
    AUX_DATA = pickle.load(handle)

n_spacegroup = len(AUX_DATA[-2])

# Per-element feature table, moved to the compute device once.
atom_embedding = torch.Tensor(np.load('data/elements_features.npy')).to(device)

box_abc_pat, box_angle_pat = (torch.Tensor(AUX_DATA[k]).to(device) for k in (0, 1))
cr = torch.Tensor(AUX_DATA[3]).to(device)
lr = torch.Tensor([AUX_DATA[2]]).to(device)

Load the trained generator that produces the crystal structures.

In [9]:
# Build the generator: input width from the element-feature table; 128 is presumably the
# latent size (the generation loop samples z with 128 features).
netG = Generator(atom_embedding.shape[1], 128).to(device)
# Path relative to the working directory set by os.chdir above, like every other data path;
# tensors are mapped straight onto the compute device.
checkpoint = torch.load('models/frac12/generator_weights.pth', map_location=device)
netG.load_state_dict(checkpoint['state_dict'])
netG.eval()  # inference mode (affects dropout / batch-norm layers, if the model has any)
spinfo = sp_lookup(device, AUX_DATA[-2])

In [10]:
# Sample n_samples generator inputs with 3 distinct elements each (ternary materials only).
random_sp, random_ele = fake_generator(n_samples, n_spacegroup, AUX_DATA[-1], 3)

# Invert the space-group -> id table so generated ids can be written back as symbols.
sp2id = AUX_DATA[-2]
id2sp = {sp_index: sp_symbol for sp_symbol, sp_index in sp2id.items()}

In [11]:
# Map element id -> element symbol (the inverse of the JSON symbol -> id table).
with open('data/elements_id_NoPo.json', 'r') as f:
    e_d = json.load(f)
re = dict(zip(e_d.values(), e_d.keys()))  # NOTE: this name shadows the stdlib `re` module

In [12]:
# Start from an empty folder for the raw generated CIFs. shutil/os calls avoid spawning a
# shell and raise on failure instead of silently returning a non-zero exit code.
if os.path.isdir(matdir):
    shutil.rmtree(matdir)
os.makedirs(matdir)

0

# Generate CIFs using PGCGM

In [13]:

# Run the generator in batches and write one CIF per sample and per coordinate channel
# (`rot` 0-2) into `matdir`. Each file name starts with the requested space group
# ('/' replaced by '#') so the next step can verify the symmetry.


def _render_cif(template, sg_symbol, sym_ops, lengths, angles, elements, site_coords, id_to_symbol):
    """Fill the CIF template for one generated structure.

    Placeholders are substituted in a fixed order (space group, cell lengths, cell angles,
    symmetry operations); then one atom-site row is appended per element.

    Args:
        template: CIF template text containing the placeholder tokens.
        sg_symbol: space-group symbol written in place of 'SYMMETRY-SG'.
        sym_ops: text that replaces the 'TRANSFORMATION' line.
        lengths, angles: the three cell lengths and the three cell angles.
        elements: element ids, one per atom-site row.
        site_coords: row m holds the three coordinates of element m's site.
        id_to_symbol: mapping element id -> element symbol.

    Returns:
        The complete CIF text, terminated by a blank line.
    """
    text = template.replace('SYMMETRY-SG', sg_symbol)
    placeholders = (('LAL', lengths[0]), ('LBL', lengths[1]), ('LCL', lengths[2]),
                    ('DEGREE1', angles[0]), ('DEGREE2', angles[1]), ('DEGREE3', angles[2]))
    for token, value in placeholders:
        text = text.replace(token, str(value))
    text = text.replace('TRANSFORMATION\n', sym_ops)

    rows = []
    for m, element_id in enumerate(elements):
        symbol = id_to_symbol[element_id]
        xyz = [str(site_coords[m][k]) for k in range(3)]
        rows.append('  '.join(['', symbol, symbol + str(m)] + xyz + ['1']) + '\n')
    return text + ''.join(rows) + '\n'


with open('data/cif-template.txt', 'r') as f:
    cif_template = f.read()  # read once; every structure starts from this text

n_batches = (n_samples + batch_size - 1) // batch_size  # ceil division: no empty trailing batch
count_sample = 0
with torch.no_grad():
    for i in range(n_batches):
        start_idx = i * batch_size
        end_idx = min((i + 1) * batch_size, n_samples)

        # NOTE(review): sp_id stays on the CPU (as before); confirm this is intended on CUDA.
        sp_id = torch.Tensor(random_sp[start_idx:end_idx]).type(torch.int64)

        ele = torch.Tensor(random_ele[start_idx:end_idx]).to(device)
        ele_ids = ele.type(torch.int64)
        e = atom_embedding[ele_ids]

        z = torch.normal(0.0, 1.0, size=(end_idx - start_idx, 128)).to(device)
        gen_coords, box_abc = netG(spinfo.symm_op_collection[sp_id], torch.transpose(e, 1, 2), z)

        # Undo the coordinate scaling (x * cr + cr) for each of the three output channels,
        # rounded to 2 decimals as written to the CIF.
        arr_coords = [np.round((gen_coords[:, rot, :, :] * cr + cr).cpu().numpy(), 2)
                      for rot in range(3)]

        # Map the length output back through exp(x * lr + lr), then combine it with the
        # space group's length pattern; angles come straight from the per-space-group table.
        box_abc = torch.exp(box_abc * lr + lr)
        box_abc = torch.einsum('bl,bls->bs', box_abc, box_abc_pat[sp_id])
        arr_lengths = box_abc.cpu().numpy()
        arr_angles = box_angle_pat[sp_id].cpu().numpy()
        arr_ele = ele_ids.cpu().numpy()  # integer ids: exact keys for the id -> symbol map
        arr_spid = sp_id.cpu().numpy()

        for j in range(arr_coords[0].shape[0]):
            sg_symbol = id2sp[arr_spid[j]]
            sg_tag = sg_symbol.replace('/', '#')  # '/' cannot appear in a file name
            # The symmetry operations depend only on the space group: read them once per sample.
            with open('data/symmetry-equiv/%s.txt' % sg_tag, 'r') as f:
                sym_ops = f.read()

            for rot in range(3):
                cif_text = _render_cif(cif_template, sg_symbol, sym_ops, arr_lengths[j],
                                       arr_angles[j], arr_ele[j], arr_coords[rot][j], re)
                with open('%s/%s---%d_%d.cif' % (matdir, sg_tag, count_sample, rot), 'w') as f:
                    f.write(cif_text)
            count_sample += 1

# Get pymatgen-readable CIFs whose symmetry matches the requested space group

In [14]:
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.core.structure import Structure
from pymatgen.io.cif import CifWriter
from pymatgen.io.vasp import Poscar
from pymatgen.core.lattice import Lattice
from pymatgen.core.periodic_table import Element

In [15]:
gen_cifs = os.listdir(matdir)

# Fresh folder for CIFs whose detected space group matches the requested one.
# shutil/os calls raise on failure instead of silently returning a shell exit code.
if os.path.isdir('ternary_symm_cifs'):
    shutil.rmtree('ternary_symm_cifs')
os.makedirs('ternary_symm_cifs')

0

In [16]:
# Re-read each raw CIF with pymatgen and keep it only if the space group detected at
# symprec=0.1 equals the requested one, encoded in the file name as
# '<space group, / -> #>---<sample>_<rot>.cif'.
n_kept = 0
failed = []
for cif in gen_cifs:
    try:
        sg_part, sample_part = cif.split('---')[:2]
        sp = sg_part.replace('#', '/')
        i = sample_part.replace('.cif', '')

        crystal = Structure.from_file(os.path.join(matdir, cif))
        formula = crystal.composition.reduced_formula
        sg_info = crystal.get_space_group_info(symprec=0.1)
        if sp == sg_info[0]:
            crystal.to(fmt='cif',
                       filename='ternary_symm_cifs/%d-%s-%d-%s.cif' % (len(crystal), formula, sg_info[1], i),
                       symprec=0.1)
            n_kept += 1
    except Exception as e:
        # Files that cannot be parsed or analysed are skipped; keep the reason instead of
        # discarding it silently (inspect `failed` for details).
        failed.append((cif, repr(e)))

print('kept %d of %d generated CIFs; %d could not be processed' % (n_kept, len(gen_cifs), len(failed)))

# Cluster and merge sites in the generated CIFs

In [17]:
from pymatgen.core.composition import Composition
from pymatgen.core.structure import Structure
from scipy.cluster.hierarchy import fcluster, linkage
from scipy.spatial.distance import squareform
from pymatgen.core.sites import PeriodicSite
from collections import defaultdict,Counter

In [18]:
class StrucPost(Structure):
    """pymatgen Structure with post-processing helpers for generated crystals.

    `merge_same_element` collapses near-duplicate sites of the same species in place;
    `check_dist` tests whether any two atoms are too close together.
    """

    def merge_same_element(self, prop = 0.6):
        """Merge same-species sites that lie within a radius-based cutoff (in place).

        Sites are grouped by species, clustered with scipy's hierarchical clustering
        (default single linkage) on their periodic distances, and each cluster is replaced
        by one site at the running minimum-image average of its members' fractional
        coordinates. Site properties are not carried over to the merged sites.

        Args:
            prop: cutoff as a fraction of twice the species' pymatgen atomic radius.
        """
        mapping = defaultdict(list)
        for site in self.sites:
            mapping[site.species.reduced_formula].append(site)

        new_sites = []
        elements = list(mapping.keys())
        for symbol in elements:
            holes = mapping[symbol]
            if len(holes) == 1:
                new_sites.append(holes[0])
                continue
            frac_coords = np.array([site.frac_coords for site in holes])
            d = self.lattice.get_all_distances(frac_coords, frac_coords)
            np.fill_diagonal(d, 0)
            # squareform needs an exactly symmetric matrix with a zero diagonal.
            cutoff = float(holes[0].specie.atomic_radius) * 2 * prop
            clusters = fcluster(linkage(squareform((d + d.T) / 2)), cutoff, "distance")

            for c in np.unique(clusters):
                inds = np.where(clusters == c)[0]
                species = holes[inds[0]].species
                coords = holes[inds[0]].frac_coords
                for n, i in enumerate(inds[1:]):
                    # Minimum-image offset, folded into a running mean over n + 2 members.
                    offset = holes[i].frac_coords - coords
                    coords = coords + ((offset - np.round(offset)) / (n + 2)).astype(coords.dtype)
                new_sites.append(PeriodicSite(species, coords, self.lattice))
        self._sites = new_sites

    def check_dist(self, prop=0.75):
        """Return True if every pair of sites is farther apart than `prop` times the sum
        of the two sites' atomic radii (trivially True for fewer than two sites)."""
        iu = np.triu_indices(len(self), 1)  # each unordered pair once, row-major order
        d = self.distance_matrix[iu]

        # Pairwise radius sums built by broadcasting instead of a Python double loop.
        radii = np.array([site.specie.atomic_radius for site in self])
        atom_radius = (radii[:, None] + radii[None, :])[iu] * prop

        return np.all(d > atom_radius)

In [19]:
# Fresh folder for the merged structures that pass the final symmetry and distance checks.
# shutil/os calls raise on failure instead of silently returning a shell exit code.
if os.path.isdir('ternary_final_cifs'):
    shutil.rmtree('ternary_final_cifs')
os.makedirs('ternary_final_cifs')

0

In [20]:
symm_cifs = os.listdir(path='ternary_symm_cifs')  # CIFs that passed the space-group check

In [21]:
# Merge near-duplicate same-species sites, then keep a structure only if its space-group
# number is unchanged and no two atoms are too close. File names follow
# '<n_sites>-<formula>-<sg number>-<sample>.cif'; the first two fields are refreshed.
n_final = 0
for cif in symm_cifs:
    try:
        crystal = StrucPost.from_file(os.path.join('ternary_symm_cifs', cif))
        crystal.merge_same_element(0.5)
        if len(crystal) > 100:
            # Still a very large cell after the first pass: merge again with a wider cutoff.
            crystal.merge_same_element(1.5)
        n = len(crystal)

        formula = crystal.composition.reduced_formula
        sg_info = crystal.get_space_group_info(symprec=0.1)
        name_parts = cif.split('-')
        name_parts[0] = str(n)
        name_parts[1] = formula
        if name_parts[2] == str(sg_info[1]) and crystal.check_dist(0.75):
            crystal.to(fmt='cif',
                       filename=os.path.join('ternary_final_cifs', '-'.join(name_parts)),
                       symprec=0.1)
            n_final += 1
    except Exception as e:
        # Best effort: report which file failed and why, then continue with the rest.
        print('%s: %s' % (cif, e))

print('wrote %d of %d CIFs to ternary_final_cifs/' % (n_final, len(symm_cifs)))