# TensorFlow GPT-2 implementation

Here we define a TensorFlow implementation of GPT-2, as provided by https://github.com/ShenakhtPajouh/gpt2-keras,
and load the pretrained weights. The goal is to duplicate as closely as possible the PyTorch implementation
from LLMFS. See load_gpt.ipynb in ch05.

In [6]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
!pip install -q tiktoken
import tiktoken


[0m

## gpt2.py

In [7]:
# From https://github.com/ShenakhtPajouh/gpt2-keras
import tensorflow as tf
import numpy as np


def get_tensor_shape(x):
    """Return the shape of ``x``, preferring static dimensions over dynamic ones.

    Args:
        x: a tensor or anything ``tf.convert_to_tensor`` accepts.

    Returns:
        * In eager mode: the static shape as a list of ints.
        * In graph mode with unknown rank: the dynamic ``tf.shape(x)`` tensor.
        * In graph mode with known rank: a list with one entry per dimension,
          holding the static int where it is known and the matching scalar
          tensor from ``tf.shape(x)`` where it is not.
    """
    x = tf.convert_to_tensor(x)
    static_shape = x.shape
    if tf.executing_eagerly():
        return static_shape.as_list()
    dynamic_shape = tf.shape(x)
    # TensorShape.as_list() raises ValueError when the rank is unknown, so the
    # rank has to be checked *before* converting to a list.
    if static_shape.rank is None:
        return dynamic_shape
    dynamic_dims = tf.unstack(dynamic_shape)
    return [dyn if st is None else st
            for st, dyn in zip(static_shape.as_list(), dynamic_dims)]


def gelu(x):
    """Tanh approximation of the GELU activation, applied element-wise to ``x``."""
    scale = np.sqrt(2/np.pi)
    inner = x+0.044715*tf.pow(x, 3)
    return 0.5*x*(1+tf.tanh(scale*inner))

def dropout_fn(x, dropout):
    """Apply dropout with rate ``dropout`` to ``x``.

    ``x`` is returned untouched when ``dropout`` is ``None`` or zero;
    otherwise the result of ``tf.nn.dropout(x, rate=dropout)`` is returned.
    """
    disabled = dropout is None or dropout == 0.0
    if disabled:
        return x
    return tf.nn.dropout(x, rate=dropout)


class LayerNormalization(tf.keras.layers.Layer):
    """Layer normalization with a learned shift (``beta``) and scale (``gamma``)."""

    def __init__(self, trainable=True, name=None):
        super().__init__(name=name, trainable=trainable)
        # Both weights are created in build(), once the feature size is known.
        self.beta = None
        self.gamma = None

    def build(self, input_shape):
        """Create ``beta`` (zeros) and ``gamma`` (ones), sized to the last input dimension."""
        feature_shape = input_shape[-1:]
        self.beta = self.add_weight(name="beta", shape=feature_shape,
                                    initializer=tf.zeros_initializer())
        self.gamma = self.add_weight(name="gamma", shape=feature_shape,
                                     initializer=tf.ones_initializer())
        super().build(input_shape)

    def call(self, inputs, axis=-1, epsilon=1e-5):
        """Normalize ``inputs`` over ``axis``, then scale by ``gamma`` and shift by ``beta``."""
        mean, variance = tf.nn.moments(inputs, axis, keepdims=True)
        normalized = (inputs - mean) * tf.math.rsqrt(variance + epsilon)
        return normalized * self.gamma + self.beta

    def __call__(self, inputs, axis=-1, epsilon=1e-5):
        """Forward every argument by keyword to the Keras ``__call__`` machinery."""
        return super().__call__(inputs=inputs, axis=axis, epsilon=epsilon)


class SelfAttention(tf.keras.layers.Layer):
    """Multi-head self-attention with optional causal masking and a key/value cache.

    Query, key and value are produced by three separate Dense layers, split
    into heads, combined with scaled dot-product attention and merged back
    into a single feature dimension. No output projection is applied here.
    """

    def __init__(self, num_attention_heads=1, size_per_head=512,
                 one_sided=True,
                 query_act=None,
                 initializer_range=0.02,
                 value_act=None,
                 key_act=None,
                 trainable=True,
                 name=None):
        """
        num_attention_heads: number of attention heads.
        size_per_head: feature size of each head.
        one_sided: if True, a causal mask is applied (see get_mask).
        query_act / key_act / value_act: optional activations of the three Dense layers.
        initializer_range: stddev of the normal initializer used for the Dense kernels.
        """
        super().__init__(name=name, trainable=trainable)
        # `query_layer` = [B*F, N*H]
        self.attention_size = num_attention_heads * size_per_head
        self.query_layer = tf.keras.layers.Dense(
            num_attention_heads * size_per_head,
            activation=query_act,
            name="query",
            kernel_initializer=tf.random_normal_initializer(stddev=initializer_range)
        )
        # `key_layer` = [B*T, N*H]
        self.key_layer = tf.keras.layers.Dense(
            num_attention_heads * size_per_head,
            activation=key_act,
            name="key",
            kernel_initializer=tf.random_normal_initializer(stddev=initializer_range)
        )
        # `value_layer` = [B*T, N*H]
        self.value_layer = tf.keras.layers.Dense(
            num_attention_heads * size_per_head,
            activation=value_act,
            name="value",
            kernel_initializer=tf.random_normal_initializer(stddev=initializer_range)
        )
        self.size_per_head = size_per_head
        self.num_attention_heads = num_attention_heads
        self.one_sided = one_sided

    def reshape(self, x, use_2d=False, shape=None):
        """Split the feature dimension of x into heads.

        Reshapes x to [batch_size, seq_length, num_heads, size_per_head] and
        transposes it to [batch_size, num_heads, seq_length, size_per_head].
        When use_2d is True, batch_size and seq_length are taken from `shape`
        instead of from x itself.
        """
        if use_2d:
            batch_size, seq_length = shape[0], shape[1]
        else:
            _shape = get_tensor_shape(x)
            batch_size, seq_length = _shape[0], _shape[1]
        x = tf.reshape(x, [batch_size, seq_length, self.num_attention_heads, self.size_per_head])
        x = tf.transpose(x, [0, 2, 1, 3])
        return x

    def final_shape(self, x, use_2d=False):
        """Merge the heads back: inverse of `reshape`.

        x is [batch_size, num_heads, seq_length, size_per_head]; the result is
        [batch_size, seq_length, num_heads * size_per_head], or
        [batch_size * seq_length, num_heads * size_per_head] when use_2d is True.
        """
        shape = get_tensor_shape(x)
        batch_size, seq_length = shape[0], shape[2]
        x = tf.transpose(x, [0, 2, 1, 3])
        if use_2d:
            x = tf.reshape(x, [batch_size * seq_length, self.num_attention_heads * self.size_per_head])
        else:
            x = tf.reshape(x, [batch_size, seq_length, self.num_attention_heads * self.size_per_head])
        return x

    def get_mask(self, inputs_shape, cache_length=None, mask=None):
        """Build the boolean attention mask.

        inputs_shape: shape of the per-head query; index 0 is read as
                      batch_size and index 2 as seq_length.
        cache_length: number of cached positions that precede the current ones, if any.
        mask: optional boolean tensor of shape [batch_size, seq_length].

        Returns the causal mask (combined with `mask` when given) if
        self.one_sided is True, otherwise the reshaped `mask`, which may be None.
        """
        batch_size, seq_length = inputs_shape[0], inputs_shape[2]
        if self.one_sided:
            rng = tf.range(seq_length)
            # Entry [i, j] is True when j <= i: a position sees itself and earlier positions.
            one_sided_mask = tf.less_equal(rng, tf.expand_dims(rng, 1))
            if cache_length is not None:
                # Every cached position is visible to every current position.
                prev_mask = tf.ones([seq_length, cache_length], tf.bool)
                one_sided_mask = tf.concat([prev_mask, one_sided_mask], 1)
        if mask is not None:
            if cache_length is not None:
                prev_mask = tf.ones([batch_size, cache_length], tf.bool)
                mask = tf.concat([prev_mask, mask], 1)
            if cache_length is None:
                cache_length = 0
            # Shaped so that it broadcasts over the head and query dimensions.
            mask = tf.reshape(mask, [batch_size, 1, 1, seq_length + cache_length])
        if self.one_sided:
            if mask is not None:
                one_sided_mask = tf.logical_and(mask, one_sided_mask)
            return one_sided_mask
        else:
            return mask

    def attend(self, query, key, value, mask=None, dropout=None):
        """Scaled dot-product attention.

        The scores query @ key^T are divided by sqrt(size_per_head). Where
        `mask` is False the score is replaced by -1e5 before the softmax.
        `dropout` is applied to the attention weights through dropout_fn.
        """
        dim = tf.cast(self.size_per_head, query.dtype)
        _sqrt = tf.math.sqrt(dim)
        _sqrt = tf.cast(_sqrt, query.dtype)
        coefficients = tf.matmul(query, key, transpose_b=True) / _sqrt
        if mask is not None:
            mask = tf.cast(mask, coefficients.dtype)
            # Keeps the score where mask == 1 and yields exactly -1e5 where mask == 0.
            coefficients = coefficients * mask - (1 - mask) * 1e5
        coefficients = tf.math.softmax(coefficients, -1)
        coefficients = dropout_fn(coefficients, dropout)
        results = tf.matmul(coefficients, value)
        return results

    def call(self, inputs, cache=None, mask=None,
             attention_dropout=None, return_cache=False,
             use_2d=False, shape=None):
        """
        inputs: a tensor of shape [batch_size, seq_length, dim] if use_2d is false,
                else a tensor of shape [batch_size * seq_length, dim]
        cache: a dictionary holding the "key" and "value" tensors from previous calls.
        mask: a boolean tensor of shape [batch_size, seq_length]
        attention_dropout: dropout rate applied to the attention weights
        return_cache: if True, the keys and values are returned alongside the layer output
        use_2d: if it is True, the model uses 2D matrices as inputs and outputs
        shape: if use_2d is True, then the shape is [batch_size, seq_length]

        Raises ValueError if use_2d is True and shape is None.
        """
        query = self.query_layer(inputs)
        key = self.key_layer(inputs)
        value = self.value_layer(inputs)
        if use_2d and shape is None:
            raise ValueError("if use_2d is True, then the shape must be specified")
        query = self.reshape(query, use_2d, shape)
        key = self.reshape(key, use_2d, shape)
        value = self.reshape(value, use_2d, shape)
        cache_length = None
        if cache is not None:
            # Axis 2 is the sequence axis after reshape(); cached entries go first.
            key = tf.concat([cache["key"], key], 2)
            value = tf.concat([cache["value"], value], 2)
            cache_length = get_tensor_shape(cache["key"])[2]
        inputs_shape = get_tensor_shape(query)
        mask = self.get_mask(inputs_shape, cache_length, mask)
        result = self.attend(query, key, value, mask, attention_dropout)
        result = self.final_shape(result, use_2d)
        if return_cache:
            # The returned cache holds the concatenated (previous + current) keys and values.
            cache = {"key": key, "value": value}
            return result, cache
        else:
            return result

    def __call__(self, inputs, cache=None, mask=None,
             attention_dropout=None, return_cache=False,
             use_2d=False, shape=None):
        """
        Forwards every argument by keyword to the Keras __call__ machinery.

        inputs: a tensor of shape [batch_size, seq_length, dim] if use_2d is false,
                else a tensor of shape [batch_size * seq_length, dim]
        cache: a dictionary holding the "key" and "value" tensors from previous calls.
        mask: a boolean tensor of shape [batch_size, seq_length]
        attention_dropout: dropout rate applied to the attention weights
        return_cache: if True, the keys and values are returned alongside the layer output
        use_2d: if it is True, the model uses 2D matrices as inputs and outputs
        shape: if use_2d is True, then the shape is [batch_size, seq_length]
        """
        return super().__call__(
            inputs=inputs,
            cache=cache,
            mask=mask,
            attention_dropout=attention_dropout,
            return_cache=return_cache,
            use_2d=use_2d,
            shape=shape
        )

class AttentionLayer(tf.keras.layers.Layer):
    """Attention sub-layer: layer norm -> self-attention -> output projection -> dropout."""

    def __init__(self, config, name=None, trainable=True, initializer_range=0.02):
        super().__init__(name=name, trainable=trainable)
        embed_dim = config["n_embd"]
        head_count = config["n_head"]
        self.layer_norm = LayerNormalization(name="layer_norm")
        self.self_attention = SelfAttention(
            num_attention_heads=head_count,
            size_per_head=embed_dim // head_count,
            initializer_range=initializer_range,
            name="self",
        )
        self.projection = tf.keras.layers.Dense(
            units=embed_dim,
            kernel_initializer=tf.random_normal_initializer(stddev=initializer_range),
            name="projection",
        )

    def call(self, inputs, cache=None, dropout=None, attention_dropout=None,
             return_cache=False, use_2d=False, shape=None):
        """
        inputs: [batch_size, seq_length, dim], or [batch_size * seq_length, dim] when use_2d is True.
        cache: optional dictionary with the key and value tensors from previous calls.
        dropout: dropout rate applied after the output projection.
        attention_dropout: dropout rate passed to the self-attention layer.
        return_cache: if True, the key/value dictionary is returned alongside the output.
        use_2d: if True, inputs and outputs are 2D tensors instead of 3D.
        shape: [batch_size, seq_length]; required when use_2d is True.
        """
        normed = self.layer_norm(inputs)
        attended = self.self_attention(
            normed,
            attention_dropout=attention_dropout,
            cache=cache,
            return_cache=return_cache,
            use_2d=use_2d,
            shape=shape,
        )
        new_cache = None
        if return_cache:
            attended, new_cache = attended
        projected = dropout_fn(self.projection(attended), dropout)
        if return_cache:
            return projected, new_cache
        return projected

    def __call__(self, inputs, cache=None, dropout=None, attention_dropout=None,
                 return_cache=False, use_2d=False, shape=None):
        """Forward every argument by keyword to the Keras ``__call__`` machinery (see ``call``)."""
        return super().__call__(
            inputs=inputs,
            cache=cache,
            dropout=dropout,
            attention_dropout=attention_dropout,
            return_cache=return_cache,
            use_2d=use_2d,
            shape=shape,
        )



class MultiLayerPerceptron(tf.keras.layers.Layer):
    """Feed-forward sub-layer: layer norm -> Dense(perceptron_size) -> Dense(embedding_size) -> dropout."""

    def __init__(self, activation_fn=None, embedding_size=768,
                 perceptron_size=3072, trainable=True,
                 initializer_range=0.02, name=None):
        super().__init__(name=name, trainable=trainable)
        self.layer_norm = LayerNormalization(name="layer_norm")
        # Expansion layer; the activation is applied here only.
        self.perceptron = tf.keras.layers.Dense(
            units=perceptron_size,
            activation=activation_fn,
            kernel_initializer=tf.random_normal_initializer(stddev=initializer_range),
            name="perceptron",
        )
        # Maps back to the embedding size, without activation.
        self.projection = tf.keras.layers.Dense(
            units=embedding_size,
            kernel_initializer=tf.random_normal_initializer(stddev=initializer_range),
            name="projection",
        )

    def call(self, inputs, dropout=None):
        """
        inputs: tensor of [batch_size, seq_length, dim]
        dropout: dropout rate applied to the projected output.
        """
        hidden = self.perceptron(self.layer_norm(inputs))
        return dropout_fn(self.projection(hidden), dropout)

    def __call__(self, inputs, dropout=None):
        """Forward every argument by keyword to the Keras ``__call__`` machinery."""
        return super().__call__(inputs=inputs, dropout=dropout)


class Block(tf.keras.layers.Layer):
    """One transformer block: attention and MLP sub-layers, each added back onto its input."""

    def __init__(self, config, trainable=True, initializer_range=0.02, name=None):
        super().__init__(name=name, trainable=trainable)
        embed_dim = config["n_embd"]
        self.attention = AttentionLayer(
            config=config,
            initializer_range=initializer_range,
            name="attention",
        )
        self.mlp = MultiLayerPerceptron(
            activation_fn=gelu,
            embedding_size=embed_dim,
            perceptron_size=4 * embed_dim,
            initializer_range=initializer_range,
            name="mlp",
        )

    def call(self, inputs, cache=None, dropout=None, attention_dropout=None,
            return_cache=False, use_2d=False, shape=None):
        """Run attention then MLP, each with a residual connection.

        Returns the block output, or ``(output, cache)`` when ``return_cache``
        is True, where ``cache`` is whatever the attention sub-layer returned.
        """
        attn_out = self.attention(inputs=inputs,
                                  cache=cache,
                                  dropout=dropout,
                                  attention_dropout=attention_dropout,
                                  return_cache=return_cache,
                                  use_2d=use_2d,
                                  shape=shape)
        new_cache = None
        if return_cache:
            attn_out, new_cache = attn_out
        hidden = inputs + attn_out
        hidden = hidden + self.mlp(inputs=hidden, dropout=dropout)
        if return_cache:
            return hidden, new_cache
        return hidden

    def __call__(self, inputs, cache=None, dropout=None, attention_dropout=None,
                 return_cache=False, use_2d=False, shape=None):
        """Forward every argument by keyword to the Keras ``__call__`` machinery."""
        return super().__call__(
            inputs=inputs,
            cache=cache,
            dropout=dropout,
            attention_dropout=attention_dropout,
            return_cache=return_cache,
            use_2d=use_2d,
            shape=shape,
        )


class Transformer(tf.keras.Model):
    """Stack of ``config["n_layer"]`` transformer blocks followed by a final layer norm."""

    def __init__(self, config, trainable=True, name=None):
        super().__init__(name=name)
        self.trainable = trainable
        self.blocks = []
        self.blocks_num = config["n_layer"]
        for ids in range(self.blocks_num):
            block = Block(config=config,
                          name="block_%d" % ids)
            self.blocks.append(block)
        self.layer_norm = LayerNormalization(name="layer_norm")

    def _normalize_block_indices(self, blocks):
        """Convert requested block indices to sorted, non-negative indices.

        Negative indices count from the end, as in Python sequences
        (``-1`` is the last block). Raises ValueError when an index falls
        outside the available blocks.
        """
        normalized = []
        for i in blocks:
            # A negative index is an offset from the end: blocks_num + i.
            k = i if i >= 0 else self.blocks_num + i
            if k >= self.blocks_num or k < 0:
                raise ValueError("output blocks should be in range [" + str(0) + ", " +
                                 str(self.blocks_num - 1) + "]")
            normalized.append(k)
        return sorted(normalized)

    def call(self, inputs, cache=None, dropout=None, attention_dropout=None,
             return_cache=False, blocks=None, use_2d=False, shape=None):
        """

        inputs: a tensor of shape [batch_size, seq_length, dim], if use_2d is False, else [batch_size * seq_length, dim]
        cache: a list of dictionaries (one per block) with the keys and values from previous calls.
        blocks: a list of block indices (negative values count from the end). If it is given
                and non-empty, the output is a dictionary {block_index: block_output}, only the
                blocks up to the largest requested index are run, and the final layer norm is
                not applied.
        return_cache: if it is True, the per-block caches are returned as well.
        use_2d: if it is True, the operations are defined on 2D tensors. (for tpu performance)
        shape: if use_2d is True, then it is [batch_size, seq_length]

        Raises ValueError if an entry of blocks is out of range.

        """
        # An empty list behaves like "no blocks requested".
        if blocks is not None and len(blocks) == 0:
            blocks = None
        if blocks is None:
            max_block = self.blocks_num - 1
        else:
            blocks = self._normalize_block_indices(blocks)
            max_block = blocks[-1]
        if blocks is not None:
            outputs = {}
        if return_cache:
            new_cache = []
        output = inputs
        for ids in range(max_block + 1):
            if cache is None:
                _cache = None
            else:
                _cache = cache[ids]
            output = self.blocks[ids](inputs=output,
                                      cache=_cache,
                                      dropout=dropout,
                                      attention_dropout=attention_dropout,
                                      return_cache=return_cache,
                                      use_2d=use_2d,
                                      shape=shape)
            if return_cache:
                output, _cache = output
                new_cache.append(_cache)
            if blocks is not None:
                if ids in blocks:
                    outputs[ids] = output
        if blocks is None:
            output = self.layer_norm(output)
            result = output
        else:
            result = outputs
        if return_cache:
            return result, new_cache
        else:
            return result

    def __call__(self, inputs, cache=None, dropout=None, attention_dropout=None,
                 return_cache=False, blocks=None, use_2d=False, shape=None):
        """

        Forwards every argument by keyword to the Keras __call__ machinery.

        inputs: a tensor of shape [batch_size, seq_length, dim], if use_2d is False, else [batch_size * seq_length, dim]
        cache: a list of dictionaries (one per block) with the keys and values from previous calls.
        blocks: a list of block indices (negative values count from the end). If it is given
                and non-empty, the output is a dictionary {block_index: block_output}.
        return_cache: if it is True, the per-block caches are returned as well.
        use_2d: if it is True, the operations are defined on 2D tensors. (for tpu performance)
        shape: if use_2d is True, then it is [batch_size, seq_length]

        """
        return super().__call__(
            inputs=inputs,
            cache=cache,
            dropout=dropout,
            attention_dropout=attention_dropout,
            return_cache=return_cache,
            blocks=blocks,
            use_2d=use_2d,
            shape=shape
        )


class Embedding(tf.keras.layers.Layer):
    """Token embedding plus a learned position embedding."""

    def __init__(self, embedding_size, vocab_size, max_position_length,
                 trainable=True, name=None, initializer_range=0.02,
                 dtype=None):
        # float32 is used when the caller does not pick a dtype.
        layer_dtype = tf.float32 if dtype is None else dtype
        super().__init__(name=name, trainable=trainable, dtype=layer_dtype)
        # Both tables are created in build().
        self.word_embedding = None
        self.position_embedding = None
        self.initializer_range = initializer_range
        self.embedding_size = embedding_size
        self.vocab_size = vocab_size
        self.max_position_length = max_position_length

    def build(self, input_shape):
        """Create the (vocab_size, embedding_size) and (max_position_length, embedding_size) tables."""
        stddev = self.initializer_range
        self.word_embedding = self.add_weight(
            name="word_embedding",
            shape=(self.vocab_size, self.embedding_size),
            initializer=tf.random_normal_initializer(stddev=stddev),
        )
        self.position_embedding = self.add_weight(
            name="position_embedding",
            shape=(self.max_position_length, self.embedding_size),
            initializer=tf.random_normal_initializer(stddev=stddev),
        )

    def call(self, inputs, start=None):
        """
        inputs: integer tensor of [batch_size, seq_length]
        start: index of the first position embedding to use (0 when None)
        """
        seq_length = get_tensor_shape(inputs)[1]
        token_vectors = tf.gather(self.word_embedding, inputs)
        first = 0 if start is None else start
        position_vectors = self.position_embedding[first:first + seq_length]
        return token_vectors + position_vectors

    def __call__(self, inputs, start=None):
        """
        Forwards both arguments by keyword to the Keras __call__ machinery.

        inputs: integer tensor of [batch_size, seq_length] holding token ids
        start: index of the first position embedding to use
        """
        return super().__call__(inputs=inputs, start=start)


class GPT2(tf.keras.Model):
    """GPT-2 language model: embedding layer, transformer stack and a tied output head.

    The logits are computed by multiplying the final hidden states with the
    transposed word-embedding matrix, so there is no separate output weight.
    """

    def __init__(self, config, name=None, trainable=True, dtype=None):
        super().__init__(name=name)
        self.trainable = trainable
        self.embedding = Embedding(
            embedding_size=config['n_embd'],
            vocab_size=config['n_vocab'],
            max_position_length=config['n_ctx'],
            name="embedding",
            dtype=dtype,
        )
        self.transformer = Transformer(config, name="transformer")

    def call(self, inputs, cache=None,
             dropout=None, attention_dropout=None,
             return_cache=False, return_logits=True, use_2d=False):
        """
        inputs: an integer tensor of shape [batch_size, seq_length]
        cache: a list of dictionaries {"key": key, "value": value} with the keys and values
               of previous calls; used for generation. The position embeddings start at the
               cached length.
        dropout / attention_dropout: dropout rates passed on to the transformer.
        return_cache: if True, the new keys and values are returned alongside the output.
        return_logits: if True, return logits, else return the last hidden states.
        use_2d: run the transformer on 2D tensors and return a 2D output of shape
                [batch_size * seq_length, -1]
        """
        # Continue the position embeddings after the cached tokens, if any.
        start = None
        if cache is not None:
            start = get_tensor_shape(cache[0]["key"])[2]
        hidden = self.embedding(inputs, start)

        shape = None
        if use_2d:
            dims = get_tensor_shape(hidden)
            hidden = tf.reshape(hidden, [dims[0] * dims[1], dims[2]])
            shape = dims[0:2]

        hidden = self.transformer(
            inputs=hidden,
            cache=cache,
            dropout=dropout,
            attention_dropout=attention_dropout,
            return_cache=return_cache,
            use_2d=use_2d,
            shape=shape,
        )
        new_cache = None
        if return_cache:
            hidden, new_cache = hidden

        if not return_logits:
            result = hidden
        else:
            dims = get_tensor_shape(hidden)
            # The vocabulary projection is done on a 2D matrix in both modes.
            flat = hidden
            if not use_2d:
                flat = tf.reshape(hidden, [dims[0] * dims[1], dims[2]])
            result = tf.matmul(flat, self.embedding.word_embedding, transpose_b=True)
            if not use_2d:
                result = tf.reshape(result, [dims[0], dims[1], self.embedding.vocab_size])

        if return_cache:
            return result, new_cache
        return result

    def __call__(self, inputs, cache=None,
                 dropout=None, attention_dropout=None,
                 return_cache=False, return_logits=True,
                 use_2d=False):
        """
        Forwards every argument by keyword to the Keras __call__ machinery.

        inputs: an integer tensor of shape [batch_size, seq_length]
        cache: a list of dictionaries {"key": key, "value": value} of previous keys and values.
        return_cache: if True, the new keys and values are returned alongside the output.
        return_logits: if True, return logits, else return the last hidden states.
        use_2d: use 2D tensors for the operations and return a 2D output.
        """
        return super().__call__(
            inputs=inputs,
            cache=cache,
            dropout=dropout,
            attention_dropout=attention_dropout,
            return_cache=return_cache,
            return_logits=return_logits,
            use_2d=use_2d,
        )



In [10]:
# Hyperparameters for the 124M model, using the keys that GPT2, Transformer,
# Block and AttentionLayer read from `config`.
config124M = {'n_embd': 768, 'n_vocab': 50257, 'n_ctx': 1024, 'n_layer': 12, 'n_head': 12}
# config = {'n_embd': 3, 'n_vocab': 10, 'n_ctx': 5, 'n_layer': 12, 'n_head': 4}
gpt2 = GPT2(name="mygpt2", config=config124M)
# Dummy batch with a single token id; it is fed through the model in the next cell.
x=tf.constant([[1]])
gpt2.compile()

In [11]:
# First forward pass on the dummy batch; displays the logits of the randomly initialised model.
gpt2(x)

<tf.Tensor: shape=(1, 1, 50257), dtype=float32, numpy=
array([[[ 0.56809425, -0.62641966,  1.2129886 , ...,  0.8097295 ,
          0.3339309 ,  0.10142062]]], dtype=float32)>

In [12]:
# Print the layer table, including the blocks nested inside the transformer.
gpt2.summary(expand_nested=True)


Model: "mygpt2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 embedding (Embedding)       multiple                  39383808  
                                                                 
 transformer (Transformer)   multiple                  85056000  
|¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯|
| block_0 (Block)            multiple                  7087872  |
|                                                               |
| block_1 (Block)            multiple                  7087872  |
|                                                               |
| block_2 (Block)            multiple                  7087872  |
|                                                               |
| block_3 (Block)            multiple                  7087872  |
|                                                               |
| block_4 (Block)            multiple                  70878

In [13]:
def print_layer_structure(layer_or_model, level=0):
    """Print ``layer_or_model`` and, recursively, every entry of its ``layers``.

    Each line has the form ``- <name> (<class name>)`` and is indented by
    two spaces per nesting ``level``.
    """
    prefix = "  " * level
    kind = type(layer_or_model).__name__
    print(f"{prefix}- {layer_or_model.name} ({kind})")

    # Objects without a `layers` attribute, or with an empty one, are leaves.
    children = getattr(layer_or_model, 'layers', None)
    if children:
        for child in children:
            print_layer_structure(child, level + 1)
            
# Walk the `gpt2` model built above and print its nested layer structure
print("\n--- Manual Recursive Traversal ---")
print_layer_structure(gpt2)


--- Manual Recursive Traversal ---
- mygpt2 (GPT2)
  - embedding (Embedding)
  - transformer (Transformer)
    - block_0 (Block)
    - block_1 (Block)
    - block_2 (Block)
    - block_3 (Block)
    - block_4 (Block)
    - block_5 (Block)
    - block_6 (Block)
    - block_7 (Block)
    - block_8 (Block)
    - block_9 (Block)
    - block_10 (Block)
    - block_11 (Block)
    - layer_norm (LayerNormalization)


In [14]:
# Save the model built above (its weights have not been replaced by the
# pretrained ones yet) as a TensorFlow checkpoint under tf_ckpts/.
checkpoint = tf.train.Checkpoint(gpt2)
model_dir = 'tf_ckpts'
save_path = checkpoint.save(model_dir + "/ckpt")

In [15]:
# list all the variables in the model
tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
# The loop body is a no-op; uncomment the print to see every variable name in the checkpoint.
for name, v in tf.train.list_variables(tf_ckpt_path):
    # print(name)
    pass

In [16]:
# load weights into a "params" dict
def load_gpt2_params_from_tf_ckpt(ckpt_path, settings):
    """Read every variable of a TensorFlow checkpoint into a nested dict.

    The variable name is split on "/" and its first component is dropped.
    When the next component starts with "h", the digits after it select an
    entry of ``params["blocks"]``; otherwise the variable goes into the top
    level of ``params``. The components in between become nested dict keys
    and the last component is the key under which the squeezed array is stored.

    Args:
        ckpt_path: checkpoint path as accepted by ``tf.train.list_variables``.
        settings: dict with an ``"n_layer"`` entry giving the number of blocks.

    Returns:
        dict with a ``"blocks"`` list of per-block dicts plus the top-level variables.
    """
    params = {"blocks": [{} for _ in range(settings["n_layer"])]}

    for name, _ in tf.train.list_variables(ckpt_path):
        # Singleton dimensions are removed from every array.
        variable_array = np.squeeze(tf.train.load_variable(ckpt_path, name))

        # Drop the leading path component (the 'model/' prefix).
        _, *path = name.split("/")
        head, leaf = path[0], path[-1]

        # Variables under "h<N>" belong to block N, everything else to the top level.
        if head.startswith("h"):
            node = params["blocks"][int(head[1:])]
        else:
            node = params

        # Descend through the intermediate components, creating dicts as needed.
        for key in path[1:-1]:
            node = node.setdefault(key, {})

        node[leaf] = variable_array

    return params


In [17]:
settings = {"n_layer": 12}

# Directory holding the original OpenAI GPT-2 124M TensorFlow checkpoint.
model_dir="ch05/01_main-chapter-code/gpt2/124M"
tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
# latest_checkpoint returns None when the directory holds no checkpoint;
# fail here with a clear message instead of deep inside the loader.
if tf_ckpt_path is None:
    raise FileNotFoundError(
        f"No TensorFlow checkpoint found in {model_dir!r}; "
        "download the GPT-2 124M weights there first."
    )
params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, settings)

In [18]:
# Let's study this "params" thing..
# Walk down the nested dict; the trailing comment on each line records the value it evaluates to.
params.keys() # ['blocks', 'b', 'g', 'wpe', 'wte']
len(params["blocks"]) # 12
params_block0 = params["blocks"][0]
params_block0.keys() # ['attn', 'ln_1', 'ln_2', 'mlp']
params_block0_attn = params_block0['attn']
params_block0_attn.keys() # ['c_attn', 'c_proj']
params_block0_attn_c_attn = params_block0_attn['c_attn']
params_block0_attn_c_attn.keys() # ['b', 'w']
x = (params["blocks"][0]["attn"]["c_attn"])["w"]
x.shape # (768, 2304)
# The fused c_attn weight is split into three equal parts along the last axis: query, key, value.
q_w, k_w, v_w = np.split((params["blocks"][0]["attn"]["c_attn"])["w"], 3, axis=-1)
q_w.shape # (768, 768)

(768, 768)

In [19]:
# Shapes of the top-level checkpoint entries; only the uncommented line is displayed.
# params['wpe'].shape # (1024, 768) position embedding
params['wte'].shape # (50257, 768) token embedding, out_head.weight
#len(params['blocks']) 12
#params['blocks'][0].keys() # dict_keys(['attn', 'ln_1', 'ln_2', 'mlp'])
# params['b'].shape # (768,) final_norm.shift (beta)
# params['g'].shape # (768,) final_norm.scale (gamma)

(50257, 768)

In [20]:
# Show the model config; the weight-loading loop below reads ['config']['n_layer'] from it.
gpt2.get_config()

{'name': 'mygpt2',
 'trainable': True,
 'dtype': 'float32',
 'config': {'n_embd': 768,
  'n_vocab': 50257,
  'n_ctx': 1024,
  'n_layer': 12,
  'n_head': 12}}

## Here we load all the weights!!

In [21]:
# a GPT Model has an Embedding layer and a Transformer Model
embedding_layer   = gpt2.embedding

# The Embedding layer owns two variables, created when the model was first called:
#   word_embedding:     (vocab_size, embedding_size)          <- params['wte'] (50257, 768)
#   position_embedding: (max_position_length, embedding_size) <- params['wpe'] (1024, 768)
# Copy the pretrained values INTO the existing variables with .assign(). Plain attribute
# assignment (layer.word_embedding = array) would replace the tf.Variable with a NumPy
# array, so the layer would no longer hold that value as a weight.
embedding_layer.word_embedding.assign(params['wte'])
embedding_layer.position_embedding.assign(params['wpe'])


In [22]:
transformer_layer = gpt2.transformer
blocks = []
for b in range(gpt2.get_config()['config']['n_layer']): # = transformer_layer.blocks_num (12)
    # A transformer_layer has a list of blocks
    block = transformer_layer.blocks[b]
    blocks.append(block)
    block_params = params["blocks"][b]

    # Each Block layer has an AttentionLayer and a MultiLayerPerceptron layer
    attn = block.attention

    # Each AttentionLayer has a LayerNormalization layer, a SelfAttention layer and a Dense
    # ("projection") layer. beta is the shift and gamma the scale of the layer norm. The values
    # are copied into the existing variables with .assign(); plain attribute assignment would
    # replace the tf.Variable with a NumPy array instead of updating the layer's weight.
    layer_norm = attn.layer_norm
    layer_norm.beta.assign(block_params["ln_1"]["b"])
    layer_norm.gamma.assign(block_params["ln_1"]["g"])
    self_attention = attn.self_attention

    # Each SelfAttention layer has a query_layer, a key_layer and a value_layer.
    # The checkpoint stores them fused in c_attn, so split into three equal parts.
    query_layer = self_attention.query_layer
    key_layer = self_attention.key_layer
    value_layer = self_attention.value_layer
    q_w, k_w, v_w = np.split(block_params["attn"]["c_attn"]["w"], 3, axis=-1)
    q_b, k_b, v_b = np.split(block_params["attn"]["c_attn"]["b"], 3, axis=-1)
    query_layer.set_weights([q_w, q_b])
    key_layer.set_weights([k_w, k_b])
    value_layer.set_weights([v_w, v_b])

    layer_proj = attn.projection
    layer_proj.set_weights([block_params["attn"]["c_proj"]["w"], block_params["attn"]["c_proj"]["b"]])

    mlp_layer = block.mlp
    # A MultiLayerPerceptron layer has a LayerNormalization layer and 2 Dense layers
    mlp_layer_norm = mlp_layer.layer_norm
    mlp_layer_norm.beta.assign(block_params["ln_2"]["b"])
    mlp_layer_norm.gamma.assign(block_params["ln_2"]["g"])
    mlp_perceptron = mlp_layer.perceptron
    mlp_perceptron.set_weights([block_params["mlp"]["c_fc"]["w"], block_params["mlp"]["c_fc"]["b"]])
    mlp_projection = mlp_layer.projection
    mlp_projection.set_weights([block_params["mlp"]["c_proj"]["w"], block_params["mlp"]["c_proj"]["b"]])

# Final layer norm applied after the last block
final_norm_layer = transformer_layer.layer_norm
final_norm_layer.beta.assign(params["b"])
final_norm_layer.gamma.assign(params["g"])

## Test each layer against the corresponding LLMFS PyTorch layer

In [23]:
# test embedding_layer
# Embed a single token (id 1); the commented values below are the reference to compare against.
x_embedded = embedding_layer(tf.constant([[1]]))
#x_embedded # [ 2.1520e-02, -2.4603e-01,  5.0275e-02

In [24]:
# Test layer_norm: fetch block 0's attention layer norm and build an all-ones
# (1, 768) float32 input to call it with.
# The bracketed numbers after each disabled call are the leading reference
# values noted for comparison.
x = np.ones((1, 768) , dtype=np.float32)

layer_norm0 = blocks[0].attention.layer_norm
# layer_norm0(x) # [-3.6773e-03,  2.7197e-02, -6.4041e-02
# Same check for block 11 (left disabled):
# layer_norm11 = blocks[11].attention.layer_norm
# layer_norm11(x) # [ 5.0957e-02,  5.3063e-03,  7.1952e-02

In [25]:
# Test query_layer, key_layer, value_layer and layer_proj.
# Fetches the attention sub-layers of the first and last block so each can be
# called on the same all-ones (1, 768) input. The calls themselves are left
# commented out; the bracketed numbers are the leading reference values noted
# for comparison.
x = tf.constant(np.ones((1, 768) , dtype=np.float32))
# block 0
query_layer0 = blocks[0].attention.self_attention.query_layer
#query_layer0(x) # [-1.3708e+01,  1.3385e+01,  1.4323e+01
key_layer0 = blocks[0].attention.self_attention.key_layer
# key_layer0(x) [ 1.8049e-01, -1.4381e-01,  6.2964e-01
value_layer0 = blocks[0].attention.self_attention.value_layer               
# value_layer0(x) # [-6.1687e-02, -1.3786e-01, -3.0145e-01
layer_proj0 = blocks[0].attention.projection
# layer_proj0(x) # [-9.7561e+00, -1.7296e+01, -6.7800e-01
# Chaining the four layers back to back (not real attention, just a combined probe):
# layer_proj0(value_layer0(key_layer0(query_layer0(x)))) # [-2.3273e+01, -7.9272e+02,  5.6245e+02

# block 11
query_layer11 = blocks[11].attention.self_attention.query_layer
# query_layer11(x) # [-5.4209e+00,  4.6236e+00,  4.5401e+00
key_layer11 = blocks[11].attention.self_attention.key_layer
# key_layer11(x) # [ 5.8911e+00, -3.3184e-01,  6.3656e-01
value_layer11 = blocks[11].attention.self_attention.value_layer               
# value_layer11(x) # [-1.2480e+00, -3.0783e+00,  5.9679e+00
layer_proj11 = blocks[11].attention.projection
#layer_proj11(x) # [-4.1535e-01,  2.1763e+00,  4.7958e-01

# layer_proj11(value_layer11(key_layer11(query_layer11(x)))) # [ 3.4414e+02,  4.9568e+02,  3.8639e+02



In [26]:
# Test mlp_layer_norm, mlp_perceptron and mlp_projection.
# Fetches the MLP sub-layers of the first and last block so they can be called
# on an all-ones (1, 768) float32 input. The calls are left commented out; the
# bracketed numbers are the leading reference values noted for comparison.
x = np.ones((1, 768) , dtype=np.float32)
mlp_layer_norm0 = blocks[0].mlp.layer_norm
# mlp_layer_norm0(x) # [ 4.2478e-02,  3.2627e-02,  4.4881e-03

mlp_perceptron0 = blocks[0].mlp.perceptron
mlp_projection0 = blocks[0].mlp.projection
# Perceptron followed by projection, applied directly to x:
# mlp_projection0(mlp_perceptron0(x)) # [-1.6735e+01, -6.9883e+00,  4.1138e+00

mlp_layer_norm11 = blocks[11].mlp.layer_norm
# mlp_layer_norm11(x) # [-1.9770e-03,  2.0055e-02,  3.8334e-02

mlp_perceptron11 = blocks[11].mlp.perceptron
mlp_projection11 = blocks[11].mlp.projection
#mlp_projection11(mlp_perceptron11(x)) # [ 1.3675e+01,  2.2839e+01, -1.7306e+01


In [27]:
# Smoke test of the full model: one sequence of three token ids.
# The displayed result is the logits tensor, one row of vocabulary scores per
# input position.
x_trivial = tf.constant([[1, 2, 3]])
gpt2(x_trivial)

<tf.Tensor: shape=(1, 3, 50257), dtype=float32, numpy=
array([[[-32.901043, -31.202375, -34.662212, ..., -39.486706,
         -39.873116, -32.238663],
        [-55.52078 , -53.42854 , -56.476715, ..., -68.1539  ,
         -66.77086 , -58.600624],
        [-61.796875, -60.53862 , -59.55034 , ..., -75.32062 ,
         -72.77314 , -65.57065 ]]], dtype=float32)>

In [28]:
def generate_text_simple(model, idx, max_new_tokens, context_size):
    """Greedily extend a batch of token-id sequences.

    On every step the model is run on the (cropped) current sequences, the
    highest-scoring token at the last position is selected, and that token is
    appended to each sequence.

    Args:
        model: Callable mapping a (batch, n_tokens) integer tensor of token ids
            to logits of shape (batch, n_tokens, vocab_size).
        idx: Token ids of the starting context, shape (batch, n_tokens).
        max_new_tokens: Number of tokens to append to every sequence.
        context_size: Maximum number of most-recent tokens fed to the model on
            each step. ``None`` or ``0`` disables cropping.

    Returns:
        An int64 tensor of shape (batch, n_tokens + max_new_tokens) containing
        the original ids followed by the generated ids.
    """
    idx = tf.cast(idx, dtype=tf.int64)
    for _ in range(max_new_tokens):
        # Only the last `context_size` tokens are shown to the model so the
        # input never grows past the supported context window.
        idx_cond = idx[:, -context_size:] if context_size else idx
        logits = model(idx_cond)
        # Keep the scores of the final position only: (batch, vocab_size).
        logits = logits[:, -1, :]
        # Softmax is monotonic, so the arg-max of the raw logits is the same
        # token the arg-max of the probabilities would select.
        idx_next = tf.argmax(logits, axis=-1)
        # (batch,) -> (batch, 1) so it can be appended along the token axis for
        # any batch size.
        idx = tf.concat((idx, tf.expand_dims(idx_next, axis=-1)), axis=1)
    return idx

In [30]:
# Greedy generation from a fixed prompt of four token ids; the trailing comment
# on the print lists the ids expected in the result.
idx = tf.constant([[6109 , 3626, 6100,  345]])
token_ids = generate_text_simple(gpt2, idx=idx, max_new_tokens=5, context_size=256)
print("token_ids:", token_ids) # [6109, 3626, 6100,  345, 2651,   13,  198,  198,  464]

token_ids: tf.Tensor([[6109 3626 6100  345 2651   13  198  198  464]], shape=(1, 9), dtype=int64)


In [31]:
# Walk through a single greedy decoding step by hand, printing every
# intermediate tensor.
idx = tf.constant([[6109 , 3626, 6100,  345]])
print(idx)
# print(idx[:, -1, :])
# generate(gpt2, idx=idx, max_new_tokens=3, context_size=256)
# Logits for every position of the prompt.
logits = gpt2(idx)
print("logits:\n", logits)
# Keep only the last position's scores.
logits = logits[:, -1, :]
print("logits:\n", logits)
print("logits[-1]:\n", logits[-1])
# Normalize over the last axis, then pick the most likely token id.
probas = tf.nn.softmax(logits, -1)
print("probas:\n", probas)
idx_next = tf.argmax(probas, -1)
print("idx_next:\n", idx_next)

tf.Tensor([[6109 3626 6100  345]], shape=(1, 4), dtype=int32)
logits:
 tf.Tensor(
[[[ -35.582027  -34.980392  -38.452198 ...  -42.095932  -41.85325
    -35.596558]
  [ -76.96017   -76.697     -81.9309   ...  -88.79839   -86.763176
    -78.96271 ]
  [-125.34871  -126.2704   -135.09479  ... -132.31728  -135.2544
   -127.65115 ]
  [-136.60023  -137.38039  -146.55562  ... -148.29782  -147.21555
   -139.56773 ]]], shape=(1, 4, 50257), dtype=float32)
logits:
 tf.Tensor([[-136.60023 -137.38039 -146.55562 ... -148.29782 -147.21555 -139.56773]], shape=(1, 50257), dtype=float32)
logits[-1]:
 tf.Tensor([-136.60023 -137.38039 -146.55562 ... -148.29782 -147.21555 -139.56773], shape=(50257,), dtype=float32)
probas:
 tf.Tensor(
[[1.6012504e-03 7.3391170e-04 7.6013585e-08 ... 1.3312578e-08
  3.9290576e-08 8.2355182e-05]], shape=(1, 50257), dtype=float32)
idx_next:
 tf.Tensor([2651], shape=(1,), dtype=int64)


In [32]:

# GPT-2 encoding from tiktoken, shared by the helper functions and demos below.
tokenizer = tiktoken.get_encoding("gpt2")
def text_to_token_ids(text, tokenizer):
    """Encode ``text`` and return its token ids with a leading batch axis of size one.

    Args:
        text: String to tokenize; the literal ``<|endoftext|>`` marker is
            allowed as a special token.
        tokenizer: Object exposing ``encode(text, allowed_special=...)``.

    Returns:
        A tensor holding the encoded ids with an extra first axis of length 1.
    """
    ids = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
    # TensorFlow counterpart of torch.tensor(ids).unsqueeze(0): add the batch dimension.
    return tf.expand_dims(tf.constant(ids), axis=0)

def token_ids_to_text(token_ids, tokenizer):
    """Decode the last sequence of a batch of token ids back into text.

    Args:
        token_ids: Batch of token ids (tensor, NumPy array or nested list) with
            the batch on the first axis; only the last row is decoded.
        tokenizer: Object exposing ``decode(tokens)``.

    Returns:
        Whatever ``tokenizer.decode`` returns for the last row's ids.
    """
    # Hand the tokenizer a plain list of Python ints instead of a tensor slice,
    # so decoding does not depend on how the tokenizer treats tensor objects.
    last_sequence = np.asarray(token_ids)[-1].tolist()
    return tokenizer.decode(last_sequence)

# Round trip: encode a prompt to token ids, then decode the ids back to text.
token_ids = text_to_token_ids("Every effort moves you", tokenizer)
print("token_ids:", token_ids)
print(token_ids_to_text(token_ids, tokenizer))

token_ids: tf.Tensor([[6109 3626 6100  345]], shape=(1, 4), dtype=int32)
Every effort moves you


In [34]:
# End-to-end demo: tokenize a text prompt, generate 10 new tokens with the
# loaded model, and print the decoded result.
start_context = "Every effort moves you"
token_ids = generate_text_simple(
    model=gpt2,
    idx=text_to_token_ids(start_context, tokenizer),
    max_new_tokens=10,
    context_size=256
)

print("Output text:\n", token_ids_to_text(token_ids, tokenizer))

Output text:
 Every effort moves you forward.

The first step is to understand
