| [06_text_generation/02_预训练文本生成模型.ipynb](https://github.com/shibing624/nlp-tutorial/tree/main/06_text_generation/02_预训练文本生成模型.ipynb)  | GPT and XLNet text generation models based on transformers  |[Open In Colab](https://colab.research.google.com/github/shibing624/nlp-tutorial/blob/main/06_text_generation/02_预训练文本生成模型.ipynb) |

# Pretrained Text Generation Models

## English GPT-2 Text Generation Model

In [None]:
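# torch is the backend; sentencepiece is needed by the Chinese XLNet tokenizer used later
!pip install transformers torch sentencepiece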

In [1]:
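import os

# Work around the "OMP: Error #15" crash caused by duplicate OpenMP runtimes
# (commonly seen on macOS); it must be set before torch/transformers are imported.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_name = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
# do_sample=True samples tokens instead of greedy decoding; max_length also counts the prompt tokens
outputs = text_generator("hi lili, what is your ", max_length=50, do_sample=True)
outputs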

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


[{'generated_text': 'hi lili, what is your ??????????????????????????????????????????????????????????????????????????????????????????????????????????????????????????????????????????????????'}]
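With `do_sample=True`, the pipeline draws each next token at random from the model's predicted distribution instead of always taking the single most likely token (greedy decoding), which is why repeated runs produce different text. The pure-Python sketch below contrasts the two strategies on a hypothetical toy logits vector; `softmax`, `greedy_next`, and `sampled_next` are illustrative helpers, not transformers APIs.

```python
import math
import random
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Turn raw scores into probabilities (max subtracted for numerical stability)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def greedy_next(logits: List[float]) -> int:
    """Greedy decoding: always pick the highest-scoring token id."""
    return max(range(len(logits)), key=lambda i: logits[i])


def sampled_next(logits: List[float], rng: random.Random) -> int:
    """Sampling: draw token id i with probability softmax(logits)[i]."""
    probs = softmax(logits)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


if __name__ == "__main__":
    toy_logits = [2.0, 1.0, 0.5]  # hypothetical scores for a 3-token vocabulary
    rng = random.Random(0)
    print("greedy: ", [greedy_next(toy_logits) for _ in range(5)])
    print("sampled:", [sampled_next(toy_logits, rng) for _ in range(5)])
```

Greedy decoding returns the same id every time, while sampling mostly returns the top token but occasionally picks the others in proportion to their probability.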

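Model hub for transformers: https://huggingface.co/models. GPT-2 model page: https://huggingface.co/gpt2, which documents how to use the model in detail. To fine-tune on your own dataset, first download the pretrained model and then fine-tune it on your data for 2-3 epochs.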
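For causal language models such as GPT-2, fine-tuning data is commonly prepared by concatenating the tokenized texts and cutting them into fixed-length blocks; the Hugging Face `run_clm.py` example does this in a function called `group_texts` and sets the labels to a copy of `input_ids`, because `GPT2LMHeadModel` shifts the labels internally. The pure-Python sketch below mimics that preparation step on plain lists of token ids and makes the one-token shift explicit. `group_texts` and `shifted_pair` here are hypothetical re-implementations for illustration, not library functions.

```python
from typing import List, Tuple


def group_texts(docs: List[List[int]], block_size: int) -> List[List[int]]:
    """Concatenate tokenized documents and cut them into equal-length blocks.

    Leftover tokens that do not fill a whole block are dropped.
    """
    concatenated = [tok for doc in docs for tok in doc]
    usable = (len(concatenated) // block_size) * block_size
    return [concatenated[i:i + block_size] for i in range(0, usable, block_size)]


def shifted_pair(block: List[int]) -> Tuple[List[int], List[int]]:
    """What the causal LM loss compares: predict block[t + 1] from block[: t + 1]."""
    return block[:-1], block[1:]


if __name__ == "__main__":
    docs = [[1, 2, 3], [4, 5, 6, 7], [8]]  # hypothetical token ids of three documents
    for block in group_texts(docs, block_size=3):
        inputs, targets = shifted_pair(block)
        print("inputs:", inputs, "targets:", targets)
```

In an actual training loop you would pass each block as both `input_ids` and `labels` and repeat over the dataset for the chosen number of epochs.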

## Chinese GPT-2 Model

A model trained on Chinese chat data: liam168/chat-DialoGPT-small-zh

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = 'liam168/chat-DialoGPT-small-zh'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# "What grade are you in?", "When I was in college, I was the student union president"
texts = ['你上几年级了？', '我上大学时是学生会主席']

for step, text in enumerate(texts):
    # encode the new user input, append the eos_token, and return a PyTorch tensor
    new_user_input_ids = tokenizer.encode(text + tokenizer.eos_token, return_tensors='pt')

    # append the new user input tokens to the chat history (from the second turn on)
    bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids

    # generate a response; max_length=1000 caps the total sequence (history + reply), it does not trim the history
    chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)

    # print only the newly generated tokens (the bot's reply)
    print("Answer: {}".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))

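In the loop above, every turn is appended to `chat_history_ids`, so the context keeps growing; `max_length=1000` only caps the total length of history plus reply passed through `generate`. One common workaround, sketched below, is to drop the oldest whole turns so the input stays within a token budget. The sketch assumes you keep each turn (user and bot) as a separate list of token ids; `build_bot_input` and `max_context` are hypothetical names. To feed the result to the model you would wrap it as `torch.tensor([ids])`, pass `tokenizer.eos_token_id` as `eos_id`, and pick `max_context` small enough to leave room for the reply.

```python
from typing import List


def build_bot_input(
    history: List[List[int]],
    new_turn: List[int],
    eos_id: int,
    max_context: int,
) -> List[int]:
    """Join past turns and the new user turn, each followed by eos_id.

    Whole turns are dropped from the oldest end until the result fits in
    max_context tokens; the newest turn is always kept, even if it alone
    is longer than max_context.
    """
    kept: List[List[int]] = []
    used = 0
    # Walk from the newest turn backwards so recent context survives first.
    for turn in reversed(history + [new_turn]):
        cost = len(turn) + 1  # +1 for the eos separator
        if kept and used + cost > max_context:
            break
        kept.append(turn + [eos_id])
        used += cost
    kept.reverse()
    return [tok for turn in kept for tok in turn]


if __name__ == "__main__":
    past = [[1, 2], [3, 4, 5]]  # hypothetical token ids of two earlier turns
    print(build_bot_input(past, [6], eos_id=0, max_context=6))
```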

## Chinese XLNet Model


In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_name = 'hfl/chinese-xlnet-base'
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
# prompt: "My dad is a police officer"
print(text_generator("我的爸爸是警察", max_length=50, do_sample=True))

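XLNet struggles with short prompts, so for text generation the prompt needs padding: a long passage is prepended before generating and stripped from the output afterwards.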
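Apart from the model call, this padding trick is plain bookkeeping: generate from padding + prompt, cut the padding and the echoed prompt off by token position, and re-attach the prompt. The sketch below shows that on lists of token ids. `generate_with_padding` and the `fake_generate` stand-in are hypothetical, and the sketch assumes, as with decoder-style `model.generate()` in transformers, that the generator returns its input followed by the new tokens.

```python
from typing import Callable, List


def generate_with_padding(
    generate: Callable[[List[int]], List[int]],
    padding_ids: List[int],
    prompt_ids: List[int],
) -> List[int]:
    """Run generate on padding + prompt and return prompt + continuation only."""
    context = padding_ids + prompt_ids
    output = generate(context)
    # Assumption: generate echoes its whole input before the new tokens.
    if output[: len(context)] != context:
        raise ValueError("generate() did not echo its input")
    continuation = output[len(context):]
    return prompt_ids + continuation


def fake_generate(ids: List[int]) -> List[int]:
    """Hypothetical stand-in for a model: appends three fixed token ids."""
    return ids + [7, 8, 9]


if __name__ == "__main__":
    print(generate_with_padding(fake_generate, padding_ids=[100, 101, 102], prompt_ids=[1, 2]))
```

Slicing by token count rather than by the length of a decoded string keeps the cut aligned even when special tokens inside the padding decode differently.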

In [None]:
# Padding text helps XLNet with short prompts - proposed by Aman Rusia in https://github.com/rusiaaman/XLNet-gen#methodology
PADDING_TEXT = """1991年，俄国沙皇尼古拉斯二世及其家人的遗体
（除了阿列克谢和玛丽亚）被发现。
尼古拉斯的小儿子沙雷维奇·阿列克谢·尼古拉耶维奇的声音讲述了故事的其余部分。1883年西伯利亚西部，
一个年轻的格里戈里·拉斯普京被他的父亲和一群人邀请表演魔术。
拉斯普京有远见，谴责其中一人是偷马贼。虽然他的父亲最初因为这样的指控而打了他一巴掌，但拉斯普金看着这名男子被追赶到外面并被殴打。
二十年后，拉斯普京看到了圣母玛利亚的幻象，促使他成为一名牧师。拉斯普京很快成名，人们，甚至主教，都在乞求他的祝福。<eod> </s> <eos>"""
prompt = "今天俄国人开始在西伯利亚表演"  # "Today the Russians began performing in Siberia"
inputs = tokenizer.encode(PADDING_TEXT + prompt, add_special_tokens=False, return_tensors="pt")
# max_length counts the padding tokens as well as the newly generated ones
outputs = model.generate(inputs, max_length=250, do_sample=True, top_p=0.95, top_k=60)
# keep only the newly generated token ids, so the padding is dropped regardless of how special tokens decode
new_tokens = outputs[0][inputs.shape[-1]:]
generated = prompt + tokenizer.decode(new_tokens, skip_special_tokens=True)
print(generated)
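The `generate` call above combines `top_k=60` and `top_p=0.95`: top-k keeps only the k most probable next tokens, and top-p (nucleus sampling) then keeps the smallest set of those whose cumulative probability reaches p, before sampling from the renormalized result. The pure-Python sketch below applies the two filters in that order (assumed here to match how transformers chains its logits warpers) to a hypothetical probability table; `top_k_top_p_filter` and `sample_token` are illustrative helpers, not library functions.

```python
import random
from typing import Dict, List, Tuple


def top_k_top_p_filter(probs: Dict[int, float], top_k: int, top_p: float) -> Dict[int, float]:
    """Apply top-k, then top-p (nucleus) filtering to a token -> probability map.

    Returns a renormalized distribution over the surviving token ids.
    """
    # Rank token ids from most to least probable.
    ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)
    # Top-k: keep only the k most probable tokens (top_k <= 0 disables it).
    if top_k > 0:
        ranked = ranked[:top_k]
    total = sum(p for _, p in ranked)
    ranked = [(tok, p / total) for tok, p in ranked]
    # Top-p: keep the shortest prefix whose cumulative probability reaches top_p.
    kept: List[Tuple[int, float]] = []
    cumulative = 0.0
    for tok, p in ranked:
        kept.append((tok, p))
        cumulative += p
        if cumulative >= top_p:
            break
    total = sum(p for _, p in kept)
    return {tok: p / total for tok, p in kept}


def sample_token(probs: Dict[int, float], rng: random.Random) -> int:
    """Draw one token id according to the (already filtered) probabilities."""
    tokens = list(probs)
    return rng.choices(tokens, weights=[probs[t] for t in tokens], k=1)[0]


if __name__ == "__main__":
    toy = {0: 0.5, 1: 0.3, 2: 0.15, 3: 0.05}  # hypothetical next-token probabilities
    filtered = top_k_top_p_filter(toy, top_k=3, top_p=0.8)
    print(filtered)
    print(sample_token(filtered, random.Random(0)))
```

In this toy case top-k drops token 3, and top-p then keeps only tokens 0 and 1, because together they already cover more than 80% of the remaining probability mass.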