In [2]:
import torch, json
from pyxtal.symmetry import Group
from data_utils import get_site_symmetry_binary_repr

# Mapping loaded from cluster_sites.json; it is indexed below by a Wyckoff
# position's `site_symm` value. NOTE(review): absolute, user-specific path —
# consider making this configurable so the notebook runs on other machines.
CLUSTER_SITES_PATH = '/home/mila/s/siba-smarak.panigrahi/DiffCSP/cluster_sites.json'
# Use a context manager so the file handle is closed after parsing.
with open(CLUSTER_SITES_PATH, 'r') as cluster_sites_file:
    cluster_sites = json.load(cluster_sites_file)

In [3]:
# map operations to spacegroups
# For each space group number 1..230, map the tuple form of each Wyckoff
# position's binary site-symmetry representation to that position's symmetry
# operations, stacked into a single float tensor of affine matrices.
spacegroup_ops_mapper = {}
for spacegroup in range(1, 231):
    group = Group(spacegroup)
    ops_by_site_symm = {}
    for wp in group.Wyckoff_positions:
        # Called for its side effect only (return value unused); presumably it
        # populates `wp.site_symm`, which is read just below — TODO confirm in pyxtal.
        wp.get_site_symmetry()
        site_symm_key = tuple(
            get_site_symmetry_binary_repr(
                cluster_sites[wp.site_symm], label=str(spacegroup)
            ).tolist()
        )
        # NOTE(review): two Wyckoff positions of the same group with the same
        # site symmetry produce the same key; the later one overwrites the
        # earlier one here. Confirm that keeping only the last is intended.
        ops_by_site_symm[site_symm_key] = torch.stack(
            [torch.from_numpy(op.affine_matrix) for op in wp.ops]
        ).float()
    spacegroup_ops_mapper[spacegroup] = ops_by_site_symm


# save the mapper using torch
torch.save(spacegroup_ops_mapper, 'spacegroup_ops_mapper.pt')

In [2]:
# collect all wyckoff positions available
# Gather every distinct label returned by `get_label()` over the Wyckoff
# positions of all space groups 1..230. A set gives O(1) de-duplication
# instead of scanning a growing list for each position.
unique_wyckoff_labels = set()
for spacegroup in range(1, 231):
    group = Group(spacegroup)
    unique_wyckoff_labels.update(wp.get_label() for wp in group.Wyckoff_positions)

# sort all the wyckoff labels (sorted() returns a list, as before)
wyckoff_labels = sorted(unique_wyckoff_labels)

# save all the wyckoff labels
torch.save(wyckoff_labels, 'wyckoff_labels.pt')

In [3]:
# create masks for spacegroups depending on which wyckoff positions are present
# Depends on `wyckoff_labels` from the previous cell. For each space group
# 1..230, entry j of its mask is 1 if the group has a Wyckoff position whose
# label is wyckoff_labels[j], otherwise 0 (torch.zeros default dtype).
# Build the label -> column lookup once instead of calling list.index per
# position; labels are unique because the previous cell de-duplicates them.
label_to_index = {label: idx for idx, label in enumerate(wyckoff_labels)}

spacegroup_wyckoff_masks = {}
for spacegroup in range(1, 231):
    group = Group(spacegroup)
    mask = torch.zeros(len(wyckoff_labels))
    for wp in group.Wyckoff_positions:
        mask[label_to_index[wp.get_label()]] = 1
    spacegroup_wyckoff_masks[spacegroup] = mask

# save the masks using torch
torch.save(spacegroup_wyckoff_masks, 'spacegroup_wyckoff_masks.pt')