<a href="https://colab.research.google.com/github/hululuzhu/chinese-ai-writing-share/blob/main/inference/2021_transformer_supervised_AI_Writing_Demo_06_2021.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Connect to Google Drive to load the [models and vocabs](https://drive.google.com/drive/folders/1d5vk9nrse4lJ55wb5zsW2wgodkwWb-V2?usp=sharing), then initialize

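- Run all cells, then try your own examples by replacing the input characters.
- Note: `top_k=1` and `temperature=1.0` are set for reproducibility; experiment with other inference parameters when you run it.
- Important: the [model files](https://drive.google.com/drive/folders/1d5vk9nrse4lJ55wb5zsW2wgodkwWb-V2?usp=sharing) are stored on Google Drive. Open the link with a Google account and click `Add to shortcut`; once the folder appears under `shared with me` on your Drive home page, choose `add shortcut to Drive`. After mounting, you can work with the files as if they were local, but make sure the path matches the one used in this notebook.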
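To give a feel for what the `top_k` and `temperature` parameters control, here is a minimal sketch of top-k sampling with temperature in plain NumPy. It is a generic illustration, not the `keras-transformer` implementation; the function name `sample_top_k` and the toy logits are hypothetical.

```python
import numpy as np

def sample_top_k(logits, top_k=1, temperature=1.0, rng=None):
    """Sample one token id from the top_k highest-scoring entries of logits.

    Hypothetical helper for illustration only.
    """
    if rng is None:
        rng = np.random.default_rng()
    # Higher temperature flattens the distribution, lower sharpens it.
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    # Keep only the ids of the top_k largest scores.
    top_ids = np.argsort(scaled)[-top_k:]
    top_scores = scaled[top_ids]
    # Softmax over the kept scores (shifted by the max for numerical stability).
    probs = np.exp(top_scores - top_scores.max())
    probs /= probs.sum()
    return int(rng.choice(top_ids, p=probs))

toy_logits = [0.1, 2.0, 0.5]  # made-up scores for a 3-token vocabulary
print(sample_top_k(toy_logits, top_k=1))  # 1
```

In this sketch, `top_k=1` always returns the highest-scoring id regardless of temperature, which is why that setting gives repeatable output. Larger `top_k` values let lower-ranked tokens through, and the result then varies from run to run.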

## Imports

In [2]:
import pickle
import os
import pandas as pd
import numpy as np
!pip install keras-transformer &> /dev/null
os.environ['TF_KERAS'] = '1'
from keras_transformer import get_model, decode, get_custom_objects
import tensorflow as tf

## Load Configs

In [None]:
# Mount Google Drive if you have your own copies of the model and config files
from google.colab import drive
drive.mount('/content/drive')

In [3]:
# Copy the folder at https://drive.google.com/drive/folders/1d5vk9nrse4lJ55wb5zsW2wgodkwWb-V2 to your Drive and set MODEL_DIR to its path
MODEL_DIR = 'drive/MyDrive/ML/Models/szhu_public_062021/'

In [4]:
# If this fails, copy all files from the Google Drive folder linked at the top and mount your Drive in Colab
!ls {MODEL_DIR}

couplet_model_config.pickle  couplet_vocab.pickle      poem_model.h5
couplet_model.h5	     poem_model_config.pickle  poem_vocab.pickle


In [5]:
with open(os.path.join(MODEL_DIR, 'couplet_model_config.pickle'), 'rb') as handle:
  couplet_model_config = pickle.load(handle)
with open(os.path.join(MODEL_DIR, 'couplet_vocab.pickle'), 'rb') as handle:
  couplet_vocab_dict = pickle.load(handle)
with open(os.path.join(MODEL_DIR, 'poem_model_config.pickle'), 'rb') as handle:
  poem_model_config = pickle.load(handle)
with open(os.path.join(MODEL_DIR, 'poem_vocab.pickle'), 'rb') as handle:
  poem_vocab_dict = pickle.load(handle)

In [6]:
rev_couplet_vocab_dict = {v: k for k, v in couplet_vocab_dict.items()}
rev_poem_vocab_dict = {v: k for k, v in poem_vocab_dict.items()}
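The reversed dictionaries map token ids back to characters, so text can make a round trip through the vocab. Here is a minimal sketch with a toy vocabulary; the ids and the helper names `toy_encode` and `toy_decode` are assumptions for illustration, not the contents of the real pickles.

```python
# Toy vocabulary; the ids are assumptions chosen for illustration.
toy_vocab = {'<PAD>': 0, '<START>': 1, '<END>': 2, '秋': 3, '思': 4}
rev_toy_vocab = {v: k for k, v in toy_vocab.items()}

def toy_encode(text, vocab):
    """Wrap the character ids of text in <START> and <END> markers."""
    return [vocab['<START>']] + [vocab[c] for c in text] + [vocab['<END>']]

def toy_decode(token_ids, rev_vocab):
    """Map ids back to characters, skipping the special tokens."""
    specials = {'<PAD>', '<START>', '<END>'}
    return ''.join(rev_vocab[i] for i in token_ids
                   if rev_vocab[i] not in specials)

ids = toy_encode('秋思', toy_vocab)
print(ids)                             # [1, 3, 4, 2]
print(toy_decode(ids, rev_toy_vocab))  # 秋思
```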

In [7]:
assert 9133 == len(couplet_vocab_dict)
assert 11289 == len(poem_vocab_dict)

## Initialize models and helper methods

In [8]:
couplet_model = get_model(
    embed_weights=np.random.random((len(couplet_vocab_dict),
                                    couplet_model_config['embed_dim'])),
    **couplet_model_config)
couplet_model.load_weights(os.path.join(MODEL_DIR, 'couplet_model.h5'))


poem_model = get_model(
    embed_weights=np.random.random((len(poem_vocab_dict),
                                    poem_model_config['embed_dim'])),
    **poem_model_config)
poem_model.load_weights(os.path.join(MODEL_DIR, 'poem_model.h5'))

In [9]:
START_TOKEN_ID = poem_vocab_dict['<START>']
END_TOKEN_ID = poem_vocab_dict['<END>']
PAD_TOKEN_ID = poem_vocab_dict['<PAD>']

COUPLET_MAX_SEQ_LEN = 34
POEM_MAX_INPUT_SEQ = 14
POEM_MAX_OUTPUT_SEQ = 66

def couplet_inference(pre_couplet, top_k=1, temperature=1.0):
  out = "上: " + pre_couplet + "\n"
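  # Use the couplet vocab's own special tokens, consistent with the decode() call below
  in_vector = [couplet_vocab_dict['<START>']]
  for c in pre_couplet:
    in_vector.append(couplet_vocab_dict[c])
  in_vector.append(couplet_vocab_dict['<END>'])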
  decoded = decode(
      couplet_model,
      [in_vector],
      start_token=couplet_vocab_dict['<START>'],
      end_token=couplet_vocab_dict['<END>'],
      pad_token=couplet_vocab_dict['<PAD>'],
      max_len=COUPLET_MAX_SEQ_LEN,
      top_k=top_k,
      temperature=temperature,
  )
  for i in range(len(decoded)):
    out += '下: ' + ''.join(map(lambda x: rev_couplet_vocab_dict[x],
                       decoded[i][1:-1]))
  print(out)

def poem_encode(raw_text, is_decode_input, is_decode_output):
  assert not (is_decode_input and is_decode_output)
  output = []
  if not is_decode_output:
    output.append(START_TOKEN_ID)
  for c in raw_text:
    output.append(poem_vocab_dict[c])
  output.append(END_TOKEN_ID)
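  # Pad to a fixed length: the output length for decoder-side sequences,
  # the input length for encoder-side sequences.
  is_decoder_side = is_decode_input or is_decode_output
  total_size = POEM_MAX_OUTPUT_SEQ if is_decoder_side else POEM_MAX_INPUT_SEQ
  output.extend([PAD_TOKEN_ID] * (total_size - len(output)))
  return output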

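def poem_decode(token_ids):
  output = ""
  for token_id in token_ids:
    if token_id > 2:
      # Ids above 2 are treated as regular characters
      output += rev_poem_vocab_dict[token_id]
    elif token_id == 0:
      # Id 0 (presumably one of the special tokens) stops decoding
      break
  return output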

def poem_inference(title, top_k=1, temperature=1.0):
  out = "标题: " + title + "\n"
  decoded = decode(
      poem_model,
      poem_encode(title, False, False),
      start_token=START_TOKEN_ID,
      end_token=END_TOKEN_ID,
      pad_token=PAD_TOKEN_ID,
      max_len=POEM_MAX_OUTPUT_SEQ,
      top_k=top_k,
      temperature=temperature,
  )
  out += "正文: " + poem_decode(decoded)
  print(out)

poem_inference('秋思')
couplet_inference('欢天喜地度佳节')

标题: 秋思
正文: 秋风吹雨过，秋色满江城。一叶无人到，千山有客情。
上: 欢天喜地度佳节
下: 举国迎春贺新年


# Inference

In [10]:
for pre in ['欢天喜地度佳节', '不待鸣钟已汗颜，重来试手竟何艰',
            '当年欲跃龙门去，今日真披马革还', '载歌在谷']:
  couplet_inference(pre, top_k=1, temperature=1.0)

上: 欢天喜地度佳节
下: 举国迎春贺新年
上: 不待鸣钟已汗颜，重来试手竟何艰
下: 只缘沧海常风雨，再去翻身只等闲
上: 当年欲跃龙门去，今日真披马革还
下: 此际重逢凤阙来，明朝再赋凤凰鸣
上: 载歌在谷
下: 如醉如痴


In [11]:
for t in ['秋思', '百度', '湾区春日之谜', '自由而无用之灵魂']:
  poem_inference(t, top_k=1, temperature=1.0)

标题: 秋思
正文: 秋风吹雨过，秋色满江城。一叶无人到，千山有客情。
标题: 百度
正文: 百尺孤城上，千金万里中。山川无限水，水石有余风。
标题: 湾区春日之谜
正文: 春风吹雨不成秋，春色如何一日休。不是春光无处着，只应春色是人愁。
标题: 自由而无用之灵魂
正文: 我生不知，不识不知。我之不知，我之不知。我亦不知，不如不知。我亦不知，不知何爲。
