In [1]:
# # On Google colab:

# # Data
# !mkdir data
# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
# !mv input.txt data/shakespeare.txt

# # Module
# !pip install git+https://github.com/shuiruge/nanogpt.git

In [2]:
# Locally:
import sys
sys.path.append('../nanogpt')

In [3]:
import numpy as np
import tensorflow as tf
from keras.layers import Layer, Dense, LayerNormalization, Dropout
from keras.models import Model
from keras.optimizers import AdamW
from keras.losses import SparseCategoricalCrossentropy
from keras.callbacks import EarlyStopping
from dataclasses import dataclass
from typing import List

# from nanogpt.utils import (
from utils import (
    CharacterTokenizer, TokPosEmbedding, FeedForward, MultiHeadSelfAttention,
    ResNetWrapper, MaskedLanguageModelDataGenerator,
)

tf.random.set_seed(42)

2024-03-18 23:13:55.249361: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [4]:
# Load the whole training corpus as one string. The encoding is given
# explicitly so the decoded text does not depend on the platform default.
with open('data/shakespeare.txt', 'r', encoding='utf-8') as f:
    corpus = f.read()
# Preview the beginning of the text only; the full corpus is large.
print(corpus[:200])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you


In [5]:
# Build the tokenizer from the corpus and register '<MASK>' as an extra
# placeholder token; its id is what replaces the tokens to be predicted.
tokenizer = CharacterTokenizer(corpus, placeholders=['<MASK>'])
mask_token_id = tokenizer.get_id('<MASK>')
# Display the vocabulary size together with the id assigned to '<MASK>'.
tokenizer.vocab_size, mask_token_id

(66, 65)

In [6]:
seq_len = 64
num_mask = 1


def get_mask_positions(context):
    """Returns the indices of the trailing `num_mask` tokens, last index first."""
    last_index = len(context) - 1
    return [last_index - offset for offset in range(num_mask)]


data = MaskedLanguageModelDataGenerator(
    tokenizer.encode(corpus), mask_token_id, get_mask_positions)

# Drain the generator into three parallel lists; it signals exhaustion by
# raising StopIteration.
# NOTE(review): the meaning of the second positional argument (False) is
# defined by MaskedLanguageModelDataGenerator -- confirm against its source.
contexts, targets, masks = [], [], []
while True:
    try:
        context, target, mask = data(seq_len, False)
    except StopIteration:
        break
    contexts.append(context)
    targets.append(target)
    masks.append(mask)

# Convert the collected samples into int64 arrays for training.
contexts = np.stack(contexts).astype('int64')
targets = np.asarray(targets, 'int64')
masks = np.asarray(masks, 'int64')

In [7]:
# Inspect the first training sample: the context token ids, the target
# token id(s), and the masked position(s) within the context.
contexts[0], targets[0], masks[0]

(array([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43,
        44, 53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39,
        52, 63,  1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1,
        51, 43,  1, 57, 54, 43, 39, 49,  8,  0,  0, 13, 65]),
 array([50]),
 array([63]))

In [8]:
@dataclass
class BERTConfig:
    """Hyperparameters consumed by `BERT`."""
    # Passed to `TokPosEmbedding`; also the width of the output `Dense` layer.
    vocab_size: int
    # Passed to `TokPosEmbedding` as the sequence length.
    seq_len: int
    # Passed to `TokPosEmbedding` as the embedding dimension.
    # NOTE(review): the blocks are wrapped in `ResNetWrapper`, which presumably
    # needs embed_dim == model_dim -- confirm against its implementation.
    embed_dim: int
    # Passed to `MultiHeadSelfAttention` and `FeedForward`.
    model_dim: int
    # Passed to `MultiHeadSelfAttention`.
    num_heads: int
    # Passed to `FeedForward` as its hidden-unit specification.
    ffd_hidden_units: List[int]
    # Number of (self-attention, feed-forward) layer pairs to stack.
    num_trans_blocks: int
    # Token id that `BERT.generate` appends in place of the tokens to predict.
    mask_token_id: int


class BERT(Model):
    """Masked language model that predicts the tokens at given mask positions.

    Token ids go through `TokPosEmbedding`, then through
    `cfg.num_trans_blocks` pairs of `ResNetWrapper`-wrapped
    `MultiHeadSelfAttention` and `FeedForward` layers. The hidden states at
    the requested mask positions are projected to vocabulary logits.
    """

    def __init__(self, cfg: BERTConfig, **kwargs):
        """Creates the layers described by `cfg`.

        Args:
            cfg: Model hyperparameters.
            **kwargs: Forwarded unchanged to `Model.__init__`.
        """
        super().__init__(**kwargs)
        self.cfg = cfg

        self.embedding_layer = TokPosEmbedding(
            cfg.vocab_size, cfg.seq_len, cfg.embed_dim)

        # Flat list that alternates attention and feed-forward sub-layers.
        self.trans_blocks = []
        for _ in range(cfg.num_trans_blocks):
            self.trans_blocks.append(
                ResNetWrapper(MultiHeadSelfAttention(cfg.model_dim, cfg.num_heads))
            )
            self.trans_blocks.append(
                ResNetWrapper(FeedForward(cfg.ffd_hidden_units, cfg.model_dim))
            )

        self.output_layer = Dense(cfg.vocab_size)

    def call(self, inputs):
        """Computes logits for the masked positions only.

        Args:
            inputs: Pair `(token_ids, mask_positions)`. `token_ids` has shape
                [batch_size, seq_len]; `mask_positions` holds, for every batch
                row, the sequence indices whose tokens shall be predicted.

        Returns:
            Logits of shape [batch_size, num_mask_positions, vocab_size].
        """
        x, mask_positions = inputs
        x = self.embedding_layer(x)
        for layer in self.trans_blocks:
            x = layer(x)
        # x has shape [batch_size, seq_len, model_dim].
        # Pick, per batch row, only the hidden states at the mask positions.
        # See: https://www.tensorflow.org/api_docs/python/tf/gather#batching
        x = tf.gather(x, indices=mask_positions, axis=1, batch_dims=1)
        x = self.output_layer(x)
        return x

    def generate(self, init_token_ids, num_new_tokens, max_iter, T):
        """Extends `init_token_ids` by repeatedly sampling masked tokens.

        Every iteration feeds the model the most recent
        `len(init_token_ids) - num_new_tokens` tokens followed by
        `num_new_tokens` mask tokens, so the model input always has length
        `len(init_token_ids)`. The tokens sampled for the masked slots are
        appended to the running sequence.

        Args:
            init_token_ids: Sequence of token ids used as the prompt.
            num_new_tokens: Number of tokens sampled per iteration. Must be at
                least 1 and smaller than `len(init_token_ids)`.
            max_iter: Number of sampling iterations.
            T: Sampling temperature; logits are divided by it. Must be > 0.

        Returns:
            A list holding the prompt followed by
            `max_iter * num_new_tokens` sampled token ids.

        Raises:
            ValueError: If `num_new_tokens` or `T` is out of range.
        """
        token_ids = list(init_token_ids)
        num_ref = len(token_ids) - num_new_tokens
        if num_new_tokens < 1:
            raise ValueError(
                f'num_new_tokens must be at least 1, got {num_new_tokens}.')
        if num_ref < 1:
            # With num_ref == 0 the slice `token_ids[-0:]` would return the
            # whole sequence, so the mask positions would point at real tokens.
            raise ValueError(
                'num_new_tokens must be smaller than len(init_token_ids); got '
                f'{num_new_tokens} for {len(token_ids)} initial tokens.')
        if T <= 0:
            raise ValueError(f'T must be positive, got {T}.')

        # Both are identical in every iteration, so build them once.
        mask_suffix = [self.cfg.mask_token_id] * num_new_tokens
        mask_positions = [[num_ref + i for i in range(num_new_tokens)]]

        for _ in range(max_iter):
            # Add batch_size for matching the input shape of `self.call`.
            model_input = tf.expand_dims(
                token_ids[-num_ref:] + mask_suffix, axis=0)
            # [1, num_new_tokens, vocab_size]
            logits = self((model_input, mask_positions))
            # [num_new_tokens, vocab_size]
            logits = logits[0, :, :]
            # [num_new_tokens, 1]: one sample per masked slot.
            next_token_ids = tf.random.categorical(logits / T, 1)
            token_ids += list(next_token_ids[:, 0].numpy())
        return token_ids

In [9]:
cfg = BERTConfig(
    vocab_size=tokenizer.vocab_size,
    seq_len=seq_len,
    embed_dim=64,
    model_dim=64,
    num_heads=4,
    ffd_hidden_units=[4 * 64],
    num_trans_blocks=2,
    mask_token_id=mask_token_id,
)

model = BERT(cfg)
# Push a small batch through the model once so that its weights exist before
# `summary()` is called.
model((contexts[:10], masks[:10]))
model.summary()

Model: "bert"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 tok_pos_embedding (TokPosE  multiple                  8320      
 mbedding)                                                       
                                                                 
 resnet_wrapper_of_multi_he  multiple                  8320      
 ad_self_attention (ResNetW                                      
 rapper)                                                         
                                                                 
 resnet_wrapper_of_feed_for  multiple                  33216     
 ward (ResNetWrapper)                                            
                                                                 
 resnet_wrapper_of_multi_he  multiple                  8320      
 ad_self_attention_1 (ResNe                                      
 tWrapper)                                                    

In [13]:
model.compile(
    optimizer=AdamW(),
    # The model outputs raw logits, hence `from_logits=True`.
    loss=SparseCategoricalCrossentropy(from_logits=True),
)
model.fit(
    x=(contexts, masks),
    y=targets,
    batch_size=64,
    validation_split=0.1,
    # `epochs` is only an upper bound: it should be as large as possible,
    # because early stopping decides how many epochs are actually run.
    epochs=100,
    # Stops as soon as the validation loss fails to improve for one epoch.
    callbacks=[EarlyStopping(monitor='val_loss', patience=0)]
)

Epoch 1/100


2024-03-18 23:26:28.836144: W tensorflow/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 513944064 exceeds 10% of free system memory.





KeyboardInterrupt



In [14]:
def generate(self, init_token_ids, num_new_tokens, max_iter, T):
    """Extends `init_token_ids` by repeatedly sampling masked tokens.

    Standalone counterpart of `BERT.generate`; `self` is the model. Every
    iteration feeds the model the most recent
    `len(init_token_ids) - num_new_tokens` tokens followed by
    `num_new_tokens` mask tokens, so the model input always has length
    `len(init_token_ids)`. The tokens sampled for the masked slots are
    appended to the running sequence.

    Args:
        self: Callable model taking `(token_ids, mask_positions)` and
            exposing `cfg.mask_token_id`.
        init_token_ids: Sequence of token ids used as the prompt.
        num_new_tokens: Number of tokens sampled per iteration. Must be at
            least 1 and smaller than `len(init_token_ids)`.
        max_iter: Number of sampling iterations.
        T: Sampling temperature; logits are divided by it. Must be > 0.

    Returns:
        A list holding the prompt followed by `max_iter * num_new_tokens`
        sampled token ids.

    Raises:
        ValueError: If `num_new_tokens` or `T` is out of range.
    """
    token_ids = list(init_token_ids)
    num_ref = len(token_ids) - num_new_tokens
    if num_new_tokens < 1:
        raise ValueError(
            f'num_new_tokens must be at least 1, got {num_new_tokens}.')
    if num_ref < 1:
        # With num_ref == 0 the slice `token_ids[-0:]` would return the whole
        # sequence, so the mask positions would point at real tokens.
        raise ValueError(
            'num_new_tokens must be smaller than len(init_token_ids); got '
            f'{num_new_tokens} for {len(token_ids)} initial tokens.')
    if T <= 0:
        raise ValueError(f'T must be positive, got {T}.')

    # Both are identical in every iteration, so build them once.
    mask_suffix = [self.cfg.mask_token_id] * num_new_tokens
    mask_positions = [[num_ref + i for i in range(num_new_tokens)]]

    for _ in range(max_iter):
        # Add batch_size for matching the input shape of `self.call`.
        model_input = tf.expand_dims(token_ids[-num_ref:] + mask_suffix, axis=0)
        # [1, num_new_tokens, vocab_size]
        logits = self((model_input, mask_positions))
        # [num_new_tokens, vocab_size]
        logits = logits[0, :, :]
        # [num_new_tokens, 1]: one sample per masked slot.
        next_token_ids = tf.random.categorical(logits / T, 1)
        token_ids += list(next_token_ids[:, 0].numpy())
    return token_ids

In [15]:
# Prompt with the first training context, without its trailing mask
# token(s), and sample 5 tokens per step for 100 steps at temperature 1.
generated = generate(
    model,
    contexts[0, :-num_mask],
    num_new_tokens=5,
    max_iter=100,
    T=1,
)
print(tokenizer.decode(generated))

First Citizen:
Before we proceed any further, hear me speak.

AhuSto
dst,cl r ;fst  pissrnni u''Tt r't tu,i dr;ttictta tene.t.i?Sent,ti t'ys et,r:t;ad' ,'s tnts u?'ssi 'b,i  nhi ssse ei t!at
ii'nsy.ur' ; rei?'s is ti, tr y.tuttcos e er
u, niit',.e ,ff',tr;y aset,teuttet,titr ysasst tu t  emt.at'm'ttsi,,e i,ti'i.ti
n,an''?'t;t an nu' 'mtotit,rneiiati,i; ,ei,''a en.u ,tntsTssy pn utos t,sryetis!t,ts
e,tmou 'r rtrs te n itsd , y 'I;e  th,iem!ss'ttt si .-ti,ss.i tn.ttaulet's'i ', tt'nei ta,i,us't ',,ie n a,irsttonse!t!
!aeeny!ts., 's,otat,t,ai , dn e!tsst Ui ts


In [17]:
# Decode the prompt (first context minus its trailing mask token(s)) to
# compare it with the generated text.
print(tokenizer.decode(contexts[0, :-num_mask]))

First Citizen:
Before we proceed any further, hear me speak.

A
