In [None]:
import torch
from torch.nn import TransformerEncoderLayer
from torch import nn, Tensor

import numpy as np

import math

In [None]:
# ---- HMM / data-generation settings ----
num_hidden_state = 3  # hidden states in the HMM
num_obs = 5           # size of the observation alphabet
seq_length = 10       # time steps per sampled sequence
nsamples = 1000       # sequences sampled for training

# ---- Transformer / training settings ----
emsize = 200          # token-embedding (model) dimension
d_hid = 2048          # hidden width of the encoder layer's feed-forward sublayer
nhead = 2             # attention heads; emsize must be divisible by nhead
ntoken = num_obs      # vocabulary = observation alphabet
batch_size = 200      # sequences per mini-batch
lr = 1e-3             # SGD learning rate
epochs = 200          # passes over the training set

In [None]:
# Generate HMM parameters
def generate_HMM_params(num_hidden_state, num_obs, rng=None):
    """Randomly generate HMM parameters and the chain's stationary distribution.

    Args:
        num_hidden_state: number of hidden states K (must be >= 1).
        num_obs: number of observation symbols V (must be >= 1).
        rng: optional source of randomness exposing ``dirichlet`` (e.g. a
            ``np.random.Generator``). Defaults to the global ``np.random``
            state, so seeding with ``np.random.seed`` keeps working.

    Returns:
        trans_mat: (K, K) row-stochastic matrix; row i is the distribution of
            the next state given current state i.
        obs_mat: (K, V) row-stochastic matrix; row i is the emission
            distribution of state i.
        stat_dist: (K,) stationary distribution pi (pi @ trans_mat == pi),
            non-negative and summing to 1.

    Raises:
        ValueError: if either size is smaller than 1.
    """
    if num_hidden_state < 1 or num_obs < 1:
        raise ValueError("num_hidden_state and num_obs must both be >= 1")
    rng = np.random if rng is None else rng

    alpha_state = np.ones(num_hidden_state)
    # Concentration 1/V (< 1 when V > 1) favours sparse emission rows.
    alpha_obs = np.ones(num_obs) / num_obs
    trans_mat = rng.dirichlet(alpha_state, num_hidden_state)
    obs_mat = rng.dirichlet(alpha_obs, num_hidden_state)

    # Stationary distribution: solve (I - T^T) pi = 0 stacked with the
    # normalisation row sum(pi) = 1, in the least-squares sense.
    system = np.ones((num_hidden_state + 1, num_hidden_state))
    system[:-1] = np.identity(num_hidden_state) - trans_mat.T
    rhs = np.zeros(num_hidden_state + 1)
    rhs[-1] = 1
    stat_dist = np.linalg.lstsq(system, rhs, rcond=None)[0]
    # lstsq is only exact up to rounding: tiny negative entries or a sum just
    # above 1 could be rejected when stat_dist is later used as multinomial
    # probabilities, so project back onto the probability simplex.
    stat_dist = np.clip(stat_dist, 0.0, None)
    stat_dist /= stat_dist.sum()
    return trans_mat, obs_mat, stat_dist

In [None]:
# Sample HMM sequences
def generate_HMM_sequences(trans_mat, obs_mat, init_dist, length, num_samples = 1):
    """Sample ``num_samples`` independent HMM trajectories of ``length`` steps.

    Randomness comes from the global ``np.random`` state, so results are
    reproducible after ``np.random.seed``.

    Args:
        trans_mat: (K, K) matrix; row ``s`` is the next-state distribution
            from state ``s``.
        obs_mat: (K, V) matrix; row ``s`` is the emission distribution of
            state ``s``.
        init_dist: (K,) distribution of the first hidden state.
        length: number of time steps per sequence.
        num_samples: number of sequences to draw.

    Returns:
        (states, obs): two float arrays of shape (num_samples, length) holding
        the hidden-state indices and the emitted observation indices.
    """
    hidden = np.zeros((num_samples, length))
    emitted = np.zeros((num_samples, length))
    # One one-hot draw per sample; argmax recovers the sampled initial state.
    current = np.random.multinomial(1, init_dist, num_samples).argmax(axis=1)
    for t in range(length):
        hidden[:, t] = current
        for n in range(num_samples):
            s = current[n]
            # Emit from the current state, then advance the chain. The draw
            # order (emission, then transition, sample by sample) determines
            # the seeded output, so it must not be reordered.
            emitted[n, t] = np.random.multinomial(1, obs_mat[s]).argmax()
            current[n] = np.random.multinomial(1, trans_mat[s]).argmax()
    return hidden, emitted

In [None]:
# Define Transformer Model
class TransformerModel(nn.Module):
  """Single-layer Transformer encoder mapping token ids to per-position logits.

  Pipeline: token embedding (scaled by sqrt(emsize)) -> one
  TransformerEncoderLayer -> linear projection back to the vocabulary.

  NOTE(review): no positional encoding is applied (a PositionalEncoding module
  was referenced but never defined), so without a mask self-attention is
  permutation-equivariant over the sequence. TODO decide whether one is needed.
  """

  def __init__(self, emsize: int, nhead: int, ntoken: int, d_hid: int = 2048, dropout: float = 0.1):
    """
    Args:
      emsize: embedding / model dimension; must be divisible by ``nhead``.
      nhead: number of attention heads.
      ntoken: vocabulary size (number of distinct token ids).
      d_hid: hidden width of the encoder layer's feed-forward network. This
        used to be read silently from the notebook-global ``d_hid``; the
        default 2048 matches the config cell, so pass it explicitly when that
        setting changes.
      dropout: dropout probability inside the encoder layer (0.1 is also
        PyTorch's default for TransformerEncoderLayer).
    """
    super().__init__()
    self.emsize = emsize
    # Construction order (embedding, encoder layer, decoder) is kept so that
    # seeded parameter initialisation draws the same random numbers as before.
    self.encoder = nn.Embedding(ntoken, emsize)
    self.transformer_encoder = TransformerEncoderLayer(emsize, nhead, d_hid, dropout=dropout, batch_first=True)
    self.decoder = nn.Linear(emsize, ntoken)

  def forward(self, src: Tensor, src_mask: "Tensor | None" = None) -> Tensor:
    """Return logits of shape (batch_size, seq_length, ntoken).

    Args:
      src: token ids of shape (batch_size, seq_length) (batch-first layout).
      src_mask: optional attention mask forwarded to the encoder layer, e.g.
        a causal (square subsequent) mask of shape (seq_length, seq_length).
        ``None`` (the default) lets every position attend to every position.
    """
    # Scale embeddings by sqrt(d_model), as in "Attention Is All You Need".
    x = self.encoder(src) * math.sqrt(self.emsize)       # (batch, seq, emsize)
    x = self.transformer_encoder(x, src_mask=src_mask)   # (batch, seq, emsize)
    return self.decoder(x)                               # (batch, seq, ntoken)

In [None]:
# Generate HMM parameters and samples used for training
seed = 20211121
np.random.seed(seed)     # drives HMM parameter and sequence sampling below
torch.manual_seed(seed)  # drives model weight init and dropout masks, previously unseeded
trans_mat, obs_mat, stat_dist = generate_HMM_params(num_hidden_state, num_obs) # generate parameters for HMM
states, obs = generate_HMM_sequences(trans_mat, obs_mat, stat_dist, seq_length, nsamples) # generate sample sequences
# Sanity checks: stochastic matrices and expected sample shapes.
assert np.allclose(trans_mat.sum(axis=1), 1) and np.allclose(obs_mat.sum(axis=1), 1)
assert states.shape == obs.shape == (nsamples, seq_length)

In [None]:
# Prepare input data
# obs holds integer-valued floats; cast to int64 token ids as required by
# nn.Embedding and by CrossEntropyLoss targets. torch.as_tensor replaces the
# legacy torch.LongTensor(ndarray) constructor.
dataset = torch.utils.data.TensorDataset(torch.as_tensor(obs, dtype=torch.long))
# NOTE(review): shuffle=False keeps the same batch order every epoch, and this
# loader is also reused for evaluation (no held-out split).
train_dl = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Set up model instance
model = TransformerModel(emsize=emsize, nhead=nhead, ntoken=ntoken)  # sizes come from the config cell

In [None]:
# Set up optimizer and loss
# The model returns raw logits; CrossEntropyLoss applies log-softmax internally.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=lr)  # plain SGD, no momentum

In [None]:
# Training process
# NOTE(review): the target is the input sequence itself and no attention mask is
# used, so each position sees its own token and the objective is reconstruction
# (the eval logits below peak at the input token). If next-token prediction of
# the HMM is intended, add a causal mask and shift targets by one -- TODO confirm.
model.train()
for epoch in range(epochs):
  epoch_loss = 0.
  for (tokens,) in train_dl:  # TensorDataset batches arrive as a 1-element list
    logits = model(tokens)  # (batch_size, seq_length, ntoken)
    # CrossEntropyLoss wants the class axis second: (batch, ntoken, seq_length).
    loss = criterion(logits.transpose(1, 2), tokens)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    epoch_loss += loss.item()
  # Every 10th epoch, report the sum of the per-batch mean losses.
  if epoch % 10 == 0:
    print(epoch_loss)

6.001508593559265
1.2639819383621216
0.7904462367296219
0.5798290818929672
0.4541909098625183
0.37019966542720795
0.30941810831427574
0.26553548127412796
0.22933536022901535
0.20264381542801857
0.1810063160955906
0.1622806005179882
0.14724822714924812
0.13425217755138874
0.12368939444422722
0.11400996707379818
0.10585581883788109
0.09840241447091103
0.09266022779047489
0.08712353557348251


In [None]:
# Evaluation pass over the training loader (no held-out split exists here).
model.eval()
with torch.no_grad():  # inference only: don't build autograd graphs
  for data in train_dl:
    data = data[0]  # TensorDataset batches arrive as a 1-element list
    print(data[:5])
    output = model(data)  # (batch_size, seq_length, ntoken)
    print(output[:5])
    # CrossEntropyLoss takes input (N, C, d) and target (N, d) -- N: batch,
    # C: classes, d: extra dim -- so swap output from
    # (batch_size, seq_length, ntoken) to (batch_size, ntoken, seq_length).
    loss = criterion(output.transpose(1, 2), data)
    print(loss)

tensor([[3, 2, 3, 4, 4, 2, 3, 4, 2, 4],
        [4, 2, 2, 4, 4, 2, 4, 2, 4, 4],
        [2, 2, 2, 4, 2, 4, 2, 4, 2, 4],
        [4, 2, 4, 3, 2, 2, 2, 2, 4, 2],
        [4, 4, 2, 2, 2, 4, 4, 0, 3, 2]])
tensor([[[-1.6437, -1.7703, -0.9735,  3.2435, -0.1992],
         [-1.4774, -0.6050,  5.5317, -1.2283, -1.3134],
         [-1.6437, -1.7703, -0.9735,  3.2435, -0.1992],
         [-1.7631, -1.6035, -1.8294, -0.7417,  5.0992],
         [-1.7631, -1.6035, -1.8294, -0.7417,  5.0992],
         [-1.4774, -0.6050,  5.5317, -1.2283, -1.3134],
         [-1.6437, -1.7703, -0.9735,  3.2435, -0.1992],
         [-1.7631, -1.6035, -1.8294, -0.7417,  5.0992],
         [-1.4774, -0.6050,  5.5317, -1.2283, -1.3134],
         [-1.7631, -1.6035, -1.8294, -0.7417,  5.0992]],

        [[-1.7631, -1.6035, -1.8294, -0.7417,  5.0992],
         [-1.3730, -0.7473,  5.5607, -0.8609, -1.5911],
         [-1.3730, -0.7473,  5.5607, -0.8609, -1.5911],
         [-1.7631, -1.6035, -1.8294, -0.7417,  5.0992],
         [-1.