In [8]:
# Make the repository's `src/` directory importable. "../src" is a relative
# path, so os.path.abspath resolves it against the current working directory;
# this assumes the notebook is launched from its own folder -- TODO confirm.
import sys
import os
sys.path.append(os.path.abspath("../src"))
# NOTE: the class names are spelled "Phlyo..." (not "Phylo...") in the
# `embedders` module; keep that spelling wherever they are referenced.
from embedders import PhlyoGPNEmbedder, PhlyoGPNPooledEmbedder

### Steps for setting up PhyloGPN and PhyloGPN-pooled in BEND:

1. **Copy the Embedder Classes**  
   Copy the `PhlyoGPNEmbedder` and `PhlyoGPNPooledEmbedder` classes from `src/embedders.py` in this repository to `bend/utils/embedders.py` in BEND.

2. **Add Model Instantiators**  
   Add the following model instantiators to `conf/embedding/embed.yaml` in BEND:
   ```yaml
   PhyloGPN:
       _target_: bend.utils.embedders.PhlyoGPNEmbedder
       config_path: path_to_config.yaml
       model_path: path_to_checkpoint.pt

   PhyloGPN-pooled:
       _target_: bend.utils.embedders.PhlyoGPNPooledEmbedder
       config_path: path_to_config.yaml
       model_path: path_to_checkpoint.pt
   ```

3. **Set Embedding Dimensions**  
   Add both models' embedding dimensions to `conf/datadims/embedding_dims.yaml` in BEND:
   ```yaml
   PhyloGPN: 960
   PhyloGPN-pooled: 960
   ```

4. **Run Scripts**  
   Execute the following scripts as instructed in BEND:
   - `precompute_embeddings.py`
   - `train_on_task.py`

5. **Add Models to VEP Tasks**  
   For BEND VEP tasks, add `PhyloGPN` and `PhyloGPN-pooled` to `predict_variant_effects.py`.

**Note**: We have also provided the `Caduceus` embedder for benchmarking. Its usage is similar to that of `PhyloGPN` and `PhyloGPN-pooled`.


### Demo: PhyloGPN as embedder in BEND

In [2]:
# Relative paths to the PhyloGPN configuration file and model checkpoint.
config_path, model_path = "../PhyloGPN/config.yaml", "../PhyloGPN/checkpoint.pt"

# Instantiate the embedder from the config/checkpoint pair.
embedder = PhlyoGPNEmbedder(
    config_path=config_path,
    model_path=model_path,
)
embedder  # bare last expression: the cell displays the object's repr

<embedders.PhlyoGPNEmbedder at 0x7fddcf4b6610>

In [3]:
# Toy input: one sequence made of a 32-nt motif (it contains ambiguous "N"
# bases) tiled 16 times, i.e. 512 nt in total.
sequences = [16 * "AGACAAAGTTTTNTTAAAANNTTTCCCGGGGG"]
print(
    f"sequence length: {len(sequences[0])}:",
    sequences,
)

# Embed the batch with the progress bar switched off, then report the shape
# of the first returned array.
embeddings = embedder.embed(sequences, disable_tqdm=True)
print(
    "Embedding shape:",
    embeddings[0].shape,
)

sequence length: 512: ['AGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGG']
Embedding shape: (1, 512, 960)


In [4]:
# Longer toy input: 312 copies of the 32-nt motif (9,984 nt) followed by a
# 16-nt tail, giving a single 10,000-nt sequence.
sequences = [312 * "AGACAAAGTTTTNTTAAAANNTTTCCCGGGGG" + "AAAAAAAATTTTAAAA"]
print(
    f"sequence length: {len(sequences[0])}:",
    sequences,
)

# Embed the batch with the progress bar switched off, then report the shape
# of the first returned array.
embeddings = embedder.embed(sequences, disable_tqdm=True)
print(
    "Embedding shape:",
    embeddings[0].shape,
)

sequence length: 10000: ['AGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNT

### Demo: PhyloGPN-pooled as embedder in BEND

In [5]:
# Relative paths to the PhyloGPN configuration file and model checkpoint.
config_path, model_path = "../PhyloGPN/config.yaml", "../PhyloGPN/checkpoint.pt"

# Instantiate the pooled embedder variant from the config/checkpoint pair.
embedder = PhlyoGPNPooledEmbedder(
    config_path=config_path,
    model_path=model_path,
)
embedder  # bare last expression: the cell displays the object's repr

<embedders.PhlyoGPNPooledEmbedder at 0x7fdd78059160>

In [6]:
# Toy input: one sequence made of a 32-nt motif (it contains ambiguous "N"
# bases) tiled 16 times, i.e. 512 nt in total.
sequences = [16 * "AGACAAAGTTTTNTTAAAANNTTTCCCGGGGG"]
print(
    f"sequence length: {len(sequences[0])}:",
    sequences,
)

# Embed the batch with the progress bar switched off, then report the shape
# of the first returned array.
embeddings = embedder.embed(sequences, disable_tqdm=True)
print(
    "Embedding shape:",
    embeddings[0].shape,
)

sequence length: 512: ['AGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGG']
Embedding shape: (1, 512, 960)


In [7]:
# Longer toy input: 312 copies of the 32-nt motif (9,984 nt) followed by a
# 16-nt tail, giving a single 10,000-nt sequence.
sequences = [312 * "AGACAAAGTTTTNTTAAAANNTTTCCCGGGGG" + "AAAAAAAATTTTAAAA"]
print(
    f"sequence length: {len(sequences[0])}:",
    sequences,
)

# Embed the batch with the progress bar switched off, then report the shape
# of the first returned array.
embeddings = embedder.embed(sequences, disable_tqdm=True)
print(
    "Embedding shape:",
    embeddings[0].shape,
)

sequence length: 10000: ['AGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNTTAAAANNTTTCCCGGGGGAGACAAAGTTTTNT