In [1]:
import sys
import os
from pathlib import Path
sys.path.append(Path.cwd().parent.as_posix())
from util import *
from visualizer import visualizer
vis = visualizer()

from mattergen.self_guidance.wyckoff_dataset_prim import *
from mattergen.self_guidance.wyckoff_sampler_prim_new import *
from mattergen.common.data.collate import collate

MODELS_PROJECT_ROOT: /home/holywater2/crystal_gen/mattergen/mattergen


In [9]:
def _wyckoff_images_to_prim_frac(frac_conv, wyckoff_batch, wyckoff_ops, perm_w, conv_lat_w, prim_lat_inv_w):
    """Map per-atom conventional fractional coordinates through the Wyckoff operations.

    Every 4x4 matrix in ``wyckoff_ops`` is applied (on homogeneous coordinates) to the
    atom selected by the matching entry of ``wyckoff_batch``.  The resulting images are
    wrapped into [0, 1), permuted by ``perm_w``, converted to Cartesian coordinates with
    ``conv_lat_w`` and finally expressed in fractional coordinates of the primitive
    lattice via ``prim_lat_inv_w``.

    Args:
        frac_conv: ``(N, 3)`` conventional-cell fractional coordinates, one row per atom.
        wyckoff_batch: ``(M,)`` flat atom index of each Wyckoff image.
        wyckoff_ops: ``(M, 4, 4)`` affine operations, one per image.
        perm_w: ``(M, 3, 3)`` axis permutation applied to each image.
        conv_lat_w: ``(M, 3, 3)`` conventional lattice (rows are lattice vectors) per image.
        prim_lat_inv_w: ``(M, 3, 3)`` inverse primitive lattice per image.

    Returns:
        ``(M, 3)`` primitive-cell fractional coordinates (not wrapped into [0, 1)).
    """
    images = frac_conv[wyckoff_batch]
    homogeneous = torch.cat(
        [images, torch.ones(images.shape[0], 1, device=images.device)],
        dim=1,
    )
    frac_proj = torch.einsum("bij,bj->bi", wyckoff_ops, homogeneous)[:, :3] % 1.0
    frac_proj = torch.einsum("bij,bj->bi", perm_w, frac_proj)
    cart = torch.einsum("bi,bij->bj", frac_proj, conv_lat_w)
    return torch.einsum("bi,bij->bj", cart, prim_lat_inv_w)


def _first_unique_image_mask(prim_frac, same_site, eps):
    """Mask out Wyckoff images that duplicate an earlier image of the same atom.

    Two images coincide when ``same_site`` marks them as images of the same atom and
    their primitive fractional coordinates agree modulo 1 within ``eps`` on every axis.
    An image is dropped if any lower-index image coincides with it.

    Args:
        prim_frac: ``(M, 3)`` primitive-cell fractional coordinates of the images.
        same_site: ``(M, M)`` boolean matrix, ``True`` where two images share an atom.
        eps: absolute tolerance (``isclose`` against zero, so ``rtol`` has no effect).

    Returns:
        ``(M,)`` boolean mask, ``True`` for images to keep.
    """
    wrapped = prim_frac % 1.0
    diff = wrapped.unsqueeze(1) - wrapped.unsqueeze(0)
    diff = diff - torch.round(diff)  # wrap pairwise differences into [-0.5, 0.5]
    coincide = torch.all(
        torch.isclose(diff, torch.zeros_like(diff), rtol=eps, atol=eps),
        dim=-1,
    )
    # Upper triangle only: column j is a duplicate if some i < j coincides with it.
    duplicate = (coincide & same_site).triu(diagonal=1).any(dim=0)
    return ~duplicate


def _project_to_space_group(batch, pos=None, cell=None, max_tries=10, eps=5e-4):
    """Symmetrize the positions and lattice of ``batch`` onto each structure's space group.

    The lattice (given or random) is mapped to the conventional cell with
    ``batch.prim_to_conv``, projected with ``CrystalFamily.proj_k_to_spacegroup`` for
    ``batch.space_groups`` and mapped back with ``batch.conv_to_prim``.  Each atom is
    expanded through its Wyckoff operations (``batch.wyckoff_ops``); images that
    coincide in the primitive cell are dropped so that exactly ``batch.pos.shape[0]``
    positions remain.

    Args:
        batch: collated batch carrying the Wyckoff fields used below
            (``wyckoff_bat``, ``wyckoff_bat_len``, ``wyckoff_ops``, ``num_atoms``,
            ``prim_to_conv``, ``conv_to_prim``, ``space_groups``).
        pos: fractional positions to start from; ``torch.rand_like(batch.pos)`` when None.
        cell: ``(B, 3, 3)`` lattice to start from; one ``initialize_random_lattice()``
            per structure when None.
        max_tries: number of random probes used to identify duplicate images.
        eps: absolute tolerance on fractional coordinates when comparing images.

    Returns:
        ``(pos_prim_frac, prim_lattice)``: fractional coordinates wrapped into [0, 1)
        w.r.t. the projected primitive lattice, and that lattice, ``(B, 3, 3)``.

    Raises:
        RuntimeError: if no probe leaves exactly one image per atom.
    """
    device = batch.pos.device
    cf = CrystalFamily()
    cf.set_device(device)
    lat_perm, _, _ = get_latttice_permutations(device=device)

    if pos is None:
        pos = torch.rand_like(batch.pos)
    batch = batch.replace(pos=pos)
    if cell is None:
        # One random lattice per structure (previously only batches of one worked).
        n_graphs = batch.prim_to_conv.shape[0]
        cell = torch.stack(
            [
                torch.tensor(initialize_random_lattice().matrix, dtype=torch.float32)
                for _ in range(n_graphs)
            ]
        ).to(device)
    batch = batch.replace(cell=cell)

    # Shift per-structure atom indices by the atom count of the preceding structures
    # so they index into the flat batch.pos tensor.
    wyckoff_batch = batch.wyckoff_bat.clone()
    start, atom_offset = 0, 0
    for n_images, n_atoms in zip(batch.wyckoff_bat_len, batch.num_atoms):
        wyckoff_batch[start : start + n_images] = wyckoff_batch[start : start + n_images] + atom_offset
        start += n_images
        atom_offset += n_atoms

    # Project the conventional lattice onto the space group, then go back to primitive.
    # The projected lattice itself is not axis-permuted; only the image coordinates are.
    conv_lat = torch.bmm(batch.prim_to_conv, batch.cell)
    conv_lat_vec = cf.m2v(cf.de_so3(conv_lat))
    conv_lat_vec_proj = cf.proj_k_to_spacegroup(conv_lat_vec, batch.space_groups)
    conv_lat_proj = cf.v2m(conv_lat_vec_proj)
    prim_lat_proj = torch.bmm(batch.conv_to_prim, conv_lat_proj)
    prim_lat_inv = torch.inverse(prim_lat_proj)

    # Axis permutation looked up by space group and by the ordering of the projected
    # lattice-vector lengths (longest first).
    rank = torch.argsort(-torch.norm(conv_lat_proj, dim=-1), dim=-1)
    perm_key = torch.cat([batch.space_groups.unsqueeze(-1), rank], dim=-1)
    perm = lat_perm[perm_key[:, 0], perm_key[:, 1], perm_key[:, 2], perm_key[:, 3]]

    # Per-image views of the per-structure tensors (loop-invariant, gathered once).
    site_graph = batch.batch[wyckoff_batch]
    perm_w = perm[site_graph]
    conv_lat_w = conv_lat_proj[site_graph]
    prim_lat_inv_w = prim_lat_inv[site_graph]

    pos_cart = torch.einsum("bi,bij->bj", batch.pos, batch.cell[batch.batch])
    pos_frac_conv = torch.einsum(
        "bi,bij->bj", pos_cart, torch.inverse(conv_lat_proj)[batch.batch]
    )
    pos_prim_frac_proj_all = _wyckoff_images_to_prim_frac(
        pos_frac_conv, wyckoff_batch, batch.wyckoff_ops, perm_w, conv_lat_w, prim_lat_inv_w
    )

    # Decide which images are redundant in the primitive cell by pushing random probe
    # coordinates through the same operations (presumably so that only coincidences
    # forced by the operations, not by the particular input, are detected).  Accept
    # the first probe whose mask keeps exactly one image per atom.
    same_site = wyckoff_batch.unsqueeze(0) == wyckoff_batch.unsqueeze(1)
    n_atoms_total = batch.pos.shape[0]
    keep = None
    for _ in range(max_tries):
        probe = torch.rand_like(pos_frac_conv)
        probe_prim = _wyckoff_images_to_prim_frac(
            probe, wyckoff_batch, batch.wyckoff_ops, perm_w, conv_lat_w, prim_lat_inv_w
        )
        mask = _first_unique_image_mask(probe_prim, same_site, eps)
        if int(mask.sum()) == n_atoms_total:
            keep = mask
            break
    if keep is None:
        raise RuntimeError(
            f"could not isolate one primitive image per atom after {max_tries} random probes"
        )

    return pos_prim_frac_proj_all[keep] % 1.0, prim_lat_proj

In [2]:
# Every space-group number 1..230. Not used below: the dataset is currently
# restricted to a single space group; pass this list instead for a full sweep.
space_group_input = list(range(1, 231))
info_path = "/home/holywater2/crystal_gen/mattergen/_my_scripts/space_group_info/prim/mp_20_val"

# Build the evaluation dataset from the space-group info files under `info_path`.
dataset = WyckoffDataset.from_space_group_list(
    space_group_infos_path=info_path,
    num_samples=100,
    space_groups=[38],
)

In [8]:
# Project the first sample once and show the resulting primitive fractional
# coordinates and lattice (`pos = ` alone on its own line was a SyntaxError).
pos, prim_lattice = _project_to_space_group(collate(dataset[0]))
pos, prim_lattice

(tensor([[1.9996e-02, 9.8000e-01, 5.0000e-01],
         [9.9995e-01, 5.2892e-05, 5.0000e-01],
         [5.2525e-01, 4.7475e-01, 5.0000e-01],
         [6.7178e-01, 3.2822e-01, 5.0000e-01],
         [7.2675e-01, 2.7325e-01, 0.0000e+00],
         [7.7708e-01, 2.2292e-01, 0.0000e+00],
         [2.8999e-01, 7.1001e-01, 0.0000e+00],
         [8.1521e-01, 1.8479e-01, 0.0000e+00],
         [1.5466e-01, 8.4534e-01, 0.0000e+00]]),
 tensor([[[ 4.9717, -8.2792,  0.0000],
          [ 4.9717,  8.2792,  0.0000],
          [ 0.0000,  0.0000,  7.4843]]]))

In [10]:
# Display the collated batch for the first sample; its repr lists the shape of every field.
collate(dataset[0])

ChemGraphBatch(pos=[9, 3], anchors=[9], anchors_len=[1], wyckoff_ops=[18, 4, 4], wyckoff_ops_pinv=[18, 3, 3], wyckoff_bat=[18], wyckoff_bat_len=[1], space_groups=[1], num_atoms=[1], conv_to_prim=[1, 3, 3], prim_to_conv=[1, 3, 3], species=[9], uniques=[9], uniques_len=[1], num_nodes=9, atomic_numbers=[9], cell=[1, 3, 3], batch=[9], ptr=[2])

In [29]:
# Evaluate the projection: for every sample, project (retrying when
# SpacegroupAnalyzer raises) and compare the detected space group with the target.
res = []  # detected space group per dataset index; None when analysis never succeeded
success = []  # indices whose detected space group equals the target
failed = []  # indices that did not match; also contains every index in `error`
error = []  # indices for which SpacegroupAnalyzer raised on every attempt
torch.set_num_threads(24)
MAX_ATTEMPTS = 3  # projections tried per sample before marking it as an error
# Wrapping the dataset (not the enumerate) lets tqdm read len(dataset), when defined, for a total.
for idx, data in enumerate(tqdm(dataset)):
    data_batch = collate([data])
    spga = None
    for _attempt in range(MAX_ATTEMPTS):
        coords, lattice = _project_to_space_group(data_batch)
        structure = Structure(lattice=lattice, coords=coords, species=data_batch.species)
        try:
            spga = SpacegroupAnalyzer(structure)
        except Exception as exc:  # not a bare except: KeyboardInterrupt must still stop the loop
            print(f"Idx: {idx} failed, spg {data.space_groups} ({exc!r})")
            continue
        break
    if spga is None:
        res.append(None)  # keep res[idx] aligned with dataset[idx]
        error.append(idx)
        failed.append(idx)
        continue
    sgn = spga.get_space_group_number()
    res.append(sgn)
    if sgn == data.space_groups.item():
        success.append(idx)
    else:
        failed.append(idx)
        print(f"Idx: {idx} failed, spg {data.space_groups} != {sgn}")
print("Len of success: ", len(success))
print("Len of failed: ", len(failed))
print("Len of error: ", len(error))

0it [00:00, ?it/s]

Len of success:  100
Len of failed:  0
Len of error:  0


In [7]:
# 9994/10000 = 0.9994 (i.e. 99.94 %)

In [None]:
import collections

# Tally how often each target space group appears among the failed samples.
failed_spg = collections.Counter(dataset[f].space_groups.item() for f in failed)
print(failed_spg)

In [61]:
# For each failed sample print: dataset index, target space group, detected space group.
for sample_idx in failed:
    target_sg = dataset[sample_idx].space_groups.item()
    print(sample_idx, target_sg, res[sample_idx])

139 225 225
229 225 225
194 5 5
221 12 12
225 2 2
225 2 2
139 225 225
221 12 12
221 12 12
139 225 225


In [62]:
# Show the dataset indices collected in `failed` by the most recent evaluation run.
failed

[139, 229, 194, 221, 225, 225, 139, 221, 221, 139]

In [6]:
# Evaluate the projection with retries: up to MAX_ROUNDS fresh projections per sample
# when the detected space group is wrong, each round allowing MAX_ATTEMPTS tries for
# SpacegroupAnalyzer to succeed. Every sample ends up in exactly one outcome.
# (Previously a round with no analysable structure hit `continue` on the retry loop,
# so the sample could be recorded as an error twice and then again as success/failure
# using a stale `sgn` from the previous sample.)
res = []  # detected space group per dataset index; None when analysis never succeeded
success = []  # indices whose detected space group equals the target
failed = []  # indices that did not match; also contains every index in `error`
error = []  # indices for which SpacegroupAnalyzer raised on every attempt of every round
torch.set_num_threads(24)
MAX_ROUNDS = 2  # fresh projections allowed when the detected space group is wrong
MAX_ATTEMPTS = 3  # projections per round tolerated to make SpacegroupAnalyzer raise
# Wrapping the dataset (not the enumerate) lets tqdm read len(dataset), when defined, for a total.
for idx, data in enumerate(tqdm(dataset)):
    data_batch = collate([data])
    target_sg = data.space_groups.item()
    sgn = None  # last space group detected for *this* sample
    for _round in range(MAX_ROUNDS):
        spga = None
        for _attempt in range(MAX_ATTEMPTS):
            coords, lattice = _project_to_space_group(data_batch)
            structure = Structure(lattice=lattice, coords=coords, species=data_batch.species)
            try:
                spga = SpacegroupAnalyzer(structure)
            except Exception as exc:  # not a bare except: KeyboardInterrupt must still stop the loop
                print(f"Idx: {idx} failed, spg {data.space_groups} ({exc!r})")
                continue
            break
        if spga is None:
            continue  # nothing analysable this round; give the next round a chance
        sgn = spga.get_space_group_number()
        if sgn == target_sg:
            break
        print(f"Idx: {idx} failed, spg {data.space_groups} != {sgn}, retry")
    res.append(sgn)  # None keeps res[idx] aligned with dataset[idx] for errors
    if sgn is None:
        error.append(idx)
        failed.append(idx)
    elif sgn == target_sg:
        success.append(idx)
    else:
        failed.append(idx)
        print(f"Idx: {idx} failed, spg {data.space_groups} != {sgn}")
print("Len of success: ", len(success))
print("Len of failed: ", len(failed))
print("Len of error: ", len(error))

0it [00:00, ?it/s]

Idx: 73 failed, spg 38
Idx: 362 failed, spg 194
Idx: 426 failed, spg 119
Idx: 506 failed, spg 136
Idx: 649 failed, spg 119
Idx: 667 failed, spg 156
Idx: 974 failed, spg 166 != 221, retry
Idx: 974 failed, spg 166 != 221, retry
Idx: 974 failed, spg 166 != 221
Idx: 1038 failed, spg 166 != 221, retry
Idx: 1359 failed, spg 194
Idx: 1729 failed, spg 38
Idx: 1859 failed, spg 71
Idx: 2538 failed, spg 139
Idx: 2603 failed, spg 38
Idx: 3649 failed, spg 63
Idx: 3732 failed, spg 160
Idx: 3830 failed, spg 148 != 166, retry
Idx: 3986 failed, spg 164
Idx: 4293 failed, spg 227
Idx: 4376 failed, spg 11


spglib: ssm_get_exact_positions failed.
spglib: get_bravais_exact_positions_and_lattice failed.


Idx: 4656 failed, spg 186 != 194, retry
Idx: 4826 failed, spg 163 != 194, retry
Idx: 5042 failed, spg 38
Idx: 5259 failed, spg 146
Idx: 5432 failed, spg 205
Idx: 5441 failed, spg 160
Idx: 5563 failed, spg 71
Idx: 5700 failed, spg 59
Idx: 6071 failed, spg 160
Idx: 6247 failed, spg 47
Idx: 6248 failed, spg 63
Idx: 6377 failed, spg 166 != 221, retry
Idx: 6998 failed, spg 166 != 221, retry
Idx: 7030 failed, spg 71 != 139, retry
Idx: 7116 failed, spg 174
Idx: 7202 failed, spg 71 != 139, retry
Idx: 7278 failed, spg 38
Idx: 7784 failed, spg 61 != 64, retry
Idx: 7981 failed, spg 26
Idx: 7998 failed, spg 20
Idx: 8201 failed, spg 156
Idx: 8427 failed, spg 71
Idx: 8918 failed, spg 42
Idx: 9090 failed, spg 139


spglib: ssm_get_exact_positions failed.
spglib: get_bravais_exact_positions_and_lattice failed.


Len of success:  9999
Len of failed:  1
Len of error:  0
