<a href="https://colab.research.google.com/github/straxFromIbr/NLP_with_BERT/blob/main/7MultiLabelCls.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import os
ON_COLAB = "COLAB_GPU" in os.environ

In [3]:
from datetime import date
modelpath = f"./model_{date.today().strftime('%y-%m-%d')}"
if ON_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    modelpath = '/content/drive/MyDrive/Colab/NLPwithBERT/Section7/'+modelpath
    

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## ライブラリインストール

In [4]:
if ON_COLAB:
    ! pip install -U pip 2>&1 >/dev/null
    ! pip install -U \
        transformers \
        fugashi \
        ipadic \
        pytorch-lightning  2>&1 >/dev/null



In [29]:
import random
import glob
import json
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader
from transformers import BertJapaneseTokenizer, BertModel
import pytorch_lightning as pl

MODEL_NAME = "cl-tohoku/bert-base-japanese-whole-word-masking"

In [6]:
class BertForSequenceClassificationMultiLabel(torch.nn.Module):
    def __init__(self, model_name, num_labels) -> None:
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.linear = torch.nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
    ):
        bert_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        last_hidden_state = bert_output.last_hidden_state

        # [PAD]以外のトークンの平均
        average_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(
            1
        ) / attention_mask.sum(1, keepdim=True)

        # 線形変換
        scores = self.linear(average_hidden_state)

        # 出力フォーマット
        output = {"logits": scores}
        if labels is not None:
            loss = torch.nn.BCEWithLogitsLoss()(scores, labels.float())
            output["loss"] = loss

        output = type("bert_output", (object,), output)

        return output

In [7]:
tokenizer = BertJapaneseTokenizer.from_pretrained(MODEL_NAME)
bert_scml = BertForSequenceClassificationMultiLabel(MODEL_NAME, num_labels=2)
bert_scml = bert_scml.cuda()

Downloading:   0%|          | 0.00/252k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/110 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/479 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/424M [00:00<?, ?B/s]

Some weights of the model checkpoint at cl-tohoku/bert-base-japanese-whole-word-masking were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [10]:
text_list = [
    "今日の仕事はうまくいったが、体調があまり良くない。",
    "昨日は楽しかった。",
]
labels_list = [
    [1, 1],
    [0, 1],
]

encoding = tokenizer(text_list, padding="longest", return_tensors="pt")
labels = torch.tensor(labels_list)

encoding = {k: v.cuda() for k, v in encoding.items()}
labels = labels.cuda()

with torch.no_grad():
    output = bert_scml(**encoding)
scores = output.logits

labels_predicted = (scores > 0).int()

num_correct = (labels_predicted == labels).all(-1).sum().item()
accuracy = num_correct / labels.size(0)

print(labels_predicted)
print(accuracy)

tensor([[0, 1],
        [0, 1]], device='cuda:0', dtype=torch.int32)
0.5


In [11]:
encoding = tokenizer(text_list, padding="longest", return_tensors="pt")
encoding["labels"] = torch.tensor(labels_list)
encoding = {k: v.cuda() for k, v in encoding.items()}

output = bert_scml(**encoding)
loss = output.loss
print(loss)


tensor(0.6128, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


In [14]:
! wget 'https://s3-ap-northeast-1.amazonaws.com/dev.tech-sketch.jp/chakki/public/chABSA-dataset.zip'
! unzip chABSA-dataset.zip >&/dev/null

--2021-12-22 09:44:42--  https://s3-ap-northeast-1.amazonaws.com/dev.tech-sketch.jp/chakki/public/chABSA-dataset.zip
Resolving s3-ap-northeast-1.amazonaws.com (s3-ap-northeast-1.amazonaws.com)... 52.219.12.66
Connecting to s3-ap-northeast-1.amazonaws.com (s3-ap-northeast-1.amazonaws.com)|52.219.12.66|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 722777 (706K) [application/zip]
Saving to: ‘chABSA-dataset.zip’


2021-12-22 09:44:43 (767 KB/s) - ‘chABSA-dataset.zip’ saved [722777/722777]



In [16]:
data = json.load(open('chABSA-dataset/e00030_ann.json'))
data['sentences'][0]

{'opinions': [{'category': 'NULL#general',
   'from': 6,
   'polarity': 'neutral',
   'target': 'わが国経済',
   'to': 11},
  {'category': 'NULL#general',
   'from': 13,
   'polarity': 'positive',
   'target': '景気',
   'to': 15},
  {'category': 'NULL#general',
   'from': 28,
   'polarity': 'positive',
   'target': '設備投資',
   'to': 32},
  {'category': 'NULL#general',
   'from': 42,
   'polarity': 'positive',
   'target': '企業収益',
   'to': 46},
  {'category': 'NULL#general',
   'from': 62,
   'polarity': 'neutral',
   'target': '資源国等',
   'to': 66},
  {'category': 'NULL#general',
   'from': 80,
   'polarity': 'negative',
   'target': '為替',
   'to': 82}],
 'sentence': '当期におけるわが国経済は、景気は緩やかな回復基調が続き、設備投資の持ち直し等を背景に企業収益は改善しているものの、海外では、資源国等を中心に不透明な状況が続き、為替が急激に変動するなど、依然として先行きが見通せない状況で推移した',
 'sentence_id': 0}

In [23]:
category_id = {'negative':0,'neutral':1,'positive':2}

dataset = []
for file in glob.glob('chABSA-dataset/*.json'):
    data = json.load(open(file))
    for sentence in data['sentences']:
        text = sentence['sentence']
        labels = [0,0,0]
        for opinion in sentence['opinions']:
            labels[category_id[opinion['polarity']]] = 1
        sample = {'text': text, 'labels':labels}
        dataset.append(sample)
print(dataset[0])

{'text': '〈経済環境〉\n\u3000当連結会計年度における世界経済は、年前半の米国経済の足踏みや英国のEU離脱決定による金融市場の混乱等を背景に減速したものの、年後半はトランプ新政権が掲げた大型減税、インフラ投資等の政策への期待により米国経済は持ち直し、英国のEU離脱への主要国中央銀行による迅速な対応によって緩やかに回復が進みました', 'labels': [1, 0, 1]}


In [28]:
tokenizer = BertJapaneseTokenizer.from_pretrained(MODEL_NAME)

max_length = 128
dataset_for_loader = []
for sample in dataset:
    text = sample['text']
    labels = sample['labels']

    encoding = tokenizer(
        text,
        max_length=max_length,
        padding='max_length',
        truncation=True,
    )
    encoding['labels'] = labels
    encoding = {k:torch.tensor(v) for k, v in encoding.items()}
    dataset_for_loader.append(encoding)

random.seed(13)
random.shuffle(dataset_for_loader)
n = len(dataset_for_loader)
n_train = int(0.6*n)
n_val = int(0.2*n)

dataset_train = dataset_for_loader[:n_train]
dataset_val = dataset_for_loader[n_train:n_train+n_val]
dataset_test = dataset_for_loader[n_train+n_val:]

dataloader_train = DataLoader(dataset_train, batch_size=32, shuffle=True)
dataloader_val = DataLoader(dataset_val)
dataloader_test = DataLoader(dataset_test)


In [37]:
class BertForSequenceClassificationMultiLabel_pl(pl.LightningModule):
    def __init__(self, model_name, num_labels, lr):
        super().__init__()
        self.save_hyperparameters()
        self.bert_scml = BertForSequenceClassificationMultiLabel(model_name, num_labels=num_labels)

        def training_step(self, batch, batch_idx):
            output = self.bert_scml(**batch)
            loss = output.loss
            self.log('train_loss', loss)
            return loss

        def validation_step(self, batch, batch_idx):
            output = self.bert_scml(**batch)
            val_loss = output.loss
            self.log('val_loss', val_loss)

        def test_step(self, batch, batch_idx):
            labels = batch.pop('labels')
            output = self.bert_scml(**batch)
            scores = output.logits
            labels_predicted = (scores > 0 ).int()
            num_correct = (labels_predicted == labels).all(-1).sum().item()
            accuracy = num_correct/scores.size(0)
            self.log('accuracy', accuracy)

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

In [40]:
checkpoint = pl.callbacks.ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    save_weights_only=True,
    dirpath='model.'
)

trainer = pl.Trainer(
    gpus=1,
    max_epochs=5,
    # callbacks = [checkpoint]
)

model = BertForSequenceClassificationMultiLabel_pl(
    MODEL_NAME, num_labels=3, lr=1e-5,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
Some weights of the model checkpoint at cl-tohoku/bert-base-japanese-whole-word-masking were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [41]:
trainer.fit(model, dataloader_train, dataloader_val)


MisconfigurationException: ignored

In [33]:
test = trainer.test(test_dataloaders=dataloader_test)
print(f'Accuracy: {test[0]["accuracy"]:.2}')

  "`trainer.test(test_dataloaders)` is deprecated in v1.4 and will be removed in v1.6."


MisconfigurationException: ignored