In [3]:
import tensorflow as tf

import numpy as np

In [4]:
batch_size = 8
raw_train_ds = tf.keras.preprocessing.text_dataset_from_directory(
    "datas/aclImdb/train",
    batch_size=batch_size,
    validation_split=0.2,
    subset="training",
    seed=1337
)
raw_val_ds = tf.keras.preprocessing.text_dataset_from_directory(
    "datas/aclImdb/train",
    batch_size=batch_size,
    validation_split=0.2,
    subset="validation",
    seed=1337
)
raw_test_ds = tf.keras.preprocessing.text_dataset_from_directory(
    "datas/aclImdb/test",
    batch_size=batch_size
)
print(
    "Number of batches in raw_train_ds: %d"
    % tf.data.experimental.cardinality(raw_train_ds)
)

print(
    "Number of batches in raw_val_ds: %d"
    % tf.data.experimental.cardinality(raw_val_ds)
)

print(
    "Number of batches in raw_test_ds: %d"
    % tf.data.experimental.cardinality(raw_test_ds)
)

Found 25000 files belonging to 2 classes.
Using 20000 files for training.
Found 25000 files belonging to 2 classes.
Using 5000 files for validation.
Found 25000 files belonging to 2 classes.
Number of batches in raw_train_ds: 2500
Number of batches in raw_val_ds: 625
Number of batches in raw_test_ds: 3125


In [5]:
for text_batch, label_batch in raw_train_ds.take(1):
    print(len(text_batch))
    for i in range(batch_size):
        print(text_batch.numpy()[i])
        print(label_batch.numpy()[i])

8
b'I feel it is my duty as a lover of horror films to warm other people about this horrible and very very bad "horror" film. Don\'t waste your time or money on this film, the acting is bad, the story is just one of the worst i have come across and the script was just awful. Nothing about it was good, you end up thinking to yourself why am i watching this crap. The plot had so many holes in it and they never got cleared up in the end, it was just so bad, i don\'t know how a film so terrible could be made. As i said before i love horror films and i was so let down, it was an 18 but you see little blood and no scares or jumps at all. Also what annoyed me was how stupid things happened in the film that had no point to the plot at all like the brother and sister kissing, why? is all i can say. Just don\'t bother, there are far more great horror films out there, just don\'t waste your time life is too short.'
0
b'This, and Immoral Tales, both left a bad taste in my mouth. It seems to me tha

In [6]:
# prepare the data

from tensorflow.keras.layers.experimental.preprocessing import TextVectorization
import string
import re

def custom_standardization(input_data):
    lowercase = tf.strings.lower(input_data)
    stripped_html = tf.strings.regex_replace(lowercase, "<br />", " ")
    return  tf.strings.regex_replace(
        stripped_html, "[%s]" % re.escape(string.punctuation), ""
    )

max_features = 20000
embedding_dim = 128
sequence_length = 500

vectorize_layer = TextVectorization(
    standardize=custom_standardization,
    max_tokens=max_features,
    output_mode="int",
    output_sequence_length=sequence_length
)

text_ds = raw_train_ds.map(lambda x, y: x)
vectorize_layer.adapt(text_ds)

In [7]:
def vectorize_text(text, label):
    text = tf.expand_dims(text, -1)
    return vectorize_layer(text), label

train_ds = raw_train_ds.map(vectorize_text)
val_ds = raw_val_ds.map(vectorize_text)
test_ds = raw_test_ds.map(vectorize_text)

train_ds = train_ds.cache().prefetch(buffer_size=10)
val_ds = val_ds.cache().prefetch(buffer_size=10)
test_ds = test_ds.cache().prefetch(buffer_size=10)


In [8]:
#build a model
from tensorflow.keras import layers

inputs = layers.Input(shape=(None, ), dtype="int64")
x = layers.Embedding(max_features, embedding_dim)(inputs)
x = layers.Dropout(0.5)(x)

x = layers.Conv1D(128, 7, padding="valid", activation="relu", strides=3)(x)
x = layers.Conv1D(128, 7, padding="valid", activation="relu", strides=3)(x)
x = layers.GlobalMaxPool1D()(x)

x = layers.Dense(128, activation="relu")(x)
x = layers.Dropout(0.5)(x)
predictions = layers.Dense(1, activation="sigmoid", name="predictions")(x)
model = tf.keras.Model(inputs, predictions)

model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])
model.summary()


Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, None)]            0         
_________________________________________________________________
embedding (Embedding)        (None, None, 128)         2560000   
_________________________________________________________________
dropout (Dropout)            (None, None, 128)         0         
_________________________________________________________________
conv1d (Conv1D)              (None, None, 128)         114816    
_________________________________________________________________
conv1d_1 (Conv1D)            (None, None, 128)         114816    
_________________________________________________________________
global_max_pooling1d (Global (None, 128)               0         
_________________________________________________________________
dense (Dense)                (None, 128)              

In [9]:
epochs = 3

model.fit(train_ds, validation_data=val_ds, epochs=epochs)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<tensorflow.python.keras.callbacks.History at 0x22484682dc0>

In [10]:
model.evaluate(test_ds)





[0.4298565983772278, 0.8668400049209595]

In [11]:
inputs = layers.Input(shape=(1, ), dtype="string")
indices = vectorize_layer(inputs)
outputs = model(indices)

end_to_end_model = tf.keras.Model(inputs, outputs)
end_to_end_model.compile(
    loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"]
)

end_to_end_model.evaluate(raw_test_ds)





[0.4298565089702606, 0.8668400049209595]

In [19]:
end_to_end_model.save("models/imdb_classification.h5", include_optimizer=False, save_format='h5')


NotImplementedError: Save or restore weights that is not an instance of `tf.Variable` is not supported in h5, use `save_format='tf'` instead. Got a model or layer TextVectorization with weights [<tensorflow.python.keras.engine.base_layer_utils.TrackableWeightHandler object at 0x00000224846738E0>]