## 7-4　BERT によるマルチラベル分類

In [1]:
# !mkdir chap7
%cd ./chap7

/content/chap7


In [2]:
# !pip install transformers==4.18.0 fugashi===1.1.0 ipadic==1.0.0 pytorch-lightning==1.6.1

In [3]:
import glob
import json
import random
from types import SimpleNamespace

import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import BertJapaneseTokenizer, BertModel

In [4]:
# Pretrained checkpoint name, shared by the tokenizer and the BERT encoder.
MODEL_NAME = "tohoku-nlp/bert-base-japanese-whole-word-masking"

In [5]:
class BertForSequenceClassificationMultiLabel(torch.nn.Module):
  """BERT encoder plus a linear head for multi-label sequence classification.

  The sentence representation is the mean of BERT's last hidden states over the
  tokens whose attention_mask is 1 (padding is excluded). A linear layer maps it
  to one logit per label. Labels are scored independently, so the training loss
  is BCEWithLogitsLoss rather than a softmax cross-entropy.
  """

  def __init__(self, model_name, num_labels):
    """
    Args:
      model_name: name or path passed to ``BertModel.from_pretrained``.
      num_labels: number of independent binary labels (output size of the head).
    """
    super().__init__()
    self.bert = BertModel.from_pretrained(model_name)
    self.linear = torch.nn.Linear(self.bert.config.hidden_size, num_labels)

  def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None):
    """Encode with BERT, mean-pool over real tokens and score every label.

    Args:
      input_ids, attention_mask, token_type_ids: tokenizer outputs, forwarded to BERT.
        When attention_mask is None, every token is treated as a real token.
      labels: optional 0/1 tensor with the same shape as the logits
        (batch, num_labels). When given, the BCE-with-logits loss is computed.

    Returns:
      An object with attribute ``logits`` of shape (batch, num_labels) and,
      only when ``labels`` is given, a scalar ``loss``.
    """
    if attention_mask is None:
      # Pooling needs a mask; without one, treat every position as a real token.
      attention_mask = torch.ones_like(input_ids)

    bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
    last_hidden_state = bert_output.last_hidden_state

    # Zero out padding positions, then divide by the number of real tokens.
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
    summed_hidden_state = (last_hidden_state * mask).sum(1)
    token_counts = mask.sum(1).clamp(min=1)  # avoid 0/0 (NaN) for an all-padding row
    averaged_hidden_state = summed_hidden_state / token_counts
    scores = self.linear(averaged_hidden_state)

    output = {'logits': scores}
    if labels is not None:
      # BCE expects floating-point targets matching the logits' dtype.
      output['loss'] = torch.nn.BCEWithLogitsLoss()(scores, labels.to(scores.dtype))

    # Attribute access (output.logits / output.loss) without defining a new class per call.
    return SimpleNamespace(**output)

In [6]:
# Toy (batch=2, seq=3, hidden=4) integer tensor holding 1..24, used to show
# how summing over dim 1 collapses the sequence axis during mean pooling.
masked_hidden_state = torch.arange(1, 25).reshape(2, 3, 4)
print(masked_hidden_state.shape)

torch.Size([2, 3, 4])


In [7]:
# Summing over dim 1 (the sequence axis) leaves one vector per batch item.
sum_hidden_state = torch.sum(masked_hidden_state, dim=1)
print(sum_hidden_state, sum_hidden_state.shape, sep='\n')

tensor([[15, 18, 21, 24],
        [51, 54, 57, 60]])
torch.Size([2, 4])


In [8]:
# Tokenizer and a two-label multi-label classifier built on the same checkpoint.
tokenizer = BertJapaneseTokenizer.from_pretrained(MODEL_NAME)
bert_scml = BertForSequenceClassificationMultiLabel(MODEL_NAME, num_labels=2)
# Use the GPU when present; fall back to CPU instead of failing inside .cuda().
bert_scml = bert_scml.to('cuda' if torch.cuda.is_available() else 'cpu')

Some weights of the model checkpoint at tohoku-nlp/bert-base-japanese-whole-word-masking were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [9]:
# Two example sentences, each with two binary (multi-hot) labels.
text_list = ['今日の仕事はうまくいったが、体調があまり良くない。', '昨日は楽しかった。']
labels_list = [[1, 1], [0, 1]]

# Place the inputs on whatever device the model lives on instead of assuming CUDA.
model_device = next(bert_scml.parameters()).device
encoding = tokenizer(text_list, padding='longest', return_tensors='pt')
encoding = {k: v.to(model_device) for k, v in encoding.items()}
labels = torch.tensor(labels_list).to(model_device)

with torch.no_grad():
  output = bert_scml(**encoding)
scores = output.logits

# logit > 0 is equivalent to sigmoid(logit) > 0.5; each label is decided independently.
labels_predicted = (scores > 0).int()

# Exact-match accuracy: a sample counts as correct only if all of its labels match.
num_correct = (labels_predicted == labels).all(-1).sum().item()
accuracy = num_correct / labels.size(0)

In [10]:
# Same batch, now with labels so the model also returns the BCE-with-logits loss.
# No torch.no_grad() here: gradients are tracked, as a training step would need.
model_device = next(bert_scml.parameters()).device
encoding = tokenizer(text_list, padding='longest', return_tensors='pt')
encoding['labels'] = torch.tensor(labels_list)
encoding = {k: v.to(model_device) for k, v in encoding.items()}

output = bert_scml(**encoding)
loss = output.loss