<a href="https://colab.research.google.com/github/hululuzhu/chinese-ai-writing-share/blob/main/training/t5_finetune/Mengzi_T5_Finetune_Chinese_Couplet_V1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# T5 Couplets
- Design: pretrained T5 + fine-tuning with a "couplet" prompt (`对联：`)
  - Compare with my [transformer trained from scratch](https://github.com/hululuzhu/chinese-ai-writing-share/blob/main/%E4%B8%AD%E6%96%87%E5%AF%B9%E8%81%94Transformer_Source_Code_V1.ipynb)
- Data: [couplet dataset on GitHub](https://github.com/wb14123/couplet-dataset)
- Related links
  - [Huggingface](https://huggingface.co/)
  - Langboat's Chinese [Mengzi T5 pretrained model](https://huggingface.co/Langboat/mengzi-t5-base) and [paper](https://arxiv.org/pdf/2110.06696.pdf)
  - [SimpleT5 by Shivanandroy](https://github.com/Shivanandroy/simpleT5) (built on PyTorch and PyTorch Lightning) and [his awesome Medium article](https://medium.com/geekculture/simplet5-train-t5-models-in-just-3-lines-of-code-by-shivanand-roy-2021-354df5ae46ba)
- Progress
  - 02/2022: code draft; model testing/training in progress

In [1]:
# Quick-test mode: when True, train on 5k samples instead of the full ~800k
IS_QUICK_TEST = True  #@param {type:"boolean"}

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Prepare Data


In [3]:
import os
import pathlib
import numpy as np
import pandas as pd
import pickle

In [4]:
working_dir = "/tmp/working_dir"
!mkdir -p {working_dir}
!wget https://github.com/wb14123/couplet-dataset/releases/download/1.0/couplet.tar.gz -P {working_dir}
!ls -l {working_dir}

--2022-02-07 02:02:03--  https://github.com/wb14123/couplet-dataset/releases/download/1.0/couplet.tar.gz
Resolving github.com (github.com)... 52.192.72.89
Connecting to github.com (github.com)|52.192.72.89|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/122695108/9643dda6-194e-11e8-9642-44c7d57d40ac?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20220207%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20220207T020203Z&X-Amz-Expires=300&X-Amz-Signature=fc60a3a41e08fdcbebd4bc384f5ec38ce188f9300fbfc31f9be875a91842b20d&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=122695108&response-content-disposition=attachment%3B%20filename%3Dcouplet.tar.gz&response-content-type=application%2Foctet-stream [following]
--2022-02-07 02:02:03--  https://objects.githubusercontent.com/github-production-release-asset-2e65be/122695108/9643dda6-194e-11e8-9642-44c7d57d40ac?X-Amz-

In [5]:
!mkdir -p {working_dir}/couplet_files
!tar -xf {working_dir}/couplet.tar.gz -C {working_dir}/couplet_files

In [6]:
!head -1 {working_dir}/couplet_files/couplet/train/in.txt {working_dir}/couplet_files/couplet/train/out.txt

==> /tmp/working_dir/couplet_files/couplet/train/in.txt <==
晚 风 摇 树 树 还 挺 

==> /tmp/working_dir/couplet_files/couplet/train/out.txt <==
晨 露 润 花 花 更 红 


In [7]:
COUPLET_PATH = f'{working_dir}/couplet_files/couplet'
MAX_SEQ_LEN = 32  # Keep at most 32 Chinese characters per line, punctuation included

train_df, test_df = None, None
for t in ['train', 'test']:
  ins, outs = [], []
  for i in ['in', 'out']:
    # Read as UTF-8 explicitly so the result does not depend on the platform default
    with open(f"{COUPLET_PATH}/{t}/{i}.txt", "r", encoding="utf-8") as f:
      for line in f:
        clean_line = line.strip().replace(' ', '').replace('\n', '').replace('\r', '')[:MAX_SEQ_LEN]
        if i=='in':
          ins.append(clean_line)
        else:
          outs.append(clean_line)
  # Column names expected by SimpleT5
  data_dict = {
      'source_text': ins,
      'target_text': outs,
  }
  if t == 'train':
    train_df = pd.DataFrame(data_dict)
  else:
    test_df = pd.DataFrame(data_dict)
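
The per-line cleaning inside the loop above can be pulled into a small helper, which makes it easy to try on a single line. This is a minimal sketch rather than the notebook's code: `clean_couplet_line` is a hypothetical name, and it assumes the same rules as the loop (drop the spaces between characters and the line ending, then truncate to `MAX_SEQ_LEN`). Unlike the original `replace(' ', '')`, it removes every kind of whitespace, not only ASCII spaces.

```python
MAX_SEQ_LEN = 32  # same cap as in the notebook


def clean_couplet_line(line: str, max_len: int = MAX_SEQ_LEN) -> str:
    """Remove inter-character spaces and line endings, then truncate."""
    # str.split() with no arguments splits on any run of whitespace, so
    # joining the pieces drops spaces, '\n' and '\r' in one pass.
    return "".join(line.split())[:max_len]


print(clean_couplet_line("晚 风 摇 树 树 还 挺 \n"))  # 晚风摇树树还挺
```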

In [8]:
COUPLET_PROMPOT = '对联：'
train_df['source_text'] = COUPLET_PROMPOT + train_df['source_text']
test_df['source_text'] = COUPLET_PROMPOT + test_df['source_text']

In [9]:
MAX_IN_TOKENS = len(COUPLET_PROMPOT) + MAX_SEQ_LEN
MAX_OUT_TOKENS = MAX_SEQ_LEN

In [10]:
# Check that every source (minus the prompt) has the same length as its target;
# print the first mismatch found in each DataFrame
size_diff = len(COUPLET_PROMPOT)
for df in [train_df, test_df]:
  for i in range(len(df)):
    if len(df['source_text'].values[i]) != len(df['target_text'].values[i]) + size_diff:
      print("mismatch found:", df['source_text'].values[i], df['target_text'].values[i])
      break
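
The loop above stops at the first mismatch in each DataFrame. If a full list of offending rows is wanted, the same check can be written over plain (source, target) pairs. This is a sketch under assumptions: `find_length_mismatches` is a hypothetical helper, and the second pair below is deliberately broken for illustration and is not a row from the dataset.

```python
def find_length_mismatches(pairs, prompt):
    """Return (index, source, target) for every pair whose lengths differ."""
    mismatches = []
    for idx, (source, target) in enumerate(pairs):
        # Drop the prompt so only the first line of the couplet is measured.
        first_line = source[len(prompt):] if source.startswith(prompt) else source
        if len(first_line) != len(target):
            mismatches.append((idx, source, target))
    return mismatches


pairs = [
    ("对联：万户银河火", "千山画海花"),
    ("对联：猛志固常在", "小儒安足"),  # made-up broken pair: target is one character short
]
print(find_length_mismatches(pairs, "对联："))
# [(1, '对联：猛志固常在', '小儒安足')]
```

In the notebook this could be called as `find_length_mismatches(zip(train_df['source_text'], train_df['target_text']), COUPLET_PROMPOT)`.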

In [11]:
train_df[1000:1010]

|  | source_text | target_text |
|---|---|---|
| 1000 | 对联：昔日斯人尘绝去 | 何时雁侣梦归来 |
| 1001 | 对联：万户银河火 | 千山画海花 |
| 1002 | 对联：魁星点斗浴文光，陡生凤翼 | 大志干霄增笔力，独占鳌头 |
| 1003 | 对联：德颂巩义，民风淳正铸文明 | 道法自然，社会和谐享太平 |
| 1004 | 对联：中华儿女歌孙氏 | 世界人民仰泰山 |
| 1005 | 对联：智者虚怀常俯首 | 强人硬骨不屈膝 |
| 1006 | 对联：遍啸江湖，一腔热血酬山海 | 从来规矩，代数几何圆角锥 |
| 1007 | 对联：来横山高处，饱览云涛，松风舒朗抱，且悠游自在乾坤，清凉世界 | 待胜地佳时，静聆天籁，星汉洗尘心，漫领略摩诘意趣，和仲情怀 |
| 1008 | 对联：百战忠魂，千秋恨事 | 一朝义愤，万古馨香 |
| 1009 | 对联：猛志固常在 | 小儒安足为 |


In [12]:
test_df[1000:1010]

|  | source_text | target_text |
|---|---|---|
| 1000 | 对联：没穷亲友往来，其家肯定势利 | 无正经人交接，这个必是奸邪 |
| 1001 | 对联：仁里胪欢，有脚阳春来大地 | 德林成荫，无声雨露润圆山 |
| 1002 | 对联：雅士云亡，谁共青灯说禅论正道 | 哲人其萎，何事宝筏登岸完大觉 |
| 1003 | 对联：悟空心似镜 | 知白意如风 |
| 1004 | 对联：书法工神王大令 | 风流闲雅谢临川 |
| 1005 | 对联：山门外三脚驴子 | 蒲团上一块兜楼 |
| 1006 | 对联：山种高梧，彩凤来仪，广纳良才谋福祉 | 金襄丕业，青龙起舞，勤施善策乐民生 |
| 1007 | 对联：蓝梦蓝图现实见 | 红尘红颜空中归 |
| 1008 | 对联：灵蛇出洞千山秀 | 紫燕归巢万木春 |
| 1009 | 对联：南山种豆望明月 | 北海牧羊思故乡 |


## Prepare Model

In [13]:
# !nvidia-smi  # Check the GPU; a P100 (16 GB) takes about 100 minutes per epoch, similar to a 1080

In [14]:
# Quietly install the simpleT5 package
!pip install -q simplet5 &> /dev/null

In [15]:
import torch
from simplet5 import SimpleT5
from transformers import T5Tokenizer, T5ForConditionalGeneration

Global seed set to 42


In [16]:
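class MengziSimpleT5(SimpleT5):
  """SimpleT5 subclass that loads the Langboat Mengzi T5 checkpoint."""

  def __init__(self) -> None:
    super().__init__()
    # Default device; load_my_model() sets it again based on use_gpu
    self.device = torch.device("cuda")

  def load_my_model(self, use_gpu: bool = True):
    """Load the Mengzi tokenizer and model, then move the model to the chosen device."""
    name = "Langboat/mengzi-t5-base"
    self.tokenizer = T5Tokenizer.from_pretrained(name)
    self.model = T5ForConditionalGeneration.from_pretrained(name)
    # Honor use_gpu; fall back to CPU when CUDA is unavailable
    self.device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")
    self.model = self.model.to(self.device)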

In [17]:
model = MengziSimpleT5()
model.load_my_model()
model.model = model.model.to('cuda')

Downloading:   0%|          | 0.00/725k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/659 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/990M [00:00<?, ?B/s]

In [18]:
model.tokenizer("回答：天上有没有云彩？")

{'input_ids': [1347, 13, 7995, 2205, 355, 1973, 17, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}

In [19]:
model.tokenizer.decode([1347, 13, 7995, 2205, 355, 1973, 17, 1])

'回答:天上有没有云彩?</s>'
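
Note that the round trip above does not return the input verbatim: the full-width `：` and `？` come back as ASCII `:` and `?`. When predictions are later compared with targets from the dataset, it may help to normalize both sides first. The sketch below uses Unicode NFKC as a stand-in; that the tokenizer's own normalization matches NFKC in general is an assumption, checked here only against this one example.

```python
import unicodedata


def normalize_punct(text: str) -> str:
    """Fold full-width punctuation to its ASCII form via NFKC."""
    return unicodedata.normalize("NFKC", text)


original = "回答：天上有没有云彩？"
decoded = "回答:天上有没有云彩?"  # decode() output from above, without the </s> marker

print(original == decoded)                   # False
print(normalize_punct(original) == decoded)  # True
```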

In [29]:
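def predict_now(in_str, model=model):
  """Generate the second line of a couplet for `in_str` using greedy decoding."""
  model.model = model.model.to('cuda')
  in_request = f"{COUPLET_PROMPOT}{in_str[:MAX_SEQ_LEN]}"
  return model.predict(
      in_request,
      # Cap the output length at one less than the character count of the first line
      max_length=min(MAX_OUT_TOKENS, len(in_request) - len(COUPLET_PROMPOT) - 1),
      # num_beams=1 with do_sample=False means greedy decoding,
      # so top_p and top_k have no effect here
      num_beams=1,
      top_p=1.0,
      top_k=1,
      do_sample=False)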

# predict_now("灵蛇出洞千山秀")
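
The `max_length` passed in `predict_now` is derived from the character count of the first line. Assuming Hugging Face `generate` counts the decoder start token toward `max_length` for encoder-decoder models such as T5 (my understanding, not something the notebook states), a budget of `n - 1` leaves room for at most `n - 2` generated tokens. The sketch below compares that budget with a more generous one that reserves a slot for the start token and one for `</s>`; both helper names are hypothetical, and the one-token-per-character worst case is an assumption.

```python
MAX_SEQ_LEN = 32  # same cap as in the notebook


def notebook_max_length(first_line: str, cap: int = MAX_SEQ_LEN) -> int:
    """Budget as computed in predict_now: character count minus one."""
    return min(cap, len(first_line[:MAX_SEQ_LEN]) - 1)


def padded_max_length(first_line: str, cap: int = MAX_SEQ_LEN + 2) -> int:
    """Alternative budget: one token per character, plus start token and </s>."""
    return min(cap, len(first_line[:MAX_SEQ_LEN]) + 2)


line = "灵蛇出洞千山秀"  # 7 characters
print(notebook_max_length(line))  # 6
print(padded_max_length(line))    # 9
```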

## Training

In [21]:
model.train(train_df=train_df if not IS_QUICK_TEST else train_df[:5000],
            eval_df=test_df, 
            source_max_token_len=MAX_IN_TOKENS, 
            target_max_token_len=MAX_OUT_TOKENS, 
            batch_size=64,
            max_epochs=3,
            use_gpu=True,
            outputdir="/content/drive/MyDrive/ML/Models/t5-couplet")

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                       | Params
-----------------------------------------------------
0 | model | T5ForConditionalGeneration | 247 M 
-----------------------------------------------------
247 M     Trainable params
0         Non-trainable params
247 M     Total params
990.311   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 42


Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]
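
For a sense of scale, the number of optimizer steps follows from the sample count and the batch size. This is a rough sketch: it assumes the last partial batch is kept (so the count rounds up), takes the ~800k figure from the comment in the first cell, and assumes the 100 minutes per epoch noted earlier refers to the full dataset at this batch size.

```python
import math


def steps_per_epoch(num_samples: int, batch_size: int) -> int:
    """Number of batches per epoch when the final partial batch is kept."""
    return math.ceil(num_samples / batch_size)


quick_steps = steps_per_epoch(5_000, 64)
print(quick_steps)      # 79 steps per epoch in quick-test mode
print(quick_steps * 3)  # 237 steps over max_epochs=3

full_steps = steps_per_epoch(800_000, 64)
print(full_steps)                  # 12500 steps per epoch on ~800k samples
print(100 * 60 / full_steps)       # 0.48 seconds per step at 100 min/epoch
```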

In [30]:
predict_now("灵蛇出洞千山秀")

['灵蛇出山万']
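
A couplet's second line should match the first in length and in where punctuation falls, so a quick structural check catches outputs like the one above without any human judgment. This is a minimal sketch: `check_second_line` and the punctuation set are hypothetical choices, and the check says nothing about tone patterns or meaning.

```python
PUNCTUATION = set("，。、；：？！,.;:?!")


def check_second_line(first: str, second: str) -> list:
    """Return a list of structural problems; an empty list means none found."""
    problems = []
    if len(first) != len(second):
        problems.append(f"length {len(second)} != {len(first)}")
    # Punctuation should sit at the same positions in both lines.
    for pos, (a, b) in enumerate(zip(first, second)):
        if (a in PUNCTUATION) != (b in PUNCTUATION):
            problems.append(f"punctuation mismatch at position {pos}")
    return problems


print(check_second_line("灵蛇出洞千山秀", "灵蛇出山万"))      # ['length 5 != 7']
print(check_second_line("灵蛇出洞千山秀", "紫燕归巢万木春"))  # []
```

The first call uses the quick-test prediction shown above; the second uses the reference target for the same input from `test_df`.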