

# **📘 Recap**

### **What is an HMM (Hidden Markov Model)?**

A **Hidden Markov Model** is a statistical model used when:

* Something is happening **behind the scenes** (hidden states); we cannot see these states directly.
* We only observe **outputs** (visible signals).
* The hidden states follow the **Markov property** (next state depends only on the current state).
* Each state emits observations with some probability.

Think of it like:

> You see the symptoms (observations) but not the actual health condition (hidden state).

---

# **What the Notebook is Doing**

The notebook **teaches HMMs through a fun toy example**:

### ✅ **1. A motivating scenario**

A friend travels between cities and sends you a selfie each day.
You want to **guess which city they are in** based only on:

* the selfie (observation)
* knowledge of how they usually move between cities (transition probabilities)

This real-life analogy is modeled as an HMM.

---

### ✅ **2. Explaining HMM structure**

The notebook walks through:

* **Hidden states** → the city your friend is in
* **Observations** → the type of selfie you receive
* **Transition probabilities** → chance of moving from one city to another
* **Emission probabilities** → chance of getting a certain kind of selfie from a city

---

### ✅ **3. HMM assumptions**

The two classic assumptions:

1. **Markov property**: next city depends only on current city
2. **Emission independence**: selfie depends only on current city, not previous days

---

### ✅ **4. HMM components**

The notebook introduces all the key parameters:

* **π (initial distribution)**: where the journey starts
* **A (transition matrix)**: how likely it is to move between states
* **B (emission matrix)**: how likely a state is to emit each observation




# Fun with Hidden Markov Models
*by Loren Lugosch*

This notebook introduces the Hidden Markov Model (HMM), a simple model for sequential data.

We will see:
- what an HMM is and when you might want to use it;
- the so-called "three problems" of an HMM; and
- how to implement an HMM in PyTorch.

(The code in this notebook can also be found at https://github.com/lorenlugosch/pytorch_HMM.)

A hypothetical scenario
------

To motivate the use of HMMs, imagine that you have a friend who gets to do a lot of travelling. Every day, this jet-setting friend sends you a selfie from the city they're in, to make you envious.

<center>

![Diagram of a traveling friend sending selfies](https://github.com/lorenlugosch/pytorch_HMM/blob/master/img/selfies.png?raw=true)
</center>





How would you go about guessing which city the friend is in each day, just by looking at the selfies?

If the selfie contains a really obvious landmark, like the Eiffel Tower, it will be easy to figure out where the photo was taken. If not, it will be a lot harder to infer the city.

But we have a clue to help us: the city the friend is in each day is not totally random. For example, the friend will probably remain in the same city for a few days to sightsee before flying to a new city.

## The HMM setup

The hypothetical scenario of the friend travelling between cities and sending you selfies can be modeled using an HMM.


An HMM models a system that is in a particular state at any given time and produces an output that depends on that state.

At each timestep or clock tick, the system randomly decides on a new state and jumps into that state. The system then randomly generates an observation. The states are "hidden": we can't observe them. (In the cities/selfies analogy, the unknown cities would be the hidden states, and the selfies would be the observations.)

Let's denote the sequence of states as $\mathbf{z} = \{z_1, z_2, \dots, z_T \}$, where each state is one of a finite set of $N$ states, and the sequence of observations as $\mathbf{x} = \{x_1, x_2, \dots, x_T\}$. The observations could be discrete, like letters, or real-valued, like audio frames.

<center>

![Diagram of an HMM for three timesteps](https://github.com/lorenlugosch/pytorch_HMM/blob/master/img/hmm.png?raw=true)
</center>

An HMM makes two key assumptions:
- **Assumption 1:** The state at time $t$ depends *only* on the state at the previous time $t-1$.
- **Assumption 2:** The output at time $t$ depends *only* on the state at time $t$.

These two assumptions make it possible to efficiently compute certain quantities that we may be interested in.
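
Taken together, the two assumptions imply that the joint probability of a state sequence and an observation sequence factorizes into per-timestep terms. Written with the $\mathbf{z}$ and $\mathbf{x}$ notation introduced above (this is the standard HMM factorization, spelled out here for reference):

$$p(\mathbf{x}, \mathbf{z}) = p(z_1) \prod_{t=2}^{T} p(z_t | z_{t-1}) \prod_{t=1}^{T} p(x_t | z_t)$$

Every factor involves at most two neighbouring variables, which is what lets dynamic-programming algorithms process the sequence one timestep at a time.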

## Components of an HMM
An HMM has three sets of trainable parameters.
  


- The **transition model** is a square matrix $A$, where $A_{s, s'}$ represents $p(z_t = s|z_{t-1} = s')$, the probability of jumping from state $s'$ to state $s$.

- The **emission model** $b_s(x_t)$ tells us $p(x_t|z_t = s)$, the probability of generating $x_t$ when the system is in state $s$. For discrete observations, which we will use in this notebook, the emission model is just a lookup table, with one row for each state, and one column for each observation. For real-valued observations, it is common to use a Gaussian mixture model or neural network to implement the emission model.

- The **state priors** tell us $p(z_1 = s)$, the probability of starting in state $s$. We use $\pi$ to denote the vector of state priors, so $\pi_s$ is the state prior for state $s$.

Let's program an HMM class in PyTorch.

In [16]:
import torch
import numpy as np

class HMM(torch.nn.Module):
  """
  Hidden Markov Model with discrete observations.
  """
  def __init__(self, M, N):
    super(HMM, self).__init__()
    self.M = M # number of possible observations
    self.N = N # number of states

    # A
    self.transition_model = TransitionModel(self.N)

    # b(x_t)
    self.emission_model = EmissionModel(self.N,self.M)

    # pi
    self.unnormalized_state_priors = torch.nn.Parameter(torch.randn(self.N))

    # use the GPU
    self.is_cuda = torch.cuda.is_available()
    if self.is_cuda: self.cuda()

class TransitionModel(torch.nn.Module):
  def __init__(self, N):
    super(TransitionModel, self).__init__()
    self.N = N
    self.unnormalized_transition_matrix = torch.nn.Parameter(torch.randn(N,N))

class EmissionModel(torch.nn.Module):
  def __init__(self, N, M):
    super(EmissionModel, self).__init__()
    self.N = N
    self.M = M
    self.unnormalized_emission_matrix = torch.nn.Parameter(torch.randn(N,M))

To sample from the HMM, we start by picking a random initial state from the state prior distribution.

Then, we sample an output from the emission distribution, sample a transition from the transition distribution, and repeat.

(Notice that we pass the unnormalized model parameters through a softmax function to make them into probabilities.)




### **1. Why are the transition, emission, and prior parameters named "unnormalized"?**

**Answer:**
They are stored as **raw learnable tensors**, not probabilities yet.
Before using them in Forward/Backward/Viterbi, we typically apply a **softmax** to convert them into valid probability distributions:

* `pi = softmax(unnormalized_state_priors)`
* `A = softmax(unnormalized_transition_matrix, dim=0)` (each column is a distribution over next states, matching the notebook's code)
* `B = softmax(unnormalized_emission_matrix, dim=1)`

This allows PyTorch to optimize them freely without the constraint of summing to 1 during training.
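
As a minimal sketch of the normalization described above, the snippet below builds three random unnormalized tensors and checks which axis sums to 1 after the softmax. The sizes `N = 3` and `M = 5` are arbitrary values chosen for illustration; the softmax axes follow the notebook's `sample` code.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, M = 3, 5  # hypothetical sizes: 3 states, 5 observation symbols

unnormalized_state_priors = torch.randn(N)
unnormalized_transition_matrix = torch.randn(N, N)
unnormalized_emission_matrix = torch.randn(N, M)

pi = F.softmax(unnormalized_state_priors, dim=0)      # whole vector sums to 1
A = F.softmax(unnormalized_transition_matrix, dim=0)  # each column sums to 1
B = F.softmax(unnormalized_emission_matrix, dim=1)    # each row sums to 1

print(pi.sum())      # total prior mass
print(A.sum(dim=0))  # one sum per column (per "from" state)
print(B.sum(dim=1))  # one sum per row (per state)
```

Whatever values the raw tensors take during training, the softmax output stays a valid distribution, which is why no explicit constraint is needed.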

---

### **2. Why use `torch.nn.Parameter` for all HMM matrices?**

**Answer:**
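`nn.Parameter` registers the tensor with the module, which tells PyTorch:
**"This tensor is trainable: track its gradient and hand it to the optimizer."**
So the transition matrix, emission matrix, and state priors become **learnable parameters**: `loss.backward()` computes their gradients and the optimizer step updates them.

If they were plain tensors, they would not appear in `model.parameters()`, so the optimizer would *not* update them.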
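
The difference can be seen in a small sketch. The `Demo` module below is hypothetical and exists only for this illustration; it holds one `nn.Parameter` and one plain tensor attribute.

```python
import torch

class Demo(torch.nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.learnable = torch.nn.Parameter(torch.randn(3))  # registered
        self.plain = torch.randn(3)                          # ordinary attribute

demo = Demo()
names = [name for name, _ in demo.named_parameters()]
print(names)                         # only the registered parameter is listed
print(demo.learnable.requires_grad)  # gradient tracking is on by default
print(demo.plain.requires_grad)      # plain tensors do not track gradients
```

An optimizer built from `demo.parameters()` would therefore only ever see `learnable`.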

---

### **3. Why are the parameters initialized with `torch.randn()` instead of fixed values?**

**Answer:**
`torch.randn()` provides **random normal initialization**, which:

* avoids symmetry in optimization
* gives the model flexibility to learn any distribution
* works well with softmax-based probability normalization

Since HMM will be trained, random initialization is appropriate and common practice.

---

### **4. Why is CUDA checked inside the HMM constructor?**

**Answer:**
Because HMM contains learnable parameters (`A`, `B`, `pi`), moving the main model to GPU automatically moves all submodules:

```python
if self.is_cuda:
    self.cuda()
```

This ensures:

* all computation (forward, backward, Viterbi) is **GPU-accelerated**
* no mismatch between CPU and GPU tensors, which would lead to runtime errors

---

### **5. Why does the model separate TransitionModel and EmissionModel into different classes?**

**Answer:**
This design makes your HMM **modular and clean**:

* TransitionModel handles **state → state** probabilities
* EmissionModel handles **state → observation** probabilities
* Each can be extended or replaced independently, e.g.,

  * add dropout
  * switch to neural emissions
  * visualize parameters
  * plug in learned embeddings

It also makes the code easier to read and demo, since each part of the HMM is encapsulated.




In [17]:
def sample(self, T=10):
  state_priors = torch.nn.functional.softmax(self.unnormalized_state_priors, dim=0)
  transition_matrix = torch.nn.functional.softmax(self.transition_model.unnormalized_transition_matrix, dim=0)
  emission_matrix = torch.nn.functional.softmax(self.emission_model.unnormalized_emission_matrix, dim=1)

  # sample initial state
  z_t = torch.distributions.categorical.Categorical(state_priors).sample().item()
  z = []; x = []
  z.append(z_t)
  for t in range(0,T):
    # sample emission
    x_t = torch.distributions.categorical.Categorical(emission_matrix[z_t]).sample().item()
    x.append(x_t)

    # sample transition
    z_t = torch.distributions.categorical.Categorical(transition_matrix[:,z_t]).sample().item()
    if t < T-1: z.append(z_t)

  return x, z

# Add the sampling method to our HMM class
HMM.sample = sample



### **1. Why do we apply softmax to the priors, transition matrix, and emission matrix before sampling?**

**Answer:**
The parameters stored in the model are **unnormalized logits**.
Softmax converts them into valid probability distributions:

* `state_priors` → probabilities of initial states
* `transition_matrix` → probabilities of moving from one state to another
* `emission_matrix` → probabilities of emitting observations

Without softmax, the raw values can be negative and need not sum to 1, so they would not be valid probabilities to sample from.

---

### **2. Why does the transition sampling use `transition_matrix[:, z_t]` instead of `transition_matrix[z_t]`?**

**Answer:**
Because the transition matrix follows the notebook's convention $A_{s, s'} = p(z_t = s|z_{t-1} = s')$:

* the **column** index is the *current (from)* state
* the **row** index is the *next (to)* state

Each column is therefore a distribution over next states (which is why the softmax uses `dim=0`), so `transition_matrix[:, z_t]` means:

> "Given we are currently in state z_t, sample the next state from column z_t."

Many textbooks store transitions row-wise instead; this code uses columns.
Either convention is valid as long as it is applied consistently.
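
To make the column convention concrete, here is a small sketch with a hand-written, deliberately asymmetric matrix (the numbers are invented for illustration). State 1 is absorbing, so sampling from column 1 always returns state 1, whereas row 1 is not even a normalized distribution.

```python
import torch
from torch.distributions.categorical import Categorical

# A[s, s_prev] = p(z_t = s | z_{t-1} = s_prev); each COLUMN sums to 1
A = torch.tensor([[0.9, 0.0],
                  [0.1, 1.0]])

z_t = 1                                   # currently in state 1
next_state = Categorical(A[:, z_t]).sample().item()

print(A.sum(dim=0))   # column sums: valid next-state distributions
print(A.sum(dim=1))   # row sums: not distributions under this convention
print(next_state)     # state 1 can only transition to itself
```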

---

### **3. Why do we sample emissions before sampling the next hidden state?**

**Answer:**
This follows the generative process of an HMM:

1. Pick initial hidden state
2. Emit observation from that state
3. Transition to next hidden state
4. Repeat

The order **emission → transition** reflects the standard generative structure:

```
z_t → x_t
z_t → z_(t+1)
```

Given `z_t`, the emission `x_t` and the next state `z_(t+1)` are conditionally independent, so either could be sampled first; this order matches common textbooks.

---

### **4. Why does the method return both `x` (observations) and `z` (hidden states)?**

**Answer:**
Because for a demo:

* `x` shows what a **generated observation sequence** looks like
* `z` allows you to **demonstrate the underlying hidden path**

This is extremely helpful when teaching:

* how HMMs generate data
* how Viterbi decoding tries to recover z from x
* how transition and emission probabilities shape the sequence

It's a great choice for visualization.

---

### **5. Why is the initial state added to `z` before the sampling loop, but later states added inside the loop?**

**Answer:**
Because the first state is sampled **before** time-step iteration:

```python
z_t = Categorical(state_priors).sample()
z.append(z_t)
```

Then each loop iteration emits an observation and transitions to the next state.
The condition:

```python
if t < T-1:
    z.append(z_t)
```

avoids adding an extra next-state after the final iteration.

This keeps `len(z) = T`, matching the length of `x`.


Let's try hard-coding an HMM for generating fake words. (We'll also add some helper functions for encoding and decoding strings.)

We will assume that the system has one state for generating vowels and one state for generating consonants, and the transition matrix has 0s on the diagonal. In other words, the system cannot stay in the vowel state or the consonant state for more than one timestep; it has to switch.

Since we pass the transition matrix through a softmax, to get 0s we set the unnormalized parameter values to $-\infty$.

In [18]:
import string
alphabet = string.ascii_lowercase

def encode(s):
  """
  Convert a string into a list of integers
  """
  x = [alphabet.index(ss) for ss in s]
  return x

def decode(x):
  """
  Convert list of ints to string
  """
  s = "".join([alphabet[xx] for xx in x])
  return s

# Initialize the model
model = HMM(M=len(alphabet), N=2)

# Hard-wiring the parameters!
# Let state 0 = consonant, state 1 = vowel
for p in model.parameters():
    p.requires_grad = False # needed to do lines below
model.unnormalized_state_priors[0] = 0.    # Let's start with a consonant more frequently
model.unnormalized_state_priors[1] = -0.5
print("State priors:", torch.nn.functional.softmax(model.unnormalized_state_priors, dim=0))

# In state 0, only allow consonants; in state 1, only allow vowels
vowel_indices = torch.tensor([alphabet.index(letter) for letter in "aeiou"])
consonant_indices = torch.tensor([alphabet.index(letter) for letter in "bcdfghjklmnpqrstvwxyz"])
model.emission_model.unnormalized_emission_matrix[0, vowel_indices] = -np.inf
model.emission_model.unnormalized_emission_matrix[1, consonant_indices] = -np.inf
print("Emission matrix:", torch.nn.functional.softmax(model.emission_model.unnormalized_emission_matrix, dim=1))

# Only allow vowel -> consonant and consonant -> vowel
model.transition_model.unnormalized_transition_matrix[0,0] = -np.inf  # consonant -> consonant
model.transition_model.unnormalized_transition_matrix[0,1] = 0.       # vowel -> consonant
model.transition_model.unnormalized_transition_matrix[1,0] = 0.       # consonant -> vowel
model.transition_model.unnormalized_transition_matrix[1,1] = -np.inf  # vowel -> vowel
print("Transition matrix:", torch.nn.functional.softmax(model.transition_model.unnormalized_transition_matrix, dim=0))



State priors: tensor([0.6225, 0.3775], device='cuda:0')
Emission matrix: tensor([[0.0000, 0.2163, 0.0250, 0.0134, 0.0000, 0.0153, 0.0334, 0.0712, 0.0000,
         0.0456, 0.0131, 0.0594, 0.0249, 0.0363, 0.0000, 0.1069, 0.0353, 0.0116,
         0.0144, 0.0151, 0.0000, 0.0340, 0.1091, 0.0645, 0.0422, 0.0132],
        [0.2717, 0.0000, 0.0000, 0.0000, 0.2163, 0.0000, 0.0000, 0.0000, 0.2508,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0335, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.2277, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
       device='cuda:0')
Transition matrix: tensor([[0., 1.],
        [1., 0.]], device='cuda:0')




### **1. Why do we set `requires_grad = False` before manually assigning parameters?**

**Answer:**
`requires_grad=False` is needed because PyTorch **blocks in-place modifications** on leaf tensors that require gradients.
Since you're *manually hard-wiring* transition/emission probabilities (e.g., setting some to `-inf`), you must disable gradients:

```python
for p in model.parameters():
    p.requires_grad = False
```

Otherwise PyTorch will throw errors like:
**"a leaf Variable that requires grad is being used in an in-place operation."**
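
The behaviour described above can be reproduced on a standalone parameter. The sketch below also shows one alternative that is not used in the notebook: wrapping the assignment in `torch.no_grad()`, which leaves `requires_grad` untouched. The tensor `p` is a stand-in for any of the model's parameters.

```python
import torch

p = torch.nn.Parameter(torch.randn(2))  # stand-in for a model parameter

try:
    p[0] = 0.0          # in-place write on a leaf tensor that requires grad
    raised = False
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)

# Alternative (an assumption, not the notebook's approach):
# suspend autograd tracking just for the manual assignment.
with torch.no_grad():
    p[0] = 0.0
    p[1] = -0.5

print(p.requires_grad)  # the parameter is still trainable afterwards
print(p.data)
```

Either approach works for hard-wiring values; disabling `requires_grad` additionally freezes the parameters for any later training.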

---

### **2. Why is `-np.inf` used in the emission and transition matrices?**

**Answer:**
`-inf` turns into **probability 0** after softmax:

```
softmax(-inf) = 0
```

This is a clean way to *completely forbid* certain transitions or emissions.

Examples in your model:

* Consonant state → cannot emit vowels
* Vowel state → cannot emit consonants
* Consonant → consonant transitions forbidden
* Vowel → vowel transitions forbidden

It enforces the structure:

```
consonant ↔ vowel alternation only
```

Perfect for demonstrating a rule-based HMM.
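
A quick check of the `-inf` trick, using made-up logits. The second part illustrates a caveat worth keeping in mind: if every entry along the softmax axis is `-inf`, there is no probability mass left to distribute and the result is `nan`, so at least one option must stay finite.

```python
import torch

logits = torch.tensor([0.0, float("-inf"), 0.0])
probs = torch.softmax(logits, dim=0)
print(probs)            # the -inf entry gets exactly zero probability

all_blocked = torch.full((2,), float("-inf"))
blocked_probs = torch.softmax(all_blocked, dim=0)
print(blocked_probs)    # nan: nothing is allowed, so no distribution exists
```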

---

### **3. Why do we convert letters to numbers using `encode()` and `decode()`?**

**Answer:**
HMMs operate on **indices**, not characters.

`encode()` maps a string like `"hello"` to integer indices:
`[7, 4, 11, 11, 14]`

`decode()` converts predicted integer sequences back to readable strings.

This mapping allows:

* emission matrix to index observations
* sampling to produce integer sequences
* decoding to show human-interpretable letters

It's the standard preprocessing step for discrete HMMs.
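
The round trip can be checked directly. This sketch repeats the notebook's two helpers in compact form so that it runs on its own.

```python
import string

alphabet = string.ascii_lowercase

def encode(s):
    """Convert a string into a list of integer indices."""
    return [alphabet.index(ch) for ch in s]

def decode(x):
    """Convert a list of integer indices back into a string."""
    return "".join(alphabet[i] for i in x)

indices = encode("hello")
print(indices)          # [7, 4, 11, 11, 14]
print(decode(indices))  # hello
```

Note that `encode` only handles lowercase ASCII letters; any other character makes `alphabet.index` raise a `ValueError`.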

---

### **4. Why is the emission matrix built using vowel indices and consonant indices?**

**Answer:**
You define:

* **State 0 = consonant state**
* **State 1 = vowel state**

Then enforce:

```
state 0 emits ONLY consonants
state 1 emits ONLY vowels
```

by setting impossible emissions to `-inf`.

This creates a **linguistically meaningful HMM** where, combined with the alternating transitions, sampling produces consonant-vowel patterns (like "b a l o n e").

It's also excellent pedagogically because students immediately see how **emission constraints shape outputs**.

---

### **5. Why does the transition matrix forbid same-type transitions (C→C and V→V)?**

**Answer:**
You set:

```
C→C = -inf
V→V = -inf
C→V = 0
V→C = 0
```

After softmax, the model becomes:

```
C → V  with probability 1
V → C  with probability 1
```

This enforces *strict alternation*:

```
consonant → vowel → consonant → vowel → ...
```
```

This cleanly demonstrates:

* how transition probabilities shape hidden sequences
* how to hard-code structural linguistic patterns
* how HMM sampling behaves under deterministic transitions



Try sampling from our hard-coded model:


In [19]:
# Sample some outputs
for _ in range(4):
  sampled_x, sampled_z = model.sample(T=5)
  print("x:", decode(sampled_x))
  print("z:", sampled_z)

x: ipeje
z: [1, 0, 1, 0, 1]
x: wiqiw
z: [0, 1, 0, 1, 0]
x: babix
z: [0, 1, 0, 1, 0]
x: hiley
z: [0, 1, 0, 1, 0]


## The Three Problems

In a [classic tutorial](https://www.cs.cmu.edu/~cga/behavior/rabiner1.pdf) on HMMs, Lawrence Rabiner describes "three problems" that need to be solved before you can effectively use an HMM. They are:
- Problem 1: How do we efficiently compute $p(\mathbf{x})$?
- Problem 2: How do we find the most likely state sequence $\mathbf{z}$ that could have generated the data?
- Problem 3: How do we train the model?

In the rest of the notebook, we will see how to solve each problem and implement the solutions in PyTorch.

### Problem 1: How do we compute $p(\mathbf{x})$?


#### *Why?*
Why might we care about computing $p(\mathbf{x})$? Here are two reasons.
* Given two HMMs, $\theta_1$ and $\theta_2$, we can compute the likelihood of some data $\mathbf{x}$ under each model, $p_{\theta_1}(\mathbf{x})$ and $p_{\theta_2}(\mathbf{x})$, to decide which model is a better fit to the data.

  (For example, given an HMM for English speech and an HMM for French speech, we could compute the likelihood given each model, and pick the model with the higher likelihood to infer whether the person is speaking English or French.)
* Being able to compute $p(\mathbf{x})$ gives us a way to train the model, as we will see later.

#### *How?*
Given that we want $p(\mathbf{x})$, how do we compute it?

We've assumed that the data is generated by visiting some sequence of states $\mathbf{z}$ and picking an output $x_t$ for each $z_t$ from the emission distribution $p(x_t|z_t)$. So if we knew $\mathbf{z}$, then the probability of $\mathbf{x}$ could be computed as follows:

$$p(\mathbf{x}|\mathbf{z}) = \prod_{t} p(x_t|z_t)$$

However, we don't know $\mathbf{z}$; it's hidden. But we do know the probability of any given $\mathbf{z}$, independent of what we observe: $p(\mathbf{z}) = \prod_{t} p(z_t|z_{t-1})$, where $p(z_1|z_0)$ is shorthand for the state prior $\pi_{z_1}$. So we could get the probability of $\mathbf{x}$ by summing over the different possibilities for $\mathbf{z}$, like this:

$$p(\mathbf{x}) = \sum_{\mathbf{z}} p(\mathbf{x}|\mathbf{z}) p(\mathbf{z}) = \sum_{\mathbf{z}} \prod_{t} p(x_t|z_t) p(z_t|z_{t-1})$$

The problem is: if you try to take that sum directly, you will need to compute $N^T$ terms. This is impossible to do for anything but very short sequences. For example, let's say the sequence is of length $T=100$ and there are $N=2$ possible states. Then we would need to check $N^T = 2^{100} \approx 10^{30}$ different possible state sequences.

We need a way to compute $p(\mathbf{x})$ that doesn't require us to explicitly calculate all $N^T$ terms. For this, we use the forward algorithm.

________

<u><b>The Forward Algorithm</b></u>

> for $s=1 \rightarrow N$:\
> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;$\alpha_{s,1} := b_s(x_1) \cdot \pi_s$
>
> for $t = 2 \rightarrow T$:\
> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;for $s = 1 \rightarrow N$:\
> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
> $\alpha_{s,t} := b_s(x_t) \cdot \underset{s'}{\sum} A_{s, s'} \cdot \alpha_{s',t-1} $
>
> $p(\mathbf{x}) := \underset{s}{\sum} \alpha_{s,T}$\
> return $p(\mathbf{x})$
________


The forward algorithm is much faster than enumerating all $N^T$ possible state sequences: it requires only $O(N^2T)$ operations to run, since each step is mostly multiplying the vector of forward variables by the transition matrix. (And very often we can reduce that complexity even further, if the transition matrix is sparse.)
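
To see that the recursion really computes the same quantity as the exhaustive sum, the sketch below implements both in plain Python for a tiny HMM. All parameter values are invented for illustration; `A[s][s_prev]` follows the column convention $A_{s,s'}$ used above.

```python
import itertools
import math

pi = [0.6, 0.4]        # state priors
A = [[0.7, 0.2],
     [0.3, 0.8]]       # A[s][s_prev] = p(z_t = s | z_{t-1} = s_prev)
B = [[0.9, 0.1],
     [0.3, 0.7]]       # B[s][x] = p(x_t = x | z_t = s)
x = [0, 1, 1, 0]       # observed sequence
N, T = len(pi), len(x)

def brute_force(x):
    """Sum p(x, z) over all N**T state sequences."""
    total = 0.0
    for z in itertools.product(range(N), repeat=T):
        p = pi[z[0]] * B[z[0]][x[0]]
        for t in range(1, T):
            p *= A[z[t]][z[t - 1]] * B[z[t]][x[t]]
        total += p
    return total

def forward(x):
    """Forward algorithm: O(N^2 T) instead of O(N^T)."""
    alpha = [B[s][x[0]] * pi[s] for s in range(N)]
    for t in range(1, T):
        alpha = [B[s][x[t]] * sum(A[s][sp] * alpha[sp] for sp in range(N))
                 for s in range(N)]
    return sum(alpha)

p_brute = brute_force(x)
p_forward = forward(x)
print(p_brute, p_forward)  # the two values agree up to rounding
```

Here the brute-force loop visits $2^4 = 16$ paths; at $T = 100$ it would be infeasible, while the forward recursion would still need only a few hundred multiplications.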

There is one practical problem with the forward algorithm as presented above: it is prone to underflow due to multiplying a long chain of small numbers, since probabilities are always between 0 and 1. Instead, let's do everything in the log domain. In the log domain, a multiplication becomes a sum, and a sum becomes a [logsumexp](https://lorenlugosch.github.io/posts/2020/06/logsumexp/).  

________

<u><b>The Forward Algorithm (Log Domain)</b></u>

> for $s=1 \rightarrow N$:\
> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;$\text{log }\alpha_{s,1} := \text{log }b_s(x_1) + \text{log }\pi_s$
>
> for $t = 2 \rightarrow T$:\
> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;for $s = 1 \rightarrow N$:\
> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
> $\text{log }\alpha_{s,t} := \text{log }b_s(x_t) +  \underset{s'}{\text{logsumexp}} \left( \text{log }A_{s, s'} + \text{log }\alpha_{s',t-1} \right)$
>
> $\text{log }p(\mathbf{x}) := \underset{s}{\text{logsumexp}} \left( \text{log }\alpha_{s,T} \right)$\
> return $\text{log }p(\mathbf{x})$
________

Now that we have a numerically stable version of the forward algorithm, let's implement it in PyTorch.

In [20]:
def HMM_forward(self, x, T):
  """
  x : IntTensor of shape (batch size, T_max)
  T : IntTensor of shape (batch size)

  Compute log p(x) for each example in the batch.
  T = length of each example
  """
  if self.is_cuda:
    x = x.cuda()
    T = T.cuda()

  batch_size = x.shape[0]; T_max = x.shape[1]
  log_state_priors = torch.nn.functional.log_softmax(self.unnormalized_state_priors, dim=0)
  log_alpha = torch.zeros(batch_size, T_max, self.N)
  if self.is_cuda: log_alpha = log_alpha.cuda()

  log_alpha[:, 0, :] = self.emission_model(x[:,0]) + log_state_priors
  for t in range(1, T_max):
    log_alpha[:, t, :] = self.emission_model(x[:,t]) + self.transition_model(log_alpha[:, t-1, :])

  # Select the sum for the final timestep (each x may have different length).
  log_sums = log_alpha.logsumexp(dim=2)
  log_probs = torch.gather(log_sums, 1, T.view(-1,1) - 1)
  return log_probs

def emission_model_forward(self, x_t):
  log_emission_matrix = torch.nn.functional.log_softmax(self.unnormalized_emission_matrix, dim=1)
  out = log_emission_matrix[:, x_t].transpose(0,1)
  return out

def transition_model_forward(self, log_alpha):
  """
  log_alpha : Tensor of shape (batch size, N)
  Multiply previous timestep's alphas by transition matrix (in log domain)
  """
  log_transition_matrix = torch.nn.functional.log_softmax(self.unnormalized_transition_matrix, dim=0)

  # Matrix multiplication in the log domain
  out = log_domain_matmul(log_transition_matrix, log_alpha.transpose(0,1)).transpose(0,1)
  return out

def log_domain_matmul(log_A, log_B):
  """
  log_A : m x n
  log_B : n x p
  output : m x p matrix

  Normally, a matrix multiplication
  computes out_{i,j} = sum_k A_{i,k} x B_{k,j}

  A log domain matrix multiplication
  computes out_{i,j} = logsumexp_k log_A_{i,k} + log_B_{k,j}
  """
  m = log_A.shape[0]
  n = log_A.shape[1]
  p = log_B.shape[1]

  # log_A_expanded = torch.stack([log_A] * p, dim=2)
  # log_B_expanded = torch.stack([log_B] * m, dim=0)
  # fix for PyTorch > 1.5 by egaznep on Github:
  log_A_expanded = torch.reshape(log_A, (m,n,1))
  log_B_expanded = torch.reshape(log_B, (1,n,p))

  elementwise_sum = log_A_expanded + log_B_expanded
  out = torch.logsumexp(elementwise_sum, dim=1)

  return out

TransitionModel.forward = transition_model_forward
EmissionModel.forward = emission_model_forward
HMM.forward = HMM_forward
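
The final `torch.gather` step in `HMM_forward` is what handles sequences of different lengths within a padded batch. A minimal sketch with invented numbers: `log_sums` stands in for the per-timestep log-sums, and `T` holds the true lengths.

```python
import torch

# Hypothetical per-timestep log-sums for 2 sequences padded to T_max = 4
log_sums = torch.tensor([[-1.0, -2.0, -3.0, -4.0],
                         [-1.5, -2.5, -3.5, -4.5]])
T = torch.tensor([4, 2])  # true lengths: first uses all 4 steps, second only 2

# Pick column T-1 (0-indexed last valid timestep) for each row
log_probs = torch.gather(log_sums, 1, T.view(-1, 1) - 1)
print(log_probs)  # row 0 takes index 3, row 1 takes index 1
```

Values computed for padded positions beyond each sequence's length are simply never selected.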



### **1. Why do we use the Forward algorithm recurrence instead of brute-forcing all hidden-state sequences?**

**Significance in code:**
The line

```python
log_alpha[:, t, :] = emission(...) + transition_model(log_alpha[:, t-1])
```

compresses an exponential number of state paths into a **single dynamic-programming step**.

**Significance for HMM:**
An HMM with `N` states and sequence length `T` has **N^T** possible hidden paths.
Brute forcing all paths is impossible.

The Forward algorithm *efficiently sums over all possible hidden paths* while preserving exact probabilities:

$$p(x_{1:T}) = \sum_{\text{all } \mathbf{z}} p(\mathbf{x}, \mathbf{z})$$

This is the **core reason** why HMMs are computationally feasible.

---

### **2. Why do we work in log space instead of normal probability space?**

**Significance in code:**
All transitions and emissions go through:

```python
log_softmax(...)
```

**Significance for HMM:**
HMMs multiply many probabilities. Even valid sequences may have likelihoods like:

$$10^{-40},\quad 10^{-80},\quad 10^{-150}$$

In 32-bit floating point, values this small lose precision or collapse to exactly zero.

Log space transforms:

* multiplication → addition
* summation → logsumexp
* extremely tiny probabilities → stable numbers

The HMM Forward algorithm **relies on this stability** to produce meaningful likelihoods.

If you do not use log space, the HMM forward pass becomes useless for moderate-length sequences.
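
The underflow problem is easy to reproduce even in 64-bit floats. In this sketch, `p_step` is a hypothetical per-timestep probability and `logsumexp` is a small hand-written helper (not the PyTorch function) that uses the usual subtract-the-max trick.

```python
import math

p_step = 1e-2   # hypothetical probability contributed by each timestep
T = 200

# Probability domain: the running product underflows to exactly 0.0
prob = 1.0
for _ in range(T):
    prob *= p_step
print(prob)

# Log domain: the same quantity is an ordinary finite number
log_prob = T * math.log(p_step)
print(log_prob)

def logsumexp(values):
    """log(sum(exp(v))) computed stably by factoring out the maximum."""
    m = max(values)
    if m == float("-inf"):
        return m
    return m + math.log(sum(math.exp(v - m) for v in values))

# Adding two equally tiny probabilities still works in the log domain
log_total = logsumexp([log_prob, log_prob])
print(log_total)  # equals log_prob + log(2)
```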

---

### **3. Why do we compute emission probabilities using indexing (`emission_model(x[:, t])`) instead of full matrix multiplication?**

**Significance in code:**
You only fetch emission probabilities **for the observed symbol** at time `t`.

**Significance for HMM:**
The HMM emission model defines:

$$p(x_t | z_t)$$

When scoring a sequence, it only needs to be evaluated at the symbol that was actually observed, not at every possible symbol.
By indexing directly:

```python
emission_matrix[:, x_t]
```

you implement the exact HMM rule:
*"How likely is each state to have emitted this specific observed symbol?"*

This avoids unnecessary work and respects the structure of discrete emission HMMs.
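
The indexing also works for a whole batch at once, which is why the notebook's `emission_model_forward` ends with a transpose. A small sketch with arbitrary sizes (`N = 2` states, `M = 4` symbols, batch of 3) and a random emission table:

```python
import torch

torch.manual_seed(0)
N, M = 2, 4
log_emission_matrix = torch.log_softmax(torch.randn(N, M), dim=1)  # (N, M)

x_t = torch.tensor([3, 0, 3])  # observed symbol at time t for 3 sequences

# Column lookup gives (N, batch); transpose to (batch, N)
out = log_emission_matrix[:, x_t].transpose(0, 1)
print(out.shape)
print(out)  # rows 0 and 2 are identical: both sequences observed symbol 3
```

Each row of `out` holds $\log p(x_t | z_t = s)$ for every state $s$, ready to be added to the forward variables.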

---

### **4. Why do we use a log-domain matrix multiplication instead of regular matrix multiplication for transitions?**

**Significance in code:**
Your custom method:

```python
log_domain_matmul(log_A, log_B)
```

computes:

$$\log \sum_{s'} \exp\left(\log A_{s, s'} + \log \alpha_{s', t-1}\right)$$

**Significance for HMM:**
The Forward algorithm needs to **sum over all previous states**:

$$\alpha_{s,t} = b_s(x_t) \sum_{s'} A_{s, s'} \, \alpha_{s', t-1}$$

A regular matrix multiplication applied to log values would multiply and add the logs themselves, which is a different quantity.
The log-domain matmul adds logs (multiplying probabilities) and combines them with logsumexp (summing probabilities), implementing the *exact HMM update rule* in a numerically safe way.

Without this, you are not executing the true HMM Forward algorithm.
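
One way to convince yourself of this is to compare against an ordinary matrix product on values large enough that nothing underflows. The sketch below uses a compact re-implementation of `log_domain_matmul` (with `unsqueeze` in place of `reshape`, which should be equivalent here) and random positive matrices.

```python
import torch

def log_domain_matmul(log_A, log_B):
    # (m, n, 1) + (1, n, p) -> (m, n, p); logsumexp over the shared axis n
    return torch.logsumexp(log_A.unsqueeze(2) + log_B.unsqueeze(0), dim=1)

torch.manual_seed(0)
A = torch.rand(3, 4) + 0.1   # strictly positive, so the logs are finite
B = torch.rand(4, 2) + 0.1

reference = torch.log(A @ B)                             # ordinary domain
result = log_domain_matmul(torch.log(A), torch.log(B))   # log domain

print(torch.allclose(result, reference, atol=1e-5))
```

The two routes agree here; the log-domain route keeps working when the entries of `A` and `B` become too small to represent directly.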

---

### **5. Why do we use `logsumexp` over hidden states at the final timestep instead of taking a max or using the last α value directly?**

**Significance in code:**

```python
log_sums = log_alpha.logsumexp(dim=2)
```

**Significance for HMM:**
For an HMM, the probability of the entire observed sequence is:

$$p(x_{1:T}) = \sum_{s} \alpha_{s,T}$$

This is a **sum**, not a maximum.
Replacing the sums with maxes (at every timestep, not just the last) gives the **Viterbi algorithm**, which finds the most likely hidden path rather than the total likelihood of the observations.

* Forward algorithm = **sum over all possible explanations**
* Viterbi = **best single explanation**

So `logsumexp` maintains the probabilistic meaning of an HMM.
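
A two-state toy example shows how the two reductions differ at the final step. The forward variables below are invented numbers, and the `max` here only illustrates the final-step reduction; full Viterbi decoding would also use a max inside every recursion step.

```python
import torch

# Hypothetical final forward variables alpha_{s,T} for two states
log_alpha_T = torch.log(torch.tensor([0.03, 0.01]))

log_p_x = torch.logsumexp(log_alpha_T, dim=0)  # log(0.03 + 0.01)
log_best = log_alpha_T.max()                   # log(0.03)

print(log_p_x.exp())   # total probability over both final states
print(log_best.exp())  # mass of the single most likely final state
```

The max discards the probability mass contributed by the other state, so it always underestimates $p(\mathbf{x})$ unless only one state has nonzero mass.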




Try running the forward algorithm on our vowels/consonants model from before:

In [21]:
x = torch.stack( [torch.tensor(encode("cat"))] )
T = torch.tensor([3])
print(model.forward(x, T))

x = torch.stack( [torch.tensor(encode("aba")), torch.tensor(encode("abb"))] )
T = torch.tensor([3,3])
print(model.forward(x, T))

tensor([[-9.6603]], device='cuda:0')
tensor([[-5.1112],
        [   -inf]], device='cuda:0')




### **1. Why do we encode characters into integers instead of feeding letters directly?**

**Significance:**
HMMs operate on **discrete symbols**, not characters or strings.
By converting `"cat"` → `[2,0,19]`, we map letters into indices so they act as **observable emission symbols**.

**Why not feed characters?**
Characters have no numerical meaning; the HMM needs integer categories to index:

* emission probabilities `P(x | state)`
* transition matrices
* priors

This step mirrors how NLP models typically use **tokenization**.

---

### **2. Why did we *hard-wire* the model parameters instead of training them?**

**Significance:**
This demo is intended to **illustrate the structure and logic of an HMM**, not how it learns.

Hard-wiring lets you:

* **visually understand** how priors, transitions, and emissions influence probability
* **enforce linguistic constraints** (like vowel/consonant alternation)
* avoid randomness from training that can hide the intended behavior

**Why not train?**
Training would produce arbitrary parameter values unless the corpus contained perfect vowel-consonant alternation, making the conceptual lesson weaker.

---

### **3. Why force the emission matrix to allow only vowel-state emissions and consonant-state emissions?**

**Significance:**
This shows **how HMM states represent *latent categories***, here:

* **State 0 → consonant-emitter**
* **State 1 → vowel-emitter**

By putting `-inf` for invalid symbols, we force:

```
P(vowel | consonant-state) = 0
P(consonant | vowel-state) = 0
```

**Why not allow soft probabilities?**
Because the purpose is to **illustrate deterministic constraints inside an HMM**.
Soft probabilities would dilute the structure and make the alternation pattern less clear.

---

### **4. Why restrict the transition matrix to ONLY vowel→consonant and consonant→vowel?**

**Significance:**
This enforces an **alternating HMM**:

```
C → V → C → V → ...
```

This helps students clearly see how transitions control **sequence structure**, independent of emissions.

**Why not allow C→C or V→V transitions?**
Because then the model could generate arbitrary sequences and you lose the clean, interpretable pattern.

This is a teaching example of how transitions encode **grammar-like constraints**.

---

### **5. Why call `model.forward(x, T)` with tensor shapes like this?**

Example:

```python
x = torch.stack([torch.tensor(encode("cat"))])
T = torch.tensor([3])
```

**Significance:**
The forward pass computes:

* the **log-likelihood** of each sequence
* respecting:

  * priors
  * allowed transitions
  * allowed emissions
  * sequence length `T`

`x` is shaped as a **batch of sequences**, so you see how HMMs naturally support **multiple sequences**, e.g.:

```python
x = torch.stack([torch.tensor(encode("aba")), torch.tensor(encode("abb"))])
T = torch.tensor([3, 3])
```

**Why not pass raw strings?**
The model needs:

* numeric emissions
* fixed-length tensors
* batch structure

Raw Python strings provide none of these.

---

### **6. Why do we see different probabilities for sequences like "aba" vs "abb"?**

**Significance:**
Because the HMM's forced structure determines how likely a sequence is.
Example:

```
a b a → vowel, consonant, vowel → fits the alternating pattern
a b b → vowel, consonant, consonant → INVALID: consonant → consonant is forbidden
```

So the HMM assigns:

* **nonzero probability** to sequences that match the alternating pattern
* **zero probability** (log probability `-inf`) to sequences that break it

This shows the power of HMMs in modeling **latent structure constraints**.
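
To make this concrete, here is a brute-force sketch that sums the joint probability over every state path of a toy alternating HMM. The two-letter alphabet, the uniform prior, and the `trans[prev][next]` indexing are assumptions chosen for illustration; they are not the notebook's parameters.

```python
from itertools import product

# Hypothetical toy alternating HMM over the alphabet {a, b}.
# State 0 = consonant, state 1 = vowel.
prior = [0.5, 0.5]                 # P(z_1 = s)
trans = [[0.0, 1.0],               # trans[prev][next]: consonant -> vowel only
         [1.0, 0.0]]               # vowel -> consonant only
emit = [{"a": 0.0, "b": 1.0},      # consonant state emits only "b"
        {"a": 1.0, "b": 0.0}]      # vowel state emits only "a"

def sequence_probability(word):
    """Sum p(x, z) over every possible state path z (brute force)."""
    total = 0.0
    for path in product(range(2), repeat=len(word)):
        p = prior[path[0]] * emit[path[0]][word[0]]
        for t in range(1, len(word)):
            p *= trans[path[t - 1]][path[t]] * emit[path[t]][word[t]]
        total += p
    return total

print(sequence_probability("aba"))  # 0.5
print(sequence_probability("abb"))  # 0.0
```

Every path for "abb" hits either a forbidden transition or a forbidden emission, so each term in the sum is zero.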

---

### **7. Why set priors to prefer starting in a consonant?**

**Significance:**
Most English-like words begin with consonants.
So you set:

```
P(z_1 = state 0 (consonant)) > P(z_1 = state 1 (vowel))
```

**Why not keep uniform priors?**
Uniform priors hide the effect priors have on:

* forward probabilities
* decoding
* sequence likelihood

Explicit priors demonstrate how initial state assumptions affect the entire probability calculation.




When using the vowel <-> consonant HMM from above, notice that the forward algorithm returns $-\infty$ for $\mathbf{x} = \text{"abb"}$. That's because our transition matrix says the probability of vowel -> vowel and consonant -> consonant is 0, so the probability of $\text{"abb"}$ happening is 0, and thus the log probability is $-\infty$.

#### *Side note: deriving the forward algorithm*

If you're interested in understanding how the forward algorithm actually computes $p(\mathbf{x})$, read this section; if not, skip to the next part on "Problem 2" (finding the most likely state sequence).



To derive the forward algorithm, start by deriving the forward variable:

$$
\begin{align}
    \alpha_{s,t} &= p(x_1, x_2, \dots, x_t, z_t=s) \\
     &= p(x_t | x_1, x_2, \dots, x_{t-1}, z_t = s) \cdot p(x_1, x_2, \dots, x_{t-1}, z_t = s)  \\
    &= p(x_t | z_t = s) \cdot p(x_1, x_2, \dots, x_{t-1}, z_t = s) \\
    &= p(x_t | z_t = s) \cdot \left( \sum_{s'} p(x_1, x_2, \dots, x_{t-1}, z_{t-1}=s', z_t = s) \right)\\
    &= p(x_t | z_t = s) \cdot \left( \sum_{s'} p(z_t = s | x_1, x_2, \dots, x_{t-1}, z_{t-1}=s') \cdot p(x_1, x_2, \dots, x_{t-1}, z_{t-1}=s') \right)\\
    &= \underbrace{p(x_t | z_t = s)}_{\text{emission model}} \cdot \left( \sum_{s'} \underbrace{p(z_t = s | z_{t-1}=s')}_{\text{transition model}} \cdot \underbrace{p(x_1, x_2, \dots, x_{t-1}, z_{t-1}=s')}_{\text{forward variable for previous timestep}} \right)\\
    &= b_s(x_t) \cdot \left( \sum_{s'} A_{s, s'} \cdot \alpha_{s',t-1} \right)
\end{align}
$$

I'll explain how to get to each line of this equation from the previous line.

Line 1 is the definition of the forward variable $\alpha_{s,t}$.

Line 2 is the chain rule ($p(A,B) = p(A|B) \cdot p(B)$, where $A$ is $x_t$ and $B$ is all the other variables).

In Line 3, we apply Assumption 2: the probability of observation $x_t$ depends only on the current state $z_t$.

In Line 4, we marginalize over all the possible states in the previous timestep $t-1$.

In Line 5, we apply the chain rule again.

In Line 6, we apply Assumption 1: the current state depends only on the previous state.

In Line 7, we substitute in the emission probability, the transition probability, and the forward variable for the previous timestep, to get the complete recursion.

The formula above can be used for $t = 2 \rightarrow T$. At $t=1$, there is no previous state, so instead of the transition matrix $A$, we use the state priors $\pi$, which tell us the probability of starting in each state. Thus for $t=1$, the forward variables are computed as follows:

$$\begin{align}
\alpha_{s,1} &= p(x_1, z_1=s) \\
  &= p(x_1 | z_1 = s) \cdot p(z_1 = s)  \\
&= b_s(x_1) \cdot \pi_s
\end{align}$$

Finally, to compute $p(\mathbf{x}) = p(x_1, x_2, \dots, x_T)$, we marginalize over $\alpha_{s,T}$, the forward variables computed in the last timestep:

$$\begin{align*}
p(\mathbf{x}) &= \sum_{s} p(x_1, x_2, \dots, x_T, z_T = s) \\
&= \sum_{s} \alpha_{s,T}
\end{align*}$$

You can get from this formulation to the log domain formulation by taking the log of the forward variable, and using these identities:
- $\text{log }(a \cdot b) = \text{log }a + \text{log }b$
- $\text{log }(a + b) = \text{log }(e^{\text{log }a} + e^{\text{log }b}) = \text{logsumexp}(\text{log }a, \text{log }b)$
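
The recursion and its log-domain form can be sketched in plain Python and checked against brute-force enumeration. The parameter values below are made up for illustration, and `forward_log` and `brute_force` are hypothetical helper names, not part of the notebook's `HMM` class.

```python
import math
from itertools import product

def logsumexp(values):
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))

def forward_log(x, log_pi, log_A, log_b):
    """
    x      : list of observation indices, length T
    log_pi : log_pi[s]    = log p(z_1 = s)
    log_A  : log_A[s][sp] = log p(z_t = s | z_{t-1} = sp)
    log_b  : log_b[s][o]  = log p(x_t = o | z_t = s)
    Returns log p(x).
    """
    N = len(log_pi)
    # t = 1: emission plus prior
    log_alpha = [log_b[s][x[0]] + log_pi[s] for s in range(N)]
    # t = 2..T: emission plus logsumexp over previous states
    for t in range(1, len(x)):
        log_alpha = [
            log_b[s][x[t]]
            + logsumexp([log_A[s][sp] + log_alpha[sp] for sp in range(N)])
            for s in range(N)
        ]
    return logsumexp(log_alpha)

def brute_force(x, pi, A, b):
    """Sum p(x, z) over all N^T state paths (only feasible for tiny T)."""
    total = 0.0
    for z in product(range(len(pi)), repeat=len(x)):
        p = pi[z[0]] * b[z[0]][x[0]]
        for t in range(1, len(x)):
            p *= A[z[t]][z[t - 1]] * b[z[t]][x[t]]
        total += p
    return total

# Made-up parameters: columns of A sum to 1, rows of b sum to 1.
pi = [0.6, 0.4]
A = [[0.7, 0.2],
     [0.3, 0.8]]
b = [[0.9, 0.1],
     [0.3, 0.7]]

def logs(matrix):
    return [[math.log(v) for v in row] for row in matrix]

x = [0, 1, 1]
log_px = forward_log(x, [math.log(p) for p in pi], logs(A), logs(b))
print(math.exp(log_px), brute_force(x, pi, A, b))
```

The two printed numbers should agree up to floating-point error; the forward version needs on the order of $N^2 T$ operations, while the brute-force sum enumerates $N^T$ paths.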

### Problem 2: How do we compute $\underset{\mathbf{z}}{\text{argmax }} p(\mathbf{z}|\mathbf{x})$?

Given an observation sequence $\mathbf{x}$, we may want to find the most likely sequence of states that could have generated $\mathbf{x}$. (Given the sequence of selfies, we want to infer what cities the friend visited.) In other words, we want $\underset{\mathbf{z}}{\text{argmax }} p(\mathbf{z}|\mathbf{x})$.

We can use Bayes' rule to rewrite this expression:
$$\begin{align*}
    \underset{\mathbf{z}}{\text{argmax }} p(\mathbf{z}|\mathbf{x}) &= \underset{\mathbf{z}}{\text{argmax }} \frac{p(\mathbf{x}|\mathbf{z}) p(\mathbf{z})}{p(\mathbf{x})} \\
    &= \underset{\mathbf{z}}{\text{argmax }} p(\mathbf{x}|\mathbf{z}) p(\mathbf{z})
\end{align*}$$

Hmm! That last expression, $\underset{\mathbf{z}}{\text{argmax }} p(\mathbf{x}|\mathbf{z}) p(\mathbf{z})$, looks suspiciously similar to the intractable expression we encountered before introducing the forward algorithm, $\underset{\mathbf{z}}{\sum} p(\mathbf{x}|\mathbf{z}) p(\mathbf{z})$.

And indeed, just as the intractable *sum* over all $\mathbf{z}$ can be implemented efficiently using the forward algorithm, so too this intractable *argmax* can be implemented efficiently using a similar divide-and-conquer algorithm: the legendary Viterbi algorithm!

________

<u><b>The Viterbi Algorithm</b></u>

> for $s=1 \rightarrow N$:\
> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;$\delta_{s,1} := b_s(x_1) \cdot \pi_s$\
> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;$\psi_{s,1} := 0$
>
> for $t = 2 \rightarrow T$:\
> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;for $s = 1 \rightarrow N$:\
> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;$\delta_{s,t} := b_s(x_t) \cdot \left( \underset{s'}{\text{max }} A_{s, s'} \cdot \delta_{s',t-1} \right)$\
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;$\psi_{s,t} := \underset{s'}{\text{argmax }} A_{s, s'} \cdot \delta_{s',t-1}$
>
> $z_T^* := \underset{s}{\text{argmax }} \delta_{s,T}$\
> for $t = T-1 \rightarrow 1$:\
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;$z_{t}^* := \psi_{z_{t+1}^*,t+1}$
>
> $\mathbf{z}^* := \{z_{1}^*, \dots, z_{T}^* \}$\
return $\mathbf{z}^*$
________

The Viterbi algorithm looks somewhat gnarlier than the forward algorithm, but it is essentially the same algorithm, with two tweaks: 1) instead of taking the sum over previous states, we take the max; and 2) we record the argmax of the previous states in a table, and loop back over this table at the end to get $\mathbf{z}^*$, the most likely state sequence. (And like the forward algorithm, we should run the Viterbi algorithm in the log domain for better numerical stability.)
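
Before the batched PyTorch version, here is a single-sequence sketch of the same procedure in plain Python, run in the log domain. The function name `viterbi_log` and the toy alternating parameters are assumptions made for illustration.

```python
import math

def viterbi_log(x, log_pi, log_A, log_b):
    """
    Log-domain Viterbi for one sequence.
    log_A[s][sp] = log p(z_t = s | z_{t-1} = sp), matching A_{s,s'} above.
    Returns (best state path, log p(x, best path)).
    """
    N, T = len(log_pi), len(x)
    log_delta = [[log_b[s][x[0]] + log_pi[s] for s in range(N)]]
    psi = [[0] * N]  # backpointers; row 0 is never read
    for t in range(1, T):
        row, back = [], []
        for s in range(N):
            scores = [log_A[s][sp] + log_delta[t - 1][sp] for sp in range(N)]
            best_prev = max(range(N), key=lambda sp: scores[sp])
            row.append(log_b[s][x[t]] + scores[best_prev])
            back.append(best_prev)
        log_delta.append(row)
        psi.append(back)
    # Backtrack from the best final state.
    z = [max(range(N), key=lambda s: log_delta[T - 1][s])]
    for t in range(T - 1, 0, -1):
        z.insert(0, psi[t][z[0]])
    return z, log_delta[T - 1][z[-1]]

# Toy alternating HMM: state 0 = consonant, state 1 = vowel;
# observation 0 = "a", observation 1 = "b" (illustrative assumption).
NEG_INF = -math.inf
log_pi = [math.log(0.5), math.log(0.5)]
log_A = [[NEG_INF, 0.0],
         [0.0, NEG_INF]]
log_b = [[NEG_INF, 0.0],   # consonant state emits only "b"
         [0.0, NEG_INF]]   # vowel state emits only "a"

path_aba, score_aba = viterbi_log([0, 1, 0], log_pi, log_A, log_b)
path_abb, score_abb = viterbi_log([0, 1, 1], log_pi, log_A, log_b)
print(path_aba, score_aba)  # [1, 0, 1] and log(0.5)
print(score_abb)            # -inf
```

For "abb" the function still returns some path, but its score of `-inf` marks it as impossible under this model.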

Let's add the Viterbi algorithm to our PyTorch model:

In [22]:
def viterbi(self, x, T):
  """
  x : IntTensor of shape (batch size, T_max)
  T : IntTensor of shape (batch size)
  Find argmax_z log p(z|x) for each (x) in the batch.
  """
  if self.is_cuda:
    x = x.cuda()
    T = T.cuda()

  batch_size = x.shape[0]; T_max = x.shape[1]
  log_state_priors = torch.nn.functional.log_softmax(self.unnormalized_state_priors, dim=0)
  log_delta = torch.zeros(batch_size, T_max, self.N).float()
  psi = torch.zeros(batch_size, T_max, self.N).long()
  if self.is_cuda:
    log_delta = log_delta.cuda()
    psi = psi.cuda()

  log_delta[:, 0, :] = self.emission_model(x[:,0]) + log_state_priors
  for t in range(1, T_max):
    max_val, argmax_val = self.transition_model.maxmul(log_delta[:, t-1, :])
    log_delta[:, t, :] = self.emission_model(x[:,t]) + max_val
    psi[:, t, :] = argmax_val

  # Get the log probability of the best path
  log_max = log_delta.max(dim=2)[0]
  best_path_scores = torch.gather(log_max, 1, T.view(-1,1) - 1)

  # This next part is a bit tricky to parallelize across the batch,
  # so we will do it separately for each example.
  z_star = []
  for i in range(0, batch_size):
    z_star_i = [ log_delta[i, T[i] - 1, :].max(dim=0)[1].item() ]
    for t in range(T[i] - 1, 0, -1):
      z_t = psi[i, t, z_star_i[0]].item()
      z_star_i.insert(0, z_t)

    z_star.append(z_star_i)

  return z_star, best_path_scores # return both the best path and its log probability

def transition_model_maxmul(self, log_alpha):
  log_transition_matrix = torch.nn.functional.log_softmax(self.unnormalized_transition_matrix, dim=0)

  out1, out2 = maxmul(log_transition_matrix, log_alpha.transpose(0,1))
  return out1.transpose(0,1), out2.transpose(0,1)

def maxmul(log_A, log_B):
	"""
	log_A : m x n
	log_B : n x p
	output : m x p matrix

	Similar to the log domain matrix multiplication,
	this computes out_{i,j} = max_k log_A_{i,k} + log_B_{k,j}
	"""
	m = log_A.shape[0]
	n = log_A.shape[1]
	p = log_B.shape[1]

	log_A_expanded = torch.stack([log_A] * p, dim=2)
	log_B_expanded = torch.stack([log_B] * m, dim=0)

	elementwise_sum = log_A_expanded + log_B_expanded
	out1,out2 = torch.max(elementwise_sum, dim=1)

	return out1,out2

TransitionModel.maxmul = transition_model_maxmul
HMM.viterbi = viterbi



### **1. Why do we use log probabilities in Viterbi instead of raw probabilities?**

**Significance:**
Multiplying many small probabilities quickly underflows to zero.
Viterbi uses **addition of log-probs instead of multiplication of probs**, which is numerically stable.

**Why not stay in probability space?**

* HMM sequence probabilities quickly get tiny: $10^{-30}$, $10^{-50}$, etc.
* Logs prevent collapse to zero
* Max over log-probs = max over probs (monotonic transformation)
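
A small sketch of the underflow problem, using 400 factors of 0.1 as an arbitrary illustrative choice:

```python
import math

# 400 emission/transition factors of 0.1 each (arbitrary illustrative values).
factors = [0.1] * 400

product = 1.0
for p in factors:
    product *= p  # 1e-400 is below the smallest positive float64

log_product = sum(math.log(p) for p in factors)

print(product)      # 0.0
print(log_product)  # roughly -921.03
```

The raw product has collapsed to zero, while the sum of logs still carries the full information.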

---

### **2. Why does Viterbi use `max` (argmax path) while forward uses `sum` (total probability)?**

**Significance:**
Forward computes:

> **Likelihood of the sequence** over *all* state paths.

Viterbi computes:

> **Most likely hidden-state path** and its score, not the total probability over all paths.

**Why not use sum here?**
Sum answers *"How likely is the sequence?"*
Max answers *"Which sequence of states best explains it?"*
For decoding → max is the correct operation.

---

### **3. Why do we use a custom `maxmul()` instead of PyTorch matrix multiply?**

**Significance:**
Typical matrix multiplication uses:

```
sum_k A[i,k] * B[k,j]
```

Viterbi needs:

```
max_k A[i,k] + B[k,j]
```

This is the **max-plus semiring** (a different algebra).

Standard matmul cannot do that.

Using a custom `maxmul()` explicitly demonstrates to learners:

* HMM algorithms operate in special semirings
* Viterbi uses max-plus
* Forward uses log-sum-exp

This is an educational moment: *same HMM, different algebra → different inference*.
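
As a rough illustration of the max-plus product, here is a loop-based sketch in plain Python. The name `maxplus_matmul` and the input values are hypothetical; the notebook's `maxmul` does the same thing with stacked tensors.

```python
def maxplus_matmul(log_A, log_B):
    """out[i][j] = max_k (log_A[i][k] + log_B[k][j]); also returns the argmax k."""
    m, n, p = len(log_A), len(log_B), len(log_B[0])
    values = [[None] * p for _ in range(m)]
    argmax = [[None] * p for _ in range(m)]
    for i in range(m):
        for j in range(p):
            best_k = max(range(n), key=lambda k: log_A[i][k] + log_B[k][j])
            values[i][j] = log_A[i][best_k] + log_B[best_k][j]
            argmax[i][j] = best_k
    return values, argmax

log_A = [[0.0, -1.0],
         [-2.0, 0.0]]
log_B = [[-3.0],
         [-1.0]]

values, argmax = maxplus_matmul(log_A, log_B)
print(values)  # [[-2.0], [-1.0]]
print(argmax)  # [[1], [1]]
```

Replacing `max` with a log-sum-exp over `k` would turn the same loop into the log-domain product used by the forward algorithm.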

---

### **4. Why do we store `psi` (the backpointer matrix)?**

**Significance:**
`psi[t][j]` tells you:

> *"From which previous state did we come to reach state `j` at time `t` with the highest probability?"*

Without `psi`, you can compute the score of the best path, but **not the actual path**.

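**Why not recompute the path afterward?**
You could, but only by redoing the argmax over previous states at every step of the backward pass. Storing `psi` records each of those decisions once, while the scores are being computed.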

---

### **5. Why do we manually backtrack through time in a loop instead of parallelizing?**

**Significance:**
Parallelizing backtracking is hard because:

* each step depends on the state chosen in the step before it
* backtracking is inherently sequential (you can only know the best state at `t-1` after knowing the best state at `t`)

For teaching purposes, this loop makes the logic transparent.

**Why not vectorize anyway?**
It complicates the code and hides the conceptual flow.

---

### **6. Why do we apply emission probabilities before transitions at the first step?**

Code:

```python
log_delta[:, 0, :] = self.emission_model(x[:,0]) + log_state_priors
```

**Significance:**
At time 0, no transitions have occurred.
The best score for being in state `i` is:

```
log P(z_0 = i) + log P(x_0 | z_0 = i)
```

Teaching point: **first step = priors + first observation**.

**Why not include transition from a dummy start state?**
That's equivalent, but increases conceptual overhead for students.

---

### **7. Why does "aba" give a valid path while "abb" gets a score of `-inf`?**

**Significance:**
Your HMM enforces alternating vowel/consonant states:

```
C → V → C → V ...
```

"aba"

```
a (vowel)
b (consonant)
a (vowel)
→ perfectly fits transitions + emissions
```

"abb"

```
a (vowel)
b (consonant)
b (consonant) → forbidden consonant → consonant transition → -inf score
```

Thus Viterbi correctly returns:

* a valid state path for "aba"
* a path with log probability `-inf` (i.e. impossible) for "abb"

This demonstrates how **HMM constraints produce structured decoding**.

---

### **8. Why is `best_path_scores` taken from the last timestep using gather?**

**Significance:**
If sequences have different lengths, Viterbi must take:

```
score of best state at time T[i] - 1
```

not at `T_max - 1`.

This teaches:

* HMMs support variable-length sequences
* we score only up to the true end, not the padded part

**Why not use a fixed-length sequence?**
Because batching variable-length sequences is standard in NLP.
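
The indexing idea can be sketched without tensors. The numbers below are made up, and the list comprehension stands in for the `torch.gather` call in the notebook's code.

```python
# log_max[i][t] = best log score over states for sequence i at timestep t.
# Made-up numbers: the second sequence has true length 2 plus one padded step.
log_max = [
    [-1.0, -2.5, -4.0],
    [-0.5, -1.5, -9.9],  # -9.9 was computed on a padded timestep
]
T = [3, 2]  # true lengths

# Read each sequence's score at its own last real timestep, T[i] - 1.
best_path_scores = [log_max[i][T[i] - 1] for i in range(len(T))]
print(best_path_scores)  # [-4.0, -1.5]
```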

---

### **9. Why run Viterbi on a batch (two words at once) instead of one word at a time?**

**Significance:**
Efficient NLP models process multiple sequences simultaneously.
Running:

```python
["aba", "abb"]
```

together demonstrates:

* vectorized HMM inference
* shared parameters across sequences
* practical batching behavior

**Why not run separately?**
Less realistic and slower in any real system.

---

### **10. Why return both `(z_star, best_path_scores)` instead of only the path?**

**Significance:**
Viterbi decoding answers two questions:

1. **Which path is most probable?** → `z_star`
2. **How probable is that path?** → `best_path_scores`

This matters for:

* comparing sequences
* confidence scoring
* ranking alternative segmentations or spellings

**Why not return only the path?**
You lose important interpretability and scoring information.




Try running Viterbi on an input sequence, given the vowel/consonant HMM:

In [23]:
x = torch.stack( [torch.tensor(encode("aba")), torch.tensor(encode("abb"))] )
T = torch.tensor([3,3])
print(model.viterbi(x, T))

([[1, 0, 1], [1, 0, 0]], tensor([[-5.1112],
        [   -inf]], device='cuda:0'))


For $\mathbf{x} = \text{"aba"}$, the Viterbi algorithm returns $\mathbf{z}^* = \{1,0,1\}$. This corresponds to "vowel, consonant, vowel" according to the way we defined the states above, which is correct for this input sequence. Yay!

For $\mathbf{x} = \text{"abb"}$, the Viterbi algorithm still returns a $\mathbf{z}^*$, but we know this is gibberish because "vowel, consonant, consonant" is impossible under this HMM, and indeed the log probability of this path is $-\infty$.

Let's compare the "forward score" (the log probability of all possible paths, returned by the forward algorithm) with the "Viterbi score" (the log probability of the maximum likelihood path, returned by the Viterbi algorithm):

In [24]:
print(model.forward(x, T))
print(model.viterbi(x, T)[1])

tensor([[-5.1112],
        [   -inf]], device='cuda:0')
tensor([[-5.1112],
        [   -inf]], device='cuda:0')


The two scores are the same! That's because in this instance there is only one possible path through the HMM, so the probability of the most likely path is the same as the sum of the probabilities of all possible paths.

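In general, though, the forward score and Viterbi score will often be fairly close. This is because of a property of the $\text{logsumexp}$ function: for a vector with $n$ entries, $\max (\mathbf{x}) \le \text{logsumexp}(\mathbf{x}) \le \max (\mathbf{x}) + \log n$, so $\text{logsumexp}(\mathbf{x}) \approx \max (\mathbf{x})$ whenever one entry dominates. ($\text{logsumexp}$ is sometimes referred to as the "smooth maximum" function.)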

In [25]:
x = torch.tensor([1., 2., 3.])
print(x.max(dim=0)[0])
print(x.logsumexp(dim=0))

tensor(3.)
tensor(3.4076)
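
A short sketch of how tight the approximation is in different cases. The three input vectors are arbitrary examples, and the bound `max + log(n)` is a standard property of log-sum-exp.

```python
import math

def logsumexp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

close = [1.0, 2.0, 3.0]          # the example above
dominated = [-50.0, -50.0, 3.0]  # one entry dominates: logsumexp is almost max
tied = [3.0, 3.0, 3.0]           # worst case: logsumexp = max + log(n)

for xs in (close, dominated, tied):
    lower = max(xs)
    upper = max(xs) + math.log(len(xs))
    print(lower, logsumexp(xs), upper)
```

The gap between the forward score and the Viterbi score behaves the same way: it is small when one state path carries most of the probability and larger when many paths contribute comparably.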


### Problem 3: How do we train the model?





Earlier, we hard-coded an HMM to have certain behavior. What we would like to do instead is have the HMM learn to model the data on its own. And while it is possible to use supervised learning with an HMM (by hard-coding the emission model or the transition model) so that the states have a particular interpretation, the really cool thing about HMMs is that they are naturally unsupervised learners, so they can learn to use their different states to represent different patterns in the data, without the programmer needing to indicate what each state means.

Like many machine learning models, an HMM can be trained using maximum likelihood estimation, i.e.:

$$\theta^* = \underset{\theta}{\text{argmin }} -\sum_{\mathbf{x}^i}\text{log }p_{\theta}(\mathbf{x}^i)$$

where $\mathbf{x}^1, \mathbf{x}^2, \dots$ are training examples.

The standard method for doing this is the Expectation-Maximization (EM) algorithm, which for HMMs is also called the "Baum-Welch" algorithm. In EM training, we alternate between an "E-step", where we estimate the values of the latent variables, and an "M-step", where the model parameters are updated given the estimated latent variables. (Think $k$-means, where you guess which cluster each data point belongs to, then reestimate where the clusters are, and repeat.) The EM algorithm has some nice properties: each iteration is guaranteed not to increase the loss function, and the E-step and M-step may have an exact closed form solution, in which case no pesky learning rates are required.

But because the HMM forward algorithm is differentiable with respect to all the model parameters, we can also just take advantage of automatic differentiation methods in libraries like PyTorch and try to minimize $-\text{log }p_{\theta}(\mathbf{x})$ directly, by backpropagating through the forward algorithm and running stochastic gradient descent. That means we don't need to write any additional HMM code to implement training: `loss.backward()` is all you need.

Here we will implement SGD training for an HMM in PyTorch. First, some helper classes:

In [26]:
import torch.utils.data
from collections import Counter
from sklearn.model_selection import train_test_split

class TextDataset(torch.utils.data.Dataset):
  def __init__(self, lines):
    self.lines = lines # list of strings
    collate = Collate() # function for generating a minibatch from strings
    self.loader = torch.utils.data.DataLoader(self, batch_size=1024, num_workers=1, shuffle=True, collate_fn=collate)

  def __len__(self):
    return len(self.lines)

  def __getitem__(self, idx):
    line = self.lines[idx].lstrip(" ").rstrip("\n").rstrip(" ").rstrip("\n")
    return line

class Collate:
  def __init__(self):
    pass

  def __call__(self, batch):
    """
    Returns a minibatch of strings, padded to have the same length.
    """
    x = []
    batch_size = len(batch)
    for index in range(batch_size):
      x_ = batch[index]

      # convert letters to integers
      x.append(encode(x_))

    # pad all sequences with 0 to have same length
    x_lengths = [len(x_) for x_ in x]
    T = max(x_lengths)
    for index in range(batch_size):
      x[index] += [0] * (T - len(x[index]))
      x[index] = torch.tensor(x[index])

    # stack into single tensor
    x = torch.stack(x)
    x_lengths = torch.tensor(x_lengths)
    return (x,x_lengths)

Let's load some training/testing data. By default, this will use the unix "words" file, but you could also use your own text file.



## **1. Why do we print both `model.forward(x, T)` and `model.viterbi(x, T)[1]`?**

### **Significance (HMM Concept)**

You are demonstrating that:

* **Forward algorithm** computes
  → *total probability of all state paths producing the sequence*.

* **Viterbi algorithm** computes
  → *probability of only the best single state path*.

So the forward log-probability is always **≥** the Viterbi log-score (log-sum-exp ≥ max).
Printing them together makes this property visible; in this example the two are equal because only one path has nonzero probability.

### **Why not print only one?**

Because students often confuse "likelihood of the sequence" with "likelihood of the best path."
Seeing both reinforces the conceptual distinction.

---

## **2. Why do we compute both `x.max(dim=0)` and `x.logsumexp(dim=0)`?**

### **Significance (HMM Math Insight)**

You are demonstrating the key algebraic difference:

* Viterbi uses **max** (max-plus semiring)
* Forward uses **logsumexp** (sum-product semiring)

This simple vector example:

```python
x = [1, 2, 3]
```

lets students see:

* Max = 3
* Log-sum-exp = log(e^1 + e^2 + e^3) ≈ 3.4076 > 3

This illustrates why the forward score is always ≥ the Viterbi score, with equality only when a single path carries all of the probability.

### **Why not explain using only theory?**

This tiny concrete example makes the connection obvious and intuitive.

---

## **3. Why do we pad sequences inside `Collate()` instead of letting PyTorch handle variable lengths automatically?**

### **Significance (HMM Implementation Detail)**

Your HMM implementation expects a tensor shaped:

```
batch_size × T_max
```

because:

* Forward algorithm needs uniform time steps
* Viterbi needs a consistent DP table shape
* You manually track the true lengths via `x_lengths`

### **Why not use PyTorch's PackedSequence / RNN utilities?**

Because this is a *hand-built* HMM, not an RNN.
Packing/unpacking is unnecessary and would hide important details from students.

Padding + explicit lengths is the standard method in *classic HMM implementations*.

---

## **4. Why does `Collate()` both pad *and* return `x_lengths` (T)?**

### **Significance (HMM Requirement)**

Even though you pad sequences, the forward and Viterbi algorithms must only use the true length.

If you didn't pass `T`, the model would try to interpret padded zeros as real emissions, which ruins:

* log-probabilities
* state decoding
* evaluation results

### **Why not mask instead of using `T`?**

Masking complicates the code.
Using `T` keeps the algorithms clean:

* forward → `log_probs = torch.gather(..., T-1)`
* Viterbi → backtracking starts at `T-1`

This is the simplest correct approach.
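
A minimal sketch of the pad-and-track-lengths idea using plain lists. `pad_batch` is a hypothetical helper and the encoded sequences are made up; the notebook's `Collate` does the equivalent work and then converts to tensors.

```python
def pad_batch(encoded, pad_value=0):
    """Pad lists of symbol indices to a common length and return true lengths."""
    lengths = [len(seq) for seq in encoded]
    t_max = max(lengths)
    padded = [seq + [pad_value] * (t_max - len(seq)) for seq in encoded]
    return padded, lengths

batch = [[2, 0, 19], [0, 1], [7, 4, 11, 11, 14]]  # made-up encodings
padded, lengths = pad_batch(batch)
print(padded)   # [[2, 0, 19, 0, 0], [0, 1, 0, 0, 0], [7, 4, 11, 11, 14]]
print(lengths)  # [3, 2, 5]
```

Unlike the in-place `+=` in `Collate`, this version builds new lists, so the input batch is left unchanged.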

---

## **5. Why does `__getitem__` strip whitespace before encoding the line?**

### **Significance**

Many datasets contain:

* leading spaces
* trailing spaces
* stray newline characters

If you don't strip:

* HMM will treat spaces as extra tokens
* encoded sequences will be wrong
* emissions matrices become polluted

### **Why not let the encoder ignore unknown chars?**

Encoding unknown characters silently hides mistakes.
Stripping ensures clean, consistent training data.

---

## **6. Why stack two sequences (`aba`, `abb`) into the same batch?**

### **Significance**

Batching demonstrates how HMMs scale.
Your forward and Viterbi implementations are **vectorized across batch dimension**, so they:

* share transition/emission parameters
* compute in parallel
* mimic real HMM usage in NLP tasks

### **Why not run sequences independently?**

It hides the efficiency benefits of HMMs and breaks the demo's focus on batching.

---

## **7. Why is padding done using the integer `0` specifically?**

### **Significance**

`0` corresponds to letter `'a'` in your alphabet, so the padding value is also the index of a real symbol.
In HMM training, that doesn't matter because:

* the true lengths `T` select which timesteps are read
* the forward and Viterbi loops still run over the padded timesteps, but the values computed there are never read
* emissions for the padded section are therefore irrelevant

### **Why not use a special PAD token?**

Because values computed at padded timesteps are never used: `T` handles it.
This avoids enlarging the alphabet.
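
The claim can be checked with a small sketch: run a log-domain forward pass over the same sequence padded with two different values, then read the score at the true length. `forward_scores` and all parameter values are assumptions made for illustration.

```python
import math

def logsumexp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def forward_scores(x, log_pi, log_A, log_b):
    """Return log p(x_1..x_t) for every prefix length t = 1..len(x)."""
    N = len(log_pi)
    log_alpha = [log_b[s][x[0]] + log_pi[s] for s in range(N)]
    scores = [logsumexp(log_alpha)]
    for t in range(1, len(x)):
        log_alpha = [
            log_b[s][x[t]]
            + logsumexp([log_A[s][sp] + log_alpha[sp] for sp in range(N)])
            for s in range(N)
        ]
        scores.append(logsumexp(log_alpha))
    return scores

# Made-up parameters (columns of A and rows of b sum to 1).
log_pi = [math.log(0.6), math.log(0.4)]
log_A = [[math.log(0.7), math.log(0.2)],
         [math.log(0.3), math.log(0.8)]]
log_b = [[math.log(0.9), math.log(0.1)],
         [math.log(0.3), math.log(0.7)]]

T = 2                          # true length of the sequence [1, 0]
padded_with_0 = [1, 0, 0, 0]
padded_with_1 = [1, 0, 1, 1]

score_a = forward_scores(padded_with_0, log_pi, log_A, log_b)[T - 1]
score_b = forward_scores(padded_with_1, log_pi, log_A, log_b)[T - 1]
print(score_a, score_b)  # identical: the padding value never reaches index T - 1
```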

---

## **8. Why is the `TextDataset` holding a `DataLoader` inside it?**

### **Significance**

This design makes your training loop extremely simple:

```python
for x, T in dataset.loader:
    ...
```

It packages:

* the data
* batching
* padding
* encoding

into a single object.

### **Why not put the DataLoader outside?**

For a classroom demo, embedding it inside keeps code clean and easy to follow.

---

## **9. Why import `Counter` and `train_test_split` when this snippet does not use them?**

### **Significance**

They are used in the next cell, which loads the data:

* `Counter` collects the distinct characters in the corpus, which become the alphabet
* `train_test_split` holds out 10% of the lines as a validation set

This lays groundwork for:

* learning HMM parameters on the training split
* computing held-out likelihood on the validation split

---

## **10. Why are we printing outputs after each step in this notebook-style code?**

### **Significance**

This is a **didactic style**:

* Show intermediate numeric results
* Reinforce understanding of each algorithmic step
* Compare max vs log-sum-exp
* Compare forward vs Viterbi
* See batching effects directly

### **Why not wrap everything in functions and hide output?**

Because transparency is crucial for teaching how HMM inference works.




In [27]:
!wget https://raw.githubusercontent.com/lorenlugosch/pytorch_HMM/master/data/train/training.txt

filename = "training.txt"

with open(filename, "r") as f:
  lines = f.readlines() # each line of lines will have one word

alphabet = list(Counter(("".join(lines))).keys())
train_lines, valid_lines = train_test_split(lines, test_size=0.1, random_state=42)
train_dataset = TextDataset(train_lines)
valid_dataset = TextDataset(valid_lines)

M = len(alphabet)

--2025-11-23 04:43:58--  https://raw.githubusercontent.com/lorenlugosch/pytorch_HMM/master/data/train/training.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2493109 (2.4M) [text/plain]
Saving to: ‘training.txt.1’


2025-11-23 04:43:58 (66.3 MB/s) - ‘training.txt.1’ saved [2493109/2493109]



We will use a Trainer class for training and testing the model:



In [28]:
from tqdm import tqdm # for displaying progress bar

class Trainer:
  def __init__(self, model, lr):
    self.model = model
    self.lr = lr
    self.optimizer = torch.optim.Adam(model.parameters(), lr=self.lr, weight_decay=0.00001)

  def train(self, dataset):
    train_loss = 0
    num_samples = 0
    self.model.train()
    print_interval = 50
    for idx, batch in enumerate(tqdm(dataset.loader)):
      x,T = batch
      batch_size = len(x)
      num_samples += batch_size
      log_probs = self.model(x,T)
      loss = -log_probs.mean()
      self.optimizer.zero_grad()
      loss.backward()
      self.optimizer.step()
      train_loss += loss.cpu().data.numpy().item() * batch_size
      if idx % print_interval == 0:
        print("loss:", loss.item())
        for _ in range(5):
          sampled_x, sampled_z = self.model.sample()
          print(decode(sampled_x))
          print(sampled_z)
    train_loss /= num_samples
    return train_loss

  def test(self, dataset):
    test_loss = 0
    num_samples = 0
    self.model.eval()
    print_interval = 50
    for idx, batch in enumerate(dataset.loader):
      x,T = batch
      batch_size = len(x)
      num_samples += batch_size
      log_probs = self.model(x,T)
      loss = -log_probs.mean()
      test_loss += loss.cpu().data.numpy().item() * batch_size
      if idx % print_interval == 0:
        print("loss:", loss.item())
        sampled_x, sampled_z = self.model.sample()
        print(decode(sampled_x))
        print(sampled_z)
    test_loss /= num_samples
    return test_loss
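
One detail in `train` and `test` is easy to miss: each batch's mean loss is multiplied by its batch size before being summed, so the epoch loss is a per-sample average even when the last batch is smaller. A sketch with made-up numbers:

```python
# (mean loss, batch size) for three hypothetical minibatches; the last is smaller.
batches = [(30.0, 1024), (28.0, 1024), (20.0, 100)]

weighted_sum = sum(loss * size for loss, size in batches)
num_samples = sum(size for _, size in batches)
epoch_loss = weighted_sum / num_samples  # per-sample average, as in Trainer

naive = sum(loss for loss, _ in batches) / len(batches)  # ignores batch sizes

print(epoch_loss)  # about 28.58
print(naive)       # 26.0
```

The naive average gives the small final batch the same weight as a full one, which biases the reported loss.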

Finally, initialize the model and run the main training loop. Every 50 batches, the code will produce a few samples from the model. Over time, these samples should look more and more realistic.

In [29]:
# Initialize model
model = HMM(N=64, M=M)

# Train the model
num_epochs = 10
trainer = Trainer(model, lr=0.01)

for epoch in range(num_epochs):
        print("========= Epoch %d of %d =========" % (epoch+1, num_epochs))
        train_loss = trainer.train(train_dataset)
        valid_loss = trainer.test(valid_dataset)

        print("========= Results: epoch %d of %d =========" % (epoch+1, num_epochs))
        print("train loss: %.2f| valid loss: %.2f\n" % (train_loss, valid_loss) )



  0%|          | 1/208 [00:00<00:40,  5.09it/s]

loss: 38.02503204345703
FbYkdNLhKk
[1, 51, 8, 61, 49, 35, 22, 63, 28, 12]
YlqRm-yXBU
[54, 26, 27, 28, 22, 40, 32, 4, 51, 44]
ZXnvXoCQDz
[22, 43, 42, 58, 4, 39, 9, 41, 13, 46]
eeWqolQGVN
[52, 44, 25, 31, 40, 0, 31, 51, 51, 37]
DWUYxpvudP
[61, 46, 44, 34, 33, 33, 31, 14, 48, 37]


 25%|██▌       | 53/208 [00:03<00:09, 16.25it/s]

loss: 32.93939971923828
UgoasnNGni
[1, 50, 32, 42, 20, 32, 20, 16, 42, 42]

piLespevY
[2, 43, 61, 22, 7, 56, 14, 36, 17, 20]
paieohscKs
[47, 49, 2, 40, 2, 6, 26, 32, 18, 55]
BNNH-eFaTK
[30, 4, 17, 30, 37, 15, 33, 0, 28, 53]
tnosAtoagj
[47, 2, 62, 26, 19, 3, 7, 56, 55, 10]


 50%|████▉     | 103/208 [00:06<00:06, 16.44it/s]

loss: 30.032291412353516
XghquxnVtH
[23, 38, 59, 48, 15, 32, 36, 17, 44, 9]
bncgeQsmJr
[36, 17, 34, 50, 7, 16, 14, 58, 17, 40]
moilcpaWiL
[15, 35, 17, 46, 34, 43, 33, 9, 4, 22]
UEqvPAxOel
[63, 49, 33, 40, 16, 2, 50, 35, 16, 45]
-asTiupmTa
[1, 27, 10, 33, 0, 35, 18, 41, 46, 42]


 74%|███████▎  | 153/208 [00:09<00:03, 16.44it/s]

loss: 28.68271827697754
Lfiptndabn
[47, 15, 26, 33, 40, 32, 48, 37, 26, 19]
Cvlslppmmy
[58, 14, 35, 26, 4, 58, 14, 45, 29, 17]
oeoeesnono
[42, 7, 45, 43, 15, 26, 42, 61, 47, 15]
mJsyhzvelC
[54, 61, 26, 32, 44, 32, 15, 26, 29, 33]
permipisuM
[47, 58, 27, 34, 7, 20, 7, 48, 15, 26]


 98%|█████████▊| 203/208 [00:12<00:00, 16.18it/s]

loss: 26.962074279785156
ZaEeesatic
[4, 16, 61, 21, 44, 26, 32, 36, 7, 47]
mrelesBLni
[4, 58, 15, 42, 7, 26, 45, 43, 42, 7]
ezihotavib
[15, 34, 16, 36, 17, 40, 16, 36, 17, 31]
dooelGLana
[48, 15, 15, 15, 26, 45, 43, 16, 42, 31]
vuhaQaiono
[54, 2, 43, 58, 36, 17, 47, 15, 42, 7]


100%|██████████| 208/208 [00:12<00:00, 16.38it/s]


loss: 27.272933959960938
inrhiiteae
[54, 2, 40, 43, 17, 17, 45, 43, 17, 36]
train loss: 30.86| valid loss: 26.82



  0%|          | 1/208 [00:00<00:38,  5.37it/s]

loss: 26.66340446472168
iseeidanra
[16, 48, 7, 50, 7, 0, 16, 42, 40, 16]
cjulnmngol
[4, 13, 35, 36, 17, 36, 17, 50, 15, 42]
unotrmytrn
[54, 35, 17, 40, 16, 36, 17, 40, 43, 48]
nosgcosavo
[42, 7, 26, 20, 48, 15, 26, 16, 36, 17]
boLaniarlc
[47, 7, 40, 32, 29, 17, 40, 35, 36, 36]


 25%|██▌       | 53/208 [00:03<00:09, 16.24it/s]

loss: 25.913742065429688
ecityzlyth
[15, 36, 17, 40, 32, 36, 36, 17, 40, 43]
cxoathoadh
[47, 15, 35, 16, 40, 43, 32, 61, 48, 43]
parialrtoa
[30, 44, 35, 17, 16, 36, 17, 40, 32, 58]
orwranlhea
[58, 34, 53, 35, 16, 2, 1, 45, 15, 16]
uoargrsssi
[54, 2, 17, 2, 50, 15, 26, 26, 26, 32]


 50%|████▉     | 103/208 [00:06<00:07, 13.84it/s]

loss: 25.620635986328125
Rorastitan
[4, 15, 35, 17, 26, 29, 17, 40, 32, 2]
nrlamlleri
[47, 35, 36, 16, 36, 36, 36, 7, 35, 32]
armlitooes
[16, 42, 12, 36, 17, 40, 15, 16, 26, 26]
QchklacoJc
[20, 47, 43, 32, 36, 17, 40, 32, 16, 40]
sngrawvyan
[54, 2, 50, 36, 16, 36, 36, 17, 16, 42]


 74%|███████▎  | 153/208 [00:09<00:03, 16.97it/s]

loss: 24.614309310913086
oemineabin
[48, 15, 34, 33, 50, 7, 16, 36, 7, 36]
melibcnvap
[4, 15, 35, 17, 36, 7, 2, 48, 17, 29]
demespleni
[48, 15, 34, 7, 26, 45, 43, 15, 35, 32]
fodnGdtame
[48, 15, 2, 50, 35, 38, 40, 16, 34, 16]
losclonyna
[36, 17, 16, 42, 4, 16, 36, 17, 2, 16]


 98%|█████████▊| 203/208 [00:12<00:00, 16.20it/s]

loss: 24.512184143066406
mogatriurm
[4, 17, 50, 17, 40, 43, 17, 15, 35, 36]
bpytessyca
[39, 36, 17, 40, 7, 26, 29, 17, 40, 17]
salevessti
[4, 16, 42, 7, 48, 7, 26, 26, 29, 17]
ruarcyxteo
[48, 32, 16, 24, 37, 61, 42, 40, 15, 35]
Boliniolat
[4, 16, 36, 17, 36, 7, 58, 36, 17, 40]


100%|██████████| 208/208 [00:12<00:00, 16.65it/s]


loss: 24.574323654174805
tolotealar
[47, 16, 36, 17, 40, 7, 26, 4, 16, 35]
train loss: 25.44| valid loss: 24.76



  0%|          | 0/208 [00:00<?, ?it/s]

loss: 24.371315002441406
Novinollis
[4, 16, 36, 17, 40, 15, 35, 36, 7, 26]
tagkatiaac
[47, 16, 50, 43, 16, 40, 32, 16, 42, 40]
ungoramyom
[54, 2, 50, 15, 35, 16, 36, 17, 16, 36]


  1%|▏         | 3/208 [00:00<00:20, 10.01it/s]

oneonubnyo
[58, 36, 7, 16, 2, 16, 25, 36, 17, 16]
caresmolef
[47, 32, 48, 7, 26, 29, 15, 35, 32, 42]


 25%|██▌       | 53/208 [00:03<00:09, 15.91it/s]

loss: 24.66459083557129
prortanari
[47, 35, 16, 35, 4, 16, 36, 16, 35, 17]
aplesslete
[16, 25, 36, 7, 26, 29, 36, 7, 40, 7]
pocesperos
[47, 16, 40, 7, 26, 29, 15, 35, 16, 26]
Herothaouo
[18, 15, 35, 16, 40, 43, 16, 11, 44, 16]
maceaagipo
[4, 16, 40, 43, 17, 16, 50, 17, 45, 15]


 50%|████▉     | 103/208 [00:06<00:07, 13.97it/s]

loss: 24.52764892578125
stlebetarg
[26, 29, 36, 7, 50, 32, 40, 16, 35, 50]
unredmedrs
[54, 2, 48, 7, 0, 36, 7, 0, 35, 48]
pecsaLogle
[4, 16, 42, 40, 16, 54, 2, 50, 36, 7]
lolbyprini
[4, 16, 34, 39, 61, 14, 43, 16, 42, 32]
dedosirine
[48, 7, 0, 16, 42, 7, 35, 17, 42, 7]


 74%|███████▎  | 153/208 [00:09<00:03, 15.85it/s]

loss: 24.304121017456055
mentruresy
[4, 7, 42, 40, 43, 15, 35, 7, 48, 17]
tumburatur
[47, 15, 34, 4, 46, 35, 17, 40, 15, 35]
vetionirit
[4, 7, 40, 32, 16, 34, 17, 36, 17, 40]
eyssunthur
[4, 17, 26, 29, 15, 42, 40, 32, 16, 35]
thytablest
[47, 43, 17, 45, 16, 25, 36, 7, 26, 29]


 98%|█████████▊| 203/208 [00:12<00:00, 16.42it/s]

loss: 24.471256256103516
lycarshoge
[43, 17, 45, 16, 42, 40, 43, 17, 50, 7]
Anarcastda
[54, 2, 16, 42, 40, 17, 26, 29, 48, 17]
ssmedrozhi
[26, 29, 57, 46, 0, 35, 17, 40, 43, 17]
tobivedyce
[47, 16, 40, 32, 48, 7, 36, 17, 40, 7]
emnetrarol
[20, 34, 13, 21, 40, 43, 16, 35, 16, 36]


100%|██████████| 208/208 [00:12<00:00, 16.56it/s]


loss: 24.334529876708984
cefllwassg
[4, 7, 0, 36, 36, 36, 17, 26, 29, 15]
train loss: 24.46| valid loss: 24.28



  0%|          | 1/208 [00:00<00:39,  5.22it/s]

loss: 24.5795955657959
antedlidoc
[16, 42, 40, 7, 0, 36, 17, 48, 17, 45]
flaucthoas
[21, 36, 7, 46, 42, 40, 43, 16, 46, 26]
saronlotis
[47, 33, 35, 16, 36, 36, 17, 40, 32, 26]
tolelosuca
[47, 16, 36, 7, 35, 16, 26, 26, 45, 16]
batiglunes
[4, 17, 40, 17, 50, 36, 17, 36, 17, 26]


 25%|██▌       | 53/208 [00:03<00:09, 16.44it/s]

loss: 24.024890899658203
IrphifonGe
[21, 35, 47, 43, 32, 48, 15, 42, 40, 32]
odeneredlo
[16, 48, 32, 48, 15, 35, 7, 0, 35, 16]
scotodatis
[26, 45, 16, 40, 16, 48, 17, 40, 32, 26]
oYtvetliss
[16, 42, 40, 36, 17, 40, 43, 32, 26, 29]
diwnislaph
[48, 17, 25, 36, 17, 45, 43, 16, 45, 43]


 50%|████▉     | 103/208 [00:06<00:06, 15.72it/s]

loss: 23.986164093017578
insjormnos
[54, 2, 26, 29, 15, 35, 57, 36, 17, 29]
merenpissa
[4, 7, 35, 16, 2, 45, 17, 26, 29, 16]
dkectrrphr
[60, 12, 7, 42, 40, 43, 35, 47, 43, 35]
teumiriolm
[47, 15, 54, 34, 7, 35, 17, 16, 36, 36]
myxatemesh
[4, 61, 34, 17, 40, 15, 34, 7, 26, 29]


 74%|███████▎  | 153/208 [00:09<00:03, 16.22it/s]

loss: 24.548397064208984
terideotih
[4, 15, 35, 32, 48, 15, 16, 40, 32, 48]
stydesmacb
[26, 29, 32, 36, 7, 26, 29, 16, 42, 55]
tolciuscur
[4, 15, 35, 45, 16, 46, 26, 45, 15, 35]
thiontical
[47, 43, 32, 16, 42, 40, 32, 45, 33, 35]
ritrinsmss
[4, 17, 40, 6, 32, 2, 26, 29, 10, 42]


 98%|█████████▊| 203/208 [00:12<00:00, 16.33it/s]

loss: 23.776012420654297
chesordond
[47, 43, 7, 29, 15, 35, 0, 16, 2, 0]
whubhavion
[47, 43, 46, 0, 43, 16, 48, 32, 16, 2]
lonvithysa
[4, 16, 2, 48, 17, 40, 6, 61, 29, 16]
franouscar
[47, 43, 17, 2, 16, 46, 26, 45, 16, 35]
denranflti
[48, 7, 42, 40, 16, 2, 47, 43, 40, 32]


100%|██████████| 208/208 [00:12<00:00, 16.52it/s]


loss: 24.0762939453125
pslaggandb
[26, 29, 36, 17, 50, 50, 16, 2, 0, 53]
train loss: 24.06| valid loss: 24.00



  0%|          | 1/208 [00:00<00:40,  5.14it/s]

loss: 24.225631713867188
ratarobned
[36, 17, 40, 16, 35, 16, 25, 36, 7, 0]
alyceangha
[16, 36, 61, 45, 15, 16, 2, 50, 43, 17]
orthauntra
[16, 42, 40, 43, 16, 46, 42, 40, 43, 17]
luntiiutas
[4, 46, 42, 40, 32, 16, 46, 29, 17, 26]
cinanvitio
[4, 32, 2, 16, 2, 48, 32, 40, 32, 16]


 25%|██▍       | 51/208 [00:03<00:13, 11.25it/s]

loss: 23.670207977294922
destlertin
[48, 17, 26, 29, 36, 7, 42, 40, 32, 36]
suspiragga
[4, 46, 26, 29, 32, 48, 17, 50, 50, 33]
skideadeis
[26, 29, 32, 48, 15, 17, 40, 43, 32, 26]
relataremo
[4, 32, 36, 17, 40, 16, 48, 15, 34, 17]
yonenepled
[4, 16, 2, 7, 48, 7, 0, 36, 7, 0]


 49%|████▉     | 102/208 [00:06<00:06, 15.54it/s]

loss: 23.64531898498535
phauctysat
[47, 43, 58, 46, 42, 40, 61, 29, 17, 40]
osmeansmet
[16, 26, 29, 15, 16, 2, 26, 29, 17, 40]
cosresWoga
[47, 16, 50, 43, 7, 26, 29, 16, 50, 17]
Mortyginys
[47, 15, 35, 29, 17, 50, 17, 36, 61, 29]
stdertodis
[26, 29, 48, 7, 42, 40, 16, 0, 32, 26]


 74%|███████▍  | 154/208 [00:09<00:03, 16.59it/s]

loss: 24.15966033935547
astoncurin
[17, 26, 29, 16, 2, 45, 33, 35, 16, 36]
derouslioZ
[48, 15, 35, 58, 46, 29, 36, 32, 16, 7]
inilinansp
[54, 2, 32, 36, 17, 2, 16, 42, 26, 29]
jentianten
[27, 7, 42, 40, 32, 16, 42, 40, 15, 42]
calenicife
[47, 33, 36, 7, 42, 32, 45, 32, 18, 15]


 97%|█████████▋| 202/208 [00:12<00:00, 12.47it/s]

loss: 23.66124725341797
Cantivenoa
[47, 16, 42, 40, 32, 48, 7, 42, 32, 16]
ciphuostan
[4, 20, 14, 43, 44, 17, 26, 40, 16, 2]
micardager
[4, 32, 45, 33, 35, 0, 17, 50, 7, 35]
daneenisma
[4, 16, 2, 32, 7, 42, 32, 26, 29, 16]
ingeonciph
[54, 2, 50, 32, 16, 42, 40, 32, 14, 6]


100%|██████████| 208/208 [00:13<00:00, 15.55it/s]


loss: 23.87270736694336
rancsouscu
[4, 16, 2, 40, 43, 58, 46, 26, 45, 33]
train loss: 23.85| valid loss: 23.80



  0%|          | 1/208 [00:00<00:40,  5.07it/s]

loss: 23.867509841918945
aussanicke
[16, 46, 26, 29, 16, 42, 32, 45, 12, 7]
ungananges
[54, 2, 50, 16, 48, 16, 2, 50, 7, 26]
pessorchan
[4, 7, 26, 29, 15, 35, 9, 6, 16, 42]
wertinesho
[4, 7, 42, 40, 32, 48, 7, 26, 29, 16]
phanedicat
[47, 43, 16, 36, 7, 0, 32, 45, 33, 40]


 25%|██▌       | 53/208 [00:03<00:10, 14.85it/s]

loss: 23.760662078857422
wacelentos
[47, 16, 42, 7, 36, 7, 42, 48, 16, 42]
nerivancec
[48, 15, 35, 17, 48, 16, 42, 40, 7, 47]
adisitynlo
[16, 0, 32, 48, 32, 40, 61, 36, 36, 58]
musquenani
[4, 46, 26, 30, 44, 7, 42, 33, 36, 17]
teleciabne
[47, 16, 36, 7, 42, 32, 16, 25, 36, 7]


 50%|████▉     | 103/208 [00:06<00:06, 16.29it/s]

loss: 23.685054779052734
whhumoworf
[47, 6, 43, 37, 34, 16, 22, 16, 35, 48]
elidicaciv
[7, 35, 32, 48, 32, 45, 33, 40, 32, 48]
palderatyn
[47, 33, 35, 0, 15, 35, 17, 40, 61, 36]
mustaiflen
[4, 46, 26, 40, 58, 32, 18, 36, 7, 42]
taliteleno
[4, 16, 36, 17, 40, 15, 35, 7, 42, 32]


 74%|███████▎  | 153/208 [00:09<00:03, 15.96it/s]

loss: 23.92264175415039
dotrodumes
[47, 16, 40, 43, 58, 48, 33, 36, 7, 26]
Memmasseor
[4, 20, 34, 34, 17, 26, 29, 15, 16, 42]
muboeunomo
[4, 46, 39, 47, 1, 54, 2, 16, 34, 16]
castoydelz
[47, 16, 42, 40, 58, 32, 48, 15, 35, 48]
cauntinelt
[4, 16, 46, 42, 40, 32, 48, 15, 2, 40]


 98%|█████████▊| 203/208 [00:12<00:00, 15.78it/s]

loss: 23.52283477783203
lerpizetis
[4, 15, 35, 29, 32, 48, 7, 40, 32, 26]
gecatiessr
[47, 15, 45, 33, 40, 17, 7, 26, 29, 15]
Viconsescu
[4, 32, 45, 16, 2, 48, 7, 26, 47, 15]
pocermazeb
[47, 16, 40, 15, 35, 49, 16, 36, 7, 13]
unemlyhrxe
[54, 2, 20, 34, 36, 61, 29, 35, 34, 7]


100%|██████████| 208/208 [00:12<00:00, 16.47it/s]


loss: 23.63994026184082
lexdigilid
[4, 7, 42, 48, 17, 50, 32, 36, 32, 48]
train loss: 23.67| valid loss: 23.64



  0%|          | 0/208 [00:00<?, ?it/s]

loss: 23.194503784179688
apertortan
[16, 14, 15, 35, 29, 16, 42, 40, 16, 2]


  1%|▏         | 3/208 [00:00<00:23,  8.54it/s]

bounchaler
[47, 1, 54, 2, 45, 6, 16, 36, 7, 42]
cobakeshic
[4, 16, 13, 17, 12, 7, 26, 29, 32, 45]
quicanlyou
[30, 44, 32, 45, 33, 36, 36, 61, 58, 46]
iucrubnest
[54, 2, 47, 43, 17, 25, 36, 7, 26, 29]


 25%|██▌       | 53/208 [00:03<00:09, 16.19it/s]

loss: 23.39376449584961
angralitas
[16, 2, 50, 43, 16, 36, 17, 40, 17, 26]
doveftzele
[4, 16, 48, 7, 18, 32, 48, 7, 36, 7]
unacarogdo
[54, 2, 17, 45, 33, 35, 58, 50, 48, 33]
rermilymed
[4, 15, 35, 49, 16, 36, 61, 29, 15, 0]
ivertrogan
[54, 48, 15, 35, 40, 43, 58, 50, 16, 2]


 50%|████▉     | 103/208 [00:06<00:06, 16.02it/s]

loss: 23.470476150512695
frekeshtob
[47, 43, 17, 12, 7, 26, 6, 40, 33, 25]
bongetesti
[47, 16, 2, 50, 7, 40, 7, 26, 40, 32]
cuilnyphda
[30, 44, 32, 36, 36, 61, 14, 6, 48, 16]
mirpeblyge
[4, 20, 34, 14, 17, 25, 36, 61, 50, 7]
holymaldic
[4, 16, 36, 61, 34, 16, 36, 48, 32, 45]


 74%|███████▎  | 153/208 [00:09<00:03, 16.36it/s]

loss: 23.511812210083008
iltetatist
[54, 2, 29, 17, 40, 17, 40, 32, 26, 29]
ipheurogym
[20, 14, 6, 58, 46, 42, 58, 50, 61, 34]
praminesso
[47, 43, 20, 34, 32, 48, 7, 26, 29, 16]
Slhoneceda
[21, 9, 6, 16, 48, 7, 42, 16, 0, 16]
gritidamag
[50, 43, 17, 40, 32, 48, 16, 34, 16, 50]


 98%|█████████▊| 203/208 [00:12<00:00, 14.46it/s]

loss: 23.536495208740234
gixberemer
[4, 20, 34, 13, 15, 35, 58, 34, 15, 35]
Ghedasereb
[47, 6, 7, 0, 17, 40, 15, 35, 17, 25]
rodiectora
[4, 16, 48, 7, 16, 42, 40, 16, 35, 16]
eneanienzi
[7, 42, 7, 16, 42, 32, 16, 2, 48, 32]
culsenglel
[47, 54, 42, 40, 15, 2, 50, 36, 7, 36]


100%|██████████| 208/208 [00:12<00:00, 16.48it/s]


loss: 23.625118255615234
vophtiaakl
[4, 58, 14, 6, 40, 32, 48, 17, 12, 36]
train loss: 23.53| valid loss: 23.52



  0%|          | 1/208 [00:00<00:56,  3.68it/s]

loss: 23.70162582397461
cactearsin
[47, 16, 42, 40, 7, 33, 35, 29, 32, 48]
nestedfale
[4, 7, 26, 40, 15, 0, 18, 33, 35, 7]
hitillabru
[6, 32, 40, 32, 36, 36, 17, 13, 43, 46]
eneshenter
[58, 48, 7, 26, 6, 7, 42, 40, 15, 35]
diccuzalop
[4, 32, 45, 40, 32, 48, 33, 36, 58, 14]


 25%|██▌       | 53/208 [00:03<00:09, 16.22it/s]

loss: 23.505718231201172
thoscedert
[9, 6, 16, 26, 40, 15, 0, 15, 35, 9]
Carmedfice
[47, 33, 35, 29, 15, 0, 18, 32, 45, 7]
lispelzice
[4, 20, 34, 14, 15, 35, 48, 32, 45, 7]
iltinestik
[54, 2, 40, 32, 48, 7, 26, 29, 17, 12]
anogematae
[16, 42, 58, 50, 20, 34, 17, 40, 17, 7]


 50%|████▉     | 103/208 [00:06<00:06, 16.16it/s]

loss: 23.149898529052734
Poncichyla
[47, 16, 42, 40, 32, 40, 6, 61, 36, 17]
mymaJuster
[4, 28, 49, 17, 41, 46, 26, 40, 15, 35]
leesserziv
[4, 7, 7, 26, 29, 15, 35, 48, 32, 48]
Nakitesckm
[4, 17, 12, 17, 40, 7, 26, 45, 12, 4]
unicaylemi
[54, 2, 32, 45, 16, 5, 43, 58, 49, 32]


 74%|███████▎  | 153/208 [00:09<00:03, 16.30it/s]

loss: 23.393461227416992
inpryntuek
[54, 2, 47, 43, 61, 42, 40, 44, 16, 12]
griinitura
[47, 43, 17, 32, 48, 32, 40, 46, 35, 33]
damanecass
[4, 58, 49, 16, 2, 58, 45, 33, 26, 29]
uncyeoussb
[54, 2, 45, 61, 63, 58, 46, 26, 29, 39]
hagrablysn
[6, 16, 50, 43, 16, 25, 36, 61, 34, 48]


 98%|█████████▊| 203/208 [00:12<00:00, 14.64it/s]

loss: 23.228116989135742
hilllatera
[6, 32, 36, 36, 36, 17, 40, 15, 35, 17]
nonogindio
[4, 16, 42, 58, 50, 32, 2, 0, 32, 16]
Merllylerc
[4, 15, 35, 36, 36, 61, 36, 7, 35, 9]
iicalonpel
[3, 32, 45, 33, 36, 58, 2, 47, 15, 36]
Squestcoso
[21, 30, 44, 7, 26, 40, 63, 58, 29, 16]


100%|██████████| 208/208 [00:12<00:00, 16.47it/s]


loss: 23.39180564880371
teoukomate
[4, 7, 16, 46, 29, 58, 49, 17, 40, 15]
train loss: 23.42| valid loss: 23.42



  0%|          | 1/208 [00:00<00:39,  5.23it/s]

loss: 23.45665168762207
raslentrou
[4, 20, 29, 36, 16, 42, 40, 43, 58, 46]
arpyablafa
[20, 35, 14, 5, 17, 25, 36, 16, 48, 33]
merotroupp
[48, 15, 35, 17, 40, 43, 58, 46, 14, 14]
Lkinicrood
[4, 12, 32, 2, 32, 45, 43, 58, 38, 0]
tanatlonte
[4, 16, 42, 17, 40, 43, 16, 42, 40, 15]


 25%|██▌       | 53/208 [00:03<00:09, 15.91it/s]

loss: 23.599506378173828
Agiotinedi
[21, 50, 32, 16, 40, 32, 48, 7, 0, 32]
Obperivath
[21, 3, 14, 15, 35, 32, 48, 17, 40, 6]
reutarhymb
[4, 16, 46, 29, 58, 9, 6, 61, 34, 25]
sustoleary
[60, 46, 26, 40, 15, 36, 7, 33, 35, 61]
scinditave
[60, 45, 32, 2, 0, 32, 40, 16, 48, 15]


 50%|████▉     | 103/208 [00:06<00:06, 16.49it/s]

loss: 23.485023498535156
scthomatic
[60, 45, 40, 6, 58, 34, 16, 42, 32, 45]
Ercyumysma
[21, 42, 9, 28, 37, 34, 61, 26, 49, 17]
pmatongetl
[26, 49, 17, 40, 16, 2, 50, 7, 29, 36]
phomorocoo
[14, 6, 58, 34, 15, 35, 16, 45, 51, 51]
metchimeno
[4, 7, 42, 9, 6, 32, 48, 7, 42, 16]


 74%|███████▎  | 153/208 [00:09<00:03, 16.81it/s]

loss: 23.612308502197266
sprastogla
[60, 47, 43, 20, 60, 40, 58, 50, 43, 16]
stentesest
[60, 40, 15, 42, 40, 7, 40, 7, 26, 29]
fwilaotire
[47, 22, 33, 36, 17, 38, 40, 32, 48, 7]
symiceboys
[60, 61, 34, 32, 45, 7, 52, 51, 28, 26]
grozzallyn
[50, 43, 58, 8, 8, 33, 36, 36, 61, 2]


 98%|█████████▊| 203/208 [00:12<00:00, 15.60it/s]

loss: 23.08568572998047
podveniosm
[47, 15, 0, 48, 7, 42, 32, 16, 26, 29]
honfarsuQu
[6, 16, 2, 23, 33, 35, 60, 46, 30, 44]
chtarivece
[47, 6, 40, 33, 35, 32, 48, 7, 42, 7]
frypertity
[47, 43, 61, 14, 15, 35, 40, 32, 40, 61]
tulyousvir
[47, 33, 36, 61, 58, 46, 29, 48, 7, 35]


100%|██████████| 208/208 [00:12<00:00, 16.60it/s]


loss: 23.78101348876953
metlemmeub
[4, 7, 40, 43, 20, 34, 49, 7, 17, 25]
train loss: 23.34| valid loss: 23.36



  0%|          | 0/208 [00:00<?, ?it/s]

loss: 23.565593719482422
remaciabca
[4, 20, 34, 17, 40, 32, 16, 39, 45, 17]
inenenamme
[54, 2, 16, 42, 7, 42, 58, 34, 49, 16]
ponomatine
[47, 16, 42, 58, 49, 17, 40, 32, 48, 7]
intitieifi
[54, 2, 40, 32, 40, 32, 16, 32, 18, 32]


  1%|▏         | 3/208 [00:00<00:20, 10.14it/s]

Tudoaredwe
[4, 46, 0, 58, 16, 35, 7, 0, 22, 15]


 25%|██▌       | 53/208 [00:03<00:09, 17.01it/s]

loss: 23.107942581176758
norsupbaph
[4, 51, 35, 60, 46, 29, 31, 16, 14, 6]
prppruarsh
[47, 43, 20, 14, 43, 44, 17, 35, 26, 6]
lureaestre
[4, 46, 35, 7, 17, 7, 26, 40, 43, 20]
Ariesistic
[21, 35, 32, 16, 48, 32, 26, 29, 32, 45]
goperetora
[47, 16, 14, 15, 35, 7, 42, 58, 43, 58]


 50%|████▉     | 103/208 [00:06<00:06, 15.73it/s]

loss: 23.038089752197266
enrodedynd
[16, 2, 43, 58, 0, 7, 0, 61, 2, 0]
crylochied
[47, 43, 61, 36, 58, 9, 6, 32, 16, 48]
knteriloif
[54, 42, 40, 15, 35, 32, 36, 58, 38, 18]
ojenoodMly
[16, 27, 7, 42, 58, 38, 0, 61, 36, 61]
banapenoma
[4, 16, 2, 20, 14, 15, 42, 58, 49, 17]


 73%|███████▎  | 151/208 [00:09<00:04, 14.11it/s]

loss: 23.52859878540039
prasoptral
[47, 43, 20, 29, 16, 14, 40, 43, 16, 36]
sipushyder
[60, 20, 14, 46, 26, 6, 61, 0, 15, 35]
pirisiapii
[47, 33, 35, 32, 26, 32, 16, 14, 43, 20]
Nestyrrope
[4, 7, 26, 40, 33, 35, 43, 58, 14, 15]
trogissito
[47, 43, 58, 50, 32, 26, 29, 32, 40, 58]


 98%|█████████▊| 203/208 [00:12<00:00, 16.83it/s]

loss: 23.023834228515625
hmentineli
[6, 49, 16, 42, 40, 32, 48, 15, 36, 32]
pyunnierwh
[47, 1, 54, 2, 48, 32, 16, 35, 9, 6]
thettorolm
[9, 6, 7, 42, 40, 15, 35, 33, 36, 49]
nanistPung
[4, 20, 42, 32, 26, 29, 47, 54, 2, 50]
moirrorome
[4, 51, 33, 35, 43, 58, 43, 58, 49, 7]


100%|██████████| 208/208 [00:12<00:00, 16.48it/s]


loss: 23.31765365600586
sprierycop
[60, 14, 43, 20, 15, 35, 28, 45, 16, 14]
train loss: 23.30| valid loss: 23.32



You may wish to try different values of $N$ and see what the impact on sample quality is.

In [30]:
x = torch.tensor(encode("quack")).unsqueeze(0)
T = torch.tensor([5])
print(model.viterbi(x,T))

x = torch.tensor(encode("quick")).unsqueeze(0)
T = torch.tensor([5])
print(model.viterbi(x,T))

x = torch.tensor(encode("qurck")).unsqueeze(0)
T = torch.tensor([5])
print(model.viterbi(x,T)) # should have lower probability---in English only vowels follow "qu"

x = torch.tensor(encode("qiick")).unsqueeze(0)
T = torch.tensor([5])
print(model.viterbi(x,T)) # should have lower probability---in English only "u" follows "q"


([[30, 44, 17, 45, 12]], tensor([[-13.7824]], device='cuda:0', grad_fn=<GatherBackward0>))
([[30, 44, 32, 45, 12]], tensor([[-10.9867]], device='cuda:0', grad_fn=<GatherBackward0>))
([[30, 44, 35, 45, 12]], tensor([[-19.2461]], device='cuda:0', grad_fn=<GatherBackward0>))
([[30, 44, 32, 45, 12]], tensor([[-19.3918]], device='cuda:0', grad_fn=<GatherBackward0>))


In [31]:
import torch

def show_prob(word):
    x = torch.tensor(encode(word)).unsqueeze(0)
    T = torch.tensor([len(word)])

    path, logp = model.viterbi(x, T)

    # Convert log-prob → normal probability
    prob = torch.exp(logp).item()

    print(f"word: {word}")
    print(f"best state path: {path[0]}")
    print(f"log probability: {logp.item():.4f}")
    print(f"normal probability: {prob:.10f}\n")  # 10 decimal places


# Run your examples
show_prob("quack")
show_prob("quick")
show_prob("qurck")  # should be low
show_prob("qiick")  # should be low


word: quack
best state path: [30, 44, 17, 45, 12]
log probability: -13.7824
normal probability: 0.0000010337

word: quick
best state path: [30, 44, 32, 45, 12]
log probability: -10.9867
normal probability: 0.0000169247

word: qurck
best state path: [30, 44, 35, 45, 12]
log probability: -19.2461
normal probability: 0.0000000044

word: qiick
best state path: [30, 44, 32, 45, 12]
log probability: -19.3918
normal probability: 0.0000000038





## **1. Why do we build the alphabet using `Counter(("".join(lines)))` instead of reusing the earlier alphabet = 'a'..'z'?**

### **Significance (HMM + Data-Driven Modeling)**

This time the HMM is being **trained on real text**, not hard-coded toy examples.

So we extract the alphabet *from the actual dataset itself*, meaning:

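* If the corpus contains accented letters, apostrophes, or hyphens, they get included
* If the dataset excludes some letters, the model won’t waste emission parameters on unused symbols

### **Why not use a fixed 26-letter alphabet?**

Because the set of characters in real data may be:

* smaller
* larger
* domain-specific

The samples above also contain capital letters (for example `Novinollis`), which a lowercase `'a'..'z'` alphabet could not represent. A data-driven alphabet guarantees that the emission matrix has exactly one column per symbol that actually occurs.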
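
As a rough illustration of the idea, the sketch below builds an alphabet and an `encode`/`decode` pair from a tiny made-up corpus. The three words in `lines` and the helper definitions are assumptions for this example; the notebook's own `encode` may differ in detail.

```python
from collections import Counter

# Hypothetical three-line corpus standing in for the notebook's word list.
lines = ["quick", "café", "don't"]

# Count every character that actually occurs in the data.
char_counts = Counter("".join(lines))

# Sort so the symbol -> index mapping is reproducible.
alphabet = sorted(char_counts)
char_to_index = {ch: i for i, ch in enumerate(alphabet)}


def encode(word):
    """Map a string to a list of integer symbol ids."""
    return [char_to_index[ch] for ch in word]


def decode(ids):
    """Inverse of encode."""
    return "".join(alphabet[i] for i in ids)


M = len(alphabet)  # number of columns the emission matrix needs
print(M, encode("quick"), decode(encode("quick")))
```

Here the accented "é" and the apostrophe each get a column, while letters that never occur (such as "z") get none.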

---

## **2. Why do we train the HMM with `N=64` states? Why 64 specifically?**

### **Significance (What HMM states are actually doing)**

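In a trained HMM, the **states are no longer hand-labelled as vowels/consonants**. They are **latent clusters** whose meaning the model discovers from the data.

64 states give the model enough “expressive capacity” to potentially learn:

* onset vs nucleus vs coda patterns
* consonant types (plosives, fricatives, liquids…)
* vowel classes
* clusters like “qu”, “ck”, “ng”, “sh”, “th”
* position-dependent variants (initial-q, mid-q, final-q, etc.)

### **Why not use 2 or 3 states?**

Because English spelling is too rich: each state has a single emission distribution and a single row of transition probabilities, so an HMM with very few states cannot represent even basic orthographic patterns such as “q is almost always followed by u”.

### **Why not use 256 states?**

The transition matrix has $N^2$ entries and the forward algorithm costs $O(T N^2)$ per sequence, so 256 states means 16 times more transition parameters and per-step work than 64. For this demo, 64 is a reasonable trade-off between expressive power and training cost rather than a tuned optimum; the notebook itself suggests trying other values of $N$.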
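
To make the trade-off concrete, here is a small sketch that counts the unnormalized parameters of an HMM for several values of $N$. The alphabet size `M = 52` is a hypothetical value chosen for illustration, not the notebook's actual alphabet size.

```python
def hmm_param_count(N, M):
    """Number of unnormalized parameters in an HMM with N states and M symbols:
    N initial-state entries + N*N transition entries + N*M emission entries."""
    return N + N * N + N * M


M = 52  # hypothetical alphabet size (upper- and lower-case letters)
for N in (2, 8, 64, 256):
    # N * N is also the number of (previous state, next state) pairs the
    # forward algorithm has to combine at every time step.
    print(f"N={N:3d}  parameters={hmm_param_count(N, M):6d}  per-step pairs={N * N}")
```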

---

## **3. Why is the training loss defined as `loss = -log_probs.mean()`?**

### **Significance (Maximum Likelihood for HMMs)**

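The forward algorithm returns **log p(x)** for each sequence in the batch, so maximizing likelihood means minimizing the **negative log-likelihood**, averaged over the batch. The per-batch `loss` values in the log above (roughly 23 to 27) are this average, in nats per sequence.

This corresponds to MLE:

```
maximize    ∑ log pθ(x)
equivalent to minimize  -∑ log pθ(x)
```

### **Why not cross-entropy?**

The usual per-token cross-entropy loss needs a target label at every step.
Here the hidden states are never observed, so the forward algorithm sums over all possible state paths to get the probability of the entire *sequence*. The resulting negative log-likelihood is itself a cross-entropy between the data distribution and the model, so the two views agree.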
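
The following is a minimal pure-Python sketch of the forward algorithm in log space and the resulting loss. The 2-state, 2-symbol parameters are made up, and `log_A[r][s]` is taken to mean "from state `r` to state `s`", which may not match the index convention used by the notebook's model.

```python
import math
from itertools import product


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


# Toy 2-state, 2-symbol HMM (all numbers hypothetical).
log_pi = [math.log(0.6), math.log(0.4)]
log_A = [[math.log(0.7), math.log(0.3)],
         [math.log(0.2), math.log(0.8)]]
log_B = [[math.log(0.9), math.log(0.1)],
         [math.log(0.3), math.log(0.7)]]


def forward_log_prob(x):
    """log p(x): sums over all hidden-state paths in O(T * N^2) time."""
    n = len(log_pi)
    alpha = [log_pi[s] + log_B[s][x[0]] for s in range(n)]
    for obs in x[1:]:
        alpha = [
            logsumexp([alpha[r] + log_A[r][s] for r in range(n)]) + log_B[s][obs]
            for s in range(n)
        ]
    return logsumexp(alpha)


batch = [[0, 1, 1], [1, 1, 0], [0, 0, 0]]
log_probs = [forward_log_prob(x) for x in batch]
loss = -sum(log_probs) / len(log_probs)  # same form as -log_probs.mean()
print(loss)
```

A useful sanity check for any forward implementation is that the probabilities of all possible sequences of a fixed length sum to 1.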

---

## **4. Why show sampled sequences during training (`model.sample()`) every 50 batches?**

### **Significance (Interpreting HMM learning qualitatively)**

Sampling shows **what the model currently believes English words look like**. Each sample in the log above is printed as a string followed by the hidden-state sequence that generated it.

You can visually track:

* random gibberish at epoch 1
* emergence of vowel/consonant balance
* discovery of frequent patterns:

  * “qu”
  * “ck”
  * double consonants
  * English-like endings (“ing”, “ed”, “er”)

Sampling is a quick **qualitative sanity check** that training is working; the loss values remain the quantitative one.

### **Why not wait until the end?**

Intermediate samples reveal problems (for example, a model that keeps emitting the same character) many epochs before a final evaluation would.
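
For reference, sampling from an HMM is plain ancestral sampling: draw a state, emit a symbol from it, then move to the next state. The sketch below uses made-up two-state parameters and a two-letter alphabet; it illustrates the procedure and is not the notebook's `model.sample()`.

```python
import random

# Toy parameters (hypothetical): rows are probability distributions.
pi = [0.6, 0.4]                 # initial state distribution
A = [[0.7, 0.3], [0.2, 0.8]]    # A[r][s] = p(next state s | current state r)
B = [[0.9, 0.1], [0.3, 0.7]]    # B[s][k] = p(symbol k | state s)
symbols = "ab"


def sample(T, rng):
    """Draw a length-T string and the hidden-state path that produced it."""
    z = rng.choices(range(len(pi)), weights=pi)[0]
    states, chars = [], []
    for _ in range(T):
        states.append(z)
        k = rng.choices(range(len(symbols)), weights=B[z])[0]
        chars.append(symbols[k])
        z = rng.choices(range(len(pi)), weights=A[z])[0]
    return "".join(chars), states


rng = random.Random(0)
word, states = sample(10, rng)
print(word)    # the sampled string
print(states)  # the state sequence, as in the training log
```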

---

## **5. Why use Adam optimizer with weight decay?**

### **Significance (Stabilizing HMM Learning)**

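HMMs trained by gradient descent can suffer from:

* very sharp posteriors
* collapse of states
* overfitting emission distributions
* becoming degenerate (one state dominates)

Weight decay acts like **soft regularization** on the parameters behind the:

* state priors
* transitions
* emissions

If those parameters are stored as unnormalized logits and passed through a softmax (a common choice for gradient-trained HMMs), pulling them toward zero pulls each distribution toward uniform.

Adam adapts the step size per parameter, which helps with noisy gradients from minibatches of variable-length sequences.

### **Why not use Baum–Welch (EM)?**

Baum–Welch is the classic way to train HMMs and would also work here. Gradient descent is used because:

* EM needs hand-derived update equations (forward and backward passes plus re-estimation), whereas autograd only needs the forward algorithm
* the same training loop runs on a GPU without extra work
* gradient-based HMMs are easier to integrate into PyTorch’s ecosystem
* it is pedagogically simpler for showing training loops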
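
The effect of weight decay can be illustrated without any training. The sketch below assumes a distribution is parameterized as softmax over unnormalized logits (an assumption about the model, not something shown in this section) and compares a row of hypothetical logits with the same row shrunk toward zero.

```python
import math


def softmax(logits):
    """Turn unnormalized logits into a probability distribution."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(p):
    """Shannon entropy in nats; larger means closer to uniform."""
    return -sum(q * math.log(q) for q in p if q > 0)


logits = [4.0, 0.0, -2.0]           # hypothetical unnormalized transition row
shrunk = [0.5 * v for v in logits]  # the direction weight decay pushes in

print(softmax(logits), entropy(softmax(logits)))
print(softmax(shrunk), entropy(softmax(shrunk)))
```

The shrunk row is less peaked and has higher entropy, which is the sense in which weight decay discourages overly sharp or degenerate distributions.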

---

## **6. Why do we split 90% training / 10% validation?**

### **Significance (True Evaluation of Generative Models)**

Validation loss measures:

* generalization of the HMM
* whether the model has learned real English structure instead of memorization

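Overfitting appears when:

* train loss ↓
* validation loss ↑

This pattern is classic for generative sequence models. In the log above, the two losses fall together and end within 0.02 of each other (23.30 vs 23.32), so there is no sign of overfitting.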
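
A 90/10 split can be sketched in a few lines. The function name `train_valid_split`, the fixed seed, and the placeholder word list are assumptions for illustration; the notebook's own splitting code may differ.

```python
import random


def train_valid_split(lines, valid_fraction=0.1, seed=0):
    """Shuffle a copy of `lines` and hold out a fraction for validation."""
    shuffled = list(lines)
    random.Random(seed).shuffle(shuffled)
    n_valid = int(len(shuffled) * valid_fraction)
    return shuffled[n_valid:], shuffled[:n_valid]


words = [f"word{i}" for i in range(1000)]  # placeholder corpus
train_lines, valid_lines = train_valid_split(words)
print(len(train_lines), len(valid_lines))
```

Shuffling before splitting matters when the corpus is sorted (for example alphabetically); otherwise the validation set would contain only words from one end of the list.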

---

## **7. Why does the model correctly score `"quack"` or `"quick"` higher than `"qurck"` or `"qiick"`?**

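### **Significance (HMM’s ability to learn orthographic dependencies)**

During training, the model learns high transition+emission probability for:

```
q → u
```

Because in English, almost every "q" is followed by "u". In the Viterbi outputs above, every "q" is decoded as state 30 followed by state 44, and in the three words containing "qu" it is state 44 that emits the "u". The same 30, 44 pair appears under "qu" in late samples such as `quicanlyou` and `Squestcoso`.

Similarly it learns:

* “ck”
* “qu”
* “ing”
* “sh”, “ch”
* vowel/consonant alternation tendencies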

### **Why not use RNNs to learn this?**

Because you are demonstrating that **even simple HMMs** can capture surprising amounts of structure.
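
Because the scores are log-probabilities, differences between them are log-ratios. The short calculation below turns the four Viterbi scores printed earlier into ratios; the numbers are copied from that output, and the `ratio` helper is just for this example.

```python
import math

# Viterbi log-scores copied from the output of the cells above.
scores = {
    "quack": -13.7824,
    "quick": -10.9867,
    "qurck": -19.2461,
    "qiick": -19.3918,
}


def ratio(word_a, word_b):
    """How many times more probable word_a's best path is than word_b's."""
    return math.exp(scores[word_a] - scores[word_b])


print(f"quick vs qiick: {ratio('quick', 'qiick'):.0f}x")
print(f"quack vs qurck: {ratio('quack', 'qurck'):.0f}x")
```

The model rates "quick" a few thousand times more likely than "qiick", and "quack" a couple of hundred times more likely than "qurck".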

---

## **8. Why do we wrap everything in a `Trainer` class instead of writing a single loop?**

### **Significance (Clean Architecture)**

This emphasizes modularity:

* `trainer.train(dataset)`
* `trainer.test(dataset)`

In real machine learning pipelines, such separation is important.

### **Why not use Lightning or high-level trainers?**

Because you want students to see **exactly how training works**, step-by-step.
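
The shape of such a class can be sketched independently of PyTorch. Everything below is hypothetical: `loss_fn` and `update_fn` stand in for the model's forward pass and the optimizer step, and the toy problem (fitting a single number `w` to the data mean) only exists to make the sketch runnable.

```python
class Trainer:
    """Bundles the training pass and the evaluation pass for one model."""

    def __init__(self, loss_fn, update_fn):
        self.loss_fn = loss_fn      # batch -> scalar loss
        self.update_fn = update_fn  # batch -> None; takes one optimization step

    def train(self, dataset):
        """One epoch: evaluate the loss and update on every batch."""
        losses = []
        for batch in dataset:
            losses.append(self.loss_fn(batch))
            self.update_fn(batch)
        return sum(losses) / len(losses)

    def test(self, dataset):
        """Evaluate only; parameters are left untouched."""
        losses = [self.loss_fn(batch) for batch in dataset]
        return sum(losses) / len(losses)


# Toy "model": a single parameter w fitted by gradient descent on squared error.
state = {"w": 0.0}


def loss_fn(batch):
    return sum((x - state["w"]) ** 2 for x in batch) / len(batch)


def update_fn(batch):
    grad = sum(-2 * (x - state["w"]) for x in batch) / len(batch)
    state["w"] -= 0.1 * grad


trainer = Trainer(loss_fn, update_fn)
train_data = [[1.0, 3.0], [2.0, 2.0]]
valid_data = [[2.0]]
for epoch in range(20):
    train_loss = trainer.train(train_data)
    valid_loss = trainer.test(valid_data)
    print(f"train loss: {train_loss:.2f}| valid loss: {valid_loss:.2f}")
```

Keeping `train` and `test` as separate methods makes it hard to update parameters by accident while evaluating.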

---

## **9. Why do we repeatedly evaluate validation loss every epoch?**

### **Significance**

Monitoring validation loss ensures the HMM:

* is not collapsing
* is improving English modeling
* is not diverging
* is not overfitting

Generative models MUST be monitored because they can silently degrade.
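
As a worked example, the snippet below takes the ten `train loss | valid loss` summaries visible in the log above and checks the things this section lists: whether validation loss keeps improving and how large the gap to the training loss is. The values are copied from the log; the variable names are just for this example.

```python
# (train, valid) losses copied from the ten epoch summaries in the log above.
history = [
    (30.86, 26.82), (25.44, 24.76), (24.46, 24.28), (24.06, 24.00),
    (23.85, 23.80), (23.67, 23.64), (23.53, 23.52), (23.42, 23.42),
    (23.34, 23.36), (23.30, 23.32),
]

valid = [v for _, v in history]

# Improvement in validation loss from each epoch to the next.
improvements = [round(prev - cur, 2) for prev, cur in zip(valid, valid[1:])]

# Index (0-based, within this excerpt) of the lowest validation loss.
best_epoch = min(range(len(valid)), key=valid.__getitem__)

# Positive gap = validation loss above training loss.
final_gap = round(history[-1][1] - history[-1][0], 2)

print("per-epoch validation improvement:", improvements)
print("best epoch in this excerpt:", best_epoch)
print("final valid - train gap:", final_gap)
```

Validation loss improves at every epoch, but by less each time, and the final gap is tiny: the model is converging rather than diverging or overfitting. In the early epochs the validation loss is lower than the training loss, likely because the training figure is averaged over an epoch during which the model is still improving.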

---

## **10. Why does `viterbi("qurck")` have lower probability?**

### **Significance (Sequence-Level Reasoning)**

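Viterbi finds the **single most probable hidden-state path** and returns its log-score. That score is the joint log-probability of the word and that path, so it is a lower bound on the full log p(x) that the forward algorithm computes.
For irregular words like “qurck”, even the best path has to take an unlikely transition or emit a letter from a state that rarely produces it, which lowers the score. In the output above, “qiick” gets exactly the same path as “quick” (`[30, 44, 32, 45, 12]`), so the whole gap of about 8.4 nats comes from state 44 having to emit "i" instead of "u".

This directly shows students the interpretability of HMMs:

* The Viterbi path is inspectable
* You can see state clusters that correspond to linguistic categories
* You can observe where the sequence violates English patterns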
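
For readers who want to see the recursion itself, here is a minimal pure-Python Viterbi sketch on a made-up 2-state, 2-symbol HMM. The parameters and the convention `log_A[r][s]` = "from state `r` to state `s`" are assumptions for this example and are unrelated to the trained model's `viterbi` method.

```python
import math
from itertools import product

# Toy 2-state, 2-symbol HMM (all numbers hypothetical).
log_pi = [math.log(0.6), math.log(0.4)]
log_A = [[math.log(0.7), math.log(0.3)],
         [math.log(0.2), math.log(0.8)]]
log_B = [[math.log(0.9), math.log(0.1)],
         [math.log(0.3), math.log(0.7)]]


def joint_log_prob(x, z):
    """log p(x, z): score of one specific state path z for observations x."""
    score = log_pi[z[0]] + log_B[z[0]][x[0]]
    for t in range(1, len(x)):
        score += log_A[z[t - 1]][z[t]] + log_B[z[t]][x[t]]
    return score


def viterbi(x):
    """Return (best state path, its joint log-probability)."""
    n = len(log_pi)
    delta = [log_pi[s] + log_B[s][x[0]] for s in range(n)]
    backpointers = []
    for obs in x[1:]:
        new_delta, pointers = [], []
        for s in range(n):
            # Best predecessor of state s at this time step.
            best_prev = max(range(n), key=lambda r: delta[r] + log_A[r][s])
            pointers.append(best_prev)
            new_delta.append(delta[best_prev] + log_A[best_prev][s] + log_B[s][obs])
        delta = new_delta
        backpointers.append(pointers)
    # Trace the best path backwards from the best final state.
    last = max(range(n), key=lambda s: delta[s])
    path = [last]
    for pointers in reversed(backpointers):
        path.append(pointers[path[-1]])
    path.reverse()
    return path, delta[last]


x = [0, 0, 1, 1]
path, best_score = viterbi(x)
print(path, best_score)
```

Because the toy model is tiny, the result can be checked by brute force: enumerating all state paths with `product` and taking the maximum of `joint_log_prob` gives the same score.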




## Conclusion

HMMs used to be very popular in natural language processing, but they have largely been overshadowed by neural network models like RNNs and Transformers. Still, it is fun and instructive to study the HMM; some commonly used machine learning techniques like [Connectionist Temporal Classification](https://www.cs.toronto.edu/~graves/icml_2006.pdf) are inspired by HMM methods. HMMs are [still used in conjunction with neural networks in speech recognition](https://arxiv.org/abs/1811.07453), where the assumption of a one-hot state makes sense for modelling phonemes, which are spoken one at a time.

## Acknowledgments

This notebook is based partly on Lawrence Rabiner's excellent article "[A Tutorial on Hidden Markov Models and Selected Applications in Speech Recognition](https://www.cs.cmu.edu/~cga/behavior/rabiner1.pdf)", which you may also like to check out. Thanks also to Dima Serdyuk and Kyle Gorman for their feedback on the draft.