In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from argparse import Namespace
import yaml
import sys
import numpy as np  
from pathlib import Path
from rdkit import Chem
import random

import os

# Root of the DiffSBDD checkout. Set the DIFFSBDD_ROOT environment variable to
# point at a different checkout without editing the notebook.
REPO_ROOT = os.environ.get("DIFFSBDD_ROOT", "/mnt/STORAGE3/sebastian2/DiffSBDD")
# Prepend (rather than append) so the repo's own top-level modules, such as
# `lightning_modules`, take precedence over same-named modules elsewhere on sys.path.
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)

from lightning_modules import LigandPocketDDPM

def to_ns(d):
    """Convert a (possibly nested) dict into ``argparse.Namespace`` objects.

    Dict values are converted recursively, so nested YAML sections become
    attribute-accessible. Any value that is not a dict is returned
    unchanged. This includes lists, even lists of dicts.
    """
    if not isinstance(d, dict):
        return d
    fields = {}
    for key, value in d.items():
        fields[key] = to_ns(value)
    return Namespace(**fields)

def _load_checkpoint_file(path):
    """Fully unpickle a checkpoint file onto the CPU.

    torch>=2.6 defaults ``torch.load(weights_only=True)``, which rejects
    arbitrary pickled objects. Lightning checkpoints may also store
    hyper-parameters (presumably including ``argparse.Namespace`` objects;
    TODO confirm for this checkpoint). So a full unpickle is requested
    wherever the ``weights_only`` argument exists.
    Only use this on trusted checkpoints: full unpickling can execute code.
    """
    import inspect

    load_kwargs = {"map_location": "cpu"}
    # Older torch releases have no `weights_only` argument at all.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    return torch.load(path, **load_kwargs)


def load_model_using_config(config, ckpt_path=None, device=None):
    """Build a ``LigandPocketDDPM`` from a YAML config and optionally load weights.

    Args:
        config: Path to the YAML config. It must provide every field read
            below (``logdir``, ``run_name``, ``datadir``, ``egnn_params``,
            ``text_model_name``, ``text_embeddings_path``, ...).
        ckpt_path: Optional checkpoint file. Weights are taken from its
            ``"state_dict"`` entry. If that key is absent, the loaded dict
            itself is used as the state dict. Loading is non-strict, and any
            missing or unexpected keys are reported with a warning.
        device: Target device (``torch.device`` or a string such as
            ``"cuda:1"``). Defaults to CUDA when available, else CPU.

    Returns:
        The module, moved to ``device`` and switched to eval mode.
    """
    import warnings

    with open(config, "r") as f:
        cfg = yaml.safe_load(f)
    args = to_ns(cfg)

    # Required histogram file from the processed dataset.
    histogram_file = Path(args.datadir, "size_distribution.npy")
    histogram = np.load(histogram_file).tolist()

    # Build LightningModule with text conditioning.
    pl_module = LigandPocketDDPM(
        outdir=Path(args.logdir, args.run_name),
        dataset=args.dataset,
        datadir=args.datadir,
        batch_size=args.batch_size,
        lr=args.lr,
        egnn_params=args.egnn_params,
        diffusion_params=args.diffusion_params,
        num_workers=args.num_workers,
        augment_noise=args.augment_noise,
        augment_rotation=args.augment_rotation,
        clip_grad=args.clip_grad,
        eval_epochs=args.eval_epochs,
        eval_params=args.eval_params,
        visualize_sample_epoch=args.visualize_sample_epoch,
        visualize_chain_epoch=args.visualize_chain_epoch,
        auxiliary_loss=args.auxiliary_loss,
        loss_params=args.loss_params,
        mode=args.mode,
        node_histogram=histogram,
        pocket_representation=args.pocket_representation,
        text_model_name=args.text_model_name,
        text_embeddings_path=args.text_embeddings_path,
    )

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    if ckpt_path is not None:
        checkpoint = _load_checkpoint_file(ckpt_path)
        # Lightning nests the weights under "state_dict"; accept a bare state dict too.
        state_dict = checkpoint.get("state_dict", checkpoint)
        incompatible = pl_module.load_state_dict(state_dict, strict=False)
        missing = list(incompatible.missing_keys)
        unexpected = list(incompatible.unexpected_keys)
        # strict=False is kept on purpose, but a mismatch must not be silent:
        # parameters absent from the checkpoint keep their initial values.
        if missing or unexpected:
            warnings.warn(
                f"Non-strict checkpoint load from {ckpt_path}: "
                f"{len(missing)} missing key(s) {missing[:5]}, "
                f"{len(unexpected)} unexpected key(s) {unexpected[:5]}. "
                "Missing parameters keep their initial values.",
                stacklevel=2,
            )

    pl_module = pl_module.to(device)
    pl_module.eval()
    return pl_module
    

In [None]:


# ---- User inputs ----
CKPT_PATH = f"{REPO_ROOT}/logs/SE3-cond-fullv2/checkpoints/best-model-epoch=epoch=42.ckpt"
CONFIG_YML = f"{REPO_ROOT}/configs/FT_crossdock_fullatom_cond.yml"
# Natural-language condition passed to generate_ligands as text_description.
TEXT_DESCRIPTION = "Generate a molecule containing a sulfate group bound to a ring system"

N_SAMPLES = 20
POCKET_PDB = f"{REPO_ROOT}/example/5ndu.pdb"  # example pocket structure
LIGAND_SDF = f"{REPO_ROOT}/example/5ndu_C_8V2.sdf"  # example reference ligand (ref_ligand)
OUT_SDF = Path(Path(POCKET_PDB).stem + "_mol.sdf")  # written to the current working directory
RELAX_ITER = 200  # number of force-field optimization steps (optional)
N_NODES_MIN = 15  # passed as n_nodes_min; presumably the minimum ligand size — confirm
SEED = 42  # sampling is stochastic; fix the seeds so runs are reproducible

# ---- Reproducibility ----
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNGs of all devices, including CUDA

# ---- Model and device setup ----
model = load_model_using_config(CONFIG_YML, CKPT_PATH)

# ---- Inference: generate ligands ----
# If your model uses text conditioning, pass text_description. Otherwise, omit it.
with torch.no_grad():
    molecules = model.generate_ligands(
        pdb_file=POCKET_PDB,
        n_samples=N_SAMPLES,
        text_description=TEXT_DESCRIPTION,
        ref_ligand=LIGAND_SDF,
        sanitize=True,
        largest_frag=True,
        relax_iter=RELAX_ITER,
        n_nodes_min=N_NODES_MIN,
    )

# generate_ligands may return None entries; those are skipped.
valid_molecules = [mol for mol in molecules if mol is not None]

# ---- Save to SDF ----
# The context manager closes the writer even if a write raises.
with Chem.SDWriter(str(OUT_SDF)) as writer:
    for mol in valid_molecules:
        writer.write(mol)
print(f"Wrote {len(valid_molecules)}/{len(molecules)} molecules to {OUT_SDF}")

# ---- Print SMILES (numbered by position in the full sample list) ----
for i, mol in enumerate(molecules):
    if mol is not None:
        print(f"Molecule {i+1}: {Chem.MolToSmiles(mol)}")

In [None]:
import py3Dmol

SHOW_POCKET_SURFACE = False  # set True to draw a translucent VDW surface over the pocket

# Show the pocket as a cartoon, then load the generated ligands as animation frames.
view = py3Dmol.view(js="https://3dmol.org/build/3Dmol.js")
# Path.read_text closes the file handle immediately, unlike open(...).read().
view.addModel(Path(POCKET_PDB).read_text(), "pdb")
view.setStyle({"model": -1}, {"cartoon": {"color": "lime"}})
if SHOW_POCKET_SURFACE:
    view.addSurface(py3Dmol.VDW, {"opacity": 0.4, "color": "lime"})
view.addModelsAsFrames(Path(OUT_SDF).read_text())
view.setStyle({"model": -1}, {"stick": {}})
view.zoomTo({"model": -1})
view.zoom(0.5)
view.animate({"loop": "forward", "interval": 1000})
view.show()