In [9]:
import numpy as np
from scipy.special import logsumexp
import random

random.seed(42)

def init_params(n_states, n_obs):
    trans_probs = np.log(np.random.rand(n_states, n_states) + 1e-4)
    emiss_probs = np.log(np.random.rand(n_states, n_obs) + 1e-4)
    pi = np.log(np.random.rand(n_states) + 1e-4)

    trans_probs -= logsumexp(trans_probs, axis=1, keepdims=True)
    emiss_probs -= logsumexp(emiss_probs, axis=1, keepdims=True)
    pi -= logsumexp(pi)
    
    return trans_probs, emiss_probs, pi

def forward(observations, trans_probs, emiss_probs, pi):
    '''Computes the forward probabilities in log.'''
    T, N = len(observations), len(pi)
    alpha = np.full((T, N), -np.inf)
    alpha[0] = pi + emiss_probs[:, observations[0]]

    for t in range(1, T):
        for j in range(N):
            alpha[t, j] = logsumexp(alpha[t-1] + trans_probs[:, j]) + emiss_probs[j, observations[t]]

    return alpha

def backward(observations, trans_probs, emiss_probs):
    '''Computes the backward probabilities in log.'''
    T, N = len(observations), trans_probs.shape[0]
    beta = np.full((T, N), -np.inf)
    beta[-1] = 0

    for t in range(T-2, -1, -1):
        for i in range(N):
            beta[t, i] = logsumexp(trans_probs[i, :] + emiss_probs[:, observations[t+1]] + beta[t+1])

    return beta

def baum_welch(observations, n_states=2, n_obs=6, max_iter=100):
    '''Applies the Baum-Welch algorithm to find the most likely parameters.'''
    trans_probs, emiss_probs, pi = init_params(n_states, n_obs)
    for _ in range(max_iter):
        alpha = forward(observations, trans_probs, emiss_probs, pi)
        beta = backward(observations, trans_probs, emiss_probs)
        
        # Expectation step for xi and gamma
        xi = np.full((len(observations)-1, n_states, n_states), -np.inf)
        for t in range(len(observations)-1):
            for i in range(n_states):
                for j in range(n_states):
                    xi[t, i, j] = alpha[t, i] + trans_probs[i, j] + emiss_probs[j, observations[t+1]] + beta[t+1, j]
            xi[t] -= logsumexp(xi[t])

        gamma = alpha + beta
        gamma -= logsumexp(gamma, axis=1, keepdims=True)
        
        # Maximization step
        for i in range(n_states):
            trans_probs[i] = logsumexp(xi[:-1, :, i], axis=0)
            trans_probs[i] -= logsumexp(trans_probs[i])
        for j in range(n_states):
            for k in range(n_obs):
                mask = observations == k
                emiss_probs[j, k] = logsumexp(gamma[mask, j])
            emiss_probs[j] -= logsumexp(emiss_probs[j])
        pi = gamma[0] - logsumexp(gamma[0])

    return np.exp(pi), np.exp(trans_probs), np.exp(emiss_probs)

def forward_backward(observations, trans_probs, emiss_probs, pi):
    alpha = forward(observations, trans_probs, emiss_probs, pi)
    beta = backward(observations, trans_probs, emiss_probs)

    gamma = np.exp(alpha + beta - logsumexp(alpha[-1]))
    most_likely_states = np.argmax(gamma, axis=1)
    
    return most_likely_states, gamma

def run_multiple_baum_welch(observations, n_states, n_obs, n_runs, max_iter):
    '''Runs the Baum-Welch algorithm multiple times.'''
    results = [baum_welch(observations, n_states, n_obs, max_iter) for _ in range(n_runs)]
    results_pi, results_trans_probs, results_emiss_probs = zip(*results)

    # Calculating statistics for initial, transition, and emission probabilities
    mean_pi, std_pi = calc_stats(results_pi)
    mean_trans_probs, std_trans_probs = calc_stats(results_trans_probs)
    mean_emiss_probs, std_emiss_probs = calc_stats(results_emiss_probs)

    return (mean_pi, std_pi), (mean_trans_probs, std_trans_probs), (mean_emiss_probs, std_emiss_probs)

def calc_stats(results):
    '''Calculates statistics (mean, standard deviation) for arrays of matrices.'''
    mean = np.mean(results, axis=0)
    std = np.std(results, axis=0)
    return mean, std

def analyze_observations(file_path, n_states=2, n_obs=6, n_runs=10, max_iter=50):
    observations = np.loadtxt(file_path, dtype=int)

    (stats_pi, stats_trans_probs, stats_emiss_probs) = run_multiple_baum_welch(
        observations, n_states, n_obs, n_runs, max_iter
    )
    (mean_pi, std_pi), (mean_trans_probs, std_trans_probs), (mean_emiss_probs, std_emiss_probs) = stats_pi, stats_trans_probs, stats_emiss_probs

    pi_sample, A_sample, B_sample = baum_welch(observations[:100], n_states=2, n_obs=6, max_iter=10)
    most_likely_states_sample, state_probabilities_sample = forward_backward(observations[:100], A_sample, B_sample, pi_sample)
    # Decide to switch dice
    print("Most Likely States Sequence (Sample):", most_likely_states_sample)
    print("State Probabilities (Sample - First 10):")
    for prob in state_probabilities_sample[:10]:
        print(f"Fair Die: {prob[0]:.3f}, Loaded Die: {prob[1]:.3f}")

    # Calculating and printing mean state probabilities for the sample
    mean_state_probabilities_sample = np.mean(state_probabilities_sample, axis=0)
    print("Mean State Probabilities (Sample):")
    print(f"Fair Die: {mean_state_probabilities_sample[0]:.3f}")
    print(f"Loaded Die: {mean_state_probabilities_sample[1]:.3f}")

    print("Learned initial state probabilities (mean):", mean_pi)
    print("Learned initial state probabilities (std deviation):", std_pi)
    print("Learned transition probabilities (mean):\n", mean_trans_probs)
    print("Learned transition probabilities (std deviation):\n", std_trans_probs)
    print("Learned emission probabilities (mean):\n", mean_emiss_probs)
    print("Learned emission probabilities (std deviation):\n", std_emiss_probs)

analyze_observations("rolls.txt")

Most Likely States Sequence (Sample): [0 1 1 1 1 1 1 1 1 1 1 1 0 1 1 0 1 1 1 1 1 1 1 0 0 0 0 1 1 1 1 0 0 0 1 1 1
 1 1 1 1 0 0 0 1 1 0 0 0 1 1 0 0 0 0 1 1 1 0 0 1 1 1 0 0 1 1 0 0 0 0 0 0 1
 1 1 1 1 0 1 1 1 1 1 1 1 1 0 1 0 0 0 1 1 1 1 0 0 0 0]
State Probabilities (Sample - First 10):
Fair Die: 0.756, Loaded Die: 0.244
Fair Die: 0.457, Loaded Die: 0.543
Fair Die: 0.458, Loaded Die: 0.542
Fair Die: 0.396, Loaded Die: 0.604
Fair Die: 0.399, Loaded Die: 0.601
Fair Die: 0.477, Loaded Die: 0.523
Fair Die: 0.472, Loaded Die: 0.528
Fair Die: 0.477, Loaded Die: 0.523
Fair Die: 0.491, Loaded Die: 0.509
Fair Die: 0.474, Loaded Die: 0.526
Mean State Probabilities (Sample):
Fair Die: 0.495
Loaded Die: 0.505
Learned initial state probabilities (mean): [0.58519562 0.41480438]
Learned initial state probabilities (std deviation): [0.47970883 0.47970883]
Learned transition probabilities (mean):
 [[0.65628919 0.34371081]
 [0.42873138 0.57126862]]
Learned transition probabilities (std deviation):
 [[0.12684