# SEMLAFLOW Model Evaluation

This notebook demonstrates how to:
1. Load a trained SEMLAFLOW model
2. Generate molecular conformations
3. Evaluate the quality of generated molecules
4. Visualize and analyze the results
5. Save the generated molecules and protein pockets to disk

## Setup

In [1]:
import sys
import numpy as np
import torch
import matplotlib.pyplot as plt
import lightning as L
import pandas as pd
from pathlib import Path
from IPython.display import display, HTML
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem import rdMolAlign
import py3Dmol

# Turn off rdkit logging
from rdkit import RDLogger
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

# Import SEMLAFLOW modules
sys.path.append('..')
import cgflow.scriptutil as util
from cgflow.buildutil import build_dm, build_model
import cgflow.util.metrics as Metrics
import cgflow.util.complex_metrics as ComplexMetrics

# Set torch properties for consistency with the evaluation script
torch.set_float32_matmul_precision("high")
L.seed_everything(12345)

[rank: 0] Seed set to 12345


12345

## Configuration

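Define the checkpoint and data paths, and load the hyper-parameters stored in the checkpoint. Inference-time settings are overridden in the following cell.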

In [13]:
import os
from types import SimpleNamespace

# Paths default to the locations this notebook was authored against; override them with
# environment variables to run elsewhere instead of editing hardcoded absolute paths.
# Path to the model checkpoint
MODEL_CHECKPOINT_PATH = os.environ.get(
    "CGFLOW_CHECKPOINT",
    "/projects/jlab/to.shen/cgflow-dev/weights/plinder_till_end.ckpt",
)
# Path to validation data
DATA_PATH = os.environ.get(
    "CGFLOW_DATA_PATH",
    "/projects/jlab/to.shen/cgflow-dev/experiments/data/complex/plinder_15A",
)

# The saved hyper-parameters become the `args` namespace consumed by build_dm / build_model.
# NOTE(review): torch >= 2.6 defaults `torch.load(..., weights_only=True)`, which may reject
# non-tensor objects stored in a Lightning checkpoint; if this load fails, pass
# weights_only=False — only for trusted checkpoints, since that unpickles arbitrary objects.
checkpoint = torch.load(MODEL_CHECKPOINT_PATH, map_location='cpu')
args = SimpleNamespace(**checkpoint['hyper_parameters'])

In [34]:
# Inference-time overrides applied on top of the hyper-parameters restored from the checkpoint.
vars(args).update(
    num_inference_steps=70,
    data_path=DATA_PATH,
    model_checkpoint=MODEL_CHECKPOINT_PATH,
    sampling_strategy="linear",
    max_atoms=2048,
)
# Gates the pocket-based (complex) metrics and the protein-ligand visualisation below.
IS_COMPLEX = True

## Load Model

Now let's build the vocabulary and datamodule, then restore the trained model weights from the checkpoint.

In [32]:
def load_model(args):
    """Build the vocabulary, datamodule and model, then restore the checkpoint weights.

    Args:
        args: namespace of hyper-parameters; reads ``model_checkpoint``,
            ``num_inference_steps`` and ``sampling_strategy`` here and passes the
            whole namespace on to ``build_dm`` / ``build_model``.

    Returns:
        ``(model, dm, vocab)``. The model is in eval mode, on CUDA when available and
        otherwise on CPU, with its integrator step count and sampling strategy set from
        ``args``.
    """
    print("Building vocabulary...")
    vocab = util.build_vocab()

    print("Loading validation datamodule...")
    dm = build_dm(args, vocab, mode="val")

    print("Building model from checkpoint...")
    model = build_model(args, dm, vocab)

    ckpt_path = args.model_checkpoint
    print(f"Loading checkpoint from {ckpt_path}...")
    weights = torch.load(ckpt_path, map_location='cpu')['state_dict']
    model.load_state_dict(weights)

    # Evaluation mode on the best available device (Module.eval()/.to() return the module).
    target = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = model.eval().to(target)

    # Sampling knobs used at generation time.
    model.integrator.steps = int(args.num_inference_steps)
    model.sampling_strategy = args.sampling_strategy

    return model, dm, vocab

# Load the model. Re-raise on failure: swallowing the exception would only postpone it
# to a confusing NameError (`model` / `dm` undefined) in the cells below.
try:
    model, dm, vocab = load_model(args)
except Exception as e:
    print(f"Error loading model: {str(e)}")
    raise
print("Model loaded successfully!")

Building vocabulary...
Loading validation datamodule...


100%|██████████| 938/938 [00:00<00:00, 55874.81it/s]

Using type ARGeometricComplexInterpolant for training
Building model from checkpoint...

items per bucket [1, 7, 59, 342, 512, 17]
bucket batch sizes [101, 76, 50, 25, 12, 6]
batches per bucket [1, 1, 2, 14, 43, 3]
Total training steps 640000
Using model class LigandGenerator





Using CFM class ARMolecularCFM
Loading checkpoint from /projects/jlab/to.shen/cgflow-dev/weights/plinder_till_end.ckpt...
Model loaded successfully!


## Generate Molecular Conformations

Now we'll use the model to generate molecular conformations for samples drawn from the datamodule loaded in validation mode.

In [33]:
from tqdm import tqdm

def _move_tensors(mapping, device):
    """Return a copy of ``mapping`` with every ``torch.Tensor`` value moved to ``device``."""
    return {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in mapping.items()}


def prepare_batch_for_generation(batch, device=None):
    """Prepare a batch from the dataloader for generation.

    Supported batch layouts, by tuple length:
      * 4: ``(prior, data, interpolated, times)``
      * 6: ``(prior, data, interpolated, pocket, pocket_raw, times)``
      * 7 (AR model): ``(prior, data, interpolated, masked_data, times, rel_times, gen_times)``
      * 9 (AR model with complex): ``(prior, data, interpolated, masked_data, pocket,
        pocket_raw, times, rel_times, gen_times)``

    Args:
        batch: sequence produced by the dataloader in one of the layouts above.
        device: target device for tensor values. ``None`` (default) selects CUDA when it is
            available and CPU otherwise; a hard-coded ``'cuda'`` default would crash on
            CPU-only machines even though the model itself falls back to CPU.

    Returns:
        ``(prior, data, pocket, pocket_raw)``. ``prior``, ``data`` and ``pocket`` are dicts
        whose tensor values are moved to ``device`` (other values pass through unchanged).
        ``pocket`` and ``pocket_raw`` are ``None`` for layouts without a pocket;
        ``pocket_raw`` is returned as-is.

    Raises:
        ValueError: if ``len(batch)`` is not 4, 6, 7 or 9.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    pocket = None
    pocket_raw = None
    n_items = len(batch)
    if n_items == 4:
        prior, data, _interpolated, _times = batch
    elif n_items == 6:
        prior, data, _interpolated, pocket, pocket_raw, _times = batch
    elif n_items == 7:  # AR model
        prior, data, _interpolated, _masked_data, _times, _rel_times, _gen_times = batch
    elif n_items == 9:  # AR model with complex
        (prior, data, _interpolated, _masked_data, pocket, pocket_raw,
         _times, _rel_times, _gen_times) = batch
    else:
        raise ValueError(f"Unexpected batch format with {n_items} elements")

    prior = _move_tensors(prior, device)
    data = _move_tensors(data, device)
    pocket = _move_tensors(pocket, device) if pocket is not None else None

    return prior, data, pocket, pocket_raw

def generate_molecules(model, dataloader, num_samples=5, categorical_strategy=None):
    """Generate molecular conformations using the model.

    Batches are consumed whole, so generation stops after the first batch that brings
    the running total to at least ``num_samples``. More than ``num_samples`` molecules
    may therefore be returned.

    Args:
        model: trained model exposing ``_generate``, ``_generate_mols``,
            ``_batch_to_onehot``, ``integrator.steps`` and ``sampling_strategy``.
        dataloader: iterable of batches in a layout accepted by
            ``prepare_batch_for_generation``.
        num_samples: minimum number of molecules to generate before stopping.
        categorical_strategy: overrides the notebook-level ``args.categorical_strategy``.
            ``None`` (default) reads the global ``args``, as before.

    Returns:
        ``(generated_mols, ground_truth_mols, pocket_data, pocket_raws)``:
        * ``generated_mols`` / ``ground_truth_mols``: lists of molecules aligned by index.
        * ``pocket_data``: list of per-batch pocket dicts, or ``None`` if no batch had a pocket.
        * ``pocket_raws``: flat list of raw pockets, or ``None`` if no batch had a pocket.
    """
    if categorical_strategy is None:
        categorical_strategy = args.categorical_strategy
    is_autoregressive = categorical_strategy == "auto-regressive"

    # Resolve the model's device *before* preparing batches so inputs land on the same
    # device as the model. Previously the batch was moved to the 'cuda' default first.
    device = next(model.parameters()).device

    generated_mols = []
    ground_truth_mols = []
    pocket_data = []
    pocket_raws = []

    count = 0
    with torch.no_grad():
        for batch in tqdm(dataloader):
            prior, data, pocket, pocket_raw = prepare_batch_for_generation(batch, device=device)

            steps = model.integrator.steps
            strategy = model.sampling_strategy
            if is_autoregressive:
                # AR batch layouts (7 / 9 elements) end with gen_times.
                gen_batch = model._generate(prior, batch[-1].to(device), steps,
                                            strategy, pocket_batch=pocket)
            else:
                gen_batch = model._generate(prior, steps, strategy, pocket_batch=pocket)

            # Convert generated tensors and ground-truth data to molecules.
            gen_mols = model._generate_mols(gen_batch)
            data_mols = model._generate_mols(model._batch_to_onehot(data), rescale=True)

            generated_mols.extend(gen_mols)
            ground_truth_mols.extend(data_mols)

            # Only pocket-carrying layouts yield a raw pocket.
            if pocket_raw is not None:
                pocket_data.append(pocket)
                pocket_raws.extend(pocket_raw)

            count += len(gen_mols)
            if count >= num_samples:
                break

    return (
        generated_mols,
        ground_truth_mols,
        pocket_data if pocket_data else None,
        pocket_raws if pocket_raws else None,
    )

# Batches come from the datamodule built in load_model (mode="val").
# NOTE(review): the markdown above mentions the validation dataset, but this calls
# train_dataloader(); confirm which split build_dm(mode="val") exposes through it.
dataloader = dm.train_dataloader()

# Ask for 10 molecules; whole batches are kept, so a few extra may come back.
(generated_mols,
 ground_truth_mols,
 pocket_data,
 pocket_raw) = generate_molecules(model, dataloader, num_samples=10)

print(f"Generated {len(generated_mols)} molecules")
print(f"Retrieved {len(ground_truth_mols)} ground truth molecules")


items per bucket [1, 7, 59, 342, 512, 17]
bucket batch sizes [101, 76, 50, 25, 12, 6]
batches per bucket [1, 1, 2, 14, 43, 3]


  0%|          | 0/64 [01:34<?, ?it/s]

Generated 12 molecules
Retrieved 12 ground truth molecules






## Evaluate Generated Molecules

Let's calculate quality metrics for the generated molecules.

In [35]:
import numpy as np
import pandas as pd

def calculate_metrics(generated_mols, reference_mols=None, pocket_raw=None):
    """Calculate quality metrics for generated molecules and summarise them.

    Most metrics are evaluated one molecule at a time (a single-element list per call),
    so each key collects one value per molecule. Uniqueness is a property of the whole
    set and is evaluated once over all ``generated_mols``.

    Args:
        generated_mols: list of generated molecules (entries may be ``None``).
        reference_mols: optional list aligned index-by-index with ``generated_mols``.
            Enables accuracy and pair-RMSD metrics.
        pocket_raw: optional list of raw pockets aligned with ``generated_mols``. Enables
            clash / interaction metrics when the notebook-level ``IS_COMPLEX`` is True.

    Returns:
        ``(summary_df, all_metrics)``:
        * ``summary_df``: DataFrame with columns Metric / Mean / Median / Std, with NaN
          values dropped; all three are NaN if nothing is left.
        * ``all_metrics``: dict mapping each metric key to its list of raw values.

    Raises:
        ValueError: if ``reference_mols`` or ``pocket_raw`` is given but shorter than
            ``generated_mols``. Such a mismatch previously paired molecules with the
            wrong (or all) pockets.
    """
    n_mols = len(generated_mols)
    for label, aligned in (("reference_mols", reference_mols), ("pocket_raw", pocket_raw)):
        if aligned and len(aligned) < n_mols:
            raise ValueError(
                f"{label} has {len(aligned)} entries but generated_mols has {n_mols}; "
                "inputs must be aligned index-by-index"
            )

    metric_functions = {
        "validity": Metrics.Validity(),
        "fc_validity": Metrics.Validity(connected=True),
        "uniqueness": Metrics.Uniqueness(),
        "energy_validity": Metrics.EnergyValidity(),
        "average_energy": Metrics.AverageEnergy(),
        "average_energy_per_atom": Metrics.AverageEnergy(per_atom=True),
        "average_strain": Metrics.AverageStrainEnergy(),
        "average_strain_per_atom": Metrics.AverageStrainEnergy(per_atom=True),
        "average_opt_rmsd": Metrics.AverageOptRmsd()
    }

    if reference_mols:
        metric_functions.update({
            "molecular_accuracy": Metrics.MolecularAccuracy(),
            "pair_rmsd": Metrics.MolecularPairRMSD(),
            "pair_no_align_rmsd": Metrics.MolecularPairRMSD(align=False)
        })

    if pocket_raw and IS_COMPLEX:
        metric_functions.update({
            "clash_score": ComplexMetrics.Clash(),
            "interactions": ComplexMetrics.Interactions()
        })

    # Metrics whose second argument is the reference molecule / the raw pocket.
    reference_keys = {"molecular_accuracy", "pair_rmsd", "pair_no_align_rmsd"}
    pocket_keys = {"clash_score"}
    # Uniqueness of a one-molecule list is trivially 1 (or undefined), so it is
    # evaluated once over the whole generated set instead.
    set_level_keys = {"uniqueness"}

    all_metrics = {}
    for key, metric in metric_functions.items():
        if key in set_level_keys:
            if n_mols:
                all_metrics[key] = [metric(generated_mols)]
            continue

        for idx in range(n_mols):
            mol = generated_mols[idx:idx + 1]

            if key == "interactions":
                # Returns a mapping of sub-metrics; each becomes its own summary row.
                interaction_values = metric(mol, pocket_raw[idx:idx + 1])
                for int_key, val in interaction_values.items():
                    all_metrics.setdefault(f"interactions_{int_key}", []).append(val)
                continue

            if key in reference_keys:
                val = metric(mol, reference_mols[idx:idx + 1])
            elif key in pocket_keys:
                val = metric(mol, pocket_raw[idx:idx + 1])
            else:
                val = metric(mol)
            all_metrics.setdefault(key, []).append(val)

    # Compute mean, median, std for each metric.
    summary = {"Metric": [], "Mean": [], "Median": [], "Std": []}
    for key, values in all_metrics.items():
        arr = np.array(values, dtype=np.float32)
        # Drop NaN entries (presumably samples the metric could not score).
        arr = arr[~np.isnan(arr)]

        summary["Metric"].append(key)
        if arr.size:
            summary["Mean"].append(arr.mean())
            summary["Median"].append(np.median(arr))
            summary["Std"].append(arr.std())
        else:
            # Avoid numpy "mean of empty slice" warnings; report NaN explicitly.
            summary["Mean"].append(np.nan)
            summary["Median"].append(np.nan)
            summary["Std"].append(np.nan)

    return pd.DataFrame(summary), all_metrics


# Calculate metrics on the first n molecules (generation can return a few extra from
# its last batch). pocket_raw is None for datasets without pockets, so guard the slice.
n = 10
metrics_df, all_metrics = calculate_metrics(
    generated_mols[:n],
    ground_truth_mols[:n],
    pocket_raw[:n] if pocket_raw is not None else None,
)

# Display table
display(metrics_df)

MDAnalysis.topology.tables has been moved to MDAnalysis.guesser.tables. This import point will be removed in MDAnalysis version 3.0.0


  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]

Unnamed: 0,Metric,Mean,Median,Std
0,validity,0.8,1.0,0.4
1,fc_validity,0.8,1.0,0.4
2,uniqueness,1.0,1.0,0.0
3,energy_validity,0.8,1.0,0.4
4,average_energy,210.92749,153.877335,228.632675
5,average_energy_per_atom,7.155328,7.416623,5.780839
6,average_strain,254.789749,192.071136,207.724167
7,average_strain_per_atom,9.453641,9.989138,3.702522
8,average_opt_rmsd,1.45938,1.35145,0.798215
9,molecular_accuracy,1.0,1.0,0.0


In [34]:
# Success rate: the fraction of pairs whose RMSD to the reference, computed without
# re-alignment, is below the threshold. NaN (unscored) entries count as failures.
# float() accepts tensors, numpy scalars and plain floats alike.
RMSD_SUCCESS_THRESHOLD = 2.0  # Å
no_align_rmsds = [float(v) for v in all_metrics['pair_no_align_rmsd']]
(
    sum(r < RMSD_SUCCESS_THRESHOLD for r in no_align_rmsds) / len(no_align_rmsds)
    if no_align_rmsds
    else float('nan')
)

0.21428571428571427

## Visualize Molecules

Let's visualize some of the generated molecules alongside their ground truth counterparts

In [37]:
def visualize_molecule_2d(mol, title="Molecule"):
    """Show a 2D depiction of an RDKit molecule in a 3x3-inch matplotlib figure.

    The molecule is copied before 2D coordinates are computed, so the caller's
    coordinates are left untouched.

    Args:
        mol: RDKit molecule, or ``None``.
        title: figure title; also used in the notice for an invalid molecule.

    Returns:
        An ``HTML`` notice (not displayed automatically) when ``mol`` is ``None``;
        otherwise shows the figure and returns ``None``.
    """
    if mol is not None:
        depiction = Chem.Mol(mol)
        AllChem.Compute2DCoords(depiction)
        image = Draw.MolToImage(depiction, size=(300, 300))

        fig, ax = plt.subplots(figsize=(3, 3))
        ax.imshow(image)
        ax.set_title(title)
        ax.axis('off')
        fig.tight_layout()
        plt.show()
        return None

    return HTML(f"<p>{title}: Invalid molecule</p>")

def visualize_molecule_3d(mol, width=400, height=400, style="stick"):
    """Render an RDKit molecule in an interactive py3Dmol viewer.

    Args:
        mol: RDKit molecule with 3D coordinates, or ``None``.
        width, height: viewer size in pixels.
        style: py3Dmol style name applied to all atoms (e.g. ``"stick"``).

    Returns:
        The rendered viewer, or an ``HTML`` notice when ``mol`` is ``None``.
    """
    if mol is None:
        return HTML("<p>Invalid molecule</p>")

    molblock = Chem.MolToMolBlock(Chem.Mol(mol))

    view = py3Dmol.view(width=width, height=height)
    view.addModel(molblock, 'mol')
    view.setStyle({style: {}})
    view.zoomTo()
    view.render()
    return view

def compare_molecules_3d(gen_mol, ref_mol, align=True, width=400, height=400):
    """Overlay a generated and a reference molecule in one py3Dmol viewer.

    The reference is drawn as green sticks and the generated molecule as blue sticks.

    Args:
        gen_mol: generated RDKit molecule; copied, so the caller's object is not modified.
        ref_mol: reference RDKit molecule; copied.
        align: if True, align the generated copy onto the reference with
            ``rdMolAlign.AlignMol`` before drawing.
        width, height: viewer size in pixels.

    Returns:
        The rendered viewer, or an ``HTML`` notice if either molecule is ``None``.
    """
    if gen_mol is None or ref_mol is None:
        return HTML("<p>One or more invalid molecules</p>")

    gen_mol = Chem.Mol(gen_mol)
    ref_mol = Chem.Mol(ref_mol)

    # Align molecules if requested
    if align:
        rdMolAlign.AlignMol(gen_mol, ref_mol)

    viewer = py3Dmol.view(width=width, height=height)

    # addModel's second argument is the file *format* ('mol'), not a label, and a style
    # only applies to the atoms matched by its selection. Models are selected by
    # creation order: 0 = reference, 1 = generated.
    viewer.addModel(Chem.MolToMolBlock(ref_mol), 'mol')
    viewer.setStyle({'model': 0}, {'stick': {'color': 'green'}})

    viewer.addModel(Chem.MolToMolBlock(gen_mol), 'mol')
    viewer.setStyle({'model': 1}, {'stick': {'color': 'blue'}})

    viewer.zoomTo()
    viewer.render()
    return viewer

# Visualize every pair where both the generated and the reference molecule are valid:
# SMILES, both 3D structures, and their RMSD without re-alignment.
# zip() stops at the shorter list instead of raising IndexError on a length mismatch.
pair_no_align_rmsd_metric = Metrics.MolecularPairRMSD(align=False)
for i, (gen_mol, ref_mol) in enumerate(zip(generated_mols, ground_truth_mols)):
    if gen_mol is None or ref_mol is None:
        continue

    print(f"\nMolecule {i+1}")
    print("Generated SMILES:", Chem.MolToSmiles(gen_mol))
    print("Reference SMILES:", Chem.MolToSmiles(ref_mol))

    # 3D visualization of generated molecule
    print("\n3D Structure (Generated):")
    display(visualize_molecule_3d(gen_mol, width=400, height=400))

    # 3D visualization of reference molecule
    print("\n3D Structure (Reference):")
    display(visualize_molecule_3d(ref_mol, width=400, height=400))

    pair_no_align_rmsd = pair_no_align_rmsd_metric([gen_mol], [ref_mol])
    print(f"Pair RMSD (no alignment): {pair_no_align_rmsd:.3f} Å")


Molecule 1
Generated SMILES: CC(C=CC=C(C)C=CC1=C(C)CCCC1(C)C)=CC=O
Reference SMILES: CC(C=CC=C(C)C=CC1=C(C)CCCC1(C)C)=CC=O

3D Structure (Generated):


<py3Dmol.view at 0x14c9365a9c50>


3D Structure (Reference):


<py3Dmol.view at 0x14c9354d2b10>

Pair RMSD (no alignment): 1.633 Å

Molecule 2
Generated SMILES: CC(=O)[O-]
Reference SMILES: CC(=O)[O-]

3D Structure (Generated):


<py3Dmol.view at 0x14c9188a24d0>


3D Structure (Reference):


<py3Dmol.view at 0x14c9365a9c50>

Pair RMSD (no alignment): 2.582 Å

Molecule 3
Generated SMILES: CN1CCN(CCCN2c3ccccc3Sc3ccc(Cl)cc32)CC1
Reference SMILES: CN1CCN(CCCN2c3ccccc3Sc3ccc(Cl)cc32)CC1

3D Structure (Generated):


<py3Dmol.view at 0x14c918956d90>


3D Structure (Reference):


<py3Dmol.view at 0x14c9188a24d0>

Pair RMSD (no alignment): 3.762 Å

Molecule 5
Generated SMILES: CC(=O)NC1COC(CO)C(OC2OC(CO)C(O)C(O)C2NC(C)=O)C1O
Reference SMILES: CC(=O)NC1COC(CO)C(OC2OC(CO)C(O)C(O)C2NC(C)=O)C1O

3D Structure (Generated):


<py3Dmol.view at 0x14c9329255d0>


3D Structure (Reference):


<py3Dmol.view at 0x14c918956d90>

Pair RMSD (no alignment): 4.765 Å

Molecule 6
Generated SMILES: OCC1OC(OC2C(CO)OC(OC3C(CO)OC(O)C(O)C3O)C(O)C2O)C(O)C(O)C1O
Reference SMILES: OCC1OC(OC2C(CO)OC(OC3C(CO)OC(O)C(O)C3O)C(O)C2O)C(O)C(O)C1O

3D Structure (Generated):


<py3Dmol.view at 0x14c91896ab50>


3D Structure (Reference):


<py3Dmol.view at 0x14ca90328750>

Pair RMSD (no alignment): 5.234 Å

Molecule 7
Generated SMILES: CC(=O)NC1COC(CO)C(OC2OC(CO)CC(O)C2NC(C)=O)C1O
Reference SMILES: CC(=O)NC1COC(CO)C(OC2OC(CO)CC(O)C2NC(C)=O)C1O

3D Structure (Generated):


<py3Dmol.view at 0x14c935416f10>


3D Structure (Reference):


<py3Dmol.view at 0x14ca90615a10>

Pair RMSD (no alignment): 5.003 Å

Molecule 9
Generated SMILES: Cc1ncc(COP(=O)(O)O)c(CO)c1O
Reference SMILES: Cc1ncc(COP(=O)(O)O)c(CO)c1O

3D Structure (Generated):


<py3Dmol.view at 0x14ca90328750>


3D Structure (Reference):


<py3Dmol.view at 0x14c932924490>

Pair RMSD (no alignment): 2.712 Å

Molecule 10
Generated SMILES: CC(C)(COP(=O)(O)OP(=O)(O)OCC1OC(n2cnc3c(N)ncnc32)C(O)C1OP(=O)(O)O)C(O)C(=O)NCCC(=O)NCCS
Reference SMILES: CC(C)(COP(=O)(O)OP(=O)(O)OCC1OC(n2cnc3c(N)ncnc32)C(O)C1OP(=O)(O)O)C(O)C(=O)NCCC(=O)NCCS

3D Structure (Generated):


<py3Dmol.view at 0x14c935797590>


3D Structure (Reference):


<py3Dmol.view at 0x14c91873b110>

Pair RMSD (no alignment): 10.608 Å

Molecule 11
Generated SMILES: OCC(O)CO
Reference SMILES: OCC(O)CO

3D Structure (Generated):


<py3Dmol.view at 0x14c9189d0490>


3D Structure (Reference):


<py3Dmol.view at 0x14c9187bdb90>

Pair RMSD (no alignment): 1.309 Å

Molecule 12
Generated SMILES: Nc1ccn(C2OC(COP(=O)(O)O)C(O)C2O)c(=O)n1
Reference SMILES: Nc1ccn(C2OC(COP(=O)(O)O)C(O)C2O)c(=O)n1

3D Structure (Generated):


<py3Dmol.view at 0x14c935332910>


3D Structure (Reference):


<py3Dmol.view at 0x14c932924490>

Pair RMSD (no alignment): 2.206 Å

Molecule 13
Generated SMILES: Nc1ccn(C2OC(COP(=O)(O)OP(=O)(O)OP(=O)(O)O)C(O)C2O)c(=O)n1
Reference SMILES: Nc1ccn(C2OC(COP(=O)(O)OP(=O)(O)OP(=O)(O)O)C(O)C2O)c(=O)n1

3D Structure (Generated):


<py3Dmol.view at 0x14c9354a41d0>


3D Structure (Reference):


<py3Dmol.view at 0x14c935795750>

Pair RMSD (no alignment): 2.438 Å

Molecule 14
Generated SMILES: CC(=O)NC1COC(CO)C(O)C1O
Reference SMILES: CC(=O)NC1COC(CO)C(O)C1O

3D Structure (Generated):


<py3Dmol.view at 0x14c93422cad0>


3D Structure (Reference):


<py3Dmol.view at 0x14c933e1f310>

Pair RMSD (no alignment): 1.654 Å


## Complex Visualization (for Protein-Ligand Complexes)

If we're working with protein-ligand complexes, let's visualize the generated binding poses inside their pockets.

In [39]:
%load_ext autoreload
%autoreload 2

from cgflow.util.visualize import complex_to_3dview

# Visualize protein-ligand complexes if applicable: each of the first few generated
# ligands is shown with the pocket at the same index. zip() avoids an IndexError if
# pocket_raw is shorter than generated_mols.
N_COMPLEX_VIEWS = 3
if IS_COMPLEX and pocket_raw is not None:
    for i, (gen_mol, pocket) in enumerate(zip(generated_mols[:N_COMPLEX_VIEWS], pocket_raw)):
        if gen_mol is None:
            continue
        print(f"\nProtein-Ligand Complex {i+1}")
        display(complex_to_3dview(gen_mol, pocket))

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload

Protein-Ligand Complex 1


<py3Dmol.view at 0x14c934276590>


Protein-Ligand Complex 2


<py3Dmol.view at 0x14c934257710>


Protein-Ligand Complex 3


<py3Dmol.view at 0x14c934276590>

In [None]:

import os

# Create output directories (makedirs also creates output_dir itself).
output_dir = "../temp"
molecules_dir = os.path.join(output_dir, "molecules")
proteins_dir = os.path.join(output_dir, "proteins")
for directory in (molecules_dir, proteins_dir):
    os.makedirs(directory, exist_ok=True)

# Keep the original index so that molecule_{i+1}.sdf lines up with pocket_{i+1}.pdb.
valid_mols = [(i, mol) for i, mol in enumerate(generated_mols) if mol is not None]

# Per-molecule SDFs get explicit hydrogens (with coordinates placed by RDKit).
# NOTE(review): the combined file below keeps molecules without added H — confirm this
# difference is intended.
for i, mol in valid_mols:
    mol_path = os.path.join(molecules_dir, f"molecule_{i+1}.sdf")
    # `with` guarantees the writer is flushed and closed even if a write fails.
    with Chem.SDWriter(mol_path) as writer:
        writer.write(Chem.AddHs(mol, addCoords=True))

# Save all molecules to a single SDF file
all_mols_path = os.path.join(output_dir, "all_molecules.sdf")
with Chem.SDWriter(all_mols_path) as writer:
    for _, mol in valid_mols:
        writer.write(mol)
print(f"Saved {len(valid_mols)} molecules to {output_dir}")

# Save protein pockets to PDB
if pocket_raw:
    saved_pockets = 0
    for i, pocket in enumerate(pocket_raw):
        if pocket is None:
            continue
        pocket.write_pdb(os.path.join(proteins_dir, f"pocket_{i+1}.pdb"))
        saved_pockets += 1
    print(f"Saved {saved_pockets} protein pockets to {output_dir}")

print(f"All files saved to {output_dir}")

Saved 12 molecules to ../temp
Saved 14 protein pockets to ../temp
All files saved to ../temp


: 

## Conclusion

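In this notebook, we loaded a trained SEMLAFLOW model, generated molecular conformations, and evaluated them with several metrics. We then visualized the generated and reference 3D structures and, for protein-ligand complexes, the generated binding poses in their pockets. Finally, we saved the molecules (SDF) and pockets (PDB) to disk.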