In [1]:
from utils.fasta_utils import *
from utils.proteome_process import *
from utils.prefetcher import *
from utils.network_utils import *

import os
from tqdm import tqdm
import pandas as pd
import hdbscan
import faiss
from sklearn.metrics.pairwise import cosine_distances,euclidean_distances
import numpy as np
import networkx as nx


In [2]:
data_folder = "/path/to/folder" # path to folder where the databse is downloaded ~250mb

model_name = "650m" # use the corresponding model
db_zarr_path = os.path.join(data_folder , f"ESM2_{model_name}_1Sept24_{model_name}.zarr")
faiss_index_path = os.path.join(data_folder , f"ESM2_{model_name}_1Sept24_{model_name}.zarrfaiss_index.bin")
eps_values_path = os.path.join(data_folder , f"ESM2_{model_name}_1Sept24_{model_name}_eps_values_flat_clusters.npy")

HieVi_INPHARED_ordered_annotation = "HieVi_INPHARED_ordered_annotation.csv"

db_zarr_store = zarr.open(db_zarr_path,'r')
db_accessions = db_zarr_store['accessions'][:]

# load annotation file for HieVi vector database
annotations = pd.read_csv(HieVi_INPHARED_ordered_annotation)
annotations= annotations[annotations["Accession"].isin(db_accessions)]
annotations = annotations.set_index("Accession").loc[db_accessions].reset_index()

eps_values = np.load(eps_values_path)
index = faiss.read_index(faiss_index_path)

  exec(code_obj, self.user_global_ns, self.user_ns)


In [57]:
# generate query vectors and save in the same folder

filename = "/path/to/proteome_multifasta.faa"

# @title Compute phage representations
expt_name = "TEST"  # @param {type:"string"}
expt_name = expt_name.replace(' ','_')
output_folder = os.path.dirname(filename) + os.sep 
fasta_path = filename
#model_name = "3b" # This colab works for 650m only
mode = "mean"
query_zarr_path = os.path.join(output_folder,f"{expt_name}_{model_name}.zarr")
!python GenPhageRepresentationsESM2.py {expt_name} {output_folder} {fasta_path} {model_name} {mode}

In [58]:
# load query vectors
query_zarr_path = f"{output_folder}{expt_name}_{model_name}.zarr"
print(query_zarr_path)
query_zarr_store = zarr.open(query_zarr_path,'r')
query_vectors = query_zarr_store['vectors_mean'][:]*1.0
query_accessions = query_zarr_store['accessions'][:]

/home/swapnesh/Downloads/TEST_650m.zarr


In [None]:
# Show nearest neighbours and distances for sanity check
#distances, indices = index.search(query_vectors, 4)
#print(db_accessions[indices],distances)

In [48]:
# extract nearest neighbours

hievi_cluster_prefix = 'HC_'
k_neighbours = 8 # 
distance_threshold = 0.023 # 
distance_in_tree = 2

distances, indices = index.search(query_vectors, k_neighbours)
valid_idx = distances[:,0] < distance_threshold
invalid_idx = np.logical_not(valid_idx)

if len(np.where(invalid_idx)[0]):
    print('Cannot classifiy: ' ,len(np.where(invalid_idx)[0]))
    invalid_query_df = pd.DataFrame({"Accession": query_accessions[invalid_idx]})
    invalid_query_df.to_csv(query_zarr_path[:-5] + "_HieVi_Unclassifieds.csv")

all_indices = np.unique(np.ravel(np.array(indices[valid_idx])))
subset_db_vectors = np.array([db_zarr_store['vectors_mean'][i] for i in all_indices])
distances, indices = index.search(subset_db_vectors, k_neighbours)


D1 = np.squeeze(distances)
I1 = np.squeeze(indices)

all_indices = np.unique(np.array(I1))
all_nearest_accessions = db_accessions[all_indices]

#subset_db_vectors = db_zarr_store['vectors_mean'][all_indices,:]

subset_db_vectors = np.array([db_zarr_store['vectors_mean'][i] for i in all_indices])

print(f"Loaded {subset_db_vectors.shape[0]} Accessions from database." )

Loaded 102 Accessions from database.


In [49]:
# @title Combine query and nearest phages from database data

# Get nearest accessions in tree
nearest_accessions = annotations[annotations["Accession"].isin(all_nearest_accessions)]
nearest_accessions = nearest_accessions.set_index("Accession").loc[all_nearest_accessions].reset_index()

# Combine query and database data
query_df = pd.DataFrame({"Accession": query_accessions[valid_idx],"Query": "yes"})
annotation_df = pd.concat([nearest_accessions, query_df], axis=0)
mprs = np.concatenate((subset_db_vectors, query_vectors[valid_idx]), axis=0)

# Perform clustering
# dist_scaled = euclidean_distances(mprs).astype("double")
# clusterer = hdbscan.HDBSCAN(
#     min_cluster_size=2,
#     n_jobs=32,
#     min_samples=1,
#     allow_single_cluster=False,
#     cluster_selection_method="leaf",
#     metric="precomputed",
#     gen_min_span_tree=True
# )
# clusterer.fit(dist_scaled)

clusterer = hdbscan.HDBSCAN(
    min_cluster_size=2,
    min_samples=1,
    allow_single_cluster=False,
    cluster_selection_method="leaf",
    metric="euclidean",
    gen_min_span_tree=True
)
clusterer.fit(mprs)

annotation_df["HieVi_cluster"] = clusterer.labels_
for i,eps in enumerate(eps_values):
    annotation_df[hievi_cluster_prefix+str(i)] = clusterer.dbscan_clustering(cut_distance=eps,min_cluster_size=2)
#annotation_df

import re
# Function to clean text
# this is required for proper formatting of the gexf file, otherwise opening it in cytoscape fails
def clean_text(text):
    text = re.sub(r'[^\w\s-]', '', text)  # Allow dashes by including "-" in the character class
    text = re.sub(r'%', '', text)  # Remove percent signs
    text = text.strip()  # Remove leading/trailing spaces
    return text    
min_lambda = 1/0.023 # @param {type:"slider", min:-1, max:32, step:1}

node_attributes = ['Accession',"Query", 'Virus_Description', 'Realm', 'Kingdom', 'Phylum',
       'Class', 'Order', 'Family', 'Subfamily', 'Genus','Host_Enveloppe',
       'Host_Isolation', 'Host_species', 'Host_order',
       'Host_phylum', 'Molecule_type','HieVi_cluster', 'HC_0', 'HC_1',
       'HC_2', 'HC_3', 'HC_4', 'HC_5', 'HC_6', 'HC_7', 'HC_8', 'HC_9', 'HC_10',
       'HC_11']

df = annotation_df[node_attributes].copy()
# Apply cleaning to column names and all values
df.columns = [clean_text(col) for col in df.columns]
df = df.applymap(lambda x: clean_text(str(x)))  # Apply to all values


min_lambda =13.00 # @param {type:"slider", min:-1, max:32, step:1}
# Create and save network
G = make_network(clusterer, df,min_lambda=min_lambda)
nx.write_gexf(G, query_zarr_path[:-5] + "_HieVi.gexf")

In [51]:
# save a html for visualization (optional)
from utils.plotter import *
# Plot the graph for small graphs only
show = True # @param {type:"boolean"}
if len(df) > 2048:
    show = False
fig = plot_hierarchical_graph(G,"radial")
# Save to HTML
fig.write_html(query_zarr_path[:-5] + "_HieVi.html")

# Show the plot only if show=True
if show:
    fig.show()

ValueError: Mime type rendering requires nbformat>=4.2.0 but it is not installed