<a href="https://colab.research.google.com/github/vihan-lakshman/mutagenic/blob/main/masking_pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Evaluation Dataset

In [55]:
# prompt: read in '/content/prelim-deletion-validation-dataset-functional-annotations.csv'
import torch
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
df = pd.read_csv('/content/prelim-deletion-validation-dataset-functional-annotations-with-interpro.csv')
print(df.head())

    Entry                     id  \
0  Q80WR5  sp|Q80WR5|CA174_MOUSE   
1  Q10N20   sp|Q10N20|MPK5_ORYSJ   
2  P67015    sp|P67015|SYD_STAAN   
3  P51861   sp|P51861|CDR1_HUMAN   
4  P03772     sp|P03772|PP_LAMBD   

                                            sequence  \
0  MRSRKLTGGVRSSARLRARSYSSASLASARDVTSSTSAKTTCLASS...   
1  MDGAPVAEFRPTMTHGGRYLLYDIFGNKFEVTNKYQPPIMPIGRGA...   
2  MSKRTTYCGLVTEAFLGQEITLKGWVNNRRDLGGLIFVDLRDREGI...   
3  MAWLEDVDFLEDVPLLEDIPLLEDVPLLEDVPLLEDTSRLEDINLM...   
4  MRYYEKIDGSKYRNIWVVGDLHGCYTNLMNKLDTIGFDNKKDLLIS...   

                                       truncated_seq  del_start  del_end  \
0  MRSRKLTGGVRSSARLRARSYSSASLASARDVTSSTSAKKATDRRT...         39       47   
1  MDGAPVAEFRPTMTHGGRYLLYDIFGNKFEVTNKYQPPIMPIGRGA...        274      293   
2  MSKRTTYCGLVTEAFLGQEITLKGWVNNRRDLGGLIFVDLRDREGI...         71      219   
3  MAWLEDVDFLEDVPLLEDIPLLEDVPLLEDVPLLEDTSRLEDINLM...        246      259   
4  MRYYEKIDGSKYRNIWVVGDLHGCYTNLMNKLDTIGFDNKKDLLIS...         77   

In [33]:
df.columns

Index(['Entry', 'id', 'sequence', 'truncated_seq', 'del_start', 'del_end',
       'Protein names', 'Organism', 'Length', 'EC number', 'InterPro',
       'Gene Ontology (GO)', 'Gene Ontology IDs',
       'Gene Ontology (molecular function)',
       'Gene Ontology (cellular component)',
       'Gene Ontology (biological process)', 'Subcellular location [CC]'],
      dtype='object')

In [56]:
# prompt: for each row, identify the substring that starts at del_start (counting from 1, not 0), and ends at del_end, and for each letter in that substring, it is a amino acid residue that should be randomly replaced with one of the 20 amino acids. save this as "substituted_seq" in a column after "truncated_seq"

import random

# List of standard amino acid single-letter codes
amino_acids = "ACDEFGHIKLMNPQRSTVWY"

def substitute_substring(sequence, del_start, del_end):
    """Substitutes amino acids in a substring with random amino acids."""

    # Adjust del_start to be 0-indexed
    del_start -= 1

    if del_start < 0 or del_end > len(sequence) or del_start >= del_end:
        return sequence  # Handle invalid indices

    substituted_seq = list(sequence)
    for i in range(del_start, del_end):
        substituted_seq[i] = random.choice(amino_acids)

    return "".join(substituted_seq)


# Apply the function to each row of the DataFrame
df['substituted_seq'] = df.apply(lambda row: substitute_substring(row['sequence'], row['del_start'], row['del_end']), axis=1)
df['substituted_seq'].head()

Unnamed: 0,substituted_seq
0,MRSRKLTGGVRSSARLRARSYSSASLASARDVTSSTSAWTSDGDAW...
1,MDGAPVAEFRPTMTHGGRYLLYDIFGNKFEVTNKYQPPIMPIGRGA...
2,MSKRTTYCGLVTEAFLGQEITLKGWVNNRRDLGGLIFVDLRDREGI...
3,MAWLEDVDFLEDVPLLEDIPLLEDVPLLEDVPLLEDTSRLEDINLM...
4,MRYYEKIDGSKYRNIWVVGDLHGCYTNLMNKLDTIGFDNKKDLLIS...


In [57]:
df['sequence'].head()

Unnamed: 0,sequence
0,MRSRKLTGGVRSSARLRARSYSSASLASARDVTSSTSAKTTCLASS...
1,MDGAPVAEFRPTMTHGGRYLLYDIFGNKFEVTNKYQPPIMPIGRGA...
2,MSKRTTYCGLVTEAFLGQEITLKGWVNNRRDLGGLIFVDLRDREGI...
3,MAWLEDVDFLEDVPLLEDIPLLEDVPLLEDVPLLEDTSRLEDINLM...
4,MRYYEKIDGSKYRNIWVVGDLHGCYTNLMNKLDTIGFDNKKDLLIS...


In [58]:
# Dictionary to hold the results
embeddings_dict = {}

# Iterate through each row in the DataFrame
for _, row in df.iterrows():
    entry = row['Entry']
    interpro = row['InterPro']

    # Skip rows where 'Interpro' is None
    if pd.isna(interpro) or not interpro.strip():
        continue

    # Split the InterPro IDs by semicolons
    interpro_ids = interpro.split(';')
    interpro_ids = interpro_ids[:-1]

    # Initialize entry in the dictionary if not present
    if entry not in embeddings_dict:
        embeddings_dict[entry] = {
            'InterPro_ids': interpro_ids
        }


In [37]:
embeddings_dict

{'Q80WR5': {'InterPro_ids': ['IPR031530']},
 'Q10N20': {'InterPro_ids': ['IPR011009',
   'IPR050117',
   'IPR003527',
   'IPR008351',
   'IPR000719',
   'IPR017441',
   'IPR008271']},
 'P67015': {'InterPro_ids': ['IPR004364',
   'IPR006195',
   'IPR045864',
   'IPR004524',
   'IPR047089',
   'IPR002312',
   'IPR047090',
   'IPR004115',
   'IPR029351',
   'IPR012340',
   'IPR004365']},
 'P51861': {'InterPro_ids': ['IPR048506', 'IPR048507']},
 'P03772': {'InterPro_ids': ['IPR050126',
   'IPR004843',
   'IPR029052',
   'IPR006186']},
 'P33171': {'InterPro_ids': ['IPR041709',
   'IPR050055',
   'IPR004161',
   'IPR033720',
   'IPR031157',
   'IPR027417',
   'IPR005225',
   'IPR000795',
   'IPR009000',
   'IPR009001',
   'IPR004541',
   'IPR004160']},
 'B8DF05': {'InterPro_ids': ['IPR001790',
   'IPR043141',
   'IPR022973',
   'IPR047865',
   'IPR002363']},
 'Q6AAB8': {'InterPro_ids': ['IPR000941',
   'IPR036849',
   'IPR029017',
   'IPR020810',
   'IPR020809',
   'IPR020811']},
 'Q9ZMM5': 

In [38]:
pip install esm



In [9]:
from huggingface_hub import login
from esm.models.esm3 import ESM3
from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig

# Will instruct you how to get an API key from huggingface hub, make one with "Read" permission.
login()



VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [10]:
model: ESM3InferenceClient = ESM3.from_pretrained("esm3_sm_open_v1").to("cuda")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Fetching 22 files:   0%|          | 0/22 [00:00<?, ?it/s]

.gitattributes:   0%|          | 0.00/1.52k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/3.00 [00:00<?, ?B/s]

data/entry_list_safety_29026.list:   0%|          | 0.00/1.60M [00:00<?, ?B/s]

README.md:   0%|          | 0.00/3.30k [00:00<?, ?B/s]

data/1utn.pdb:   0%|          | 0.00/569k [00:00<?, ?B/s]

data/ParentChildTreeFile.txt:   0%|          | 0.00/595k [00:00<?, ?B/s]

data/esm3_entry.list:   0%|          | 0.00/1.93M [00:00<?, ?B/s]

hyperplanes_8bit_58641.npz:   0%|          | 0.00/30.0M [00:00<?, ?B/s]

hyperplanes_8bit_68103.npz:   0%|          | 0.00/34.9M [00:00<?, ?B/s]

data/interpro2keywords.csv:   0%|          | 0.00/7.32M [00:00<?, ?B/s]

(…)ata/interpro_29026_to_keywords_58641.csv:   0%|          | 0.00/10.1M [00:00<?, ?B/s]

keyword_idf_safety_filtered_58641.npy:   0%|          | 0.00/469k [00:00<?, ?B/s]

(…)ord_vocabulary_safety_filtered_58641.txt:   0%|          | 0.00/788k [00:00<?, ?B/s]

data/keywords.txt:   0%|          | 0.00/788k [00:00<?, ?B/s]

data/tag_dict_4.json:   0%|          | 0.00/691k [00:00<?, ?B/s]

data/tag_dict_4_safety_filtered.json:   0%|          | 0.00/569k [00:00<?, ?B/s]

(…)0_residue_annotations_gt_1k_proteins.csv:   0%|          | 0.00/109k [00:00<?, ?B/s]

tfidf_safety_filtered_58641.pkl:   0%|          | 0.00/2.02M [00:00<?, ?B/s]

esm3_function_decoder_v0.pth:   0%|          | 0.00/1.30G [00:00<?, ?B/s]

esm3_sm_open_v1.pth:   0%|          | 0.00/2.80G [00:00<?, ?B/s]

esm3_structure_encoder_v0.pth:   0%|          | 0.00/62.3M [00:00<?, ?B/s]

esm3_structure_decoder_v0.pth:   0%|          | 0.00/1.24G [00:00<?, ?B/s]

  state_dict = torch.load(


In [15]:
import torch.nn as nn
from esm.tokenization import InterProQuantizedTokenizer
from esm.utils.types import FunctionAnnotation
def get_keywords_from_interpro(
    interpro_annotations,
    interpro2keywords=InterProQuantizedTokenizer().interpro2keywords,
):
    keyword_annotations_list = []
    for interpro_annotation in interpro_annotations:
        keywords = interpro2keywords.get(interpro_annotation.label, [])
        keyword_annotations_list.extend([
            FunctionAnnotation(
                label=keyword,
                start=interpro_annotation.start,
                end=interpro_annotation.end,
            )
            for keyword in keywords
        ])
    return keyword_annotations_list

In [16]:
#protein that only has one function?
#longer sequences of all As, all Gs, or completely random
def get_label_embedding(interpro_label,sequence):
  hostProtein = ESMProtein(sequence=sequence)
  embedding_function = model.encoder.function_embed
  hostProtein.function_annotations = get_keywords_from_interpro([FunctionAnnotation(label=interpro_label, start=1, end=len(sequence))])
  hostProtein_tensor = model.encode(hostProtein)
  device = hostProtein_tensor.function.device  # Get the device of protein2_tensor.function
  embedding_function = embedding_function.to(device)  # Move embedding_function to the device

  function_embed = torch.cat(
      [
          embed_fn(funcs.to(device)) # Ensure funcs is on the same device
          for embed_fn, funcs in zip(
              embedding_function, hostProtein_tensor.function.unbind(-1)
          )
      ],
      -1,
  )

  if function_embed.shape[0] >= 3:
      row_sum = function_embed.sum(dim=0)  # Sum all rows
      row_avg = row_sum / (function_embed.shape[0] - 2)  # Divide by (number of rows - 2)
      row_avg_np = row_avg.cpu().detach().type(torch.float32).numpy()
      return row_avg_np
  else:
      return None

In [61]:
import numpy as np
for entry, interpro_ids in embeddings_dict.items():
  embeddings = []
  for interpro_id in interpro_ids['InterPro_ids']:
    embeddings.append(get_label_embedding(interpro_id,"A"))
  avg_embedding = np.mean(embeddings, axis=0)
  embeddings_dict[entry]['embedding'] = avg_embedding


In [18]:
embeddings_dict

{'Q80WR5': {'InterPro_ids': ['IPR031530'],
  'embedding': array([-0.48046875, -0.31640625,  0.95703125, ...,  0.09082031,
          0.4140625 ,  1.1796875 ], dtype=float32)},
 'Q10N20': {'InterPro_ids': ['IPR011009',
   'IPR050117',
   'IPR003527',
   'IPR008351',
   'IPR000719',
   'IPR017441',
   'IPR008271'],
  'embedding': array([ 0.07560512,  0.15066965, -0.48082623, ..., -0.06509835,
          0.46777344,  0.07798549], dtype=float32)},
 'P67015': {'InterPro_ids': ['IPR004364',
   'IPR006195',
   'IPR045864',
   'IPR004524',
   'IPR047089',
   'IPR002312',
   'IPR047090',
   'IPR004115',
   'IPR029351',
   'IPR012340',
   'IPR004365'],
  'embedding': array([-0.2986062 ,  0.23703836, -0.1981534 , ...,  0.08385121,
         -0.07188832,  0.30217952], dtype=float32)},
 'P51861': {'InterPro_ids': ['IPR048506', 'IPR048507'],
  'embedding': array([-0.1328125 , -0.03613281,  0.02978516, ..., -0.14404297,
         -0.26611328, -0.7246094 ], dtype=float32)},
 'P03772': {'InterPro_ids': ['I

In [64]:
def embedding_masking_model(
    prompt,
    model,
    df,
    embeddings_dict
):
    """
    Helper function to process a protein sequence, calculate similarities,
    and return indices for masking.

    Args:
        prompt (str): The protein sequence to be processed.
        model: The model used for protein generation and embeddings.
        df (pd.DataFrame): DataFrame containing protein data.
        embeddings_dict (dict): Dictionary storing embeddings and other details.

    Returns:
        List[int]: Indices used for masking in the sequence.
    """
    # Create an ESMProtein object
    protein = ESMProtein(sequence=prompt)

    # Configure the model for generation
    generation_config = GenerationConfig(track="function", num_steps=8)

    # Generate the protein
    generated_protein = model.generate(protein, generation_config)

    # Check if function annotations are available
    entry = df.loc[df['substituted_seq'] == prompt, 'Entry'].iloc[0]
    if generated_protein.function_annotations is None:
        embeddings_dict[entry]['hamming_distance'] = None
        return []

    # Getting embedding for the protein
    protein_tensor = model.encode(generated_protein)
    embedding_function = model.encoder.function_embed
    device = protein_tensor.function.device  # Get the device of protein_tensor.function
    embedding_function = embedding_function.to(device)  # Move embedding_function to the device

    function_embed = torch.cat(
        [
            embed_fn(funcs.to(device))  # Ensure funcs is on the same device
            for embed_fn, funcs in zip(
                embedding_function, protein_tensor.function.unbind(-1)
            )
        ],
        -1,
    )

    # Exclude start and end tokens
    function_embed = function_embed[1:-1, :]

    # Convert the protein_tensor.function to a NumPy array
    protein_np = function_embed.cpu().detach().type(torch.float32).numpy()

    # Retrieve target sequence and embedding
    target_seq = df.loc[df['substituted_seq'] == prompt, 'sequence'].iloc[0]
    embedding = embeddings_dict[entry]['embedding']

    # Calculate cosine similarity
    similarities = cosine_similarity(protein_np, embedding.reshape(1, -1))

    # Find the top 10 indices with the lowest cosine similarity
    indices = np.argsort(similarities.flatten())[:10]

    # Store the indices in the embeddings_dict
    embeddings_dict[entry]['indices'] = indices.tolist()

    return indices.tolist()

In [67]:
for idx, prompt in enumerate(df['substituted_seq']):
    indices = embedding_masking_model(prompt, model, df, embeddings_dict)
    if not indices:
      continue
    # Replace locations in prompt corresponding to indices in 'indices' with "_"
    modified_prompt = list(prompt)
    for index in indices:
        modified_prompt[index] = "_"
    modified_prompt = "".join(modified_prompt)
    protein_prompt = ESMProtein(sequence=modified_prompt)
    sequence_generation = model.generate(
        protein_prompt,
        GenerationConfig(
            track="sequence",
            num_steps=protein_prompt.sequence.count("_") // 2,
            temperature=0.5,
        ),
    )
    print("Generated Sequence: " + str(sequence_generation.sequence))
    generated_sequence = sequence_generation.sequence
    print("Target Sequence: " + str(target_seq))
    # Ensure sequences are of equal length
    if len(generated_sequence) != len(target_seq):
        raise ValueError("Sequences must be of the same length to calculate Hamming distance.")

    # Calculate Hamming distance
    hamming_distance = sum(1 for gen, target in zip(generated_sequence, target_seq) if gen != target)

    # Print results
    print("Hamming Distance:", hamming_distance)
    embeddings_dict[entry]['hamming_distance'] = hamming_distance


100%|██████████| 8/8 [00:02<00:00,  2.98it/s]
100%|██████████| 8/8 [00:04<00:00,  1.93it/s]
100%|██████████| 5/5 [00:02<00:00,  1.93it/s]


Generated Sequence: MDGAPVAEFRPTMTHGGRYLLYDIFGNKFEVTNKYQPPIMPIGRGAYGIVCSVMNFETREMVAIKKIANAFNNDMDAKRTLREIKLLRHLDHENIIGIRDVIPPPIPQAFNDVYIATELMDTDLHHIIRSNQELSEEHCQYFLYQILRGLKYIHSANVIHRDLKPSNLLLNANCDLKICDFGLARPSSESDMMTEYVVTRWYRAPELLLNSTDYSAAIDVWSVGCIFMELINRQPLFPGRDHMHQMRLITEVIGTPTDDELGFIRNEDARKYMGFDDIETFVRKYCRNDHRWLAALDLIERMLTFNPLQRITVEEALDHPYLERLHDIADEPICLEPFDFSFEDQALNEDQMKQLIFNEAIEMNPNIRY
Target Sequence: MDGAPVAEFRPTMTHGGRYLLYDIFGNKFEVTNKYQPPIMPIGRGAYGIVCSVMNFETREMVAIKKIANAFNNDMDAKRTLREIKLLRHLDHENIIGIRDVIPPPIPQAFNDVYIATELMDTDLHHIIRSNQELSEEHCQYFLYQILRGLKYIHSANVIHRDLKPSNLLLNANCDLKICDFGLARPSSESDMMTEYVVTRWYRAPELLLNSTDYSAAIDVWSVGCIFMELINRQPLFPGRDHMHQMRLITEVIGTPTDDELGFIRNEDARKYMRHLPQYPRRTFASMFPRVQPAALDLIERMLTFNPLQRITVEEALDHPYLERLHDIADEPICLEPFSFDFEQKALNEDQMKQLIFNEAIEMNPNIRY
Hamming Distance: 24


100%|██████████| 8/8 [00:07<00:00,  1.08it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 148.00 MiB. GPU 0 has a total capacity of 14.75 GiB of which 83.06 MiB is free. Process 5192 has 14.66 GiB memory in use. Of the allocated memory 13.69 GiB is allocated by PyTorch, and 872.62 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)