# V2: SEQ2SEQ

This notebook follows [an online tutorial](https://www.tensorflow.org/text/tutorials/nmt_with_attention#create_a_tfdata_dataset).

Other resources that may be helpful in the future: https://arxiv.org/pdf/2111.10746, https://aclanthology.org/2021.findings-emnlp.393.pdf

In [1]:
import tensorflow as tf


**Notes**:
- This notebook follows [an online tutorial](https://www.tensorflow.org/text/tutorials/nmt_with_attention) (and [at least one other](https://www.tensorflow.org/text/tutorials/text_generation) of the Tensorflow tutorials).
- This [blog post](https://janakiev.com/blog/jupyter-virtual-envs/) was referenced to set up the virtual environment.

In [2]:
import numpy as np
from pathlib import Path

import tensorflow_datasets as tfds


In [3]:
dataset_raw_train = tfds.load('ai2_arc_with_ir', split='train', shuffle_files=True)
dataset_raw_test = tfds.load('ai2_arc_with_ir', split='test', shuffle_files=True)
dataset_raw = dataset_raw_train.concatenate(dataset_raw_test)
dataset_raw = dataset_raw.map(lambda x: x['paragraph'])
for i in dataset_raw.take(1):
	print(i)

tf.Tensor(b'It and the other noble gases - helium, neon, krypton, xenon, and radon - will react with other substances only under extreme conditions. The noble gases The noble, or inert, gases are helium, neon, argon, krypton, xenon and radon. The rare gases are helium, neon, argon, krypton or xenon. The noble "gases" are helium, neon, argon, krypton and xenon. The noble gases are helium, neon, argon, krypton, xenon and radon. These occur for the noble gases helium, neon, argon, krypton, xenon, radon and ununoctium. Because of their chemical inertness, the elements in this group are called the Nobel Gases : Helium Neon Argon Krypton Xenon Radon Group I Elements The elements in this group have one electron in their outer electronic shell. The inert, or noble, gases (helium, neon, argon, krypton, xenon, and radon) all have completely filled outer shells. The other noble gases, which together make about 1% of the Earth\'s atmosphere, are neon, argon, krypton, xenon and radon. Most of the n

2024-09-10 22:23:47.498980: W tensorflow/core/kernels/data/cache_dataset_ops.cc:913] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-09-10 22:23:47.501014: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [4]:
data_file_paths = list(Path('data/processed/en/').glob('*.txt'))
dataset_raw = dataset_raw.concatenate(
	tf.data.TextLineDataset(data_file_paths)
)

In [5]:
dataset_raw_mlqa = tfds.load('mlqa/en', split='test', shuffle_files=True) # "train" not available
dataset_raw_mlqa = dataset_raw_mlqa.map(lambda x: x['context'])
dataset_raw = dataset_raw_mlqa.concatenate(dataset_raw)

In [6]:

def remove_unsupported_characters(s):
	return tf.strings.regex_replace(s, r'[^A-Za-z0-9ùúûüÿàâæçéèêëïîôœÙÚÛÜŸÀÂÆÇÉÈÊËÏÎÔŒ \t\n.,?!\-\':;&]', '')
dataset_raw = dataset_raw.map(remove_unsupported_characters)

We've now loaded the `.txt` training data files using [`tf.data.TextLineDataset`](https://www.tensorflow.org/api_docs/python/tf/data/TextLineDataset). Each line in the source files is mapped to a new training example. 

Although some preprocessing has been done by `/data/process_data.py`, paragraphs aren't filtered out based on length/content. Let's do that now:

In [7]:

def break_long_sequences(paragraph):
	return tf.strings.regex_replace(paragraph, r'.{0,164}[.?]', r'\0 [SEP]')

def split_on_separators(text):
	return tf.data.Dataset.from_tensor_slices(tf.strings.split(text, '[SEP]'))

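def filter_paragraphs(context, target):
	# Keep only pairs where both strings are longer than 5 bytes.
	# tf.logical_and makes the element-wise "and" explicit instead of relying on AutoGraph.
	return tf.logical_and(tf.strings.length(context) > 5, tf.strings.length(target) > 5)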

punctuation_chars = r'\?!.,"\-\':;'
def add_context(target):
	context = tf.strings.regex_replace(target, r'[\-]{2,}', ' - ')
	context = tf.strings.regex_replace(context, r'[\-\']', '')
	context = tf.strings.regex_replace(context, '[{}]+'.format(punctuation_chars), ' ')
	context = tf.strings.strip(
		tf.strings.regex_replace(context, '[ ]+', ' ')
	)
	context = tf.strings.lower(context)
	return context, target

dataset_raw = dataset_raw.map(break_long_sequences).flat_map(split_on_separators).map(add_context).filter(filter_paragraphs)
for text, label in dataset_raw.take(6):
	print('item', text, label)

item tf.Tensor(b'after completing the journey around south america on 23 february 2006 queen mary 2 met her namesake the original rms queen mary which is permanently docked at long beach california', shape=(), dtype=string) tf.Tensor(b'After completing the journey around South America, on 23 February 2006, Queen Mary 2 met her namesake, the original RMS Queen Mary, which is permanently docked at Long Beach, California. ', shape=(), dtype=string)
item tf.Tensor(b'escorted by a flotilla of smaller ships the two queens exchanged a whistle salute which was heard throughout the city of long beach', shape=(), dtype=string) tf.Tensor(b' Escorted by a flotilla of smaller ships, the two Queens exchanged a whistle salute which was heard throughout the city of Long Beach. ', shape=(), dtype=string)
item tf.Tensor(b'queen mary 2 met the other serving cunard liners queen victoria and queen elizabeth 2 on 13 january 2008 near the statue of liberty in new york city harbour with a celebratory firework
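
To make the splitting and context-building steps easier to follow, here is a minimal plain-Python sketch of the same idea using the standard `re` module instead of `tf.strings`. The names `make_context` and `make_pairs`, and the `max_chars` parameter, are introduced only for illustration. Note also that `len` counts characters, whereas `tf.strings.length` counts bytes by default, so the filter is only an approximation of the one above.

```python
import re

# Same character set as punctuation_chars in the cell above.
PUNCTUATION_CHARS = r'\?!.,"\-\':;'

def break_long_sequences(paragraph: str, max_chars: int = 164) -> list[str]:
    # Greedily match up to max_chars characters followed by '.' or '?',
    # mark the end of each match with [SEP], then split on the marker.
    pattern = r'.{0,%d}[.?]' % max_chars
    marked = re.sub(pattern, lambda m: m.group(0) + ' [SEP]', paragraph)
    return marked.split('[SEP]')

def make_context(target: str) -> str:
    # Mirror add_context: drop punctuation, collapse spaces, lowercase.
    context = re.sub(r'[\-]{2,}', ' - ', target)
    context = re.sub(r"[\-']", '', context)
    context = re.sub('[{}]+'.format(PUNCTUATION_CHARS), ' ', context)
    context = re.sub('[ ]+', ' ', context).strip()
    return context.lower()

def make_pairs(paragraph: str, max_chars: int = 164) -> list[tuple[str, str]]:
    # Build (context, target) pairs and drop the very short ones.
    pairs = [(make_context(t), t) for t in break_long_sequences(paragraph, max_chars)]
    return [(c, t) for c, t in pairs if len(c) > 5 and len(t) > 5]

# A tiny max_chars is used here only so the chunking is visible on a short string.
for context, target in make_pairs('One. Two. Three four five six. End?', max_chars=12):
    print(repr(context), '<-', repr(target))
```

Because the match is greedy, several short sentences can end up in one chunk, and a sentence longer than the limit is not cut in the middle: the separator is only ever placed after a `.` or `?`.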



Now let's inspect the data:

In [8]:
for text, label in dataset_raw.take(4).as_numpy_iterator():
	print('item', text, label)

item b'after completing the journey around south america on 23 february 2006 queen mary 2 met her namesake the original rms queen mary which is permanently docked at long beach california' b'After completing the journey around South America, on 23 February 2006, Queen Mary 2 met her namesake, the original RMS Queen Mary, which is permanently docked at Long Beach, California. '
item b'escorted by a flotilla of smaller ships the two queens exchanged a whistle salute which was heard throughout the city of long beach' b' Escorted by a flotilla of smaller ships, the two Queens exchanged a whistle salute which was heard throughout the city of Long Beach. '
item b'queen mary 2 met the other serving cunard liners queen victoria and queen elizabeth 2 on 13 january 2008 near the statue of liberty in new york city harbour with a celebratory fireworks display queen elizabeth 2 and queen victoria made a tandem crossing of the atlantic for the meeting' b' Queen Mary 2 met the other serving Cunard li



### Batching

In [9]:
BUFFER_SIZE = 100_000
BATCH_SIZE = 32
dataset_train = dataset_raw.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

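# Inspired by https://stackoverflow.com/a/74609848.
# Note: take/skip operate on batches here, so validate_size = 2 reserves
# 2 batches (2 * BATCH_SIZE examples). Because shuffle() reshuffles on each
# iteration by default, these batches are not a fixed hold-out set.
validate_size = 2
dataset_validate = dataset_train.take(validate_size)
dataset_train = dataset_train.skip(validate_size)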

### Preparing to process data

The [`TextVectorization`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/TextVectorization) layer takes a `standardize` option that preprocesses input data. The default (`'lower_and_strip_punctuation'`) removes punctuation, which is exactly what this model needs to predict, so let's define our own:

In [10]:

def standardize_tf_text(text):
	punctuation_regex = '[{p}]'.format(p = punctuation_chars)

	# Clean up &quot artifacts
	text = tf.strings.regex_replace(text, r'&quot;?', r' " ')

	# Surround punctuation with spaces for easier tokenization
	text = tf.strings.regex_replace(text, punctuation_regex, r' \0 ')

	# Replace punctuation this task isn't concerned with
	text = tf.strings.regex_replace(text, r' [.?!] ', '')

	# Remove repeated spaces
	text = tf.strings.regex_replace(text, r'\s+', ' ')

	# Add a special "capitalize the next letter" token
	# Disabled for this part:
	# text = tf.strings.regex_replace(text, r'(\s|^)([A-Z])', r' [CAP] \2')

	# Lowercase everything
	text = tf.strings.lower(text)

	# English-specific: split common suffixes (-ing, -er, -ed, -ly, -s, ...) into separate tokens
	text = tf.strings.regex_replace(text, r'([a-z]{3,})(ing|er|ed|ily|ly|ish|s)(\s|$)', r'\1 [\2]\3')

	# Remove leading and trailing spaces
	text = tf.strings.strip(text)

	# Add sequence markings
	return tf.strings.join(['[START]', text, '[END]'], separator=' ')

print(standardize_tf_text('This is a test! It\'s working?!'))

tf.Tensor(b"[START] thi [s] is a test , it ' s work [ing] , , [END]", shape=(), dtype=string)
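
The suffix-splitting step is the least obvious part of the standardization, so here is a small plain-Python sketch of just that rule, using the same regular expression with the standard `re` module. The helper name `split_suffixes` is made up for this example.

```python
import re

# Same pattern as in standardize_tf_text: a stem of at least three letters,
# one of the listed suffixes, then whitespace or the end of the string.
SUFFIX_PATTERN = re.compile(r'([a-z]{3,})(ing|er|ed|ily|ly|ish|s)(\s|$)')

def split_suffixes(text: str) -> str:
    # Move the suffix into its own bracketed token, e.g. "works" -> "work [s]".
    return SUFFIX_PATTERN.sub(r'\1 [\2]\3', text.lower())

print(split_suffixes('this is processing'))  # thi [s] is process [ing]
print(split_suffixes('water'))               # wat [er]
```

The rule is purely textual, so it also splits words where the ending is not really a suffix (`this`, `water`). That is acceptable here because the split is applied identically to context and target, and it keeps the vocabulary small.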


The text standardization function can now be used to preprocess text:

In [11]:
# Keep only the 4000 most commonly used tokens
max_vocab_size = 4000

target_text_processor = tf.keras.layers.TextVectorization(
	standardize=standardize_tf_text,
	max_tokens=max_vocab_size,
	# Allow entries of different lengths
	ragged=True,
)
target_text_processor.adapt(dataset_train.map(lambda context, target: target))

print('First 14 target words:', target_text_processor.get_vocabulary()[:14])

# The target text is roughly the context text plus additional (punctuation) tokens,
# so the same vocabulary is shared between the two.
context_text_processor = target_text_processor

2024-09-10 22:23:58.837566: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:450] ShuffleDatasetV3:63: Filling up shuffle buffer (this may take a while): 53084 of 100000
2024-09-10 22:24:07.591801: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:480] Shuffle buffer filled.


First 14 target words: ['', '[UNK]', ',', '[s]', 'the', '[ed]', '[START]', '[END]', 'of', 'and', '[er]', 'to', '[ing]', 'in']


2024-09-10 22:25:04.800246: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


We can use these layers to convert to/from token IDs:


In [12]:
example_text = 'hello world this is a test tensorflow is processing this'
example_tokens = context_text_processor(example_text)
print('Example tokens', example_tokens)

context_vocab = np.array(context_text_processor.get_vocabulary())
tokens = context_vocab[example_tokens.numpy()]
print('Back to text', ' '.join(tokens))

Example tokens tf.Tensor(
[   6    1  125   43    3   20   14  546    1   20 3835   12   43    3
    7], shape=(15,), dtype=int64)
Back to text [START] [UNK] world thi [s] is a test [UNK] is process [ing] thi [s] [END]
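
As a rough mental model of what the vectorization layer does with its vocabulary, the sketch below builds a frequency-ranked vocabulary and maps tokens to and from IDs in plain Python. It assumes, as in the vocabulary printed above, that index 0 is the padding token `''` and index 1 is `'[UNK]'`. The function names are hypothetical, and the real layer's tie-breaking between equally frequent tokens may differ.

```python
from collections import Counter

def build_vocabulary(token_lists, max_tokens):
    # Reserve index 0 for padding and index 1 for out-of-vocabulary tokens,
    # then keep the most frequent tokens that still fit in max_tokens.
    counts = Counter(token for tokens in token_lists for token in tokens)
    most_common = [token for token, _ in counts.most_common(max_tokens - 2)]
    return ['', '[UNK]'] + most_common

def encode(tokens, vocabulary):
    index = {token: i for i, token in enumerate(vocabulary)}
    # Unknown tokens map to 1, the [UNK] ID.
    return [index.get(token, 1) for token in tokens]

def decode(ids, vocabulary):
    return [vocabulary[i] for i in ids]

corpus = [['the', 'cat', 'sat'], ['the', 'dog', 'sat'], ['the', 'end']]
vocab = build_vocabulary(corpus, max_tokens=4)
ids = encode(['the', 'cat', 'sat'], vocab)
print(vocab)               # ['', '[UNK]', 'the', 'sat']
print(ids)                 # [2, 1, 3]
print(decode(ids, vocab))  # ['the', '[UNK]', 'sat']
```

This is why rare words such as `hello` and `tensorflow` come back as `[UNK]` in the output above: they did not make it into the 4000 most common tokens.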


### Processing the data

Now, we'll:
1. Map the data through the text processors we just made.
2. Shift the target data, so that our network is provided with a history of generated tokens.
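
Step 2 is easiest to see on a single unbatched sequence. The sketch below uses a plain Python list in place of a tensor. The token values are made up, and the function name is reused only to show the correspondence with the notebook code.

```python
def add_target_history(target_tokens):
    # The decoder input is everything except the last token and the
    # expected output is everything except the first token.
    target_in = target_tokens[:-1]
    target_out = target_tokens[1:]
    return target_in, target_out

tokens = ['[START]', 'hello', ',', 'world', '[END]']
target_in, target_out = add_target_history(tokens)

# At each position the network sees the history so far and must predict the next token.
for history_length, expected in enumerate(target_out, start=1):
    print(target_in[:history_length], '->', expected)
```

Both lists have the same length, and position `i` of `target_out` is always the token that follows position `i` of `target_in`.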

In [13]:
def process_text(context, target):
	return context_text_processor(context), target_text_processor(target)

print('pre', context_text_processor('this is a test'))
print(dataset_train.map(process_text))
print('post')

def add_target_history(context, target):
	# .to_tensor(): Converts from RaggedTensors to Tensors.
	# We give our network the history as target_in
	target_in = target[:, :-1].to_tensor()
	target_out = target[:, 1:].to_tensor()
	return (context.to_tensor(), target_in), target_out
dataset_train = dataset_train.map(process_text).map(add_target_history).repeat()
dataset_validate = dataset_validate.map(process_text).map(add_target_history)

pre tf.Tensor([  6  43   3  20  14 546   7], shape=(7,), dtype=int64)
<_MapDataset element_spec=(RaggedTensorSpec(TensorShape([None, None]), tf.int64, 1, tf.int64), RaggedTensorSpec(TensorShape([None, None]), tf.int64, 1, tf.int64))>
post


In [14]:

def inspect_dataset(dataset: tf.data.Dataset):
	target_vocab = np.array(target_text_processor.get_vocabulary())
	for (context, target_in), target_out in dataset.take(1):
		context_words = context_vocab[context[0]]
		print('context', ','.join(context_words))
		print('target_in', ','.join(target_vocab[target_in[0]]))
		print('target_out', ','.join(target_vocab[target_out[0]]))

inspect_dataset(dataset_train)

2024-09-10 22:25:15.873282: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:450] ShuffleDatasetV3:63: Filling up shuffle buffer (this may take a while): 52483 of 100000


context [START],in,ancient,time,[s],greek,[s],and,persian,[s],had,been,at,war,for,centurie,[s],and,the,persian,priest,[s],call,[ed],[UNK],in,persian,came,to,be,known,as,[UNK],in,greek,[END],,,,,,,,,,,,,,,,,,,,,,
target_in [START],in,ancient,time,[s],,,greek,[s],and,persian,[s],had,been,at,war,for,centurie,[s],,,and,the,persian,priest,[s],,,call,[ed],[UNK],in,persian,,,came,to,be,known,as,[UNK],in,greek,,,,,,,,,,,,,,,,,,,,,,,,,,
target_out in,ancient,time,[s],,,greek,[s],and,persian,[s],had,been,at,war,for,centurie,[s],,,and,the,persian,priest,[s],,,call,[ed],[UNK],in,persian,,,came,to,be,known,as,[UNK],in,greek,,,[END],,,,,,,,,,,,,,,,,,,,,,,,


2024-09-10 22:25:20.275417: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:480] Shuffle buffer filled.


## Model

### The encoder

See https://www.tensorflow.org/text/tutorials/nmt_with_attention#the_encoder

In [15]:
class Encoder(tf.keras.Layer):
	def __init__(self, text_processor, units: int):
		"""
		Creates a new Encoder layer. `units` is the size of the embedding vectors and the
		number of units in each GRU (and so the size of the last axis of the output).
		"""
		super(Encoder, self).__init__()
		self.text_processor = text_processor
		self.vocab_size = text_processor.vocabulary_size()
		self.units = units

		# Converts tokens -> vectors
		self.embedding = tf.keras.layers.Embedding(
			# mask_zero: Treats zero as a padding value that should be ignored
			self.vocab_size, units, mask_zero = True,
		)
		gru = tf.keras.layers.GRU(
			units, return_sequences = True,
			# Use the recurrent_initializer suggested by the tutorial (& the default
			# for kernel_initializer).
			recurrent_initializer='glorot_uniform'
		)
		self.rnn = tf.keras.layers.Bidirectional(
			# merge_mode determines how the forward and backward layers are combined
			#            'concat' is another option here
			merge_mode = 'sum',
			layer=gru,
		)

	def call(self, x):
		x = self.embedding(x)
		x = self.rnn(x)
		return x

	def prepare_for_input(self, texts):
		"""
		Utility method that converts `texts` to a form that can be provided to the `call` method.
		"""
		texts = tf.convert_to_tensor(texts)
		if len(texts.shape) == 0:
			texts = texts[None]
		context = self.text_processor(texts).to_tensor()
		return context


Try it:

In [16]:
ENCODER_UNITS = 256
encoder = Encoder(context_text_processor, ENCODER_UNITS)

for (context, target_history), target_next in dataset_validate.take(1):
	encoder_result = encoder(context)
	print('Context tokens shape (batch, s):', context.shape)
	print('Encoder output shape (batch, s, ENCODER_UNITS):', encoder_result.shape)

Context tokens shape (batch, s): (16, 57)
Encoder output shape (batch, s, ENCODER_UNITS): (16, 57, 256)


### The attention layer

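Attention can be thought of as a trainable, "soft" lookup table. Each `query` vector is compared against a set of keys, and the result is a weighted average of the corresponding `value` vectors rather than a single entry. The layer below only passes `query` and `value`; when no `key` is given, `MultiHeadAttention` uses `value` as the key too.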
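
To make the lookup-table picture concrete, here is a minimal sketch of scaled dot-product attention for a single query in plain Python. It deliberately leaves out the learned projections that `MultiHeadAttention` applies to the query, keys and values, and the name `soft_lookup` is introduced only for this example.

```python
import math

def soft_lookup(query, keys, values):
    # Score each key by its dot product with the query, scaled by sqrt(dimension).
    scale = math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]

    # Softmax turns the scores into weights that sum to 1.
    max_score = max(scores)
    exps = [math.exp(score - max_score) for score in scores]
    total = sum(exps)
    weights = [e / total for e in exps]

    # The result is the weighted average of the value vectors.
    size = len(values[0])
    output = [sum(w * value[i] for w, value in zip(weights, values)) for i in range(size)]
    return output, weights

keys = [[10.0, 0.0], [0.0, 10.0]]
values = [[1.0, 2.0], [3.0, 4.0]]
output, weights = soft_lookup([10.0, 0.0], keys, values)
print(weights)  # almost all of the weight falls on the first key
print(output)   # so the output is very close to the first value
```

A query that matches one key strongly behaves like an ordinary dictionary lookup, while a query that is similar to several keys returns a blend of their values.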

In [17]:
class CrossAttention(tf.keras.Layer):
	def __init__(self, units, **kwargs):
		super().__init__()
		self.attention_layer = tf.keras.layers.MultiHeadAttention(
			key_dim=units,
			num_heads=1,
			**kwargs
		)
		# Keeps "the mean activation within each example close to 0 and the
		# activation standard deviation close to 1" -- https://www.tensorflow.org/api_docs/python/tf/keras/layers/LayerNormalization?hl=en
		self.norm_layer = tf.keras.layers.LayerNormalization()
		self.add_layer = tf.keras.layers.Add()
		self.supports_masking = True

	def call(self, query, value):
		attention_output = self.attention_layer(
			query = query,
			value = value,
			#use_causal_mask=True,
			# Return the attention scores for later plotting
			# return_attention_scores = True,
		)

		x = self.add_layer([ query, attention_output ])
		x = self.norm_layer(x)
		return x


In [18]:
attention_layer = CrossAttention(ENCODER_UNITS)

# Test with an example
for (context, target_history), target_next in dataset_validate.take(1):
	embed_layer = tf.keras.layers.Embedding(target_text_processor.vocabulary_size(), output_dim=ENCODER_UNITS, mask_zero=True)
	target_embed = embed_layer(target_history)
	encoded_context = encoder(context)
	attention_result = attention_layer(target_embed, encoded_context)

	print('Encoded context sequence shape (batch, s, units):', encoded_context.shape)
	print('Target history sequence shape (batch, t, units):', target_embed.shape)
	print('Attention result shape (batch, t, units):', attention_result.shape)

	# Used later 
	test_encoded_context = encoded_context

Encoded context sequence shape (batch, s, units): (16, 47, 256)
Target history sequence shape (batch, t, units): (16, 52, 256)
Attention result shape (batch, t, units): (16, 52, 256)




### The decoder

The decoder produces queries for the attention layer. The decoder operates on `target_history`. At each step during training, it should have no information about future target output (that's what we're trying to determine). As such, we use a unidirectional RNN.


In [19]:
class CustomDense(tf.keras.layers.Dense):
	def __init__(self, *args, **kwargs):
		super(CustomDense, self).__init__(*args, **kwargs)
	
	def compute_mask(self, _inputs, mask=None):
		return mask

class Decoder(tf.keras.Layer):
	def __init__(self, text_processor, units):
		super(Decoder, self).__init__()
		self.text_processor = text_processor
		self.vocab_size = text_processor.vocabulary_size()
		self.units = units

		self.embedding_layer = tf.keras.layers.Embedding(
			# mask_zero: Treats zero as a padding value that should be ignored
			self.vocab_size, units, mask_zero = True,
		)
		self.rnn_layer = tf.keras.layers.GRU(
			units, return_sequences = True, return_state = True, recurrent_initializer='glorot_uniform',
		)
		self.attention_layer = CrossAttention(units)

		# Creates logits with the estimated probability of each output token
		self.output_layer = CustomDense(self.vocab_size)

		# Conversion:
		self.word_to_id = tf.keras.layers.StringLookup(
			vocabulary = text_processor.get_vocabulary(),
			mask_token = '',
			oov_token = '[UNK]',
		)
		self.id_to_word = tf.keras.layers.StringLookup(
			vocabulary = text_processor.get_vocabulary(),
			mask_token = '',
			oov_token = '[UNK]',
			invert = True,
		)
		# Pre-computing these simplifies exporting
		self.start_id = self.word_to_id('[START]')
		self.end_id = self.word_to_id('[END]')

		self.supports_masking = True
	
	def build(self, input_shape):
		# Nothing that needs a size allocation based on the input shape
		pass
	
	def call(self, context, target_history, state = None, return_state = False):
		x = self.embedding_layer(target_history)
		x, state = self.rnn_layer(x, initial_state = state)
		x = self.attention_layer(x, context)

		logits = self.output_layer(x)
		if return_state:
			return logits, state
		else:
			return logits
	
	## Conversion/testing ##

	def tokens_to_text(self, tokens):
		text = tf.strings.reduce_join(self.id_to_word(tokens), separator = ' ')
		text = tf.strings.regex_replace(text, r'\s*\[START\]\s*', '')
		text = tf.strings.regex_replace(text, r'\s*\[END\].*$', '')
		return text

	def generate_next_token(self, context, target_history, done_vec, state, temperature = 0.0):
		# Note: done_vec is a boolean tensor with shape (batch, 1), indicating whether each
		# item in the batch is done

		logits, state = self(context, target_history, state = state, return_state = True)

		# logits has shape (batch, t, target_vocab_size). Only generate the token corresponding
		# to the last logits in the sequence (at t - 1)
		last_logits = logits[:, -1, :]
		if temperature > 0:
			# Pick the token from a categorical distribution
			next_token = tf.random.categorical(last_logits / temperature, num_samples = 1)
		else:
			# Greedy decoding: pick the most likely token, keeping the shape (batch, 1)
			next_token = tf.math.argmax(last_logits, axis=-1)[:, tf.newaxis]

		# Emit 0 (padding) after a sequence is done, for both branches
		next_token = tf.where(done_vec, tf.constant(0, dtype=tf.int64), next_token)
		done_vec = done_vec|(next_token == self.end_id)
		return next_token, done_vec, state
	
	def get_initial_state(self, context):
		# context has shape (batch_size, s, units)
		batch_size = tf.shape(context)[0]
		start_tokens = tf.fill([batch_size, 1], self.start_id)
		done_vec = tf.zeros([batch_size, 1], dtype = tf.bool)

		# From the Tensorflow source code:
		# > RNN expect the states in a list, even if single state.
		# Note: Without the [0] we get a type mismatch while exporting.
		initial_state = self.rnn_layer.get_initial_state(batch_size)[0]

		return start_tokens, done_vec, initial_state

Let's try it!

In [20]:
def test_generation_loop():
	decoder = Decoder(target_text_processor, ENCODER_UNITS)
	next_token, done_vec, state = decoder.get_initial_state(test_encoded_context[:3, :, :])
	tokens = [next_token]

	for i in range(8):
		next_token, done_vec, state = decoder.generate_next_token(test_encoded_context[:3, :, :], next_token, done_vec, state)
		tokens.append(next_token)
	
	# Merge all batch outputs into a single dimension
	tokens = tf.concat(tokens, -1) # -1 = last axis

	print('Output:', decoder.tokens_to_text(tokens).numpy())

test_generation_loop()

Output: b'healthy finish fetch strong 400 37 sake slipphealthy finish fetch strong 400 37 sake slipphealthy finish fetch strong 400 37 sake slipp'
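
The `temperature` argument of `generate_next_token` switches between greedy decoding and sampling. The sketch below shows the same choice for one vector of logits in plain Python. The function name `sample_next_token` and its `rng` parameter are made up for this illustration and are not part of the notebook's model.

```python
import math
import random

def sample_next_token(logits, temperature=0.0, rng=random):
    if temperature <= 0:
        # Greedy decoding: always take the highest-scoring token.
        return max(range(len(logits)), key=lambda i: logits[i])

    # Dividing by the temperature sharpens (T < 1) or flattens (T > 1) the
    # distribution before sampling from softmax(logits / T).
    scaled = [value / temperature for value in logits]
    highest = max(scaled)
    weights = [math.exp(value - highest) for value in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]

logits = [0.1, 2.0, -1.0]
print(sample_next_token(logits))                   # 1
print(sample_next_token(logits, temperature=1.0))  # usually 1, sometimes 0 or 2
```

With an untrained decoder, as in the test above, the logits are close to random, which is why the generated text is a stream of unrelated words.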


## The model

We can now combine the encoder and decoder into a model that can be trained and then used to punctuate text:

In [21]:
class Punctuator(tf.keras.Model):
	def __init__(self, units, context_text_processor, target_text_processor):
		super().__init__()
		self.encoder = Encoder(context_text_processor, units)
		self.decoder = Decoder(target_text_processor, units)
	
	def call(self, inputs):
		context, target_history = inputs
		context = self.encoder(context)
		logits = self.decoder(context, target_history)
		return logits
	
	def fix_punctuation_raw(self, input):
		"""
		Adds punctuation to `input`, where `input` is a `Tensor` with shape (batch_size, s) where s is the
		context length.
		"""
		context = self.encoder(input)

		next_token, done_vec, state = self.decoder.get_initial_state(context)

		# Although a TensorArray would allow more efficient exporting, the ONNX exporter seems to
		# have trouble with it. For now, use a Python list.
		tokens = []
		max_iterations = 56

		for i in range(max_iterations):
			# token_history has size: (batch, t, target_vocab_size)
			# token_history = tf.concat(tokens, 1)
			# print('history', model.decoder.id_to_word(token_history))
			next_token, done_vec, state = self.decoder.generate_next_token(context, next_token, done_vec, state, temperature=0)
			#tokens = tokens.write(i + 1, next_token)
			tokens.append(next_token)

			if tf.executing_eagerly() and tf.reduce_all(done_vec):
				break
		
		tokens = tf.concat(tokens, -1)
		return tokens

	def fix_punctuation(self, text: list[str]):
		inputs = self.encoder.prepare_for_input(text)
		tokens = self.fix_punctuation_raw(inputs)
		return self.decoder.tokens_to_text(tokens)

In [22]:
model = Punctuator(ENCODER_UNITS, context_text_processor, target_text_processor)

for (example_context_tok, example_target_hist), _ in dataset_validate.take(1):
	test_logits = model((example_context_tok, example_target_hist))
	print('Context tokens shape (batch, s):', example_context_tok.shape)
	print('Target history tokens shape (batch, t):', example_target_hist.shape)
	print('Logits shape (batch, t, vocab_size)', test_logits.shape)

Context tokens shape (batch, s): (16, 83)
Target history tokens shape (batch, t): (16, 92)
Logits shape (batch, t, vocab_size) (16, 92, 4000)


2024-09-10 22:25:50.110662: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [23]:
model.summary()

To avoid penalizing masked outputs, we use a custom loss function (see the tutorial):

In [24]:
base_loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')

def masked_loss(y_true, y_predict):
	loss = base_loss_fn(y_true, y_predict)
	
	unmasked = y_true != 0
	unmasked = tf.cast(unmasked, loss.dtype)
	# Only consider output with a corresponding label.
	loss *= unmasked

	count_unmasked = tf.math.reduce_sum(unmasked)

	# reduce_sum: Adds all entries of a vector.
	return tf.math.reduce_sum(loss)/count_unmasked

In [25]:
def masked_accuracy(y_true, predict_logits):
	predicted_index = tf.math.argmax(predict_logits, axis=-1)
	predicted_index = tf.cast(predicted_index, y_true.dtype)

	match = tf.cast(y_true == predicted_index, tf.float32)
	unmasked = tf.cast(y_true != 0, tf.float32)
	count_unmasked = tf.math.reduce_sum(unmasked)

	return tf.math.reduce_sum(match * unmasked) / count_unmasked
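
To see why the mask matters, here is a plain-Python sketch of the same two computations for one sequence, with token ID 0 treated as padding. It uses lists in place of tensors. The function names are reused only to show the correspondence with the cells above.

```python
import math

def masked_cross_entropy(labels, logits_per_step):
    total, count = 0.0, 0
    for label, logits in zip(labels, logits_per_step):
        if label == 0:
            continue  # padded position: contributes nothing to the loss
        # Cross-entropy from logits: log(sum(exp(logits))) - logits[label]
        highest = max(logits)
        log_sum_exp = highest + math.log(sum(math.exp(x - highest) for x in logits))
        total += log_sum_exp - logits[label]
        count += 1
    return total / count

def masked_accuracy(labels, logits_per_step):
    kept = [(label, logits) for label, logits in zip(labels, logits_per_step) if label != 0]
    hits = sum(1 for label, logits in kept
               if max(range(len(logits)), key=logits.__getitem__) == label)
    return hits / len(kept)

labels = [1, 2, 0, 0]  # the last two positions are padding
logits = [[0.0, 5.0, 0.0], [0.0, 5.0, 0.0], [9.0, 0.0, 0.0], [0.0, 0.0, 9.0]]
print(masked_cross_entropy(labels, logits))
print(masked_accuracy(labels, logits))  # 0.5
```

Only the two real tokens count, so whatever the model outputs at padded positions changes neither number.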


In [26]:
#model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=[masked_accuracy, masked_loss])
model.compile(optimizer='adam', loss=masked_loss, metrics=[masked_accuracy, masked_loss])

In [27]:
print('From the tutorial:')
vocab_size = float(target_text_processor.vocabulary_size())

print('expected loss', tf.math.log(vocab_size).numpy())
print('expected accuracy', 1/vocab_size)

From the tutorial:
expected loss 8.294049
expected accuracy 0.00025
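
These baseline numbers follow from assuming that an untrained model spreads its probability roughly evenly over the vocabulary. The cross-entropy of a uniform guess over `V` tokens is `-log(1/V) = log(V)`, and a uniform guess is right once in `V` tries. A quick check with the standard library, using the vocabulary size of 4000 from above:

```python
import math

vocab_size = 4000

# Cross-entropy of a uniform distribution: -log(1 / V) = log(V)
expected_loss = math.log(vocab_size)
# Chance of guessing the right token uniformly at random
expected_accuracy = 1 / vocab_size

print(round(expected_loss, 4), expected_accuracy)  # 8.294 0.00025
```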


In [28]:
model.evaluate(dataset_validate, steps=20, return_dict=True)


20/20 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - loss: 8.3182 - masked_accuracy: 0.0013 - masked_loss: 5.8228




{'loss': 8.317972183227539,
 'masked_accuracy': 0.001192605821415782,
 'masked_loss': 5.545314788818359}

In [29]:


def test_punctuation(text):
	return '[test]: ' + model.fix_punctuation(text).numpy().decode('utf-8')

class DemoCallback(tf.keras.callbacks.Callback):
	def on_epoch_end(self, epoch_index: int, logs = None):
		print('\r', test_punctuation([ 'im testing this models performance how well is it working' ]))
		print(test_punctuation([ 'i think its working well but its really hard to tell why are the question marks missing' ]))
		if epoch_index % 3 == 0:
			# From the test data
			print(test_punctuation([
				'not that alice had any idea of doing that she felt as if she would never be able to talk again she was getting so much out of breath and still the queen cried faster faster and dragged her along'
			]))
			print(test_punctuation([ 'tensorflow is a library that is used for machine learning it is available for more languages than just python' ]))
			print(test_punctuation([ 'the joplin note taking app can be used to take multimedia notes' ]))
			print(test_punctuation([ 'here are a few words javascript typescript python joplin interesting loud and sequence these words are all very useful' ]))

test_punctuation(tf.constant([ 'this is an example they said' ]))

'[test]: suppose fate cristo velocity bless tie beside meet route hollow walt stair dar movement writer oh resemble windsor gate 1996 hundred anxiety survey 1980 x struck cricket septemb eight within coastal formation handkerchief decision globe mysteriou earn symbol establ examination 13 1948 test sharp try 1985 talk afraid attitude worse knot promote gratitude preserve catholic reaction'

In [30]:
history = model.fit(
	dataset_train,
	epochs = 30,
	steps_per_epoch = 1200,
	validation_data = dataset_validate,
	callbacks=[DemoCallback()]
)

Epoch 1/30


2024-09-10 22:26:16.769648: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:450] ShuffleDatasetV3:63: Filling up shuffle buffer (this may take a while): 97268 of 100000


   1/1200 ━━━━━━━━━━━━━━━━━━━━ 5:12:46 16s/step - loss: 8.3111 - masked_accuracy: 0.0000e+00 - masked_loss: 8.3111

2024-09-10 22:26:16.998084: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:480] Shuffle buffer filled.


 620/1200 ━━━━━━━━━━━━━━━━━━━━ 2:31 262ms/step - loss: 4.7711 - masked_accuracy: 0.2643 - masked_loss: 4.7711

KeyboardInterrupt: 

In [70]:

print(test_punctuation([
	'not that alice had any idea of doing that she felt as if she would never be able to talk again she was getting so much out of breath and still the queen cried faster faster and dragged her along'
]))
print(test_punctuation([ 'this is a test of the punctuation system for i am curious how well it works will it work' ]))

[test]: [cap] not that [cap] alice had any idea of doing that she felt as if she would nev [er] be able to talk again , she was gett [ing] so much out of breath , and still the queen ! cri [ed] fast [er] , him , and her [ed] her of her along .
[test]: [cap] thi [s] is a test of the [UNK] system for [cap] i am curiou [s] how well it work [s] will it work .


## Exporting

Based on the [Export](https://www.tensorflow.org/text/tutorials/nmt_with_attention#export) section of the tutorial:

In [67]:
class Export(tf.Module):
	def __init__(self, model):
		self.model = model
	
	@tf.function(input_signature=[tf.RaggedTensorSpec(dtype=tf.int64, shape=[None])])
	def fix_punctuation(self, input):
		# Returns encoded tokens
		return self.model.fix_punctuation_raw(
			tf.reshape(input, [1, -1])
		)

Run `fix_punctuation` once to compile it:

In [68]:
export = Export(model)

In [69]:
sample_inputs = context_text_processor('this sentence shall be punctuated for the following reasons first punctatuion makes things easier to read second um')
model.decoder.tokens_to_text(export.fix_punctuation(sample_inputs))

<tf.Tensor: shape=(), dtype=string, numpy=b'[cap] thi [s] sentence shall be [UNK] [ed] for the follow [ing] reason [s] , first [UNK] make [s] thing [s] easi [er] to read second [UNK] .'>

Now we save the model:

In [70]:
tf.saved_model.save(export, 'punctuator-seq2seq', signatures={ 'serving_default': export.fix_punctuation })



INFO:tensorflow:Assets written to: punctuator-seq2seq/assets




See [the documentation](https://www.tensorflow.org/guide/saved_model#specifying_signatures_during_export) for information about the `signatures` option.

In [71]:
import json

web_output_dir = Path('web')
vocab_output_file = web_output_dir / 'wordEncodings.ts'

vocab_output_file.write_text('''
// Auto-generated file!
// Created by v2-seq2seq.ipynb
export default {};
'''.format(json.dumps(target_text_processor.get_vocabulary(), indent = '\t')))

43714

### Testing the saved model

In [72]:
reloaded = tf.saved_model.load('punctuator-seq2seq')
# Warmup
reloaded.fix_punctuation(sample_inputs)
print('Imported and warmed up!')

Imported and warmed up!


In [73]:
%%time
model.decoder.tokens_to_text(reloaded.fix_punctuation(sample_inputs))


CPU times: user 135 ms, sys: 43.9 ms, total: 179 ms
Wall time: 93.1 ms


<tf.Tensor: shape=(), dtype=string, numpy=b'[cap] thi [s] sentence shall be [UNK] [ed] for the follow [ing] reason [s] , first [UNK] make [s] thing [s] easi [er] to read second [UNK] .'>

In [74]:
context_text_processor('this sentence shall be punctuated for the following reasons first punctatuion makes things easier to read second um')

<tf.Tensor: shape=(27,), dtype=int64, numpy=
array([   6,   46,    3, 2257,  173,   39,    1,    9,   31,    5,  159,
         14,  608,    3,  130,    1,  119,    3,  220,    3,  792,   12,
         13,  459,  313,    1,    7])>