<a href="https://colab.research.google.com/github/samaneh-m/TU-simulation-base-inference/blob/main/Viterbi_algorithm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install "bayesflow>=2.0"
!pip install tensorflow  # CPU version is fine; GPU optional



In [2]:
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import bayesflow

INFO:bayesflow:Using backend 'tensorflow'


In [3]:
import numpy as np
import tensorflow as tf
import bayesflow as bf

from bayesflow.types import Shape
from bayesflow.utils import tree_concatenate
from bayesflow.utils.decorators import allow_batch_size

In [4]:
class HiddenStateSimulator:
    def __init__(self, seq_len=50):
        self.seq_len = seq_len
        self.states = ['alpha', 'other']
        self.amino_acids = ['A', 'R', 'N', 'D', 'C', 'E', 'Q', 'G', 'H', 'I',
                            'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V']

        # Emission probabilities
        self.emissions = {
            'alpha': np.array([
                12, 6, 3, 5, 1, 9, 5, 4, 2, 7,
                12, 6, 3, 4, 2, 5, 4, 1, 3, 6
            ]) / 100,
            'other': np.array([
                6, 5, 5, 6, 2, 5, 3, 9, 3, 5,
                8, 6, 2, 4, 6, 7, 6, 1, 4, 7
            ]) / 100
        }

        # Transition matrix: rows are from-states, columns are to-states
        self.transitions = {
            'alpha': {'alpha': 0.90, 'other': 0.10},
            'other': {'alpha': 0.05, 'other': 0.95}
        }

    def sample(self, batch_shape=1):
        samples = []
        states = []

        for _ in range(batch_shape):
            seq = []
            st_seq = []

            current_state = 'other'
            for _ in range(self.seq_len):
                # Emit amino acid
                probs = self.emissions[current_state]
                aa = np.random.choice(self.amino_acids, p=probs)
                seq.append(aa)
                st_seq.append(current_state)

                # Transition to next state
                next_state = np.random.choice(self.states, p=[
                    self.transitions[current_state]['alpha'],
                    self.transitions[current_state]['other']
                ])
                current_state = next_state

            samples.append(seq)
            states.append(st_seq)

        return {
            'observed_sequence': samples,
            'hidden_states': states
        }

In [5]:
sim = HiddenStateSimulator(seq_len=50)
data = sim.sample(batch_shape=1000)

print("Amino acid sequence:")
print(data['observed_sequence'][0])
print("\nHidden states:")
print(data['hidden_states'][0])
# len(data['observed_sequence'][0])

Amino acid sequence:
['E', 'P', 'L', 'G', 'P', 'F', 'K', 'S', 'G', 'K', 'E', 'D', 'L', 'W', 'I', 'D', 'S', 'D', 'G', 'C', 'P', 'T', 'K', 'S', 'R', 'S', 'Q', 'V', 'F', 'N', 'H', 'M', 'D', 'T', 'K', 'P', 'I', 'E', 'I', 'S', 'Y', 'P', 'N', 'Y', 'D', 'P', 'K', 'L', 'I', 'S']

Hidden states:
['other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'alpha', 'alpha', 'alpha', 'alpha', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other']


In [18]:
# Map amino acids to integers (0 to 19)
amino_acids = ['A', 'R', 'N', 'D', 'C', 'E', 'Q', 'G', 'H', 'I',
               'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V']
aa_to_index = {aa: i for i, aa in enumerate(amino_acids)}

# Map states to integers
states = ["alpha", "other"]
state_to_int = {'alpha': 1, 'other': 0}
index_to_state = {0: "alpha", 1: "other"}

In [10]:
def viterbi(obs_seq, states, start_prob, trans_prob, emit_prob):
    n_states = len(states)
    T = len(obs_seq)

    log_prob = np.full((n_states, T), -np.inf)
    backpointer = np.zeros((n_states, T), dtype=int)

    # Initialization
    for s, state in enumerate(states):
        aa_idx = aa_to_index[obs_seq[0]]
        log_prob[s, 0] = np.log(start_prob[state]) + np.log(emit_prob[state][aa_idx])

    # Recursion
    for t in range(1, T):
        aa_idx = aa_to_index[obs_seq[t]]
        for s, state in enumerate(states):
            trans_probs = [
                log_prob[sp, t - 1] + np.log(trans_prob[states[sp]][state])
                for sp in range(n_states)
            ]
            best_prev_state = np.argmax(trans_probs)
            log_prob[s, t] = trans_probs[best_prev_state] + np.log(emit_prob[state][aa_idx])
            backpointer[s, t] = best_prev_state

    # Backtracking
    best_last_state = np.argmax(log_prob[:, -1])
    best_path = [best_last_state]

    for t in range(T - 1, 0, -1):
        best_last_state = backpointer[best_last_state, t]
        best_path.insert(0, best_last_state)

    return [index_to_state[i] for i in best_path]

In [24]:
transition_probs = {
    "alpha": {"alpha": 0.90, "other": 0.10},
    "other": {"alpha": 0.05, "other": 0.95}
}

emission_probs = {
    "alpha": np.array([12, 6, 3, 5, 1, 9, 5, 4, 2, 7, 12, 6, 3, 4, 2, 5, 4, 1, 3, 6]) / 100,  # your alpha emission vector
    "other": np.array([6, 5, 5, 6, 2, 5, 3, 9, 3, 5, 8, 6, 2, 4, 6, 7, 6, 1, 4, 7]) / 100,  # your other emission vector
}

In [25]:
sequence = data["observed_sequence"][0]  # e.g. ["A", "E", "F", ..., "V"]
start_prob = {"alpha": 0, "other": 1}  # or use the HMM's long-run proportions

viterbi_path = viterbi(
    obs_seq=sequence,
    states=states,
    start_prob=start_prob,
    trans_prob=transition_probs,
    emit_prob=emission_probs
)

print("Viterbi decoded states:")
print(viterbi_path)

Viterbi decoded states:
['other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other']


  log_prob[s, 0] = np.log(start_prob[state]) + np.log(emit_prob[state][aa_idx])
