[Open In Colab](https://colab.research.google.com/github/shibing624/textgen/blob/main/examples/T5/T5_Finetune_Chinese_Couplet.ipynb)


# T5 对联
- 设计：Pretrained T5 + “对联 prompt” fine-tuning
  - 对比我的 [transformer training from scratch](https://github.com/hululuzhu/chinese-ai-writing-share/blob/main/%E4%B8%AD%E6%96%87%E5%AF%B9%E8%81%94Transformer_Source_Code_V1.ipynb)
- 数据：[对联github](https://github.com/wb14123/couplet-dataset)
- 相关内容
  - [Huggingface](https://huggingface.co/)
  - Langboat (澜舟科技) Chinese [Mengzi T5 pretrained model](https://huggingface.co/Langboat/mengzi-t5-base) and [paper](https://arxiv.org/pdf/2110.06696.pdf)
  - [textgen](https://github.com/shibing624/textgen)

In [None]:
# Quick-test switch: when True, later cells sample 3,000 training rows and
# 300 eval rows instead of using the full data (about 800k couplets).
IS_TEST_FLOW = True  #@param {type:"boolean"}

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

## Prepare Data

In [None]:
import argparse
from loguru import logger
import sys
import os
import pathlib
import numpy as np
import pandas as pd
import pickle

In [None]:
# Directory that receives the downloaded archive and, later, the extracted data.
working_dir = "./"
# Lines starting with "!" are shell commands; {working_dir} is filled in from
# the Python variable above.
!mkdir -p {working_dir}
# Download release 1.0 of the couplet dataset archive into working_dir.
!wget https://github.com/wb14123/couplet-dataset/releases/download/1.0/couplet.tar.gz -P {working_dir}
!ls -l {working_dir}

In [None]:
# Unpack the downloaded archive into {working_dir}/couplet_files.
!mkdir -p {working_dir}/couplet_files
!tar -xf {working_dir}/couplet.tar.gz -C {working_dir}/couplet_files

In [None]:
# Print the first line of train/in.txt and of train/out.txt.
!head -1 {working_dir}/couplet_files/couplet/train/in.txt {working_dir}/couplet_files/couplet/train/out.txt

## Load Data

In [None]:
COUPLET_PATH = f'{working_dir}/couplet_files/couplet'
MAX_SEQ_LEN = 32  # Keep at most 32 characters per line, punctuation marks included
COUPLET_PROMPOT = '对联：'  # Stored in the 'prefix' column of every row

# Build one DataFrame per split. Line k of in.txt and line k of out.txt end up
# in the same row, so the two files are paired purely by line position.
train_df, test_df = None, None
for t in ['train', 'test']:
    ins, outs = [], []
    for i in ['in', 'out']:
        bucket = ins if i == 'in' else outs
        # Decode explicitly as UTF-8: without `encoding`, open() falls back to
        # the platform locale, which can mis-decode or reject Chinese text.
        with open(f"{COUPLET_PATH}/{t}/{i}.txt", "r", encoding="utf-8") as f:
            for line in f:
                # Remove all spaces and line-break characters, then truncate.
                clean_line = line.strip().replace(' ', '').replace('\n', '').replace('\r', '')[:MAX_SEQ_LEN]
                bucket.append(clean_line)
    # The column names to match textgen
    data_dict = {
        'prefix': [COUPLET_PROMPOT] * len(ins),
        'input_text': ins,
        'target_text': outs,
    }
    if t == 'train':
        train_df = pd.DataFrame(data_dict)
    else:
        test_df = pd.DataFrame(data_dict)

In [None]:
# Ensure size match for every train/test sample
# Sanity check: compare the character count of every input line with that of
# its target line. For each split, only the first differing pair is printed.
for df in [train_df, test_df]:
    for source_text, target_text in zip(df['input_text'].values, df['target_text'].values):
        if len(source_text) == len(target_text):
            continue
        print("mismatch found:", source_text, target_text)
        break

In [None]:
# Preview the first three training rows.
train_df.head(3)

In [None]:
# Preview the first three test rows.
test_df.head(3)

In [None]:
# The test split doubles as the evaluation set.
eval_df = test_df
print("train", len(train_df), "eval", len(eval_df))

# In quick-test mode, shrink both sets to small random samples.
if IS_TEST_FLOW:
    train_df = train_df.sample(3000)
    eval_df = eval_df.sample(300)
print("train", len(train_df), "eval", len(eval_df))

## Prepare Model

In [None]:
# Show the GPU available to this runtime before training.
!nvidia-smi  # Check GPU, P100/16G takes 100mins per epoch similar to 1080

In [None]:
# Quietly install the textgen package (-q reduces pip's output)
!pip install -q textgen

In [None]:
import torch
import sys
# Also search two directories up for modules. append() puts the path at the
# end of sys.path, so an already-installed textgen is found first.
# NOTE(review): presumably this targets a local textgen checkout — confirm.
sys.path.append('../..')
from textgen.t5 import T5Model

In [None]:
# Settings for the fine-tuning run; they are passed to T5Model and model_args
# in the next cell.
model_type = 't5'
model_name = "Langboat/mengzi-t5-base"  # pretrained model name given to T5Model
output_dir = 'outputs/mengzi_t5_couplet/'  # becomes "output_dir" in model_args
max_seq_length = 50  # used for both "max_seq_length" and "max_length"
num_epochs = 10  # becomes "num_train_epochs"
batch_size = 32  # becomes "train_batch_size"

In [None]:
# Options handed to textgen's T5Model through `args`.
# NOTE(review): the meaning of each option is defined by textgen, not by this
# notebook — confirm against the textgen documentation before changing any.
model_args = {
    "reprocess_input_data": True,
    "overwrite_output_dir": True,
    "max_seq_length": max_seq_length,
    "max_length": max_seq_length,  # deliberately the same limit as max_seq_length
    "train_batch_size": batch_size,
    "num_train_epochs": num_epochs,
    "save_eval_checkpoints": False,
    "save_model_every_epoch": False,
    "evaluate_generated_text": True,
    "evaluate_during_training": True,
    "evaluate_during_training_verbose": True,
    "use_multiprocessing": False,
    "save_best_model": True,
    "output_dir": output_dir,
    "use_early_stopping": True,
}
# model_type: t5  model_name: Langboat/mengzi-t5-base
model = T5Model(model_type, model_name, args=model_args)

In [None]:
# Tokenize a sample sentence to inspect what the model's tokenizer returns.
model.tokenizer("回答：天上有没有云彩？")

In [None]:
# Decode a list of token ids back into text.
# NOTE(review): these ids look like the output of the tokenizer call in the
# previous cell — confirm.
model.tokenizer.decode([1347, 13, 7995, 2205, 355, 1973, 17, 1])

In [None]:
def predict_now(sentences, model=model, prefix=COUPLET_PROMPOT):
    """Print the given sentences and the model's predictions for them.

    Each sentence is sent to ``model.predict`` as ``prefix + ": " + sentence``.

    Args:
        sentences: Input strings to generate outputs for.
        model: Object exposing ``predict``; defaults to the notebook's model.
        prefix: Text placed in front of every sentence.
    """
    prompts = []
    for sentence in sentences:
        prompts.append(prefix + ": " + sentence)
    print("inputs:", sentences)
    predictions = model.predict(prompts)
    print("outputs:", predictions)


# Baseline prediction with the current model.
predict_now(["灵蛇出洞千山秀"], model=model)

## Training

In [None]:
def sim_text_chars(text1, text2):
    """Return the character-overlap ratio of two strings.

    The score is the number of distinct characters present in BOTH strings
    divided by the length of the longer string, so it always lies in [0, 1].
    Character order is ignored and a repeated character counts once, so a
    string containing duplicates scores below 1.0 even against itself.

    Args:
        text1: First string.
        text2: Second string.

    Returns:
        The overlap ratio as a float; 0.0 if either argument is empty or None.
    """
    if not text1 or not text2:
        return 0.0
    # Intersection, not union: only characters common to both texts are
    # matches. A union would give unrelated strings scores above 1.
    shared = set(text1) & set(text2)
    return len(shared) / max(len(text1), len(text2))

def count_matches(labels, preds):
    """Return the mean ``sim_text_chars`` score over paired labels and predictions.

    Args:
        labels: Reference texts.
        preds: Generated texts, paired with ``labels`` by position.

    Returns:
        The sum of the per-pair scores divided by ``len(labels)``, or 0.0
        when ``labels`` is empty.
    """
    logger.debug(f"labels: {labels[:10]}")
    logger.debug(f"preds: {preds[:10]}")
    total = len(labels)
    if total == 0:
        # Nothing to score; avoid dividing by zero.
        match = 0.0
    else:
        match = sum(sim_text_chars(label, pred) for label, pred in zip(labels, preds)) / total
    logger.debug(f"match: {match}")
    return match


# Fine-tune on train_df, evaluating on eval_df with count_matches passed as
# the "matches" metric.
model.train_model(train_df, eval_data=eval_df, matches=count_matches)

# Evaluate once more after training and show the result.
eval_result = model.eval_model(eval_df, matches=count_matches)
print(eval_result)

In [None]:
# Repeat the earlier sample prediction now that the model has been trained.
predict_now(["灵蛇出洞千山秀"], model=model)

本节完。