In [1]:
import os
from dataclasses import dataclass
from functools import cached_property
import scanpy as sc
import anndata as ad
import pandas as pd
import requests
import torch
from collections.abc import Callable, Iterable
from esm import FastaBatchedDataset, pretrained

try:
    import torch
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer, EsmModel
except ImportError as e:
    torch = None
    DataLoader = None
    AutoTokenizer = None
    EsmModel = None
    raise ImportError(
        "To use gene embedding, please install `transformers` and `torch` \
            e.g. via `pip install cfp['embedding']`."
    ) from e

  from .autonotebook import tqdm as notebook_tqdm


In [18]:
from Bio import SeqIO
import pandas as pd

out_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/pbmc"
fasta_file = os.path.join(out_dir, "cytokines.fasta")

records = list(SeqIO.parse(fasta_file, "fasta"))

data = {
    "gene_id": [record.description for record in records],
    "protein_sequence": [str(record.seq) for record in records]
}
df_sequences = pd.DataFrame(data)



In [19]:
df_sequences

Unnamed: 0,gene_id,protein_sequence
0,4-1BBL,MEYASDASLDPEAPWPPAPRARACRVLPWALVAGLLLLLLLAAACA...
1,ADSF,MKALCLLLLPVLGLLVSSKTLCSMEEAINERIQEVAGSLIFRAISS...
2,APRIL,MPASSPFLLAPKGPPGNMGGPVREPALSVALWLSWGAALGAVACAM...
3,BAFF,MRRGPRSLRGRDAPAPTPCVPAECFDLLVRHCVACGLLRTPRPKPA...
4,C3a,MASFSAETNSTDLLSQPWNEPPVILSMVILSLTFLLGLPGNGLVLW...
...,...,...
88,LT-alpha,MTPPERLFLPRVCGTTLHLLLLGLLLVLLPGAQGLPGVGLTPSAAQ...
89,LT-beta,MGALGLEGRGGRLQGRGSLLLAVAGATSLVTLLLAVPITVLAVLAL...
90,EBI3,MTPQLLLALVLWASCPPCSGRKGPPAALTLPRVQCRASRYPIAVDC...
91,LT-alpha1-beta2,MTPPERLFLPRVCGTTLHLLLLGLLLVLLPGAQGLPGVGLTPSAAQ...


In [20]:
@dataclass
class EmbeddingConfig:
    fasta_path: str
    model_name: str = "esm2_t36_3B_UR50D"
    output_dir: str = "gene_embeddings"
    include: str = "mean"
    use_gpu: bool = True
    toks_per_batch: int = 4096
    truncation_seq_length: int = 1022
    repr_layers: list[int] | None = None
    save_to_disk: bool = True
    _valid_includes = ["per_tok", "mean", "bos"]

    def __post_init__(self):
        if self.repr_layers is None:
            self.repr_layers = [-1]
        assert (
            self.include in self._valid_includes
        ), f"Must be one of {self._valid_includes}"




In [21]:
def embedding_from_seq(
    config: EmbeddingConfig,
):
    model, alphabet = pretrained.load_model_and_alphabet(config.model_name)
    model.eval()
    if torch.cuda.is_available() and config.use_gpu:
        model = model.cuda()
    dataset = FastaBatchedDataset.from_file(config.fasta_path)
    batches = dataset.get_batch_indices(config.toks_per_batch, extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(config.truncation_seq_length),
        batch_sampler=batches,
    )
    print(f"Read {len(dataset)} sequences from {config.fasta_path}")
    results: dict[str, dict[str, torch.Tensor]] = {}
    # Don't overwrite existing embeddings
    for file in os.listdir(config.output_dir):
        if file.endswith(".pth"):
            print(
                f"Found existing .pth file in {config.output_dir}. Skipping embedding generation."
            )
            return

        assert all(
            -(model.num_layers + 1) <= i <= model.num_layers for i in config.repr_layers
        )
    repr_layers = [
        (i + model.num_layers + 1) % (model.num_layers + 1) for i in config.repr_layers
    ]

    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
            print(
                f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
            )
            if torch.cuda.is_available() and config.use_gpu:
                toks = toks.to(device="cuda", non_blocking=True)

            out = model(toks, repr_layers=repr_layers, return_contacts=False)
            representations = {
                layer: t.to(device="cpu") for layer, t in out["representations"].items()
            }

            for i, label in enumerate(labels):
                output_file = os.path.join(config.output_dir, f"{label}.pth")
                result = {"label": label}
                truncate_len = min(config.truncation_seq_length, len(strs[i]))
                # Call clone on tensors to ensure tensors are not views into a larger representation
                # See https://github.com/pytorch/pytorch/issues/1995
                emb = None
                if "per_tok" in config.include:
                    emb = {
                        layer: t[i, 1 : truncate_len + 1].clone()
                        for layer, t in representations.items()
                    }
                elif "mean" in config.include:
                    emb = {
                        layer: t[i, 1 : truncate_len + 1].mean(0).clone()
                        for layer, t in representations.items()
                    }
                elif "bos" in config.include:
                    emb = {
                        layer: t[i, 0].clone() for layer, t in representations.items()
                    }

                result[config.include] = emb
                average_layers = torch.stack(list(emb.values())).mean(0)
                result["average_layers"] = average_layers
                if config.save_to_disk:
                    torch.save(
                        result,
                        output_file,
                    )
                    results[label] = output_file
                else:
                    results[label] = average_layers
    return results

In [22]:
class BatchedDataset:
    """Modified batched dataset from fair-esm `c9c7d4f0fec964ce10c3e11dccec6c16edaa5144`"""

    def __init__(self, sequence_labels, sequence_strs):
        self.sequence_labels = list(sequence_labels)
        self.sequence_strs = list(sequence_strs)

    def __len__(self):
        return len(self.sequence_labels)

    def __getitem__(self, idx):
        return self.sequence_labels[idx], self.sequence_strs[idx]

    def get_batch_indices(
        self, toks_per_batch, extra_toks_per_seq=0
    ) -> list[list[int]]:
        sizes = [(len(s), i) for i, s in enumerate(self.sequence_strs)]
        sizes.sort()
        batches = []
        buf: list[int] = []
        max_len = 0

        def _flush_current_buf():
            nonlocal max_len, buf
            if len(buf) == 0:
                return
            batches.append(buf)
            buf = []
            max_len = 0

        for sz, i in sizes:
            sz += extra_toks_per_seq
            if max(sz, max_len) * (len(buf) + 1) > toks_per_batch:
                _flush_current_buf()
            max_len = max(max_len, sz)
            buf.append(i)

        _flush_current_buf()
        return batches
        
def create_dataloader(
    prot_names: list[str],
    sequences: list[str],
    toks_per_batch: int,
    collate_fn: Callable,  # type: ignore[type-arg]
) -> DataLoader:
    dataset = BatchedDataset(prot_names, sequences)
    batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
    data_loader = DataLoader(
        dataset,
        collate_fn=collate_fn,
        batch_sampler=batches,
    )
    return data_loader


def _get_esm_collate_fn(
    tokenizer: Callable, max_length: int | None, truncation: bool, return_tensors: str  # type: ignore[type-arg]
) -> Callable:  # type: ignore[type-arg]
    def collate_fn(batch):
        # batch of tuples (gene_id, sequence)
        gene_id, seq = zip(*batch, strict=False)
        metadata = {"gene_id": gene_id, "protein_sequence": seq}
        token = tokenizer(
            seq,
            padding=True,
            max_length=max_length,
            truncation=truncation,
            return_tensors=return_tensors,
        )
        return metadata, token

    return collate_fn
    
def get_model_and_tokenizer(
    model_name: str, use_cuda: bool, cache_dir: None | str
) -> tuple[EsmModel, AutoTokenizer]:
    model_path = os.path.join("facebook", model_name)
    model = EsmModel.from_pretrained(model_path, cache_dir=cache_dir)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir=cache_dir)
    if use_cuda:
        model = model.cuda()
    model.requires_grad_(False)
    return model, tokenizer
    
def protein_features_from_genes(
    metadata: pd.DataFrame,
    esm_model_name: str = "esm2_t36_3B_UR50D",
    toks_per_batch: int = 4096,
    trunc_len: int | None = 1022,
    truncation: bool = True,
    use_cuda: bool = True,
    cache_dir: None | str = None,
) -> tuple[dict[str, torch.Tensor], pd.DataFrame]:
    """
    Compute gene embeddings using ESM2 :cite:`lin:2023`.
    Parameters
    ----------
    genes : list[str]
        List of gene names.
    esm_model_name : str
        Name of the ESM model to use.
    toks_per_batch : int
        Number of tokens per batch.
    trunc_len : int | None
        Maximum length of the sequence.
    truncation : bool
        Whether to truncate the sequence.
    use_cuda : bool
        Use GPU if available.
    cache_dir : str | None
        Directory to cache the model.
    Returns
    -------
    dict[str, torch.Tensor]
        Dictionary with gene names as keys and embeddings as values.
    """
    if os.getenv("HF_HOME") is None and cache_dir is None:
        print(
            "HF_HOME environment variable is not set and `cache_dir` is None. \
                Cache will be stored in the current directory."
        )
    to_emb = metadata[metadata.protein_sequence.notnull()]
    use_cuda = use_cuda and torch.cuda.is_available()
    esm, tokenizer = get_model_and_tokenizer(esm_model_name, use_cuda, cache_dir)
    data_loader = create_dataloader(
        prot_names=to_emb["gene_id"].to_list(),
        sequences=to_emb["protein_sequence"].to_list(),
        toks_per_batch=toks_per_batch,
        collate_fn=_get_esm_collate_fn(
            tokenizer, max_length=trunc_len, truncation=truncation, return_tensors="pt"
        ),
    )
    results = {}
    for batch_metadata, batch in data_loader:
        if use_cuda:
            batch = {k: v.cuda() for k, v in batch.items()}
        batch_results = esm(**batch).last_hidden_state
        for i, name in enumerate(batch_metadata["gene_id"]):
            trunc_len = min(trunc_len, len(batch_metadata["protein_sequence"][i]))  # type: ignore[type-var]
            emb = batch_results[i, 1 : trunc_len + 1].mean(0).clone()  # type: ignore[operator]
            results[name] = emb
    return results, metadata


In [23]:
res = protein_features_from_genes(df_sequences)

HF_HOME environment variable is not set and `cache_dir` is None.                 Cache will be stored in the current directory.


Loading checkpoint shards: 100%|██████████| 2/2 [00:09<00:00,  4.91s/it]
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t36_3B_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [24]:
len(res[0])

93

In [25]:
len(df_sequences)

93

In [26]:
for k1,seq1 in res[0].items():
    for k2,seq2 in res[0].items():
        if (seq1==seq2).all() and k1!=k2:
            print(k1,k2)

In [27]:
df_sequences.head()

Unnamed: 0,gene_id,protein_sequence
0,4-1BBL,MEYASDASLDPEAPWPPAPRARACRVLPWALVAGLLLLLLLAAACA...
1,ADSF,MKALCLLLLPVLGLLVSSKTLCSMEEAINERIQEVAGSLIFRAISS...
2,APRIL,MPASSPFLLAPKGPPGNMGGPVREPALSVALWLSWGAALGAVACAM...
3,BAFF,MRRGPRSLRGRDAPAPTPCVPAECFDLLVRHCVACGLLRTPRPKPA...
4,C3a,MASFSAETNSTDLLSQPWNEPPVILSMVILSLTFLLGLPGNGLVLW...


In [28]:
res[0]

{'IFN-lambda1': tensor([ 0.0112, -0.0815,  0.0285,  ...,  0.0089, -0.0819, -0.0276],
        device='cuda:0'),
 'IL-8': tensor([-0.0143, -0.0439, -0.0228,  ..., -0.0219, -0.1844, -0.1153],
        device='cuda:0'),
 'ADSF': tensor([ 0.0192, -0.0595, -0.0010,  ...,  0.0255, -0.0967,  0.0006],
        device='cuda:0'),
 'TWEAK': tensor([ 0.0480, -0.0882, -0.0545,  ...,  0.0856, -0.0350, -0.1216],
        device='cuda:0'),
 'IL-5': tensor([ 0.0349, -0.0719,  0.0069,  ...,  0.0076, -0.0923, -0.0406],
        device='cuda:0'),
 'GM-CSF': tensor([ 0.0212, -0.0584, -0.0158,  ...,  0.0066, -0.0830, -0.0619],
        device='cuda:0'),
 'IL-13': tensor([ 0.0542, -0.0739,  0.0117,  ..., -0.0002, -0.1194, -0.0866],
        device='cuda:0'),
 'LIGHT': tensor([-0.0672,  0.0025,  0.0571,  ...,  0.0814, -0.1939, -0.1834],
        device='cuda:0'),
 'IL-3': tensor([ 0.0705, -0.0476, -0.0049,  ..., -0.0105, -0.1240, -0.0033],
        device='cuda:0'),
 'IL-2': tensor([ 0.0839, -0.0132, -0.0095,  ..., -0

In [30]:
import numpy as np
embs = {}
for cytokine, v in res[0].items():
    embs[cytokine] = np.array(v.cpu())

In [31]:
embs_dict = {k: v.astype("float64") for k,v in embs.items()}

In [32]:
import pickle
with open(os.path.join(out_dir, "esm2_embeddings.pkl"), "wb") as file:
    pickle.dump(embs_dict, file)