In [16]:
import tensorflow as tf

In [17]:

from transformers import XLNetTokenizer
tokenizer = XLNetTokenizer.from_pretrained(
    'huseinzol05/xlnet-base-bahasa-cased', do_lower_case = False
)

INFO:transformers.tokenization_utils_base:Model name 'huseinzol05/xlnet-base-bahasa-cased' not found in model shortcut name list (xlnet-base-cased, xlnet-large-cased). Assuming 'huseinzol05/xlnet-base-bahasa-cased' is a path, a model identifier, or url to a directory containing tokenizer files.
INFO:transformers.tokenization_utils_base:loading file https://s3.amazonaws.com/models.huggingface.co/bert/huseinzol05/xlnet-base-bahasa-cased/spiece.model from cache at /Users/samsonlee/.cache/torch/transformers/c5ed46a1c7dc1002ab4f2106928fce75836edca5e1988fb9ef5c7b34eadb7a88.69797efcf2cbceb2ff4faaa9fda1b49630bc0a6af197b3bf7709a355149d5f4a
INFO:transformers.tokenization_utils_base:loading file https://s3.amazonaws.com/models.huggingface.co/bert/huseinzol05/xlnet-base-bahasa-cased/added_tokens.json from cache at None
INFO:transformers.tokenization_utils_base:loading file https://s3.amazonaws.com/models.huggingface.co/bert/huseinzol05/xlnet-base-bahasa-cased/special_tokens_map.json from cache at 

In [18]:
import json

with open('export_model/vocab-xlnet-base.json') as fopen:
    data = json.load(fopen)
    
LABEL_VOCAB = data['label']
TAG_VOCAB = data['tag']

In [19]:
with tf.gfile.GFile('export_model/xlnet-base.pb.quantized', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def)

In [20]:
input_ids = graph.get_tensor_by_name('import/input_ids:0')
word_end_mask = graph.get_tensor_by_name('import/word_end_mask:0')
charts = graph.get_tensor_by_name('import/charts:0')
tags = graph.get_tensor_by_name('import/tags:0')
sess = tf.InteractiveSession(graph = graph)

In [21]:

BERT_MAX_LEN = 512
import numpy as np
from parse_nk_xlnet_base import BERT_TOKEN_MAPPING

def make_feed_dict_bert(sentences):
    all_input_ids = np.zeros((len(sentences), BERT_MAX_LEN), dtype=int)
    all_word_end_mask = np.zeros((len(sentences), BERT_MAX_LEN), dtype=int)

    subword_max_len = 0
    for snum, sentence in enumerate(sentences):
        tokens = []
        word_end_mask = []

        cleaned_words = []
        for word in sentence:
            word = BERT_TOKEN_MAPPING.get(word, word)
            if word == "n't" and cleaned_words:
                cleaned_words[-1] = cleaned_words[-1] + "n"
                word = "'t"
            cleaned_words.append(word)

        for word in cleaned_words:
            word_tokens = tokenizer.tokenize(word)
            for _ in range(len(word_tokens)):
                word_end_mask.append(0)
            word_end_mask[-1] = 1
            tokens.extend(word_tokens)
        tokens.append("<sep>")
        word_end_mask.append(1)
        tokens.append("<cls>")
        word_end_mask.append(1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        input_mask = [1] * len(input_ids)

        subword_max_len = max(subword_max_len, len(input_ids))

        all_input_ids[snum, :len(input_ids)] = input_ids
        all_word_end_mask[snum, :len(word_end_mask)] = word_end_mask

    all_input_ids = all_input_ids[:, :subword_max_len]
    all_word_end_mask = all_word_end_mask[:, :subword_max_len]
    return all_input_ids, all_word_end_mask

In [22]:
s = 'Dr Mahathir menasihati mereka supaya berhenti berehat dan tidur sebentar sekiranya mengantuk ketika memandu.'.split()
sentences = [s]
i, m = make_feed_dict_bert(sentences)
i, m

(array([[  383,  1096, 21767,    88,   757,  1606, 15738,    24,   198,
          4049,  2479,  7529,   271,  7644,     9,     4,     3]]),
 array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1]]))

In [23]:
charts_val, tags_val = sess.run((charts, tags), {input_ids: i, word_end_mask: m})
charts_val, tags_val

(array([[[[ 0.        , -2.7198172 , -2.6694536 , ..., -2.8303301 ,
           -2.8180645 , -2.3391323 ],
          [ 0.        , -1.6336582 , -2.2570708 , ..., -1.8680124 ,
           -1.8989975 , -1.8138791 ],
          [ 0.        , -0.92996144, -1.759844  , ..., -2.0503466 ,
           -1.6889832 , -2.0735917 ],
          ...,
          [ 0.        , -1.5485951 , -2.700939  , ..., -1.7113425 ,
           -1.8153486 , -2.475861  ],
          [ 0.        , -1.9920502 , -2.868455  , ..., -2.0737116 ,
           -2.0749557 , -2.1500504 ],
          [ 0.        ,  0.        ,  0.        , ...,  0.        ,
            0.        ,  0.        ]],
 
         [[ 0.        , -1.5348125 , -1.6977209 , ..., -1.6391273 ,
           -1.8877805 , -1.9240421 ],
          [ 0.        , -2.7198172 , -2.6694536 , ..., -2.8303301 ,
           -2.8180645 , -2.3391323 ],
          [ 0.        , -1.2483176 , -1.5230591 , ..., -1.9593236 ,
           -1.7585529 , -2.185865  ],
          ...,
          [ 0

In [24]:
for snum, sentence in enumerate(sentences):
    chart_size = len(sentence) + 1
    chart = charts_val[snum,:chart_size,:chart_size,:]

In [25]:
import wget
url_chart_decoder = 'https://raw.githubusercontent.com/michaeljohns2/self-attentive-parser/michaeljohns2-support-tf2-patch/benepar/chart_decoder.pyx'
wget.download(url_chart_decoder)

'chart_decoder.pyx'

In [26]:
import chart_decoder_py

In [27]:
chart_decoder_py.decode(chart)

(7.619638919830322,
 array([ 0,  0,  0,  1,  2,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  7,
         8,  8,  9,  9, 10, 10, 11, 11, 12, 13]),
 array([14,  2,  1,  2, 14, 13,  3, 13,  4, 13,  5, 13,  6, 13,  7, 13,  8,
        13,  9, 13, 10, 13, 11, 13, 12, 13, 14]),
 array([ 3, 10,  0, 10,  0,  7,  0,  0,  5,  7,  0,  7,  0,  0,  0,  0,  0,
         7,  0,  0,  0,  2,  0,  3, 12,  0,  0]))

In [28]:
import nltk
from nltk import Tree

In [29]:
PTB_TOKEN_ESCAPE = {u"(": u"-LRB-",
    u")": u"-RRB-",
    u"{": u"-LCB-",
    u"}": u"-RCB-",
    u"[": u"-LSB-",
    u"]": u"-RSB-"}


def make_nltk_tree(sentence, tags, score, p_i, p_j, p_label):

    # Python 2 doesn't support "nonlocal", so wrap idx in a list
    idx_cell = [-1]
    def make_tree():
        idx_cell[0] += 1
        idx = idx_cell[0]
        i, j, label_idx = p_i[idx], p_j[idx], p_label[idx]
        label = LABEL_VOCAB[label_idx]
        if (i + 1) >= j:
            word = sentence[i]
            tag = TAG_VOCAB[tags[i]]
            tag = PTB_TOKEN_ESCAPE.get(tag, tag)
            word = PTB_TOKEN_ESCAPE.get(word, word)
            tree = Tree(tag, [word])
            for sublabel in label[::-1]:
                tree = Tree(sublabel, [tree])
            return [tree]
        else:
            left_trees = make_tree()
            right_trees = make_tree()
            children = left_trees + right_trees
            if label:
                tree = Tree(label[-1], children)
                for sublabel in reversed(label[:-1]):
                    tree = Tree(sublabel, [tree])
                return [tree]
            else:
                return children

    tree = make_tree()[0]
    tree.score = score
    return tree

In [30]:
tree = make_nltk_tree(s, tags_val[0], *chart_decoder_py.decode(chart))
print(str(tree))

(S
  (NP-SBJ (<START> Dr) (NP-SBJ (NN Mahathir)))
  (VP
    (NNP menasihati)
    (NP (VB mereka))
    (VP
      (PRP supaya)
      (VP
        (CC berhenti)
        (VB berehat)
        (JJ dan)
        (VP
          (CC tidur)
          (VB sebentar)
          (SBAR
            (JJ sekiranya)
            (S (FRAG (NP (IN mengantuk))) (JJ ketika)))))))
  (NN memandu.))


In [28]:
def make_str_tree(sentence, tags, score, p_i, p_j, p_label):
    idx_cell = [-1]
    def make_str():
        idx_cell[0] += 1
        idx = idx_cell[0]
        i, j, label_idx = p_i[idx], p_j[idx], p_label[idx]
        label = LABEL_VOCAB[label_idx]
        if (i + 1) >= j:
            word = sentence[i]
            tag = TAG_VOCAB[tags[i]]
            tag = PTB_TOKEN_ESCAPE.get(tag, tag)
            word = PTB_TOKEN_ESCAPE.get(word, word)
            s = u"({} {})".format(tag, word)
        else:
            children = []
            while ((idx_cell[0] + 1) < len(p_i)
                and i <= p_i[idx_cell[0] + 1]
                and p_j[idx_cell[0] + 1] <= j):
                children.append(make_str())

            s = u" ".join(children)
            
        for sublabel in reversed(label):
            s = u"({} {})".format(sublabel, s)
        return s
    return make_str()

In [29]:
make_str_tree(s, tags_val[0], *chart_decoder_py.decode(chart))

'(S (NP-SBJ (<START> Dr) (NP-SBJ (NN Mahathir))) (VP (NNP menasihati) (NP (VB mereka)) (VP (PRP supaya) (VP (CC berhenti) (VB berehat) (JJ dan) (VP (CC tidur) (VB sebentar) (SBAR (JJ sekiranya) (S (FRAG (NP (IN mengantuk))) (JJ ketika))))))) (NN memandu.))'

In [32]:
str(tree)

'(S\n  (NP-SBJ (<START> Dr) (NP-SBJ (NN Mahathir)))\n  (VP\n    (NNP menasihati)\n    (NP (VB mereka))\n    (VP\n      (PRP supaya)\n      (VP\n        (CC berhenti)\n        (VB berehat)\n        (JJ dan)\n        (VP\n          (CC tidur)\n          (VB sebentar)\n          (SBAR\n            (JJ sekiranya)\n            (S (FRAG (NP (IN mengantuk))) (JJ ketika)))))))\n  (NN memandu.))'

In [33]:
import re
re.sub('\n', ' <br>', str(tree))

'(S<br>  (NP-SBJ (<START> Dr) (NP-SBJ (NN Mahathir)))<br>  (VP<br>    (NNP menasihati)<br>    (NP (VB mereka))<br>    (VP<br>      (PRP supaya)<br>      (VP<br>        (CC berhenti)<br>        (VB berehat)<br>        (JJ dan)<br>        (VP<br>          (CC tidur)<br>          (VB sebentar)<br>          (SBAR<br>            (JJ sekiranya)<br>            (S (FRAG (NP (IN mengantuk))) (JJ ketika)))))))<br>  (NN memandu.))'