# Python segmentation benchmark

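Generate synthetic two-state spike-count data with known changepoints, then benchmark how quickly and how accurately the PELT segmentation algorithm recovers those changepoints.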

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats
%matplotlib inline

# Fix NumPy's global RNG so every random draw below (np.random.randn and the
# scipy.stats .rvs calls, which use it when no random_state is passed) repeats.
np.random.seed(seed=12348)

# Set up constants

In [2]:
# Timing parameters

T = 500  # total duration (in s)
dt = 0.05  # bin size (in s)
# Number of time bins. Round T / dt before taking the ceiling so floating-point
# noise (e.g. 1.1 / 0.1 == 11.000000000000002) cannot add a spurious extra bin;
# genuinely fractional ratios (e.g. 10.5) still round up as before.
Nt = int(np.ceil(np.round(T / dt, 9)))

mu = 0.5  # mean state duration (in s)
sig = 0.05  # standard deviation of state duration (in s)

In [3]:
# Firing rates: state 0 fires at the baseline rate `lam`, state 1 at nu * lam,
# so `nu` is a dimensionless gain rather than a rate.
lam, nu = 1., 500.  # baseline rate (Hz), state-1 rate multiplier (unitless)

# Generate alternating state intervals

In [4]:
# Draw Gaussian state durations and accumulate them into changepoint times (s).
# Nt draws is far more than needed (Nt * mu >> T), but the number of draws is
# kept fixed so the seeded random stream, and every later draw, is unchanged.
durations = mu + sig * np.random.randn(Nt)
# Gaussian durations can in principle be <= 0, which would make the cumulative
# sum non-monotonic and the truncation below meaningless.
assert (durations > 0).all(), "non-positive state duration; reduce sig relative to mu"
changepoints = np.cumsum(durations)
# Keep only changepoints inside the recording (<= T). Fail with a clear message
# instead of an opaque IndexError if the draws never reach T.
assert changepoints[-1] > T, "durations do not cover the recording; draw more"
# First index whose changepoint exceeds T (same as argwhere(cp > T)[0, 0] for
# a sorted array).
maxind = np.searchsorted(changepoints, T, side='right')
changepoints = changepoints[:maxind]

In [5]:
# True changepoints expressed in (fractional) units of time bins.
np.divide(changepoints, dt)

array([   11.3095624 ,    19.66190965,    29.58543266, ...,  9972.56112388,
        9983.99454471,  9993.95774758])

In [6]:
# Time axis (one entry per bin, starting at 0), latent state per bin, and
# firing rate per bin. Build the axis from Nt instead of np.arange(0, T, dt):
# with a float step, arange's length is rounding-sensitive, and taxis must line
# up bin-for-bin with `states` and `rates`.
taxis = dt * np.arange(Nt)
states = np.zeros(Nt)  # state 0 everywhere until state-1 intervals are filled in
rates = lam * np.ones(Nt)  # baseline rate everywhere until state-1 intervals are filled in

In [7]:
# Alternate states: state 1 starts at each even-indexed changepoint and ends at
# the next (odd-indexed) one; an unmatched final changepoint stays in state 1
# until the end of the recording (T). Bins exactly on a changepoint keep state 0.
# `range` (not Python 2's `xrange`) so the cell runs under Python 2 and 3.
for idx in range(0, len(changepoints), 2):
    if idx == len(changepoints) - 1:
        upper = T
    else:
        upper = changepoints[idx + 1]
    in_state_1 = (changepoints[idx] < taxis) & (taxis < upper)
    states[in_state_1] = 1
    rates[in_state_1] = nu * lam

# Simulate Poisson counts per bin

In [8]:
# One Poisson draw per bin with mean rate * bin width. (`mu` here is scipy's
# name for the Poisson mean, unrelated to the notebook's duration `mu`.)
counts = stats.poisson.rvs(mu=rates * dt)

# Calculate log likelihoods

In [9]:
# psi[t, k] = log P(counts[t] | state k): column 0 uses the baseline mean
# lam * dt, column 1 the state-1 mean nu * lam * dt. Broadcasting the (Nt, 1)
# counts against the two per-state means yields the (Nt, 2) table in one call.
psi = stats.poisson.logpmf(counts[:, np.newaxis], np.array([lam, nu * lam]) * dt)

# Run inference

In [10]:
from spiketopics import pelt

In [11]:
# prior parameters:
theta = 0.5  # unbiased z prior
alpha = 1  # exp(-m) prior on changepoint number

cplist = pelt.find_changepoints(psi, theta, alpha)

# now time without jit overhead
%time cplist = pelt.find_changepoints(psi, theta, alpha)

#%prun cplist = pelt.find_changepoints(psi, theta, alpha)

CPU times: user 63.5 ms, sys: 0 ns, total: 63.5 ms
Wall time: 63.5 ms


In [12]:
# State probabilities given the PELT changepoints (the output format is defined
# by spiketopics.pelt.calc_state_probs and is not inspected in this notebook).
inferred = pelt.calc_state_probs(
    psi,
    theta,
    cplist
)

In [13]:
# True changepoints (in fractional bin units) next to the bin indices found by
# PELT; as the output shows, the inferred list begins with an extra 0.
(np.divide(changepoints, dt), np.asarray(cplist))

(array([   11.3095624 ,    19.66190965,    29.58543266, ...,  9972.56112388,
         9983.99454471,  9993.95774758]),
 array([   0,   11,   19, ..., 9972, 9983, 9993]))