<a href="https://colab.research.google.com/github/vihan-lakshman/mutagenic/blob/main/substitution_pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [50]:
# prompt: read in '/content/prelim-deletion-validation-dataset-functional-annotations.csv'
import torch
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity

df = pd.read_csv('/content/subset_nonfunctional_DMS_substitutions_with_interpro.csv')


In [51]:
df.head()

Unnamed: 0,mutant,mutated_sequence,DMS_score,wildtype_seq,UniProt_ID,InterPro
0,L43P,SPEVQIAILTEQINNLNEHLRVHKKDHHSRRGLLKMVGKRRRPLAY...,-3.196763,SPEVQIAILTEQINNLNEHLRVHKKDHHSRRGLLKMVGKRRRLLAY...,RS15_GEOSE,IPR000589;IPR005290;IPR009068;
1,L56N,EAHAAIDTFTKYLDIDEDFATVLVEEGFSTLEELAYVPMKELLEIE...,-3.952602,EAHAAIDTFTKYLDIDEDFATVLVEEGFSTLEELAYVPMKELLEIE...,NUSA_ECOLI,IPR010995;IPR015946;IPR025249;IPR009019;IPR012...
2,I182K,MDSLVVLVLCLSCLLLLSLWRQSSGRGKLPPGPTPLPVIGNILQIG...,-0.045873,MDSLVVLVLCLSCLLLLSLWRQSSGRGKLPPGPTPLPVIGNILQIG...,CP2C9_HUMAN,IPR001128;IPR017972;IPR002401;IPR036396;IPR050...
3,Q2E,MEFKVYTYKRESRYRLFVDVQSDIIDTPGRRMVIPLASARLLSDKV...,-9.0,MQFKVYTYKRESRYRLFVDVQSDIIDTPGRRMVIPLASARLLSDKV...,CCDB_ECOLI,IPR002712;IPR011067;
4,G324V:G344M,MDCLCIVTTKKYRYQDEDTPPLEHSPAHLPNQANSPPVIVNTDTLE...,-1.545841,MDCLCIVTTKKYRYQDEDTPPLEHSPAHLPNQANSPPVIVNTDTLE...,DLG4_HUMAN,IPR019583;IPR016313;IPR019590;IPR008145;IPR008...


In [52]:
!pip install blosum
import blosum as bl
matrix = bl.BLOSUM(80)



In [53]:
# # prompt: for each row, identify the substring that starts at del_start (counting from 1, not 0), and ends at del_end, and for each letter in that substring, it is a amino acid residue that should be randomly replaced with one of the 20 amino acids. save this as "substituted_seq" in a column after "truncated_seq"

import random

# List of standard amino acid single-letter codes
amino_acids = "ACDEFGHIKLMNPQRSTVWY"

def substitute_substring(sequence, del_start, del_end):
    """Substitutes amino acids in a substring with random amino acids."""

    # Adjust del_start to be 0-indexed
    #del_start -= 1

    if del_start < 0 or del_end > len(sequence) or del_start >= del_end:
        return sequence  # Handle invalid indices

    substituted_seq = list(sequence)
    for i in range(del_start, del_end):
        substituted_seq[i] = random.choice(amino_acids)

    return "".join(substituted_seq)


#Apply the function to each row of the DataFrame
df['substituted_seq'] = df['mutated_sequence']
df['substituted_seq'].head()

Unnamed: 0,substituted_seq
0,SPEVQIAILTEQINNLNEHLRVHKKDHHSRRGLLKMVGKRRRPLAY...
1,EAHAAIDTFTKYLDIDEDFATVLVEEGFSTLEELAYVPMKELLEIE...
2,MDSLVVLVLCLSCLLLLSLWRQSSGRGKLPPGPTPLPVIGNILQIG...
3,MEFKVYTYKRESRYRLFVDVQSDIIDTPGRRMVIPLASARLLSDKV...
4,MDCLCIVTTKKYRYQDEDTPPLEHSPAHLPNQANSPPVIVNTDTLE...


In [54]:
# Dictionary to hold the results
embeddings_dict = {}

# Iterate through each row in the DataFrame
for _, row in df.iterrows():
    entry = row['UniProt_ID']
    interpro = row['InterPro']

    # Skip rows where 'Interpro' is None
    if pd.isna(interpro) or not interpro.strip():
        continue

    # Split the InterPro IDs by semicolons
    interpro_ids = interpro.split(';')
    interpro_ids = interpro_ids[:-1]

    # Initialize entry in the dictionary if not present
    if entry not in embeddings_dict:
        embeddings_dict[entry] = {
            'InterPro_ids': interpro_ids
        }

In [55]:
!pip install esm



In [56]:
from huggingface_hub import login
from esm.models.esm3 import ESM3
from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig

# Will instruct you how to get an API key from huggingface hub, make one with "Read" permission.
login("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")

In [57]:
model: ESM3InferenceClient = ESM3.from_pretrained("esm3_sm_open_v1").to('cuda')

  state_dict = torch.load(


In [58]:
import torch.nn as nn
from esm.tokenization import InterProQuantizedTokenizer
from esm.utils.types import FunctionAnnotation

def get_keywords_from_interpro(
    interpro_annotations,
    interpro2keywords=InterProQuantizedTokenizer().interpro2keywords,
):
    keyword_annotations_list = []
    for interpro_annotation in interpro_annotations:
        keywords = interpro2keywords.get(interpro_annotation.label, [])
        keyword_annotations_list.extend([
            FunctionAnnotation(
                label=keyword,
                start=interpro_annotation.start,
                end=interpro_annotation.end,
            )
            for keyword in keywords
        ])
    return keyword_annotations_list

In [59]:
#protein that only has one function?
#longer sequences of all As, all Gs, or completely random
def get_label_embedding(interpro_label,sequence):
  hostProtein = ESMProtein(sequence=sequence)
  embedding_function = model.encoder.function_embed
  hostProtein.function_annotations = get_keywords_from_interpro([FunctionAnnotation(label=interpro_label, start=1, end=len(sequence))])
  hostProtein_tensor = model.encode(hostProtein)
  device = hostProtein_tensor.function.device  # Get the device of protein2_tensor.function
  embedding_function = embedding_function.to(device)  # Move embedding_function to the device

  function_embed = torch.cat(
      [
          embed_fn(funcs.to(device)) # Ensure funcs is on the same device
          for embed_fn, funcs in zip(
              embedding_function, hostProtein_tensor.function.unbind(-1)
          )
      ],
      -1,
  )
  if function_embed.shape[0] >= 3:
      row_sum = function_embed.sum(dim=0)  # Sum all rows
      row_avg = row_sum / (function_embed.shape[0] - 2)  # Divide by (number of rows - 2)
      row_avg_np = row_avg.cpu().detach().type(torch.float32).numpy()
      return row_avg_np
  else:
      return None

In [60]:
import numpy as np
for entry, interpro_ids in embeddings_dict.items():
  embeddings = []
  for interpro_id in interpro_ids['InterPro_ids']:
    # embeddings.append(get_label_embedding(interpro_id,"A"))
    label_embedding = get_label_embedding(interpro_id,"A")
    embeddings.append(label_embedding)

  #embeddings = torch.stack(embeddings, dim=0)
  #avg_embedding = torch.mean(embeddings)
  embeddings_dict[entry]['embedding'] = embeddings

In [61]:
def embedding_masking_model(
    prompt,
    model,
    df,
    embeddings_dict,
    percentage=10,
):
    """
    Helper function to process a protein sequence, calculate similarities,
    and return indices for masking.

    Args:
        prompt (str): The protein sequence to be processed.
        model: The model used for protein generation and embeddings.
        df (pd.DataFrame): DataFrame containing protein data.
        embeddings_dict (dict): Dictionary storing embeddings and other details.

    Returns:
        List[int]: Indices used for masking in the sequence.
    """
    # Create an ESMProtein object
    protein = ESMProtein(sequence=prompt)

    # Configure the model for generation
    generation_config = GenerationConfig(track="function", num_steps=8)

    # Generate the protein
    generated_protein = model.generate(protein, generation_config)

    # Check if function annotations are available
    entry = df.loc[df['substituted_seq'] == prompt, 'UniProt_ID'].iloc[0]
    if generated_protein.function_annotations is None:
        embeddings_dict[entry]['hamming_distance'] = None
        return []

    # Getting embedding for the protein
    protein_tensor = model.encode(generated_protein)
    embedding_function = model.encoder.function_embed
    device = protein_tensor.function.device  # Get the device of protein_tensor.function
    embedding_function = embedding_function.to(device)  # Move embedding_function to the device

    function_embed = torch.cat(
        [
            embed_fn(funcs.to(device))  # Ensure funcs is on the same device
            for embed_fn, funcs in zip(
                embedding_function, protein_tensor.function.unbind(-1)
            )
        ],
        -1,
    )

    # Exclude start and end tokens
    function_embed = function_embed[1:-1, :]

    # Convert the protein_tensor.function to a NumPy array
    protein_np = function_embed.cpu().detach().type(torch.float32).numpy()
    print(protein_np.shape)
    return None

    # Retrieve target sequence and embedding
    embedding = embeddings_dict[entry]['embedding']

    # Calculate cosine similarity
    embedding = embedding.cpu().detach().type(torch.float32).numpy()
    similarities = cosine_similarity(protein_np, embedding.reshape(1, -1))

    num_indices = int(len(prompt) * percentage / 100)

    # Ensure we select at least 1 index
    num_indices = max(1, num_indices)

    # Find the top 10 indices with the lowest cosine similarity
    indices = np.argsort(similarities.flatten())[:num_indices]

    # Store the indices in the embeddings_dict
    embeddings_dict[entry]['indices'] = indices.tolist()

    return indices.tolist()

In [62]:
def calculatedist(proteinembed,targetembed):
  alldist = []
  for embedding in proteinembed:
    dist = cosine_similarity(embedding, targetembed.reshape(1, -1))
    alldist.append(dist**2)
  return alldist

In [64]:
import math
row = df.iloc[0]
maskedindeces = embedding_masking_model(row['substituted_seq'], model, df, embeddings_dict)

100%|██████████| 8/8 [00:00<00:00, 12.29it/s]
  state_dict = torch.load(
  state_dict = torch.load(


(63, 1536)


In [65]:
def get_random_indices(prompt, percentage):
    """
    Randomly select indices to mask based on the percentage of the prompt's length.
    """
    num_indices = int(len(prompt) * percentage / 100)
    # Ensure we select at least one index
    num_indices = max(1, num_indices)

    # Randomly select unique indices to mask
    return random.sample(range(len(prompt)), num_indices)

In [None]:
#RICHIE RUN THIS!!!!!

import math
allnuminterpro = []
# allpercentmasks = df['percent_deleted'].tolist()
allpercentidentities = []
allindexes = []
allmasked = []
sequence_similarity = []
masked_sequence = []
generated_sequence_list = []
for index, row in df.iterrows():
  if row["UniProt_ID"] not in embeddings_dict:
    continue
  try:
    maskedindeces = embedding_masking_model(row['substituted_seq'], model, df, embeddings_dict,)
  except:
    torch.cuda.empty_cache()
    continue
  if not maskedindeces:
    continue
  allindexes.append(index)
  numinterpro = int(len(row['InterPro'])/10)
  allnuminterpro.append(numinterpro)
  correctmasks = set(np.arange(row['del_start'],row['del_end']+1))
  truncatedpredictions = set(maskedindeces[:len(correctmasks)])
  allmasked.append(truncatedpredictions)
  identical_count = len(truncatedpredictions.intersection(correctmasks))
  percent_identity = (identical_count / len(correctmasks))
  allpercentidentities.append(percent_identity)
  modified_prompt = list(row['substituted_seq'])
  for index in maskedindeces:
      modified_prompt[index] = "_"
  modified_prompt = "".join(modified_prompt)
  masked_sequence.append(modified_prompt)
  protein_prompt = ESMProtein(sequence=modified_prompt)
  sequence_generation = model.generate(
      protein_prompt,
      GenerationConfig(
          track="sequence",
          num_steps=protein_prompt.sequence.count("_") // 2,
          temperature=0.5,
      ),
  )
  generated_sequence = sequence_generation.sequence
  generated_sequence_list.append(generated_sequence)
  # Ensure sequences are of equal length
  if len(generated_sequence) != len(row['sequence']):
      print("Sequences must be of the same length to calculate Hamming distance.")
      sequence_similarity.append(None)
  else:
      blosum_score = 0
      for gen_residue, target_residue in zip(generated_sequence, row['sequence']):
          blosum_val =  matrix[gen_residue][target_residue]
          blosum_score += blosum_val
      blosum_score = blosum_score / len(generated_sequence)
      sequence_similarity.append(blosum_score)
  torch.cuda.empty_cache()

shortenedpercentmasks = [allpercentmasks[i] for i in allindexes]
df2 = pd.DataFrame({
    'Number of Interpro Terms': allnuminterpro,
    'Percentage Deleted': shortenedpercentmasks,
    'Percent Correct': allpercentidentities,
    'Index': allindexes,
    'Masked sites': allmasked,
    'Sequence Similarity': sequence_similarity,
    'Generated Sequences': generated_sequence_list,
    'Masked Sequences': masked_sequence
})

# Save the DataFrame as a CSV file
df2.to_csv('with_seq_similarity_embedding_output_full.csv')


100%|██████████| 8/8 [00:00<00:00, 12.86it/s]


(63, 1536)


100%|██████████| 8/8 [00:00<00:00, 12.71it/s]


(69, 1536)


100%|██████████| 8/8 [00:00<00:00, 12.96it/s]


(490, 1536)


100%|██████████| 8/8 [00:00<00:00, 12.82it/s]


(101, 1536)


100%|██████████| 8/8 [00:00<00:00, 12.69it/s]
