In [1]:
import numpy as np
from sklearn.mixture import GaussianMixture

import scipy.io
import os
import tqdm
from typing import List


def get_cluster_mlist(
    clsuter_list: np.ndarray, wide: float, mlists_load_fold: str
) -> List[np.ndarray]:
    """Collect the molecules lying around each cluster centre and mean-centre them.

    Every file found in ``mlists_load_fold`` is read with ``scipy.io.loadmat``.
    Each molecule list is converted from (H W D), 1-based, to (D H W), 0-based,
    and its depth coordinate is rescaled from a 64-plane grid to a 32-plane
    grid. A molecule is assigned to a cluster when all three of its coordinates
    lie within ``wide`` of the cluster centre (bounds inclusive).

    Args:
        clsuter_list: cluster centres, one row per cluster, in (D H W) order
            with 0-based indices. Only the first three entries of a row are used.
        wide: half-width of the axis-aligned box around each centre.
        mlists_load_fold: directory holding one .mat molecule list per file.

    Returns:
        One float array of shape (n_k, 3) per cluster, in the order of
        ``clsuter_list``, with the cluster's own mean subtracted. A cluster
        without any molecule yields an empty (0, 3) array.
    """
    # Iterate over the files actually present instead of a fixed count, so a
    # folder with fewer (or more) files is handled. Sorting makes the molecule
    # order reproducible because os.listdir() returns entries in arbitrary order.
    mlist_files = sorted(os.listdir(mlists_load_fold))

    # The box limits do not depend on the file being read: compute them once.
    boxes = []
    for centre in clsuter_list:
        centre = np.asarray(centre, dtype=float)[:3]  # (D H W)
        boxes.append((centre - wide, centre + wide))

    # per cluster: list of (n, 3) chunks, one chunk per contributing file
    chunks = [[] for _ in boxes]

    # go through all mlists and get the molecules for each cluster
    for file_name in tqdm.tqdm(mlist_files, unit="mlist"):
        # popitem() returns the most recently inserted entry of the loaded
        # dict; NOTE(review): this assumes each file stores a single variable.
        _, mlist = scipy.io.loadmat(
            os.path.join(mlists_load_fold, file_name)
        ).popitem()

        # Force float so that the half-pixel rescale below is not truncated
        # when the file stores integer coordinates.
        mlist = np.asarray(mlist, dtype=float)
        if mlist.ndim != 2 or mlist.shape[0] == 0:
            continue  # no molecules in this file

        # (H W D), index start from 1 -> (D H W), index start from 0
        mlist = mlist[:, [2, 0, 1]] - 1
        # [64 512 512] -> [32 512 512]
        mlist[:, 0] = (mlist[:, 0] + 0.5) / 2 - 0.5

        for c, (lower, upper) in enumerate(boxes):
            # keep the molecules whose three coordinates all fall in the box
            inside = np.all((mlist >= lower) & (mlist <= upper), axis=1)
            if inside.any():
                chunks[c].append(mlist[inside])

    # mean-centre each cluster's molecules
    cluster_mlist = []
    for cluster_chunks in chunks:
        if not cluster_chunks:
            cluster_mlist.append(np.empty((0, 3)))
            continue
        molecules = np.concatenate(cluster_chunks, axis=0)
        # A one-component Gaussian mixture has the sample mean as its mean, so
        # subtract the mean directly rather than fitting a model to obtain it.
        cluster_mlist.append(molecules - molecules.mean(axis=0))

    return cluster_mlist  # [K, num of mol in each cluster]

In [6]:
# Cluster centres, (D H W) with 0-based indices, as located on the
# super-resolution image.
cluster_list = np.array(
    [
        [85, 1603, 523],
        [88, 1437, 512],
        [86, 1447, 493],
        [51, 1546, 685],
        [53, 1576, 867],
        [57, 1768, 1110],
        [58, 1686, 1214],
        [61, 670, 482],
        [59, 574, 607],
        [62, 531, 1239],
        [62, 712, 1416],
        [61, 1059, 1671],
    ]
)
# The centres were picked on the super-resolution image; rescale them by a
# factor of 4 using pixel-centre coordinates (+0.5 before, -0.5 after).
cluster_list = (cluster_list + 0.5) / 4 - 0.5

# Scan every mlist and gather the mean-centred molecules of each cluster.
cluster_mlist = get_cluster_mlist(cluster_list, 0.5, "D:/mlists")

# Pool the molecules of all clusters and fit a single Gaussian to them.
all_mlist = [molecule for molecules in cluster_mlist for molecule in molecules]
gmm = GaussianMixture(n_components=1)
gmm.fit(all_mlist)

# Report the fitted covariance and mean.
print(gmm.covariances_)
print(gmm.means_)

  0%|          | 0/45000 [00:00<?, ?mlist/s]

100%|██████████| 45000/45000 [00:22<00:00, 1963.25mlist/s]


(65, 3)
(116, 3)
(79, 3)
(44, 3)
(71, 3)
(68, 3)
(28, 3)
(132, 3)
(68, 3)
(21, 3)
(55, 3)
(55, 3)
[[[ 0.06378605  0.00073886 -0.00064982]
  [ 0.00073886  0.04276527 -0.00056021]
  [-0.00064982 -0.00056021  0.05327207]]]
[[-1.28021727e-15  1.21199808e-14 -1.54689229e-14]]


