<a href="https://colab.research.google.com/github/zkhandker/rupi-kaur/blob/main/Nikhil_Scratch1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Playing Around with Word2Vec

In [1]:
#Install word2vec and import usual libraries

!pip install -q tqdm

import io
import itertools
import numpy as np
import os
import re
import string
import tensorflow as tf
import tqdm

from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import Activation, Dense, Dot, Embedding, Flatten, GlobalAveragePooling1D, Reshape
from tensorflow.keras.layers.experimental.preprocessing import TextVectorization

SEED = 42 #no idea what these 2 lines do, but they have it in the word2vec documentation
AUTOTUNE = tf.data.AUTOTUNE

Let's work through the example in the word2vec documentation.

In [2]:
#let's look at an example sentence.

sentence = "The wide road shimmered in the hot sun"
tokens = list(sentence.lower().split())
print(len(tokens))
print(tokens)

8
['the', 'wide', 'road', 'shimmered', 'in', 'the', 'hot', 'sun']


In [3]:
#map each word to an index

vocab, index = {}, 1 # start indexing from 1
vocab['<pad>'] = 0 # add a padding token 
for token in tokens:
  if token not in vocab: 
    vocab[token] = index
    # print(token)
    index += 1
vocab_size = len(vocab)
print(vocab)

#and its inverse
inverse_vocab = {index: token for token, index in vocab.items()}
print(inverse_vocab)

{'<pad>': 0, 'the': 1, 'wide': 2, 'road': 3, 'shimmered': 4, 'in': 5, 'hot': 6, 'sun': 7}
{0: '<pad>', 1: 'the', 2: 'wide', 3: 'road', 4: 'shimmered', 5: 'in', 6: 'hot', 7: 'sun'}


In [4]:
#we can map the sentence to a vector
#note that "the" is repeated in the sentence
example_sequence = [vocab[word] for word in tokens]
print(example_sequence)

[1, 2, 3, 4, 5, 1, 6, 7]


In [6]:
#now we need to come up with "skip grams"
window_size = 2
positive_skip_grams, _ = tf.keras.preprocessing.sequence.skipgrams(
      example_sequence, 
      vocabulary_size=vocab_size,
      window_size=window_size,
      negative_samples=0)
print(len(positive_skip_grams))

26


In [7]:
print(positive_skip_grams)
for target, context in positive_skip_grams[:5]:
  print(f"({target}, {context}): ({inverse_vocab[target]}, {inverse_vocab[context]})")

[[7, 1], [4, 1], [4, 3], [3, 4], [5, 1], [3, 5], [2, 4], [4, 5], [1, 7], [6, 5], [2, 3], [6, 1], [3, 2], [2, 1], [3, 1], [5, 3], [1, 6], [1, 2], [4, 2], [6, 7], [5, 4], [7, 6], [5, 6], [1, 5], [1, 4], [1, 3]]
(7, 1): (sun, the)
(4, 1): (shimmered, the)
(4, 3): (shimmered, road)
(3, 4): (road, shimmered)
(5, 1): (in, the)


In [8]:
# Get target and context words for one positive skip-gram.
target_word, context_word = positive_skip_grams[0]

# Set the number of negative samples per positive context. 
num_ns = 4

context_class = tf.reshape(tf.constant(context_word, dtype="int64"), (1, 1))
negative_sampling_candidates, _, _ = tf.random.log_uniform_candidate_sampler(
    true_classes=context_class, # class that should be sampled as 'positive'
    num_true=1, # each positive skip-gram has 1 positive context class
    num_sampled=num_ns, # number of negative context words to sample
    unique=True, # all the negative samples should be unique
    range_max=vocab_size, # pick index of the samples from [0, vocab_size]
    seed=SEED, # seed for reproducibility
    name="negative_sampling" # name of this operation
)
print(negative_sampling_candidates)
print([inverse_vocab[index.numpy()] for index in negative_sampling_candidates])

tf.Tensor([2 1 4 3], shape=(4,), dtype=int64)
['wide', 'the', 'shimmered', 'road']


In [9]:
# Add a dimension so you can use concatenation (on the next step).
negative_sampling_candidates = tf.expand_dims(negative_sampling_candidates, 1)

# Concat positive context word with negative sampled words.
context = tf.concat([context_class, negative_sampling_candidates], 0)

# Label first context word as 1 (positive) followed by num_ns 0s (negative).
label = tf.constant([1] + [0]*num_ns, dtype="int64") 

# Reshape target to shape (1,) and context and label to (num_ns+1,).
target = tf.squeeze(target_word)
context = tf.squeeze(context)
label =  tf.squeeze(label)

In [10]:
print(f"target_index    : {target}")
print(f"target_word     : {inverse_vocab[target_word]}")
print(f"context_indices : {context}")
print(f"context_words   : {[inverse_vocab[c.numpy()] for c in context]}")
print(f"label           : {label}")

target_index    : 7
target_word     : sun
context_indices : [1 2 1 4 3]
context_words   : ['the', 'wide', 'the', 'shimmered', 'road']
label           : [1 0 0 0 0]


In [11]:
print(f"target  :", target)
print(f"context :", context )
print(f"label   :", label )

target  : tf.Tensor(7, shape=(), dtype=int32)
context : tf.Tensor([1 2 1 4 3], shape=(5,), dtype=int64)
label   : tf.Tensor([1 0 0 0 0], shape=(5,), dtype=int64)


In [16]:
sampling_table = tf.keras.preprocessing.sequence.make_sampling_table(size=10)
print(sampling_table)

[0.00315225 0.00315225 0.00547597 0.00741556 0.00912817 0.01068435
 0.01212381 0.01347162 0.01474487 0.0159558 ]


In [17]:
# Generates skip-gram pairs with negative sampling for a list of sequences
# (int-encoded sentences) based on window size, number of negative samples
# and vocabulary size.
def generate_training_data(sequences, window_size, num_ns, vocab_size, seed):
  # Elements of each training example are appended to these lists.
  targets, contexts, labels = [], [], []

  # Build the sampling table for vocab_size tokens.
  sampling_table = tf.keras.preprocessing.sequence.make_sampling_table(vocab_size)

  # Iterate over all sequences (sentences) in dataset.
  for sequence in tqdm.tqdm(sequences):

    # Generate positive skip-gram pairs for a sequence (sentence).
    positive_skip_grams, _ = tf.keras.preprocessing.sequence.skipgrams(
          sequence, 
          vocabulary_size=vocab_size,
          sampling_table=sampling_table,
          window_size=window_size,
          negative_samples=0)

    # Iterate over each positive skip-gram pair to produce training examples 
    # with positive context word and negative samples.
    for target_word, context_word in positive_skip_grams:
      context_class = tf.expand_dims(
          tf.constant([context_word], dtype="int64"), 1)
      negative_sampling_candidates, _, _ = tf.random.log_uniform_candidate_sampler(
          true_classes=context_class,
          num_true=1, 
          num_sampled=num_ns, 
          unique=True, 
          range_max=vocab_size, 
          seed=SEED, 
          name="negative_sampling")

      # Build context and label vectors (for one target word)
      negative_sampling_candidates = tf.expand_dims(
          negative_sampling_candidates, 1)

      context = tf.concat([context_class, negative_sampling_candidates], 0)
      label = tf.constant([1] + [0]*num_ns, dtype="int64")

      # Append each element from the training example to global lists.
      targets.append(target_word)
      contexts.append(context)
      labels.append(label)

  return targets, contexts, labels

In [119]:
#get data
path_to_file = tf.keras.utils.get_file('combined2', 'https://raw.githubusercontent.com/zkhandker/rupi-kaur/main/data/combined2.txt')

In [120]:
with open(path_to_file) as f: 
  lines = f.read().splitlines()
for line in lines[:20]:
  print(line)

how is it so easy for you
to be kind to people he asked

milk and honey dripped
from my lips as i answered

cause people have not
been kind to me
the first boy that kissed me
held my shoulders down
like the handlebars of
the first bicycle
he ever rode
i was five

he had the smell of
starvation on his lips
which he picked up from
his father feasting on his mother at 4 a.m.



In [121]:
print(path_to_file)

/root/.keras/datasets/combined2


In [122]:
text_ds = tf.data.TextLineDataset(path_to_file).filter(lambda x: tf.cast(tf.strings.length(x), bool))


In [123]:
# We create a custom standardization function to lowercase the text and 
# remove punctuation.
def custom_standardization(input_data):
  lowercase = tf.strings.lower(input_data)
  return tf.strings.regex_replace(lowercase,
                                  '[%s]' % re.escape(string.punctuation), '')

# Define the vocabulary size and number of words in a sequence.
vocab_size = 4096
sequence_length = 10

# Use the text vectorization layer to normalize, split, and map strings to
# integers. Set output_sequence_length length to pad all samples to same length.
vectorize_layer = TextVectorization(
    standardize=custom_standardization,
    max_tokens=vocab_size,
    output_mode='int',
    output_sequence_length=sequence_length)

In [124]:
vectorize_layer.adapt(text_ds.batch(1024))
# import sys
# sys.setdefaultencoding('ISO-8859-1')


In [125]:
# Save the created vocabulary for reference.
inverse_vocab = vectorize_layer.get_vocabulary()
print(inverse_vocab[:20])

['', '[UNK]', 'you', 'the', 'i', 'to', 'and', 'of', 'is', 'it', 'me', 'your', 'my', 'a', 'in', 'not', 'that', 'when', 'with', 'for']


In [127]:
def vectorize_text(text):
  text = tf.expand_dims(text, -1)
  return tf.squeeze(vectorize_layer(text))

# Vectorize the data in text_ds.
text_vector_ds = text_ds.batch(1024).prefetch(AUTOTUNE).map(vectorize_layer).unbatch()


sequences = list(text_vector_ds.as_numpy_iterator())
print(len(sequences))

2568


In [128]:
for seq in sequences[:5]:
  print(f"{seq} => {[inverse_vocab[i] for i in seq]}")

[ 21   8   9  25 180  19   2   0   0   0] => ['how', 'is', 'it', 'so', 'easy', 'for', 'you', '', '', '']
[   5   23  353    5   80   31 1330    0    0    0] => ['to', 'be', 'kind', 'to', 'people', 'he', 'asked', '', '', '']
[ 499    6  149 1192    0    0    0    0    0    0] => ['milk', 'and', 'honey', 'dripped', '', '', '', '', '', '']
[  72   12   90   36    4 1340    0    0    0    0] => ['from', 'my', 'lips', 'as', 'i', 'answered', '', '', '', '']
[67 80 27 15  0  0  0  0  0  0] => ['cause', 'people', 'have', 'not', '', '', '', '', '', '']


In [129]:
targets, contexts, labels = generate_training_data(
    sequences=sequences, 
    window_size=2, 
    num_ns=4, 
    vocab_size=vocab_size, 
    seed=SEED)
print(len(targets), len(contexts), len(labels))

100%|██████████| 2568/2568 [00:00<00:00, 6812.13it/s]

2896 2896 2896





In [130]:
BATCH_SIZE = 1024
BUFFER_SIZE = 10000
dataset = tf.data.Dataset.from_tensor_slices(((targets, contexts), labels))
dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)
print(dataset)

<BatchDataset shapes: (((1024,), (1024, 5, 1)), (1024, 5)), types: ((tf.int32, tf.int64), tf.int64)>


In [131]:
dataset = dataset.cache().prefetch(buffer_size=AUTOTUNE)
print(dataset)

<PrefetchDataset shapes: (((1024,), (1024, 5, 1)), (1024, 5)), types: ((tf.int32, tf.int64), tf.int64)>


In [132]:

class Word2Vec(Model):
  def __init__(self, vocab_size, embedding_dim):
    super(Word2Vec, self).__init__()
    self.target_embedding = Embedding(vocab_size, 
                                      embedding_dim,
                                      input_length=1,
                                      name="w2v_embedding", )
    self.context_embedding = Embedding(vocab_size, 
                                       embedding_dim, 
                                       input_length=num_ns+1)
    self.dots = Dot(axes=(3,2))
    self.flatten = Flatten()

  def call(self, pair):
    target, context = pair
    we = self.target_embedding(target)
    ce = self.context_embedding(context)
    dots = self.dots([ce, we])
    return self.flatten(dots)

In [133]:
def custom_loss(x_logit, y_true):
      return tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=y_true)

In [134]:
embedding_dim = 128
word2vec = Word2Vec(vocab_size, embedding_dim)
word2vec.compile(optimizer='adam',
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

In [135]:
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="logs")

In [136]:
word2vec.fit(dataset, epochs=20, callbacks=[tensorboard_callback])

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<tensorflow.python.keras.callbacks.History at 0x7ff34570c750>

In [138]:
weights = word2vec.get_layer('w2v_embedding').get_weights()[0]
vocab = vectorize_layer.get_vocabulary()

In [139]:
out_v = io.open('vectors.tsv', 'w', encoding='utf-8')
out_m = io.open('metadata.tsv', 'w', encoding='utf-8')

for index, word in enumerate(vocab):
  if  index == 0: continue # skip 0, it's padding.
  vec = weights[index] 
  out_v.write('\t'.join([str(x) for x in vec]) + "\n")
  out_m.write(word + "\n")
out_v.close()
out_m.close()

In [140]:
try:
  from google.colab import files
  files.download('vectors.tsv')
  files.download('metadata.tsv')
except Exception as e:
  pass

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

## Now generate text - try doing it by characters first?

In [152]:
# Download the dataset
# Read, then decode for py2 compat.
text = open(path_to_file, 'rb').read().decode(encoding='utf-8')
# length of text is the number of characters in it
print('Length of text: {} characters'.format(len(text)))

Length of text: 66688 characters


In [157]:
print(text[:250])

how is it so easy for you
to be kind to people he asked

milk and honey dripped
from my lips as i answered

cause people have not
been kind to me
the first boy that kissed me
held my shoulders down
like the handlebars of
the first bicycle
he ever rod


In [158]:
vocab = sorted(set(text))
print('{} unique characters'.format(len(vocab)))

48 unique characters


In [159]:
example_texts = ['abcdefg', 'xyz']

chars = tf.strings.unicode_split(example_texts, input_encoding='UTF-8')
chars

<tf.RaggedTensor [[b'a', b'b', b'c', b'd', b'e', b'f', b'g'], [b'x', b'y', b'z']]>

In [161]:
from tensorflow.keras.layers.experimental import preprocessing

ids_from_chars = preprocessing.StringLookup(
    vocabulary=list(vocab))

In [147]:
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding_1 (Embedding)      (32, None, 256)           12288     
_________________________________________________________________
lstm (LSTM)                  (32, None, 1024)          5246976   
_________________________________________________________________
dense (Dense)                (32, None, 48)            49200     
Total params: 5,308,464
Trainable params: 5,308,464
Non-trainable params: 0
_________________________________________________________________


In [150]:
print(vectorize_text(text))

tf.Tensor([ 21   8   9  25 180  19   2   5  23 353], shape=(10,), dtype=int64)


In [162]:
chars_from_ids = tf.keras.layers.experimental.preprocessing.StringLookup(
    vocabulary=ids_from_chars.get_vocabulary(), invert=True)

In [164]:
ids = ids_from_chars(chars)
ids
chars = chars_from_ids(ids)
chars

<tf.RaggedTensor [[b'a', b'b', b'c', b'd', b'e', b'f', b'g'], [b'x', b'y', b'z']]>

In [165]:
def text_from_ids(ids):
  return tf.strings.reduce_join(chars_from_ids(ids), axis=-1)

In [166]:
all_ids = ids_from_chars(tf.strings.unicode_split(text, 'UTF-8'))
all_ids

<tf.Tensor: shape=(66688,), dtype=int64, numpy=array([29, 36, 44, ..., 26, 39,  2])>

In [168]:
ids_dataset = tf.data.Dataset.from_tensor_slices(all_ids)
for ids in ids_dataset.take(10):
    print(chars_from_ids(ids).numpy().decode('utf-8'))

seq_length = 100
examples_per_epoch = len(text)//(seq_length+1)

h
o
w
 
i
s
 
i
t
 


In [169]:
sequences = ids_dataset.batch(seq_length+1, drop_remainder=True)

for seq in sequences.take(1):
  print(chars_from_ids(seq))

tf.Tensor(
[b'h' b'o' b'w' b' ' b'i' b's' b' ' b'i' b't' b' ' b's' b'o' b' ' b'e'
 b'a' b's' b'y' b' ' b'f' b'o' b'r' b' ' b'y' b'o' b'u' b'\n' b't' b'o'
 b' ' b'b' b'e' b' ' b'k' b'i' b'n' b'd' b' ' b't' b'o' b' ' b'p' b'e'
 b'o' b'p' b'l' b'e' b' ' b'h' b'e' b' ' b'a' b's' b'k' b'e' b'd' b'\n'
 b'\n' b'm' b'i' b'l' b'k' b' ' b'a' b'n' b'd' b' ' b'h' b'o' b'n' b'e'
 b'y' b' ' b'd' b'r' b'i' b'p' b'p' b'e' b'd' b'\n' b'f' b'r' b'o' b'm'
 b' ' b'm' b'y' b' ' b'l' b'i' b'p' b's' b' ' b'a' b's' b' ' b'i' b' '
 b'a' b'n' b's'], shape=(101,), dtype=string)


In [170]:
for seq in sequences.take(5):
  print(text_from_ids(seq).numpy())

b'how is it so easy for you\nto be kind to people he asked\n\nmilk and honey dripped\nfrom my lips as i ans'
b'wered\n\ncause people have not\nbeen kind to me\nthe first boy that kissed me\nheld my shoulders down\nlike'
b' the handlebars of\nthe first bicycle\nhe ever rode\ni was five\n\nhe had the smell of\nstarvation on his l'
b'ips\nwhich he picked up from\nhis father feasting on his mother at 4 a.m.\n\nhe was the first boy\nto teac'
b'h me my body was\nfor giving to those that wanted\nthat i should feel anything\nless than whole\n\nand my '


In [171]:
def split_input_target(sequence):
    input_text = sequence[:-1]
    target_text = sequence[1:]
    return input_text, target_text

In [172]:
dataset = sequences.map(split_input_target)
for input_example, target_example in  dataset.take(1):
    print("Input :", text_from_ids(input_example).numpy())
    print("Target:", text_from_ids(target_example).numpy())

Input : b'how is it so easy for you\nto be kind to people he asked\n\nmilk and honey dripped\nfrom my lips as i an'
Target: b'ow is it so easy for you\nto be kind to people he asked\n\nmilk and honey dripped\nfrom my lips as i ans'


In [173]:
# Batch size
BATCH_SIZE = 64

# Buffer size to shuffle the dataset
# (TF data is designed to work with possibly infinite sequences,
# so it doesn't attempt to shuffle the entire sequence in memory. Instead,
# it maintains a buffer in which it shuffles elements).
BUFFER_SIZE = 10000

dataset = (
    dataset
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.experimental.AUTOTUNE))

dataset

<PrefetchDataset shapes: ((64, 100), (64, 100)), types: (tf.int64, tf.int64)>

In [175]:
# Length of the vocabulary in chars
vocab_size = len(vocab)

# The embedding dimension
embedding_dim = 256

# Number of RNN units
rnn_units = 1024

class MyModel(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim, rnn_units):
    super().__init__(self)
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    self.gru = tf.keras.layers.GRU(rnn_units,
                                   return_sequences=True, 
                                   return_state=True)
    self.dense = tf.keras.layers.Dense(vocab_size)

  def call(self, inputs, states=None, return_state=False, training=False):
    x = inputs
    x = self.embedding(x, training=training)
    if states is None:
      states = self.gru.get_initial_state(x)
    x, states = self.gru(x, initial_state=states, training=training)
    x = self.dense(x, training=training)

    if return_state:
      return x, states
    else: 
      return x

model = MyModel(
    # Be sure the vocabulary size matches the `StringLookup` layers.
    vocab_size=len(ids_from_chars.get_vocabulary()),
    embedding_dim=embedding_dim,
    rnn_units=rnn_units)

In [176]:
for input_example_batch, target_example_batch in dataset.take(1):
    example_batch_predictions = model(input_example_batch)
    print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")

(64, 100, 50) # (batch_size, sequence_length, vocab_size)


In [177]:
model.summary()


Model: "my_model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding_3 (Embedding)      multiple                  12800     
_________________________________________________________________
gru_1 (GRU)                  multiple                  3938304   
_________________________________________________________________
dense_2 (Dense)              multiple                  51250     
Total params: 4,002,354
Trainable params: 4,002,354
Non-trainable params: 0
_________________________________________________________________


In [178]:
sampled_indices = tf.random.categorical(example_batch_predictions[0], num_samples=1)
sampled_indices = tf.squeeze(sampled_indices,axis=-1).numpy()
print("Input:\n", text_from_ids(input_example_batch[0]).numpy())
print()
print("Next Char Predictions:\n", text_from_ids(sampled_indices).numpy())

Input:
 b'heir\nbody and leave after saying you will find better than me.\n\nyou will stand there naked with half'

Next Char Predictions:
 b"q:'xdv47nw,\xc3\xadheuo1\n5zcqwwmkma9i4rx0*[UNK]'bk6-40vexaa3u:ytt5j8. ck0hp(*t40jr6t0g\ni9-oc)na(c0rxn9q80wo*."


In [179]:
loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)
example_batch_loss = loss(target_example_batch, example_batch_predictions)
mean_loss = example_batch_loss.numpy().mean()
print("Prediction shape: ", example_batch_predictions.shape, " # (batch_size, sequence_length, vocab_size)")
print("Mean loss:        ", mean_loss)

model.compile(optimizer='adam', loss=loss)

# Directory where the checkpoints will be saved
checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True)

EPOCHS = 20

history = model.fit(dataset, epochs=EPOCHS, callbacks=[checkpoint_callback])


Prediction shape:  (64, 100, 50)  # (batch_size, sequence_length, vocab_size)
Mean loss:         3.914017
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


In [180]:
class OneStep(tf.keras.Model):
  def __init__(self, model, chars_from_ids, ids_from_chars, temperature=1.0):
    super().__init__()
    self.temperature=temperature
    self.model = model
    self.chars_from_ids = chars_from_ids
    self.ids_from_chars = ids_from_chars

    # Create a mask to prevent "" or "[UNK]" from being generated.
    skip_ids = self.ids_from_chars(['','[UNK]'])[:, None]
    sparse_mask = tf.SparseTensor(
        # Put a -inf at each bad index.
        values=[-float('inf')]*len(skip_ids),
        indices = skip_ids,
        # Match the shape to the vocabulary
        dense_shape=[len(ids_from_chars.get_vocabulary())]) 
    self.prediction_mask = tf.sparse.to_dense(sparse_mask)

  @tf.function
  def generate_one_step(self, inputs, states=None):
    # Convert strings to token IDs.
    input_chars = tf.strings.unicode_split(inputs, 'UTF-8')
    input_ids = self.ids_from_chars(input_chars).to_tensor()

    # Run the model.
    # predicted_logits.shape is [batch, char, next_char_logits] 
    predicted_logits, states =  self.model(inputs=input_ids, states=states, 
                                          return_state=True)
    # Only use the last prediction.
    predicted_logits = predicted_logits[:, -1, :]
    predicted_logits = predicted_logits/self.temperature
    # Apply the prediction mask: prevent "" or "[UNK]" from being generated.
    predicted_logits = predicted_logits + self.prediction_mask

    # Sample the output logits to generate token IDs.
    predicted_ids = tf.random.categorical(predicted_logits, num_samples=1)
    predicted_ids = tf.squeeze(predicted_ids, axis=-1)

    # Convert from token ids to characters
    predicted_chars = self.chars_from_ids(predicted_ids)

    # Return the characters and model state.
    return predicted_chars, states

In [182]:
one_step_model = OneStep(model, chars_from_ids, ids_from_chars)


In [186]:
start = time.time()
states = None
next_char = tf.constant(['i felt so sad'])
result = [next_char]

for n in range(1000):
  next_char, states = one_step_model.generate_one_step(next_char, states=states)
  result.append(next_char)

result = tf.strings.join(result)
end = time.time()

print(result[0].numpy().decode('utf-8'), '\n\n' + '_'*80)

print(f"\nRun time: {end - start}")

i felt so sades like
tiches el is sraythanesting itín thion your nowe you hlveny wintít love allfune hearing
if thee the kurt
the warplops of on
you whinghate
whee havee
lowely of heasize
now you stiflinn spuns you a diks pealf.

frinc reeselld noft ckoplongafters
full your lageders
whint it saly han
ded an fach h ar fould you

ime fore. iver ablits
you hav skerinnawlige a loveand

of lly you an bets end yours. you.
so ken whthad the canin elousande it still an thimemunstill

hesth wi stply to comary onavem on owhel
i hevs waro in so
ker macas bivey


- sond wirging liver irs enound de rat in therr boingired clo ke potelf the stow tood arous whonged
bow lite a wryoul
cum of you wers the you atre
toll whowe in thay no
tee hell
awe you
sull. one set
quesello gayte


- chaed if anching. in the mbe ters it sitce
sull kads deats. sind bowthinim is it you pole
is
i chould
1x3 dowerfuregilg
but de
never wirl the roillkes the lilesty im hund thenísts
like thand you apl o fthecakring
by endith 