In [2]:
import numpy as np
import pandas as pd
import torch

from decifer import (
    DeciferDataset,
    Decifer,
    DeciferConfig,
    load_model_from_checkpoint,
    Tokenizer,
    extract_prompt,
    replace_symmetry_loop_with_P1,
    extract_space_group_symbol,
    reinstate_symmetry_loop,
)

In [3]:
# load a cif string and its xrd
dataset = DeciferDataset('../data/crystallm/full/serialized/test.h5', ['cif_string', 'cif_tokenized', 'xrd_cont.iq'])
decode = Tokenizer().decode
padding_id = Tokenizer().padding_id
dataset_iter = iter(dataset)

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = load_model_from_checkpoint('../deep_conditioning/ckpt.pt', device)
model.eval()

number of total parameters: 25.75M


Decifer(
  (transformer): ModuleDict(
    (cond_embedding): Sequential(
      (0): Linear(in_features=1000, out_features=256, bias=True)
      (1): ReLU()
      (2): Linear(in_features=256, out_features=512, bias=True)
    )
    (wte): Embedding(372, 512)
    (wpe): Embedding(1024, 512)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-7): 8 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=512, out_features=1536, bias=False)
          (c_proj): Linear(in_features=512, out_features=512, bias=False)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=512, out_features=2048, bias=False)
          (c_proj): Linear(in_features=2048, out_features=512, bias=False)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): Layer

In [6]:
cif_string, cif_tokenized, xrd = next(dataset_iter)

prompt = extract_prompt(cif_tokenized, model.device, add_composition=True, add_spacegroup=True).unsqueeze(0)
# print(prompt)
# print(len())

cond_vec = xrd.to(model.device).unsqueeze(0)
print(cond_vec.shape)
print(len())
token_ids = model.generate_batched_reps(prompt, max_new_tokens=1000, cond_vec=cond_vec, start_indices_batch=[[0]]).cpu().numpy()
token_ids = [ids[ids != padding_id] for ids in token_ids]  # Remove padding tokens

out_cif = decode(list(token_ids[0]))
out_cif = replace_symmetry_loop_with_P1(out_cif)

# Extract space group symbol from the CIF string
spacegroup_symbol = extract_space_group_symbol(out_cif)

# If the space group is not "P1", reinstate symmetry
if spacegroup_symbol != "P 1":
    out_cif = reinstate_symmetry_loop(out_cif, spacegroup_symbol)
    
print(cif_string)
print()
print(out_cif)

torch.Size([1, 1000])


TypeError: len() takes exactly one argument (0 given)