In [1]:
# Scikit-Learn ≥0.20 is required
import sklearn
assert sklearn.__version__ >= "0.20"

# TensorFlow ≥2.0 is required
import tensorflow as tf
from tensorflow import keras
assert tf.__version__ >= "2.0"

In [2]:
# Common imports
import numpy as np
import os

# to make this notebook's output stable across runs
np.random.seed(42)
tf.random.set_seed(42)


In [3]:
# To plot pretty figures
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rc('axes', labelsize=14)
mpl.rc('xtick', labelsize=12)
mpl.rc('ytick', labelsize=12)


In [4]:
shakespeare_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
filepath = keras.utils.get_file("shakespeare.txt", shakespeare_url)
with open(filepath) as f:
    shakespeare_text = f.read()

In [5]:
type(shakespeare_text)


str

In [6]:
print(shakespeare_text[:148])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?



In [7]:
"".join(sorted(set(shakespeare_text.lower())))

"\n !$&',-.3:;?abcdefghijklmnopqrstuvwxyz"

In [8]:
tokenizer = keras.preprocessing.text.Tokenizer(char_level=True)
tokenizer.fit_on_texts(shakespeare_text)

In [9]:
tokenizer.word_index

{' ': 1,
 'e': 2,
 't': 3,
 'o': 4,
 'a': 5,
 'i': 6,
 'h': 7,
 's': 8,
 'r': 9,
 'n': 10,
 '\n': 11,
 'l': 12,
 'd': 13,
 'u': 14,
 'm': 15,
 'y': 16,
 'w': 17,
 ',': 18,
 'c': 19,
 'f': 20,
 'g': 21,
 'b': 22,
 'p': 23,
 ':': 24,
 'k': 25,
 'v': 26,
 '.': 27,
 "'": 28,
 ';': 29,
 '?': 30,
 '!': 31,
 '-': 32,
 'j': 33,
 'q': 34,
 'x': 35,
 'z': 36,
 '3': 37,
 '&': 38,
 '$': 39}

In [10]:
tokenizer.texts_to_sequences(["Eto"])

[[2, 3, 4]]

In [11]:
tokenizer.sequences_to_texts([[2, 3, 4]])

['e t o']

In [12]:
max_id = len(tokenizer.word_index) # number of distinct characters
dataset_size = tokenizer.document_count # total number of characters

print(f'number of distinct characters: {max_id}')
print(f'total number of characters: {dataset_size}')

number of distinct characters: 39
total number of characters: 1115394


In [13]:
[encoded] = np.array(tokenizer.texts_to_sequences([shakespeare_text])) - 1
train_size = dataset_size * 90 // 100
dataset = tf.data.Dataset.from_tensor_slices(encoded[:train_size])

Metal device set to: Apple M1

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB



2022-03-01 10:22:45.583637: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2022-03-01 10:22:45.583774: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [14]:
n_steps = 10
window_length = n_steps + 1 # target = input shifted 1 character ahead
dataset = dataset.window(window_length, shift=1, drop_remainder=True)

In [15]:
dataset = dataset.flat_map(lambda window: window.batch(window_length))

In [16]:
batch_size = 32
dataset = dataset.shuffle(10000).batch(batch_size)
dataset = dataset.map(lambda windows: (windows[:, :-1], windows[:, 1:]))

In [17]:
dataset = dataset.map(
    lambda X_batch, Y_batch: (tf.one_hot(X_batch, depth=max_id), Y_batch))

In [18]:
dataset = dataset.prefetch(1)

In [19]:
for X_batch, Y_batch in dataset.take(1):
    print(X_batch.shape, Y_batch.shape)

2022-03-01 10:22:56.563725: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


(32, 10, 39) (32, 10)


In [20]:
model = keras.models.Sequential([
    keras.layers.GRU(128, return_sequences=True, input_shape=[None, max_id],
                     #dropout=0.2, recurrent_dropout=0.2),
                     dropout=0.2),
    keras.layers.GRU(128, return_sequences=True,
                     #dropout=0.2, recurrent_dropout=0.2),
                     dropout=0.2),
    keras.layers.TimeDistributed(keras.layers.Dense(max_id,
                                                    activation="softmax"))
])
model.compile(loss="sparse_categorical_crossentropy", optimizer="adam")
history = model.fit(dataset, epochs=3)

Epoch 1/3


2022-03-01 10:23:02.101173: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
2022-03-01 10:23:02.834677: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
2022-03-01 10:23:03.402270: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
2022-03-01 10:23:05.259809: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.


      7/Unknown - 6s 18ms/step - loss: 3.6185

2022-03-01 10:23:06.352695: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.


Epoch 2/3
Epoch 3/3


In [21]:
def preprocess(texts):
    X = np.array(tokenizer.texts_to_sequences(texts)) - 1
    return tf.one_hot(X, max_id)

In [22]:
X_new = preprocess(["How are yo"])
#Y_pred = model.predict_classes(X_new)
Y_pred = np.argmax(model(X_new), axis=-1)
tokenizer.sequences_to_texts(Y_pred + 1)[0][-1] # 1st sentence, last char

' '

In [23]:
tf.random.set_seed(42)

tf.random.categorical([[np.log(0.5), np.log(0.4), np.log(0.1)]], num_samples=40).numpy()

array([[0, 1, 0, 2, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 2, 1, 0, 2, 1,
        0, 1, 2, 1, 1, 1, 2, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 2]])

In [24]:
def next_char(text, temperature=1):
    X_new = preprocess([text])
    y_proba = model(X_new)[0, -1:, :]
    rescaled_logits = tf.math.log(y_proba) / temperature
    char_id = tf.random.categorical(rescaled_logits, num_samples=1) + 1
    return tokenizer.sequences_to_texts(char_id.numpy())[0]

In [28]:
tf.random.set_seed(42)

next_char("How are yo", temperature=1)

't'

In [29]:
def complete_text(text, n_chars=50, temperature=1):
    for _ in range(n_chars):
        text += next_char(text, temperature)
    return text

In [30]:
tf.random.set_seed(42)

print(complete_text("t", temperature=0.2))

the   o       eoe       t       e        eee h     


In [31]:
print(complete_text("t", temperature=1))

thin ' et d
.enb.ecut,c 
nct rsiweeww,iee re lrn mi


In [32]:
print(complete_text("t", temperature=2))

th nh w! '
dehsgjarogao fdro; o bgml ohprfngsb rvap


In [33]:
print(complete_text("t", temperature=.1))

the   o                 e            e             


In [34]:
tf.random.set_seed(42)

In [47]:
dataset = tf.data.Dataset.from_tensor_slices(encoded[:train_size])
dataset = dataset.window(window_length, shift=n_steps, drop_remainder=True)
dataset = dataset.flat_map(lambda window: window.batch(window_length))
dataset = dataset.batch(1)
dataset = dataset.map(lambda windows: (windows[:, :-1], windows[:, 1:]))
dataset = dataset.map(
    lambda X_batch, Y_batch: (tf.one_hot(X_batch, depth=max_id), Y_batch))
dataset = dataset.prefetch(1)

In [48]:
model = keras.models.Sequential([
    keras.layers.GRU(128, return_sequences=True, stateful=True,
                     dropout=0.2, recurrent_dropout=0.2,
                     batch_input_shape=[batch_size, None, max_id]),
    keras.layers.GRU(128, return_sequences=True, stateful=True,
                     dropout=0.2, recurrent_dropout=0.2),
    keras.layers.TimeDistributed(keras.layers.Dense(max_id,
                                                    activation="softmax"))
])



In [49]:
class ResetStatesCallback(keras.callbacks.Callback):
    def on_epoch_begin(self, epoch, logs):
        self.model.reset_states()

In [50]:
model.compile(loss="sparse_categorical_crossentropy", optimizer="adam")
history = model.fit(dataset, epochs=10,
                    callbacks=[ResetStatesCallback()])

Epoch 1/10


2022-03-01 12:06:16.760999: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.


InvalidArgumentError: Graph execution error:

Detected at node 'sequential_4/gru_8/TensorArrayUnstack/TensorListFromTensor' defined at (most recent call last):
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/runpy.py", line 194, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/ipykernel_launcher.py", line 16, in <module>
      app.launch_new_instance()
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/traitlets/config/application.py", line 846, in launch_instance
      app.start()
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/ipykernel/kernelapp.py", line 677, in start
      self.io_loop.start()
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/tornado/platform/asyncio.py", line 199, in start
      self.asyncio_loop.run_forever()
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/asyncio/base_events.py", line 570, in run_forever
      self._run_once()
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/asyncio/base_events.py", line 1859, in _run_once
      handle._run()
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/asyncio/events.py", line 81, in _run
      self._context.run(self._callback, *self._args)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 471, in dispatch_queue
      await self.process_one()
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 460, in process_one
      await dispatch(*args)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 367, in dispatch_shell
      await result
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 662, in execute_request
      reply_content = await reply_content
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/ipykernel/ipkernel.py", line 360, in do_execute
      res = shell.run_cell(code, store_history=store_history, silent=silent)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/ipykernel/zmqshell.py", line 532, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2768, in run_cell
      result = self._run_cell(
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2814, in _run_cell
      return runner(coro)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3012, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3191, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3251, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/var/folders/d2/8_sjl_k56djc0380jpk4qxrc0000gn/T/ipykernel_13340/1118436172.py", line 2, in <module>
      history = model.fit(dataset, epochs=10,
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/engine/training.py", line 1384, in fit
      tmp_logs = self.train_function(iterator)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/engine/training.py", line 1021, in train_function
      return step_function(self, iterator)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/engine/training.py", line 1010, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/engine/training.py", line 1000, in run_step
      outputs = model.train_step(data)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/engine/training.py", line 859, in train_step
      y_pred = self(x, training=True)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/engine/sequential.py", line 374, in call
      return super(Sequential, self).call(inputs, training=training, mask=mask)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/engine/functional.py", line 451, in call
      return self._run_internal_graph(
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/engine/functional.py", line 589, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/layers/recurrent.py", line 679, in __call__
      return super(RNN, self).__call__(inputs, **kwargs)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/layers/recurrent_v2.py", line 431, in call
      last_output, outputs, states = backend.rnn(
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/backend.py", line 4586, in rnn
      input_ta = tuple(
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/backend.py", line 4587, in <genexpr>
      ta.unstack(input_) if not go_backwards else ta
Node: 'sequential_4/gru_8/TensorArrayUnstack/TensorListFromTensor'
Detected at node 'sequential_4/gru_8/TensorArrayUnstack/TensorListFromTensor' defined at (most recent call last):
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/runpy.py", line 194, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/ipykernel_launcher.py", line 16, in <module>
      app.launch_new_instance()
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/traitlets/config/application.py", line 846, in launch_instance
      app.start()
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/ipykernel/kernelapp.py", line 677, in start
      self.io_loop.start()
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/tornado/platform/asyncio.py", line 199, in start
      self.asyncio_loop.run_forever()
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/asyncio/base_events.py", line 570, in run_forever
      self._run_once()
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/asyncio/base_events.py", line 1859, in _run_once
      handle._run()
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/asyncio/events.py", line 81, in _run
      self._context.run(self._callback, *self._args)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 471, in dispatch_queue
      await self.process_one()
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 460, in process_one
      await dispatch(*args)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 367, in dispatch_shell
      await result
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 662, in execute_request
      reply_content = await reply_content
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/ipykernel/ipkernel.py", line 360, in do_execute
      res = shell.run_cell(code, store_history=store_history, silent=silent)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/ipykernel/zmqshell.py", line 532, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2768, in run_cell
      result = self._run_cell(
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2814, in _run_cell
      return runner(coro)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3012, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3191, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3251, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/var/folders/d2/8_sjl_k56djc0380jpk4qxrc0000gn/T/ipykernel_13340/1118436172.py", line 2, in <module>
      history = model.fit(dataset, epochs=10,
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/engine/training.py", line 1384, in fit
      tmp_logs = self.train_function(iterator)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/engine/training.py", line 1021, in train_function
      return step_function(self, iterator)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/engine/training.py", line 1010, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/engine/training.py", line 1000, in run_step
      outputs = model.train_step(data)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/engine/training.py", line 859, in train_step
      y_pred = self(x, training=True)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/engine/sequential.py", line 374, in call
      return super(Sequential, self).call(inputs, training=training, mask=mask)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/engine/functional.py", line 451, in call
      return self._run_internal_graph(
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/engine/functional.py", line 589, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/layers/recurrent.py", line 679, in __call__
      return super(RNN, self).__call__(inputs, **kwargs)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/layers/recurrent_v2.py", line 431, in call
      last_output, outputs, states = backend.rnn(
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/backend.py", line 4586, in rnn
      input_ta = tuple(
    File "/Users/westonshuken/tensorflow-test/env/lib/python3.8/site-packages/keras/backend.py", line 4587, in <genexpr>
      ta.unstack(input_) if not go_backwards else ta
Node: 'sequential_4/gru_8/TensorArrayUnstack/TensorListFromTensor'
2 root error(s) found.
  (0) INVALID_ARGUMENT:  Specified a list with shape [32,39] from a tensor with shape [1,39]
	 [[{{node sequential_4/gru_8/TensorArrayUnstack/TensorListFromTensor}}]]
	 [[sequential_4/gru_8/TensorArrayUnstack/TensorListFromTensor/_58]]
  (1) INVALID_ARGUMENT:  Specified a list with shape [32,39] from a tensor with shape [1,39]
	 [[{{node sequential_4/gru_8/TensorArrayUnstack/TensorListFromTensor}}]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_450934]

In [41]:
stateless_model = keras.models.Sequential([
    keras.layers.GRU(128, return_sequences=True, input_shape=[None, max_id]),
    keras.layers.GRU(128, return_sequences=True),
    keras.layers.TimeDistributed(keras.layers.Dense(max_id,
                                                    activation="softmax"))
])

In [42]:
stateless_model.build(tf.TensorShape([None, None, max_id]))

In [43]:
stateless_model.set_weights(model.get_weights())
model = stateless_model

In [44]:
tf.random.set_seed(42)

print(complete_text("t"))

tospgind l wass'dd sialo'd do raperrdeder.
why seal
