seonghann/FragFM3D

FragFM3D

Fragment-based Flow Matching for 3D Structure-Based Drug Design.

FragFM3D decomposes protein-ligand complexes into fragments, learns fragment-level representations with a hierarchical SE(3) transformer, and generates 3D structures via VP-SDE diffusion for continuous variables plus discrete jump processes for types and edges.

Architecture

  • FragFM3D: Hierarchical transformer with fragment-level and atom-level processing
    • FragmentEmbedder (MPNN) → frag_proj → HierarchicalTransformer → TypeHead + EdgeHead
    • 7 types of hierarchical edges (frag-frag, atom-atom, cross-level)
    • Per-node timestep support (unified training mode)
  • Diffusion: VP-SDE for continuous (positions), CTMC jump process for discrete (types, edges)
  • Training modes: diffusion (default), initializer (atom position pre-training), unified (per-node timestep, single-pass)
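The forward corruption implied by this setup (a VP-SDE marginal for positions, uniform jump resampling for discrete types, with an independent timestep per node in unified mode) can be sketched as follows. The schedule parameters and function names here are illustrative assumptions, not the repo's actual API:

```python
import numpy as np

def alpha_bar(t, beta_min=0.1, beta_max=20.0):
    # VP-SDE signal retention exp(-integral of beta(s) ds) for a linear beta(t)
    return np.exp(-(beta_min * t + 0.5 * (beta_max - beta_min) * t ** 2))

def noise_positions(x0, t, rng=None):
    # Forward marginal with a *per-node* timestep t of shape [N]
    rng = np.random.default_rng() if rng is None else rng
    ab = alpha_bar(np.asarray(t))[:, None]               # [N, 1]
    return np.sqrt(ab) * x0 + np.sqrt(1.0 - ab) * rng.standard_normal(x0.shape)

def corrupt_types(types, t, vocab_size, rng=None):
    # Jump corruption: node i resamples its type uniformly w.p. 1 - alpha_bar(t_i)
    rng = np.random.default_rng() if rng is None else rng
    jump = rng.random(len(types)) < 1.0 - alpha_bar(np.asarray(t))
    return np.where(jump, rng.integers(0, vocab_size, len(types)), types)

x0 = np.random.randn(5, 3)     # 5 nodes, 3D positions
t = np.random.rand(5)          # independent timestep per node (unified mode)
xt = noise_positions(x0, t)
```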

Setup

# Create conda environment
conda env create -f environment.yml
conda activate fragfm3d

# Install pip dependencies
pip install -r requirements.txt
pip install -r requirements-pyg.txt
pip install -e .

# Verify
pytest test/test_imports.py -v

Requires Python >= 3.10, CUDA 12.8, PyTorch 2.8.0, and PyTorch Geometric 2.7.0.

Data

Download

CrossDocked2020 (small molecules):

mkdir -p data/crossdocked && cd data/crossdocked
gdown https://drive.google.com/drive/folders/1j21cc7-97TedKh_El5E34yI8o5ckI7eK --folder
if [ -d "data" ]; then mv data/* . && rmdir data; fi
tar -xzvf crossdocked_v1.1_rmsd1.0.tar.gz && rm crossdocked_v1.1_rmsd1.0.tar.gz
unzip test_set.zip && rm test_set.zip
cd ../..

Peptides (PepBench):

mkdir -p data/pepbench && cd data/pepbench
aria2c -x 16 -s 16 --user-agent="Mozilla/5.0" -o "train_valid.tar.gz" \
    "https://zenodo.org/records/13373108/files/train_valid.tar.gz"
aria2c -x 16 -s 16 --user-agent="Mozilla/5.0" -o "ProtFrag.tar.gz" \
    "https://zenodo.org/records/13373108/files/ProtFrag.tar.gz"
aria2c -x 16 -s 16 --user-agent="Mozilla/5.0" -o "LNR.tar.gz" \
    "https://zenodo.org/records/13373108/files/LNR.tar.gz"
for f in *.tar.gz; do tar -xzvf "$f"; done && rm *.tar.gz
cd ../..

Antibodies (SAbDab):

mkdir -p data/sabdab && cd data/sabdab
aria2c -x 16 -s 16 -o summary.csv \
    "https://opig.stats.ox.ac.uk/webapps/sabdab-sabpred/sabdab/summary/all/"
aria2c -x 16 -s 16 -o all_structures.zip \
    "https://opig.stats.ox.ac.uk/webapps/newsabdab/sabdab/archive/all/"
unzip all_structures.zip && rm all_structures.zip
cd ../..

HiQBind (binding affinity):

mkdir -p data/hiqbind && cd data/hiqbind
wget -O hiqbind.tar.gz "https://ndownloader.figshare.com/files/52379345"
wget -O hiqbind_metadata.csv "https://ndownloader.figshare.com/files/52379381"
tar -xzvf hiqbind.tar.gz
cd ../..

Preprocess

# CrossDocked: pocket extraction + LMDB creation (two-pass)
python scripts/preprocess.py --dataset crossdocked

# With options
python scripts/preprocess.py --dataset crossdocked --max-samples 1000 --num-workers 8

# Precompute vocab graphs (required for training)
python scripts/precompute_vocab_graphs.py \
    --vocab_path data/crossdocked/processed/vocab_ligand.csv \
    --output_path data/crossdocked/processed/vocab_graphs.lmdb

Output structure:

data/crossdocked/processed/
├── samples.lmdb           # All preprocessed samples
├── vocab_ligand.csv       # Ligand fragment vocabulary
├── vocab_protein.csv      # Protein fragment vocabulary
├── vocab_graphs.lmdb      # Precomputed vocab fragment graphs
└── splits/
    ├── train.txt
    └── test.txt
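Pocket extraction conventionally keeps every residue with at least one atom near the ligand. A minimal numpy sketch of that idea; the 8 Å cutoff is an assumption, not necessarily the repo's setting:

```python
import numpy as np

def extract_pocket(protein_xyz, residue_ids, ligand_xyz, cutoff=8.0):
    # Keep residues with any atom within `cutoff` angstroms of any ligand atom.
    # [P, L] pairwise distances between protein and ligand atoms
    d = np.linalg.norm(protein_xyz[:, None, :] - ligand_xyz[None, :, :], axis=-1)
    close_atoms = d.min(axis=1) < cutoff            # per protein atom
    kept = set(np.asarray(residue_ids)[close_atoms].tolist())
    mask = np.isin(residue_ids, sorted(kept))       # all atoms of kept residues
    return protein_xyz[mask], np.asarray(residue_ids)[mask]

# Toy example: residue 0 sits near the ligand, residue 1 far away
prot = np.array([[1.0, 0, 0], [100.0, 0, 0], [1.5, 0, 0]])
res = np.array([0, 1, 0])
lig = np.array([[0.0, 0, 0]])
pocket_xyz, pocket_res = extract_pocket(prot, res, lig)
```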

Training

All training uses Hydra configs (configs/).

# Diffusion + jump (default)
python scripts/train.py --config-name=train_diffusion

# Unified (per-node timestep, recommended)
python scripts/train.py --config-name=train_unified_overfit

# Initializer only
python scripts/train.py --config-name=train_initializer_overfit

# Override any config
python scripts/train.py --config-name=train_diffusion model.d_hidden=128 max_epoch=5000
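Hydra's `key=value` overrides address nested config entries by dotted path. Conceptually, the merge works like this stdlib sketch (not Hydra's actual implementation):

```python
def apply_override(cfg: dict, override: str) -> dict:
    # Apply one dotted "a.b.c=value" override in place (minimal sketch)
    path, _, raw = override.partition("=")
    keys = path.split(".")
    node = cfg
    for k in keys[:-1]:
        node = node.setdefault(k, {})
    # crude type coercion: try int, then float, else keep the string
    for cast in (int, float):
        try:
            raw = cast(raw)
            break
        except ValueError:
            pass
    node[keys[-1]] = raw
    return cfg

cfg = {"model": {"d_hidden": 64}, "max_epoch": 1000}
apply_override(cfg, "model.d_hidden=128")
apply_override(cfg, "max_epoch=5000")
```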

Sampling

# Unified sampler (single-pass, per-node timestep)
python scripts/sample_unified.py checkpoint=/path/to/ckpt

# Diffusion sampler (two-pass: diffusion + init model)
python scripts/sample_diffusion.py --config-name=sample_diffusion checkpoint=/path/to/ckpt
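The reverse-diffusion update behind DDPM-style samplers reduces to one ancestral step per timestep. A numpy sketch for reference; the linear beta schedule and its endpoints are illustrative assumptions, not the repo's configuration:

```python
import numpy as np

def ddpm_step(x_t, eps_hat, t, T=1000, beta_min=1e-4, beta_max=0.02, rng=None):
    # One ancestral reverse step x_t -> x_{t-1} given predicted noise eps_hat
    rng = np.random.default_rng() if rng is None else rng
    betas = np.linspace(beta_min, beta_max, T)
    alphas = 1.0 - betas
    alpha_bars = np.cumprod(alphas)
    mean = (x_t - betas[t] / np.sqrt(1.0 - alpha_bars[t]) * eps_hat) / np.sqrt(alphas[t])
    if t == 0:
        return mean                      # no noise on the final step
    return mean + np.sqrt(betas[t]) * rng.standard_normal(x_t.shape)

x = np.random.randn(8, 3)
x_prev = ddpm_step(x, eps_hat=np.zeros_like(x), t=500)
```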

Project Structure

src/
├── models/              # FragFM3D model
│   ├── fragfm3d.py          Main model (FragFM3D, create_fragfm3d_from_config)
│   ├── hierarchical_transformer.py   SE(3) transformer with per-node timestep
│   ├── unified_layer.py     GET attention + message passing
│   ├── fragment_embedder.py  MPNN for fragment graph encoding
│   ├── edge_builder.py      Hierarchical edge construction
│   ├── heads/               TypeHead (fragment type), EdgeHead (bond prediction)
│   └── modules/             Attention submodules
├── data/                # Data pipeline
│   ├── datasets.py          CrossDockedDataset (PyG, LMDB-backed)
│   ├── parsers.py           PDB parser
│   ├── processing.py        Preprocessing pipeline
│   ├── features.py          Feature computation
│   ├── ligand.py            Ligand processing (BRICS decomposition)
│   ├── protein.py           Protein processing
│   └── tokenizer.py         Fragment tokenizer
├── trainer/             # Training
│   ├── diffusion_trainer.py     VP-SDE + jump training
│   ├── initializer_trainer.py   Atom position initialization
│   ├── unified_trainer.py       Per-node timestep (init + diffusion)
│   ├── sample_noiser.py         Noise application + atom masking
│   ├── diffusion_scheduler.py   VP-SDE / jump schedules
│   ├── negative_sampler.py      Vocabulary negative sampling
│   └── loss.py                  Loss functions
├── sampler/             # Generation
│   ├── diffusion_sampler.py     Reverse diffusion (DDPM/DDIM)
│   ├── unified_sampler.py       Single-pass with per-node timestep
│   ├── sampling_utils.py        Sampling utilities
│   ├── oracle.py                Pipeline verification
│   └── types.py                 SamplingState, SamplingOutput
├── qm/                  # Quantum mechanics
│   ├── fragment.py          BRICS fragmentation, H-capping, XYZ I/O, Kabsch
│   ├── xtb.py               GFN2-xTB calculator
│   ├── dft.py               DFT/ORCA calculator
│   ├── moi.py               Moment of inertia, rotation utilities
│   └── tensor_store.py      TensorBundle, TensorStore
├── evaluation/          # Metrics (RMSD, clash, diversity, dihedrals)
├── affinity/            # Binding affinity prediction (HiQBind)
├── constant.py          # Protein constants (sidechain atoms, residues)
└── utils/               # Molecular transforms, GNN helpers, logging

scripts/                 # Entry points
├── train.py                 Training (diffusion / initializer / unified)
├── preprocess.py            Data preprocessing (crossdocked / pdbbind)
├── precompute_vocab_graphs.py   Vocab graph LMDB creation
├── sample_diffusion.py      Reverse diffusion sampling
├── sample_unified.py        Unified sampling
└── qm/                      QM extraction scripts

configs/                 # Hydra YAML configs
├── model/fragfm3d.yaml     Model architecture defaults
├── train_*.yaml             Training configs
└── sample_*.yaml            Sampling configs

test/                    # Tests
├── test_imports.py
├── test_crossdocked_ligand.py
├── test_crossdocked_protein.py
├── test_protein_reconstruction.py
└── test_xtb_wrapper.py

analyze/                 # Analysis scripts (permanent record, never delete)
docs/progress/           # Progress logs (YYMMDD_slug.md)
experiments/             # Experiment tracking
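src/qm/fragment.py lists Kabsch alignment among its utilities. For reference, a standalone numpy version of the Kabsch optimal rotation and translation (a textbook implementation, not the repo's code):

```python
import numpy as np

def kabsch(P, Q):
    # Rotation R and translation t minimizing ||(P @ R.T + t) - Q||
    # P, Q: [N, 3] paired coordinates
    Pc, Qc = P - P.mean(0), Q - Q.mean(0)
    H = Pc.T @ Qc                                  # 3x3 covariance
    U, S, Vt = np.linalg.svd(H)
    d = np.sign(np.linalg.det(Vt.T @ U.T))         # guard against reflections
    R = Vt.T @ np.diag([1.0, 1.0, d]) @ U.T
    t = Q.mean(0) - P.mean(0) @ R.T
    return R, t
```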

Tests

pytest test/ -v
pytest test/test_imports.py -v          # Quick import check
pytest test/test_crossdocked_ligand.py  # Ligand pipeline
pytest test/test_xtb_wrapper.py         # QM (requires xTB)

Experiments

Experiment IDs follow FRAG3D-YYMMDD-NN. Registry at experiments/registry.csv. Progress logs at docs/progress/. Analysis scripts at analyze/.
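A small helper for minting the next ID in that scheme, treating the registry as a list of existing ID strings (this helper is an illustration, not part of the repo):

```python
from datetime import date

def next_experiment_id(existing_ids, today=None):
    # Return the next FRAG3D-YYMMDD-NN id, restarting NN each day
    today = today or date.today()
    prefix = f"FRAG3D-{today.strftime('%y%m%d')}-"
    taken = [int(i.rsplit("-", 1)[1]) for i in existing_ids if i.startswith(prefix)]
    return f"{prefix}{max(taken, default=0) + 1:02d}"
```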

Key Dependencies

PyTorch, PyTorch Geometric, RDKit, Hydra/OmegaConf, Biopython, wandb, scipy, xTB (optional).
