<a href="https://colab.research.google.com/github/prokaj/elte-python/blob/main/tiny_shakespeare.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# List the installed versions of the packages this notebook depends on.
! pip freeze | grep -e 'optax' -e 'haiku' -e 'jax' -e 'tensorflow'


dm-haiku==0.0.9
jax==0.3.25
jaxlib @ https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.25+cuda11.cudnn805-cp38-cp38-manylinux2014_x86_64.whl
optax==0.1.4
tensorflow==2.9.2
tensorflow-datasets==4.6.0
tensorflow-estimator==2.9.0
tensorflow-gcs-config==2.9.1
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.28.0
tensorflow-metadata==1.11.0
tensorflow-probability==0.17.0


In [3]:
! pip install dm-haiku
! pip install optax

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tiny Shakespeare as a language modelling dataset."""

from typing import Iterator, Mapping

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

Batch = Mapping[str, np.ndarray]
NUM_CHARS = 128


def dataset_load(
    split: tfds.Split,
    *,
    batch_size: int,
    sequence_length: int,
) -> Iterator[Batch]:
  """Builds an endless iterator of character-level Tiny Shakespeare batches.

  Args:
    split: Split of the `tiny_shakespeare` TFDS dataset to read.
    batch_size: Number of sequences in each batch.
    sequence_length: Number of characters in each sequence.

  Returns:
    An iterator over dicts with keys 'input' and 'target'. Both values are
    int32 arrays laid out time-major, `[sequence_length, batch_size]`, and
    'target' holds the character that follows the matching 'input' character.
  """

  def to_input_target(
      example: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
    # Split the text into characters and read each one back as a byte value.
    chars = tf.strings.unicode_split(example['text'], 'UTF-8')
    codes = tf.squeeze(tf.io.decode_raw(chars, tf.uint8), axis=-1)
    codes = tf.cast(codes, tf.int32)
    # Next-character prediction: targets are the inputs shifted by one.
    return {'input': codes[:-1], 'target': codes[1:]}

  def to_time_major(batch):
    return tf.nest.map_structure(tf.transpose, batch)

  pipeline = (
      tfds.load(name='tiny_shakespeare', split=split)
      .map(to_input_target)
      .unbatch()  # One element per character.
      .batch(sequence_length, drop_remainder=True)  # Fixed-length sequences.
      .shuffle(100)
      .repeat()
      .batch(batch_size)
      .map(to_time_major))

  return iter(tfds.as_numpy(pipeline))


def decode(x: np.ndarray) -> str:
  """Turns a 1-D sequence of character codes into the string it spells."""
  return ''.join(map(chr, x))


def encode(x: str) -> np.ndarray:
  """Converts a string into an integer array of its character codes.

  Args:
    x: Text to encode; may be empty.

  Returns:
    A 1-D integer array holding `ord(c)` for every character `c` of `x`, in
    order. The array has NumPy's default integer dtype for any input,
    including the empty string.
  """
  # Without an explicit dtype an empty list becomes a float64 array, which
  # would make encode('') the only input that yields a non-integer result.
  # `np.int_` is the dtype NumPy infers for Python ints, so non-empty results
  # are unchanged.
  return np.array([ord(s) for s in x], dtype=np.int_)

In [3]:
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Character-level language modelling with a recurrent network in JAX."""

from typing import Any, NamedTuple

from absl import logging
from absl import flags

import haiku as hk

import jax
from jax import lax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow_datasets as tfds

# Hyperparameters, registered as absl flags. Each name below is a flag holder;
# the parsed value is read with `.value` once the flags have been parsed.

# Batch size of the training stream built in main().
TRAIN_BATCH_SIZE = flags.DEFINE_integer('train_batch_size', 32, '')
# Batch size of each evaluation stream built in main().
EVAL_BATCH_SIZE = flags.DEFINE_integer('eval_batch_size', 1000, '')
# Number of characters per training/evaluation sequence.
SEQUENCE_LENGTH = flags.DEFINE_integer('sequence_length', 128, '')
# Width of both LSTM layers and of the MLP hidden layer in make_network().
HIDDEN_SIZE = flags.DEFINE_integer('hidden_size', 256, '')
# Number of characters generated each time main() draws a sample.
SAMPLE_LENGTH = flags.DEFINE_integer('sample_length', 128, '')
# Learning rate handed to the optimizer in make_optimizer().
LEARNING_RATE = flags.DEFINE_float('learning_rate', 1e-3, '')
# main() runs steps 0..training_steps inclusive.
TRAINING_STEPS = flags.DEFINE_integer('training_steps', 50_000, '')
# Losses are logged whenever step % evaluation_interval == 0.
EVALUATION_INTERVAL = flags.DEFINE_integer('evaluation_interval', 1000, '')
# A sample is logged whenever step % sampling_interval == 0.
SAMPLING_INTERVAL = flags.DEFINE_integer('sampling_interval', 1000, '')
# Seed of the hk.PRNGSequence used for parameter init and sampling.
SEED = flags.DEFINE_integer('seed', 42, '')


In [4]:
from absl import app


In [5]:
# Parse a hand-written argv ('main' stands in for the program name) so the flag
# holders can be read via `.value`; '-v 1' sets absl's verbosity flag to 1.
app.parse_flags_with_usage(['main', '-v', '1'])
# Sanity check: displays the parsed value of the `train_batch_size` flag.
TRAIN_BATCH_SIZE.value

32

In [6]:
# Show every registered absl flag together with its current value.
flags.FLAGS.flag_values_dict()

{'logtostderr': False,
 'alsologtostderr': False,
 'log_dir': '',
 'v': 1,
 'verbosity': 1,
 'logger_levels': {},
 'stderrthreshold': 'fatal',
 'showprefixforinfo': True,
 'run_with_pdb': False,
 'pdb_post_mortem': False,
 'pdb': False,
 'run_with_profiling': False,
 'profile_file': None,
 'use_cprofile_for_profiling': True,
 'only_check_args': False,
 'op_conversion_fallback_to_while_loop': True,
 'runtime_oom_exit': True,
 'hbm_oom_exit': True,
 'test_srcdir': '',
 'test_tmpdir': '/tmp/absl_testing',
 'test_random_seed': 301,
 'test_randomize_ordering_seed': '',
 'xml_output_file': '',
 'tfds_debug_list_dir': False,
 'wikipedia_auto_select_flume_mode': True,
 'chex_n_cpu_devices': 1,
 'chex_assert_multiple_cpu_devices': False,
 'chex_skip_pmap_variant_if_single_device': True,
 'train_batch_size': 32,
 'eval_batch_size': 1000,
 'sequence_length': 128,
 'hidden_size': 256,
 'sample_length': 128,
 'learning_rate': 0.001,
 'training_steps': 50000,
 'evaluation_interval': 1000,
 'sampling

In [10]:


class LoopValues(NamedTuple):
  """Values carried from one iteration to the next of the loop in `sample`."""
  # Buffer of sampled tokens; iteration t reads entry t and writes entry t + 1.
  tokens: jnp.ndarray
  # Recurrent state returned by the network core on the previous step.
  state: Any
  # PRNG key; split on every iteration so each draw uses a fresh subkey.
  rng_key: jnp.ndarray


class TrainingState(NamedTuple):
  """Everything `update` needs to read and produce for one training step."""
  # Current network parameters.
  params: hk.Params
  # Optimizer state matching `params`.
  opt_state: optax.OptState


def make_network() -> hk.RNNCore:
  """Builds the model: one-hot -> LSTM -> ReLU -> LSTM -> MLP over characters.

  Returns:
    A recurrent core mapping integer tokens to `NUM_CHARS` logits per step.
    Layer widths come from the `hidden_size` flag.
  """
  width = HIDDEN_SIZE.value

  def embed(tokens):
    return jax.nn.one_hot(tokens, num_classes=NUM_CHARS)

  # Modules are created in the same order as they are stacked.
  lower_lstm = hk.LSTM(width)
  upper_lstm = hk.LSTM(width)
  readout = hk.nets.MLP([width, NUM_CHARS])

  return hk.DeepRNN([embed, lower_lstm, jax.nn.relu, upper_lstm, readout])


def make_optimizer() -> optax.GradientTransformation:
  """Returns the Adam optimizer configured with the `learning_rate` flag."""
  step_size = LEARNING_RATE.value
  adam = optax.adam(step_size)
  return adam


def sequence_loss(batch: Batch) -> jnp.ndarray:
  """Mean next-character cross-entropy of the network over a batch.

  Args:
    batch: Dict with 'input' and 'target' token arrays; 'input' must be 2-D
      and is unpacked as `(sequence_length, batch_size)`.

  Returns:
    Scalar: summed negative log-likelihood of the targets divided by
    `sequence_length * batch_size`.
  """
  # Note: this builds Haiku modules, so it must be wrapped in hk.transform().
  network = make_network()
  inputs = batch['input']
  num_steps, num_sequences = inputs.shape
  start_state = network.initial_state(num_sequences)
  logits, _ = hk.dynamic_unroll(network, inputs, start_state)
  log_probs = jax.nn.log_softmax(logits)
  vocab_size = logits.shape[-1]
  target_mask = jax.nn.one_hot(batch['target'], num_classes=vocab_size)
  # The mask zeroes out every log-probability except the target's.
  summed_log_likelihood = jnp.sum(target_mask * log_probs)
  return -summed_log_likelihood / (num_steps * num_sequences)


@jax.jit
def update(state: TrainingState, batch: Batch) -> TrainingState:
  """Applies one optimizer step computed from `batch`.

  Args:
    state: Current parameters and optimizer state.
    batch: Inputs and targets passed to `sequence_loss`.

  Returns:
    A new TrainingState holding the updated parameters and optimizer state.
  """
  optimizer = make_optimizer()
  loss_apply = hk.without_apply_rng(hk.transform(sequence_loss)).apply
  grads = jax.grad(loss_apply)(state.params, batch)
  param_updates, next_opt_state = optimizer.update(grads, state.opt_state)
  return TrainingState(
      params=optax.apply_updates(state.params, param_updates),
      opt_state=next_opt_state)


def sample(
    rng_key: jnp.ndarray,
    context: jnp.ndarray,
    sample_length: int,
) -> jnp.ndarray:
  """Draws samples from the model, given an initial context.

  The network is first unrolled over `context` to warm up its state, then
  tokens are drawn one at a time, each fed back in as the next input.

  Args:
    rng_key: PRNG key used for all categorical draws.
    context: 1-D array of prompt tokens (batching is not supported).
    sample_length: Number of tokens to generate; must be a static Python int.

  Returns:
    An int32 array of shape `[sample_length]` with the sampled tokens.
  """
  # Note: this function is impure; we hk.transform() it below.
  assert context.ndim == 1  # No batching for now.
  core = make_network()

  def body_fn(t: int, v: LoopValues) -> LoopValues:
    # Feed token t to the network and store the sampled successor at t + 1.
    token = v.tokens[t]
    next_logits, next_state = core(token, v.state)
    key, subkey = jax.random.split(v.rng_key)
    next_token = jax.random.categorical(subkey, next_logits, axis=-1)
    new_tokens = v.tokens.at[t + 1].set(next_token)
    return LoopValues(tokens=new_tokens, state=next_state, rng_key=key)

  logits, state = hk.dynamic_unroll(core, context, core.initial_state(None))
  key, subkey = jax.random.split(rng_key)
  # The first sampled token comes from the logits after the last context step.
  first_token = jax.random.categorical(subkey, logits[-1])
  tokens = jnp.zeros(sample_length, dtype=np.int32)
  tokens = tokens.at[0].set(first_token)
  initial_values = LoopValues(tokens=tokens, state=state, rng_key=key)

  # tokens[0] is already filled, so only sample_length - 1 further tokens are
  # needed. Running sample_length iterations would make the final iteration
  # write to index sample_length, which is past the end of the buffer. The
  # max() keeps the trip count non-negative when sample_length is 0.
  num_steps = max(sample_length - 1, 0)
  values: LoopValues = lax.fori_loop(0, num_steps, body_fn, initial_values)

  return values.tokens


def main():
  """Trains the character model, periodically logging samples and losses."""
  flags.FLAGS.alsologtostderr = True
  flags.FLAGS.verbosity = 1

  # Training stream.
  train_data = dataset_load(
      tfds.Split.TRAIN,
      batch_size=TRAIN_BATCH_SIZE.value,
      sequence_length=SEQUENCE_LENGTH.value)

  # One evaluation stream per split, built in train-then-test order.
  eval_data = {}
  for split in [tfds.Split.TRAIN, tfds.Split.TEST]:
    eval_data[split] = dataset_load(
        split,
        batch_size=EVAL_BATCH_SIZE.value,
        sequence_length=SEQUENCE_LENGTH.value)

  # Pure (init, apply) pairs for the loss and the sampler, plus the optimizer.
  loss_transformed = hk.without_apply_rng(hk.transform(sequence_loss))
  sample_transformed = hk.without_apply_rng(hk.transform(sample))
  optimizer = make_optimizer()

  loss_fn = jax.jit(loss_transformed.apply)
  # Argument 3 of the applied sampler is `sample_length`, which fixes shapes.
  sample_fn = jax.jit(sample_transformed.apply, static_argnums=[3])

  # Initial parameters and optimizer state.
  rng = hk.PRNGSequence(SEED.value)
  init_key = next(rng)
  first_batch = next(train_data)
  initial_params = loss_transformed.init(init_key, first_batch)
  state = TrainingState(
      params=initial_params,
      opt_state=optimizer.init(initial_params))

  for step in range(TRAINING_STEPS.value + 1):
    # One optimizer step on a fresh training batch.
    train_batch = next(train_data)
    state = update(state, train_batch)

    # Every SAMPLING_INTERVAL steps, continue the batch's first sequence.
    if step % SAMPLING_INTERVAL.value == 0:
      context = train_batch['input'][:, 0]
      assert context.ndim == 1
      sample_key = next(rng)
      samples = sample_fn(
          state.params, sample_key, context, SAMPLE_LENGTH.value)

      prompt_text = decode(context)
      continuation_text = decode(samples)

      logging.info('\n===Prompt:\n%s\n===\n', prompt_text)
      logging.info('\n---Continuation:\n%s\n---\n', continuation_text)

    # Every EVALUATION_INTERVAL steps, log the loss on one batch per split.
    if step % EVALUATION_INTERVAL.value == 0:
      for split, batches in eval_data.items():
        split_loss = loss_fn(state.params, next(batches))
        logging.info({
            'step': step,
            'loss': float(split_loss),
            'split': split,
        })


In [8]:
# Display the parsed value of the `train_batch_size` flag.
TRAIN_BATCH_SIZE.value

32

In [22]:
# Build the training iterator at notebook scope; the first call downloads and
# prepares the dataset. main() builds its own iterators and does not use this.
train_data = dataset_load(
      tfds.Split.TRAIN,
      batch_size=TRAIN_BATCH_SIZE.value,
      sequence_length=SEQUENCE_LENGTH.value)


INFO:absl:Load pre-computed DatasetInfo (eg: splits, num examples,...) from GCS: tiny_shakespeare/1.0.0
INFO:absl:Load dataset info from /tmp/tmpgo22fok1tfds
INFO:absl:Field info.description from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.splits from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.module_name from disk and from code do not match. Keeping the one from code.
INFO:absl:Generating dataset tiny_shakespeare (~/tensorflow_datasets/tiny_shakespeare/1.0.0)


Downloading and preparing dataset Unknown size (download: Unknown size, generated: 1.06 MiB, total: 1.06 MiB) to ~/tensorflow_datasets/tiny_shakespeare/1.0.0...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

INFO:absl:Downloading https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt into /root/tensorflow_datasets/downloads/raw.gith.com_karp_char-rnn_mast_tiny_inpugogO998CpE557gFI85J16FbtM1Ig3B0ySj9UhS6f7GM.txt.tmp.52fd7acb5b964ffdb17ac386990da976...


Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]

Generating train examples...:   0%|          | 0/1 [00:00<?, ? examples/s]

Shuffling ~/tensorflow_datasets/tiny_shakespeare/1.0.0.incomplete7XK138/tiny_shakespeare-train.tfrecord*...:  …

INFO:absl:Done writing ~/tensorflow_datasets/tiny_shakespeare/1.0.0.incomplete7XK138/tiny_shakespeare-train.tfrecord*. Number of examples: 1 (shards: [1])


Generating validation examples...:   0%|          | 0/1 [00:00<?, ? examples/s]

Shuffling ~/tensorflow_datasets/tiny_shakespeare/1.0.0.incomplete7XK138/tiny_shakespeare-validation.tfrecord*.…

INFO:absl:Done writing ~/tensorflow_datasets/tiny_shakespeare/1.0.0.incomplete7XK138/tiny_shakespeare-validation.tfrecord*. Number of examples: 1 (shards: [1])


Generating test examples...:   0%|          | 0/1 [00:00<?, ? examples/s]

Shuffling ~/tensorflow_datasets/tiny_shakespeare/1.0.0.incomplete7XK138/tiny_shakespeare-test.tfrecord*...:   …

INFO:absl:Done writing ~/tensorflow_datasets/tiny_shakespeare/1.0.0.incomplete7XK138/tiny_shakespeare-test.tfrecord*. Number of examples: 1 (shards: [1])
INFO:absl:Constructing tf.data.Dataset tiny_shakespeare for split train, from ~/tensorflow_datasets/tiny_shakespeare/1.0.0


Dataset tiny_shakespeare downloaded and prepared to ~/tensorflow_datasets/tiny_shakespeare/1.0.0. Subsequent calls will reuse this data.


In [None]:
# Run the full training loop; samples and losses are written to the log.
main()

INFO:absl:Load dataset info from ~/tensorflow_datasets/tiny_shakespeare/1.0.0
INFO:absl:Reusing dataset tiny_shakespeare (~/tensorflow_datasets/tiny_shakespeare/1.0.0)
INFO:absl:Constructing tf.data.Dataset tiny_shakespeare for split train, from ~/tensorflow_datasets/tiny_shakespeare/1.0.0
INFO:absl:Load dataset info from ~/tensorflow_datasets/tiny_shakespeare/1.0.0
INFO:absl:Reusing dataset tiny_shakespeare (~/tensorflow_datasets/tiny_shakespeare/1.0.0)
INFO:absl:Constructing tf.data.Dataset tiny_shakespeare for split train, from ~/tensorflow_datasets/tiny_shakespeare/1.0.0
INFO:absl:Load dataset info from ~/tensorflow_datasets/tiny_shakespeare/1.0.0
INFO:absl:Reusing dataset tiny_shakespeare (~/tensorflow_datasets/tiny_shakespeare/1.0.0)
INFO:absl:Constructing tf.data.Dataset tiny_shakespeare for split test, from ~/tensorflow_datasets/tiny_shakespeare/1.0.0
INFO:absl:
===Prompt:
Shalt see me once more strike at Tullus' face.
What, art thou stiff? stand'st out?

TITUS:
No, Caius Marci