In [3]:
import torch
torch.set_default_device('mps:0')

In [4]:
import torch
import numpy as np

class HMM(torch.nn.Module):
  """
  Hidden Markov Model with discrete observations.
  """
  def __init__(self, M, N):
    super(HMM, self).__init__()
    self.M = M # number of possible observations
    self.N = N # number of states

    # A
    self.transition_model = TransitionModel(self.N)

    # b(x_t)
    self.emission_model = EmissionModel(self.N,self.M)

    # pi
    self.unnormalized_state_priors = torch.nn.Parameter(torch.randn(self.N))

    # use the GPU
    self.is_cuda = torch.cuda.is_available()
    if self.is_cuda: self.cuda()

class TransitionModel(torch.nn.Module):
  def __init__(self, N):
    super(TransitionModel, self).__init__()
    self.N = N
    self.unnormalized_transition_matrix = torch.nn.Parameter(torch.randn(N,N))

class EmissionModel(torch.nn.Module):
  def __init__(self, N, M):
    super(EmissionModel, self).__init__()
    self.N = N
    self.M = M
    self.unnormalized_emission_matrix = torch.nn.Parameter(torch.randn(N,M))

In [5]:
def sample(self, T=10):
  state_priors = torch.nn.functional.softmax(self.unnormalized_state_priors, dim=0)
  transition_matrix = torch.nn.functional.softmax(self.transition_model.unnormalized_transition_matrix, dim=0)
  emission_matrix = torch.nn.functional.softmax(self.emission_model.unnormalized_emission_matrix, dim=1)

  # sample initial state
  z_t = torch.distributions.categorical.Categorical(state_priors).sample().item()
  z = []; x = []
  z.append(z_t)
  for t in range(0,T):
    # sample emission
    x_t = torch.distributions.categorical.Categorical(emission_matrix[z_t]).sample().item()
    x.append(x_t)

    # sample transition
    z_t = torch.distributions.categorical.Categorical(transition_matrix[:,z_t]).sample().item()
    if t < T-1: z.append(z_t)

  return x, z

# Add the sampling method to our HMM class
HMM.sample = sample

In [6]:
import string
alphabet = string.ascii_lowercase

def encode(s):
  """
  Convert a string into a list of integers
  """
  x = [alphabet.index(ss) for ss in s]
  return x

def decode(x):
  """
  Convert list of ints to string
  """
  s = "".join([alphabet[xx] for xx in x])
  return s

# Initialize the model
model = HMM(M=len(alphabet), N=2) 

# Hard-wiring the parameters!
# Let state 0 = consonant, state 1 = vowel
for p in model.parameters():
    p.requires_grad = False # needed to do lines below
model.unnormalized_state_priors[0] = 0.    # Let's start with a consonant more frequently
model.unnormalized_state_priors[1] = -0.5
print("State priors:", torch.nn.functional.softmax(model.unnormalized_state_priors, dim=0))

# In state 0, only allow consonants; in state 1, only allow vowels
vowel_indices = torch.tensor([alphabet.index(letter) for letter in "aeiou"])
consonant_indices = torch.tensor([alphabet.index(letter) for letter in "bcdfghjklmnpqrstvwxyz"])
model.emission_model.unnormalized_emission_matrix[0, vowel_indices] = -np.inf
model.emission_model.unnormalized_emission_matrix[1, consonant_indices] = -np.inf 
print("Emission matrix:", torch.nn.functional.softmax(model.emission_model.unnormalized_emission_matrix, dim=1))

# Only allow vowel -> consonant and consonant -> vowel
model.transition_model.unnormalized_transition_matrix[0,0] = -np.inf  # consonant -> consonant
model.transition_model.unnormalized_transition_matrix[0,1] = 0.       # vowel -> consonant
model.transition_model.unnormalized_transition_matrix[1,0] = 0.       # consonant -> vowel
model.transition_model.unnormalized_transition_matrix[1,1] = -np.inf  # vowel -> vowel
print("Transition matrix:", torch.nn.functional.softmax(model.transition_model.unnormalized_transition_matrix, dim=0))



State priors: tensor([0.6225, 0.3775], device='mps:0')
Emission matrix: tensor([[0.0000, 0.0120, 0.1351, 0.0388, 0.0000, 0.0127, 0.1529, 0.1171, 0.0000,
         0.0141, 0.0510, 0.0121, 0.0148, 0.0163, 0.0000, 0.0757, 0.0100, 0.0425,
         0.0298, 0.0134, 0.0000, 0.0137, 0.0644, 0.0119, 0.1228, 0.0391],
        [0.1667, 0.0000, 0.0000, 0.0000, 0.0621, 0.0000, 0.0000, 0.0000, 0.2245,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4337, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.1129, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
       device='mps:0')
Transition matrix: tensor([[0., 1.],
        [1., 0.]], device='mps:0')


In [7]:
# Sample some outputs
for _ in range(4):
  sampled_x, sampled_z = model.sample(T=5)
  print("x:", decode(sampled_x))
  print("z:", sampled_z)

x: cidoy
z: [0, 1, 0, 1, 0]
x: gigub
z: [0, 1, 0, 1, 0]
x: woyug
z: [0, 1, 0, 1, 0]
x: azaha
z: [1, 0, 1, 0, 1]


In [8]:
def HMM_forward(self, x, T):
  """
  x : IntTensor of shape (batch size, T_max)
  T : IntTensor of shape (batch size)

  Compute log p(x) for each example in the batch.
  T = length of each example
  """
  if self.is_cuda:
    x = x.cuda()
    T = T.cuda()

  batch_size = x.shape[0]; T_max = x.shape[1]
  log_state_priors = torch.nn.functional.log_softmax(self.unnormalized_state_priors, dim=0)
  log_alpha = torch.zeros(batch_size, T_max, self.N)
  if self.is_cuda: log_alpha = log_alpha.cuda()

  log_alpha[:, 0, :] = self.emission_model(x[:,0]) + log_state_priors
  for t in range(1, T_max):
    log_alpha[:, t, :] = self.emission_model(x[:,t]) + self.transition_model(log_alpha[:, t-1, :])

  # Select the sum for the final timestep (each x may have different length).
  log_sums = log_alpha.logsumexp(dim=2)
  log_probs = torch.gather(log_sums, 1, T.view(-1,1) - 1)
  return log_probs

def emission_model_forward(self, x_t):
  log_emission_matrix = torch.nn.functional.log_softmax(self.unnormalized_emission_matrix, dim=1)
  out = log_emission_matrix[:, x_t].transpose(0,1)
  return out

def transition_model_forward(self, log_alpha):
  """
  log_alpha : Tensor of shape (batch size, N)
  Multiply previous timestep's alphas by transition matrix (in log domain)
  """
  log_transition_matrix = torch.nn.functional.log_softmax(self.unnormalized_transition_matrix, dim=0)

  # Matrix multiplication in the log domain
  out = log_domain_matmul(log_transition_matrix, log_alpha.transpose(0,1)).transpose(0,1)
  return out

def log_domain_matmul(log_A, log_B):
	"""
	log_A : m x n
	log_B : n x p
	output : m x p matrix

	Normally, a matrix multiplication
	computes out_{i,j} = sum_k A_{i,k} x B_{k,j}

	A log domain matrix multiplication
	computes out_{i,j} = logsumexp_k log_A_{i,k} + log_B_{k,j}
	"""
	m = log_A.shape[0]
	n = log_A.shape[1]
	p = log_B.shape[1]

	# log_A_expanded = torch.stack([log_A] * p, dim=2)
	# log_B_expanded = torch.stack([log_B] * m, dim=0)
    # fix for PyTorch > 1.5 by egaznep on Github:
	log_A_expanded = torch.reshape(log_A, (m,n,1))
	log_B_expanded = torch.reshape(log_B, (1,n,p))

	elementwise_sum = log_A_expanded + log_B_expanded
	out = torch.logsumexp(elementwise_sum, dim=1)

	return out

TransitionModel.forward = transition_model_forward
EmissionModel.forward = emission_model_forward
HMM.forward = HMM_forward

In [9]:
x = torch.stack( [torch.tensor(encode("cat"))] )
T = torch.tensor([3])
print(model.forward(x, T))

x = torch.stack( [torch.tensor(encode("aba")), torch.tensor(encode("abb"))] )
T = torch.tensor([3,3])
print(model.forward(x, T))

tensor([[-4.3148]], device='mps:0')
tensor([[-8.9836],
        [   -inf]], device='mps:0')


In [10]:
def viterbi(self, x, T):
  """
  x : IntTensor of shape (batch size, T_max)
  T : IntTensor of shape (batch size)
  Find argmax_z log p(x|z) for each (x) in the batch.
  """
  if self.is_cuda:
    x = x.cuda()
    T = T.cuda()

  batch_size = x.shape[0]; T_max = x.shape[1]
  log_state_priors = torch.nn.functional.log_softmax(self.unnormalized_state_priors, dim=0)
  log_delta = torch.zeros(batch_size, T_max, self.N).float()
  psi = torch.zeros(batch_size, T_max, self.N).long()
  if self.is_cuda:
    log_delta = log_delta.cuda()
    psi = psi.cuda()

  log_delta[:, 0, :] = self.emission_model(x[:,0]) + log_state_priors
  for t in range(1, T_max):
    max_val, argmax_val = self.transition_model.maxmul(log_delta[:, t-1, :])
    log_delta[:, t, :] = self.emission_model(x[:,t]) + max_val
    psi[:, t, :] = argmax_val

  # Get the log probability of the best path
  log_max = log_delta.max(dim=2)[0]
  best_path_scores = torch.gather(log_max, 1, T.view(-1,1) - 1)

  # This next part is a bit tricky to parallelize across the batch,
  # so we will do it separately for each example.
  z_star = []
  for i in range(0, batch_size):
    z_star_i = [ log_delta[i, T[i] - 1, :].max(dim=0)[1].item() ]
    for t in range(T[i] - 1, 0, -1):
      z_t = psi[i, t, z_star_i[0]].item()
      z_star_i.insert(0, z_t)

    z_star.append(z_star_i)

  return z_star, best_path_scores # return both the best path and its log probability

def transition_model_maxmul(self, log_alpha):
  log_transition_matrix = torch.nn.functional.log_softmax(self.unnormalized_transition_matrix, dim=0)

  out1, out2 = maxmul(log_transition_matrix, log_alpha.transpose(0,1))
  return out1.transpose(0,1), out2.transpose(0,1)

def maxmul(log_A, log_B):
	"""
	log_A : m x n
	log_B : n x p
	output : m x p matrix

	Similar to the log domain matrix multiplication,
	this computes out_{i,j} = max_k log_A_{i,k} + log_B_{k,j}
	"""
	m = log_A.shape[0]
	n = log_A.shape[1]
	p = log_B.shape[1]

	log_A_expanded = torch.stack([log_A] * p, dim=2)
	log_B_expanded = torch.stack([log_B] * m, dim=0)

	elementwise_sum = log_A_expanded + log_B_expanded
	out1,out2 = torch.max(elementwise_sum, dim=1)

	return out1,out2

TransitionModel.maxmul = transition_model_maxmul
HMM.viterbi = viterbi

In [11]:
x = torch.stack( [torch.tensor(encode("aba")), torch.tensor(encode("abb"))] )
T = torch.tensor([3,3])
print(model.viterbi(x, T))

([[1, 0, 1], [1, 0, -1]], tensor([[-8.9836],
        [   -inf]], device='mps:0'))


In [12]:
print(model.forward(x, T))
print(model.viterbi(x, T)[1])

tensor([[-8.9836],
        [   -inf]], device='mps:0')
tensor([[-8.9836],
        [   -inf]], device='mps:0')


In [13]:
x = torch.tensor([1., 2., 3.])
print(x.max(dim=0)[0])
print(x.logsumexp(dim=0))

tensor(3., device='mps:0')
tensor(3.4076, device='mps:0')


In [14]:
import torch.utils.data
from collections import Counter
from sklearn.model_selection import train_test_split

class TextDataset(torch.utils.data.Dataset):
  def __init__(self, lines):
    self.lines = lines # list of strings
    collate = Collate() # function for generating a minibatch from strings
    self.loader = torch.utils.data.DataLoader(self, batch_size=1024, shuffle=True, collate_fn=collate, generator=torch.Generator(device='mps:0'))

  def __len__(self):
    return len(self.lines)

  def __getitem__(self, idx):
    line = self.lines[idx].lstrip(" ").rstrip("\n").rstrip(" ").rstrip("\n")
    return line

class Collate:
  def __init__(self):
    pass

  def __call__(self, batch):
    """
    Returns a minibatch of strings, padded to have the same length.
    """
    x = []
    batch_size = len(batch)
    for index in range(batch_size):
      x_ = batch[index]

      # convert letters to integers
      x.append(encode(x_))

    # pad all sequences with 0 to have same length
    x_lengths = [len(x_) for x_ in x]
    T = max(x_lengths)
    for index in range(batch_size):
      x[index] += [0] * (T - len(x[index]))
      x[index] = torch.tensor(x[index])

    # stack into single tensor
    x = torch.stack(x)
    x_lengths = torch.tensor(x_lengths)
    return (x,x_lengths)

In [15]:
# !wget https://raw.githubusercontent.com/lorenlugosch/pytorch_HMM/master/data/train/training.txt

filename = "training.txt"

with open(filename, "r") as f:
  lines = f.readlines() # each line of lines will have one word

alphabet = list(Counter(("".join(lines))).keys())
train_lines, valid_lines = train_test_split(lines, test_size=0.1, random_state=42)
train_dataset = TextDataset(train_lines)
valid_dataset = TextDataset(valid_lines)

M = len(alphabet)

In [16]:
from tqdm import tqdm # for displaying progress bar

class Trainer:
  def __init__(self, model, lr):
    self.model = model
    self.lr = lr
    self.optimizer = torch.optim.Adam(model.parameters(), lr=self.lr, weight_decay=0.00001)
  
  def train(self, dataset):
    train_loss = 0
    num_samples = 0
    self.model.train()
    print_interval = 50
    for idx, batch in enumerate(tqdm(dataset.loader)):
      x,T = batch
      batch_size = len(x)
      num_samples += batch_size
      log_probs = self.model(x,T)
      loss = -log_probs.mean()
      self.optimizer.zero_grad()
      loss.backward()
      self.optimizer.step()
      train_loss += loss.cpu().data.numpy().item() * batch_size
      if idx % print_interval == 0:
        print("loss:", loss.item())
        for _ in range(5):
          sampled_x, sampled_z = self.model.sample()
          print(decode(sampled_x))
          print(sampled_z)
    train_loss /= num_samples
    return train_loss

  def test(self, dataset):
    test_loss = 0
    num_samples = 0
    self.model.eval()
    print_interval = 50
    for idx, batch in enumerate(dataset.loader):
      x,T = batch
      batch_size = len(x)
      num_samples += batch_size
      log_probs = self.model(x,T)
      loss = -log_probs.mean()
      test_loss += loss.cpu().data.numpy().item() * batch_size
      if idx % print_interval == 0:
        print("loss:", loss.item())
        sampled_x, sampled_z = self.model.sample()
        print(decode(sampled_x))
        print(sampled_z)
    test_loss /= num_samples
    return test_loss

In [17]:
# Initialize model
model = HMM(N=64, M=M)

# Train the model
num_epochs = 10
trainer = Trainer(model, lr=0.01)

for epoch in range(num_epochs):
        print("========= Epoch %d of %d =========" % (epoch+1, num_epochs))
        train_loss = trainer.train(train_dataset)
        valid_loss = trainer.test(valid_dataset)

        print("========= Results: epoch %d of %d =========" % (epoch+1, num_epochs))
        print("train loss: %.2f| valid loss: %.2f\n" % (train_loss, valid_loss) )



  0%|          | 0/208 [00:00<?, ?it/s]

loss: 38.04960632324219
kLXe-bfWnN
[53, 53, 42, 63, 21, 32, 40, 35, 44, 16]
DImAmwKvLq
[55, 17, 39, 33, 46, 60, 33, 34, 7, 48]


  0%|          | 1/208 [00:03<10:48,  3.13s/it]

jqaSSWQxVL
[6, 32, 30, 30, 52, 16, 44, 21, 45, 52]
epQ
KPzHkq
[47, 28, 2, 0, 63, 18, 39, 14, 14, 9]
xC
PnB
oEx
[7, 48, 8, 11, 10, 10, 53, 33, 14, 46]


 24%|██▍       | 50/208 [00:16<00:42,  3.73it/s]

loss: 32.982521057128906
mnYlaASezq
[30, 4, 48, 36, 14, 44, 21, 50, 21, 53]
amlizxEUyN
[26, 37, 28, 0, 35, 12, 25, 40, 60, 16]
MpQMRQnscp
[37, 0, 14, 29, 7, 63, 25, 49, 50, 16]
qdrrat
uae
[24, 34, 22, 52, 30, 28, 58, 30, 49, 19]


 25%|██▍       | 51/208 [00:17<00:53,  2.93it/s]

KtjlsoWzru
[22, 4, 52, 6, 32, 34, 58, 19, 48, 56]


 48%|████▊     | 100/208 [00:30<00:29,  3.70it/s]

loss: 29.74675750732422
kerSmidrwt
[46, 45, 36, 47, 58, 42, 6, 36, 38, 9]
lmfcttRnpG
[45, 10, 4, 25, 21, 32, 0, 29, 14, 43]
OLrsoaEnTe
[1, 45, 36, 18, 52, 18, 36, 32, 2, 47]
lrsptCa
lW
[26, 22, 26, 34, 28, 24, 4, 49, 38, 37]


 49%|████▊     | 101/208 [00:31<00:36,  2.93it/s]

rNtthcXocr
[59, 16, 38, 55, 31, 21, 50, 21, 50, 22]


 72%|███████▏  | 150/208 [00:44<00:16,  3.56it/s]

loss: 28.041339874267578
ddOhypicge
[46, 32, 16, 11, 56, 62, 16, 25, 8, 0]
uJtmnvoncn
[46, 0, 42, 57, 0, 29, 0, 29, 50, 21]
qeWPoihpng
[7, 37, 0, 39, 14, 14, 16, 46, 4, 18]
oucytreooi
[18, 16, 6, 47, 55, 36, 47, 4, 26, 0]


 73%|███████▎  | 151/208 [00:45<00:19,  2.86it/s]

tPejumaCbt
[22, 20, 50, 21, 52, 39, 0, 59, 26, 55]


 96%|█████████▌| 200/208 [00:58<00:02,  3.80it/s]

loss: 27.020925521850586
drtlstotoo
[30, 16, 42, 63, 9, 55, 0, 29, 14, 0]
edgucgluin
[47, 21, 24, 37, 25, 3, 34, 4, 16, 25]
cletoaenrJ
[30, 53, 59, 39, 14, 26, 50, 21, 42, 22]
ladeliZhxc
[46, 0, 29, 45, 22, 0, 29, 50, 21, 26]


 97%|█████████▋| 201/208 [00:58<00:02,  3.05it/s]

yepheritHc
[7, 47, 15, 55, 45, 36, 49, 42, 4, 30]


100%|██████████| 208/208 [01:00<00:00,  3.44it/s]


loss: 26.387760162353516
kanemipron
[37, 0, 29, 50, 18, 16, 9, 7, 60, 25]
train loss: 30.49| valid loss: 26.45



  0%|          | 0/208 [00:00<?, ?it/s]

loss: 26.39613151550293
-rrhcchznt
[1, 16, 9, 55, 45, 30, 55, 0, 29, 55]
dluNlyWMcH
[46, 51, 37, 34, 34, 60, 9, 45, 30, 55]
gytteeatyp
[24, 0, 29, 55, 45, 63, 0, 42, 5, 9]


  0%|          | 1/208 [00:00<01:52,  1.84it/s]

rateguanip
[22, 0, 14, 13, 14, 30, 0, 21, 24, 22]
tcuclyofai
[24, 24, 4, 26, 34, 60, 4, 16, 30, 52]


 24%|██▍       | 50/208 [00:14<00:42,  3.69it/s]

loss: 25.71279525756836
niorapiyug
[29, 16, 0, 59, 0, 29, 16, 26, 37, 29]
saxobyolin
[24, 37, 62, 0, 35, 5, 0, 34, 0, 29]
vtrFasesbu
[46, 14, 36, 63, 0, 29, 50, 21, 26, 37]
ePXecohcli
[50, 36, 51, 45, 30, 4, 39, 26, 34, 16]


 25%|██▍       | 51/208 [00:14<00:53,  2.91it/s]

oqeotoaict
[37, 20, 50, 21, 26, 0, 35, 16, 9, 55]


 48%|████▊     | 100/208 [00:28<00:32,  3.29it/s]

loss: 25.006629943847656
PaticEniri
[24, 4, 42, 16, 25, 50, 21, 49, 36, 16]
weeratisha
[46, 0, 29, 22, 0, 29, 16, 9, 55, 45]
iftekgfyup
[0, 29, 55, 45, 36, 8, 51, 60, 4, 26]


 49%|████▊     | 101/208 [00:28<00:42,  2.51it/s]

nickercoto
[29, 16, 25, 3, 45, 36, 30, 60, 29, 52]
priverboti
[24, 22, 0, 39, 45, 36, 34, 49, 42, 16]


 72%|███████▏  | 150/208 [00:42<00:15,  3.81it/s]

loss: 24.798236846923828
crantlapoe
[24, 22, 0, 29, 55, 22, 0, 29, 47, 45]
bigiinaden
[8, 0, 29, 16, 0, 29, 16, 25, 50, 21]
cymctaxUed
[24, 45, 21, 30, 55, 0, 21, 26, 45, 36]
unrhelpiep
[37, 21, 30, 55, 45, 36, 39, 5, 0, 29]


 73%|███████▎  | 151/208 [00:42<00:19,  2.98it/s]

icWohoTtri
[16, 25, 8, 0, 29, 45, 36, 42, 22, 52]


 96%|█████████▌| 200/208 [00:55<00:02,  3.73it/s]

loss: 24.698890686035156
fagousiver
[46, 0, 29, 52, 4, 9, 16, 53, 45, 36]
voflalnstE
[46, 0, 26, 34, 45, 36, 29, 9, 55, 14]
jhocynateg
[30, 55, 0, 29, 60, 29, 0, 29, 0, 29]


 97%|█████████▋| 201/208 [00:56<00:02,  2.92it/s]

dandqprsim
[46, 0, 21, 42, 40, 45, 36, 30, 52, 39]
perlhiahio
[46, 45, 36, 9, 55, 16, 9, 55, 16, 0]


100%|██████████| 208/208 [00:58<00:00,  3.58it/s]


loss: 24.69890785217285
pablyneKju
[24, 0, 32, 34, 60, 29, 45, 36, 11, 37]
train loss: 25.18| valid loss: 24.54



  0%|          | 0/208 [00:00<?, ?it/s]

loss: 24.635353088378906
abrotenshi
[0, 29, 22, 0, 42, 45, 36, 9, 55, 16]
thishitieg
[30, 55, 16, 30, 55, 16, 9, 16, 0, 29]
suprorbest
[24, 4, 55, 22, 0, 36, 26, 50, 21, 42]


  0%|          | 1/208 [00:00<01:47,  1.92it/s]

oOtotariok
[0, 4, 42, 0, 29, 0, 36, 16, 0, 29]
eanbatging
[0, 37, 21, 26, 0, 29, 3, 16, 25, 3]


 24%|██▍       | 50/208 [00:14<00:43,  3.66it/s]

loss: 24.287635803222656
lyactorast
[22, 5, 0, 9, 55, 0, 36, 49, 42, 60]
irgeistilm
[0, 29, 31, 45, 0, 9, 55, 16, 34, 39]
padengusyb
[43, 0, 42, 45, 36, 29, 16, 30, 60, 26]


 25%|██▍       | 51/208 [00:14<00:55,  2.85it/s]

prozylopke
[24, 22, 0, 39, 60, 34, 0, 9, 55, 45]
cenWidetia
[46, 52, 21, 42, 16, 25, 45, 42, 16, 0]


 48%|████▊     | 100/208 [00:27<00:28,  3.73it/s]

loss: 24.18640899658203
peeskicina
[24, 45, 0, 9, 55, 16, 9, 16, 25, 50]
siddigilat
[30, 0, 36, 29, 16, 25, 3, 34, 60, 29]
Trasiolana
[24, 22, 0, 29, 16, 0, 29, 0, 29, 0]


 49%|████▊     | 101/208 [00:28<00:37,  2.88it/s]

onatoamizi
[0, 29, 0, 29, 0, 37, 39, 16, 25, 16]
motinilate
[39, 0, 29, 16, 25, 3, 34, 45, 42, 45]


 72%|███████▏  | 150/208 [00:42<00:15,  3.73it/s]

loss: 24.31041717529297
pransedNdB
[24, 22, 0, 29, 42, 50, 6, 16, 25, 3]
bomatatete
[46, 52, 39, 49, 42, 45, 42, 49, 42, 50]
poticoicai
[46, 0, 42, 16, 42, 60, 16, 25, 3, 16]


 73%|███████▎  | 151/208 [00:42<00:20,  2.78it/s]

shertigere
[9, 55, 45, 36, 42, 16, 31, 45, 36, 50]
stymmenmba
[30, 55, 60, 39, 39, 45, 36, 39, 18, 0]


 96%|█████████▌| 200/208 [00:56<00:02,  3.67it/s]

loss: 24.15997314453125
incerouneg
[37, 21, 42, 45, 36, 38, 37, 29, 0, 29]
semstynsoc
[46, 52, 39, 49, 42, 60, 21, 30, 0, 29]
bWreereter
[46, 0, 36, 47, 45, 36, 52, 24, 45, 36]
tagaxoraty
[24, 0, 29, 0, 29, 0, 36, 14, 55, 60]


 97%|█████████▋| 201/208 [00:56<00:02,  2.91it/s]

nibinalars
[29, 16, 32, 16, 25, 3, 34, 0, 36, 30]


100%|██████████| 208/208 [00:58<00:00,  3.56it/s]


loss: 23.726512908935547
wostormban
[24, 0, 21, 42, 45, 36, 39, 18, 0, 29]
train loss: 24.25| valid loss: 24.11



  0%|          | 0/208 [00:00<?, ?it/s]

loss: 24.068622589111328
oeyrostorw
[0, 36, 60, 22, 0, 21, 42, 45, 36, 18]
rievarypar
[22, 16, 0, 53, 45, 22, 5, 8, 0, 36]
pwetiaxmen
[24, 22, 50, 42, 16, 0, 21, 39, 50, 21]
dullenRhig
[40, 4, 34, 34, 50, 21, 30, 55, 16, 25]


  0%|          | 1/208 [00:00<01:47,  1.93it/s]

flydroviph
[51, 34, 60, 24, 22, 0, 53, 16, 9, 55]


 24%|██▍       | 50/208 [00:13<00:42,  3.73it/s]

loss: 23.755441665649414
daqurisher
[46, 0, 40, 4, 36, 16, 9, 55, 45, 36]
stypluYiom
[30, 55, 60, 26, 34, 4, 35, 16, 0, 39]
shamerount
[30, 55, 0, 39, 45, 36, 38, 37, 21, 55]
utgrovefco
[37, 29, 31, 22, 0, 53, 50, 21, 30, 0]


 25%|██▍       | 51/208 [00:14<00:53,  2.94it/s]

nocerilyst
[22, 52, 26, 45, 36, 16, 34, 60, 9, 55]


 48%|████▊     | 100/208 [00:27<00:29,  3.68it/s]

loss: 23.777591705322266
veaikillka
[53, 50, 6, 16, 53, 16, 34, 34, 53, 45]
ushinaciac
[37, 30, 55, 16, 25, 3, 25, 16, 0, 9]
lantelymvm
[22, 0, 21, 42, 45, 34, 60, 39, 39, 39]
kanflylame
[46, 0, 21, 51, 22, 60, 34, 0, 39, 50]


 49%|████▊     | 101/208 [00:28<00:37,  2.85it/s]

rescambedl
[22, 50, 21, 9, 0, 39, 18, 45, 6, 34]


 72%|███████▏  | 150/208 [00:41<00:15,  3.74it/s]

loss: 23.57806396484375
Eotiandanr
[24, 0, 42, 16, 0, 21, 42, 0, 29, 22]
hamiptetan
[46, 0, 39, 16, 9, 55, 45, 42, 0, 21]
alypaviaxe
[0, 34, 60, 26, 0, 53, 16, 0, 29, 0]
agablurmMr
[0, 29, 0, 32, 34, 45, 36, 39, 45, 36]


 73%|███████▎  | 151/208 [00:42<00:19,  2.92it/s]

biZmacashi
[46, 16, 9, 55, 0, 25, 3, 9, 55, 16]


 96%|█████████▌| 200/208 [00:55<00:02,  3.75it/s]

loss: 24.040958404541016
sencineica
[30, 45, 21, 30, 16, 25, 50, 16, 9, 0]
munaesenti
[46, 4, 25, 3, 50, 21, 50, 21, 42, 16]
uncenguati
[37, 21, 42, 50, 21, 40, 4, 49, 42, 16]


 97%|█████████▋| 201/208 [00:56<00:02,  2.84it/s]

onediskrer
[0, 29, 50, 6, 16, 9, 55, 22, 45, 36]
stltheerli
[30, 50, 21, 42, 7, 52, 45, 36, 38, 37]


100%|██████████| 208/208 [00:57<00:00,  3.59it/s]


loss: 24.051681518554688
hostiahush
[46, 0, 9, 55, 16, 0, 43, 4, 9, 55]
train loss: 23.97| valid loss: 23.93



  0%|          | 0/208 [00:00<?, ?it/s]

loss: 24.166139602661133
lorsnabliz
[22, 0, 36, 30, 25, 3, 32, 34, 16, 25]
culescolle
[30, 4, 34, 50, 21, 30, 0, 34, 34, 52]
droeendici
[24, 22, 52, 50, 0, 21, 42, 16, 9, 16]


  0%|          | 1/208 [00:00<01:48,  1.92it/s]

spinatiark
[30, 55, 16, 25, 49, 42, 16, 0, 36, 14]
sunarbochi
[30, 4, 26, 45, 36, 8, 0, 9, 55, 16]


 24%|██▍       | 50/208 [00:13<00:43,  3.66it/s]

loss: 23.644033432006836
symadirpte
[30, 60, 39, 49, 42, 45, 36, 9, 55, 52]
endelusana
[50, 21, 42, 45, 34, 5, 9, 0, 21, 30]
paritoriol
[24, 0, 36, 49, 42, 45, 36, 16, 0, 34]


 25%|██▍       | 51/208 [00:14<00:54,  2.87it/s]

qZiteatety
[40, 4, 16, 42, 45, 49, 42, 52, 29, 60]
umpurlicen
[37, 39, 18, 4, 36, 34, 16, 25, 50, 21]


 48%|████▊     | 100/208 [00:27<00:28,  3.76it/s]

loss: 23.997222900390625
terarregla
[24, 45, 36, 0, 36, 22, 50, 3, 34, 0]
bantedilll
[18, 0, 21, 42, 52, 29, 16, 34, 34, 34]
riamitenoc
[22, 16, 0, 39, 16, 42, 52, 29, 0, 9]
parQtolmpl
[26, 45, 36, 49, 42, 45, 36, 39, 26, 34]


 49%|████▊     | 101/208 [00:27<00:36,  2.95it/s]

isuinailim
[37, 30, 4, 16, 29, 0, 16, 34, 16, 40]


 72%|███████▏  | 150/208 [00:40<00:15,  3.75it/s]

loss: 23.67516326904297
dAntinaler
[46, 37, 21, 42, 16, 25, 3, 34, 0, 36]
caecapllis
[46, 0, 52, 8, 0, 26, 34, 34, 16, 9]
nackapolen
[46, 0, 25, 31, 0, 26, 0, 34, 50, 21]
demonenshy
[46, 0, 39, 0, 21, 50, 21, 30, 55, 60]


 73%|███████▎  | 151/208 [00:41<00:19,  2.92it/s]

esferedayi
[50, 21, 51, 45, 36, 50, 6, 0, 29, 16]


 96%|█████████▌| 200/208 [00:54<00:02,  3.61it/s]

loss: 23.630165100097656
Cylogithet
[24, 60, 34, 52, 29, 16, 9, 55, 45, 42]
kermJgacom
[46, 45, 36, 39, 49, 11, 49, 42, 52, 39]
istrototng
[37, 21, 42, 22, 52, 29, 52, 42, 25, 3]
hobuerxord
[7, 52, 8, 4, 45, 36, 29, 0, 36, 6]


 97%|█████████▋| 201/208 [00:55<00:02,  2.82it/s]

vemoredHer
[46, 0, 39, 0, 34, 0, 29, 46, 45, 36]


100%|██████████| 208/208 [00:56<00:00,  3.66it/s]


loss: 24.044784545898438
onensterip
[0, 29, 50, 21, 30, 55, 0, 36, 16, 9]
train loss: 23.83| valid loss: 23.82



  0%|          | 0/208 [00:00<?, ?it/s]

loss: 23.885086059570312
Benebitfur
[46, 0, 25, 50, 32, 16, 42, 62, 4, 36]
Hizgzehale
[46, 16, 25, 31, 22, 52, 8, 0, 34, 52]
onivuruxno
[52, 29, 16, 53, 45, 22, 37, 21, 25, 3]


  0%|          | 1/208 [00:00<01:52,  1.84it/s]

gamopturce
[46, 0, 39, 0, 14, 55, 45, 36, 8, 0]
dhibyiteci
[42, 55, 16, 32, 53, 16, 42, 52, 9, 16]


 24%|██▍       | 50/208 [00:13<00:42,  3.71it/s]

loss: 24.263641357421875
qulereapra
[40, 4, 34, 50, 22, 52, 0, 26, 22, 0]
ablartpric
[0, 32, 34, 0, 21, 30, 24, 22, 16, 42]
ctelonnero
[24, 55, 0, 34, 0, 21, 25, 45, 36, 38]
squtyacpru
[30, 40, 4, 42, 60, 0, 9, 55, 22, 37]


 25%|██▍       | 51/208 [00:14<00:53,  2.91it/s]

nisdllogro
[46, 16, 21, 42, 28, 34, 0, 54, 22, 0]


 48%|████▊     | 100/208 [00:27<00:29,  3.64it/s]

loss: 23.761247634887695
romiimphau
[22, 52, 39, 16, 0, 39, 14, 55, 0, 54]
ladiposite
[8, 0, 29, 16, 26, 0, 9, 16, 42, 0]
ilusuntagi
[37, 21, 37, 21, 37, 21, 42, 0, 29, 16]
recporpans
[22, 52, 9, 55, 45, 36, 26, 0, 29, 40]


 49%|████▊     | 101/208 [00:28<00:37,  2.86it/s]

coscrephyp
[8, 0, 9, 55, 22, 52, 26, 43, 60, 26]


 72%|███████▏  | 150/208 [00:41<00:16,  3.52it/s]

loss: 23.517261505126953
garatouste
[46, 0, 36, 49, 42, 38, 37, 21, 42, 50]
bistiamane
[32, 16, 9, 55, 16, 0, 39, 0, 25, 52]
roscinoins
[22, 0, 21, 25, 16, 25, 52, 37, 21, 30]


 73%|███████▎  | 151/208 [00:42<00:20,  2.79it/s]

bicglerema
[46, 16, 25, 3, 34, 52, 22, 52, 39, 0]
sonousdist
[30, 0, 29, 38, 37, 21, 42, 16, 9, 55]


 96%|█████████▌| 200/208 [00:55<00:02,  3.58it/s]

loss: 23.653709411621094
parvichilo
[46, 45, 36, 35, 16, 9, 55, 16, 34, 49]
otharaline
[0, 9, 55, 0, 36, 3, 34, 16, 25, 50]
cyloletort
[8, 60, 22, 52, 34, 52, 42, 45, 36, 42]


 97%|█████████▋| 201/208 [00:56<00:02,  2.82it/s]

marliperaw
[46, 45, 36, 34, 16, 26, 45, 36, 0, 1]
entilynaph
[50, 21, 42, 16, 34, 60, 29, 0, 14, 55]


100%|██████████| 208/208 [00:57<00:00,  3.59it/s]


loss: 23.487239837646484
dolacahfer
[46, 0, 34, 0, 9, 0, 36, 51, 45, 36]
train loss: 23.73| valid loss: 23.74



  0%|          | 0/208 [00:00<?, ?it/s]

loss: 23.74848175048828
entotlegen
[0, 21, 24, 0, 24, 22, 52, 8, 0, 21]
beseridith
[18, 0, 53, 45, 36, 16, 25, 16, 9, 55]
pradinathe
[24, 22, 0, 29, 16, 25, 49, 42, 22, 52]


  0%|          | 1/208 [00:00<01:55,  1.78it/s]

isoxalesti
[37, 30, 0, 29, 0, 34, 0, 9, 55, 16]
Tholateuns
[24, 43, 0, 34, 49, 42, 52, 37, 21, 30]


 24%|██▍       | 50/208 [00:14<00:42,  3.68it/s]

loss: 23.719844818115234
miseyosawo
[46, 16, 30, 50, 63, 0, 9, 0, 29, 0]
docoigyazu
[13, 52, 8, 38, 37, 29, 60, 0, 25, 5]
cublifyplt
[40, 4, 32, 34, 16, 20, 60, 26, 34, 42]
miseninale
[27, 37, 30, 52, 35, 16, 25, 3, 34, 50]


 25%|██▍       | 51/208 [00:14<00:53,  2.91it/s]

romistiarb
[22, 52, 39, 16, 9, 55, 16, 0, 36, 41]


 48%|████▊     | 100/208 [00:28<00:29,  3.65it/s]

loss: 23.925342559814453
canisatifi
[24, 0, 29, 16, 9, 49, 42, 16, 20, 16]
norofhenit
[46, 0, 29, 52, 8, 7, 52, 29, 16, 42]
alcariedig
[0, 21, 42, 45, 36, 16, 0, 6, 16, 25]
Dotismably
[46, 0, 42, 16, 9, 39, 49, 32, 34, 60]


 49%|████▊     | 101/208 [00:28<00:37,  2.88it/s]

sueticteph
[30, 4, 45, 42, 16, 9, 55, 0, 26, 43]


 72%|███████▏  | 150/208 [00:42<00:16,  3.56it/s]

loss: 24.228832244873047
Tuptaledot
[46, 5, 9, 55, 0, 34, 50, 6, 52, 42]
iwerdyxyge
[37, 24, 45, 36, 29, 60, 29, 60, 42, 45]
Katoindiad
[46, 49, 42, 38, 37, 21, 42, 16, 0, 29]


 73%|███████▎  | 151/208 [00:43<00:20,  2.75it/s]

yantirmand
[23, 0, 21, 42, 45, 36, 39, 0, 29, 6]
itronogoge
[37, 29, 22, 0, 29, 0, 29, 0, 29, 52]


 96%|█████████▌| 200/208 [00:56<00:02,  3.78it/s]

loss: 23.320175170898438
unifiapery
[37, 29, 16, 51, 16, 0, 26, 45, 36, 60]
rutinatedv
[22, 49, 42, 16, 25, 49, 42, 50, 6, 25]
quriporusa
[40, 4, 36, 16, 26, 60, 22, 5, 9, 49]
hermeyrmen
[43, 45, 36, 39, 45, 45, 36, 39, 50, 21]


 97%|█████████▋| 201/208 [00:57<00:02,  2.95it/s]

alilygarba
[0, 34, 16, 34, 60, 29, 0, 36, 18, 0]


100%|██████████| 208/208 [00:58<00:00,  3.53it/s]


loss: 23.4674129486084
usesXiscog
[37, 21, 50, 21, 30, 16, 9, 8, 0, 29]
train loss: 23.65| valid loss: 23.67



  0%|          | 0/208 [00:00<?, ?it/s]

loss: 23.907119750976562
Minsurglit
[46, 37, 21, 30, 4, 36, 31, 34, 49, 42]
lidalahett
[46, 16, 25, 3, 34, 58, 55, 52, 14, 55]
Lagingwkyw
[46, 0, 29, 16, 25, 31, 1, 58, 23, 1]
cercultyci
[8, 0, 36, 8, 4, 36, 42, 60, 42, 16]


  0%|          | 1/208 [00:00<01:48,  1.92it/s]

Lapellflyt
[8, 0, 26, 45, 34, 34, 51, 34, 60, 42]


 24%|██▍       | 50/208 [00:13<00:43,  3.65it/s]

loss: 23.695032119750977
Ratoretica
[46, 45, 42, 45, 36, 49, 42, 16, 25, 3]
goshicarba
[46, 52, 9, 55, 16, 25, 3, 36, 18, 0]
bootercnid
[46, 0, 21, 42, 45, 36, 29, 29, 16, 42]


 25%|██▍       | 51/208 [00:14<00:54,  2.88it/s]

ussloGhova
[37, 21, 30, 34, 52, 26, 43, 0, 53, 45]
wetitalace
[46, 49, 42, 16, 42, 3, 34, 49, 42, 45]


 48%|████▊     | 100/208 [00:27<00:29,  3.64it/s]

loss: 23.597536087036133
stosieuted
[30, 55, 0, 9, 16, 0, 5, 42, 45, 6]
chyalinkll
[24, 43, 60, 3, 34, 16, 25, 31, 34, 34]
crasmabate
[24, 22, 0, 21, 8, 0, 32, 49, 42, 45]


 49%|████▊     | 101/208 [00:28<00:38,  2.75it/s]

jeborenzom
[13, 52, 18, 45, 36, 50, 21, 8, 52, 39]
toresecict
[24, 0, 36, 0, 9, 52, 9, 16, 9, 55]


 72%|███████▏  | 150/208 [00:41<00:15,  3.65it/s]

loss: 23.76523208618164
usorniarfl
[37, 39, 45, 36, 29, 16, 0, 36, 51, 34]
thyneighto
[24, 43, 60, 29, 0, 5, 11, 44, 42, 45]
toonsishor
[46, 52, 0, 21, 30, 16, 9, 55, 45, 36]
usimphicme
[37, 21, 16, 39, 14, 55, 16, 25, 46, 52]


 73%|███████▎  | 151/208 [00:42<00:19,  2.85it/s]

herrhepipa
[43, 45, 36, 24, 7, 52, 26, 16, 9, 0]


 96%|█████████▌| 200/208 [00:55<00:02,  3.64it/s]

loss: 23.59160804748535
cnedaexsym
[24, 22, 50, 6, 3, 50, 21, 30, 60, 39]
sChynytera
[30, 24, 43, 60, 29, 60, 42, 45, 36, 49]
Oraralyoch
[24, 22, 0, 36, 3, 34, 63, 0, 9, 55]


 97%|█████████▋| 201/208 [00:56<00:02,  2.83it/s]

nolaletice
[46, 0, 21, 0, 34, 50, 55, 16, 25, 45]
vichnargis
[35, 16, 9, 55, 25, 3, 36, 31, 16, 9]


100%|██████████| 208/208 [00:58<00:00,  3.57it/s]


loss: 23.567001342773438
tanctinest
[24, 0, 21, 9, 55, 16, 25, 52, 9, 55]
train loss: 23.59| valid loss: 23.60



  0%|          | 0/208 [00:00<?, ?it/s]

loss: 23.410438537597656
prenilisis
[24, 22, 52, 29, 16, 34, 16, 30, 52, 9]
cocllllite
[8, 52, 26, 34, 34, 34, 34, 49, 42, 50]
buedellils
[18, 4, 49, 42, 50, 34, 34, 16, 21, 30]
ipoleastic
[37, 19, 0, 34, 52, 0, 9, 55, 16, 9]


  0%|          | 1/208 [00:00<01:45,  1.97it/s]

myrimusurE
[46, 60, 22, 52, 39, 4, 9, 4, 36, 24]


 24%|██▍       | 50/208 [00:13<00:42,  3.70it/s]

loss: 23.268503189086914
kheonsagre
[24, 7, 52, 0, 21, 9, 0, 54, 22, 52]
gilatolela
[46, 16, 34, 49, 42, 0, 34, 50, 22, 0]
rontiagner
[22, 0, 21, 42, 52, 0, 11, 29, 45, 36]
utcankille
[37, 21, 8, 0, 21, 42, 16, 34, 34, 50]


 25%|██▍       | 51/208 [00:14<00:55,  2.84it/s]

ednalatyss
[50, 6, 25, 3, 34, 49, 42, 60, 21, 30]


 48%|████▊     | 100/208 [00:27<00:27,  3.87it/s]

loss: 23.533544540405273
ladisulari
[22, 49, 42, 16, 9, 4, 34, 45, 36, 16]
wutsopacet
[46, 5, 42, 22, 0, 29, 0, 9, 52, 42]
soateracis
[30, 52, 49, 42, 45, 36, 49, 42, 16, 9]
vidensamag
[35, 16, 25, 50, 21, 30, 0, 39, 0, 54]


 49%|████▊     | 101/208 [00:27<00:35,  2.99it/s]

soongtinsp
[30, 52, 0, 29, 29, 29, 16, 25, 43, 14]


 72%|███████▏  | 150/208 [00:40<00:14,  4.02it/s]

loss: 23.200355529785156
destreppph
[13, 52, 9, 55, 22, 52, 26, 26, 26, 43]
Fhoningiab
[24, 43, 0, 29, 16, 25, 31, 16, 0, 32]
selonriana
[30, 45, 34, 0, 36, 22, 16, 0, 29, 52]
proKphanam
[24, 22, 52, 39, 26, 43, 0, 21, 0, 39]


 73%|███████▎  | 151/208 [00:41<00:18,  3.04it/s]

biticipflo
[43, 16, 42, 16, 42, 16, 26, 51, 22, 52]


 96%|█████████▌| 200/208 [00:53<00:02,  3.83it/s]

loss: 23.843347549438477
tovelisere
[46, 0, 53, 45, 34, 16, 39, 50, 22, 52]
urecerpoin
[37, 36, 49, 42, 45, 36, 14, 55, 16, 25]
heexinalul
[7, 52, 50, 21, 16, 25, 3, 34, 4, 34]
insuiainte
[37, 21, 30, 4, 16, 0, 37, 21, 42, 45]


 97%|█████████▋| 201/208 [00:54<00:02,  3.02it/s]

Odestensem
[37, 6, 50, 21, 42, 50, 21, 30, 52, 39]


100%|██████████| 208/208 [00:56<00:00,  3.71it/s]


loss: 23.585073471069336
yoctiadala
[46, 52, 9, 55, 16, 0, 42, 3, 34, 49]
train loss: 23.52| valid loss: 23.55



  0%|          | 0/208 [00:00<?, ?it/s]

loss: 23.49643898010254
ieryonagli
[52, 45, 36, 23, 0, 29, 0, 54, 22, 16]
olattylall
[0, 34, 0, 21, 42, 60, 34, 45, 34, 34]
enderfulle
[50, 21, 42, 45, 36, 62, 4, 34, 34, 52]
raticedutt
[22, 0, 42, 16, 25, 50, 6, 5, 14, 55]


  0%|          | 1/208 [00:00<01:51,  1.86it/s]

locentosic
[46, 0, 25, 50, 21, 42, 0, 9, 16, 25]


 24%|██▍       | 50/208 [00:13<00:40,  3.88it/s]

loss: 23.765445709228516
myyrchymip
[27, 12, 45, 36, 8, 43, 60, 39, 16, 26]
fiasermatr
[51, 16, 0, 39, 45, 36, 27, 49, 42, 22]
intiverslo
[37, 21, 42, 16, 53, 45, 36, 30, 34, 0]
detesiorch
[13, 52, 55, 0, 9, 16, 0, 36, 8, 7]


 25%|██▍       | 51/208 [00:13<00:50,  3.13it/s]

phormphahe
[24, 43, 45, 36, 39, 26, 43, 14, 7, 52]


 48%|████▊     | 100/208 [00:26<00:27,  4.00it/s]

loss: 23.434844970703125
jeticeaspr
[46, 52, 42, 16, 25, 50, 21, 30, 24, 22]
cabicalyaz
[8, 0, 32, 16, 25, 3, 34, 63, 0, 29]
ocyanadusE
[52, 8, 63, 0, 29, 45, 6, 5, 30, 55]
Qunhastilo
[46, 37, 21, 43, 0, 9, 55, 16, 34, 0]


 49%|████▊     | 101/208 [00:26<00:33,  3.19it/s]

diareavedu
[35, 16, 0, 36, 52, 0, 53, 50, 6, 5]


 72%|███████▏  | 150/208 [00:39<00:15,  3.83it/s]

loss: 23.631690979003906
underonone
[37, 21, 6, 45, 36, 0, 29, 0, 25, 52]
istraningl
[37, 21, 24, 22, 0, 29, 16, 25, 31, 34]
Bsugnealeo
[24, 22, 5, 11, 29, 52, 3, 34, 50, 38]


 73%|███████▎  | 151/208 [00:39<00:19,  3.00it/s]

dacradalet
[46, 0, 54, 22, 49, 42, 3, 34, 49, 42]
joroincasK
[46, 0, 36, 38, 37, 21, 8, 0, 21, 24]


 96%|█████████▌| 200/208 [00:52<00:02,  3.71it/s]

loss: 23.789161682128906
folepaelde
[51, 45, 34, 0, 26, 45, 50, 34, 42, 52]
coshadedoc
[8, 52, 9, 55, 45, 6, 45, 6, 0, 8]
palygwiten
[24, 3, 34, 60, 54, 22, 49, 42, 50, 21]
ratiticani
[22, 49, 42, 49, 42, 16, 8, 0, 36, 16]


 97%|█████████▋| 201/208 [00:53<00:02,  2.90it/s]

muntonalen
[27, 37, 21, 42, 0, 25, 3, 34, 50, 21]


100%|██████████| 208/208 [00:54<00:00,  3.78it/s]


loss: 23.48554039001465
ursyummper
[37, 36, 30, 60, 59, 39, 39, 26, 45, 36]
train loss: 23.48| valid loss: 23.51



You may wish to try different values of $N$ and see what the impact on sample quality is.

In [18]:
x = torch.tensor(encode("quack")).unsqueeze(0)
T = torch.tensor([5])
print(model.viterbi(x,T))

x = torch.tensor(encode("quick")).unsqueeze(0)
T = torch.tensor([5])
print(model.viterbi(x,T))

x = torch.tensor(encode("qurck")).unsqueeze(0)
T = torch.tensor([5])
print(model.viterbi(x,T)) # should have lower probability---in English only vowels follow "qu"

x = torch.tensor(encode("qiick")).unsqueeze(0)
T = torch.tensor([5])
print(model.viterbi(x,T)) # should have lower probability---in English only "u" follows "q"


([[40, 4, 0, 25, 31]], tensor([[-16.2346]], device='mps:0', grad_fn=<GatherBackward0>))
([[40, 4, 16, 25, 31]], tensor([[-13.1426]], device='mps:0', grad_fn=<GatherBackward0>))
([[40, 4, 36, 14, 55]], tensor([[-16.1358]], device='mps:0', grad_fn=<GatherBackward0>))
([[40, 4, 16, 25, 31]], tensor([[-20.7870]], device='mps:0', grad_fn=<GatherBackward0>))
