# ross_backend/resources/funcs/gmm.py -- gmm_sorter as of patch 2f2bc3f ("gmm sorting - test")
def gmm_sorter(aligned_spikemat, ss):
    """Cluster samples with a Gaussian mixture, picking the component count by silhouette score.

    One mixture is fitted per candidate component count; the labelling with
    the highest mean silhouette score is returned (first one wins on ties).

    Args:
        aligned_spikemat: Sample matrix handed unchanged to
            ``GaussianMixture.fit_predict`` and ``silhouette_score``
            (one row per sample).
        ss: Settings object exposing ``g_min``, ``g_max``, ``max_iter`` and
            ``error`` (``error`` is used as the mixture's convergence ``tol``).

    Returns:
        The array of cluster labels produced by the best-scoring mixture.

    Raises:
        ValueError: If ``g_max <= g_min``, i.e. there is no candidate
            component count to evaluate.
    """
    g_min = ss.g_min
    g_max = ss.g_max
    max_iter = ss.max_iter
    error = ss.error

    # NOTE(review): candidates start at g_min + 1, so g_min itself is never
    # tried -- confirm this offset is intended.
    n_cluster_range = np.arange(g_min + 1, g_max + 1)
    if n_cluster_range.size == 0:
        # Fail before fitting anything, with a message that names the
        # offending settings instead of numpy's empty-argmax error.
        raise ValueError(
            f"no candidate cluster counts: g_max ({g_max}) must be greater "
            f"than g_min ({g_min})"
        )

    scores = []
    labelings = []
    for n_cluster in n_cluster_range:
        model = GMM(n_components=n_cluster, random_state=5, tol=error, max_iter=max_iter)
        labels = model.fit_predict(aligned_spikemat)
        scores.append(silhouette_score(aligned_spikemat, labels))
        # Keep the labels: with a fixed random_state a later re-fit would only
        # reproduce them, so the winner can be returned without fitting again.
        labelings.append(labels)

    return labelings[int(np.argmax(scores))]
# test/backend/test_sorting.py -- new file added by patch 2f2bc3f ("gmm sorting - test")
from dataclasses import dataclass

import numpy as np

from ross_backend.resources.funcs.gmm import gmm_sorter


@dataclass
class Config:
    """Stand-in settings object exposing the attributes ``gmm_sorter`` reads."""

    # Annotated so that @dataclass turns them into real fields; individual
    # values can then be overridden, e.g. ``Config(g_max=5)``.
    g_max: int = 10
    g_min: int = 2
    max_iter: int = 1000
    error: float = 1e-5


def generate_data(num_clusters=2, n_per_class=200, seed=None):
    """Return 2-D Gaussian blobs stacked row-wise, one blob per cluster.

    Blob ``i`` is centred at ``(2 * i, 2 * i)`` with covariance
    ``0.01 * I``, so neighbouring blobs are far apart relative to their spread.

    Args:
        num_clusters: Number of blobs to generate.
        n_per_class: Number of samples drawn for each blob.
        seed: Optional seed for the random generator; ``None`` draws fresh
            entropy, a fixed value makes the data reproducible.

    Returns:
        Array of shape ``(num_clusters * n_per_class, 2)``; rows belonging to
        the same blob are contiguous.
    """
    rng = np.random.default_rng(seed)
    covariance = [[0.01, 0], [0, 0.01]]
    data = np.zeros((num_clusters * n_per_class, 2))
    for i in range(num_clusters):
        start = i * n_per_class
        data[start:start + n_per_class, :] = rng.multivariate_normal(
            [2 * i, 2 * i], covariance, size=n_per_class
        )
    return data


def test_gmm():
    """gmm_sorter should recover the number of well-separated blobs."""
    num_clusters = 3
    # Fixed seed keeps the test deterministic from run to run.
    data = generate_data(num_clusters, seed=0)

    cluster_index = gmm_sorter(data, Config())

    assert len(cluster_index) == len(data)
    assert len(np.unique(cluster_index)) == num_clusters