In [1]:
import re
import json
import operator
import numpy as np

from pprint import pprint
from functools import reduce

from itertools import chain
from collections import Counter, defaultdict
from pymatgen import Structure, PeriodicSite, Lattice, Element
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from multiprocessing import Pool
from numpy.random import uniform

from crystallus import WyckoffCfgGenerator, CrystalGenerator, SpaceGroupDB

## Build DB

This step builds a database (DB) that holds the Wyckoff information for each space group. The DB is queried during Wyckoff position generation.

## Structure generator

To generate a valid structure in a given space group with a specific chemical composition, the following four steps are needed:

1. calculate possible Wyckoff configurations under a given space group for each chemical composition.
2. randomly generate fractional positions for each element, given a Wyckoff configuration calculated in step 1).
3. randomly generate a lattice for the space group used in step 1).
4. combine the results of steps 2) and 3) to obtain a valid structure.

The following cells implement steps 1) to 4).

### Wyckoff position generation

A Wyckoff position in the International Tables for Crystallography is written as a set of symbolic coordinate triplets, for example:

```
(x,y,1/2)	(-y,x-y,1/2)	(-x+y,-x,1/2)	(-x,-y,1/2)
(y,-x+y,1/2)	(x-y,x,1/2)	(y,x,1/2)	(x-y,-y,1/2)
(-x,-x+y,1/2)	(-y,-x,1/2)	(-x+y,y,1/2)	(x,x-y,1/2)
```

We need a simple parser that converts this notation into a calculator on the fly. The work splits into two steps:
1. parse the expression of a single coordinate (the `x`, `y` or `z` axis).
2. repeat step 1) for each axis.

In [2]:
class Coordinate:
    """Parse one symbolic Wyckoff coordinate into linear coefficients.

    An expression such as ``'-2x+1/2'`` or ``'x-y'`` is read as
    ``cx*x + cy*y + cz*z + const`` and returned as ``(cx, cy, cz, const)``.
    """

    # Token alternatives, tried left to right at each position:
    #   xyz       -- optional sign, optional integer multiplier, variable ('-2x')
    #   cons_frac -- optionally signed fraction ('1/2', '+3/4')
    #   cons      -- optionally signed integer ('1')
    # The denominator is mandatory so a stray '1/' can no longer reach
    # float('') and raise.
    patten = re.compile(
        r'(?P<xyz>[+-]?\d*[xyz])|(?P<cons_frac>[+-]?\d+/\d+)|(?P<cons>[+-]?\d+)',
        re.IGNORECASE,
    )

    def __call__(self, coordinates):
        """Return ``(x_coeff, y_coeff, z_coeff, const)`` for ``coordinates``.

        Parameters
        ----------
        coordinates : str
            A single coordinate expression, e.g. ``'x'``, ``'-x+y'``,
            ``'2x'`` or ``'1/4'``. Characters matching no token are ignored.

        Returns
        -------
        tuple
            Integer coefficients of x, y and z, followed by the constant
            offset (``0`` if absent, otherwise a float). Repeated terms are
            summed.
        """
        coeffs = {'x': 0, 'y': 0, 'z': 0}
        const = 0

        for match in self.patten.finditer(coordinates):
            term = match.group('xyz')
            if term is not None:
                term = term.lower()
                sign = -1 if term.startswith('-') else 1
                # Digits between the sign and the variable form the multiplier
                # (any integer, not just 2); no digits means 1.
                multiplier = term.lstrip('+-')[:-1]
                coeffs[term[-1]] += sign * int(multiplier or 1)
                continue

            frac = match.group('cons_frac')
            if frac is not None:
                numerator, denominator = frac.split('/')
                const += float(numerator) / float(denominator)
                continue

            const += float(match.group('cons'))

        return coeffs['x'], coeffs['y'], coeffs['z'], const

In [3]:
# Parse a single coordinate expression into (x, y, z, const) coefficients.
parse_coordinate = Coordinate()
parse_coordinate('-2x+1/2')

(-2, 0, 0, 0.5)

In [4]:
class Particle():
    """Parse a comma-separated coordinate triplet.

    The input is the inside of a Wyckoff position, e.g. ``'x,1/2,-z'``
    (parentheses already removed). Each comma-separated part is handed to the
    ``Coordinate`` parser stored on the instance, and the parsed results are
    returned as a list in the same order.
    """
    patten = re.compile(r',\s*')

    def __init__(self):
        self.Coordinate = Coordinate()

    def __call__(self, position):
        parse_one = self.Coordinate
        pieces = self.patten.split(position)
        return list(map(parse_one, pieces))
        

In [5]:
# Parse one position triplet; the surrounding parentheses are stripped first.
Particle()('(x,1/2,-z)'.strip('()'))

[(1, 0, 0, 0), (0, 0, 0, 0.5), (0, 0, -1, 0)]

In [6]:
def index(idx, id):
    """Return the 1-based position of ``id`` in ``idx``, or ``''`` if absent.

    ``idx`` is any object with an ``.index`` method that raises ``ValueError``
    for missing items (list, tuple, str).
    """
    try:
        zero_based = idx.index(id)
    except ValueError:
        return ''
    return zero_based + 1

In [7]:
# multiplicity_table[spacegroup_number][wyckoff_letter] -> first field of the
# letter's info list (used as the multiplicity by reformat_wy_info).
multiplicity_table = defaultdict(dict)

# JSON is UTF-8 by specification; open it explicitly as such instead of
# relying on the platform's locale encoding.
with open('wyckoffs.json', 'r', encoding='utf-8') as f:
    wyckoff_table = json.load(f)

# Assumes wyckoffs.json is a list ordered by space-group number (entry 0 is
# space group 1), each entry mapping letter -> info list — TODO confirm.
for i, sg in enumerate(wyckoff_table):
    for wy, info in sg.items():
        multiplicity_table[i + 1][wy] = info[0]

In [8]:
def reformat_wy_info(spg, data):
    """Expand per-letter atom counts into a sorted tuple of Wyckoff letters.

    Parameters
    ----------
    spg : int
        Space-group number; selects the row of the global ``multiplicity_table``.
    data : dict
        ``{wyckoff_letter: atom_count}`` for one element (as built by
        ``count_wy``).

    Returns
    -------
    tuple of str
        Each letter repeated ``int(atom_count / multiplicity)`` times,
        presumably once per occupied orbit, sorted alphabetically.
    """
    letters = [
        letter
        for letter, atoms in data.items()
        for _ in range(int(atoms / multiplicity_table[spg][letter]))
    ]
    return tuple(sorted(letters))

In [9]:
def count_wy(wys, elems):
    """Count Wyckoff-letter occurrences per chemical element.

    Parameters
    ----------
    wys : iterable of str
        Wyckoff letter of each site.
    elems : iterable
        Species of each site, aligned with ``wys``. Items may be plain
        symbol strings or any object exposing a ``symbol`` attribute
        (pymatgen ``Element``, and presumably oxidation-state ``Species``).

    Returns
    -------
    dict
        ``{symbol: {wyckoff_letter: count}}``. Pairing stops at the shorter
        of the two inputs (``zip`` semantics).
    """
    counter = defaultdict(lambda: defaultdict(int))

    for wy, e in zip(wys, elems):
        # Duck-type on ``.symbol`` rather than ``isinstance(e, Element)`` so
        # species objects that are not ``Element`` instances are still keyed
        # by their symbol string instead of by the object itself.
        symbol = getattr(e, 'symbol', e)
        counter[symbol][wy] += 1

    # Convert nested defaultdicts to plain dicts.
    return {k: dict(v) for k, v in counter.items()}

In [20]:
from pymatgen import Structure
from tqdm.notebook import tqdm
import gc
import pandas as pd

def parse_structure(structure_list):
    """Turn raw generator records into a DataFrame of pymatgen structures.

    Parameters
    ----------
    structure_list : iterable of dict
        Each record must provide the keys ``'lattice'``, ``'species'``,
        ``'coords'``, ``'volume'``, ``'wyckoff_letters'`` and
        ``'spacegroup_num'`` (in this notebook: the output of
        ``CrystalGenerator.gen_one``).

    Returns
    -------
    pandas.DataFrame
        One row per record with columns ``structure``, ``volume``,
        ``wy_letters``, ``spacegroup_num``, ``species``, ``composition``,
        ``formula``, ``reduced_formula`` and ``num_atoms``.

    Notes
    -----
    ``structure`` holds the primitive cell, whereas the composition columns
    are computed from the cell as given in the record, so they may describe
    more atoms than ``structure`` contains. NOTE(review): confirm which one
    downstream code expects.

    The cyclic garbage collector is switched off while parsing (presumably
    to speed up allocating many objects). Its previous state is restored in a
    ``finally`` clause, so an exception inside the loop cannot leave it
    disabled for the rest of the kernel session.
    """
    gc_was_enabled = gc.isenabled()
    gc.disable()
    try:
        structure_cans = []
        for s in tqdm(structure_list, desc='parsing structures'):
            struct = Structure(lattice=s['lattice'], species=s['species'], coords=s['coords'])
            composition = struct.composition
            structure_cans.append({
                'structure': struct.get_primitive_structure(),
                'volume': s['volume'],
                'wy_letters': s['wyckoff_letters'],
                'spacegroup_num': s['spacegroup_num'],
                'species': s['species'],
                'composition': dict(composition.as_dict()),
                'formula': composition.formula,
                'reduced_formula': composition.reduced_formula,
                'num_atoms': composition.num_atoms,
            })
    finally:
        if gc_was_enabled:
            gc.enable()
        gc.collect()

    return pd.DataFrame(structure_cans)

In [11]:
# Rhombohedral (R-centred) space groups. For these the structure is kept as
# loaded; all others are converted to the conventional standard cell.
# NOTE(review): presumably chosen to match the setting of the Wyckoff table —
# confirm.
RHOMBOHEDRAL_SPGS = (146, 148, 155, 160, 161, 166, 167)

true_struct = Structure.from_file('cifs/mp-9770.cif')
analyzer = SpacegroupAnalyzer(true_struct)
true_spg = analyzer.get_space_group_number()

if true_spg not in RHOMBOHEDRAL_SPGS:
    # `analyzer` was built from this same `true_struct`, so reuse it.
    true_struct = analyzer.get_conventional_standard_structure()
# Re-analyse: true_struct may have just been replaced.
sa = SpacegroupAnalyzer(true_struct).get_symmetry_dataset()
ss_symbol = sa['site_symmetry_symbols']
wy_data = sa['wyckoffs']

# Element -> sorted Wyckoff letters, then element-agnostic views of the pattern.
true_wy_count = count_wy(wy_data, true_struct.species)
true_wy_reformat = {k: reformat_wy_info(true_spg, v) for k, v in true_wy_count.items()}
true_wy_pattern = tuple(sorted(true_wy_reformat.values()))
# chain.from_iterable flattens in linear time and, unlike reduce without an
# initial value, does not raise on an empty pattern.
true_wy_pattern_loose = tuple(sorted(chain.from_iterable(true_wy_pattern)))
true_wy_unique = tuple(sorted(set(true_wy_pattern_loose)))

In [12]:
class WyckoffPos():
    """Recover free Wyckoff parameters from a site's fractional coordinates.

    For one space group, stores the first (representative) coordinate triplet
    of each Wyckoff letter, e.g. ``'x,1/2,-z'`` from ``'(x,1/2,-z),...'``.
    """
    # Split a position list at the commas that follow a closing parenthesis.
    patten = re.compile(r'(?<=\)),\s*')

    def __init__(self, spacegroup_num):
        wys = SpaceGroupDB.get(SpaceGroupDB.spacegroup_num == spacegroup_num).wyckoffs
        self.particle = Particle()
        # letter -> representative triplet with its outer parentheses removed
        self.wyckoff_pos = {wy.letter: self.patten.split(wy.positions)[0][1:-1] for wy in wys}

    def __call__(self, wy_letter, b):
        """Solve the representative position of ``wy_letter`` for ``b``.

        Parameters
        ----------
        wy_letter : str
            Wyckoff letter; a key of ``self.wyckoff_pos``.
        b : array_like, length 3
            Fractional coordinates of a site lying on that Wyckoff position.

        Returns
        -------
        numpy.ndarray
            A new float array. Entries for the free parameters that appear in
            the position (x -> 0, y -> 1, z -> 2) hold the solved values; the
            other entries are copied from ``b``. ``b`` is not modified.

        Raises
        ------
        numpy.linalg.LinAlgError
            If the selected coefficient sub-matrix is singular.
        """
        # Rows: coordinate expressions; columns: x, y, z coefficients, const.
        a = np.asarray(self.particle(self.wyckoff_pos[wy_letter]), dtype=float)
        # Work on a float copy so the caller's array is never mutated.
        coords = np.array(b, dtype=float)

        free = [j for j in range(3) if np.count_nonzero(a[:, j])]
        if free:
            # Assumes free variable j appears in row j of the representative
            # triplet (x in the first coordinate, etc.). Only those rows are
            # used: other rows may hold periodically wrapped copies of the
            # same parameter (e.g. -x stored as 1 - x). The constant offsets
            # are subtracted exactly once.
            rhs = coords[free] - a[free, -1]
            coords[free] = np.linalg.solve(a[np.ix_(free, free)], rhs)

        return coords

In [13]:
# Fractional coordinates of the first Ag site of `true_struct` (rounded to 4 dp).
b = np.array([0.9949, 0.0950, 0.7716])
wyckoff_solver = WyckoffPos(true_spg)
xyz = wyckoff_solver('a', b)
xyz

array([0.9949, 0.095 , 0.7716])

In [15]:
# Generator for the reference space group. The volume estimate is close to
# the reference cell volume, and Wyckoff position `a` is seeded with the
# coordinates solved above (zero variance).
cg = CrystalGenerator(
    spacegroup_num=true_spg,
    estimated_volume=1254.599172,
    estimated_variance=10,
    empirical_coords={'a': [tuple(xyz.tolist())]},
    empirical_coords_variance=0,
)
cg

CrystalGenerator(            
    spacegroup_num=33,            
    estimated_volume=1254.599172,            
    estimated_variance=10,            
    angle_range=(30.0, 150.0),            
    angle_tolerance=20.0,            
    max_attempts_number=5000,            
    empirical_coords=...,            
    empirical_coords_variance=0,            
    n_jobs=-1            
)

In [25]:
# Generate one Ag-only candidate on Wyckoff position `a` and show its structure.
ss = parse_structure([cg.gen_one(Ag=['a'])])
ss.structure[0]

HBox(children=(HTML(value='parsing structures'), FloatProgress(value=0.0, max=1.0), HTML(value='')))




Structure Summary
Lattice
    abc : 7.66624283213384 9.852965708621726 9.852965708621726
 angles : 79.80557030551672 67.10570345167348 67.10570345167348
 volume : 631.3188867798196
      A : -0.0 -0.0 7.66624283213384
      B : 6.320548652632781 6.514505214877541 3.8331214160669207
      C : -6.320548652632782 6.514505214877541 3.83312141606692
PeriodicSite: Ag (-0.0000, 1.2378, 5.9153) [0.6766, 0.0950, 0.0950]
PeriodicSite: Ag (-0.0000, 11.7913, 9.7484) [0.3666, 0.9050, 0.9050]

In [26]:
# A second random draw from the same generator, to compare with the previous one.
generated = cg.gen_one(Ag=['a'])
ss = parse_structure([generated])
ss['structure'][0]

HBox(children=(HTML(value='parsing structures'), FloatProgress(value=0.0, max=1.0), HTML(value='')))




Structure Summary
Lattice
    abc : 6.010835927601179 10.705249074328446 10.705249074328446
 angles : 82.94616563169949 73.69547795017682 73.69547795017682
 volume : 633.8390683893151
      A : 3.680575489468348e-16 -6.010835927601179 -3.680575489468348e-16
      B : -7.089745184856023 -3.0054179638005896 -7.436755597710979
      C : 7.089745184856023 -3.0054179638005896 -7.436755597710979
PeriodicSite: Ag (-0.0000, -5.4398, -3.3971) [0.6766, 0.2284, 0.2284]
PeriodicSite: Ag (0.0000, -6.5819, -10.8339) [0.3666, 0.7284, 0.7284]

In [157]:
# Per-element sorted Wyckoff letters of the reference structure.
true_wy_reformat

{'Ag': ('a', 'a', 'a', 'a', 'a', 'a', 'a', 'a'),
 'Ge': ('a',),
 'S': ('a', 'a', 'a', 'a', 'a', 'a')}

In [149]:
# Reference structure, converted to the conventional standard cell unless its
# space group is one of the listed rhombohedral groups.
true_struct

Structure Summary
Lattice
    abc : 7.62616066 10.74315461 15.31324474
 angles : 90.0 90.0 90.0
 volume : 1254.5991812902842
      A : 7.62616066 0.0 4.669676621026233e-16
      B : 1.7276310257318973e-15 10.74315461 6.578284952940816e-16
      C : 0.0 0.0 15.31324474
PeriodicSite: Ag (7.5876, 1.0208, 11.8156) [0.9949, 0.0950, 0.7716]
PeriodicSite: Ag (6.6626, 4.3578, 11.3221) [0.8737, 0.4056, 0.7394]
PeriodicSite: Ag (0.9635, 9.7294, 3.9911) [0.1263, 0.9056, 0.2606]
PeriodicSite: Ag (4.7766, 4.3578, 3.6655) [0.6263, 0.4056, 0.2394]
PeriodicSite: Ag (3.9392, 1.2013, 7.4041) [0.5165, 0.1118, 0.4835]
PeriodicSite: Ag (3.6869, 6.5729, 7.9091) [0.4835, 0.6118, 0.5165]
PeriodicSite: Ag (7.5000, 1.2013, 15.0607) [0.9835, 0.1118, 0.9835]
PeriodicSite: Ag (0.1261, 6.5729, 0.2525) [0.0165, 0.6118, 0.0165]
PeriodicSite: Ag (6.7044, 2.0538, 3.5062) [0.8791, 0.1912, 0.2290]
PeriodicSite: Ag (0.9218, 7.4253, 11.8070) [0.1209, 0.6912, 0.7710]
PeriodicSite: Ag (4.7349, 2.0538, 11.1629) [0.6209, 0.191

In [148]:
# Re-analyse the reference structure and group its sites into
# symmetry-equivalent sets; `s` is consumed by the site-table cell below.
a = SpacegroupAnalyzer(true_struct)
s = a.get_symmetrized_structure()

Error: Pip module debugpy is required for debugging cells. You will need to install it to debug cells.

In [41]:
def to_s(x):
    """Render ``x`` as fixed-point text with six digits after the decimal point."""
    template = "%0.6f"
    return template % x

In [43]:
# One row per set of symmetry-equivalent sites: index, species, fractional
# coordinates of its first member, and the set's Wyckoff symbol.
for i, sites in enumerate(s.equivalent_sites):
    site = sites[0]
    row = [
        str(i),
        site.species_string,
        *(to_s(j) for j in site.frac_coords),
        s.wyckoff_symbols[i],
    ]
    print(row)

['0', 'Ag', '0.994945', '0.095017', '0.771593', '4a']
['1', 'Ag', '0.873654', '0.405638', '0.739368', '4a']
['2', 'Ag', '0.516540', '0.111819', '0.483510', '4a']
['3', 'Ag', '0.879130', '0.191169', '0.228967', '4a']
['4', 'Ag', '0.610870', '0.190429', '0.087637', '4a']
['5', 'Ag', '0.566539', '0.522052', '0.066411', '4a']
['6', 'Ag', '0.733839', '0.335581', '0.436881', '4a']
['7', 'Ag', '0.770548', '0.869876', '0.370507', '4a']
['8', 'Ge', '0.733802', '0.850394', '0.123830', '4a']
['9', 'S', '0.765449', '0.105633', '0.376252', '4a']
['10', 'S', '0.812687', '0.367857', '0.112236', '4a']
['11', 'S', '0.725622', '0.727762', '0.241888', '4a']
['12', 'S', '0.973621', '0.973780', '0.126849', '4a']
['13', 'S', '0.769252', '0.727800', '0.505203', '4a']
['14', 'S', '0.996893', '0.475094', '0.379234', '4a']
