In [1]:
from CRF import CRF
from utils import crf_train_loop
import numpy as np
import torch

In [2]:
# Two dice: one fair (uniform over the six faces) and one loaded, which lands on
# face index 5 (presumably the six) 80% of the time and on each other face 4%.
fair_dice = np.full(6, 1 / 6)
loaded_dice = np.array([0.04] * 5 + [0.8])

# Emission distribution over face indices 0..5 for each hidden die.
probabilities = dict(fair=fair_dice, loaded=loaded_dice)

In [3]:
# Transition probabilities p(die at t+1 | die at t); each row is ordered like
# `states`. If the die is fair at time t there is a 0.6 chance it stays fair at
# time t+1 and a 0.4 chance it switches to loaded. 'start' is a pseudo-state
# used only to draw the first die: every row gives it probability 0, so it is
# never transitioned into.
transition_mat = {'fair': np.array([0.6, 0.4, 0.0]),
                  'loaded': np.array([0.3, 0.7, 0.0]),
                  'start': np.array([0.5, 0.5, 0.0])}
states = ['fair', 'loaded', 'start']
# Derived from `states` so the index map can never disagree with the row order.
state2ix = {state: ix for ix, state in enumerate(states)}

# Each row is passed as `p=` to np.random.choice over `states`, so it must be a
# valid distribution of matching length.
assert all(len(row) == len(states) and np.isclose(row.sum(), 1.0)
           for row in transition_mat.values())

# Log emission probabilities, shape (6 faces, 2 dice): column 0 is the fair
# die, column 1 the loaded die.
log_likelihood = np.hstack([np.log(fair_dice).reshape(-1, 1),
                            np.log(loaded_dice).reshape(-1, 1)])

In [4]:
def simulate_data(n_timesteps, rng=None):
    """Sample one sequence of rolls from the fair/loaded dice hidden Markov model.

    The hidden die starts in the 'start' pseudo-state. At each step a die is
    drawn from ``transition_mat[prev_state]`` and a face is then rolled from
    ``probabilities[die]``.

    Args:
        n_timesteps: Number of rolls to generate.
        rng: Optional random source exposing ``choice`` (e.g. a seeded
            ``np.random.Generator``). Defaults to the global ``np.random``
            state, matching the previous behaviour.

    Returns:
        (rolled, hidden): two int arrays of length ``n_timesteps`` holding the
        rolled face index (0-5) and the hidden die index (per ``state2ix``).
    """
    if rng is None:
        rng = np.random
    faces = [0, 1, 2, 3, 4, 5]
    # Integer dtype: these are categorical indices, not measurements.
    rolled = np.zeros(n_timesteps, dtype=int)
    hidden = np.zeros(n_timesteps, dtype=int)
    prev_state = 'start'
    for t in range(n_timesteps):
        # Draw the die before the roll; this RNG call order matches the original.
        state = rng.choice(states, p=transition_mat[prev_state])
        hidden[t] = state2ix[state]
        rolled[t] = rng.choice(faces, p=probabilities[state])
        prev_state = state
    return rolled, hidden

In [5]:
# Training set: n_sequences independent sequences of n_obs rolls each, with the
# hidden die of every roll as its target label.
n_obs = 15
n_sequences = 5000
np.random.seed(0)  # simulate_data samples from the global NumPy RNG by default

rolls = np.zeros((n_sequences, n_obs), dtype=int)
targets = np.zeros((n_sequences, n_obs), dtype=int)

for i in range(n_sequences):
    data, dices = simulate_data(n_obs)
    rolls[i] = data
    targets[i] = dices

# The 'start' pseudo-state (index 2) has zero transition probability, so only
# fair (0) / loaded (1) should ever appear as labels.
assert set(np.unique(targets)) <= {0, 1}


In [None]:
# One tag per real hidden die (fair, loaded). NOTE(review): assumes CRF's first
# argument is the number of tags — confirm against CRF.__init__.
n_hidden_states = 2
model = CRF(n_hidden_states, log_likelihood)

In [None]:
# NOTE(review): crf_train_loop's signature is not visible here. The training log
# shows only "Epoch 0", so 1 looks like the epoch count and 0.001 like the
# learning rate — confirm against utils.crf_train_loop.
n_epochs = 1
learning_rate = 0.001
model = crf_train_loop(model, rolls, targets, n_epochs, learning_rate)

Epoch 0: Batch 0/100 loss is 12.0569
Epoch 0: Batch 1/100 loss is 7.5017
Epoch 0: Batch 2/100 loss is 7.3981
Epoch 0: Batch 3/100 loss is 7.0718
Epoch 0: Batch 4/100 loss is 7.1291
Epoch 0: Batch 5/100 loss is 7.2085
Epoch 0: Batch 6/100 loss is 7.2193
Epoch 0: Batch 7/100 loss is 7.9664
Epoch 0: Batch 8/100 loss is 7.1083
Epoch 0: Batch 9/100 loss is 7.7929
Epoch 0: Batch 10/100 loss is 7.1832
Epoch 0: Batch 11/100 loss is 7.0405
Epoch 0: Batch 12/100 loss is 7.3751
Epoch 0: Batch 13/100 loss is 7.4534
Epoch 0: Batch 14/100 loss is 7.2513
Epoch 0: Batch 15/100 loss is 6.7386
Epoch 0: Batch 16/100 loss is 7.1120
Epoch 0: Batch 17/100 loss is 7.7527
Epoch 0: Batch 18/100 loss is 7.3870
Epoch 0: Batch 19/100 loss is 7.0290
Epoch 0: Batch 20/100 loss is 7.4686
Epoch 0: Batch 21/100 loss is 7.2330
Epoch 0: Batch 22/100 loss is 6.9453
Epoch 0: Batch 23/100 loss is 6.9831
Epoch 0: Batch 24/100 loss is 7.7516
Epoch 0: Batch 25/100 loss is 7.4639
Epoch 0: Batch 26/100 loss is 6.5551
Epoch 0: B

In [None]:
# torch.save writes PyTorch's own serialized format, not HDF5, despite the
# extension; the filename is kept so existing load code still finds it.
checkpoint_path = "./checkpoint.hdf5"
torch.save(model.state_dict(), checkpoint_path)

In [None]:
# Restore previously saved weights. torch.load unpickles the file, so only load
# checkpoints from a trusted source.
saved_state = torch.load("./checkpoint.hdf5")
model.load_state_dict(saved_state)

In [None]:
# One fresh held-out sequence, the same length as the training sequences, to
# compare the model's decoding with the true hidden dice.
data, dices = simulate_data(n_obs)
test_rolls = data.reshape(1, -1).astype(int)
test_targets = dices.reshape(1, -1).astype(int)

In [None]:
# Observed faces of the held-out sequence.
test_rolls[0, :]

In [None]:
# Decode the held-out sequence. Calling the module (not .forward) runs any
# nn.Module hooks; no_grad avoids building an autograd graph for pure inference.
# NOTE(review): the meaning of element [0] of the output is not visible here
# (presumably the score or decoded path) — confirm against CRF.forward.
with torch.no_grad():
    predicted = model(test_rolls[0])
predicted[0]

In [None]:
# True hidden dice (0 = fair, 1 = loaded) for the held-out sequence.
test_targets[0, :]

In [None]:
# Inspect the first registered parameter (presumably the learned transition
# scores — confirm against CRF.__init__). detach().cpu() is the supported way to
# get a NumPy view (`.data` is discouraged), and next() avoids materialising
# every parameter just to read the first.
next(model.parameters()).detach().cpu().numpy()

In [None]:
# A second held-out sequence of the same length as the training data.
data, dices = simulate_data(n_obs)
test_rolls = data.reshape(1, -1).astype(int)
test_targets = dices.reshape(1, -1).astype(int)

In [None]:
# Observed faces of the second held-out sequence.
test_rolls[0, :]

In [None]:
# Decode the second held-out sequence via the module call (runs nn.Module hooks)
# without tracking gradients. NOTE(review): the meaning of element [0] of the
# output is not visible here — confirm against CRF.forward.
with torch.no_grad():
    predicted = model(test_rolls[0])
predicted[0]

In [None]:
# True hidden dice (0 = fair, 1 = loaded) for the second held-out sequence.
test_targets[0, :]