### This notebook generates ProteinBERT embeddings for bacterial data

In [None]:
import functools
import glob
import json
import os

import pandas as pd

from proteinbert import load_pretrained_model
from proteinbert.conv_and_global_attention_model import get_model_with_hidden_layers_as_outputs

from tqdm import tqdm
tqdm.pandas() 

In [None]:
# Root of the cloned repository; the input CSVs are read from <base_dir>/data.
base_dir = '' # Insert the directory where you cloned the repository
data_dir = f'{base_dir}/data'
# glob returns matches in arbitrary, filesystem-dependent order; sorting keeps
# the row order of the assembled table reproducible between runs and machines.
data_files = sorted(glob.glob(f'{data_dir}/*.csv'))

In [None]:
# Create a df containing protein sequences from all genes
if not data_files:
    # Fail early with an actionable message; otherwise the cleanup step below
    # would raise an unrelated-looking KeyError on the 'organism' column.
    raise FileNotFoundError(f'No CSV files found in {data_dir!r}; check base_dir')

gene_frames = []
for file in data_files:
    data = pd.read_csv(file)
    # The gene name is the file name up to its first dot. basename() also
    # handles the platform separator, which glob emits before the file name
    # on Windows and which a plain split('/') would leave in the gene name.
    gene_name = os.path.basename(file).split('.')[0]
    data = data.drop(columns='dna_seq')
    data['gene'] = gene_name
    gene_frames.append(data)

# Concatenate once: growing the frame inside the loop re-copies all previously
# loaded rows on every iteration.
protein_df = pd.concat(gene_frames, ignore_index=True)

# Cleanup: some headers carry a leading space; normalise them, then drop rows
# that have no organism or no strain.
protein_df.rename(columns={' organism': 'organism', ' strain': 'strain'}, inplace=True)
protein_df = protein_df[~protein_df['organism'].isna()]
protein_df = protein_df[~protein_df['strain'].isna()]

In [None]:
@functools.lru_cache(maxsize=1)
def _load_proteinbert():
    """Load the pretrained ProteinBERT model generator and input encoder once.

    Returns the ``(pretrained_model_generator, input_encoder)`` pair produced by
    ``load_pretrained_model()``. The result is cached so the pretrained weights
    are not reloaded for every sequence.
    """
    return load_pretrained_model()


@functools.lru_cache(maxsize=4)
def _get_embedding_model(seq_len):
    """Build the hidden-layer-output model for one token length.

    A small bounded cache lets sequences that share a length reuse the same
    model instead of rebuilding it, without keeping an unbounded number of
    models alive.
    """
    pretrained_model_generator, _ = _load_proteinbert()
    return get_model_with_hidden_layers_as_outputs(
        pretrained_model_generator.create_model(seq_len)
    )


def get_embedding(aa_seq):
    """Return the ProteinBERT global representation of one protein sequence.

    Args:
        aa_seq: Amino-acid sequence as a string.

    Returns:
        The first (and only) row of the global representations predicted for
        ``aa_seq``.
    """
    seqs = [aa_seq]
    # seq_len is the token length of the model input, not the number of
    # sequences. The ProteinBERT tokenizer adds a start and an end token, so
    # the input is two tokens longer than the amino-acid string.
    seq_len = len(aa_seq) + 2
    batch_size = 1
    _, input_encoder = _load_proteinbert()
    model = _get_embedding_model(seq_len)
    encoded_x = input_encoder.encode_X(seqs, seq_len)
    # The model yields (local, global) representations; only the global one is used.
    _local_representations, global_representations = model.predict(encoded_x, batch_size=batch_size)
    return global_representations[0]

# Embed every protein sequence, one row at a time, with a tqdm progress bar.
sequences = protein_df['protein_seq']
protein_df['embeddings'] = sequences.progress_apply(get_embedding)

In [None]:
# Serialise each embedding array to a JSON list string so it survives the CSV export.
protein_df['embeddings_json'] = [
    json.dumps(vector.tolist()) for vector in protein_df['embeddings']
]

In [None]:
# Write the full table, including both embedding columns, to the working directory.
output_path = 'protein_embeddings.csv'
protein_df.to_csv(output_path)