In [14]:
import random


class Tokenizer:
    """Character-level tokenizer and synthetic sample generator.

    Every generated sample is the character sequence
    ``S<prefix>:<question>=<answer>E`` where ``question`` is a random 4-digit
    integer and ``answer`` is that integer's decimal string repeated four
    times, read back as an integer and multiplied by four. The answer digits
    are rendered in one of four character sets; the index of the chosen set
    is the sample's label.
    """

    def __init__(self):
        # Vocabulary grouped by role. The order of the groups (and of the
        # characters inside each group) fixes the token ids, so 'P' is id 0,
        # 'S' is 1 and 'E' is 2. Within the four digit groups, position d
        # holds the symbol used for decimal digit d.
        self.vocab = {
            # P = padding, S = sequence start, E = sequence end.
            # 'U' is not used anywhere in the code visible here.
            'mark': list('PSEU'),
            'number': list('0123456789'),
            'letter': list('pqwertyuio'),
            'chinese_lower': list('〇一二三四五六七八九'),
            'chinese_upper': list('零壹贰叁肆伍陆柒捌玖'),
            # Characters needed by the prefixes, plus ':', '=' and '_'.
            'other': list('数字大写小母:=_'),
        }

        # id -> character and character -> id lookups.
        self.decoder = [ch for group in self.vocab.values() for ch in group]
        self.encoder = {ch: idx for idx, ch in enumerate(self.decoder)}

        # Character-set name -> class index.
        self.label = {
            'number': 0,
            'letter': 1,
            'chinese_lower': 2,
            'chinese_upper': 3
        }
        # Two-character tag for each class index (same order as self.label).
        self.prefix = ['数字', '字母', '小写', '大写']

    def decode(self, x):
        """Turn an iterable of token ids back into a string."""
        return ''.join(self.decoder[i] for i in x)

    def get_data(self, prefix):
        """Generate one random sample.

        Args:
            prefix: if truthy, the sequence starts with the two-character tag
                naming the answer's character set; otherwise with ``'__'``.

        Returns:
            ``(label, token)``: the class index of the character set used for
            the answer, and the encoded sequence including the leading 'S'
            and trailing 'E'.
        """
        # Build the question and its answer.
        # (The original cell also kept `question**8` as a commented-out
        # alternative target.)
        question = random.randint(1000, 9999)
        answer = int(str(question) * 4) * 4

        # Pick the character set at random, then map each answer digit to
        # the symbol at that position in the set.
        label_name = random.choice(list(self.label))
        charset = self.vocab[label_name]
        answer_chars = [charset[int(digit)] for digit in str(answer)]

        # Class index of the chosen character set.
        label = self.label[label_name]

        # Assemble the sequence, then encode it.
        tag = self.prefix[label] if prefix else '__'
        chars = ['S', *tag, ':', *str(question), '=', *answer_chars, 'E']
        token = [self.encoder[ch] for ch in chars]

        return label, token

    def get_batch_data(self, prefix, batch_size=64):
        """Generate a padded batch of random samples.

        Args:
            prefix: forwarded to :meth:`get_data`.
            batch_size: number of samples to generate (default 64).

        Returns:
            ``(label, input_ids, attention_mask)``: a list of class indices
            and the two padded lists produced by :meth:`batch_pad`.
        """
        data = [self.get_data(prefix=prefix) for _ in range(batch_size)]

        label = [sample[0] for sample in data]
        token = [sample[1] for sample in data]

        return label, *self.batch_pad(token=token)

    def batch_pad(self, text=None, token=None):
        """Right-pad a batch of sequences to a common length.

        Args:
            text: optional list of strings; when non-empty it is encoded
                character by character and takes precedence over ``token``.
            token: optional list of already-encoded id lists.

        Returns:
            ``(input_ids, attention_mask)``: lists of equal-length rows.
            Padding positions hold the id of 'P' in ``input_ids`` and 0 in
            ``attention_mask``; real positions hold 1 in the mask. An empty
            batch yields ``([], [])``.

        Raises:
            ValueError: if neither ``text`` nor ``token`` is given.
        """
        if text:
            # Encode raw strings.
            token = [[self.encoder[ch] for ch in row] for row in text]
        elif token is None:
            if text is None:
                raise ValueError('batch_pad needs either text or token')
            # An explicitly empty text batch is an empty batch.
            token = []

        # default=0 keeps an empty batch from raising inside max().
        width = max((len(row) for row in token), default=0)
        pad_id = self.encoder['P']

        input_ids = []
        attention_mask = []
        for row in token:
            missing = width - len(row)
            attention_mask.append([1] * len(row) + [0] * missing)
            input_ids.append(row + [pad_id] * missing)

        return input_ids, attention_mask


tokenizer = Tokenizer()

# Decode the padded input_ids (index 1 of the batch tuple) of one freshly
# generated batch so the samples can be inspected as text.
# Slice the result with [:10] to preview only the first ten rows.
list(map(tokenizer.decode, tokenizer.get_batch_data(prefix=True)[1]))

['S字母:4072=qywioywioywioywiiE',
 'S数字:9552=38211821182118208E',
 'S数字:2761=11045104510451044E',
 'S字母:9826=eoepuoepuoepuoeprE',
 'S小写:1659=六六三六六六三六六六三六六六三六EP',
 'S字母:1690=yuypyuypyuypyuypEP',
 'S数字:7819=31279127912791276E',
 'S小写:5754=二三〇一八三〇一八三〇一八三〇一六E',
 'S字母:7683=epuetpuetpuetpuewE',
 'S字母:3736=qrortrortrortrorrE',
 'S大写:4882=壹玖伍贰玖玖伍贰玖玖伍贰玖玖伍贰捌E',
 'S大写:6507=贰陆零叁零陆零叁零陆零叁零陆零贰捌E',
 'S大写:2747=壹零玖捌玖零玖捌玖零玖捌玖零玖捌捌E',
 'S字母:9709=eiieoiieoiieoiieyE',
 'S数字:2022=8088808880888088EP',
 'S大写:6868=贰柒肆柒肆柒肆柒肆柒肆柒肆柒肆柒贰E',
 'S字母:3044=qwquuwquuwquuwquyE',
 'S大写:3059=壹贰贰叁柒贰贰叁柒贰贰叁柒贰贰叁陆E',
 'S数字:8316=33267326732673264E',
 'S字母:6345=wteiwteiwteiwteipE',
 'S大写:5408=贰壹陆叁肆壹陆叁肆壹陆叁肆壹陆叁贰E',
 'S小写:2499=九九九六九九九六九九九六九九九六EP',
 'S小写:5256=二一〇二六一〇二六一〇二六一〇二四E',
 'S字母:7908=eqyetqyetqyetqyewE',
 'S字母:9370=eurieurieurieuripE',
 'S数字:2954=11817181718171816E',
 'S字母:2502=qpppopppopppopppiE',
 'S大写:5686=贰贰柒肆陆贰柒肆陆贰柒肆陆贰柒肆肆E',
 'S数字:6815=27262726272627260E',
 'S字母:8127=ewtqqwtqqwtqqwtpiE',
 'S大写:4700=壹捌捌零壹捌捌零壹捌捌零壹捌捌零零E',
 'S小写:89

In [15]:
import torch

# Run on the GPU when torch reports one as available, otherwise on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

device

'cuda'

In [16]:
class ModelGEN(torch.nn.Module):
    """Small GPT-2 backbone with a bias-free linear head over the vocabulary.

    Relies on the module-level ``tokenizer`` (for the vocabulary size and the
    ids of 'S' and 'E') and on the module-level ``device``.
    """

    def __init__(self):
        super().__init__()
        from transformers import GPT2Config, GPT2Model

        token_ids = tokenizer.encoder
        gpt2_kwargs = dict(
            bos_token_id=token_ids['S'],
            eos_token_id=token_ids['E'],
            n_embd=64,
            n_head=4,
            n_layer=4,
            n_positions=128,
            vocab_size=len(tokenizer.decoder),
        )
        self.config = GPT2Config(**gpt2_kwargs)

        # Backbone that produces per-position hidden states.
        self.feature = GPT2Model(self.config)

        # Projects each 64-wide hidden state (n_embd above) to vocabulary logits.
        self.fc_out = torch.nn.Linear(64, self.config.vocab_size, bias=False)

        # The module moves itself to the target device and starts in train mode.
        self.to(device)
        self.train()

    def forward(self, input_ids, attention_mask):
        """Return vocabulary logits for every position of ``input_ids``.

        Both arguments are passed straight through to the GPT-2 backbone.
        """
        backbone_out = self.feature(input_ids=input_ids,
                                    attention_mask=attention_mask)
        hidden = backbone_out.last_hidden_state

        return self.fc_out(hidden)

# Build the generator and make sure it lives on the selected device.
model_gen = ModelGEN()
model_gen = model_gen.to(device)

In [17]:
class ModelCLS(torch.nn.Module):
    """Small BERT encoder with a 4-way classification head.

    Relies on the module-level ``tokenizer`` for the vocabulary size and on
    the module-level ``device``. The four outputs match the number of entries
    in ``Tokenizer.label``.
    """

    def __init__(self):
        super().__init__()
        from transformers import BertConfig, BertModel

        bert_kwargs = dict(
            hidden_size=64,
            intermediate_size=64,
            max_position_embeddings=128,
            num_attention_heads=4,
            num_hidden_layers=4,
            vocab_size=len(tokenizer.decoder),
        )
        self.config = BertConfig(**bert_kwargs)

        # No pad_token_id is passed; the printed module shows padding_idx=0,
        # which is the id the tokenizer assigns to 'P'.
        self.feature = BertModel(self.config)

        # Dropout followed by a linear map from the 64-wide pooled vector to
        # the four class scores.
        dropout = torch.nn.Dropout(p=0.1)
        classifier = torch.nn.Linear(64, 4)
        self.fc_out = torch.nn.Sequential(dropout, classifier)

        # The module moves itself to the target device and starts in train mode.
        self.to(device)
        self.train()

    def forward(self, input_ids, attention_mask):
        """Return the four class scores for each sequence in the batch.

        Both arguments are passed straight through to the BERT encoder; the
        head is applied to the encoder's ``pooler_output``.
        """
        encoder_out = self.feature(input_ids=input_ids,
                                   attention_mask=attention_mask)
        pooled = encoder_out.pooler_output

        return self.fc_out(pooled)

# Build the classifier on the selected device and display its module tree.
# NOTE: the name `model` is rebound to a ModelPPO instance in a later cell.
model = ModelCLS()
model = model.to(device)
model

ModelCLS(
  (feature): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(53, 64, padding_idx=0)
      (position_embeddings): Embedding(128, 64)
      (token_type_embeddings): Embedding(2, 64)
      (LayerNorm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-3): 4 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=64, out_features=64, bias=True)
              (key): Linear(in_features=64, out_features=64, bias=True)
              (value): Linear(in_features=64, out_features=64, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=64, out_features=64, bias=True)
              (LayerNorm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)
          

In [18]:
class ModelPPO(torch.nn.Module):
    """Generator wrapped with a per-position scalar value head.

    The wrapped ``model_gen`` is stored by reference (not copied), so its
    parameters are shared with whatever else holds that object. Relies on the
    module-level ``device``.
    """

    def __init__(self, model_gen):
        """
        Args:
            model_gen: generator exposing ``feature`` (backbone), ``fc_out``
                (logit head) and ``config.n_embd`` (hidden width), as
                ``ModelGEN`` does.
        """
        super().__init__()
        self.model_gen = model_gen

        # Size the value head from the backbone's hidden width (64 for
        # ModelGEN) instead of repeating the literal.
        hidden_size = model_gen.config.n_embd
        self.v_head = torch.nn.Sequential(torch.nn.Dropout(0.1),
                                          torch.nn.Linear(hidden_size, 1))

        # The module moves itself to the target device and starts in train mode.
        self.to(device)
        self.train()

    def forward(self, input_ids, attention_mask):
        """Return ``(logits, value)`` for every position of ``input_ids``.

        ``logits`` comes from the generator's own head; ``value`` is the
        value head's output with its trailing size-1 dimension squeezed away.
        """
        # Only the final hidden state is used, so the backbone is not asked
        # to also return every intermediate layer's hidden states.
        last_hidden_state = self.model_gen.feature(
            input_ids=input_ids,
            attention_mask=attention_mask).last_hidden_state

        logits = self.model_gen.fc_out(last_hidden_state)
        value = self.v_head(last_hidden_state).squeeze(-1)

        return logits, value

# Wrap the generator with a value head (this rebinds `model`, which
# previously held the classifier) and display the module tree.
model = ModelPPO(model_gen)
model = model.to(device)
model

ModelPPO(
  (model_gen): ModelGEN(
    (feature): GPT2Model(
      (wte): Embedding(53, 64)
      (wpe): Embedding(128, 64)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0-3): 4 x GPT2Block(
          (ln_1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2Attention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (mlp): GPT2MLP(
            (c_fc): Conv1D()
            (c_proj): Conv1D()
            (act): NewGELUActivation()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (ln_f): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    )
    (fc_out): Linear(in_features=64, out_features=53, bias=False)
  )
  (v_head): Sequential(
    (0): Dropout(p=0.1, inplace=False)

In [19]:
# Cached GPT2LMHeadModel wrapper used by generate(); built on first use.
generater = None


def generate(model_gen, input_ids):
    """Sample continuations of ``input_ids`` from ``model_gen``.

    ``model_gen`` is wrapped in a ``GPT2LMHeadModel`` whose backbone and head
    are replaced by ``model_gen.feature`` and ``model_gen.fc_out``, so the
    library's ``generate`` method can drive it. The wrapper is cached in the
    module-level ``generater`` and rebuilt whenever it no longer points at
    the modules of the ``model_gen`` passed in.

    Relies on the module-level ``tokenizer`` (ids of 'P' and 'E') and
    ``device``.

    Args:
        model_gen: generator exposing ``config``, ``feature`` and ``fc_out``.
        input_ids: prompt token ids, forwarded unchanged to ``generate``.

    Returns:
        Whatever ``GPT2LMHeadModel.generate`` returns for these arguments.
    """
    global generater

    # Rebuild when there is no wrapper yet or when it was built around a
    # different generator; otherwise a second model would silently be
    # sampled through the first model's weights.
    stale = (generater is None
             or generater.transformer is not model_gen.feature
             or generater.lm_head is not model_gen.fc_out)
    if stale:
        # Wrapper class used only for generation.
        from transformers import GPT2LMHeadModel
        generater = GPT2LMHeadModel(model_gen.config)
        generater.transformer = model_gen.feature
        generater.lm_head = model_gen.fc_out
        generater.to(device)

    # top_k=0 (an int, as the parameter expects) together with top_p=1.0 is
    # meant to leave the sampling distribution unfiltered.
    return generater.generate(input_ids=input_ids,
                              min_length=-1,
                              top_k=0,
                              top_p=1.0,
                              do_sample=True,
                              pad_token_id=tokenizer.encoder['P'],
                              max_new_tokens=25,
                              eos_token_id=tokenizer.encoder['E'])