Skip to content

Commit

Permalink
update examples and notebooks following fit refactor
Browse files Browse the repository at this point in the history
fix examples and notebooks following fit refactor

fix simple LDS notebook
  • Loading branch information
davidzoltowski committed Jul 15, 2019
1 parent 080878d commit 223ec32
Show file tree
Hide file tree
Showing 13 changed files with 578 additions and 469 deletions.
1 change: 1 addition & 0 deletions examples/lds.py
Expand Up @@ -28,6 +28,7 @@
colors = sns.xkcd_palette(color_names)

from ssm import LDS
from ssm.util import random_rotation

# Set the parameters of the HMM
T = 1000 # number of time bins
Expand Down
2 changes: 0 additions & 2 deletions examples/rslds.py
Expand Up @@ -16,8 +16,6 @@
sns.set_context("talk")

import ssm
from ssm.variational import SLDSMeanFieldVariationalPosterior, \
SLDSTriDiagVariationalPosterior, SLDSStructuredMeanFieldVariationalPosterior
from ssm.util import random_rotation

# Global parameters
Expand Down
13 changes: 9 additions & 4 deletions examples/slds.py
Expand Up @@ -6,10 +6,14 @@
import matplotlib.pyplot as plt

import ssm
from ssm.variational import SLDSMeanFieldVariationalPosterior, SLDSTriDiagVariationalPosterior, \
SLDSStructuredMeanFieldVariationalPosterior
from ssm.util import random_rotation, find_permutation

import seaborn as sns
color_names = ["windows blue", "red", "amber", "faded green"]
colors = sns.xkcd_palette(color_names)
sns.set_style("white")
sns.set_context("talk")

# Set the parameters of the HMM
T = 10000 # number of time bins
K = 5 # number of discrete states
Expand Down Expand Up @@ -40,7 +44,7 @@
slds.permute(find_permutation(z, slds.most_likely_states(q_mf_x, y)))
q_mf_z = slds.most_likely_states(q_mf_x, y)

Do the same with the structured posterior
# Do the same with the structured posterior
print("Fitting SLDS with SVI using structured variational posterior")
slds = ssm.SLDS(N, K, D, emissions="bernoulli")
slds.initialize(y)
Expand Down Expand Up @@ -76,7 +80,8 @@
# Plot the ELBOS
plt.figure()
plt.plot(q_lem_elbos, label="Laplace EM")
# plt.plot(q_struct_elbos, label="LDS")
plt.plot(q_struct_elbos, label="LDS")
plt.plot(q_mf_elbos, label="MF")
plt.xlabel("Iteration")
plt.ylabel("ELBO")
plt.legend()
Expand Down
255 changes: 136 additions & 119 deletions notebooks/1b Simple Linear Dynamical System.ipynb

Large diffs are not rendered by default.

56 changes: 32 additions & 24 deletions notebooks/1b Simple Linear Dynamical System.py
Expand Up @@ -37,11 +37,10 @@
colors = sns.xkcd_palette(color_names)

import ssm
from ssm.variational import SLDSMeanFieldVariationalPosterior, SLDSTriDiagVariationalPosterior
from ssm.util import random_rotation, find_permutation
from ssm.util import random_rotation

# Specify whether or not to save figures
save_figures = True
save_figures = False


# In[2]:
Expand Down Expand Up @@ -125,9 +124,24 @@
if save_figures:
plt.savefig("lds_2.pdf")


# In[6]:

print("Fitting LDS with SVI using structured variational posterior")
lds = ssm.LDS(N, D, emissions="gaussian_orthog")
lds.initialize(y)

q_lem_elbos, q_lem = lds.fit(y, method="laplace_em", variational_posterior="structured_meanfield",
num_iters=10, initialize=False)

# Get the posterior mean of the continuous states
q_lem_x = q_lem.mean_continuous_states[0]

# Smooth the data under the variational posterior
q_lem_y = lds.smooth(q_lem_x, y)


# In[7]:


print("Fitting LDS with SVI")

Expand All @@ -136,16 +150,11 @@
lds.initialize(y)

# Create a variational posterior
q_mf = SLDSMeanFieldVariationalPosterior(lds, y)
q_mf_elbos = lds.fit(q_mf, y, num_iters=1000, initialize=False)
q_mf_elbos, q_mf = lds.fit(y, method="bbvi", variational_posterior="mf", num_iters=1000, initialize=False)

# Get the posterior mean of the continuous states
q_mf_x = q_mf.mean[0]


# In[7]:


# Smooth the data under the variational posterior
q_mf_y = lds.smooth(q_mf_x, y)

Expand All @@ -157,8 +166,7 @@
lds = ssm.LDS(N, D, emissions="gaussian_orthog")
lds.initialize(y)

q_struct = SLDSTriDiagVariationalPosterior(lds, y)
q_struct_elbos = lds.fit(q_struct, y, num_iters=1000, initialize=False)
q_struct_elbos, q_struct = lds.fit(y, method="bbvi", variational_posterior="tridiag", num_iters=1000, initialize=False)

# Get the posterior mean of the continuous states
q_struct_x = q_struct.mean[0]
Expand All @@ -171,6 +179,7 @@


# Plot the ELBOs
plt.plot(q_lem_elbos, label="Laplace-EM")
plt.plot(q_mf_elbos, label="MF")
plt.plot(q_struct_elbos, label="LDS")
plt.xlabel("Iteration")
Expand All @@ -184,8 +193,9 @@
plt.figure(figsize=(8,4))
plt.plot(x + 4 * np.arange(D), '-k')
for d in range(D):
plt.plot(q_mf_x[:,d] + 4 * d, '-', color=colors[0], label="MF" if d==0 else None)
plt.plot(q_struct_x[:,d] + 4 * d, '-', color=colors[1], label="Struct" if d==0 else None)
plt.plot(q_lem_x[:,d] + 4 * d, '--', color=colors[0], label="Laplace-EM" if d==0 else None)
plt.plot(q_mf_x[:,d] + 4 * d, '--', color=colors[1], label="MF" if d==0 else None)
plt.plot(q_struct_x[:,d] + 4 * d, ':', color=colors[2], label="Struct" if d==0 else None)
plt.ylabel("$x$")
plt.legend()

Expand All @@ -197,8 +207,9 @@
plt.figure(figsize=(8,4))
for n in range(N):
plt.plot(y[:, n] + 4 * n, '-k', label="True" if n == 0 else None)
plt.plot(q_mf_y[:, n] + 4 * n, '--', color=colors[0], label="MF" if n == 0 else None)
plt.plot(q_struct_y[:, n] + 4 * n, ':', color=colors[1], label="Struct" if n == 0 else None)
plt.plot(q_lem_y[:, n] + 4 * n, '--', color=colors[0], label="Laplace-EM" if n == 0 else None)
plt.plot(q_mf_y[:, n] + 4 * n, '--', color=colors[1], label="MF" if n == 0 else None)
plt.plot(q_struct_y[:, n] + 4 * n, ':', color=colors[2], label="Struct" if n == 0 else None)
plt.legend()
plt.xlabel("time")

Expand All @@ -208,8 +219,6 @@
# In[13]:


from ssm.models import HMM

N_iters = 50
K = 15
hmm = ssm.HMM(K, D, observations="gaussian")
Expand Down Expand Up @@ -244,10 +253,10 @@

plt.figure(figsize=(6, 6))
for k in range(K):
plt.contour(XX, YY, np.exp(lls[:,k]).reshape(XX.shape),
plt.contour(XX, YY, np.exp(lls[:,k]).reshape(XX.shape),
cmap=white_to_color_cmap(colors[k % len(colors)]))
plt.plot(x[z==k, 0], x[z==k, 1], 'o', mfc=colors[k], mec='none', ms=4)

plt.plot(x[:,0], x[:,1], '-k', lw=2, alpha=.5)
plt.xlabel("$x_1$")
plt.ylabel("$x_2$")
Expand Down Expand Up @@ -276,7 +285,7 @@
for i, (_, x_smpl) in enumerate(smpls):
x_smpl = np.concatenate((x[:1], x_smpl))
plt.plot(x_smpl[:,d] - d*lim, '-', lw=1, color=colors[i])

plt.yticks(-np.arange(D) * lim, ["$x_{}$".format(d+1) for d in range(D)])
plt.xlabel("time")
plt.xlim(0, T)
Expand Down Expand Up @@ -306,10 +315,10 @@

plt.figure(figsize=(6, 6))
for k in range(K):
plt.contour(XX, YY, np.exp(lls[:,k]).reshape(XX.shape),
plt.contour(XX, YY, np.exp(lls[:,k]).reshape(XX.shape),
cmap=white_to_color_cmap(colors[k % len(colors)]))
plt.plot(x[z==k, 0], x[z==k, 1], 'o', mfc=colors[k], mec='none', ms=4)

plt.plot(x[:,0], x[:,1], '-k', lw=2, alpha=.5)
for i, (_, x_smpl) in enumerate(smpls):
x_smpl = np.concatenate((x[:1], x_smpl))
Expand Down Expand Up @@ -393,4 +402,3 @@

if save_figures:
plt.savefig("lds_7.pdf")

0 comments on commit 223ec32

Please sign in to comment.