Goal
-Train a small transformer-based language model (MiniGPT) using JAX + Flax and optimize its performance.

1️⃣ Install Dependencies

In [1]:
!pip install -q jax flax optax datasets transformers


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.2/69.2 kB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m485.4/485.4 kB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m183.9/183.9 kB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m40.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.8/194.8 kB[0m [31m8.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m274.9/274.9 kB[0m [31m14.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━

2️⃣ Import Required Libraries

In [2]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from transformers import AutoTokenizer
from datasets import load_dataset


3️⃣ Define a Simple Transformer Model (MiniGPT)

In [3]:
class MiniGPT(nn.Module):
    vocab_size: int
    embed_dim: int
    num_heads: int
    num_layers: int
    hidden_dim: int

    def setup(self):
        self.token_embedding = nn.Embed(self.vocab_size, self.embed_dim)
        self.transformer_layers = [
            nn.SelfAttention(num_heads=self.num_heads)
            for _ in range(self.num_layers)
        ]
        self.dense = nn.Dense(self.vocab_size)

    def __call__(self, x):
        x = self.token_embedding(x)
        for layer in self.transformer_layers:
            x = layer(x)
        return self.dense(x)


4️⃣ Load and Tokenize a Text Dataset

In [6]:
from datasets import load_dataset
from transformers import AutoTokenizer

# Step 1: Load the dataset (Example: Using 'wikitext' dataset)
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

# Step 2: Load the tokenizer (Example: GPT-2)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Step 3: Set a padding token if the tokenizer doesn't have one
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # Use EOS token as padding

# Step 4: Define a function to tokenize the text
def tokenize_fn(batch):
    return tokenizer(
        batch["text"],
        padding="max_length",   # Pad to max length
        truncation=True,        # Truncate longer sequences
        max_length=128,         # Set maximum token length
        return_tensors="np"     # Convert to NumPy format (for JAX)
    )

# Step 5: Tokenize the dataset
tokenized_dataset = dataset.map(tokenize_fn, batched=True)

# Step 6: Print an example
print(tokenized_dataset["train"][0])  # Check the first tokenized example



tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Map:   0%|          | 0/4358 [00:00<?, ? examples/s]

Map:   0%|          | 0/36718 [00:00<?, ? examples/s]

Map:   0%|          | 0/3760 [00:00<?, ? examples/s]

{'text': '', 'input_ids': [50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256], 'attention_mask': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

Full training model

In [10]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from datasets import load_dataset
from transformers import AutoTokenizer
from flax.training import train_state

# ✅ Step 1: Load Dataset
dataset = load_dataset("imdb")  # Example dataset (IMDb reviews)

# ✅ Step 2: Load Tokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# ✅ Step 3: Define Tokenization Function
def tokenize_fn(batch):
    return tokenizer(
        batch["text"],
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="np"  # Ensure NumPy format for JAX
    )

# ✅ Step 4: Apply Tokenization to Dataset
tokenized_dataset = dataset.map(tokenize_fn, batched=True)

# ✅ Step 5: Define a Simple Model in Flax
class MiniGPT(nn.Module):
    vocab_size: int

    @nn.compact
    def __call__(self, input_ids):
        embed = nn.Embed(self.vocab_size, 128)(input_ids)  # Embedding layer
        x = nn.Dense(256)(embed)  # Dense layer
        x = nn.relu(x)
        x = nn.Dense(self.vocab_size)(x)  # Output layer
        return x

# ✅ Step 6: Initialize Model
rng = jax.random.PRNGKey(0)
model = MiniGPT(vocab_size=len(tokenizer))

# ✅ Step 7: Create Model Parameters
params = model.init(rng, jnp.ones((1, 128), dtype=jnp.int32))  # Dummy input

# ✅ Step 8: Define Loss Function
def loss_fn(params, batch):
    logits = model.apply(params, batch)
    labels = batch  # For simplicity, using input as labels
    loss = jnp.mean(optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, len(tokenizer))))
    return loss

# ✅ Step 9: Optimizer & Training State
tx = optax.adam(learning_rate=1e-3)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

# ✅ Step 10: Training Step
@jax.jit
def train_step(state, batch):
    def loss_fn_wrapper(params):
        return loss_fn(params, batch)
    grad_fn = jax.value_and_grad(loss_fn_wrapper)
    loss, grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss

# ✅ Step 11: Train the Model
for epoch in range(3):
    for batch in tokenized_dataset["train"].shuffle(seed=42).select(range(100)):  # Small batch for testing
        input_ids = jnp.array(batch["input_ids"], dtype=jnp.int32)  # Ensure integer input
        state, loss = train_step(state, input_ids)
    print(f"Epoch {epoch+1} - Loss: {loss:.4f}")

print("Training Completed Successfully ✅🚀")


README.md:   0%|          | 0.00/7.81k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/21.0M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/20.5M [00:00<?, ?B/s]

unsupervised-00000-of-00001.parquet:   0%|          | 0.00/42.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised split:   0%|          | 0/50000 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

Epoch 1 - Loss: 4.8063
Epoch 2 - Loss: 3.3170
Epoch 3 - Loss: 2.4390
Training Completed Successfully ✅🚀


📌 Optimized Code for Training on GPU/TPU

In [11]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from datasets import load_dataset
from transformers import AutoTokenizer
from flax.training import train_state

# ✅ Step 1: Load Dataset
dataset = load_dataset("imdb")  # Example: IMDb reviews dataset

# ✅ Step 2: Load Tokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# ✅ Step 3: Tokenization Function (Efficient with Padding)
def tokenize_fn(batch):
    return tokenizer(
        batch["text"],
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="np"
    )

# ✅ Step 4: Apply Tokenization to Dataset
tokenized_dataset = dataset.map(tokenize_fn, batched=True)

# ✅ Step 5: Convert to JAX Tensors for Faster Processing
def format_data(batch):
    return {
        "input_ids": jnp.array(batch["input_ids"], dtype=jnp.int32),
        "attention_mask": jnp.array(batch["attention_mask"], dtype=jnp.float32)
    }

tokenized_dataset.set_format("numpy")
train_data = tokenized_dataset["train"].map(format_data)
train_data = list(train_data)

# ✅ Step 6: Define a Transformer Model in Flax
class MiniGPT(nn.Module):
    vocab_size: int

    @nn.compact
    def __call__(self, input_ids, attention_mask):
        embed = nn.Embed(self.vocab_size, 128)(input_ids)
        x = nn.Dense(256)(embed)
        x = nn.relu(x)
        x = nn.Dense(self.vocab_size)(x)
        return x

# ✅ Step 7: Initialize Model
rng = jax.random.PRNGKey(0)
model = MiniGPT(vocab_size=len(tokenizer))

# ✅ Step 8: Create Model Parameters
params = model.init(rng, jnp.ones((1, 128), dtype=jnp.int32), jnp.ones((1, 128)))

# ✅ Step 9: Define Loss Function
def loss_fn(params, batch):
    logits = model.apply(params, batch["input_ids"], batch["attention_mask"])
    labels = batch["input_ids"]
    loss = jnp.mean(optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, len(tokenizer))))
    return loss

# ✅ Step 10: Optimizer & Training State
tx = optax.adamw(learning_rate=3e-4)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

# ✅ Step 11: Training Step (JIT-Compiled for GPU/TPU)
@jax.jit
def train_step(state, batch):
    def loss_fn_wrapper(params):
        return loss_fn(params, batch)
    grad_fn = jax.value_and_grad(loss_fn_wrapper)
    loss, grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss

# ✅ Step 12: Train the Model (Efficient Batching)
batch_size = 32
epochs = 3

for epoch in range(epochs):
    total_loss = 0.0
    num_batches = len(train_data) // batch_size

    for i in range(num_batches):
        batch = {
            "input_ids": jnp.stack([train_data[j]["input_ids"] for j in range(i * batch_size, (i + 1) * batch_size)]),
            "attention_mask": jnp.stack([train_data[j]["attention_mask"] for j in range(i * batch_size, (i + 1) * batch_size)])
        }

        state, loss = train_step(state, batch)
        total_loss += loss

        if (i + 1) % 10 == 0:  # Print every 10 batches
            print(f"Epoch {epoch+1}, Batch {i+1}/{num_batches} - Loss: {loss:.4f}")

    print(f"Epoch {epoch+1} - Avg Loss: {total_loss / num_batches:.4f}")

print("🚀 Training Completed Successfully!")


Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Epoch 1, Batch 10/781 - Loss: 10.2465
Epoch 1, Batch 20/781 - Loss: 10.1595
Epoch 1, Batch 30/781 - Loss: 10.0543
Epoch 1, Batch 40/781 - Loss: 9.9503
Epoch 1, Batch 50/781 - Loss: 9.8389
Epoch 1, Batch 60/781 - Loss: 9.6983
Epoch 1, Batch 70/781 - Loss: 9.4948
Epoch 1, Batch 80/781 - Loss: 9.2692
Epoch 1, Batch 90/781 - Loss: 9.0719
Epoch 1, Batch 100/781 - Loss: 8.7522
Epoch 1, Batch 110/781 - Loss: 8.4790
Epoch 1, Batch 120/781 - Loss: 8.0079
Epoch 1, Batch 130/781 - Loss: 7.5516
Epoch 1, Batch 140/781 - Loss: 7.1254
Epoch 1, Batch 150/781 - Loss: 6.3764
Epoch 1, Batch 160/781 - Loss: 5.4394
Epoch 1, Batch 170/781 - Loss: 4.9734
Epoch 1, Batch 180/781 - Loss: 4.5680
Epoch 1, Batch 190/781 - Loss: 4.0754
Epoch 1, Batch 200/781 - Loss: 3.7293
Epoch 1, Batch 210/781 - Loss: 3.5789
Epoch 1, Batch 220/781 - Loss: 3.3354
Epoch 1, Batch 230/781 - Loss: 3.2881
Epoch 1, Batch 240/781 - Loss: 3.1055
Epoch 1, Batch 250/781 - Loss: 2.7363
Epoch 1, Batch 260/781 - Loss: 2.7166
Epoch 1, Batch 270

🚀 Performance Optimizations
🔹 JAX Array Preloading: Converts dataset to JAX arrays before training.
🔹 Mini-batching (Batch Size = 32): Efficiently processes multiple samples at once.
🔹 AdamW Optimizer (3e-4 LR): Faster and stable convergence.
🔹 JIT-Compilation (@jax.jit): Significantly speeds up training on TPU/GPU.
🔹 Softmax Cross-Entropy Loss: More stable training for text data.