#### Based on https://medium.com/swlh/painless-fine-tuning-of-bert-in-pytorch-b91c14912caa
https://github.com/aniruddhachoudhury/BERT-Tutorials/blob/master/Blog%202/BERT_Fine_Tuning_Sentence_Classification.ipynb

In [1]:
# Reload imported modules automatically before each cell runs, so local edits
# (e.g. to nlpClassifiers) take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
import torch.nn as nn
from os.path import join
import torch
from nlpClassifiers.data.dataset  import NLPDataset
#from torch.optim import AdamW, SGD
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
from transformers import BertForSequenceClassification, AdamW, BertConfig
from transformers import get_linear_schedule_with_warmup
from torch.nn import LayerNorm as BertLayerNorm
import numpy as np
import time
import logging
import datetime
import random
import pandas as pd
import argparse
import pickle as pk
import itertools
import os
import shutil
from pathlib import Path
import copy
import wandb
import re
from nlpClassifiers import settings
from scipy.special import expit
from sklearn.metrics import classification_report

In [3]:
def predict(
    model_path: Path,
    dataset: str,
    batch_size: int,
    labels_dict,
    device: torch.device
):
    """Evaluate a fine-tuned BERT classifier on the test split and return its logits.

    Loads the "test" split of ``dataset`` through ``NLPDataset``, runs the
    model stored at ``model_path`` over it without gradient tracking, prints a
    scikit-learn classification report and drops the model afterwards.

    Note:
        ``sentence_max_len`` and ``bert_path`` are read from module-level
        globals, so they must be defined before this function is called.

    Args:
        model_path: Location handed to
            ``BertForSequenceClassification.from_pretrained``.
        dataset: Dataset name forwarded to ``NLPDataset``.
        batch_size: Number of examples per evaluation batch.
        labels_dict: Mapping from label name to label id; its keys and values
            are used as ``target_names`` and ``labels`` of the report.
        device: Device on which the model and the batches are placed.

    Returns:
        A NumPy array with one row of logits per test example, in the order
        the dataloader yields them.
    """
    print("====Loading dataset for testing")
    test_corpus = NLPDataset(dataset, "test", sentence_max_len, bert_path, labels_dict)
    test_dataloader = DataLoader(
        test_corpus,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=0,
        drop_last=False,
    )

    print("====Loading model for testing")
    model = BertForSequenceClassification.from_pretrained(
        model_path,
        num_labels=test_corpus.num_labels,
        output_attentions=False,
        output_hidden_states=False,
    )
    model.to(device)
    model.eval()

    pred_labels = []
    test_labels = []
    logits_list = []

    print("====Testing model...")
    for batch in test_dataloader:
        # The segment ids are part of the batch but are not fed to the model.
        b_input_ids, _b_segment_ids, b_input_mask, b_labels = (
            t.to(device) for t in batch
        )
        with torch.no_grad():
            outputs = model(b_input_ids, b_input_mask, token_type_ids=None, labels=b_labels)

        # Index instead of unpacking: with labels supplied the result is
        # (loss, logits, ...), and positional indexing works both for a plain
        # tuple and for a transformers ModelOutput.
        logits = outputs[1].detach().cpu()

        pred_labels.extend(logits.argmax(dim=1).tolist())
        test_labels.extend(b_labels.int().reshape(-1).cpu().tolist())
        # Keep the logits of every example in the batch, not only the first.
        logits_list.extend(logits.numpy())

    print(classification_report(
        test_labels,
        pred_labels,
        labels=list(labels_dict.values()),
        target_names=np.array(list(labels_dict.keys())),
        digits=3,
    ))

    # Drop the model reference and ask PyTorch to release cached GPU memory.
    del model
    torch.cuda.empty_cache()
    return np.array(logits_list)



In [4]:
def get_accuracy_from_logits(logits, labels):
    """Return the percentage of rows whose arg-max logit equals the label.

    Both tensors are moved to the CPU before comparison; the prediction for
    each row is the index of its largest value along the last dimension.
    """
    predicted = logits.cpu().argmax(-1)
    hits = (labels.cpu() == predicted).float().detach().numpy()
    total = len(hits)
    return float(100 * hits.sum() / total)

In [5]:
# Relative directory holding the virtual-operator CSV splits.
DATA_PATH = '../../data/virtual-operator'
TRAIN_DATASET, VAL_DATASET, TEST_DATASET = (
    os.path.join(DATA_PATH, split_file)
    for split_file in ('train.csv', 'val.csv', 'test.csv')
)
# Pretrained BERT identifier handed to NLPDataset.
bert_path = 'neuralmind/bert-base-portuguese-cased'

In [6]:
# Evaluation settings used by the cells below.
dataset = 'virtual-operator'
# Fine-tuned classifier checkpoint to evaluate.
model_path = '../../models/virtual-operator/bert-base-portuguese-tapt-classifier/base-dataset-virtual-operator-virtual-operator-100-epochs-early-stop-reset-3-bertimbau-base'
# CUDA device index, evaluation batch size, and the max sentence length given to NLPDataset.
gpu, batch_size, sentence_max_len = 0, 16, 82

In [7]:
# Use the selected GPU when CUDA is usable; otherwise fall back to the CPU so
# the notebook still runs on machines without a GPU.
device = torch.device(f"cuda:{gpu}" if torch.cuda.is_available() else "cpu")


In [8]:
# Build the training split; in the visible cells it is used only as the source
# of the label mapping that the test split and the report are given.
train_corpus = NLPDataset(dataset, "train", sentence_max_len, bert_path)
labels_dict = train_corpus.labels_dict

In [9]:
# Evaluate the checkpoint on the test split; prints the classification report
# and keeps the logits array that predict() returns.
prediction = predict(model_path, dataset, batch_size, labels_dict, device)

====Loading dataset for testing
====Loading model for testing
====Testing model...


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


                                                                  precision    recall  f1-score   support

                              Sintomas.Genérico.Sky não funciona      0.966     0.954     0.960      8357
                                    Sintomas.Genérico.Instalação      0.994     0.992     0.993       647
                                Sintomas.Genérico.Canal não pega      0.947     0.957     0.952      5967
                    Sintomas.Genérico.Equipamento não funciona G      0.983     0.967     0.975      2318
                                     Sintomas.Genérico.Sem sinal      0.978     0.970     0.974     14552
                               Sintomas.Qualificado.Cancelamento      0.951     0.979     0.965      1847
                           Sintomas.Qualificado.Outros problemas      0.916     0.856     0.885       729
                              Sintomas.Qualificado.NãoTéc_fatura      0.918     0.924     0.921      1451
                          Sintomas.Qualificad

  _warn_prf(average, modifier, msg_start, len(result))
