In [1]:
from easydict import EasyDict as edict

# Notebook-wide configuration, built in one shot from a mapping so every
# tunable value is visible in a single literal.
cfg = edict({
    # Location of the pretrained GloVe archive that gets downloaded.
    'GLOVE': 'http://nlp.stanford.edu/data/glove.840B.300d.zip',
    # Directories for data and model files (relative paths).
    'DATA_DIR': '../../data/summarization/data',
    'MODEL_DIR': '../../data/summarization/model',
    # from https://arxiv.org/abs/1704.04368
    'ART_LEN': 781,
    'SUM_LEN': 56,
})

## Download data

In [2]:
import zipfile
import wget
import os

# NOTE(review): hard-coded absolute working directory; the relative paths in
# `cfg` are resolved against it, so this cell only works where this path exists.
os.chdir("/tf/src/examples/summarization")

# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(cfg.DATA_DIR, exist_ok=True)

glovefile = os.path.join(cfg.DATA_DIR, "glove", "glove.840B.300d.txt")

# Download at most once. The previous `while` loop would re-download the
# archive forever if extraction never produced `glovefile`.
if not os.path.exists(glovefile):
  glove_zip = os.path.join(cfg.DATA_DIR, "glove.840B.300d.zip")
  glove_unzip = os.path.join(cfg.DATA_DIR, "glove")
  # wget.download returns the path it actually wrote; it picks a different
  # name when `glove_zip` already exists (e.g. left over from a failed run),
  # so use the returned path rather than assuming `glove_zip`.
  downloaded_zip = wget.download(cfg.GLOVE, glove_zip)
  try:
    with zipfile.ZipFile(downloaded_zip) as f:
      f.extractall(glove_unzip)
  finally:
    # Remove the archive even when extraction fails, so a corrupt download
    # does not linger and get picked up by the next attempt.
    os.remove(downloaded_zip)
  if not os.path.exists(glovefile):
    raise FileNotFoundError(
        "Expected %s after extracting %s, but it was not found"
        % (glovefile, cfg.GLOVE))

In [3]:
import tensorflow_datasets as tfds
import tensorflow as tf
import os
# Load the `cnn_dailymail/plain_text` dataset through tensorflow_datasets,
# keeping its files under <DATA_DIR>/tf_data.
data_dir = os.path.join(cfg.DATA_DIR, "tf_data")
data = tfds.load(
    name='cnn_dailymail/plain_text',
    data_dir=data_dir,
)
# Bare expression: display the loaded splits as the cell output.
data

{'test': <PrefetchDataset shapes: {article: (), highlights: ()}, types: {article: tf.string, highlights: tf.string}>,
 'train': <PrefetchDataset shapes: {article: (), highlights: ()}, types: {article: tf.string, highlights: tf.string}>,
 'validation': <PrefetchDataset shapes: {article: (), highlights: ()}, types: {article: tf.string, highlights: tf.string}>}

## Preprocess Dataset

In [4]:
from copynet_tf import Vocab

In [5]:
import spacy
# Load spaCy's `en_core_web_sm` pipeline; `process` uses it through `nlp.pipe`
# with the tagger, parser and NER components disabled.
nlp = spacy.load('en_core_web_sm')

In [6]:
from typing import Dict
import numpy as np
from tqdm import tqdm


class GloVeReader:
    """Reads GloVe text-format vectors and carries the special-token strings."""

    # Special-token strings; the notebook passes them to `Vocab` as
    # (START, END, PAD, UNK).
    UNK = 'UNKNOWN'
    PAD = 'PAD'
    START = '<S>'
    END = 'EOS'

    def read(self,
             filename: str) -> Dict[str, np.ndarray]:
        """Parse a GloVe text file into a ``{token: vector}`` mapping.

        Each non-blank line is split on single spaces: the first field
        (stripped) is the token, the remaining fields are parsed as float32.

        Args:
            filename: Path to the UTF-8 encoded GloVe text file.

        Returns:
            Dict mapping each token to a 1-D ``np.float32`` array. If a token
            occurs more than once, the last occurrence wins.

        Raises:
            ValueError: If a vector field cannot be parsed as a float.
        """
        data = {}
        # GloVe files are UTF-8; being explicit avoids decode errors (or silent
        # mojibake) on platforms whose default encoding is something else.
        with open(filename, 'r', encoding='utf-8') as fin:
            for line in tqdm(fin, desc='Loading vectors'):
                # Skip blank lines (e.g. a trailing newline at end of file),
                # which would otherwise add a '' key with an empty vector.
                if not line.strip():
                    continue
                # Split on a single space only, so tokens that contain other
                # whitespace characters stay intact.
                tokens = line.rstrip('\n').split(' ')
                data[tokens[0].strip()] = np.array(
                    tokens[1:], dtype=np.float32)
        return data

In [7]:
def process(vocab, data, fit=False, pretrained_vectors=None):
  """Tokenize a raw split and wrap its id sequences in a tf.data.Dataset.

  Args:
    vocab: Vocabulary object exposing `fit`, `transform`, `_source_seq_len`
      and `_target_seq_len`.
    data: Iterable of examples with `article` and `highlights` string tensors.
    fit: When True, fit `vocab` on this split before transforming it.
    pretrained_vectors: Pretrained embeddings handed to `vocab.fit`;
      required when `fit` is True.

  Returns:
    A `tf.data.Dataset` of four int32 vectors per example, in this order:
    article as source ids, article as target ids (source length),
    highlights as target ids, highlights as source ids (target length).

  Raises:
    ValueError: If `fit` is True and `pretrained_vectors` is None.
  """
  # Fail fast: validate arguments before the expensive caching/tokenizing
  # passes instead of after them.
  if fit and pretrained_vectors is None:
    raise ValueError("Give pretrained vectors while fitting")

  articles = []
  highlights = []
  print("Caching data...")
  for tup in data:
    articles.append(tup['article'].numpy().decode())
    highlights.append(tup['highlights'].numpy().decode())

  print("Tokenizing data...")
  # Only tokenization is needed, so the heavier pipeline components are off.
  disabled = ["tagger", "parser", "ner"]
  articles = list(nlp.pipe(articles, disable=disabled, batch_size=1000))
  highlights = list(nlp.pipe(highlights, disable=disabled, batch_size=1000))

  if fit:
    print("Fitting vocab over tokens...")
    # NOTE(review): the trailing 0 and 5 are positional arguments defined by
    # copynet_tf's Vocab.fit — confirm their meaning there.
    vocab.fit(articles, highlights, pretrained_vectors, 0, 5)

  source_ids = vocab.transform(articles, "source")
  source_as_target = vocab.transform(articles, "target", vocab._source_seq_len)
  target_ids = vocab.transform(highlights, "target")
  target_as_source = vocab.transform(highlights, "source", vocab._target_seq_len)

  def gen():
    # Build a fresh zip on every call. from_generator invokes `gen` once per
    # iteration of the dataset, and a single shared zip iterator would be
    # exhausted after the first pass, leaving later passes empty.
    # (Re-iteration still requires vocab.transform to return re-iterable
    # sequences rather than one-shot generators — TODO confirm.)
    yield from zip(source_ids, source_as_target, target_ids, target_as_source)

  return tf.data.Dataset.from_generator(
      gen,
      (tf.int32, tf.int32, tf.int32, tf.int32),
      (tf.TensorShape([vocab._source_seq_len]),
       tf.TensorShape([vocab._source_seq_len]),
       tf.TensorShape([vocab._target_seq_len]),
       tf.TensorShape([vocab._target_seq_len])))

In [8]:
# Read the pretrained GloVe vectors from disk.
reader = GloVeReader()
glove_txt_path = os.path.join(cfg.DATA_DIR, "glove", "glove.840B.300d.txt")
pretrained_vectors = reader.read(glove_txt_path)

# Vocabulary configured with the reader's special tokens and the
# article/summary lengths from `cfg`.
vocab = Vocab(
    reader.START,
    reader.END,
    reader.PAD,
    reader.UNK,
    cfg.ART_LEN,
    cfg.SUM_LEN,
)

Loading vectors: 2196017it [03:15, 11210.26it/s]


In [9]:
# Build the training dataset; fit=True also fits `vocab` on this split using
# the pretrained vectors.
train = process(vocab, data['train'], fit=True, pretrained_vectors=pretrained_vectors)

Caching data...
Tokenizing data...
Fitting vocab over tokens...
source max ('.', 11114933) min ('195-a', 1)
target max ('.', 1028967) min ('Mimicking', 1)


In [10]:
# Build the validation dataset with the same `vocab`; `fit` defaults to False,
# so the vocabulary is not refit here.
val = process(vocab, data['validation'])

Caching data...
Tokenizing data...


In [11]:
# Build the test dataset with the same `vocab`; `fit` defaults to False,
# so the vocabulary is not refit here.
test = process(vocab, data['test'])

Caching data...
Tokenizing data...


In [12]:
def bytes_feature(value):
    """Wrap a single bytes value in a `tf.train.Feature` holding a BytesList.

    Eager tensors are converted with `.numpy()` first, because BytesList
    cannot unpack the underlying string from an EagerTensor.
    """
    eager_tensor_type = type(tf.constant(0))
    raw = value.numpy() if isinstance(value, eager_tensor_type) else value
    payload = tf.train.BytesList(value=[raw])
    return tf.train.Feature(bytes_list=payload)


def pyserialize(X, Xt, y, yt):
    """Pack four tensors into one serialized `tf.train.Example`.

    Each tensor is serialized with `tf.io.serialize_tensor` and stored as a
    bytes feature under the keys "X", "y", "Xt" and "yt".

    Returns:
        The serialized Example as a byte string.
    """
    named_tensors = (("X", X), ("y", y), ("Xt", Xt), ("yt", yt))
    feature = {
        key: bytes_feature(tf.io.serialize_tensor(tensor))
        for key, tensor in named_tensors
    }
    features = tf.train.Features(feature=feature)
    return tf.train.Example(features=features).SerializeToString()


def serialize(X, Xt, y, yt):
    """Dataset-map wrapper that runs `pyserialize` via `tf.py_function`.

    Returns:
        A scalar `tf.string` tensor; the py_function result is reshaped
        to shape ().
    """
    record = tf.py_function(func=pyserialize, inp=[X, Xt, y, yt], Tout=tf.string)
    return tf.reshape(record, ())

def save(vocab, train, test, val):
    """Write the vocabulary and the three splits under <DATA_DIR>/prepared.

    Args:
        vocab: Object with a `save(path)` method; saved at `prepared/vocab`.
        train, test, val: Datasets passed to `TFRecordWriter.write`, written
            as ZLIB-compressed `train.tfrecord`, `test.tfrecord` and
            `val.tfrecord`.
    """
    out_dir = os.path.join(cfg.DATA_DIR, "prepared")
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    print("******** Saving Vocabulary ********")
    vocab.save(os.path.join(out_dir, "vocab"))

    # (progress message, file name, dataset), in the order they are written.
    splits = (
        ("******** Saving Validation set ********", "val.tfrecord", val),
        ("******** Saving Test set ********", "test.tfrecord", test),
        ("******** Saving Training set ********", "train.tfrecord", train),
    )
    for message, filename, dataset in splits:
        print(message)
        record_path = os.path.join(out_dir, filename)
        tf.data.experimental.TFRecordWriter(record_path, "ZLIB").write(dataset)

    print("******** Finished saving dataset ********")

In [14]:
# Serialize every example and write all splits to disk.
# AUTOTUNE is the named constant for the -1 previously passed positionally as
# `num_parallel_calls`.
save(
    vocab,
    train.map(serialize, num_parallel_calls=tf.data.experimental.AUTOTUNE),
    test.map(serialize, num_parallel_calls=tf.data.experimental.AUTOTUNE),
    val.map(serialize, num_parallel_calls=tf.data.experimental.AUTOTUNE),
)

******** Saving Vocabulary ********
******** Saving Validation set ********
******** Saving Test set ********
******** Saving Training set ********
******** Finished saving dataset ********
