Please click [here](https://ai.baidu.com/docs#/AIStudio_Project_Notebook/a38e5576) for basic usage instructions for this environment.

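### 1 Pretrained model
We first download the mengzi-t5-base model from Hugging Face (through the hf-mirror.com mirror) so that it can be fine-tuned later.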

In [48]:
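# Use -o with explicit names: with -O, curl typically keeps the "?download=true" query string in the file name.
# The target directory is the one loaded later as local_model_path.
!mkdir -p ./MengziT5_base
!curl -L -o ./MengziT5_base/pytorch_model.bin "https://hf-mirror.com/Langboat/mengzi-t5-base/resolve/main/pytorch_model.bin?download=True"
!curl -L -o ./MengziT5_base/config.json "https://hf-mirror.com/Langboat/mengzi-t5-base/resolve/main/config.json?download=true"
!curl -L -o ./MengziT5_base/spiece.vocab "https://hf-mirror.com/Langboat/mengzi-t5-base/resolve/main/spiece.vocab?download=true"
!curl -L -o ./MengziT5_base/spiece.model "https://hf-mirror.com/Langboat/mengzi-t5-base/resolve/main/spiece.model?download=true"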
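Before loading the model it can help to confirm that the four files arrived. The following is a minimal sketch, assuming the files are meant to sit in `./MengziT5_base` (the `local_model_path` used later in this notebook); the helper name `missing_model_files` is hypothetical and not part of the original notebook.

```python
from pathlib import Path

# File names taken from the download cell above.
EXPECTED_FILES = ("pytorch_model.bin", "config.json", "spiece.vocab", "spiece.model")

def missing_model_files(model_dir, expected=EXPECTED_FILES):
    """Return the expected file names that are absent or empty in model_dir."""
    root = Path(model_dir)
    return [name for name in expected
            if not (root / name).is_file() or (root / name).stat().st_size == 0]

# Assumed directory; adjust if the files were saved elsewhere.
print("missing:", missing_model_files("./MengziT5_base"))
```

An empty list means all four files exist and are non-empty; it does not verify that the downloads are complete or uncorrupted.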

In [51]:
!nvidia-smi

Wed Dec  4 08:57:32 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.125.06   Driver Version: 525.125.06   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  On   | 00000000:03:00.0 Off |                    0 |
| N/A   42C    P0    55W / 300W |   3207MiB / 32768MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               

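### 2 Data preparation
#### 2.1 Data download
The data are Tang and Song poems collected by [chinese-poetry](https://github.com/chinese-poetry/chinese-poetry). Because the PaddlePaddle platform already ships this dataset, we only need to add it to the project; the cell below unzips it.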

In [52]:
!unzip -n ./data/data70759/poems_json.zip

Archive:  ./data/data70759/poems_json.zip


In [53]:
!pip install -q chinese-converter

DEPRECATION: pytorch-lightning 1.5.10 has a non-standard dependency specifier torch>=1.7.*. pip 24.1 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of pytorch-lightning or contact the author to suggest that they release a version with a conforming dependency specifiers. Discussion can be found at https://github.com/pypa/pip/issues/12063

In [54]:
# IS_TEST_FLOW = True
IS_TEST_FLOW = False

In [55]:
import json
import urllib.request
import pandas as pd
# from tqdm.notebook import tqdm
import chinese_converter  # needed for Traditional-to-Simplified conversion
import pickle
import os
import numpy as np

#### 2.2 Data processing
Parse the JSON files and load the data.

In [56]:
# https://github.com/chinese-poetry/chinese-poetry, last update 04/18/2023
POEM_CONTENT = {
    'tang': {
        'total': 58,
        'pattern': "./poems_json/poet.tang.{0}.json"
    },
    'song': {
        'total': 255,
        'pattern': "./poems_json/poet.song.{0}.json"
    }
}

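def get_poems(is_test=True, verbose=True):
  """Load the poem JSON files of every dynasty into one DataFrame."""
  df_list = []
  for dynasty in POEM_CONTENT:
    # In test mode, read only the first 3 files of each dynasty
    size = 3 if is_test else POEM_CONTENT[dynasty]['total']
    for i in range(size):
      # Files are numbered in steps of 1000: poet.tang.0.json, poet.tang.1000.json, ...
      path = POEM_CONTENT[dynasty]['pattern'].format(i * 1000)
      if verbose:
        print(f"load {path} now")
      df_list.append(pd.read_json(path))
  return pd.concat(df_list)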

In [None]:
df = get_poems(is_test=IS_TEST_FLOW, verbose=True)
df['concat_paragraphs'] = [''.join(map(str, l)) for l in df['paragraphs']]
df = df[['author', 'title', 'concat_paragraphs']]

def convert_schinese(tchinese):
  return chinese_converter.to_simplified(tchinese)

df['s_content'] = df.apply(lambda row: convert_schinese(''.join(row.concat_paragraphs)), axis=1)
df['s_title'] = df.apply(lambda row: convert_schinese(''.join(row.title)), axis=1)
df['s_author'] = df.apply(lambda row: convert_schinese(''.join(row.author)), axis=1)

my_df = df
print("my_df size", len(my_df))

In [58]:
MAX_AUTHOR_CHAR = 4
MAX_TITLE_CHAR = 12
MIN_CONTENT_CHAR = 20
MAX_CONTENT_CHAR = 32
BAD_TOKENS = " ()[]《》（）□{}abcdefgxyz一"

def trim_author_fn(row):
  return row.s_author[:MAX_AUTHOR_CHAR]

def trim_title_fn(row):
  trimmed_title = row.s_title[:MAX_TITLE_CHAR]
  for b in BAD_TOKENS:
    trimmed_title = trimmed_title.replace(b, "")
  return trimmed_title

def trim_content_fn(row):
  trimmed_content = row.s_content[:MAX_CONTENT_CHAR]
  for b in BAD_TOKENS:
    trimmed_content = trimmed_content.replace(b, "")
  # Cut after the last full stop so a partial final line does not confuse the model.
  # Without any full stop, rfind returns -1 and the result is empty (filtered out later).
  last_period = trimmed_content.rfind("。")
  return trimmed_content[:last_period+1]

# Trim the size, a soft copy to avoid the view/copy conflict warning
my_df['s_author_trim'] = my_df.copy().apply(trim_author_fn, axis=1)
my_df['s_title_trim'] = my_df.copy().apply(trim_title_fn, axis=1)
my_df['s_content_trim'] = my_df.copy().apply(trim_content_fn, axis=1)

print("my_df size", len(my_df))

my_df size 311860
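The truncation rule in the content-trimming function is easier to see on a single string. The following is a minimal sketch that repeats the same steps on a plain `str` instead of a DataFrame row; the function name `trim_content` is hypothetical, and the sample is simply two poems from this dataset joined together so that the 32-character limit falls in the middle of a line.

```python
MAX_CONTENT_CHAR = 32
BAD_TOKENS = " ()[]《》（）□{}abcdefgxyz一"

def trim_content(text, max_chars=MAX_CONTENT_CHAR, bad_tokens=BAD_TOKENS):
    """Truncate, drop unwanted characters, then cut after the last full stop."""
    trimmed = text[:max_chars]
    for b in bad_tokens:
        trimmed = trimmed.replace(b, "")
    last_period = trimmed.rfind("。")  # -1 when there is no full stop
    return trimmed[:last_period + 1]

first = "九月正乘秋，三杯兴已周。泛桂迎尊满，吹花向酒浮。"   # 24 characters
second = "四郊秦汉国，八水帝王都。阊阖雄里闬，城阙壮规模。"
# The cut at 32 characters lands inside the third couplet, which is then dropped.
print(trim_content(first + second))
```

Because the cut always falls back to the last full stop, a five-character poem keeps at most two couplets and a seven-character poem keeps exactly two.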


Filter the data.

In [59]:
# Title cannot be empty
empty_title_mask = (my_df['s_title_trim'].str.len() == 0)
too_short_content_mask = (my_df['s_content_trim'].str.len() <= MIN_CONTENT_CHAR)
invalid_mask = (('无正文' == my_df['s_content_trim']) | ('无正文' == my_df['s_author_trim']))
too_short_mask =  empty_title_mask | too_short_content_mask | invalid_mask
# filtered_my_df = my_df.loc[too_short_mask]
# filtered_my_df

my_df = my_df.loc[~too_short_mask][[
  's_author_trim', 's_title_trim', 's_content_trim']]
print("my_df size", len(my_df))

my_df size 297836


In [60]:
import re
result_dict = {
    's_author_trim': [],
    's_title_trim': [],
    's_content_trim': [],
}
for i, row in my_df.iterrows():
  c = row['s_content_trim']
  # Split into lines at the punctuation marks and measure the non-empty ones
  snippets = re.split('，|。|？', c)
  lens = [len(s) for s in snippets if s.strip() != '']
  # Keep only poems whose lines all share one length, either 5 or 7 characters
  if not lens or max(lens) != min(lens) or max(lens) not in [5, 7]:
    continue
  result_dict['s_author_trim'].append(row['s_author_trim'])
  result_dict['s_title_trim'].append(row['s_title_trim'])
  result_dict['s_content_trim'].append(c)
my_df = pd.DataFrame(data=result_dict)
print("left", len(my_df))

left 225853
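The loop above keeps only poems in which every line has the same length, 5 or 7 characters. As a minimal sketch, the same check can be written as a standalone predicate; the name `is_regular_verse` is hypothetical, and the sample strings are taken from rows shown elsewhere in this notebook.

```python
import re

def is_regular_verse(text, allowed=(5, 7)):
    """True when every line has the same length and that length is allowed."""
    lines = [s for s in re.split("，|。|？", text) if s.strip() != ""]
    if not lines:
        return False
    lengths = {len(s) for s in lines}
    return len(lengths) == 1 and lengths.pop() in allowed

print(is_regular_verse("九月正乘秋，三杯兴已周。泛桂迎尊满，吹花向酒浮。"))
print(is_regular_verse("火树灯山高入云，红筵翠幄自成春。游女有时还解佩，青楼何处不留人。"))
```

A predicate of this shape could also be applied with `my_df['s_content_trim'].map(is_regular_verse)` to build a boolean mask instead of iterating row by row.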


In [61]:
AUTHOR_PROMPT = "模仿："
TITLE_PROMPT = "作诗："
EOS_TOKEN = '</s>'
def build_dataset_df(df, include_author=True):
  dfc = df.copy()
  if include_author:
    dfc['source_text'] = TITLE_PROMPT + df['s_title_trim'] + EOS_TOKEN + AUTHOR_PROMPT + df['s_author_trim']
  else:
    dfc['source_text'] = TITLE_PROMPT + df['s_title_trim']
  dfc['target_text'] = df['s_content_trim']
  dfc = dfc[['source_text', 'target_text']]
  return dfc
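The prompt format produced by `build_dataset_df` can be illustrated on plain strings. This is a minimal sketch; the helper `build_source_text` is hypothetical and only mirrors the string concatenation above for a single title and an optional author.

```python
TITLE_PROMPT = "作诗："
AUTHOR_PROMPT = "模仿："
EOS_TOKEN = "</s>"

def build_source_text(title, author=None):
    """Compose the model input for one poem."""
    if author:
        return TITLE_PROMPT + title + EOS_TOKEN + AUTHOR_PROMPT + author
    return TITLE_PROMPT + title

print(build_source_text("凤门泉", "周文璞"))
print(build_source_text("村饮"))
```

The `</s>` token separates the title part from the author part, so the same model can be prompted with or without an author.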

Data with the author included

In [None]:
df_author_title_content = build_dataset_df(my_df, True)
df_author_title_content[100:105]

Data without the author

In [63]:

df_title_content = build_dataset_df(my_df, False)
df_title_content[100:105]

| | source_text | target_text |
|---|---|---|
| 100 | 作诗：九月九日幸临渭亭登高得秋 | 九月正乘秋，三杯兴已周。泛桂迎尊满，吹花向酒浮。 |
| 101 | 作诗：登骊山高顶寓目 | 四郊秦汉国，八水帝王都。阊阖雄里闬，城阙壮规模。 |
| 102 | 作诗：幸秦始皇陵 | 眷言君失德，骊邑想秦余。政烦方改篆，愚俗乃焚书。 |
| 103 | 作诗：立春日游苑迎春 | 神皐福地三秦邑，玉台金阙九仙家。寒光犹恋甘泉树，淑景偏临建始花。 |
| 104 | 作诗：十月诞辰内殿宴羣臣效柏梁 | 润色鸿业寄贤才，叨居右弼媿盐梅。运筹帷幄荷时来，职掌图籍滥蓬莱。 |


In [64]:

merged_df = pd.concat([df_author_title_content, df_title_content])

In [65]:
merged_df = merged_df.sample(frac=1.)
merged_df

| | source_text | target_text |
|---|---|---|
| 178654 | 作诗：凤门泉</s>模仿：周文璞 | 石根两眼北流泉，乳窦潜通御井边。飘出宫花人不识，又沿衰草下山前。 |
| 188770 | 作诗：和仲弟十绝其</s>模仿：刘克庄 | 懒窥户外问晴阴，静向窗前閲古今。江国事稀聊袖手，钧天夢断久灰心。 |
| 71566 | 作诗：和吕秘校其二 | 王子如今未夢刀，不须感慨论官曹。寝郎悟意犹爲相，鄠尉知名固可褒。 |
| 42945 | 作诗：观灯玉台体十首其四 | 火树灯山高入云，红筵翠幄自成春。游女有时还解佩，青楼何处不留人。 |
| 77779 | 作诗：次韵答张天觉二首其 | 车轻马稳轡衔坚，但有蚊蝱喜扑缘。截断口前君莫问，人间差乐胜巢仙。 |
| ... | ... | ... |
| 33325 | 作诗：宫词百首三十九 | 锦褥花明满殿铺，宫娥分坐学樗蒲。欲教官马冲关过，呪愿纤纤早掷卢。 |
| 139985 | 作诗：村饮 | 吴中霣霜晚，冬草有未衰。坐令老病叟，遂失凋年悲。 |
| 131932 | 作诗：送赵德庄右司赴江东漕八首</s>模仿：曾协 | 我公廊庙姿，当爲济时霖。生材必有用，应物初无心。 |
| 146867 | 作诗：早朝紫宸殿贺雪呈尤延之二</s>模仿：杨万里 | 雪花将瑞献君王，晴早销迟恋建章。不肯独清须带月，犹嫌未冷更吹霜。 |


### 3 Training
Install the required libraries, such as torch and simplet5.

In [66]:
!pip install torch
!pip install simplet5
import torch
from simplet5 import SimpleT5
from transformers import T5Tokenizer, T5ForConditionalGeneration

Looking in indexes: https://mirror.baidu.com/pypi/simple/, https://mirrors.aliyun.com/pypi/simple/

In [67]:
torch.cuda.empty_cache()

Load the mengzi-t5-base model.

In [68]:
# Local model path
# local_model_path = "./mengzi_t5_base"
local_model_path = "./MengziT5_base"

# Number of extra_ids (sentinel tokens)
extra_ids = 100

# Special-token list containing every extra_id
additional_special_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)]


class MengziSimpleT5(SimpleT5):
  def __init__(self) -> None:
    super().__init__()
    self.device = torch.device("cuda")

  def load_my_model(self, use_gpu: bool = True):
    # Honor use_gpu instead of always assuming CUDA
    self.device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")
    # self.tokenizer = T5Tokenizer.from_pretrained(local_model_path,
    # extra_ids=extra_ids,
    # additional_special_tokens=additional_special_tokens)
    self.tokenizer = T5Tokenizer.from_pretrained(local_model_path)
    self.model = T5ForConditionalGeneration.from_pretrained(local_model_path)
In [69]:
# # # !pip install requests
# # !export HF_ENDPOINT=https://hf-mirror.com
# # !export HF_HOME="/home/aistudio/hf"
# # # !pwd
# # import os
# # os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
# # !cd ./cache
# # !ls -la ./.cache/huggingface/transformers

# from paddlenlp.transformers import T5ForConditionalGeneration, T5Tokenizer

# # Model save path
# local_dir = "./mengzi_t5_base"

# # Download and save the model and tokenizer
# model_name = "Langboat/mengzi-t5-base"
# model = T5ForConditionalGeneration.from_pretrained(model_name)
# tokenizer = T5Tokenizer.from_pretrained(model_name)

# # Save locally
# model.save_pretrained(local_dir)
# tokenizer.save_pretrained(local_dir)
# !zip

In [70]:
model = MengziSimpleT5()
model.load_my_model()
model.model = model.model.to('cuda')

  state_dict = torch.load(resolved_archive_file, map_location="cpu")


In [71]:
model.tokenizer("桥形通汉上，峰势接云危。</s>烟霞交隐映，花鸟自参差。")

{'input_ids': [1012, 955, 406, 921, 23, 3, 1440, 2180, 799, 355, 4008, 4, 1, 1448, 4152, 690, 3934, 4990, 3, 17544, 178, 2572, 769, 4, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [72]:
model.tokenizer.decode([1012, 955, 406, 921, 23, 3, 1440, 2180, 799, 355, 4008, 4, 1, 1448, 4152, 690, 3934, 4990, 3, 17544, 178, 2572, 769, 4, 1])

'桥形通汉上,峰势接云危。</s> 烟霞交隐映,花鸟自参差。</s>'
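The round trip above shows two details that matter when reading model output: the decoded text carries `</s>` markers, and the full-width comma "，" comes back as an ASCII ",". The following is a minimal sketch of a clean-up step under those observations; the helper name `split_decoded` is hypothetical and not part of the notebook or of the tokenizer API.

```python
EOS_TOKEN = "</s>"

def split_decoded(decoded):
    """Split decoded text at </s> and restore full-width commas."""
    parts = [p.strip() for p in decoded.split(EOS_TOKEN)]
    return [p.replace(",", "，") for p in parts if p]

# The decoded string printed by the cell above.
decoded = "桥形通汉上,峰势接云危。</s> 烟霞交隐映,花鸟自参差。</s>"
print(split_decoded(decoded))
```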

Split the dataset into training and validation sets at a 0.98 / 0.02 ratio.

In [73]:
from sklearn.model_selection import train_test_split
merged_df = merged_df.sample(frac=1) # Shuffle
train_df, eval_df = train_test_split(merged_df, test_size=0.02)

In [74]:
print("train", len(train_df), "eval", len(eval_df))

train 442671 eval 9035
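These sizes can be reproduced by hand. The merged set holds two variants (with and without author) of each of the 225853 poems left after filtering. The sketch below assumes that the test share is rounded up to a whole row, which is my understanding of how `train_test_split` treats a fractional `test_size` and which matches the printed numbers; the helper `split_sizes` is hypothetical.

```python
import math

def split_sizes(n_samples, test_size):
    """Sizes of the train and test parts for a fractional test_size."""
    n_test = math.ceil(n_samples * test_size)  # assumed rounding rule
    return n_samples - n_test, n_test

n_poems = 225853          # rows left after the 5/7-character filter
n_rows = 2 * n_poems      # with-author and without-author variants
print(n_rows, split_sizes(n_rows, 0.02))
```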


Start training.

In [79]:
model.train(train_df=train_df,
            eval_df=eval_df,
            source_max_token_len=(len(TITLE_PROMPT) + MAX_TITLE_CHAR +  1 + len(AUTHOR_PROMPT) + MAX_AUTHOR_CHAR),
            target_max_token_len=MAX_CONTENT_CHAR,
            batch_size=256,
            max_epochs=5,
            use_gpu=True,
            outputdir="./Models/t5-poem-v2.1")

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                       | Params
-----------------------------------------------------
0 | model | T5ForConditionalGeneration | 247 M 
-----------------------------------------------------
247 M     Trainable params
0         Non-trainable params
247 M     Total params
990.311   Total estimated model params size (MB)


                                                                      

Global seed set to 42


Epoch 4:  33%|███▎      | 580/1766 [09:57<20:21,  1.03s/it, loss=3.43, v_num=3, train_loss_step=3.400, val_loss_step=3.570, val_loss_epoch=3.550, train_loss_epoch=3.590] 
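Two numbers in this cell can be checked with simple arithmetic: the value passed as `source_max_token_len` and the total of 1766 shown in the progress bar. The sketch below assumes that the bar counts training plus validation batches and that the last partial batch is kept, which is consistent with the log but not confirmed by it. Note also that the budget is counted in characters while the parameter is in tokens, so it is only an approximation of what the tokenizer produces.

```python
import math

TITLE_PROMPT, AUTHOR_PROMPT = "作诗：", "模仿："
MAX_TITLE_CHAR, MAX_AUTHOR_CHAR = 12, 4

# Character budget of the longest prompt; the "+ 1" stands for the </s> separator.
source_budget = (len(TITLE_PROMPT) + MAX_TITLE_CHAR + 1
                 + len(AUTHOR_PROMPT) + MAX_AUTHOR_CHAR)

def batches(n_rows, batch_size):
    """Number of batches when the last partial batch is kept."""
    return math.ceil(n_rows / batch_size)

train_batches = batches(442671, 256)
eval_batches = batches(9035, 256)
print(source_budget, train_batches, eval_batches, train_batches + eval_batches)
```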

### 4 Testing
After training, test the quality of the model's output.

In [None]:
def poem(title_str, opt_author=None, model=model,
         is_input_traditional_chinese=False,
         num_beams=2):
  model.model = model.model.to('cuda')
  if opt_author:
    in_request = TITLE_PROMPT + title_str[:MAX_TITLE_CHAR] + EOS_TOKEN + AUTHOR_PROMPT + opt_author[:MAX_AUTHOR_CHAR]
  else:
    in_request = TITLE_PROMPT + title_str[:MAX_TITLE_CHAR]
  if is_input_traditional_chinese:
    in_request = chinese_converter.to_simplified(in_request)
  out = model.predict(in_request,
                      max_length=MAX_CONTENT_CHAR,
                      num_beams=num_beams)[0].replace(",", "，")
  if is_input_traditional_chinese:
    out = chinese_converter.to_traditional(out)
    print(f"標題： {in_request.replace('</s>', ' ')}\n詩歌： {out}")
  else:
    print(f"标题： {in_request.replace('</s>', ' ')}\n诗歌： {out}")

In [77]:
for title in ['秋思', "百花", '佳人有约']:
  # Empty author means general style
  for author in ['', "杜甫", "李白", "李清照", "苏轼"]:
    poem(title, author)
  print()

标题： 作诗：秋思
诗歌： 西风送秋思，北风起南枝。黄叶满庭树，白露湿衣襟。
标题： 作诗：秋思 模仿：杜甫
诗歌： 秋思苦不眠，夜吟愁复起。天高云欲雨，风定雁初飞。
标题： 作诗：秋思 模仿：李白
诗歌： 秋风吹客衣，西风起我愁。独坐对明月，不觉夜来时。
标题： 作诗：秋思 模仿：李清照
诗歌： 江上秋容好，山中夜气清。月明寒未落，风冷客愁生。
标题： 作诗：秋思 模仿：苏轼
诗歌： 秋声萧萧入夢思，夜雨淅淅鸣蟋蟀。西风吹我衣衾冷，北窗凉生枕簟眠。

标题： 作诗：百花
诗歌： 百花如玉照春光，花外花开几度香。莫道人间多好物，只应无此与君同。
标题： 作诗：百花 模仿：杜甫
诗歌： 百花开后百花，红紫纷纷满地红。谁家春色无多子，独有东风向晚风。
标题： 作诗：百花 模仿：李白
诗歌： 百花如玉色，不与衆芳同。若使君王爱，何妨我辈人。
标题： 作诗：百花 模仿：李清照
诗歌： 百花开尽满城春，花柳如丝绿未匀。欲识芳菲无限意，只将红粉染胭脂。
标题： 作诗：百花 模仿：苏轼
诗歌： 百花开后日，花木发新枝。不待春风至，何妨晚节来。

标题： 作诗：佳人有约
诗歌： 佳人有约在芳辰，花发红英照眼新。不待春风催柳絮，却教春色到人间。
标题： 作诗：佳人有约 模仿：杜甫
诗歌： 佳人有约在天涯，不似花间笑语同。春色已随桃李发，雨声犹逐燕啼。
标题： 作诗：佳人有约 模仿：李白
诗歌： 佳人有约今何许，花柳阴阴满地香。莫道春风吹不至，且将红粉照窗纱。
标题： 作诗：佳人有约 模仿：李清照
诗歌： 佳人有约到芳丛，爲报春光入酒杯。花下红楼开宴座，月中黄屋起歌声。
标题： 作诗：佳人有约 模仿：苏轼
诗歌： 佳人有约醉中仙，花下相邀笑语同。莫道春光催酒盏，且教红粉染胭脂



In [78]:
for title in ['冬雪']:
  for author in  ['', "杜甫"]:
    for num_beams in (2, 3, 5, 10, 20, 50, 100, 200):
      print(f"num beams: {num_beams}")
      poem(title, author, num_beams=num_beams)
    print("-"*80)

num beams: 2
标题： 作诗：冬雪
诗歌： 朔风卷地雪纷纷，飞霰飘然下帝阍。天遣玉人成底事，人间无此与谁同。
num beams: 3
标题： 作诗：冬雪
诗歌： 朔风吹雪满江城，万瓦琼花照眼明。老去只知身是客，病来惟觉鬓成丝。
num beams: 5
标题： 作诗：冬雪
诗歌： 腊雪连三白，寒云入九重。天公忧岁晚，我独喜春深。
num beams: 10
标题： 作诗：冬雪
诗歌： 朔风吹雪满江城，万木千林冻不收。忽有飞花纷扑地，不知何物是吾家。
num beams: 20
标题： 作诗：冬雪
诗歌： 朔风卷地雪纷纷，白帝楼头雪作花。老去不知身是客，今朝重见鬓成翁。
num beams: 50
标题： 作诗：冬雪
诗歌： 朔风吹雪满天涯，独倚阑干对落晖。老去不知身是客，故来犹喜鬓成丝。
num beams: 100
标题： 作诗：冬雪
诗歌： 朔风吹雪白皑皑，万瓦琼花照眼明。不道人间无此景，却疑天上有佳人。
num beams: 200
标题： 作诗：冬雪
诗歌： 腊后寒犹重，冬来雪未消。遥知春意早，已觉岁华徂。
--------------------------------------------------------------------------------
num beams: 2
标题： 作诗：冬雪 模仿：杜甫
诗歌： 冬雪何太早，朔风已惨凄。天公不可见，我辈岂能知。
num beams: 3
标题： 作诗：冬雪 模仿：杜甫
诗歌： 朔风卷地来，飞雪满空山。寒气侵人骨，清光入客衣。
num beams: 5
标题： 作诗：冬雪 模仿：杜甫
诗歌： 朔雪连年积，阴风卷地来。天寒犹有雪，岁晚更无梅。
num beams: 10
标题： 作诗：冬雪 模仿：杜甫
诗歌： 江上雪初下，山前雪未消。冻云凝不散，飞雪满空山。
num beams: 20
标题： 作诗：冬雪 模仿：杜甫
诗歌： 腊后冬来雪，今朝霰雪飞。朔风欺客帽，寒日逼人衣。
num beams: 50
标题： 作诗：冬雪 模仿：杜甫
诗歌： 朔风卷地来，冬雪满空山。寒气入肌骨，冷光侵肌骨。
num beams: 100
标题： 作诗：冬雪 模仿：杜甫
诗歌： 腊雪连三白，冬雪满两黄。北风犹未定