In [None]:
import numpy as np
import pandas as pd
import nd2reader
import matplotlib.pyplot as plt
import numba
import dask
import dask.array as da
from cytoolz import partial, compose, juxt
import numpy_indexed
import segmentation

In [None]:
# Open the ND2 acquisition.
# NOTE(review): the path is absolute and tied to one filesystem; the notebook
# cannot run elsewhere without editing it.
nd2 = nd2reader.ND2Reader(
    "/n/scratch2/jqs1/190922/190922_photobleaching_greens/GFP_photobleaching_100pct_100ms_0001.nd2"
)
# Read time indices t = 0..9 at v=0 and stack them along a new leading axis,
# so img_stack[t] is the frame read for time index t.
# (v is presumably the field-of-view index -- confirm against nd2reader.)
img_stack = np.stack([nd2.get_frame_2D(v=0, t=t) for t in range(10)])
# First frame of the stack, used below for display.
img = img_stack[0]

In [None]:
# Label image computed from the first frame (the same pixels as `img`).
# The functions below call .max() and .ravel() on it and compare it against
# integer label ids, treating 0 as the label that skip0 leaves out.
labels = segmentation.segment(img_stack[0])

In [None]:
# Flatten the image, index it with the identity permutation 0..size-1, and
# reshape back: the pixels shown are the same as `img`. This is the same
# ravel -> integer-index pattern that mean_split/mean_split2 apply with `sorter`.
plt.imshow(img.ravel()[np.arange(img.size)].reshape(img.shape))

In [None]:
def mean_naive(labels, img_stack, skip0=True):
    """Baseline grouped mean: one boolean mask per label, per frame.

    Parameters
    ----------
    labels : array of int
        Label image; ``labels == k`` selects the pixels of label ``k``.
    img_stack : sequence of arrays
        Frames to measure. Each frame is indexed with a mask built from
        ``labels``, so it must have the same shape as ``labels``.
    skip0 : bool, default True
        Start at label 1 instead of label 0.

    Returns
    -------
    list of numpy.ndarray
        One 1-D array per frame. Entry ``i`` is the mean of the pixels whose
        label is ``first + i``, where ``first`` is 1 if ``skip0`` else 0.
        A label in ``first..labels.max()`` with no pixels gives NaN.
        (Previously the means were computed and thrown away, so the function
        returned None and could not be checked against the other variants.)
    """
    first_label = 1 if skip0 else 0
    label_ids = np.arange(first_label, labels.max() + 1)
    # The mask is rebuilt for every (frame, label) pair, as in the naive
    # formulation; hoisting it would change what this baseline measures.
    return [
        np.array([frame[labels == label_id].mean() for label_id in label_ids])
        for frame in img_stack
    ]


def mean_pandas(labels, img_stack):
    """Grouped mean per frame via ``pandas`` groupby.

    Parameters
    ----------
    labels : array of int
        Label image; flattened and used as the grouping column.
    img_stack : sequence of arrays
        Frames to measure; each is flattened, so it must have as many
        elements as ``labels``.

    Returns
    -------
    list of pandas.DataFrame
        One frame per image, indexed by label (label 0 included -- this
        variant has no ``skip0``), with the single column
        ``("value", "mean")`` produced by ``.agg(["mean"])``.
        (Previously the result was discarded and the function returned None.)
    """
    flat_labels = labels.ravel()
    return [
        pd.DataFrame({"label": flat_labels, "value": frame.ravel()})
        .groupby("label")
        .agg(["mean"])
        for frame in img_stack
    ]


def mean_npi(labels, img_stack):
    """Grouped mean per frame via ``numpy_indexed.group_by``.

    Parameters
    ----------
    labels : array of int
        Label image; flattened and used as the grouping keys.
    img_stack : sequence of arrays
        Frames to measure; each is flattened, so it must have as many
        elements as ``labels``.

    Returns
    -------
    list
        One entry per frame: the value ``numpy_indexed.group_by`` returns for
        that frame when called with ``reduction=np.mean``. Label 0 is not
        filtered out (this variant has no ``skip0``).
    """
    flat_labels = labels.ravel()
    # Group the *current frame*. The earlier version passed img_stack.ravel()
    # (the whole stack, one frame-count times longer than the keys) on every
    # iteration and never used the loop variable, so it was not measuring the
    # per-frame work the other variants do.
    return [
        numpy_indexed.group_by(flat_labels, frame.ravel(), reduction=np.mean)
        for frame in img_stack
    ]


def mean_npi2(labels, img_stack, skip0=True):
    """Grouped mean per frame via ``numpy_indexed.GroupBy``, rebuilt per frame.

    Parameters
    ----------
    labels : array of int
        Label image; flattened and used as the grouping keys.
    img_stack : sequence of arrays
        Frames to measure; each is flattened before being split by label.
    skip0 : bool, default True
        Leave the group whose key is 0 out of the result.

    Returns
    -------
    list of list of (key, mean)
        One inner list per frame, in the order of ``GroupBy.unique``.
        (Previously only a local ``ret`` was built, overwritten each frame
        and never returned.)
    """
    per_frame = []
    for frame in img_stack:
        # The grouping is constructed inside the loop, once per frame;
        # mean_npi3 is the variant that constructs it a single time.
        grouping = numpy_indexed.GroupBy(labels.ravel())
        groups = grouping.split(frame.ravel())
        per_frame.append(
            [
                (key, np.mean(group))
                for key, group in zip(grouping.unique, groups)
                if key != 0 or not skip0
            ]
        )
    return per_frame


def mean_npi3(labels, img_stack, skip0=True):
    """Grouped mean per frame via one shared ``numpy_indexed.GroupBy``.

    Parameters
    ----------
    labels : array of int
        Label image; flattened and used as the grouping keys.
    img_stack : sequence of arrays
        Frames to measure; each is flattened before being split by label.
    skip0 : bool, default True
        Leave the group whose key is 0 out of the result.

    Returns
    -------
    list of list of (key, mean)
        One inner list per frame, in the order of ``GroupBy.unique``.
        (Previously only a local ``ret`` was built, overwritten each frame
        and never returned.)
    """
    # The grouping depends only on the labels, so it is built once and reused
    # for every frame.
    grouping = numpy_indexed.GroupBy(labels.ravel())
    per_frame = []
    for frame in img_stack:
        groups = grouping.split(frame.ravel())
        per_frame.append(
            [
                (key, np.mean(group))
                for key, group in zip(grouping.unique, groups)
                if key != 0 or not skip0
            ]
        )
    return per_frame


def mean_split(labels, img_stack, skip0=True):
    """Grouped mean per frame using a stable argsort and ``np.split``.

    The labels are sorted once; every frame is then reordered with the same
    permutation and cut at the positions where the sorted label changes.

    Parameters
    ----------
    labels : array of int
        Label image; flattened and used as the grouping keys. Must be
        non-empty.
    img_stack : sequence of arrays
        Frames to measure; each is flattened, so it must have as many
        elements as ``labels``.
    skip0 : bool, default True
        Leave the group whose key is 0 out of the result.

    Returns
    -------
    list of list of (key, mean)
        One inner list per frame, keys in ascending order.
        (Previously only a local ``ret`` was built, overwritten each frame
        and never returned.)
    """
    keys = labels.ravel()
    sorter = np.argsort(keys, kind="mergesort")
    sorted_keys = keys[sorter]
    # Positions in the sorted order where a new label starts (excluding 0).
    boundaries = np.flatnonzero(sorted_keys[:-1] != sorted_keys[1:]) + 1
    # First element of each run is that run's label.
    unique = sorted_keys[np.concatenate(([0], boundaries))]

    per_frame = []
    for frame in img_stack:
        groups = np.split(frame.ravel()[sorter], boundaries, axis=0)
        per_frame.append(
            [
                (key, np.mean(group))
                for key, group in zip(unique, groups)
                if key != 0 or not skip0
            ]
        )
    return per_frame


def mean_split2(labels, img_stack, skip0=True):
    """Grouped mean for all frames at once using argsort and ``np.split``.

    Same grouping as ``mean_split``, but the stack is reshaped to
    ``(frames, pixels)``, reordered along the pixel axis in one indexing
    operation, and split along that axis, so no Python loop over frames is
    needed.

    Parameters
    ----------
    labels : array of int
        Label image; flattened and used as the grouping keys. Must be
        non-empty.
    img_stack : numpy.ndarray
        Array whose first axis indexes frames; the remaining axes are
        flattened and must hold as many elements as ``labels``.
    skip0 : bool, default True
        Leave the group whose key is 0 out of the result.

    Returns
    -------
    list of (key, numpy.ndarray)
        One pair per label, keys in ascending order; the array holds that
        label's mean in each frame (length ``img_stack.shape[0]``).
        (Previously the list was built and discarded.)
    """
    keys = labels.ravel()
    sorter = np.argsort(keys, kind="mergesort")
    sorted_keys = keys[sorter]
    # Positions in the sorted order where a new label starts (excluding 0).
    boundaries = np.flatnonzero(sorted_keys[:-1] != sorted_keys[1:]) + 1
    # First element of each run is that run's label.
    unique = sorted_keys[np.concatenate(([0], boundaries))]

    # (frames, pixels), with pixels permuted into label-sorted order.
    values = img_stack.reshape((img_stack.shape[0], -1))[:, sorter]
    groups = np.split(values, boundaries, axis=1)
    return [
        (key, np.mean(group, axis=1))
        for key, group in zip(unique, groups)
        if key != 0 or not skip0
    ]


from numba import prange


@numba.jit(nopython=True, parallel=True)
def mean_vectorized(labels, img_stack, skip0=True):
    """Per-label mean of every frame, compiled with numba.

    Parameters
    ----------
    labels : 2-D array of int
        Label image, indexed as ``labels[y, x]``; values are used directly
        as row indices of the result.
    img_stack : 3-D array
        Frames, indexed as ``img_stack[frame, y, x]``.
    skip0 : bool, default True
        Do not accumulate pixels whose label is 0.

    Returns
    -------
    2-D float array, shape ``(labels.max() + 1, n_frames)``
        ``out[k, t]`` is the mean of frame ``t`` over the pixels labelled
        ``k``. A label with no counted pixels (including label 0 when
        ``skip0`` is true) has count 0, so its row is 0/0.

    Notes
    -----
    Two defects of the earlier version are fixed here:

    * It ran ``sums[label] += ...`` and ``counts[label, 0] += 1`` inside a
      ``prange`` over image rows. Different rows can contain the same label,
      so several threads could update the same element at once (a data
      race). The parallel loop now runs over frames: each iteration writes
      only column ``t`` of ``sums``, so writes never overlap.
    * ``sums`` used ``img_stack.dtype``; for an integer image type the running
      sum can wrap around. Accumulation is now done in float64.
    """
    max_label = labels.max()
    n_frames, height, width = img_stack.shape

    # Pixel counts per label are identical for every frame: count them once,
    # in an ordinary serial loop.
    counts = np.zeros((max_label + 1, 1), dtype=np.uint64)
    for row in range(height):
        for col in range(width):
            lab = labels[row, col]
            if lab == 0 and skip0:
                continue
            counts[lab, 0] += 1

    sums = np.zeros((max_label + 1, n_frames), dtype=np.float64)
    for t in prange(n_frames):
        for y in range(height):
            for x in range(width):
                label = labels[y, x]
                if label == 0 and skip0:
                    continue
                sums[label, t] += img_stack[t, y, x]
    return sums / counts


from numba import prange


@numba.jit(nopython=True, parallel=True)
def mean_vectorized2(labels, img_stack, skip0=True):
    """Per-label mean of every frame for a frame-last stack, compiled with numba.

    Parameters
    ----------
    labels : 2-D array of int
        Label image, indexed as ``labels[y, x]``; values are used directly
        as row indices of the result.
    img_stack : 3-D array
        Frames stored with the frame axis last, indexed as
        ``img_stack[y, x, frame]``.
    skip0 : bool, default True
        Do not accumulate pixels whose label is 0.

    Returns
    -------
    2-D float array, shape ``(labels.max() + 1, n_frames)``
        ``out[k, t]`` is the mean of frame ``t`` over the pixels labelled
        ``k``. A label with no counted pixels (including label 0 when
        ``skip0`` is true) has count 0, so its row is 0/0.

    Notes
    -----
    Two defects of the earlier version are fixed here:

    * It ran ``sums[label] += ...`` and ``counts[label, 0] += 1`` inside a
      ``prange`` over image rows. Different rows can contain the same label,
      so several threads could update the same element at once (a data
      race). The parallel loop now runs over frames: each iteration writes
      only column ``t`` of ``sums``, so writes never overlap.
    * ``sums`` used ``img_stack.dtype``; for an integer image type the running
      sum can wrap around. Accumulation is now done in float64.

    NOTE(review): with the frame axis last, each parallel iteration reads one
    element per pixel rather than a whole frame vector -- re-time this variant
    before comparing it with ``mean_vectorized``.
    """
    max_label = labels.max()
    height, width, n_frames = img_stack.shape

    # Pixel counts per label are identical for every frame: count them once,
    # in an ordinary serial loop.
    counts = np.zeros((max_label + 1, 1), dtype=np.uint64)
    for row in range(height):
        for col in range(width):
            lab = labels[row, col]
            if lab == 0 and skip0:
                continue
            counts[lab, 0] += 1

    sums = np.zeros((max_label + 1, n_frames), dtype=np.float64)
    for t in prange(n_frames):
        for y in range(height):
            for x in range(width):
                label = labels[y, x]
                if label == 0 and skip0:
                    continue
                sums[label, t] += img_stack[y, x, t]
    return sums / counts


# def mean_npi2(labels, img_stack):
#    numpy_indexed.group_by(labels.ravel(), img_stack.ravel(), reduction=np.mean)

In [None]:
# Timings on the full stack built in the loading cell (10 frames).
# mean_pandas and mean_npi take no skip0 argument, so label 0 is included.
%timeit mean_pandas(labels, img_stack)

In [None]:
%timeit mean_npi(labels, img_stack)

In [None]:
# Called with the default skip0=True.
%timeit mean_npi2(labels, img_stack)

In [None]:
# Called with the default skip0=True.
%timeit mean_npi3(labels, img_stack)

In [None]:
# Called with the default skip0=True.
%timeit mean_split(labels, img_stack)

In [None]:
# NOTE(review): skip0=False here, unlike the three calls above, so label 0
# is included in this timing.
%timeit mean_split2(labels, img_stack, skip0=False)

In [None]:
# skip0=False: label 0 is accumulated as well.
%timeit mean_vectorized(labels, img_stack, skip0=False)

In [None]:
# Move the frame axis last and copy into a contiguous array: the
# (y, x, frame) layout that mean_vectorized2 expects.
img_stack_T = np.ascontiguousarray(np.moveaxis(img_stack, 0, -1))

In [None]:
# skip0=False: label 0 is accumulated as well.
%timeit mean_vectorized2(labels, img_stack_T, skip0=False)

In [None]:
# NOTE(review): this import sits mid-notebook rather than in the import cell
# at the top.
import numpy_groupies as npg

# Same grouped mean through numpy_groupies, on the stack reshaped to
# (frames, pixels) and grouped along axis 1 by the flattened labels.
%timeit npg.aggregate(labels.ravel(), img_stack.reshape((img_stack.shape[0],-1)), func='mean', axis=1)

In [None]:
# Profile (rather than time) the mean_npi3 variant.
%prun mean_npi3(labels, img_stack)

In [None]:
# Display the GroupBy class used by mean_npi2/mean_npi3. Only the
# `numpy_indexed` module is imported in this notebook, so the class has to be
# reached through it; no cell binds a bare `GroupBy`, which made this cell
# raise NameError on a fresh kernel.
numpy_indexed.GroupBy

In [None]:
# The naive baseline is timed on the first two frames only (img_stack[:2]),
# so this figure is not directly comparable with the timings taken on the
# full 10-frame stack.
%timeit mean_naive(labels, img_stack[:2])