In [1]:
# !wget https://storage.googleapis.com/xlnet/released_models/cased_L-12_H-768_A-12.zip -O xlnet.zip
# !unzip xlnet.zip

In [2]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''

In [3]:
import sentencepiece as spm
from prepro_utils import preprocess_text, encode_ids

sp_model = spm.SentencePieceProcessor()
sp_model.Load('xlnet_cased_L-12_H-768_A-12/spiece.model')

def tokenize_fn(text):
    text = preprocess_text(text, lower= False)
    return encode_ids(sp_model, text)

In [4]:
SEG_ID_A   = 0
SEG_ID_B   = 1
SEG_ID_CLS = 2
SEG_ID_SEP = 3
SEG_ID_PAD = 4

special_symbols = {
    "<unk>"  : 0,
    "<s>"    : 1,
    "</s>"   : 2,
    "<cls>"  : 3,
    "<sep>"  : 4,
    "<pad>"  : 5,
    "<mask>" : 6,
    "<eod>"  : 7,
    "<eop>"  : 8,
}

VOCAB_SIZE = 32000
UNK_ID = special_symbols["<unk>"]
CLS_ID = special_symbols["<cls>"]
SEP_ID = special_symbols["<sep>"]
MASK_ID = special_symbols["<mask>"]
EOD_ID = special_symbols["<eod>"]

In [5]:
text = 'A politician is a person active in party politics, or a person holding or seeking office in government. Politicians propose, support and create laws or policies that govern the land and, by extension, its people'
text

'A politician is a person active in party politics, or a person holding or seeking office in government. Politicians propose, support and create laws or policies that govern the land and, by extension, its people'

In [6]:
aug_percent = 0.8
splitted = text.split()
size = len(splitted)
cnt = int(aug_percent * size)
cnt

28

In [7]:
import json
with open('en.json') as fopen:
    stopwords = json.load(fopen)

In [8]:
import random
import string

results = []
samples = random.sample([i for i in range(size)], cnt)
for token_idx, token in enumerate(samples):
    if splitted[token] in string.punctuation:
        continue
    if splitted[token] in stopwords:
        continue
    results.append(token)
    
results

[1, 18, 4, 25, 19, 14, 0, 23, 11, 30, 27, 17, 15, 20, 5, 8, 29, 22]

In [29]:
import numpy as np

def tokenizer(string, mask_id):
    string = string.split()
    ids = []
    for no, word in enumerate(string):
        if no == mask_id:
            ids.append(MASK_ID)
        ids.extend(tokenize_fn(word))
    mask_ind = ids.index(MASK_ID)
    segment_id = [SEG_ID_A] * len(ids)
    input_mask = [0] * len(ids)
    
    perm_masks = np.zeros((1, len(ids)))
    perm_masks[0, mask_ind] = 1.0
    target_mappings = np.zeros((1, len(ids)))
    target_mappings[0, mask_ind] = 1.0
    
    return ids, segment_id, input_mask, mask_ind, perm_masks, target_mappings

In [12]:
import xlnet
import tensorflow as tf
import model_utils

kwargs = dict(
      is_training=True,
      use_tpu=False,
      use_bfloat16=False,
      dropout=0.0,
      dropatt=0.0,
      init='normal',
      init_range=0.1,
      init_std=0.05,
      clamp_len=-1)

xlnet_parameters = xlnet.RunConfig(**kwargs)
xlnet_config = xlnet.XLNetConfig(json_path='xlnet_cased_L-12_H-768_A-12/xlnet_config.json')

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])






  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [13]:
def top_k_logits(logits, k):
    if k == 0:
        return logits

    def _top_k():
        values, _ = tf.nn.top_k(logits, k=k)
        min_values = values[:, -1, tf.newaxis]
        return tf.where(
            logits < min_values,
            tf.ones_like(logits, dtype=logits.dtype) * -1e10,
            logits,
        )
    return tf.cond(
        tf.equal(k, 0),
        lambda: logits,
        lambda: _top_k(),
    )

def top_p_logits(logits, p):
    with tf.variable_scope('top_p_logits'):
        logits_sort = tf.sort(logits, direction='DESCENDING')
        probs_sort = tf.nn.softmax(logits_sort)
        probs_sums = tf.cumsum(probs_sort, axis=1, exclusive=True)
        logits_masked = tf.where(probs_sums < p, logits_sort, tf.ones_like(
            logits_sort)*1000)  # [batchsize, vocab]
        min_logits = tf.reduce_min(logits_masked, axis=1, keepdims=True)  # [batchsize, 1]
        return tf.where(
            logits < min_logits,
            tf.ones_like(logits, dtype=logits.dtype) * -1e10,
            logits,
        )

class Model:
    def __init__(
        self,
    ):
        self.X = tf.placeholder(tf.int32, [None, None])
        self.segment_ids = tf.placeholder(tf.int32, [None, None])
        self.input_masks = tf.placeholder(tf.float32, [None, None])
        self.perm_masks = tf.placeholder(tf.float32, [None, None, None])
        self.target_mappings = tf.placeholder(tf.float32, [None, None, None])
        self.top_p = tf.placeholder(tf.float32, None)
        self.top_k = tf.placeholder(tf.int32, None)
        self.k = tf.placeholder(tf.int32, None)
        self.temperature = tf.placeholder(tf.float32, None)
        self.indices = tf.placeholder(tf.int32, [None, None])
        
        xlnet_model = xlnet.XLNetModel(
            xlnet_config=xlnet_config,
            run_config=xlnet_parameters,
            input_ids=self.X,
            seg_ids=self.segment_ids,
            input_mask=self.input_masks,
            perm_mask = self.perm_masks,
            target_mapping = self.target_mappings
        )
        
        output = xlnet_model.get_sequence_output()
        self.output = output
        lookup_table = xlnet_model.get_embedding_table()

        initializer = xlnet_model.get_initializer()
        with tf.variable_scope('model', reuse = tf.AUTO_REUSE):
            with tf.variable_scope('lm_loss'):
                softmax_w = lookup_table
                softmax_b = tf.get_variable(
                    'bias',
                    [xlnet_config.n_token],
                    dtype = output.dtype,
                    initializer = tf.zeros_initializer(),
                )
                logits = tf.einsum('ibd,nd->ibn', output, softmax_w) + softmax_b
                self.logits = logits
                
        logits = tf.gather_nd(self.logits, self.indices)
        logits = logits / self.temperature
        
        def necleus():
            return top_p_logits(logits, self.top_p)
        
        def select_k():
            return top_k_logits(logits, self.top_k)
        
        logits = tf.cond(self.top_p > 0, necleus, select_k)
        self.samples = tf.multinomial(
                logits, num_samples=self.k, output_dtype=tf.int32)

In [14]:
tf.reset_default_graph()
sess = tf.InteractiveSession()
model = Model()

sess.run(tf.global_variables_initializer())




INFO:tensorflow:memory input None
INFO:tensorflow:Use float type <dtype: 'float32'>

Instructions for updating:
Use keras.layers.dropout instead.
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

Instructions for updating:
Use keras.layers.dense instead.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Instructions for updating:
Use `tf.random.categorical` instead.


In [15]:
import collections
import re

def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
    """Compute the union of the current variables and checkpoint variables."""
    assignment_map = {}
    initialized_variable_names = {}

    name_to_variable = collections.OrderedDict()
    for var in tvars:
        name = var.name
        m = re.match('^(.*):\\d+$', name)
        if m is not None:
            name = m.group(1)
        name_to_variable[name] = var

    init_vars = tf.train.list_variables(init_checkpoint)

    assignment_map = collections.OrderedDict()
    for x in init_vars:
        (name, var) = (x[0], x[1])
        if name not in name_to_variable:
            continue
        assignment_map[name] = name_to_variable[name]
        initialized_variable_names[name] = 1
        initialized_variable_names[name + ':0'] = 1

    return (assignment_map, initialized_variable_names)

In [16]:
tvars = tf.trainable_variables()
checkpoint = 'xlnet_cased_L-12_H-768_A-12/xlnet_model.ckpt'
assignment_map, initialized_variable_names = get_assignment_map_from_checkpoint(tvars, 
                                                                                checkpoint)

In [17]:
saver = tf.train.Saver(var_list = assignment_map)
saver.restore(sess, checkpoint)

Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from xlnet_cased_L-12_H-768_A-12/xlnet_model.ckpt


In [20]:
import numpy as np
from tensorflow.keras.preprocessing.sequence import pad_sequences

In [31]:
tokenized = [tokenizer(text, result) for result in results]
a = list(zip(*tokenized))
len(a)

6

In [32]:
# ids, segment_id, input_mask, mask_ind, perm_masks, target_mappings

batch_x = pad_sequences(a[0],padding='post')
batch_segment = pad_sequences(a[1],padding='post', value = SEG_ID_PAD)
batch_mask = pad_sequences(a[2],padding='post', value = 1)
perm_masks = pad_sequences(a[4],padding='post')
target_mappings = pad_sequences(a[5],padding='post')

In [33]:
indices = a[3]
batch_indices = np.array([np.arange(len(indices)), indices]).T
batch_indices.shape

(18, 2)

In [34]:
batch_mask.shape

(18, 43)

In [36]:
# self.segment_ids = tf.placeholder(tf.int32, [None, None])
# self.input_masks = tf.placeholder(tf.float32, [None, None])
# self.perm_masks = tf.placeholder(tf.float32, [None, None, None])
# self.target_mappings = tf.placeholder(tf.float32, [None, None, None])

samples = sess.run(model.samples, feed_dict = {model.X: batch_x,
                                    model.input_masks: batch_mask,
                                    model.segment_ids: batch_segment,
                                    model.perm_masks: perm_masks,
                                    model.target_mappings: target_mappings,
                                    model.top_p: 0.8,
                                    model.top_k: 100,
                                    model.temperature: 0.8,
                                    model.indices: batch_indices,
                                    model.k: 5})

In [37]:
def convert_ids_to_tokens(ids):
    return [sp_model.IdToPiece(i) for i in ids]

In [41]:
for i in range(samples.shape[1]):
    print('SAMPLE %d'%(i))
    sample_i = samples[:, i]
    samples_tokens = convert_ids_to_tokens(samples[:, i].tolist())
    new_splitted = splitted[:]
    for no, index in enumerate(results):
        new_splitted[index] = samples_tokens[no]

    new = ' '.join(new_splitted)
    print('BEFORE:', text)
    print('AFTER:', new)
    print()

SAMPLE 0
BEFORE: A politician is a person active in party politics, or a person holding or seeking office in government. Politicians propose, support and create laws or policies that govern the land and, by extension, its people
AFTER: . ly is a ▁This ly in party ▁Eventually or a s holding or ▁The ly in ▁Typically ly ▁Typically s and s s or ▁Generally that ▁These the ▁ s by extension, its people

SAMPLE 1
BEFORE: A politician is a person active in party politics, or a person holding or seeking office in government. Politicians propose, support and create laws or policies that govern the land and, by extension, its people
AFTER: ▁It ▁ is a ▁Zo . in party ly or a ▁concurrently holding or ▁Typically ly in s ▁upon ly s and s s or ly that ly the ▁as ▁Upon by extension, its people

SAMPLE 2
BEFORE: A politician is a person active in party politics, or a person holding or seeking office in government. Politicians propose, support and create laws or policies that govern the land and, by extens