In [1]:
import tensorflow as tf
from tensorflow import keras
import tensorflow_datasets as tfds
import numpy as np

In [2]:
(X_train, y_train), (X_test, y_test) = keras.datasets.imdb.load_data()

In [3]:
X_train[0][:10]

[1, 14, 22, 16, 43, 530, 973, 1622, 1385, 65]

In [4]:
word_index = keras.datasets.imdb.get_word_index()
id_to_word = {id_ + 3: word for word, id_ in word_index.items()}

In [5]:
print(id_to_word)

'fcker', 88155: 'gravestones', 23369: 'freshmen', 34649: 'formatted', 88156: 'drooping', 76036: 'zelig', 88157: 'yakusyo', 82060: 'lunceford', 88158: 'editorializing', 34650: 'plywood', 88159: 'banalities', 30539: 'nestor', 64829: 'revitalizes', 40755: 'voguing', 21669: 'sedate', 51860: 'dictum', 88160: 'brasher', 782: 'york', 88161: 'unchallengeable', 88162: 'subtelly', 8775: 'opposition', 88163: 'fetchingly', 70079: "'secrets", 88164: 'appearance\x85', 88165: 'teleflick', 19484: 'viennese', 10079: 'orphanage', 40756: 'movers', 27608: "cameraman's", 88166: "cameraman't", 88167: 'pornoes', 51861: 'embodiments', 88168: 'heorine', 16108: 'fraternity', 88169: "'procedures'", 659: 'finds', 88170: 'caratherisic', 27609: 'munshi', 20587: 'clashing', 40757: "mjh's", 88171: 'lärm', 76040: 'nikah', 51862: 'incandescent', 51863: 'stowing', 51864: 'acrid', 25216: 'eyewitness', 24004: 'maniacally', 51866: 'suspenders', 57063: 'acupat', 11585: 'nominee', 23370: 'toshiro', 51867: "'anita", 25217: 'c

In [6]:
for id_, token in enumerate(("<pad>", "<sos>", "<unk>")):
    id_to_word[id_] = token

In [7]:
" ".join([id_to_word[id_] for id_ in X_train[0][:10]])

'<sos> this film was just brilliant casting location scenery story'

In [8]:
datasets, info = tfds.load("imdb_reviews", as_supervised=True, with_info=True)
train_size = info.splits["train"].num_examples

In [9]:
def preprocess(X_batch, y_batch):
    X_batch = tf.strings.substr(X_batch, 0, 300)
    X_batch = tf.strings.regex_replace(X_batch, b'<br\\s*/?>', b" ")
    X_batch = tf.strings.regex_replace(X_batch, b"[^a-zA-Z']", b" ")
    X_batch = tf.strings.split(X_batch)
    return X_batch.to_tensor(default_value=b"<pad>"), y_batch

In [10]:
from collections import Counter

In [11]:
vocabulary = Counter()
for X_batch, y_batch in datasets["train"].batch(32).map(preprocess):
    for review in X_batch:
        vocabulary.update(list(review.numpy()))

In [12]:
vocabulary.most_common()[:3]

[(b'<pad>', 214309), (b'the', 61137), (b'a', 38564)]

In [13]:
vocab_size = 10000
truncated_vocabulary = [
    word for word, count in vocabulary.most_common()[:vocab_size]
]

In [14]:
words = tf.constant(truncated_vocabulary)
word_ids = tf.range(len(truncated_vocabulary), dtype=tf.int64)
vocab_init = tf.lookup.KeyValueTensorInitializer(words, word_ids)
num_oov_buckets = 1000
table =  tf.lookup.StaticVocabularyTable(vocab_init, num_oov_buckets)

In [15]:
table.lookup(tf.constant([b"This movie was faaaaaantastic".split()]))

<tf.Tensor: shape=(1, 4), dtype=int64, numpy=array([[   22,    12,    11, 10053]])>

In [16]:
def encode_words(X_batch, y_batch):
    return table.lookup(X_batch), y_batch

In [17]:
train_set = datasets["train"].batch(32).map(preprocess)
train_set = train_set.map(encode_words).prefetch(1)

In [21]:
embed_size = 128
model = keras.models.Sequential([
    keras.layers.Embedding(vocab_size + num_oov_buckets, embed_size, 
                           input_shape=[None]),
    keras.layers.GRU(128, return_sequences=True),
    keras.layers.GRU(128),
    keras.layers.Dense(1, activation="sigmoid")
])
model.compile(loss="binary_crossentropy", 
              optimizer=keras.optimizers.Nadam(lr=4e-4),
              metrics=["accuracy"])

In [22]:
history = model.fit(train_set, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


#### Masking

In [None]:
embed_size = 128
model = keras.models.Sequential([
    keras.layers.Embedding(vocab_size + num_oov_buckets, embed_size, 
                           input_shape=[None], mask_zero=True),
    keras.layers.GRU(128, return_sequences=True),
    keras.layers.GRU(128),
    keras.layers.Dense(1, activation="sigmoid")
])
model.compile(loss="binary_crossentropy", 
              optimizer=keras.optimizers.Nadam(lr=4e-4),
              metrics=["accuracy"])