In [None]:
import tensorflow as tf

# Print the installed TensorFlow version.
print("Tensorflow version " + tf.__version__)
# Shorthand handed to the tf.data calls below so the runtime tunes parallelism and buffer sizes.
AUTO = tf.data.experimental.AUTOTUNE

# Enable XLA JIT compilation in the graph optimizer.
tf.config.optimizer.set_jit(True)

In [None]:
from tensorflow.keras.mixed_precision import experimental as mixed_precision
# Install 'mixed_float16' as the global Keras dtype policy and print the dtypes it resolves to.
# NOTE(review): the policy is fixed here, before the hardware-detection cell runs; TensorFlow's
# mixed precision guide recommends 'mixed_bfloat16' on TPUs — confirm this is intended for TPU runs.
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_policy(policy)
print('Compute dtype: %s' % policy.compute_dtype)
print('Variable dtype: %s' % policy.variable_dtype)

In [None]:
# Detect hardware
# `gpus` is bound up front so the strategy selection below never depends on
# which branch of the try/except happened to run.
gpus = []
try:
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver() # TPU detection
except ValueError:
  # No TPU could be resolved: fall back to whatever GPUs are visible.
  tpu = None
  gpus = tf.config.experimental.list_logical_devices("GPU")

# Select appropriate distribution strategy for hardware
if tpu:
  tf.config.experimental_connect_to_cluster(tpu)
  tf.tpu.experimental.initialize_tpu_system(tpu)
  strategy = tf.distribute.experimental.TPUStrategy(tpu)
  print('Running on TPU ', tpu.master())  
elif len(gpus) > 0:
  # MirroredStrategy's `devices` argument is documented as a list of device
  # *strings*, so pass the logical device names rather than the LogicalDevice objects.
  strategy = tf.distribute.MirroredStrategy([gpu.name for gpu in gpus]) # this works for 1 to multiple GPUs
  print('Running on ', len(gpus), ' GPU(s) ')
else:
  strategy = tf.distribute.get_strategy() # default strategy that works on CPU and single GPU
  print('Running on CPU')

# How many accelerators do we have ?
print("Number of accelerators: ", strategy.num_replicas_in_sync)

In [None]:
import numpy as np
import re
import time

from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Activation, Dense, Dropout, Embedding, Bidirectional, GRU, LSTM
from tensorflow.keras.optimizers import Adam, RMSprop

In [None]:
# WandB – Install the W&B library
%pip install wandb --upgrade
import wandb
from wandb.keras import WandbCallback

import os
# Set the Weights & Biases entity and project through environment variables.
# NOTE(review): both are set to empty strings here — presumably placeholders to fill in before running; confirm.
os.environ['WANDB_ENTITY'] = ""
os.environ['WANDB_PROJECT'] = ""

In [None]:
# Run-wide constants shared by the input pipeline, the model and the training loop.
EPOCHS = 100  # upper bound on epochs passed to model.fit (early stopping may end the run sooner)
BATCH_SIZE = 10000  # examples per batch, used for both the training and validation datasets
MAX_SEQ_LENGTH = 15  # every input sequence is zero-padded to this length
NUM_CLASSES = 500 + 1 # vocab size + 1 for masking
SHUFFLE_BUFFER_SIZE = 10000  # size of the tf.data shuffle buffer used for training

In [None]:
# Default hyperparameters. Each training run copies this dict and overlays one
# entry of the parameter grid on top; the values are forwarded to create_model.
config_defaults = {
    'optimizer': 'adam',  # 'adam' selects Adam; any other value falls back to RMSprop
    'learning_rate': 0.005,
    'dropout': 0.0,  # rate of the Dropout layer added after every recurrent layer
    'embedding_dims': 512,
    'rnn_units': 128,
    'rnn_type': 'gru',  # 'lstm' selects LSTM; any other value selects GRU
    'bidirectional': 0,  # truthy wraps every recurrent layer in Bidirectional
    'stack_size': 1  # number of stacked recurrent layers
}

In [None]:
LOG_DIR = 'runs'  # NOTE(review): not referenced anywhere in the visible cells — confirm it is still needed
# Glob patterns resolved with tf.io.gfile.glob in train(). Training file names must
# embed their record count, which count_data_items parses to derive steps per epoch.
TRAINING_DATASET_PATTERN = '../*_train_*.tfrec'
VALIDATION_DATASET_PATTERN = '../*_test_*.tfrec'

In [None]:
def count_data_items(filenames):
    """Return the total number of records across the given TFRecord files.

    The record count is read from each file name: it is the first run of
    digits that sits between an underscore and a dot, e.g.
    "ratings_train_230.tfrec" contributes 230 items.

    Args:
        filenames: iterable of file names / paths (strings).

    Returns:
        The sum of the per-file counts; 0 for an empty iterable.

    Raises:
        ValueError: if a file name contains no "_<digits>." token. The
            message names the offending file.
    """
    # "+" rather than "*": an empty digit run (e.g. "foo_.tfrec") must not count
    # as a match, otherwise int('') would fail with an unhelpful error.
    regex = re.compile(r'_([0-9]+)\.')
    total = 0
    for filename in filenames:
        match = regex.search(filename)
        if match is None:
            raise ValueError(
                'Cannot read the item count from file name %r: expected "_<count>." in the name' % (filename,))
        total += int(match.group(1))
    return total

def read_tfrecord(example):
    """Parse one serialized example into a (padded sequence, one-hot label) pair.

    Args:
        example: a scalar string tensor holding one serialized tf.train.Example
            with a variable-length int64 'rating_chunk' feature and a scalar
            int64 'label' feature.

    Returns:
        A tuple (rating_chunk, label):
          * rating_chunk: int64 tensor of static shape [MAX_SEQ_LENGTH], the
            parsed sequence right-padded with zeros.
          * label: one-hot vector of depth NUM_CLASSES for index `label + 1`.

    Sequences longer than MAX_SEQ_LENGTH are not truncated; they yield a
    negative padding amount, which tf.pad rejects.
    """
    features = {
        'rating_chunk': tf.io.VarLenFeature(tf.int64),
        'label': tf.io.FixedLenFeature([], tf.int64)
    }

    example = tf.io.parse_single_example(example, features)

    rating_chunk = tf.sparse.to_dense(example['rating_chunk'])

    # The sequence length is only known at run time, so read it from the dynamic
    # shape instead of relying on AutoGraph to rewrite Python's len().
    pad_amount = MAX_SEQ_LENGTH - tf.shape(rating_chunk)[0]
    # Pad the 1-D sequence directly on the right with zeros.
    rating_chunk = tf.pad(rating_chunk, [[0, pad_amount]], 'CONSTANT')
    # After padding, the length is always MAX_SEQ_LENGTH; record that in the
    # static shape so downstream batching / XLA compilation sees a fixed size.
    rating_chunk = tf.ensure_shape(rating_chunk, [MAX_SEQ_LENGTH])

    # Labels are shifted by one before encoding (NUM_CLASSES reserves an extra slot for masking).
    label = tf.one_hot(example['label'] + 1, NUM_CLASSES)  # TODO num classes to arg
    return rating_chunk, label

def load_dataset(filenames):
    """Build an unbatched dataset of parsed examples from TFRecord files.

    Args:
        filenames: TFRecord file names to read.

    Returns:
        A tf.data.Dataset whose elements are produced by read_tfrecord.
    """
    # Reading several files in parallel is fastest when element order is
    # allowed to vary, so determinism is switched off for this pipeline.
    ignore_order = tf.data.Options()
    ignore_order.experimental_deterministic = False

    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    return records.with_options(ignore_order).map(read_tfrecord, num_parallel_calls=AUTO)

def get_training_dataset(filenames, batch_size, shuffle_buffer_size):
    """Build the infinite, shuffled, batched training dataset.

    Args:
        filenames: TFRecord file names to read.
        batch_size: number of examples per batch.
        shuffle_buffer_size: size of the shuffle buffer.

    Returns:
        A repeating tf.data.Dataset of (sequences, labels) batches. The caller
        must bound each epoch with `steps_per_epoch`.
    """
    dataset = load_dataset(filenames)
    dataset = dataset.cache()
    # Shuffle before repeating: the shuffle buffer then drains at the end of each
    # pass, so examples from different passes over the data are never mixed.
    dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.repeat()
    # The stream is infinite, so drop_remainder discards nothing; it only makes the
    # batch dimension statically known, which TPU / XLA compilation relies on.
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(AUTO) # prefetch next batch while training (autotune prefetch buffer size)
    return dataset

def get_validation_dataset(filenames, batch_size):
    """Build the single-pass, batched validation dataset.

    Args:
        filenames: TFRecord file names to read.
        batch_size: number of examples per batch.

    Returns:
        A finite tf.data.Dataset of (sequences, labels) batches, cached and
        prefetched, with auto-sharding set to the DATA policy.
    """
    batches = load_dataset(filenames).cache().batch(batch_size).prefetch(AUTO)

    # Needed for a TPU 32-core pod: the test dataset has only 3 files but there
    # are 4 TPUs, so FILE-based sharding must be disabled in favour of DATA.
    shard_by_data = tf.data.Options()
    shard_by_data.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
    return batches.with_options(shard_by_data)

In [None]:
def create_model(num_classes, optimizer_name, learning_rate, dropout, embedding_dims, rnn_units, rnn_type, bidirectional, stack_size):
    """Build and compile the recurrent classifier.

    The network is: masked Embedding -> `stack_size` recurrent layers (each
    followed by Dropout) -> Dense(num_classes) -> float32 softmax.

    Args:
        num_classes: size of the embedding vocabulary and of the output layer.
        optimizer_name: 'adam' selects Adam; any other value selects RMSprop.
        learning_rate: learning rate given to the optimizer.
        dropout: rate of the Dropout layer placed after every recurrent layer.
        embedding_dims: output dimension of the Embedding layer.
        rnn_units: number of units in every recurrent layer.
        rnn_type: 'lstm' selects LSTM; any other value selects GRU.
        bidirectional: if truthy, every recurrent layer is wrapped in Bidirectional.
        stack_size: number of recurrent layers; values below 1 still produce one layer.

    Returns:
        The compiled keras Sequential model (categorical cross-entropy loss,
        categorical accuracy plus top-5 / top-10 / top-MAX_SEQ_LENGTH accuracy).
    """
    model = Sequential()
    model.add(Embedding(num_classes, embedding_dims, mask_zero=True, input_length=MAX_SEQ_LENGTH)) # TODO input_length needed?

    def add_rnn_layer(return_sequences):
        # Appends one recurrent layer (optionally bidirectional) plus its Dropout.
        # All other settings come from create_model's own arguments.
        rnn_cls = LSTM if rnn_type == 'lstm' else GRU
        layer = rnn_cls(rnn_units, return_sequences=return_sequences)
        if bidirectional:
            layer = Bidirectional(layer)
        model.add(layer)
        model.add(Dropout(dropout))

    # Every layer but the last must emit the full sequence for the next one to consume.
    for _ in range(stack_size - 1):
        add_rnn_layer(return_sequences=True)

    add_rnn_layer(return_sequences=False)

    model.add(Dense(num_classes))
    # The softmax is pinned to float32 independently of the global dtype policy.
    model.add(Activation('softmax', dtype='float32'))

    if optimizer_name == 'adam':
        optimizer = Adam(learning_rate=learning_rate)
    else:
        optimizer = RMSprop(learning_rate=learning_rate)

    metrics = [keras.metrics.CategoricalAccuracy(),
               keras.metrics.TopKCategoricalAccuracy(k=5, name='top_5_categorical_accuracy'),
               keras.metrics.TopKCategoricalAccuracy(k=10, name='top_10_categorical_accuracy'),
               keras.metrics.TopKCategoricalAccuracy(k=MAX_SEQ_LENGTH, name=f'top_{MAX_SEQ_LENGTH}_categorical_accuracy')]

    model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=metrics)

    return model

In [None]:
def train(config):
    """Run one training job for the given hyperparameter configuration.

    Starts a W&B run, builds the training and validation datasets from the
    module-level file patterns, creates the model under the detected
    distribution strategy and fits it with W&B logging and early stopping.

    Args:
        config: dict with the keys 'optimizer', 'learning_rate', 'dropout',
            'embedding_dims', 'rnn_units', 'rnn_type', 'bidirectional' and
            'stack_size'.

    Raises:
        FileNotFoundError: if TRAINING_DATASET_PATTERN matches no files.
    """
    print('Start training')
    # Initialize a new wandb run
    wandb.init(config=config)
    
    print('Reading dataset')
    training_filenames = tf.io.gfile.glob(TRAINING_DATASET_PATTERN)
    validation_filenames = tf.io.gfile.glob(VALIDATION_DATASET_PATTERN)

    # An empty match would otherwise surface later as a training run with zero steps.
    if not training_filenames:
        raise FileNotFoundError('No training files match %r' % (TRAINING_DATASET_PATTERN,))

    # The training dataset repeats forever, so the epoch length is set explicitly.
    # Keep at least one step when the data holds fewer than BATCH_SIZE examples.
    train_steps = max(1, count_data_items(training_filenames) // BATCH_SIZE)

    training_data = get_training_dataset(training_filenames, BATCH_SIZE, SHUFFLE_BUFFER_SIZE)
    validation_data = get_validation_dataset(validation_filenames, BATCH_SIZE)
    
    print('Creating model')
    # Variables, optimizer and metrics have to be created inside the strategy's
    # scope; otherwise the TPU / multi-GPU strategy selected earlier is never applied.
    with strategy.scope():
        model = create_model(NUM_CLASSES,
                             config['optimizer'],
                             config['learning_rate'],
                             config['dropout'],
                             config['embedding_dims'],
                             config['rnn_units'],
                             config['rnn_type'],
                             config['bidirectional'],
                             config['stack_size'])
    model.summary()
    
    #model_checkpoint = keras.callbacks.ModelCheckpoint('model.hdf5', monitor='val_categorical_accuracy', save_best_only=True, period=5)
    # Stop once validation accuracy has not improved for 5 consecutive epochs.
    early_stopping = keras.callbacks.EarlyStopping(monitor='val_categorical_accuracy', min_delta=0, patience=5, mode='auto')
    wandb_callback = WandbCallback(monitor='val_categorical_accuracy')

    callbacks = [wandb_callback, early_stopping]
    
    print('Start training')
    model.fit(training_data, validation_data=validation_data, epochs=EPOCHS, steps_per_epoch=train_steps, callbacks=callbacks)

In [None]:
# Hyperparameter overrides to try; each dict is merged over config_defaults to form one training run.
param_grid = [{'rnn_units': 512, 'embedding_dims':128, 'dropout':0.5, 'bidirectional':1, 'stack_size':2}]

In [None]:
# Launch one training run per grid entry, layering its overrides on top of the defaults.
for param_config in param_grid:
    config = {**config_defaults, **param_config}
    train(config)