# 使用 Megatron 进行预测

## CD

切换到工作目录。请根据实际情况调整路径，不一定是下面这条命令。

In [None]:
%cd ..

## Imports

In [None]:
import copy
import csv
import json
import os
import random
import sys
import time
from contextlib import closing
from types import SimpleNamespace

import numpy as np
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm

import mpu
from data_utils.tokenization import SentencePieceTokenizer, make_tokenizer
from pretrain_gpt2 import get_masks_and_position_ids
from predict_gpt2 import initialize_distributed, prepare_tokenizer, set_random_seed, setup_model, get_token_stream

## Args

In [None]:
# Settings namespace handed to the helpers imported from predict_gpt2
# (initialize_distributed, prepare_tokenizer, setup_model, get_token_stream).
# The options are listed as a mapping and unpacked, so the attribute names,
# values and their order are exactly those of a keyword-argument call.
args = SimpleNamespace(**{
    # ---- Model ----
    'num_layers': 12,
    'hidden_size': 768,
    'num_attention_heads': 12,
    'max_position_embeddings': 1024,
    'vocab_size': None,
    'make_vocab_size_divisible_by': 128,
    'attention_dropout': 0.1,
    'hidden_dropout': 0.1,
    # ---- Train / valid / test data ----
    'seq_length': 512,
    'model_parallel_size': 1,
    'tokenizer_model_type': 'bert-large-uncased',
    'tokenizer_type': SentencePieceTokenizer,
    'tokenizer_path': "./data/spm/gpt2_huamei_corpus_bpe_32k_v2.model",
    'cache_dir': None,
    # ---- Training ----
    'load': './checkpoints/gpt2-117m-emotion.finetune/',
    'seed': 1234,
    'checkpoint_activations': None,
    'checkpoint_num_layers': 1,
    'finetune': None,
    'no_load_optim': None,
    'no_load_rng': None,
    'resume_dataloader': None,
    'fp16': True,
    'hysteresis': 2,
    # None here is what later switches on dynamic loss scaling.
    'loss_scale': None,
    'loss_scale_window': 1000,
    'min_scale': 1,
    'distributed_backend': 'nccl',
    'DDP_impl': 'local',
    'local_rank': None,
    'reset_position_ids': None,
    'reset_attention_mask': None,
    'eod_mask_loss': None,
    # ---- Text generation ----
    'recompute': None,
    'greedy': False,
    'top_p': 0.0,
    'top_k': 0,
    'temperature': 1.0,
    'out_seq_length': 128,
})

In [None]:
# Derive the runtime fields of `args` from the machine and the environment.
# Without RANK / WORLD_SIZE set, this falls back to rank 0 in a world of 1.
args.cuda = torch.cuda.is_available()
args.rank = int(os.environ.get('RANK', '0'))
args.world_size = int(os.environ.get("WORLD_SIZE", '1'))

if os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK'):
    # Launched through (OpenMPI) mpirun: take the per-node rank and size
    # from the OpenMPI variables instead.
    local_rank = int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK'))
    local_size = int(os.environ.get('OMPI_COMM_WORLD_LOCAL_SIZE'))

    # Slurm variables, when present, give the node count and this node's id.
    num_nodes = int(os.environ.get('SLURM_JOB_NUM_NODES', '1'))
    nodeid = int(os.environ.get('SLURM_NODEID', '0'))

    args.world_size = num_nodes*local_size
    args.rank = nodeid*local_size + local_rank
    args.local_rank = local_rank

# The model-parallel group can never be larger than the world itself.
if args.model_parallel_size > args.world_size:
    args.model_parallel_size = args.world_size
if args.rank == 0:
    print(
        f'using world size: {args.world_size} '
        f'and model-parallel size: {args.model_parallel_size} '
    )

# Dynamic loss scaling is used exactly when no fixed loss scale was given.
args.dynamic_loss_scale = args.loss_scale is None
if args.dynamic_loss_scale and args.rank == 0:
    print(' > using dynamic loss scaling')

# The fp32_* switches only matter together with fp16, so they are forced
# off when fp16 is disabled (and left untouched otherwise).
if not args.fp16:
    args.fp32_embedding = args.fp32_tokentypes = args.fp32_layernorm = False


## Init

In [None]:
%%time

# One-time session setup. The steps below run in a fixed order; each later
# step receives the same `args` namespace.

# Disable cuDNN for this session.
torch.backends.cudnn.enabled = False

# Set up PyTorch distributed state through the project helper.
initialize_distributed(args)

# Seed through the project helper, for reproducibility.
set_random_seed(args.seed)

# Build the tokenizer through the project helper.
tokenizer = prepare_tokenizer(args)

# Build the model; only the model object is kept here.
# NOTE(review): presumably setup_model restores weights from args.load — confirm.
model = setup_model(args)

# Record the index of the currently selected CUDA device.
args.device = torch.cuda.current_device()

# Default the batch size to 1.
args.batch_size = 1

# Stop here unless this process is model-parallel rank 0.
assert mpu.get_model_parallel_rank() == 0

## Inference functions

In [None]:
def infer_generative(contex_text, model, tokenizer):
    """Stream generated text for one prompt, one decoded piece at a time.

    The prompt is stripped, encoded with ``tokenizer.EncodeAsIds`` and passed
    as a single-element batch to ``get_token_stream`` together with the
    module-level ``args``. For every item the stream produces, the last id of
    the first row is decoded on its own and yielded.

    A prompt whose token count is at least half of ``args.seq_length`` only
    triggers a warning on stderr; generation is still attempted.

    Args:
        contex_text: Prompt text; surrounding whitespace is ignored.
        model: Model object forwarded to ``get_token_stream``.
        tokenizer: Object providing ``EncodeAsIds`` and ``DecodeIds``.

    Yields:
        The result of ``tokenizer.DecodeIds`` for each newest token id.
    """
    prompt_ids = tokenizer.EncodeAsIds(contex_text.strip()).tokenization
    prompt_len = len(prompt_ids)

    if args.seq_length//2 <= prompt_len:
        warning = (
            f"\nContext length {prompt_len}"
            "\nPlease give smaller context (half of the sequence length)!"
        )
        print(warning, file=sys.stderr)

    for step_tokens, _ in get_token_stream(model, [prompt_ids], tokenizer, args):
        # Row 0 is the only sequence in the batch; its last entry is the
        # id that gets decoded for this step.
        newest_id = step_tokens.cpu().numpy().tolist()[0][-1]
        yield tokenizer.DecodeIds([newest_id])

## Infer texts

In [None]:
# Longer candidate prompts kept for reference. They are model input rather
# than commentary, so they are left verbatim in their original language.
# 很多人都在说爱情不需要物质，尤其是陷入爱情的女性，往往容易将爱情和物质对立
#
# 日常生活中经常会遇到一些人，说话很直，经常得罪人，但是他们往往自己并不知道，或者说即便知道也好像不太在乎

# 南京长江大桥是长江上的一座桥梁
# 我最近迷恋上了早白垩纪土伦阶恐龙演化的相关知识，整天想得都是兽脚类，鸟臀类什么的，是不是心里有问题？很幼稚？能说说你对这个地质年代的知识吗？
# 广府文化是广府民系的文化。是以广州为核心、以珠江三角洲为通行范围的粤语文化，它从属于岭南文化，在岭南文化中个性最鲜明、影响最大，在各个领域常被作为粤文化的代称。

# Prompts actually fed to the generator; commented-out entries are disabled.
input_texts = [
    # '日常生活中经常会遇到一些人，说话很直，经常得罪人', # emotion
    '南京市长江大桥位于江苏', # encyclopedia
    '广府文化是广府民系的文化。是以广州为核心、以珠江三角洲为通行范围的粤语文化', # encyclopedia
    # '很多人都在说爱情不需要物质，尤其是陷入爱情的女性',    # emotion
    '啤酒作为一种风味独特的酒精饮料，苦而爽口，幽香清雅'
]


In [None]:
# Number of continuations generated for each prompt.
n_gen = 3

for prompt in input_texts:
    # Echo the prompt, then leave one blank line before the samples.
    print(prompt, end='\n\n')
    for sample_no in range(1, n_gen + 1):
        print(f'{sample_no}) ', end='')
        # Print each piece as soon as it is produced, all on the same line.
        for piece in infer_generative(prompt, model, tokenizer):
            print(piece, end='')
        print()
    # Blank line, a 100-character rule, and another blank line.
    print('\n' + '=' * 100 + '\n')
    