<a href="https://colab.research.google.com/github/tsato-code/bert/blob/main/20220326_multilabel_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# このノートブックの概要

- ストックマーク社の BERT 本7章を動作確認。
- 内容はマルチラベル文章分類。

TODO
- `with torch.no_grad()` をどう使い分けるのか。本では学習時には `with torch.no_grad()` せず、推論時に `with torch.no_grad()` している。

In [1]:
### fugashi は 形態素解析ツール Mecab を Python から使えるようにしたもの
### ipadic は Mecab で形態素解析を利用するときに使う辞書
!pip install -q transformers==4.15.0 fugashi==1.1.0 ipadic==1.0.0 pytorch-lightning==1.5.8

[K     |████████████████████████████████| 3.4 MB 15.2 MB/s 
[K     |████████████████████████████████| 486 kB 65.4 MB/s 
[K     |████████████████████████████████| 13.4 MB 54.5 MB/s 
[K     |████████████████████████████████| 526 kB 57.4 MB/s 
[K     |████████████████████████████████| 67 kB 6.3 MB/s 
[K     |████████████████████████████████| 3.3 MB 65.4 MB/s 
[K     |████████████████████████████████| 895 kB 57.5 MB/s 
[K     |████████████████████████████████| 596 kB 57.6 MB/s 
[K     |████████████████████████████████| 829 kB 60.5 MB/s 
[K     |████████████████████████████████| 134 kB 67.3 MB/s 
[K     |████████████████████████████████| 398 kB 71.6 MB/s 
[K     |████████████████████████████████| 1.1 MB 41.5 MB/s 
[K     |████████████████████████████████| 271 kB 60.3 MB/s 
[K     |████████████████████████████████| 144 kB 55.5 MB/s 
[K     |████████████████████████████████| 94 kB 2.8 MB/s 
[?25h  Building wheel for ipadic (setup.py) ... [?25l[?25hdone
  Building wheel for f

In [2]:
### グローバル変数
JAPANESE_WIKI_MODEL = 'cl-tohoku/bert-base-japanese-whole-word-masking'

In [3]:
### ライブラリのインポート
import random
import glob
import json
import pytorch_lightning as pl
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from transformers import BertJapaneseTokenizer, BertModel

In [4]:
class BertForSequenceClassificationMultiLabel(torch.nn.Module):
    
    def __init__(self, model_name, num_labels):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.linear = torch.nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(
        self, 
        input_ids=None, 
        attention_mask=None, 
        token_type_ids=None, 
        labels=None
    ):
        ### BERT の最終層の出力
        bert_output = self.bert(
            input_ids=input_ids, 
            attention_mask=attention_mask, 
            token_type_ids=token_type_ids
        )
        last_hidden_state = bert_output.last_hidden_state

        ### [PAD] 以外のトークンで隠れ状態の平均をとる
        averaged_hidden_state = \
            (last_hidden_state * attention_mask.unsqueeze(-1)).sum(1) \
            / attention_mask.sum(1, keepdim=True)
        
        ### 線形変換
        scores = self.linear(averaged_hidden_state)

        ### 出力形式を整える
        output = {'logits': scores}

        ### labels が入力に含まれていたら、損失を計算し出力
        if labels is not None:
            loss = torch.nn.BCEWithLogitsLoss()(scores, labels.float())
            output['loss'] = loss
        
        ### 属性でアクセスできるようにする
        output = type('bert_output', (object,), output)

        return output

In [5]:
tokenizer = BertJapaneseTokenizer.from_pretrained(JAPANESE_WIKI_MODEL)
bert_scml = BertForSequenceClassificationMultiLabel(JAPANESE_WIKI_MODEL, num_labels=2)
bert_scml = bert_scml.cuda()

Downloading:   0%|          | 0.00/252k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/110 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/479 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/424M [00:00<?, ?B/s]

Some weights of the model checkpoint at cl-tohoku/bert-base-japanese-whole-word-masking were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
text_list = [
    '恵比寿から北千住に向かう。',
    '海老名SAに寄ったが財布を忘れた。'
]

labels_list = [
    [0, 1],
    [1, 1]
]

### 符号化（推論のとき）
encoding = tokenizer(
    text_list, padding='longest', return_tensors='pt'
)
encoding = { k: v.cuda() for k, v in encoding.items() }
labels = torch.tensor(labels_list).cuda()

### BERT に入力して分類スコアを得る
with torch.no_grad():
    output = bert_scml(**encoding)
scores = output.logits

### スコアが正ならば、そのカテゴリーを選択
labels_predicted = ( scores > 0 ).int()

### 精度計算
num_correct = ( labels_predicted == labels ).all(-1).sum().item()
accuracy = num_correct / labels.size(0)
print(f'# accuracy: {accuracy}')

# accuracy: 1.0


In [7]:
### 符号化（学習のとき）
encoding = tokenizer(
    text_list, padding='longest', return_tensors='pt'
)
encoding['labels'] = torch.tensor(labels_list)
encoding = { k: v.cuda() for k, v in encoding.items() }

### BERT に入力して分類スコアを得る
with torch.no_grad():
    output = bert_scml(**encoding)
loss = output.loss
print(loss)

tensor(0.5909, device='cuda:0')


# chABSA-dataset によるマルチラベル文章分類

In [8]:
!wget https://s3-ap-northeast-1.amazonaws.com/dev.tech-sketch.jp/chakki/public/chABSA-dataset.zip
!unzip chABSA-dataset.zip

--2022-03-26 14:06:47--  https://s3-ap-northeast-1.amazonaws.com/dev.tech-sketch.jp/chakki/public/chABSA-dataset.zip
Resolving s3-ap-northeast-1.amazonaws.com (s3-ap-northeast-1.amazonaws.com)... 52.219.4.56
Connecting to s3-ap-northeast-1.amazonaws.com (s3-ap-northeast-1.amazonaws.com)|52.219.4.56|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 722777 (706K) [application/zip]
Saving to: ‘chABSA-dataset.zip’


2022-03-26 14:06:49 (632 KB/s) - ‘chABSA-dataset.zip’ saved [722777/722777]

Archive:  chABSA-dataset.zip
   creating: chABSA-dataset/
  inflating: chABSA-dataset/.DS_Store  
   creating: __MACOSX/
   creating: __MACOSX/chABSA-dataset/
  inflating: __MACOSX/chABSA-dataset/._.DS_Store  
 extracting: chABSA-dataset/.gitkeep  
  inflating: chABSA-dataset/e00008_ann.json  
  inflating: chABSA-dataset/e00017_ann.json  
  inflating: chABSA-dataset/e00024_ann.json  
  inflating: chABSA-dataset/e00026_ann.json  
  inflating: chABSA-dataset/e00030_ann.json  
  in

In [9]:
data = json.load(open('./chABSA-dataset/e00030_ann.json'))
print(data['sentences'][0])

{'sentence_id': 0, 'sentence': '当期におけるわが国経済は、景気は緩やかな回復基調が続き、設備投資の持ち直し等を背景に企業収益は改善しているものの、海外では、資源国等を中心に不透明な状況が続き、為替が急激に変動するなど、依然として先行きが見通せない状況で推移した', 'opinions': [{'target': 'わが国経済', 'category': 'NULL#general', 'polarity': 'neutral', 'from': 6, 'to': 11}, {'target': '景気', 'category': 'NULL#general', 'polarity': 'positive', 'from': 13, 'to': 15}, {'target': '設備投資', 'category': 'NULL#general', 'polarity': 'positive', 'from': 28, 'to': 32}, {'target': '企業収益', 'category': 'NULL#general', 'polarity': 'positive', 'from': 42, 'to': 46}, {'target': '資源国等', 'category': 'NULL#general', 'polarity': 'neutral', 'from': 62, 'to': 66}, {'target': '為替', 'category': 'NULL#general', 'polarity': 'negative', 'from': 80, 'to': 82}]}


In [38]:
category_id = {'negative': 0, 'neutral':1, 'positive': 2}
dataset = []
for file in glob.glob('chABSA-dataset/*.json'):
    data = json.load(open(file))
    for sentence in data['sentences']:
        text = sentence['sentence']
        labels = [0, 0, 0]
        for opinion in sentence['opinions']:
            labels[category_id[opinion['polarity']]] = 1
        sample = {'text': text, 'labels': labels}
        dataset.append(sample)

print(dataset[0])
print(dataset[1])
print(dataset[2])

{'text': '当連結会計年度におけるわが国経済は、政府の経済政策を背景に企業収益や雇用情勢の改善が見られるなど、緩やかな回復基調が続きました', 'labels': [0, 0, 1]}
{'text': '一方、海外においては、中国ならびにアジア新興国経済の減速のほか、米国新政権の政策や英国のＥＵ離脱が経済に与える影響も懸念されるなど、先行き不透明な状況のまま推移いたしました', 'labels': [1, 1, 0]}
{'text': '化学工業におきましては、原油価格の下落に伴い原燃料費用は低下したものの、引き続き厳しい事業環境にありました', 'labels': [1, 0, 1]}


In [44]:
### トークナイザ
tokenizer = BertJapaneseTokenizer.from_pretrained(JAPANESE_WIKI_MODEL)

### データの整形
max_length = 128
dataset_for_loader = []
for sampe in dataset:
    text = sample['text']
    labels = sample['labels']
    encoding = tokenizer(
        text, 
        max_length=max_length, 
        padding='max_length', 
        truncation=True
    )
    encoding['labels'] = labels
    encoding = { k: torch.tensor(v) for k, v in encoding.items() }
    dataset_for_loader.append(encoding)

### train/val/test
random.shuffle(dataset_for_loader)
n = len(dataset_for_loader)
n_train = int(0.6*n)
n_val = int(0.2*n)
dataset_train = dataset_for_loader[:n_train]
dataset_val = dataset_for_loader[n_train:n_train+n_val]
dataset_test = dataset_for_loader[n_train+n_val:]

### データローダ作成
dataloader_train = DataLoader(dataset_train, batch_size=32, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size=256)
dataloader_test = DataLoader(dataset_test, batch_size=256)

In [46]:
class BertForSequenceClassificationMultiLabel_pl(pl.LightningModule):

    def __init__(self, model_name, num_labels, lr):
        super().__init__()
        self.save_hyperparameters()
        self.bert_scml = BertForSequenceClassificationMultiLabel(
            model_name, num_labels=num_labels
        )
    
    def training_step(self, batch, batch_idx):
        output = self.bert_scml(**batch)
        loss = output.loss
        self.log('train_loss', loss)
        return loss
    
    def validation_step(self, batch, batch_idx):
        output = self.bert_scml(**batch)
        val_loss = output.loss
        self.log('val_loss', val_loss)
    
    def test_step(self, batch, batch_idx):
        labels = batch.pop('labels')
        output = self.bert_scml(**batch)
        scores = output.logits
        labels_predicted = ( scores > 0 ).int()
        num_correct = ( labels_predicted == labels ).all(-1).sum().item()
        accuracy = num_correct / scores.size(0)
        self.log('accuracy', accuracy)
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
    

checkpoint = pl.callbacks.ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    save_weights_only=True,
    dirpath='model/',
)

trainer = pl.Trainer(gpus=1, max_epochs=5, callbacks=[checkpoint])

model = BertForSequenceClassificationMultiLabel_pl(JAPANESE_WIKI_MODEL, num_labels=3, lr=1e-5)
trainer.fit(model, dataloader_train, dataloader_val)
test = trainer.test(test_dataloaders=dataloader_test)
print(f'# accuracy: {test[0]["accuracy"]:.2f}')

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
Some weights of the model checkpoint at cl-tohoku/bert-base-japanese-whole-word-masking were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
LOCAL_RANK: 0 - 

Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

  "`trainer.test(test_dataloaders)` is deprecated in v1.4 and will be removed in v1.6."
  f"`.{fn}(ckpt_path=None)` was called without a model."
Restoring states from the checkpoint path at /content/model/epoch=4-step=574.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at /content/model/epoch=4-step=574.ckpt


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'accuracy': 1.0}
--------------------------------------------------------------------------------
# accuracy: 1.00


In [60]:
### 推論
text_list = [
    '恵比寿から北千住に向かう。',
    '海老名SAに寄ったが財布を忘れた。',
    '2020年に改正、公布された改正個人情報保護法が2022年4月から全面施行されます。「個人の権利利益の保護」、「情報活用の強化」、「AI・ビッグデータへの対応」などを目的に改正され、大きく6つの変更ポイントがあります。改正の目的、背景に加え、変更ポイントごとの改正内容とこれまでのルールとの違い、企業が準備しておくべき改正後の影響への対応について具体的に解説します。'
]

best_model_path = checkpoint.best_model_path
model = BertForSequenceClassificationMultiLabel_pl.load_from_checkpoint(best_model_path)
bert_scml = model.bert_scml.cuda()

encoding = tokenizer(text_list, padding='longest', return_tensors='pt')
encoding = { k: v.cuda() for k, v in encoding.items() }

with torch.no_grad():
    output = bert_scml(**encoding)
scores = output.logits
labels_predicted = ( scores > 0 ).int().cpu().numpy().tolist()

for text, label in zip(text_list, labels_predicted):
    print('--')
    print(f'入力 : {text}')
    print(f'出力 : {label}')

Some weights of the model checkpoint at cl-tohoku/bert-base-japanese-whole-word-masking were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


--
入力 : 恵比寿から北千住に向かう。
出力 : [0, 0, 0]
--
入力 : 海老名SAに寄ったが財布を忘れた。
出力 : [0, 0, 0]
--
入力 : 2020年に改正、公布された改正個人情報保護法が2022年4月から全面施行されます。「個人の権利利益の保護」、「情報活用の強化」、「AI・ビッグデータへの対応」などを目的に改正され、大きく6つの変更ポイントがあります。改正の目的、背景に加え、変更ポイントごとの改正内容とこれまでのルールとの違い、企業が準備しておくべき改正後の影響への対応について具体的に解説します。
出力 : [0, 0, 0]
