Skip to content

Commit

Permalink
gmm sorting - test
Browse files Browse the repository at this point in the history
  • Loading branch information
ramintoosi committed Jun 25, 2023
1 parent 1d94500 commit 2f2bc3f
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 13 deletions.
22 changes: 9 additions & 13 deletions ross_backend/resources/funcs/gmm.py
Expand Up @@ -3,27 +3,23 @@
from sklearn.mixture import GaussianMixture as GMM


def gmm_sorter(alignedSpikeMat, ss):
out = dict()
print('GMM')

def gmm_sorter(aligned_spikemat, ss):
g_max = ss.g_max
g_min = ss.g_min
max_iter = ss.max_iter
n_cluster_range = np.arange(g_min + 1, g_max + 1)
scores = []
error = ss.error

scores = []

for n_cluster in n_cluster_range:
clusterer = GMM(n_components=n_cluster, random_state=5, tol=error, max_iter=max_iter)
cluster_labels = clusterer.fit_predict(alignedSpikeMat)
silhouette_avg = silhouette_score(alignedSpikeMat, cluster_labels)
print('For n_cluster={0} ,score={1}'.format(n_cluster, silhouette_avg))
cluster = GMM(n_components=n_cluster, random_state=5, tol=error, max_iter=max_iter)
cluster_labels = cluster.fit_predict(aligned_spikemat)
silhouette_avg = silhouette_score(aligned_spikemat, cluster_labels)
scores.append(silhouette_avg)

k = n_cluster_range[np.argmax(scores)]
clusterer = GMM(n_components=k, random_state=5, tol=error, max_iter=max_iter)
out['cluster_index'] = clusterer.fit_predict(alignedSpikeMat)
print('clusters : ', out['cluster_index'])
cluster = GMM(n_components=k, random_state=5, tol=error, max_iter=max_iter)
cluster_index = cluster.fit_predict(aligned_spikemat)

return out['cluster_index']
return cluster_index
31 changes: 31 additions & 0 deletions test/backend/test_sorting.py
@@ -0,0 +1,31 @@
from dataclasses import dataclass

import numpy as np

from ross_backend.resources.funcs.gmm import gmm_sorter


@dataclass
class Config:
g_max = 10
g_min = 2
max_iter = 1000
error = 1e-5


def generate_data(num_clusters=2, n_per_class=200):
data = np.zeros((num_clusters * n_per_class, 2))
for i in range(num_clusters):
data[i * n_per_class: (i + 1) * n_per_class, :] = np.random.multivariate_normal(
[2 * i, 2 * i], [[0.01, 0], [0, 0.01]], size=(n_per_class,)
)
return data


def test_gmm():
num_clusters = 3
config = Config()
data = generate_data(num_clusters)
cluster_index = gmm_sorter(data, config)

assert len(np.unique(cluster_index)) == num_clusters

0 comments on commit 2f2bc3f

Please sign in to comment.