In [1]:
!nvidia-smi

Sat Sep 11 03:45:27 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.63.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   40C    P0    27W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
!pip install transformers SentencePiece torch tqdm

Collecting transformers
  Downloading transformers-4.10.2-py3-none-any.whl (2.8 MB)
[K     |████████████████████████████████| 2.8 MB 5.0 MB/s 
[?25hCollecting SentencePiece
  Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[K     |████████████████████████████████| 1.2 MB 34.6 MB/s 
Collecting huggingface-hub>=0.0.12
  Downloading huggingface_hub-0.0.16-py3-none-any.whl (50 kB)
[K     |████████████████████████████████| 50 kB 4.0 MB/s 
[?25hCollecting tokenizers<0.11,>=0.10.1
  Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
[K     |████████████████████████████████| 3.3 MB 53.8 MB/s 
Collecting sacremoses
  Downloading sacremoses-0.0.45-py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 61.5 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-5.4.1-cp37-cp37m-manylinux1_x86_64.whl (636 kB)
[K     |███████████████████████

In [3]:
import math

from tqdm import tqdm
import numpy as np
from transformers import MT5ForConditionalGeneration, T5Tokenizer
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score

In [4]:
class SoftEmbedding(nn.Module):
    def __init__(self, 
                wte: nn.Embedding,
                n_tokens: int = 10, 
                random_range: float = 0.5,
                initialize_from_vocab: bool = True):
        """appends learned embedding to 
        Args:
            wte (nn.Embedding): original transformer word embedding
            n_tokens (int, optional): number of tokens for task. Defaults to 10.
            random_range (float, optional): range to init embedding (if not initialize from vocab). Defaults to 0.5.
            initialize_from_vocab (bool, optional): initalizes from default vocab. Defaults to True.
        """
        super(SoftEmbedding, self).__init__()
        self.wte = wte
        self.n_tokens = n_tokens
        self.learned_embedding = nn.parameter.Parameter(self.initialize_embedding(wte,
                                                                                  n_tokens, 
                                                                                  random_range, 
                                                                                  initialize_from_vocab))
            
    def initialize_embedding(self, 
                             wte: nn.Embedding,
                             n_tokens: int = 10, 
                             random_range: float = 0.5, 
                             initialize_from_vocab: bool = True):
        """initializes learned embedding
        Args:
            same as __init__
        Returns:
            torch.float: initialized using original schemes
        """
        if initialize_from_vocab:
            return self.wte.weight[:n_tokens].clone().detach()
        return torch.FloatTensor(n_tokens, wte.weight.size(1)).uniform_(-random_range, random_range)
            
    def forward(self, tokens):
        """run forward pass
        Args:
            tokens (torch.long): input tokens before encoding
        Returns:
            torch.float: encoding of text concatenated with learned task specifc embedding
        """
        input_embedding = self.wte(tokens[:, self.n_tokens:])
        learned_embedding = self.learned_embedding.repeat(input_embedding.size(0), 1, 1)
        return torch.cat([learned_embedding, input_embedding], 1)

In [5]:
!pip install zh-dataset-inews

Collecting zh-dataset-inews
  Downloading zh_dataset_inews-0.0.2-py3-none-any.whl (11.3 MB)
[K     |████████████████████████████████| 11.3 MB 5.2 MB/s 
[?25hInstalling collected packages: zh-dataset-inews
Successfully installed zh-dataset-inews-0.0.2


In [6]:
from zh_dataset_inews import title_train, label_train, title_dev, label_dev, title_test, label_test

In [7]:
def generate_data(batch_size, n_tokens, title_data, label_data):

    labels = [
        torch.tensor([[3]]),  # \x00
        torch.tensor([[4]]),  # \x01
        torch.tensor([[5]]),  # \x02
    ]

    def yield_data(x_batch, y_batch, l_batch):
        x = torch.nn.utils.rnn.pad_sequence(x_batch, batch_first=True)
        y = torch.cat(y_batch, dim=0)
        m = (x > 0).to(torch.float32)
        decoder_input_ids = torch.full((x.size(0), n_tokens), 1)
        if torch.cuda.is_available():
            x = x.cuda()
            y = y.cuda()
            m = m.cuda()
            decoder_input_ids = decoder_input_ids.cuda()
        return x, y, m, decoder_input_ids, l_batch

    x_batch, y_batch, l_batch = [], [], []
    for x, y in zip(title_data, label_data):
        context = x
        inputs = tokenizer(context, return_tensors="pt")
        inputs['input_ids'] = torch.cat([torch.full((1, n_tokens), 1), inputs['input_ids']], 1)
        l_batch.append(y)
        y = labels[y]
        y = torch.cat([torch.full((1, n_tokens - 1), -100), y], 1)
        x_batch.append(inputs['input_ids'][0])
        y_batch.append(y)
        if len(x_batch) >= batch_size:
            yield yield_data(x_batch, y_batch, l_batch)
            x_batch, y_batch, l_batch = [], [], []

    if len(x_batch) > 0:
        yield yield_data(x_batch, y_batch, l_batch)
        x_batch, y_batch, l_batch = [], [], []

In [8]:
model = MT5ForConditionalGeneration.from_pretrained("google/mt5-base")
tokenizer = T5Tokenizer.from_pretrained("google/mt5-base")
n_tokens = 100
s_wte = SoftEmbedding(model.get_input_embeddings(), 
                      n_tokens=n_tokens, 
                      initialize_from_vocab=True)
model.set_input_embeddings(s_wte)
if torch.cuda.is_available():
    model = model.cuda()

Downloading:   0%|          | 0.00/702 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.33G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/4.31M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/65.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/376 [00:00<?, ?B/s]

In [9]:
parameters = list(model.parameters())
for x in parameters[1:]:
    x.requires_grad = False

In [10]:
parameters[0]

Parameter containing:
tensor([[ 1.7500e+00, -1.6719e+00,  2.4062e+00,  ...,  6.9580e-03,
         -9.8828e-01, -4.6875e-01],
        [ 8.5625e+00,  5.5625e+00, -1.7109e+00,  ...,  7.7812e+00,
         -5.2812e+00, -3.2188e+00],
        [ 6.8750e-01, -4.5312e-01,  5.7812e-01,  ...,  7.3828e-01,
         -3.0078e-01,  2.0312e-01],
        ...,
        [-4.9219e-01,  1.9141e-01, -4.3555e-01,  ..., -8.0469e-01,
         -4.3359e-01,  5.8594e-01],
        [ 4.9609e-01,  1.1797e+00,  3.7109e-01,  ...,  1.7090e-01,
         -2.5195e-01, -3.3789e-01],
        [ 1.6328e+00,  3.4961e-01,  3.9062e-01,  ...,  1.9336e-01,
         -7.4219e-01,  3.1836e-01]], device='cuda:0', requires_grad=True)

In [11]:
parameters[2]

Parameter containing:
tensor([[-1.3977e-02,  3.8818e-02,  5.7129e-02,  ...,  4.9316e-02,
         -8.1177e-03, -3.8147e-03],
        [ 6.3965e-02, -1.0193e-02, -2.0020e-02,  ..., -8.3618e-03,
         -1.1902e-02, -2.6978e-02],
        [-1.6357e-02, -4.4922e-02,  4.8584e-02,  ..., -1.6479e-02,
         -4.0039e-02,  6.3782e-03],
        ...,
        [ 7.7820e-03, -6.5918e-03, -3.9062e-03,  ...,  1.9165e-02,
          7.4863e-05, -2.6001e-02],
        [-1.4587e-02,  1.8433e-02, -2.6489e-02,  ..., -3.9062e-02,
         -4.0527e-02,  4.1992e-02],
        [ 7.8125e-02,  1.6602e-02,  6.4941e-02,  ...,  4.2152e-04,
          4.5166e-02, -1.1780e-02]], device='cuda:0')

In [12]:
for x, y, m, dii, true_labels in generate_data(8, n_tokens, title_train, label_train):
    assert dii.shape == y.shape
    outputs = model(input_ids=x, labels=y, attention_mask=m, decoder_input_ids=dii)
    assert outputs['logits'].shape[:2] == y.shape
    pred_labels = outputs['logits'][:, -1, 3:6].argmax(-1).detach().cpu().numpy().tolist()
    break

In [13]:
batch_size = 16
n_epoch = 50
total_batch = math.ceil(len(title_train) / batch_size)
dev_total_batch = math.ceil(len(title_dev) / batch_size)
use_ce_loss = False
ce_loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(s_wte.parameters(), lr=0.5)

for epoch in range(n_epoch):
    print('epoch', epoch)

    all_true_labels = []
    all_pred_labels = []
    losses = []
    pbar = tqdm(enumerate(generate_data(batch_size, n_tokens, title_train, label_train)), total=total_batch)
    for i, (x, y, m, dii, true_labels) in pbar:
        all_true_labels += true_labels
        
        optimizer.zero_grad()
        outputs = model(input_ids=x, labels=y, attention_mask=m, decoder_input_ids=dii)
        pred_labels = outputs['logits'][:, -1, 3:6].argmax(-1).detach().cpu().numpy().tolist()
        all_pred_labels += pred_labels

        if use_ce_loss:
            logits = outputs['logits'][:, -1, 3:6]
            true_labels_tensor = torch.tensor(true_labels, dtype=torch.long).cuda()
            loss = ce_loss(logits, true_labels_tensor)
        else:
            loss = outputs.loss
        loss.backward()
        optimizer.step()
        loss_value = float(loss.detach().cpu().numpy().tolist()) / batch_size
        losses.append(loss_value)

        acc = accuracy_score(all_true_labels, all_pred_labels)
        pbar.set_description(f'train: loss={np.mean(losses):.4f}, acc={acc:.4f}')

    all_true_labels = []
    all_pred_labels = []
    losses = []
    with torch.no_grad():
        pbar = tqdm(enumerate(generate_data(batch_size, n_tokens, title_dev, label_dev)), total=dev_total_batch)
        for i, (x, y, m, dii, true_labels) in pbar:
            all_true_labels += true_labels
            outputs = model(input_ids=x, labels=y, attention_mask=m, decoder_input_ids=dii)
            loss = outputs.loss
            loss_value = float(loss.detach().cpu().numpy().tolist()) / batch_size
            losses.append(loss_value)
            pred_labels = outputs['logits'][:, -1, 3:6].argmax(-1).detach().cpu().numpy().tolist()
            all_pred_labels += pred_labels
            acc = accuracy_score(all_true_labels, all_pred_labels)
            pbar.set_description(f'dev: loss={np.mean(losses):.4f}, acc={acc:.4f}')

epoch 0


train: loss=1.2467, acc=0.2045: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.6062, acc=0.4104: 100%|██████████| 63/63 [00:15<00:00,  4.10it/s]


epoch 1


train: loss=0.5303, acc=0.3890: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.4836, acc=0.4194: 100%|██████████| 63/63 [00:15<00:00,  4.10it/s]


epoch 2


train: loss=0.4756, acc=0.4129: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.4530, acc=0.4104: 100%|██████████| 63/63 [00:15<00:00,  4.10it/s]


epoch 3


train: loss=0.4401, acc=0.4411: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.4246, acc=0.4364: 100%|██████████| 63/63 [00:15<00:00,  4.10it/s]


epoch 4


train: loss=0.4274, acc=0.4693: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.3864, acc=0.4935: 100%|██████████| 63/63 [00:15<00:00,  4.10it/s]


epoch 5


train: loss=0.3643, acc=0.4233: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.3957, acc=0.4525: 100%|██████████| 63/63 [00:15<00:00,  4.10it/s]


epoch 6


train: loss=0.3134, acc=0.3261: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.2806, acc=0.3103: 100%|██████████| 63/63 [00:15<00:00,  4.10it/s]


epoch 7


train: loss=0.2656, acc=0.3294: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.2206, acc=0.4535: 100%|██████████| 63/63 [00:15<00:00,  4.10it/s]


epoch 8


train: loss=0.1937, acc=0.3789: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.2822, acc=0.4935: 100%|██████████| 63/63 [00:15<00:00,  4.10it/s]


epoch 9


train: loss=0.1487, acc=0.4204: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.1525, acc=0.5095: 100%|██████████| 63/63 [00:15<00:00,  4.10it/s]


epoch 10


train: loss=0.1217, acc=0.4474: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.1070, acc=0.4905: 100%|██████████| 63/63 [00:15<00:00,  4.11it/s]


epoch 11


train: loss=0.1196, acc=0.4394: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.1106, acc=0.5345: 100%|██████████| 63/63 [00:15<00:00,  4.10it/s]


epoch 12


train: loss=0.0889, acc=0.4784: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.0781, acc=0.5005: 100%|██████████| 63/63 [00:15<00:00,  4.11it/s]


epoch 13


train: loss=0.0784, acc=0.5102: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.0844, acc=0.5015: 100%|██████████| 63/63 [00:15<00:00,  4.11it/s]


epoch 14


train: loss=0.0693, acc=0.5221: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.0873, acc=0.4915: 100%|██████████| 63/63 [00:15<00:00,  4.11it/s]


epoch 15


train: loss=0.0655, acc=0.5451: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.0872, acc=0.5526: 100%|██████████| 63/63 [00:15<00:00,  4.11it/s]


epoch 16


train: loss=0.2844, acc=0.4585: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=1.0427, acc=0.0861: 100%|██████████| 63/63 [00:15<00:00,  4.10it/s]


epoch 17


train: loss=0.4167, acc=0.3049: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.3308, acc=0.2783: 100%|██████████| 63/63 [00:15<00:00,  4.11it/s]


epoch 18


train: loss=0.3083, acc=0.2338: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.4314, acc=0.4505: 100%|██████████| 63/63 [00:15<00:00,  4.10it/s]


epoch 19


train: loss=0.3152, acc=0.2866: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.3012, acc=0.4765: 100%|██████████| 63/63 [00:15<00:00,  4.10it/s]


epoch 20


train: loss=0.2597, acc=0.3414: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.2011, acc=0.3163: 100%|██████████| 63/63 [00:15<00:00,  4.10it/s]


epoch 21


train: loss=0.1745, acc=0.3834: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.1191, acc=0.4925: 100%|██████████| 63/63 [00:15<00:00,  4.10it/s]


epoch 22


train: loss=0.0879, acc=0.4422: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.0927, acc=0.4925: 100%|██████████| 63/63 [00:15<00:00,  4.11it/s]


epoch 23


train: loss=0.0772, acc=0.4594: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.0680, acc=0.4925: 100%|██████████| 63/63 [00:15<00:00,  4.11it/s]


epoch 24


train: loss=0.0719, acc=0.4777: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.0741, acc=0.4945: 100%|██████████| 63/63 [00:15<00:00,  4.10it/s]


epoch 25


train: loss=0.0718, acc=0.4885: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.0969, acc=0.4925: 100%|██████████| 63/63 [00:15<00:00,  4.10it/s]


epoch 26


train: loss=0.0690, acc=0.5160: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.0815, acc=0.5045: 100%|██████████| 63/63 [00:15<00:00,  4.10it/s]


epoch 27


train: loss=0.0646, acc=0.5561: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.0670, acc=0.5716: 100%|██████████| 63/63 [00:15<00:00,  4.11it/s]


epoch 28


train: loss=0.0660, acc=0.5647: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.0780, acc=0.5616: 100%|██████████| 63/63 [00:15<00:00,  4.11it/s]


epoch 29


train: loss=0.0653, acc=0.6069: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.0552, acc=0.7097: 100%|██████████| 63/63 [00:15<00:00,  4.10it/s]


epoch 30


train: loss=0.0652, acc=0.6349: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.0544, acc=0.7217: 100%|██████████| 63/63 [00:15<00:00,  4.10it/s]


epoch 31


train: loss=0.0595, acc=0.6689: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.0772, acc=0.7227: 100%|██████████| 63/63 [00:15<00:00,  4.10it/s]


epoch 32


train: loss=0.0541, acc=0.7068: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.0627, acc=0.7648: 100%|██████████| 63/63 [00:15<00:00,  4.10it/s]


epoch 33


train: loss=0.0478, acc=0.7266: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.0648, acc=0.7778: 100%|██████████| 63/63 [00:15<00:00,  4.11it/s]


epoch 34


train: loss=0.0477, acc=0.7374: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.0510, acc=0.7708: 100%|██████████| 63/63 [00:15<00:00,  4.10it/s]


epoch 35


train: loss=0.0454, acc=0.7468: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.0659, acc=0.7828: 100%|██████████| 63/63 [00:15<00:00,  4.11it/s]


epoch 36


train: loss=0.0464, acc=0.7511: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.0688, acc=0.7728: 100%|██████████| 63/63 [00:15<00:00,  4.11it/s]


epoch 37


train: loss=0.0444, acc=0.7630: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.0588, acc=0.7618: 100%|██████████| 63/63 [00:15<00:00,  4.11it/s]


epoch 38


train: loss=0.0427, acc=0.7632: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.0623, acc=0.7838: 100%|██████████| 63/63 [00:15<00:00,  4.11it/s]


epoch 39


train: loss=0.0424, acc=0.7619: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.0622, acc=0.7758: 100%|██████████| 63/63 [00:15<00:00,  4.11it/s]


epoch 40


train: loss=0.0396, acc=0.7787: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.0514, acc=0.7768: 100%|██████████| 63/63 [00:15<00:00,  4.11it/s]


epoch 41


train: loss=0.0378, acc=0.7882: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.0593, acc=0.7978: 100%|██████████| 63/63 [00:15<00:00,  4.10it/s]


epoch 42


train: loss=0.0396, acc=0.7862: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.0497, acc=0.7988: 100%|██████████| 63/63 [00:15<00:00,  4.10it/s]


epoch 43


train: loss=0.0379, acc=0.7907: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.0458, acc=0.8058: 100%|██████████| 63/63 [00:15<00:00,  4.11it/s]


epoch 44


train: loss=0.0394, acc=0.7847: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.0571, acc=0.8068: 100%|██████████| 63/63 [00:15<00:00,  4.11it/s]


epoch 45


train: loss=0.3734, acc=0.3819: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.1661, acc=0.4484: 100%|██████████| 63/63 [00:15<00:00,  4.11it/s]


epoch 46


train: loss=0.1279, acc=0.4310: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.0930, acc=0.5526: 100%|██████████| 63/63 [00:15<00:00,  4.10it/s]


epoch 47


train: loss=0.0717, acc=0.5380: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.0986, acc=0.4985: 100%|██████████| 63/63 [00:15<00:00,  4.10it/s]


epoch 48


train: loss=0.0625, acc=0.6308: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.0876, acc=0.6937: 100%|██████████| 63/63 [00:15<00:00,  4.11it/s]


epoch 49


train: loss=0.0550, acc=0.7066: 100%|██████████| 335/335 [02:50<00:00,  1.96it/s]
dev: loss=0.0598, acc=0.7518: 100%|██████████| 63/63 [00:15<00:00,  4.11it/s]


In [14]:
parameters2 = list(model.parameters())

In [15]:
parameters2[0]

Parameter containing:
tensor([[ 16.9844, -27.7936, -24.9008,  ...,   6.1309,  -6.2886,  -1.4139],
        [  6.2328,  29.8126,  -4.9943,  ...,  19.5903, -31.4742,  14.9552],
        [ 14.4980,   2.2994, -30.5098,  ...,   3.5612,   5.8513,  -4.0489],
        ...,
        [ 13.1602,  38.7834,  26.7747,  ..., -11.5163,  -0.2207,   4.6533],
        [ 22.6867,  28.9404,  -9.7768,  ...,   4.4906,  11.8440,   1.6058],
        [ 40.6202,  -8.4282,  31.4642,  ...,  -5.5024, -15.7918, -23.5409]],
       device='cuda:0', requires_grad=True)

In [16]:
parameters2[2]

Parameter containing:
tensor([[-1.3977e-02,  3.8818e-02,  5.7129e-02,  ...,  4.9316e-02,
         -8.1177e-03, -3.8147e-03],
        [ 6.3965e-02, -1.0193e-02, -2.0020e-02,  ..., -8.3618e-03,
         -1.1902e-02, -2.6978e-02],
        [-1.6357e-02, -4.4922e-02,  4.8584e-02,  ..., -1.6479e-02,
         -4.0039e-02,  6.3782e-03],
        ...,
        [ 7.7820e-03, -6.5918e-03, -3.9062e-03,  ...,  1.9165e-02,
          7.4863e-05, -2.6001e-02],
        [-1.4587e-02,  1.8433e-02, -2.6489e-02,  ..., -3.9062e-02,
         -4.0527e-02,  4.1992e-02],
        [ 7.8125e-02,  1.6602e-02,  6.4941e-02,  ...,  4.2152e-04,
          4.5166e-02, -1.1780e-02]], device='cuda:0')

In [17]:
def predict(text):
    inputs = tokenizer(text, return_tensors='pt')
    inputs['input_ids'] = torch.cat([torch.full((1, n_tokens), 1), inputs['input_ids']], 1)

    decoder_input_ids = torch.full((1, n_tokens), 1)
    with torch.no_grad():
        outputs = model(input_ids=inputs['input_ids'].cuda(), decoder_input_ids=decoder_input_ids.cuda())
    logits = outputs['logits'][:, -1, 3:6]
    pred = logits.argmax(-1).detach().cpu().numpy()[0]
    # print(logits)
    return pred

In [18]:
train_rets = []
for i in tqdm(range(len(title_train))):
    pred = predict(title_train[i])
    train_rets.append((label_train[i], pred, title_train[i]))

100%|██████████| 5355/5355 [03:45<00:00, 23.80it/s]


In [19]:
rets = []
for i in tqdm(range(len(title_test))):
    pred = predict(title_test[i])
    rets.append((label_test[i], pred, title_test[i]))

100%|██████████| 999/999 [00:42<00:00, 23.27it/s]


In [20]:
print(
    accuracy_score(
        [x[0] for x in train_rets],
        [x[1] for x in train_rets],
    )
)

0.7454715219421102


In [21]:
print(
    accuracy_score(
        [x[0] for x in rets],
        [x[1] for x in rets],
    )
)

0.7387387387387387


In [23]:
print(
    accuracy_score(
        [x[0] for x in rets],
        [0] * len(rets),
    ),
    accuracy_score(
        [x[0] for x in rets],
        [1] * len(rets),
    ),
    accuracy_score(
        [x[0] for x in rets],
        [2] * len(rets),
    )
)

0.0990990990990991 0.4944944944944945 0.4064064064064064
