<a href="https://colab.research.google.com/github/ryanzhao29/Jax/blob/main/gpt_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Building a GPT

A JAX/Flax port of the companion notebook to the [Zero To Hero](https://karpathy.ai/zero-to-hero.html) video on GPT.

In [None]:
# We always start with a dataset to train on. Let's download the tiny shakespeare dataset
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2024-06-18 15:47:03--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.110.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.6’


2024-06-18 15:47:04 (11.9 MB/s) - ‘input.txt.6’ saved [1115394/1115394]



In [None]:
# load the whole corpus into memory as a single string so we can inspect it
with open('input.txt', mode='r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

In [None]:
# report the corpus size, measured in characters
num_chars = len(text)
print("length of dataset in characters: ", num_chars)

length of dataset in characters:  1115394


In [None]:
# let's look at the first 1000 characters
# print(text[:1000])

In [None]:
# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [None]:
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

print(encode("hii there"))
print(decode(encode("hii there")))

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


In [None]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import flax
import optax
import copy

In [None]:
# Split the encoded text into train and validation sets:
# the first 90% is used for training, the remainder for validation.
data = jnp.array(encode(text), dtype=jnp.int32)
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

In [None]:
# a short context length, with a peek at the first block_size + 1 training tokens
block_size = 8
train_data[:block_size + 1]

Array([18, 47, 56, 57, 58,  1, 15, 47, 58], dtype=int32)

In [None]:
# y is x shifted by one token, so each prefix of x is paired with the token that follows it
x, y = train_data[:block_size], train_data[1:block_size + 1]
for t in range(block_size):
    context, target = x[:t + 1], y[t]
    print(f"when input is {context} the target: {target}")

when input is [18] the target: 47
when input is [18 47] the target: 56
when input is [18 47 56] the target: 57
when input is [18 47 56 57] the target: 58
when input is [18 47 56 57 58] the target: 1
when input is [18 47 56 57 58  1] the target: 15
when input is [18 47 56 57 58  1 15] the target: 47
when input is [18 47 56 57 58  1 15 47] the target: 58


In [None]:
seed = 1
batch_size = 64 # how many independent sequences will we process in parallel?
block_size = 32 # what is the maximum context length for predictions?

def get_batch(split, key):
    """Sample a random batch of (input, target) windows from one data split.

    Args:
        split: 'train' selects train_data; any other value selects val_data.
        key: JAX PRNG key used to draw the window start offsets.

    Returns:
        A tuple (x, y) of arrays, each of shape (batch_size, block_size).
        y holds the same windows as x shifted one position to the right, so
        y[b, t] is the token that follows x[b, t] in the source data.
    """
    data = train_data if split == 'train' else val_data
    # randint's upper bound is exclusive, so the largest start offset still leaves
    # room for block_size inputs plus the one extra token the shifted targets need
    ix = jax.random.randint(key, (batch_size,), 0, len(data) - block_size)
    # Build all window indices at once: (batch_size, 1) + (block_size,) broadcasts to
    # (batch_size, block_size). One gather per output replaces 2 * batch_size
    # Python-level slices on device scalars.
    offsets = ix[:, None] + jnp.arange(block_size)
    x = data[offsets]
    y = data[offsets + 1]
    return x, y

# draw one training batch with a fresh subkey and show both halves of it
key = jax.random.PRNGKey(seed)
key, subkey = jax.random.split(key)
xb, yb = get_batch('train', subkey)
for title, batch_arr in (('inputs:', xb), ('targets:', yb)):
    print(title)
    print(batch_arr.shape)
    print(batch_arr)

print('----')

# for b in range(batch_size): # batch dimension
#     for t in range(block_size): # time dimension
#         context = xb[b, :t+1]
#         target = yb[b,t]
#         print(f"when input is context the target: {context.tolist(), target.tolist()}")


inputs:
(64, 32)
[[46 63  1 ... 52 43  8]
 [57 46 39 ... 43  1 43]
 [56 52  5 ... 57 43 52]
 ...
 [ 1 51 63 ... 52 42  1]
 [60 43  1 ...  1 41 46]
 [53  1 58 ...  1 39 52]]
targets:
(64, 32)
[[63  1 57 ... 43  8  0]
 [46 39 50 ...  1 43 39]
 [52  5 42 ... 43 52 42]
 ...
 [51 63  1 ... 42  1 42]
 [43  1 58 ... 41 46 53]
 [ 1 58 46 ... 39 52 42]]
----


In [None]:
# our input to the transformer
print(xb)

[[46 63  1 ... 52 43  8]
 [57 46 39 ... 43  1 43]
 [56 52  5 ... 57 43 52]
 ...
 [ 1 51 63 ... 52 42  1]
 [60 43  1 ...  1 41 46]
 [53  1 58 ...  1 39 52]]


In [None]:
class BigramLanguageModel(nn.Module):
  """Bigram language model: every token id directly looks up its next-token logits.

  Attributes:
    embedding_size: used as both the number of embeddings and the feature width
      of the table, so each looked-up row holds one logit per token id.
  """
  embedding_size: int

  @nn.compact
  def __call__(self, x):
    """Look up logits for the integer token ids in x.

    The embedding lookup appends one trailing axis of size embedding_size.
    """
    logits = nn.Embed(self.embedding_size, self.embedding_size)(x)
    return logits

  def loss(self, params, x, y):
    """Mean cross-entropy between the model's logits for x and the targets y.

    Args:
      params: model variables, as produced by `init`.
      x: integer token ids of shape (B, T).
      y: integer target ids with B*T elements in total.

    Returns:
      The loss averaged over all B*T positions.
    """
    logits = self.apply(params, x)
    B, T, C = logits.shape
    # flatten batch and time so every position is scored as an independent example
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits.reshape(B * T, C), y.reshape(-1))
    return jnp.mean(loss)

  def generate(self, params, idx, max_new_tokens, key):
    """Extend idx by sampling max_new_tokens tokens, one at a time.

    Args:
      params: model variables, as produced by `init`.
      idx: integer token ids of shape (B, T) used as the starting context.
      max_new_tokens: number of sampling steps to run.
      key: JAX PRNG key; a fresh subkey is split off for every step.

    Returns:
      Array of shape (B, T + max_new_tokens): idx followed by the sampled tokens.
    """
    for _ in range(max_new_tokens):
      key, subkey = jax.random.split(key)
      logits = self.apply(params, idx)
      logits = logits[:, -1, :]  # keep only the last time step: (B, C)
      idx_next = jax.random.categorical(subkey, logits, axis=1)  # (B,)
      # Insert the new axis on the time dimension: (B,) -> (B, 1). Expanding on
      # axis 0 gave (1, B), which only concatenates with (B, T) when B == 1.
      idx_next = jnp.expand_dims(idx_next, axis=1)
      idx = jnp.concatenate((idx, idx_next), axis=1)
    return idx

# Size the model from the dataset's vocabulary rather than a hard-coded 65, so the
# logits always have one entry per character even if the input text changes.
model = BigramLanguageModel(vocab_size)

print(model)
# separate keys for parameter initialisation and for sampling
key, subkey, init_key = jax.random.split(key, 3)
params = model.init(init_key, xb)
loss = model.loss(params, xb, yb)
print("loss is")
print(loss)
# start generation from a single sequence that holds the token id 1
idx = jnp.ones((1, 1), dtype=jnp.int32)
print(decode(model.generate(params, idx, max_new_tokens=100, key=subkey)[0].tolist()))

BigramLanguageModel(
    # attributes
    embedding_size = 65
)
(64, 32)
(64, 32, 65)
loss is
4.191567
 Y
rvDf$MHYq-TqlsDb$mKEmMiyr:jiXHaYYgmkvWGyRFHal hPizzTMfQIubc&DdplW3ZbVHlswolBaiQxXQZAowgnWeMHK;jbCH


In [None]:
# restart the PRNG stream from a fixed seed
seed = 0
key = jax.random.PRNGKey(seed)

# Adam optimiser and its state for the current parameters
opt = optax.adam(learning_rate=1e-3)
opt_state = opt.init(params)

# for epoch in range(5000):
#   key, key1 = jax.random.split(key)
#   xb, yb = get_batch('train', key1)
#   loss, grads = jax.value_and_grad(model.loss)(params, xb, yb)
#   updates, opt_state = opt.update(grads, opt_state)
#   # print(opt_state)
#   params = optax.apply_updates(params, updates)

#   # param_diff = jax.tree.map(lambda p,q: p-q, params, params1)
#   # print(param_diff)
#   # print(decode(model.generate(params, idx, max_new_tokens=10)[0].tolist()))
#   if epoch % 10 == 0:
#     print(f"epoch {epoch}, loss {loss}")

In [None]:
# sample 500 new tokens from the current parameters and show them as text
generated_ids = model.generate(params, idx, max_new_tokens=500, key=key)
print(decode(generated_ids[0].tolist()))

 br'qDW!cVenexdcDOJhoE.fqwwC'E.zBDoXvJoCFcTO;R!D  HOZFP
VD.sZoP..tm!YGFkybeGt3Jn'u?-$x  SsQSrkHqXL,GyKq-pvCg:GWs:B$GmO$grOUFnBLaJHBobZBHq
v-c
iVQvgxD- :d:UCQuDrEgFRN?xFZTC
RnfoCn;3fyu,biAn'W,r O3.iy?HKkl
kkDFzZN'H;yySPN-zd:oWU'DBqu;es.ob-h!RfZQ:P,IA-TcokJajFQEh;OHvvngJ3uDUd
3! Pp3iyDkTIYc;z UPk
M$bQF!JWtoN
qoWfGR!R;.voXhdMnrL'IXlD3KjpF&v!tD,g?ex?NQBND
fgGMc,YJRNp3wkqWLCz
XTjivBFE:Vfz!AHHEFrwLFb$HOWXx'ZKUmXhL.gF:BUPuFzp;sJtvnXspBKNiJL rStSjCYzjpk; o?:mVwzXYytAF-t
cKhGUJdhki;-LGwvB,ChIHKN!Oz,gDwbj$


## The mathematical trick in self-attention

In [None]:
# `split` lives in jax.random; the top-level jax module has no `split` attribute
key, subkey = jax.random.split(key)
B, T, C = 4, 8, 2
# random toy input with the three axes sized B, T and C
x = jax.random.normal(subkey, (B, T, C))
print(x.shape)

AttributeError: module 'jax' has no attribute 'split'