In [1]:
import numpy as np
import torch

In [2]:
#parameter setting
num_hidden_state = 3
num_obs = 5
length = 3
num_samples = 100

In [3]:
#define some useful functions
def generate_HMM_params(num_hidden_state, num_obs):
    # random generate the transition matrix and observation matrix, and compute the stationary distribution
    
    alpha_state = np.ones(num_hidden_state)
    alpha_obs = np.ones(num_obs) / num_obs
    trans_mat = np.random.dirichlet(alpha_state, num_hidden_state)
    obs_mat = np.random.dirichlet(alpha_obs, num_hidden_state)
    tmp = np.ones((num_hidden_state + 1, num_hidden_state))
    tmp[:-1] = np.identity(num_hidden_state) - trans_mat.T
    tmp_v = np.zeros(num_hidden_state + 1)
    tmp_v[-1] = 1
    stat_dist = np.linalg.lstsq(tmp, tmp_v, rcond=None)[0]
    return trans_mat, obs_mat, stat_dist

In [4]:
def generate_HMM_sequences(trans_mat, obs_mat, init_dist, length, num_samples = 1):
    # generate sample sequences from HMM
    
    states = np.zeros((num_samples, length))
    obs = np.zeros((num_samples, length))
    tmp_state = np.argmax(np.random.multinomial(1, init_dist, num_samples), axis = 1)
    #print(tmp_state)
    for i in range(length):
        #print("i: ", i)
        states[:, i] = tmp_state
        for j in range(num_samples):
            obs[j, i] = np.random.multinomial(1, obs_mat[tmp_state[j]]).argmax()
            tmp_state[j] = np.random.multinomial(1, trans_mat[tmp_state[j]]).argmax()
        #print("obs[:, i]: ", obs[:, i])
    return states, obs

In [5]:
def forward_compute(trans_mat, obs_mat, init_dist, obs_to_pos):
    # compute \sum_{h_1,...,h_{pos-1}} P(h_1,...,h_{pos},x_1,...,x_{pos-1})
    pos = obs_to_pos.shape[0] + 1
    num_hidden_state = trans_mat.shape[0]
    num_obs = obs_mat.shape[1]
    forward = np.zeros((pos, num_hidden_state))
    forward[0] = init_dist
    for i in range(1, pos):
        for j in range(num_hidden_state):
            for k in range(num_hidden_state):
                #print(i, j, k)
                #print(forward[i - 1, k], trans_mat[k, j], obs_mat[k, int(obs_to_pos[i - 1])])
                forward[i, j] += forward[i - 1, k] * trans_mat[k, j] * obs_mat[k, int(obs_to_pos[i - 1])]
    #print("forward: ", forward)
    return forward[pos - 1]

In [6]:
def backward_compute(trans_mat, obs_mat, obs_from_pos):
    num_hidden_state = trans_mat.shape[0]
    num_obs = obs_mat.shape[1]
    back_length = obs_from_pos.shape[0]
    if (back_length == 0):
        return np.ones(num_hidden_state)
    backward = np.zeros((back_length, num_hidden_state))
    for j in range(num_hidden_state):
         for k in range(num_hidden_state):
            backward[0, j] += trans_mat[j, k] * obs_mat[k, int(obs_from_pos[-1])]
    for i in range(1, back_length):
        for j in range(num_hidden_state):
            for k in range(num_hidden_state):
                backward[i, j] += trans_mat[j, k] * obs_mat[k, int(obs_from_pos[-(i + 1)])] * backward[i - 1, k]
    #print("backward: ", backward)
    return backward[-1]

In [7]:
def x_i_conditional_prob(trans_mat, obs_mat, init_dist, known_X, pos):
    num_hidden_state = trans_mat.shape[0]
    num_obs = obs_mat.shape[1]
    num_samples = known_X.shape[0]
    length = known_X.shape[1]
    x_pos_conditional_prob = np.zeros((num_samples, num_obs))
    h_pos_conditional_prob = np.zeros((num_samples, num_hidden_state))
    h_all_pos_conditional_prob = np.zeros((num_samples, num_hidden_state))
    for i in range(num_samples):
        #print("x_i_conditional_prob: i=", i)
        sample_obs_vec = known_X[i]
        forward_vec = forward_compute(trans_mat, obs_mat, init_dist, known_X[i, :pos[i]])
        backward_vec = backward_compute(trans_mat, obs_mat, known_X[i, pos[i] + 1:])
        #print("forward_vec: ", forward_vec)
        #print("backward_vec: ", backward_vec)
        h_prob_tmp = forward_vec * backward_vec
        tmp = h_prob_tmp.sum()
        h_prob_tmp /= tmp
        h_pos_conditional_prob[i] = h_prob_tmp
        x_pos_conditional_prob[i] = h_prob_tmp @ obs_mat
        h_all_pos_conditional_prob[i] = h_prob_tmp * obs_mat[:, int(known_X[i, pos[i]])] / x_pos_conditional_prob[i, int(known_X[i, pos[i]])]
    return h_pos_conditional_prob, x_pos_conditional_prob, h_all_pos_conditional_prob

In [8]:
trans_mat, obs_mat, stat_dist = generate_HMM_params(num_hidden_state, num_obs) # generate parameters for HMM

states, obs = generate_HMM_sequences(trans_mat, obs_mat, stat_dist, length, num_samples) # generate sample sequences

pos = np.random.randint(length, size = num_samples)

print("transition matrix")
print(trans_mat)
print("observation matrix")
print(obs_mat)
print("stationary distribution")
print(stat_dist)
print("states and observations, first half of each row is states")
print(np.concatenate((states, obs), axis = 1))
print("positions: ", pos)
h, x, hh = x_i_conditional_prob(trans_mat, obs_mat, stat_dist, obs, pos)
print("Pr[H_i|x_-i], j-th row is for j-th sample and i=positions[j]:")
print(h)
print("Pr[X_i|x_-i], j-th row is for j-th sample and i=positions[j]:")
print(x)
print("Pr[H_i|x_i,x_-i], j-th row is for j-th sample and i=positions[j]:")
print(hh)

transition matrix
[[0.49723535 0.47180762 0.03095703]
 [0.37852201 0.50632584 0.11515216]
 [0.5253994  0.11642175 0.35817885]]
observation matrix
[[4.51351361e-04 4.05187237e-02 1.33331734e-01 1.23057687e-01
  7.02640504e-01]
 [8.06260940e-06 3.22726641e-01 6.39043461e-01 3.82209436e-02
  8.91509141e-07]
 [2.75679296e-01 4.11532475e-06 9.71653088e-04 8.84559138e-02
  6.34889022e-01]]
stationary distribution
[0.44658521 0.45096487 0.10244993]
states and observations, first half of each row is states
[[2. 2. 0. 4. 4. 2.]
 [0. 0. 1. 4. 4. 2.]
 [0. 1. 2. 4. 2. 4.]
 [0. 1. 1. 2. 1. 2.]
 [1. 0. 1. 2. 4. 1.]
 [0. 0. 0. 4. 4. 4.]
 [1. 0. 0. 2. 4. 4.]
 [0. 0. 1. 4. 4. 2.]
 [1. 1. 2. 2. 2. 4.]
 [2. 2. 0. 4. 4. 4.]
 [0. 1. 0. 4. 2. 4.]
 [1. 0. 1. 2. 4. 2.]
 [0. 0. 1. 4. 4. 3.]
 [1. 1. 1. 2. 2. 2.]
 [0. 1. 0. 4. 2. 4.]
 [1. 1. 1. 1. 2. 2.]
 [0. 1. 1. 4. 2. 1.]
 [0. 0. 1. 4. 4. 2.]
 [1. 0. 1. 2. 4. 2.]
 [1. 0. 0. 1. 4. 4.]
 [1. 0. 0. 2. 4. 4.]
 [2. 0. 0. 4. 4. 4.]
 [1. 0. 1. 2. 4. 2.]
 [0. 0. 0. 4.

In [10]:
logh = np.log(h)

In [21]:
from torch import nn
from torch.utils import data

In [73]:
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.linearnetwork = nn.Sequential(
            nn.Linear(num_obs, num_hidden_state, bias=False)
        )
    
    def forward(self, logh, ind_x):
        logits = self.linearnetwork(ind_x)
        return nn.Softmax(dim = 1)(logh + logits)

In [74]:
x_one_hot = np.zeros((num_samples, num_obs))
for i in range(num_samples):
    x_one_hot[i, int(obs[i, pos[i]])] = 1
print(pos[:9])
print(obs[:9])
print(x_one_hot[:9])

[2 1 0 0 1 1 2 1 1]
[[4. 4. 2.]
 [4. 4. 2.]
 [4. 2. 4.]
 [2. 1. 2.]
 [2. 4. 1.]
 [4. 4. 4.]
 [2. 4. 4.]
 [4. 4. 2.]
 [2. 2. 4.]]
[[0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 1.]
 [0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 1.]
 [0. 0. 1. 0. 0.]]


In [75]:
features1, features2, labels = logh, x_one_hot, hh

In [88]:
# Model parameters.
lr = 1
epochs = 1000
batch_size = 10

In [89]:
dataset = data.TensorDataset(torch.FloatTensor(features1), torch.FloatTensor(features2), torch.FloatTensor(labels))
train_dl = data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

In [90]:
net = NeuralNetwork()

In [91]:
loss = nn.MSELoss()

In [92]:
trainer = torch.optim.SGD(net.parameters(), lr=lr)

In [93]:
for i in range(epochs):
    total_loss = 0
    for X1, X2, y in train_dl:
        l = loss(net(X1, X2) ,y)
        total_loss += l
        trainer.zero_grad()
        l.backward()
        trainer.step()
    if (i % 100 == 0):
        print("epoch: ", i)
        print("loss: ", total_loss)

epoch:  0
loss:  tensor(0.5779, grad_fn=<AddBackward0>)
epoch:  100
loss:  tensor(0.0111, grad_fn=<AddBackward0>)
epoch:  200
loss:  tensor(0.0043, grad_fn=<AddBackward0>)
epoch:  300
loss:  tensor(0.0026, grad_fn=<AddBackward0>)
epoch:  400
loss:  tensor(0.0019, grad_fn=<AddBackward0>)
epoch:  500
loss:  tensor(0.0015, grad_fn=<AddBackward0>)
epoch:  600
loss:  tensor(0.0012, grad_fn=<AddBackward0>)
epoch:  700
loss:  tensor(0.0010, grad_fn=<AddBackward0>)
epoch:  800
loss:  tensor(0.0009, grad_fn=<AddBackward0>)
epoch:  900
loss:  tensor(0.0008, grad_fn=<AddBackward0>)


In [95]:
net.eval()

NeuralNetwork(
  (linearnetwork): Sequential(
    (0): Linear(in_features=5, out_features=3, bias=False)
  )
)

In [96]:
for X1, X2, y in train_dl:
    print(net(X1, X2)[:9])
    print(y[:9])
    break

tensor([[0.1940, 0.8005, 0.0055],
        [0.9354, 0.0079, 0.0567],
        [0.9151, 0.0096, 0.0753],
        [0.1632, 0.8344, 0.0024],
        [0.9176, 0.0122, 0.0702],
        [0.7959, 0.0061, 0.1980],
        [0.8546, 0.0070, 0.1384],
        [0.9354, 0.0079, 0.0567],
        [0.1492, 0.8392, 0.0115]], grad_fn=<SliceBackward>)
tensor([[1.9789e-01, 8.0190e-01, 2.1677e-04],
        [9.4185e-01, 9.9453e-07, 5.8145e-02],
        [9.2260e-01, 1.1978e-06, 7.7401e-02],
        [1.6609e-01, 8.3381e-01, 9.4405e-05],
        [9.2768e-01, 1.5305e-06, 7.2323e-02],
        [7.9775e-01, 7.6093e-07, 2.0225e-01],
        [8.5832e-01, 8.7939e-07, 1.4168e-01],
        [9.4185e-01, 9.9453e-07, 5.8145e-02],
        [1.5322e-01, 8.4632e-01, 4.5619e-04]])


In [97]:
states, obs = generate_HMM_sequences(trans_mat, obs_mat, stat_dist, length, num_samples) # generate sample sequences

pos = np.random.randint(length, size = num_samples)

print("transition matrix")
print(trans_mat)
print("observation matrix")
print(obs_mat)
print("stationary distribution")
print(stat_dist)
print("states and observations, first half of each row is states")
print(np.concatenate((states, obs), axis = 1))
print("positions: ", pos)
h, x, hh = x_i_conditional_prob(trans_mat, obs_mat, stat_dist, obs, pos)
print("Pr[H_i|x_-i], j-th row is for j-th sample and i=positions[j]:")
print(h)
print("Pr[X_i|x_-i], j-th row is for j-th sample and i=positions[j]:")
print(x)
print("Pr[H_i|x_i,x_-i], j-th row is for j-th sample and i=positions[j]:")
print(hh)

transition matrix
[[0.49723535 0.47180762 0.03095703]
 [0.37852201 0.50632584 0.11515216]
 [0.5253994  0.11642175 0.35817885]]
observation matrix
[[4.51351361e-04 4.05187237e-02 1.33331734e-01 1.23057687e-01
  7.02640504e-01]
 [8.06260940e-06 3.22726641e-01 6.39043461e-01 3.82209436e-02
  8.91509141e-07]
 [2.75679296e-01 4.11532475e-06 9.71653088e-04 8.84559138e-02
  6.34889022e-01]]
stationary distribution
[0.44658521 0.45096487 0.10244993]
states and observations, first half of each row is states
[[1. 1. 0. 1. 1. 3.]
 [2. 0. 1. 0. 2. 2.]
 [0. 0. 0. 2. 4. 4.]
 [0. 0. 1. 4. 3. 2.]
 [1. 1. 1. 2. 3. 1.]
 [1. 2. 0. 1. 4. 4.]
 [0. 1. 1. 3. 2. 1.]
 [1. 1. 2. 1. 1. 4.]
 [1. 1. 1. 1. 1. 1.]
 [2. 2. 2. 4. 0. 0.]
 [1. 1. 0. 2. 1. 1.]
 [1. 0. 1. 2. 3. 2.]
 [1. 1. 0. 2. 2. 4.]
 [1. 1. 1. 1. 2. 2.]
 [1. 1. 1. 1. 1. 2.]
 [0. 0. 0. 4. 4. 4.]
 [1. 2. 2. 1. 4. 4.]
 [1. 0. 0. 2. 4. 4.]
 [0. 0. 0. 4. 4. 1.]
 [0. 1. 0. 4. 2. 4.]
 [1. 0. 0. 2. 3. 2.]
 [2. 1. 1. 4. 1. 2.]
 [0. 0. 0. 2. 4. 4.]
 [1. 0. 0. 2.

In [98]:
logh = np.log(h)
x_one_hot = np.zeros((num_samples, num_obs))
for i in range(num_samples):
    x_one_hot[i, int(obs[i, pos[i]])] = 1

In [106]:
test_dataset = data.TensorDataset(torch.FloatTensor(features1), torch.FloatTensor(features2), torch.FloatTensor(labels))
test_dl = data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

In [107]:
for X1, X2, y in test_dl:
    print(net(X1, X2)[:9])
    print(y[:9])
    break

tensor([[0.1940, 0.8005, 0.0055],
        [0.9354, 0.0079, 0.0567],
        [0.9151, 0.0096, 0.0753],
        [0.1632, 0.8344, 0.0024],
        [0.9176, 0.0122, 0.0702],
        [0.7959, 0.0061, 0.1980],
        [0.8546, 0.0070, 0.1384],
        [0.9354, 0.0079, 0.0567],
        [0.1492, 0.8392, 0.0115]], grad_fn=<SliceBackward>)
tensor([[1.9789e-01, 8.0190e-01, 2.1677e-04],
        [9.4185e-01, 9.9453e-07, 5.8145e-02],
        [9.2260e-01, 1.1978e-06, 7.7401e-02],
        [1.6609e-01, 8.3381e-01, 9.4405e-05],
        [9.2768e-01, 1.5305e-06, 7.2323e-02],
        [7.9775e-01, 7.6093e-07, 2.0225e-01],
        [8.5832e-01, 8.7939e-07, 1.4168e-01],
        [9.4185e-01, 9.9453e-07, 5.8145e-02],
        [1.5322e-01, 8.4632e-01, 4.5619e-04]])
