In [None]:
# step 0 install required packages TODO: rename this notebook
%pip install transformers torch pandas numpy matplotlib networkx seaborn scikit-learn umap-learn

In [3]:
# step 1 import required packages
import os

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_similarity
import random
import seaborn as sns

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# step 2: Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("InstaDeepAI/nucleotide-transformer-2.5b-multi-species")
model = AutoModelForMaskedLM.from_pretrained("InstaDeepAI/nucleotide-transformer-2.5b-multi-species")

Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.67s/it]


In [5]:
# step 3: function to parse FASTA files
def parse_fasta(file_path: str):
    '''
    Parse fasta file (.fna) file into list of strings

    Parameters
    ----------
    file_path : str
        Path to the fasta sequence file
    
    Returns
    -------
    seq : List
        The sequence parsed so that each element is a line from the fasta file
    '''
    with open(file_path) as f:
        seq = []
        for line in f:
            line = line.rstrip()
            # do not parse the header line
            if line.startswith('>'):
                continue
            else:
                seq.append(line)
    return seq

In [9]:
# step 4: parse sequences from data file, at the moment, NOTE: at the moment
# we are only considering the first line of the sequence
strain_1 = parse_fasta(os.path.join('data', 'GCA_006094915.1', 'GCA_006094915.1_ASM609491v1_genomic.fna'))[0]
strain_2 = parse_fasta(os.path.join('data', 'GCA_026167765.1', 'GCA_026167765.1_ASM2616776v1_genomic.fna'))[0]
strain_3 = parse_fasta(os.path.join('data', 'GCA_900607265.1', 'GCA_900607265.1_BPH2003_genomic.fna'))[0]
strain_4 = parse_fasta(os.path.join('data', 'GCA_900620245.1', 'GCA_900620245.1_BPH2947_genomic.fna'))[0]

print(strain_1)

ATGTCGGAAAAAGAAATTTGGGAAAAAGTGCTTGAAATTGCTCAAGAAAAATTATCAGCTGTAAGTTACTCAACTTTCCT


In [59]:
# step 5: Define masked embedding function
def get_masked_embedding(sequence: str):
    '''
    Create an embedding of a genome sequence

    Parameters
    ----------
    sequence : str
        The sequence we parsed earlier, as a single string

    Returns
    -------
    np.ndarray
        The mean embedding for the tokenized sequence
    '''
    # this splits the sequence into input:
    #  - IDs e.g. [2, 312, ... , 3671] which each correspond to a different 6-mer string in the
    # tokenizer's vocabulary
    #  - an attention mask e.g. [1, 1, 1, 0, 0] which tells the embedding model which IDs to
    # process. In our case (for now), it'll be tensor of 1s since we are including every 6-mer
    # string
    tokens = tokenizer(sequence, return_tensors="pt")
    # input_ids = tokens["input_ids"][0].tolist()
    input_ids = tokens["input_ids"]
    # print(tokenizer.convert_ids_to_tokens(input_ids))
    # print(input_ids)
    attention_mask = tokens['attention_mask']
    with torch.no_grad():
        outputs = model(
            input_ids,
            attention_mask=attention_mask,
            encoder_attention_mask=attention_mask,
            output_hidden_states=True
        )
    # outputs is an encoding for each *token* for each layer of the (32-layer) Transformer,
    # it has length 33 since layer 0 is an 'embedding layer'
    # to get the 'final' embeddings, take the last layer NOTE: we may not always want the last
    # layer
    embeddings = outputs.hidden_states[-1]
    print(embeddings.shape)
    print(embeddings)
    attention_mask = attention_mask.unsqueeze(-1)
    masked_embeddings = embeddings * attention_mask
    mean_embedding = masked_embeddings.sum(dim=1) / attention_mask.sum(dim=1)
    return mean_embedding.squeeze().numpy()

In [60]:
# step 5: generate mean embeddings for each strain
embedding_for_strain_1 = get_masked_embedding(strain_1)
print(len(embedding_for_strain_1))
# embedding_for_strain_2 = get_masked_embedding(strain_2)
# embedding_for_strain_3 = get_masked_embedding(strain_3)
# embedding_for_strain_4 = get_masked_embedding(strain_4)

torch.Size([1, 16, 2560])
tensor([[[-0.0720,  0.2746,  0.0986,  ...,  0.0417, -0.0201,  0.2377],
         [ 0.3900, -0.5494,  0.8788,  ..., -0.4445,  0.5566,  0.3560],
         [ 0.4664, -0.9278,  0.3667,  ..., -0.2012,  0.7068, -0.1666],
         ...,
         [-0.7106,  0.3668,  0.8364,  ..., -0.6543, -0.5756,  0.1532],
         [ 0.6815, -0.2832,  0.1981,  ...,  0.6719,  0.0036,  0.3221],
         [-0.5512, -0.6719, -0.2787,  ...,  0.5013,  0.5288,  0.0133]]])
2560


In [None]:
# step 6: Create dataframe with embeddings
headers = ['strain_1', 'strain_2', 'strain_3', 'strain_4']
embeddings = [
    embedding_for_strain_1,
    embedding_for_strain_2,
    embedding_for_strain_3,
    embedding_for_strain_4
]

df_all = pd.DataFrame({
    'Header': headers,
    'embedding': embeddings
})

# step 7: Compute cosine similarity matrix
X = np.stack(df_all['embedding'].values)
similarity_matrix = cosine_similarity(X)

# step 8: Visualize similarity matrix
plt.figure(figsize=(8, 6))
sns.heatmap(similarity_matrix, xticklabels=df_all['Header'], yticklabels=df_all['Header'],
            annot=True, cmap='viridis')
plt.title("Cosine Similarity Heatmap")
plt.show()

# step 9: Clustering using DBSCAN
from sklearn.cluster import DBSCAN
clustering = DBSCAN(eps=0.1, min_samples=2, metric='cosine').fit(X)
df_all['Cluster'] = clustering.labels_

In [None]:
# step 10: Build similarity graph
import networkx as nx

threshold = 0.85  # Adjust as needed
G = nx.Graph()
for i, label_i in enumerate(df_all['Header']):
    G.add_node(label_i)
    for j in range(i+1, len(df_all)):
        if similarity_matrix[i][j] >= threshold:
            G.add_edge(label_i, df_all['Header'][j], weight=similarity_matrix[i][j])

# step 11: Visualise similarity graph
plt.figure(figsize=(8, 6))
pos = nx.spring_layout(G, seed=42)
nx.draw_networkx_nodes(G, pos, node_color='skyblue', node_size=800)
nx.draw_networkx_edges(G, pos, edge_color='gray')
nx.draw_networkx_labels(G, pos, font_size=10)
edge_labels = {(u, v): f"{d['weight']:.2f}" for u, v, d in G.edges(data=True)}
nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=8)
plt.title("Genome Similarity Graph (Cosine ≥ 0.85)")
plt.axis('off')
plt.tight_layout()
plt.show()

In [None]:
# step 12: PCA visualisation
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X)
jitter = np.random.normal(0, 0.01, X_pca.shape)
X_pca_jittered = X_pca + jitter

unique_clusters = sorted(df_all['Cluster'].unique())
color_map = {cluster: plt.cm.tab10(i % 10) for i, cluster in enumerate(unique_clusters)}
colors = [color_map[c] for c in df_all['Cluster']]

plt.figure(figsize=(8,6))
plt.scatter(X_pca_jittered[:,0], X_pca_jittered[:,1], c=colors)
for i, label in enumerate(df_all['Header']):
    plt.text(X_pca_jittered[i,0], X_pca_jittered[i,1], label, fontsize=8)
plt.title("PCA of Genome Embeddings with Clustering")
plt.xlabel("PC1")
plt.ylabel("PC2")
plt.grid(True)
plt.show()

In [None]:
%pip install umap-learn

In [None]:
# step 13: euclidean distance matrix and visualisation
from sklearn.metrics import pairwise_distances
distance_matrix = pairwise_distances(X, metric='euclidean')

sns.heatmap(distance_matrix, xticklabels=df_all['Header'], yticklabels=df_all['Header'], annot=True, cmap='coolwarm')
plt.title("Euclidean Distance Heatmap")
plt.show()

In [None]:
# step 14: UMAP visualisation for non linear relationships between embeddings
from umap import UMAP
umap = UMAP(n_neighbors=5, min_dist=0.3, metric='cosine')
X_umap = umap.fit_transform(X)

plt.figure(figsize=(8,6))
plt.scatter(X_umap[:,0], X_umap[:,1], c=colors)
for i, label in enumerate(df_all['Header']):
    plt.text(X_umap[i,0], X_umap[i,1], label, fontsize=8)
plt.title("UMAP projection of Genome Embeddings")
plt.show()

In [None]:
# step 15: Hierarchical Clustering Dendrogram
from scipy.cluster.hierarchy import linkage, dendrogram

linked = linkage(X, 'ward')
plt.figure(figsize=(8, 5))
dendrogram(linked, labels=df_all['Header'].values)
plt.title("Hierarchical Clustering Dendrogram")
plt.show()