In [6]:
from text_gan import cfg
from text_gan.features import GloVe, NERTagger, PosTagger
from text_gan.utils import MapReduce

import en_core_web_sm
import tensorflow_datasets as tfds
import tensorflow as tf
import numpy as np
import logging
import os
import gc

In [2]:
cembs = GloVe(cfg.EMBS_FILE, cfg.CSEQ_LEN)
qembs = GloVe(cfg.EMBS_FILE, cfg.QSEQ_LEN, cembs.data)

In [3]:
ner = NERTagger(cfg.NER_TAGS_FILE, cfg.CSEQ_LEN)
pos = PosTagger(cfg.POS_TAGS_FILE, cfg.CSEQ_LEN)
nlp = en_core_web_sm.load()

In [4]:
train = tfds.load("squad", data_dir="/tf/data/tf_data", split='train')

In [5]:
def tokenize_example(x):
    context, question, ans = list(nlp.pipe([
        x['context'].decode('utf-8'),
        x['question'].decode('utf-8'),
        x['answers']['text'][0].decode('utf-8')
    ]))
    del x
    return (context, question, ans)

def substrSearch(ans, context):
    i = 0
    j = 0
    s = -1
    while i < len(context) and j < len(ans):
        if context[i].text == ans[j].text:
            s = i
            i += 1
            j += 1
        else:
            i += 1
            j = 0
            s = -1
    return s, j

In [6]:
mr = MapReduce()
train_iter = train.as_numpy_iterator()
train_tokenized = mr.process(tokenize_context, train_iter)

In [7]:
train_context = []
train_question = []
train_ans = []
for context, ques, ans in train_tokenized:
    ans_start, l = substrSearch(ans, context)
    if len(ques) >= 20 or ans_start == -1 or ans_start + l >= 250:
        continue
    train_context.append(context)
    train_question.append(ques)
    train_ans.append((ans_start, l))
len(train_context), len(train_question), len(train_ans)

(83433, 83433, 83433)

In [8]:
cembs.fit(train_context)
qembs.fit(train_question)
train_cidx = cembs.transform(train_context)
train_ner = ner.transform(train_context)
train_pos = ner.transform(train_context)
train_qidx = qembs.transform(train_question)


In [10]:
cseq = cfg.CSEQ_LEN
qseq = cfg.QSEQ_LEN

def gen():
    for cidx, ner, pos, qidx, ans in zip(
            train_cidx, train_ner, train_pos,
            train_qidx, train_ans):
        yield (cidx, ans, qidx, ner, pos)

train_dataset = tf.data.Dataset.from_generator(
    gen,
    (tf.int32, tf.int32, tf.int32, tf.uint8, tf.uint8),
    (
        tf.TensorShape([cseq]), tf.TensorShape([2]),
        tf.TensorShape([qseq]), tf.TensorShape([cseq]), tf.TensorShape([cseq]))
)

In [12]:
i = 0
for cidx, ans, qidx, ner, pos in train_dataset:
    i += 1
print("Total", i)

Total 83433
