### Colocalization-based ion image clustering with NRDC

Publication: "A noise-robust deep clustering of biomolecular ions improves interpretability of mass spectrometric images" by Dan Guo et al. *Bioinformatics*, 2023. https://academic.oup.com/bioinformatics/article/39/2/btad067/7028486

Code by Tim Daniel Rose: https://github.com/tdrose/deep_mzClustering

In [2]:
# Automatically reload edited modules (e.g. moran_imaging) before each cell executes.
# Re-running this cell in a live kernel only prints an "already loaded" notice.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
import pickle
import time
import random
import numpy as np
from sklearn.metrics import adjusted_rand_score, adjusted_mutual_info_score
from pathlib import Path

# Fix random seeds. NRDC receives its own `random_seed` argument further down;
# seeding the stdlib and NumPy global generators here covers any other code in
# this notebook that draws from them.
random.seed(0)
np.random.seed(0)

In [9]:
# Moran Imaging helpers: balanced clustering metrics, plotting utilities and the
# NRDC (noise-robust deep clustering) implementation. NRDC is imported only once.
from moran_imaging.ari_balance import balanced_adjusted_rand_index
import moran_imaging.nrdc_clustering as NRDC
from moran_imaging.plotting import position_discrete_colorbar_ticks
from moran_imaging.vmeasure_balance import balanced_homogeneity_completeness_v_measure

#### Load imaging mass spectrometry data

In [4]:
# Load the pickled zebrafish dataset (shared metadata plus eight reference clusters).
# NOTE: pickle.load can execute arbitrary code -- only open files from a trusted source.
data_dir = Path(".").resolve() / "Data"
assert data_dir.exists(), f"The data directory {data_dir} does not exist."

with (data_dir / 'Zebra_fish_8_clusters_dataset.pickle').open('rb') as file:
    clustering_data_metadata = pickle.load(file)

# Image geometry and background mask shared by every cluster
image_shape, background_mask = (
    clustering_data_metadata['utils'][key]
    for key in ('image_shape', 'background_mask')
)

# One data matrix per reference cluster, in cluster order 1..8
(data_cluster1, data_cluster2, data_cluster3, data_cluster4,
 data_cluster5, data_cluster6, data_cluster7, data_cluster8) = (
    clustering_data_metadata[key]['data']
    for key in ('cluster_1', 'cluster_2', 'cluster_3', 'cluster_4',
                'cluster_5', 'cluster_6', 'cluster_7', 'cluster_8')
)

In [5]:
# Stack the per-cluster matrices column-wise. The reference label of each
# column is the 1-based index of the cluster matrix it came from.
cluster_matrices = (data_cluster1, data_cluster2, data_cluster3, data_cluster4,
                    data_cluster5, data_cluster6, data_cluster7, data_cluster8)
dataset = np.hstack(cluster_matrices)

ref_labels = [
    label
    for label, matrix in enumerate(cluster_matrices, start=1)
    for _ in range(matrix.shape[1])
]

total_num_pixels = np.prod(image_shape)
num_mz_bins = dataset.shape[1]

#### Noise-robust deep clustering (NRDC) workflow

In [None]:
# Build the NRDC model. use_gpu="auto" lets the library choose the compute
# device itself (the captured output shows it selected 'mps'); presumably
# True/False force or disable GPU use -- confirm in NRDC.Deep_Clustering.
# The timer starts here, so the reported time also covers model construction.
start_time = time.time()

# The mask passed in is the complement of `background_mask`. np.logical_not is
# identical to np.invert for boolean masks, but stays a correct logical NOT if
# the mask is stored as 0/1 integers (np.invert would flip bits to -1/-2).
NRDC_cluster_model = NRDC.Deep_Clustering(dataset, np.logical_not(background_mask), image_shape, num_cluster=8, lr=0.0001, knn=True, k=5, use_gpu="auto", random_seed=0)

device(type='mps')

In [None]:
# Train the NRDC networks, then predict one cluster label per column of `dataset`.
cae, CLUST = NRDC_cluster_model.train()
NRDC_labels = NRDC_cluster_model.inference(cae, CLUST)

# `start_time` comes from the model-construction cell, so this elapsed
# wall-clock time spans construction, training and inference.
end_time = time.time()
compute_time = end_time - start_time
print(f"Noise robust deep clustering execution time: {compute_time:.6f} seconds")

In [10]:
# Score the NRDC labels against the reference cluster labels; every score is
# rounded to 4 decimals. Labels are converted to arrays once and reused.
true_labels = np.array(ref_labels)
pred_labels = np.array(NRDC_labels)

ARI = np.round(adjusted_rand_score(true_labels, pred_labels), 4)
print('Adjusted Rand index:', ARI)

AMI = np.round(
    adjusted_mutual_info_score(true_labels, pred_labels, average_method='arithmetic'), 4
)
print('Adjusted mutual information:', AMI)

BARI = np.round(balanced_adjusted_rand_index(true_labels, pred_labels), 4)
print('Balanced adjusted Rand index:', BARI)

balanced_homogeneity, balanced_completeness, balanced_v_measure = np.round(
    balanced_homogeneity_completeness_v_measure(true_labels, pred_labels), 4
)
print('Balanced V-measure:', balanced_v_measure)

Adjusted Rand index: 0.6217
Adjusted mutual information: 0.7191
Balanced adjusted Rand index: 0.4849
Balanced V-measure: 0.7199
