# このノートブックについて
* OpenPromptを使って、Livedoorニュースコーパスの分類問題を解きます
* [公式のチュートリアルスクリプト](https://github.com/thunlp/OpenPrompt/blob/fe7f4cbb719e796311973c882883773c8306c4b2/tutorial/1.2_soft_verbalizers.py)をベースに、Colab上で動かせるように改変しています
* 学習に時間がかかるため、Colab利用時はランタイムタイプを変更してGPUを有効化することを推奨します

# Google Drive設定
Colab環境でファイルを永続化するため、Google Driveをマウントして保存用ディレクトリを作成

* マウント時に認証を要求されるので許可が必要

In [None]:
# Mount Google Drive at ./gdrive (relative to the current working directory).
# Colab asks for authorization when this cell runs.
# NOTE(review): project_root in the configuration cell is the absolute path
# "/content/gdrive/...", so this assumes the working directory is /content — confirm.
from google.colab import drive
drive.mount('./gdrive')

In [None]:
!mkdir -p gdrive/MyDrive/openprompt/result gdrive/MyDrive/openprompt/models

# 必要なライブラリをインストール
* 注意
    * ColabのPythonバージョンが3.7系のため、OpenPromptの0.1.1までしかインストールできません
    * 自前のJupyter環境で動かす場合は、より新しいバージョンを利用できる可能性があります（ただし、このノートブックのコードがそのまま動作する保証はありません）

In [None]:
# Install OpenPrompt (pinned to 0.1.1) together with the tokenizer, metric and utility packages it is used with here.
!pip install openprompt==0.1.1 \
ja_sentence_segmenter \
'torch>=1.9.0' \
'transformers>=4.10.0' \
sentencepiece==0.1.96 \
'scikit-learn>=0.24.2' \
'tqdm>=4.62.2' \
tensorboardX \
nltk \
yacs \
dill \
datasets \
rouge==1.0.0 \
scipy==1.4.1 \
fugashi \
ipadic \
unidic-lite

In [None]:
# Create the local data directory, then download the Livedoor news corpus archive
# unless a file with that name is already present.
!mkdir -p data
!test -e ldcc-20140209.tar.gz || wget -O ldcc-20140209.tar.gz https://www.rondhuit.com/download/ldcc-20140209.tar.gz

In [None]:
# Draw a random numeric string per run so checkpoint file names do not collide.
import random

# randint needs integer bounds: a float such as 1e10 is deprecated and raises
# TypeError on newer Python versions, so the bound is written as 10**10.
this_run_unicode = str(random.randint(0, 10**10))

# 変数宣言
[公式チュートリアルのスクリプト](https://github.com/thunlp/OpenPrompt/blob/main/tutorial/1.4_soft_template.py)で使われているコマンドライン引数等をここで設定できるようにしてあります。

各変数の説明は、上記スクリプトをご参照ください。

In [None]:
# Seed passed to set_seed and to the few-shot sampler.
seed = 42

# Number of training examples kept per label for few-shot learning
num_examples_per_label = 16 # None uses the full training set

# Prompt-learning parameters
# model = "t5"
# model_name_or_path = "sonoisa/t5-base-japanese"
model = "bert-ja"
model_name_or_path = "cl-tohoku/bert-large-japanese"
# shot is only recorded in the result file name and the run summary.
shot = 1
tune_plm = True
plm_eval_mode = False
max_steps = 1000
eval_every_steps = 100
prompt_lr = 0.3
warmup_step_prompt = 100
optimizer = "Adafactor"
multi_token_handler = "max" # "first" or "mean" or "max"
truncate_method="tail" # "head" or "tail" or "balanced"


# Adjust these to the hardware you are running on
# batchsize_t = 4
# batchsize_e = 4
batchsize_t = 2
batchsize_e = 2
max_seq_l = 512
gradient_accumulation_steps = 8
model_parallelize = False
use_cuda = True

# Output location and file name for the results
project_root = "/content/gdrive/MyDrive/openprompt/"
import datetime
# Timestamp makes the result file name unique per run.
dt_now = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
result_file = f"results/results-{model}-{shot}-{dt_now}.txt"

In [None]:
from openprompt.utils.reproduciblity import set_seed

# Apply the configured seed through OpenPrompt's set_seed helper for reproducibility.
set_seed(seed)

# Livedoorニュースコーパスの準備
* [sonoisa/t5-japanese](https://github.com/sonoisa/t5-japanese)のノートブックを参考に、OpenPrompt用の形式に変換しています

In [None]:
# Genres to classify. The position of a genre in this list becomes its integer
# genre_id label when the corpus is read, so the order must stay stable.
target_genres = ["dokujo-tsushin",
                 "it-life-hack",
                 "kaden-channel",
                 "livedoor-homme",
                 "movie-enter",
                 "peachy",
                 "smax",
                 "sports-watch",
                 "topic-news"]

In [None]:
import tarfile
import re

from ja_sentence_segmenter.normalize.neologd_normalizer import normalize

def remove_brackets(text):
    """Return ``text`` without a 【...】 tag at its very start or very end.

    Tags that appear in the middle of the string are left untouched.
    """
    edge_tag_pattern = r"(^【[^】]*】)|(【[^】]*】$)"
    return re.sub(edge_tag_pattern, "", text)

def normalize_text(text):
    """Normalize a single-line string for the TSV export.

    Tabs become spaces, surrounding whitespace is trimmed, the text is passed
    through ``normalize`` from ja_sentence_segmenter (first result is kept)
    and finally lower-cased. The input must not contain line breaks.
    """
    assert "\n" not in text and "\r" not in text
    trimmed = text.replace("\t", " ").strip()
    normalized = list(normalize(trimmed))[0]
    return normalized.lower()

def read_title_body(file):
    """Read one article from a binary file object and return ``(title, body)``.

    The first two lines are discarded, the third line is the title (edge
    【...】 tags removed) and all remaining lines are joined with spaces to
    form the body. Both values are normalized with ``normalize_text``.
    """
    # Skip the two header lines that precede the title.
    for _ in range(2):
        next(file)

    raw_title = next(file).decode("utf-8").strip()
    title = normalize_text(remove_brackets(raw_title))

    body_lines = [raw.decode("utf-8").strip() for raw in file.readlines()]
    body = normalize_text(" ".join(body_lines))
    return title, body

# Collect the article file names of every target genre, then read each article
# into all_data as {"title", "body", "genre_id"} where genre_id is the index of
# the genre in target_genres.
genre_files_list = [[] for _ in target_genres]

all_data = []

with tarfile.open("ldcc-20140209.tar.gz") as archive_file:
    for archive_item in archive_file:
        item_name = archive_item.name
        if not item_name.endswith(".txt"):
            continue
        # The corpus ships a LICENSE.txt inside each genre directory; it passes
        # the genre/.txt filter but is not an article, so it is skipped.
        if item_name.endswith("LICENSE.txt"):
            continue
        for i, genre in enumerate(target_genres):
            if genre in item_name:
                genre_files_list[i].append(item_name)

    for i, genre_files in enumerate(genre_files_list):
        for name in genre_files:
            file = archive_file.extractfile(name)
            title, body = read_title_body(file)
            # read_title_body already normalizes both fields; the second pass
            # is kept so the generated text stays the same as before.
            title = normalize_text(title)
            body = normalize_text(body)

            # Drop articles whose title or body is empty after normalization.
            if len(title) > 0 and len(body) > 0:
                all_data.append({
                    "title": title,
                    "body": body,
                    "genre_id": i
                    })

In [None]:
import random
from tqdm.notebook import tqdm

# Shuffle in place: all_data was filled genre by genre, and the split below is positional.
random.shuffle(all_data)

def to_line(data):
    """Serialize one article dict as a ``title<TAB>body<TAB>genre_id`` TSV line.

    The returned string ends with a newline. Title and body must be non-empty.
    """
    title, body, genre_id = data["title"], data["body"], data["genre_id"]

    assert len(title) > 0 and len(body) > 0
    return f"{title}\t{body}\t{genre_id}\n"

# Split the shuffled articles positionally into train / dev / test TSV files.
data_size = len(all_data)
train_ratio, dev_ratio, test_ratio = 0.7, 0.15, 0.15

with open("data/train.tsv", "w", encoding="utf-8") as f_train, \
    open("data/dev.tsv", "w", encoding="utf-8") as f_dev, \
    open("data/test.tsv", "w", encoding="utf-8") as f_test:

    # tqdm wraps the list itself (not the enumerate object) so it knows the
    # total and can show real progress.
    for i, data in enumerate(tqdm(all_data)):
        line = to_line(data)
        if i < train_ratio * data_size:
            f_train.write(line)
        elif i < (train_ratio + dev_ratio) * data_size:
            f_dev.write(line)
        else:
            # Everything after the train and dev portions goes to the test set.
            f_test.write(line)

In [None]:
import openprompt.plms as plms
from openprompt.plms.mlm import MLMTokenizerWrapper
from transformers import BertConfig, BertForMaskedLM, BertJapaneseTokenizer

In [None]:
# Register a Japanese BERT entry under the key 'bert-ja' in OpenPrompt's model
# registry so it can be requested by that name.
plms._MODEL_CLASSES['bert-ja'] = plms.ModelClass(
    config=BertConfig,
    tokenizer=BertJapaneseTokenizer,
    model=BertForMaskedLM,
    wrapper=MLMTokenizerWrapper,
)

In [None]:
# Display the registry to check that the 'bert-ja' entry is present.
plms._MODEL_CLASSES

In [None]:
import os
from openprompt.data_utils.data_processor import DataProcessor
from openprompt.data_utils.utils import InputExample

class LivedoorNewsProcessor(DataProcessor):
    """Data processor that reads the Livedoor news TSV splits.

    Each line of ``<split>.tsv`` is ``title<TAB>body<TAB>genre_id``; the title
    becomes ``text_a``, the body ``text_b`` and the genre id the integer label.
    """

    def __init__(self):
        super().__init__()
        # Label names are the genre names; label ids index into this list.
        self.labels = target_genres

    def get_examples(self, data_dir, split):
        """Return the ``InputExample`` list stored in ``data_dir/<split>.tsv``."""
        tsv_path = os.path.join(data_dir, "{}.tsv".format(split))
        with open(tsv_path, "r", encoding='utf8') as tsv_file:
            rows = [raw.strip().split("\t") for raw in tsv_file]

        return [
            InputExample(
                guid=str(row_no),
                text_a=row[0],
                text_b=row[1],
                label=int(row[2]),
            )
            for row_no, row in enumerate(rows)
        ]

In [None]:
# Read the three TSV splits from ./data and fetch the label names.
processor = LivedoorNewsProcessor()
dataset = {
    'train': processor.get_train_examples("./data"),
    'dev': processor.get_dev_examples("./data"),
    'test': processor.get_test_examples("./data"),
}
class_labels = processor.get_labels()

# モデル学習


In [None]:
from openprompt.plms import load_plm
# Load the pretrained model selected in the configuration cell; load_plm returns
# the model, its tokenizer, its config and the tokenizer wrapper class.
plm, tokenizer, model_config, WrapperClass = load_plm(model, model_name_or_path)
# Alternative (T5):
# plm, tokenizer, model_config, WrapperClass = load_plm("t5", "sonoisa/t5-base-japanese")

In [None]:
from openprompt.prompts import ManualTemplate
# Prompt template: title (text_a) and body (text_b) followed by a Japanese cloze
# sentence meaning "The genre of this article is [MASK]."
mytemplate = ManualTemplate(tokenizer=tokenizer, text='{"placeholder":"text_a"} {"placeholder":"text_b"} この記事のジャンルは{"mask"}。')

In [None]:
from openprompt.data_utils.data_sampler import FewShotSampler

# Few-shot setting: when a per-label count is configured, replace the training
# set with a sample drawn using the configured seed. Dev and test sets are not touched.
if num_examples_per_label is not None:
    sampler  = FewShotSampler(num_examples_per_label=num_examples_per_label)
    dataset['train'] = sampler(dataset['train'], seed=seed)

In [None]:
from openprompt import PromptDataLoader

# Training loader: wraps the training examples with the template, using the
# training batch size and shuffling.
train_dataloader = PromptDataLoader(dataset=dataset["train"], template=mytemplate, tokenizer=tokenizer, 
    tokenizer_wrapper_class=WrapperClass, max_seq_length=max_seq_l, decoder_max_length=3, 
    batch_size=batchsize_t,shuffle=True, teacher_forcing=False, predict_eos_token=False,
    truncate_method=truncate_method)

In [None]:
# Validation loader: same settings as the training loader but with the
# evaluation batch size and no shuffling.
validation_dataloader = PromptDataLoader(dataset=dataset["dev"], template=mytemplate, tokenizer=tokenizer, 
    tokenizer_wrapper_class=WrapperClass, max_seq_length=max_seq_l, decoder_max_length=3, 
    batch_size=batchsize_e,shuffle=False, teacher_forcing=False, predict_eos_token=False,
    truncate_method=truncate_method)

In [None]:
# Test loader: evaluation batch size, no shuffling.
test_dataloader = PromptDataLoader(dataset=dataset["test"], template=mytemplate, tokenizer=tokenizer, 
    tokenizer_wrapper_class=WrapperClass, max_seq_length=max_seq_l, decoder_max_length=3, 
    batch_size=batchsize_e,shuffle=False, teacher_forcing=False, predict_eos_token=False,
    truncate_method=truncate_method)

In [None]:
from openprompt.prompts import SoftVerbalizer
import torch

# TODO: passing label_words does not work properly
# https://github.com/thunlp/OpenPrompt/issues/78
# label_words = [
#                 "独女",
#                 "ライフハック",
#                 "家電",
#                 "メンズ",
#                 "映画",
#                 "恋愛",
#                 "ガジェット",
#                 "スポーツ",
#                 "ニュース"
# ]
# myverbalizer = SoftVerbalizer(tokenizer=tokenizer, plm=plm, label_words=label_words, classes=target_genres, num_classes=9,multi_token_handler="max")

# Soft verbalizer without label words. The class count is derived from
# target_genres so it cannot drift from the label list.
myverbalizer = SoftVerbalizer(tokenizer=tokenizer, plm=plm, classes=target_genres, num_classes=len(target_genres), multi_token_handler=multi_token_handler)

In [None]:
from sklearn.metrics import confusion_matrix, classification_report

def evaluate(prompt_model, dataloader, desc):
    """Evaluate ``prompt_model`` on ``dataloader`` and return its accuracy.

    Prints the confusion matrix and the classification report as a side effect.
    The model is switched to eval mode and left there; the caller has to call
    ``train()`` again before continuing training.

    Args:
        prompt_model: Callable model that maps a batch to class logits.
        dataloader: Iterable of batches; each batch provides ``batch['label']``.
        desc: Label for the evaluation (e.g. "Valid"); currently not used.

    Returns:
        Fraction of examples whose arg-max prediction equals the label.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no examples.
    """
    prompt_model.eval()
    allpreds = []
    alllabels = []

    # Inference only: disabling autograd avoids building graphs for every batch,
    # which saves memory and time.
    with torch.no_grad():
        for inputs in dataloader:
            if use_cuda:
                inputs = inputs.cuda()
            logits = prompt_model(inputs)
            labels = inputs['label']
            alllabels.extend(labels.cpu().tolist())
            allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())
    acc = sum(int(pred == gold) for pred, gold in zip(allpreds, alllabels)) / len(allpreds)

    print(confusion_matrix(alllabels, allpreds))
    print(classification_report(alllabels, allpreds))
    return acc

In [None]:
from openprompt import PromptForClassification
# Combine the PLM, template and verbalizer into the classification model.
# The PLM is frozen exactly when tune_plm is False.
myPromptModel = PromptForClassification(plm=plm,template=mytemplate, verbalizer=myverbalizer, freeze_plm=(not tune_plm), plm_eval_mode=plm_eval_mode)
if use_cuda:
    myPromptModel=  myPromptModel.cuda()
if model_parallelize:
    myPromptModel.parallelize()

In [None]:
from transformers import  AdamW, get_linear_schedule_with_warmup,get_constant_schedule_with_warmup  # use AdamW is a standard practice for transformer 
from transformers.optimization import Adafactor, AdafactorSchedule  # use Adafactor is the default setting for T5

# Loss function and the optimizer/scheduler pair for the PLM parameters.
# tot_step is the number of steps the linear schedule is configured for.
loss_func = torch.nn.CrossEntropyLoss()
tot_step = max_steps

if tune_plm: # normally we freeze the model when using soft_template. However, we keep the option to tune plm
    no_decay = ['bias', 'LayerNorm.weight'] # it's always good practice to set no decay to biase and LayerNorm parameters
    # Two groups: weight decay 0.01 for regular weights, none for bias/LayerNorm.
    optimizer_grouped_parameters1 = [
        {'params': [p for n, p in myPromptModel.plm.named_parameters() if (not any(nd in n for nd in no_decay))], 'weight_decay': 0.01},
        {'params': [p for n, p in myPromptModel.plm.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer1 = AdamW(optimizer_grouped_parameters1, lr=3e-5)
    scheduler1 = get_linear_schedule_with_warmup(
        optimizer1, 
        num_warmup_steps=warmup_step_prompt, num_training_steps=tot_step)
else:
    # PLM is frozen: the training loop skips optimizer1/scheduler1 when they are None.
    optimizer1 = None
    scheduler1 = None

In [None]:
# Second optimizer/scheduler pair: collects the template's parameters (all except raw_embedding).
# NOTE(review): the template built earlier is a ManualTemplate, which presumably has no
# trainable parameters, and the SoftVerbalizer's parameters are not referenced explicitly
# by either optimizer's parameter groups — confirm whether the verbalizer head is meant
# to be trained (the upstream soft-verbalizer tutorial passes it to an optimizer).
# NOTE(review): optimizer2/scheduler2 stay undefined if `optimizer` is neither
# "adafactor" nor "adamw".
optimizer_grouped_parameters2 = [{'params': [p for name, p in myPromptModel.template.named_parameters() if 'raw_embedding' not in name]}] # note that you have to remove the raw_embedding manually from the optimization
if optimizer.lower() == "adafactor":
    optimizer2 = Adafactor(optimizer_grouped_parameters2,  
                            lr=prompt_lr,
                            relative_step=False,
                            scale_parameter=False,
                            warmup_init=False)  # when lr is 0.3, it is the same as the configuration of https://arxiv.org/abs/2104.08691
    scheduler2 = get_constant_schedule_with_warmup(optimizer2, num_warmup_steps=warmup_step_prompt) # when num_warmup_steps is 0, it is the same as the configuration of https://arxiv.org/abs/2104.08691
elif optimizer.lower() == "adamw":
    optimizer2 = AdamW(optimizer_grouped_parameters2, lr=prompt_lr) # usually lr = 0.5
    scheduler2 = get_linear_schedule_with_warmup(
                    optimizer2, 
                    num_warmup_steps=warmup_step_prompt, num_training_steps=tot_step) # usually num_warmup_steps is 500

In [None]:
# Build the run summary that is later appended to the result file:
# a separator line, then two rows of tab-terminated "name:value" pairs.
model_settings = (
    ("model", model),
    ("model_name_or_path", model_name_or_path),
    ("seed", seed),
    ("shot", shot),
    ("num_examples_per_label", num_examples_per_label),
    ("plm_eval_mode", plm_eval_mode),
    ("eval_every_steps", eval_every_steps),
    ("warmup_step_prompt", warmup_step_prompt),
    ("prompt_lr", prompt_lr),
    ("optimizer", optimizer),
    ("multi_token_handler", multi_token_handler),
)
runtime_settings = (
    ("batchsize_t", batchsize_t),
    ("batchsize_e", batchsize_e),
    ("max_seq_l", max_seq_l),
    ("gradient_accumulation_steps", gradient_accumulation_steps),
    ("model_parallelize", model_parallelize),
    ("use_cuda", use_cuda),
)

content_write = "="*20+"\n"
content_write += "".join(f"{key}:{value}\t" for key, value in model_settings)
content_write += "\n"
content_write += "".join(f"{key}:{value}\t" for key, value in runtime_settings)

print(content_write)

In [None]:
# Mutable state for the training loop.
tot_loss = 0  # running sum of batch losses
log_loss = 0  # value of tot_loss at the last progress-bar update
best_val_acc = 0  # best validation accuracy seen so far
glb_step = 0  # counter incremented once per gradient_accumulation_steps batches
actual_step = 0  # number of batches processed
leave_training = False  # set to True to break out of the epoch loop

acc_traces = []  # validation accuracy of every evaluation, in order
tot_train_time = 0  # accumulated wall-clock training time in seconds
pbar_update_freq = 10  # refresh the progress bar every this many global steps
myPromptModel.train()

In [None]:
import time

# Train for max_steps optimizer updates. One update ("global step") is taken
# every gradient_accumulation_steps batches; validation runs every
# eval_every_steps updates and the best model so far is checkpointed.
pbar = tqdm(total=tot_step, desc="Train")
for epoch in range(1000000):
    pbar.set_description(f"Train[Epoch {epoch}]")
    for step, inputs in enumerate(train_dataloader):
        if use_cuda:
            inputs = inputs.cuda()
        tot_train_time -= time.time()
        logits = myPromptModel(inputs)
        labels = inputs['label']
        loss = loss_func(logits, labels)
        # Scale so the accumulated gradient is the mean over the effective batch.
        (loss / gradient_accumulation_steps).backward()
        tot_loss += loss.item()
        actual_step += 1

        is_update_step = actual_step % gradient_accumulation_steps == 0
        if is_update_step:
            torch.nn.utils.clip_grad_norm_(myPromptModel.parameters(), 1.0)

            # Optimizers and schedulers advance only on update steps. Stepping
            # them on every batch would discard the accumulation and exhaust
            # the learning-rate schedule (sized for max_steps updates) early.
            for opt, sched in ((optimizer1, scheduler1), (optimizer2, scheduler2)):
                if opt is not None:
                    opt.step()
                    opt.zero_grad()
                if sched is not None:
                    sched.step()

            glb_step += 1
            if glb_step % pbar_update_freq == 0:
                # tot_loss sums one loss per batch, so average over batches.
                aveloss = (tot_loss - log_loss) / (pbar_update_freq * gradient_accumulation_steps)
                pbar.update(pbar_update_freq)
                pbar.set_postfix({'loss': aveloss})
                log_loss = tot_loss

        tot_train_time += time.time()

        if is_update_step and glb_step > 0 and glb_step % eval_every_steps == 0:
            val_acc = evaluate(myPromptModel, validation_dataloader, desc="Valid")
            # ">=" keeps the most recent checkpoint among equally good ones.
            if val_acc >= best_val_acc:
                torch.save(myPromptModel.state_dict(),f"{project_root}models/{this_run_unicode}.ckpt")
                best_val_acc = val_acc

            acc_traces.append(val_acc)
            print("Glb_step {}, val_acc {}, average time {}".format(glb_step, val_acc, tot_train_time/actual_step ), flush=True)
            # evaluate() leaves the model in eval mode.
            myPromptModel.train()

        # Stop after exactly max_steps updates (matches the progress-bar total).
        if glb_step >= max_steps:
            leave_training = True
            break

    if leave_training:
        break

pbar.close()

In [None]:
# Restore the best checkpoint written during training (same path as the save
# call: project_root already ends with "/") and evaluate it on the test set.
myPromptModel.load_state_dict(torch.load(f"{project_root}models/{this_run_unicode}.ckpt"))
test_acc = evaluate(myPromptModel, test_dataloader, desc="Test")

In [None]:
# a simple measure for the convergence speed.
thres99 = 0.99*best_val_acc
thres98 = 0.98*best_val_acc
thres100 = best_val_acc
step100=step98=step99=max_steps
for val_time, acc in enumerate(acc_traces):
    if acc>=thres98:
        step98 = min(val_time*eval_every_steps, step98)
        if acc>=thres99:
            step99 = min(val_time*eval_every_steps, step99)
            if acc>=thres100:
                step100 = min(val_time*eval_every_steps, step100)

In [None]:
# Append the final metrics to the run summary, print it and append it to the
# result file on Drive.
import os

# acc_traces is empty when training ended before the first evaluation.
end_val_acc = acc_traces[-1] if acc_traces else None
content_write += f"BestValAcc:{best_val_acc}\tEndValAcc:{end_val_acc}\tcritical_steps:{[step98,step99,step100]}\n"
content_write += f"testAcc:{test_acc}\n"
content_write += "\n"

print(content_write)

result_path = f"{project_root}{result_file}"
# result_file lives in a sub-folder; create it so the write cannot fail on a
# missing directory.
os.makedirs(os.path.dirname(result_path), exist_ok=True)
with open(result_path, "a", encoding="utf-8") as fout:
    fout.write(content_write)