Fragment-based Flow Matching for 3D Structure-Based Drug Design
Decomposes protein-ligand complexes into fragments, learns fragment-level representations via a hierarchical SE(3) transformer, and generates 3D structures via VP-SDE diffusion + discrete jump processes.
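The fragment decomposition step can be sketched with RDKit's BRICS module (the repo's own ligand pipeline lives in src/data/ligand.py; the molecule below is an arbitrary example, not from the dataset):

```python
from rdkit import Chem
from rdkit.Chem import BRICS

# Decompose a small molecule into BRICS fragments; dummy atoms ([n*])
# mark where bonds were broken and encode the BRICS link type.
mol = Chem.MolFromSmiles("CC(=O)Nc1ccc(O)cc1")  # acetaminophen, as an example
frags = sorted(BRICS.BRICSDecompose(mol))
print(frags)
```

Each fragment SMILES carries its attachment points, which is what allows fragment-level vocabularies to be reassembled into full ligands.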
- FragFM3D: hierarchical SE(3) transformer with fragment-level and atom-level processing:
  `FragmentEmbedder (MPNN) → frag_proj → HierarchicalTransformer → TypeHead + EdgeHead`
- 7 types of hierarchical edges (frag-frag, atom-atom, cross-level)
- Per-node timestep support (unified training mode)
- Diffusion: VP-SDE for continuous variables (positions), CTMC jump process for discrete variables (types, edges)
- Training modes: `diffusion` (default), `initializer` (atom-position pre-training), `unified` (per-node timestep, single-pass)
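The noising scheme behind these bullets can be illustrated with a small NumPy sketch: continuous positions follow the VP-SDE marginal, discrete types are corrupted by a toy uniform jump kernel, and each node carries its own timestep as in unified mode. The schedule constants and the uniform kernel here are illustrative assumptions, not the repo's exact settings:

```python
import numpy as np

rng = np.random.default_rng(0)

def vp_sde_marginal(x0, t, beta_min=0.1, beta_max=20.0):
    """Sample x_t from the VP-SDE marginal
    x_t = exp(-B(t)/2) * x0 + sqrt(1 - exp(-B(t))) * eps,
    with B(t) = beta_min*t + (beta_max - beta_min)*t**2 / 2."""
    B = beta_min * t + 0.5 * (beta_max - beta_min) * t**2
    mean_coef = np.exp(-0.5 * B)
    std = np.sqrt(1.0 - np.exp(-B))
    eps = rng.standard_normal(x0.shape)
    return mean_coef[:, None] * x0 + std[:, None] * eps

def jump_corrupt(types, t, n_classes):
    """Toy CTMC corruption for discrete labels: with probability
    1 - exp(-t), resample the label uniformly over the vocabulary."""
    jump = rng.random(types.shape) < (1.0 - np.exp(-t))
    return np.where(jump, rng.integers(0, n_classes, size=types.shape), types)

# Per-node timesteps (unified mode): each node draws its own t.
n_nodes = 5
x0 = rng.standard_normal((n_nodes, 3))       # atom positions
types = rng.integers(0, 8, size=n_nodes)     # discrete fragment-type labels
t = rng.uniform(0.0, 1.0, size=n_nodes)      # one timestep per node

x_t = vp_sde_marginal(x0, t)
types_t = jump_corrupt(types, t, n_classes=8)
print(x_t.shape, types_t.shape)
```

At t = 0 both corruptions are the identity; at large t positions approach a standard Gaussian and types approach the uniform distribution.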
```bash
# Create conda environment
conda env create -f environment.yml
conda activate fragfm3d

# Install pip dependencies
pip install -r requirements.txt
pip install -r requirements-pyg.txt
pip install -e .

# Verify the installation
pytest test/test_imports.py -v
```

Requires Python >= 3.10, CUDA 12.8, PyTorch 2.8.0, and PyTorch Geometric 2.7.0.
CrossDocked2020 (small molecules):

```bash
mkdir -p data/crossdocked && cd data/crossdocked
gdown https://drive.google.com/drive/folders/1j21cc7-97TedKh_El5E34yI8o5ckI7eK --folder
if [ -d "data" ]; then mv data/* . && rmdir data; fi
tar -xzvf crossdocked_v1.1_rmsd1.0.tar.gz && rm crossdocked_v1.1_rmsd1.0.tar.gz
unzip test_set.zip && rm test_set.zip
cd ../..
```

Peptides (PepBench):
```bash
mkdir -p data/pepbench && cd data/pepbench
aria2c -x 16 -s 16 --user-agent="Mozilla/5.0" -o "train_valid.tar.gz" \
    "https://zenodo.org/records/13373108/files/train_valid.tar.gz"
aria2c -x 16 -s 16 --user-agent="Mozilla/5.0" -o "ProtFrag.tar.gz" \
    "https://zenodo.org/records/13373108/files/ProtFrag.tar.gz"
aria2c -x 16 -s 16 --user-agent="Mozilla/5.0" -o "LNR.tar.gz" \
    "https://zenodo.org/records/13373108/files/LNR.tar.gz"
for f in *.tar.gz; do tar -xzvf "$f"; done && rm *.tar.gz
cd ../..
```

Antibodies (SAbDab):
```bash
mkdir -p data/sabdab && cd data/sabdab
aria2c -x 16 -s 16 -o summary.csv \
    "https://opig.stats.ox.ac.uk/webapps/sabdab-sabpred/sabdab/summary/all/"
aria2c -x 16 -s 16 -o all_structures.zip \
    "https://opig.stats.ox.ac.uk/webapps/newsabdab/sabdab/archive/all/"
unzip all_structures.zip && rm all_structures.zip
cd ../..
```

HiQBind (binding affinity):
```bash
mkdir -p data/hiqbind && cd data/hiqbind
wget -O hiqbind.tar.gz "https://ndownloader.figshare.com/files/52379345"
wget -O hiqbind_metadata.csv "https://ndownloader.figshare.com/files/52379381"
tar -xzvf hiqbind.tar.gz
cd ../..
```

Preprocessing:

```bash
# CrossDocked: pocket extraction + LMDB creation (two-pass)
python scripts/preprocess.py --dataset crossdocked

# With options
python scripts/preprocess.py --dataset crossdocked --max-samples 1000 --num-workers 8

# Precompute vocab graphs (required for training)
python scripts/precompute_vocab_graphs.py \
    --vocab_path data/crossdocked/processed/vocab_ligand.csv \
    --output_path data/crossdocked/processed/vocab_graphs.lmdb
```

Output structure:

```
data/crossdocked/processed/
├── samples.lmdb         # All preprocessed samples
├── vocab_ligand.csv     # Ligand fragment vocabulary
├── vocab_protein.csv    # Protein fragment vocabulary
├── vocab_graphs.lmdb    # Precomputed vocab fragment graphs
└── splits/
    ├── train.txt
    └── test.txt
```
All training uses Hydra configs (configs/).
```bash
# Diffusion + jump (default)
python scripts/train.py --config-name=train_diffusion

# Unified (per-node timestep, recommended)
python scripts/train.py --config-name=train_unified_overfit

# Initializer only
python scripts/train.py --config-name=train_initializer_overfit

# Override any config value from the command line
python scripts/train.py --config-name=train_diffusion model.d_hidden=128 max_epoch=5000
```

Sampling:

```bash
# Unified sampler (single-pass, per-node timestep)
python scripts/sample_unified.py checkpoint=/path/to/ckpt

# Diffusion sampler (two-pass: diffusion + init model)
python scripts/sample_diffusion.py --config-name=sample_diffusion checkpoint=/path/to/ckpt
```
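As background on what the reverse-diffusion sampler iterates, here is a toy NumPy version of one ancestral (DDPM-style) step for the VP-SDE. The schedule constants and the zero-noise stand-in model are illustrative assumptions, not the repo's implementation:

```python
import numpy as np

rng = np.random.default_rng(1)
BETA_MIN, BETA_MAX = 0.1, 20.0   # assumed linear VP-SDE schedule

def beta(t):
    return BETA_MIN + (BETA_MAX - BETA_MIN) * t

def ddpm_reverse_step(x, eps_hat, t, dt):
    """One ancestral reverse step:
    x <- (x + beta*dt*score) / sqrt(1 - beta*dt) + sqrt(beta*dt)*z,
    where score is estimated from the predicted noise eps_hat."""
    b = beta(t) * dt
    B = BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t**2   # integral of beta
    std = np.sqrt(1.0 - np.exp(-B))
    score = -eps_hat / max(std, 1e-5)
    z = rng.standard_normal(x.shape) if t > dt else np.zeros_like(x)
    return (x + b * score) / np.sqrt(1.0 - b) + np.sqrt(b) * z

def zero_eps_model(x, t):
    # stand-in for the trained denoiser; a real model predicts eps from (x, t)
    return np.zeros_like(x)

n_steps = 50
dt = 1.0 / n_steps
x = rng.standard_normal((5, 3))          # start from the Gaussian prior
for i in range(n_steps, 0, -1):
    t = i * dt
    x = ddpm_reverse_step(x, zero_eps_model(x, t), t, dt)
print(x.shape)
```

A DDIM-style variant would make the same update deterministic by dropping the injected noise z.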
Repository layout:

```
src/
├── models/                            # FragFM3D model
│   ├── fragfm3d.py                    # Main model (FragFM3D, create_fragfm3d_from_config)
│   ├── hierarchical_transformer.py    # SE(3) transformer with per-node timestep
│   ├── unified_layer.py               # GET attention + message passing
│   ├── fragment_embedder.py           # MPNN for fragment graph encoding
│   ├── edge_builder.py                # Hierarchical edge construction
│   ├── heads/                         # TypeHead (fragment type), EdgeHead (bond prediction)
│   └── modules/                       # Attention submodules
├── data/                              # Data pipeline
│   ├── datasets.py                    # CrossDockedDataset (PyG, LMDB-backed)
│   ├── parsers.py                     # PDB parser
│   ├── processing.py                  # Preprocessing pipeline
│   ├── features.py                    # Feature computation
│   ├── ligand.py                      # Ligand processing (BRICS decomposition)
│   ├── protein.py                     # Protein processing
│   └── tokenizer.py                   # Fragment tokenizer
├── trainer/                           # Training
│   ├── diffusion_trainer.py           # VP-SDE + jump training
│   ├── initializer_trainer.py         # Atom position initialization
│   ├── unified_trainer.py             # Per-node timestep (init + diffusion)
│   ├── sample_noiser.py               # Noise application + atom masking
│   ├── diffusion_scheduler.py         # VP-SDE / jump schedules
│   ├── negative_sampler.py            # Vocabulary negative sampling
│   └── loss.py                        # Loss functions
├── sampler/                           # Generation
│   ├── diffusion_sampler.py           # Reverse diffusion (DDPM/DDIM)
│   ├── unified_sampler.py             # Single-pass with per-node timestep
│   ├── sampling_utils.py              # Sampling utilities
│   ├── oracle.py                      # Pipeline verification
│   └── types.py                       # SamplingState, SamplingOutput
├── qm/                                # Quantum mechanics
│   ├── fragment.py                    # BRICS fragmentation, H-capping, XYZ I/O, Kabsch
│   ├── xtb.py                         # GFN2-xTB calculator
│   ├── dft.py                         # DFT/ORCA calculator
│   ├── moi.py                         # Moment of inertia, rotation utilities
│   └── tensor_store.py                # TensorBundle, TensorStore
├── evaluation/                        # Metrics (RMSD, clash, diversity, dihedrals)
├── affinity/                          # Binding affinity prediction (HiQBind)
├── constant.py                        # Protein constants (sidechain atoms, residues)
└── utils/                             # Molecular transforms, GNN helpers, logging

scripts/                               # Entry points
├── train.py                           # Training (diffusion / initializer / unified)
├── preprocess.py                      # Data preprocessing (crossdocked / pdbbind)
├── precompute_vocab_graphs.py         # Vocab graph LMDB creation
├── sample_diffusion.py                # Reverse diffusion sampling
├── sample_unified.py                  # Unified sampling
└── qm/                                # QM extraction scripts

configs/                               # Hydra YAML configs
├── model/fragfm3d.yaml                # Model architecture defaults
├── train_*.yaml                       # Training configs
└── sample_*.yaml                      # Sampling configs

test/                                  # Tests
├── test_imports.py
├── test_crossdocked_ligand.py
├── test_crossdocked_protein.py
├── test_protein_reconstruction.py
└── test_xtb_wrapper.py

analyze/                               # Analysis scripts (permanent record, never delete)
docs/progress/                         # Progress logs (YYMMDD_slug.md)
experiments/                           # Experiment tracking
```
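src/qm/fragment.py lists Kabsch alignment among its utilities; as background, the textbook SVD construction fits in a few NumPy lines (this is the standard algorithm, not the repo's code):

```python
import numpy as np

def kabsch(P, Q):
    """Optimal rotation R aligning P onto Q (both (N, 3), already centered),
    via SVD of the cross-covariance with a determinant fix for reflections."""
    H = P.T @ Q                                  # 3x3 cross-covariance
    U, S, Vt = np.linalg.svd(H)
    d = np.sign(np.linalg.det(Vt.T @ U.T))       # +1 rotation, -1 reflection
    D = np.diag([1.0, 1.0, d])
    return Vt.T @ D @ U.T

# Usage: rotate a centered point cloud and recover the rotation.
rng = np.random.default_rng(2)
P = rng.standard_normal((10, 3))
P -= P.mean(axis=0)
theta = 0.3
Rz = np.array([[np.cos(theta), -np.sin(theta), 0.0],
               [np.sin(theta),  np.cos(theta), 0.0],
               [0.0, 0.0, 1.0]])
Q = P @ Rz.T                                     # Q_i = Rz @ P_i
R = kabsch(P, Q)
rmsd = np.sqrt(((P @ R.T - Q) ** 2).sum() / len(P))
print(rmsd)
```

The determinant correction matters in practice: without it, near-planar fragments can be "aligned" by an improper reflection.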
```bash
pytest test/ -v                          # Full test suite
pytest test/test_imports.py -v           # Quick import check
pytest test/test_crossdocked_ligand.py   # Ligand pipeline
pytest test/test_xtb_wrapper.py          # QM (requires xTB)
```

Experiment IDs follow FRAG3D-YYMMDD-NN, with the registry at experiments/registry.csv. Progress logs live in docs/progress/ and analysis scripts in analyze/.
Core dependencies: PyTorch, PyTorch Geometric, RDKit, Hydra/OmegaConf, Biopython, wandb, SciPy, and xTB (optional).