In [None]:
# Install esm and other dependencies
!pip install esm
!pip install matplotlib
!pip install biopython
!pip install parquet
!pip install tables

# Import necessary libraries

In [31]:
import glob
from concurrent.futures import ThreadPoolExecutor
from typing import Sequence, Dict, List

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

import boto3

from Bio import SeqIO

from esm.sdk import client
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig

from esm.sdk.api import (
    ESM3InferenceClient,
    ESMProtein,
    ESMProteinError,
    LogitsConfig,
    LogitsOutput,
    ProteinType,
)
import parquet

# Define utility functions

In [87]:
# Options passed to every `model.logits(...)` call made by `embed_sequence`.
# NOTE(review): hidden states are requested here, but the cells below only read
# `.embeddings` from the output — confirm they are still needed.
# NOTE(review): ith_hidden_layer=100 — confirm the loaded model has that many layers.
EMBEDDING_CONFIG = LogitsConfig(
    sequence=True,
    return_embeddings=True,
    return_hidden_states=True,
    ith_hidden_layer=100,
)


def embed_sequence(model: ESM3InferenceClient, sequence: str) -> LogitsOutput:
    """Encode one amino-acid sequence and run the model's logits pass on it.

    Args:
        model: Object exposing `encode(protein)` and `logits(tensor, config)`.
        sequence: Amino-acid sequence wrapped into an `ESMProtein`.

    Returns:
        Whatever `model.logits` returns for the module-level `EMBEDDING_CONFIG`.
    """
    protein = ESMProtein(sequence=sequence)
    protein_tensor = model.encode(protein)
    return model.logits(protein_tensor, EMBEDDING_CONFIG)


def batch_embed(
    model: ESM3InferenceClient, inputs: Sequence[ProteinType]
) -> Sequence[LogitsOutput]:
    """Run `embed_sequence` for every input concurrently on a thread pool.

    Per the original author's note, Forge supports auto-batching, so issuing the
    calls in parallel is enough. Parallelism comes from `ThreadPoolExecutor`
    threads, not asyncio.

    Args:
        model: Model object forwarded to `embed_sequence`.
        inputs: Items forwarded one by one, unchanged, to `embed_sequence`.

    Returns:
        A list with one entry per input, in input order. A call that raised is
        represented by `ESMProteinError(500, <message>)` instead of an output, so
        callers must check the type of each entry.
    """
    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(embed_sequence, model, item) for item in inputs]
        results = []
        # Collect in submission order so results line up with `inputs`.
        for future in futures:
            try:
                results.append(future.result())
            except Exception as e:
                # Deliberate best-effort: keep the batch going and record the failure.
                results.append(ESMProteinError(500, str(e)))
    return results

def pad_arrays_to_max_shape(list_of_arrays, pad_value=0):
    """
    Pads NumPy arrays at the end of each axis so they all share the largest shape.

    Works for arrays of any (common) number of dimensions. For 1-D input this is
    identical to padding every vector to the length of the longest one.

    Args:
        list_of_arrays (iterable): Arrays that all have the same number of dimensions.
        pad_value (int or float): The value to use for padding (default is 0).

    Returns:
        list: A new list containing the padded arrays, all with the same shape.
            An empty input yields an empty list.

    Raises:
        ValueError: If the arrays do not all have the same number of dimensions.
    """
    # Materialise once: the input is traversed twice and may be a one-shot iterable.
    arrays = [np.asarray(arr) for arr in list_of_arrays]
    if not arrays:
        return []

    ndim = arrays[0].ndim
    if any(arr.ndim != ndim for arr in arrays):
        raise ValueError("all arrays must have the same number of dimensions")

    # Target extent along every axis is the largest extent seen on that axis.
    max_shape = tuple(max(arr.shape[axis] for arr in arrays) for axis in range(ndim))

    padded_arrays = []
    for arr in arrays:
        # One (before, after) pair per axis; padding is only ever added at the end.
        pad_width = [(0, target - current) for target, current in zip(max_shape, arr.shape)]
        padded_arrays.append(
            np.pad(arr, pad_width=pad_width, mode='constant', constant_values=pad_value)
        )

    return padded_arrays

# Request specific layer

ESM C 6B's hidden states are really large, so we only allow one specific layer to be requested per API call. This also works for other ESM C models, but it is required for ESM C 6B. Refer to https://forge.evolutionaryscale.ai/console to find the number of hidden layers for each model.

In [88]:
# ESMC_6B_EMBEDDING_CONFIG = LogitsConfig(return_hidden_states=True, ith_hidden_layer=55)

# Load model

In [89]:
# Load the pretrained `esmc_300m` model and move it to the GPU.
# NOTE: this rebinds the name `client`, hiding the `client` imported from `esm.sdk` above.
client = ESMC.from_pretrained("esmc_300m").to("cuda")  # or "cpu"

# Smoke test: encode a short poly-alanine sequence and request logits plus embeddings.
protein = ESMProtein(sequence="AAAAA")
protein_tensor = client.encode(protein)
logits_output = client.logits(
    protein_tensor,
    LogitsConfig(sequence=True, return_embeddings=True),
)

# Subset proteins set

In [95]:
def get_embeddings_from_fasta(proteome_fasta: str = None, model: ESM3InferenceClient = None) -> np.ndarray:
    """
    Embed every record of a FASTA file and average the results into one vector.

    Each record's sequence is passed through `embed_sequence`; its embeddings are
    reduced with `np.mean(..., axis=1)`, the per-record vectors are padded to a
    common length with `pad_arrays_to_max_shape`, and the padded vectors are
    averaged element-wise. Every record contributes, including records that share
    an id with an earlier one.

    Args:
        proteome_fasta: Path (or handle) accepted by `SeqIO.parse(..., 'fasta')`.
        model: Model object forwarded to `embed_sequence`.

    Returns:
        np.ndarray: The element-wise mean of the padded per-record vectors.

    Raises:
        ValueError: If the FASTA file yields no records.
    """
    embeddings_list = []

    for record in SeqIO.parse(proteome_fasta, 'fasta'):
        record_embeddings = embed_sequence(model, str(record.seq)).embeddings.cpu().squeeze(0).numpy()
        # NOTE(review): assumes the squeezed embeddings are (positions, hidden_dim) — TODO confirm.
        # With axis=1 the result has one value per position, so its length varies
        # per record; if a fixed-size mean-pooled vector was intended, axis=0 would
        # be the usual choice. Left unchanged because it alters the output shape.
        embeddings_list.append(np.mean(record_embeddings, axis=1))

    if not embeddings_list:
        raise ValueError(f"no sequences found in FASTA file: {proteome_fasta!r}")

    # Zero-padding shorter vectors means the padded positions pull the mean toward 0.
    padded_embeddings = pad_arrays_to_max_shape(embeddings_list)

    return np.mean(padded_embeddings, axis=0)

def get_embeddings_from_csv(proteome: str = None, model: ESM3InferenceClient = None) -> tuple:
    """
    Embed every sequence listed in a proteome CSV and average them into one vector.

    The CSV must provide a `seq` column (one sequence per row) and a `phage_id`
    column; the id is taken from the first row only. Each sequence goes through
    `embed_sequence`, is reduced with `np.mean(..., axis=1)`, padded to a common
    length with `pad_arrays_to_max_shape`, and the padded vectors are averaged
    element-wise.

    Args:
        proteome: Path accepted by `pd.read_csv`.
        model: Model object forwarded to `embed_sequence`.

    Returns:
        tuple: `(phage_id, mean_embedding)` where `mean_embedding` is a NumPy array.

    Raises:
        ValueError: If the CSV contains no rows.
    """
    df = pd.read_csv(proteome)
    if df.empty:
        raise ValueError(f"no rows found in proteome CSV: {proteome!r}")

    # Only the first row's id is used for the whole file.
    phage_id = df['phage_id'].tolist()[0]

    embeddings_list = []
    for seq in df['seq'].tolist():
        seq_embeddings = embed_sequence(model, seq).embeddings.cpu().squeeze(0).numpy()
        # NOTE(review): assumes the squeezed embeddings are (positions, hidden_dim) — TODO confirm.
        # axis=1 yields one value per position (variable length per protein);
        # axis=0 would give a fixed-size mean-pooled vector. Left unchanged
        # because it alters the output shape.
        embeddings_list.append(np.mean(seq_embeddings, axis=1))

    # Zero-padding shorter vectors means the padded positions pull the mean toward 0.
    padded_embeddings = pad_arrays_to_max_shape(embeddings_list)

    mean_embedding = np.mean(padded_embeddings, axis=0)

    return phage_id, mean_embedding


# Embed all phage proteomes

In [96]:
# One folder of per-phage proteome CSV files for each host species.
phages_folders = [
    'phages_proteoms/Escherichia_coli/',
    'phages_proteoms/Klebsiella_pneumoniae/',
    'phages_proteoms/Pseudomonas_aeruginosa/',
    'phages_proteoms/Staphylococcus_aureus/',
]

# Column-oriented accumulator, turned into a DataFrame in the next cell.
embeddings_df_columns = {"phage_id": [], "embeddings": []}

for folder in phages_folders:
    # Every entry directly inside the folder is treated as one proteome CSV.
    for file in glob.glob(f"{folder}*"):
        phage_id, embedding = get_embeddings_from_csv(proteome=file, model=client)
        embeddings_df_columns["phage_id"].append(phage_id)
        embeddings_df_columns["embeddings"].append(embedding)

In [97]:
# Assemble the collected phage ids and embeddings into a two-column DataFrame.
embeddings_df = pd.DataFrame(embeddings_df_columns)

In [None]:
# Persist the phage embeddings. `key` is passed by keyword because pandas has
# deprecated passing `to_hdf` arguments other than the path positionally.
embeddings_df.to_hdf("../tmp/phage_embeddings.h5", key="phage_embeddings")

# Reload embeddings for check

In [103]:
# Read the phage embeddings back from disk to check the round trip.
embeddings_df_reloaded = pd.read_hdf("../tmp/phage_embeddings.h5", key="phage_embeddings")

In [104]:
# Preview the first rows of the reloaded phage embeddings.
embeddings_df_reloaded.head()

Unnamed: 0,phage_id,embeddings
0,PQ850631,"[-4.6356887e-05, 0.0010236966, 0.0007067089, 0..."
1,MZ826699,"[-3.8704005e-05, 0.001010636, 0.00069760345, 0..."
2,MF356679,"[-6.9524205e-05, 0.000920012, 0.000651642, 0.0..."
3,GQ149088,"[-1.899178e-05, 0.0010069675, 0.0007526012, 0...."
4,KX266606,"[-6.262038e-05, 0.0010293312, 0.0008145436, 0...."


In [105]:
# Shape of the first phage's embedding vector.
embeddings_df_reloaded.iloc[0]['embeddings'].shape

(967,)

## Embed bacteria proteins

In [106]:
# Read both columns as strings so `bvbrc_id` values keep their exact text
# instead of being parsed as numbers.
bacterias_df = pd.read_csv(
    "bacterias.csv",
    dtype={"bacteria": str, "bvbrc_id": str},
)
bacterias_df.head()

Unnamed: 0,bacteria,bvbrc_id
0,Escherichia coli,511145.12
1,Klebsiella pneumoniae,1125630.4
2,Pseudomonas aeruginosa,287.5706
3,Ralstonia solanacearum,305.1006
4,Lactococcus lactis,1360.457


In [107]:

# Column-oriented accumulator: one embedding per bacterium, read from its FASTA file.
bacteria_embeddings_df_columns = {"bacteria_id": [], "embeddings": []}

for bacteria_id, bvbrc_id in zip(bacterias_df["bacteria"], bacterias_df["bvbrc_id"]):
    print(bacteria_id)
    proteome = f"bacteria_proteoms/{bvbrc_id}.PATRIC.selected.faa"
    embedding = get_embeddings_from_fasta(proteome_fasta=proteome, model=client)
    bacteria_embeddings_df_columns["bacteria_id"].append(bacteria_id)
    bacteria_embeddings_df_columns["embeddings"].append(embedding)

Escherichia coli
Klebsiella pneumoniae
Pseudomonas aeruginosa
Ralstonia solanacearum
Lactococcus lactis
Staphylococcus aureus
Salmonella sp.
Acinetobacter baumannii
Streptococcus thermophilus


In [108]:
# Assemble the collected bacteria names and embeddings into a two-column DataFrame.
bacteria_embeddings_df = pd.DataFrame(bacteria_embeddings_df_columns)

In [155]:
# Shape of the first bacterium's embedding vector (before padding).
bacteria_embeddings_df.iloc[0]['embeddings'].shape

(1261,)

In [None]:
# Persist the bacteria embeddings to HDF5 under the "bacteria_embeddings" key.
bacteria_embeddings_df.to_hdf("../tmp/bacteria_embeddings.h5", key="bacteria_embeddings")

In [110]:
# Read the bacteria embeddings back from disk, naming the key explicitly.
bacteria_embeddings_df_reloaded = pd.read_hdf(
    "../tmp/bacteria_embeddings.h5",
    key="bacteria_embeddings",
)

## Preparation of concatenated vectors

In [111]:
# Continue from the frames read back from disk.
# These are references, not copies: columns assigned to `embeddings_df` /
# `bacteria_embeddings_df` later also appear on the `*_reloaded` frames.
embeddings_df = embeddings_df_reloaded
bacteria_embeddings_df = bacteria_embeddings_df_reloaded

In [112]:
# Tab-separated phage metadata; `find_label` below reads its "Accession" and
# isolation-host columns.
metadata = pd.read_csv("./14Apr2025_data_excluding_refseq.tsv", sep="\t")

In [113]:
# Pad phage and bacteria embeddings separately, each to the longest vector
# within its own frame.
embeddings_df["padded_embeddings"] = pad_arrays_to_max_shape(embeddings_df["embeddings"].values)
bacteria_embeddings_df["padded_embeddings"] = pad_arrays_to_max_shape(bacteria_embeddings_df["embeddings"].values)

In [157]:
# Shape of the first bacterium's embedding after padding.
bacteria_embeddings_df.iloc[0]['padded_embeddings'].shape

(4199,)

In [114]:
# `.flatten()` returns a 1-D copy of each padded array, ready for concatenation.
embeddings_df["flat_embeddings"] = [embedding.flatten() for embedding in embeddings_df["padded_embeddings"].values]
bacteria_embeddings_df["flat_embeddings"] = [embedding.flatten() for embedding in bacteria_embeddings_df["padded_embeddings"].values]

In [115]:
import itertools

In [116]:
# Distinct phage ids and bacteria names that will be paired up.
viruses = embeddings_df["phage_id"].unique()
bacterias = bacterias_df["bacteria"].unique()

# Number of (phage, bacterium) pairs in the full cross product.
len(viruses) * len(bacterias)

3402

In [117]:
def find_label(phage_id, bacteria_id):
    """Return True when the metadata lists `bacteria_id` as an isolation host of `phage_id`.

    Looks up the rows of the module-level `metadata` frame whose "Accession"
    equals `phage_id` and checks whether any of their isolation-host strings
    contains `bacteria_id`.

    Args:
        phage_id: Accession to look up in `metadata["Accession"]`.
        bacteria_id: Bacterium name searched for as a literal substring.

    Returns:
        bool: True if at least one matching row exists, otherwise False (also
        when the accession is absent or its host value is missing).
    """
    host_column = "Isolation Host (beware inconsistent and nonsense values)"
    phage_hosts = metadata.loc[metadata["Accession"] == phage_id, host_column]
    # regex=False: names such as "Salmonella sp." contain regex metacharacters and
    # must be matched literally. na=False: a missing host counts as "no match"
    # instead of producing a NaN mask.
    matches = phage_hosts.str.contains(bacteria_id, regex=False, na=False)
    return bool(matches.any())

In [118]:
# Column-oriented accumulator for one row per (phage, bacterium) pair.
features_df_columns = {
    "phage_id": [],
    "bacteria_id": [],
    "embedding": [],
    "label": []
}

# Index the flattened embeddings once instead of filtering the frames for every
# pair. `setdefault` keeps the first row per id, like the previous `.values[0]`.
viral_embeddings_by_id = {}
for known_phage_id, flat_embedding in zip(embeddings_df["phage_id"], embeddings_df["flat_embeddings"]):
    viral_embeddings_by_id.setdefault(known_phage_id, flat_embedding)

bacteria_embeddings_by_id = {}
for known_bacteria_id, flat_embedding in zip(bacteria_embeddings_df["bacteria_id"], bacteria_embeddings_df["flat_embeddings"]):
    bacteria_embeddings_by_id.setdefault(known_bacteria_id, flat_embedding)

for phage_id, bacteria_id in itertools.product(viruses, bacterias):
    viral_embedding = viral_embeddings_by_id[phage_id]
    bacteria_embedding = bacteria_embeddings_by_id[bacteria_id]
    # The label is looked up once per pair.
    label = find_label(phage_id, bacteria_id)

    features_df_columns["phage_id"].append(phage_id)
    features_df_columns["bacteria_id"].append(bacteria_id)
    # Feature vector = phage embedding followed by bacterium embedding.
    features_df_columns["embedding"].append(np.append(viral_embedding, bacteria_embedding))
    features_df_columns["label"].append(label)
    

In [129]:
# Build the pair-level feature table and show how many pairs carry each label.
features_df = pd.DataFrame(features_df_columns)
features_df["label"].value_counts()


label
False    3024
True      378
Name: count, dtype: int64

In [130]:
# Split the pairs into positives (label True) and negatives (label False).
features_df_true = features_df.loc[features_df["label"]]
features_df_false = features_df.loc[~features_df["label"]]

In [143]:
# Preview the first negative (label False) pairs.
features_df_false.head()

Unnamed: 0,phage_id,bacteria_id,embedding,label
1,PQ850631,Klebsiella pneumoniae,"[-4.6356887e-05, 0.0010236966, 0.0007067089, 0...",False
2,PQ850631,Pseudomonas aeruginosa,"[-4.6356887e-05, 0.0010236966, 0.0007067089, 0...",False
3,PQ850631,Ralstonia solanacearum,"[-4.6356887e-05, 0.0010236966, 0.0007067089, 0...",False
4,PQ850631,Lactococcus lactis,"[-4.6356887e-05, 0.0010236966, 0.0007067089, 0...",False
5,PQ850631,Staphylococcus aureus,"[-4.6356887e-05, 0.0010236966, 0.0007067089, 0...",False


In [131]:
# Draw 350 pairs from each class to get a balanced set.
# A fixed `random_state` makes the draw reproducible across kernel restarts.
sampled_features_df_false = features_df_false.sample(350, random_state=42)
sampled_features_df_true = features_df_true.sample(350, random_state=42)

In [132]:
# Stack the sampled negatives and positives into one frame (original row labels are kept).
sampled_features_df = pd.concat([sampled_features_df_false, sampled_features_df_true])

In [None]:
# Write the sampled features to HDF5 in chunks of `batch_size` rows, one key per chunk.
batch_size = 10
for i, df_chunk in sampled_features_df.groupby(np.arange(sampled_features_df.shape[0]) // batch_size):
    # `key` is passed by keyword: pandas has deprecated positional `to_hdf`
    # arguments other than the path.
    # mode='a' opens the file for appending, so it is not recreated for each chunk.
    df_chunk.to_hdf('../tmp/features.h5', key='features_%i'%i, complib= 'blosc:lz4', mode='a')

In [137]:
# Expand each embedding vector into one column per element.
exploded_df = features_df['embedding'].apply(pd.Series)

In [None]:
# Preview the expanded embedding columns.
exploded_df.head()

In [148]:
# Replace the array-valued `embedding` column with the expanded per-element columns.
new_df = pd.concat([features_df.drop('embedding', axis=1), exploded_df], axis=1)

In [151]:
# Write the full (unsampled) feature table to CSV without the index column.
new_df.to_csv("final_features.csv", index = False)