# Circular Otsu

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact
from skimage.filters import threshold_otsu


In [2]:
def calc_weighted_variance(x, h):
    """Return the weight-scaled spread ``sum(h * (x - mu)**2)`` of one class.

    ``mu`` is the ``h``-weighted mean of ``x``. The result equals the class
    variance multiplied by the class weight ``sum(h)``. Summing this quantity
    over classes gives Otsu's within-class criterion.

    Returns 0 when the total weight is zero (empty or all-zero class), so no
    mean has to be formed.
    """
    total_weight = h.sum()
    if total_weight == 0:
        return 0
    centre = (x * h).sum() / total_weight
    return ((x - centre) ** 2 * h).sum()


In [3]:
def otsu(x, h):
    """Linear (non-circular) Otsu threshold by exhaustive search.

    For each split index ``k`` in ``range(len(x))``, class 1 is ``x[:k]`` and
    class 2 is ``x[k:]``. The score is the sum of both classes'
    ``calc_weighted_variance``. Returns ``x`` at the first index of class 2
    for the first lowest-scoring split.

    ``9999`` / ``1e30`` act as "nothing found yet" sentinels. With an empty
    ``x`` (or no score below 1e30) the lookup ``x[9999]`` is attempted.
    """
    scores = [
        calc_weighted_variance(x[:k], h[:k]) + calc_weighted_variance(x[k:], h[k:])
        for k in range(len(x))
    ]

    best_k, best_score = 9999, 1e30
    for k, score in enumerate(scores):
        # strict '<' keeps the earliest split on ties
        if score < best_score:
            best_k, best_score = k, score
    return x[best_k]

def circular_otsu_naive(x, h):
    """Brute-force circular Otsu thresholding of a periodic histogram.

    The histogram ``h`` sampled at positions ``x`` is treated as circular
    (the last bin is adjacent to the first). It is cut into two contiguous
    arcs:
    - class 1 holds ``len(x) // 2`` bins;
    - class 2 holds all remaining bins.

    Every distinct rotation of that split is scored with the sum of both
    classes' ``calc_weighted_variance``. The first lowest score wins.

    Parameters
    ----------
    x : 1-D array
        Bin positions. They are presumably uniformly spaced: the wrap-around
        period is taken as ``x[-1] - x[0]`` plus the first step ``x[1] - x[0]``.
    h : 1-D array
        Histogram weights, same length as ``x``.

    Returns
    -------
    tuple
        ``(t1, t2)``: the ``x`` positions of the first bin of class 1 and of
        class 2 for the best split.

    Raises
    ------
    ValueError
        If fewer than two bins are given.
    """
    n = len(x)
    if n < 2:
        raise ValueError("circular Otsu needs at least two histogram bins")
    half = n // 2

    # Unroll the circle once. The second copy is shifted by one full period,
    # so an arc that wraps past the last bin is still a contiguous slice with
    # increasing positions.
    period = x[-1] + x[1] - 2 * x[0]
    xe = np.hstack([x, x + period])
    he = np.tile(h, 2)

    # Even n: starting class 1 at i or at i + half gives the same partition
    # with the classes swapped, so half of the rotations suffice.
    # Odd n: the arcs differ in size, so every rotation is a distinct split.
    n_starts = half if n % 2 == 0 else n

    i_min = 0
    sigma_min = np.inf
    for i in range(n_starts):
        sigma_1 = calc_weighted_variance(xe[i:i + half], he[i:i + half])
        # Class 2 ends at i + n so it contains every bin not in class 1.
        # For odd n it is one bin longer than class 1; ending it at
        # 2 * half + i would silently drop a bin.
        sigma_2 = calc_weighted_variance(xe[i + half:i + n], he[i + half:i + n])
        sigma = sigma_1 + sigma_2

        if sigma < sigma_min:
            sigma_min = sigma
            i_min = i

    # Report thresholds in the original (un-shifted) coordinates.
    return x[i_min], x[(i_min + half) % n]

def circular_otsu_less(x, h):
    """Circular Otsu thresholding that skips candidate splits early.

    It scans the same splits as an exhaustive search:
    - class 1 is an arc of ``len(x) // 2`` bins;
    - class 2 is the remaining arc;
    - every distinct rotation of the circular histogram is tried.

    Each class term ``calc_weighted_variance`` is a sum of
    ``h * (x - mean)**2``, which is non-negative for non-negative weights.
    So once one class alone reaches the best total found so far, the split
    cannot win and the other class is not evaluated.

    Parameters
    ----------
    x : 1-D array
        Bin positions. They are presumably uniformly spaced: the wrap-around
        period is taken as ``x[-1] - x[0]`` plus the first step ``x[1] - x[0]``.
    h : 1-D array
        Histogram weights, same length as ``x``.

    Returns
    -------
    tuple
        ``(t1, t2)``: the ``x`` positions of the first bin of class 1 and of
        class 2 for the best split.

    Raises
    ------
    ValueError
        If fewer than two bins are given.
    """
    n = len(x)
    if n < 2:
        raise ValueError("circular Otsu needs at least two histogram bins")
    half = n // 2

    # Unroll the circle once. The second copy is shifted by one full period,
    # so wrapping arcs are contiguous slices.
    period = x[-1] + x[1] - 2 * x[0]
    xe = np.hstack([x, x + period])
    he = np.tile(h, 2)

    # Even n: rotations i and i + half are the same split with the classes
    # swapped. Odd n: all n rotations are distinct.
    n_starts = half if n % 2 == 0 else n

    i_min = 0
    sigma_min = np.inf
    # Class terms from the previous rotation. They only decide which class
    # is evaluated first: the one that was larger last time, presumably
    # because it is the more likely to exceed sigma_min on its own.
    sigma_1 = 0
    sigma_2 = 0
    for i in range(n_starts):
        lo_1, hi_1 = i, i + half
        # Class 2 takes every bin not in class 1 (one extra bin for odd n).
        lo_2, hi_2 = i + half, i + n

        if sigma_1 > sigma_2:
            sigma_1 = calc_weighted_variance(xe[lo_1:hi_1], he[lo_1:hi_1])
            if sigma_1 >= sigma_min:
                continue
            sigma_2 = calc_weighted_variance(xe[lo_2:hi_2], he[lo_2:hi_2])
        else:
            sigma_2 = calc_weighted_variance(xe[lo_2:hi_2], he[lo_2:hi_2])
            if sigma_2 >= sigma_min:
                continue
            sigma_1 = calc_weighted_variance(xe[lo_1:hi_1], he[lo_1:hi_1])

        sigma = sigma_1 + sigma_2
        if sigma < sigma_min:
            sigma_min = sigma
            i_min = i

    # Report thresholds in the original (un-shifted) coordinates.
    return x[i_min], x[(i_min + half) % n]


In [4]:
def circular_otsu_updater(x, h):
    """Circular Otsu thresholding with O(1) class weight/mean lookups.

    The histogram ``h`` sampled at positions ``x`` is treated as circular.
    It is cut into two contiguous arcs:
    - class 1 holds ``len(x) // 2`` bins;
    - class 2 holds all remaining bins.

    Every distinct rotation is scored with the sum over both classes of
    ``sum(h * (x - mean)**2)``. A class with zero total weight scores 0.
    The first lowest score wins.

    Class weights and first moments come from float64 prefix sums. Each
    class's weight and mean is therefore a difference of two entries rather
    than a running float32 update. An all-zero class gets exactly zero weight
    instead of a drifting value or a 0/0 NaN.

    Parameters
    ----------
    x : 1-D array
        Bin positions. They are presumably uniformly spaced: the wrap-around
        period is taken as ``x[-1] - x[0]`` plus the first step ``x[1] - x[0]``.
    h : 1-D array
        Histogram weights, same length as ``x``.

    Returns
    -------
    tuple
        ``(t1, t2)``: the ``x`` positions of the first bin of class 1 and of
        class 2 for the best split.

    Raises
    ------
    ValueError
        If fewer than two bins are given.
    """
    n = len(x)
    if n < 2:
        raise ValueError("circular Otsu needs at least two histogram bins")
    ww = n // 2

    # Unroll the circle once. The second copy is shifted by one full period,
    # so every arc (including wrapping ones) is a contiguous slice.
    period = x[-1] + x[1] - 2 * x[0]
    xe = np.hstack([x, x + period])
    he = np.tile(h, 2)

    # cum_w[k] = sum(he[:k]) and cum_m[k] = sum(xe[:k] * he[:k]), in float64.
    cum_w = np.concatenate(([0.0], np.cumsum(he, dtype=np.float64)))
    cum_m = np.concatenate(([0.0], np.cumsum(xe.astype(np.float64) * he)))

    def class_sigma(lo, hi):
        """Return sum(h * (x - mean)**2) over xe[lo:hi], or 0 if the weight is 0."""
        omega = cum_w[hi] - cum_w[lo]
        if omega == 0:
            return 0.0
        mean = (cum_m[hi] - cum_m[lo]) / omega
        return float((np.square(xe[lo:hi] - mean) * he[lo:hi]).sum())

    # Even n: rotations s and s + ww are the same split with classes swapped.
    # Odd n: the arcs differ in size, so all n rotations are distinct.
    n_starts = ww if n % 2 == 0 else n

    best_start = 0
    best_sigma = np.inf
    for start in range(n_starts):
        sigma = class_sigma(start, start + ww) + class_sigma(start + ww, start + n)
        if sigma < best_sigma:
            best_sigma = sigma
            best_start = start

    # Report thresholds in the original (un-shifted) coordinates.
    return x[best_start], x[(best_start + ww) % n]


In [6]:
import time


def gaussian(x, x0, a, s):
    """Unnormalised Gaussian bump: amplitude ``a``, centre ``x0``, width ``s``."""
    return a * np.exp(-np.square(x - x0) / (2 * s * s))


# Dropdown choices for the interactive demo: display name -> implementation.
functions = {
    "naive": circular_otsu_naive,
    "less": circular_otsu_less,
    "updater": circular_otsu_updater,
}
# Length of one period of the circular axis; the demo samples [-W, 2W)
# and keeps the central period [0, W).
W = 10

@interact(
    p1=(0.1, 0.9, 0.01), p2=(0.1, 0.9, 0.01), noise=(0, 0.1, 0.01),
    roll=(-0.65, 0.65, 0.05), N=(3, 500, 11),
    impl=functions.keys())
def aux(
    p1=0.4, p2=0.54, N=256, roll=0, noise=0.01,
    impl="naive",
):
    """Plot a synthetic periodic histogram with circular and linear Otsu thresholds.

    Two Gaussian peaks (centres ``p1*W`` and ``p2*W``, both shifted by
    ``roll*W``) plus uniform noise of amplitude ``noise`` are sampled at ``N``
    points per period. The tails outside the central period are folded back
    onto it. The plot shows:
    - the circular thresholds returned by ``functions[impl]`` (red);
    - skimage's linear ``threshold_otsu`` (blue, dotted);
    - the wall-clock time of the circular call, in the title.
    """
    # Sample three consecutive periods so tails that leave [0, W) can be
    # wrapped back onto it below.
    x = np.linspace(-W, 2*W, 3*N, endpoint=False, dtype=np.float32)
    h = noise * np.random.random(x.shape).astype(np.float32)
    h += gaussian(x, p1*W + roll*W, 1, 0.5)
    h += gaussian(x, p2*W + roll*W, 0.5, 0.5)
    # Fold the left and right periods onto the central one -> periodic histogram.
    h[N:2*N] += h[0:N]
    h[N:2*N] += h[2*N:3*N]

    # NOTE(review): positions are shifted by half a sample step before being
    # passed to skimage as bin centres; presumably so the linear threshold
    # lands on a bin boundary like the circular ones. Confirm.
    bins = x + 0.5*(x[1]-x[0])
    bins = bins[N:2*N]
    x = x[N:2*N]
    h = h[N:2*N]

    _, ax = plt.subplots(figsize=(10, 5))
    ax.plot(x, h, c="#000", label="histogram")

    start = time.perf_counter()
    co_th = functions[impl](x, h)
    end = time.perf_counter()

    for k, th in enumerate(co_th):
        # Label only the first threshold line so the legend lists the method
        # once; matplotlib omits labels starting with "_" from the legend.
        label = impl + " circular otsu" if k == 0 else "_nolegend_"
        ax.axvline(th, c="#f00", label=label, lw=2, alpha=0.3)

    ax.axvline(threshold_otsu(hist=(h, bins)), c="#00f", label="linear Otsu", ls="dotted", lw=2, alpha=0.3)
    ax.set_title(f"calculation time: {1_000*(end-start):.3f} ms")
    ax.legend(loc="lower right")


interactive(children=(FloatSlider(value=0.4, description='p1', max=0.9, min=0.1, step=0.01), FloatSlider(value…