#### Based on https://medium.com/swlh/painless-fine-tuning-of-bert-in-pytorch-b91c14912caa
https://github.com/aniruddhachoudhury/BERT-Tutorials/blob/master/Blog%202/BERT_Fine_Tuning_Sentence_Classification.ipynb

In [1]:
# Load IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
import torch.nn as nn
from os.path import join
import sys
import torch
from nlpClassifiers.data.dataset  import NLPDataset
#from torch.optim import AdamW, SGD
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
from transformers import BertForSequenceClassification, AdamW, BertConfig
from transformers import get_linear_schedule_with_warmup
from torch.nn import LayerNorm as BertLayerNorm
import numpy as np
import time
import logging
import datetime
import random
import pandas as pd
import argparse
import pickle as pk
import itertools
import os
import shutil
from pathlib import Path
import copy
import wandb
import re
from nlpClassifiers import settings
from scipy.special import expit
from sklearn.metrics import classification_report



In [3]:
def predict(
    model_path: Path,
    dataset: str,
    batch_size: int,
    labels_dict,
    device: torch.device
):
    """Evaluate a fine-tuned BERT classifier on the test split and print a report.

    The test split of ``dataset`` is loaded through ``NLPDataset``, the
    checkpoint at ``model_path`` is loaded with
    ``BertForSequenceClassification.from_pretrained`` and every test example
    is classified.  A scikit-learn ``classification_report`` is printed; the
    function returns ``None``.

    Note: besides its arguments this function reads the module-level names
    ``sentence_max_len`` and ``bert_path``, which must be defined before it
    is called.

    Args:
        model_path: Directory (or identifier) of the fine-tuned checkpoint.
        dataset: Dataset name forwarded to ``NLPDataset``.
        batch_size: Number of examples per evaluation batch.
        labels_dict: Mapping whose keys are used as the report's class names
            and whose values are used as the corresponding label ids.
        device: Device on which the model and the batches are placed.
    """
    print("====Loading dataset for testing")
    test_corpus = NLPDataset(dataset, "test", sentence_max_len, bert_path, labels_dict)
    # Evaluation must see every example exactly once: iterate in order and
    # keep the final partial batch (drop_last=True would silently discard up
    # to batch_size - 1 test samples and skew the reported support/metrics).
    test_dataloader = DataLoader(
        test_corpus,
        batch_size=batch_size,
        sampler=SequentialSampler(test_corpus),
        pin_memory=True,
        num_workers=0,
        drop_last=False,
    )

    print("====Loading model for testing")
    model = BertForSequenceClassification.from_pretrained(
        model_path,
        num_labels=test_corpus.num_labels,
        output_attentions=False,
        output_hidden_states=True,
    )
    model.to(device)
    model.eval()

    pred_labels = []
    test_labels = []

    print("====Testing model...")
    with torch.no_grad():
        for batch in test_dataloader:
            b_input_ids, b_segment_ids, b_input_mask, b_labels = (
                t.to(device) for t in batch
            )
            # Second positional argument is the attention mask, as in the
            # original call; segment ids are intentionally not passed.
            outputs = model(b_input_ids, b_input_mask, token_type_ids=None, labels=b_labels)
            # With labels supplied the outputs are (loss, logits, ...).
            # Integer indexing works both for plain tuples and for the
            # ModelOutput objects returned by newer transformers releases,
            # whereas tuple-unpacking a ModelOutput would yield its keys.
            logits = outputs[1]

            # Highest-scoring class index per example.
            preds = logits.argmax(dim=1)
            pred_labels.extend(preds.cpu().tolist())
            test_labels.extend(b_labels.int().reshape(-1).cpu().tolist())

    print(classification_report(test_labels, pred_labels, labels=list(labels_dict.values()), target_names=np.array(list(labels_dict.keys())), digits=3))

    # Release the model's GPU memory so later cells can load another model.
    del model
    torch.cuda.empty_cache()



In [4]:
def get_accuracy_from_logits(logits, labels):
    """Return the percentage (0-100) of rows whose arg-max logit equals the label.

    Both tensors are moved to the CPU before comparison, so they may live on
    any device.
    """
    predicted = logits.cpu().argmax(-1)
    # 1.0 where the prediction matches the label, 0.0 elsewhere.
    hits = (labels.cpu() == predicted).float().detach().numpy()
    n_correct = hits.sum()
    n_total = len(hits)
    return float(100 * n_correct / n_total)

In [5]:
# Shell escape: show the last 10 rows of the training CSV to eyeball its format.
!tail -n 10 ../../data/mercado-livre-pt-only/train.csv

protetor porta flexivel gm chevette 1990 preto univ,AUTOMOTIVE_WEATHERSTRIPS
fogao vintage pecas original continental,RANGES
100 adesivos bolo no pote etiqueta rotulo personalizado 5 cm,SELF_ADHESIVE_LABELS
eliptico caloi act home fitness clt 20,ELLIPTICAL_MACHINES
gravar conversa do skype produtos espionagem grava audio bd1,DIGITAL_VOICE_RECORDERS
12 faca desossar inox 5 plenus pre tramontina 23425005,KITCHEN_KNIVES
controle n64 usb,GAMEPADS_AND_JOYSTICKS
pinceis daymakeup- 1 pincel delinear- temos todos os pinceis,MAKEUP_BRUSHES
_protetor solar de parabrisa kangoo express logan master,AUTOMOTIVE_WEATHERSTRIPS
chave combinada tramontina 19 ee51071,WRENCHES


In [6]:
# Data and checkpoint locations, relative to the notebook's working directory.
# NOTE(review): DATA_PATH/MODELS_PATH/PATH_TO_BERT point at "virtual-operator"
# while the run configured below uses 'mercado-livre-pt' — confirm they are
# still needed.
DATA_PATH = '../../data/virtual-operator'
MODELS_PATH = '../../models/virtual-operator/bert-base-portuguese-tapt-classifier/'
PATH_TO_BERT = '../../models/virtual-operator/bertimbau-adaptive-base-finetuned/'

# CSV files for the three splits, all located directly under DATA_PATH.
TRAIN_DATASET, VAL_DATASET, TEST_DATASET = (
    join(DATA_PATH, split_file)
    for split_file in ('train.csv', 'val.csv', 'test.csv')
)

# Root directory holding the saved models of each dataset.
PATH_TO_VIRTUAL_OPERATOR_MODELS = "../../models/virtual-operator"
PATH_TO_AGENT_BENCHMARK_MODELS = "../../models/agent-benchmark"
PATH_TO_ML_PT_MODELS = "../../models/mercado-livre-pt-only"

In [7]:
# Run configuration read by the cells below (and, for sentence_max_len and
# bert_path, read as globals inside predict()).
gpu = 7  # CUDA device index used to build the torch.device
dataset = 'mercado-livre-pt'  # key into BASE_PATH_TO_MODELS; also passed to NLPDataset
save_name = 'mercado-livre-pt-100-epochs-early-stop-reset-multilingual-base-dataset-fixed'  # suffix of the checkpoint directory name
bert_path = 'bert-base-multilingual-cased'  # passed to NLPDataset
batch_size = 16  # evaluation batch size passed to predict()
sentence_max_len = 30  # passed to NLPDataset

In [8]:
# Root model directory for each supported dataset name.
BASE_PATH_TO_MODELS = {
    "virtual-operator": PATH_TO_VIRTUAL_OPERATOR_MODELS,
    "agent-benchmark": PATH_TO_AGENT_BENCHMARK_MODELS,
    "mercado-livre-pt": PATH_TO_ML_PT_MODELS,
}
# Classifier checkpoints live in a fixed sub-directory of the selected
# dataset's model root.
FULL_PATH_TO_MODELS = os.path.join(
    BASE_PATH_TO_MODELS[dataset],
    "bert-base-portuguese-tapt-classifier",
)

In [9]:
# Select the CUDA device whose index is given by `gpu`; the model and every
# batch are moved to this device inside predict().
device = torch.device(f"cuda:{gpu}")


In [10]:
# Checkpoint directory: <models root>/base-dataset-<dataset>-<save_name>.
# Both names refer to the same Path object.
last_saved_model = model_path = (
    Path(FULL_PATH_TO_MODELS) / f"base-dataset-{dataset}-{save_name}"
)

In [11]:
# Build the training split only to obtain its label mapping, which is then
# handed to predict() (forwarded to the test NLPDataset and used for the
# class names / label ids of the classification report).
train_corpus = NLPDataset(dataset, "train", sentence_max_len, bert_path)
labels_dict = train_corpus.labels_dict

In [12]:
# Evaluate the saved checkpoint on the test split and print the per-class report.
predict(last_saved_model, dataset, batch_size, labels_dict, device)

====Loading dataset for testing
====Loading model for testing
====Testing model...


  _warn_prf(average, modifier, msg_start, len(result))


                                             precision    recall  f1-score   support

                              FISHING_LINES      0.933     0.908     0.920       584
                     MOBILE_DEVICE_CHARGERS      0.984     0.965     0.974       804
                                 SUNGLASSES      0.985     0.974     0.979       875
                                   FREEZERS      0.922     0.953     0.937       444
                                 CAR_WHEELS      0.990     0.997     0.994       720
                           BATHROOM_FAUCETS      0.911     0.900     0.905       521
                             ACTION_FIGURES      0.800     0.740     0.769       800
                                      IRONS      0.988     0.982     0.985       433
                                 MATTRESSES      0.991     0.985     0.988       534
                      SWIMMING_POOL_HEATERS      0.870     0.952     0.909        21
                             KITCHEN_KNIVES      0.960     0.950