# 例子 - 使用 `huggingface/transformers.GPT2LMHeadModel` 进行微调

> ❗ **要点**:
>
> NVIDIA/Megatron-LM 的 `tokenizer` 会在 `128` 对齐的基础上，强制加上 `8` 个特殊 `token`，所以注意 `token` 的偏移！


该 Notebook 的代码大部分来自: <https://github.com/huggingface/transformers/blob/master/examples/run_generation.py>

## 环境变量

- 不需要 `CUDA` 设备:

In [1]:
%env CUDA_VISIBLE_DEVICES=-1

env: CUDA_VISIBLE_DEVICES=-1


## 导入模组

In [2]:
import argparse
import glob
import logging
import os
import pickle
import random
import re
import shutil

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler
from torch.utils.data.distributed import DistributedSampler

try:
    from torch.utils.tensorboard import SummaryWriter
except:
    from tensorboardX import SummaryWriter

from tqdm.auto import tqdm, trange

from transformers import (
    WEIGHTS_NAME, AdamW, WarmupLinearSchedule,
    BertConfig, BertForMaskedLM, BertTokenizer,
    GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
    OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
    RobertaConfig, RobertaForMaskedLM, RobertaTokenizer,
    DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer
)


  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html


## 常量定义

- 最大输出长度

In [None]:
# Number of tokens to generate; also passed to tqdm as the progress-bar total.
MAX_OUTPUT_LENGTH = 128

- GPT2 模型所在目录

In [None]:
# Directory of the GPT-2 checkpoint handed to `GPT2LMHeadModel.from_pretrained`.
MODEL_DIR = '../checkpoints/hfgpt2-117m-emotion'

- SentencePiece 模型文件路径

In [None]:
# Path of the SentencePiece model file loaded by the tokenizer.
SPM_MODEL = '../data/spm/gpt2_huamei_corpus_bpe_32k_v2.model'

## 全局函数定义

In [None]:

def set_seed(args):
    """Seed the NumPy and PyTorch random number generators.

    Args:
        args: Object exposing ``seed`` (the seed value) and ``n_gpu``;
            the CUDA generators are seeded only when ``n_gpu`` is positive.
    """
    seed = args.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if args.n_gpu <= 0:
        return
    torch.cuda.manual_seed_all(seed)


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Mask a 1-D logits tensor using top-k and/or nucleus (top-p) filtering.

    The tensor is modified in place and is also returned.

    Args:
        logits: 1-D tensor of unnormalised scores, one entry per vocabulary item.
        top_k: When > 0, keep only the ``top_k`` highest-scoring entries
            (clamped to the vocabulary size). Entries tied with the k-th
            score are kept as well.
        top_p: When > 0.0, keep the shortest prefix of the score-sorted
            entries whose cumulative probability exceeds ``top_p``
            (nucleus filtering, Holtzman et al., http://arxiv.org/abs/1904.09751).
        filter_value: Value written into every removed position.

    Returns:
        The same ``logits`` tensor with removed positions set to ``filter_value``.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a score lower than the k-th best score.
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        # torch.softmax avoids depending on an `F` alias for torch.nn.functional.
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

        # Mark tokens whose cumulative probability is above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask right so the first token crossing the threshold is kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        # The most probable token is always kept.
        sorted_indices_to_remove[..., 0] = False

        # Map the mask from sorted order back to vocabulary positions.
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits


def isample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0,
                    is_xlnet=False, is_xlm_mlm=False, xlm_mask_token=None, xlm_lang=None, device='cpu'):
    """Generate tokens one at a time, yielding each token as soon as it is chosen.

    Args:
        model: Callable invoked as ``model(input_ids=...)``; element 0 of its
            result is indexed as ``[0, -1, :]`` to obtain the next-token logits.
        length: Number of tokens to generate.
        context: Sequence of token ids used as the prompt.
        num_samples: Number of times the prompt is repeated along the batch axis.
            NOTE(review): only batch row 0 is sampled and the new token is
            concatenated as a single row, so values other than 1 look
            unsupported - confirm before relying on them.
        temperature: Logits are divided by this value when it is > 0;
            ``0`` selects greedy (argmax) decoding.
        top_k: Forwarded to ``top_k_top_p_filtering``.
        top_p: Forwarded to ``top_k_top_p_filtering``.
        repetition_penalty: Penalty applied to the logit of every token id
            already present in the sequence (CTRL, https://arxiv.org/abs/1909.05858).
        is_xlnet, is_xlm_mlm, xlm_mask_token, xlm_lang: Accepted for signature
            compatibility; not used by this function.
        device: Device on which the prompt tensor is created.

    Yields:
        A ``torch.long`` tensor of shape ``(1,)`` holding the newly chosen token id.
    """
    context = torch.tensor(context, dtype=torch.long, device=device)
    generated = context.unsqueeze(0).repeat(num_samples, 1)
    with torch.no_grad():
        for _ in range(length):
            # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
            outputs = model(input_ids=generated)
            next_token_logits = outputs[0][0, -1, :] / (temperature if temperature > 0 else 1.)

            # Repetition penalty from CTRL (https://arxiv.org/abs/1909.05858).
            # Dividing a negative logit would raise its probability, so
            # negative logits are multiplied by the penalty instead.
            for i in set(generated.view(-1).tolist()):
                if next_token_logits[i] < 0:
                    next_token_logits[i] *= repetition_penalty
                else:
                    next_token_logits[i] /= repetition_penalty

            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            if temperature == 0:
                # Greedy decoding: take the highest-scoring token.
                next_token = torch.argmax(filtered_logits).unsqueeze(0)
            else:
                # torch.softmax avoids depending on an `F` alias for torch.nn.functional.
                next_token = torch.multinomial(torch.softmax(filtered_logits, dim=-1), num_samples=1)
            yield next_token
            # Append the new token so the next step conditions on it.
            generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)


## 变量初始化

- SentencePiece tokenizer

In [None]:
# The import cell does not bring `spm` into scope, so import it here.
import sentencepiece as spm

# Load the SentencePiece model used to encode the prompt and decode generated ids.
tokenizer = spm.SentencePieceProcessor()
tokenizer.load(SPM_MODEL)

- GPT2 Model

In [None]:
# Load the GPT-2 language-model checkpoint stored in MODEL_DIR.
model = GPT2LMHeadModel.from_pretrained(MODEL_DIR)
# Switch to evaluation mode (disables dropout); the model is only used for generation here.
model.eval()

## 输入文本

In [None]:
# Prompt text for generation; leading and trailing whitespace is stripped.
context_string = r'''
很多人都在说爱情不需要物质，尤其是陷入爱情的女性，往往容易将爱情和物质对立
'''.strip()
# 日常生活中经常会遇到一些人，说话很直，经常得罪人，但是他们往往自己并不知道，或者说即便知道也好像不太在乎

# Encode the prompt into SentencePiece ids; a NumPy array allows an integer offset to be added element-wise.
context_tokens = np.array(tokenizer.encode_as_ids(context_string))

## 输出文本

In [None]:
# Echo the prompt before streaming the generated continuation.
print(context_string)

# Ids fed to the model are SentencePiece ids shifted up by this amount;
# the shift is applied to the prompt and undone on every generated token.
TOKEN_OFFSET = 8

token_stream = isample_sequence(
    model=model,
    context=context_tokens + TOKEN_OFFSET,
    length=MAX_OUTPUT_LENGTH,
)

for token in tqdm(token_stream, total=MAX_OUTPUT_LENGTH):
    output_id = (token - TOKEN_OFFSET).numpy()
    # Ids that become negative once the offset is removed are skipped
    # (presumably the extra special tokens - confirm).
    kept_ids = [int(i) for i in output_id if i >= 0]
    text = tokenizer.decode_ids(kept_ids)
    print(text, end='')
