## Importing libraries and connecting to Google Drive

In [None]:
!pip install jax jaxlib flax optax



In [None]:
import jax
import jax.numpy as jnp
from jax import random, grad, jit, vmap, pmap
import flax
import flax.linen as nn
import optax
import numpy as np

In [None]:
!pip install -U -q PyDrive
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

In [None]:
from google.colab import drive

In [None]:
# Authenticate the Colab user, then build a PyDrive client from the resulting
# application-default credentials.
# NOTE(review): this rebinds the name `drive`, shadowing the
# `google.colab.drive` module imported in the previous cell; the download cell
# below relies on `drive` being this PyDrive client.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

In [None]:
# Global JAX PRNG key with a fixed seed of 0 (used e.g. for model.init).
# NOTE(review): JAX keys are not stateful — every draw made with this same key
# produces the same sample; derive fresh sub-keys with jax.random.split.
rng_key = jax.random.PRNGKey(0)



In [None]:
# Download the training text from Google Drive by its file ID (the long string
# in the file's shareable link) and save it as 'theSecretBook.txt' in the
# current working directory.
file_id = '1uZiUwBQGpCcr2G9s1mXdwvNYnttTdOIU'
downloaded = drive.CreateFile({'id': file_id})
downloaded.GetContentFile('theSecretBook.txt')

In [None]:
# Text file read by Dataset.read_dataset().
# NOTE(review): assumes the working directory during the download was /content
# (the usual Colab default), so this matches 'theSecretBook.txt' saved above — TODO confirm.
input_file_path = '/content/theSecretBook.txt'

In [None]:
# hyperparameters
batch_size = 16       # sequences sampled per batch by Dataset.get_batch
block_size = 32       # context length: tokens per training window; generate() crops to this many
max_iters = 2500      # number of optimizer steps in the training loop
eval_interval = 200   # estimate train/val loss every this many steps
learning_rate = 1e-3  # Adam learning rate
eval_iters = 100      # batches averaged per split when estimating the loss
n_embd = 64           # embedding / residual-stream width
n_head = 4            # attention heads per transformer block (head_size = n_embd // n_head)
n_layer = 8           # number of stacked transformer blocks
# ------------

## Dataset Loading

In [None]:
class Dataset:
    """Character-level text dataset backed by the file at ``input_file_path``.

    ``data_split()`` reads the file, builds a sorted character vocabulary
    (``vocab_size``, ``encode``, ``decode``) and splits the encoded text 80/20
    into ``train_data`` / ``val_data``. ``get_batch()`` then samples random
    context windows from one of the splits.
    """

    def __init__(self, key=None):
        """Create an empty dataset.

        Args:
            key: optional JAX PRNG key used for batch sampling. Defaults to the
                module-level ``rng_key``.
        """
        self.vocab_size = 0
        self.train_data = jnp.array([], dtype=jnp.int32)
        self.val_data = jnp.array([], dtype=jnp.int32)
        # Sampling state; get_batch() splits it on every call so consecutive
        # batches differ (reusing one fixed key returned the same batch forever).
        self._key = rng_key if key is None else key

    def read_dataset(self):
        """Read the whole text file into ``self.data``."""
        with open(input_file_path, 'r', encoding='utf-8') as f:
            self.data = f.read()

    def prepare_dataset(self):
        """Read the text and build the character vocabulary and codecs."""
        self.read_dataset()

        chars = sorted(set(self.data))
        self.vocab_size = len(chars)
        char_to_int = {ch: i for i, ch in enumerate(chars)}
        int_to_char = dict(enumerate(chars))
        self.encode = lambda s: [char_to_int[c] for c in s]
        self.decode = lambda l: ''.join([int_to_char[i] for i in l])

    def data_split(self):
        """Encode the full text and split it 80/20 into train/val id arrays."""
        self.prepare_dataset()

        data_tensor = jnp.array(self.encode(self.data), dtype=jnp.int32)
        n = int(0.8 * len(data_tensor))
        self.train_data = data_tensor[:n]
        self.val_data = data_tensor[n:]

    def get_batch(self, split, key=None):
        """Sample ``batch_size`` random windows of ``block_size`` tokens.

        Args:
            split: ``'train'`` selects the training split; any other value
                selects the validation split.
            key: optional PRNG key. When omitted, a fresh sub-key is split off
                the dataset's internal key so every call draws a new batch.

        Returns:
            ``(x, y)`` int32 arrays of shape ``(batch_size, block_size)``, where
            ``y`` is ``x`` shifted one token to the right (next-token targets).

        Raises:
            ValueError: if the chosen split has at most ``block_size`` tokens.
        """
        # Encode/split lazily and only once; doing it on every call re-read and
        # re-encoded the entire file for each batch.
        if self.train_data.size == 0 and self.val_data.size == 0:
            self.data_split()

        data = self.train_data if split == 'train' else self.val_data
        if len(data) <= block_size:
            raise ValueError(
                f"{split!r} split has {len(data)} tokens; "
                f"need more than block_size={block_size}")

        if key is None:
            self._key, key = random.split(self._key)
        # Start offsets in [0, len - block_size) so y's last index stays in range.
        ix = random.randint(key, (batch_size,), 0, len(data) - block_size)
        # One vectorized gather instead of a Python loop of per-row slices.
        positions = ix[:, None] + jnp.arange(block_size)
        return data[positions], data[positions + 1]

## Loss Function

In [None]:
class Loss:
    """Evaluation helper that averages the model loss over random batches."""

    def estimate_loss(self, model_params=None):
        """Return the mean loss over ``eval_iters`` random batches per split.

        Args:
            model_params: the model's parameter tree (the value that goes under
                the ``'params'`` key of ``model.apply``). Defaults to the
                module-level ``params``.

        Returns:
            Dict mapping ``'train'`` and ``'val'`` to the mean batch loss.
        """
        # A Flax module holds no weights itself: it must be run through
        # apply() with an explicit variables dict (model(X, Y) fails).
        variables = {'params': params if model_params is None else model_params}
        out = {}
        for split in ('train', 'val'):
            losses = []
            for _ in range(eval_iters):
                X, Y = dataObj.get_batch(split)
                _, loss = model.apply(variables, X, Y)
                losses.append(loss)
            out[split] = jnp.mean(jnp.array(losses))
        return out

lossObj = Loss()

In [None]:
class Head(nn.Module):
    """A single causal (masked) self-attention head.

    Attributes:
        head_size: output width of the key, query and value projections.
    """
    head_size: int

    def setup(self):
        self.key = nn.Dense(self.head_size, use_bias=False, kernel_init=nn.initializers.xavier_uniform())
        self.query = nn.Dense(self.head_size, use_bias=False, kernel_init=nn.initializers.xavier_uniform())
        self.value = nn.Dense(self.head_size, use_bias=False, kernel_init=nn.initializers.xavier_uniform())

    def __call__(self, x):
        """Return attention outputs with last-axis width ``head_size``.

        Args:
            x: array whose last two axes are (time, features), e.g. (B, T, C).
        """
        T = x.shape[-2]
        k = self.key(x)
        q = self.query(x)
        # Scaled dot-product attention divides by sqrt(head_size), the width of
        # q/k. Scaling by the input width C (as before) is wrong when C != head_size.
        w = jnp.matmul(q, jnp.swapaxes(k, -2, -1)) * k.shape[-1] ** -0.5
        # Causal mask sized to the actual sequence length: position t may only
        # attend to positions <= t. The diagonal is always kept, so no row is all -inf.
        causal = jnp.tril(jnp.ones((T, T), dtype=bool))
        w = jnp.where(causal, w, float('-inf'))
        w = nn.softmax(w, axis=-1)

        v = self.value(x)
        return jnp.matmul(w, v)

In [None]:
class MultiHeadAttention(nn.Module):
    """Run ``n_head`` independent attention heads side by side.

    Each head produces ``head_size`` features; the results are concatenated
    along the last axis and linearly projected back to ``n_embd`` features.
    """
    n_head: int
    head_size: int

    def setup(self):
        self.heads = [Head(head_size=self.head_size) for _ in range(self.n_head)]
        self.proj = nn.Dense(n_embd)

    def __call__(self, x):
        per_head = [head(x) for head in self.heads]
        merged = jnp.concatenate(per_head, axis=-1)
        projected = self.proj(merged)
        return projected

In [None]:
class FeedForward(nn.Module):
    """Position-wise MLP: widen to ``4 * n_embd``, ReLU, project back to ``n_embd``."""

    def setup(self):
        hidden_width = 4 * n_embd
        layers = [nn.Dense(hidden_width), nn.relu, nn.Dense(n_embd)]
        self.net = nn.Sequential(layers)

    def __call__(self, x):
        out = self.net(x)
        return out

In [None]:
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block.

    Applies multi-head self-attention and then a feed-forward MLP, each to a
    LayerNorm-ed input and each added back through a residual connection.
    """

    def setup(self):
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head=n_head, head_size=head_size)
        self.ffwd = FeedForward()
        self.ln1 = nn.LayerNorm(epsilon=1e-6)
        self.ln2 = nn.LayerNorm(epsilon=1e-6)

    def __call__(self, x):
        # ln1/ln2 used to be created but never applied, so the block ran with
        # no normalization at all; normalize before each sublayer (pre-LN).
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

In [None]:
class TokenEmbedding(nn.Module):
    """Map integer token ids to ``n_embd``-dimensional vectors.

    The ids are one-hot encoded over ``dataObj.vocab_size`` and multiplied by a
    bias-free Dense kernel, which selects one learned row per token.
    """

    @nn.compact
    def __call__(self, idx):
        one_hot_tokens = jax.nn.one_hot(idx, dataObj.vocab_size)
        table = nn.Dense(n_embd, use_bias=False)
        return table(one_hot_tokens)

class PositionEmbedding(nn.Module):
    """Learned absolute position embedding for positions ``0..block_size-1``.

    Positions are one-hot encoded over ``block_size`` and multiplied by a
    bias-free Dense kernel. A position >= ``block_size`` one-hot encodes to all
    zeros and therefore gets a zero embedding.
    """

    def set_T(self, T):
        """No-op kept for backward compatibility.

        The previous implementation stored ``T`` on the module, but Flax
        modules are frozen outside ``setup()``, so the assignment raised. The
        sequence length now comes from the ``idx`` passed to ``__call__``.
        """

    @nn.compact
    def __call__(self, idx):
        """Embed integer positions ``idx`` (NanoGPT passes ``jnp.arange(T)``)."""
        return nn.Dense(n_embd, use_bias=False)(jax.nn.one_hot(idx, block_size))

class NanoGPT(nn.Module):
    """Character-level GPT language model.

    Token + position embeddings feed ``n_layer`` transformer blocks, followed
    by a final LayerNorm and a linear head over ``dataObj.vocab_size`` tokens.
    """

    def setup(self):
        self.blocks = [TransformerBlock() for _ in range(n_layer)]
        self.ln_f = nn.LayerNorm(epsilon=1e-6)
        self.lm_head = nn.Dense(dataObj.vocab_size, kernel_init=nn.initializers.xavier_uniform())

        # Initialize submodules within setup
        self.token_embedding = TokenEmbedding()
        self.position_embedding = PositionEmbedding()

    def __call__(self, idx, targets=None):
        """Compute next-token logits and, optionally, the mean cross-entropy.

        Args:
            idx: integer token ids of shape (B, T), with T <= block_size.
            targets: optional integer ids of shape (B, T) to score against.

        Returns:
            ``(logits, loss)``; ``loss`` is ``None`` when ``targets`` is None.
        """
        T = idx.shape[1]

        tok_emb = self.token_embedding(idx)
        # Positions are passed explicitly: Flax modules are frozen outside
        # setup(), so mutating the submodule (the old set_T call) raised.
        pos_emb = self.position_embedding(jnp.arange(T))

        x = tok_emb + pos_emb
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            # jax.nn has no softmax_cross_entropy; optax's integer-label variant
            # avoids materializing one-hot targets.
            loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean()

        return logits, loss

    def generate(self, idx, max_new_tokens, key):
        """Append ``max_new_tokens`` sampled tokens to ``idx`` and return it.

        The model only sees the last ``block_size`` tokens of the context.
        ``key`` is split on every step so each draw uses fresh randomness.
        """
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]
            key, subkey = random.split(key)
            idx_next = random.categorical(subkey, logits)
            idx = jnp.concatenate([idx, idx_next[:, None].astype(idx.dtype)], axis=1)
        return idx

In [None]:
def generateNext(max_new_tokens=2000, seed=0):
    """Sample text from the trained model and print it.

    Generation starts from a single token id 0. Relies on the module-level
    ``model``, ``params`` (the parameter tree, as used by ``update``) and
    ``dataObj`` created in the training cell.

    Args:
        max_new_tokens: number of tokens to sample.
        seed: seed for the sampling PRNG key.
    """
    key = random.PRNGKey(seed)
    context = jnp.zeros((1, 1), dtype=jnp.int32)
    # An unbound Flax module cannot run its methods directly; bind the weights
    # through apply() and dispatch to NanoGPT.generate.
    generated_seq = model.apply(
        {'params': params}, context,
        max_new_tokens=max_new_tokens, key=key, method=NanoGPT.generate)
    print(dataObj.decode(generated_seq[0].tolist()))


In [None]:
if __name__ == '__main__':
    dataObj = Dataset()
    # data_split() reads the file, builds the vocabulary (NanoGPT.setup needs
    # dataObj.vocab_size) and creates the train/val split in one call; the old
    # read_dataset() + prepare_dataset() pair read the file twice.
    dataObj.data_split()

    model = NanoGPT()
    # init() returns the variables dict {'params': ...}; keep only the
    # parameter tree, because every apply() call re-wraps it as {'params': params}.
    params = model.init(rng_key, jnp.zeros((1, block_size), dtype=jnp.int32))['params']

    optimizer = optax.adam(learning_rate=learning_rate)
    state = optimizer.init(params)

    def loss_fn(params, xb, yb):
        """Scalar training loss for one batch (the function we differentiate)."""
        _, loss = model.apply({'params': params}, xb, yb)
        return loss

    @jit
    def update(params, xb, yb, state):
        """One Adam step; returns (new_params, batch_loss, new_optimizer_state)."""
        # Differentiate the loss *function*; the old code called jax.grad on
        # the already-computed loss value, which is not a function.
        loss, grads = jax.value_and_grad(loss_fn)(params, xb, yb)
        updates, new_state = optimizer.update(grads, state)
        new_params = optax.apply_updates(params, updates)
        return new_params, loss, new_state

    # `step` rather than `iter`, which shadowed the builtin.
    for step in range(max_iters):
        if step % eval_interval == 0 or step == max_iters - 1:
            losses = lossObj.estimate_loss()
            print(f"step {step}: train loss {float(losses['train']):.4f}, val loss {float(losses['val']):.4f}")

        xb, yb = dataObj.get_batch('train')
        params, loss, state = update(params, xb, yb, state)

In [None]:
# Sample and print text from the trained model; uses the `model` and `dataObj`
# globals created in the training cell above.
generateNext()