In [2]:
import sys
import numpy as np
import robot
def log_prob(prob):
    """Compute the elementwise natural log of probabilities.

    Parameters
    ----------
    prob : array_like
        Probabilities. Scalars, lists and numpy arrays are accepted.

    Returns
    -------
    numpy.ndarray
        ``log(prob)`` with the same shape as ``prob``. Entries where
        ``prob <= 0`` are ``-inf``; no divide/invalid warnings are emitted
        for them. NaN inputs stay NaN.
    """
    prob = np.asarray(prob)
    # log(0) triggers a "divide" warning and log(<0) an "invalid" warning;
    # both kinds of entries are overwritten with -inf below.
    with np.errstate(divide="ignore", invalid="ignore"):
        log_values = np.log(prob)
    return np.where(prob <= 0, -np.inf, log_values)


def create_state_index_mapping(all_possible_hidden_states):
    """Build lookup tables between hidden states and integer positions.

    Returns
    -------
    tuple of (dict, dict)
        ``(state_to_index, index_to_state)``. ``state_to_index[state]`` is the
        position of ``state`` in ``all_possible_hidden_states``; if a state
        appears more than once, its last position wins. ``index_to_state``
        is the inverse of ``state_to_index``.
    """
    forward_lookup = {}
    for position, hidden_state in enumerate(all_possible_hidden_states):
        forward_lookup[hidden_state] = position
    inverse_lookup = dict(
        (position, hidden_state) for hidden_state, position in forward_lookup.items()
    )
    return forward_lookup, inverse_lookup


def distribution_to_vector(distribution, state_to_index):
    """Turn a ``{state: probability}`` mapping into a dense numpy vector.

    The vector has ``len(state_to_index)`` entries. Position
    ``state_to_index[state]`` holds that state's probability, and states not
    present in ``distribution`` stay 0. A state of ``distribution`` that is
    missing from ``state_to_index`` raises KeyError.
    """
    dense = np.zeros(len(state_to_index))
    for hidden_state, probability in distribution.items():
        position = state_to_index[hidden_state]
        dense[position] = probability
    return dense


def matrix_from_transition_model(
    transition_model, all_possible_hidden_states, state_to_index
):
    """Tabulate a transition model as a dense square matrix.

    Entry ``[i, j]`` is the probability of moving from the state with index
    ``i`` to the state with index ``j`` (indices per ``state_to_index``), as
    listed by ``transition_model(from_state).items()``. Transitions the model
    does not list stay 0.
    """
    n_states = len(all_possible_hidden_states)
    transition = np.zeros((n_states, n_states))
    for source in all_possible_hidden_states:
        row = state_to_index[source]
        next_distribution = transition_model(source)
        for target, probability in next_distribution.items():
            transition[row, state_to_index[target]] = probability
    return transition


def matrix_from_observation_model(
    observation_model,
    all_possible_hidden_states,
    all_possible_observed_states,
    state_to_index,
):
    """Tabulate an observation model as a dense |X| x |Y| matrix.

    Entry ``[i, k]`` is the probability that the hidden state with row index
    ``i`` (per ``state_to_index``) emits ``all_possible_observed_states[k]``,
    as listed by ``observation_model(state).items()``. Pairs the model does
    not list stay 0, and ``None`` observations are skipped because they have
    no column.

    Parameters
    ----------
    observation_model : callable
        Maps a hidden state to a mapping ``{observation: probability}``.
    all_possible_hidden_states : sequence
        Hidden states; its length sets the number of rows.
    all_possible_observed_states : sequence
        Observed states; column ``k`` corresponds to element ``k`` (first
        occurrence if an observation is repeated).
    state_to_index : dict
        Hidden state -> row index.

    Returns
    -------
    numpy.ndarray
        Shape ``(len(all_possible_hidden_states), len(all_possible_observed_states))``.

    Raises
    ------
    ValueError
        If the model emits an observation not in ``all_possible_observed_states``.
    """
    num_states = len(all_possible_hidden_states)
    num_observations = len(all_possible_observed_states)

    # Build the column lookup once instead of a linear list.index() scan per
    # (state, observation) pair; setdefault keeps the first occurrence, which
    # is what list.index() returned.
    observation_to_index = {}
    for column, observation in enumerate(all_possible_observed_states):
        observation_to_index.setdefault(observation, column)

    matrix = np.zeros((num_states, num_observations))
    for state in all_possible_hidden_states:
        state_index = state_to_index[state]
        for observation, prob in observation_model(state).items():
            if observation is None:
                continue
            if observation not in observation_to_index:
                raise ValueError(
                    "%r is not in all_possible_observed_states" % (observation,)
                )
            matrix[state_index, observation_to_index[observation]] = prob
    return matrix

In [74]:
def _logsumexp(log_values, axis):
    """Return ``log(sum(exp(log_values), axis=axis))`` computed without underflow.

    The per-slice maximum is subtracted before exponentiating, so the largest
    term becomes ``exp(0) = 1``. Slices that are entirely ``-inf`` (all
    probabilities zero) yield ``-inf`` rather than NaN.
    """
    peak = np.max(log_values, axis=axis, keepdims=True)
    # (-inf) - (-inf) is NaN, so all -inf slices are shifted by 0 instead.
    shift = np.where(np.isfinite(peak), peak, 0.0)
    with np.errstate(divide="ignore"):  # log(0) -> -inf is intended here
        log_sum = np.log(np.sum(np.exp(log_values - shift), axis=axis, keepdims=True))
    return np.squeeze(log_sum + shift, axis=axis)


def _log_evidence_vector(observation, all_possible_observed_states, log_observation_matrix):
    """Return log p(observation | state) for every hidden state (one per matrix row).

    A missing observation (``None``) carries no evidence: log(1) = 0 for
    every state.
    """
    if observation is None:
        return np.zeros(log_observation_matrix.shape[0])
    column = all_possible_observed_states.index(observation)
    return log_observation_matrix[:, column]


def forward_backward(
    all_possible_hidden_states,
    all_possible_observed_states,
    prior_distribution,
    transition_model,
    observation_model,
    observations,
):
    """Compute per-time-step posterior marginals of an HMM with forward-backward.

    Messages are kept in the log domain and combined with a stable
    log-sum-exp, so long observation sequences do not underflow to zero.

    Inputs
    ------
    all_possible_hidden_states: a list of possible hidden states

    all_possible_observed_states: a list of possible observed states

    prior_distribution: a distribution over states

    transition_model: a function that takes a hidden state and returns a
        Distribution for the next state

    observation_model: a function that takes a hidden state and returns a
        Distribution for the observation from that hidden state

    observations: a list of observations, one per hidden state
        (a missing observation is encoded as None)

    Output
    ------
    A list of marginal distributions, one per time step; the i-th entry is a
    robot.Distribution over all hidden states for time step i, renormalized
    with Distribution.renormalize(). If every state has zero probability at
    some step (observations impossible under the model), that entry keeps
    all-zero probabilities and is not renormalized. An empty ``observations``
    list yields an empty list.
    """
    num_time_steps = len(observations)
    if num_time_steps == 0:
        return []
    num_states = len(all_possible_hidden_states)

    # Precompute the model tables once, in log space.
    state_to_index, index_to_state = create_state_index_mapping(
        all_possible_hidden_states
    )
    # |X| x |X|; entry [i, j] = log p(next state j | current state i)
    log_transition_matrix = log_prob(
        matrix_from_transition_model(
            transition_model, all_possible_hidden_states, state_to_index
        )
    )
    # |X| x |Y|; entry [i, k] = log p(observation k | state i)
    log_observation_matrix = log_prob(
        matrix_from_observation_model(
            observation_model,
            all_possible_hidden_states,
            all_possible_observed_states,
            state_to_index,
        )
    )
    log_prior = log_prob(distribution_to_vector(prior_distribution, state_to_index))

    # log_evidence[t, i] = log p(observations[t] | state i); computed once and
    # shared by the forward, backward and marginal passes.
    log_evidence = np.array(
        [
            _log_evidence_vector(
                observation, all_possible_observed_states, log_observation_matrix
            )
            for observation in observations
        ]
    )

    # Forward pass: alpha_t(j) = sum_i alpha_{t-1}(i) p(y_{t-1} | i) p(j | i).
    log_forward = np.zeros((num_time_steps, num_states))
    log_forward[0] = log_prior
    for t in range(1, num_time_steps):
        # Rows index the previous state i, columns the current state j.
        log_terms = (log_forward[t - 1] + log_evidence[t - 1])[:, None] + log_transition_matrix
        log_forward[t] = _logsumexp(log_terms, axis=0)

    # Backward pass: beta_t(i) = sum_j p(j | i) p(y_{t+1} | j) beta_{t+1}(j).
    # beta_{T-1} = 1, i.e. log 1 = 0, which np.zeros already provides.
    log_backward = np.zeros((num_time_steps, num_states))
    for t in range(num_time_steps - 2, -1, -1):
        # Rows index the current state i, columns the next state j.
        log_terms = log_transition_matrix + (log_backward[t + 1] + log_evidence[t + 1])[None, :]
        log_backward[t] = _logsumexp(log_terms, axis=1)

    # Marginals: p(X_t | y) is proportional to alpha_t * p(y_t | X_t) * beta_t.
    marginals = []
    for t in range(num_time_steps):
        log_marginal = log_forward[t] + log_backward[t] + log_evidence[t]
        log_max = np.max(log_marginal)
        impossible = np.isinf(log_max)  # every state has zero probability
        if impossible:
            marginal = np.zeros(num_states)
        else:
            # Shift by the max before exponentiating so the largest entry is 1
            # and nothing underflows; renormalization cancels the shift.
            marginal = np.exp(log_marginal - log_max)
        distribution = robot.Distribution(
            {index_to_state[i]: prob for i, prob in enumerate(marginal)}
        )
        if not impossible:
            # Skipped for all-zero marginals: renormalize() presumably divides
            # by the total, which would be 0 there and produce NaNs.
            distribution.renormalize()
        marginals.append(distribution)

    return marginals

In [77]:
def print_top_marginal_states(marginals, timestep, top_k=10):
    """Print up to ``top_k`` highest-probability (state, prob) pairs of ``marginals[timestep]``.

    Zero-probability entries are dropped; a ``None`` marginal prints a placeholder.
    """
    print("Most likely parts of marginal at time %d:" % (timestep))
    marginal = marginals[timestep]
    if marginal is not None:
        top_states = sorted(marginal.items(), key=lambda item: item[-1], reverse=True)[:top_k]
        print([s for s in top_states if s[-1] > 0])
    else:
        print("*No marginal computed*")
    print("\n")


filename = 'test.txt'
hidden_states, observations = robot.load_data(filename)
need_to_generate_data = False  # left over from the assignment skeleton; unused here
# forward_backward returns one marginal per observation, so index by that count.
num_time_steps = len(observations)
all_possible_hidden_states = robot.get_all_hidden_states()
all_possible_observed_states = robot.get_all_observed_states()
prior_distribution = robot.initial_distribution()
print("Running forward-backward...")
marginals = forward_backward(
    all_possible_hidden_states,
    all_possible_observed_states,
    prior_distribution,
    robot.transition_model,
    robot.observation_model,
    observations,
)
print("\n")

print_top_marginal_states(marginals, num_time_steps - 1)
print_top_marginal_states(marginals, 1)

Running forward-backward...


Most likely parts of marginal at time 99:
[((11, 0, 'stay'), 0.8102633355840676), ((11, 0, 'right'), 0.1796083727211315), ((10, 1, 'down'), 0.010128291694800885)]


Most likely parts of marginal at time 1:
[((6, 5, 'right'), 0.5), ((6, 5, 'down'), 0.5)]




In [46]:
# Quick check that axis=0 sums down each column: [1+2+10, 7+1+5].
m1 = np.array([[1, 7], [2, 1], [10, 5]])
m_sum = m1.sum(axis=0)
m_sum

array([13, 13])

In [68]:
# Sanity check of the broadcasting used in forward_backward: column j of
# prev[:, None] + eye + obs[:, None] sums to sum(prev) + 1 + sum(obs).
m_prev = np.array([0.5, 0.2, 0.1])
m_tran = np.eye(3)
m_obser = np.array([0.1, 0.2, 0.3])
# Distinct names instead of `sum`, which would shadow the builtin for every later cell.
combined = m_prev[:, None] + m_tran + m_obser[:, None]
print(combined)
column_sums = np.sum(combined, axis=0)
print(column_sums)

[[1.6 0.6 0.6]
 [0.4 1.4 0.4]
 [0.4 0.4 1.4]]
[2.4 2.4 2.4]
