# 例子 - 使用 `huggingface/transformers.GPT2LMHeadModel` 进行预测

> ❗ **要点**:
>
> NVIDIA/Megatron-LM 的 `tokenizer` 会在 `128` 对齐的基础上，强制加上 `8` 个特殊 `token`，所以注意 `token` 的偏移！


该 Notebook 的代码大部分来自: <https://github.com/huggingface/transformers/blob/master/examples/run_generation.py>

## 工作目录

In [None]:
# Change to the parent directory. Presumably this is the project root, so the relative
# paths used later (`checkpoints/`, `data/`, the `data_utils` package) resolve from there.
# Confirm against where this notebook lives.
%cd ..

## 环境变量

- 预测时，不需要 `CUDA` 设备:

In [None]:
# Hide all GPUs from CUDA: prediction here runs on the CPU only.
%env CUDA_VISIBLE_DEVICES=-1

## 导入模组

In [None]:
import json
import os
from contextlib import closing, ExitStack
from functools import partial

import torch
import torch.nn.functional as F
import numpy as np
from tqdm.auto import trange, tqdm

import sentencepiece as spm
from transformers import GPT2LMHeadModel, GPT2Config

from data_utils.tokenization import SentencePieceTokenizer, make_tokenizer

## 常量定义

- 最大输出长度

In [None]:
# Maximum number of tokens generated per prompt (used as the `length` passed to `predict`).
MAX_OUTPUT_LENGTH: int = 64

- GPT2 模型所在目录

In [None]:
# Checkpoint directories to load with `GPT2Config/GPT2LMHeadModel.from_pretrained`.
# Every listed model is run on the same inputs; uncomment entries to compare models.
MODEL_DIRS = [
    'checkpoints/tr_191023',
#     'checkpoints/hfgpt2-117m-emotion',
#     'checkpoints/gpt2.hf-117m.finetune-baike-dev/checkpoint-100',
]

- SentencePiece 模型文件路径

In [None]:
# SentencePiece model file used to build the tokenizer (path relative to the working directory).
SPM_MODEL = 'data/spm/gpt2_huamei_corpus_bpe_32k_v2.model'

## 全局函数定义

In [None]:

def set_seed(args):
    """Seed the NumPy and PyTorch random generators from ``args.seed``.

    Args:
        args: any object exposing ``seed`` and ``n_gpu`` attributes. When
            ``n_gpu`` is positive, every CUDA device is seeded as well.
    """
    seed = args.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if args.n_gpu <= 0:
        return
    torch.cuda.manual_seed_all(seed)


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Mask unlikely entries of a single logits vector (top-k and/or nucleus filtering).

    Removed positions are overwritten *in place* with ``filter_value`` and the
    same tensor object is returned. Top-k runs first; top-p then works on the
    distribution that survives it.

    Args:
        logits: 1-D logits over the vocabulary (batched input trips the assert).
        top_k: if > 0, keep only the ``top_k`` largest logits; values tied with
            the k-th largest are kept too. Clamped to the vocabulary size.
        top_p: if > 0.0, keep the highest-probability tokens up to and including
            the first one at which the cumulative probability exceeds ``top_p``
            (nucleus sampling, Holtzman et al., http://arxiv.org/abs/1904.09751).
        filter_value: value written into removed positions (default ``-inf``).

    Returns:
        ``logits`` itself, after masking.

    Adapted from https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    # Single distribution only; batching would need per-row index handling.
    assert logits.dim() == 1
    vocab_size = logits.size(-1)
    k = min(top_k, vocab_size)
    if k > 0:
        kth_largest = torch.topk(logits, k)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value

    if top_p > 0.0:
        ordered, order = torch.sort(logits, descending=True)
        cum_probs = F.softmax(ordered, dim=-1).cumsum(dim=-1)
        drop_sorted = cum_probs > top_p
        # Shift right by one so the token that crosses the threshold survives,
        # and never drop the single most likely token.
        drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
        drop_sorted[..., 0] = 0
        logits[order[drop_sorted]] = filter_value
    return logits


def gen_next_token(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device='cpu'):
    """Lazily generate up to ``length`` tokens continuing ``context``, one per step.

    The full sequence generated so far is re-fed to ``model`` on every step
    (no ``past`` key/value cache), so each step costs a full forward pass.

    Args:
        model: callable accepting ``input_ids`` and returning a sequence whose
            first element is logits indexed ``[batch, position, vocab]``
            (e.g. a ``GPT2LMHeadModel``).
        length: maximum number of tokens to produce.
        context: prompt token ids; must convert to a 1-D long tensor.
        num_samples: number of rows the prompt is repeated into. NOTE(review):
            only row 0's logits are used and the concatenation below assumes a
            single row, so values other than 1 currently raise in ``torch.cat``.
        temperature: logits are divided by it when > 0; ``0`` means greedy
            (argmax) decoding.
        top_k, top_p: forwarded to ``top_k_top_p_filtering``.
        repetition_penalty: CTRL-style penalty (> 1 discourages repeats),
            applied to every token id already in the sequence, prompt included.
        device: device on which the context tensor is created.

    Yields:
        A 0-d long tensor with each newly chosen token id. Stopping early
        (e.g. on EOS) is up to the caller, who should close the generator.
    """
    context = torch.tensor(context, dtype=torch.long, device=device)
    generated = context.unsqueeze(0).repeat(num_samples, 1)
    with torch.no_grad():
        for _ in range(length):
            outputs = model(input_ids=generated)  # no 'past' caching: whole sequence every step
            # Last position of row 0. The division also produces a fresh tensor,
            # so the in-place edits below never touch the model's own output.
            next_token_logits = outputs[0][0, -1, :] / (temperature if temperature > 0 else 1.)

            # Repetition penalty from CTRL (https://arxiv.org/abs/1909.05858).
            # Dividing a *negative* logit by a penalty > 1 moves it toward zero,
            # i.e. makes the repeated token MORE likely, so negative logits are
            # multiplied instead. A penalty of exactly 1.0 is a no-op and skipped.
            if repetition_penalty != 1.0:
                for token_id in set(generated.view(-1).tolist()):
                    if next_token_logits[token_id] < 0:
                        next_token_logits[token_id] *= repetition_penalty
                    else:
                        next_token_logits[token_id] /= repetition_penalty

            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            if temperature == 0:  # greedy decoding
                next_token = torch.argmax(filtered_logits).unsqueeze(0)
            else:
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
            yield next_token[0]


def predict(model, tokenizer, ids, length=256, printing=True, tqdm_callable=None, **gen_kwargs):
    """Generate a continuation of ``ids`` and return it as decoded text.

    Tokens are pulled from ``gen_next_token`` one at a time and decoded
    individually, so the output can be streamed to stdout as it is produced.
    Generation stops after ``length`` tokens or at the first EOS/PAD id.

    Args:
        model: language model handed to ``gen_next_token``.
        tokenizer: project tokenizer exposing ``get_command(name).Id`` and
            ``DecodeIds``.
        ids: prompt token ids.
        length: maximum number of tokens to generate.
        printing: if true, print each decoded piece (no newline) as it arrives.
        tqdm_callable: optional wrapper applied to the token iterator, e.g. a
            ``tqdm`` progress-bar factory.
        **gen_kwargs: extra sampling options forwarded to ``gen_next_token``
            (``temperature``, ``top_k``, ``top_p``, ``repetition_penalty``,
            ``device``, ...).

    Returns:
        The generated text, excluding the prompt and the stop token.
    """
    stop_ids = {tokenizer.get_command(name).Id for name in ('eos', 'pad')}
    pieces = []
    # closing() shuts the generator down even when we break out early.
    with closing(
        gen_next_token(model=model, length=length, context=ids, **gen_kwargs)
    ) as tokens:
        iterable = tqdm_callable(tokens) if tqdm_callable else tokens
        for token in iterable:
            # .item() works for CPU and CUDA tensors alike (.numpy() is CPU-only).
            output_id = token.item()
            if output_id in stop_ids:
                break
            piece = tokenizer.DecodeIds([output_id])
            pieces.append(piece)
            if printing:
                print(piece, end='')
    return ''.join(pieces)


## 变量初始化

- SentencePiece tokenizer

In [None]:
%%time

# Build the project's SentencePiece-based tokenizer from the model file at SPM_MODEL.
# NOTE(review): the meaning of the second positional argument (None) of make_tokenizer
# is not visible here; confirm against data_utils.tokenization.
tokenizer = make_tokenizer(SentencePieceTokenizer, None, model_path=SPM_MODEL)

- GPT2 Model

In [None]:
%%time

# Load every checkpoint listed in MODEL_DIRS, keyed by its directory path.
models = {}
for model_dir in MODEL_DIRS:
    print(f'Load model from {model_dir} ...')   
    config = GPT2Config.from_pretrained(model_dir)
    # eval() switches the module to inference mode (training=False) and returns it.
    model = GPT2LMHeadModel.from_pretrained(model_dir, config=config).eval()
    models[model_dir] = model
    print('Ok')


## 预测多个短文本

### 输入文本

In [None]:
# Alternative sample prompts (kept verbatim as ready-to-use Chinese inputs):
# 很多人都在说爱情不需要物质，尤其是陷入爱情的女性，往往容易将爱情和物质对立
#
# 日常生活中经常会遇到一些人，说话很直，经常得罪人，但是他们往往自己并不知道，或者说即便知道也好像不太在乎

# 南京长江大桥是长江上的一座桥梁
# 我最近迷恋上了早白垩纪土伦阶恐龙演化的相关知识，整天想得都是兽脚类，鸟臀类什么的，是不是心里有问题？很幼稚？能说说你对这个地质年代的知识吗？
# 广府文化是广府民系的文化。是以广州为核心、以珠江三角洲为通行范围的粤语文化，它从属于岭南文化，在岭南文化中个性最鲜明、影响最大，在各个领域常被作为粤文化的代称。

# Prompts to continue with every loaded model.
input_texts = [
    '日常生活中经常会遇到一些人，说话很直，经常得罪人',
    '南京市长江大桥位于江苏',
    '广府文化是广府民系的文化。是以广州为核心、以珠江三角洲为通行范围的粤语文化',
    '很多人都在说爱情不需要物质，尤其是陷入爱情的女性',    
]

# Token ids of each (stripped) prompt, converted to plain Python ints; parallel to input_texts.
input_list = [
    [int(n) for n in tokenizer.EncodeAsIds(s.strip())]
    for s in input_texts
]

### 输出文本

In [None]:
# For each model, print every prompt followed by its streamed continuation,
# with dashed rules between prompts and a double rule between models.
for model_dir, model in models.items():
    print(f'Predict with {model_dir} ')
    print()

    for prompt, prompt_ids in zip(input_texts, input_list):
        print(prompt)

        predict(model, tokenizer, prompt_ids, length=MAX_OUTPUT_LENGTH)
        # Blank line, dashed rule, blank line.
        print('', '-' * 100, '', sep='\n')

    # Blank line, double rule, blank line.
    print('', '=' * 100, '', sep='\n')


## 文件预测

In [None]:
import csv

input_file = 'data/test_xinli_qax_convai.json'
output_file = 'data/test_xinli_qax_convai_预测结果-1024_04.tsv'

n_gen = 3  # number of continuations generated per input line
length = MAX_OUTPUT_LENGTH

# Count lines up front so the progress bar can show a total; the `with`
# closes the handle (a bare open() inside sum() would leak it).
with open(input_file, encoding='utf-8') as fp_count:
    total = sum(1 for _ in fp_count)

# Split the base name once. Re-splitting the previous per-model name on every
# iteration would chain suffixes (`x.1.tsv`, then `x.1.2.tsv`, ...).
output_root, output_ext = os.path.splitext(output_file)
for i, (model_dir, model) in enumerate(models.items()):
    print(f'Predict with {model_dir} ')
    print()

    # One TSV per model: <root>.<model number><ext>.
    model_output_file = f'{output_root}.{i+1}{output_ext}'

    with ExitStack() as stack:
        # One JSON object per line with a `text` field; JSON text is UTF-8.
        fp_input = stack.enter_context(open(input_file, encoding='utf-8'))
        # newline='' as the csv module requires, so the writer controls line endings.
        fp_output = stack.enter_context(open(model_output_file, 'w', encoding='utf-8', newline=''))
        writer = csv.writer(fp_output, delimiter='\t')

        for j, line in tqdm(enumerate(fp_input), total=total):
            line = line.strip()
            if not line:
                continue
            data = json.loads(line)
            text = (data.get('text') or '').strip()
            if not text:
                continue
            ids = [int(n) for n in tokenizer.EncodeAsIds(text)]
            # Skip prompts whose prompt + generated tokens could exceed the
            # model's context window. `continue`, not `break`: one long line
            # must not abort the rest of the file.
            if len(ids) + length > model.config.n_ctx:
                continue
            # Row layout: prompt text, then n_gen generated continuations.
            row = [text]
            for k in range(n_gen):
                # Bind the progress-bar label now rather than via a lambda that
                # reads the loop variables j/k at call time.
                tqdm_callable = partial(tqdm, desc=f'sample {j+1} ({k+1}/{n_gen})', leave=False)
                s_gen = predict(model, tokenizer, ids, length, printing=False, tqdm_callable=tqdm_callable)
                row.append(s_gen)
            writer.writerow(row)
           