In [None]:
import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
import optax
from tqdm import tqdm

In [None]:
# Mount Google Drive into the Colab runtime at /content/drive.
# The data-loading cell further down opens a file under /content/drive/MyDrive,
# so this cell has to run first.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
class FeedForward(nn.Module):
    """Position-wise MLP: widen to 4 * n_emb, apply ReLU, project back to n_emb."""
    n_emb: int  # size of the feature axis on both input and output

    @nn.compact
    def __call__(self, x):
        """Apply the two dense layers to the last axis of `x` and return the result."""
        hidden_width = 4 * self.n_emb
        hidden = nn.relu(nn.Dense(hidden_width)(x))
        # Project back down to n_emb so the output can be added to a residual stream.
        return nn.Dense(self.n_emb)(hidden)

class AttentionHead(nn.Module):
    """Single head of causal (masked) scaled dot-product self-attention.

    The causal mask is built from the sequence length of the input itself, so the
    module does not depend on any notebook-level `block_size` global being defined
    (or being in sync with the model's own `block_size`).
    """
    head_size: int  # output width of the query/key/value projections

    def setup(self):
        # Bias-free linear projections of the input features: (C) -> (hs).
        self.to_q = nn.Dense(self.head_size, use_bias=False) # (C, hs)
        self.to_k = nn.Dense(self.head_size, use_bias=False) # (C, hs)
        self.to_v = nn.Dense(self.head_size, use_bias=False) # (C, hs)

    def __call__(self, x):
        """Attend over the time axis of `x`.

        Args:
            x: array unpacked as (B, T, C).

        Returns:
            Array of shape (B, T, head_size).
        """
        _, T, _ = x.shape
        q = self.to_q(x) # (B, T, C) @ (C, hs) -> (B, T, hs)
        k = self.to_k(x) # (B, T, C) @ (C, hs) -> (B, T, hs)
        v = self.to_v(x) # (B, T, C) @ (C, hs) -> (B, T, hs)

        # Scaled dot-product scores: (B, T, hs) @ (B, hs, T) -> (B, T, T)
        attn_scores = (q @ k.swapaxes(-2, -1)) * (self.head_size ** -0.5)

        # Lower-triangular mask: position t may only attend to positions <= t.
        # T is a static shape, so under jit this is a compile-time constant.
        causal_mask = jnp.tril(jnp.ones((T, T), dtype=bool))
        # The diagonal is always kept, so every row has at least one finite score
        # and the softmax below never sees an all -inf row.
        attn_scores = jnp.where(causal_mask[None, :, :], attn_scores, -jnp.inf)
        attn_weights = nn.softmax(attn_scores, axis=-1)
        return attn_weights @ v # (B, T, T) @ (B, T, hs) ---> (B, T, hs)

class MultiHeadAttention(nn.Module):
    """Runs `num_heads` independent AttentionHead modules and mixes their outputs."""
    num_heads: int  # number of parallel attention heads
    head_size: int  # output width of each individual head

    @nn.compact
    def __call__(self, x):
        """Concatenate every head's output on the last axis, then apply a dense projection."""
        head_outputs = []
        for _ in range(self.num_heads):
            head_outputs.append(AttentionHead(self.head_size)(x))
        merged = jnp.concatenate(head_outputs, axis=-1)  # (B, T, hs*num_heads)
        # Output projection back onto the residual pathway.
        return nn.Dense(self.num_heads * self.head_size)(merged)

class Block(nn.Module):
    """Pre-LayerNorm transformer block: self-attention then feed-forward, each with a residual add."""
    n_emb: int      # width of the residual stream
    num_heads: int  # attention heads; each head gets n_emb // num_heads features

    @nn.compact
    def __call__(self, x):
        """Return `x` after the attention and feed-forward residual updates. Shape (B, T, C) is kept."""
        head_size = self.n_emb // self.num_heads

        # Residual connection around multi-head self-attention on the normalised input.
        attention = MultiHeadAttention(self.num_heads, head_size)
        x = x + attention(nn.LayerNorm()(x))

        # Residual connection around the position-wise MLP on the normalised input.
        feed_forward = FeedForward(self.n_emb)
        x = x + feed_forward(nn.LayerNorm()(x))
        return x

class GPT(nn.Module):
    """Decoder-only transformer language model over integer token ids.

    Token and learned positional embeddings are summed, passed through
    `num_blocks` transformer blocks, layer-normalised and projected to
    `vocab_size` logits per position.
    """
    n_emb: int       # embedding / residual stream width
    vocab_size: int  # number of distinct token ids
    block_size: int  # maximum sequence length (size of the positional embedding table)
    num_heads: int   # attention heads per block
    num_blocks: int  # number of stacked transformer blocks


    @nn.compact
    def __call__(self, idx):
        """Compute next-token logits.

        Args:
            idx: integer token ids of shape (B, T) with T <= block_size.

        Returns:
            Logits of shape (B, T, vocab_size).

        Raises:
            ValueError: if T exceeds `block_size`, since positions beyond the
                positional embedding table cannot be embedded.
        """
        _, T = idx.shape
        if T > self.block_size:
            raise ValueError(
                f"Sequence length {T} exceeds block_size {self.block_size}"
            )
        x = nn.Embed(self.vocab_size, self.n_emb)(idx) # (B, T, C)
        # positional embedding layer
        x += nn.Embed(self.block_size, self.n_emb)(jnp.arange(T)) # (B, T, C) + (T, C) (broadcast)
        blocks = nn.Sequential([Block(self.n_emb, self.num_heads) for _ in range(self.num_blocks)])
        x = blocks(x)
        logits = nn.Dense(self.vocab_size)(nn.LayerNorm()(x)) # (B, T, vocab_size)
        return logits

    def generate(self, key, params, idx, max_new_tokens):
        """Autoregressively sample `max_new_tokens` tokens after the prompt `idx`.

        Args:
            key: PRNG key; split once per generated token.
            params: variables to pass to `self.apply`.
            idx: integer prompt of shape (B, T).
            max_new_tokens: number of tokens to append.

        Returns:
            Integer array of shape (B, T + max_new_tokens).
        """
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size:] # crop idx so that it's never bigger than block_size (B, T)
            logits = self.apply(params, idx_cond)
            logits = logits[:, -1, :] # keep only the last position: (B, vocab_size)
            key, subkey = jax.random.split(key)
            # Sample one token per batch row -> (B,), then add the time axis -> (B, 1).
            # Requesting shape=(B, 1) directly is not broadcast-compatible with the
            # (B,) batch shape of the logits and fails for any B > 1.
            idx_next = jax.random.categorical(subkey, logits)[:, None]
            idx = jnp.concatenate([idx, idx_next], axis=1).astype(int)
        return idx


In [None]:
# we always start with a dataset to train on. let's download the tiny shakespeare dataset
# NOTE(review): the data-loading cell in this notebook opens
# /content/drive/MyDrive/smiles.txt, not the input.txt fetched here, so this
# download looks unused — confirm before removing.
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2025-04-25 07:10:17--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-04-25 07:10:18 (18.5 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [None]:
with open('/content/drive/MyDrive/smiles.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Character-level vocabulary: every distinct character of the corpus, sorted.
chars = sorted(set(text))
vocab_size = len(chars)

# Lookup tables between characters and their integer token ids.
itos = dict(enumerate(chars))
stoi = {c: i for i, c in itos.items()}


def encode(s):
    """Map a string to the list of its characters' token ids."""
    return [stoi[c] for c in s]


def decode(l):
    """Map a list of token ids back to the string they spell."""
    return "".join(itos[i] for i in l)


# Whole corpus as one int32 token array, split 90% / 10% into train / validation.
data = jnp.array(encode(text), dtype=jnp.int32)
n = int(0.9*len(data))
train_data, val_data = data[:n], data[n:]

In [None]:
# Model / training hyperparameters. The cells below read these as notebook globals.
batch_size = 64 # how many independent sequences will we process in parallel?
block_size = 128 # what is the maximum context length for predictions?
n_emb = 384 # embedding width; Block derives the per-head size as n_emb // num_heads
num_heads = 6 # attention heads per transformer block
num_blocks = 8 # number of stacked transformer blocks
learning_rate = 1e-3 # learning rate handed to optax.adam
num_epochs = 1 # NOTE(review): not referenced by any cell visible here — confirm whether it is still needed

In [None]:
def get_batch(key, data):
    """Sample a random batch of (input, target) windows from a 1-D token array.

    Reads the notebook globals `batch_size` and `block_size`.

    Args:
        key: PRNG key used to draw the window start offsets.
        data: 1-D array of token ids with more than `block_size` elements.

    Returns:
        (xs, ys): two arrays of shape (batch_size, block_size); `ys` is `xs`
        shifted one position to the right in `data` (the next-token targets).

    Raises:
        ValueError: if `data` is too short to hold one input/target window pair.
    """
    n_tokens = data.shape[0]
    if n_tokens <= block_size:
        raise ValueError(
            f"data has {n_tokens} tokens but at least {block_size + 1} are needed"
        )

    # A start i is valid when the target slice [i+1, i+1+block_size) fits, i.e.
    # i <= n_tokens - block_size - 1, which gives n_tokens - block_size valid starts.
    # randint's upper bound is exclusive, so passing the count covers all of them.
    num_starts = n_tokens - block_size
    starts = jax.random.randint(key, (batch_size,), 0, num_starts)

    def slice_pair(i):
        # pull [i : i+block_size]  and [i+1 : i+1+block_size]
        x = jax.lax.dynamic_slice(data, (i,), (block_size,))
        y = jax.lax.dynamic_slice(data, (i+1,), (block_size,))
        return x, y

    # vmap over our vector of starts
    xs, ys = jax.vmap(slice_pair)(starts)
    return xs, ys

In [None]:
model = GPT(
    n_emb=n_emb,
    vocab_size=vocab_size,
    block_size=block_size,
    num_heads=num_heads,
    num_blocks=num_blocks,
)


def loss_fn(params, x, y):
    """Mean cross-entropy between the model's logits for `x` and the integer labels `y`."""
    logits = model.apply(params, x)
    per_token_loss = optax.losses.softmax_cross_entropy_with_integer_labels(logits, y)
    return per_token_loss.mean()

In [None]:
# Initialise Model ===============================
# Derive three independent keys from the seed so that no key is consumed twice:
#   key       - carried forward and split again by the training / generation cells
#   init_key  - seeds the parameter initialisation
#   batch_key - draws the example batch whose shape drives model.init
key, init_key, batch_key = jax.random.split(jax.random.PRNGKey(420696969), 3)
xb, yb = get_batch(batch_key, train_data)
# Only the shape/dtype of xb matters here; the values are replaced by ones.
params = model.init(init_key, jnp.ones_like(xb))

In [None]:
def estimate_loss(key, params, eval_iters: int = 200):
    """Estimate the mean loss on random batches of the train and validation data.

    Reads the notebook globals `train_data`, `val_data`, `get_batch` and `loss_fn`.

    Args:
        key: PRNG key used to sample the evaluation batches.
        params: model variables forwarded to `loss_fn`.
        eval_iters: number of batches sampled per split. This sets array shapes,
            so it is compiled as a static argument (each new value recompiles).

    Returns:
        Dict with scalar entries 'train' and 'val'.
    """
    # Three-way split kept so the train/val keys are the same as before; the
    # first sub-key is intentionally unused.
    _, train_key, val_key = jax.random.split(key, 3)

    # One sub-key per evaluation batch, for each split.
    train_keys = jax.random.split(train_key, eval_iters)
    val_keys   = jax.random.split(val_key,   eval_iters)

    # vmapped get_batch: maps over the first axis of keys, shares the data array
    vget_batch = jax.vmap(get_batch, in_axes=(0, None))

    Xs_train, Ys_train = vget_batch(train_keys, train_data)
    Xs_val,   Ys_val   = vget_batch(val_keys,   val_data)

    # vmapped loss_fn: maps over axis 0 of (Xs, Ys), so all eval_iters batches of
    # a split go through the model in one batched call.
    vloss = jax.vmap(lambda x, y: loss_fn(params, x, y))

    train_losses = vloss(Xs_train, Ys_train)
    val_losses   = vloss(Xs_val,   Ys_val)

    return {
        'train': jnp.mean(train_losses),
        'val':   jnp.mean(val_losses),
    }


# eval_iters is passed to jax.random.split as a count, which must be a concrete
# Python int; mark it static so callers can pass a non-default value.
estimate_loss = jax.jit(estimate_loss, static_argnames=("eval_iters",))

In [None]:
# Adam optimiser and its initial state for the current parameters.
tx = optax.adam(learning_rate=learning_rate)
opt_state = tx.init(params)


@jax.jit
def train_step(opt_state, params, X, Y):
    """Run one optimiser step on the batch (X, Y).

    Returns:
        (loss, params, opt_state): the loss evaluated before the update, followed
        by the updated parameters and optimiser state.
    """
    loss_value, gradients = jax.value_and_grad(loss_fn)(params, X, Y)
    param_updates, next_opt_state = tx.update(gradients, opt_state)
    next_params = optax.apply_updates(params, param_updates)
    return loss_value, next_params, next_opt_state


In [None]:
# Training loop: one freshly sampled batch and one optimiser step per iteration.
for i in tqdm(range(5000)):
    key, train_key = jax.random.split(key)
    xb, yb = get_batch(train_key, train_data)
    loss, params, opt_state = train_step(opt_state, params, xb, yb)

    # Only every 1000th step (including step 0) reports the estimated losses.
    if i % 1000 != 0:
        continue

    key, est_key = jax.random.split(key)
    losses = estimate_loss(est_key, params)
    print(f"Training loss: {losses['train']}, Validation loss: {losses['val']}")

  0%|          | 4/5000 [00:07<1:58:00,  1.42s/it] 

Training loss: 0.45827242732048035, Validation loss: 0.4805230498313904


 11%|█▏        | 564/5000 [00:48<05:25, 13.61it/s]Exception ignored in: <function _xla_gc_callback at 0x799dcd5e4400>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/lib/__init__.py", line 96, in _xla_gc_callback
    def _xla_gc_callback(*args):
    
KeyboardInterrupt: 
 12%|█▏        | 609/5000 [00:52<06:16, 11.67it/s]


KeyboardInterrupt: 

In [None]:
from functools import partial

def _gen_step(carry, rng):
    """Single `jax.lax.scan` step of sliding-window sampling.

    Note: `model` and `params` are read from the notebook globals, not from the carry.

    Args:
        carry: 1-tuple holding the current context window, shape (B, block_size).
        rng: PRNG key for this step.

    Returns:
        ((new_window,), next_token): the window with its oldest column dropped and
        the sampled token appended, plus the sampled tokens of shape (B,).
    """
    (context_window,) = carry
    # Logits for the last position only: (B, vocab_size).
    last_logits = model.apply(params, context_window)[:, -1, :]
    sampled = jax.random.categorical(rng, last_logits)  # (B,)
    # Slide the window left by one and append the new token as the last column.
    shifted = jnp.concatenate([context_window[:, 1:], sampled[:, None]], axis=1)
    return (shifted,), sampled


@partial(jax.jit, static_argnums=(3,))
def generate_batch(params, init_idx, key, max_new_tokens: int):
    """Sample `max_new_tokens` tokens per batch row with a fixed-size sliding window.

    Reads the notebook globals `model` and `block_size`.

    Args:
        params:         model variables passed to `model.apply` on every step.
        init_idx:       integer array of shape (B, T0) with T0 <= block_size.
        key:            a PRNGKey; split into one sub-key per generated token.
        max_new_tokens: number of new tokens to sample (static under jit).

    Returns:
        (full_seq, new_tokens): `full_seq` has shape (B, T0 + max_new_tokens) and
        is the prompt followed by the sampled tokens; `new_tokens` has shape
        (max_new_tokens, B), i.e. time-major.

    Note:
        The prompt is left-padded with token id 0 up to `block_size`. The model
        applies only a causal mask, so these pad positions are attended to like
        ordinary tokens.
    """
    B, T0 = init_idx.shape
    assert T0 <= block_size, f"Context length must be ≤ block_size ({block_size}), got {T0}"

    # left-pad init_idx up to block_size so our carry window is fixed-size
    pad_len     = block_size - T0
    init_window = jnp.pad(init_idx, ((0,0), (pad_len,0)), constant_values=0)  # (B, block_size)

    def step(window, rng):
        # Closes over the `params` argument so the caller's parameters are the
        # ones actually used (not whatever global `params` held at trace time).
        logits = model.apply(params, window)
        next_token = jax.random.categorical(rng, logits[:, -1, :])  # (B,)
        # Keep the carry dtype fixed across scan iterations.
        next_token = next_token.astype(window.dtype)
        new_window = jnp.concatenate([window[:, 1:], next_token[:, None]], axis=1)
        return new_window, next_token

    # split RNG into one key per token
    keys = jax.random.split(key, max_new_tokens)

    # run the scan; the final window is not needed
    _, new_tokens = jax.lax.scan(step, init_window, keys)
    # new_tokens: (max_new_tokens, B)

    # Time-major -> batch-major. This must be a transpose: a reshape to (B, T)
    # keeps memory order and would interleave tokens from different batch rows.
    gen_seq = new_tokens.T                                   # (B, max_new_tokens)
    full_seq = jnp.concatenate([init_idx, gen_seq], axis=1)  # (B, T0 + max_new_tokens)

    return full_seq, new_tokens

In [None]:
# Use a dedicated name for the sampling batch size: rebinding the global
# `batch_size` here would silently change what get_batch returns if the
# training cells are run again afterwards.
gen_batch_size = 1
# Prompt is a single token with id 0 per row.
# NOTE(review): id 0 is simply the first character of the sorted vocabulary; no
# dedicated BOS token is defined in this notebook — confirm this is the intended seed.
init       = jnp.zeros((gen_batch_size, 1), dtype=jnp.int32)
key, subkey = jax.random.split(key)

full_ids, tokens = generate_batch(params, init, subkey, 2000)
print(decode(full_ids[0].tolist()))  # your generated text


Cc1c-2c(cn1c3ccc(cc3)C2=O)F
CCNOC(=O)N1CCCN(CC1)c2ccccc2F
Cc1ccccc1Oc2ccc(cc2[N+](=O)[O-])C=O
CN1C(=O)/C(=C/C(=O)NC1=O)N2CCCC2)/N(CCOc3ccccc3)/NS(=O)(=O)C
Cc1cccc(c1)c2nc3ccc(cc3c(=O)o2)OCC(=O)N4CCCCC4
CCc1ccccc1N2CC[NH+](CC2)Cc3ccc4c(c3)OCO4
C[NH+]1CCN(CC1)c2cc(ccc2OCCO2
c1ccc2c(c1)cccc2NC(=O)NCCC(=O)[O-]
CCC(=O)Nc1c(c2c(s1)CCC2)C(=O)Nc3cc(ccc3OC)Cl
C[C@H]1CCCCN([C@@H]1CC)c2cc(ccc2OCC(=O)Nc3ccc(cc3)Cl)C
C[C@@H]1CCCCN([C@@H]1CC)c2cc(ccc2OCC(=O)Nc3ccc(cc3)F)Cl
C[C@H]1CCCCN([C@@H]1CCCCNC(=O)c2cccs2)C
C[C@@H]1CCCCN([C@H]1CCCCNC(=O)c2cccs2)C
C[C@H]1CCCCN([C@@H]1CCCCNC(=O)c2cccs2)C
CCc1ccccc1NC(=O)c2cccs2
C[C@@H]1CCCCN1C(=O)c2cccs2
Cc1ccc(o1)/C=C/2\C(=NN(C2=O)c3ccccc3)C
c1cc(ccc1N2[C@H]4C=CC[C@@H](O2)[C@@H]5[C@H]3C(=O)N(C5=O)c6ccco6)Cl
CC12CC[C@@]3([C@@H]1[C@@H](C2)C3(C(=O)OC)OCc4ccc(cc4)F)C
COc1ccc(c(c1)[N-]S(=O)(=O)c2ccc3c(c2)C(=O)N(C3=O)CCc4ccc(cc4)O)OC
CCOC(=O)c1ccc(cc1)NC/2=C(\C(=C(C(=O)OC)OCC)/S2)[O-]
CCOC(=O)[C@H]1[C@H](CC1=O)c2ccccc2[N+](=O)[O-]
CCOc1ccc(cc1OC)/C=C/2\C(=O)OC(=N2)c3