# Implement GPT-2 Using keras.ops and Convert PyTorch Weights to Keras Format

Import Libraries

In [304]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from keras import backend as K
import keras.ops as ops
import torch
from transformers import GPT2LMHeadModel, GPT2Model, GPT2Tokenizer
from numpy import dot
from numpy.linalg import norm

## Implement GPT2 using Keras.ops

**Linear Layer**

In [305]:
class CustomDense(keras.layers.Layer):
  def __init__(self, units, activation=None, name=None, **kwargs):
    super().__init__(name=name, **kwargs)
    self.units = units
    self.activation = keras.activations.get(activation)

  def build(self, input_shape):
    input_dim = input_shape[-1]

    self.kernel = self.add_weight(
        shape=(input_dim, self.units),
        initializer="glorot_uniform",
        trainable=True,
        name="kernel"
    )

    self.bias = self.add_weight(
        shape=(self.units,),
        initializer="zeros",
        trainable=True,
        name="bias"
    )

    self.built = True

  def call(self, inputs):
    outputs = ops.matmul(inputs, self.kernel)
    outputs = ops.add(outputs, self.bias)

    if self.activation is not None:
      outputs = self.activation(outputs)

    return outputs

**Embedding Layer**

In [306]:
class CustomEmbedding(keras.layers.Layer):
  def __init__(self, input_dim, output_dim, embeddings_initializer="uniform", mask_zero=False, name=None, **kwargs):
    super().__init__(name=name, **kwargs)
    self.input_dim = input_dim
    self.output_dim = output_dim
    self.embeddings_initializer = keras.initializers.get(embeddings_initializer)
    self.mask_zero = mask_zero
    self.supports_masking = self.mask_zero

  def build(self, input_shape):
    # here input_dim is vocab size and output dim is embedding dim
    self.embeddings = self.add_weight(
        shape=(self.input_dim, self.output_dim),
        initializer=self.embeddings_initializer,
        name="embeddings"
    )
    self.built=True

  def call(self, inputs):
    # embeddings has shape of (vocab_szie, embedding_dim).
    # u can think of each row represent a certain word.
    outputs = keras.ops.take(self.embeddings, inputs, axis=0)
    return outputs

  def get_config(self):
    config = super().get_config()
    config.update({
        "input_dim": self.input_dim,
        "output_dim": self.output_dim,
        "embeddings_initializer": keras.initializers.serialize(self.embeddings_initializer),
        "mask_zero": self.mask_zero,
    })
    return config

**MHA**

In [307]:
class CustomMHA(keras.layers.Layer):
  def __init__(self, d_model, num_heads, weight_initializer="glorot_uniform", causal=False, name=None, **kwargs):
    super().__init__(name=name, **kwargs)
    self.d_model = d_model
    self.num_heads = num_heads
    self.d_k = d_model // num_heads
    self.causal = causal
    self.weight_initializer = keras.initializers.get(weight_initializer)



  def build(self, input_shape):
    self.W_qkv = self.add_weight(shape=(self.d_model, 3*self.d_model), initializer=self.weight_initializer, name="W_qkv")
    self.b_qkv = self.add_weight(shape=(3*self.d_model,), initializer="zeros", name="b_qkv")
    self.W_o = self.add_weight(shape=(self.d_model, self.d_model), initializer=self.weight_initializer, name="W_o")
    self.b_o = self.add_weight(shape=(self.d_model,), initializer="zeros", name="b_o")


  def call(self, x):
    qkv = ops.matmul(x, self.W_qkv) # (N, T, D) @ (D, 3D) => (N, T, 3D)
    qkv = ops.add(qkv, self.b_qkv) # gptw had bais term (N, T, 3D)
    q, k, v = ops.split(qkv, 3, axis=-1)
    # q: (N, T, D)
    # v: (N, T, D)
    # v: (N, T, D)

    new_shape = (ops.shape(q)[0], ops.shape(q)[1], self.num_heads, self.d_k)
    q = ops.reshape(q, new_shape) # (N, T, num_heads, d_k)
    k = ops.reshape(k, new_shape) # (N, T, num_heads, d_k)
    v = ops.reshape(v, new_shape) # (N, T, num_heads, d_k)

    q = ops.transpose(q, (0, 2, 1, 3)) # (N, num_heads, T, d_k)
    k = ops.transpose(k, (0, 2, 1, 3)) # (N, num_heads, T, d_k)
    v = ops.transpose(v, (0, 2, 1, 3)) # (N, num_heads, T, d_k)


    dk_float = ops.cast(self.d_k, "float32")
    scale = ops.sqrt(dk_float)


    # (N, num_heads, T, d_k) @ (N, num_heads, d_k, T) => (N, num_heads, T, T)
    logits = ops.matmul(q, ops.transpose(k, (0, 1, 3, 2)))
    logits = logits / scale # (N, num_heads, T, T)


    if self.causal:
      seq_len = ops.shape(logits)[-1]
      causal_mask = ops.tril(ops.ones((seq_len, seq_len)))
      logits = logits + (1.0 - causal_mask) * -1e9
      '''
      the mask we're adding is:
      [
        0 -e9 -e9
        0 0 -e9
        0 0 0
      ]
      '''

    # (N, num_heads, T, T)
    weights = ops.softmax(logits, axis=-1)
    attn_out = ops.matmul(weights, v)
    # (N, num_heads, T, d_v)

    attn_out = ops.transpose(attn_out, (0, 2, 1, 3))
    out_shape = (ops.shape(attn_out)[0], ops.shape(attn_out)[1], self.num_heads * self.d_k)
    attn_out = ops.reshape(attn_out, out_shape)
    # (N, T, d_model)

    attn_out = ops.matmul(attn_out, self.W_o)
    # (N, T, d_model)
    attn_out = ops.add(attn_out, self.b_o)
    # (N, T, d_model)

    return attn_out

**FeedForwardNetwork**

In [308]:
class FeedFoward(keras.layers.Layer):
  def __init__(self, d_model, **kwargs):
    super().__init__(**kwargs)
    self.d_model = d_model


  def build(self, input_shape):
    # (N, T, d_model
    self.ff1 = CustomDense(units=self.d_model * 4, activation=keras.activations.gelu)
    # (N, T, 4*d_model)
    self.ff2 = CustomDense(units=self.d_model, activation=None)
    # (N, T, d_model)

  def call(self, inputs):
    inputs = self.ff1(inputs)
    inputs = self.ff2(inputs)
    return inputs

**Decoder Layer**

In [321]:
class TransformerDecoderLayer(keras.layers.Layer):
  def __init__(self, d_model, num_heads, weight_initializer="uniform", causal=True, name=None, **kwargs):
    super().__init__(name=name, **kwargs)
    self.d_model = d_model
    self.num_heads = num_heads
    self.causal = causal

    self.self_attention = CustomMHA(d_model, num_heads, weight_initializer=weight_initializer, causal=causal)
    self.feed_forward = FeedFoward(d_model)

    self.layernorm1 = keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm2 = keras.layers.LayerNormalization(epsilon=1e-6)


  def call(self, x):
    x = self.layernorm1(x)
    attn = self.self_attention(x)
    x = ops.add(x, attn)

    x = self.layernorm2(x)
    ff_output = self.feed_forward(x)
    x = ops.add(x, ff_output)



    return x

**GPT2**

In [322]:
class TransformerDecoder(keras.layers.Layer):
  def __init__(self, num_layers, d_model, num_heads, vocab_size, max_pos_encoding=1024, name=None, **kwargs):
    super().__init__(name=name, **kwargs)
    self.d_model = d_model
    self.num_layers = num_layers
    self.final_ln = keras.layers.LayerNormalization(epsilon=1e-6, name="ln_f")
    self.embedding = CustomEmbedding(input_dim=vocab_size, output_dim=d_model)
    self.pos_embedding = self.add_weight(shape=(max_pos_encoding, self.d_model), initializer="uniform", trainable=True, name="pos_embedding")
    self.dec_layers = [TransformerDecoderLayer(d_model, num_heads, causal=True, name=f"decoder_layer_{i}") for i in range(num_layers)]
    self.dropout = keras.layers.Dropout(0.1)


  def call(self, x):
    seq_len = x.shape[1]
    x = self.embedding(x)
    pos_embeds = self.pos_embedding[:seq_len, :]
    pos_embeds = ops.broadcast_to(pos_embeds, ops.shape(x))
    x = ops.add(x, pos_embeds)
    x = self.dropout(x)

    for i in range(self.num_layers):
      x = self.dec_layers[i](x)

    x = self.final_ln(x)
    final_logits = ops.matmul(x, ops.transpose(self.embedding.embeddings, (1, 0)))

    return final_logits

In [323]:
vocab_size = 50257
d_model = 768
num_heads = 12
num_layers = 12
seq_len = 512
batch_size = 1

# check whether the output of keras-gpt2 has correct shape.
gpt2 = TransformerDecoder(num_layers, d_model, num_heads, vocab_size)
sample_input = np.zeros((batch_size, seq_len), dtype='int32')
sample_output = gpt2(sample_input)
print(sample_output.shape) # (B, T, D)

(1, 512, 50257)


## Check Weight Keys

In [324]:
count = 0
for weight in gpt2.weights:
  print(weight.name)
  count += 1

print(f"{count}")

pos_embedding
gamma
beta
embeddings
W_qkv
b_qkv
W_o
b_o
kernel
bias
kernel
bias
gamma
beta
gamma
beta
W_qkv
b_qkv
W_o
b_o
kernel
bias
kernel
bias
gamma
beta
gamma
beta
W_qkv
b_qkv
W_o
b_o
kernel
bias
kernel
bias
gamma
beta
gamma
beta
W_qkv
b_qkv
W_o
b_o
kernel
bias
kernel
bias
gamma
beta
gamma
beta
W_qkv
b_qkv
W_o
b_o
kernel
bias
kernel
bias
gamma
beta
gamma
beta
W_qkv
b_qkv
W_o
b_o
kernel
bias
kernel
bias
gamma
beta
gamma
beta
W_qkv
b_qkv
W_o
b_o
kernel
bias
kernel
bias
gamma
beta
gamma
beta
W_qkv
b_qkv
W_o
b_o
kernel
bias
kernel
bias
gamma
beta
gamma
beta
W_qkv
b_qkv
W_o
b_o
kernel
bias
kernel
bias
gamma
beta
gamma
beta
W_qkv
b_qkv
W_o
b_o
kernel
bias
kernel
bias
gamma
beta
gamma
beta
W_qkv
b_qkv
W_o
b_o
kernel
bias
kernel
bias
gamma
beta
gamma
beta
W_qkv
b_qkv
W_o
b_o
kernel
bias
kernel
bias
gamma
beta
gamma
beta
148


In [325]:
model_name = "gpt2"
hf_model = GPT2Model.from_pretrained(model_name)
hf_model.train() # this is to ensure dropout is also enabled for hf gpt2
hf_state_dict = hf_model.state_dict()
numpy_state_dict = {k: v.cpu().numpy() for k, v in hf_state_dict.items()}

In [326]:
count = 0
for key in numpy_state_dict.keys():
  count += 1
  print(key)

print(f"total num of keys: {count}")

wte.weight
wpe.weight
h.0.ln_1.weight
h.0.ln_1.bias
h.0.attn.c_attn.weight
h.0.attn.c_attn.bias
h.0.attn.c_proj.weight
h.0.attn.c_proj.bias
h.0.ln_2.weight
h.0.ln_2.bias
h.0.mlp.c_fc.weight
h.0.mlp.c_fc.bias
h.0.mlp.c_proj.weight
h.0.mlp.c_proj.bias
h.1.ln_1.weight
h.1.ln_1.bias
h.1.attn.c_attn.weight
h.1.attn.c_attn.bias
h.1.attn.c_proj.weight
h.1.attn.c_proj.bias
h.1.ln_2.weight
h.1.ln_2.bias
h.1.mlp.c_fc.weight
h.1.mlp.c_fc.bias
h.1.mlp.c_proj.weight
h.1.mlp.c_proj.bias
h.2.ln_1.weight
h.2.ln_1.bias
h.2.attn.c_attn.weight
h.2.attn.c_attn.bias
h.2.attn.c_proj.weight
h.2.attn.c_proj.bias
h.2.ln_2.weight
h.2.ln_2.bias
h.2.mlp.c_fc.weight
h.2.mlp.c_fc.bias
h.2.mlp.c_proj.weight
h.2.mlp.c_proj.bias
h.3.ln_1.weight
h.3.ln_1.bias
h.3.attn.c_attn.weight
h.3.attn.c_attn.bias
h.3.attn.c_proj.weight
h.3.attn.c_proj.bias
h.3.ln_2.weight
h.3.ln_2.bias
h.3.mlp.c_fc.weight
h.3.mlp.c_fc.bias
h.3.mlp.c_proj.weight
h.3.mlp.c_proj.bias
h.4.ln_1.weight
h.4.ln_1.bias
h.4.attn.c_attn.weight
h.4.attn.c_at

## Load HuggingFace Weights into Keras GPT-2 Model

In [327]:
def load_hf_weights_into_keras_model(gpt2_keras, hf_weights):
  # embeddings
  gpt2_keras.embedding.embeddings.assign(hf_weights["wte.weight"])
  gpt2_keras.pos_embedding.assign(hf_weights["wpe.weight"])

  # decoder layers
  for i, decoder_layer in enumerate(gpt2_keras.dec_layers):
    # LN
    decoder_layer.layernorm1.gamma.assign(hf_weights[f"h.{i}.ln_1.weight"])
    decoder_layer.layernorm1.beta.assign(hf_weights[f"h.{i}.ln_1.bias"])

    decoder_layer.layernorm2.gamma.assign(hf_weights[f"h.{i}.ln_2.weight"])
    decoder_layer.layernorm2.beta.assign(hf_weights[f"h.{i}.ln_2.bias"])

    # MHA
    decoder_layer.self_attention.W_qkv.assign(hf_weights[f"h.{i}.attn.c_attn.weight"])
    decoder_layer.self_attention.b_qkv.assign(hf_weights[f"h.{i}.attn.c_attn.bias"])
    decoder_layer.self_attention.W_o.assign(hf_weights[f"h.{i}.attn.c_proj.weight"])
    decoder_layer.self_attention.b_o.assign(hf_weights[f"h.{i}.attn.c_proj.bias"])

    # FFN
    decoder_layer.feed_forward.ff1.kernel.assign(hf_weights[f"h.{i}.mlp.c_fc.weight"])
    decoder_layer.feed_forward.ff1.bias.assign(hf_weights[f"h.{i}.mlp.c_fc.bias"])
    decoder_layer.feed_forward.ff2.kernel.assign(hf_weights[f"h.{i}.mlp.c_proj.weight"])
    decoder_layer.feed_forward.ff2.bias.assign(hf_weights[f"h.{i}.mlp.c_proj.bias"])

  # final layer norm layer
  gpt2_keras.final_ln.gamma.assign(hf_weights["ln_f.weight"])
  gpt2_keras.final_ln.beta.assign(hf_weights["ln_f.bias"])

In [328]:
load_hf_weights_into_keras_model(gpt2, numpy_state_dict)

## Check Cosine Similarity between HF GPT2 and Keras GPT2 outputs

In [329]:
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
random_input_ids = np.random.randint(0, tokenizer.vocab_size, (1, 10))
torch_input = torch.tensor(random_input_ids)
print("Random input tokens (IDs):", random_input_ids)

Random input tokens (IDs): [[ 6374  1678 33827 16198  9914 27890 22299 43585 43689 42557]]


In [330]:
hf_model = GPT2LMHeadModel.from_pretrained(model_name)
hf_output = hf_model(torch_input).logits
print(hf_output.shape)

torch.Size([1, 10, 50257])


In [331]:
tf_input = tf.convert_to_tensor(random_input_ids)
keras_output = gpt2(tf_input).numpy()
print(keras_output.shape)

(1, 10, 50257)


In [334]:
hf_output_np = hf_output.detach().cpu().numpy()

cos_sim = dot(hf_output_np.flatten(), keras_output.flatten()) / (norm(hf_output_np.flatten()) * norm(keras_output.flatten()))
print(f"Cosine Similarity between HF GPT2 and Keras GPT2 outputs: {cos_sim}")


Cosine Similarity between HF GPT2 and Keras GPT2 outputs: 0.980139970779419
