In [1]:
%load_ext autoreload
%autoreload 2

from os.path import join
import os
import pandas as pd
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn.init import kaiming_uniform_
from torch.optim import SGD
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
from torch.optim import Adam
from torch.autograd import Variable
from nlpClassifiers.data.dataset  import NLPDataset, Vocabulary
from sklearn.metrics import roc_curve, auc, classification_report
from scipy.special import expit
import gensim
import shutil
from pathlib import Path
import random

In [2]:
ROOT = '../../'
PATH_TO_VIRTUAL_OPERATOR_DATA = join(ROOT, "data/virtual-operator")
PATH_TO_AGENT_BENCHMARK_DATA = join(ROOT, "data/agent-benchmark")
PATH_TO_ML_PT_DATA = join(ROOT, "data/mercado-livre-pt-only")

PATH_TO_VIRTUAL_OPERATOR_MODELS = join(ROOT, "models/virtual-operator")
PATH_TO_AGENT_BENCHMARK_MODELS = join(ROOT, "models/agent-benchmark")
PATH_TO_ML_PT_MODELS = join(ROOT, "models/mercado-livre-pt-only")
#EMBEDDINGS_FILE = '../../data/agent-benchmark/vocab/word2vec-model-agent-benchmark.bin'
EMBEDDING_DIM = 300
MAX_WORDS = 30000
dataset = 'virtual-operator'

sentence_max_len = 82
batch_size = 512
biLSTM = True
epochs=30
stopwords_lang = None
gpu=1
lr=1e-3
clipping_value=0.25
save_name='bilstm-pytorch-jointly-trained-embeddings'
patience=5

In [3]:
BASE_PATH_TO_MODELS = {"virtual-operator": PATH_TO_VIRTUAL_OPERATOR_MODELS, "agent-benchmark": PATH_TO_AGENT_BENCHMARK_MODELS, "mercado-livre-pt": PATH_TO_ML_PT_MODELS}
FULL_PATH_TO_MODELS = join(BASE_PATH_TO_MODELS[dataset], "bilstm-classifier")

In [4]:
def _init_fn(worker_id):
    np.random.seed(seed)

In [5]:
def read_data(dataset, subset):
    BASE_PATH_TO_DATASET = {"virtual-operator": PATH_TO_VIRTUAL_OPERATOR_DATA, "agent-benchmark": PATH_TO_AGENT_BENCHMARK_DATA, "mercado-livre-pt": PATH_TO_ML_PT_DATA}
    BASE_PATH_TO_DATASET = {"train": join(BASE_PATH_TO_DATASET[dataset], "train.csv"), "val": join(BASE_PATH_TO_DATASET[dataset], "val.csv"), "test": join(BASE_PATH_TO_DATASET[dataset], "test.csv")}
    FULL_PATH_TO_DATASET = BASE_PATH_TO_DATASET[subset]
    
    if dataset == "mercado-livre-pt":
        sep=","
    else:
        sep=";"
    data = pd.read_csv(FULL_PATH_TO_DATASET, sep=sep, names =['utterance','label'], header=None, dtype={'utterance':str, 'label': str} )
    return data

In [6]:
train_df = read_data(dataset, "train")
val_df = read_data(dataset, "val")

In [7]:
voc = Vocabulary('CNN', stopwords_lang)
voc.build_vocab(train_df['utterance'].tolist() +  val_df['utterance'].tolist(), MAX_WORDS)
#embedding_weights = voc.load_embeddings(EMBEDDINGS_FILE, EMBEDDING_DIM)
if voc.num_words < MAX_WORDS or MAX_WORDS == 0:
    MAX_WORDS = voc.num_words

In [8]:
train_corpus = NLPDataset(dataset, "train", sentence_max_len, vocab= voc)
labels_dict = train_corpus.labels_dict
val_corpus = NLPDataset(dataset, "val", sentence_max_len, labels_dict = labels_dict, vocab= voc)

In [9]:
train_dataloader = DataLoader(
            train_corpus,
            sampler = RandomSampler(train_corpus),
            batch_size = batch_size,
            pin_memory=True,
            #shuffle=True,
            worker_init_fn=_init_fn,
            num_workers=0
)

validation_dataloader = DataLoader(
            val_corpus,
            sampler = RandomSampler(val_corpus),
            batch_size = batch_size,
            pin_memory=True,
            #shuffle=True,
            worker_init_fn=_init_fn,
            num_workers=0
)

In [10]:
#del model
torch.cuda.empty_cache()

In [11]:
def seed_everything(seed=10):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

In [12]:
seed_everything()

In [13]:
device = torch.device(f"cuda:{gpu}")

In [14]:
class LSTMNet(nn.Module):
    def __init__(self, device, num_classes, num_features, criterion, embedding_dim, bidirectional, embedding_weights=None):
        super(LSTMNet, self).__init__()
        self.criterion = criterion
        self.embedding_dim = embedding_dim
        self.lstm_units = embedding_dim
        self.lstm_act = nn.Tanh()
        self.num_directions = 2 if bidirectional else 1
        self.device = device
        self.n_layers = 1
        self.embedding = nn.Embedding(num_features, embedding_dim, padding_idx = 0)
        
        if(embedding_weights is not None):
            print("Embedding layer weights won't be updated.")
            self.embedding.load_state_dict({'weight': embedding_weights})
            #self.embedding.from_pretrained(embedding_weights, freeze=True)
            self.embedding.weight.requires_grad = False
        else:
            self.embedding.weight.data = torch.zeros(self.embedding.weight.data.size())
            self.embedding.weight.requires_grad = True
        
        self.lstm = nn.LSTM(embedding_dim, self.lstm_units, num_layers=self.n_layers, bidirectional=bidirectional, batch_first=True)

        self.dropout = nn.Dropout(0.2)       
        self.linear = nn.Linear(self.num_directions * self.lstm_units, num_classes)   
        self.init_weights()
        
    def forward(self, x, y):

        # lstm step => then ONLY take the sequence's final timetep to pass into the linear/dense layer
        # Note: lstm_out contains outputs for every step of the sequence we are looping over (for BPTT)
        # but we just need the output of the last step of the sequence, aka lstm_out[-1]
        #print("x input:", x.size())
        #set_trace()
        x = self.embedding(x)
        #print("embeddings:", x.size())
        x_, (h_n, c_n) = self.lstm(x)
        h = x_[:,-0,:]
        #h = self.lstm_act(h)
        '''        
        if self.num_directions == 2:
            h = torch.cat((h_n[-2, :, :], h_n[-1, :, :]), dim=1)
        else:
            h = h_n[-1, :, :]
        '''
        #print("h before activation:", h.size())
        #print("h after activation:", h.size())
        output = self.dropout(h)
        output = self.linear(output)
        #print("output after linear:", output.size())
        loss = self.criterion(output.squeeze(), y)
        #print("loss:", loss.size())
        return loss,output
    
    def init_weights(self):  
        nn.init.xavier_normal_(self.lstm.weight_hh_l0)
        nn.init.xavier_normal_(self.lstm.weight_ih_l0)
        nn.init.xavier_normal_(self.linear.weight.data)
       
        for name, param in self.lstm.named_parameters():
            if 'bias' in name:
                nn.init.constant_(param, 0.0)
            elif 'weight_ih' in name:
                nn.init.kaiming_normal_(param)
            elif 'weight_hh' in name:
                nn.init.orthogonal_(param)

In [15]:
criterion = torch.nn.CrossEntropyLoss()
model = LSTMNet(device, train_corpus.num_labels, MAX_WORDS, criterion, EMBEDDING_DIM, biLSTM)

In [16]:
model = model.to(device)
optimizer = Adam(model.parameters(),lr, betas=(0.7, 0.999))

In [17]:
def get_accuracy_from_logits(logits, labels):
    acc = (labels.cpu() == logits.cpu().argmax(-1)).float().detach().numpy()
    return float(100 * acc.sum() / len(acc))

In [18]:
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=3,verbose=True, factor=0.1)

best_epoch = -1
last_saved_model = ""
training_stats = []
global_step = 0
best_val_acc = float("-inf")
best_model_wts = None
best_curr_val = 0

for epoch_i in range(0, epochs):
    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
    print('Training...')
    total_train_loss = 0
    model.train()
    for step, batch in enumerate(train_dataloader):
         # Progress update every 40 batches.
        if step % 40 == 0 and not step == 0:
            print('  Batch {:>5,}  of  {:>5,}.    .'.format(step, len(train_dataloader)))
        b_input_ids = batch[0]
        b_labels = batch[1].clone().detach().to(device)
        b_input_ids = b_input_ids.clone().detach().to(device).long()
        loss, logits= model(b_input_ids, b_labels)
        step_loss = loss.item()
        model.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clipping_value)
        optimizer.step()
        

        if step % 40 == 0 and not step == 0:
            print("  step loss: {0:.2f}".format(step_loss))
        total_train_loss += step_loss
        global_step += 1
    avg_train_loss = total_train_loss / len(train_dataloader)
    # Measure how long this epoch took.
    print("")
    print("  Average training loss: {0:.2f}".format(avg_train_loss))

    print("")
    print("Running Validation...")
    model.eval()

    # Tracking variables
    total_eval_accuracy = 0.0
    total_eval_loss = 0
    nb_eval_steps = 0
    for step, batch in enumerate(validation_dataloader):
        b_input_ids = batch[0]
        b_input_ids = b_input_ids.clone().detach().to(device).long()
        b_labels = batch[1].to(device)
        with torch.no_grad():
            # Forward pass, calculate logit predictions.
            loss, logits = model(b_input_ids, b_labels)
        total_eval_loss += loss
        batch_acc = get_accuracy_from_logits(logits, b_labels)
        print("Batch accuracy: {0:.2f}".format(batch_acc))
        total_eval_accuracy += batch_acc
    
    # Calculate the average loss over all of the batches.
    avg_val_accuracy = total_eval_accuracy / len(validation_dataloader)
    print("  Accuracy: {0:.2f}".format(avg_val_accuracy))
    
    if avg_val_accuracy > best_val_acc:
        print(f"New best model, saving it!")
        if avg_val_accuracy > best_curr_val:
            best_curr_val = avg_val_accuracy
        if last_saved_model:
            shutil.rmtree(last_saved_model)
        model_path = Path(
            FULL_PATH_TO_MODELS,
            f"base-dataset-{dataset}-{save_name}"
        )
        last_saved_model = model_path
        model_path.mkdir(parents=True, exist_ok=True)
        best_val_acc = avg_val_accuracy
        torch.save(model, join(model_path, "best-model.pth"))
        best_epoch = epoch_i
        n_epochs_no_improvement = 0
    elif avg_val_accuracy > best_curr_val:
        best_curr_val = avg_val_accuracy
        n_epochs_no_improvement = 0
    else:
        n_epochs_no_improvement += 1
        print(f"The model does not improve for {n_epochs_no_improvement} epochs!")

    if n_epochs_no_improvement > patience:
        print(f"====>Stopping training, the model did not improve for {n_epochs_no_improvement}\n====>Best epoch: {best_epoch + 1}.")
        break
    scheduler.step(total_eval_loss)
print("")
print("Training complete!")
    
    


Training...
  Batch    40  of    943.    .
  step loss: 2.91
  Batch    80  of    943.    .
  step loss: 2.29
  Batch   120  of    943.    .
  step loss: 1.77
  Batch   160  of    943.    .
  step loss: 1.38
  Batch   200  of    943.    .
  step loss: 1.12
  Batch   240  of    943.    .
  step loss: 1.01
  Batch   280  of    943.    .
  step loss: 0.98
  Batch   320  of    943.    .
  step loss: 0.74
  Batch   360  of    943.    .
  step loss: 0.72
  Batch   400  of    943.    .
  step loss: 0.68
  Batch   440  of    943.    .
  step loss: 0.61
  Batch   480  of    943.    .
  step loss: 0.54
  Batch   520  of    943.    .
  step loss: 0.58
  Batch   560  of    943.    .
  step loss: 0.52
  Batch   600  of    943.    .
  step loss: 0.53
  Batch   640  of    943.    .
  step loss: 0.56
  Batch   680  of    943.    .
  step loss: 0.56
  Batch   720  of    943.    .
  step loss: 0.44
  Batch   760  of    943.    .
  step loss: 0.44
  Batch   800  of    943.    .
  step loss: 0.42
  Batch

  step loss: 0.25
  Batch   880  of    943.    .
  step loss: 0.24
  Batch   920  of    943.    .
  step loss: 0.17

  Average training loss: 0.23

Running Validation...
Batch accuracy: 93.55
Batch accuracy: 92.38
Batch accuracy: 93.36
Batch accuracy: 91.60
Batch accuracy: 91.80
Batch accuracy: 93.16
Batch accuracy: 92.38
Batch accuracy: 91.80
Batch accuracy: 91.60
Batch accuracy: 91.99
Batch accuracy: 92.58
Batch accuracy: 92.97
Batch accuracy: 91.60
Batch accuracy: 93.16
Batch accuracy: 93.75
Batch accuracy: 92.77
Batch accuracy: 92.38
Batch accuracy: 91.41
Batch accuracy: 93.36
Batch accuracy: 93.16
Batch accuracy: 93.36
Batch accuracy: 93.36
Batch accuracy: 92.77
Batch accuracy: 94.34
Batch accuracy: 93.95
Batch accuracy: 91.80
Batch accuracy: 90.23
Batch accuracy: 92.58
Batch accuracy: 93.55
Batch accuracy: 92.97
Batch accuracy: 93.95
Batch accuracy: 93.75
Batch accuracy: 93.55
Batch accuracy: 91.21
Batch accuracy: 91.99
Batch accuracy: 92.58
Batch accuracy: 94.14
Batch accuracy: 

Batch accuracy: 94.73
Batch accuracy: 93.75
Batch accuracy: 92.77
Batch accuracy: 93.36
Batch accuracy: 95.90
Batch accuracy: 92.97
Batch accuracy: 91.41
Batch accuracy: 93.75
Batch accuracy: 94.14
Batch accuracy: 93.95
Batch accuracy: 93.95
Batch accuracy: 92.38
Batch accuracy: 93.36
Batch accuracy: 93.36
Batch accuracy: 92.97
Batch accuracy: 92.58
Batch accuracy: 93.75
Batch accuracy: 94.14
Batch accuracy: 92.19
Batch accuracy: 92.58
Batch accuracy: 91.80
Batch accuracy: 92.97
Batch accuracy: 93.95
Batch accuracy: 94.34
Batch accuracy: 93.75
Batch accuracy: 91.41
Batch accuracy: 94.53
Batch accuracy: 92.58
Batch accuracy: 93.75
Batch accuracy: 92.58
Batch accuracy: 92.38
Batch accuracy: 92.97
Batch accuracy: 91.02
Batch accuracy: 93.55
Batch accuracy: 92.19
Batch accuracy: 93.95
Batch accuracy: 93.36
Batch accuracy: 95.31
Batch accuracy: 94.14
Batch accuracy: 91.80
Batch accuracy: 93.75
Batch accuracy: 94.53
Batch accuracy: 91.21
Batch accuracy: 94.92
Batch accuracy: 93.55
Batch accu

Batch accuracy: 94.34
Batch accuracy: 94.53
Batch accuracy: 94.53
Batch accuracy: 92.38
Batch accuracy: 91.60
Batch accuracy: 92.77
Batch accuracy: 92.58
Batch accuracy: 92.58
Batch accuracy: 93.75
Batch accuracy: 94.34
Batch accuracy: 95.31
Batch accuracy: 94.53
Batch accuracy: 93.36
Batch accuracy: 94.34
Batch accuracy: 92.58
Batch accuracy: 91.80
Batch accuracy: 95.39
  Accuracy: 93.71
New best model, saving it!

Training...
  Batch    40  of    943.    .
  step loss: 0.05
  Batch    80  of    943.    .
  step loss: 0.04
  Batch   120  of    943.    .
  step loss: 0.07
  Batch   160  of    943.    .
  step loss: 0.05
  Batch   200  of    943.    .
  step loss: 0.06
  Batch   240  of    943.    .
  step loss: 0.04
  Batch   280  of    943.    .
  step loss: 0.05
  Batch   320  of    943.    .
  step loss: 0.09
  Batch   360  of    943.    .
  step loss: 0.06
  Batch   400  of    943.    .
  step loss: 0.06
  Batch   440  of    943.    .
  step loss: 0.07
  Batch   480  of    943.    

  step loss: 0.05
  Batch   520  of    943.    .
  step loss: 0.05
  Batch   560  of    943.    .
  step loss: 0.03
  Batch   600  of    943.    .
  step loss: 0.03
  Batch   640  of    943.    .
  step loss: 0.06
  Batch   680  of    943.    .
  step loss: 0.04
  Batch   720  of    943.    .
  step loss: 0.06
  Batch   760  of    943.    .
  step loss: 0.07
  Batch   800  of    943.    .
  step loss: 0.02
  Batch   840  of    943.    .
  step loss: 0.05
  Batch   880  of    943.    .
  step loss: 0.08
  Batch   920  of    943.    .
  step loss: 0.07

  Average training loss: 0.05

Running Validation...
Batch accuracy: 93.75
Batch accuracy: 91.21
Batch accuracy: 93.36
Batch accuracy: 93.55
Batch accuracy: 93.55
Batch accuracy: 94.53
Batch accuracy: 94.73
Batch accuracy: 94.53
Batch accuracy: 93.55
Batch accuracy: 94.73
Batch accuracy: 94.34
Batch accuracy: 94.34
Batch accuracy: 94.73
Batch accuracy: 94.14
Batch accuracy: 94.34
Batch accuracy: 92.97
Batch accuracy: 94.34
Batch accuracy:

Batch accuracy: 95.12
Batch accuracy: 93.16
Batch accuracy: 94.34
Batch accuracy: 94.53
Batch accuracy: 93.36
Batch accuracy: 92.77
Batch accuracy: 93.16
Batch accuracy: 94.73
Batch accuracy: 95.70
Batch accuracy: 96.29
Batch accuracy: 94.14
Batch accuracy: 95.31
Batch accuracy: 95.31
Batch accuracy: 92.38
Batch accuracy: 95.70
Batch accuracy: 94.53
Batch accuracy: 94.34
Batch accuracy: 93.16
Batch accuracy: 94.53
Batch accuracy: 92.97
Batch accuracy: 94.73
Batch accuracy: 95.31
Batch accuracy: 94.34
Batch accuracy: 96.09
Batch accuracy: 94.34
Batch accuracy: 93.95
Batch accuracy: 94.14
Batch accuracy: 94.34
Batch accuracy: 94.92
Batch accuracy: 93.55
Batch accuracy: 93.55
Batch accuracy: 94.53
Batch accuracy: 90.82
Batch accuracy: 95.12
Batch accuracy: 93.95
Batch accuracy: 92.19
Batch accuracy: 93.55
Batch accuracy: 95.51
Batch accuracy: 92.38
Batch accuracy: 92.97
Batch accuracy: 94.92
Batch accuracy: 95.51
Batch accuracy: 91.21
Batch accuracy: 94.14
Batch accuracy: 93.16
Batch accu

Batch accuracy: 94.34
Batch accuracy: 93.75
Batch accuracy: 95.51
Batch accuracy: 93.75
Batch accuracy: 92.97
Batch accuracy: 94.14
Batch accuracy: 94.92
Batch accuracy: 93.95
Batch accuracy: 92.97
Batch accuracy: 94.34
Batch accuracy: 95.12
Batch accuracy: 94.34
Batch accuracy: 95.12
Batch accuracy: 93.75
Batch accuracy: 94.14
Batch accuracy: 94.73
Batch accuracy: 94.14
Batch accuracy: 93.55
Batch accuracy: 92.19
Batch accuracy: 94.34
Batch accuracy: 95.70
Batch accuracy: 94.14
Batch accuracy: 94.14
Batch accuracy: 94.92
Batch accuracy: 95.70
Batch accuracy: 92.97
Batch accuracy: 94.34
Batch accuracy: 95.51
Batch accuracy: 95.90
Batch accuracy: 93.95
Batch accuracy: 93.75
Batch accuracy: 94.53
Batch accuracy: 94.14
Batch accuracy: 94.73
Batch accuracy: 94.14
Batch accuracy: 94.53
Batch accuracy: 94.14
Batch accuracy: 94.92
Batch accuracy: 94.53
Batch accuracy: 91.64
  Accuracy: 94.22
New best model, saving it!
Epoch    14: reducing learning rate of group 0 to 1.0000e-05.

Training...


  Batch    40  of    943.    .
  step loss: 0.01
  Batch    80  of    943.    .
  step loss: 0.01
  Batch   120  of    943.    .
  step loss: 0.01
  Batch   160  of    943.    .
  step loss: 0.01
  Batch   200  of    943.    .
  step loss: 0.01
  Batch   240  of    943.    .
  step loss: 0.02
  Batch   280  of    943.    .
  step loss: 0.01
  Batch   320  of    943.    .
  step loss: 0.01
  Batch   360  of    943.    .
  step loss: 0.01
  Batch   400  of    943.    .
  step loss: 0.01
  Batch   440  of    943.    .
  step loss: 0.01
  Batch   480  of    943.    .
  step loss: 0.01
  Batch   520  of    943.    .
  step loss: 0.01
  Batch   560  of    943.    .
  step loss: 0.02
  Batch   600  of    943.    .
  step loss: 0.02
  Batch   640  of    943.    .
  step loss: 0.02
  Batch   680  of    943.    .
  step loss: 0.01
  Batch   720  of    943.    .
  step loss: 0.01
  Batch   760  of    943.    .
  step loss: 0.00
  Batch   800  of    943.    .
  step loss: 0.02
  Batch   840  of   

Batch accuracy: 94.14
Batch accuracy: 94.73
Batch accuracy: 93.75
Batch accuracy: 94.73
Batch accuracy: 95.51
Batch accuracy: 94.53
Batch accuracy: 95.12
Batch accuracy: 94.53
Batch accuracy: 94.53
Batch accuracy: 94.73
Batch accuracy: 92.97
Batch accuracy: 96.88
Batch accuracy: 94.34
Batch accuracy: 93.55
Batch accuracy: 94.53
Batch accuracy: 93.55
Batch accuracy: 94.14
Batch accuracy: 93.16
Batch accuracy: 93.55
Batch accuracy: 92.97
Batch accuracy: 94.73
Batch accuracy: 92.77
Batch accuracy: 93.16
Batch accuracy: 95.12
Batch accuracy: 92.97
Batch accuracy: 94.73
Batch accuracy: 94.92
Batch accuracy: 93.55
Batch accuracy: 94.14
Batch accuracy: 95.12
Batch accuracy: 93.95
Batch accuracy: 94.53
Batch accuracy: 94.34
Batch accuracy: 93.75
Batch accuracy: 94.14
Batch accuracy: 93.36
Batch accuracy: 95.12
Batch accuracy: 94.73
Batch accuracy: 94.14
Batch accuracy: 92.77
Batch accuracy: 94.14
Batch accuracy: 95.12
Batch accuracy: 94.34
Batch accuracy: 95.90
Batch accuracy: 95.31
Batch accu

Batch accuracy: 95.51
Batch accuracy: 93.55
Batch accuracy: 95.70
Batch accuracy: 92.38
Batch accuracy: 95.31
Batch accuracy: 94.81
  Accuracy: 94.24
The model does not improve for 2 epochs!

Training...
  Batch    40  of    943.    .
  step loss: 0.01
  Batch    80  of    943.    .
  step loss: 0.00
  Batch   120  of    943.    .
  step loss: 0.01
  Batch   160  of    943.    .
  step loss: 0.02
  Batch   200  of    943.    .
  step loss: 0.01
  Batch   240  of    943.    .
  step loss: 0.01
  Batch   280  of    943.    .
  step loss: 0.02
  Batch   320  of    943.    .
  step loss: 0.01
  Batch   360  of    943.    .
  step loss: 0.01
  Batch   400  of    943.    .
  step loss: 0.01
  Batch   440  of    943.    .
  step loss: 0.01
  Batch   480  of    943.    .
  step loss: 0.02
  Batch   520  of    943.    .
  step loss: 0.01
  Batch   560  of    943.    .
  step loss: 0.01
  Batch   600  of    943.    .
  step loss: 0.02
  Batch   640  of    943.    .
  step loss: 0.02
  Batch   68

  Batch   640  of    943.    .
  step loss: 0.01
  Batch   680  of    943.    .
  step loss: 0.01
  Batch   720  of    943.    .
  step loss: 0.01
  Batch   760  of    943.    .
  step loss: 0.01
  Batch   800  of    943.    .
  step loss: 0.01
  Batch   840  of    943.    .
  step loss: 0.01
  Batch   880  of    943.    .
  step loss: 0.01
  Batch   920  of    943.    .
  step loss: 0.01

  Average training loss: 0.01

Running Validation...
Batch accuracy: 93.16
Batch accuracy: 92.58
Batch accuracy: 94.73
Batch accuracy: 95.31
Batch accuracy: 93.95
Batch accuracy: 93.95
Batch accuracy: 94.53
Batch accuracy: 94.92
Batch accuracy: 95.12
Batch accuracy: 94.34
Batch accuracy: 94.53
Batch accuracy: 95.70
Batch accuracy: 94.34
Batch accuracy: 94.14
Batch accuracy: 93.75
Batch accuracy: 93.75
Batch accuracy: 95.90
Batch accuracy: 94.73
Batch accuracy: 94.14
Batch accuracy: 93.16
Batch accuracy: 94.34
Batch accuracy: 94.73
Batch accuracy: 92.97
Batch accuracy: 94.73
Batch accuracy: 93.16
Batc

In [19]:
def predict(
    model_path: Path,
    dataset: str,
    batch_size: int,
    labels_dict,
    device: torch.device
):

    print(f"====Loading dataset for testing")
    test_corpus = NLPDataset(dataset, "test", sentence_max_len, labels_dict = labels_dict, vocab= voc)

    test_dataloader = DataLoader(
        test_corpus,
        batch_size=batch_size,
        #sampler = RandomSampler(test_corpus),
        pin_memory=True,
        num_workers=0,
        drop_last=False
    )

    print(f"====Loading model for testing")
    model = torch.load(join(model_path, "best-model.pth"))
    model.to(device)
    model.eval()
    pred_labels = []
    test_labels = []
    logits_list = []

    def _list_from_tensor(tensor):
        if tensor.numel() == 1:
            return [tensor.item()]
        return list(tensor.cpu().detach().numpy())

    print("====Testing model...")
    for batch in test_dataloader:
        b_input_ids = batch[0]
        b_input_ids = torch.tensor(b_input_ids).to(device).long()
        b_labels = batch[1].to(device)
        with torch.no_grad():
            loss, logits = model(b_input_ids, b_labels)
            preds = np.argmax(logits.cpu(), axis=1) # Convert one-hot to index
            b_labels = b_labels.int()
            pred_labels.extend(_list_from_tensor(preds))
            test_labels.extend(_list_from_tensor(b_labels))
        logits_list.extend(_list_from_tensor(logits))

    print(classification_report(test_labels, pred_labels, labels=list(labels_dict.values()), target_names=np.array(list(labels_dict.keys())), digits=3, output_dict=False))
    logits_list = expit(logits_list)
    del model
    torch.cuda.empty_cache()

In [20]:
del train_dataloader
del validation_dataloader
del train_corpus
del val_corpus
del model
torch.cuda.empty_cache()

In [21]:
predict(last_saved_model, dataset, batch_size, labels_dict, device)

====Loading dataset for testing
====Loading model for testing
====Testing model...


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


                                                                  precision    recall  f1-score   support

                              Sintomas.Genérico.Sky não funciona      0.930     0.935     0.932      8357
                                    Sintomas.Genérico.Instalação      0.988     0.978     0.983       647
                                Sintomas.Genérico.Canal não pega      0.922     0.937     0.930      5967
                    Sintomas.Genérico.Equipamento não funciona G      0.946     0.962     0.954      2318
                                     Sintomas.Genérico.Sem sinal      0.958     0.960     0.959     14552
                               Sintomas.Qualificado.Cancelamento      0.934     0.959     0.946      1847
                           Sintomas.Qualificado.Outros problemas      0.824     0.757     0.789       729
                              Sintomas.Qualificado.NãoTéc_fatura      0.880     0.881     0.881      1451
                          Sintomas.Qualificad