In [61]:
import jax
from jax import numpy as jnp, random
from typing import Any, Callable, Sequence
import flax
from flax import linen as nn
from functools import partial
import optax

In [31]:
### loading and processing dataset
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Character-level vocabulary: every distinct character in the corpus, sorted.
chars = sorted(set(text))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)

# Lookup tables between characters and their integer ids.
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    """Encoder: take a string, output a list of integers."""
    return [stoi[c] for c in s]


def decode(l):
    """Decoder: take a list of integers, output a string."""
    return ''.join([itos[i] for i in l])


# int32 is requested explicitly: JAX keeps 64-bit types disabled by default, so
# asking for int64 only produces a truncation warning and an int32 array anyway.
data = jnp.array(encode(text), dtype=jnp.int32)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


  data = jnp.array(encode(text), dtype=jnp.int64)


In [32]:
# Hold out everything after the first 90% of the encoded corpus for validation.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

# One example window: `y` is `x` shifted forward by a single position.
block_size = 8
x, y = train_data[:block_size], train_data[1:block_size + 1]

In [36]:
batch_size = 4 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?

def get_batch(split, key):
    """Sample a random minibatch of input windows and next-token targets.

    Args:
        split: 'train' selects `train_data`; any other value selects `val_data`.
        key: JAX PRNG key used to draw the window start positions.

    Returns:
        A pair `(x, y)`, each of shape `(batch_size, block_size)`, where `y`
        holds the same windows as `x` shifted one position to the right.
    """
    data = train_data if split == 'train' else val_data
    # Start offsets are drawn from [0, len(data) - block_size), so the shifted
    # target window never reads past the end of `data`.
    ix = random.randint(key, (batch_size,), 0, len(data) - block_size)
    # Build every window position at once and gather in a single indexing op,
    # instead of slicing and stacking sample by sample in a Python loop.
    offsets = ix[:, None] + jnp.arange(block_size)
    x = data[offsets]
    y = data[offsets + 1]
    return x, y

key = random.key(0)
xb, yb = get_batch('train', key)

In [37]:
# Dump the sampled batch: shapes first, then the raw token ids.
print('inputs:', xb.shape, xb, sep='\n')
print('targets:', yb.shape, yb, sep='\n')

print('----')

# Every prefix of a row is a training example whose label is the next token.
for b in range(batch_size):  # batch dimension
    row_inputs, row_targets = xb[b], yb[b]
    for t in range(block_size):  # time dimension
        context = row_inputs[:t + 1]
        target = row_targets[t]
        print(f"when input is {context.tolist()} the target: {target}")

inputs:
(4, 8)
[[ 0 32 46 39 58  1 57 46]
 [57  6  1 40 63  1 63 53]
 [ 1 58 43 50 50  1 63 53]
 [ 0 37 53 59 56  1 51 53]]
targets:
(4, 8)
[[32 46 39 58  1 57 46 43]
 [ 6  1 40 63  1 63 53 59]
 [58 43 50 50  1 63 53 59]
 [37 53 59 56  1 51 53 58]]
----
when input is [0] the target: 32
when input is [0, 32] the target: 46
when input is [0, 32, 46] the target: 39
when input is [0, 32, 46, 39] the target: 58
when input is [0, 32, 46, 39, 58] the target: 1
when input is [0, 32, 46, 39, 58, 1] the target: 57
when input is [0, 32, 46, 39, 58, 1, 57] the target: 46
when input is [0, 32, 46, 39, 58, 1, 57, 46] the target: 43
when input is [57] the target: 6
when input is [57, 6] the target: 1
when input is [57, 6, 1] the target: 40
when input is [57, 6, 1, 40] the target: 63
when input is [57, 6, 1, 40, 63] the target: 1
when input is [57, 6, 1, 40, 63, 1] the target: 63
when input is [57, 6, 1, 40, 63, 1, 63] the target: 53
when input is [57, 6, 1, 40, 63, 1, 63, 53] the target: 59
when inpu

In [25]:
@partial(jax.vmap, in_axes=(0, 0, None))
def _multinomial(keys, probs, num_samples):
    """Per-row sampler; `jax.vmap` maps it over the leading axis of keys/probs.

    Inside the mapped function `keys` is a single PRNG key and `probs` a single
    row of category weights; `num_samples` is shared by all rows.
    """
    categories = jnp.arange(probs.shape[0])
    return jax.random.choice(keys, categories, shape=(num_samples,), p=probs)


def multinomial(key, probs, num_samples):
    """Draw `num_samples` category indices from each row of `probs`.

    Args:
        key: JAX PRNG key; it is split into one independent key per row.
        probs: 2-D array of per-row category weights.
        num_samples: number of draws per row.

    Returns:
        Integer array of shape `(probs.shape[0], num_samples)`.
    """
    row_keys = jax.random.split(key, probs.shape[0])
    return _multinomial(row_keys, probs, num_samples)

In [26]:
key = random.key(0)
# Rows with a zero weight must never yield that category's index.
demo_probs = jnp.array([[0, 1, 1],
                        [1, 0, 0],
                        [0, 1, 1]])
multinomial(key, demo_probs, 5)

Array([[2, 2, 1, 1, 2],
       [0, 0, 0, 0, 0],
       [1, 1, 2, 2, 2]], dtype=int32)

In [121]:
# def cross_entropy(logits, targets, axis=-1):
#     ### logits: [B, N, C], targets: [B, N]
#     targets = jax.nn.one_hot(targets, logits.shape[-1])
#     return -(logits * targets).sum()

n_embd = 16
class BigramLanguageModel(nn.Module):
    """Bigram model: each token id indexes a row of next-token logits.

    The embedding table is `vocab_size x vocab_size`, so looking up a token
    directly returns the unnormalised scores for the token that follows it.
    """
    vocab_size: int

    def setup(self):
        # Size the table from the module's own field rather than the
        # notebook-level global, so the constructor argument is honoured.
        self.token_embedding_table = nn.Embed(num_embeddings=self.vocab_size,
                                              features=self.vocab_size)

    def __call__(self, idx):
        """Return logits for integer token ids `idx` (one extra trailing axis)."""
        return self.token_embedding_table(idx)

    def generate(self, key, idx, max_new_tokens):
        """Autoregressively append `max_new_tokens` sampled tokens to `idx`.

        Args:
            key: JAX PRNG key; split into one key per generated token.
            idx: integer array of shape (B, T) holding the current context.
            max_new_tokens: number of tokens to sample and append.

        Returns:
            Integer array of shape (B, T + max_new_tokens).
        """
        keys = random.split(key, max_new_tokens)
        for i in range(max_new_tokens):
            # Only the logits at the last time step predict the next token.
            logits = self(idx)[:, -1, :]
            probs = jax.nn.softmax(logits, axis=-1)
            idx_next = multinomial(keys[i], probs, num_samples=1)
            idx = jnp.concatenate((idx, idx_next), axis=1)
        return idx

def loss_fn(preds, targets):
    """Mean cross-entropy between predicted logits and integer targets.

    Args:
        preds: logits of shape (B, T, C).
        targets: integer class ids of shape (B, T).

    Returns:
        Scalar loss averaged over all B*T positions.
    """
    B, T, C = preds.shape
    logits = preds.reshape(B * T, C)
    labels = targets.reshape(B * T)
    # The integer-label variant avoids building a (B*T, C) one-hot matrix.
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    
# Build the bigram model, initialise its parameters on one sampled batch and
# report the loss of the untrained model.
m = BigramLanguageModel(vocab_size=vocab_size)
# NOTE(review): `key` is used both to sample this batch and to initialise the
# parameters; splitting it first would give the two uses independent randomness.
x, y = get_batch('train', key)
params = m.init(key, x)
logits = m.apply(params, x)
loss = loss_fn(logits, y)
print(logits.shape, loss.shape, loss)

(4, 8, 65) () 4.188377


In [86]:
# Sanity check of optax's cross entropy on a hand-computable case: with logits
# [0, 0, 0, 1] and the one-hot label on the last class, the loss is
# log(3 + e) - 1, which is approximately 0.7437.
logits = jnp.array([[0,0,0,1]], dtype=jnp.float32)
labels = jnp.array([[0,0,0,1]], dtype=jnp.float32)
optax.softmax_cross_entropy(logits, labels)

Array([0.7436683], dtype=float32)

In [88]:
# Generate 100 tokens starting from a single context token with id 0, then
# decode the resulting ids back to text.
context = jnp.zeros((1, 1), dtype=jnp.int32)
# `method='generate'` makes `apply` run `generate` instead of `__call__`.
autocomplete = m.apply(params, key=key, idx=context, max_new_tokens=100, method='generate')
print(decode(autocomplete.tolist()[0]))



!IkvgeEYJ EkhFBgkOqJXdbuEzswu!.b-z:AU3j:cxvmzaHGhTnK,aHeFfTWhdSvCo:;:ZrpFa-oJQwt:oC mT-si'xe;TV
eFsI


In [122]:
%%time
# Optimiser configuration: AdamW with a fixed learning rate.
learning_rate = 1e-3

tx = optax.adamw(learning_rate=learning_rate)
# The optimiser state is created from the current parameters and is threaded
# through `step` explicitly alongside them.
opt_state = tx.init(params)

@jax.jit
def step(params, opt_state, xb, yb):
    """Run one jitted optimisation step on a single minibatch.

    Args:
        params: current model parameters.
        opt_state: current optimiser state for `tx`.
        xb: batch of input token ids.
        yb: batch of target token ids.

    Returns:
        Tuple `(params, opt_state, loss)` holding the updated parameters, the
        updated optimiser state and the loss measured before the update.
    """
    def batch_loss(p):
        # Differentiated with respect to the parameters only; the batch is
        # captured from the enclosing scope.
        return loss_fn(m.apply(p, xb), yb)

    loss_value, grads = jax.value_and_grad(batch_loss)(params)
    updates, new_opt_state = tx.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss_value
    
for steps in range(10000): # increase number of steps for good results...
    # Split off a fresh subkey for this minibatch and carry the rest forward.
    data_key, key = random.split(key)
    xb, yb = get_batch('train', data_key)

    # One optimisation step; `loss` is the value measured before the update.
    params, opt_state, loss = step(params, opt_state, xb, yb)

    # Report progress every 1000 steps.
    if steps % 1000 == 0:
        print('Loss step {}: '.format(steps), loss)



Loss step 0:  4.1692815
Loss step 1000:  3.7298422
Loss step 2000:  3.1380658
Loss step 3000:  3.3327022
Loss step 4000:  2.668402
Loss step 5000:  2.6300232
Loss step 6000:  3.137767
Loss step 7000:  2.418674
Loss step 8000:  2.8468523
Loss step 9000:  2.821628
CPU times: user 27.8 s, sys: 78.4 ms, total: 27.9 s
Wall time: 28 s


In [124]:
# Sample 500 tokens with the current `params`, again starting from a single
# context token with id 0, and decode them to text.
context = jnp.zeros((1, 1), dtype=jnp.int32)
# `method='generate'` makes `apply` run `generate` instead of `__call__`.
autocomplete = m.apply(params, key=key, idx=context, max_new_tokens=500, method='generate')
print(decode(autocomplete.tolist()[0]))




Bllee.
Rof mes tyoorwedet, IUSoven ee me,

NRThauntion FI! sie. d therdandst isuprerend, Yierurrin furtinghis bis r ppe thigre d pa':
I:
'd apo tout m m I:
stir3had d owotlleashonctoofopllbemainotunthrssen:
SIA:
t grbalisee be thum he
Angrou, s,lyrs KELAn ad Conde.
-irr he tinced n:Fz;
Morompr y rf ntith, doantheavey HYD:
CHANLAxp'linene be, y
llirt, bell y ICALoullowond, l implf whin

WAMBer ton mu!GQShyssde; m ce Yo y,
UCovere me'Sa

'shaus, n,
Twnof dreere I dn s tan:
Qd.Enom hen wivithwP,




So far so good. Some notes on how jax+flax code is organized differently from PyTorch:
1. The `__call__` method doesn't compute the loss; loss computation is moved out into a `step` function that can be jitted for the training loop.
2. Some utility functions, such as `multinomial`, are not available in jax, so one has to write custom versions of them.
3. Managing and updating the optimizer state is more involved than in PyTorch, where it just amounts to calling `optimizer.step()`.

The next step is to rewrite each of the classes in the pytorch implementation in jax.

```python
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.head_size = head_size
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)
        q = self.query(x)
        wei = q @ k.transpose(-2, -1) * self.head_size**-0.5
        wei = wei.masked_fill(self.tril[:T, :T]==0, float('-inf')) 
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)
        out = wei @ v
        return out
```

In [155]:
from dataclasses import dataclass

@dataclass
class Config:
    """Hyperparameters for the transformer model and its training run.

    Every attribute carries a type annotation so that `@dataclass` registers it
    as a field; this lets individual values be overridden at construction time,
    e.g. `Config(n_embd=64)`, while `Config()` keeps the defaults below.
    """
    batch_size: int = 32
    block_size: int = 32 #k-gram
    max_iters: int = 5000
    eval_interval: int = 500
    learning_rate: float = 5e-3
    eval_iters: int = 200
    n_embd: int = 32
    n_head: int = 4
    n_layer: int = 4
    dropout: float = 0.2

class Head(nn.Module):
    """One head of causal (masked) self-attention."""

    config: Config
    head_size: int

    def setup(self):
        self.key = nn.Dense(self.head_size, use_bias=False)
        self.query = nn.Dense(self.head_size, use_bias=False)
        self.value = nn.Dense(self.head_size, use_bias=False)
        self.dropout = nn.Dropout(self.config.dropout)

    def __call__(self, x, deterministic=True):
        """Apply causal self-attention to `x`.

        Args:
            x: input array of shape (B, T, C).
            deterministic: when True, dropout is disabled (inference mode).

        Returns:
            Array of shape (B, T, head_size).
        """
        _, T, _ = x.shape
        k = self.key(x)
        q = self.query(x)
        # Scaled dot-product scores, shape (B, T, T).
        wei = jnp.matmul(q, k.swapaxes(-2, -1)) * self.head_size**-0.5
        # Mask by position, not by value: deriving the mask from the scores
        # themselves would also drop legitimate scores that happen to be
        # exactly zero and could leave a whole row at -inf (NaN after softmax).
        causal_mask = jnp.tril(jnp.ones((T, T), dtype=bool))
        wei = jnp.where(causal_mask, wei, -jnp.inf)
        wei = jax.nn.softmax(wei, axis=-1)
        wei = self.dropout(wei, deterministic=deterministic)
        v = self.value(x)
        return jnp.matmul(wei, v)

Notes on some differences between jax+flax and PyTorch:
1. `flax.linen.Dense` doesn't need the input shape.
2. `jnp.tril` directly returns a masked array, so there is no need to store the mask as a frozen variable.
3. Use `jnp.matmul` to explicitly do batched multiplication between matrices of shapes `(..., K, N)` and `(..., N, M)`.
4. `Dropout` behaves differently at training and inference time; this is controlled by passing the `deterministic` boolean argument rather than by calling `model.train()` or `model.eval()` (which is stateful).

In [141]:
# Small demo of the tril + where trick on random integers in [0, 10).
lower = jnp.tril(random.randint(key, (5, 5), 0, 10))
# `jnp.where` keeps entries that are truthy (non-zero), so every zero entry,
# above the diagonal or not, is replaced by -inf.
arr = jnp.where(lower, lower, float('-inf'))
arr

Array([[  5., -inf, -inf, -inf, -inf],
       [  9.,   8., -inf, -inf, -inf],
       [  6.,   3.,   3., -inf, -inf],
       [  4.,   3.,   3.,   6., -inf],
       [  2.,   7.,   2.,   4.,   5.]], dtype=float32, weak_type=True)

In [158]:
# Smoke test: run one attention head on random input and check the output shape.
config = Config()
head = Head(config=config, head_size=8)
rand_x = random.normal(key, (config.batch_size, config.block_size, config.n_embd))
# NOTE(review): this rebinds `params`, replacing the bigram model's parameters
# held under the same name by earlier cells.
params = head.init(rngs=key,
                   x=rand_x)
head.apply(params, rand_x).shape

(32, 32, 8)