<a href="https://colab.research.google.com/github/hululuzhu/chinese-ai-writing-share/blob/main/training/t5_finetune/Mengzi_T5_Finetune_Chinese_Poem_Writing_V1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# T5 Poem Writing
- Design: pretrained T5 + fine-tuning on a "write a poem" prompt
  - Compare with my [transformer training from scratch](https://github.com/hululuzhu/chinese-ai-writing-share/blob/main/%E4%B8%AD%E6%96%87%E5%86%99%E8%AF%97Transformer_Source_Code_Share_V1.ipynb)
  - Goal: accept the author as an optional input
    - Each poem is fed in twice: once with the author's name and once with a "None" author (the generic case)
- Data: [chinese-poetry on GitHub](https://github.com/chinese-poetry/chinese-poetry)
- Related resources
  - [Huggingface](https://huggingface.co/)
  - Langboat Chinese [Mengzi T5 pretrained model](https://huggingface.co/Langboat/mengzi-t5-base) and [paper](https://arxiv.org/pdf/2110.06696.pdf)
  - [SimpleT5 by Shivanandroy](https://github.com/Shivanandroy/simpleT5) (built on PyTorch and PyTorch Lightning) and [his awesome Medium article](https://medium.com/geekculture/simplet5-train-t5-models-in-just-3-lines-of-code-by-shivanand-roy-2021-354df5ae46ba)
- Progress
  - 02/2022, code drafting

## Load Data

In [1]:
!nvidia-smi

Mon Feb  7 22:10:16 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0    27W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
IS_TEST_FLOW = False  #@param {type: "boolean"}

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
import json
import urllib.request
import pandas as pd
!pip install -q "tqdm>=4.36.1" > /tmp/na
from tqdm.notebook import tqdm
!pip install -q chinese-converter > /tmp/na
import chinese_converter  # needed for Traditional-to-Simplified Chinese conversion
import pickle
import os
import numpy as np

In [5]:
# https://github.com/chinese-poetry/chinese-poetry
POEM_CONTENT = {
    'tang': {
        'total': 58,
        'pattern': "https://raw.githubusercontent.com/chinese-poetry/chinese-poetry/master/json/poet.tang.{0}.json"
    },
    'song': {
        'total': 255,
        'pattern': "https://raw.githubusercontent.com/chinese-poetry/chinese-poetry/master/json/poet.song.{0}.json"
    }
}

def get_poems(is_test=True, verbose=True):
  """Download the JSON shards of each dynasty (only 3 per dynasty in test mode) into one DataFrame."""
  df_list = []
  for dynasty in POEM_CONTENT:
    size = 3 if is_test else POEM_CONTENT[dynasty]['total']
    pbar = tqdm(total=size, desc="Dynasty " + dynasty)
    for i in range(size):
      # Shards are named by their starting offset: 0, 1000, 2000, ...
      url = POEM_CONTENT[dynasty]['pattern'].format(i * 1000)
      if verbose:
        print(f"download {url} now")
      df_list.append(pd.read_json(url))
      pbar.update(1)
    pbar.close()
  return pd.concat(df_list)
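
The shard naming used by `get_poems` can be checked without downloading anything. The sketch below is a minimal illustration, not part of the original notebook: `shard_urls` is a hypothetical helper, and it only reproduces the `pattern.format(i * 1000)` step for the Tang pattern shown above.

```python
# Hypothetical helper (not in the notebook): list the shard URLs for one dynasty.
TANG_PATTERN = (
    "https://raw.githubusercontent.com/chinese-poetry/chinese-poetry/"
    "master/json/poet.tang.{0}.json"
)

def shard_urls(pattern, total):
    """Return one URL per shard; shard i is named by its starting offset i * 1000."""
    return [pattern.format(i * 1000) for i in range(total)]

# Test mode in the notebook uses 3 shards per dynasty.
urls = shard_urls(TANG_PATTERN, 3)
for url in urls:
    print(url)
```

Printing the list before a full run is a cheap way to confirm that the pattern and the `total` count line up with the files in the repository.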

In [6]:
df = get_poems(is_test=IS_TEST_FLOW, verbose=False)
df['concat_paragraphs'] = [''.join(map(str, l)) for l in df['paragraphs']]
df = df[['author', 'title', 'concat_paragraphs']]

def convert_schinese(tchinese):
  return chinese_converter.to_simplified(tchinese)

df['s_content'] = df.apply(lambda row: convert_schinese(''.join(row.concat_paragraphs)), axis=1)
df['s_title'] = df.apply(lambda row: convert_schinese(''.join(row.title)), axis=1)
df['s_author'] = df.apply(lambda row: convert_schinese(''.join(row.author)), axis=1)

my_df = df
print("my_df size", len(my_df))

Dynasty tang:   0%|          | 0/58 [00:00<?, ?it/s]

Dynasty song:   0%|          | 0/255 [00:00<?, ?it/s]

my_df size 311855


In [7]:
MAX_AUTHOR_CHAR = 4
MAX_TITLE_CHAR = 12
MIN_CONTENT_CHAR = 10
MAX_CONTENT_CHAR = 64

def trim_author_fn(row):
  return row.s_author[:MAX_AUTHOR_CHAR]

def trim_title_fn(row):
  # Truncate first, then drop spaces and ASCII parentheses
  trimmed_title = row.s_title[:MAX_TITLE_CHAR].replace(" ", "").replace("(", "").replace(")", "")
  return trimmed_title

def trim_content_fn(row):
  trimmed_content = row.s_content[:MAX_CONTENT_CHAR]
  # Disabled: cut at the last full stop so a partial final line does not confuse the model
  # last_period = trimmed_content.rfind("。")
  # return trimmed_content[:last_period+1]
  return trimmed_content

# Trim each field to its maximum length; apply() runs on a copy to avoid the view/copy conflict warning
my_df['s_author_trim'] = my_df.copy().apply(trim_author_fn, axis=1)
my_df['s_title_trim'] = my_df.copy().apply(trim_title_fn, axis=1)
my_df['s_content_trim'] = my_df.copy().apply(trim_content_fn, axis=1)
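
To make the trimming rules concrete, here is a small plain-string sketch of the same logic. The function names `trim_title` and `trim_content` and the sample strings are made up for illustration; the constants mirror the cell above. Note that the title is truncated before spaces and ASCII parentheses are removed, so the result can be shorter than `MAX_TITLE_CHAR`, and full-width parentheses are left untouched.

```python
MAX_TITLE_CHAR = 12
MAX_CONTENT_CHAR = 64

def trim_title(title):
    # Same order as trim_title_fn: truncate first, then strip spaces and ASCII parentheses.
    return title[:MAX_TITLE_CHAR].replace(" ", "").replace("(", "").replace(")", "")

def trim_content(content):
    # Hard cut at the character limit; the poem may end mid-line.
    return content[:MAX_CONTENT_CHAR]

# Made-up title: the first 12 characters include spaces and parentheses.
sample_title = "送 客 亭 (其一) 寄友人"
trimmed_title = trim_title(sample_title)
trimmed_content = trim_content("春" * 70)

print(trimmed_title, len(trimmed_title))
print(len(trimmed_content))
```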

In [8]:
# Drop rows with an empty title, too-short content, or the "无正文" (no body text) placeholder
empty_title_mask = (my_df['s_title_trim'].str.len() == 0)
too_short_content_mask = (my_df['s_content_trim'].str.len() <= MIN_CONTENT_CHAR)
invalid_mask = (('无正文' == my_df['s_content_trim']) | ('无正文' == my_df['s_author_trim']))
too_short_mask =  empty_title_mask | too_short_content_mask | invalid_mask
# filtered_my_df = my_df.loc[too_short_mask]
# filtered_my_df

qualitied_df = my_df.loc[~too_short_mask][[
  's_author_trim', 's_title_trim', 's_content_trim']]
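
The three masks amount to a single row-level rule. The sketch below restates that rule as a plain Python predicate so the boundary cases are easy to see; `is_usable` is a hypothetical name and the sample rows are invented, with only the thresholds taken from the cells above.

```python
MIN_CONTENT_CHAR = 10
PLACEHOLDER = "无正文"  # "no body text" marker filtered out by the notebook

def is_usable(author, title, content):
    """Mirror of the notebook masks: a row is kept only if every check passes."""
    if len(title) == 0:
        return False
    if len(content) <= MIN_CONTENT_CHAR:  # exactly 10 characters is still rejected
        return False
    if content == PLACEHOLDER or author == PLACEHOLDER:
        return False
    return True

rows = [
    ("刘着", "送客亭", "十年羁旅鬓成丝，千里淮山信息稀。"),  # kept
    ("刘着", "", "十年羁旅鬓成丝，千里淮山信息稀。"),        # empty title
    ("刘着", "送客亭", "一二三四五六七八九十"),              # only 10 characters
    ("无正文", "送客亭", "十年羁旅鬓成丝，千里淮山信息稀。"),  # placeholder author
]
kept = [is_usable(*row) for row in rows]
print(kept)
```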

In [9]:
qualitied_df.sample(3)

Unnamed: 0,s_author_trim,s_title_trim,s_content_trim
302,元稹,遣悲怀三首二,昔日戏言身后意，今朝皆到眼前来。衣裳已施行看尽，针线犹存未忍开。尚想旧情怜婢仆，也曾因夢送钱...
489,朱松,记草木杂诗七首月桂花,窗前小桂丛，着花无旷月。月行晦朔周，一再开复歇。初如醉肌红，忽作绛裙色。谁人相料理，耿耿自开...
730,刘着,送客亭,十年羁旅鬓成丝，千里淮山信息稀。送尽长亭短亭客，且看庄舄几时归。


In [10]:
AUTHOR_PROMPT = "模仿："
TITLE_PROMPT = "作诗："
EOS_TOKEN = '</s>'
def build_dataset_df(df, include_author=True):
  """Build (source_text, target_text) pairs; the author prompt follows an EOS separator when included."""
  dfc = df.copy()
  if include_author:
    dfc['source_text'] = TITLE_PROMPT + df['s_title_trim'] + EOS_TOKEN + AUTHOR_PROMPT + df['s_author_trim']
  else:
    dfc['source_text'] = TITLE_PROMPT + df['s_title_trim']
  dfc['target_text'] = df['s_content_trim']
  dfc = dfc[['source_text', 'target_text']]
  return dfc
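
The prompt format is easier to read on a single poem. This is a minimal sketch using plain strings instead of a DataFrame: `build_sources` is a hypothetical helper, and the prompt constants are copied from the cell above. It returns the two source strings that one poem contributes to the training set.

```python
TITLE_PROMPT = "作诗："
AUTHOR_PROMPT = "模仿："
EOS_TOKEN = "</s>"

def build_sources(title, author):
    """Return both prompt variants for one poem: with the author, and title only."""
    with_author = TITLE_PROMPT + title + EOS_TOKEN + AUTHOR_PROMPT + author
    title_only = TITLE_PROMPT + title
    return [with_author, title_only]

sources = build_sources("过温汤", "高宗皇帝")
for s in sources:
    print(s)
```

Both variants share the same target text, which is why the merged dataset ends up with twice as many rows as there are poems.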

In [11]:
df_author_title_content = build_dataset_df(qualitied_df, True)
df_author_title_content[100:105]

Unnamed: 0,source_text,target_text
100,作诗：太子纳妃太平公主出降</s>模仿：高宗皇帝,龙楼光曙景，鲁馆啓朝扉。艳日浓妆影，低星降婺辉。玉庭浮瑞色，银牓藻祥徽。云转花萦盖，霞飘叶缀...
101,作诗：七夕宴悬圃二首一</s>模仿：高宗皇帝,羽盖飞天汉，凤驾越层峦。俱叹三秋阻，共敍一宵欢。璜亏夜月落，靥碎晓星残。谁能重操杼，纤手濯清澜。
102,作诗：七夕宴悬圃二首二</s>模仿：高宗皇帝,霓裳转云路，凤驾俨天潢。亏星凋夜靥，残月落朝璜。促欢今夕促，长离别后长。轻梭聊驻织，掩泪独悲伤。
103,作诗：过温汤</s>模仿：高宗皇帝,温渚停仙跸，丰郊驻晓旌。路曲回输影，岩虚传漏声。暖溜惊湍驶，寒空碧雾轻。林黄疎叶下，野白曙霜...
104,作诗：九月九日</s>模仿：高宗皇帝,端居临玉扆，初律啓金商。凤阙澄秋色，龙闱引夕凉。野净山气敛，林疎风露长。砌兰亏半影，岩桂发全...


In [12]:
df_title_content = build_dataset_df(qualitied_df, False)
df_title_content[100:105]

Unnamed: 0,source_text,target_text
100,作诗：太子纳妃太平公主出降,龙楼光曙景，鲁馆啓朝扉。艳日浓妆影，低星降婺辉。玉庭浮瑞色，银牓藻祥徽。云转花萦盖，霞飘叶缀...
101,作诗：七夕宴悬圃二首一,羽盖飞天汉，凤驾越层峦。俱叹三秋阻，共敍一宵欢。璜亏夜月落，靥碎晓星残。谁能重操杼，纤手濯清澜。
102,作诗：七夕宴悬圃二首二,霓裳转云路，凤驾俨天潢。亏星凋夜靥，残月落朝璜。促欢今夕促，长离别后长。轻梭聊驻织，掩泪独悲伤。
103,作诗：过温汤,温渚停仙跸，丰郊驻晓旌。路曲回输影，岩虚传漏声。暖溜惊湍驶，寒空碧雾轻。林黄疎叶下，野白曙霜...
104,作诗：九月九日,端居临玉扆，初律啓金商。凤阙澄秋色，龙闱引夕凉。野净山气敛，林疎风露长。砌兰亏半影，岩桂发全...


In [13]:
merged_df = pd.concat([df_author_title_content, df_title_content])

In [14]:
merged_df

Unnamed: 0,source_text,target_text
0,作诗：帝京篇十首一</s>模仿：太宗皇帝,秦川雄帝宅，函谷壮皇居。绮殿千寻起，离宫百雉余。连甍遥接汉，飞观迥凌虚。云日隐层阙，风烟出绮疎。
1,作诗：帝京篇十首二</s>模仿：太宗皇帝,岩廊罢机务，崇文聊驻辇。玉匣啓龙图，金绳披凤篆。韦编断仍续，缥帙舒还卷。对此乃淹留，欹案观坟典。
2,作诗：帝京篇十首三</s>模仿：太宗皇帝,移步出词林，停舆欣武宴。琱弓写明月，骏马疑流电。惊雁落虚弦，啼猿悲急箭。阅赏诚多美，于兹乃忘倦。
3,作诗：帝京篇十首四</s>模仿：太宗皇帝,鸣笳临乐馆，眺听欢芳节。急管韵朱弦，清歌凝白雪。彩凤肃来仪，玄鹤纷成列。去兹郑卫声，雅音方可悦。
4,作诗：帝京篇十首五</s>模仿：太宗皇帝,芳辰追逸趣，禁苑信多奇。桥形通汉上，峰势接云危。烟霞交隐映，花鸟自参差。何如肆辙迹？万里赏瑶池。
...,...,...
232,作诗：状元峰,马蹄一日遍长安，萤火鸡窗千载寒。从此锦衣归故里，文峰高并彩云端。
233,作诗：蜕龙洞,苍岩磊落任龙蟠，绵亘千年露未干。一自爲霖破壁去，至今风雨逼山寒。
234,作诗：登竺云山,独上千峰与万峰，晴岚淡写海江容。偶从动问山居事，笑拍岩前一树松。
235,作诗：寒云千叠山,松竹阴森护上方，老仙蓬髪一簪霜。闲来欹枕松风裏，归夢不知山水长。


## Modeling

In [15]:
# Quiet install simple T5 package
!pip install -q simplet5 &> /dev/null

In [16]:
import torch
from simplet5 import SimpleT5
from transformers import T5Tokenizer, T5ForConditionalGeneration

Global seed set to 42


In [17]:
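class MengziSimpleT5(SimpleT5):
  """SimpleT5 wrapper that loads the Langboat Mengzi T5 checkpoint."""
  def __init__(self) -> None:
    super().__init__()
    # Fall back to CPU when no GPU is available
    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  def load_my_model(self, use_gpu: bool = True):
    self.tokenizer = T5Tokenizer.from_pretrained("Langboat/mengzi-t5-base")
    self.model = T5ForConditionalGeneration.from_pretrained("Langboat/mengzi-t5-base")
    # Honor use_gpu, which was previously ignored
    if use_gpu and torch.cuda.is_available():
      self.model = self.model.to("cuda")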

In [18]:
model = MengziSimpleT5()
model.load_my_model()
model.model = model.model.to('cuda')

Downloading:   0%|          | 0.00/725k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/659 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/990M [00:00<?, ?B/s]

In [19]:
model.tokenizer("桥形通汉上，峰势接云危。</s>烟霞交隐映，花鸟自参差。")

{'input_ids': [1012, 955, 406, 921, 23, 3, 1440, 2180, 799, 355, 4008, 4, 1, 1448, 4152, 690, 3934, 4990, 3, 17544, 178, 2572, 769, 4, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [20]:
model.tokenizer.decode([1012, 955, 406, 921, 23, 3, 1440, 2180, 799, 355, 4008, 4, 1, 1448, 4152, 690, 3934, 4990, 3, 17544, 178, 2572, 769, 4, 1])

'桥形通汉上,峰势接云危。</s> 烟霞交隐映,花鸟自参差。</s>'
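
The round trip above shows two side effects of decoding: the full-width comma comes back as an ASCII comma, and the `</s>` markers (plus a space after the first one) appear in the text. If generated poems should look like the training data, a small clean-up step is one option. The sketch below is an assumption about how that might be done; `clean_decoded` is a hypothetical helper, and it only handles the comma case visible in this output. Other punctuation may need the same treatment. With the real tokenizer, passing `skip_special_tokens=True` to `decode` should remove the `</s>` markers directly.

```python
def clean_decoded(text):
    """Strip EOS markers and spaces, then restore the full-width comma."""
    text = text.replace("</s>", "").replace(" ", "")
    return text.replace(",", "，")

# Decoded string copied from the cell output above.
decoded = "桥形通汉上,峰势接云危。</s> 烟霞交隐映,花鸟自参差。</s>"
cleaned = clean_decoded(decoded)
print(cleaned)
```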

In [21]:
from sklearn.model_selection import train_test_split
merged_df = merged_df.sample(frac=1) # Shuffle
train_df, eval_df = train_test_split(merged_df, test_size=0.02)

In [22]:
print("train", len(train_df), "eval", len(eval_df))

train 607776 eval 12404
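
These sizes can be sanity-checked by hand. The sketch below redoes the arithmetic, assuming scikit-learn's behavior for a fractional `test_size` (the eval size is rounded up and the rest goes to training). It also recovers the number of poems that passed filtering, since each poem appears once per prompt variant.

```python
import math

n_train_reported = 607776
n_eval_reported = 12404
n_total = n_train_reported + n_eval_reported

# Fractional test_size: eval size is rounded up, the remainder is the training set.
n_eval = math.ceil(0.02 * n_total)
n_train = n_total - n_eval

# Each poem contributes two rows (with author, title only).
n_poems = n_total // 2

print(n_total, n_eval, n_train, n_poems)
```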


In [None]:
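# Max lengths are character counts used as a proxy for token counts;
# the + 1 presumably accounts for the </s> separator between title and author.
model.train(train_df=train_df,
            eval_df=eval_df,
            source_max_token_len=(len(TITLE_PROMPT) + MAX_TITLE_CHAR + 1 + len(AUTHOR_PROMPT) + MAX_AUTHOR_CHAR),
            target_max_token_len=MAX_CONTENT_CHAR,
            batch_size=48,
            max_epochs=3,
            use_gpu=True,
            outputdir="/content/drive/MyDrive/ML/Models/t5-poem")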

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                       | Params
-----------------------------------------------------
0 | model | T5ForConditionalGeneration | 247 M 
-----------------------------------------------------
247 M     Trainable params
0         Non-trainable params
247 M     Total params
990.311   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  f"The dataloader, {name}, does not have many workers which may be a bottleneck."
Global seed set to 42
  f"The dataloader, {name}, does not have many workers which may be a bottleneck."


Training: -1it [00:00, ?it/s]
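
For a rough sense of the run's size, the arguments passed to `model.train` can be worked out by hand. The sketch below is only arithmetic on values that appear in this notebook. It assumes one optimizer step per batch (no gradient accumulation) and that a final partial batch would still be used; neither is confirmed by the output above.

```python
import math

TITLE_PROMPT = "作诗："
AUTHOR_PROMPT = "模仿："
MAX_TITLE_CHAR = 12
MAX_AUTHOR_CHAR = 4

# Same expression as source_max_token_len in the train() call.
source_max_len = len(TITLE_PROMPT) + MAX_TITLE_CHAR + 1 + len(AUTHOR_PROMPT) + MAX_AUTHOR_CHAR

n_train = 607776
batch_size = 48
max_epochs = 3

steps_per_epoch = math.ceil(n_train / batch_size)
total_steps = steps_per_epoch * max_epochs

print(source_max_len, steps_per_epoch, total_steps)
```

Because the limits are counted in characters, a title or poem that the tokenizer splits into more pieces than it has characters could still be truncated.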