In [1]:
from __future__ import print_function
import argparse
import torch
import torch.utils.data
from torch import nn, optim
from torch.autograd import Variable, Function
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.distributions import Normal

import numpy as np
import matplotlib.pyplot as plt 
%matplotlib inline
import scipy.misc as scpm
import math
import h5py
import copy
from tqdm import tqdm
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from datta.messages.torchhmm import hmm_marginal_likelihood
from datta.messages.torchhmm import hmm_sample
from time import time
import joblib
import copy

In [2]:
###############
## Load data ##
###############

# Shared settings: frames per batch and how many sessions (HDF5 keys) to train on.
batch_size = 1000
n_mice = 17

# Read-only handle to the PCA scores; each top-level key holds one session's frames.
h5_temp = h5py.File('whitened_clean_pca.h5', 'r')

In [3]:
################################
## Get rid of mouse with nans ##
################################

# Keep only the sessions whose PCA scores contain no NaNs, preserving file order.
uuids_all = list(h5_temp.keys())
uuids = [uuid for uuid in uuids_all if not np.isnan(h5_temp[uuid][:]).any()]

# Every later cell indexes uuids[0:n_mice]; fail loudly here rather than with an
# IndexError deep inside the data generator if too few clean sessions remain.
if len(uuids) < n_mice:
    raise ValueError('Only %d NaN-free sessions in the file but n_mice=%d'
                     % (len(uuids), n_mice))


In [4]:
############################
## CREATE DATA GENERATORS ##
############################

def data_generator(pca_file, uuids, batch_size, n_mice, n_epochs=1000):
    """Yield shuffled, fixed-size batches of frames from the first ``n_mice`` sessions.

    Each session ``pca_file[uuids[i]]`` is cut into ``len // batch_size``
    non-overlapping batches; a trailing partial batch is dropped. Every epoch
    visits all batches exactly once, in an order drawn with
    ``np.random.shuffle``. Seeding numpy's global RNG therefore makes the
    order reproducible.

    Args:
        pca_file: mapping from uuid to an array-like of frames (e.g. an open
            h5py.File). Only ``.shape[0]`` and slicing are used.
        uuids: session keys; the first ``n_mice`` are used.
        batch_size: frames per batch.
        n_mice: number of sessions to draw from.
        n_epochs: number of passes over all batches before the generator is
            exhausted (previously hard-coded to 1000).

    Yields:
        (frames, i_mouse, which_batch), where ``frames`` is
        ``pca_file[uuids[i_mouse]][which_batch*batch_size:(which_batch+1)*batch_size]``.
    """
    # One (mouse index, batch index) row per full batch, in session order.
    batch_inds = np.array(
        [(i_mouse, i_batch)
         for i_mouse in range(n_mice)
         for i_batch in range(pca_file[uuids[i_mouse]].shape[0] // batch_size)],
        dtype=np.int64).reshape(-1, 2)

    loop_vec = np.arange(len(batch_inds))
    for _ in range(n_epochs):
        np.random.shuffle(loop_vec)
        for ii in loop_vec:
            i_mouse, which_batch = int(batch_inds[ii, 0]), int(batch_inds[ii, 1])
            start = which_batch * batch_size
            yield pca_file[uuids[i_mouse]][start:start + batch_size], i_mouse, which_batch


In [5]:
# Get number of training batches
# (same per-session count as data_generator: full batches only)

pca_file = h5_temp
n_batches = [np.floor(pca_file[uuids[i_mouse]].shape[0] / batch_size)
             for i_mouse in range(n_mice)]
total_batches = int(np.sum(n_batches))
nb_tng_batches = total_batches
print(nb_tng_batches)

901


In [6]:
# Create data generator. A single shared generator is used: initialization,
# training and analysis all draw consecutive batches from it.
data_gen = data_generator(pca_file=h5_temp, uuids=uuids, batch_size=batch_size, n_mice=n_mice)

In [7]:
def log_sum_exp(value, dim=None, keepdim=False):
    """Numerically stable ``value.exp().sum(dim, keepdim).log()``.

    Args:
        value: input tensor.
        dim: dimension to reduce. If None, reduce over all elements and return
            a 0-dim tensor; ``keepdim`` is ignored in that case.
        keepdim: keep the reduced dimension with size 1 (only when ``dim`` is given).

    Returns:
        Tensor of log-sum-exp values. A slice that is entirely ``-inf`` gives
        ``-inf``. The previous max-shift implementation returned NaN there,
        because ``-inf - (-inf)`` is NaN.
    """
    if dim is None:
        return torch.logsumexp(value.reshape(-1), dim=0)
    return torch.logsumexp(value, dim=dim, keepdim=keepdim)


In [8]:
#################
## Model class ##
#################

class ARHMM(nn.Module):
    """Autoregressive HMM with diagonal-Gaussian emissions and a sticky Dirichlet prior.

    Each of the K = ``n_discrete_states`` states predicts frame h[t] as
    ``concat(h[t-nlags], ..., h[t-1]) @ As[k] + bs[k]``. The noise variance is
    diagonal: ``softplus(inv_softplus_Qs[k])``. The first ``nlags`` frames of a
    batch are predicted by ``bs[k]`` alone. A single K x K transition matrix
    is stored in log space and shared by all time steps.

    Args:
        n_discrete_states: number of discrete states K.
        latent_dim_size_h: dimensionality H of each frame.
        transition_init: the initial transition matrix is
            ``transition_init * I + (1 - transition_init) / K``.
        kappa: extra Dirichlet concentration on each self-transition.
        alpha: Dirichlet concentration spread evenly (alpha / K each) over a row.
        batch_size: frames T per batch. The tiled transition tensor assumes
            batches of exactly this length.
        nlags: autoregressive order.
        n_batches_per_epoch: the log-prior is divided by this in the losses, so
            summing per-batch losses over one epoch counts the prior once. The
            default 901 is the value previously hard-coded there.
    """

    def __init__(self, n_discrete_states, latent_dim_size_h, transition_init, kappa, alpha,
                 batch_size, nlags, n_batches_per_epoch=901):
        super(ARHMM, self).__init__()
        self.n_discrete_states = n_discrete_states
        self.latent_dim_size_h = latent_dim_size_h
        self.transition_init = transition_init
        self.batch_size = batch_size
        self.nlags = nlags
        self.kappa = kappa
        self.alpha = alpha
        self.n_batches_per_epoch = n_batches_per_epoch
        self.__build_model()

    def __build_model(self):
        """Create the learnable parameters."""
        K, H = self.n_discrete_states, self.latent_dim_size_h

        # Emission parameters:
        #   AR weights (K, nlags*H, H)
        #   biases (K, H)
        #   pre-softplus diagonal noise variances (K, H)
        self.As = nn.Parameter(torch.zeros((K, H * self.nlags, H)))
        self.bs = nn.Parameter(torch.zeros((K, H)))
        self.inv_softplus_Qs = nn.Parameter(torch.ones((K, H)))

        # Unnormalized log transition matrix (K, K); rows are normalized on use.
        self.log_transition_proba = nn.Parameter(torch.log(
            self.transition_init * torch.eye(K)
            + (1 - self.transition_init) / K * torch.ones((K, K))))

        self.softplus = nn.Softplus()

    def _normalized_log_Ps(self):
        """Row-normalized log transition matrix, shape (K, K)."""
        return self.log_transition_proba - torch.logsumexp(
            self.log_transition_proba, dim=-1, keepdim=True)

    def _lagged_inputs(self, h):
        """Stack the ``nlags`` previous frames for every t >= nlags.

        Row r (time t = r + nlags) is ``[h[t-nlags], ..., h[t-1]]`` concatenated
        along the feature axis. The result has shape (T - nlags, nlags * H).
        """
        T = h.shape[0]
        return torch.cat([h[lag:T - self.nlags + lag] for lag in range(self.nlags)], dim=1)

    def get_tiled_normalized_log_Ps(self):
        """Normalized log transition matrix repeated for each of the batch_size-1 steps.

        Returns:
            (batch_size - 1, K, K) tensor.
        """
        return self._normalized_log_Ps().unsqueeze(0).repeat(self.batch_size - 1, 1, 1)

    def get_transition_dirichlet_prior(self):
        """Sum over rows of the Dirichlet log-density of the transition matrix.

        Row i has concentration alpha/K on every entry plus kappa on entry i.

        Returns:
            0-dim tensor.
        """
        log_Ps = self._normalized_log_Ps()
        K = self.n_discrete_states
        ones = torch.ones((K, K), dtype=log_Ps.dtype, device=log_Ps.device)
        eye = torch.eye(K, dtype=log_Ps.dtype, device=log_Ps.device)
        concentration = self.alpha * ones / K + self.kappa * eye
        return ((log_Ps * (concentration - 1.0)).sum()
                + torch.lgamma(concentration.sum(-1)).sum()
                - torch.lgamma(concentration).sum())

    def get_lls(self, h):
        """Per-frame, per-state Gaussian AR log-likelihoods.

        Args:
            h: (T, H) frames.

        Returns:
            (T, K) tensor. Entry [t, k] is log N(h[t]; mean_k(t), diag(Q_k)).
        """
        # (T-nlags, nlags*H) @ (K, nlags*H, H) -> (K, T-nlags, H) -> (T-nlags, K, H)
        E_hs = torch.matmul(self._lagged_inputs(h), self.As).transpose(1, 0) + self.bs
        # The first nlags frames have no full history: use the bias alone -> (T, K, H)
        E_hs = torch.cat((self.bs.unsqueeze(0).expand(self.nlags, -1, -1), E_hs), 0)

        Qs = self.softplus(self.inv_softplus_Qs)                               # (K, H)
        lls = -0.5 * torch.sum((h.unsqueeze(1) - E_hs) ** 2 / Qs, dim=2)       # (T, K)
        lls = lls - 0.5 * torch.sum(math.log(2 * math.pi) + torch.log(Qs), dim=1)  # + (K,)
        return lls

    def get_log_pi0(self):
        """Uniform log initial-state distribution, shape (K,)."""
        K = self.n_discrete_states
        return torch.full((K,), -math.log(K), device=self.log_transition_proba.device)

    def forward(self, h):
        """Return (log_pi0, lls, tiled log transition matrices) for the batch h."""
        return self.get_log_pi0(), self.get_lls(h), self.get_tiled_normalized_log_Ps()

    def sample_states(self, h, log_pi0, lls, log_Ps, num_samples=1):
        """Draw ``num_samples`` state sequences with ``hmm_sample``; stacked on dim 0."""
        return torch.stack([hmm_sample(log_pi0, log_Ps, lls) for _ in range(num_samples)])

    def EM_loss(self, zs, h, log_pi0, lls, log_Ps):
        """Negative complete-data log-probability for sampled state sequences.

        For each row z of ``zs`` (shape (n_samples, T), integer states), the loss
        adds sum_t log_Ps[t, z_t, z_{t+1}] and sum_t lls[t, z_t]. The Dirichlet
        prior is added once, divided by ``n_batches_per_epoch``. The total is
        negated and divided by ``h.shape[0] * h.shape[1]``. ``log_pi0`` is
        unused: the initial-state term is a constant under the uniform pi0.
        """
        dirichlet_prior = self.get_transition_dirichlet_prior()

        T = lls.shape[0]
        time_idx = torch.arange(T, device=lls.device)
        lp = 0
        for z in zs:
            # log p(z | theta): transitions only
            lp = lp + log_Ps[time_idx[:-1], z[:-1], z[1:]].sum()
            # log p(x | z, theta): index this sample's states, not the whole zs matrix
            lp = lp + lls[time_idx, z].sum()

        lp = lp + dirichlet_prior / self.n_batches_per_epoch
        return -lp / (h.shape[0] * h.shape[1])

    def SGD_loss(self, h, log_pi0, lls, log_Ps):
        """Negative marginal log-likelihood plus the per-batch share of the prior, normalized."""
        dirichlet_prior = self.get_transition_dirichlet_prior()
        log_prob = hmm_marginal_likelihood(log_pi0, log_Ps, lls) + dirichlet_prior / self.n_batches_per_epoch
        return -log_prob / (h.shape[0] * h.shape[1])

    def get_expected_states(self, h):
        """Posterior state marginals p(z_t = k | h) via forward-backward in log space.

        Uses a uniform initial distribution (the same as get_log_pi0).

        Returns:
            (T, K) tensor whose rows sum to 1.
        """
        log_Ps = self.get_tiled_normalized_log_Ps()          # (T-1, K, K)
        log_PsT = log_Ps.transpose(1, 2)
        lls = self.get_lls(h)                                 # (T, K)

        # Forward pass: alpha[t, k] = log p(h[:t+1], z_t = k)
        alpha = torch.zeros_like(lls)
        alpha[0] = -math.log(self.n_discrete_states) + lls[0]
        for t in range(alpha.shape[0] - 1):
            alpha[t + 1] = torch.logsumexp(alpha[t] + log_PsT[t], dim=1) + lls[t + 1]

        # Backward pass: beta[t, k] = log p(h[t+1:] | z_t = k)
        beta = torch.zeros_like(lls)
        for t in range(beta.shape[0] - 2, -1, -1):
            beta[t] = torch.logsumexp(log_Ps[t] + beta[t + 1] + lls[t + 1], dim=1)

        # Combine and normalize each row (subtract the row max first for stability).
        expected_z = alpha + beta
        expected_z = expected_z - expected_z.max(1)[0].view((-1, 1))
        expected_z = torch.exp(expected_z)
        expected_z = expected_z / expected_z.sum(1).view((-1, 1))
        return expected_z

    def _fit_state_regression(self, k, X, Y, L2_reg):
        """Ridge-regress Y on X and write the fit into state k's A, b and noise variance."""
        XTX = torch.matmul(X.t(), X) + L2_reg * torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
        W = torch.matmul(torch.inverse(XTX), torch.matmul(X.t(), Y))

        # Column 0 of X is the constant 1, so W[0] is the bias and W[1:] the AR weights.
        self.As.data[k] = W[1:]
        self.bs.data[k] = W[0]

        Qs = torch.var(torch.matmul(X, W) - Y, 0)
        # Stable inverse softplus: log(exp(q) - 1) = q + log(1 - exp(-q)).
        self.inv_softplus_Qs.data[k] = Qs + torch.log(-torch.expm1(-Qs))

    def initialize_transitions(self, data_gen, nb_tng_batches, device, L2_reg=0.01):
        """Initialize each state's AR emission by ridge regression on its own chunk of data.

        Exactly ``nb_tng_batches`` batches are drawn from ``data_gen``. They are
        split into consecutive chunks of ``nb_tng_batches // K`` batches. Chunk k
        is fit by ridge regression, with penalty ``L2_reg`` applied to the bias
        too. The fit sets As[k], bs[k] and inv_softplus_Qs[k], the last from the
        residual variance. Batches after the K-th chunk are still drawn, as
        before, but are ignored. Despite the name, the transition matrix is
        not modified.

        Raises:
            ValueError: if nb_tng_batches < K (some state would get no data).
            RuntimeError: if not every state was initialized.
        """
        K = self.n_discrete_states
        data_split = nb_tng_batches // K
        if data_split < 1:
            raise ValueError('initialize_transitions needs at least %d batches, got %d'
                             % (K, nb_tng_batches))

        i_discrete_state = 0
        chunk_X, chunk_Y = [], []
        for batch_nb in range(nb_tng_batches):
            pca, _, _ = next(data_gen)
            if i_discrete_state >= K:
                continue  # every state is initialized; just drain the remaining batches

            pca = torch.tensor(pca).to(device).float()
            # Prepend a column of ones so the regression also fits the bias.
            chunk_X.append(F.pad(self._lagged_inputs(pca), (1, 0), value=1))
            chunk_Y.append(pca[self.nlags:])

            if (batch_nb + 1) % data_split == 0:
                self._fit_state_regression(i_discrete_state, torch.cat(chunk_X, 0),
                                           torch.cat(chunk_Y, 0), L2_reg)
                chunk_X, chunk_Y = [], []
                i_discrete_state += 1

        if i_discrete_state < K:
            raise RuntimeError('ERROR WITH INITIALIZATION')


In [9]:
#############################
## Create/initialize model ##
#############################

transition_init=0.5   # initial transition matrix = transition_init*I + (1-transition_init)/K
n_discrete_states=100
latent_dim_size_h=10  # NOTE(review): presumably must equal the PCA dimensionality in the h5 file — confirm
nlags=3               # autoregressive order
alpha=200             # Dirichlet concentration spread evenly over each transition row
kappa=1e6             # extra Dirichlet concentration on self-transitions

# Move the parameters to the same device every batch is sent to below.
model = ARHMM(n_discrete_states, latent_dim_size_h, transition_init, kappa, alpha, batch_size, nlags).to(device)

# Initialize model: per-chunk ridge regression, drawing nb_tng_batches batches from data_gen.
model.initialize_transitions(data_gen, nb_tng_batches, device)


In [10]:
# ##########################
# ## Load in Gibb's Model ##
# ##########################

# model_fit = 'clean_data_pca_gibbs_sampling_100_iter.p'
# model_fit = joblib.load(model_fit)


# model.As.data = torch.tensor(np.swapaxes(np.asarray(model_fit['model_parameters'][0]['ar_mat'])[:,:,:-1],2,1)).float()
# model.bs.data = torch.tensor(np.asarray(model_fit['model_parameters'][0]['ar_mat'])[:,:,-1]).float()
# sig = np.asarray(model_fit['model_parameters'][0]['sig'])
# diag_sig = np.asarray([np.diag(sig[i]) for i in range(100)])
# Qs = torch.tensor(sig).float()
# model.inv_softplus_Qs.data = torch.tensor(np.log(np.exp(diag_sig)-1)).float()
# model.log_transition_proba.data = torch.tensor(np.log(np.asarray(model_fit['model_parameters'][0]['transition_matrix']))).float()



In [11]:
# # # #######################
# # # ## SGD loss training ##
# # # #######################

# optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=.01)
# for i_epoch in range(100):
#     train_loss = 0
#     for i_train in tqdm(range(nb_tng_batches)):
#         pca, i_mouse, i_batch = next(data_gen)
#         pca = torch.tensor(pca).to(device).float()

#         optimizer.zero_grad()
        
#         log_pi0, lls, log_Ps = model(pca)
#         loss = model.SGD_loss(pca,log_pi0, lls, log_Ps)
        
#         if i_epoch > 0:
#             loss.backward()
#             optimizer.step()
#         train_loss += loss.item()
#     print("Epoch: ", i_epoch, " ", train_loss/nb_tng_batches)


In [12]:
######################
## EM loss training ##
######################
# Each batch: draw a state sequence zs with model.sample_states (hmm_sample,
# presumably from the posterior given the batch), then score (zs, pca) with
# model.EM_loss. Epoch 0 only evaluates the initialized model (no optimizer
# step); parameter updates start at epoch 1.
n_epochs = 1
all_training_loss = []
all_likelihood_loss = []
all_dirichlet_loss = []
n_unique_states = []
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=.001)
# Per-batch weight of the Dirichlet prior. Read it from the model when available
# so the logged split matches the weight used inside EM_loss.
prior_weight = getattr(model, 'n_batches_per_epoch', 901)

loss = None
for i_epoch in range(n_epochs):
    train_loss = 0
    dirichlet_loss = 0
    likelihood_loss = 0
    unique_states = np.empty((0,))
    for i_train in tqdm(range(nb_tng_batches)):

        pca, i_mouse, i_batch = next(data_gen)
        pca = torch.tensor(pca).to(device).float()

        log_pi0, lls, log_Ps = model(pca)
        zs = model.sample_states(pca, log_pi0, lls, log_Ps)

        # .cpu() so this also works when training on a GPU.
        unique_states = np.unique(np.concatenate((zs.detach().cpu().numpy().reshape((-1,)), unique_states)))
        optimizer.zero_grad()

        loss = model.EM_loss(zs, pca, log_pi0, lls, log_Ps)
        # Split the normalized loss into its prior and likelihood parts for logging.
        n_elements = pca.shape[0] * pca.shape[1]
        d_loss = -model.get_transition_dirichlet_prior().item() / n_elements / prior_weight
        dirichlet_loss += d_loss
        likelihood_loss += loss.item() - d_loss

        if np.isnan(loss.item()):
            break
        if i_epoch > 0:
            loss.backward()
            optimizer.step()

        train_loss += loss.item()
    if loss is not None and np.isnan(loss.item()):
        break
    print("Epoch: ", i_epoch, " ", train_loss/nb_tng_batches)
    all_training_loss.append(train_loss/nb_tng_batches)
    all_likelihood_loss.append(likelihood_loss/nb_tng_batches)
    all_dirichlet_loss.append(dirichlet_loss/nb_tng_batches)
    n_unique_states.append(unique_states.shape[0])

 30%|███       | 272/901 [01:03<02:27,  4.28it/s]

KeyboardInterrupt: 

In [None]:
###################
## Analyze model ##
###################

# Save model
# torch.save(model, 'model.pt')

# One pass over the training batches. It accumulates three things:
#   - the likelihood part (train_loss) and prior part (dirichlet_loss) of the EM loss,
#     split exactly as in the training loop;
#   - how often each state is the per-frame posterior argmax (these_counts);
#   - the lengths of the argmax-state segments (b_freqs).
prior_weight = getattr(model, 'n_batches_per_epoch', 901)
train_loss = 0
dirichlet_loss = 0
these_counts = np.zeros((model.n_discrete_states,))
b_freqs = np.empty((0,))
for i_train in tqdm(range(nb_tng_batches)):
    pca, i_mouse, i_batch = next(data_gen)
    pca = torch.tensor(pca).to(device).float()

    log_pi0, lls, log_Ps = model(pca)
    zs = model.sample_states(pca, log_pi0, lls, log_Ps)
    loss = model.EM_loss(zs, pca, log_pi0, lls, log_Ps)
    # Per-batch, normalized prior share. Plain floats are used so no autograd
    # graph is kept alive across iterations.
    d_loss = -model.get_transition_dirichlet_prior().item() / (pca.shape[0] * pca.shape[1]) / prior_weight
    dirichlet_loss += d_loss
    train_loss += loss.item() - d_loss

    expected_z = model.get_expected_states(pca)
    map_states = np.argmax(expected_z.detach().cpu().numpy(), 1)
    these_counts += np.bincount(map_states, minlength=model.n_discrete_states)
    # Lengths between consecutive state switches. The first and last segments
    # of each batch are excluded.
    change_points = np.flatnonzero(np.diff(map_states))
    b_freqs = np.concatenate((b_freqs, np.diff(change_points)))
    


In [None]:
import matplotlib
# Bump the default font size for every figure drawn after this cell.
matplotlib.rcParams['font.size'] = 14


In [None]:
# One panel per logged per-epoch quantity, sharing the epoch axis.
curves = [
    (all_training_loss, 'Training Loss'),
    (all_likelihood_loss, 'Likelihood Loss'),
    (all_dirichlet_loss, 'Dirichlet Loss'),
    (n_unique_states, '# unique states'),
]
fig, ax = plt.subplots(len(curves), 1, figsize=(10, 20), sharex=True)
for panel, (values, label) in zip(ax, curves):
    panel.plot(np.asarray(values), 'k')
    panel.set_ylabel(label)
ax[-1].set_xlabel('Epoch Number')

In [None]:
# Posterior state probabilities for the last analyzed batch: states on y, frames on x.
posterior_kt = expected_z.detach().numpy().T
fig, ax = plt.subplots(1, 1, figsize=(20, 15))
ax.imshow(posterior_kt, cmap='gray_r', vmin=0, vmax=1)
ax.set_xlabel('Time (frames)')
ax.set_ylabel('State Number')
ax.set_title('Expected z state')

In [None]:
# Per-frame most probable state for the last analyzed batch.
map_states_last = expected_z.detach().numpy().argmax(axis=1)
plt.plot(map_states_last)

In [None]:
# Draw one state sequence for the last analyzed batch. This line previously
# bound the method object itself to `zs` without calling it.
zs = model.sample_states(pca, log_pi0, lls, log_Ps)

In [None]:
# Distribution of argmax-state segment lengths (frames) gathered during the analysis pass.
bins = np.linspace(0, 400, 50)
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
ax.hist(b_freqs, bins, color='k')
ax.set_xlabel('Length of states')
ax.set_ylabel('Occurrences')
ax.set_xlim([-10, 410])

In [None]:
# Frames assigned (per-frame posterior argmax) to each state, one unit-wide bin
# centred on every state index. The old bins np.arange(0,100) formed only 99
# bins, so states 98 and 99 were merged into the last (closed) bin.
n_states = len(these_counts)
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
ax.hist(np.arange(n_states), np.arange(n_states + 1) - 0.5, weights=these_counts, color='k');
ax.set_xlabel('Expected State')
ax.set_ylabel('Occurrences')


In [None]:
# Percentage of analyzed frames in each state, and how many states exceed 1% / 2%.
occupancy_pct = these_counts / np.sum(these_counts) * 100
n_states_over_1 = np.sum(occupancy_pct > 1)
n_states_over_2 = np.sum(occupancy_pct > 2)
# Mean length (frames) of the state segments collected in b_freqs.
mean_freq = np.mean(b_freqs)

In [None]:
# Number of states that are the per-frame argmax in more than 1% of analyzed frames.
n_states_over_1

In [None]:
# Number of states that are the per-frame argmax in more than 2% of analyzed frames.
n_states_over_2

In [None]:
# Mean length (frames) of the argmax-state segments collected in b_freqs.
mean_freq