In [1]:
import anndata
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from kb_python.utils import import_matrix_as_anndata
import anndata 
from scipy.io import mmread
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Base font size; applied to matplotlib via plt.rcParams["font.size"] further down in this cell.
fsize = 15
import sys

def nd(arr):
    """Return ``arr`` as a flat, one-dimensional NumPy array.

    The input is first converted with ``np.asarray`` and then reshaped to
    a single axis of whatever length is needed.
    """
    as_array = np.asarray(arr)
    flat = np.reshape(as_array, -1)
    return flat


def yex(ax):
    """Draw the y = x diagonal on ``ax`` and make both axes share one range.

    The common range runs from the smallest to the largest value found in
    the current x- and y-limits. The diagonal is plotted first, then the
    axes are set to an equal aspect ratio with identical limits.

    Returns the same ``ax`` so calls can be chained.
    """
    current_limits = [ax.get_xlim(), ax.get_ylim()]
    lo = np.min(current_limits)  # smallest bound across both axes
    hi = np.max(current_limits)  # largest bound across both axes
    shared = [lo, hi]

    # Plot the shared range against itself to get the identity line,
    # drawn behind everything else (zorder=0).
    ax.plot(shared, shared, c="k", alpha=0.75, zorder=0)
    ax.set(aspect="equal", xlim=shared, ylim=shared)
    return ax


# Use fsize as the default font size for every matplotlib figure in this notebook.
plt.rcParams.update({"font.size": fsize})
# IPython magic: ask the inline backend to render figures in 'retina' format.
%config InlineBackend.figure_format = 'retina'



In [2]:
from sklearn.mixture import GaussianMixture
from scipy.stats import entropy
def gmm(x, v, comps):
    """Choose a count cutoff by fitting one or more Gaussian mixture stages.

    Each entry of ``comps`` is the number of mixture components for one
    stage. A stage fits a ``GaussianMixture`` to ``v``, computes the
    per-sample entropy of the posterior component probabilities, and picks
    a cutoff from ``x``. If more stages remain, the function recurses on
    the samples whose ``x`` is strictly above that cutoff.

    Parameters
    ----------
    x : array
        Per-sample values the cutoff is expressed in; indexed positionally
        and compared against the cutoff.
    v : array
        Features passed to ``GaussianMixture``, one row per entry of ``x``.
        NOTE(review): the ``n_comps > 2`` branch maps the fitted means
        through ``exp(.) - 1`` before comparing with ``x``, so it assumes
        ``v`` is ``log1p(x)`` -- confirm against the caller.
    comps : list of int
        Component counts, one per stage, consumed left to right. The
        caller's list is NOT modified.

    Returns
    -------
    tuple
        ``(cutoff, n_above, ent)`` from the last stage: the chosen cutoff,
        the number of entries of that stage's ``x`` strictly greater than
        it, and that stage's per-sample entropy array.

    Raises
    ------
    IndexError
        If ``comps`` is empty.
    """
    # Work on a private copy: popping from the caller's list would drain it
    # across the recursion and break any later reuse of the same list.
    remaining = list(comps)
    n_comps = remaining.pop(0)

    gm = GaussianMixture(n_components=n_comps, random_state=42)
    gm.fit(v)
    prob = gm.predict_proba(v)
    print(gm.means_, gm.covariances_, sep="\n")
    # High entropy = the mixture is least sure which component a sample
    # belongs to, i.e. the sample sits near a boundary between components.
    ent = entropy(prob, axis=1)

    cutoff = 0  # a single-component stage filters nothing out
    if n_comps == 2:
        # The most ambiguous sample marks the boundary between the two components.
        cutoff = x[np.argmax(ent)]
    elif n_comps > 2:
        # Restrict the search to samples lying strictly between the two
        # largest component means (mapped back through exp(.) - 1).
        means = np.sort((np.exp(gm.means_) - 1).flatten())
        in_range = np.logical_and(x > means[-2], x < means[-1])
        df = pd.DataFrame({"ent": ent, "idx": np.arange(ent.shape[0]).astype(int)})[in_range]
        # Position (within the restricted frame) of the highest entropy,
        # then translate it back to an index into x.
        amax = df["ent"].argmax()
        idx = df.iloc[amax]["idx"].astype(int)
        cutoff = x[idx]

    keep = x > cutoff
    if not remaining:
        return (cutoff, keep.sum(), ent)
    return gmm(x[keep], v[keep], remaining)



In [3]:
from mx.mx_filter import mx_filter, knee

ModuleNotFoundError: No module named 'mx'

In [None]:
# Read the Matrix Market file and convert the result to CSR format.
# NOTE(review): absolute, machine-specific path -- this cell only runs where that directory exists.
mtx = mmread("/home/cellatlas/human/data/colon/GSM3587010/mito_filter/matrix.mtx").tocsr()

In [None]:
# Keep only rows whose sum along axis 1 is positive, i.e. drop all-zero rows.
# The name mtx is rebound, so the unfiltered matrix is no longer reachable after this cell.
mtx = mtx[nd(mtx.sum(1)>0)].copy()

In [None]:
# knee is imported from mx.mx_filter; its return values are used below as
# x (values the cutoff is picked from) and v (features fitted by the mixture).
# NOTE(review): the recorded run of that import cell failed with
# ModuleNotFoundError, so knee is undefined until the mx package is available.
u, x, v = knee(mtx, 1)

# Single stage with a two-component mixture: returns the cutoff, the number
# of entries of x above it, and the per-sample entropy.
(cutoff, ncells, ent) = gmm(x, v, comps=[2])

In [None]:
# Display the cutoff returned by gmm.
cutoff

In [None]:
# Sanity check: for a single two-component stage gmm sets cutoff = x[argmax(ent)],
# so this should display the same value as cutoff (as long as x is still the array returned by knee).
x[np.argmax(ent)]

In [None]:
# Hard-coded parameters of the two Gaussian curves drawn by the plotting cell
# over log(UMI counts); variances are converted to standard deviations there.
# NOTE(review): presumably copied by hand from the gm.means_ / gm.covariances_
# values printed by gmm -- they will go stale if the data or the fit changes; confirm.
means = [9.40710059, 4.89899629]
variances = [0.82530123, 0.25033376]

In [None]:
# Display the dimensions of the matrix after the all-zero rows were removed.
mtx.shape

In [None]:
# (mtxf, bcsf) = mx_filter(mtx, np.arange(mtx.shape[0]),sum_axis=1, comps=[2], select_axis=None)

In [None]:
# Per-row totals of the matrix as a flat 1-D array
# (the plotting cell labels these values "UMI counts").
s = nd(mtx.sum(1))

In [None]:
# Three-panel summary of the mixture-based filter:
#   1. histogram of log total counts with the chosen cutoff,
#   2. the two Gaussian curves described by means / variances,
#   3. knee plot (sorted counts vs. rank) coloured by entropy.
# Panel-local names are used instead of re-binding x / y, so the x returned
# by knee stays intact and earlier cells give the same answer when re-run.
fig, axs = plt.subplots(figsize=(15, 5), ncols=3, constrained_layout=True)

# --- Panel 1: distribution of log counts ---------------------------------
ax = axs[0]
log_counts = np.log(s)
ax.hist(log_counts, color="#7394B3", edgecolor="k")
ax.axvline(x=np.log(cutoff), color="grey", linestyle="--")
ax.set(xlabel="log(UMI counts)", ylabel="Frequency")

# --- Panel 2: Gaussian components ----------------------------------------
ax = axs[1]
log_grid = np.linspace(np.log(s.min()), np.log(s.max()), 1000)

for mu, sigma, color in zip(means, np.sqrt(variances), ["#7394B3", "#941655"]):
    # Normal probability density with mean mu and standard deviation sigma.
    density = (1 / (sigma * np.sqrt(2 * np.pi))) * np.exp(-(log_grid - mu) ** 2 / (2 * sigma**2))
    ax.plot(log_grid, density, linewidth=5, color=color)
# The curves are probability densities, not histogram counts.
ax.set(xlabel="log(UMI counts)", ylabel="Density")

# --- Panel 3: knee plot ----------------------------------------------------
ax = axs[2]
counts_desc = np.sort(s)[::-1]
# Rank starts at 0, which cannot be drawn on the log-scaled y axis.
rank = np.arange(counts_desc.shape[0])
# NOTE(review): colouring by ent assumes ent is ordered like counts_desc
# (i.e. knee returned its values sorted in descending order) -- confirm.
im = ax.scatter(counts_desc, rank, c=ent, cmap="Blues_r")
ax.axvline(x=cutoff, color="grey", linestyle="--")
# Horizontal guide at the first rank whose count falls below the cutoff;
# skipped when nothing lies below it instead of raising IndexError.
below_cutoff = np.flatnonzero(counts_desc < cutoff)
if below_cutoff.size:
    ax.axhline(rank[below_cutoff[0]], color="grey", linestyle="--")
# The y axis is the position in the descending sort, not a frequency.
ax.set(xscale="log", yscale="log", xlabel="UMI counts", ylabel="Cell rank")

# Attach a narrow colour bar to the right of the knee plot.
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
plt.colorbar(im, cax=cax, label="Entropy")

fig.savefig("figures/mx_filter.png", dpi=300, bbox_inches="tight")
# plt.show() instead of fig.show(): the latter warns under the inline (non-GUI) backend.
plt.show()