In [1]:
import os
os.chdir(os.path.expanduser("~/Projects/zimmer"))
import copy
import pickle

import numpy as np
np.random.seed(1234)

import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
plt.ion()
%matplotlib inline

from tqdm import tqdm
from functools import partial

from zimmer.io import load_kato_data
import zimmer.plotting as zplt
from zimmer.observations import HierarchicalIndependentAutoRegressiveObservations

from ssm.models import HMM
from ssm.core import _HMM
from ssm.init_state_distns import InitialStateDistribution
from ssm.transitions import RecurrentTransitions, StationaryTransitions
from ssm.util import find_permutation, compute_state_overlap
from ssm.preprocessing import pca_with_imputation, trend_filter, standardize

In [2]:
# Load the named neurons for every worm, then detrend each recording
ys, ms, z_trues, z_true_key, neuron_names = load_kato_data(include_unnamed=False, signal="dff")
ys = list(map(trend_filter, ys))

# Dataset dimensions used throughout the notebook
W = len(ys)                    # number of worms (recordings)
N = ys[0].shape[1]             # number of neurons (columns of each recording)
Ts = [y.shape[0] for y in ys]  # number of time steps in each recording
K_true = len(z_true_key)       # number of reference state labels in z_true_key

Only including named neurons.
59 neurons across all 5 worms


In [3]:
# Run PCA (pca_with_imputation, which also takes the masks `ms`) to get a 4-dimensional
# projection of the data; the PC plots further down show PC 1 against PCs 2, 3 and 4.
from scipy.ndimage import gaussian_filter1d
from ssm.preprocessing import pca_with_imputation
num_pcs = 4        # number of principal components to keep
pc_smoothing = 1   # std. dev. of the Gaussian smoothing kernel along axis 0 (presumably time steps)
pca, xs = pca_with_imputation(num_pcs, ys, ms)
xs = [gaussian_filter1d(x, pc_smoothing, axis=0) for x in xs]
# Symmetric axis limit shared by the 2D PC plots: 10% beyond the largest |PC| value
lim = 1.1 * np.abs(np.vstack(xs)).max()

# Split the data into chunks and assign each chunk to the train, validation, or test set

In [4]:
# Chunking / split configuration
chunk = 250        # samples per chunk
train_frac = 0.7   # fraction of each worm's chunks used for training
val_frac = 0.15    # fraction used for validation; the remainder is the test set

all_ys = []
all_ms = []
all_tags = []
all_z_trues = []
all_choices = []
for tag, (y, m, ztr) in enumerate(zip(ys, ms, z_trues)):
    T = y.shape[0]

    # Cut this worm's recording into consecutive chunks (the last one may be shorter)
    starts = range(0, T, chunk)
    for start in starts:
        stop = min(start + chunk, T)
        all_ys.append(y[start:stop])
        all_ms.append(m[start:stop])
        all_z_trues.append(ztr[start:stop])
        all_tags.append(tag)
    C = len(starts)

    # Split per worm so every worm contributes the same fractions to train (0),
    # validation (1) and test (2); then shuffle so a chunk's split is not tied to its
    # position in the recording. Exactly one permutation is drawn per worm.
    n_train = int(train_frac * C)
    n_trainval = int((train_frac + val_frac) * C)
    choices = np.full(C, 2, dtype=int)
    choices[:n_train] = 0
    choices[n_train:n_trainval] = 1
    choices = choices[np.random.permutation(C)]
    all_choices.append(choices)

all_choices = np.concatenate(all_choices)
assert len(all_choices) == len(all_ys), "each chunk needs exactly one split label"


def get(arr, chc):
    """Return the items of `arr` (one per chunk) whose chunk is in split `chc`.

    `chc` is 0 for train, 1 for validation and 2 for test.
    """
    return [x for x, c in zip(arr, all_choices) if c == chc]


train_ys = get(all_ys, 0)
train_ms = get(all_ms, 0)
train_zs = get(all_z_trues, 0)
train_tags = get(all_tags, 0)

val_ys = get(all_ys, 1)
val_ms = get(all_ms, 1)
val_zs = get(all_z_trues, 1)
val_tags = get(all_tags, 1)

test_ys = get(all_ys, 2)
test_ms = get(all_ms, 2)
test_zs = get(all_z_trues, 2)
test_tags = get(all_tags, 2)

In [5]:
# Worm index (tag) of every chunk in each split: train, validation, test
for split_tags in (train_tags, val_tags, test_tags):
    print(split_tags)

[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4]
[0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
[0, 0, 1, 1, 2, 2, 3, 3, 4, 4]


In [6]:
# Model-selection grid and hyperparameters shared by the fitting cells
Ks = np.arange(2, 21, 2, dtype=int)  # numbers of discrete states to try: 2, 4, ..., 20
M = 0       # passed as M to every HMM constructor (presumably the input dimension in ssm)
eta = 1e-4  # passed as `eta` to HierarchicalIndependentAutoRegressiveObservations

# Fit standard AR models to the training data

In [7]:
# Baseline: for every K, fit an AR-HMM whose dynamics are shared by all worms
# (no per-worm tags), then score it on the validation chunks.
standard_hmms, standard_train_lls, standard_val_lls = [], [], []
for K in Ks:
    print("Fitting model with {} states".format(K))
    standard_hmm = HMM(K, N, M, transitions="standard", observations="independent_ar")
    # fit returns the training log likelihood at each EM iteration
    lls = standard_hmm.fit(train_ys, masks=train_ms,
                           method="em", num_em_iters=50)
    standard_hmms.append(standard_hmm)
    standard_train_lls.append(lls)
    standard_val_lls.append(standard_hmm.log_probability(val_ys, masks=val_ms))

Fitting model with 2 states
Iteration 0.  LL: 884979.7
Iteration 1.  LL: 891258.0
Iteration 2.  LL: 892653.9
Iteration 3.  LL: 893595.6
Iteration 4.  LL: 894051.6
Iteration 5.  LL: 894252.2
Iteration 6.  LL: 894352.4
Iteration 7.  LL: 894399.6
Iteration 8.  LL: 894421.4
Iteration 9.  LL: 894431.1
Iteration 10.  LL: 894434.7
Iteration 11.  LL: 894435.9
Iteration 12.  LL: 894436.2
Iteration 13.  LL: 894436.4
Iteration 14.  LL: 894436.4
Iteration 15.  LL: 894436.4
Iteration 16.  LL: 894436.4
Iteration 17.  LL: 894436.4
Iteration 18.  LL: 894436.4
Iteration 19.  LL: 894436.4
Iteration 20.  LL: 894436.4
Iteration 21.  LL: 894436.4
Iteration 22.  LL: 894436.4
Iteration 23.  LL: 894436.4
Iteration 24.  LL: 894436.4
Iteration 25.  LL: 894436.4
Iteration 26.  LL: 894436.4
Iteration 27.  LL: 894436.4
Iteration 28.  LL: 894436.4
Iteration 29.  LL: 894436.4
Iteration 30.  LL: 894436.4
Iteration 31.  LL: 894436.4
Iteration 32.  LL: 894436.4
Iteration 33.  LL: 894436.4
Iteration 34.  LL: 894436.4
It

KeyboardInterrupt: 

In [None]:
# Model selection for the standard (non-hierarchical) AR-HMMs: EM training curves (left)
# and validation log likelihood versus number of states (right).
# Reads the standard_* results; `train_lls` / `val_lls` belong to the hierarchical fit.
fig, (ax_train, ax_val) = plt.subplots(1, 2, figsize=(12, 6))

for K, lls in zip(Ks, standard_train_lls):
    ax_train.plot(lls, label="K={}".format(K))
ax_train.set_xlabel("Iteration")
ax_train.set_ylabel("Train LL")
ax_train.set_title("Standard AR-HMM: training")
ax_train.legend(loc="lower right", ncol=4)

for K, val_ll in zip(Ks, standard_val_lls):
    ax_val.plot(K, val_ll, 'ko')
ax_val.set_xticks(Ks)
ax_val.set_xlabel("K")
ax_val.set_ylabel("Validation LL")
ax_val.set_title("Standard AR-HMM: validation");

# Fit hierarchical HMMs

In [None]:
def make_hierarchical_hmm(K, standard_hmm):
    """Build a hierarchical AR-HMM initialized from a fitted standard AR-HMM.

    The new model uses StationaryTransitions and
    HierarchicalIndependentAutoRegressiveObservations with W per-worm parameter
    sets. The initial-state and transition parameters, the shared AR parameters
    (As, Vs, bs), and every worm's own AR parameters (As, Vs, bs, inv_sigmas)
    are all copied from ``standard_hmm``. Before any fitting, the hierarchical
    model should therefore assign the same log probability to the training data
    as the standard model. This is checked on (train_ys, train_ms, train_tags).

    Relies on the notebook globals N, M, W, eta, train_ys, train_ms and train_tags.

    Parameters
    ----------
    K : int
        Number of discrete states; should match ``standard_hmm``.
    standard_hmm : HMM
        Fitted model with "independent_ar" observations to copy parameters from.

    Returns
    -------
    _HMM
        The initialized, not yet fitted, hierarchical model.

    Raises
    ------
    AssertionError
        If the copied model's training log probability differs from the
        standard model's.
    """
    init_state_distn = InitialStateDistribution(K, N, M)
    transition_distn = StationaryTransitions(K, N, M)
    observation_distn = HierarchicalIndependentAutoRegressiveObservations(K, N, W, M, eta=eta)
    hmm = _HMM(K, N, M, init_state_distn, transition_distn, observation_distn)

    # Start from the standard HMM's initial-state and transition parameters
    hmm.init_state_distn.log_pi0 = standard_hmm.init_state_distn.log_pi0.copy()
    hmm.transitions.log_Ps = standard_hmm.transitions.log_Ps.copy()

    # Shared (population-level) AR parameters = the standard HMM's parameters
    std_obs = standard_hmm.observations
    hmm.observations.shared_As = std_obs.As.copy()
    hmm.observations.shared_Vs = std_obs.Vs.copy()
    hmm.observations.shared_bs = std_obs.bs.copy()

    # Every worm starts from the same parameters; copy so no worm's parameters alias
    # the standard model's arrays
    for w in range(W):
        hmm.observations.As[w] = std_obs.As.copy()
        hmm.observations.Vs[w] = std_obs.Vs.copy()
        hmm.observations.bs[w] = std_obs.bs.copy()
        hmm.observations.inv_sigmas[w] = std_obs.inv_sigmas.copy()

    # Sanity check: the copied model must reproduce the standard model's training LL
    standard_ll = standard_hmm.log_probability(train_ys, masks=train_ms)
    hierarchical_ll = hmm.log_probability(train_ys, masks=train_ms, tags=train_tags)
    assert np.allclose(standard_ll, hierarchical_ll), (
        "Hierarchical HMM initialization does not reproduce the standard HMM: "
        "train LL {} (hierarchical) vs {} (standard)".format(hierarchical_ll, standard_ll))

    return hmm

In [None]:
# Fit a hierarchical AR-HMM for every K, warm-started from the matching standard AR-HMM.
# After the loop, `hmm` still refers to the last model fitted; the state-inference and
# simulation cells below use that model.
hmms, train_lls, val_lls = [], [], []
for K, standard_hmm in zip(Ks, standard_hmms):
    print("Fitting model with {} states".format(K))
    hmm = make_hierarchical_hmm(K, standard_hmm)
    # tags tell the model which worm each training chunk came from
    lls = hmm.fit(train_ys, masks=train_ms, tags=train_tags,
                  method="em", num_em_iters=50)
    hmms.append(hmm)
    train_lls.append(lls)
    val_lls.append(hmm.log_probability(val_ys, masks=val_ms, tags=val_tags))

In [None]:
# Model selection for the hierarchical AR-HMMs: EM training curves (left) and
# validation log likelihood versus number of states (right).
fig, (ax_train, ax_val) = plt.subplots(1, 2, figsize=(12, 6))

for K, lls in zip(Ks, train_lls):
    ax_train.plot(lls, label="K={}".format(K))
ax_train.set_xlabel("Iteration")
ax_train.set_ylabel("Train LL")
ax_train.legend(loc="lower right", ncol=4)

for K, val_ll in zip(Ks, val_lls):
    ax_val.plot(K, val_ll, 'ko')
ax_val.set_xticks(Ks)
ax_val.set_xlabel("K")
ax_val.set_ylabel("Validation LL")

# Infer discrete states for the entire dataset with the last fitted hierarchical model

In [None]:
# zs = [standard_hmm.most_likely_states(y, mask=m) for y, m in zip(ys, ms)]
# standard_hmm.permute(find_permutation(np.concatenate(zs), np.concatenate(z_trues)))
# zs = [standard_hmm.most_likely_states(y, mask=m) for y, m in zip(ys, ms)]

In [None]:
# Infer each worm's most likely discrete states over its full recording. Relabel the
# model's states to best match the reference labels `z_trues`, then re-infer so `zs`
# uses the new labelling.
zs = [hmm.most_likely_states(y, mask=m, tag=tag) for tag, (y, m) in enumerate(zip(ys, ms))]
# Pass the label counts explicitly. NOTE: when K2 is omitted, ssm's find_permutation
# presumably infers the model's state count from zs' largest label. That count is too
# small if the highest-numbered state is never visited, and the permutation would then
# not cover all hmm.K states.
perm = find_permutation(np.concatenate(z_trues), np.concatenate(zs), K1=K_true, K2=hmm.K)
hmm.permute(perm)
zs = [hmm.most_likely_states(y, mask=m, tag=tag) for tag, (y, m) in enumerate(zip(ys, ms))]

In [None]:
# One figure per worm: reference labels (top) vs. inferred states (bottom). Both use the
# shared state colormap with the same vmin/vmax, so equal label indices share a color.
for w in range(W):
    fig, (ax_true, ax_inf) = plt.subplots(2, 1, figsize=(12, 2))
    for ax, states, label in ((ax_true, z_trues[w], "$z_{true}$"), (ax_inf, zs[w], "$z$")):
        ax.imshow(states[None, :], aspect="auto", cmap=zplt.default_cmap,
                  vmin=0, vmax=len(zplt.default_colors) - 1)
        ax.set_ylabel(label)

In [None]:
# Plot how the inferred states `zs` overlap with the reference labels `z_trues`
# (plotting helper from zimmer.plotting).
zplt.plot_state_overlap(zs, z_trues)

In [None]:
# E_z, E_zzp1 = hmm.expected_states(ys[0], mask=ms[0], tag=0)
# plt.imshow(E_z.T, aspect="auto", vmin=0, vmax=1)
# plt.colorbar()

In [None]:
# For each worm (columns), plot the smoothed PC trajectory colored by the inferred discrete
# state, with PC 1 on the x-axis against PC 2, PC 3 and PC 4 (rows).
pc_pairs = [(0, 1), (0, 2), (0, 3)]  # 0-based PC indices (x, y) for each row
plt.figure(figsize=(12, 9))

for w in range(W):
    for row, (i, j) in enumerate(pc_pairs):
        ax = plt.subplot(len(pc_pairs), W, row * W + w + 1)
        zplt.plot_2d_continuous_states(xs[w], zs[w], xlims=(-lim, lim), ylims=(-lim, lim),
                                       inds=(i, j), ax=ax)
        # Derive labels from the indices so they always match the plotted components
        plt.ylabel("PC {}".format(j + 1) if w == 0 else "")
        plt.xlabel("PC {}".format(i + 1))
        if row == 0:
            plt.title("worm {}".format(w + 1))

# One figure-level title, set once rather than on every loop iteration
plt.suptitle("Continuous Latent States (Inferred Labels)");

# Simulate from the model with the inferred discrete states

In [None]:
# Simulate one worm's activity by driving the fitted model with that worm's inferred
# discrete state sequence. with_noise=False presumably returns the noise-free prediction
# at each step.
w = 4  # 0-based index of the worm to simulate
assert 0 <= w < W, "worm index {} out of range for {} worms".format(w, W)
z_smpl = zs[w].copy()
y_smpl = np.zeros((Ts[w], N))
for t in range(Ts[w]):
    # Each step conditions on the samples simulated so far, y_smpl[:t]
    y_smpl[t] = hmm.observations.sample_x(z_smpl[t], y_smpl[:t], tag=w, with_noise=False)

In [None]:
# Throw out the unstable samples: a neuron's simulated trace is "bad" if it contains any
# non-finite value or its magnitude ever exceeds the threshold. The non-finite test is
# needed because NaN fails every `>` comparison. Bad columns are set to NaN so they are
# not drawn in the comparison plot.
unstable_thresh = 3
nonfinite = ~np.all(np.isfinite(y_smpl), axis=0)
too_large = np.abs(y_smpl).max(axis=0) > unstable_thresh
bad = np.flatnonzero(nonfinite | too_large)
y_good = y_smpl.copy()
y_good[:, bad] = np.nan
print("neurons with bad simulations: ", bad)

In [None]:
# Compare the real recording (black) with the simulated traces (color) for worm `w`.
# The discrete states that drove the simulation are shown on top.
samples_per_sec = 3.0             # time axis in seconds = sample index / samples_per_sec
t_sec = np.arange(Ts[w]) / samples_per_sec
offsets = np.arange(N)            # vertical offset per neuron so traces do not overlap

plt.figure(figsize=(12, 10))
gs = GridSpec(2, 1, height_ratios=[1, 8])

plt.subplot(gs[0])
plt.imshow(z_smpl[None, :], aspect="auto", cmap=zplt.default_cmap, vmin=0, vmax=len(zplt.default_colors) - 1)
plt.ylabel("$z$")
plt.title("worm {}: real (black) vs. simulated (colored) activity".format(w + 1))

plt.subplot(gs[1])
plt.plot(t_sec, ys[w] - offsets, '-k')
plt.plot(t_sec, y_good - offsets, '-', color=zplt.default_colors[3])
plt.yticks(-offsets, neuron_names)
plt.ylim(-N, 1)
plt.xlim(0, Ts[w] / samples_per_sec)
plt.xlabel("time (s)");