In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

class HMM(nn.Module):
    """Discrete hidden Markov model trained by gradient descent.

    The transition table ``A``, emission table ``B`` and initial distribution
    ``pi`` are stored as unconstrained logits. They are converted to
    log-probabilities with ``log_softmax`` whenever they are used, so every
    row of ``A`` / ``B`` and the vector ``pi`` always normalises to one.

    Args:
        num_states: Number of hidden states.
        num_observations: Number of distinct observation symbols.
    """

    def __init__(self, num_states, num_observations):
        super().__init__()
        self.num_states = num_states
        self.num_observations = num_observations
        # Transition logits A[i, j]: from state i to state j.
        self.A = nn.Parameter(torch.randn(num_states, num_states))
        # Emission logits B[j, k]: state j emits observation symbol k.
        self.B = nn.Parameter(torch.randn(num_states, num_observations))
        # Initial state logits.
        self.pi = nn.Parameter(torch.randn(num_states))

    def _log_params(self):
        """Return ``(log_pi, log_A, log_B)`` as normalised log-probabilities."""
        return (
            torch.log_softmax(self.pi, dim=0),
            torch.log_softmax(self.A, dim=1),
            torch.log_softmax(self.B, dim=1),
        )

    def forward(self, observations):
        """Run the forward algorithm in log space.

        Args:
            observations: 1-D sequence of observation indices (a long tensor
                or a list of ints).

        Returns:
            Tensor of shape ``(len(observations), num_states)`` whose entry
            ``[t, j]`` is the log joint probability of the first ``t + 1``
            observations and being in state ``j`` at step ``t``. It has the
            dtype and device of the parameters; an empty sequence yields an
            empty ``(0, num_states)`` tensor.
        """
        seq_len = len(observations)
        # Normalise once per call; these do not change across time steps.
        log_pi, log_A, log_B = self._log_params()
        if seq_len == 0:
            return log_pi.new_zeros((0, self.num_states))

        log_alpha_t = log_pi + log_B[:, observations[0]]
        rows = [log_alpha_t]
        for t in range(1, seq_len):
            # (prev_state, next_state) scores; reduce over prev_state for all
            # next states at once instead of looping over them in Python.
            log_alpha_t = (
                torch.logsumexp(log_alpha_t.unsqueeze(1) + log_A, dim=0)
                + log_B[:, observations[t]]
            )
            rows.append(log_alpha_t)
        # Stacking (rather than writing into a preallocated buffer) keeps the
        # result on the parameters' device/dtype and avoids in-place autograd ops.
        return torch.stack(rows)

    @torch.no_grad()
    def predict(self, observations):
        """Decode the most likely hidden-state path with the Viterbi algorithm.

        Runs without gradient tracking since decoding is inference only.

        Args:
            observations: 1-D sequence of observation indices (a long tensor
                or a list of ints).

        Returns:
            Long tensor of shape ``(len(observations),)`` holding the decoded
            state index for every time step (empty for an empty sequence).
        """
        seq_len = len(observations)
        log_pi, log_A, log_B = self._log_params()
        states = torch.zeros(seq_len, dtype=torch.long, device=log_pi.device)
        if seq_len == 0:
            return states

        log_delta = log_pi + log_B[:, observations[0]]
        # backpointers[t - 1][j] is the best predecessor of state j at step t.
        backpointers = []
        for t in range(1, seq_len):
            best_score, best_prev = torch.max(log_delta.unsqueeze(1) + log_A, dim=0)
            log_delta = best_score + log_B[:, observations[t]]
            backpointers.append(best_prev)

        # Backtrack from the best final state.
        states[-1] = torch.argmax(log_delta)
        for t in range(seq_len - 2, -1, -1):
            states[t] = backpointers[t][states[t + 1]]
        return states

# Generate some synthetic data for demonstration
def generate_data(num_sequences, sequence_length, num_states, num_observations, seed=None):
    """Sample observation sequences from a randomly drawn discrete HMM.

    A random row-normalised transition matrix, emission matrix and initial
    distribution are drawn first and printed (they are not returned), then
    ``num_sequences`` sequences of ``sequence_length`` steps are sampled.

    Args:
        num_sequences: Number of sequences to sample.
        sequence_length: Number of time steps per sequence.
        num_states: Number of hidden states of the generating HMM.
        num_observations: Number of distinct observation symbols.
        seed: Optional seed for a private random generator, making the
            output reproducible. When ``None`` (default) the global
            ``np.random`` stream is used, exactly as before.

    Returns:
        Tuple ``(sequences, states)``: two lists of ``num_sequences`` lists,
        each of length ``sequence_length``, holding the observation indices
        and the hidden-state indices that produced them.
    """
    # RandomState exposes the same ``rand`` / ``choice`` methods as the global
    # ``np.random`` module, so both sources can be used interchangeably here.
    rng = np.random if seed is None else np.random.RandomState(seed)

    A = rng.rand(num_states, num_states)
    A = A / A.sum(axis=1, keepdims=True)
    B = rng.rand(num_states, num_observations)
    B = B / B.sum(axis=1, keepdims=True)
    pi = rng.rand(num_states)
    pi = pi / pi.sum()
    # Show the ground-truth parameters so they can be compared with the fit.
    print(A, B, pi)

    sequences = []
    states = []

    for _seq_index in range(num_sequences):
        seq = []
        state_seq = []
        state = rng.choice(num_states, p=pi)
        for _step in range(sequence_length):
            obs = rng.choice(num_observations, p=B[state])
            seq.append(obs)
            state_seq.append(state)
            # Drawn after every emission (including the last one) so the
            # random stream matches the original sampling order.
            state = rng.choice(num_states, p=A[state])
        sequences.append(seq)
        states.append(state_seq)

    return sequences, states

# Parameters
num_states = 3
num_observations = 5
num_sequences = 100
sequence_length = 10
SEED = 0

# Seed both libraries so that the synthetic data, the parameter initialisation
# and therefore the printed losses are reproducible under Restart & Run All.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Generate synthetic data (the hidden states are kept for inspection only)
raw_sequences, true_states = generate_data(num_sequences, sequence_length, num_states, num_observations)

# Convert sequences to a (num_sequences, sequence_length) long tensor
sequences = torch.tensor(raw_sequences, dtype=torch.long)

# Initialize HMM model
model = HMM(num_states, num_observations)
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Training loop: one optimizer step per sequence, minimising the negative
# log-likelihood obtained from the last row of the forward table.
num_epochs = 100
for epoch in range(num_epochs):
    total_loss = 0
    for sequence in sequences:
        optimizer.zero_grad()
        log_alpha = model(sequence)
        loss = -torch.logsumexp(log_alpha[-1], dim=0)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Average over the sequences actually iterated, not the config constant.
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(sequences)}')

# Test the model with a new sequence; decoding needs no gradient tracking.
test_sequence = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4], dtype=torch.long)
with torch.no_grad():
    predicted_states = model.predict(test_sequence)
print("Predicted States:", predicted_states)


In [None]:
# Learned logits mapped to probabilities: transition rows, emission rows and
# the initial-state vector, shown together as the cell's output.
(
    model.A.softmax(dim=1),
    model.B.softmax(dim=1),
    model.pi.softmax(dim=0),
)