# Simple example of a recurrent SLDS

In [1]:
import os
import pickle

import autograd.numpy as np
import autograd.numpy.random as npr
npr.seed(12345)

import matplotlib.pyplot as plt
%matplotlib inline

import seaborn as sns
color_names = ["windows blue", "red", "amber", "faded green"]
colors = sns.xkcd_palette(color_names)
sns.set_style("white")
sns.set_context("talk")

from ssm.models import SLDS
from ssm.variational import SLDSTriDiagVariationalPosterior
from ssm.util import random_rotation, find_permutation

In [7]:
# read DEBS2013 dataset
indir = os.path.join('..','_csv','ver1')
infile_list = [infile for infile in os.listdir(indir) if infile.endswith('.csv')]

for infile in infile_list[:1]:
    y = np.loadtxt(os.path.join(indir, infile), delimiter=',').reshape(-1,8,2)
    
# Global parameters
T = y.shape[0]
K = 4
D_obs = y.shape[2]
D_latent = 2

print(T, K, D_obs, D_latent)

8779 4 2 2


In [20]:
# Fit a robust rSLDS with its default initialization
y_tmp = np.hstack([y[:,0,:],])

rslds = SLDS(D_obs*2, K, D_latent,
             transitions="recurrent_only",
             dynamics="gaussian",
             emissions="gaussian",
#              emissions="ar_id",
             single_subspace=True)

rslds.initialize(y_tmp)

q = SLDSTriDiagVariationalPosterior(rslds, y_tmp)
elbos = rslds.fit(q, y_tmp, num_iters=1000, initialize=False)

Initializing with an ARHMM using 25 steps of EM.


  return f_raw(*args, **kwargs)


HBox(children=(IntProgress(value=0, max=25), HTML(value='')))

  return ufunc.reduce(obj, axis, dtype, out, **passkwargs)


AssertionError: 

In [None]:
# Get the posterior mean of the continuous states
print(q)
# xhat = q.mean[0]

# Find the permutation that matches the true and inferred states
# rslds.permute(find_permutation(z, rslds.most_likely_states(xhat, y)))
zhat = rslds.most_likely_states(xhat, y)

In [None]:
# Plot some results
plt.figure()
plt.plot(elbos)
plt.xlabel("Iteration")
plt.ylabel("ELBO")

In [134]:
# Helper functions for plotting results
def plot_trajectory(z, x, ax=None, ls="-"):
    zcps = np.concatenate(([0], np.where(np.diff(z))[0] + 1, [z.size]))
    if ax is None:
        fig = plt.figure(figsize=(4, 4))
        ax = fig.gca()
    for start, stop in zip(zcps[:-1], zcps[1:]):
        ax.plot(x[start:stop + 1, 0],
                x[start:stop + 1, 1],
                lw=1, ls=ls,
                color=colors[z[start] % len(colors)],
                alpha=1.0)

    return ax


def plot_most_likely_dynamics(model,
    xlim=(-4, 4), ylim=(-3, 3), nxpts=20, nypts=20,
    alpha=0.8, ax=None, figsize=(3, 3)):
    
    K = model.K
    assert model.D == 2
    x = np.linspace(*xlim, nxpts)
    y = np.linspace(*ylim, nypts)
    X, Y = np.meshgrid(x, y)
    xy = np.column_stack((X.ravel(), Y.ravel()))

    # Get the probability of each state at each xy location
    z = np.argmax(xy.dot(model.transitions.Rs.T) + model.transitions.r, axis=1)

    if ax is None:
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111)

    for k, (A, b) in enumerate(zip(model.dynamics.As, model.dynamics.bs)):
        dxydt_m = xy.dot(A.T) + b - xy

        zk = z == k
        if zk.sum(0) > 0:
            ax.quiver(xy[zk, 0], xy[zk, 1],
                      dxydt_m[zk, 0], dxydt_m[zk, 1],
                      color=colors[k % len(colors)], alpha=alpha)

    ax.set_xlabel('$x_1$')
    ax.set_ylabel('$x_2$')

    plt.tight_layout()

    return ax

#### Figure.5 のような可視化を行う

In [147]:
def plot_most_likely_dynamics(model,
    xlim=(-2.643110244480175, 2.6196227080428462), 
    ylim=(-2.6453652504108405, 2.4965166456481436), 
    nxpts=100, nypts=100,
    alpha=0.8, ax=None, figsize=(3, 3)):
    
    K = model.K
    x = np.linspace(*xlim, nxpts)
    y = np.linspace(*ylim, nypts)
    X, Y = np.meshgrid(x, y)
    xy = np.column_stack((X.ravel(), Y.ravel()))
    

    # Get the probability of each state at each xy location
    z = np.argmax(xy.dot(model.transitions.Rs.T) + model.transitions.r, axis=1)

    if ax is None:
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111)

    for k, (A, b) in enumerate(zip(model.dynamics.As, model.dynamics.bs)):
        dxydt_m = xy.dot(A.T) + b - xy

        zk = z == k
        if zk.sum(0) > 0:
            ax.quiver(xy[zk, 0], xy[zk, 1],
                      dxydt_m[zk, 0], dxydt_m[zk, 1],
                      color=colors[k % len(colors)], alpha=alpha)

    ax.set_xlabel('$x_1$')
    ax.set_ylabel('$x_2$')

    plt.tight_layout()
    
    return ax

plot_most_likely_dynamics(rslds)

ValueError: shapes (10000,2) and (1,4) not aligned: 2 (dim 1) != 1 (dim 0)