In [1]:
import pandas as pd
import numpy as np
import scipy.linalg as la
import matplotlib.pyplot as plt
from scipy.misc import logsumexp
from scipy.stats import multivariate_normal
from tqdm import tqdm_notebook as tqdm

In [2]:
df = pd.read_csv('alsfrs_cleaned.csv', index_col=0)

In [3]:
df.head()

Unnamed: 0,subject_id,Q1_Speech,Q2_Salivation,Q3_Swallowing,Q4_Handwriting,Q5_Cutting,Q6_Dressing_and_Hygiene,Q7_Turning_in_Bed,Q8_Walking,Q9_Climbing_Stairs,Q10_Respiratory,ALSFRS_Delta,ALSFRS_Total,delta_t
0,329,4.0,3.0,4.0,3.0,2.0,3.0,2.0,2.0,1.0,3.0,0.0,27.0,0.0
1,329,4.0,3.0,4.0,3.0,1.5,3.0,2.0,2.0,1.0,3.0,8.0,26.5,8.0
2,329,4.0,3.0,4.0,3.0,1.0,3.0,2.0,2.0,1.0,3.0,16.0,26.0,8.0
3,329,4.0,3.0,4.0,3.0,3.0,3.0,3.0,2.0,1.0,4.0,42.0,30.0,26.0
4,329,4.0,3.0,4.0,3.0,2.0,3.0,4.0,2.0,2.0,3.0,72.0,30.0,30.0


In [96]:
class CTHMM:
    def __init__(self, n_states, n_dim):
        self.n_states = n_states
        self.log_pi = np.log(np.ones((self.n_states)) / self.n_states)
        self.log_P = {}
        self.n_pid = 0
        
        # init R
        self.R = -np.eye(self.n_states)
        self.R[-1, -1] = 0
        for i in range(self.n_states - 1):
            self.R[i,i+1:] = 1 / (self.R[i+1:].shape[0])
        
        self.n_dim = n_dim
        
        # init emission matrix
        self.emission_matrix = np.zeros((2, self.n_states, self.n_dim))
        self.emission_matrix[0, :, :] = 2
        self.emission_matrix[1, :, :] = .5
        self.emission_matrix[0, 0, :] = 4
        self.emission_matrix[0, -1, :] = 0

    def EM_step(self, data):
        ### E Step ###
        
        self.n_pid = data['subject_id'].unique().shape[0]
        
        log_pi_update = np.zeros((self.n_states))
        weighted_means = np.log(np.zeros((self.n_states, self.n_dim)))

        unique_intervals = data['delta_t'].unique()
        C = np.zeros((unique_intervals.shape[0], self.n_states, self.n_states))
        interval_map = {}
        
        total_weight_assgn = np.log(np.zeros((self.n_states)))
        
        for pid, pdata in data.groupby('subject_id'):
#             print (pdata.columns)
            obs = pdata.drop(['subject_id', 'ALSFRS_Delta', 'delta_t', 'ALSFRS_Total'], axis=1).values
            intervals = pdata['delta_t'].values

            alpha = self.forward(obs, intervals)
            beta = self.backward(obs, intervals)

            LL = logsumexp((alpha[:, -1] + beta[:, -1]))
            
#             print ('Sanity:\n', logsumexp((alpha + beta), axis=0))
            
            for idx, t_delta in enumerate(intervals[1:]):
                if t_delta not in interval_map:
                    interval_map[t_delta] = len(interval_map.keys())
                log_P = self.log_transition_matrix(t_delta)
                log_emission = self.log_emission(obs[idx + 1, :])
                for src in range(self.n_states):
                    for dest in range(self.n_states):
                        C[interval_map[t_delta], src, dest] = logsumexp([C[interval_map[t_delta], src, dest], alpha[src, idx], log_P[src, dest],
                                                                     beta[dest, idx + 1], log_emission[dest]])

            log_pi_update = logsumexp([log_pi_update, alpha[:, 0] + beta[:, 0] - LL], axis=0)
#             print ('alpha:\n', alpha)
#             print ('beta:\n', beta)
#             print ('LL:\n', LL)
#             print ('mean update:\n', np.e**(alpha + beta - LL) @ obs)
            log_weights = np.zeros(alpha.shape)
            for t in range(log_weights.shape[1]):
                log_weights[:,t] = alpha[:,t] + beta[:,t] - logsumexp(alpha[:,t] + beta[:,t])  # M x T
#             print ('Resp:\n', np.sum(np.exp(log_weights), axis=0))
#             print ('log_weights before:\n', log_weights)
#             print ('log_weights e sum:\n', np.sum(np.exp(log_weights), axis=1))
            for i in range(self.n_states):
                for t in range(log_weights.shape[1]):
                    for d in range(self.n_dim):
#                         print ('weighted_prev:\n', weighted_means[i,d])
#                         print ('obs:\n', obs[t,d])
#                         print ('adding:\n', log_weights[i,t] + np.log(obs[t,d]))
#                         print ('logsumexp:\n', logsumexp([weighted_means[i,d], log_weights[i,t] + np.log(obs[t,d])]))
                        weighted_means[i, d] = logsumexp([weighted_means[i,d], log_weights[i,t] + np.log(obs[t,d])])
                    total_weight_assgn[i] = logsumexp([total_weight_assgn[i], log_weights[i,t]])
#                     weighted_means[i, j] = np.e**(alpha + beta - LL) @ obs


        ### M Step ###

        # Update emission params
        self.emission_matrix[0, 1:-1, :] = np.e**(weighted_means - total_weight_assgn[:, None])[1:-1, :]
#         self.emission_matrix[0, 1:-1, :] = np.e**(weighted_means)[1:-1, :]
        print ('total_weight_assgn:\n', np.e**total_weight_assgn)

        # Update pi
        self.log_pi = log_pi_update - logsumexp(log_pi_update)

        # Updated R
        A = np.zeros((self.n_states * 2, self.n_states * 2))
        A[:self.n_states, :self.n_states] = self.R
        A[self.n_states:, self.n_states:] = self.R

        D = np.zeros((self.n_states, self.n_states, self.n_states))
        tau = np.zeros((self.n_states))

        N = np.zeros((self.n_states, self.n_states, self.n_states, self.n_states))
        nu = np.zeros((self.n_states, self.n_states))
        
        C = np.e**(C) - 1
        
#         print('past groupby')
#         print (interval_map)
        for i in range(self.n_states):
            A[i, self.n_states + i] = 1
            for t_delta in unique_intervals:
                if t_delta == 0:
                    continue
#                 print ('A exp:\n', la.expm(A * t_delta)[:self.n_states, self.n_states:])
#                 print ('e trans:\n',np.e**(self.log_transition_matrix(t_delta)))
#                 exit(0)
                D[i] = la.expm(A * t_delta)[:self.n_states, self.n_states:] /                           \
                       np.e**(self.log_transition_matrix(t_delta))
                D = np.nan_to_num(D)
                tau[i] += np.sum(C[interval_map[t_delta], :, :] * D[i, :, :])
            A[i, self.n_states + i] = 0

        for i in range(self.n_states):
            for j in range(self.n_states):
                A[i, self.n_states + j] = 1
                for t_delta in unique_intervals:
                    if t_delta == 0:
                        continue
                    N[i, j] = self.R[i, j] * la.expm(A * t_delta)[:self.n_states, self.n_states:] /        \
                           np.e**(self.log_transition_matrix(t_delta))
                    N = np.nan_to_num(N)
                    nu[i, j] += np.sum(C[interval_map[t_delta], :, :] * N[i, j, :, :])
                A[i, self.n_states + j] = 0
#         print ('C:\n', C)
#         print ('D:\n', D)
#         print('nu:\n',nu)
#         print ('tau:\n',tau)
        
        for i in range(self.n_states):
            self.R[i, i+1:] = nu[i, i+1:] / tau[i]
            self.R[i, i] = -np.sum(self.R[i, i+1:])
            
        self.log_P = {}


    def log_transition_matrix(self, t_delta):
        """
        Input:
            t_delta scalar
        Output:
            P M x M
        """
        if t_delta in self.log_P:
            return self.log_P[t_delta]
        
        self.log_P[t_delta] = np.log(la.expm(self.R * t_delta))
        
#         print('nonlog transition:\n', la.expm(self.R * t_delta))
#         print('log transition:\n', self.log_P[t_delta])
#         print('back to nonlog:\n', np.sum(np.e**(self.log_P[t_delta]), axis=1))

        return self.log_P[t_delta]

    def log_emission(self, observation):
        """
            Input: D x 1
            Output: M x 1
        """
        b = np.ndarray(self.n_states, dtype=float)
        for i in range(self.n_states):
            means = self.emission_matrix[0, i]
            covariance = np.diag(self.emission_matrix[1, i])
            b[i] = multivariate_normal.logpdf(observation, means, covariance)
        return b

    def forward(self, obs, intervals):
        """
        Input:
            obs T x D
            intervals T
            n_states scalar
        Output:
            alpha M x T
        """
        T = obs.shape[0]

        alpha = np.zeros((self.n_states, T))

#         print ('pi:\n', self.log_pi)
#         print ('emission:\n', self.log_emission(obs[0,:]))
        alpha[:, 0] = self.log_pi + self.log_emission(obs[0, :])
        
#         print('alpha t=0:\n', alpha[:,0])

        tmp = np.zeros((self.n_states))

        for idx, t_delta in enumerate(intervals[1:]):
            log_B = self.log_emission(obs[idx + 1, :])
            log_P = self.log_transition_matrix(t_delta)
            
#             print ('transition:\n', log_P)

            for dest in range(self.n_states):
                for src in range(self.n_states):
                    tmp[src] = alpha[src, idx] + log_P[src, dest]

                alpha[dest, idx + 1] = log_B[dest] + logsumexp(tmp)

#         print ('alpha:\n', alpha)
        return alpha

    def backward(self, observations, time_intervals):
        T = observations.shape[0]
        beta = np.zeros((self.n_states, T), dtype=float)
        for t in range(T - 2, -1, -1):
            a = self.log_transition_matrix(time_intervals[t])
            b = self.log_emission(observations[t + 1])
            for i in range(self.n_states):
                beta[i, t] = logsumexp([beta[j, t + 1] + a[i, j] + b[j] for j in range(self.n_states)])

#         print ('beta:\n', beta)
        return beta

    def update_pi(self, alpha, beta):
        self.log_pi = alpha[0, :] + beta[0, :]

In [None]:
def train(model, training_data, n_epochs):
    log_likelihoods = np.ndarray(n_epochs, dtype=float)
    for epoch in tqdm(range(n_epochs)):
        model.EM_step(training_data[:1000])
        print ('R:\n', model.R)
        print ('Means:\n', model.emission_matrix[0])
        print ('Pi:\n', np.exp(model.log_pi))
#         log_likelihood = model.LogLikelihood(training_data)
#         log_likelihoods[epoch] = log_likelihood
    
    plt.scatter(range(1,n_epochs+1), log_likelihoods)
    plt.xlabel('epoch')
    plt.ylabel('log likelihood')
    plt.title('model training')
    plt.show()

train(CTHMM(5, 10), df, 20)



total_weight_assgn:
 [259.78704314 160.79003207 129.11206825 153.28209973 297.02875681]




R:
 [[-0.09672674  0.03990836  0.02039441  0.01829266  0.0181313 ]
 [-0.         -0.08947536  0.04344862  0.02358904  0.02243771]
 [-0.         -0.         -0.10317918  0.05968459  0.04349459]
 [-0.         -0.         -0.         -0.16375076  0.16375076]
 [-0.         -0.         -0.         -0.         -0.        ]]
Means:
 [[4.         4.         4.         4.         4.         4.
  4.         4.         4.         4.        ]
 [2.969676   3.15779596 3.32166195 2.71392874 2.22195085 1.98046749
  2.6082267  2.26014791 1.3097787  3.45964573]
 [2.92093975 3.13446172 3.25883531 2.65442142 2.1277009  1.86419646
  2.54126143 2.22234375 1.22023623 3.38784355]
 [2.85094678 3.09522362 3.22640321 2.53992362 1.98027684 1.8071585
  2.44727967 2.1837725  1.14588082 3.30209335]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]]
Pi:
 [1.00000000e+00 7.15094292e-14 2.22023000e-14 3.10396856e-15
 6.54945075e-46]
total_weight_assgn:
 [189.4837

total_weight_assgn:
 [145.67961833 396.52535556 215.97999193 212.94643141  28.86860277]
R:
 [[-0.03075901  0.01744008  0.00666305  0.00438677  0.00226912]
 [-0.         -0.02569758  0.01735248  0.00588656  0.00245853]
 [-0.         -0.         -0.02153849  0.01745767  0.00408082]
 [-0.         -0.         -0.         -0.02016902  0.02016902]
 [-0.         -0.         -0.         -0.         -0.        ]]
Means:
 [[4.         4.         4.         4.         4.         4.
  4.         4.         4.         4.        ]
 [3.69397941 3.70781553 3.80651295 3.15688302 2.79427972 2.19174899
  2.69970247 1.96532099 1.01566984 3.57490425]
 [1.58089395 2.26506663 2.4001675  2.28177633 1.5591137  1.81673845
  2.74340521 2.76254166 1.71072861 3.06421808]
 [2.85461625 3.21190828 3.07063381 0.77789098 0.54857426 0.34646999
  0.87056632 1.41833485 0.32193912 2.97407404]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]]
Pi:
 [0.23148123 0.58702