In [1]:
import os
os.chdir(os.path.expanduser("~/Projects/zimmer"))
import pickle

import numpy as np
np.random.seed(1234)

import matplotlib.pyplot as plt
plt.ion()

from tqdm import tqdm
from functools import partial

from zimmer.io import load_kato_data
import zimmer.plotting as zplt

from ssm.models import HMM, SLDS, LDS
from ssm.util import find_permutation, compute_state_overlap
from ssm.preprocessing import pca_with_imputation, trend_filter

In [2]:
# Load the data
ys, ms, z_trues, z_true_key, neuron_names = load_kato_data(include_unnamed=False, signal="dff")
ys = [trend_filter(y) for y in ys]

K_true = len(z_true_key)
N = ys[0].shape[1]
W = len(ys)
Ts = [y.shape[0] for y in ys]

Only including named neurons.
59 neurons across all 5 worms


In [None]:
Ds = np.arange(2, 21, step=2)

In [None]:
N_iters = 1000
ldss = []
elboss = []
xss = []
for D in Ds:
    print("Fitting LDS with {} latent dimensions".format(D))
    lds = LDS(N, D)
    elbos, variational_params = lds.fit(ys, masks=ms, optimizer="adam_with_convergence_check", print_intvl=1)
    xs = [vp[0] for vp in variational_params]
    
    # Save results
    ldss.append(lds)
    elboss.append(elbos)
    xss.append(xs)

Fitting LDS with 2 latent dimensions
Initializing with an ARHMM using 25 steps of EM.
Done
Iteration 0.  ELBO: 70379.0
Iteration 1.  ELBO: 72028.1
Iteration 2.  ELBO: 72077.2
Iteration 3.  ELBO: 72913.6
Iteration 4.  ELBO: 73795.9
Iteration 5.  ELBO: 74311.1
Iteration 6.  ELBO: 72411.1
Iteration 7.  ELBO: 73845.8
Iteration 8.  ELBO: 77153.2
Iteration 9.  ELBO: 76045.0
Iteration 10.  ELBO: 75914.1
Iteration 11.  ELBO: 75323.4
Iteration 12.  ELBO: 75745.5
Iteration 13.  ELBO: 76656.6
Iteration 14.  ELBO: 78267.6
Iteration 15.  ELBO: 77354.8
Iteration 16.  ELBO: 79013.9
Iteration 17.  ELBO: 77040.2
Iteration 18.  ELBO: 78630.1
Iteration 19.  ELBO: 81697.4
Iteration 20.  ELBO: 81077.1
Iteration 21.  ELBO: 80556.6
Iteration 22.  ELBO: 80430.0
Iteration 23.  ELBO: 81781.0
Iteration 24.  ELBO: 82137.1
Iteration 25.  ELBO: 84181.7
Iteration 26.  ELBO: 83742.3
Iteration 27.  ELBO: 83961.6
Iteration 28.  ELBO: 84949.3
Iteration 29.  ELBO: 85498.2
Iteration 30.  ELBO: 87269.1
Iteration 31.  ELBO:

Iteration 267.  ELBO: 158806.2
Iteration 268.  ELBO: 159871.1
Iteration 269.  ELBO: 159660.1
Iteration 270.  ELBO: 160370.1
Iteration 271.  ELBO: 160914.2
Iteration 272.  ELBO: 161632.0
Iteration 273.  ELBO: 159214.7
Iteration 274.  ELBO: 161011.0
Iteration 275.  ELBO: 161056.0
Iteration 276.  ELBO: 161012.1
Iteration 277.  ELBO: 160220.0
Iteration 278.  ELBO: 161978.7
Iteration 279.  ELBO: 162918.9
Iteration 280.  ELBO: 163002.5
Iteration 281.  ELBO: 161361.5
Iteration 282.  ELBO: 161400.6
Iteration 283.  ELBO: 162711.0
Iteration 284.  ELBO: 163190.3
Iteration 285.  ELBO: 162425.9
Iteration 286.  ELBO: 163803.7
Iteration 287.  ELBO: 164267.8
Iteration 288.  ELBO: 163459.0
Iteration 289.  ELBO: 164864.0
Iteration 290.  ELBO: 164392.1
Iteration 291.  ELBO: 164616.0
Iteration 292.  ELBO: 163941.0
Iteration 293.  ELBO: 165129.6
Iteration 294.  ELBO: 165573.3
Iteration 295.  ELBO: 167049.1
Iteration 296.  ELBO: 166496.0
Iteration 297.  ELBO: 165932.2
Iteration 298.  ELBO: 167276.4
Iteratio

Iteration 532.  ELBO: 202691.4
Iteration 533.  ELBO: 204097.5
Iteration 534.  ELBO: 203568.4
Iteration 535.  ELBO: 203699.2
Iteration 536.  ELBO: 204957.1
Iteration 537.  ELBO: 203496.7
Iteration 538.  ELBO: 204493.7
Iteration 539.  ELBO: 204660.7
Iteration 540.  ELBO: 204659.3
Iteration 541.  ELBO: 204214.4
Iteration 542.  ELBO: 204074.3
Iteration 543.  ELBO: 204483.9
Iteration 544.  ELBO: 205702.5
Iteration 545.  ELBO: 205445.0
Iteration 546.  ELBO: 204849.0
Iteration 547.  ELBO: 205703.1
Iteration 548.  ELBO: 205072.8
Iteration 549.  ELBO: 206569.8
Iteration 550.  ELBO: 205367.1
Iteration 551.  ELBO: 206313.7
Iteration 552.  ELBO: 207156.6
Iteration 553.  ELBO: 205902.4
Iteration 554.  ELBO: 206445.3
Iteration 555.  ELBO: 206761.4
Iteration 556.  ELBO: 206649.9
Iteration 557.  ELBO: 205645.5
Iteration 558.  ELBO: 207754.4
Iteration 559.  ELBO: 206347.5
Iteration 560.  ELBO: 206101.0
Iteration 561.  ELBO: 206276.0
Iteration 562.  ELBO: 207692.4
Iteration 563.  ELBO: 206710.9
Iteratio

Iteration 797.  ELBO: 228609.1
Iteration 798.  ELBO: 228326.4
Iteration 799.  ELBO: 229400.0
Iteration 800.  ELBO: 229249.2
Iteration 801.  ELBO: 229493.2
Iteration 802.  ELBO: 229588.4
Iteration 803.  ELBO: 228902.3
Iteration 804.  ELBO: 229620.4
Iteration 805.  ELBO: 229193.8
Iteration 806.  ELBO: 229801.4
Iteration 807.  ELBO: 229961.0
Iteration 808.  ELBO: 229801.6
Iteration 809.  ELBO: 230201.1
Iteration 810.  ELBO: 228898.9
Iteration 811.  ELBO: 229770.7
Iteration 812.  ELBO: 229325.9
Iteration 813.  ELBO: 230237.8
Iteration 814.  ELBO: 229475.4
Iteration 815.  ELBO: 230159.1
Iteration 816.  ELBO: 230479.9
Iteration 817.  ELBO: 230117.4
Iteration 818.  ELBO: 230686.5
Iteration 819.  ELBO: 230923.9
Iteration 820.  ELBO: 230897.6
Iteration 821.  ELBO: 230266.2
Iteration 822.  ELBO: 231050.8
Iteration 823.  ELBO: 230842.9
Iteration 824.  ELBO: 230747.7
Iteration 825.  ELBO: 231008.1
Iteration 826.  ELBO: 230751.9
Iteration 827.  ELBO: 231501.5
Iteration 828.  ELBO: 231445.1
Iteratio

Iteration 1060.  ELBO: 243773.8
Iteration 1061.  ELBO: 244061.7
Iteration 1062.  ELBO: 244170.0
Iteration 1063.  ELBO: 244279.7
Iteration 1064.  ELBO: 243994.8
Iteration 1065.  ELBO: 243721.7
Iteration 1066.  ELBO: 244665.9
Iteration 1067.  ELBO: 244203.1
Iteration 1068.  ELBO: 244610.6
Iteration 1069.  ELBO: 243944.5
Iteration 1070.  ELBO: 244346.9
Iteration 1071.  ELBO: 244388.7
Iteration 1072.  ELBO: 244837.1
Iteration 1073.  ELBO: 244470.6
Iteration 1074.  ELBO: 244810.7
Iteration 1075.  ELBO: 245084.6
Iteration 1076.  ELBO: 244574.9
Iteration 1077.  ELBO: 244698.4
Iteration 1078.  ELBO: 244881.0
Iteration 1079.  ELBO: 244566.2
Iteration 1080.  ELBO: 245269.3
Iteration 1081.  ELBO: 244939.9
Iteration 1082.  ELBO: 245442.4
Iteration 1083.  ELBO: 245703.4
Iteration 1084.  ELBO: 245420.9
Iteration 1085.  ELBO: 245227.8
Iteration 1086.  ELBO: 245523.2
Iteration 1087.  ELBO: 245764.9
Iteration 1088.  ELBO: 245795.8
Iteration 1089.  ELBO: 244847.7
Iteration 1090.  ELBO: 245362.9
Iteratio

Iteration 1317.  ELBO: 253446.8
Iteration 1318.  ELBO: 253939.9
Iteration 1319.  ELBO: 254006.6
Iteration 1320.  ELBO: 254076.1
Iteration 1321.  ELBO: 253729.3
Iteration 1322.  ELBO: 253712.2
Iteration 1323.  ELBO: 253970.8
Iteration 1324.  ELBO: 253833.8
Iteration 1325.  ELBO: 253859.6
Iteration 1326.  ELBO: 254082.6
Iteration 1327.  ELBO: 253923.5
Iteration 1328.  ELBO: 253816.7
Iteration 1329.  ELBO: 253989.4
Iteration 1330.  ELBO: 253849.6
Iteration 1331.  ELBO: 253344.4
Iteration 1332.  ELBO: 254158.3
Iteration 1333.  ELBO: 254214.6
Iteration 1334.  ELBO: 254308.9
Iteration 1335.  ELBO: 254039.9
Iteration 1336.  ELBO: 253976.7
Iteration 1337.  ELBO: 254011.4
Iteration 1338.  ELBO: 254538.5
Iteration 1339.  ELBO: 254524.7
Iteration 1340.  ELBO: 254504.8
Iteration 1341.  ELBO: 254135.1
Iteration 1342.  ELBO: 254517.2
Iteration 1343.  ELBO: 254388.3
Iteration 1344.  ELBO: 254016.3
Iteration 1345.  ELBO: 253784.9
Iteration 1346.  ELBO: 254267.2
Iteration 1347.  ELBO: 254811.2
Iteratio

Iteration 1574.  ELBO: 259701.8
Iteration 1575.  ELBO: 259938.1
Iteration 1576.  ELBO: 260111.5
Iteration 1577.  ELBO: 260186.1
Iteration 1578.  ELBO: 260273.8
Iteration 1579.  ELBO: 260364.8
Iteration 1580.  ELBO: 260343.2
Iteration 1581.  ELBO: 260259.4
Iteration 1582.  ELBO: 259685.0
Iteration 1583.  ELBO: 260268.7
Iteration 1584.  ELBO: 260569.2
Iteration 1585.  ELBO: 260147.9
Iteration 1586.  ELBO: 260757.4
Iteration 1587.  ELBO: 260121.4
Iteration 1588.  ELBO: 260610.7
Iteration 1589.  ELBO: 260054.9
Iteration 1590.  ELBO: 260226.6
Iteration 1591.  ELBO: 260663.9
Iteration 1592.  ELBO: 260666.3
Iteration 1593.  ELBO: 260661.5
Iteration 1594.  ELBO: 260238.2
Iteration 1595.  ELBO: 260558.6
Iteration 1596.  ELBO: 260502.5
Iteration 1597.  ELBO: 260410.0
Iteration 1598.  ELBO: 260610.1
Iteration 1599.  ELBO: 260408.9
Iteration 1600.  ELBO: 260264.6
Iteration 1601.  ELBO: 260214.8
Iteration 1602.  ELBO: 260352.1
Iteration 1603.  ELBO: 260925.1
Iteration 1604.  ELBO: 260462.5
Iteratio

Iteration 1831.  ELBO: 264785.3
Iteration 1832.  ELBO: 264414.7
Iteration 1833.  ELBO: 264475.4
Iteration 1834.  ELBO: 264354.8
Iteration 1835.  ELBO: 264434.2
Iteration 1836.  ELBO: 264600.7
Iteration 1837.  ELBO: 264466.3
Iteration 1838.  ELBO: 264411.5
Iteration 1839.  ELBO: 264411.5
Iteration 1840.  ELBO: 264534.3
Iteration 1841.  ELBO: 264901.1
Iteration 1842.  ELBO: 264885.2
Iteration 1843.  ELBO: 264537.1
Iteration 1844.  ELBO: 264861.9
Iteration 1845.  ELBO: 264905.3
Iteration 1846.  ELBO: 264953.8
Iteration 1847.  ELBO: 264474.1
Iteration 1848.  ELBO: 264638.7
Iteration 1849.  ELBO: 264604.6
Iteration 1850.  ELBO: 264333.2
Iteration 1851.  ELBO: 264756.0
Iteration 1852.  ELBO: 264474.6
Iteration 1853.  ELBO: 265160.7
Iteration 1854.  ELBO: 264994.6
Iteration 1855.  ELBO: 264336.9
Iteration 1856.  ELBO: 264772.7
Iteration 1857.  ELBO: 264693.8
Iteration 1858.  ELBO: 264930.3
Iteration 1859.  ELBO: 265109.6
Iteration 1860.  ELBO: 264903.0
Iteration 1861.  ELBO: 264983.8
Iteratio

In [None]:
plt.figure()
plt.plot(np.array(elboss).T)
plt.xlabel("Iteration")
plt.ylabel("ELBO")

plt.figure()
plt.plot(Ds, np.array(elboss)[:,-1])
plt.xlabel("D")
plt.ylabel("Final ELBO")

In [None]:
plt.figure(figsize=(12, 9))
for w, (x, z) in enumerate(zip(xs, z_trues)):
    ax = plt.subplot(3, W, w+1)
    zplt.plot_2d_continuous_states(x, z, xlims=(-2.5, 2.5), ylims=(-2.5, 2.5), inds=(0, 1), ax=ax)
    plt.ylabel("PC 2" if w == 0 else "")
    plt.title("worm {}".format(w+1))

    ax = plt.subplot(3, W, W+w+1)
    zplt.plot_2d_continuous_states(x, z, xlims=(-2.5, 2.5), ylims=(-2.5, 2.5), inds=(0, 2), ax=ax)
    plt.ylabel("PC 3" if w == 0 else "")

    ax = plt.subplot(3, W, 2*W+w+1)
    zplt.plot_2d_continuous_states(x, z, xlims=(-2.5, 2.5), ylims=(-2.5, 2.5), inds=(0, 3), ax=ax)
    plt.xlabel("PC 1")
    plt.ylabel("PC 4" if w == 0 else "")
    
plt.suptitle("Continuous Latent States (Zimmer Labels)")

In [None]:
plt.imshow(lds.dynamics.As[0])
plt.title("dynamics")
plt.colorbar()

print("A")
print(lds.dynamics.As[0].round(2))
print("b")
print(lds.dynamics.bs.round(2))
print("sigma")
print(np.exp(lds.dynamics.inv_sigmas.round(2)))

In [None]:
C = lds.emissions.Cs[0]
plt.imshow(C, aspect=0.5)
plt.colorbar()
print(C.T.dot(C).round(2))
print(lds.emissions.ds[0].round(2))
print(np.exp(lds.emissions.inv_etas.round(2)))

# Fit a robust, recurrent, hierarchical ARHMM to the latent states