In [1]:
!pip install transformers==4.11.3.

Collecting transformers==4.11.3.
  Downloading transformers-4.11.3-py3-none-any.whl (2.9 MB)
[K     |████████████████████████████████| 2.9 MB 5.0 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 31.4 MB/s 
[?25hCollecting tokenizers<0.11,>=0.10.1
  Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
[K     |████████████████████████████████| 3.3 MB 28.4 MB/s 
[?25hCollecting sacremoses
  Downloading sacremoses-0.0.49-py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 33.4 MB/s 
[?25hCollecting huggingface-hub>=0.0.17
  Downloading huggingface_hub-0.5.1-py3-none-any.whl (77 kB)
[K     |████████████████████████████████| 77 kB 4.4 MB/s 
Installing collected packages: pyyaml, tokenizers, sacremoses, huggingface-hub, tra

In [2]:
import sys
from google.colab import drive
drive.mount('/content/gdrive/')
sys.path.append('/content/gdrive/MyDrive/data')

Mounted at /content/gdrive/


In [13]:
SYSPATH = '/content/gdrive/MyDrive/data/'
import torch
import random
import numpy as np
config = {
    'train_file_path': SYSPATH + 'train.csv',
    'test_file_path': SYSPATH + 'test.csv',
    'train_val_ratio': 0.1,
    'model_path': SYSPATH + 'BERT_model',
    
    'batch_size': 16,
    'num_epochs': 2,
    'learning_rate': 2e-5,
    'logging_step': 500,
    'seed': 2021}

config['device']='cuda' if torch.cuda.is_available() else 'cpu'

def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    return seed

seed_everything(config['seed'])

2021

In [4]:
import pandas as pd
from tqdm import tqdm
from collections import defaultdict
def read_data(config, tokenizer, mode='train'):
    df = pd.read_csv(config[mode + '_file_path'], sep=',')
    if mode == 'train':
        X_train, y_train = defaultdict(list), []
        X_val, y_val = defaultdict(list), []
        num_val = int(len(df) * config['train_val_ratio'])
    else:
        X_test, y_test = defaultdict(list), []

    for i, row in df.iterrows():
      # get label
        label=row[1] if mode == 'train' else 0
      # get sentence
        sentence = row[0]

        # add_special_tokens  CLS、 SEP
        inputs = tokenizer.encode_plus(sentence, add_special_tokens=True, return_token_type_ids=True, return_attention_mask=True)

        if mode == 'train':
            if i < num_val:
                X_val['inputs_ids'].append(inputs['input_ids'])
                y_val.append(label)
                X_val['token_type_ids'].append(inputs['token_type_ids'])
                X_val['attention_mask'].append(inputs['attention_mask'])
                                    
            else:
                X_train['inputs_ids'].append(inputs['input_ids'])
                y_train.append(label)
                X_train['token_type_ids'].append(inputs['token_type_ids'])
                X_train['attention_mask'].append(inputs['attention_mask'])

        else:
            X_test['inputs_ids'].append(inputs['input_ids'])
            y_test.append(label) 
            X_test['token_type_ids'].append(inputs['token_type_ids'])
            X_test['attention_mask'].append(inputs['attention_mask'])
            
    if mode == 'train':
        label2id = {label: i for i, label in enumerate(np.unique(y_train))} 
        id2label = {i: label for label, i in label2id.items()} 
        y_train = torch.tensor([label2id[i] for i in y_train], dtype=torch.long)  
        y_val = torch.tensor([label2id[i] for i in y_val], dtype=torch.long)  
        return X_train, y_train, X_val, y_val, label2id, id2label
        
    else:
        y_test = torch.tensor(y_test, dtype=torch.long)
        return X_test, y_test

In [5]:
from torch.utils.data import Dataset
class NEWSData(Dataset):
    def __init__(self, X, y):
        self.x = X
        self.y = y

    def __getitem__(self, idx):
        return {
            'inputs_ids' : self.x['inputs_ids'][idx],
            'label' : self.y[idx],
            'token_type_ids': self.x['token_type_ids'][idx],
            'attention_mask': self.x['attention_mask'][idx]
        }
    
    def __len__(self):
        return self.y.size(0)

In [6]:
def collate_fn(examples):
    input_ids_list = []
    labels = []
    token_type_ids_list = []
    attention_mask_list = []

    for example in examples:
        input_ids_list.append(example['inputs_ids'])
        labels.append(example['label'])
        token_type_ids_list.append(example['token_type_ids'])
        attention_mask_list.append(example['attention_mask'])
    
    # get max_length
    max_length = max(len(input_ids) for input_ids in input_ids_list)
    input_ids_tensor = torch.zeros((len(labels), max_length), dtype=torch.long)
    token_type_ids_tensor = torch.zeros_like(input_ids_tensor)
    attention_mask_tensor = torch.zeros_like(input_ids_tensor)

    # insert values into the tensor
    for i, input_ids in enumerate(input_ids_list):
        input_ids_tensor[i, :len(input_ids)] = torch.tensor(input_ids, dtype=torch.long)
        token_type_ids_tensor[i, :len(input_ids)] = torch.tensor(token_type_ids_list[i], dtype=torch.long)
        attention_mask_tensor[i, :len(input_ids)] = torch.tensor(attention_mask_list[i], dtype=torch.long)
    
    return{
        'input_ids' : input_ids_tensor,
        'labels' : torch.tensor(labels, dtype=torch.long),
        'token_type_ids': token_type_ids_tensor,
        'attention_mask': attention_mask_tensor
        }

In [7]:
from transformers import BertTokenizer
from torch.utils.data import DataLoader
def build_dataloader(config):
    # get bert pretrain
    tokenizer = BertTokenizer.from_pretrained(config['model_path'])
    X_train, y_train, X_val, y_val, label2id, id2label = read_data(config, tokenizer, mode='train')
    X_test, y_test = read_data(config, tokenizer, mode='test')

    train_dataset = NEWSData(X_train, y_train)
    val_dataset = NEWSData(X_val, y_val)
    test_dataset = NEWSData(X_test, y_test)

    train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], num_workers=4, shuffle=True, collate_fn=collate_fn)
    val_dataloader = DataLoader(val_dataset, batch_size=config['batch_size'], num_workers=4, shuffle=False, collate_fn=collate_fn)
    test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], num_workers=4, shuffle=False, collate_fn=collate_fn)

    return train_dataloader, val_dataloader, test_dataloader, id2label

train_dataloader, val_dataloader, test_dataloader, id2label = build_dataloader(config)

  cpuset_checked))


In [8]:
from sklearn.metrics import f1_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import recall_score
from sklearn.metrics import precision_score

def evaluation(config, model, val_dataloader):
    model.eval()
    preds = []
    labels = []
    val_loss = 0.
    val_iterator = tqdm(val_dataloader, desc='Evaluation', total=len(val_dataloader))

    with torch.no_grad():
        for batch in val_iterator:
            labels.append(batch['labels'])
            batch = {item: value.to(config['device']) for item, value in batch.items()}
            loss, logits = model(**batch)[:2]

            val_loss += loss.item()
            preds.append(logits.argmax(dim=-1).detach().cpu())

    avg_val_loss = val_loss / len(val_dataloader)
    labels = torch.cat(labels, dim=0).numpy()
    preds = torch.cat(preds, dim=0).numpy()

    accuracy = accuracy_score(labels, preds)
    recall = recall_score(labels, preds, average='macro')
    precision = precision_score(labels, preds, average='macro')
    f1 = f1_score(labels, preds, average='macro')
    return [avg_val_loss, accuracy, recall, precision, f1]

In [9]:
from transformers import BertConfig, BertForSequenceClassification
from transformers import AdamW
from tqdm import trange

def train(config, id2label, train_dataloader, val_dataloader):
    bert_config = BertConfig.from_pretrained(config['model_path'])
    bert_config.num_labels = len(id2label)
    model = BertForSequenceClassification.from_pretrained(config['model_path'], config=bert_config)

    # define optimizer
    optimizer = AdamW(model.parameters(), lr=config['learning_rate'])
    model.to(config['device'])
    global_steps = 0
    train_loss = 0.
    logging_loss = 0.

    
    for epoch in trange(config['num_epochs']):
      train_iterator = tqdm(train_dataloader, desc='Training', total=len(train_dataloader))
      model.train()
      for batch in train_dataloader:
        batch = {item: value.to(config['device']) for item, value in batch.items()}
        loss = model(**batch)[0]
        model.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        global_steps += 1

        if global_steps % config['logging_step'] == 0:
            print_train_loss = (train_loss - logging_loss) / config['logging_step']
            logging_loss = train_loss
            result = evaluation(config, model, val_dataloader)
            avg_val_loss, accuracy = result[0], result[1]
            print_log = f'>>> training loss: {print_train_loss:.4f}, valid loss: {avg_val_loss:.4f}, ' \
                        f'valid accuracy score: {accuracy:.4f}'
            print(print_log)
            model.train()

    return model

In [14]:
model = train(config, id2label, train_dataloader, val_dataloader)

Some weights of the model checkpoint at /content/gdrive/MyDrive/data/BERT_model were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from t

>>> training loss: 0.4244, valid loss: 0.2435, valid accuracy score: 0.9265






  cpuset_checked))




Evaluation:   0%|          | 1/282 [00:00<02:03,  2.28it/s][A[A[A[A



Evaluation:   1%|          | 2/282 [00:00<01:20,  3.50it/s][A[A[A[A



Evaluation:   1%|          | 3/282 [00:00<01:06,  4.18it/s][A[A[A[A



Evaluation:   1%|▏         | 4/282 [00:01<01:01,  4.51it/s][A[A[A[A



Evaluation:   2%|▏         | 5/282 [00:01<00:57,  4.81it/s][A[A[A[A



Evaluation:   2%|▏         | 6/282 [00:01<00:55,  4.97it/s][A[A[A[A



Evaluation:   2%|▏         | 7/282 [00:01<00:54,  5.07it/s][A[A[A[A



Evaluation:   3%|▎         | 8/282 [00:01<00:53,  5.13it/s][A[A[A[A



Evaluation:   3%|▎         | 9/282 [00:01<00:53,  5.14it/s][A[A[A[A



Evaluation:   4%|▎         | 10/282 [00:02<00:52,  5.21it/s][A[A[A[A



Evaluation:   4%|▍         | 11/282 [00:02<00:51,  5.24it/s][A[A[A[A



Evaluation:   4%|▍         | 12/282 [00:02<00:51,  5.29it/s][A[A[A[A



Evaluation:   5%|▍         | 13/282 [00:02<00:50,  5.30it/s][A[A[A

>>> training loss: 0.2425, valid loss: 0.2075, valid accuracy score: 0.9358






  cpuset_checked))




Evaluation:   0%|          | 1/282 [00:00<01:55,  2.43it/s][A[A[A[A



Evaluation:   1%|          | 2/282 [00:00<01:17,  3.62it/s][A[A[A[A



Evaluation:   1%|          | 3/282 [00:00<01:05,  4.24it/s][A[A[A[A



Evaluation:   1%|▏         | 4/282 [00:00<01:01,  4.51it/s][A[A[A[A



Evaluation:   2%|▏         | 5/282 [00:01<00:57,  4.78it/s][A[A[A[A



Evaluation:   2%|▏         | 6/282 [00:01<00:54,  5.04it/s][A[A[A[A



Evaluation:   2%|▏         | 7/282 [00:01<00:54,  5.07it/s][A[A[A[A



Evaluation:   3%|▎         | 8/282 [00:01<00:53,  5.15it/s][A[A[A[A



Evaluation:   3%|▎         | 9/282 [00:01<00:52,  5.22it/s][A[A[A[A



Evaluation:   4%|▎         | 10/282 [00:02<00:51,  5.24it/s][A[A[A[A



Evaluation:   4%|▍         | 11/282 [00:02<00:51,  5.22it/s][A[A[A[A



Evaluation:   4%|▍         | 12/282 [00:02<00:51,  5.21it/s][A[A[A[A



Evaluation:   5%|▍         | 13/282 [00:02<00:51,  5.26it/s][A[A[A

>>> training loss: 0.2171, valid loss: 0.1850, valid accuracy score: 0.9402






  cpuset_checked))




Evaluation:   0%|          | 1/282 [00:00<01:54,  2.46it/s][A[A[A[A



Evaluation:   1%|          | 2/282 [00:00<01:17,  3.62it/s][A[A[A[A



Evaluation:   1%|          | 3/282 [00:00<01:06,  4.21it/s][A[A[A[A



Evaluation:   1%|▏         | 4/282 [00:00<01:01,  4.50it/s][A[A[A[A



Evaluation:   2%|▏         | 5/282 [00:01<00:57,  4.81it/s][A[A[A[A



Evaluation:   2%|▏         | 6/282 [00:01<00:55,  4.98it/s][A[A[A[A



Evaluation:   2%|▏         | 7/282 [00:01<00:54,  5.04it/s][A[A[A[A



Evaluation:   3%|▎         | 8/282 [00:01<00:53,  5.09it/s][A[A[A[A



Evaluation:   3%|▎         | 9/282 [00:01<00:52,  5.17it/s][A[A[A[A



Evaluation:   4%|▎         | 10/282 [00:02<00:52,  5.21it/s][A[A[A[A



Evaluation:   4%|▍         | 11/282 [00:02<00:52,  5.20it/s][A[A[A[A



Evaluation:   4%|▍         | 12/282 [00:02<00:52,  5.19it/s][A[A[A[A



Evaluation:   5%|▍         | 13/282 [00:02<00:51,  5.20it/s][A[A[A

>>> training loss: 0.2114, valid loss: 0.1749, valid accuracy score: 0.9427






  cpuset_checked))




Evaluation:   0%|          | 1/282 [00:00<01:54,  2.45it/s][A[A[A[A



Evaluation:   1%|          | 2/282 [00:00<01:22,  3.39it/s][A[A[A[A



Evaluation:   1%|          | 3/282 [00:00<01:07,  4.14it/s][A[A[A[A



Evaluation:   1%|▏         | 4/282 [00:00<01:02,  4.48it/s][A[A[A[A



Evaluation:   2%|▏         | 5/282 [00:01<00:58,  4.73it/s][A[A[A[A



Evaluation:   2%|▏         | 6/282 [00:01<00:56,  4.89it/s][A[A[A[A



Evaluation:   2%|▏         | 7/282 [00:01<00:55,  4.96it/s][A[A[A[A



Evaluation:   3%|▎         | 8/282 [00:01<00:55,  4.96it/s][A[A[A[A



Evaluation:   3%|▎         | 9/282 [00:01<00:53,  5.06it/s][A[A[A[A



Evaluation:   4%|▎         | 10/282 [00:02<00:53,  5.09it/s][A[A[A[A



Evaluation:   4%|▍         | 11/282 [00:02<00:52,  5.14it/s][A[A[A[A



Evaluation:   4%|▍         | 12/282 [00:02<00:52,  5.11it/s][A[A[A[A



Evaluation:   5%|▍         | 13/282 [00:02<00:52,  5.13it/s][A[A[A

>>> training loss: 0.1942, valid loss: 0.1817, valid accuracy score: 0.9394


 50%|█████     | 1/2 [26:42<26:42, 1602.33s/it]



Training:   0%|          | 0/2532 [26:42<?, ?it/s]
  cpuset_checked))



Evaluation:   0%|          | 0/282 [00:00<?, ?it/s][A[A[A


Evaluation:   0%|          | 1/282 [00:00<01:54,  2.46it/s][A[A[A


Evaluation:   1%|          | 2/282 [00:00<01:17,  3.63it/s][A[A[A


Evaluation:   1%|          | 3/282 [00:00<01:07,  4.14it/s][A[A[A


Evaluation:   1%|▏         | 4/282 [00:00<01:02,  4.48it/s][A[A[A


Evaluation:   2%|▏         | 5/282 [00:01<00:58,  4.72it/s][A[A[A


Evaluation:   2%|▏         | 6/282 [00:01<00:56,  4.88it/s][A[A[A


Evaluation:   2%|▏         | 7/282 [00:01<00:55,  4.92it/s][A[A[A


Evaluation:   3%|▎         | 8/282 [00:01<00:54,  4.98it/s][A[A[A


Evaluation:   3%|▎         | 9/282 [00:01<00:53,  5.10it/s][A[A[A


Evaluation:   4%|▎         | 10/282 [00:02<00:52,  5.21it/s][A[A[A


Evaluation:   4%|▍         | 11/282 [00:02<00:51,  5.25it/s][A[A[A


Evaluation:   4%|▍         | 

>>> training loss: 0.1388, valid loss: 0.1773, valid accuracy score: 0.9430





  cpuset_checked))



Evaluation:   0%|          | 1/282 [00:00<01:54,  2.45it/s][A[A[A


Evaluation:   1%|          | 2/282 [00:00<01:17,  3.63it/s][A[A[A


Evaluation:   1%|          | 3/282 [00:00<01:06,  4.18it/s][A[A[A


Evaluation:   1%|▏         | 4/282 [00:00<01:01,  4.51it/s][A[A[A


Evaluation:   2%|▏         | 5/282 [00:01<00:58,  4.77it/s][A[A[A


Evaluation:   2%|▏         | 6/282 [00:01<00:55,  4.95it/s][A[A[A


Evaluation:   2%|▏         | 7/282 [00:01<00:55,  4.99it/s][A[A[A


Evaluation:   3%|▎         | 8/282 [00:01<00:54,  4.99it/s][A[A[A


Evaluation:   3%|▎         | 9/282 [00:01<00:53,  5.09it/s][A[A[A


Evaluation:   4%|▎         | 10/282 [00:02<00:53,  5.12it/s][A[A[A


Evaluation:   4%|▍         | 11/282 [00:02<00:52,  5.13it/s][A[A[A


Evaluation:   4%|▍         | 12/282 [00:02<00:52,  5.18it/s][A[A[A


Evaluation:   5%|▍         | 13/282 [00:02<00:51,  5.18it/s][A[A[A


Evaluation:   5%|▍         | 14/282 [00:02<00:51

>>> training loss: 0.1421, valid loss: 0.1665, valid accuracy score: 0.9472





  cpuset_checked))



Evaluation:   0%|          | 1/282 [00:00<01:48,  2.60it/s][A[A[A


Evaluation:   1%|          | 2/282 [00:00<01:17,  3.63it/s][A[A[A


Evaluation:   1%|          | 3/282 [00:00<01:06,  4.21it/s][A[A[A


Evaluation:   1%|▏         | 4/282 [00:00<01:01,  4.50it/s][A[A[A


Evaluation:   2%|▏         | 5/282 [00:01<00:58,  4.76it/s][A[A[A


Evaluation:   2%|▏         | 6/282 [00:01<00:55,  4.95it/s][A[A[A


Evaluation:   2%|▏         | 7/282 [00:01<00:54,  5.07it/s][A[A[A


Evaluation:   3%|▎         | 8/282 [00:01<00:53,  5.11it/s][A[A[A


Evaluation:   3%|▎         | 9/282 [00:01<00:52,  5.15it/s][A[A[A


Evaluation:   4%|▎         | 10/282 [00:02<00:52,  5.20it/s][A[A[A


Evaluation:   4%|▍         | 11/282 [00:02<00:51,  5.24it/s][A[A[A


Evaluation:   4%|▍         | 12/282 [00:02<00:51,  5.22it/s][A[A[A


Evaluation:   5%|▍         | 13/282 [00:02<00:51,  5.27it/s][A[A[A


Evaluation:   5%|▍         | 14/282 [00:02<00:50

>>> training loss: 0.1387, valid loss: 0.1636, valid accuracy score: 0.9467





  cpuset_checked))



Evaluation:   0%|          | 1/282 [00:00<01:58,  2.37it/s][A[A[A


Evaluation:   1%|          | 2/282 [00:00<01:19,  3.51it/s][A[A[A


Evaluation:   1%|          | 3/282 [00:00<01:08,  4.08it/s][A[A[A


Evaluation:   1%|▏         | 4/282 [00:00<01:01,  4.53it/s][A[A[A


Evaluation:   2%|▏         | 5/282 [00:01<00:57,  4.81it/s][A[A[A


Evaluation:   2%|▏         | 6/282 [00:01<00:54,  5.02it/s][A[A[A


Evaluation:   2%|▏         | 7/282 [00:01<00:55,  4.97it/s][A[A[A


Evaluation:   3%|▎         | 8/282 [00:01<00:53,  5.14it/s][A[A[A


Evaluation:   3%|▎         | 9/282 [00:01<00:52,  5.18it/s][A[A[A


Evaluation:   4%|▎         | 10/282 [00:02<00:52,  5.15it/s][A[A[A


Evaluation:   4%|▍         | 11/282 [00:02<00:52,  5.19it/s][A[A[A


Evaluation:   4%|▍         | 12/282 [00:02<00:52,  5.17it/s][A[A[A


Evaluation:   5%|▍         | 13/282 [00:02<00:51,  5.21it/s][A[A[A


Evaluation:   5%|▍         | 14/282 [00:02<00:51

>>> training loss: 0.1462, valid loss: 0.1596, valid accuracy score: 0.9475





  cpuset_checked))



Evaluation:   0%|          | 1/282 [00:00<01:54,  2.46it/s][A[A[A


Evaluation:   1%|          | 2/282 [00:00<01:19,  3.52it/s][A[A[A


Evaluation:   1%|          | 3/282 [00:00<01:08,  4.10it/s][A[A[A


Evaluation:   1%|▏         | 4/282 [00:00<01:01,  4.51it/s][A[A[A


Evaluation:   2%|▏         | 5/282 [00:01<00:57,  4.78it/s][A[A[A


Evaluation:   2%|▏         | 6/282 [00:01<00:55,  4.95it/s][A[A[A


Evaluation:   2%|▏         | 7/282 [00:01<00:55,  4.94it/s][A[A[A


Evaluation:   3%|▎         | 8/282 [00:01<00:53,  5.10it/s][A[A[A


Evaluation:   3%|▎         | 9/282 [00:01<00:52,  5.17it/s][A[A[A


Evaluation:   4%|▎         | 10/282 [00:02<00:52,  5.22it/s][A[A[A


Evaluation:   4%|▍         | 11/282 [00:02<00:52,  5.21it/s][A[A[A


Evaluation:   4%|▍         | 12/282 [00:02<00:52,  5.17it/s][A[A[A


Evaluation:   5%|▍         | 13/282 [00:02<00:52,  5.12it/s][A[A[A


Evaluation:   5%|▍         | 14/282 [00:02<00:52

>>> training loss: 0.1393, valid loss: 0.1655, valid accuracy score: 0.9459


100%|██████████| 2/2 [53:28<00:00, 1604.04s/it]
Training:   0%|          | 0/2532 [26:45<?, ?it/s]


In [15]:
def predict(config, id2label, model, test_dataloader):
    test_iterator = tqdm(test_dataloader, desc='Predicting', total=len(test_dataloader))
    model.eval()
    test_preds = []
    with torch.no_grad():
        for batch in test_iterator:
            batch = {item: value.to(config['device']) for item, value in batch.items()}
            logits = model(**batch)[1]

            test_preds.append(logits.argmax(dim=-1).detach().cpu())
    test_preds = torch.cat(test_preds, dim=0).numpy()
    test_preds = [id2label[id_] for id_ in test_preds]
    test_df = pd.read_csv(config['test_file_path'], sep=',')
    test_df.insert(1, column='predicted_label', value=test_preds)
    test_df.drop(columns=['sentence'], inplace=True)
    
    return test_df

In [16]:
prediction = predict(config, id2label, model, test_dataloader)

  cpuset_checked))
Predicting: 100%|██████████| 157/157 [00:28<00:00,  5.47it/s]


In [17]:
accuracy_score(prediction['predicted_label'], prediction['label'])

0.9457

In [18]:
recall_score(prediction['predicted_label'], prediction['label'], average='weighted')

0.9457

In [19]:
precision_score(prediction['predicted_label'], prediction['label'], average='weighted')

0.9462547999999998

In [20]:
f1_score(prediction['predicted_label'], prediction['label'], average='weighted')

0.9458034959098653