In [1]:
# step 0 installs
%pip install transformers torch pandas numpy matplotlib networkx seaborn

Collecting transformersNote: you may need to restart the kernel to use updated packages.

  Using cached transformers-4.46.3-py3-none-any.whl.metadata (44 kB)
Collecting torch
  Downloading torch-2.4.1-cp38-cp38-win_amd64.whl.metadata (27 kB)
Collecting pandas
  Using cached pandas-2.0.3-cp38-cp38-win_amd64.whl.metadata (18 kB)
Collecting numpy
  Downloading numpy-1.24.4-cp38-cp38-win_amd64.whl.metadata (5.6 kB)
Collecting matplotlib
  Using cached matplotlib-3.7.5-cp38-cp38-win_amd64.whl.metadata (5.8 kB)
Collecting networkx
  Downloading networkx-3.1-py3-none-any.whl.metadata (5.3 kB)
Collecting seaborn
  Downloading seaborn-0.13.2-py3-none-any.whl.metadata (5.4 kB)
Collecting filelock (from transformers)
  Downloading filelock-3.16.1-py3-none-any.whl.metadata (2.9 kB)
Collecting huggingface-hub<1.0,>=0.23.2 (from transformers)
  Using cached huggingface_hub-0.36.0-py3-none-any.whl.metadata (14 kB)
Collecting pyyaml>=5.1 (from transformers)
  Using cached PyYAML-6.0.3-cp38-cp38-win_a

In [4]:
%pip install scikit-learn

Collecting scikit-learnNote: you may need to restart the kernel to use updated packages.

  Using cached scikit_learn-1.3.2-cp38-cp38-win_amd64.whl.metadata (11 kB)
Collecting scipy>=1.5.0 (from scikit-learn)
  Using cached scipy-1.10.1-cp38-cp38-win_amd64.whl.metadata (58 kB)
Collecting joblib>=1.1.1 (from scikit-learn)
  Using cached joblib-1.4.2-py3-none-any.whl.metadata (5.4 kB)
Collecting threadpoolctl>=2.0.0 (from scikit-learn)
  Using cached threadpoolctl-3.5.0-py3-none-any.whl.metadata (13 kB)
Using cached scikit_learn-1.3.2-cp38-cp38-win_amd64.whl (9.3 MB)
Using cached joblib-1.4.2-py3-none-any.whl (301 kB)
Using cached scipy-1.10.1-cp38-cp38-win_amd64.whl (42.2 MB)
Using cached threadpoolctl-3.5.0-py3-none-any.whl (18 kB)
Installing collected packages: threadpoolctl, scipy, joblib, scikit-learn
Successfully installed joblib-1.4.2 scikit-learn-1.3.2 scipy-1.10.1 threadpoolctl-3.5.0


In [5]:
# step 0 imports
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_similarity
import random
import seaborn as sns

In [None]:
# step 1: Load model and tokenizer
# The tokenizer and the masked-LM model are both loaded from the same
# pretrained checkpoint name, so it is spelled out only once.
_nt_checkpoint = "InstaDeepAI/nucleotide-transformer-2.5b-multi-species"
tokenizer = AutoTokenizer.from_pretrained(_nt_checkpoint)
model = AutoModelForMaskedLM.from_pretrained(_nt_checkpoint)

Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.08s/it]


In [None]:
# step 2 function to parse FASTA files
def parse_fasta(file_path):
    """Read a FASTA file and return its sequences, one string per record.

    Every line starting with ``>`` is treated as a record header: it closes
    the record being read and is itself discarded. The sequence lines that
    follow a header are concatenated into a single string, so a record whose
    sequence is wrapped over many lines comes back as one element rather than
    one element per line.

    Args:
        file_path: Path of the FASTA file to read.

    Returns:
        A list of sequence strings in file order. Records without any
        sequence lines are omitted; an empty file yields an empty list.

    Raises:
        OSError: If the file cannot be opened.
    """
    records = []
    chunks = []  # sequence lines of the record currently being read

    with open(file_path) as f:
        for line in f:
            line = line.strip()
            if not line:
                # Blank lines carry no sequence data.
                continue
            if line.startswith('>'):
                # A header ends the previous record, if it had any sequence.
                if chunks:
                    records.append(''.join(chunks))
                    chunks = []
            else:
                chunks.append(line)

    # The last record has no following header to close it, so flush it here.
    if chunks:
        records.append(''.join(chunks))
    return records



In [None]:
# step 3 parse sequences from data file
# Paths use forward slashes: a backslash followed by "G" inside a normal string
# literal is an invalid escape sequence, and backslash separators only resolve
# on Windows. Forward slashes are accepted on Windows and POSIX alike.
# `[0]` keeps only the first element of the list returned by parse_fasta.
strain_1 = parse_fasta('data/GCA_006094915.1/GCA_006094915.1_ASM609491v1_genomic.fna')[0]
strain_2 = parse_fasta('data/GCA_026167765.1/GCA_026167765.1_ASM2616776v1_genomic.fna')[0]
strain_3 = parse_fasta('data/GCA_900607265.1/GCA_900607265.1_BPH2003_genomic.fna')[0]
strain_4 = parse_fasta('data/GCA_900620245.1/GCA_900620245.1_BPH2947_genomic.fna')[0]

In [23]:
# step 4: Define masked embedding function
def get_masked_embedding(sequence):
    """Return the padding-aware mean of the model's last hidden state.

    The sequence is tokenized with the module-level ``tokenizer``, truncated
    and padded to ``tokenizer.model_max_length``, and passed through the
    module-level ``model`` with gradients disabled. Positions whose token id
    equals the tokenizer's pad id are excluded from the average.

    Args:
        sequence: Input passed straight to ``tokenizer``.
            (assumed to be a single nucleotide string - TODO confirm)

    Returns:
        A NumPy array: the mean over dim 1 of the last hidden state, with
        size-1 dimensions squeezed out.
    """
    encoded = tokenizer(
        sequence,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=tokenizer.model_max_length,
    )
    ids = encoded["input_ids"]
    # True wherever the token is real, False on padding.
    keep = ids != tokenizer.pad_token_id

    with torch.no_grad():
        result = model(
            ids,
            attention_mask=keep,
            encoder_attention_mask=keep,
            output_hidden_states=True,
        )

    last_hidden = result.hidden_states[-1]
    # Add a trailing axis so the mask broadcasts across the hidden dimension.
    weights = keep.unsqueeze(-1)
    summed = (last_hidden * weights).sum(dim=1)
    n_kept = weights.sum(dim=1)
    pooled = summed / n_kept
    return pooled.squeeze().numpy()

In [None]:
# Embed each strain in order; the generator calls get_masked_embedding once per
# strain and the results are unpacked into the four per-strain names.
(
    embedding_for_strain_1,
    embedding_for_strain_2,
    embedding_for_strain_3,
    embedding_for_strain_4,
) = (
    get_masked_embedding(strain_seq)
    for strain_seq in (strain_1, strain_2, strain_3, strain_4)
)