# LSTM model for predicting intent based on phonemes

In [None]:
# Show the attached GPU, then fetch the repository that contains the corpus and GRABO splits.
!nvidia-smi
!git clone https://github.com/wwzeng1/Intent-Recognition.git

Sat May  8 17:00:47 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0    27W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Imports

In [None]:
import os
import numpy as np
import time

import pandas as pd
import torch 
import torch.nn as nn
import torch.optim as optim
from google.colab import files

from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from torch.utils.data import random_split
from more_itertools import sort_together

## Phone vocabulary

In [None]:
# Inventory of phone symbols that may occur in the phonemized corpora. A symbol's
# position in this list determines its integer id (phone2index below assigns
# all_phones[i] -> i + 1), so the order must not change once a model is trained.
# NOTE(review): some entries look alike but appear to be different code points
# (e.g. 'ә' vs 'ə', 'ɡ' vs 'g'); confirm they are meant to be distinct phones.
all_phones = ['I', 'a', 'aː', 'ã', 'ă', 'b', 'bʲ', 'bʲj', 'bʷ', 'bʼ', 'bː', 'b̞',
              'b̤', 'b̥', 'c', 'd', 'dʒ', 'dʲ', 'dː', 'd̚', 'd̥', 'd̪', 'd̯', 'd͡z', 
              'd͡ʑ', 'd͡ʒ', 'd͡ʒː', 'd͡ʒ̤', 'e', 'eː', 'e̞', 'f', 'fʲ', 'fʷ', 'fː', 
              'g', 'gʲ', 'gʲj', 'gʷ', 'gː', 'h', 'hʷ', 'i', 'ij', 'iː', 'i̞', 
              'i̥', 'i̯', 'j', 'k', 'kx', 'kʰ', 'kʲ', 'kʲj', 'kʷ', 'kʷʼ', 'kʼ', 
              'kː', 'k̟ʲ', 'k̟̚', 'k͡p̚', 'l', 'lʲ', 'lː', 'l̪', 'm', 'mʲ', 'mʲj', 
              'mʷ', 'mː', 'n', 'nj', 'nʲ', 'nː', 'n̪', 'n̺', 'o', 'oː', 'o̞', 'o̥', 
              'p', 'pf', 'pʰ', 'pʲ', 'pʲj', 'pʷ', 'pʷʼ', 'pʼ', 'pː', 'p̚', 'q', 
              'r', 'rː', 's', 'sʲ', 'sʼ', 'sː', 's̪', 't', 'ts', 'tsʰ', 'tɕ', 
              'tɕʰ', 'tʂ', 'tʂʰ', 'tʃ', 'tʰ', 'tʲ', 'tʷʼ', 'tʼ', 'tː', 't̚', 
              't̪', 't̪ʰ', 't̪̚', 't͡s', 't͡sʼ', 't͡ɕ', 't͡ɬ', 't͡ʃ', 't͡ʃʲ', 't͡ʃʼ', 't͡ʃː',
              'u', 'uə', 'uː', 'u͡w', 'v', 'vʲ', 'vʷ', 'vː', 'v̞', 'v̞ʲ', 'w', 'x',
              'x̟ʲ', 'y', 'z', 'zj', 'zʲ', 'z̪', 'ä', 'æ', 'ç', 'çj', 'ð', 'ø', 'ŋ',
              'ŋ̟', 'ŋ͡m', 'œ', 'œ̃', 'ɐ', 'ɐ̞', 'ɑ', 'ɑ̱', 'ɒ', 'ɓ', 'ɔ', 'ɔ̃', 'ɕ', 
              'ɕː', 'ɖ̤', 'ɗ', 'ə', 'ɛ', 'ɛ̃', 'ɟ', 'ɡ', 'ɡʲ', 'ɡ̤', 'ɡ̥', 'ɣ', 'ɣj',
              'ɤ', 'ɤɐ̞', 'ɤ̆', 'ɥ', 'ɦ', 'ɨ', 'ɪ', 'ɫ', 'ɯ', 'ɯ̟', 'ɯ̥', 'ɰ', 'ɱ', 
              'ɲ', 'ɳ', 'ɴ', 'ɵ', 'ɸ', 'ɹ', 'ɹ̩', 'ɻ', 'ɻ̩', 'ɽ', 'ɾ', 'ɾj', 'ɾʲ',
              'ɾ̠', 'ʀ', 'ʁ', 'ʁ̝', 'ʂ', 'ʃ', 'ʃʲː', 'ʃ͡ɣ', 'ʈ', 'ʉ̞', 'ʊ', 'ʋ', 'ʋʲ',
              'ʌ', 'ʎ', 'ʏ', 'ʐ', 'ʑ', 'ʒ', 'ʒ͡ɣ', 'ʔ', 'ʝ', 'ː', 'β', 'β̞', 'θ', 
              'χ', 'ә', 'ḁ']
# Integer id for every phone symbol: the empty string gets 0, and all_phones[i]
# gets i + 1, so ids run from 0 to len(all_phones) inclusive.
phone2index = {'': 0}
phone2index.update({symbol: position for position, symbol in enumerate(all_phones, start=1)})

In [None]:
# Copy the pretraining corpus and the Dutch GRABO train/dev/test splits from the
# cloned repository into the working directory.
!cp /content/Intent-Recognition/Code/raw_corpus.npy .
!cp /content/Intent-Recognition/Dutch/grabo_train_data.npy .
!cp /content/Intent-Recognition/Dutch/grabo_train_labels.npy .
!cp /content/Intent-Recognition/Dutch/grabo_dev_data.npy .
!cp /content/Intent-Recognition/Dutch/grabo_dev_labels.npy .
!cp /content/Intent-Recognition/Dutch/grabo_test_data.npy .
!cp /content/Intent-Recognition/Dutch/grabo_test_labels.npy .

## Dataset and Collate

In [None]:
# Load the pretraining corpus (phone symbols mapped to integer ids through
# phone2index; symbols missing from phone2index map to None) and the GRABO splits.
# NOTE(review): allow_pickle=True unpickles arbitrary objects; only load trusted files.
X_pretrain = np.vectorize(phone2index.get)(np.load('raw_corpus.npy', allow_pickle=True))
train_data, train_labels = (np.load('grabo_train_data.npy', allow_pickle=True),
                            np.load('grabo_train_labels.npy', allow_pickle=True))
dev_data, dev_labels = (np.load('grabo_dev_data.npy', allow_pickle=True),
                        np.load('grabo_dev_labels.npy', allow_pickle=True))
test_data, test_labels = (np.load('grabo_test_data.npy', allow_pickle=True),
                          np.load('grabo_test_labels.npy', allow_pickle=True))

In [None]:
class PretrainDataset(Dataset):
    """Fixed-length windows over a flat phone-id corpus for previous/next-phone pretraining.

    The corpus is cut into chunks of ``seq_len + 2`` ids. Item ``i`` is a window of
    ``seq_len`` ids that starts a random offset (in ``[0, seq_len)``) into chunk ``i``,
    together with the id immediately before the window and the id immediately after it.

    Args:
        X: 1-D indexable sequence of integer phone ids (e.g. a NumPy array).
        seq_len: number of ids in each input window.
        num_phones: vocabulary size (stored for reference; not used here).
    """

    def __init__(self, X, seq_len, num_phones):
        self.X = X
        self.length = len(X)
        self.seq_len = seq_len
        self.total_len = seq_len + 2
        self.num_phones = num_phones
        self.index_mapping = list(range(0, self.length, self.total_len))
        # Drop the last two chunk starts: with the random offset a window reads up to
        # 2 * seq_len positions past its chunk start, so the tail chunks (including an
        # incomplete final one) could run past the end of the corpus.
        self.index_mapping = self.index_mapping[:-2]

    def __len__(self):
        return len(self.index_mapping)

    def __getitem__(self, item):
        """Return ``(window, prev_id, next_id)``; ``window`` is a 1-D long tensor."""
        # Draw the offset from torch's RNG, not NumPy's global one: DataLoader seeds
        # torch separately in every worker for every epoch, while NumPy's state may be
        # copied unchanged into each worker (depending on the PyTorch version), which
        # would repeat the same windows.
        offset = int(torch.randint(0, self.seq_len, (1,)).item())
        start = self.index_mapping[item] + offset
        end = start + self.seq_len
        x = torch.tensor(self.X[start + 1:end + 1], dtype=torch.long)
        y1 = self.X[start]    # id right before the window
        y2 = self.X[end + 1]  # id right after the window
        return x, y1, y2
def pretrain_collate(batch):
    """Collate ``(window, prev_id, next_id)`` samples into one batch.

    Returns:
        A tuple ``(padded, lengths, prev_labels, next_labels)``: the windows right-padded
        with 0 into a ``(B, T_max)`` tensor, a tensor of the unpadded window lengths, and
        the previous/next labels as plain Python lists in batch order.
    """
    sequences, prev_labels, next_labels = [], [], []
    for sample in batch:
        sequences.append(sample[0])
        prev_labels.append(sample[1])
        next_labels.append(sample[2])
    lengths = torch.from_numpy(np.array(list(map(len, sequences))))
    padded = pad_sequence(sequences, batch_first=True, padding_value=0.0)
    return padded, lengths, prev_labels, next_labels

## Model

In [None]:
class CNN_LSTM(nn.Module):
    """Two parallel 1-D convolutions over phone embeddings followed by a 6-layer BiLSTM.

    The network has two output paths, selected by ``mode`` in :meth:`forward`:

    * ``'pretrain'``: predict the phone right before the input window (from the
      backward LSTM direction at the first step) and the phone right after it (from
      the forward direction at each sequence's last valid step).
    * any other mode (``'train'``, ``'dev'``, ``'test'``, ...): classify the sequence
      into one of ``num_classes`` intents from the final hidden states of every LSTM
      layer and direction.

    Args:
        in_features: number of embedding rows; every input id must be in
            ``[0, in_features)``.
        num_classes: number of intent classes for the classification head.
        pretrain_classes: number of phone classes for the pretraining heads.
        pretrain: unused; kept so existing calls remain valid.
    """

    def __init__(self, in_features, num_classes, pretrain_classes, pretrain = False):
        super().__init__()
        # Attribute names are kept unchanged because saved state_dicts are keyed by them.
        self.dropout = nn.Dropout(p=0.6)
        embedding_dim = 128
        self.embedding = nn.Embedding(in_features, embedding_dim)
        self.conv1 = nn.Conv1d(embedding_dim, 128, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv2 = nn.Conv1d(embedding_dim, 128, kernel_size=5, stride=1, padding=2, bias=True)

        self.bn1 = nn.BatchNorm1d(128)
        self.bn2 = nn.BatchNorm1d(128)

        self.lstm1 = nn.LSTM(input_size = 256, hidden_size = 128, num_layers = 6, batch_first = True,
                       bidirectional = True)

        # Classification head input: 6 layers * 2 directions * 128 hidden units = 1536.
        self.lin1 = nn.Linear(1536, 256)
        self.relu = nn.ReLU()
        self.lin_out = nn.Linear(256, num_classes)

        self.pretrain_fwd = nn.Linear(128, 256)
        self.pretrain_bwd = nn.Linear(128, 256)
        self.pretrain_out = nn.Linear(256, pretrain_classes)

    def forward(self, x, x_lens, mode):
        """Run the network.

        Args:
            x: (B, T) tensor of phone ids (assumed right-padded to the longest sequence).
            x_lens: (B,) unpadded lengths of the rows of ``x``; moved to the CPU for packing.
            mode: ``'pretrain'`` returns ``(prev_logits, next_logits)``, each of shape
                (B, pretrain_classes). Any other value returns intent logits of shape
                (B, num_classes); ``'train'`` additionally applies dropout after the
                convolutions and after ``lin1``.
        """
        # (B, T) ids -> (B, T, E) embeddings -> (B, E, T) for Conv1d.
        x = self.embedding(x)
        x = x.permute(0, 2, 1)
        out1 = self.bn1(self.conv1(x))  # (B, 128, T)
        out2 = self.bn2(self.conv2(x))  # (B, 128, T)
        out = torch.cat((out1, out2), dim=1)  # (B, 256, T)
        if mode == 'train':
            out = self.dropout(out)
        out = out.permute(0, 2, 1)  # (B, T, 256)
        out = pack_padded_sequence(out, x_lens.cpu(), batch_first=True, enforce_sorted=False)
        out, (hidden, _cell) = self.lstm1(out)
        # hidden: (num_layers * num_directions, B, hidden_size). The batch size is
        # dimension 1; dimension 2 is the hidden size.
        batch = hidden.shape[1]
        hidden_size = hidden.shape[2]

        if mode == 'pretrain':
            # padded: (T, B, 2 * hidden_size); lengths: (B,) on the CPU, in input order.
            padded, lengths = pad_packed_sequence(out)
            steps = padded.shape[0]
            # Split the concatenated [forward, backward] features per time step.
            directions = padded.reshape(steps, batch, 2, hidden_size)
            # The backward direction at t=0 has read the whole window right-to-left.
            prev_word = directions[0, :, 1, :]
            # The forward direction at each sequence's last valid (non-padded) step.
            forward_states = directions[:, :, 0, :]  # (T, B, hidden_size)
            last_step = (lengths - 1).to(padded.device)
            batch_index = torch.arange(batch, device=padded.device)
            next_word = forward_states[last_step, batch_index]

            out_prev = self.pretrain_bwd(prev_word)
            out_prev = self.dropout(out_prev)
            out_prev = self.relu(out_prev)
            out_prev = self.pretrain_out(out_prev)

            out_next = self.pretrain_fwd(next_word)
            out_next = self.dropout(out_next)
            out_next = self.relu(out_next)
            out_next = self.pretrain_out(out_next)
            return out_prev, out_next

        # (L*D, B, H) -> (B, L*D, H) -> (B, L*D*H)
        hidden = hidden.permute(1, 0, 2)
        hidden = torch.flatten(hidden, start_dim=1)
        out = self.lin1(hidden)
        if mode == 'train':
            out = self.dropout(out)
        out = self.relu(out)
        return self.lin_out(out)

In [None]:
batch_size = 128
# Cap the worker count at the number of CPUs; DataLoader warns when num_workers
# exceeds what is available.
num_workers = min(8, os.cpu_count() or 1)
seq_len = 64
# phone2index gives '' the id 0 and phones the ids 1..len(all_phones), so the embedding
# table and the pretraining output layer need len(all_phones) + 1 rows.
# NOTE: checkpoints saved with the previous size (len(all_phones)) will not load.
phoneme_num = len(all_phones) + 1
pretrain_dataset = PretrainDataset(X_pretrain, seq_len, phoneme_num)
pretrain_loader = DataLoader(pretrain_dataset, batch_size = batch_size, shuffle = True,
                            collate_fn = pretrain_collate, num_workers = num_workers)

  cpuset_checked))


In [None]:
# ---- Pretraining hyperparameters ----
actions_num = 36  # size of the intent classification head (unused while pretraining)
num_epochs = 250
learning_rate = 1e-4
weight_decay = 5e-5

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Phones are both the inputs and the pretraining targets.
network = CNN_LSTM(phoneme_num, actions_num, phoneme_num, pretrain = True).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Halve the learning rate whenever the monitored loss stops improving.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=10, threshold=0.0001,
    threshold_mode='rel', cooldown=10, min_lr=0, eps=1e-08, verbose=True,
)

In [None]:
for epoch in range(num_epochs):
    # ---- Pretraining epoch: predict the phone before and after each window ----
    network.train()
    correct = 0.0
    total = 0
    avg_loss = 0.0
    total_bwd_perplexity = 0.0
    total_fwd_perplexity = 0.0
    num_batches = 0
    for batch_num, (x, x_lens, prev_labels, next_labels) in enumerate(pretrain_loader):
        optimizer.zero_grad()
        prev_labels = torch.Tensor(prev_labels).long().to(device)
        next_labels = torch.Tensor(next_labels).long().to(device)
        # x_lens stays on the CPU: the model moves it there for packing anyway.
        x = x.to(device)
        prev_out, next_out = network(x, x_lens, 'pretrain')

        prev_loss = criterion(prev_out, prev_labels)
        next_loss = criterion(next_out, next_labels)
        loss = prev_loss + next_loss
        loss.backward()
        optimizer.step()

        avg_loss += loss.item()
        # Perplexity of each direction = exp(mean cross-entropy).
        backward_perplexity = torch.exp(prev_loss)
        forward_perplexity = torch.exp(next_loss)
        total_bwd_perplexity += backward_perplexity.item()
        total_fwd_perplexity += forward_perplexity.item()

        correct += (prev_out.argmax(dim=1) == prev_labels).sum().item()
        correct += (next_out.argmax(dim=1) == next_labels).sum().item()
        # Count the samples actually in this batch (the last batch may be short);
        # each sample yields two predictions, previous and next.
        total += 2 * len(prev_labels)
        num_batches += 1
        if batch_num % 50 == 0:
            print("BATCH NUM:", batch_num)
            print("BATCH LOSS:", loss.item())
            print("FORWARD PERPLEXITY:", forward_perplexity.item())
            print("BACKWARD PERPLEXITY:", backward_perplexity.item())

    # Average over the real number of batches (batch_num is a 0-based index, one short);
    # max(..., 1) guards against an empty loader.
    num_batches = max(num_batches, 1)
    epoch_loss = avg_loss / num_batches
    print("EPOCH: ", epoch)
    print("LOSS: ", epoch_loss)
    print("TRAIN ACC: ", correct / max(total, 1))
    print("AVG BWD PERPLEXITY:", total_bwd_perplexity / num_batches)
    print("AVG FWD PERPLEXITY:", total_fwd_perplexity / num_batches)

    scheduler.step(epoch_loss)

In [None]:
# Persist the pretrained weights and download them to the local machine.
from google.colab import files
checkpoint_path = "pretrained_dutch.pt"
torch.save(network.state_dict(), checkpoint_path)
files.download(checkpoint_path)

## Training the model

In [None]:
# Upload files from the local machine into the Colab working directory
# (presumably the pretrained checkpoint that the fine-tuning cell loads — confirm).
from google.colab import files
files.upload()

In [None]:
# Drop empty dev utterances (a zero-length sequence cannot be packed for the LSTM) and
# drop their labels as well, so dev_data[i] still pairs with dev_labels[i].
nonempty_idx = [i for i, item in enumerate(dev_data) if len(item) != 0]
dev_data = [dev_data[i] for i in nonempty_idx]
dev_labels = [dev_labels[i] for i in nonempty_idx]
min([len(item) for item in dev_data])

3

In [None]:
class PhonesDataset(Dataset):
    """``(phone-id sequence, label)`` pairs for intent classification.

    Args:
        data: sequence of utterances, each a sequence of phone symbols.
        labels: labels aligned index-for-index with ``data``.
        num_phones: vocabulary size (stored for reference).
        num_classes: number of intent classes (stored for reference).
        mode: split name such as "Train", "Valid" or "Test" (stored for reference).
        vocab: optional phone -> id mapping; when omitted, the module-level
            ``phone2index`` is used.
    """

    def __init__(self, data, labels, num_phones, num_classes, mode="Train", vocab=None):
        self.X = data
        self.Y = labels
        self.num_phones = num_phones
        self.num_classes = num_classes
        self.mode = mode
        self.vocab = vocab

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``(ids, label)``; ``ids`` is a 1-D long tensor of phone ids.

        Raises:
            KeyError: if the utterance contains a phone missing from the vocabulary.
        """
        # phone2index is looked up lazily so the class works with either mapping.
        vocab = self.vocab if self.vocab is not None else phone2index
        try:
            ids = [vocab[phone] for phone in self.X[idx]]
        except KeyError as err:
            # Fail clearly instead of letting a None id break tensor construction.
            raise KeyError(f"unknown phone {err.args[0]!r} in utterance {idx}") from err
        return torch.tensor(ids, dtype=torch.long), self.Y[idx]
def my_collate_fn(batch):
    """Collate ``(phone-id tensor, label)`` pairs, longest sequence first.

    Returns:
        A tuple ``(x_padded, y, x_len)``: a ``(B, T_max)`` tensor right-padded with 0,
        the labels as a list reordered to match, and an int64 tensor of the unpadded
        lengths.
    """
    # Stable descending sort by length: ties keep their original batch order.
    order = sorted(range(len(batch)), key=lambda i: len(batch[i][0]), reverse=True)
    x = [batch[i][0] for i in order]
    y = [batch[i][1] for i in order]

    x_padded = torch.nn.utils.rnn.pad_sequence(x, batch_first = True)
    # Lengths are counts; pack_padded_sequence expects them as int64.
    x_len = torch.tensor([len(seq) for seq in x], dtype=torch.long)

    return x_padded, y, x_len

In [None]:
# Small batches work better here (an earlier note: going from 128 to 16 helped); 8 is used.
batch_size = 8
train_dataset = PhonesDataset(train_data, train_labels, phoneme_num, actions_num, mode="Train")
train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, collate_fn = my_collate_fn)
dev_dataset = PhonesDataset(dev_data, dev_labels, phoneme_num, actions_num, mode="Valid")
# Evaluation needs no shuffling; a fixed order keeps runs reproducible.
dev_loader = DataLoader(dev_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, collate_fn = my_collate_fn)
test_dataset = PhonesDataset(test_data, test_labels, phoneme_num, actions_num, mode="Test")
test_loader = DataLoader(test_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, collate_fn = my_collate_fn)

  cpuset_checked))


In [None]:
# Fine-tuning hyperparameters: start from the pretrained checkpoint.
actions_num = 36
num_epochs = 40

learning_rate = 1e-3
weight_decay = 5e-5

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

network = CNN_LSTM(phoneme_num, actions_num, phoneme_num, pretrain = True)
network = network.to(device)
# map_location lets a checkpoint saved in a GPU session be restored on a CPU-only runtime.
network.load_state_dict(torch.load('/content/pretrained_dutch.pt', map_location=device))
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(network.parameters(), lr = learning_rate, weight_decay=weight_decay)

# Halve the learning rate when the monitored loss plateaus.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, \
  patience=3, threshold=0.0001, threshold_mode='rel', cooldown=3, min_lr=0, eps=1e-08, verbose=True)

In [None]:
for epoch in range(num_epochs):
    # ---- Train ----
    network.train()
    correct = 0.0
    total = 0
    avg_loss = 0.0
    num_batches = 0
    for batch_num, (x, y, x_len) in enumerate(train_loader):
        optimizer.zero_grad()
        y = torch.Tensor(y).long().to(device)
        x = x.to(device)
        out = network(x, x_len, 'train')
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        avg_loss += loss.item()
        correct += (out.argmax(dim=1) == y).sum().item()
        # The last batch can be smaller than batch_size; count the real samples.
        total += len(y)
        num_batches += 1
        if batch_num % 10 == 0:
            print("BATCH NUM:", batch_num)
            print("BATCH LOSS:", loss.item())

    # Divide by the batch count (batch_num is 0-based); guard against an empty loader.
    num_batches = max(num_batches, 1)
    train_loss = avg_loss / num_batches
    print("EPOCH: ", epoch)
    print("LOSS: ", train_loss)
    print("TRAIN ACC: ", correct / max(total, 1))
    # NOTE(review): the scheduler monitors the training loss; validation loss is the
    # more usual signal for ReduceLROnPlateau.
    scheduler.step(train_loss)

    # ---- Validate ----
    correct = 0.0
    total = 0
    avg_loss = 0.0
    num_batches = 0
    network.eval()
    with torch.no_grad():
        for batch_num, (x, y, x_len) in enumerate(dev_loader):
            y = torch.Tensor(y).long().to(device)
            x = x.to(device)
            out = network(x, x_len, 'dev')
            avg_loss += criterion(out, y).item()
            correct += (out.argmax(dim=1) == y).sum().item()
            total += len(y)
            num_batches += 1
    print("EPOCH: ", epoch)
    print("LOSS: ", avg_loss / max(num_batches, 1))
    print("VALID ACC: ", correct / max(total, 1))

  cpuset_checked))


BATCH NUM: 0
BATCH LOSS: 3.590276002883911
BATCH NUM: 10
BATCH LOSS: 3.8581855297088623
BATCH NUM: 20
BATCH LOSS: 3.7033214569091797
BATCH NUM: 30
BATCH LOSS: 3.659391403198242
BATCH NUM: 40
BATCH LOSS: 3.4339118003845215
BATCH NUM: 50
BATCH LOSS: 3.46943998336792
BATCH NUM: 60
BATCH LOSS: 3.7407004833221436
BATCH NUM: 70
BATCH LOSS: 3.522886276245117
BATCH NUM: 80
BATCH LOSS: 3.5160555839538574
BATCH NUM: 90
BATCH LOSS: 3.518364906311035
BATCH NUM: 100
BATCH LOSS: 3.4357845783233643
BATCH NUM: 110
BATCH LOSS: 3.6865665912628174
BATCH NUM: 120
BATCH LOSS: 3.4749069213867188
BATCH NUM: 130
BATCH LOSS: 2.811558723449707
BATCH NUM: 140
BATCH LOSS: 3.58172345161438
BATCH NUM: 150
BATCH LOSS: 3.3634557723999023
BATCH NUM: 160
BATCH LOSS: 3.490199565887451
BATCH NUM: 170
BATCH LOSS: 3.372067928314209
BATCH NUM: 180
BATCH LOSS: 3.1473910808563232
BATCH NUM: 190
BATCH LOSS: 3.2416093349456787
BATCH NUM: 200
BATCH LOSS: 2.7390081882476807
BATCH NUM: 210
BATCH LOSS: 3.428964138031006
BATCH NUM: 

In [None]:
# Evaluate the current network on the held-out test split.
network.eval()
correct = 0.0
total = 0.0
avg_loss = 0.0  # reset so the test loss is not added to the value left by validation
num_batches = 0
with torch.no_grad():
    for batch_num, (x, y, x_len) in enumerate(test_loader):
        y = torch.Tensor(y).long().to(device)
        x = x.to(device)
        out = network(x, x_len, 'test')
        avg_loss += criterion(out, y).item()
        correct += (out.argmax(dim=1) == y).sum().item()
        # The last batch can be smaller than batch_size; count the real samples.
        total += len(y)
        num_batches += 1
print("TEST LOSS: ", avg_loss / max(num_batches, 1))
print(correct / max(total, 1))

  cpuset_checked))


0.9006622516556292


In [None]:
# Baseline Model hyperparameters
actions_num = 36
num_epochs = 40

learning_rate = 1e-3
weight_decay = 5e-5

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

network = CNN_LSTM(phoneme_num, actions_num, phoneme_num, pretrain = True)
network = network.to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(network.parameters(), lr = learning_rate, weight_decay=weight_decay)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, \
  patience=3, threshold=0.0001, threshold_mode='rel', cooldown=3, min_lr=0, eps=1e-08, verbose=True)

In [None]:
for epoch in range(num_epochs):
    # ---- Train ----
    network.train()
    correct = 0.0
    total = 0
    avg_loss = 0.0
    num_batches = 0
    for batch_num, (x, y, x_len) in enumerate(train_loader):
        optimizer.zero_grad()
        y = torch.Tensor(y).long().to(device)
        x = x.to(device)
        out = network(x, x_len, 'train')
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        avg_loss += loss.item()
        correct += (out.argmax(dim=1) == y).sum().item()
        # The last batch can be smaller than batch_size; count the real samples.
        total += len(y)
        num_batches += 1
        if batch_num % 10 == 0:
            print("BATCH NUM:", batch_num)
            print("BATCH LOSS:", loss.item())

    # Divide by the batch count (batch_num is 0-based); guard against an empty loader.
    num_batches = max(num_batches, 1)
    train_loss = avg_loss / num_batches
    print("EPOCH: ", epoch)
    print("LOSS: ", train_loss)
    print("TRAIN ACC: ", correct / max(total, 1))
    # NOTE(review): the scheduler monitors the training loss; validation loss is the
    # more usual signal for ReduceLROnPlateau.
    scheduler.step(train_loss)

    # ---- Validate ----
    correct = 0.0
    total = 0
    avg_loss = 0.0
    num_batches = 0
    network.eval()
    with torch.no_grad():
        for batch_num, (x, y, x_len) in enumerate(dev_loader):
            y = torch.Tensor(y).long().to(device)
            x = x.to(device)
            out = network(x, x_len, 'dev')
            avg_loss += criterion(out, y).item()
            correct += (out.argmax(dim=1) == y).sum().item()
            total += len(y)
            num_batches += 1
    print("EPOCH: ", epoch)
    print("LOSS: ", avg_loss / max(num_batches, 1))
    print("VALID ACC: ", correct / max(total, 1))

  cpuset_checked))


BATCH NUM: 0
BATCH LOSS: 3.639611005783081
BATCH NUM: 10
BATCH LOSS: 3.6981654167175293
BATCH NUM: 20
BATCH LOSS: 3.5624847412109375
BATCH NUM: 30
BATCH LOSS: 3.5395774841308594
BATCH NUM: 40
BATCH LOSS: 3.56912899017334
BATCH NUM: 50
BATCH LOSS: 3.58838152885437
BATCH NUM: 60
BATCH LOSS: 3.6796164512634277
BATCH NUM: 70
BATCH LOSS: 3.5833353996276855
BATCH NUM: 80
BATCH LOSS: 3.5716745853424072
BATCH NUM: 90
BATCH LOSS: 3.584365129470825
BATCH NUM: 100
BATCH LOSS: 3.755573034286499
BATCH NUM: 110
BATCH LOSS: 3.5094223022460938
BATCH NUM: 120
BATCH LOSS: 3.6206047534942627
BATCH NUM: 130
BATCH LOSS: 3.5220744609832764
BATCH NUM: 140
BATCH LOSS: 3.0708060264587402
BATCH NUM: 150
BATCH LOSS: 3.349790096282959
BATCH NUM: 160
BATCH LOSS: 3.362287759780884
BATCH NUM: 170
BATCH LOSS: 3.2244904041290283
BATCH NUM: 180
BATCH LOSS: 3.5970730781555176
BATCH NUM: 190
BATCH LOSS: 2.688559055328369
BATCH NUM: 200
BATCH LOSS: 2.851349115371704
BATCH NUM: 210
BATCH LOSS: 3.357717990875244
BATCH NUM: 

In [None]:
# Evaluate the current network on the held-out test split.
network.eval()
correct = 0.0
total = 0.0
avg_loss = 0.0  # reset so the test loss is not added to the value left by validation
num_batches = 0
with torch.no_grad():
    for batch_num, (x, y, x_len) in enumerate(test_loader):
        y = torch.Tensor(y).long().to(device)
        x = x.to(device)
        out = network(x, x_len, 'test')
        avg_loss += criterion(out, y).item()
        correct += (out.argmax(dim=1) == y).sum().item()
        # The last batch can be smaller than batch_size; count the real samples.
        total += len(y)
        num_batches += 1
print("TEST LOSS: ", avg_loss / max(num_batches, 1))
print(correct / max(total, 1))

  cpuset_checked))


0.9056291390728477
