In [None]:
%pip install --upgrade jax flax optax orbax-checkpoint transformers --quiet

In [None]:
# preferred
!pip install --upgrade "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html


In [None]:
!nvidia-smi
!nvcc --version || true

Fri Oct  3 11:06:10 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   53C    P8             11W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
import jax.numpy as jnp
def tokenize(word_list, max_len):
    """Encode space-separated sentences as a fixed-width array of token ids.

    The vocabulary is every distinct token obtained by splitting each sentence
    on single spaces, plus ``'<EOS>'``; it is sorted so ids are deterministic.
    Each sentence is right-padded with the ``'<EOS>'`` id up to ``max_len``
    tokens. Sentences longer than ``max_len`` are truncated so that every row
    has the same length.

    Args:
        word_list: Sequence of sentence strings (iterated twice).
        max_len: Number of token ids per output row.

    Returns:
        A ``jnp`` array with one row of ``max_len`` token ids per sentence.
    """
    vocab = {'<EOS>'}
    for sentence in word_list:
        vocab.update(sentence.split(' '))
    word_to_idx = {word: index for index, word in enumerate(sorted(vocab))}
    eos_id = word_to_idx['<EOS>']

    dataset = []
    for sentence in word_list:
        # Truncate before padding: a row longer than max_len would make the
        # nested list ragged, which jnp.array cannot turn into a 2-D array.
        tokens = [word_to_idx[token] for token in sentence.split(' ')][:max_len]
        tokens += [eos_id] * (max_len - len(tokens))
        dataset.append(tokens)
    return jnp.array(dataset)

In [None]:
sentences = [
    "the cat sat on the mat",
    "the dog ate my homework",
    "jax is a high performance numerical computing library",
    "transformers are powerful models for sequence processing",
    "a decoder only model predicts the next token in a sequence",
    "flax is a neural network library for jax",
    "we will train this model from scratch",
    "attention is all you need"
]

In [None]:
# Length of the longest sentence, counted in space-separated tokens
# (0 when `sentences` is empty).
g = max((len(sentence.split(' ')) for sentence in sentences), default=0)

In [None]:
from flax import nnx

class Transformer(nnx.Module):
    """Pre-norm decoder block: causal self-attention plus a three-layer feed-forward stack.

    NOTE: a later cell in this notebook defines another ``Transformer`` (with an
    optional MoE feed-forward); once that cell has run it shadows this class.

    Args:
        max_lim: Accepted for interface compatibility; not used by this block.
        num_heads: Number of attention heads.
        in_features: Feature width of the block's input and output.
        attention_dropout: Dropout rate handed to the attention layer.
        ffn_dropout: Dropout rate applied after the first two feed-forward layers.
        rngs: ``nnx.Rngs`` passed to every sub-layer.
    """

    def __init__(self, max_lim, num_heads, in_features, attention_dropout, ffn_dropout:float, rngs: nnx.Rngs):
        super().__init__()
        # Sub-layers are created in the same order as before so that they draw
        # from `rngs` in the same sequence.
        self.layer_norm_0 = nnx.LayerNorm(in_features, rngs=rngs, param_dtype=jnp.float32)
        self.attention = nnx.MultiHeadAttention(
            num_heads=num_heads,
            in_features=in_features,
            dropout_rate=attention_dropout,
            rngs=rngs,
            param_dtype=jnp.float32,
        )
        self.layer_norm = nnx.LayerNorm(num_features=in_features, rngs=rngs, param_dtype=jnp.float32)
        # Feed-forward widths: in -> 2*in -> 4*in -> in.
        self.linear_1 = nnx.Linear(
            in_features=in_features,
            out_features=in_features * 2,
            rngs=rngs,
            param_dtype=jnp.float32,
        )
        self.dropout_1 = nnx.Dropout(rate=ffn_dropout, rngs=rngs)
        self.linear_2 = nnx.Linear(
            in_features=in_features * 2,
            out_features=in_features * 4,
            rngs=rngs,
            param_dtype=jnp.float32,
        )
        self.dropout_2 = nnx.Dropout(rate=ffn_dropout, rngs=rngs)
        self.linear_3 = nnx.Linear(
            in_features=in_features * 4,
            out_features=in_features,
            rngs=rngs,
            param_dtype=jnp.float32,
        )
        self.layer_norm_2 = nnx.LayerNorm(num_features=in_features, rngs=rngs, param_dtype=jnp.float32)

    def __call__(self, x, train: bool = True, decode: bool = False):
        """Apply the block to ``x``.

        Args:
            x: Activations whose last two axes are (sequence, features).
            train: When False, dropout is disabled (``deterministic=True``).
            decode: Forwarded to the attention layer's ``decode`` argument.

        Returns:
            An array with the same shape as ``x``.
        """
        q_value = self.layer_norm_0(x)

        # Next-token training needs a causal mask: without it position t can
        # attend to positions > t, i.e. to the very tokens it must predict.
        # As before, no explicit mask is passed in decode mode.
        mask = None
        if not decode:
            seq_len = x.shape[-2]
            # Lower-triangular boolean mask, broadcast over batch and heads.
            mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))[None, None, :, :]

        output_attention = x + self.attention(q_value, mask=mask, deterministic=not train, decode=decode)

        linear_out = nnx.gelu(self.dropout_1(self.linear_1(self.layer_norm(output_attention)), deterministic=not train))
        linear_out = nnx.gelu(self.dropout_2(self.linear_2(linear_out), deterministic=not train))
        linear_out = nnx.gelu(self.linear_3(linear_out))

        # Residual connection around the (normalised) feed-forward output.
        return output_attention + self.layer_norm_2(linear_out)

In [None]:
class Decoder(nnx.Module):
  """Token + position embeddings, a stack of ``Transformer`` blocks, and a vocabulary projection.

  NOTE: a later cell in this notebook defines another ``Decoder``; once that
  cell has run it shadows this class.

  Args:
    vocab_size: Number of token ids (embedding rows and output logits).
    max_len: Number of rows in the position-embedding table.
    in_features: Embedding / model width.
    attention_dropout: Dropout rate forwarded to every block's attention.
    ffn_dropout: Dropout rate forwarded to every block's feed-forward stack.
    num_transformers: Number of stacked ``Transformer`` blocks.
    num_head: Number of attention heads per block.
    rngs: ``nnx.Rngs`` passed to every sub-layer.
  """

  def __init__(self, vocab_size: int, max_len: int, in_features: int, attention_dropout: float, ffn_dropout: float, num_transformers: int, num_head: int, rngs: nnx.Rngs):
    super().__init__()
    self.token_embed = nnx.Embed(vocab_size, features=in_features, rngs = rngs, param_dtype=jnp.float32)
    self.pos_embedding = nnx.Embed(max_len, features=in_features, rngs = rngs, param_dtype=jnp.float32)
    self.transformer = nnx.Sequential(*[Transformer(max_len, num_head, in_features, attention_dropout, ffn_dropout, rngs) for _ in range(num_transformers)])
    self.linear = nnx.Linear(in_features, vocab_size, rngs = rngs, param_dtype=jnp.float32)

  def __call__(self, x, train: bool = True, decode: bool = False):
    """Return next-token logits for integer token ids ``x`` of shape (batch, seq_len)."""
    seq_len = x.shape[1]
    # NOTE(review): positions always start at 0, so single-token decode steps
    # would all receive position 0 -- confirm before relying on decode=True.
    positions = jnp.arange(0, seq_len)
    hidden = self.token_embed(x) + self.pos_embedding(positions)

    # nnx.Sequential forwards keyword arguments to its first layer only, so
    # calling it with train=/decode= leaves every later block on its defaults
    # (train=True, decode=False). Call each block explicitly instead.
    for block in self.transformer.layers:
      hidden = block(hidden, train=train, decode=decode)

    logits = self.linear(hidden)

    return logits

In [None]:
import jax
import jax.numpy as jnp
from flax import nnx
from typing import Optional, Dict, Any, Tuple

class MoE(nnx.Module):
    """
    Simple mixture-of-experts layer.

      - ``num_experts`` experts, each a small MLP (in -> hidden -> in)
      - a linear gate yields a softmax weight per expert for every token
      - when ``1 < top_k < num_experts`` only the ``top_k`` largest gates are
        kept and renormalised; for any other ``top_k`` -- including the default
        of 1 -- the gates stay dense and every expert contributes
      - every expert is evaluated on every token; the output is the
        gate-weighted sum
      - returns ``(moe_out, aux_loss)`` where ``aux_loss`` is a scalar
        load-balancing term
    """
    def __init__(
        self,
        num_experts: int,
        in_features: int,
        expert_hidden: Optional[int] = None,
        top_k: int = 1,
        dropout_rate: float = 0.0,
        rngs: nnx.Rngs = None,
    ):
        super().__init__()
        expert_hidden = expert_hidden or (in_features * 4)
        self.num_experts = num_experts
        self.in_features = in_features
        self.expert_hidden = expert_hidden
        self.top_k = top_k

        # gating network: project token -> num_experts logits
        self.gate = nnx.Linear(in_features=in_features, out_features=num_experts, rngs=rngs, param_dtype=jnp.float32)

        # create experts: each expert is a 2-layer MLP (in -> hidden -> in)
        # We use a list of nnx.Sequential to keep each expert isolated.
        self.experts = nnx.List([
            nnx.Sequential(
                nnx.Linear(in_features=in_features, out_features=expert_hidden, rngs=rngs, param_dtype=jnp.float32),
                nnx.gelu,
                nnx.Dropout(rate=dropout_rate, rngs=rngs),
                nnx.Linear(in_features=expert_hidden, out_features=in_features, rngs=rngs, param_dtype=jnp.float32),
                nnx.Dropout(rate=dropout_rate, rngs=rngs),
            )
            for _ in range(num_experts)
        ])

    def __call__(self, x: jnp.ndarray, train: bool = True) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        x: [batch, seq_len, in_features]
        train: when False, dropout inside the experts is disabled
        returns:
          - out: [batch, seq_len, in_features]
          - aux_loss: scalar (load balancing)
        """
        # Unpacking also enforces that x is rank 3.
        _, _, dim = x.shape
        assert dim == self.in_features, f"MoE expected in_features={self.in_features}, got {dim}"

        # Gate logits per token -> shape [batch, seq_len, num_experts]
        gate_logits = self.gate(x)  # no activation
        gates = jax.nn.softmax(gate_logits, axis=-1)  # [B, T, E]

        # Top-k masking is only applied for 1 < top_k < num_experts.
        if self.top_k > 1 and self.top_k < self.num_experts:
            # k-th largest gate per token: sort ascending, take k-th from the end
            kth_vals = jnp.sort(gates, axis=-1)[..., -self.top_k][..., None]  # [B, T, 1]
            mask = (gates >= kth_vals).astype(gates.dtype)  # keep top-k
            # renormalize gated probabilities to sum to 1 over the selected experts
            gates = gates * mask
            denom = jnp.sum(gates, axis=-1, keepdims=True) + 1e-9
            gates = gates / denom

        def apply_expert(expert_module, x_in):
            # Walk the expert's layers by hand so the dropout layers receive an
            # explicit `deterministic` flag; calling the Sequential directly
            # would ignore `train` and leave dropout in its default mode.
            hidden = x_in
            for layer in expert_module.layers:
                if isinstance(layer, nnx.Dropout):
                    hidden = layer(hidden, deterministic=not train)
                else:
                    hidden = layer(hidden)
            return hidden

        # Evaluate every expert on the full input and stack: [E, B, T, D]
        expert_outs = jnp.stack([apply_expert(e, x) for e in self.experts], axis=0)

        # gates: [B, T, E] -> move E axis in front for broadcasting: [E, B, T, 1]
        gates_for_mult = jnp.transpose(gates, axes=(2, 0, 1))[..., None]
        weighted = expert_outs * gates_for_mult  # [E, B, T, D]
        moe_out = jnp.sum(weighted, axis=0)  # [B, T, D]

        # Aux loss: encourage load balancing across experts.
        # Proxy: mean gate usage per expert over batch and time, penalised by its variance.
        mean_gate = jnp.mean(gates, axis=(0, 1))  # shape [E,]
        aux_loss = jnp.var(mean_gate)  # smaller variance -> more balanced
        # Scale auxiliary loss by number of experts so magnitude roughly consistent
        aux_loss = aux_loss * float(self.num_experts)

        return moe_out, aux_loss

class Transformer(nnx.Module):
    """Pre-norm decoder block: causal self-attention followed by an MoE or plain feed-forward.

    Args:
        max_lim: Accepted for interface compatibility; not used by this block.
        num_heads: Number of attention heads.
        in_features: Feature width of the block's input and output.
        attention_dropout: Dropout rate handed to the attention layer.
        ffn_dropout: Dropout rate for the plain feed-forward path; also the
            default expert dropout when ``moe_config`` gives no ``dropout_rate``.
        rngs: ``nnx.Rngs`` passed to every sub-layer.
        use_moe: If True the feed-forward is an ``MoE`` layer, otherwise a
            three-layer MLP (in -> 2*in -> 4*in -> in).
        moe_config: Optional overrides for the MoE layer. Recognised keys:
            ``num_experts`` (default 4), ``top_k`` (default 1),
            ``expert_hidden`` (default ``in_features * 4``) and
            ``dropout_rate`` (default ``ffn_dropout``).
    """
    def __init__(
        self,
        max_lim,
        num_heads,
        in_features,
        attention_dropout,
        ffn_dropout: float,
        rngs: nnx.Rngs,
        use_moe: bool = True,
        moe_config: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()
        self.layer_norm_0 = nnx.LayerNorm(in_features, rngs=rngs, param_dtype=jnp.float32)
        self.attention = nnx.MultiHeadAttention(
            num_heads=num_heads,
            in_features=in_features,
            dropout_rate=attention_dropout,
            rngs=rngs,
            param_dtype=jnp.float32
        )
        self.layer_norm = nnx.LayerNorm(num_features=in_features, rngs=rngs, param_dtype=jnp.float32)

        # Option: normal FFN or MoE-based FFN
        self.use_moe = use_moe
        if self.use_moe:
            # default MoE params if not provided
            moe_config = moe_config or {}
            num_experts = moe_config.get("num_experts", 4)
            top_k = moe_config.get("top_k", 1)
            expert_hidden = moe_config.get("expert_hidden", in_features * 4)
            moe_dropout = moe_config.get("dropout_rate", ffn_dropout)
            self.moe = MoE(
                num_experts=num_experts,
                in_features=in_features,
                expert_hidden=expert_hidden,
                top_k=top_k,
                dropout_rate=moe_dropout,
                rngs=rngs
            )
        else:
            # Plain FFN (in -> 2*in -> 4*in -> in)
            self.linear_1 = nnx.Linear(
                in_features=in_features,
                out_features=in_features * 2,
                rngs=rngs,
                param_dtype=jnp.float32
            )
            self.dropout_1 = nnx.Dropout(rate=ffn_dropout, rngs=rngs)
            self.linear_2 = nnx.Linear(
                in_features=in_features * 2,
                out_features=in_features * 4,
                rngs=rngs,
                param_dtype=jnp.float32
            )
            self.dropout_2 = nnx.Dropout(rate=ffn_dropout, rngs=rngs)
            self.linear_3 = nnx.Linear(
                in_features=in_features * 4,
                out_features=in_features,
                rngs=rngs,
                param_dtype=jnp.float32
            )

        self.layer_norm_2 = nnx.LayerNorm(num_features=in_features, rngs=rngs, param_dtype=jnp.float32)

    def __call__(self, x, train: bool = True, decode: bool = False, return_aux: bool = False):
        """
        Apply the block to ``x`` (last two axes: sequence, features).

        ``train=False`` disables dropout; ``decode`` is forwarded to the
        attention layer.

        If return_aux==True and MoE is used, returns (output, aux_loss_sum)
        Otherwise returns output only (compatibility with existing code).
        """
        aux_loss_sum = 0.0
        q_value = self.layer_norm_0(x)

        # Next-token training needs a causal mask: without it position t can
        # attend to positions > t, i.e. to the very tokens it must predict.
        # As before, no explicit mask is passed in decode mode.
        mask = None
        if not decode:
            seq_len = x.shape[-2]
            # Lower-triangular boolean mask, broadcast over batch and heads.
            mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))[None, None, :, :]

        output_attention = x + self.attention(q_value, mask=mask, deterministic=not train, decode=decode)

        if self.use_moe:
            # MoE returns (moe_out, aux_loss)
            moe_out, aux_loss = self.moe(self.layer_norm(output_attention), train=train)
            linear_out = nnx.gelu(moe_out)
            aux_loss_sum += aux_loss
        else:
            linear_out = nnx.gelu(self.dropout_1(self.linear_1(self.layer_norm(output_attention)), deterministic=not train))
            linear_out = nnx.gelu(self.dropout_2(self.linear_2(linear_out), deterministic=not train))
            linear_out = nnx.gelu(self.linear_3(linear_out))

        # Residual connection around the (normalised) feed-forward output.
        out = output_attention + self.layer_norm_2(linear_out)

        if return_aux and self.use_moe:
            return out, aux_loss_sum
        return out

class Decoder(nnx.Module):
    """Token + position embeddings, a stack of ``Transformer`` blocks, and a vocabulary projection.

    Args:
        vocab_size: Number of token ids (embedding rows and output logits).
        max_len: Number of rows in the position-embedding table.
        in_features: Embedding / model width.
        attention_dropout: Dropout rate forwarded to every block's attention.
        ffn_dropout: Dropout rate forwarded to every block's feed-forward.
        num_transformers: Number of stacked ``Transformer`` blocks.
        num_head: Number of attention heads per block.
        rngs: ``nnx.Rngs`` passed to every sub-layer.
        transformer_ctor_kwargs: Extra keyword arguments passed to every
            ``Transformer`` constructor (e.g. ``use_moe``, ``moe_config``).
    """
    def __init__(
        self,
        vocab_size: int,
        max_len: int,
        in_features: int,
        attention_dropout: float,
        ffn_dropout: float,
        num_transformers: int,
        num_head: int,
        rngs: nnx.Rngs,
        transformer_ctor_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()
        self.token_embed = nnx.Embed(vocab_size, features=in_features, rngs=rngs, param_dtype=jnp.float32)
        self.pos_embedding = nnx.Embed(max_len, features=in_features, rngs=rngs, param_dtype=jnp.float32)

        # The same constructor kwargs are applied to every block.
        transformer_ctor_kwargs = transformer_ctor_kwargs or {}
        transformers = [
            Transformer(
                max_len,
                num_head,
                in_features,
                attention_dropout,
                ffn_dropout,
                rngs,
                **transformer_ctor_kwargs
            ) for _ in range(num_transformers)
        ]
        self.transformer = nnx.Sequential(*transformers)
        self.linear = nnx.Linear(in_features, vocab_size, rngs=rngs, param_dtype=jnp.float32)

    def __call__(self, x, train: bool = True, decode: bool = False, return_aux: bool = False):
        """Return next-token logits for integer token ids ``x`` of shape (batch, seq_len).

        Args:
            x: Integer token ids, shape (batch, seq_len).
            train: When False, dropout is disabled in every block.
            decode: Forwarded to every block's attention layer.
            return_aux: If True, return ``(logits, aux_total)`` where
                ``aux_total`` is the sum of the MoE auxiliary losses reported
                by the blocks (0.0 when no block uses MoE).

        Returns:
            ``logits`` of shape (batch, seq_len, vocab_size), or
            ``(logits, aux_total)`` when ``return_aux`` is True.
        """
        seq_len = x.shape[1]
        # NOTE(review): positions always start at 0, so single-token decode
        # steps would all receive position 0 -- confirm before relying on
        # decode=True.
        positions = jnp.arange(0, seq_len)
        cur = self.token_embed(x) + self.pos_embedding(positions)

        aux_total = 0.0

        # nnx.Sequential forwards keyword arguments to its first layer only, so
        # calling it with train=/decode= leaves every later block on its
        # defaults, and it cannot unpack (out, aux) tuples either. Call each
        # block explicitly and collect any auxiliary loss it reports.
        for layer_module in self.transformer.layers:
            maybe_out = layer_module(cur, train=train, decode=decode, return_aux=True)
            if isinstance(maybe_out, tuple) and len(maybe_out) == 2:
                cur, aux_loss = maybe_out
                aux_total = aux_total + aux_loss
            else:
                cur = maybe_out

        logits = self.linear(cur)

        if return_aux:
            return logits, aux_total
        return logits


In [None]:
import requests
import os
def download_shakespeare():
    """Return the Tiny Shakespeare corpus as one string, downloading it on first use.

    The text is cached in ``tinyshakespeare.txt`` in the working directory and
    re-read from there on later calls.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection failure or timeout.
    """
    url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    filename = 'tinyshakespeare.txt'
    if not os.path.exists(filename):
        print(f"Downloading {filename}...")
        response = requests.get(url, timeout=30)
        # Fail before writing: otherwise an error page would be cached as the
        # corpus and silently reused on every later call.
        response.raise_for_status()
        with open(filename, 'w', encoding='utf-8') as f:
            f.write(response.text)
        print("Download complete.")

    with open(filename, 'r', encoding='utf-8') as f:
        return f.read()

In [None]:
def process_character_data(text: str):
    """
    Build a character-level vocabulary from ``text`` and encode it as integer ids.

    Ids are assigned by the sorted order of the distinct characters in ``text``.

    Args:
      text: The whole corpus as a single string.

    Returns:
      A 4-tuple ``(data, vocab_size, char_to_idx, idx_to_char)``:
        - data: JAX array holding one id per character of ``text``.
        - vocab_size: Count of distinct characters.
        - char_to_idx: Mapping from character to id.
        - idx_to_char: Mapping from id back to character.
    """
    alphabet = sorted(set(text))

    char_to_idx = {}
    idx_to_char = {}
    for position, symbol in enumerate(alphabet):
        char_to_idx[symbol] = position
        idx_to_char[position] = symbol

    encoded = jnp.array([char_to_idx[symbol] for symbol in text])

    return encoded, len(alphabet), char_to_idx, idx_to_char
# k, v, m, a = process_character_data(download_shakespeare())

In [None]:
from datasets import load_dataset

wikitext_dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")

full_text = "\n".join(wikitext_dataset['text'])

In [None]:
from collections import Counter
import re
def process_word_data(full_text: str, min_freq: int = 5):
    """Pre-processes, tokenizes, and encodes a text block into a 1D array of word IDs.

    Punctuation in ``=!$&,.:;?`` is split off into its own token, all runs of
    whitespace (including newlines) collapse to a single space, and the text is
    split on spaces. Tokens seen fewer than ``min_freq`` times are encoded as
    ``'<UNK>'``.

    Args:
        full_text: The corpus as a single string.
        min_freq: Minimum number of occurrences for a token to get its own id.

    Returns:
        A 4-tuple ``(data, vocab_size, word_to_idx, idx_to_word)`` where
        ``data`` is a JAX array with one id per token.
    """
    print("1. Pre-processing text to space out punctuation...")
    processed_text = re.sub(r'([=\n!$&,.:;?])', r' \1 ', full_text)
    processed_text = re.sub(r'\s+', ' ', processed_text).strip()

    print("2. Tokenizing the entire corpus...")
    all_tokens = processed_text.split(' ')

    print(f"3. Building vocabulary from {len(all_tokens):,} tokens...")
    word_counts = Counter(all_tokens)
    # A set keeps '<UNK>' unique even when the corpus already contains it as a
    # frequent token; a duplicate entry would make vocab_size exceed the
    # number of ids actually assigned.
    vocab_set = {word for word, count in word_counts.items() if count >= min_freq}
    vocab_set.add('<UNK>')
    vocab = sorted(vocab_set)
    vocab_size = len(vocab)

    word_to_idx = {word: i for i, word in enumerate(vocab)}
    idx_to_word = {i: word for word, i in word_to_idx.items()}
    unk_token_id = word_to_idx['<UNK>']

    print(f"   -> Vocabulary size reduced to {vocab_size} (min frequency: {min_freq})")

    print("4. Encoding the data...")
    data = jnp.array([word_to_idx.get(token, unk_token_id) for token in all_tokens])

    return data, vocab_size, word_to_idx, idx_to_word

# Load WikiText and process it
print("Loading WikiText-103 dataset...")

Loading WikiText-103 dataset...


In [None]:
import os
import re
import jax
import jax.numpy as jnp
import optax
from flax import nnx
from tqdm import tqdm
from dataclasses import dataclass
from collections import Counter
from datasets import load_dataset

In [None]:
data, vocab_size, word_to_idx, idx_to_word = process_character_data(download_shakespeare())

Downloading tinyshakespeare.txt...
Download complete.


In [None]:
data, vocab_size, word_to_idx, idx_to_word = process_word_data(full_text[:5_000_000])

In [None]:
import jax
import jax.numpy as jnp
import optax
from flax import nnx
from tqdm import tqdm

BATCH_SIZE = 512
BLOCK_SIZE = 64
LEARNING_RATE = 1e-3
TRAINING_STEPS = 10000
NUM_TRANSFORMERS = 4
NUM_HEADS = 8
EMBED_DIM = 256
DECAY_STEPS = 200
DECAY_RATE=0.7
def get_batch(data_array, key, batch_size=None, block_size=None):
    """Sample a random batch of (input, target) windows from a 1-D token array.

    Args:
        data_array: 1-D array of token ids.
        key: JAX PRNG key used to draw the window start positions.
        batch_size: Number of windows; defaults to the module-level BATCH_SIZE.
        block_size: Window length; defaults to the module-level BLOCK_SIZE.

    Returns:
        ``(x, y)``, both of shape (batch_size, block_size), where ``y`` is
        ``x`` shifted one position to the right in ``data_array``.

    Raises:
        ValueError: If ``data_array`` is too short to hold one window plus its
            shifted target.
    """
    batch_size = BATCH_SIZE if batch_size is None else batch_size
    block_size = BLOCK_SIZE if block_size is None else block_size

    max_start = len(data_array) - block_size - 1
    if max_start <= 0:
        raise ValueError(
            f"data_array has {len(data_array)} tokens; need at least {block_size + 2}"
        )

    # sample batch_size random starting indices (upper bound is exclusive)
    ix = jax.random.randint(key, (batch_size,), 0, max_start)
    # One gather instead of a Python loop of per-row slices: the loop issued
    # 2 * batch_size separate device operations for every batch.
    offsets = ix[:, None] + jnp.arange(block_size)[None, :]
    x = data_array[offsets]
    y = data_array[offsets + 1]
    return x, y
key = jax.random.PRNGKey(0)
key, model_key = jax.random.split(key)
rngs = nnx.Rngs(model_key)

model = Decoder(
    vocab_size=vocab_size,
    max_len=BLOCK_SIZE,
    in_features=EMBED_DIM,
    attention_dropout=0.1,
    ffn_dropout=0.1,
    num_transformers=NUM_TRANSFORMERS,
    num_head=NUM_HEADS,
    rngs=rngs,
)
schedule = optax.exponential_decay(
    init_value=LEARNING_RATE,
    transition_steps=DECAY_STEPS,
    decay_rate=DECAY_RATE,
    staircase=True  # makes it step-wise decay
)

# Use the schedule in your optimizer
optimizer = nnx.Optimizer(model, optax.adamw(schedule), wrt=nnx.Param)
def main_train(model, optimizer, data, vocab_size, word_to_idx, idx_to_word):
    """Run TRAINING_STEPS optimisation steps of next-token cross-entropy training.

    Each step samples a batch with ``get_batch``, applies one jitted update and
    logs the loss to Weights & Biases, so ``wandb`` must be imported and
    ``wandb.init`` called before this function runs.

    Args:
        model: The ``nnx`` model; updated in place.
        optimizer: ``nnx.Optimizer`` wrapping ``model``; updated in place.
        data: 1-D array of token ids to sample batches from.
        vocab_size: Unused; kept so existing call sites keep working.
        word_to_idx: Unused; kept so existing call sites keep working.
        idx_to_word: Unused; kept so existing call sites keep working.
    """
    key = jax.random.PRNGKey(0)
    # The second half of this split used to seed an unused nnx.Rngs; the split
    # itself is kept so the batch-sampling key sequence does not change.
    key, _ = jax.random.split(key)

    @nnx.jit
    def train_step(model, optimizer, x, y):
        def loss_fn(m):
            logits = m(x, train=True)
            loss = optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
            return loss

        loss, grads = nnx.value_and_grad(loss_fn)(model)
        optimizer.update(model, grads)
        return loss

    print("Starting training...")
    pbar = tqdm(range(TRAINING_STEPS))
    for step in pbar:
        key, batch_key = jax.random.split(key)
        x, y = get_batch(data, batch_key)
        loss = train_step(model, optimizer, x, y)
        # Convert once to a Python float so the logger and the progress bar
        # receive a plain number rather than a device array.
        loss_value = float(loss)
        wandb.log(
            {
                "loss": round(loss_value, 5),
                "step": step
            }
        )
        if step % 10 == 0:
            pbar.set_postfix(loss=f"{loss_value:.5f}")

    print("Training complete.")



In [None]:
nnx.display(model)

In [None]:
import jax
jax.devices()

[CudaDevice(id=0)]

In [None]:
%pip install wandb

In [None]:
import wandb

# Log the hyperparameters the training cell actually uses; the previous
# hard-coded lr (1e-4) disagreed with LEARNING_RATE (1e-3).
wandb.init(project="JAX-LLM", config={
    "lr": LEARNING_RATE,
    "Epochs": TRAINING_STEPS,
})

In [None]:
main_train(model, optimizer, data, vocab_size, word_to_idx, idx_to_word)

Starting training...


  3%|▎         | 259/10000 [02:35<1:37:45,  1.66it/s, loss=0.25708]


KeyboardInterrupt: 

In [None]:
import cloudpickle
with open("model.pickle", "wb") as file:
  cloudpickle.dump(model, file)

In [None]:
import cloudpickle
# Derive `state` here: it was otherwise only assigned by a later cell, so this
# cell raised NameError on a fresh Restart & Run All.
state = nnx.state(model)
with open("state.pickle", "wb") as file:
  cloudpickle.dump(state, file)

In [None]:
import jax
import orbax.checkpoint as orbax
# Extract the model's state so it can be checkpointed on its own.
state = nnx.state(model)


# Hard-coded absolute output directory; the zip and download cells below use
# the same path, so change all three together.
path = "/content/attention_model_0.59"
checkpointer = orbax.PyTreeCheckpointer()
# Model state and optimizer go to separate sub-directories of `path`.
checkpointer.save(f'{path}/model_state', state)
# NOTE(review): this passes the optimizer object itself, not nnx.state(optimizer)
# -- confirm it restores as expected.
checkpointer.save(f'{path}/optimizer', optimizer)



In [None]:
# Replace 'folder_to_zip' with the path to the folder you want to zip
# Replace 'output_archive_name.zip' with the desired name for the zip file
!zip -r '/content/attention_model_0.59.zip' '/content/attention_model_0.59'

  adding: content/attention_model_0.59/ (stored 0%)
  adding: content/attention_model_0.59/optimizer/ (stored 0%)
  adding: content/attention_model_0.59/optimizer/_sharding (deflated 97%)
  adding: content/attention_model_0.59/optimizer/ocdbt.process_0/ (stored 0%)
  adding: content/attention_model_0.59/optimizer/ocdbt.process_0/d/ (stored 0%)
  adding: content/attention_model_0.59/optimizer/ocdbt.process_0/d/ef2dce5d915cbd42bbc49a8973f9bb37 (deflated 0%)
  adding: content/attention_model_0.59/optimizer/ocdbt.process_0/d/169ce1abd2ad03e63c9e13fba0e46ade (deflated 0%)
  adding: content/attention_model_0.59/optimizer/ocdbt.process_0/d/11002da2a38d27ee24940e0c26d139c6 (stored 0%)
  adding: content/attention_model_0.59/optimizer/ocdbt.process_0/d/5e77e8be26da1ca50194bd0e48192a45 (stored 0%)
  adding: content/attention_model_0.59/optimizer/ocdbt.process_0/d/e4761a16da0e7d8bd68914ab4a5b69fc (deflated 0%)
  adding: content/attention_model_0.59/optimizer/ocdbt.process_0/d/902b20a4b4ce590d2b449

In [None]:
from google.colab import files
files.download('/content/attention_model_0.59.zip')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>