In [3]:
import numpy as np
import pandas as pd
import torch

from bin_refactored.evaluate import load_model_from_checkpoint
from bin_refactored.train import TrainConfig
from decifer_refactored.tokenizer import Tokenizer
from decifer_refactored.decifer_model import Decifer
from decifer_refactored.utility import (
    reinstate_symmetry_loop,
    replace_symmetry_loop_with_P1,
    extract_space_group_symbol,
    generate_continuous_xrd_from_cif,
    discrete_to_continuous_xrd,
)

import warnings
warnings.simplefilter("ignore")

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the trained checkpoint onto the chosen device and switch the model
# to evaluation mode for inference.
model = load_model_from_checkpoint('../experiments/model__nocond__context_3076__robust/ckpt.pt', device)
model.eval()

# Build the tokenizer once and share it: the original constructed three
# separate Tokenizer() instances just to pull one attribute from each.
tokenizer = Tokenizer()
encode = tokenizer.encode
decode = tokenizer.decode
padding_id = tokenizer.padding_id

number of total non-trainable parameters: 26.14M
number of total trainable parameters: 27.72M
number of total conditioning MLP parameters: 0.78M


In [47]:
# Token sequence that seeds generation: the CIF `data_` header followed by the
# composition tokens and a newline.
prompt_tokens = ['data_', 'Ca', '3', 'B', 'Ti', 'Ge', '3', 'O', '1', '2', '\n']

# Encode to ids, add a leading batch dimension, and move to the model's device.
prompt_ids = torch.tensor(encode(prompt_tokens))
prompt = prompt_ids.unsqueeze(0).to(device=model.device)

In [53]:
# Generate up to 3073 new tokens from the prompt; cond_vec=None passes no
# conditioning vector. NOTE(review): 3073 presumably relates to the model's
# context length (checkpoint name says context_3076) - confirm.
generated = model.generate_batched_reps(
    prompt,
    max_new_tokens=3073,
    cond_vec=None,
    start_indices_batch=[[0]],
)

# Bring the result back to host memory as a NumPy array for post-processing.
token_ids = generated.cpu().numpy()

Generating sequence:   0%|          | 0/3073 [00:00<?, ?it/s]

In [54]:
# Strip padding ids from every generated sequence.
token_ids = [sequence[sequence != padding_id] for sequence in token_ids]

# Decode the first sequence back to CIF text, then pass it through
# replace_symmetry_loop_with_P1 before inspecting the space group.
out_cif = replace_symmetry_loop_with_P1(decode(list(token_ids[0])))

# Read the space group symbol out of the CIF text.
spacegroup_symbol = extract_space_group_symbol(out_cif)

# Anything other than "P 1" gets its symmetry loop reinstated for that group.
if not (spacegroup_symbol == "P 1"):
    out_cif = reinstate_symmetry_loop(out_cif, spacegroup_symbol)

In [55]:
# Use print() so the multi-line CIF text is shown with real line breaks;
# a bare `out_cif` expression would display the escaped string repr instead.
print(out_cif)

data_Ca3BTiGe3O12
loop_
_atom_type_symbol
_atom_type_electronegativity
_atom_type_radius
_atom_type_ionic_radius
Ca 1.0000 1.8000 1.1400
B 2.0400 0.8500 0.4100
O 3.4400 0.6000 1.2600
_symmetry_space_group_name_H-M P-62m
_cell_length_a 8.3328
_cell_length_b 8.3328
_cell_length_c 3.7671
_cell_angle_alpha 90.0000
_cell_angle_beta 90.0000
_cell_angle_gamma 120.0000
_symmetry_Int_Tables_number 189
_chemical_formula_structural CaBO4
_chemical_formula_sum 'Ca3 B3 O12'
_cell_volume 226.7012
_cell_formula_units_Z 3
loop_
 _symmetry_equiv_pos_site_id
 _symmetry_equiv_pos_as_xyz
  1  'x-y, -y, z'
  2  '-x, -x+y, -z'
  3  'x, y, -z'
  4  '-x+y, -x, z'
  5  '-x+y, -x, -z'
  6  '-y, x-y, z'
  7  'x, y, z'
  8  '-y, x-y, -z'
  9  'y, x, z'
  10  'x-y, -y, -z'
  11  '-x, -x+y, z'
  12  'y, x, -z'
loop_
_atom_site_type_symbol
_atom_site_label
_atom_site_symmetry_multiplicity
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
Ca Ca0 3 0.0000 0.5996 0.5000 1.0000
B B1 2 0.3333 