<a href="https://colab.research.google.com/github/pranukrish/CMPE297-SpecialTopics/blob/main/Assignment3/NanoGPT_Jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q jax jaxlib flax optax tensorflow tensorflow-datasets

In [None]:
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
from flax.training import train_state
from jax import random

In [None]:
# Define the model
class NanoGPT(nn.Module):
    """Minimal GPT-style language model: embedding -> causal self-attention -> logits.

    Calling the module on integer token ids of shape (batch, length) returns
    unnormalised next-token logits of shape (batch, length, vocab_size).

    NOTE(review): the model has no positional embeddings, residual
    connections or layer normalisation; it is a single attention layer.
    """

    # Number of distinct token ids the model can embed and predict.
    vocab_size: int
    # Width of the token embeddings and of the attention projections.
    d_model: int = 128
    # Number of attention heads; d_model must be divisible by this value.
    num_heads: int = 4

    def setup(self):
        # nn.Embed needs both the table size and the embedding width.
        self.embedding = nn.Embed(
            num_embeddings=self.vocab_size,
            features=self.d_model,
        )
        # No dropout is configured (dropout_rate defaults to 0), so the layer
        # is marked deterministic and needs no 'dropout' RNG stream.
        self.transformer = nn.SelfAttention(
            num_heads=self.num_heads,
            qkv_features=self.d_model,
            use_bias=True,
            deterministic=True,
            dtype=jnp.float32,
        )
        self.fc = nn.Dense(self.vocab_size)

    def __call__(self, x):
        """Map token ids (batch, length) to logits (batch, length, vocab_size)."""
        # Build the mask from the integer ids so that each position can only
        # attend to itself and earlier positions.
        causal_mask = nn.make_causal_mask(x)
        x = self.embedding(x)
        x = self.transformer(x, mask=causal_mask)
        x = self.fc(x)
        return x

# Tokenization and preprocessing
tokenizer = tfds.deprecated.text.Tokenizer()

# Build vocabulary from the lower-cased training reviews
vocabulary = set()
for text, _ in tfds.load('imdb_reviews', split='train', as_supervised=True):
    vocabulary.update(tokenizer.tokenize(text.numpy().lower()))

# Encoder. The vocabulary is sorted so that token ids are reproducible
# between runs (set iteration order is not).
encoder = tfds.deprecated.text.TokenTextEncoder(sorted(vocabulary))

# Encode data
def encode(text_tensor, _):
    """Return the token ids of one review; the label argument is ignored."""
    # Lower-case exactly as was done when building the vocabulary, otherwise
    # every capitalised word would be mapped to the out-of-vocabulary id.
    return encoder.encode(text_tensor.numpy().lower())

def encode_map_fn(text, label):
    """Wrap `encode` so it can be used inside `tf.data.Dataset.map`."""
    return tf.py_function(encode, inp=[text, label], Tout=(tf.int64))

train_data = tfds.load('imdb_reviews', split='train', as_supervised=True).map(encode_map_fn)
test_data = tfds.load('imdb_reviews', split='test', as_supervised=True).map(encode_map_fn)

# Convert to fixed-length arrays: truncate long reviews, zero-pad short ones
MAX_LENGTH = 1000

def _pad_or_truncate(dataset, max_length):
    """Return an int32 array of shape (num_examples, max_length).

    Each sequence yielded by `dataset` is cut to `max_length` ids and then
    right-padded with zeros.
    """
    rows = []
    for ids in dataset.as_numpy_iterator():
        ids = np.asarray(ids, dtype=np.int32)[:max_length]
        # After truncation the pad width is always >= 0.
        rows.append(np.pad(ids, (0, max_length - len(ids)), 'constant'))
    return np.stack(rows)

train_data = _pad_or_truncate(train_data, MAX_LENGTH)
test_data = _pad_or_truncate(test_data, MAX_LENGTH)

# Hyperparameters
# Take the size from the encoder itself: besides the vocabulary words it also
# emits a padding id and an out-of-vocabulary bucket id, so
# len(vocabulary) + 1 would leave the largest id outside the embedding table.
VOCAB_SIZE = encoder.vocab_size
D_MODEL = 128
LR = 0.001
EPOCHS = 5
BATCH_SIZE = 64

# Training setup
rng = random.PRNGKey(0)
model = NanoGPT(vocab_size=VOCAB_SIZE, d_model=D_MODEL)
# Parameter shapes do not depend on the batch size, so a single dummy example
# is enough for initialisation and avoids a full-batch forward pass.
params = model.init(rng, jnp.ones((1, MAX_LENGTH), jnp.int32))
tx = optax.adam(LR)
# `params` is the full variables dict returned by `init`; it is what
# `model.apply` expects as its first argument.
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

# Loss function
def compute_loss(params, batch):
    """Return the mean next-token cross-entropy over non-padding positions.

    Args:
        params: variables dict accepted by `model.apply`.
        batch: integer token ids of shape (batch, length), zero-padded on the
            right.

    Returns:
        Scalar loss: the model reads tokens [0, length-1) and is scored on
        predicting tokens [1, length).
    """
    inputs = batch[:, :-1]
    targets = batch[:, 1:]
    logits = model.apply(params, inputs)
    per_token = optax.softmax_cross_entropy_with_integer_labels(logits, targets)
    # Id 0 is the padding value; exclude those positions from the average.
    not_pad = (targets != 0).astype(per_token.dtype)
    # Guard the division in case a batch contains only padding.
    return jnp.sum(per_token * not_pad) / jnp.maximum(jnp.sum(not_pad), 1.0)

# Training step
@jax.jit
def train_step(state, batch):
    """Apply one optimiser update and return the new train state.

    Args:
        state: current `TrainState` (parameters and optimiser state).
        batch: array of token ids passed unchanged to `compute_loss`.

    Returns:
        The `TrainState` after applying the gradients of the loss on `batch`.
    """
    # JIT-compiled so the forward/backward pass runs as one compiled program
    # instead of being dispatched op by op on every step.
    def loss_fn(params):
        return compute_loss(params, batch)

    grads = jax.grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads)

# Training loop: every epoch walks the training array in consecutive,
# full-size batches; a trailing partial batch is skipped.
num_full_batches = len(train_data) // BATCH_SIZE
for epoch in range(EPOCHS):
    for batch_index in range(num_full_batches):
        start = batch_index * BATCH_SIZE
        stop = start + BATCH_SIZE
        state = train_step(state, train_data[start:stop])
    # (Optional) Evaluate on test data and print metrics
