In [1]:
import hdbscan 
import zarr
import numpy as np
from sklearn.metrics.pairwise import euclidean_distances
from sklearn import metrics# import adjusted_rand_score, adjusted_mutual_info_score,silhouette_score
from tqdm import tqdm
import matplotlib.pyplot as plt
import networkx as nx
from scipy.signal import find_peaks
import pandas as pd
import faiss

In [2]:
zarr_path = "/media/microscopie-lcb/swapnesh/protein/embeddings/phages/1Sept2024_INPHARED_db_latest/ESM2_650m_1Sept24_650m.zarr"
zarr_store = zarr.open(zarr_path,'r')
#db_vectors = zarr_store['vectors_mean'][:]*1.0
db_accessions = zarr_store['accessions'][:]
#db_vectors.shape

In [3]:
annotations = pd.read_csv('../HieVi_INPHARED_ordered_annotation.csv')
annotations= annotations[annotations["Accession"].isin(db_accessions)]
annotations = annotations.set_index("Accession").loc[db_accessions].reset_index()
hievi_cluster = pd.read_csv('/media/microscopie-lcb/swapnesh/protein/embeddings/phages/1Sept2024_INPHARED_db_latest/ESM2_650m_1Sept24_650m._cluster.csv')
# annotations["Accession"] = pd.Categorical(annotations["Accession"], categories=accessions, ordered=True)
# annotations = annotations.sort_values("Accession")
hievi_cluster_prefix = 'HieVi_cluster_'
annotations.head()

Unnamed: 0,Accession,Virus_Description,Virus_Genome_size,Virus_molGC_(%),Virus_Number_CDS,Realm,Kingdom,Phylum,Class,Order,...,VC_Subcluster_Size,VC_number,VC_subcluster,Adj P-value,Families in VC,Genera in VC,Genus Confidence Score,Orders in VC,Quality,Topology Confidence Score
0,AY319521,Salmonella phage SopEPhi,35155.0,51.3,45.0,Duplodnaviria,Heunggongvirae,Uroviricota,Caudoviricetes,Unclassified,...,44.0,4.0,0.0,1.0,0.0,0.0,1.0,0.0,0.9656,0.9656
1,MW175890,Dompiswa phage TSP7_1,150892.0,39.1,272.0,Duplodnaviria,Heunggongvirae,Uroviricota,Caudoviricetes,Unclassified,...,44.0,913.0,1.0,1.0,0.0,0.0,0.9683,0.0,0.7568,0.7568
2,GU339467,Mycobacterium phage RedRock,53332.0,64.5,90.0,Duplodnaviria,Heunggongvirae,Uroviricota,Caudoviricetes,Unclassified,...,5.0,221.0,1.0,0.99999997,0.0,0.0,1.0,0.0,0.0484,0.0484
3,MF417929,Uncultured Caudovirales phage clone 2F_1,32618.0,39.2,42.0,Duplodnaviria,Heunggongvirae,Uroviricota,Caudoviricetes,Unclassified,...,Unclassified,Unclassified,Unclassified,Unclassified,Unclassified,Unclassified,Unclassified,Unclassified,Unclassified,Unclassified
4,MH616963,crAssphage sp. isolate ctbg_1,94878.0,28.5,89.0,Duplodnaviria,Heunggongvirae,Uroviricota,Caudoviricetes,Crassvirales,...,16.0,703.0,1.0,1.0,0.0,0.0,0.725,0.0,0.8536,0.8536


In [4]:
query_zarr_path = "/media/microscopie-lcb/swapnesh/protein/embeddings/phages/Metavirome_INRAE_HieVi/MV_INRAE_ESM2_650m.zarr"
query_zarr_store = zarr.open(query_zarr_path,'r')
query_vectors = query_zarr_store['vectors_mean'][:]*1.0
query_accessions = query_zarr_store['accessions'][:]
query_vectors.shape

(1258, 1280)

In [5]:
faiss_index_path = zarr_path+'faiss_index.bin'
# Load FAISS index and search for nearest neighbors
index = faiss.read_index(faiss_index_path)

In [6]:
eps_values = np.load('eps_values_flat_clusters_650m.npy')
eps_values

array([0.00202933, 0.00389589, 0.00540029, 0.00824194, 0.00924487,
       0.01027566, 0.01244868, 0.01389736, 0.01556891, 0.01665543,
       0.01707331, 0.01860557, 0.01977566, 0.02097361, 0.02164223,
       0.0221437 , 0.02420528, 0.02531965])

In [7]:
k_neighbours = 5
distance_threshold = 3e-2
distances, indices = index.search(query_vectors, k_neighbours)
#outliers = distances[:,0] > distance_threshold


In [8]:
n_clusts = len(np.where([hievi_cluster_prefix in h for h in list(hievi_cluster.columns)])[0])

In [9]:
query_labels = {}
all_nearest_accessions = []
for i,(distance_to_neighbours,neighbours) in enumerate(zip(distances,indices)):
    if distance_to_neighbours[0] > distance_threshold:
        idx_in_hievi = hievi_cluster['Accession'] == db_accessions[neighbour]            
        clust_values = np.squeeze(hievi_cluster[idx_in_hievi][[hievi_cluster_prefix+ str(j) for j in range(n_clusts)]].values)
        query_labels[i] = {"name": query_accessions[i], "cluster_label":-2,'lowest_hievi_cluster': None,'all_clusters':clust_values,'nearest_accessions':None}
    else:
        visited = False
        for neighbour in neighbours:
            if visited:
                break;
            idx_in_hievi = hievi_cluster['Accession'] == db_accessions[neighbour]            
            clust_values = np.squeeze(hievi_cluster[idx_in_hievi][[hievi_cluster_prefix+ str(j) for j in range(n_clusts)]].values)
            lowest_clust_idx = np.where(clust_values!=-1)[0]
            if (len(lowest_clust_idx)>0):
                visited = True
                idx = hievi_cluster['HieVi_cluster_' + str(lowest_clust_idx[0])] == clust_values[lowest_clust_idx[0]]                
                query_labels[i] = {"name": query_accessions[i],"cluster_label":clust_values[lowest_clust_idx[0]],
                                   'lowest_hievi_cluster': lowest_clust_idx[0],
                                   'all_clusters':clust_values,
                                   'nearest_accessions': hievi_cluster[idx]['Accession'].values}            
                all_nearest_accessions += list(hievi_cluster[idx]['Accession'].values)
            else:
                query_labels[i] = {"name": query_accessions[i],"cluster_label":-2,'lowest_hievi_cluster': None,'all_clusters':clust_values,'nearest_accessions':None}            
all_nearest_accessions = np.unique(np.array(all_nearest_accessions))            
len(all_nearest_accessions)

8594

In [10]:
query_labels

{0: {'name': 'Cabbage-00001',
  'cluster_label': 1117,
  'lowest_hievi_cluster': 1,
  'all_clusters': array([  -1, 1117, 1201, 1229, 1214, 1184, 1045,  972,  864,  820,  801,
          759,  711,  674,  655,  634,  544,  506]),
  'nearest_accessions': array(['MN830255', 'MN830254'], dtype=object)},
 1: {'name': 'Cabbage-00002',
  'cluster_label': 1117,
  'lowest_hievi_cluster': 1,
  'all_clusters': array([  -1, 1117, 1201, 1229, 1214, 1184, 1045,  972,  864,  820,  801,
          759,  711,  674,  655,  634,  544,  506]),
  'nearest_accessions': array(['MN830255', 'MN830254'], dtype=object)},
 2: {'name': 'Cabbage-00003',
  'cluster_label': 1378,
  'lowest_hievi_cluster': 8,
  'all_clusters': array([  -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1, 1378, 1322, 1301,
         1237, 1177, 1112, 1081, 1051,  912,  843]),
  'nearest_accessions': array(['KF302033', 'KF302032'], dtype=object)},
 3: {'name': 'Cabbage-00004',
  'cluster_label': 1226,
  'lowest_hievi_cluster': 2,
  'all_clusters'

In [11]:
all_indices = np.squeeze(np.array([np.where(db_accessions==acc)[0] for acc in all_nearest_accessions]))
subset_db_vectors = zarr_store['vectors_mean'][all_indices]*1.0
subset_db_vectors.shape
#all_indices

(8594, 1280)

In [12]:
indices = np.unique(np.ravel(indices))
print(f"Nearest neighbor search completed. Found {len(indices)} unique neighbors.")
# Get nearest accessions in tree
nearest_accessions = annotations[annotations["Accession"].isin(all_nearest_accessions)]
nearest_accessions = nearest_accessions.set_index("Accession").loc[all_nearest_accessions].reset_index()

nearest_accessions.head()

Nearest neighbor search completed. Found 1639 unique neighbors.


Unnamed: 0,Accession,Virus_Description,Virus_Genome_size,Virus_molGC_(%),Virus_Number_CDS,Realm,Kingdom,Phylum,Class,Order,...,VC_Subcluster_Size,VC_number,VC_subcluster,Adj P-value,Families in VC,Genera in VC,Genus Confidence Score,Orders in VC,Quality,Topology Confidence Score
0,AB009866,Staphylococcus phage PVL,41401.0,33.6,64.0,Duplodnaviria,Heunggongvirae,Uroviricota,Caudoviricetes,Unclassified,...,32.0,6.0,0.0,1.0,0.0,0.0,0.996,0.0,0.4272,0.4272
1,AB044554,Staphylococcus prophage phiPV83,45636.0,33.5,68.0,Duplodnaviria,Heunggongvirae,Uroviricota,Caudoviricetes,Unclassified,...,32.0,6.0,0.0,1.0,0.0,0.0,0.996,0.0,0.4272,0.4272
2,AB045978,Staphylococcus phage phiSLT,42942.0,33.3,68.0,Duplodnaviria,Heunggongvirae,Uroviricota,Caudoviricetes,Unclassified,...,72.0,7.0,0.0,1.0,0.0,0.0,0.9237,0.0,0.8381,0.8381
3,AB231700,Microcystis phage LMM01,162109.0,46.0,189.0,Duplodnaviria,Heunggongvirae,Uroviricota,Caudoviricetes,Unclassified,...,3.0,43.0,0.0,0.96819872,0.0,0.0,1.0,0.0,0.839,0.8123
4,AB243556,Staphylococcus virus 108PVL,44857.0,33.5,66.0,Duplodnaviria,Heunggongvirae,Uroviricota,Caudoviricetes,Unclassified,...,32.0,6.0,0.0,1.0,0.0,0.0,0.996,0.0,0.4272,0.4272


In [13]:
# Combine query and database data
query_df = pd.DataFrame({"Accession": query_accessions})
annotation_df = pd.concat([nearest_accessions, query_df], axis=0)
mprs = np.concatenate((subset_db_vectors, query_vectors), axis=0)

# Perform clustering
dist_scaled = euclidean_distances(mprs).astype("double")
clusterer = hdbscan.HDBSCAN(
    min_cluster_size=2,
    n_jobs=32,
    min_samples=1,
    allow_single_cluster=False,
    cluster_selection_method="leaf",
    metric="precomputed",
    gen_min_span_tree=True
)
clusterer.fit(dist_scaled)
annotation_df["HieVi_granular_cluster"] = clusterer.labels_
for i,eps in enumerate(eps_values):
    annotation_df['HieVi_cluster_'+str(i)] = clusterer.dbscan_clustering(cut_distance=eps,min_cluster_size=2)
annotation_df

Unnamed: 0,Accession,Virus_Description,Virus_Genome_size,Virus_molGC_(%),Virus_Number_CDS,Realm,Kingdom,Phylum,Class,Order,...,HieVi_cluster_8,HieVi_cluster_9,HieVi_cluster_10,HieVi_cluster_11,HieVi_cluster_12,HieVi_cluster_13,HieVi_cluster_14,HieVi_cluster_15,HieVi_cluster_16,HieVi_cluster_17
0,AB009866,Staphylococcus phage PVL,41401.0,33.6,64.0,Duplodnaviria,Heunggongvirae,Uroviricota,Caudoviricetes,Unclassified,...,-1,-1,-1,-1,-1,172,160,150,50,40
1,AB044554,Staphylococcus prophage phiPV83,45636.0,33.5,68.0,Duplodnaviria,Heunggongvirae,Uroviricota,Caudoviricetes,Unclassified,...,-1,-1,-1,-1,-1,172,160,150,50,40
2,AB045978,Staphylococcus phage phiSLT,42942.0,33.3,68.0,Duplodnaviria,Heunggongvirae,Uroviricota,Caudoviricetes,Unclassified,...,-1,-1,-1,98,192,172,160,150,50,40
3,AB231700,Microcystis phage LMM01,162109.0,46.0,189.0,Duplodnaviria,Heunggongvirae,Uroviricota,Caudoviricetes,Unclassified,...,46,38,32,33,28,26,26,26,19,15
4,AB243556,Staphylococcus virus 108PVL,44857.0,33.5,66.0,Duplodnaviria,Heunggongvirae,Uroviricota,Caudoviricetes,Unclassified,...,-1,-1,-1,9,192,172,160,150,50,40
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1253,Turnip-00679,,,,,,,,,,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
1254,Turnip-00680,,,,,,,,,,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
1255,Turnip-00681,,,,,,,,,,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
1256,Turnip-00682,,,,,,,,,,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1


In [14]:
# Prepare final DataFrame
node_properties = [
    "Accession", "Virus_Description", "Virus_Genome_size",
    "Virus_molGC_(%)", "Virus_Number_CDS", "Realm", "Kingdom", "Phylum",
    "Class", "Order", "Family", "Subfamily", "Genus", "Lowest_taxa", "tRNAs",
    "VC_cluster", "VC", "VC Status", "VC_Size", "VC_Subcluster",
    "VC_Subcluster_Size", "VC_number", "VC_subcluster","HieVi_granular_cluster"
]
node_properties += [f"HieVi_cluster_{i}" for i in range(len(eps_values))]
annotation_df = annotation_df[node_properties].fillna("Unclassified").astype(str)


In [18]:
import numpy as np
import networkx as nx

def make_network(hdb, df, wt_nan=1e12,min_lambda = None):
    """
    Constructs a directed graph from HDBSCAN's condensed tree.
    
    Parameters:
        hdb: HDBSCAN object with condensed_tree_ attribute.
        df (pd.DataFrame): Dataframe containing node attributes.
        wt_nan (float): Weight assigned when lambda_val is NaN.
    
    Returns:
        nx.DiGraph: A directed graph with nodes and edges from the condensed tree.
    """
    G = nx.DiGraph()
    
    # Add all nodes from the condensed tree
    all_nodes = set(hdb.condensed_tree_._raw_tree['parent']).union(set(hdb.condensed_tree_._raw_tree['child']))
    G.add_nodes_from(all_nodes)

    # Add edges with weights
    for row in hdb.condensed_tree_._raw_tree:
        parent, child, lambda_val = int(row['parent']), int(row['child']), row['lambda_val']
        weight = lambda_val if np.isfinite(lambda_val) else wt_nan
        if min_lambda is not None:
            if weight > min_lambda:
                G.add_edge(parent, child, weight=weight, distance=1 / weight if weight != 0 else wt_nan)
        else:
            G.add_edge(parent, child, weight=weight, distance=1 / weight if weight != 0 else wt_nan)

    # Assign attributes to existing nodes
    for i, (idx, row) in enumerate(df.iterrows()):
        node_id = int(i)
        if node_id in G.nodes:
            G.nodes[node_id].update(row.to_dict())

    return G


# Create and save network
G = make_network(clusterer, annotation_df)
nx.write_gexf(G, query_zarr_path[:-5] + "_HieVi.gexf")


In [88]:
annotation_df

Unnamed: 0,Accession,Virus_Description,Virus_Genome_size,Virus_molGC_(%),Virus_Number_CDS,Realm,Kingdom,Phylum,Class,Order,...,HieVi_cluster_2,HieVi_cluster_3,HieVi_cluster_4,HieVi_cluster_5,HieVi_cluster_6,HieVi_cluster_7,HieVi_cluster_8,HieVi_cluster_9,HieVi_cluster_10,HieVi_cluster_11
281,PP759150,Escherichia phage HH3,44661.0,54.5,58.0,Duplodnaviria,Heunggongvirae,Uroviricota,Caudoviricetes,Unclassified,...,8,5,6,6,7,7,7,7,7,7
483,OZ035806,Felixounavirus NRG857CP1,87084.0,38.8,121.0,Duplodnaviria,Heunggongvirae,Uroviricota,Caudoviricetes,Unclassified,...,8,5,6,6,7,7,7,7,7,7
489,OZ035800,Felixounavirus LF82P5,85898.0,39.0,114.0,Duplodnaviria,Heunggongvirae,Uroviricota,Caudoviricetes,Unclassified,...,-1,1,2,2,2,2,2,2,2,2
493,OZ035796,Felixounavirus NRG857CP1,84421.0,38.8,114.0,Duplodnaviria,Heunggongvirae,Uroviricota,Caudoviricetes,Unclassified,...,5,2,3,3,3,3,3,3,3,3
499,OZ035790,Felixounavirus LF110P4,86072.0,39.1,117.0,Duplodnaviria,Heunggongvirae,Uroviricota,Caudoviricetes,Unclassified,...,-1,8,8,8,9,9,9,9,9,9
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5,PHI-JG5,Unclassified,Unclassified,Unclassified,Unclassified,Unclassified,Unclassified,Unclassified,Unclassified,Unclassified,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
6,PHI-LetoIII,Unclassified,Unclassified,Unclassified,Unclassified,Unclassified,Unclassified,Unclassified,Unclassified,Unclassified,...,-1,-1,2,2,2,2,2,2,2,2
7,PHI-PR1,Unclassified,Unclassified,Unclassified,Unclassified,Unclassified,Unclassified,Unclassified,Unclassified,Unclassified,...,-1,-1,-1,-1,-1,-1,-1,4,4,4
8,PHI-PR3,Unclassified,Unclassified,Unclassified,Unclassified,Unclassified,Unclassified,Unclassified,Unclassified,Unclassified,...,-1,-1,-1,-1,-1,-1,-1,8,8,8
