In [1]:
from os.path import join
import os
import pandas as pd
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn.init import kaiming_uniform_
from torch.optim import SGD
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
from torch.optim import Adam
from nlpClassifiers.data.dataset  import NLPDataset, Vocabulary
from sklearn.metrics import roc_curve, auc, classification_report
from scipy.special import expit
import gensim
import shutil
from pathlib import Path
import random

In [2]:
ROOT = '../../'
PATH_TO_VIRTUAL_OPERATOR_DATA = join(ROOT, "data/virtual-operator")
PATH_TO_AGENT_BENCHMARK_DATA = join(ROOT, "data/agent-benchmark")
PATH_TO_ML_PT_DATA = join(ROOT, "data/mercado-livre-pt-only")

PATH_TO_VIRTUAL_OPERATOR_MODELS = join(ROOT, "models/virtual-operator")
PATH_TO_AGENT_BENCHMARK_MODELS = join(ROOT, "models/agent-benchmark")
PATH_TO_ML_PT_MODELS = join(ROOT, "models/mercado-livre-pt-only")

EMBEDDING_DIM = 300
EMBEDDINGS_FILE = '../../data/fasttext-embeddings/wiki.pt.vec'
MAX_WORDS = 30000
dataset = 'virtual-operator'
sentence_max_len = 82
batch_size = 512

epochs=30
stopwords_lang = None
gpu=1
lr=0.001
clipping_value=0.25
save_name='cnn-pytorch-fasttext-embeddings'
patience=5

In [3]:
BASE_PATH_TO_MODELS = {"virtual-operator": PATH_TO_VIRTUAL_OPERATOR_MODELS, "agent-benchmark": PATH_TO_AGENT_BENCHMARK_MODELS, "mercado-livre-pt": PATH_TO_ML_PT_MODELS}
FULL_PATH_TO_MODELS = join(BASE_PATH_TO_MODELS[dataset], "bow-classifier")

In [4]:
def _init_fn(worker_id):
    np.random.seed(seed)

In [5]:
def read_data(dataset, subset):
    BASE_PATH_TO_DATASET = {"virtual-operator": PATH_TO_VIRTUAL_OPERATOR_DATA, "agent-benchmark": PATH_TO_AGENT_BENCHMARK_DATA, "mercado-livre-pt": PATH_TO_ML_PT_DATA}
    BASE_PATH_TO_DATASET = {"train": join(BASE_PATH_TO_DATASET[dataset], "train.csv"), "val": join(BASE_PATH_TO_DATASET[dataset], "val.csv"), "test": join(BASE_PATH_TO_DATASET[dataset], "test.csv")}
    FULL_PATH_TO_DATASET = BASE_PATH_TO_DATASET[subset]
    
    if dataset == "mercado-livre-pt":
        sep=","
    else:
        sep=";"
    data = pd.read_csv(FULL_PATH_TO_DATASET, sep=sep, names =['utterance','label'], header=None, dtype={'utterance':str, 'label': str} )
    return data

In [6]:
train_df = read_data(dataset, "train")
val_df = read_data(dataset, "val")

In [7]:
voc = Vocabulary('CNN', stopwords_lang)
voc.build_vocab(train_df['utterance'].tolist() +  val_df['utterance'].tolist(), MAX_WORDS)
embedding_weights = voc.load_embeddings(EMBEDDINGS_FILE, EMBEDDING_DIM)
if voc.num_words < MAX_WORDS or MAX_WORDS == 0:
    MAX_WORDS = voc.num_words

4456 tokens not found in vocabulary.


In [8]:
train_corpus = NLPDataset(dataset, "train", sentence_max_len, vocab= voc)
labels_dict = train_corpus.labels_dict
val_corpus = NLPDataset(dataset, "val", sentence_max_len, labels_dict = labels_dict, vocab= voc)

In [9]:
train_dataloader = DataLoader(
            train_corpus,
            sampler = RandomSampler(train_corpus),
            batch_size = batch_size,
            pin_memory=True,
            worker_init_fn=_init_fn,
            num_workers=0
)

validation_dataloader = DataLoader(
            val_corpus,
            sampler = RandomSampler(val_corpus),
            batch_size = batch_size,
            pin_memory=True,
            worker_init_fn=_init_fn,
            num_workers=0
)

In [10]:
def initialize_weights(model):
    if type(model) in [nn.Linear]:
        nn.init.xavier_normal_(model.weight.data)
    elif type(model) in [nn.LSTM, nn.RNN, nn.GRU]:
        nn.init.xavier_normal_(model.weight_hh_l0)
        nn.init.xavier_normal_(model.weight_ih_l0)

In [11]:
#del model
torch.cuda.empty_cache()

In [12]:
def seed_everything(seed=10):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

In [13]:
seed_everything()

In [14]:
class CNNNet(nn.Module):
    def __init__(self, num_classes, num_features, criterion, embedding_dim, embedding_weights=None):
        super(CNNNet, self).__init__()
        self.criterion = criterion
        self.embedding_dim = embedding_dim
        self.n_layers = 1
        self.embedding = nn.Embedding(num_features, embedding_dim, padding_idx = 0)
        if(embedding_weights is not None):
            print("Embedding layer Weights won't be updated.")
            self.embedding.load_state_dict({'weight': embedding_weights})
            self.embedding.weight.requires_grad = False
            #self.embedding.from_pretrained(embedding_weights, freeze=True)
        else:
            self.embedding.weight.data = torch.zeros(self.embedding.weight.data.size())
            self.embedding.weight.requires_grad = True
        
        self.cnn = nn.Conv1d(embedding_dim, 256, 4)
        
        for name, param in self.cnn.named_parameters():
            if 'bias' in name:
                nn.init.constant_(param, 0.0)
            elif 'weight_ih' in name:
                nn.init.kaiming_normal_(param)
            elif 'weight_hh' in name:
                nn.init.orthogonal_(param)
             
        self.dropout = nn.Dropout(0.2)       
        self.out = nn.Linear(256, num_classes)
        self.init_weights()
    def forward(self, x, y):
        # Conv1d takes in (batch, channels, seq_len), but raw embedded is (batch, seq_len, channels)
        h_embedding = self.embedding(x)
        h_embedding = h_embedding.permute(0, 2, 1)
        h_cnn = self.cnn(h_embedding)
        h_cnn = h_cnn.permute(0, 2, 1)
        h_cnn = self.dropout(h_cnn)
        h_max_pool = self.global_max_pool(h_cnn)
        x = self.out(h_max_pool)
        loss = self.criterion(x, y)
        # return the final output
        return loss,x 
    
    @staticmethod
    def global_max_pool(x):
        """Convolution and global max pooling layer"""
        return x.max(1)[0]
    
    def init_weights(self):
        """
        Here we reproduce Keras default initialization weights to initialize Embeddings/LSTM weights
        """
        ih = (param.data for name, param in self.named_parameters() if 'weight_ih' in name)
        hh = (param.data for name, param in self.named_parameters() if 'weight_hh' in name)
        b = (param.data for name, param in self.named_parameters() if 'bias' in name)
        for t in ih:
            nn.init.xavier_uniform(t)
        for t in hh:
            nn.init.orthogonal(t)
        for t in b:
            nn.init.constant(t, 0)
    
    def init_hidden(self, batch_size, device):
        weight = next(self.parameters()).data
        hidden = (weight.new(self.n_layers, batch_size, self.embedding_dim).zero_().to(device),
                      weight.new(self.n_layers, batch_size, self.embedding_dim).zero_().to(device))
        return hidden

In [15]:

criterion = torch.nn.CrossEntropyLoss()
model = CNNNet(train_corpus.num_labels, MAX_WORDS, criterion, EMBEDDING_DIM, embedding_weights)

Embedding layer Weights won't be updated.




In [16]:
model.apply(initialize_weights)

CNNNet(
  (criterion): CrossEntropyLoss()
  (embedding): Embedding(22418, 300, padding_idx=0)
  (cnn): Conv1d(300, 256, kernel_size=(4,), stride=(1,))
  (dropout): Dropout(p=0.2, inplace=False)
  (out): Linear(in_features=256, out_features=121, bias=True)
)

In [17]:
device = torch.device(f"cuda:{gpu}")

In [18]:
model = model.to(device)
optimizer = Adam(model.parameters(),lr, betas=(0.7, 0.999))

In [19]:
def multi_acc(y_pred, y_test):
    y_pred_softmax = torch.log_softmax(y_pred, dim = 1)
    _, y_pred_tags = torch.max(y_pred_softmax, dim = 1)    
    
    correct_pred = (y_pred_tags == y_test).float()
    acc = correct_pred.sum() / len(correct_pred)
    acc = acc * 100.0

    return acc

In [20]:
def get_accuracy_from_logits(logits, labels):
    acc = (labels.cpu() == logits.cpu().argmax(-1)).float().detach().numpy()
    return float(100 * acc.sum() / len(acc))

In [21]:
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=3,verbose=True, factor=0.1)

best_epoch = -1
last_saved_model = ""
training_stats = []
global_step = 0
best_val_acc = float("-inf")
best_model_wts = None
best_curr_val = 0

for epoch_i in range(0, epochs):
    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
    print('Training...')
    total_train_loss = 0
    model.train()
    for step, batch in enumerate(train_dataloader):
        #val_h = model.init_hidden(len(batch[0]), device)
        #val_h = tuple([each.data for each in val_h])
         # Progress update every 40 batches.
        if step % 40 == 0 and not step == 0:
            print('  Batch {:>5,}  of  {:>5,}.    .'.format(step, len(train_dataloader)))
        b_input_ids = batch[0]
        b_labels = batch[1].to(device)
        b_input_ids = torch.tensor(b_input_ids).to(device).long()
        model.zero_grad()
        loss, logits = model(b_input_ids, b_labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clipping_value)
        optimizer.step()
        step_loss = loss.item()
        if step % 40 == 0 and not step == 0:
            print("  step loss: {0:.2f}".format(step_loss))
        total_train_loss += step_loss
        global_step += 1
    avg_train_loss = total_train_loss / len(train_dataloader)
    # Measure how long this epoch took.
    print("")
    print("  Average training loss: {0:.2f}".format(avg_train_loss))

    print("")
    print("Running Validation...")
    model.eval()

    # Tracking variables
    total_eval_accuracy = 0.0
    total_eval_loss = 0
    nb_eval_steps = 0
    for step, batch in enumerate(validation_dataloader):
        #val_h = model.init_hidden(len(batch[0]), device)
        #val_h = tuple([each.data for each in val_h])
        b_input_ids = batch[0]
        b_input_ids = torch.tensor(b_input_ids).to(device).long()
        b_labels = batch[1].to(device)
        with torch.no_grad():
            # Forward pass, calculate logit predictions.
            loss, logits = model(b_input_ids, b_labels)
        total_eval_loss += loss
        batch_acc = get_accuracy_from_logits(logits, b_labels)
        batch_acc2 = multi_acc(logits, b_labels)
        print("Batch accuracy: {0:.2f}".format(batch_acc))
        total_eval_accuracy += batch_acc
    
    # Calculate the average loss over all of the batches.
    avg_val_accuracy = total_eval_accuracy / len(validation_dataloader)
    print("  Accuracy: {0:.2f}".format(avg_val_accuracy))
    
    if avg_val_accuracy > best_val_acc:
        print(f"New best model, saving it!")
        if avg_val_accuracy > best_curr_val:
            best_curr_val = avg_val_accuracy
        if last_saved_model:
            shutil.rmtree(last_saved_model)
        model_path = Path(
            FULL_PATH_TO_MODELS,
            f"base-dataset-{dataset}-{save_name}"
        )
        last_saved_model = model_path
        model_path.mkdir(parents=True, exist_ok=True)
        best_val_acc = avg_val_accuracy
        torch.save(model, join(model_path, "best-model.pth"))
        best_epoch = epoch_i
        n_epochs_no_improvement = 0
    elif avg_val_accuracy > best_curr_val:
        best_curr_val = avg_val_accuracy
        n_epochs_no_improvement = 0
    else:
        n_epochs_no_improvement += 1
        print(f"The model does not improve for {n_epochs_no_improvement} epochs!")

    if n_epochs_no_improvement > patience:
        print(f"====>Stopping training, the model did not improve for {n_epochs_no_improvement}\n====>Best epoch: {best_epoch + 1}.")
        break
    scheduler.step(total_eval_loss)
print("")
print("Training complete!")
    
    


Training...




  Batch    40  of    943.    .
  step loss: 1.62
  Batch    80  of    943.    .
  step loss: 1.15
  Batch   120  of    943.    .
  step loss: 1.03
  Batch   160  of    943.    .
  step loss: 0.97
  Batch   200  of    943.    .
  step loss: 0.89
  Batch   240  of    943.    .
  step loss: 0.85
  Batch   280  of    943.    .
  step loss: 0.75
  Batch   320  of    943.    .
  step loss: 0.77
  Batch   360  of    943.    .
  step loss: 0.59
  Batch   400  of    943.    .
  step loss: 0.67
  Batch   440  of    943.    .
  step loss: 0.62
  Batch   480  of    943.    .
  step loss: 0.55
  Batch   520  of    943.    .
  step loss: 0.52
  Batch   560  of    943.    .
  step loss: 0.61
  Batch   600  of    943.    .
  step loss: 0.59
  Batch   640  of    943.    .
  step loss: 0.60
  Batch   680  of    943.    .
  step loss: 0.47
  Batch   720  of    943.    .
  step loss: 0.49
  Batch   760  of    943.    .
  step loss: 0.46
  Batch   800  of    943.    .
  step loss: 0.50
  Batch   840  of   



Batch accuracy: 86.52
Batch accuracy: 88.87
Batch accuracy: 86.13
Batch accuracy: 87.30
Batch accuracy: 86.13
Batch accuracy: 85.16
Batch accuracy: 87.30
Batch accuracy: 87.70
Batch accuracy: 90.82
Batch accuracy: 90.82
Batch accuracy: 88.67
Batch accuracy: 86.52
Batch accuracy: 88.28
Batch accuracy: 87.30
Batch accuracy: 89.06
Batch accuracy: 86.13
Batch accuracy: 88.09
Batch accuracy: 89.06
Batch accuracy: 86.72
Batch accuracy: 89.45
Batch accuracy: 88.09
Batch accuracy: 86.72
Batch accuracy: 87.50
Batch accuracy: 90.04
Batch accuracy: 87.89
Batch accuracy: 87.70
Batch accuracy: 86.72
Batch accuracy: 86.52
Batch accuracy: 87.30
Batch accuracy: 86.72
Batch accuracy: 86.52
Batch accuracy: 87.50
Batch accuracy: 89.06
Batch accuracy: 86.72
Batch accuracy: 84.38
Batch accuracy: 87.50
Batch accuracy: 87.11
Batch accuracy: 87.30
Batch accuracy: 90.04
Batch accuracy: 85.74
Batch accuracy: 86.33
Batch accuracy: 87.70
Batch accuracy: 88.28
Batch accuracy: 88.67
Batch accuracy: 90.04
Batch accu

Batch accuracy: 90.62
Batch accuracy: 90.62
Batch accuracy: 93.16
Batch accuracy: 88.87
Batch accuracy: 91.80
Batch accuracy: 90.23
Batch accuracy: 91.41
Batch accuracy: 88.67
Batch accuracy: 90.04
Batch accuracy: 90.43
Batch accuracy: 89.26
Batch accuracy: 89.06
Batch accuracy: 88.67
Batch accuracy: 90.04
Batch accuracy: 89.65
Batch accuracy: 89.45
Batch accuracy: 91.41
Batch accuracy: 90.04
Batch accuracy: 89.65
Batch accuracy: 89.45
Batch accuracy: 88.87
Batch accuracy: 92.19
Batch accuracy: 90.23
Batch accuracy: 88.67
Batch accuracy: 90.82
Batch accuracy: 90.23
Batch accuracy: 88.87
Batch accuracy: 88.87
Batch accuracy: 90.82
Batch accuracy: 87.70
Batch accuracy: 89.65
Batch accuracy: 92.19
Batch accuracy: 91.21
Batch accuracy: 88.28
Batch accuracy: 90.62
Batch accuracy: 90.82
Batch accuracy: 88.87
Batch accuracy: 90.04
Batch accuracy: 92.38
Batch accuracy: 89.06
Batch accuracy: 90.04
Batch accuracy: 90.23
Batch accuracy: 88.09
Batch accuracy: 90.23
Batch accuracy: 88.87
Batch accu

Batch accuracy: 91.60
Batch accuracy: 91.80
Batch accuracy: 90.62
Batch accuracy: 90.20
  Accuracy: 91.10
New best model, saving it!

Training...
  Batch    40  of    943.    .
  step loss: 0.25
  Batch    80  of    943.    .
  step loss: 0.24
  Batch   120  of    943.    .
  step loss: 0.33
  Batch   160  of    943.    .
  step loss: 0.28
  Batch   200  of    943.    .
  step loss: 0.30
  Batch   240  of    943.    .
  step loss: 0.28
  Batch   280  of    943.    .
  step loss: 0.23
  Batch   320  of    943.    .
  step loss: 0.26
  Batch   360  of    943.    .
  step loss: 0.29
  Batch   400  of    943.    .
  step loss: 0.28
  Batch   440  of    943.    .
  step loss: 0.23
  Batch   480  of    943.    .
  step loss: 0.24
  Batch   520  of    943.    .
  step loss: 0.32
  Batch   560  of    943.    .
  step loss: 0.31
  Batch   600  of    943.    .
  step loss: 0.30
  Batch   640  of    943.    .
  step loss: 0.30
  Batch   680  of    943.    .
  step loss: 0.29
  Batch   720  of    

  Batch   760  of    943.    .
  step loss: 0.32
  Batch   800  of    943.    .
  step loss: 0.31
  Batch   840  of    943.    .
  step loss: 0.19
  Batch   880  of    943.    .
  step loss: 0.23
  Batch   920  of    943.    .
  step loss: 0.32

  Average training loss: 0.24

Running Validation...
Batch accuracy: 92.97
Batch accuracy: 91.21
Batch accuracy: 90.62
Batch accuracy: 91.21
Batch accuracy: 90.04
Batch accuracy: 91.99
Batch accuracy: 93.75
Batch accuracy: 91.60
Batch accuracy: 93.36
Batch accuracy: 93.75
Batch accuracy: 90.62
Batch accuracy: 90.82
Batch accuracy: 92.97
Batch accuracy: 94.14
Batch accuracy: 92.97
Batch accuracy: 91.80
Batch accuracy: 92.38
Batch accuracy: 92.19
Batch accuracy: 92.38
Batch accuracy: 90.62
Batch accuracy: 92.19
Batch accuracy: 91.02
Batch accuracy: 92.97
Batch accuracy: 89.45
Batch accuracy: 92.58
Batch accuracy: 91.60
Batch accuracy: 91.80
Batch accuracy: 91.60
Batch accuracy: 90.43
Batch accuracy: 91.80
Batch accuracy: 91.41
Batch accuracy: 91.

Batch accuracy: 93.16
Batch accuracy: 90.23
Batch accuracy: 90.23
Batch accuracy: 91.02
Batch accuracy: 91.80
Batch accuracy: 92.38
Batch accuracy: 92.58
Batch accuracy: 91.41
Batch accuracy: 91.02
Batch accuracy: 92.77
Batch accuracy: 92.58
Batch accuracy: 92.38
Batch accuracy: 90.43
Batch accuracy: 90.23
Batch accuracy: 92.77
Batch accuracy: 93.55
Batch accuracy: 93.55
Batch accuracy: 91.99
Batch accuracy: 93.16
Batch accuracy: 90.62
Batch accuracy: 92.97
Batch accuracy: 91.80
Batch accuracy: 90.23
Batch accuracy: 91.99
Batch accuracy: 91.41
Batch accuracy: 91.60
Batch accuracy: 90.04
Batch accuracy: 92.38
Batch accuracy: 89.26
Batch accuracy: 92.19
Batch accuracy: 93.36
Batch accuracy: 93.36
Batch accuracy: 89.26
Batch accuracy: 91.60
Batch accuracy: 91.99
Batch accuracy: 91.80
Batch accuracy: 94.14
Batch accuracy: 91.21
Batch accuracy: 90.62
Batch accuracy: 92.77
Batch accuracy: 91.41
Batch accuracy: 91.21
Batch accuracy: 93.16
Batch accuracy: 93.36
Batch accuracy: 90.04
Batch accu

Batch accuracy: 91.60
Batch accuracy: 92.77
Batch accuracy: 93.16
Batch accuracy: 92.58
Batch accuracy: 93.36
Batch accuracy: 89.26
Batch accuracy: 91.21
Batch accuracy: 91.41
Batch accuracy: 91.02
Batch accuracy: 90.82
Batch accuracy: 90.82
Batch accuracy: 92.38
Batch accuracy: 90.43
Batch accuracy: 91.99
Batch accuracy: 94.34
Batch accuracy: 92.97
Batch accuracy: 90.04
Batch accuracy: 92.58
Batch accuracy: 91.99
Batch accuracy: 91.60
Batch accuracy: 91.60
Batch accuracy: 93.08
  Accuracy: 91.93
New best model, saving it!

Training...
  Batch    40  of    943.    .
  step loss: 0.17
  Batch    80  of    943.    .
  step loss: 0.22
  Batch   120  of    943.    .
  step loss: 0.19
  Batch   160  of    943.    .
  step loss: 0.22
  Batch   200  of    943.    .
  step loss: 0.26
  Batch   240  of    943.    .
  step loss: 0.23
  Batch   280  of    943.    .
  step loss: 0.17
  Batch   320  of    943.    .
  step loss: 0.20
  Batch   360  of    943.    .
  step loss: 0.17
  Batch   400  of

  Batch   440  of    943.    .
  step loss: 0.24
  Batch   480  of    943.    .
  step loss: 0.21
  Batch   520  of    943.    .
  step loss: 0.20
  Batch   560  of    943.    .
  step loss: 0.17
  Batch   600  of    943.    .
  step loss: 0.13
  Batch   640  of    943.    .
  step loss: 0.16
  Batch   680  of    943.    .
  step loss: 0.18
  Batch   720  of    943.    .
  step loss: 0.16
  Batch   760  of    943.    .
  step loss: 0.21
  Batch   800  of    943.    .
  step loss: 0.26
  Batch   840  of    943.    .
  step loss: 0.19
  Batch   880  of    943.    .
  step loss: 0.20
  Batch   920  of    943.    .
  step loss: 0.16

  Average training loss: 0.18

Running Validation...
Batch accuracy: 91.41
Batch accuracy: 93.75
Batch accuracy: 89.65
Batch accuracy: 92.58
Batch accuracy: 91.02
Batch accuracy: 92.97
Batch accuracy: 91.60
Batch accuracy: 89.84
Batch accuracy: 89.06
Batch accuracy: 90.43
Batch accuracy: 92.19
Batch accuracy: 90.43
Batch accuracy: 92.58
Batch accuracy: 93.75
B

Batch accuracy: 93.16
Batch accuracy: 92.38
Batch accuracy: 91.02
Batch accuracy: 91.99
Batch accuracy: 93.75
Batch accuracy: 90.82
Batch accuracy: 89.45
Batch accuracy: 91.02
Batch accuracy: 91.60
Batch accuracy: 93.95
Batch accuracy: 91.21
Batch accuracy: 93.75
Batch accuracy: 93.16
Batch accuracy: 92.77
Batch accuracy: 92.19
Batch accuracy: 93.36
Batch accuracy: 91.21
Batch accuracy: 95.12
Batch accuracy: 91.99
Batch accuracy: 92.58
Batch accuracy: 92.19
Batch accuracy: 93.36
Batch accuracy: 91.80
Batch accuracy: 90.23
Batch accuracy: 90.43
Batch accuracy: 92.77
Batch accuracy: 92.19
Batch accuracy: 90.62
Batch accuracy: 90.04
Batch accuracy: 89.65
Batch accuracy: 91.60
Batch accuracy: 92.58
Batch accuracy: 94.53
Batch accuracy: 92.97
Batch accuracy: 92.19
Batch accuracy: 89.84
Batch accuracy: 91.99
Batch accuracy: 91.02
Batch accuracy: 92.97
Batch accuracy: 91.80
Batch accuracy: 91.21
Batch accuracy: 90.43
Batch accuracy: 92.77
Batch accuracy: 92.19
Batch accuracy: 91.99
Batch accu

Batch accuracy: 92.58
Batch accuracy: 91.21
Batch accuracy: 93.75
Batch accuracy: 90.04
Batch accuracy: 92.77
Batch accuracy: 91.80
Batch accuracy: 92.97
Batch accuracy: 92.19
Batch accuracy: 92.97
Batch accuracy: 92.58
Batch accuracy: 90.82
Batch accuracy: 89.84
Batch accuracy: 91.60
Batch accuracy: 93.36
Batch accuracy: 91.21
Batch accuracy: 93.75
Batch accuracy: 89.26
Batch accuracy: 92.19
Batch accuracy: 93.75
Batch accuracy: 90.62
Batch accuracy: 90.43
Batch accuracy: 92.58
Batch accuracy: 91.02
Batch accuracy: 92.97
Batch accuracy: 90.82
Batch accuracy: 93.36
Batch accuracy: 93.95
Batch accuracy: 92.58
Batch accuracy: 92.77
Batch accuracy: 92.97
Batch accuracy: 91.99
Batch accuracy: 91.80
Batch accuracy: 91.02
Batch accuracy: 93.16
Batch accuracy: 91.99
Batch accuracy: 93.16
Batch accuracy: 90.82
Batch accuracy: 91.99
Batch accuracy: 92.58
Batch accuracy: 90.62
Batch accuracy: 92.97
Batch accuracy: 93.08
  Accuracy: 92.38
New best model, saving it!

Training...
  Batch    40  of 

  Batch    80  of    943.    .
  step loss: 0.12
  Batch   120  of    943.    .
  step loss: 0.11
  Batch   160  of    943.    .
  step loss: 0.11
  Batch   200  of    943.    .
  step loss: 0.13
  Batch   240  of    943.    .
  step loss: 0.11
  Batch   280  of    943.    .
  step loss: 0.11
  Batch   320  of    943.    .
  step loss: 0.11
  Batch   360  of    943.    .
  step loss: 0.12
  Batch   400  of    943.    .
  step loss: 0.11
  Batch   440  of    943.    .
  step loss: 0.10
  Batch   480  of    943.    .
  step loss: 0.12
  Batch   520  of    943.    .
  step loss: 0.12
  Batch   560  of    943.    .
  step loss: 0.12
  Batch   600  of    943.    .
  step loss: 0.12
  Batch   640  of    943.    .
  step loss: 0.09
  Batch   680  of    943.    .
  step loss: 0.14
  Batch   720  of    943.    .
  step loss: 0.13
  Batch   760  of    943.    .
  step loss: 0.11
  Batch   800  of    943.    .
  step loss: 0.13
  Batch   840  of    943.    .
  step loss: 0.11
  Batch   880  of   

  Batch   920  of    943.    .
  step loss: 0.10

  Average training loss: 0.12

Running Validation...
Batch accuracy: 93.16
Batch accuracy: 92.97
Batch accuracy: 91.80
Batch accuracy: 91.41
Batch accuracy: 93.36
Batch accuracy: 93.16
Batch accuracy: 92.38
Batch accuracy: 91.99
Batch accuracy: 91.21
Batch accuracy: 92.19
Batch accuracy: 94.73
Batch accuracy: 91.60
Batch accuracy: 91.41
Batch accuracy: 92.97
Batch accuracy: 92.58
Batch accuracy: 92.58
Batch accuracy: 90.82
Batch accuracy: 91.99
Batch accuracy: 91.02
Batch accuracy: 91.99
Batch accuracy: 91.80
Batch accuracy: 91.99
Batch accuracy: 92.77
Batch accuracy: 94.92
Batch accuracy: 93.75
Batch accuracy: 93.16
Batch accuracy: 92.19
Batch accuracy: 93.55
Batch accuracy: 92.38
Batch accuracy: 93.16
Batch accuracy: 92.77
Batch accuracy: 94.53
Batch accuracy: 91.02
Batch accuracy: 93.55
Batch accuracy: 93.95
Batch accuracy: 89.45
Batch accuracy: 92.19
Batch accuracy: 93.55
Batch accuracy: 92.97
Batch accuracy: 92.38
Batch accuracy: 9

Batch accuracy: 92.38
Batch accuracy: 90.62
Batch accuracy: 93.16
Batch accuracy: 94.53
Batch accuracy: 90.62
Batch accuracy: 92.97
Batch accuracy: 91.80
Batch accuracy: 90.82
Batch accuracy: 92.77
Batch accuracy: 92.77
Batch accuracy: 91.21
Batch accuracy: 93.55
Batch accuracy: 94.92
Batch accuracy: 93.75
Batch accuracy: 92.58
Batch accuracy: 92.58
Batch accuracy: 92.19
Batch accuracy: 92.77
Batch accuracy: 92.97
Batch accuracy: 92.19
Batch accuracy: 92.97
Batch accuracy: 93.55
Batch accuracy: 93.55
Batch accuracy: 92.97
Batch accuracy: 91.80
Batch accuracy: 92.97
Batch accuracy: 93.16
Batch accuracy: 92.97
Batch accuracy: 91.60
Batch accuracy: 89.84
Batch accuracy: 90.62
Batch accuracy: 91.02
Batch accuracy: 93.55
Batch accuracy: 92.97
Batch accuracy: 90.23
Batch accuracy: 91.41
Batch accuracy: 92.77
Batch accuracy: 93.95
Batch accuracy: 91.41
Batch accuracy: 91.80
Batch accuracy: 90.82
Batch accuracy: 90.62
Batch accuracy: 94.34
Batch accuracy: 91.41
Batch accuracy: 91.80
Batch accu

Batch accuracy: 92.58
Batch accuracy: 93.75
Batch accuracy: 93.75
Batch accuracy: 93.55
Batch accuracy: 93.36
Batch accuracy: 92.19
Batch accuracy: 91.60
Batch accuracy: 94.52
  Accuracy: 92.44
The model does not improve for 3 epochs!

Training...
  Batch    40  of    943.    .
  step loss: 0.13
  Batch    80  of    943.    .
  step loss: 0.10
  Batch   120  of    943.    .
  step loss: 0.09
  Batch   160  of    943.    .
  step loss: 0.14
  Batch   200  of    943.    .
  step loss: 0.12
  Batch   240  of    943.    .
  step loss: 0.11
  Batch   280  of    943.    .
  step loss: 0.10
  Batch   320  of    943.    .
  step loss: 0.11
  Batch   360  of    943.    .
  step loss: 0.14
  Batch   400  of    943.    .
  step loss: 0.15
  Batch   440  of    943.    .
  step loss: 0.11
  Batch   480  of    943.    .
  step loss: 0.09
  Batch   520  of    943.    .
  step loss: 0.08
  Batch   560  of    943.    .
  step loss: 0.13
  Batch   600  of    943.    .
  step loss: 0.09
  Batch   640  of

In [22]:
def predict(
    model_path: Path,
    dataset: str,
    batch_size: int,
    labels_dict,
    device: torch.device
):

    print(f"====Loading dataset for testing")
    test_corpus = NLPDataset(dataset, "test", sentence_max_len, labels_dict = labels_dict, vocab= voc)

    test_dataloader = DataLoader(
        test_corpus,
        batch_size=batch_size,
        #sampler = RandomSampler(test_corpus),
        pin_memory=True,
        num_workers=0,
        drop_last=False
    )

    print(f"====Loading model for testing")
    model = torch.load(join(model_path, "best-model.pth"))
    model.to(device)
    model.eval()
    pred_labels = []
    test_labels = []
    logits_list = []

    def _list_from_tensor(tensor):
        if tensor.numel() == 1:
            return [tensor.item()]
        return list(tensor.cpu().detach().numpy())

    print("====Testing model...")
    for batch in test_dataloader:
        #h = model.init_hidden(len(batch[0]), device)
        #h = tuple([each.data for each in val_h])
        b_input_ids = batch[0]
        b_input_ids = torch.tensor(b_input_ids).to(device).long()
        b_labels = batch[1].to(device)
        with torch.no_grad():
            #h = tuple([each.data for each in h])
            loss, logits = model(b_input_ids, b_labels)
            preds = np.argmax(logits.cpu(), axis=1) # Convert one-hot to index
            b_labels = b_labels.int()
            pred_labels.extend(_list_from_tensor(preds))
            test_labels.extend(_list_from_tensor(b_labels))
        logits_list.extend(_list_from_tensor(logits))

    print(classification_report(test_labels, pred_labels, labels=list(labels_dict.values()), target_names=np.array(list(labels_dict.keys())), digits=3, output_dict=False))
    logits_list = expit(logits_list)
    del model
    torch.cuda.empty_cache()

In [23]:
del train_dataloader
del validation_dataloader
del train_corpus
del val_corpus
del model
torch.cuda.empty_cache()

In [24]:
predict(last_saved_model, dataset, batch_size, labels_dict, device)

====Loading dataset for testing
====Loading model for testing
====Testing model...


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


                                                                  precision    recall  f1-score   support

                              Sintomas.Genérico.Sky não funciona      0.931     0.911     0.921      8357
                                    Sintomas.Genérico.Instalação      0.978     0.972     0.975       647
                                Sintomas.Genérico.Canal não pega      0.891     0.913     0.902      5967
                    Sintomas.Genérico.Equipamento não funciona G      0.922     0.944     0.933      2318
                                     Sintomas.Genérico.Sem sinal      0.945     0.947     0.946     14552
                               Sintomas.Qualificado.Cancelamento      0.903     0.962     0.932      1847
                           Sintomas.Qualificado.Outros problemas      0.800     0.704     0.749       729
                              Sintomas.Qualificado.NãoTéc_fatura      0.839     0.830     0.834      1451
                          Sintomas.Qualificad