# SEMLAFLOW Model Evaluation

This notebook demonstrates how to:
1. Load a trained SEMLAFLOW model
2. Generate molecular conformations
3. Evaluate the quality of generated molecules
4. Visualize and analyze the results

## Setup

In [1]:
import sys
import posecheck
import numpy as np
import torch
import matplotlib.pyplot as plt
import lightning as L
import pandas as pd
from pathlib import Path
from IPython.display import display, HTML
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem import rdMolAlign
import py3Dmol

# Turn off rdkit logging
from rdkit import RDLogger
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

# Import SEMLAFLOW modules
sys.path.append('..')
import cgflow.scriptutil as util
from cgflow.buildutil import build_dm, build_model
import cgflow.util.metrics as Metrics
import cgflow.util.complex_metrics as ComplexMetrics

# Set torch properties for consistency with the evaluation script
torch.set_float32_matmul_precision("high")
L.seed_everything(12345)

MDAnalysis.topology.tables has been moved to MDAnalysis.guesser.tables. This import point will be removed in MDAnalysis version 3.0.0
[rank: 0] Seed set to 12345


12345

## Configuration

Define paths and parameters for model evaluation

In [2]:
# Path to the model checkpoint
MODEL_CHECKPOINT_PATH = "/projects/jlab/to.shen/cgflow-dev/wandb/equinv-plinder/yao66v3b/checkpoints/last.ckpt"
# Path to validation data
DATA_PATH = "/projects/jlab/to.shen/cgflow-dev/data/complex/plinder/smol"

# Dataset name
DATASET = "plinder"

# Number of molecules to evaluate
NUM_EVAL_MOLS = np.inf

# Number of inference steps
NUM_INFERENCE_STEPS = 100

# Whether the data involves protein-ligand complexes
IS_COMPLEX = DATASET in ["plinder", "crossdock", "zinc15m"] or False

# Create a class to simulate command line arguments
class Args:
    def __init__(self):
        pass

args = Args()

# Set required arguments
args.model_checkpoint = MODEL_CHECKPOINT_PATH
args.data_path = DATA_PATH
args.dataset = DATASET
args.n_validation_mols = NUM_EVAL_MOLS
args.num_inference_steps = NUM_INFERENCE_STEPS
args.num_gpus = 1
args.is_pseudo_complex = False
args.batch_cost = 1200
args.use_complex_metrics = IS_COMPLEX
args.sampling_strategy = "linear"
args.num_workers = 0

# Model architecture parameters - these should match the trained model
args.d_model = 384
args.n_layers = 12
args.d_message = 64
args.d_edge = 128
args.n_coord_sets = 64
args.n_attn_heads = 32
args.d_message_hidden = 96
args.coord_norm = "length"
args.size_emb = 64
args.max_atoms = 256
args.pocket_n_layers = 4
args.pocket_d_inv = 256
args.fixed_equi = False

# Flow matching parameters
args.categorical_strategy = "auto-regressive"
args.conf_coord_strategy = "gaussian"
args.optimal_transport = None
args.cat_sampling_noise_level = 1
args.coord_noise_std_dev = 0.2
args.type_dist_temp = 1.0
args.time_alpha = 1.0
args.time_beta = 1.0
args.t_per_ar_action = 0.3
args.max_interp_time = 0.4
args.max_action_t = 0.6
args.max_num_cuts = 2
args.dist_loss_weight = 0.0
args.type_loss_weight = 0.0
args.bond_loss_weight = 0.0
args.charge_loss_weight = 0.0
args.monitor = "val-strain"
args.monitor_mode = "min"
args.val_check_epochs = 1


# Autoregressive parameters (only needed if model was trained with AR)
args.t_per_ar_action = 0.3  # updated
args.max_interp_time = 0.4  # updated
args.decomposition_strategy = "reaction"  # updated
args.ordering_strategy = "connected"  # updated
args.max_action_t = 0.6  # updated
args.max_num_cuts = 2  # updated
args.min_group_size = 5

# Model loading defaults
args.arch = "semla"
args.trial_run = False
args.use_ema = True
args.self_condition = True
args.lr = 0.0003
args.type_loss_weight = 0.0  # updated
args.bond_loss_weight = 0.0  # updated
args.charge_loss_weight = 0.0  # updated
args.dist_loss_weight = 0.0  # updated
args.lr_schedule = "constant"
args.warm_up_steps = 10000
args.bucket_cost_scale = "linear"
args.epochs = 1
args.acc_batches = 1
args.val_check_epochs = 1  # updated
args.gradient_clip_val = 1.0
args.monitor = "val-strain"  # updated
args.monitor_mode = "min"  # updated

args.n_training_mols = np.inf

## Load Model

Now let's load the trained model from the checkpoint

In [3]:
def load_model(args):
    print("Building vocabulary...")
    vocab = util.build_vocab()
    
    print("Loading validation datamodule...")
    dm = build_dm(args, vocab)
    
    print("Building model from checkpoint...")
    model = build_model(args, dm, vocab)
    
    print(f"Loading checkpoint from {args.model_checkpoint}...")
    checkpoint = torch.load(args.model_checkpoint, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    
    # Set model to eval mode and move to GPU if available
    model.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    
    # Set inference parameters
    model.integrator.steps = int(args.num_inference_steps)
    model.sampling_strategy = args.sampling_strategy
    
    return model, dm, vocab

# Load the model
try:
    model, dm, vocab = load_model(args)
    print("Model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {str(e)}")

Building vocabulary...
Loading validation datamodule...
Using type ARGeometricComplexInterpolant for training
Building model from checkpoint...

items per bucket [30389, 0, 0, 0, 0, 0]
bucket batch sizes [9, 7, 4, 2, 1, 1]
batches per bucket [3377, 0, 0, 0, 0, 0]
Total training steps 3377
Using model class LigandGenerator
Using CFM class ARMolecularCFM
Loading checkpoint from /projects/jlab/to.shen/cgflow-dev/wandb/equinv-plinder/yao66v3b/checkpoints/last.ckpt...
Model loaded successfully!


In [7]:
model.integrator.coord_noise_std_dev = 0.0

In [10]:
model.integrator.steps

100

## Generate Molecular Conformations

Now we'll use the model to generate molecular conformations from the validation dataset

In [8]:
from tqdm import tqdm

def prepare_batch_for_generation(batch, device='cuda'):
    """Prepare a batch from the dataloader for generation"""
    pocket = None
    pocket_raw = None
    if len(batch) == 4:
        prior, data, interpolated, times = batch
        
    elif len(batch) == 6:
        prior, data, interpolated, pocket, pocket_raw, times = batch
    elif len(batch) == 7:  # AR model
        prior, data, interpolated, masked_data, times, rel_times, gen_times = batch
    elif len(batch) == 9:  # AR model with complex
        prior, data, interpolated, masked_data, pocket, pocket_raw, times, rel_times, gen_times = batch
    else:
        raise ValueError(f"Unexpected batch format with {len(batch)} elements")
    
    prior = {
        k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in prior.items()
    }
    data = {
        k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in data.items()
    }
    pocket = {
        k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in pocket.items()
    } if pocket is not None else None
    
    return prior, data, pocket, pocket_raw

def generate_molecules(model, dataloader, num_samples=5):
    """Generate molecular conformations using the model"""
    generated_mols = []
    ground_truth_mols = []
    pocket_data = []
    pocket_raws = []
    
    count = 0
    with torch.no_grad():
        for batch in tqdm(dataloader):
            prior, data, pocket, pocket_raw = prepare_batch_for_generation(batch)
            # Ensure inputs are on the same device as the model
            device = next(model.parameters()).device
            # Generate molecules
            if args.categorical_strategy == "auto-regressive":
                # AR specific generation
                gen_batch = model._generate(prior, batch[-1].to(device), model.integrator.steps, 
                                                model.sampling_strategy, pocket_batch=pocket)
            else:
                # Standard generation
                gen_batch = model._generate(prior, model.integrator.steps, 
                                                model.sampling_strategy, pocket_batch=pocket)
            
            # Convert generated tensors to RDKit molecules
            gen_mols = model._generate_mols(gen_batch)
            
            # Get ground truth molecules
            data = model._batch_to_onehot(data)
            data_mols = model._generate_mols(data, rescale=True)
            
            # Add molecules to lists
            generated_mols.extend(gen_mols)
            ground_truth_mols.extend(data_mols)
            
            # Store pocket data if available
            if len(batch) == 6 or len(batch) == 9:
                pocket_data.append(pocket)
                pocket_raws.extend(pocket_raw)
            
            count += len(gen_mols)
            if count >= num_samples:
                break
    
    return generated_mols, ground_truth_mols, pocket_data if pocket_data else None, pocket_raws if pocket_raws else None

# Prepare dataloader
dataloader = dm.val_dataloader()

# Generate molecules
generated_mols, ground_truth_mols, pocket_data, pocket_raw = generate_molecules(model, dataloader, num_samples=12)

print(f"Generated {len(generated_mols)} molecules")
print(f"Retrieved {len(ground_truth_mols)} ground truth molecules")


items per bucket [56, 0, 0, 0, 0, 0]
bucket batch sizes [9, 7, 4, 2, 1, 1]
batches per bucket [7, 0, 0, 0, 0, 0]


 14%|█▍        | 1/7 [01:05<06:34, 65.74s/it]

Generated 18 molecules
Retrieved 18 ground truth molecules






Let's calculate quality metrics for the generated molecules

In [12]:
import numpy as np
import pandas as pd

def calculate_metrics(generated_mols, reference_mols=None, pocket_raw=None):
    """Calculate quality metrics for generated molecules"""
    all_metrics = {}

    # Initialize all metrics with empty lists to gather per-sample results
    metric_functions = {
        "validity": Metrics.Validity(),
        "fc_validity": Metrics.Validity(connected=True),
        "uniqueness": Metrics.Uniqueness(),
        "energy_validity": Metrics.EnergyValidity(),
        "average_energy": Metrics.AverageEnergy(),
        "average_energy_per_atom": Metrics.AverageEnergy(per_atom=True),
        "average_strain": Metrics.AverageStrainEnergy(),
        "average_strain_per_atom": Metrics.AverageStrainEnergy(per_atom=True),
        "average_opt_rmsd": Metrics.AverageOptRmsd()
    }

    if reference_mols:
        metric_functions.update({
            "molecular_accuracy": Metrics.MolecularAccuracy(),
            "pair_rmsd": Metrics.MolecularPairRMSD(),
            "pair_no_align_rmsd": Metrics.MolecularPairRMSD(align=False)
        })

    if pocket_raw and IS_COMPLEX:
        metric_functions.update({
            "clash_score": ComplexMetrics.Clash(),
            "interactions": ComplexMetrics.Interactions()
        })

    # Collect individual metric values
    for key, metric in metric_functions.items():
        results = []
        for idx in range(len(generated_mols)):
            mol = generated_mols[idx:idx+1]
            ref = reference_mols[idx:idx+1] if reference_mols else None
            pocket = pocket_raw[idx:idx+1] if pocket_raw else None

            if key.startswith("interactions"):
                interaction_values = metric(mol, pocket or pocket_raw)
                for int_key, val in interaction_values.items():
                    all_metrics.setdefault(f"interactions_{int_key}", []).append(val)
            else:
                if "pair" in key or "accuracy" in key:
                    val = metric(mol, ref)
                elif "clash" in key or "interactions" in key:
                    val = metric(mol, pocket or pocket_raw)
                else:
                    val = metric(mol)
                all_metrics.setdefault(key, []).append(val)

    # Compute mean, median, std for each metric
    summary = {"Metric": [], "Mean": [], "Median": [], "Std": []}
    for key, values in all_metrics.items():
        values = np.array(values, dtype=np.float32)
        # remove nan values
        values = values[~np.isnan(values)]
        
        summary["Metric"].append(key)
        summary["Mean"].append(values.mean())
        summary["Median"].append(np.median(values))
        summary["Std"].append(values.std())

    return pd.DataFrame(summary), all_metrics


# Calculate metrics
n = 108
metrics_df, all_metrics = calculate_metrics(generated_mols[:n], ground_truth_mols[:n])  # , pocket_raw[:n])

# Display table
display(metrics_df)

Unnamed: 0,Metric,Mean,Median,Std
0,validity,0.888889,1.0,0.31427
1,fc_validity,0.888889,1.0,0.31427
2,uniqueness,1.0,1.0,0.0
3,energy_validity,0.888889,1.0,0.31427
4,average_energy,17602.753906,660.838135,41239.601562
5,average_energy_per_atom,386.339783,38.085571,839.565735
6,average_strain,17598.210938,668.529297,41254.933594
7,average_strain_per_atom,386.176666,38.813705,839.913513
8,average_opt_rmsd,1.742558,1.559513,0.994739
9,molecular_accuracy,1.0,1.0,0.0


In [23]:
len([i.item() for i in all_metrics['pair_no_align_rmsd'] if i.item() < 2]) / len(all_metrics['pair_no_align_rmsd'])

0.16666666666666666

## Visualize Molecules

Let's visualize some of the generated molecules alongside their ground truth counterparts

In [11]:
def visualize_molecule_2d(mol, title="Molecule"):
    """Visualize an RDKit molecule in 2D"""
    if mol is None:
        return HTML(f"<p>{title}: Invalid molecule</p>")
    
    mol = Chem.Mol(mol)
    AllChem.Compute2DCoords(mol)
    img = Draw.MolToImage(mol, size=(300, 300))
    
    plt.figure(figsize=(3, 3))
    plt.imshow(img)
    plt.title(title)
    plt.axis('off')
    plt.tight_layout()
    plt.show()

def visualize_molecule_3d(mol, width=400, height=400, style="stick"):
    """Visualize an RDKit molecule in 3D using py3Dmol"""
    if mol is None:
        return HTML("<p>Invalid molecule</p>")
    
    mol = Chem.Mol(mol)
    
    viewer = py3Dmol.view(width=width, height=height)
    mb = Chem.MolToMolBlock(mol)
    viewer.addModel(mb, 'mol')
    viewer.setStyle({style:{}})
    viewer.zoomTo()
    viewer.render()
    return viewer

def compare_molecules_3d(gen_mol, ref_mol, align=True, width=400, height=400):
    """Compare generated and reference molecules in 3D"""
    if gen_mol is None or ref_mol is None:
        return HTML("<p>One or more invalid molecules</p>")
    
    gen_mol = Chem.Mol(gen_mol)
    ref_mol = Chem.Mol(ref_mol)
    
    # Align molecules if requested
    if align:
        rdMolAlign.AlignMol(gen_mol, ref_mol)
    
    viewer = py3Dmol.view(width=width, height=height)
    
    # Add reference molecule (green)
    mb_ref = Chem.MolToMolBlock(ref_mol)
    viewer.addModel(mb_ref, 'ref')
    viewer.setStyle({'ref': {'stick': {'color': 'green'}}})
    
    # Add generated molecule (blue)
    mb_gen = Chem.MolToMolBlock(gen_mol)
    viewer.addModel(mb_gen, 'gen')
    viewer.setStyle({'gen': {'stick': {'color': 'blue'}}})
    
    viewer.zoomTo()
    viewer.render()
    return viewer

# Visualize a few molecules
for i in range(len(generated_mols)):
    if generated_mols[i] is not None and ground_truth_mols[i] is not None:
        print(f"\nMolecule {i+1}")
        print("Generated SMILES:", Chem.MolToSmiles(generated_mols[i]) if generated_mols[i] else "Invalid")
        print("Reference SMILES:", Chem.MolToSmiles(ground_truth_mols[i]) if ground_truth_mols[i] else "Invalid")

        # 3D visualization of generated molecule
        print("\n3D Structure (Generated):")
        gen_view = visualize_molecule_3d(generated_mols[i], width=400, height=400)
        display(gen_view)
        
        # 3D visualization of reference molecule
        print("\n3D Structure (Reference):")
        ref_view = visualize_molecule_3d(ground_truth_mols[i], width=400, height=400)
        display(ref_view)
        # Calculate RMSD between the molecules with and without alignment
        pair_no_align_rmsd_metric = Metrics.MolecularPairRMSD(align=False)
        pair_no_align_rmsd = pair_no_align_rmsd_metric([generated_mols[i]], [ground_truth_mols[i]])
        print(f"Pair RMSD (no alignment): {pair_no_align_rmsd:.3f} Å")


Molecule 1
Generated SMILES: O=P(O)(O)OCC1OC(OP(=O)(O)O)C(O)C1O
Reference SMILES: O=P(O)(O)OCC1OC(OP(=O)(O)O)C(O)C1O

3D Structure (Generated):


<py3Dmol.view at 0x14930383e850>


3D Structure (Reference):


<py3Dmol.view at 0x149303013690>

Pair RMSD (no alignment): 1.221 Å

Molecule 2
Generated SMILES: CC(=O)NC1C(O)OC(CO)C(O)C1O
Reference SMILES: CC(=O)NC1C(O)OC(CO)C(O)C1O

3D Structure (Generated):


<py3Dmol.view at 0x1492c49af450>


3D Structure (Reference):


<py3Dmol.view at 0x14930383e850>

Pair RMSD (no alignment): 5.369 Å

Molecule 3
Generated SMILES: Nc1ncnc2c1ncn2C1OC(COP(=O)(O)OC(=O)CCCCC2CCSS2)C(O)C1O
Reference SMILES: Nc1ncnc2c1ncn2C1OC(COP(=O)(O)OC(=O)CCCCC2CCSS2)C(O)C1O

3D Structure (Generated):


<py3Dmol.view at 0x1492c4729e50>


3D Structure (Reference):


<py3Dmol.view at 0x1492c49af450>

Pair RMSD (no alignment): 7.508 Å

Molecule 5
Generated SMILES: CCC(=O)OC(CCOCc1ccccc1)C(C)C(O)C(C)CCC(=O)C(C)C(OC(C)=O)C(C)CCN(C)C=O
Reference SMILES: CCC(=O)OC(CCOCc1ccccc1)C(C)C(O)C(C)CCC(=O)C(C)C(OC(C)=O)C(C)CCN(C)C=O

3D Structure (Generated):


<py3Dmol.view at 0x1493031be310>


3D Structure (Reference):


<py3Dmol.view at 0x1492c4729e50>

Pair RMSD (no alignment): 12.255 Å

Molecule 6
Generated SMILES: CC12CCC(O)CC1CCC1C2CCC2(C)C(c3ccc(=O)oc3)CCC12O
Reference SMILES: CC12CCC(O)CC1CCC1C2CCC2(C)C(c3ccc(=O)oc3)CCC12O

3D Structure (Generated):


<py3Dmol.view at 0x1493034d8a10>


3D Structure (Reference):


<py3Dmol.view at 0x1493031be310>

Pair RMSD (no alignment): 8.107 Å

Molecule 7
Generated SMILES: NCC1=C2C(=O)N=C(N)N=C2N=C1
Reference SMILES: NCC1=C2C(=O)N=C(N)N=C2N=C1

3D Structure (Generated):


<py3Dmol.view at 0x1493031bd190>


3D Structure (Reference):


<py3Dmol.view at 0x149305422650>

Pair RMSD (no alignment): 1.151 Å

Molecule 8
Generated SMILES: CC(=O)NC1C(O)OC(CO)C(O)C1O
Reference SMILES: CC(=O)NC1C(O)OC(CO)C(O)C1O

3D Structure (Generated):


<py3Dmol.view at 0x149417b108d0>


3D Structure (Reference):


<py3Dmol.view at 0x1493031bd190>

Pair RMSD (no alignment): 4.098 Å

Molecule 9
Generated SMILES: CNCC1CCCCC1CN1CNCN1
Reference SMILES: CNCC1CCCCC1CN1CNCN1

3D Structure (Generated):


<py3Dmol.view at 0x149302fe9d10>


3D Structure (Reference):


<py3Dmol.view at 0x149417b108d0>

Pair RMSD (no alignment): 3.855 Å

Molecule 10
Generated SMILES: CC=C(C)C(=O)OC1C(C)=C2C(C1OC(=O)CCCCCCC)C(C)(OC(C)=O)CC(OC(=O)CCC)C1(O)C2OC(=O)C1(C)O
Reference SMILES: CC=C(C)C(=O)OC1C(C)=C2C(C1OC(=O)CCCCCCC)C(C)(OC(C)=O)CC(OC(=O)CCC)C1(O)C2OC(=O)C1(C)O

3D Structure (Generated):


<py3Dmol.view at 0x1493031bd810>


3D Structure (Reference):


<py3Dmol.view at 0x1492c4728a50>

Pair RMSD (no alignment): 12.104 Å

Molecule 11
Generated SMILES: O=C(O)CCC(NC(=O)COc1cccc(C(=O)O)c1)C(=O)O
Reference SMILES: O=C(O)CCC(NC(=O)COc1cccc(C(=O)O)c1)C(=O)O

3D Structure (Generated):


<py3Dmol.view at 0x149417fb47d0>


3D Structure (Reference):


<py3Dmol.view at 0x1493049806d0>

Pair RMSD (no alignment): 2.659 Å

Molecule 12
Generated SMILES: NC(=[NH2+])NCCCC(N)C(=O)O
Reference SMILES: NC(=[NH2+])NCCCC(N)C(=O)O

3D Structure (Generated):


<py3Dmol.view at 0x1492c46d9310>


3D Structure (Reference):


<py3Dmol.view at 0x1492c47a3450>

Pair RMSD (no alignment): 1.337 Å

Molecule 13
Generated SMILES: CC1OC(OC2CC(O)C3(CO)C4C(O)CC5(C)C(C6=CC(=O)OC6)CCC5(O)C4CCC3(O)C2)C(O)C(O)C1O
Reference SMILES: CC1OC(OC2CC(O)C3(CO)C4C(O)CC5(C)C(C6=CC(=O)OC6)CCC5(O)C4CCC3(O)C2)C(O)C(O)C1O

3D Structure (Generated):


<py3Dmol.view at 0x149303013690>


3D Structure (Reference):


<py3Dmol.view at 0x149302fdad90>

Pair RMSD (no alignment): 11.116 Å

Molecule 14
Generated SMILES: CC(=O)NC1C2OCC(O2)C(O)C1OC(C)C(=O)O
Reference SMILES: CC(=O)NC1C2OCC(O2)C(O)C1OC(C)C(=O)O

3D Structure (Generated):


<py3Dmol.view at 0x1492c4729f10>


3D Structure (Reference):


<py3Dmol.view at 0x149302fdbb50>

Pair RMSD (no alignment): 4.943 Å

Molecule 15
Generated SMILES: NCCCC(N)CC(=O)NCCCC(N)CC(=O)NCCCC(N)CC(=O)NC1C(NC2=NC3C(=O)NCC(O)C3N2)OC(CO)C(OC(N)=O)C1O
Reference SMILES: NCCCC(N)CC(=O)NCCCC(N)CC(=O)NCCCC(N)CC(=O)NC1C(NC2=NC3C(=O)NCC(O)C3N2)OC(CO)C(OC(N)=O)C1O

3D Structure (Generated):


<py3Dmol.view at 0x149302fea510>


3D Structure (Reference):


<py3Dmol.view at 0x149304981750>

Pair RMSD (no alignment): 14.959 Å

Molecule 16
Generated SMILES: CCCCNC1CC(O)(CO)C(O)C(O)C1O
Reference SMILES: CCCCNC1CC(O)(CO)C(O)C(O)C1O

3D Structure (Generated):


<py3Dmol.view at 0x149302fdad10>


3D Structure (Reference):


<py3Dmol.view at 0x1492c47afa90>

Pair RMSD (no alignment): 4.236 Å

Molecule 17
Generated SMILES: Nc1ncnc2c1ncn2C1OC(COS(=O)(=O)NC(=O)C(N)CCC(=O)O)C(O)C1O
Reference SMILES: Nc1ncnc2c1ncn2C1OC(COS(=O)(=O)NC(=O)C(N)CCC(=O)O)C(O)C1O

3D Structure (Generated):


<py3Dmol.view at 0x1493031bee50>


3D Structure (Reference):


<py3Dmol.view at 0x1492c4870b50>

Pair RMSD (no alignment): 3.529 Å


## Complex Visualization (for Protein-Ligand Complexes)

If we're working with protein-ligand complexes, let's visualize the binding poses

In [24]:
%load_ext autoreload
%autoreload 2

from cgflow.util.visualize import complex_to_3dview

# Visualize protein-ligand complexes if applicable
if IS_COMPLEX and pocket_raw is not None:
    for i in range(min(3, len(generated_mols))):
        if generated_mols[i] is not None:
            print(f"\nProtein-Ligand Complex {i+1}")
            view = complex_to_3dview(generated_mols[i], pocket_raw[i])
            display(view)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload

Protein-Ligand Complex 1


<py3Dmol.view at 0x14e660cf91d0>

In [13]:

import os

# Create output directories
output_dir = "../temp"
os.makedirs(output_dir, exist_ok=True)
os.makedirs(os.path.join(output_dir, "molecules"), exist_ok=True)
os.makedirs(os.path.join(output_dir, "proteins"), exist_ok=True)

# Save generated molecules to SDF
from rdkit import Chem
for i, mol in enumerate(generated_mols):
    if mol is not None:
        # Save individual molecule
        mol = Chem.AddHs(mol, addCoords=True)
        mol_path = os.path.join(output_dir, "molecules", f"molecule_{i+1}.sdf")
        writer = Chem.SDWriter(mol_path)
        writer.write(mol)
        writer.close()

# Save all molecules to a single SDF file
all_mols_path = os.path.join(output_dir, "all_molecules.sdf")
writer = Chem.SDWriter(all_mols_path)
for mol in generated_mols:
    if mol is not None:
        writer.write(mol)
writer.close()
print(f"Saved {sum(1 for mol in generated_mols if mol is not None)} molecules to {output_dir}")

# Save protein pockets to PDB
if pocket_raw and len(pocket_raw) > 0:
    for i, pocket in enumerate(pocket_raw):
        if pocket is not None:
            pocket_path = os.path.join(output_dir, "proteins", f"pocket_{i+1}.pdb")
            pocket.write_pdb(pocket_path)
    print(f"Saved {sum(1 for p in pocket_raw if p is not None)} protein pockets to {output_dir}")

print(f"All files saved to {output_dir}")

Saved 8 molecules to ../temp
Saved 9 protein pockets to ../temp
All files saved to ../temp


## Conclusion

In this notebook, we loaded a trained SEMLAFLOW model, generated molecular conformations, evaluated them with several metrics, and visualized both 2D and 3D structures including complex (e.g. protein-ligand) representations.