In [None]:
# step 0: install required packages. TODO: rename this notebook
%pip install transformers torch pandas numpy matplotlib networkx seaborn scikit-learn

In [None]:
# step 1: import required packages
import os
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

In [None]:
# step 2: Load model and tokenizer
# The checkpoint id lives in a single constant so the tokenizer and the model
# are always loaded from the same checkpoint.
MODEL_NAME = "InstaDeepAI/nucleotide-transformer-2.5b-multi-species"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)

In [None]:
# step 3: function to parse FASTA files
def parse_fasta(file_path: str) -> list:
    '''
    Parse a fasta (.fna) file into a list of sequence lines

    Header lines (those starting with '>') are dropped, and so are blank or
    whitespace-only lines, so every returned element is a non-empty piece of
    sequence. Lines from every record in the file are returned in file order,
    with no marker of where one record ends and the next begins.

    Parameters
    ----------
    file_path : str
        Path to the fasta sequence file

    Returns
    -------
    seq : list of str
        The sequence lines, stripped of surrounding whitespace, one element
        per non-empty, non-header line of the file. Empty if the file holds
        no sequence data.
    '''
    with open(file_path) as f:
        stripped_lines = (line.strip() for line in f)
        # keep only real sequence data: blank lines (e.g. a trailing newline
        # or spacing between records) would otherwise show up as '' entries,
        # and header lines are not part of the sequence
        return [
            line for line in stripped_lines
            if line and not line.startswith('>')
        ]

In [None]:
# step 4: parse sequences from the data files
# NOTE: at the moment we are only considering the first line of each sequence
# (hence the [0] on every parsed file)
strain_1, strain_2, strain_3, strain_4 = (
    parse_fasta(os.path.join('data', folder, filename))[0]
    for folder, filename in (
        ('GCA_006094915.1', 'GCA_006094915.1_ASM609491v1_genomic.fna'),
        ('GCA_026167765.1', 'GCA_026167765.1_ASM2616776v1_genomic.fna'),
        ('GCA_900607265.1', 'GCA_900607265.1_BPH2003_genomic.fna'),
        ('GCA_900620245.1', 'GCA_900620245.1_BPH2947_genomic.fna'),
    )
)


In [None]:
# step 5: Define masked embedding function
def get_masked_embedding(sequence: str):
    '''
    Create a single embedding vector for a genome sequence

    The sequence is tokenized, passed through the masked language model, and
    the last layer's per-token hidden states are averaged over the tokens that
    the attention mask marks as 1.

    Parameters
    ----------
    sequence : str
        The sequence we parsed earlier, as a single string

    Returns
    -------
    np.ndarray
        The attention-mask-weighted mean of the last-layer token embeddings,
        with the batch axis removed (one value per hidden dimension)
    '''
    # this splits the sequence into input:
    #  - IDs e.g. [2, 312, ... , 3671] which each correspond to a different 6-mer string in the
    # tokenizer's vocabulary
    #  - an attention mask e.g. [1, 1, 1, 0, 0] which tells the embedding model which IDs to
    # process. In our case (for now), it'll be tensor of 1s since we are including every 6-mer
    # string
    tokens = tokenizer(sequence, return_tensors="pt")
    # run the inputs on whatever device the model lives on, so the function keeps working if
    # the model has been moved to a GPU
    device = model.device
    input_ids = tokens["input_ids"].to(device)
    attention_mask = tokens['attention_mask'].to(device)
    with torch.no_grad():
        outputs = model(
            input_ids,
            attention_mask=attention_mask,
            encoder_attention_mask=attention_mask,
            output_hidden_states=True
        )
    # outputs is an encoding for each *token* for each layer of the (32-layer) Transformer,
    # it has length 33 since layer 0 is an 'embedding layer'
    # to get the 'final' embeddings, take the last layer NOTE: we may not always want the last
    # layer
    embeddings = outputs.hidden_states[-1]
    # add a trailing axis so the mask broadcasts across the hidden dimension
    token_mask = attention_mask.unsqueeze(-1)
    # zero out masked positions, then average over the token axis (dim=1) counting only the
    # positions whose mask is 1
    masked_embeddings = embeddings * token_mask
    mean_embedding = masked_embeddings.sum(dim=1) / token_mask.sum(dim=1)
    # drop only the batch axis; .numpy() requires a CPU tensor, so copy back from the model's
    # device first (a no-op when already on the CPU)
    return mean_embedding.squeeze(0).cpu().numpy()

In [None]:
# step 6: embed each strain's sequence, in strain order
(
    embedding_for_strain_1,
    embedding_for_strain_2,
    embedding_for_strain_3,
    embedding_for_strain_4,
) = map(get_masked_embedding, (strain_1, strain_2, strain_3, strain_4))