In [9]:
%load_ext autoreload
%autoreload 2
import os
import pickle
import numpy as np
import polars as pl
from pathlib import Path
from typing import List, Dict, Tuple
from tqdm import tqdm
import logging
import pyarrow as pa
import pyarrow.parquet as pq
import time
from datetime import datetime, timedelta
from try_load_dataset import find_latest_ckpt_file, load_datamodule_from_config, deep_compare
from lobster.model import LobsterPCLM2, LobsterPCLM
import sys

# QM9 and RDKit imports
from atomic_datasets import QM9
from atomic_datasets.utils.rdkit import is_molecule_sane
from rdkit import Chem
from rdkit.Chem import AllChem, DataStructs, rdFingerprintGenerator
import selfies as sf

# Import our utility functions
#from utils_qm9 import get_shape_tanimoto
import pandas as pd
#from qm9_pair_generation import is_valency_ok
from qm9_pair_gen.utils_mol import *
import sys
sys.path.append('/homefs/home/lawrenh6/lobster/src/lobster/callbacks')
from _molecule_validation_callback import MoleculeValidationCallback
import torch

# Define the test input molecule
test_mol = "CN(C)C[C@@H](O)[C@@H](c1ccccc1)c1ccc(Cl)cc1"

# Determine the device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
overrides = None
config_path = '../src/lobster/hydra_config/train_molecule_improvement.yaml' #train_chembl.yaml'
datamodule, cfg, transform_fn = load_datamodule_from_config(config_path, overrides=overrides)

print("Preparing data...")
datamodule.prepare_data()

stage = "fit"
print(f"Setting up datamodule for stage: {stage}")
datamodule.setup(stage=stage)

dsets = {'train': datamodule._train_dataset, 'val': datamodule._val_dataset, 'test': datamodule._test_dataset}

config_dir: ../src/lobster/hydra_config
config_name: train_molecule_improvement
original_cwd /homefs/home/lawrenh6/lobster/notebooks
hydra initialized


INFO:try_load_dataset:Configuration loaded:
INFO:try_load_dataset:dryrun: false
run_test: null
compile: false
seed: 42
logger:
  project: null
  name: null
  entity: null
  _target_: lightning.pytorch.loggers.WandbLogger
  save_dir: .
  offline: false
  group: null
  notes: null
  tags: null
paths:
  root_dir: first_run
  output_dir: ${paths.root_dir}/${paths.timestamp}
  timestamp: ${now:%Y-%m-%d}T${now:%H-%M-%S.%f}
data:
  _target_: lobster.data._molecule_improvement_datamodule.MoleculeImprovementLightningDataModule
  root: /data/lawrenh6/cache/test_new_code/qm9_pairs_per_mol_5_full
  train_pair_filename: pairs_train.parquet
  val_pair_filename: pairs_val.parquet
  test_pair_filename: pairs_test.parquet
  utility_key: gap
  shape_tanimoto_percentile: null
  shape_tanimoto_num_pairs: 50
  delta: None
  epsilon: 0.001
  batch_size: 64
  shuffle: true
  num_workers: 4
  pin_memory: true
  drop_last: true
  max_train_samples: null
  transform_fn:
    _target_: lobster.transforms.Tokenize

Preparing data...
Setting up datamodule for stage: fit
Loaded 5480540 pairs for split 'train' using utility 'gap'
After shape Tanimoto num_pairs filtering (50 pairs per molecule): 5480540 pairs
After utility filtering (> 0.001): 5204656 pairs
Loaded 87450 pairs for split 'val' using utility 'gap'
After shape Tanimoto num_pairs filtering (50 pairs per molecule): 87450 pairs
After utility filtering (> 0.001): 897 pairs
Loaded 64375 pairs for split 'test' using utility 'gap'
After shape Tanimoto num_pairs filtering (50 pairs per molecule): 64375 pairs
After utility filtering (> 0.001): 688 pairs


In [10]:
# Load the LobsterPMLM model

ckpt_path = find_latest_ckpt_file('/data/lawrenh6/lobster_runs/val_initialized_propen_4', use_val_loss=True) # serious_chembl_large2
print(ckpt_path)

model = LobsterPCLM2("CLM_150M", transform_fn=transform_fn).to(device) # CLM_mini # CLM_150M  # CLM_mini
model.eval()

ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
model.load_state_dict(ckpt['state_dict'])


/data/lawrenh6/lobster_runs/val_initialized_propen_4/2025-08-12T06-51-27.224517/epoch=0-step=20330-val_loss=0.5613.ckpt


<All keys matched successfully>

In [11]:
tokenizer = model.tokenizer

In [13]:
tokens = tokenizer.encode(test_mol)
print(tokens)

[0, 8, 16, 10, 8, 11, 8, 20, 10, 12, 11, 20, 10, 9, 13, 9, 9, 9, 9, 9, 13, 11, 9, 13, 9, 9, 9, 10, 26, 11, 9, 9, 13, 2]


In [14]:
dsets['train'][10]['input_ids']

tensor([[ 0, 53,  8,  ...,  1,  1,  1]])

In [18]:
MoleculeValidationCallback._tokens_to_smiles(tokenizer, dsets['train'][10]['input_ids'].reshape(-1).tolist())

('[H]C1([H])C([H])([H])C([H])([H])C([H])([H])C([H])([H])C([H])([H])C1([H])[H]',
 '[H]C([H])([H])[H]')

In [20]:
a, b = MoleculeValidationCallback._tokens_to_smiles(tokenizer, tokenizer.encode(test_mol))
print(a, b)

CN(C)C[C@@H](O)[C@@H](c1ccccc1)c1ccc(Cl)cc1 None


In [19]:
from _molecule_validation_callback import short_print

In [21]:
short_print(a)

'CN(C)C[C@@H](O)[C@@H](c1ccccc1)c1ccc(Cl)cc1'

In [24]:
short_print(a)

'CN(C)C[C@@H](O)[C@@H](c1ccccc1)c1ccc(Cl)cc1'

In [25]:
model = LobsterPCLM2("CLM_150M", transform_fn=transform_fn).to(device) # CLM_mini # CLM_150M  # CLM_mini
model.eval()

LobsterPCLM2(
  (_transform_fn): TokenizerTransform()
  (model): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(1226, 816, padding_idx=1)
      (layers): ModuleList(
        (0-19): 20 x LlamaDecoderLayer(
          (self_attn): LlamaAttention(
            (q_proj): Linear(in_features=816, out_features=816, bias=False)
            (k_proj): Linear(in_features=816, out_features=816, bias=False)
            (v_proj): Linear(in_features=816, out_features=816, bias=False)
            (o_proj): Linear(in_features=816, out_features=816, bias=False)
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=816, out_features=2048, bias=False)
            (up_proj): Linear(in_features=816, out_features=2048, bias=False)
            (down_proj): Linear(in_features=2048, out_features=816, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm((816,), eps=1e-06)
          (post_attention_layernorm): LlamaR

In [29]:
model.tokenizer._tokenizer.post_processor = None  # changes by reference, seemingly -- so 
generated = model.sample(seed_seq="<cls>", temperature=0.95, max_length=300)

Device set to use cuda:0
Both `max_new_tokens` (=256) and `max_length`(=300) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


In [36]:
generated=out

In [28]:
tokenizer.decode(tokenizer.encode("<cls>"))

'<cls> <cls> <eos>'

In [37]:
gen_ids = generated[0].detach().cpu()
gen_full_str = pl_module.tokenizer.decode(gen_ids.tolist())
gen_new_ids = gen_ids[len(unmasked_ids):]
gen_new_str = pl_module.tokenizer.decode(gen_new_ids.tolist()) if gen_new_ids.numel() > 0 else ""

gen_ids = generated[0].detach().cpu()
gen_full_str = pl_module.tokenizer.decode(gen_ids.tolist()) # space separated string
gen_new_ids = gen_ids[len(unmasked_ids):]
gen_new_str = pl_module.tokenizer.decode(gen_new_ids.tolist()) if gen_new_ids.numel() > 0 else ""


MoleculeValidationCallback._tokens_to_smiles(model.tokenizer, generated_as_list)

AttributeError: 'str' object has no attribute 'detach'

In [40]:
temp = out[0].split(" ")
temp[0] = temp[0][5:]
smiles = ''.join(temp)

In [46]:
print(smiles)
print('\n\n\n')
print(short_print(smiles, n=200))

[99Tc+3][89Zr][Lu][Al+3][H-][Ac-][18OH2][33PH][18OH2][Au+][Rh-3][PH-2][111InH3][13N+][12CH2][Sn+2]%23[AlH+2][Ir-][Tm+3][Cr+2][Se@][PoH2][Hs][89Zr+4][S-][54Cr][Pb+3]%48[PH3][7NaH]4[Hg][F-][Bi+5][Rn][Rh+4][Zr+2][Sb+][SrH2][Te+4][13NH2][12cH]%91[PbH2][Li-][177Lu][Ac-]([52Mn+2][Ca+2][Na-2][Pm][Mn+3][68Ga][20CH3][PbH2][ZrH3][Sc+2]5[Tl][BH2-][Al-]/[Xe+][18CH][15nH+][11cH][Pt-][NbH2]%28%48[Ti+5][137Cs+][U+3][Nb+3][SnH2+][18FH][V+2][Ge@@][B@@-]%78[LaH][Pm][Nb+2][SbH4][NH3+2][Ru-3][Dy+3][AlH+][14CH2-][B@-][11CH3][WH][Tc][Ir+4][cH+]%82[Ag-]\[PbH4]$[OsH6][Sn+2][14CH2][Nd+][68GaH3][Ir+2][Ru-3][SH][Os+6][Ce+4]%14[Er][ClH2+][I-][Dy+3][PH-][13CH][Ru+8][188Re][Te+4][Al-][18O][Si@@][ArH][Tc+2][Sg]%24[209Po]%54[67Cu+2][208Po][SH2+][Mn+4][125I][Cu-2]%43[Fe+6][133I][Si@@H][68Zn][Rh-3][Pd-][PtH3][Sr][Li-][SbH2][BH-][CH+][CoH+2][GeH2][13CH-][Si@][121I][122Xe][Ho+3][111InH3][15O][V][Pd+2][10C][Ce][62Ni][Al][Pr+3]\%60[PH-2][Pb+3][SbH-][Rh+][13C@@H][CH3][Al-][10CH2][Cu-][C+4][81Kr][123I-][36SH2][NH2][BH][SnH+3

In [47]:
len(short_print(smiles, n=200))

203