In [1]:
import tensorflow as tf

2024-09-01 13:19:46.729584: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-09-01 13:19:46.732574: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-09-01 13:19:46.741594: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-01 13:19:46.756655: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-01 13:19:46.760760: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-09-01 13:19:46.771911: I tensorflow/core/platform/cpu_feature_gu

**Notes**:
- This notebook follows [an online tutorial](https://www.tensorflow.org/text/tutorials/nmt_with_attention) (and [at least one other](https://www.tensorflow.org/text/tutorials/text_generation) of the Tensorflow tutorials).
- This [blog post](https://janakiev.com/blog/jupyter-virtual-envs/) was referenced to set up the virtual environment.

In [2]:
import numpy as np
import typing
from typing import Any, Tuple
from prepare_data import load_data, reconstruct_from_labels, preprocess_text
from IPython.display import display, Markdown



In [3]:
context_raw, target_raw = load_data('./data/en/')

We store the **expected** output in `target_raw` and the input to our model in `context_raw`. Let's see an example:

In [4]:
target_raw

array(['1', '2', '1', ..., '1', '1', '.'], dtype='<U1')

Each element in `target_raw` is an operation (e.g. 0 = copy) followed by a character code. For example, `1` is "capitalize" and `0` is "copy". Note that `120` (`ord(x)`) is used for operations that take no arguments.

In [5]:
target_raw[24], context_raw[24]

('1', 'v')


## Creating a dataset

We begin by vectorizing our data. `target_raw` and `context_raw` are already tokenized by characters/operations.

We start by creating a vectorization for the `target_raw`.

In [6]:
input_vocab = sorted(set(context_raw))
print('Input vocab size in chars:', len(input_vocab))

chars_to_ids_in = tf.keras.layers.StringLookup(vocabulary=input_vocab)
# Invert: Map chars to IDs instead of IDs to chars
ids_to_chars_in = tf.keras.layers.StringLookup(vocabulary=chars_to_ids_in.get_vocabulary(), invert=True)

# in: "Input"
def text_from_ids_in(ids: list[int]):
	return tf.strings.reduce_join(ids_to_chars_in(ids), axis=-1)

Input vocab size in chars: 54


In [7]:
all_ids_input = chars_to_ids_in(context_raw)
all_ids_input

<tf.Tensor: shape=(7252805,), dtype=int64, numpy=array([40, 22, 25, ..., 16, 18,  1])>

Now, we do the same for the output.

In [8]:
output_vocab = sorted(set(target_raw))
print('Output vocab size in chars:', len(output_vocab))

chars_to_ids_out = tf.keras.layers.StringLookup(vocabulary=output_vocab)
# Invert: Map chars to IDs instead of IDs to chars
ids_to_chars_out = tf.keras.layers.StringLookup(vocabulary=chars_to_ids_out.get_vocabulary(), invert=True)


Output vocab size in chars: 8


In [9]:
sorted(output_vocab)

["'", ',', '-', '.', '1', '2', ':', '?']

In [10]:
all_ids_output = chars_to_ids_out(target_raw)
all_ids_output

<tf.Tensor: shape=(7252805,), dtype=int64, numpy=array([5, 6, 5, ..., 5, 5, 4])>

Now that we have vectorized inputs and outputs, let's create a `Dataset` we can feed to the model.

First, combine the expected inputs and outputs into a single vector:

In [11]:
def column(v):
	return tf.reshape(v, [-1, 1])

ids_and_outputs = tf.concat([
	column(all_ids_input), column(all_ids_output)
], 1)
ids_and_outputs

<tf.Tensor: shape=(7252805, 2), dtype=int64, numpy=
array([[40,  5],
       [22,  6],
       [25,  5],
       ...,
       [16,  5],
       [18,  5],
       [ 1,  4]])>

Next, create a `Dataset`:

In [12]:
# Separates ids_and_outputs along its first dimension into different items in the dataset.
ids_input_main = all_ids_input

# Tuples: from_tensor_slices pairs entries of each tuple item to produce the dataset.
input_dataset = tf.data.Dataset.from_tensor_slices(ids_input_main)
output_dataset = tf.data.Dataset.from_tensor_slices(all_ids_output)
dataset = tf.data.Dataset.zip(input_dataset, output_dataset)

# Preview the dataset -- demonstrates converting Tensors to numpy to text
for input, expected_sample_outputs in dataset.take(32):
	input_char = ids_to_chars_in(input).numpy().decode('utf-8')
	output_char = ids_to_chars_out(expected_sample_outputs).numpy().decode('utf-8')

	print('{} c({})'.format(input_char, output_char), end = ', ')


~ c(1), i c(2), l c(1), l c(1), u c(1), s c(1), t c(1), r c(1), a c(1), t c(1), i c(1), o c(1), n c(1),   c(1), ~ c(1), a c(2), l c(1), i c(1), c c(1), e c(1), s c('),   c(1), a c(2), d c(1), v c(1), e c(1), n c(1), t c(1), u c(1), r c(1), e c(1), s c(1), 

2024-09-01 13:19:55.462029: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [13]:
seq_length = 108

# batch: Convert the dataset to sequences of the target size.
# drop_remainder: Drop the last batch if it has fewer than seq_length elements
sequences = dataset.batch(seq_length + 1, drop_remainder=True)

# Roll each sequence so that it starts with a word or paragraph.
space_id_in = int(chars_to_ids_in([' '])[0])
parbreak_id_in = int(chars_to_ids_in(['~'])[0])


def shift_left_by(s, amount: int):
	padding = space_id_in * tf.ones([amount], dtype = tf.dtypes.int64)
	return tf.concat((s[amount:], padding), axis=0)

def word_align_sequences(in_seq, out_seq):
	first_space_index = tf.math.argmax(in_seq == space_id_in, axis=0)
	# Include a virtual paragraph break at the end of [a]. This means that first_paragraph_index
	# will be a large number if no paragraph break is found.
	paragraph_matches = tf.concat((in_seq == parbreak_id_in, [True]), axis=0)
	first_parbreak_index = tf.math.argmax(paragraph_matches, axis=0)

	roll_amount = tf.math.minimum(first_space_index + 1, first_parbreak_index)
	map_sequence = lambda s: shift_left_by(s, roll_amount)

	return map_sequence(in_seq), map_sequence(out_seq)
sequences = sequences.map(word_align_sequences)

for sample_inputs, expected_sample_outputs in sequences.take(12):
	print('Inputs:', text_from_ids_in(sample_inputs))

Inputs: tf.Tensor(b'~illustration ~alices adventures in wonderland ~by lewis carroll ~the millennium fulcrum edition 30 ~contents', shape=(), dtype=string)
Inputs: tf.Tensor(b'~chapter i down the rabbithole chapter ii the pool of tears chapter iii a caucusrace and a long tale chapter ', shape=(), dtype=string)
Inputs: tf.Tensor(b'iv the rabbit sends in a little bill chapter v advice from a caterpillar chapter vi pig and pepper chapter v ', shape=(), dtype=string)
Inputs: tf.Tensor(b'a mad teaparty chapter viii the queens croquetground chapter ix the mock turtles story chapter x the lobst   ', shape=(), dtype=string)
Inputs: tf.Tensor(b'quadrille chapter xi who stole the tarts chapter xii alices evidence ~chapter i down the rabbithole ~alice   ', shape=(), dtype=string)
Inputs: tf.Tensor(b'was beginning to get very tired of sitting by her sister on the bank and of having nothing to do once or twi ', shape=(), dtype=string)
Inputs: tf.Tensor(b'she had peeped into the book her sister was 

2024-09-01 13:19:55.652834: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Now we have word-aligned training data! To make things a bit easier on our model, let's also give it access to the next few characters:

In [14]:
shifts_features_amounts = [ 0, 8 ]
def add_lookahead_to_input(input):
	mapped_inputs = []
	for shift in shifts_features_amounts:
		mapped_inputs.append(shift_left_by(input, shift))
	return tuple(mapped_inputs)

def add_character_lookahead_features(input, output):
	return add_lookahead_to_input(input), output
sequences_with_lookahead = sequences.map(add_character_lookahead_features)


for sample_inputs, expected_sample_outputs in sequences_with_lookahead.take(1):
	for item in sample_inputs:
		print('Input:', text_from_ids_in(item))

Input: tf.Tensor(b'~illustration ~alices adventures in wonderland ~by lewis carroll ~the millennium fulcrum edition 30 ~contents', shape=(), dtype=string)
Input: tf.Tensor(b'ation ~alices adventures in wonderland ~by lewis carroll ~the millennium fulcrum edition 30 ~contents        ', shape=(), dtype=string)


In [15]:
dataset = sequences_with_lookahead

Our dataset now pairs inputs and labels!

**Note**: This [StackOverflow](https://stackoverflow.com/questions/53171885/how-to-use-tf-data-dataset-and-tf-keras-do-multi-inputs-and-multi-outpus) question, the documentation on [Dataset.zip](https://www.tensorflow.org/api_docs/python/tf/data/Dataset?hl=en#zip), and documentation on [Dataset.from_tensor_slices](https://www.tensorflow.org/api_docs/python/tf/data/Dataset?hl=en#from_tensor_slices) were helpful.

## Final preprocessing

We now shuffle the data, then do final batching.

In [16]:
BUFFER_SIZE = 10000
BATCH_SIZE = 64

dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder = True).prefetch(tf.data.AUTOTUNE)
# Break into test and training data (no validation data for now).
# Inspired by https://stackoverflow.com/a/74609848.
validate_size = dataset.cardinality() * 1 // 8
dataset_validate = dataset.take(validate_size)
dataset = dataset.skip(validate_size)


## Building the model



In [17]:
# .get_vocabulary: Returns a list of the characters in use.
vocab_size_in = len(chars_to_ids_in.get_vocabulary())

vocab_size_out = len(chars_to_ids_out.get_vocabulary())
size_out = vocab_size_out # Output includes both commands and the command arg

EMBEDDING_DIM = 128
RNN_UNITS = 128 # Dimensionality of GRU output

print('vocab_size_in', vocab_size_in)
print('vocab_size_out', vocab_size_out)
print('EMBEDDING_DIM', EMBEDDING_DIM)
print('RNN_UNITS', RNN_UNITS)

vocab_size_in 55
vocab_size_out 9
EMBEDDING_DIM 128
RNN_UNITS 128


In [18]:
class LanguageModel(tf.keras.Model):
	def __init__(self):
		super().__init__()

		self.embedding_layer = tf.keras.layers.Embedding(vocab_size_in, EMBEDDING_DIM)
		self.merge_layer = tf.keras.layers.Concatenate()

		# return_sequences: Return the full sequence of outputs, rather than just the last.
		# return_state: Returns the last state in addition to the output
		self.gru_layer = tf.keras.layers.GRU(RNN_UNITS, return_sequences=True, return_state=True)
		self.dense_layer = tf.keras.layers.Dense(size_out, activation=tf.keras.activations.log_softmax)
	
	def call(self, inputs, states = None, return_state = False, training = False):
		x = self.merge_layer(
			list(map(lambda x: self.embedding_layer(x, training=training), inputs))
		)
		if states is None:
			batch_size, _ = inputs[0].shape
			states = self.gru_layer.get_initial_state(batch_size)

		x, states = self.gru_layer(x, initial_state = states, training = training)
		x = self.dense_layer(x, training = training)

		if return_state:
			return x, states
		else:
			return x

# We override tf.keras.Model to allow extracting the state later.

In [19]:
model = LanguageModel()

## Trying the (untrained) model


In [20]:
print(dataset.take(1))

for sample_inputs, expected_sample_outputs in dataset.take(1):
	sample_predictions = model(sample_inputs)
	print(sample_predictions.shape, ':: (batch_size, seq_length, num_commands)')

model.summary()

<_TakeDataset element_spec=((TensorSpec(shape=(64, None), dtype=tf.int64, name=None), TensorSpec(shape=(64, None), dtype=tf.int64, name=None)), TensorSpec(shape=(64, None), dtype=tf.int64, name=None))>
(64, 109, 9) :: (batch_size, seq_length, num_commands)


2024-09-01 13:19:59.405356: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Now let's inspect `sample_predictions`:

In [21]:
# Take one sample of the data, where sample_cmd_predictions[0] contains log probability
sampled_indices = tf.random.categorical(sample_predictions[0], num_samples = 1)
print(sampled_indices.shape)

# tf.squeeze: Removes dimensions of size 1.
sampled_indices = tf.squeeze(sampled_indices).numpy()
print(sampled_indices.shape)


(109, 1)
(109,)


In [22]:
input_text = text_from_ids_in(sample_inputs[0][0]).numpy().decode('utf-8')
print('Input:', input_text)

sampled_commands = ids_to_chars_out(sampled_indices).numpy()
reconstructed = reconstruct_from_labels(input_text, sampled_commands)
print('Next predictions:', reconstructed)

Input: fair play whatever i did that idea would bother me it was so tiresomely pertinacious that i resolved on requ 
Next predictions: 'f.a.iR ?p'l'a-y- .w'h,a[UNK]t:e[UNK]v'e'r. [UNK]i, .di,d- :tH?a-t. -i[UNK]d?e[UNK]a[UNK] .w[UNK]ou?l.d- -b[UNK]o?t,hE-r- ,m?e. ,i:t' ,w-a-s[UNK] -s,o ,ti[UNK]r,eS?o,m.e,l:y- PE:r[UNK]t-i,na:c.i[UNK]o,u-s? ,t[UNK]ha-t? 'i[UNK] [UNK]r,e,so-lV.e-d? ,o-n .rE'qu[UNK] 


Seemingly random output, as expected!

## Training!

We can train it now! It's a standard classification problem -- given the previous RNN state and the current character, predict the next character.

We're using the `SparseCategoricalCrossentropy` loss. See https://datascience.stackexchange.com/a/41923 and perhaps https://stats.stackexchange.com/a/420730 for commentary.

In [23]:
loss_fn = tf.losses.SparseCategoricalCrossentropy(from_logits=True)

sample_batch_mean_loss = loss_fn(expected_sample_outputs, sample_predictions)
print('loss pre-training:', float(tf.exp(sample_batch_mean_loss)))

loss pre-training: 8.997515678405762


As expected, the initial loss is large.

Now we attach the loss function and an optimizer:

In [24]:
model.compile(optimizer='adam', loss=loss_fn)

We're just about ready to train.

In [25]:
# Set up checkpoints

checkpoint_path = './tf_model_checkpoints/checkpoint.weights.h5'

checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
	filepath=checkpoint_path, monitor='val_loss', mode='min', save_weights_only=True, save_best_only=True
)

In [26]:
EPOCHS = 10

In [27]:

history = model.fit(dataset, epochs=EPOCHS, callbacks=[checkpoint_callback], validation_data=dataset_validate)

Epoch 1/10
[1m910/910[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m145s[0m 155ms/step - loss: 0.2683 - val_loss: 0.1879
Epoch 2/10
[1m910/910[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m123s[0m 132ms/step - loss: 0.1506 - val_loss: 0.1680
Epoch 3/10
[1m910/910[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m144s[0m 156ms/step - loss: 0.1340 - val_loss: 0.1584
Epoch 4/10
[1m910/910[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m220s[0m 176ms/step - loss: 0.1232 - val_loss: 0.1492
Epoch 5/10
[1m910/910[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m161s[0m 174ms/step - loss: 0.1173 - val_loss: 0.1471
Epoch 6/10
[1m910/910[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m123s[0m 133ms/step - loss: 0.1147 - val_loss: 0.1434
Epoch 7/10
[1m910/910[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m124s[0m 133ms/step - loss: 0.1103 - val_loss: 0.1434
Epoch 8/10
[1m910/910[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m126s[0m 136ms/step - loss: 0.1084 - val_loss: 0.1424
Epoch 9/

## Add punctuation

Let's try it!


In [28]:
def logits_to_text(original: str, predicted_logits):
	predicted_commands = tf.squeeze(tf.random.categorical(predicted_logits, num_samples=1))
	return reconstruct_from_labels(original, ids_to_chars_out(predicted_commands).numpy())

class Punctuator:
	def __init__(self, model: LanguageModel, temperature: float = 1.0):
		self.temperature = temperature
		self.model = model
		self.last_states = None

		# See https://www.tensorflow.org/api_docs/python/tf/sparse/SparseTensor?hl=en
		skip_ids = chars_to_ids_out(['[UNK]'])
		out_vocab_size = len(chars_to_ids_out.get_vocabulary())
		print(out_vocab_size, skip_ids)
		self.prediction_mask = tf.sparse.to_dense(tf.sparse.reorder(tf.SparseTensor(
			indices=tf.reshape(skip_ids, [-1, 1]), # shape [N, ndims]. This specifies the nonzero elements' indices.
			values=[float('-inf')] * len(skip_ids),
			dense_shape=[out_vocab_size],
		)))
	
	def step(self, input: str|Any):
		# Data conversion
		input_chars = tf.strings.unicode_split(input, 'UTF-8')
		input_ids = chars_to_ids_in(input_chars)
		reshape_vec = lambda x: tf.reshape(x, [1, -1])
		inputs = tuple(map(reshape_vec, add_lookahead_to_input(input_ids)))

		# Run it!
		# predicted.shape is [batch, char, next_char_logits]
		predicted_commands_raw, states = self.model(inputs=inputs, states=self.last_states, return_state=True)
		self.last_states = states

		predicted_logits = (predicted_commands_raw[-1, :, :]) / self.temperature
		predicted_logits += self.prediction_mask # Sets some weights to -inf

		return predicted_logits

	def step_and_predict(self, original: str):
		return logits_to_text(self.step(self, original))

class CombinedPunctuator:
	def __init__(self, model):
		self.punctuators = [
			Punctuator(model, temperature = 1.0),
			Punctuator(model, temperature = 0.8),
			Punctuator(model, temperature = 0.4),
		]

	def step(self, text: str):
		step_size = seq_length - 1
		num_punctuators = len(self.punctuators)
		text_length = len(text)

		all_logits = np.zeros([ text_length, vocab_size_out ])

		for i in range(0, text_length, step_size):
			shift = -5
			for punctuator in self.punctuators:
				from_idx = max(0, i + shift)
				to_idx = min(text_length - 1, i + step_size + shift)

				text_shifted = text[from_idx:to_idx]
				while len(text_shifted) < step_size:
					text_shifted += ' '
				
				predicted_logits = punctuator.step(text_shifted)
				sliced_prediction = predicted_logits[0:to_idx-from_idx]

				all_logits[from_idx:to_idx] = all_logits[from_idx:to_idx] + (sliced_prediction / num_punctuators)
				shift += 5
					
		return logits_to_text(text, all_logits)

In [29]:
%%time

punctuator = CombinedPunctuator(model)

text = '''~the punctuator is a small machine learning model for punctuation restoration
at present its performance is rather poor i do hope however that with additional training
and very little rearchitecting the punctuator will be a usable and fast model
~i suspect that i will need to look into the neural machine translation tutorials will the
approach taken by the example seq2seq model for spanish to english translation be sufficient
will i need to learn about transformers
~note there are concerns about model size in addition to performance as models will need to be run
on mobile devices
~the punctuator was trained on old books does prose of a similar style work better heres some text from
frankenstein
~although it denied warmth safie agatha and felix departed on a long country walk
~interesting perhaps it isnt any better in that case how unfortunate 
'''
text = text.replace('\n', ' ')

display(Markdown(punctuator.step(text)))


9 tf.Tensor([0], shape=(1,), dtype=int64)
9 tf.Tensor([0], shape=(1,), dtype=int64)
9 tf.Tensor([0], shape=(1,), dtype=int64)


~The punctuator is a small machine, learning Model for punctuation restoration. at present, its performance is rather poor I do hope However, that with additional training, and very little rearchitecting the punctuator will be a usable and fast Model. ~I suspect that I will need to look into the neural machine, translation tutorials will the approach taken by the example seq2seq Model for spanish to English translation, be sufficient will i need to Learn about transformers. ~Note there are concerns about Model size in addition to performance as models will need to be run on mobile Devices. ~The punctuatOr was trained on old books does prose of a similar style work better Heres some text from frankenstein. ~Although it Denied Warmth, Safie Agatha and felix departed on a long country walk. ~Interesting perhaps, it Isnt any better in that case, How unfortunate : 

CPU times: user 5.42 s, sys: 475 ms, total: 5.9 s
Wall time: 5.92 s


That isn't working very well. For comparison, let's try an example from the training data:

In [30]:
orig = tf.strings.reduce_join(context_raw[4400:5500]).numpy().decode('utf-8')
print(orig)
display(Markdown(punctuator.step(orig)))

up somewhere ~down down down there was nothing else to do so alice soon began talking again dinahll miss me very much tonight i should think dinah was the cat i hope theyll remember her saucer of milk at teatime dinah my dear i wish you were down here with me there are no mice in the air im afraid but you might catch a bat and thats very like a mouse you know but do cats eat bats i wonder and here alice began to get rather sleepy and went on saying to herself in a dreamy sort of way do cats eat bats do cats eat bats and sometimes do bats eat cats for you see as she couldnt answer either question it didnt much matter which way she put it she felt that she was dozing off and had just begun to dream that she was walking hand in hand with dinah and saying to her very earnestly now dinah tell me the truth did you ever eat a bat when suddenly thump thump down she came upon a heap of sticks and dry leaves and the fall was over ~alice was not a bit hurt and she jumped up on to her feet in a mo

up somewhere? ~Down down down there was nothing else to do so, alice soon began talking again dinahll miss me very much tonight, I should think dinah was the cat, I hope theyll remember her saucer of milk at teatime dinah my dear I wish you were down here with me there are no mice in the air, I'm afraid, but you might catch a bat and thats very like a mouse. you know but do cats eat bats I Wonder and here alice began to get rather sleepy, and went on saying to herself in a dreamy sort of way. Do Cats eat, bats do Cats eat Bats, and sometimes do bats eat cats for you see, as she couldn't answer either question. it didn't much Matter which way she put it she felt that she was dozing off, and had just begun to dream that she was walking hand in hand with dinah, and saying to her very earnestly Now dinah tell me, the truth did you ever eat a bat When suddenly thump thump down she came upon a heap of sticks, and dry leaves, and the fall was over. ~Alice was not a bit hurt, and she jumped up on to her feet in a moment she looked up but it was all dark overhead before her was another long passage and the white r[UNK]a