<h1>Multilayer Perceptron Language Model using Flax and Jax</h1>

In [1]:
import pandas as pd
from google.colab import drive
drive.mount('/content/drive')
file = "/content/sample_data/names.txt"
words = open(file, 'r').read().splitlines()

Mounted at /content/drive


In [201]:
# build the vocabulary of characters and mappings to/from integers
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}
print(itos)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}


In [226]:
import jax.numpy as jnp
import flax.linen as nn
from jax import random as jrandom
from jax import value_and_grad, jit
import optax
import jax
from jax import lax

This code builds a dataset for a language modeling task where given a sequence of characters, the next character needs to be predicted

In [258]:
# build the dataset
block_size = 3 # context length: how many characters do we take to predict the next one?

def build_dataset(words):  
  X, Y = [], []
  for w in words:

    #print(w)
    context = [0] * block_size
    for ch in w + '.':
      ix = stoi[ch]
      X.append(context)
      Y.append(ix)
      # print(''.join(itos[i] for i in context), '--->', itos[ix])
      context = context[1:] + [ix] # crop and append

  X = jnp.array(X)
  Y = jnp.array(Y)
  print(X.shape, Y.shape)
  return X, Y

import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))

Xtr, Ytr = build_dataset(words[:n1])

(182512, 3) (182512,)


In [12]:
Xtr.shape, Ytr.shape # dataset

((182424, 3), (182424,))

In the Flax implementation of the MLP, the weight matrices W1 and W2 are replaced with nn.Dense(128) and nn.Dense(self.out_dims = 10), respectively, while the biases b1 and b2 are automatically included as part of the Dense layers.

This is because the Dense layer in Flax combines both the weights and biases in a single parameter tuple. Specifically, each Dense layer has a kernel parameter representing the weight matrix, and a bias parameter representing the bias vector. These parameters are automatically initialized by Flax when the model is instantiated, and are updated during training as part of the model's parameter tree.

In [204]:
class MLP(nn.Module):                    # create a Flax Module dataclass
  out_dims: int
  vocab_size: int

  @nn.compact
  def __call__(self, x):
    x = nn.Embed(self.vocab_size, 128)(x)
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(128)(x)                 # create inline Flax Module submodules
    x = nn.relu(x)
    x = nn.Dense(self.out_dims)(x)       # shape inference
    return x

model = MLP(out_dims=10, vocab_size=len(itos))                 # instantiate the MLP model
params = model.init(jrandom.PRNGKey(0), Xtr)['params'] # initialize the weights

In [205]:
def loss_fn(params):
    # Convert ground truth labels to one-hot encoding
    Ytr_onehot = jnp.eye(N=10)[Ytr]
    # forward pass
    logits = model.apply({'params': params}, Xtr) #forward pass
    return jnp.mean(jnp.square(logits - Ytr_onehot))


This code block optimize MLP's parameters using the Optax library in JAX. The first step is to create an optimizer using optax.sgd(learning_rate=0.1) which specifies the learning rate of the optimizer. The initial optimizer state is initialized with opt.init(params).

In [207]:
#Optimizing with Optax
opt = optax.sgd(learning_rate=0.1)
opt_state = opt.init(params)
loss_grad_fn = value_and_grad(loss_fn)

Then, a loop is run for a fixed number of iterations (in this case, 100) and in each iteration, the loss and gradients of the MLP parameters are computed using the value_and_grad function with loss_fn as an argument. The optimizer is then updated with these gradients using opt.update(grads, opt_state) and the updated parameters are obtained with optax.apply_updates(params, updates).

In [208]:
for i in range(100):
  loss_val, grads = loss_grad_fn(params)
  updates, opt_state = opt.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  if i % 10 == 0:
    print('Loss step {}: '.format(i), loss_val)

Loss step 0:  0.09833379
Loss step 10:  0.08320424
Loss step 20:  0.07557934
Loss step 30:  0.07171275
Loss step 40:  0.0697162
Loss step 50:  0.068644024
Loss step 60:  0.06802529
Loss step 70:  0.06763111
Loss step 80:  0.06735133
Loss step 90:  0.06713202


In the for loop, the emb variable contains the output of the MLP model. The logits variable is calculated from the softmax of emb[0]. The ix variable is then randomly sampled from the categorical distribution using jrandom.categorical() with the generated logits as input. The context variable is updated by removing the first element and appending the newly generated index, ix.

Finally, the generated name is printed by concatenating the corresponding characters in itos.

In [257]:
rng_key = jrandom.PRNGKey(2147483647+10)
context = jnp.zeros((1, block_size), dtype=jnp.int32)
for _ in range(20):
    out = []
    while True:
        # emb = model(context)
        emb = model.apply({'params': params}, context)
        logits = jax.nn.softmax(emb[0])
        # Sample from the categorical distribution
        key, rng_key = jrandom.split(rng_key)
        ix = jrandom.categorical(rng_key, logits).item()
        context = jnp.hstack([context[:, 1:], jnp.reshape(ix, (1, 1))])
        out.append(ix)
        # print(context)
        if ix == 0:
            break

    print(''.join([itos[i] for i in out]))

eadeggg.
gd.
cighgieaccbdabedd.
iadedceb.
idbaghggeebgb.
dbc.
ichibhfch.
chbahaiieadbgb.
adffibhbeafigabhchcahfabhgahdhchdbbhdbbhdf.
.
aacacfghd.
ecdcificabf.
a.
gc.
b.
dbddhdbhcddahhaihdhaieifciiaigaigde.
ci.
haaiahihbdcaafbfcfggdghebaiedgd.
iaffhh.
iidaabigbed.
