In [18]:
import os

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf

import tensorflow_hub as hub
import tensorflow_datasets as tfds
tfds.disable_progress_bar()

from official.modeling import tf_utils
from official import nlp
from official.nlp import bert

# Load the required submodules
import official.nlp.bert.bert_models
import official.nlp.bert.configs
import official.nlp.bert.tokenization
from official.nlp.modeling import models

import json

# Largest number of placeholder tokens substituted for a marked phrase:
# load_data() builds one variant per value in 1..max_len_mask.
# main() uses its own max_len_mask rather than reading this global.
max_len_mask = 3

In [19]:
# Cloud Storage folder with the uncased_L-12_H-768_A-12 BERT files; the
# tokenizer vocab and the model config used below are both read from here.
gs_folder_bert = "gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12"
# Last expression of the cell: shows which files the folder contains.
tf.io.gfile.listdir(gs_folder_bert)

['bert_config.json',
 'bert_model.ckpt.data-00000-of-00001',
 'bert_model.ckpt.index',
 'vocab.txt']

In [20]:
# Build the module-level tokenizer from the checkpoint folder's vocab.txt.
# load_data(), encode_sentence(), bert_encode() and main() all read this
# global, so this cell must run before any of them is called.
tokenizer = bert.tokenization.FullTokenizer(
    vocab_file=os.path.join(gs_folder_bert, "vocab.txt"),
     do_lower_case=True)

# Sanity check that the vocabulary was loaded.
print("Vocab size:", len(tokenizer.vocab))

Vocab size: 30522


In [22]:
#phrase is connected with '_', single word is also masked with '_'
def load_data(input_file, max_len_mask, max_seq_len, max_prediction_per_sequence, is_eval_data = True):
    """Read a tab-separated file and build masked/original sentence pairs.

    Every line is split on tabs: column 0 is the sentence, in which the target
    phrase is written with '_' (multi-word phrases joined by '_', single words
    suffixed with '_'); column 1 is the phrase; when ``is_eval_data`` is true,
    columns 3 and onwards are kept as the reference substitutes.

    For each line ``max_len_mask`` variants are produced. Variant ``k``
    (0-based) replaces every marked word with ``k + 1`` 'mm' placeholders,
    which ``create_masks`` then turns into '[MASK]' tokens.

    Args:
        input_file: path readable by ``tf.io.gfile.GFile``.
        max_len_mask: number of variants (and maximum masks per marked word).
        max_seq_len: budget for the pair; the two token lists are each cut to
            ``(max_seq_len - 3) // 2`` tokens when together they exceed
            ``max_seq_len - 3``.
        max_prediction_per_sequence: forwarded to ``create_masks``.
        is_eval_data: whether reference substitutes are read from the line.

    Returns:
        An 8-tuple ``(masked_sents, masked_lm_positions, masked_lm_weights,
        original_sents, vocab_synomies, vocab_positions, phrases, sent_ids)``.
        All lists except ``vocab_synomies`` have one entry per variant;
        ``vocab_synomies`` only gets an entry when the line had substitutes.
        ``sent_ids`` entries are ``"<line number>_<variant index>"``.
    """
    original_sents = []
    masked_sents = []
    vocab_positions = []
    vocab_synomies = []
    phrases = []
    masked_lm_positions = []
    masked_lm_weights = []
    sent_ids = []
    idx = 1
    with tf.io.gfile.GFile(input_file, "r") as reader:
        while True:
            # An empty (or whitespace-only) line ends reading, as does EOF.
            line = reader.readline().strip()
            if not line:
                break
            parts = line.split('\t')
            sent = parts[0]
            phrase = parts[1]
            synomies = parts[3:] if is_eval_data else []

            words = sent.split()

            # The unmasked sentence and the phrase tokens do not depend on the
            # variant, so tokenize them once per input line.
            plain_words = []
            for word in words:
                if '_' in word:
                    if word.endswith('_'):
                        plain_words.append(word.replace('_', ''))
                    else:
                        plain_words.extend(word.split('_'))
                else:
                    plain_words.append(word)
            plain_tokens = tokenizer.tokenize(' '.join(plain_words))
            sub_ps = tokenizer.tokenize(phrase)
            window_size = len(sub_ps)

            for i in range(max_len_mask):
                # Create a new sentence in which placeholders replace the phrase,
                # e.g. "This is the state_of_the_art method."
                # --> This is the [MASK] method.
                # --> This is the [MASK] [MASK] method.
                # --> This is the [MASK] [MASK] [MASK] method.
                new_tokens = []
                for word in words:
                    if '_' in word:
                        new_tokens.extend(['mm'] * (i + 1))
                    else:
                        new_tokens.append(word)
                tokens_a = tokenizer.tokenize(' '.join(new_tokens))
                # Fresh list per variant so callers never share one list
                # object between several entries of original_sents.
                tokens_b = list(plain_tokens)

                if len(tokens_a) + len(tokens_b) > max_seq_len - 3:
                    evg_len = int((max_seq_len - 3) / 2)
                    tokens_a = tokens_a[:evg_len]
                    tokens_b = tokens_b[:evg_len]

                # Locate the phrase inside tokens_b. A dedicated scan variable
                # keeps the variant index ``i`` intact for sent_ids, and the
                # inclusive upper bound also finds a phrase that ends the
                # sentence. An empty phrase yields no positions.
                phrase_ids = []
                if window_size > 0:
                    for scan in range(len(tokens_b) - window_size + 1):
                        if tokens_b[scan: scan + window_size] == sub_ps:
                            # +2 skips [CLS] and the [SEP] that follows tokens_a.
                            start = scan + len(tokens_a) + 2
                            end = start + window_size - 1
                            phrase_ids.append((start, end))

                original_sents.append(tokens_b)
                vocab_positions.append(phrase_ids)
                phrases.append(phrase)
                if len(synomies) > 0:
                    vocab_synomies.append(synomies)
                masked_tokens, masked_positions, masked_weights = create_masks(tokens_a, max_prediction_per_sequence)
                masked_sents.append(masked_tokens)
                masked_lm_weights.append(masked_weights)
                masked_lm_positions.append(masked_positions)
                sent_ids.append(str(idx)+'_'+str(i))
            # One id per input line, so it must advance inside the read loop.
            idx += 1
    return masked_sents,masked_lm_positions, masked_lm_weights, original_sents, vocab_synomies, vocab_positions, phrases, sent_ids

In [None]:
#phrase is connected with '_', single word is also masked with '_'
def load_data(input_file, max_len_mask, max_seq_len, max_prediction_per_sequence, is_eval_data = True):
    """Read a tab-separated file; return masked variants plus per-line originals.

    NOTE(review): this notebook defines ``load_data`` twice. This later
    definition replaces the earlier one when the cells run in order, and it
    returns a 4-tuple, whereas ``main()`` unpacks eight values. Decide which
    definition is the intended one before running the notebook top to bottom.

    Every line is split on tabs: column 0 is the sentence (target phrase
    marked with '_'), column 1 is the phrase, and when ``is_eval_data`` is
    true columns 3 and onwards are kept as reference substitutes.

    Args:
        input_file: path readable by ``tf.io.gfile.GFile``.
        max_len_mask: number of masked variants generated per line; variant
            ``k`` (0-based) uses ``k + 1`` placeholders per marked word.
        max_seq_len: accepted for signature compatibility; not used here.
        max_prediction_per_sequence: forwarded to ``create_masks``.
        is_eval_data: whether reference substitutes are read from the line.

    Returns:
        ``(masked_sents, masked_lm_positions, masked_lm_weights,
        original_sents)``. The first three have one entry per variant;
        ``original_sents`` has one ``(tokens, phrase_ids, synomies)`` tuple
        per input line, where ``phrase_ids`` are inclusive ``(start, end)``
        token indices of the phrase within ``tokens``.
    """
    original_sents = []
    masked_sents = []
    masked_lm_positions = []
    masked_lm_weights = []
    with tf.io.gfile.GFile(input_file, "r") as reader:
        while True:
            # An empty (or whitespace-only) line ends reading, as does EOF.
            line = reader.readline().strip()
            if not line:
                break
            parts = line.split('\t')
            sent = parts[0]
            phrase = parts[1]
            synomies = parts[3:] if is_eval_data else []

            words = sent.split()
            original_tokens = []
            for word in words:
                if '_' in word:
                    if word.endswith('_'):
                        original_tokens.append(word.replace('_', ''))
                    else:
                        original_tokens.extend(word.split('_'))
                else:
                    original_tokens.append(word)

            orginal_sent = tokenizer.tokenize(' '.join(original_tokens))
            sub_ps = tokenizer.tokenize(phrase)
            window_size = len(sub_ps)
            phrase_ids = []
            # Inclusive upper bound so a phrase ending the sentence is found;
            # an empty phrase yields no positions.
            if window_size > 0:
                for scan in range(len(orginal_sent) - window_size + 1):
                    if orginal_sent[scan:scan + window_size] == sub_ps:
                        phrase_ids.append((scan, scan + window_size - 1))
            # Record the original sentence and its phrase positions; the
            # sentence index is the index in this list.
            original_sents.append((orginal_sent, phrase_ids, synomies))

            for i in range(max_len_mask):
                # Create a new sentence in which placeholders replace the phrase,
                # e.g. "This is the state_of_the_art method."
                # --> This is the [MASK] method.
                # --> This is the [MASK] [MASK] method.
                # --> This is the [MASK] [MASK] [MASK] method.
                new_tokens = []
                for word in words:
                    if '_' in word:
                        new_tokens.extend(['mm'] * (i + 1))
                    else:
                        new_tokens.append(word)

                tokens_a = tokenizer.tokenize(' '.join(new_tokens))

                masked_tokens, masked_positions, masked_weights = create_masks(tokens_a, max_prediction_per_sequence)
                masked_sents.append(masked_tokens)
                masked_lm_weights.append(masked_weights)
                masked_lm_positions.append(masked_positions)
    return masked_sents,masked_lm_positions, masked_lm_weights, original_sents

In [23]:
def create_masks(tokens, max_prediction_per_sequence):
    """Swap 'mm' placeholders for '[MASK]' and build padded prediction slots.

    Args:
        tokens: list of token strings; every 'mm' marks a slot to predict.
        max_prediction_per_sequence: fixed length of the returned position
            and weight lists.

    Returns:
        ``(masked_tokens, masked_lm_positions, masked_lm_weights)``.
        Positions are 1-based indices into ``tokens`` in ascending order,
        right-padded with 0; weights are 1 for real slots and 0 for padding.

    Raises:
        AssertionError: if there are more placeholders than
            ``max_prediction_per_sequence``.
    """
    placeholder = 'mm'
    masked_tokens = ['[MASK]' if tok == placeholder else tok for tok in tokens]
    # Counting starts at 1, so the first token has position 1.
    slots = [pos for pos, tok in enumerate(tokens, start=1) if tok == placeholder]

    n_slots = len(slots)
    assert n_slots <= max_prediction_per_sequence

    padding = [0] * (max_prediction_per_sequence - n_slots)
    masked_lm_positions = slots + padding
    masked_lm_weights = [1] * n_slots + padding
    return masked_tokens, masked_lm_positions, masked_lm_weights

In [24]:
def encode_sentence(tokens):
    """Return the vocabulary ids of ``tokens`` followed by the '[SEP]' id.

    The separator is added to a copy, so the caller's list is left unchanged
    and encoding the same sentence twice gives the same result.

    Args:
        tokens: sequence of token strings.

    Returns:
        Whatever the module-level ``tokenizer.convert_tokens_to_ids`` returns
        for ``tokens + ['[SEP]']``.
    """
    return tokenizer.convert_tokens_to_ids(list(tokens) + ['[SEP]'])

In [25]:
def bert_encode(masked_sents, 
                original_sents, 
                masked_lm_positions, 
                masked_lm_weights, 
                max_prediction_per_sequence):
    """Pack sentence pairs into the input dict fed to the BERT models.

    Each row is laid out as ``[CLS] masked [SEP] original [SEP]`` (the
    separators come from ``encode_sentence``). Rows are padded with zeros to
    the longest row in the batch.

    Args:
        masked_sents: list of token lists containing '[MASK]' tokens.
        original_sents: list of token lists for the unmasked sentences, in
            the same order as ``masked_sents``.
        masked_lm_positions: per-sentence lists of mask positions.
        masked_lm_weights: per-sentence lists of 1/0 slot weights.
        max_prediction_per_sequence: accepted for signature compatibility;
            not used in this function.

    Returns:
        dict with int tensors under the keys 'input_word_ids', 'input_mask',
        'input_type_ids', 'masked_lm_weights', 'masked_lm_positions' and
        'masked_lm_ids' (the last one all zeros).
    """
    # Iterate the Python lists directly. Wrapping ragged nested lists in
    # np.array() fails on recent NumPy, and for equal-length sentences it
    # would hand ndarray rows (which have no .append) to encode_sentence.
    sentence1 = tf.ragged.constant([encode_sentence(s) for s in masked_sents])
    sentence2 = tf.ragged.constant([encode_sentence(s) for s in original_sents])
    print(sentence1.shape)
    print(sentence2.shape)
    cls = [tokenizer.convert_tokens_to_ids(['[CLS]'])]*sentence1.shape[0]
    input_word_ids = tf.concat([cls, sentence1, sentence2], axis=-1)

    # 1 for every real token; padding introduced by to_tensor() is 0.
    input_mask = tf.ones_like(input_word_ids).to_tensor()
    # Segment ids: 0 for [CLS] and the masked sentence, 1 for the original.
    type_cls = tf.zeros_like(cls)
    type_s1 = tf.zeros_like(sentence1)
    type_s2 = tf.ones_like(sentence2)
    input_type_ids = tf.concat(
      [type_cls, type_s1, type_s2], axis=-1).to_tensor()

    masked_lm_positions = tf.convert_to_tensor(masked_lm_positions, dtype=tf.int32)
    masked_lm_weights = tf.convert_to_tensor(masked_lm_weights, dtype=tf.int32)
    # Placeholder labels with the same shape as the weights.
    masked_lm_ids = tf.zeros_like(masked_lm_weights)

    inputs = {
      'input_word_ids': input_word_ids.to_tensor(),
      'input_mask': input_mask,
      'input_type_ids': input_type_ids, 
      'masked_lm_weights': masked_lm_weights, 
      'masked_lm_positions': masked_lm_positions,
      'masked_lm_ids':masked_lm_ids,
    }

    return inputs

In [None]:
def extract_features():
    """Placeholder: feature extraction has not been written yet.

    Raises:
        NotImplementedError: always, so an accidental call fails loudly
            instead of the empty definition breaking the whole cell.
    """
    raise NotImplementedError("extract_features is not implemented yet")

In [28]:
def calculate_similarity(substitute, original):
    """Cosine similarity between the mean vectors of two vector sequences.

    Args:
        substitute: array-like of vectors for the substitute; averaged over
            axis 0.
        original: array-like of vectors for the original; averaged over
            axis 0.

    Returns:
        The cosine similarity as a Python float, or 0.0 when either mean
        vector has zero norm (the cosine is undefined there).
    """
    ps_a = np.ravel(np.mean(substitute, axis=0))
    # The second mean must come from ``original``; averaging ``substitute``
    # twice would compare the substitute with itself.
    ps_b = np.ravel(np.mean(original, axis=0))

    denominator = np.linalg.norm(ps_a) * np.linalg.norm(ps_b)
    if denominator == 0:
        return 0.0
    return float(np.dot(ps_a, ps_b) / denominator)

In [None]:
def combine_tokens(tokens):
    """Regroup per-slot candidate lists into per-rank phrases.

    ``tokens[s][n]`` is the n-th candidate for slot ``s``; the result's
    ``[n][s]`` is that same item, so each returned list is one phrase made
    of the n-th candidate of every slot.

    Args:
        tokens: list of candidate lists, one per slot. The number of phrases
            is taken from the first list.

    Returns:
        A list of phrases (lists); empty when ``tokens`` is empty.

    Raises:
        IndexError: if a later candidate list is shorter than the first.
    """
    if not tokens:
        # No slots means no phrases (previously an IndexError on tokens[0]).
        return []
    return [[token_parts[n] for token_parts in tokens]
            for n in range(len(tokens[0]))]

In [None]:
class SentenceEntity():
    """A candidate sentence together with where its phrase sits and its id."""

    def __init__(self, sentence, phrase_position, sent_id):
        """Store the three values unchanged.

        Args:
            sentence: the sentence (main() passes a list of tokens).
            phrase_position: position(s) of the phrase in the sentence.
            sent_id: identifier of the source sentence.
        """
        self.sentence = sentence
        self.phrase_position = phrase_position
        self.sent_id = sent_id

    def get_input(self):
        """Placeholder: building model input has not been written yet.

        Raises:
            NotImplementedError: always, so an accidental call fails loudly
                instead of the empty method breaking the class definition.
        """
        raise NotImplementedError("SentenceEntity.get_input is not implemented yet")

In [29]:
def main(input_file='/home/weiwei/lexical_simplification/BenchPS.txt',
         max_prediction_per_sequence=5,
         max_seq_len=64,
         max_len_mask=3,
         top_k=10):
    """Predict mask fillers for every sentence and build candidate sentences.

    Loads the data, encodes it, runs the masked-LM model, takes the ``top_k``
    vocabulary entries for each active mask slot and writes the n-th
    candidate of every slot back into a copy of the masked sentence.

    Args:
        input_file: tab-separated benchmark file passed to ``load_data``.
        max_prediction_per_sequence: number of mask slots per sentence.
        max_seq_len: sequence length used for data loading and the model.
        max_len_mask: number of mask variants generated per sentence.
        top_k: number of candidates kept per mask slot.

    Returns:
        List of ``SentenceEntity`` objects, ``top_k`` per sentence that has
        at least one active mask slot.
    """
    (masked_sents, masked_lm_positions, masked_lm_weights, original_sents,
     vocab_synomies, vocab_positions, phrases, sent_ids) = load_data(
        input_file = input_file, max_len_mask = max_len_mask,
        max_seq_len = max_seq_len, max_prediction_per_sequence = max_prediction_per_sequence)

    # Snapshot the masked sentences before encoding so the templates filled
    # in below cannot be affected by any in-place edit made while encoding.
    masked_templates = [list(sent) for sent in masked_sents]

    inputs = bert_encode(masked_sents, original_sents, masked_lm_positions, masked_lm_weights, max_prediction_per_sequence = max_prediction_per_sequence)

    bert_config_file = os.path.join(gs_folder_bert, "bert_config.json")
    with tf.io.gfile.GFile(bert_config_file) as config_reader:
        config_dict = json.loads(config_reader.read())
    bert_config = bert.configs.BertConfig.from_dict(config_dict)
    # NOTE(review): no checkpoint is restored in this function, so the models
    # appear to run with freshly initialised weights - confirm that the
    # pretrained weights are loaded somewhere before trusting predictions.
    model, encoder, pretrained = bert.bert_models.pretrain_model(bert_config=bert_config, seq_length=max_seq_len,max_predictions_per_seq=max_prediction_per_sequence,  use_next_sentence_label=False, return_core_pretrainer_model=True)

    predictions = pretrained.predict(inputs)
    # NOTE(review): the encoder output is computed but not consumed yet.
    embeddings = encoder.predict(inputs)
    results = predictions['masked_lm']

    all_sents = []
    for idx, result in enumerate(results):
        # One list of top_k candidate words per active mask slot; each
        # sentence holds a single complex phrase.
        tokens = []
        for masked, weight in zip(result, masked_lm_weights[idx]):
            if weight == 1:
                # top_k returns (scores, vocabulary ids); the ids are what
                # has to be decoded into words.
                _, indices = tf.math.top_k(input=masked, k=top_k)
                words = tokenizer.convert_ids_to_tokens(indices.numpy().tolist())
                tokens.append(words)
        if not tokens:
            # Nothing was masked in this sentence.
            continue

        # Real positions come first in the padded list, so the first
        # len(tokens) entries belong to the active slots.
        positions = list(masked_lm_positions[idx][:len(tokens)])
        # A separate name keeps the phrases returned by load_data intact.
        for candidate in combine_tokens(tokens):
            new_sent = list(masked_templates[idx])
            for position, token in zip(positions, candidate):
                # Positions are 1-based (slot 0 is [CLS] in the model input),
                # while new_sent has no [CLS], hence the -1.
                new_sent[position - 1] = token
            all_sents.append(SentenceEntity(new_sent, positions, sent_ids[idx]))

    return all_sents
    

In [30]:
# Run the pipeline when this cell is executed in the notebook kernel or the
# file is run as a script; the two shapes printed below come from bert_encode().
if __name__ == '__main__':
    main()

(1200, None)
(1200, None)
