In [2]:
import torch
from transformers import T5Config, T5Tokenizer, T5ForConditionalGeneration

In [3]:
t5_config = T5Config(
    vocab_size=32128,
    d_model=768,
    d_kv=64,
    d_ff=2048,
    num_layers=12,
    num_decoder_layers=12,
    num_heads=12,
    relative_attention_num_buckets=32,
    dropout_rate=0.1,
    layer_norm_epsilon=1e-6,
    initializer_factor=1.0,
    feed_forward_proj="gated-gelu",
    is_encoder_decoder=True,
    use_cache=True,
    pad_token_id=0,
    eos_token_id=1,
    decoder_start_token_id=0,
    tie_word_embeddings=False,
    torch_dtype="float32",
    gradient_checkpointing=False)
model = T5ForConditionalGeneration(t5_config)
tokenizer = T5Tokenizer.from_pretrained("Langboat/mengzi-t5-base")

Downloading:   0%|          | 0.00/725k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/659 [00:00<?, ?B/s]

In [9]:
model.load_state_dict(torch.load("../serving/trained_model/20220327_kaggle/pytorch_model.bin", map_location=torch.device('cpu')))

<All keys matched successfully>

In [11]:
model.save_pretrained("../serving/trained_model/GuwenNet")

In [5]:
def generate_classic(text):
    input_ids = tokenizer("转古文：" + text, return_tensors="pt").input_ids
    outputs = model.generate(input_ids, max_length=100)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)

def generate_modern(text):
    input_ids = tokenizer("转现代文：" + text, return_tensors="pt").input_ids
    outputs = model.generate(input_ids, max_length=100)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)

In [10]:
generate_classic("先帝开创的事业没有完成一半，却中途去世了。现在天下分裂成三个国家。蜀汉民力困乏，这实在是危急存亡的时候啊。")

['先帝创业未半而中道,今天下为三国,蜀汉民困,此实危存之时也。']

In [8]:
generate_classic("我们今天一起吃饭")

['其身甚。']

In [15]:
generate_modern("先帝创业未半而中道崩殂。今天下三分，冀州疲敝，此乃危急存亡之秋")

['先帝创业不到一半,中途中途就崩殂,现在下边三分,冀州疲弱,这是危急存亡的秋日。']