# V2: SEQ2SEQ

This notebook follows [an online tutorial](https://www.tensorflow.org/text/tutorials/nmt_with_attention#create_a_tfdata_dataset).

Other resources that may be helpful in the future: https://arxiv.org/pdf/2111.10746, https://aclanthology.org/2021.findings-emnlp.393.pdf

In [1]:
import tensorflow as tf

2024-09-10 18:48:56.663087: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-09-10 18:48:56.675072: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-10 18:48:56.689630: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-10 18:48:56.693881: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-09-10 18:48:56.705545: I tensorflow/core/platform/cpu_feature_gu

**Notes**:
- This notebook follows [an online tutorial](https://www.tensorflow.org/text/tutorials/nmt_with_attention) (and [at least one other](https://www.tensorflow.org/text/tutorials/text_generation) of the TensorFlow tutorials).
- This [blog post](https://janakiev.com/blog/jupyter-virtual-envs/) was referenced to set up the virtual environment.

In [2]:
import numpy as np
from pathlib import Path

import tensorflow_datasets as tfds


In [38]:
dataset_raw_train = tfds.load('ai2_arc_with_ir', split='train', shuffle_files=True)
dataset_raw_test = tfds.load('ai2_arc_with_ir', split='test', shuffle_files=True)
dataset_raw = dataset_raw_train.concatenate(dataset_raw_test)
dataset_raw = dataset_raw.map(lambda x: x['paragraph'])
for i in dataset_raw.take(1):
	print(i)

tf.Tensor(b'It and the other noble gases - helium, neon, krypton, xenon, and radon - will react with other substances only under extreme conditions. The noble gases The noble, or inert, gases are helium, neon, argon, krypton, xenon and radon. The rare gases are helium, neon, argon, krypton or xenon. The noble "gases" are helium, neon, argon, krypton and xenon. The noble gases are helium, neon, argon, krypton, xenon and radon. These occur for the noble gases helium, neon, argon, krypton, xenon, radon and ununoctium. Because of their chemical inertness, the elements in this group are called the Nobel Gases : Helium Neon Argon Krypton Xenon Radon Group I Elements The elements in this group have one electron in their outer electronic shell. The inert, or noble, gases (helium, neon, argon, krypton, xenon, and radon) all have completely filled outer shells. The other noble gases, which together make about 1% of the Earth\'s atmosphere, are neon, argon, krypton, xenon and radon. Most of the n

2024-09-10 20:48:02.785603: W tensorflow/core/kernels/data/cache_dataset_ops.cc:913] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


In [39]:
data_file_paths = list(Path('data/processed/en/').glob('*.txt'))
dataset_raw = dataset_raw.concatenate(
	tf.data.TextLineDataset(data_file_paths)
)

In [40]:
dataset_raw_mlqa = tfds.load('mlqa/en', split='test', shuffle_files=True) # "train" not available
dataset_raw_mlqa = dataset_raw_mlqa.map(lambda x: x['context'])
dataset_raw = dataset_raw_mlqa.concatenate(dataset_raw)

In [41]:

def remove_unsupported_characters(s):
	return tf.strings.regex_replace(s, r'[^A-Za-z0-9ùúûüÿàâæçéèêëïîôœÙÚÛÜŸÀÂÆÇÉÈÊËÏÎÔŒ \t\n.,?!\-\':;&]', '')
dataset_raw = dataset_raw.map(remove_unsupported_characters)

In [42]:
def add_extra_sentences(s):
	result = [s]
	for sentence in [
		'It was interesting. It was all very interesting.',
		'I added this.',
		'Why would it not?',
		'There it was.',
		'It happened.',
		'However, this part was added.',
		'That was a sentence.',
		'Here is a word: Test.',
		'This might help.',
		'How well do these extra words help?',
	]:
		result.append(tf.strings.join([s, sentence]))
	for sentence in [
		'I am hopeful that this prefix sentence will help. ',
		'This is a sample. ',
		'Here is some text: ',
		'Do short added sentences help? ',
	]:
		result.append(tf.strings.join([sentence, s]))
	return tf.data.Dataset.from_tensor_slices(result)

# Append extra sentences to help the model learn to detect sentence breaks
# dataset_raw = dataset_raw.flat_map(add_extra_sentences)

We've now loaded the `.txt` training data files using [`tf.data.TextLineDataset`](https://www.tensorflow.org/api_docs/python/tf/data/TextLineDataset). Each line in the source files is mapped to a new training example. 

Although some preprocessing has been done by `/data/process_data.py`, paragraphs aren't filtered out based on length/content. Let's do that now:

In [43]:
def remove_incorrect_sentences(paragraph):
	# Some parts of the training data include what seems to be image captions and other content that
	# starts with a lowercase after a sentence break.
	return tf.strings.regex_replace(paragraph, r'([.?!]) [a-z][^A-Z]+', r'\1')

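def break_long_sequences(paragraph):
	# Insert a [SEP] marker after each run of at most 120 characters followed by '.' or '?'.
	# A sentence longer than that is not cut: the match simply starts later in the sentence.
	return tf.strings.regex_replace(paragraph, r'.{0,120}[.?]', r'\0 [SEP]')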

def split_on_separators(text):
	return tf.data.Dataset.from_tensor_slices(tf.strings.split(text, '[SEP]'))

def filter_paragraphs(context, target):
	return tf.logical_and(tf.strings.length(context) > 5, tf.strings.length(target) > 5)

punctuation_chars = r'\?!.,"\-\':;'
def add_context(target):
	context = tf.strings.regex_replace(target, r'[\-]{2,}', ' - ')
	context = tf.strings.regex_replace(context, r'[\-\']', '')
	context = tf.strings.regex_replace(context, '[{}]+'.format(punctuation_chars), ' ')
	context = tf.strings.strip(
		tf.strings.regex_replace(context, '[ ]+', ' ')
	)
	context = tf.strings.lower(context)
	return context, target

dataset_raw = (
	dataset_raw
	.map(remove_incorrect_sentences)
	.map(break_long_sequences)
	.flat_map(split_on_separators)
	.map(add_context)
	.filter(filter_paragraphs)
)
for text, label in dataset_raw.take(12):
	print('item', text, label)

item tf.Tensor(b'after completing the journey around south america on 23 february 2006 queen mary 2 met her namesake the original rms queen mary which is permanently docked at long beach california', shape=(), dtype=string) tf.Tensor(b'After completing the journey around South America, on 23 February 2006, Queen Mary 2 met her namesake, the original RMS Queen Mary, which is permanently docked at Long Beach, California. ', shape=(), dtype=string)
item tf.Tensor(b'escorted by a flotilla of smaller ships the two queens exchanged a whistle salute which was heard throughout the city of long beach', shape=(), dtype=string) tf.Tensor(b' Escorted by a flotilla of smaller ships, the two Queens exchanged a whistle salute which was heard throughout the city of Long Beach. ', shape=(), dtype=string)
item tf.Tensor(b'queen mary 2 met the other serving cunard liners queen victoria and queen elizabeth 2 on 13 january 2008 near the statue of liberty in new york city harbour with a celebratory firework

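To see what `break_long_sequences` and `split_on_separators` do to a paragraph, here is a rough plain-Python sketch of the same regular expression. It uses Python's `re` module rather than `tf.strings` (an assumption: the two engines should agree for this simple pattern), and `split_paragraph` is a hypothetical helper that isn't part of the notebook.

```python
import re

def split_paragraph(paragraph: str, max_chars: int = 120) -> list[str]:
    """Plain-Python approximation of break_long_sequences + split_on_separators."""
    # Mark the end of each run of up to `max_chars` characters that ends in '.' or '?'
    pattern = r'.{0,%d}[.?]' % max_chars
    marked = re.sub(pattern, lambda m: m.group(0) + ' [SEP]', paragraph)
    # Split on the marker and drop empty chunks.
    # (The notebook keeps surrounding spaces and filters out short chunks instead.)
    return [chunk.strip() for chunk in marked.split('[SEP]') if chunk.strip()]

short = 'One. Two? Three.'
two_sentences = 'a' * 100 + '. ' + 'b' * 100 + '.'
long_sentence = 'c' * 150 + '.'

print(split_paragraph(short))               # several short sentences stay together
print(len(split_paragraph(two_sentences)))  # split at the first sentence break
print(len(split_paragraph(long_sentence)))  # a single over-long sentence is kept whole
```

The pattern never cuts inside a sentence, which is why the first example printed above is longer than 120 characters.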


Now let's inspect the data:

In [44]:
for text, label in dataset_raw.take(4).as_numpy_iterator():
	print('item', text, label)

item b'after completing the journey around south america on 23 february 2006 queen mary 2 met her namesake the original rms queen mary which is permanently docked at long beach california' b'After completing the journey around South America, on 23 February 2006, Queen Mary 2 met her namesake, the original RMS Queen Mary, which is permanently docked at Long Beach, California. '
item b'escorted by a flotilla of smaller ships the two queens exchanged a whistle salute which was heard throughout the city of long beach' b' Escorted by a flotilla of smaller ships, the two Queens exchanged a whistle salute which was heard throughout the city of Long Beach. '
item b'queen mary 2 met the other serving cunard liners queen victoria and queen elizabeth 2 on 13 january 2008 near the statue of liberty in new york city harbour with a celebratory fireworks display queen elizabeth 2 and queen victoria made a tandem crossing of the atlantic for the meeting' b' Queen Mary 2 met the other serving Cunard li



### Batching

In [45]:
BUFFER_SIZE = 100_000
BATCH_SIZE = 16
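# reshuffle_each_iteration=False keeps the shuffled order fixed across iterations. Otherwise,
# take() and skip() below would see a different order each time the data is iterated, and
# validation batches could leak into the training data.
dataset_train = dataset_raw.shuffle(BUFFER_SIZE, reshuffle_each_iteration=False).batch(BATCH_SIZE)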

# Inspired by https://stackoverflow.com/a/74609848.
validate_size = 2
dataset_validate = dataset_train.take(validate_size)
dataset_train = dataset_train.skip(validate_size)


### Preparing to process data

The [`TextVectorization`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/TextVectorization) layer takes a `standardize` option that preprocesses input data. The default removes punctuation, but we don't want that. Let's redefine it:

In [46]:

def standardize_tf_text(text):
	punctuation_regex = '[{p}]'.format(p = punctuation_chars)

	# Clean up &quot artifacts
	text = tf.strings.regex_replace(text, r'&quot;?', r' " ')

	# Surround punctuation with spaces for easier tokenization
	text = tf.strings.regex_replace(text, punctuation_regex, r' \0 ')

	# Remove repeated spaces
	text = tf.strings.regex_replace(text, r'\s+', ' ')

	# Add a special "capitalize the next letter" token
	text = tf.strings.regex_replace(text, r'(\s|^)([A-Z])', r' [CAP] \2')

	# Lowercase everything
	text = tf.strings.lower(text)

	# Remove leading and trailing spaces
	text = tf.strings.strip(text)

	# Add sequence markings
	return tf.strings.join(['[START]', text, '[END]'], separator=' ')

print(standardize_tf_text('This is a test! It\'s working?!'))

tf.Tensor(b"[START] [cap] this is a test ! [cap] it ' s working ? ! [END]", shape=(), dtype=string)


In [47]:
# Keep only the 4000 most commonly used tokens
max_vocab_size = 4000

bert_tokenizer_params = dict(
	vocab_size = max_vocab_size,
	reserved_tokens = [ '', '[UNK]', '[MASK]', '[START]', '[END]', '[cap]' ],
	bert_tokenizer_params = {},
	learn_params = {},
)

In [48]:
standardized_target = dataset_train.map(lambda context, target: standardize_tf_text(target))

from tensorflow_text.tools.wordpiece_vocab import bert_vocab_from_dataset as bert_vocab

# This can be slow -- avoid recomputing:
vocab_file = Path('./data/en_vocab.txt')
if not vocab_file.exists():
	vocab = bert_vocab.bert_vocab_from_dataset(
		standardized_target.batch(1000),
		**bert_tokenizer_params,
	)

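	# Write as UTF-8 explicitly: the vocabulary contains accented characters and is read
	# back as UTF-8 later.
	with open(vocab_file, 'w', encoding='utf-8') as f: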
		for token in vocab:
			if token:
				f.write(token + '\n')

In [49]:
import tensorflow_text as tf_text

en_tokenizer = tf_text.BertTokenizer('data/en_vocab.txt', token_out_type=tf.string)

def bert_tokenize(text):
	# The Bert tokenizer outputs a RaggedTensor, where the last dimension groups parts of each word.
	# We want all tokens to be on the same dimension -- merge the last two.
	tokens = en_tokenizer.tokenize(standardize_tf_text(text)).merge_dims(-2, -1)

	# Work around a bug where special tokens like [START] are split into '[', 'START', ']'
	result = tf.strings.reduce_join(tokens, separator = ' ', axis=-1)
	result = tf.strings.regex_replace(result, r'\[ ([A-Za-z]+) \]', r'[\1]')
	result = tf.strings.regex_replace(result, r'\s+', r' ')
	return tf.strings.split(result, ' ')


example_text = ['hello world this is a test tensorflow is processing this', 'test']
example_tokens = bert_tokenize(example_text)
print('Example tokens', example_tokens)


def bert_detokenize(tokens):
	#return tf.strings.reduce_join(en_tokenizer.detokenize(tokens), separator = ' ')
	return tf.strings.regex_replace(tf.strings.reduce_join(tokens, separator = ' ', axis=-1), r' ##', '')

print('Back to text', bert_detokenize(example_tokens).numpy())

Example tokens <tf.RaggedTensor [[b'[START]', b'he', b'##ll', b'##o', b'world', b'this', b'is', b'a',
  b'test', b'ten', b'##s', b'##or', b'##f', b'##lo', b'##w', b'is',
  b'process', b'##ing', b'this', b'[END]']                           ,
 [b'[START]', b'test', b'[END]']]>
Back to text [b'[START] hello world this is a test tensorflow is processing this [END]'
 b'[START] test [END]']
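The raw model output later in this notebook still contains `[cap]` markers, `##` word pieces, and spaces around punctuation. As a rough sketch of how that could be turned back into readable text, here is a hypothetical `restore_text` helper in plain Python. It isn't part of the notebook, and it only handles the simple cases shown (hyphens and quotation marks would need more rules).

```python
import re

def restore_text(tokens: str) -> str:
    """Hypothetical inverse of the standardization step, for space-separated tokens."""
    # Drop the sequence markers
    text = re.sub(r'\s*\[(?:START|END)\]\s*', ' ', tokens).strip()
    # Merge word pieces: "process ##ing" -> "processing"
    text = text.replace(' ##', '')
    # Apply the "capitalize the next letter" token
    text = re.sub(r'\[cap\] (\w)', lambda m: m.group(1).upper(), text)
    # Re-attach punctuation to the preceding word
    text = re.sub(r' ([.,?!:;])', r'\1', text)
    # Rejoin contractions: "it ' s" -> "it's"
    text = re.sub(r" ' ", "'", text)
    return text

print(restore_text("[START] [cap] this is a test ! [cap] it ' s working ? ! [END]"))
print(restore_text('[cap] ten ##s ##or ##f ##lo ##w is a library .'))
```

Running the model output through something like this would make the per-epoch demo output easier to read, at the cost of hiding tokenization details that can be useful while debugging.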


The text standardization function can now be used to preprocess text:

In [50]:
vocabulary = vocab_file.read_text('utf-8').split('\n')
# TextVectorization fails to construct if the vocabulary contains mask or [UNK] tokens.
vocabulary = list(filter(lambda s: s and s != '[UNK]', vocabulary))

target_text_processor = tf.keras.layers.TextVectorization(
	# Standardize first with the BERT tokenizer
	standardize=lambda s: tf.strings.reduce_join(bert_tokenize(s), separator = ' ', axis=-1),
	max_tokens=max_vocab_size,
	# Allow entries of different lengths
	ragged=True,
	vocabulary=vocabulary,
)
#target_text_processor.adapt(dataset_train.map(lambda context, target: target))

print('First 14 target words:', target_text_processor.get_vocabulary()[:14])

# The target data should be roughly equivalent to the context data, except have additional (punctuation)
# tokens.
context_text_processor = target_text_processor

First 14 target words: ['', '[UNK]', '[MASK]', '[START]', '[END]', '[cap]', '!', '"', '&', "'", ',', '-', '.', '0']


These layers convert text to token IDs. Indexing into `get_vocabulary()` converts IDs back to tokens, as `inspect_dataset` does below.


### Processing the data

Now, we'll:
1. Map the data through the text processors we just made.
2. Shift the target data, so that our network is provided with a history of generated tokens.
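The shift in step 2 can be illustrated with a plain Python list. This is only a sketch: token strings stand in for the integer IDs that the real pipeline uses, and the example sentence is made up.

```python
# Token strings stand in for the integer IDs used by the real pipeline.
target = ['[START]', '[cap]', 'hello', 'world', '.', '[END]']

target_in = target[:-1]   # history given to the decoder
target_out = target[1:]   # token the decoder should predict at each step

for step, expected in enumerate(target_out):
    history = target_in[:step + 1]
    print(' '.join(history), '->', expected)
```

At every step the decoder sees the tokens up to that point and is trained to predict the one that follows, so `target_in` never contains `[END]` and `target_out` never contains `[START]`.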

In [51]:
def process_text(context, target):
	return context_text_processor(context), target_text_processor(target)

def add_target_history(context, target):
	# .to_tensor(): Converts from RaggedTensors to Tensors.
	# We give our network the history as target_in
	target_in = target[:, :-1].to_tensor()
	target_out = target[:, 1:].to_tensor()
	return (context.to_tensor(), target_in), target_out
dataset_train = dataset_train.map(process_text).map(add_target_history).repeat()
dataset_validate = dataset_validate.map(process_text).map(add_target_history)

In [52]:
def inspect_dataset(dataset: tf.data.Dataset):
	target_vocab = np.array(target_text_processor.get_vocabulary())
	context_vocab = target_vocab
	for (context, target_in), target_out in dataset.take(1):
		context_words = context_vocab[context[0]]
		print('context', ','.join(context_words))
		print('target_in', ','.join(target_vocab[target_in[0]]))
		print('target_out', ','.join(target_vocab[target_out[0]]))

inspect_dataset(dataset_train)

context [START],v,##i,##ke,##la,##s,was,then,elected,the,first,president,of,the,new,##ly,established,international,olympic,committee,i,##o,##c,[END],,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
target_in [START],[cap],v,##i,##ke,##la,##s,was,then,elected,the,first,president,of,the,new,##ly,established,[cap],international,[cap],olympic,[cap],committee,[cap],i,##o,##c,.,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
target_out [cap],v,##i,##ke,##la,##s,was,then,elected,the,first,president,of,the,new,##ly,established,[cap],international,[cap],olympic,[cap],committee,[cap],i,##o,##c,.,[END],,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,




## Model

### The encoder

See https://www.tensorflow.org/text/tutorials/nmt_with_attention#the_encoder

In [18]:
class Encoder(tf.keras.Layer):
	def __init__(self, text_processor, units: int):
		"""
		Creates a new Encoder layer. `units` is the dimensionality of the token embeddings
		and of the encoder's output vectors.
		"""
		super(Encoder, self).__init__()
		self.text_processor = text_processor
		self.vocab_size = text_processor.vocabulary_size()
		self.units = units

		# Converts tokens -> vectors
		self.embedding = tf.keras.layers.Embedding(
			# mask_zero: Treats zero as a padding value that should be ignored
			self.vocab_size, units, mask_zero = True,
		)
		gru = tf.keras.layers.GRU(
			units, return_sequences = True,
			# Use the recurrent_initializer suggested by the tutorial (& the default
			# for kernel_initializer).
			recurrent_initializer='glorot_uniform'
		)
		self.rnn = tf.keras.layers.Bidirectional(
			# merge_mode determines how the forward and backward layers are combined
			#            'concat' is another option here
			merge_mode = 'sum',
			layer=gru,
		)

	def call(self, x):
		x = self.embedding(x)
		x = self.rnn(x)
		return x

	def prepare_for_input(self, texts):
		"""
		Utility method that converts `texts` to a form that can be provided to the `call` method.
		"""
		texts = tf.convert_to_tensor(texts)
		if len(texts.shape) == 0:
			texts = texts[None]
		context = self.text_processor(texts).to_tensor()
		return context


Try it:

In [19]:
ENCODER_UNITS = 256
encoder = Encoder(context_text_processor, ENCODER_UNITS)

for (context, target_history), target_next in dataset_validate.take(1):
	encoder_result = encoder(context)
	print('Context tokens shape (batch, s):', context.shape)
	print('Encoder output shape (batch, s, ENCODER_UNITS):', encoder_result.shape)



Context tokens shape (batch, s): (16, 59)
Encoder output shape (batch, s, ENCODER_UNITS): (16, 59, 256)


### The attention layer

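Attention can be thought of as a trainable, soft lookup table. Each `query` vector is compared against a set of keys, and the result is a weighted average of the corresponding `value` vectors. The layer below takes `query` and `value` as inputs; when no separate `key` is given, Keras's `MultiHeadAttention` uses `value` as the key.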
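As a minimal sketch of the lookup-table idea, here is scaled dot-product attention for a single query in plain Python. The `soft_lookup` name and the toy vectors are made up for illustration. The Keras layer used in this notebook additionally applies learned projections to its inputs, which this sketch leaves out.

```python
import math

def soft_lookup(query, keys, values):
    """Return a weighted average of `values`, weighted by how well each key matches `query`."""
    scale = math.sqrt(len(query))
    # Similarity between the query and each key (scaled dot product)
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    # Softmax turns the scores into weights that sum to 1
    largest = max(scores)
    exps = [math.exp(score - largest) for score in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Blend the values using those weights
    dim = len(values[0])
    output = [sum(w * value[i] for w, value in zip(weights, values)) for i in range(dim)]
    return output, weights

keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[10.0, 0.0], [0.0, 10.0]]
output, weights = soft_lookup([4.0, 0.0], keys, values)
print('weights', weights)  # most of the weight goes to the first key
print('output', output)    # so the output is close to the first value
```

Unlike a dictionary lookup, every value contributes a little. That keeps the operation differentiable, which is what allows it to be trained.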

In [20]:
class CrossAttention(tf.keras.Layer):
	def __init__(self, units, **kwargs):
		super().__init__()
		self.attention_layer = tf.keras.layers.MultiHeadAttention(
			key_dim=units,
			num_heads=1,
			**kwargs
		)
		# Keeps "the mean activation within each example close to 0 and the
		# activation standard deviation close to 1" -- https://www.tensorflow.org/api_docs/python/tf/keras/layers/LayerNormalization?hl=en
		self.norm_layer = tf.keras.layers.LayerNormalization()
		self.add_layer = tf.keras.layers.Add()
		self.supports_masking = True

	def call(self, query, value):
		attention_output = self.attention_layer(
			query = query,
			value = value,
			#use_causal_mask=True,
			# Return the attention scores for later plotting
			# return_attention_scores = True,
		)

		x = self.add_layer([ query, attention_output ])
		x = self.norm_layer(x)
		return x


In [21]:
attention_layer = CrossAttention(ENCODER_UNITS)

# Test with an example
for (context, target_history), target_next in dataset_validate.take(1):
	embed_layer = tf.keras.layers.Embedding(target_text_processor.vocabulary_size(), output_dim=ENCODER_UNITS, mask_zero=True)
	target_embed = embed_layer(target_history)
	encoded_context = encoder(context)
	attention_result = attention_layer(target_embed, encoded_context)

	print('Encoded context sequence shape (batch, s, units):', encoded_context.shape)
	print('Target history sequence shape (batch, t, units):', target_embed.shape)
	print('Attention result shape (batch, t, units):', attention_result.shape)

	# Used later 
	test_encoded_context = encoded_context



Encoded context sequence shape (batch, s, units): (16, 138, 256)
Target history sequence shape (batch, t, units): (16, 167, 256)
Attention result shape (batch, t, units): (16, 167, 256)




### The decoder

The decoder produces queries for the attention layer. The decoder operates on `target_history`. At each step during training, it should have no information about future target output (that's what we're trying to determine). As such, we use a unidirectional RNN.


In [22]:
class CustomDense(tf.keras.layers.Dense):
	def __init__(self, *args, **kwargs):
		super(CustomDense, self).__init__(*args, **kwargs)
	
	def compute_mask(self, _inputs, mask=None):
		return mask

class Decoder(tf.keras.Layer):
	def __init__(self, text_processor, units):
		super(Decoder, self).__init__()
		self.text_processor = text_processor
		self.vocab_size = text_processor.vocabulary_size()
		self.units = units

		self.embedding_layer = tf.keras.layers.Embedding(
			# mask_zero: Treats zero as a padding value that should be ignored
			self.vocab_size, units, mask_zero = True,
		)
		self.rnn_layer = tf.keras.layers.GRU(
			units, return_sequences = True, return_state = True, recurrent_initializer='glorot_uniform',
		)
		self.attention_layer = CrossAttention(units)

		# Creates logits with the estimated probability of each output token
		self.output_layer = CustomDense(self.vocab_size)

		# Conversion:
		self.word_to_id = tf.keras.layers.StringLookup(
			vocabulary = text_processor.get_vocabulary(),
			mask_token = '',
			oov_token = '[UNK]',
		)
		self.id_to_word = tf.keras.layers.StringLookup(
			vocabulary = text_processor.get_vocabulary(),
			mask_token = '',
			oov_token = '[UNK]',
			invert = True,
		)
		# Pre-computing these simplifies exporting
		self.start_id = self.word_to_id('[START]')
		self.end_id = self.word_to_id('[END]')

		self.supports_masking = True
	
	def build(self, input_shape):
		# Nothing that needs a size allocation based on the input shape
		pass
	
	def call(self, context, target_history, state = None, return_state = False):
		x = self.embedding_layer(target_history)
		x, state = self.rnn_layer(x, initial_state = state)
		x = self.attention_layer(x, context)

		logits = self.output_layer(x)
		if return_state:
			return logits, state
		else:
			return logits
	
	## Conversion/testing ##

	def tokens_to_text(self, tokens):
		text = tf.strings.reduce_join(self.id_to_word(tokens), separator = ' ')
		text = tf.strings.regex_replace(text, r'\s*\[START\]\s*', '')
		text = tf.strings.regex_replace(text, r'\s*\[END\].*$', '')
		return text

	def generate_next_token(self, context, target_history, done_vec, state, temperature = 0.0):
		# Note: done_vec is a boolean tensor of shape (batch, 1), indicating whether each item
		# in the batch is done

		logits, state = self(context, target_history, state = state, return_state = True)

		# logits has shape (batch, t, target_vocab_size). Only generate the token corresponding
		# to the last logits in the sequence (at t - 1)
		last_logits = logits[:, -1, :]
		if temperature > 0:
			# Pick the token from a categorical distribution
			next_token = tf.random.categorical(last_logits / temperature, num_samples = 1)
		else:
			# Greedy decoding: pick the most likely token. Shape: (batch,) -> (batch, 1)
			next_token = tf.math.argmax(last_logits, axis=-1)[:, tf.newaxis]
		# Emit 0 (padding) after a sequence is done
		next_token = tf.where(done_vec, tf.constant(0, dtype=tf.int64), next_token)
		done_vec = done_vec|(next_token == self.end_id)
		return next_token, done_vec, state
	
	def get_initial_state(self, context):
		# context has shape (batch_size, s, units)
		batch_size = tf.shape(context)[0]
		start_tokens = tf.fill([batch_size, 1], self.start_id)
		done_vec = tf.zeros([batch_size, 1], dtype = tf.bool)

		# From the TensorFlow source code:
		# > RNN expect the states in a list, even if single state.
		# Note: Without the [0] we get a type mismatch while exporting.
		initial_state = self.rnn_layer.get_initial_state(batch_size)[0]

		return start_tokens, done_vec, initial_state

Let's try it!

In [23]:
def test_generation_loop():
	decoder = Decoder(target_text_processor, ENCODER_UNITS)
	next_token, done_vec, state = decoder.get_initial_state(test_encoded_context[:3, :, :])
	tokens = [next_token]

	for i in range(8):
		next_token, done_vec, state = decoder.generate_next_token(test_encoded_context[:3, :, :], next_token, done_vec, state)
		tokens.append(next_token)
	
	# Merge all batch outputs into a single dimension
	tokens = tf.concat(tokens, -1) # -1 = last axis

	print('Output:', decoder.tokens_to_text(tokens).numpy())

test_generation_loop()

Output: b'roman scientists 2016 rock ##ians plan ##ian causedroman scientists 2016 rock ##ians plan ##ian causedroman scientists 2016 rock ##ians plan ##ian caused'
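The `temperature` argument of `generate_next_token` switches between picking the most likely token and sampling. As a rough plain-Python sketch of that choice for a single position (the `pick_token` helper and the example logits are made up, not part of the notebook):

```python
import math
import random

def pick_token(logits, temperature=0.0, rng=random):
    """Choose a token index from a list of logits."""
    if temperature <= 0:
        # Greedy: always take the highest-scoring token
        return max(range(len(logits)), key=logits.__getitem__)
    # Sampling: lower temperatures sharpen the distribution, higher ones flatten it
    scaled = [logit / temperature for logit in logits]
    largest = max(scaled)
    weights = [math.exp(s - largest) for s in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]

logits = [0.1, 2.0, -1.0]
print('greedy:', pick_token(logits))
print('sampled:', [pick_token(logits, temperature=1.0) for _ in range(10)])
```

Greedy decoding is deterministic, which makes it a reasonable default for restoring punctuation, where the goal is one best answer rather than varied output.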


## The model

We can now combine the encoder and decoder into a single model, used both for training and for adding punctuation:

In [24]:
class Punctuator(tf.keras.Model):
	def __init__(self, units, context_text_processor, target_text_processor):
		super().__init__()
		self.encoder = Encoder(context_text_processor, units)
		self.decoder = Decoder(target_text_processor, units)
	
	def call(self, inputs):
		context, target_history = inputs
		context = self.encoder(context)
		logits = self.decoder(context, target_history)
		return logits
	
	def fix_punctuation_raw(self, input):
		"""
		Adds punctuation to `input`, where `input` is a `Tensor` with shape (batch_size, s) where s is the
		context length.
		"""
		context = self.encoder(input)

		next_token, done_vec, state = self.decoder.get_initial_state(context)

		# Although a TensorArray would allow more efficient exporting, the ONNX exporter seems to
		# have trouble with it. For now, use a Python list.
		tokens = []
		max_iterations = 56

		for i in range(max_iterations):
			# token_history has size: (batch, t, target_vocab_size)
			# token_history = tf.concat(tokens, 1)
			# print('history', model.decoder.id_to_word(token_history))
			next_token, done_vec, state = self.decoder.generate_next_token(context, next_token, done_vec, state, temperature=0)
			#tokens = tokens.write(i + 1, next_token)
			tokens.append(next_token)

			if tf.executing_eagerly() and tf.reduce_all(done_vec):
				break
		
		tokens = tf.concat(tokens, -1)
		return tokens

	def fix_punctuation(self, text: list[str]):
		inputs = self.encoder.prepare_for_input(text)
		tokens = self.fix_punctuation_raw(inputs)
		return self.decoder.tokens_to_text(tokens)

In [25]:
model = Punctuator(ENCODER_UNITS, context_text_processor, target_text_processor)

for (example_context_tok, example_target_hist), _ in dataset_validate.take(1):
	test_logits = model((example_context_tok, example_target_hist))
	print('Context tokens shape (batch, s):', example_context_tok.shape)
	print('Target history tokens shape (batch, t):', example_target_hist.shape)
	print('Logits shape (batch, t, vocab_size)', test_logits.shape)



Context tokens shape (batch, s): (16, 136)
Target history tokens shape (batch, t): (16, 151)
Logits shape (batch, t, vocab_size) (16, 151, 3890)


In [26]:
model.summary()

To avoid penalizing masked outputs, we use a custom loss function (see the tutorial):

In [27]:
base_loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')

def masked_loss(y_true, y_predict):
	loss = base_loss_fn(y_true, y_predict)
	
	unmasked = y_true != 0
	unmasked = tf.cast(unmasked, loss.dtype)
	# Only consider output with a corresponding label.
	loss *= unmasked

	count_unmasked = tf.math.reduce_sum(unmasked)

	# reduce_sum: Adds all entries of a vector.
	return tf.math.reduce_sum(loss)/count_unmasked
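To make the masking concrete, here is a small worked example in plain Python. It is a simplified sketch: it takes per-token probabilities rather than logits, and `masked_mean_loss` and the numbers are made up for illustration.

```python
import math

def masked_mean_loss(labels, probabilities):
    """Average cross-entropy over positions whose label isn't the padding ID (0)."""
    total, count = 0.0, 0
    for label, probs in zip(labels, probabilities):
        if label == 0:  # Padding: contributes nothing to the loss
            continue
        total += -math.log(probs[label])
        count += 1
    return total / count

labels = [2, 1, 0, 0]  # Two real tokens followed by two padding positions
probabilities = [
    [0.1, 0.2, 0.7],
    [0.1, 0.8, 0.1],
    [0.9, 0.05, 0.05],  # Predictions at padding positions are ignored
    [0.9, 0.05, 0.05],
]
print(masked_mean_loss(labels, probabilities))
```

Dividing by the number of unmasked positions, rather than by the sequence length, keeps heavily padded batches from looking artificially good. Like the version above, this sketch would divide by zero if every label were padding.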

In [28]:
def masked_accuracy(y_true, predict_logits):
	predicted_index = tf.math.argmax(predict_logits, axis=-1)
	predicted_index = tf.cast(predicted_index, y_true.dtype)

	match = tf.cast(y_true == predicted_index, tf.float32)
	unmasked = tf.cast(y_true != 0, tf.float32)
	count_unmasked = tf.math.reduce_sum(unmasked)

	return tf.math.reduce_sum(match * unmasked) / count_unmasked


In [29]:
#model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=[masked_accuracy, masked_loss])
model.compile(optimizer='adam', loss=masked_loss, metrics=[masked_accuracy, masked_loss])

In [30]:
print('From the tutorial:')
vocab_size = float(target_text_processor.vocabulary_size())

print('expected loss', tf.math.log(vocab_size).numpy())
print('expected accuracy', 1/vocab_size)

From the tutorial:
expected loss 8.266165
expected accuracy 0.0002570694087403599
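These baselines assume an untrained model that spreads its probability evenly over the vocabulary. Every token then gets probability 1/V, so the cross-entropy is -ln(1/V) = ln(V), and a guess is right 1/V of the time. A quick check in plain Python, assuming the vocabulary size of 3890 seen in the logits shape above:

```python
import math

vocab_size = 3890  # Assumed: matches the last dimension of the logits shape printed earlier

uniform_probability = 1 / vocab_size
expected_loss = -math.log(uniform_probability)

print('expected loss', expected_loss)
print('expected accuracy', uniform_probability)
```

An initial loss close to this value suggests the model starts out roughly uniform, as expected before any training.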


In [31]:
model.evaluate(dataset_validate, steps=20, return_dict=True)


2024-09-10 18:49:47.814254: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]

20/20 ━━━━━━━━━━━━━━━━━━━━ 9s 7ms/step - loss: 8.2718 - masked_accuracy: 0.0000e+00 - masked_loss: 5.7902

{'loss': 8.27199935913086,
 'masked_accuracy': 0.0,
 'masked_loss': 5.5146660804748535}

In [32]:


def test_punctuation(text):
	return '[test]: ' + model.fix_punctuation(text).numpy().decode('utf-8')

class DemoCallback(tf.keras.callbacks.Callback):
	def on_epoch_end(self, epoch_index: int, logs = None):
		print('\r', test_punctuation([ 'im testing this models performance how well is it working' ]))
		print(test_punctuation([ 'i think its working well but its really hard to tell why are the question marks missing' ]))
		if epoch_index % 3 == 0:
			# From the test data
			print(test_punctuation([
				'not that alice had any idea of doing that she felt as if she would never be able to talk again she was getting so much out of breath and still the queen cried faster faster and dragged her along'
			]))
			print(test_punctuation([ 'tensorflow is a library that is used for machine learning it is available for more languages than just python' ]))
			print(test_punctuation([ 'the joplin note taking app can be used to take multimedia notes' ]))
			print(test_punctuation([ 'here are a few words javascript typescript python joplin interesting loud and sequence these words are all very useful' ]))

test_punctuation(tf.constant([ 'this is an example they said' ]))

'[test]: slow standard merely ##back military cousin reply style charles relief colonists importance piece ##pted suicide per roman ##zer interview fuel ##ca originally ##ground empty wind won villefort seasons boat wine ##owed ##istic ##ap oh atmosphere ##ization abbé 2010 descended nor my probably ##tered throne collection woods paper lost slavery star ##father ##ey inhabitants ##ak board few'

In [55]:
history = model.fit(
	dataset_train,
	epochs = 30,
	steps_per_epoch = 1200,
	validation_data = dataset_validate,
	callbacks=[DemoCallback()]
)

Epoch 1/30
[test]: [cap] i ' m test ##ing this models performance ? [cap] how well is it working ?
.1748 - masked_accuracy: 0.9517 - masked_loss: 0.1748
[test]: [cap] i think it ' s working well , but it ' s really hard to tell why are the question mark ' s miss ##ing .
[test]: [cap] not that [cap] alice had any idea of doing that she felt as if she would never be able to talk again she was getting so much out of breath and still , the [cap] queen cried faster faster and dr ##ag ##ged her along .
[test]: [cap] ten ##s ##or ##f ##lo ##w is a library that is used for machine learn ##ing it is available for more languages than just p ##y ##th ##on .
[test]: [cap] the [cap] j ##op ##lin note taking a ##p ##p can be used to take m ##ult ##imed ##ia notes .
[test]: [cap] here are a few words , [cap] j ##a ##va ##s ##cript types ##cript , [cap] p ##y ##th ##on , [cap] j ##op ##lin interest ##ing loud and sequence these words are all very useful .
1200/1200 ━━━━━━━━━━━━━━━━━━━━

KeyboardInterrupt: 

In [56]:

print(test_punctuation([
	'not that alice had any idea of doing that she felt as if she would never be able to talk again she was getting so much out of breath and still the queen cried faster faster and dragged her along'
]))
print(test_punctuation([ 'this is a test of the punctuation system for i am curious how well it works will it work' ]))

[test]: [cap] not that [cap] alice had any idea of doing that she felt as if she would never be able to talk again , she was getting so much out of breath , and still the [cap] queen cried , faster , faster and dr ##ag ##ged her along .
[test]: [cap] this is a test of the p ##un ##ct ##uation system , for [cap] i am c ##urious how well it works will it work .
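The raw output is easier to judge once the tokens are joined back into text. The helper below is a hypothetical sketch (`detokenize` is not part of the notebook's model) and its rules are inferred only from the samples above: `##` pieces attach to the previous token, `[cap]` capitalizes the next token, and punctuation attaches to the preceding word.

```python
PUNCTUATION = {'.', ',', '?', '!', ';', ':'}

def detokenize(tokens_text: str) -> str:
    """Join space-separated output tokens into readable text (hypothetical helper)."""
    words = []
    capitalize_next = False
    join_next = False  # True right after an apostrophe, e.g. "i ' m" -> "i'm"
    for token in tokens_text.split():
        if token == '[cap]':
            capitalize_next = True
            continue
        if token.startswith('##'):
            piece, attach = token[2:], True
        elif token in PUNCTUATION or token == "'":
            piece, attach = token, True
        else:
            piece, attach = token, join_next
        if capitalize_next:
            piece = piece[:1].upper() + piece[1:]
            capitalize_next = False
        if attach and words:
            words[-1] += piece
        else:
            words.append(piece)
        join_next = token == "'"
    return ' '.join(words)

sample = "[cap] i ' m test ##ing this models performance ? [cap] how well is it working ?"
readable = detokenize(sample)
print(readable)
```

This only covers the `##`-style output shown here; a different tokenizer would need different joining rules.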


## Exporting

Based on the [Export](https://www.tensorflow.org/text/tutorials/nmt_with_attention#export) section of the tutorial:

In [57]:
class Export(tf.Module):
	def __init__(self, model):
		self.model = model
	
	@tf.function(input_signature=[tf.RaggedTensorSpec(dtype=tf.int64, shape=[None])])
	def fix_punctuation(self, input):
		# Reshape to a batch of one and return the encoded (not yet detokenized) tokens.
		return self.model.fix_punctuation_raw(
			tf.reshape(input, [1, -1])
		)

Run `fix_punctuation` once to compile it:

In [58]:
export = Export(model)

In [61]:
sample_inputs = context_text_processor('this sentence shall be punctuated for the following reasons first punctatuion makes things easier to read second um')
model.decoder.tokens_to_text(export.fix_punctuation(sample_inputs))

TypeError: Binding inputs to tf.function failed due to `The two structures don't have the same nested structure.

First structure: type=list str=[TensorSpec(shape=(None,), dtype=tf.int64, name=None)]

Second structure: type=list str=[<tf.Tensor: shape=(32,), dtype=int64, numpy=
array([   3,  122,  826,  969,  285,  116,   50,  984,  426, 1432,  104,
         89,  413, 3044,  161,   50,  984,  426,  396,  276,  214, 1429,
        689,   39,  469,  758,   94,  836,  342,   55,  194,    4])>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([ 0, 32])>]

More specifically: The two structures don't have the same number of elements. First structure: type=list str=[TensorSpec(shape=(None,), dtype=tf.int64, name=None)]. Second structure: type=list str=[<tf.Tensor: shape=(32,), dtype=int64, numpy=
array([   3,  122,  826,  969,  285,  116,   50,  984,  426, 1432,  104,
         89,  413, 3044,  161,   50,  984,  426,  396,  276,  214, 1429,
        689,   39,  469,  758,   94,  836,  342,   55,  194,    4])>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([ 0, 32])>]
Entire first structure:
[.]
Entire second structure:
[., .]`. Received args: (<tf.RaggedTensor [[3, 122, 826, 969, 285, 116, 50, 984, 426, 1432, 104, 89, 413, 3044, 161,
  50, 984, 426, 396, 276, 214, 1429, 689, 39, 469, 758, 94, 836, 342, 55,
  194, 4]]>,) and kwargs: {} for signature: (input: RaggedTensorSpec(TensorShape([None]), tf.int64, 0, tf.int64)).
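One possible reading of this error: the signature describes a 1-D value (`shape=[None]`, which gives a ragged rank of 0), while the argument is a batched `RaggedTensor` of shape `[1, None]` with ragged rank 1. The sketch below reproduces only that mismatch, using a short stand-in for the token ids; it does not call the model, and the two alternatives shown are assumptions about how the call could be made to match, not a confirmed fix.

```python
import tensorflow as tf

# Stand-in for the processor output: one row of token ids as a batched RaggedTensor.
batched = tf.ragged.constant([[3, 122, 826, 4]], dtype=tf.int64)

# The spec used in the signature above: 1-D, so its ragged rank defaults to 0.
spec_1d = tf.RaggedTensorSpec(dtype=tf.int64, shape=[None])
print(spec_1d.is_compatible_with(batched))

# Alternative 1 (assumption): describe the batched ragged input explicitly.
spec_batched = tf.RaggedTensorSpec(dtype=tf.int64, shape=[1, None], ragged_rank=1)
print(spec_batched.is_compatible_with(batched))

# Alternative 2 (assumption): pass a single dense row against a plain 1-D TensorSpec.
dense_spec = tf.TensorSpec(dtype=tf.int64, shape=[None])
row = batched[0]
print(dense_spec.is_compatible_with(row))
```

Either the signature or the argument would need to change so the two agree; which one is appropriate depends on what `context_text_processor` returns in the current session.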

Now we save the model:

In [35]:
tf.saved_model.save(export, 'punctuator-seq2seq', signatures={ 'serving_default': export.fix_punctuation })



INFO:tensorflow:Assets written to: punctuator-seq2seq/assets




See [the documentation](https://www.tensorflow.org/guide/saved_model#specifying_signatures_during_export) for information about the `signatures` option.

In [36]:
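import json

web_output_dir = Path('web')
# Create the output directory if it does not exist yet.
web_output_dir.mkdir(parents=True, exist_ok=True)
vocab_output_file = web_output_dir / 'wordEncodings.ts'

# Write the vocabulary as a TypeScript module; `{}` is replaced with the JSON-encoded word list.
vocab_output_file.write_text('''
// Auto-generated file!
// Created by v2-seq2seq.ipynb
export default {};
'''.format(json.dumps(target_text_processor.get_vocabulary(), indent = '\t')), encoding='utf-8')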

43430
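The number above is the return value of `Path.write_text`, which is the count of characters written. To check that the generated module still contains a valid word list, the JSON payload can be read back out. The following is a minimal sketch that uses a small hypothetical vocabulary in place of `target_text_processor.get_vocabulary()` and writes to a temporary directory rather than `web/`:

```python
import json
import tempfile
from pathlib import Path

# Hypothetical stand-in for target_text_processor.get_vocabulary().
vocabulary = ['', '[UNK]', '[START]', '[END]', '[cap]', 'abbé', '##ing']

template = '''
// Auto-generated file!
export default {};
'''

with tempfile.TemporaryDirectory() as tmp:
    out_file = Path(tmp) / 'web' / 'wordEncodings.ts'
    out_file.parent.mkdir(parents=True, exist_ok=True)
    chars_written = out_file.write_text(
        template.format(json.dumps(vocabulary, indent='\t')),
        encoding='utf-8',
    )
    written = out_file.read_text(encoding='utf-8')

# Strip the TypeScript wrapper and parse what is left as JSON.
payload = written.split('export default ', 1)[1].rstrip().rstrip(';')
restored = json.loads(payload)
print(restored == vocabulary)
```

By default `json.dumps` escapes non-ASCII characters (for example the `é` in `abbé`), so the generated file stays plain ASCII and the round trip still restores the original strings.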

### Testing the saved model

In [37]:
reloaded = tf.saved_model.load('punctuator-seq2seq')
# Warmup
reloaded.fix_punctuation(sample_inputs)
print('Imported and warmed up!')

Imported and warmed up!


In [38]:
%%time
model.decoder.tokens_to_text(reloaded.fix_punctuation(sample_inputs))


CPU times: user 152 ms, sys: 64.7 ms, total: 216 ms
Wall time: 65.7 ms


<tf.Tensor: shape=(), dtype=string, numpy=b'[cap] thi [s] sentence shall be [UNK] [ed] for the follow [ing] reason [s] . [cap] first [UNK] make [s] thing [s] easi [er] to read second [UNK] .'>

In [39]:
context_text_processor('this sentence shall be punctuated for the following reasons first punctatuion makes things easier to read second um')

<tf.Tensor: shape=(27,), dtype=int64, numpy=
array([   8,   45,    3, 2553,  193,   40,    1,    7,   28,    5,  139,
         14,  556,    3,   78,    1,  146,    3,  253,    3, 1003,   12,
         13,  384,  213,    1,    9])>