In [None]:
!nvidia-smi

Sat Sep 11 09:17:34 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.63.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   37C    P0    50W / 300W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
!pip install transformers SentencePiece torch tqdm



In [None]:
import math

from tqdm import tqdm
import numpy as np
from transformers import MT5ForConditionalGeneration, T5Tokenizer
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score

In [None]:
class SoftEmbedding(nn.Module):
    def __init__(self, 
                wte: nn.Embedding,
                n_tokens: int = 10, 
                random_range: float = 0.5,
                initialize_from_vocab: bool = True):
        """appends learned embedding to 
        Args:
            wte (nn.Embedding): original transformer word embedding
            n_tokens (int, optional): number of tokens for task. Defaults to 10.
            random_range (float, optional): range to init embedding (if not initialize from vocab). Defaults to 0.5.
            initialize_from_vocab (bool, optional): initalizes from default vocab. Defaults to True.
        """
        super(SoftEmbedding, self).__init__()
        self.wte = wte
        self.n_tokens = n_tokens
        self.learned_embedding = nn.parameter.Parameter(self.initialize_embedding(wte,
                                                                                  n_tokens, 
                                                                                  random_range, 
                                                                                  initialize_from_vocab))
            
    def initialize_embedding(self, 
                             wte: nn.Embedding,
                             n_tokens: int = 10, 
                             random_range: float = 0.5, 
                             initialize_from_vocab: bool = True):
        """initializes learned embedding
        Args:
            same as __init__
        Returns:
            torch.float: initialized using original schemes
        """
        if initialize_from_vocab:
            return self.wte.weight[:n_tokens].clone().detach()
        return torch.FloatTensor(n_tokens, wte.weight.size(1)).uniform_(-random_range, random_range)
            
    def forward(self, tokens):
        """run forward pass
        Args:
            tokens (torch.long): input tokens before encoding
        Returns:
            torch.float: encoding of text concatenated with learned task specifc embedding
        """
        input_embedding = self.wte(tokens[:, self.n_tokens:])
        learned_embedding = self.learned_embedding.repeat(input_embedding.size(0), 1, 1)
        return torch.cat([learned_embedding, input_embedding], 1)

In [None]:
!pip install zh-dataset-inews



In [None]:
from zh_dataset_inews import title_train, label_train, title_dev, label_dev, title_test, label_test

In [None]:
def generate_data(batch_size, n_tokens, title_data, label_data):

    labels = [
        torch.tensor([[3]]),  # \x00
        torch.tensor([[4]]),  # \x01
        torch.tensor([[5]]),  # \x02
    ]

    def yield_data(x_batch, y_batch, l_batch):
        x = torch.nn.utils.rnn.pad_sequence(x_batch, batch_first=True)
        y = torch.cat(y_batch, dim=0)
        m = (x > 0).to(torch.float32)
        decoder_input_ids = torch.full((x.size(0), n_tokens), 1)
        if torch.cuda.is_available():
            x = x.cuda()
            y = y.cuda()
            m = m.cuda()
            decoder_input_ids = decoder_input_ids.cuda()
        return x, y, m, decoder_input_ids, l_batch

    x_batch, y_batch, l_batch = [], [], []
    for x, y in zip(title_data, label_data):
        context = x
        inputs = tokenizer(context, return_tensors="pt")
        inputs['input_ids'] = torch.cat([torch.full((1, n_tokens), 1), inputs['input_ids']], 1)
        l_batch.append(y)
        y = labels[y]
        y = torch.cat([torch.full((1, n_tokens - 1), -100), y], 1)
        x_batch.append(inputs['input_ids'][0])
        y_batch.append(y)
        if len(x_batch) >= batch_size:
            yield yield_data(x_batch, y_batch, l_batch)
            x_batch, y_batch, l_batch = [], [], []

    if len(x_batch) > 0:
        yield yield_data(x_batch, y_batch, l_batch)
        x_batch, y_batch, l_batch = [], [], []

In [None]:
model = MT5ForConditionalGeneration.from_pretrained("google/mt5-large")
tokenizer = T5Tokenizer.from_pretrained("google/mt5-large")
n_tokens = 100
s_wte = SoftEmbedding(model.get_input_embeddings(), 
                      n_tokens=n_tokens, 
                      initialize_from_vocab=True)
model.set_input_embeddings(s_wte)
if torch.cuda.is_available():
    model = model.cuda()

In [None]:
parameters = list(model.parameters())
for x in parameters[1:]:
    x.requires_grad = False

In [None]:
parameters[0]

Parameter containing:
tensor([[ -1.0312,  -4.2500,   7.0000,  ...,   6.0938,  -8.0625,  -9.5000],
        [ -7.7500, -12.1250,  -2.3438,  ...,  -7.8438,   9.1875,   4.4375],
        [  0.9805,   1.0781,  -0.3867,  ...,  -1.0156,  -0.4785,   0.8008],
        ...,
        [ -1.4922,   0.1895,  -0.2041,  ...,   0.6250,   0.0131,  -1.8828],
        [  0.8789,   0.1108,   1.1953,  ...,   0.8281,   1.4844,   0.3418],
        [  0.1436,  -0.3867,  -0.7734,  ...,   0.5078,  -0.0157,   0.1060]],
       device='cuda:0', requires_grad=True)

In [None]:
parameters[2]

Parameter containing:
tensor([[ 0.0099,  0.0084,  0.0172,  ...,  0.0220,  0.0435, -0.0337],
        [ 0.0112, -0.0181, -0.0107,  ...,  0.0227,  0.0190,  0.0033],
        [ 0.0061,  0.0430,  0.0625,  ..., -0.0334, -0.0130,  0.0205],
        ...,
        [ 0.0034,  0.0228,  0.0003,  ...,  0.0113, -0.0045, -0.0222],
        [ 0.0297, -0.0042, -0.0393,  ...,  0.0037, -0.0145, -0.0023],
        [ 0.0053, -0.0029,  0.0157,  ..., -0.0125,  0.0068,  0.0106]],
       device='cuda:0')

In [None]:
for x, y, m, dii, true_labels in generate_data(2, n_tokens, title_train, label_train):
    assert dii.shape == y.shape
    outputs = model(input_ids=x, labels=y, attention_mask=m, decoder_input_ids=dii)
    assert outputs['logits'].shape[:2] == y.shape
    pred_labels = outputs['logits'][:, -1, 3:6].argmax(-1).detach().cpu().numpy().tolist()
    break

In [13]:
batch_size = 2
n_epoch = 50
total_batch = math.ceil(len(title_train) / batch_size)
dev_total_batch = math.ceil(len(title_dev) / batch_size)
use_ce_loss = False
ce_loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(s_wte.parameters(), lr=0.5)

for epoch in range(n_epoch):
    print('epoch', epoch)

    all_true_labels = []
    all_pred_labels = []
    losses = []
    pbar = tqdm(enumerate(generate_data(batch_size, n_tokens, title_train, label_train)), total=total_batch)
    for i, (x, y, m, dii, true_labels) in pbar:
        all_true_labels += true_labels
        
        optimizer.zero_grad()
        outputs = model(input_ids=x, labels=y, attention_mask=m, decoder_input_ids=dii)
        pred_labels = outputs['logits'][:, -1, 3:6].argmax(-1).detach().cpu().numpy().tolist()
        all_pred_labels += pred_labels

        if use_ce_loss:
            logits = outputs['logits'][:, -1, 3:6]
            true_labels_tensor = torch.tensor(true_labels, dtype=torch.long).cuda()
            loss = ce_loss(logits, true_labels_tensor)
        else:
            loss = outputs.loss
        loss.backward()
        optimizer.step()
        loss_value = float(loss.detach().cpu().numpy().tolist()) / batch_size
        losses.append(loss_value)

        acc = accuracy_score(all_true_labels, all_pred_labels)
        pbar.set_description(f'train: loss={np.mean(losses):.4f}, acc={acc:.4f}')

    all_true_labels = []
    all_pred_labels = []
    losses = []
    with torch.no_grad():
        pbar = tqdm(enumerate(generate_data(batch_size, n_tokens, title_dev, label_dev)), total=dev_total_batch)
        for i, (x, y, m, dii, true_labels) in pbar:
            all_true_labels += true_labels
            outputs = model(input_ids=x, labels=y, attention_mask=m, decoder_input_ids=dii)
            loss = outputs.loss
            loss_value = float(loss.detach().cpu().numpy().tolist()) / batch_size
            losses.append(loss_value)
            pred_labels = outputs['logits'][:, -1, 3:6].argmax(-1).detach().cpu().numpy().tolist()
            all_pred_labels += pred_labels
            acc = accuracy_score(all_true_labels, all_pred_labels)
            pbar.set_description(f'dev: loss={np.mean(losses):.4f}, acc={acc:.4f}')

epoch 0


train: loss=27.0702, acc=0.1079: 100%|██████████| 2678/2678 [06:10<00:00,  7.23it/s]
dev: loss=24.9345, acc=0.0861: 100%|██████████| 500/500 [00:32<00:00, 15.35it/s]


epoch 1


train: loss=18.2214, acc=0.1617: 100%|██████████| 2678/2678 [06:08<00:00,  7.26it/s]
dev: loss=16.0809, acc=0.4344: 100%|██████████| 500/500 [00:32<00:00, 15.47it/s]


epoch 2


train: loss=24.5954, acc=0.2510: 100%|██████████| 2678/2678 [06:08<00:00,  7.26it/s]
dev: loss=31.2844, acc=0.0861: 100%|██████████| 500/500 [00:32<00:00, 15.45it/s]


epoch 3


train: loss=27.7627, acc=0.1079: 100%|██████████| 2678/2678 [06:07<00:00,  7.29it/s]
dev: loss=17.8570, acc=0.0861: 100%|██████████| 500/500 [00:32<00:00, 15.51it/s]


epoch 4


train: loss=5.3752, acc=0.4288: 100%|██████████| 2678/2678 [06:08<00:00,  7.27it/s]
dev: loss=2.5420, acc=0.4935: 100%|██████████| 500/500 [00:32<00:00, 15.43it/s]


epoch 5


train: loss=1.6852, acc=0.4426: 100%|██████████| 2678/2678 [06:07<00:00,  7.29it/s]
dev: loss=1.8112, acc=0.4925: 100%|██████████| 500/500 [00:32<00:00, 15.38it/s]


epoch 6


train: loss=6.8852, acc=0.3473: 100%|██████████| 2678/2678 [06:07<00:00,  7.29it/s]
dev: loss=15.4861, acc=0.0861: 100%|██████████| 500/500 [00:32<00:00, 15.24it/s]


epoch 7


train: loss=5.7757, acc=0.4325: 100%|██████████| 2678/2678 [06:07<00:00,  7.29it/s]
dev: loss=2.9378, acc=0.4895: 100%|██████████| 500/500 [00:32<00:00, 15.45it/s]


epoch 8


train: loss=3.1350, acc=0.4405: 100%|██████████| 2678/2678 [06:08<00:00,  7.27it/s]
dev: loss=1.9451, acc=0.4925: 100%|██████████| 500/500 [00:32<00:00, 15.25it/s]


epoch 9


train: loss=1.4360, acc=0.4459: 100%|██████████| 2678/2678 [06:07<00:00,  7.28it/s]
dev: loss=1.5743, acc=0.4925: 100%|██████████| 500/500 [00:32<00:00, 15.34it/s]


epoch 10


train: loss=0.7307, acc=0.4702: 100%|██████████| 2678/2678 [06:07<00:00,  7.28it/s]
dev: loss=2.3358, acc=0.4925: 100%|██████████| 500/500 [00:32<00:00, 15.38it/s]


epoch 11


train: loss=0.7700, acc=0.4480: 100%|██████████| 2678/2678 [06:05<00:00,  7.32it/s]
dev: loss=1.3577, acc=0.4925: 100%|██████████| 500/500 [00:32<00:00, 15.50it/s]


epoch 12


train: loss=0.6409, acc=0.4652: 100%|██████████| 2678/2678 [06:06<00:00,  7.31it/s]
dev: loss=1.1933, acc=0.4925: 100%|██████████| 500/500 [00:32<00:00, 15.43it/s]


epoch 13


train: loss=0.5962, acc=0.5100: 100%|██████████| 2678/2678 [06:05<00:00,  7.32it/s]
dev: loss=1.0476, acc=0.4925: 100%|██████████| 500/500 [00:31<00:00, 15.64it/s]


epoch 14


train: loss=0.5385, acc=0.5968: 100%|██████████| 2678/2678 [06:05<00:00,  7.32it/s]
dev: loss=1.1322, acc=0.4925: 100%|██████████| 500/500 [00:31<00:00, 15.69it/s]


epoch 15


train: loss=0.4824, acc=0.6390: 100%|██████████| 2678/2678 [06:02<00:00,  7.39it/s]
dev: loss=0.8160, acc=0.4925: 100%|██████████| 500/500 [00:31<00:00, 15.64it/s]


epoch 16


train: loss=0.4478, acc=0.6753: 100%|██████████| 2678/2678 [06:03<00:00,  7.38it/s]
dev: loss=0.6584, acc=0.5726: 100%|██████████| 500/500 [00:32<00:00, 15.60it/s]


epoch 17


train: loss=0.4170, acc=0.6930: 100%|██████████| 2678/2678 [06:02<00:00,  7.39it/s]
dev: loss=0.6231, acc=0.6336: 100%|██████████| 500/500 [00:32<00:00, 15.52it/s]


epoch 18


train: loss=0.3995, acc=0.7029: 100%|██████████| 2678/2678 [06:01<00:00,  7.41it/s]
dev: loss=0.5515, acc=0.6667: 100%|██████████| 500/500 [00:32<00:00, 15.56it/s]


epoch 19


train: loss=0.3821, acc=0.7107: 100%|██████████| 2678/2678 [06:03<00:00,  7.38it/s]
dev: loss=0.5967, acc=0.6466: 100%|██████████| 500/500 [00:31<00:00, 15.63it/s]


epoch 20


train: loss=0.3647, acc=0.7303: 100%|██████████| 2678/2678 [06:01<00:00,  7.40it/s]
dev: loss=0.4908, acc=0.6557: 100%|██████████| 500/500 [00:32<00:00, 15.57it/s]


epoch 21


train: loss=0.3547, acc=0.7315: 100%|██████████| 2678/2678 [06:01<00:00,  7.42it/s]
dev: loss=0.4713, acc=0.7067: 100%|██████████| 500/500 [00:31<00:00, 15.65it/s]


epoch 22


train: loss=0.3406, acc=0.7415: 100%|██████████| 2678/2678 [06:01<00:00,  7.40it/s]
dev: loss=0.4078, acc=0.7417: 100%|██████████| 500/500 [00:31<00:00, 15.64it/s]


epoch 23


train: loss=0.3255, acc=0.7565: 100%|██████████| 2678/2678 [06:02<00:00,  7.39it/s]
dev: loss=0.3752, acc=0.7337: 100%|██████████| 500/500 [00:32<00:00, 15.61it/s]


epoch 24


train: loss=0.3098, acc=0.7621: 100%|██████████| 2678/2678 [06:02<00:00,  7.38it/s]
dev: loss=0.3692, acc=0.7407: 100%|██████████| 500/500 [00:32<00:00, 15.57it/s]


epoch 25


train: loss=0.2986, acc=0.7811: 100%|██████████| 2678/2678 [06:02<00:00,  7.40it/s]
dev: loss=0.3479, acc=0.7628: 100%|██████████| 500/500 [00:32<00:00, 15.50it/s]


epoch 26


train: loss=0.2924, acc=0.7824: 100%|██████████| 2678/2678 [06:02<00:00,  7.39it/s]
dev: loss=0.3496, acc=0.7548: 100%|██████████| 500/500 [00:32<00:00, 15.62it/s]


epoch 27


train: loss=0.2865, acc=0.7826: 100%|██████████| 2678/2678 [06:00<00:00,  7.42it/s]
dev: loss=0.3458, acc=0.7608: 100%|██████████| 500/500 [00:31<00:00, 15.69it/s]


epoch 28


train: loss=0.2746, acc=0.7950: 100%|██████████| 2678/2678 [06:03<00:00,  7.37it/s]
dev: loss=0.3276, acc=0.7648: 100%|██████████| 500/500 [00:31<00:00, 15.63it/s]


epoch 29


train: loss=0.2642, acc=0.8011: 100%|██████████| 2678/2678 [06:06<00:00,  7.31it/s]
dev: loss=0.3346, acc=0.7578: 100%|██████████| 500/500 [00:32<00:00, 15.58it/s]


epoch 30


train: loss=0.2633, acc=0.8021: 100%|██████████| 2678/2678 [06:05<00:00,  7.33it/s]
dev: loss=0.3168, acc=0.7477: 100%|██████████| 500/500 [00:32<00:00, 15.35it/s]


epoch 31


train: loss=0.2450, acc=0.8121: 100%|██████████| 2678/2678 [06:05<00:00,  7.32it/s]
dev: loss=0.3279, acc=0.7628: 100%|██████████| 500/500 [00:32<00:00, 15.46it/s]


epoch 32


train: loss=0.2345, acc=0.8239: 100%|██████████| 2678/2678 [06:04<00:00,  7.34it/s]
dev: loss=0.3625, acc=0.7538: 100%|██████████| 500/500 [00:32<00:00, 15.55it/s]


epoch 33


train: loss=0.2345, acc=0.8243: 100%|██████████| 2678/2678 [06:03<00:00,  7.36it/s]
dev: loss=0.3248, acc=0.7618: 100%|██████████| 500/500 [00:32<00:00, 15.56it/s]


epoch 34


train: loss=0.2715, acc=0.7938: 100%|██████████| 2678/2678 [06:05<00:00,  7.32it/s]
dev: loss=0.3177, acc=0.7487: 100%|██████████| 500/500 [00:32<00:00, 15.54it/s]


epoch 35


train: loss=0.2448, acc=0.8136: 100%|██████████| 2678/2678 [06:05<00:00,  7.34it/s]
dev: loss=0.3142, acc=0.7638: 100%|██████████| 500/500 [00:32<00:00, 15.44it/s]


epoch 36


train: loss=0.2281, acc=0.8288: 100%|██████████| 2678/2678 [06:06<00:00,  7.31it/s]
dev: loss=0.3680, acc=0.7518: 100%|██████████| 500/500 [00:32<00:00, 15.35it/s]


epoch 37


train: loss=0.2265, acc=0.8317: 100%|██████████| 2678/2678 [06:06<00:00,  7.31it/s]
dev: loss=0.3473, acc=0.7588: 100%|██████████| 500/500 [00:32<00:00, 15.46it/s]


epoch 38


train: loss=0.2330, acc=0.8263: 100%|██████████| 2678/2678 [06:04<00:00,  7.34it/s]
dev: loss=0.3554, acc=0.7497: 100%|██████████| 500/500 [00:32<00:00, 15.37it/s]


epoch 39


train: loss=0.2277, acc=0.8261: 100%|██████████| 2678/2678 [06:05<00:00,  7.33it/s]
dev: loss=0.3790, acc=0.7407: 100%|██████████| 500/500 [00:32<00:00, 15.42it/s]


epoch 40


train: loss=0.2221, acc=0.8304: 100%|██████████| 2678/2678 [06:05<00:00,  7.32it/s]
dev: loss=0.3345, acc=0.7407: 100%|██████████| 500/500 [00:32<00:00, 15.40it/s]


epoch 41


train: loss=0.2417, acc=0.8144: 100%|██████████| 2678/2678 [06:06<00:00,  7.31it/s]
dev: loss=0.3752, acc=0.7347: 100%|██████████| 500/500 [00:32<00:00, 15.50it/s]


epoch 42


train: loss=0.2384, acc=0.8174: 100%|██████████| 2678/2678 [06:06<00:00,  7.31it/s]
dev: loss=0.3535, acc=0.7588: 100%|██████████| 500/500 [00:32<00:00, 15.43it/s]


epoch 43


train: loss=0.2198, acc=0.8303: 100%|██████████| 2678/2678 [06:06<00:00,  7.32it/s]
dev: loss=0.3527, acc=0.7467: 100%|██████████| 500/500 [00:32<00:00, 15.48it/s]


epoch 44


train: loss=0.2125, acc=0.8400: 100%|██████████| 2678/2678 [06:05<00:00,  7.33it/s]
dev: loss=0.3540, acc=0.7447: 100%|██████████| 500/500 [00:32<00:00, 15.47it/s]


epoch 45


train: loss=0.2063, acc=0.8476: 100%|██████████| 2678/2678 [06:04<00:00,  7.34it/s]
dev: loss=0.3499, acc=0.7518: 100%|██████████| 500/500 [00:32<00:00, 15.53it/s]


epoch 46


train: loss=0.2001, acc=0.8514: 100%|██████████| 2678/2678 [06:05<00:00,  7.32it/s]
dev: loss=0.3517, acc=0.7467: 100%|██████████| 500/500 [00:32<00:00, 15.48it/s]


epoch 47


train: loss=0.1885, acc=0.8646: 100%|██████████| 2678/2678 [06:05<00:00,  7.33it/s]
dev: loss=0.3732, acc=0.7477: 100%|██████████| 500/500 [00:32<00:00, 15.37it/s]


epoch 48


train: loss=0.1811, acc=0.8650: 100%|██████████| 2678/2678 [06:05<00:00,  7.32it/s]
dev: loss=0.4366, acc=0.7227: 100%|██████████| 500/500 [00:32<00:00, 15.33it/s]


epoch 49


train: loss=0.1723, acc=0.8726: 100%|██████████| 2678/2678 [06:04<00:00,  7.34it/s]
dev: loss=0.4849, acc=0.7387: 100%|██████████| 500/500 [00:32<00:00, 15.41it/s]


In [14]:
parameters2 = list(model.parameters())

In [15]:
parameters2[0]

Parameter containing:
tensor([[  39.9078, -138.6385,  217.1636,  ...,   29.5207,  145.8943,
          144.5315],
        [ 515.6390,  -31.6162,   51.5134,  ...,  -53.0618,  245.4292,
          -69.3007],
        [ 236.2896,   43.0374,  -19.2581,  ..., -127.3152,  130.7397,
           31.1689],
        ...,
        [  88.1486,   49.1501,  125.5696,  ...,  113.4881,   96.0846,
          368.0652],
        [ 100.6963, -102.7619,  -35.8637,  ..., -144.5385,  -25.3403,
          173.1718],
        [-164.1508,  -81.5056,  152.1980,  ..., -178.5098,    6.0514,
         -129.9609]], device='cuda:0', requires_grad=True)

In [16]:
parameters2[2]

Parameter containing:
tensor([[ 0.0099,  0.0084,  0.0172,  ...,  0.0220,  0.0435, -0.0337],
        [ 0.0112, -0.0181, -0.0107,  ...,  0.0227,  0.0190,  0.0033],
        [ 0.0061,  0.0430,  0.0625,  ..., -0.0334, -0.0130,  0.0205],
        ...,
        [ 0.0034,  0.0228,  0.0003,  ...,  0.0113, -0.0045, -0.0222],
        [ 0.0297, -0.0042, -0.0393,  ...,  0.0037, -0.0145, -0.0023],
        [ 0.0053, -0.0029,  0.0157,  ..., -0.0125,  0.0068,  0.0106]],
       device='cuda:0')

In [17]:
def predict(text):
    inputs = tokenizer(text, return_tensors='pt')
    inputs['input_ids'] = torch.cat([torch.full((1, n_tokens), 1), inputs['input_ids']], 1)

    decoder_input_ids = torch.full((1, n_tokens), 1)
    with torch.no_grad():
        outputs = model(input_ids=inputs['input_ids'].cuda(), decoder_input_ids=decoder_input_ids.cuda())
    logits = outputs['logits'][:, -1, 3:6]
    pred = logits.argmax(-1).detach().cpu().numpy()[0]
    # print(logits)
    return pred

In [18]:
train_rets = []
for i in tqdm(range(len(title_train))):
    pred = predict(title_train[i])
    train_rets.append((label_train[i], pred, title_train[i]))

100%|██████████| 5355/5355 [04:39<00:00, 19.19it/s]


In [None]:
rets = []
for i in tqdm(range(len(title_test))):
    pred = predict(title_test[i])
    rets.append((label_test[i], pred, title_test[i]))

In [23]:
print(
    accuracy_score(
        [x[0] for x in train_rets],
        [x[1] for x in train_rets],
    )
)

0.861624649859944


In [24]:
print(
    accuracy_score(
        [x[0] for x in rets],
        [x[1] for x in rets],
    )
)

0.7447447447447447


In [25]:
print(
    accuracy_score(
        [x[0] for x in rets],
        [0] * len(rets),
    ),
    accuracy_score(
        [x[0] for x in rets],
        [1] * len(rets),
    ),
    accuracy_score(
        [x[0] for x in rets],
        [2] * len(rets),
    )
)

0.0990990990990991 0.4944944944944945 0.4064064064064064
