# SEMLAFLOW Model Evaluation

This notebook demonstrates how to:
1. Load a trained SEMLAFLOW model
2. Generate molecular conformations
3. Evaluate the quality of generated molecules
4. Visualize and analyze the results

## Setup

In [1]:
import sys
import posecheck
import numpy as np
import torch
import matplotlib.pyplot as plt
import lightning as L
import pandas as pd
from pathlib import Path
from IPython.display import display, HTML
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem import rdMolAlign
import py3Dmol

# Turn off rdkit logging
from rdkit import RDLogger
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

# Import SEMLAFLOW modules
sys.path.append('../../..')
import src.cgflow.scriptutil as util
from src.cgflow.buildutil import build_dm, build_model
import src.cgflow.util.metrics as Metrics
import src.cgflow.util.complex_metrics as ComplexMetrics

# Set torch properties for consistency with the evaluation script
torch.set_float32_matmul_precision("high")
L.seed_everything(12345)

MDAnalysis.topology.tables has been moved to MDAnalysis.guesser.tables. This import point will be removed in MDAnalysis version 3.0.0
[rank: 0] Seed set to 12345


12345

## Configuration

Define paths and parameters for model evaluation

In [2]:
# Path to the model checkpoint.
# NOTE: machine-specific absolute path; point this at your own checkpoint before running.
MODEL_CHECKPOINT_PATH = "/home/to.shen/projects/CGFlow/wandb/equinv-plinder/ig1k2kp4/checkpoints/last.ckpt"

# Path to validation data (machine-specific as well).
DATA_PATH = "/home/to.shen/projects/CGFlow/data/complex/topo2/smol"

# Dataset name
DATASET = "plinder"

# Number of molecules to evaluate (np.inf = no limit)
NUM_EVAL_MOLS = np.inf

# Number of inference steps
NUM_INFERENCE_STEPS = 100

# Whether the data involves protein-ligand complexes
IS_COMPLEX = DATASET in ("plinder", "crossdock", "zinc15m")


class Args:
    """Plain attribute container standing in for the parsed command line arguments."""


args = Args()

# --- Run settings -----------------------------------------------------------
args.model_checkpoint = MODEL_CHECKPOINT_PATH
args.data_path = DATA_PATH
args.dataset = DATASET
args.n_validation_mols = NUM_EVAL_MOLS
args.num_inference_steps = NUM_INFERENCE_STEPS
args.num_gpus = 1
args.is_pseudo_complex = False
args.batch_cost = 1200
args.use_complex_metrics = IS_COMPLEX
args.sampling_strategy = "linear"
args.num_workers = 0

# --- Model architecture parameters - these should match the trained model ----
args.d_model = 384
args.n_layers = 12
args.d_message = 64
args.d_edge = 128
args.n_coord_sets = 64
args.n_attn_heads = 32
args.d_message_hidden = 96
args.coord_norm = "length"
args.size_emb = 64
args.max_atoms = 256
args.pocket_n_layers = 4
args.pocket_d_inv = 256
args.fixed_equi = False

# --- Flow matching parameters -------------------------------------------------
args.categorical_strategy = "auto-regressive"
args.conf_coord_strategy = "gaussian"
args.optimal_transport = None
args.cat_sampling_noise_level = 1
args.coord_noise_std_dev = 0.2
args.type_dist_temp = 1.0
args.time_alpha = 1.0
args.time_beta = 1.0

# --- Autoregressive parameters (only needed if model was trained with AR) ----
args.t_per_ar_action = 0.3
args.max_interp_time = 0.4
args.max_action_t = 0.6
args.max_num_cuts = 2
args.decomposition_strategy = "reaction"
args.ordering_strategy = "connected"
args.min_group_size = 5

# --- Loss weights -------------------------------------------------------------
args.dist_loss_weight = 0.0
args.type_loss_weight = 0.0
args.bond_loss_weight = 0.0
args.charge_loss_weight = 0.0

# --- Model loading / training defaults ---------------------------------------
# Each attribute is assigned exactly once so the values cannot drift apart.
args.arch = "semla"
args.trial_run = False
args.use_ema = True
args.self_condition = True
args.lr = 0.0003
args.lr_schedule = "constant"
args.warm_up_steps = 10000
args.bucket_cost_scale = "linear"
args.epochs = 1
args.acc_batches = 1
args.val_check_epochs = 1
args.gradient_clip_val = 1.0
args.monitor = "val-strain"
args.monitor_mode = "min"
args.n_training_mols = np.inf

## Load Model

Now let's load the trained model from the checkpoint

In [3]:
def load_model(args):
    """Build the vocabulary, datamodule and model, then restore checkpoint weights.

    Args:
        args: Namespace-like object with the attributes consumed by
            ``build_dm`` / ``build_model`` plus ``model_checkpoint``,
            ``num_inference_steps`` and ``sampling_strategy``.

    Returns:
        Tuple ``(model, dm, vocab)``. The model is in eval mode, placed on CUDA
        when available (otherwise CPU), with its integrator step count and
        sampling strategy taken from ``args``.
    """
    print("Building vocabulary...")
    vocab = util.build_vocab()

    print("Loading validation datamodule...")
    dm = build_dm(args, vocab)

    print("Building model from checkpoint...")
    model = build_model(args, dm, vocab)

    print(f"Loading checkpoint from {args.model_checkpoint}...")
    # Deserialize onto the CPU first; the model is moved to its target device below.
    weights = torch.load(args.model_checkpoint, map_location='cpu')['state_dict']
    model.load_state_dict(weights)

    # Inference only: switch to eval mode, then pick GPU when one is present.
    model.eval()
    if torch.cuda.is_available():
        target_device = torch.device('cuda')
    else:
        target_device = torch.device('cpu')
    model = model.to(target_device)

    # Inference-time settings come from the notebook configuration.
    model.integrator.steps = int(args.num_inference_steps)
    model.sampling_strategy = args.sampling_strategy

    return model, dm, vocab

# Load the model.
# `dm` (datamodule) and `vocab` are returned alongside the model; `dm` is used
# later in the notebook to obtain the validation dataloader.
model, dm, vocab = load_model(args)
print("Model loaded successfully!")


Building vocabulary...
Loading validation datamodule...
Using type ARGeometricComplexInterpolant for training
Building model from checkpoint...

items per bucket [94233, 0, 0, 0, 0, 0]
bucket batch sizes [9, 7, 4, 2, 1, 1]
batches per bucket [10471, 0, 0, 0, 0, 0]
Total training steps 10471
Using model class LigandGenerator
Using CFM class ARMolecularCFM
Loading checkpoint from /home/to.shen/projects/CGFlow/wandb/equinv-plinder/ig1k2kp4/checkpoints/last.ckpt...
Model loaded successfully!


## Generate Molecular Conformations

Now we'll use the model to generate molecular conformations from the validation dataset

In [4]:
from tqdm import tqdm

def prepare_batch_for_generation(batch, device='cuda'):
    """Extract the generation inputs from a dataloader batch and move them to ``device``.

    Supported batch layouts (by number of elements):
        4: prior, data, interpolated, times
        6: prior, data, interpolated, pocket, pocket_raw, times
        7: prior, data, interpolated, masked_data, times, rel_times, gen_times (AR model)
        9: prior, data, interpolated, masked_data, pocket, pocket_raw, times,
           rel_times, gen_times (AR model with complex)

    Args:
        batch: Sequence in one of the layouts above.
        device: Target device for every tensor value in the returned dicts.

    Returns:
        Tuple ``(prior, data, pocket, pocket_raw)``. ``pocket`` and
        ``pocket_raw`` are ``None`` for layouts without a pocket.

    Raises:
        ValueError: If the batch length matches none of the known layouts.
    """
    # Index of (pocket, pocket_raw) for each known layout; None = no pocket present.
    pocket_positions = {4: None, 6: (3, 4), 7: None, 9: (4, 5)}

    if len(batch) not in pocket_positions:
        raise ValueError(f"Unexpected batch format with {len(batch)} elements")

    fields = list(batch)
    prior, data = fields[0], fields[1]

    where = pocket_positions[len(fields)]
    if where is None:
        pocket, pocket_raw = None, None
    else:
        pocket, pocket_raw = fields[where[0]], fields[where[1]]

    def _on_device(tensor_dict):
        # Only tensors are moved; any other value is passed through untouched.
        moved = {}
        for key, value in tensor_dict.items():
            moved[key] = value.to(device) if isinstance(value, torch.Tensor) else value
        return moved

    prior = _on_device(prior)
    data = _on_device(data)
    if pocket is not None:
        pocket = _on_device(pocket)

    return prior, data, pocket, pocket_raw

def generate_molecules(model, dataloader, num_samples=5):
    """Generate molecular conformations using the model.

    Iterates over ``dataloader`` batch by batch and stops after the first batch
    that brings the number of generated molecules to ``num_samples`` or more.
    Whole batches are kept, so the returned lists can hold more than
    ``num_samples`` entries.

    Reads the notebook-level ``args.categorical_strategy`` to decide whether
    the auto-regressive generation call (which also receives the last batch
    element) is used.

    Args:
        model: Trained model exposing ``_generate``, ``_generate_mols``,
            ``_batch_to_onehot``, ``integrator.steps`` and ``sampling_strategy``.
        dataloader: Iterable of batches accepted by ``prepare_batch_for_generation``.
        num_samples: Minimum number of molecules to generate before stopping.

    Returns:
        Tuple ``(generated_mols, ground_truth_mols, pocket_data, pocket_raws)``.
        The last two are ``None`` when no batch carried pocket information.
    """
    generated_mols = []
    ground_truth_mols = []
    pocket_data = []
    pocket_raws = []

    # Resolve the model's device once and move every batch to that same device.
    # Previously the batch always went to the default 'cuda' device, which
    # fails on CPU-only machines even though the model itself falls back to CPU.
    device = next(model.parameters()).device
    is_autoregressive = args.categorical_strategy == "auto-regressive"

    count = 0
    with torch.no_grad():
        for batch in tqdm(dataloader):
            prior, data, pocket, pocket_raw = prepare_batch_for_generation(batch, device=device)

            # Generate molecules
            if is_autoregressive:
                # AR specific generation: the last batch element is passed along
                gen_batch = model._generate(prior, batch[-1].to(device), model.integrator.steps,
                                            model.sampling_strategy, pocket_batch=pocket)
            else:
                # Standard generation
                gen_batch = model._generate(prior, model.integrator.steps,
                                            model.sampling_strategy, pocket_batch=pocket)

            # Convert generated tensors to RDKit molecules
            gen_mols = model._generate_mols(gen_batch)

            # Get ground truth molecules
            data = model._batch_to_onehot(data)
            data_mols = model._generate_mols(data, rescale=True)

            # Add molecules to lists
            generated_mols.extend(gen_mols)
            ground_truth_mols.extend(data_mols)

            # Store pocket data if available (only the 6- and 9-element layouts carry it)
            if len(batch) == 6 or len(batch) == 9:
                pocket_data.append(pocket)
                pocket_raws.extend(pocket_raw)

            count += len(gen_mols)
            if count >= num_samples:
                break

    return generated_mols, ground_truth_mols, pocket_data if pocket_data else None, pocket_raws if pocket_raws else None

# Prepare dataloader (validation split of the datamodule returned by load_model)
dataloader = dm.val_dataloader()

# Generate molecules.
# Generation stops after the first batch that reaches `num_samples`, so the
# lists may contain more than 8 entries. `pocket_data` / `pocket_raw` are None
# when the batches carry no pocket information.
generated_mols, ground_truth_mols, pocket_data, pocket_raw = generate_molecules(model, dataloader, num_samples=8)

print(f"Generated {len(generated_mols)} molecules")
print(f"Retrieved {len(ground_truth_mols)} ground truth molecules")


items per bucket [4960, 0, 0, 0, 0, 0]
bucket batch sizes [9, 7, 4, 2, 1, 1]
batches per bucket [552, 0, 0, 0, 0, 0]


  0%|          | 0/552 [00:00<?, ?it/s]

coords torch.Size([9, 32, 3]) torch.Size([9, 32, 3])
atomics torch.Size([9, 32]) torch.Size([9, 32, 18])
bonds torch.Size([9, 32, 32]) torch.Size([9, 32, 32, 5])
charges torch.Size([9, 32]) torch.Size([9, 32, 7])
residues torch.Size([9, 32]) torch.Size([9, 32])
mask torch.Size([9, 32]) torch.Size([9, 32])


  0%|          | 0/552 [00:00<?, ?it/s]


AttributeError: 'LigandDecoder' object has no attribute 'pocket_invs'


## Evaluate Generated Molecules

Let's calculate quality metrics for the generated molecules.

In [None]:
%load_ext autoreload
%autoreload 2

def calculate_metrics(generated_mols, reference_mols=None, pocket_raw=None):
    """Calculate quality metrics for generated molecules.

    Args:
        generated_mols: Generated molecules to score.
        reference_mols: Optional ground-truth molecules, paired by position with
            ``generated_mols``. When given (and non-empty), the pairwise
            accuracy and RMSD metrics are added.
        pocket_raw: Optional pocket objects, paired by position with
            ``generated_mols``. When given (and non-empty) and the
            notebook-level ``IS_COMPLEX`` flag is set, the clash and
            interaction metrics are added.

    Returns:
        Dict mapping metric name to a Python float, in computation order.
    """
    metrics = {}

    # (result key, metric class, constructor kwargs); evaluated on generated_mols only.
    single_set_metrics = [
        ("validity", Metrics.Validity, {}),
        ("fc_validity", Metrics.Validity, {"connected": True}),
        ("uniqueness", Metrics.Uniqueness, {}),
        ("energy_validity", Metrics.EnergyValidity, {}),
        ("average_energy", Metrics.AverageEnergy, {}),
        ("average_energy_per_atom", Metrics.AverageEnergy, {"per_atom": True}),
        ("average_strain", Metrics.AverageStrainEnergy, {}),
        ("average_strain_per_atom", Metrics.AverageStrainEnergy, {"per_atom": True}),
        ("average_opt_rmsd", Metrics.AverageOptRmsd, {}),
    ]
    for key, metric_cls, kwargs in single_set_metrics:
        # A fresh metric instance is created per entry, as in the per-metric stanzas.
        metrics[key] = float(metric_cls(**kwargs)(generated_mols))

    # If we have reference molecules, calculate the pairwise metrics
    if reference_mols:
        paired_metrics = [
            ("molecular_accuracy", Metrics.MolecularAccuracy, {}),
            ("pair_rmsd", Metrics.MolecularPairRMSD, {}),
            ("pair_no_align_rmsd", Metrics.MolecularPairRMSD, {"align": False}),
        ]
        for key, metric_cls, kwargs in paired_metrics:
            metrics[key] = float(metric_cls(**kwargs)(generated_mols, reference_mols))

    # If we have pocket data, calculate complex metrics.
    # The guard uses the `pocket_raw` argument; it previously tested the
    # unrelated global `pocket_data`, ignoring what the caller passed in.
    if pocket_raw and IS_COMPLEX:
        # Clash score
        clash = ComplexMetrics.Clash()(generated_mols, pocket_raw)
        metrics["clash_score"] = float(clash)

        # Interactions (one entry per key returned by the metric)
        interactions = ComplexMetrics.Interactions()(generated_mols, pocket_raw)
        for key, value in interactions.items():
            metrics[f"interactions_{key}"] = float(value)

    return metrics

# Calculate metrics on the first `n` molecules
n = 1
metrics = calculate_metrics(
    generated_mols[:n],
    ground_truth_mols[:n],
    # generate_molecules returns None for pockets when no batch carried them;
    # slicing None would raise a TypeError.
    pocket_raw[:n] if pocket_raw else None,
)

# Display metrics as a table
metrics_df = pd.DataFrame({"Metric": list(metrics.keys()), "Value": list(metrics.values())})
display(metrics_df)



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]

Unnamed: 0,Metric,Value
0,validity,1.0
1,fc_validity,1.0
2,uniqueness,1.0
3,energy_validity,1.0
4,average_energy,269.753357
5,average_energy_per_atom,10.652372
6,average_strain,262.232788
7,average_strain_per_atom,10.233875
8,average_opt_rmsd,1.051045
9,molecular_accuracy,1.0


In [5]:
import tempfile
from cgflow.utils import gnina
from openbabel import pybel

# Score every generated ligand against its pocket with a gnina local optimisation.
with tempfile.TemporaryDirectory() as tmp_dir:
    # Scratch files are reused (overwritten) for every ligand/pocket pair.
    ligand_sdf_path = tmp_dir + '/molecule.sdf'
    pocket_pdb_path = tmp_dir + '/pocket.pdb'
    pocket_pdbqt_path = tmp_dir + '/pocket.pdbqt'
    result_sdf_path = tmp_dir + '/result.sdf'

    # `pocket_raw` is None when the data carried no pockets; iterate over nothing then.
    for mol, pocket in zip(generated_mols, pocket_raw or []):
        # Failed generations / missing pockets cannot be docked.
        if mol is None or pocket is None:
            continue

        # write mol to sdf (close the writer even if writing fails)
        writer = Chem.SDWriter(ligand_sdf_path)
        try:
            writer.write(mol)
        finally:
            writer.close()

        # Convert the pocket PDB -> PDBQT through Open Babel. A separate name is
        # used so the loop variable `pocket` is not rebound to a pybel molecule.
        pocket.write_pdb(pocket_pdb_path)
        pocket_pybel = next(pybel.readfile("pdb", pocket_pdb_path))
        pocket_pybel.write("pdbqt", pocket_pdbqt_path, overwrite=True)

        # run docking
        scores = gnina.local_opt(ligand_sdf_path, pocket_pdbqt_path, result_sdf_path, num_workers=1)
        print(scores)
        

NameError: name 'generated_mols' is not defined

## Visualize Molecules

Let's visualize some of the generated molecules alongside their ground truth counterparts

In [8]:
def visualize_molecule_2d(mol, title="Molecule"):
    """Draw a 2D depiction of an RDKit molecule with matplotlib.

    Args:
        mol: RDKit molecule, or ``None`` for an invalid molecule.
        title: Title shown above the drawing (and in the fallback message).

    Returns:
        An ``HTML`` notice when ``mol`` is ``None``; otherwise ``None`` (the
        figure is shown as a side effect).
    """
    if mol is None:
        return HTML(f"<p>{title}: Invalid molecule</p>")

    # Work on a copy so the caller's conformer is not replaced by 2D coordinates.
    flat_mol = Chem.Mol(mol)
    AllChem.Compute2DCoords(flat_mol)
    picture = Draw.MolToImage(flat_mol, size=(300, 300))

    plt.figure(figsize=(3, 3))
    plt.imshow(picture)
    plt.title(title)
    plt.axis('off')
    plt.tight_layout()
    plt.show()

def visualize_molecule_3d(mol, width=400, height=400, style="stick"):
    """Build a py3Dmol view of a single RDKit molecule.

    Args:
        mol: RDKit molecule, or ``None`` for an invalid molecule.
        width: Viewer width passed to ``py3Dmol.view``.
        height: Viewer height passed to ``py3Dmol.view``.
        style: Name of the style applied to the model (with default options).

    Returns:
        The ``py3Dmol`` view, or an ``HTML`` notice when ``mol`` is ``None``.
    """
    if mol is None:
        return HTML("<p>Invalid molecule</p>")

    # Serialize a copy of the molecule as a MOL block for the viewer.
    mol_copy = Chem.Mol(mol)

    view = py3Dmol.view(width=width, height=height)
    view.addModel(Chem.MolToMolBlock(mol_copy), 'mol')
    style_spec = {style: {}}
    view.setStyle(style_spec)
    view.zoomTo()
    view.render()
    return view

def compare_molecules_3d(gen_mol, ref_mol, align=True, width=400, height=400):
    """Overlay a generated and a reference molecule in one py3Dmol view.

    The reference is drawn as green sticks and the generated molecule as blue
    sticks.

    Args:
        gen_mol: Generated RDKit molecule (``None`` if invalid).
        ref_mol: Reference RDKit molecule (``None`` if invalid).
        align: When true, align a copy of ``gen_mol`` onto ``ref_mol`` with
            ``rdMolAlign.AlignMol`` before drawing.
        width: Viewer width passed to ``py3Dmol.view``.
        height: Viewer height passed to ``py3Dmol.view``.

    Returns:
        The ``py3Dmol`` view, or an ``HTML`` notice when either molecule is ``None``.
    """
    if gen_mol is None or ref_mol is None:
        return HTML("<p>One or more invalid molecules</p>")

    # Copies, so alignment does not move the caller's conformers.
    gen_mol = Chem.Mol(gen_mol)
    ref_mol = Chem.Mol(ref_mol)

    # Align molecules if requested
    # NOTE(review): AlignMol presumably fails when no atom mapping between the
    # two molecules can be found - confirm and handle if inputs may differ.
    if align:
        rdMolAlign.AlignMol(gen_mol, ref_mol)

    viewer = py3Dmol.view(width=width, height=height)

    # The second argument of addModel is the data *format*, so it must be 'mol'
    # for a MOL block. Each model is styled right after it is added through the
    # selection {'model': -1}, which targets the most recently added model.

    # Add reference molecule (green)
    viewer.addModel(Chem.MolToMolBlock(ref_mol), 'mol')
    viewer.setStyle({'model': -1}, {'stick': {'color': 'green'}})

    # Add generated molecule (blue)
    viewer.addModel(Chem.MolToMolBlock(gen_mol), 'mol')
    viewer.setStyle({'model': -1}, {'stick': {'color': 'blue'}})

    viewer.zoomTo()
    viewer.render()
    return viewer

# Visualize every generated molecule next to its ground-truth counterpart.
# Pairs where either molecule is missing are skipped, so the printed molecule
# numbers may have gaps.
for i, (gen_mol, ref_mol) in enumerate(zip(generated_mols, ground_truth_mols)):
    if gen_mol is None or ref_mol is None:
        continue

    print(f"\nMolecule {i+1}")
    print("Generated SMILES:", Chem.MolToSmiles(gen_mol))
    print("Reference SMILES:", Chem.MolToSmiles(ref_mol))

    # 3D visualization of generated molecule
    print("\n3D Structure (Generated):")
    gen_view = visualize_molecule_3d(gen_mol, width=400, height=400)
    display(gen_view)

    # 3D visualization of reference molecule
    print("\n3D Structure (Reference):")
    ref_view = visualize_molecule_3d(ref_mol, width=400, height=400)
    display(ref_view)

    # Calculate RMSD between the two molecules as generated, without alignment
    pair_no_align_rmsd_metric = Metrics.MolecularPairRMSD(align=False)
    pair_no_align_rmsd = pair_no_align_rmsd_metric([gen_mol], [ref_mol])
    # \u00c5 is the angstrom sign; written as an escape so it survives re-encoding.
    print(f"Pair RMSD (no alignment): {pair_no_align_rmsd:.3f} \u00c5")


Molecule 1
Generated SMILES: COCCC1CCN(C(=O)C2(C(=O)C3CCC[NH2+]C3)CCC(C(C)C)CC2)C1
Reference SMILES: COCCC1CCN(C(=O)C2(C(=O)C3CCC[NH2+]C3)CCC(C(C)C)CC2)C1

3D Structure (Generated):


<py3Dmol.view at 0x1468c496b6d0>


3D Structure (Reference):


<py3Dmol.view at 0x14679473aad0>

Pair RMSD (no alignment): 3.448 Ã…

Molecule 2
Generated SMILES: CC(=O)C1(C(=O)N2CC(C3COC3)C2)CC1c1cnccn1
Reference SMILES: CC(=O)C1(C(=O)N2CC(C3COC3)C2)CC1c1cnccn1

3D Structure (Generated):


<py3Dmol.view at 0x1468c4878150>


3D Structure (Reference):


<py3Dmol.view at 0x1468c496b6d0>

Pair RMSD (no alignment): 7.517 Ã…

Molecule 3
Generated SMILES: O=C(C=C([O-])C1(F)CC2C=CC1C2)C1CCCCN1
Reference SMILES: O=C(C=C([O-])C1(F)CC2C=CC1C2)C1CCCCN1

3D Structure (Generated):


<py3Dmol.view at 0x1468c486f550>


3D Structure (Reference):


<py3Dmol.view at 0x1468c4878150>

Pair RMSD (no alignment): 6.189 Ã…

Molecule 4
Generated SMILES: CC1(C)CC(Cn2cc(C3C4CCC(C4)N3C(=O)C3(F)CCCCC3)nn2)CCO1
Reference SMILES: CC1(C)CC(Cn2cc(C3C4CCC(C4)N3C(=O)C3(F)CCCCC3)nn2)CCO1

3D Structure (Generated):


<py3Dmol.view at 0x1468c487af50>


3D Structure (Reference):


<py3Dmol.view at 0x1468c486f550>

Pair RMSD (no alignment): 8.427 Ã…

Molecule 6
Generated SMILES: [NH3+]C1CC(F)(C(=O)N2CC3CC(C2)C3C(=O)N2CCC(C3CC3)CC2)C1
Reference SMILES: [NH3+]C1CC(F)(C(=O)N2CC3CC(C2)C3C(=O)N2CCC(C3CC3)CC2)C1

3D Structure (Generated):


<py3Dmol.view at 0x14679cc10c50>


3D Structure (Reference):


<py3Dmol.view at 0x146793462310>

Pair RMSD (no alignment): 5.894 Ã…

Molecule 7
Generated SMILES: O=C([O-])C1CCC2(CN(C(=O)C3(NC(=O)C4COCCN4)CCCC3)C2)NC1
Reference SMILES: O=C([O-])C1CCC2(CN(C(=O)C3(NC(=O)C4COCCN4)CCCC3)C2)NC1

3D Structure (Generated):


<py3Dmol.view at 0x1468c4a42610>


3D Structure (Reference):


<py3Dmol.view at 0x14679cc10c50>

Pair RMSD (no alignment): 1.249 Ã…

Molecule 8
Generated SMILES: CC1CN(C(=O)C2CCCCC2NC(=O)C2CCC([NH3+])C2)CC1C(=O)[O-]
Reference SMILES: CC1CN(C(=O)C2CCCCC2NC(=O)C2CCC([NH3+])C2)CC1C(=O)[O-]

3D Structure (Generated):


<py3Dmol.view at 0x1468c486e7d0>


3D Structure (Reference):


<py3Dmol.view at 0x1468c497e810>

Pair RMSD (no alignment): 1.814 Ã…

Molecule 9
Generated SMILES: CC1(C(=O)N2CCC3CC(O)CC3C2)CCC2(CN(C(=O)C3(C)CC4C[NH2+]CCC43)C2)O1
Reference SMILES: CC1(C(=O)N2CCC3CC(O)CC3C2)CCC2(CN(C(=O)C3(C)CC4C[NH2+]CCC43)C2)O1

3D Structure (Generated):


<py3Dmol.view at 0x1468c48710d0>


3D Structure (Reference):


<py3Dmol.view at 0x1468c4870310>

Pair RMSD (no alignment): 9.179 Ã…


## Complex Visualization (for Protein-Ligand Complexes)

If we're working with protein-ligand complexes, let's visualize the binding poses

In [9]:
%load_ext autoreload
%autoreload 2

from cgflow.util.visualize import complex_to_3dview

# Show the first (up to) three generated ligands inside their pockets.
# Only runs for complex datasets for which pockets were collected.
if IS_COMPLEX and pocket_raw is not None:
    for i, ligand in enumerate(generated_mols[:3]):
        if ligand is None:
            # Failed generation: nothing to display for this entry.
            continue
        print(f"\nProtein-Ligand Complex {i+1}")
        view = complex_to_3dview(ligand, pocket_raw[i])
        display(view)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload

Protein-Ligand Complex 1


<py3Dmol.view at 0x146793315550>


Protein-Ligand Complex 2


<py3Dmol.view at 0x146794739a90>


Protein-Ligand Complex 3


<py3Dmol.view at 0x146793315550>

In [10]:

import os

# Output layout: <output_dir>/molecules/*.sdf, <output_dir>/proteins/*.pdb,
# plus one combined SDF directly under <output_dir>.
output_dir = "../temp"
molecules_dir = os.path.join(output_dir, "molecules")
proteins_dir = os.path.join(output_dir, "proteins")
for directory in (output_dir, molecules_dir, proteins_dir):
    os.makedirs(directory, exist_ok=True)

# One SDF per generated molecule, with explicit hydrogens (and coordinates) added.
from rdkit import Chem
for i, mol in enumerate(generated_mols):
    if mol is None:
        continue
    mol = Chem.AddHs(mol, addCoords=True)
    mol_path = os.path.join(molecules_dir, f"molecule_{i+1}.sdf")
    writer = Chem.SDWriter(mol_path)
    writer.write(mol)
    writer.close()

# All molecules in a single SDF file.
# NOTE(review): unlike the per-molecule files, these are written without the
# AddHs step - confirm this difference is intended.
all_mols_path = os.path.join(output_dir, "all_molecules.sdf")
writer = Chem.SDWriter(all_mols_path)
n_saved_mols = 0
for mol in generated_mols:
    if mol is not None:
        writer.write(mol)
        n_saved_mols += 1
writer.close()
print(f"Saved {n_saved_mols} molecules to {output_dir}")

# Protein pockets as PDB files (skipped when no pockets were collected).
if pocket_raw:
    n_saved_pockets = 0
    for i, pocket in enumerate(pocket_raw):
        if pocket is None:
            continue
        pocket_path = os.path.join(proteins_dir, f"pocket_{i+1}.pdb")
        pocket.write_pdb(pocket_path)
        n_saved_pockets += 1
    print(f"Saved {n_saved_pockets} protein pockets to {output_dir}")

print(f"All files saved to {output_dir}")

Saved 8 molecules to ../temp
Saved 9 protein pockets to ../temp
All files saved to ../temp


## Conclusion

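In this notebook, we loaded a trained SEMLAFLOW model, generated molecular conformations, evaluated them with several metrics, and visualized the generated and reference structures in 3D (a 2D drawing helper is also provided), including protein-ligand complex representations. Finally, the generated molecules and protein pockets were exported to disk.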