# tinyshakespeare-transformer

> Code for instantiating a pre-trained TinyShakespeare transformer model, plus a helper that maps the model's characters to filesystem-safe names.


In [None]:
# | default_exp trained_models.tinyshakespeare_transformer

In [None]:
# | hide
%load_ext autoreload
%autoreload 2

In [None]:
# | hide
from nbdev.showdoc import *

In [None]:
#| hide
from fastcore.test import *

In [None]:
# | export
from typing import Callable, Dict, Iterable, Tuple

In [None]:
# | export
import torch

In [None]:
# | export
from transformer_experiments.datasets.tinyshakespeare import (
    TinyShakespeareDataSet,
)
from transformer_experiments.models.transformer import TransformerLanguageModel
from transformer_experiments.tokenizers.char_tokenizer import (
    CharacterTokenizer,
)


In [None]:
from transformer_experiments.environments import get_environment

In [None]:
# | export


def create_model_and_tokenizer(
    saved_model_filename: str, dataset: TinyShakespeareDataSet, device: str
) -> Tuple[TransformerLanguageModel, CharacterTokenizer]:
    """Instantiates a pre-trained TinyShakespeare model together with its tokenizer.

    Args:
        saved_model_filename: path of a file holding the model's saved state dict.
        dataset: dataset whose ``text`` defines the tokenizer's character vocabulary.
        device: torch device string the model is created on and the weights are mapped to.

    Returns:
        ``(model, tokenizer)``: the model with its saved parameters loaded and switched
        to eval mode, and a ``CharacterTokenizer`` built from ``dataset.text``.
    """
    char_tokenizer = CharacterTokenizer(dataset.text)

    # The vocabulary size must match the one the saved weights were trained with.
    model = TransformerLanguageModel(vocab_size=char_tokenizer.vocab_size, device=device)
    model.to(device)

    # map_location lets a checkpoint saved on one device be loaded onto another.
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    state_dict = torch.load(saved_model_filename, map_location=torch.device(device))
    model.load_state_dict(state_dict)

    # Put the module in evaluation (inference) mode.
    model.eval()

    return model, char_tokenizer

In [None]:
#| exporti

# Filesystem-safe names for the non-letter characters of the TinyShakespeare
# text (whitespace, punctuation and the digit '3'). Letters are not listed
# here; FilenameForToken names them by case instead.
special_char_names = {
    # whitespace
    "\n": "newline",
    " ": "space",
    # punctuation and symbols
    "!": "exclamation",
    "$": "dollar",
    "&": "ampersand",
    "'": "single_quote",
    ",": "comma",
    "-": "dash",
    ":": "colon",
    ";": "semicolon",
    ".": "period",
    "?": "question",
    # digits
    "3": "three",
}

In [None]:
#| export

class FilenameForToken:
    """Callable that maps a character of the tokenizer's vocabulary to a safe filename.

    Letters become ``capital_<letter>`` / ``lower_<letter>``. The explicit case prefix
    keeps 'A' and 'a' distinct even on case-insensitive filesystems. Every other
    character must have an entry in the module-level ``special_char_names`` table.
    """

    def __init__(self, tokenizer: CharacterTokenizer):
        # Only ``tokenizer.chars`` is consulted, to decide which characters are accepted.
        self.tokenizer = tokenizer

    def __call__(self, token: str) -> str:
        """Given a character, returns a safe filename representing that character.

        Raises:
            ValueError: if ``token`` is not a single character found in
                ``self.tokenizer.chars``, or if it is such a character but no
                filename is defined for it.
        """
        # The length check rejects multi-character strings, which could otherwise
        # pass the membership test if ``chars`` happens to be a str.
        if (
            not isinstance(token, str)
            or len(token) != 1
            or token not in self.tokenizer.chars
        ):
            raise ValueError(f'unknown character {token}')

        if token in special_char_names:
            return special_char_names[token]

        # Classify by the character itself rather than by its token id. This does not
        # rely on the vocabulary being sorted, and it does not need 'A'/'Z'/'a'/'z' to
        # be present in the vocabulary (looking them up in ``stoi`` would raise KeyError).
        if 'A' <= token <= 'Z':
            return f'capital_{token.lower()}'
        if 'a' <= token <= 'z':
            return f'lower_{token}'

        # Ensure that there is not some character in chars we didn't specifically handle.
        raise ValueError(f'unknown character {token}')

In [None]:
# Detect which environment this notebook is running in; its `code_root`
# is used below to locate the dataset file.
environment = get_environment()
print("environment is {}".format(environment.name))

environment is local_mac


In [None]:
# Load the TinyShakespeare dataset, passing <code_root>/nbs/artifacts/input.txt as
# its cache_file, and build the character tokenizer used by the tests below.
ts = TinyShakespeareDataSet(cache_file=environment.code_root / 'nbs/artifacts/input.txt')
tokenizer = CharacterTokenizer(ts.text)

In [None]:
# Tests for FilenameForToken
filename_for_token = FilenameForToken(tokenizer)

# Letters, including both ends of each alphabet range.
test_eq(filename_for_token('A'), 'capital_a')
test_eq(filename_for_token('Z'), 'capital_z')
test_eq(filename_for_token('a'), 'lower_a')
test_eq(filename_for_token('z'), 'lower_z')

# Characters resolved through the special-name table.
test_eq(filename_for_token(' '), 'space')
test_eq(filename_for_token('\n'), 'newline')
test_eq(filename_for_token('!'), 'exclamation')
test_eq(filename_for_token('\''), 'single_quote')
test_eq(filename_for_token('3'), 'three')

# Every character in chars must map to a name without hitting the ValueError at the
# end of the function. The names must also be distinct, compared case-insensitively,
# because some filesystems (e.g. macOS's default) ignore case and two characters
# would otherwise overwrite each other's file.
all_filenames = [filename_for_token(token) for token in tokenizer.chars]
test_eq(len({name.lower() for name in all_filenames}), len(all_filenames))

# Test that we do get an exception for an unknown character.
with ExceptionExpected(ex=ValueError):
    filename_for_token('🤔')

In [None]:
# | hide
import nbdev

# Regenerate the library module from this notebook's export-tagged cells
# (target module set by the `default_exp` directive at the top).
nbdev.nbdev_export()