In [147]:
import jax.numpy as jnp
import jax.random
import jax.lax

In [215]:
def rmsnorm(x, weight):
  """Root-mean-square normalisation of a 1-D vector, followed by a per-element gain.

  Args:
    x: 1-D vector to normalise.
    weight: gain applied element-wise to the normalised vector.

  Returns:
    weight * x / sqrt(mean(x**2) + 1e-5)
  """
  mean_square = x.dot(x) / x.shape[0]
  # 1e-5 keeps the reciprocal finite when x is (close to) all zeros.
  inv_rms = 1 / jnp.sqrt(mean_square + 1e-5)
  return weight * x * inv_rms

def softmax(x):
  """Numerically stable softmax over all elements of `x`.

  The maximum is subtracted before exponentiating so that large inputs do not
  overflow. The normaliser is computed with `jnp.sum` (a single reduction)
  rather than Python's builtin `sum`, which would iterate over the array
  element by element and unroll into one add per element under `jax.jit`.

  Args:
    x: array of scores.

  Returns:
    Array of the same shape whose entries are non-negative and sum to 1.
  """
  shifted = jnp.exp(x - jnp.max(x))
  return shifted / jnp.sum(shifted)

def sigmoid(x):
  """Logistic function 1 / (1 + exp(-x)), applied element-wise."""
  denominator = 1 + jnp.exp(-x)
  return 1 / denominator

def silu(x):
  """SiLU activation: the input multiplied by its own sigmoid, element-wise."""
  gate = sigmoid(x)
  return x * gate


# Set to True to enable the optional shape checks inside forward().
asserts = False
def forward(token, config, weights, key_cache, value_cache):
  """Run one decoding step of the transformer for a single token.

  The position of the token is not passed explicitly: it is the number of
  entries already stored in the KV cache (`key_cache.shape[1]`).

  Args:
    token: token value, used as a row index into `token_embedding_table`.
    config: dict providing 'n_layers', 'n_heads', 'n_kv_heads', 'vocab_size',
      'dim' and 'hidden_dim'.
    weights: dict of arrays. Expected shapes (checked when `asserts` is True):
      token_embedding_table (vocab_size, dim), rms_att_weight (n_layers, dim),
      wq (n_layers, dim, dim), wk / wv (n_layers, kv_dim, dim),
      wo (n_layers, dim, dim), rms_ffn_weight (n_layers, dim),
      w1 / w3 (n_layers, hidden_dim, dim), w2 (n_layers, dim, hidden_dim),
      rms_final_weight (dim,), wcls (vocab_size, dim).
    key_cache: keys of the previous positions, (n_layers, pos, kv_dim).
    value_cache: values of the previous positions, (n_layers, pos, kv_dim).

  Returns:
    (logits, key_cache, value_cache): logits has shape (vocab_size,), and the
    returned caches have this step appended, shape (n_layers, pos + 1, kv_dim).
  """
  pos = key_cache.shape[1]
  assert pos == value_cache.shape[1]

  n_layers = config['n_layers']
  n_heads = config['n_heads']
  n_kv_heads = config['n_kv_heads']
  vocab_size = config['vocab_size']

  # Width of the residual stream
  dim = config['dim']

  # Width of the feed-forward hidden layer
  hidden_dim = config['hidden_dim']

  # Number of parameters per head
  head_size = dim // n_heads

  # Number of query heads sharing one key/value head
  kv_mul = n_heads // n_kv_heads

  # Number of parameters in one cached key (or value) row
  kv_dim = head_size * n_kv_heads

  wo = weights['wo']
  if asserts: assert wo.shape == (n_layers, dim, dim)
  rms_ffn_weight = weights['rms_ffn_weight']
  if asserts: assert rms_ffn_weight.shape == (n_layers, dim)
  w1 = weights['w1']
  if asserts: assert w1.shape == (n_layers, hidden_dim, dim)
  w3 = weights['w3']
  if asserts: assert w3.shape == (n_layers, hidden_dim, dim)
  w2 = weights['w2']
  if asserts: assert w2.shape == (n_layers, dim, hidden_dim)

  rms_att_weight = weights['rms_att_weight']
  if asserts: assert rms_att_weight.shape == (n_layers, dim)

  rms_final_weight = weights['rms_final_weight']
  if asserts: assert rms_final_weight.shape == (dim,)
  wcls = weights['wcls']
  if asserts: assert wcls.shape == (vocab_size, dim)

  token_embedding_table = weights['token_embedding_table']
  if asserts: assert token_embedding_table.shape == (vocab_size, dim)

  x = token_embedding_table[token, :]
  if asserts: assert x.shape == (dim, )

  wq = weights['wq']
  if asserts: assert wq.shape == (n_layers, dim, dim)

  wk = weights['wk']
  if asserts: assert wk.shape == (n_layers, kv_dim, dim)

  wv = weights['wv']
  if asserts: assert wv.shape == (n_layers, kv_dim, dim)

  # Rotary position embedding: one 2x2 rotation per consecutive pair of
  # elements. The angle depends on `pos` and on the pair's offset inside its
  # head, so the pattern repeats for every head.
  rotations = []
  for i in range(0, dim, 2):
    freq = 1 / jnp.power(10000, (i % head_size) / head_size)
    angle = pos * freq
    cos_a = jnp.cos(angle)
    sin_a = jnp.sin(angle)
    rotations.append(jnp.array([[cos_a, -sin_a],
                                [sin_a, cos_a]]))
  q_rot = jnp.array(rotations)
  # Keys only have kv_dim elements, so they use the first kv_dim // 2 rotations.
  k_rot = q_rot[:kv_dim // 2]

  new_keys = []
  new_values = []
  for l in range(n_layers):
    # ---- attention ----
    xb = rmsnorm(x, rms_att_weight[l, :])
    if asserts: assert xb.shape == (dim, )

    q = wq[l, :, :] @ xb
    if asserts: assert q.shape == (dim, )

    k = wk[l, :, :] @ xb
    if asserts: assert k.shape == (kv_dim, )

    v = wv[l, :, :] @ xb
    if asserts: assert v.shape == (kv_dim, )

    # Apply the rotations as a batched 2x2 matrix-vector product.
    # q has dim elements, k has kv_dim elements.
    q_pairs = jnp.reshape(q, (dim // 2, 2))
    k_pairs = jnp.reshape(k, (kv_dim // 2, 2))
    k = jnp.reshape(jnp.einsum('ijk,ik -> ij', k_rot, k_pairs), (kv_dim,))
    q = jnp.reshape(jnp.einsum('ijk,ik -> ij', q_rot, q_pairs), (dim,))

    # Append this position's key/value to the layer's cache.
    key_cache_l = jnp.append(key_cache[l, :, :], jnp.reshape(k, (1, kv_dim)), axis=0)
    value_cache_l = jnp.append(value_cache[l, :, :], jnp.reshape(v, (1, kv_dim)), axis=0)
    new_keys.append(key_cache_l)
    new_values.append(value_cache_l)

    head_outputs = []
    for h in range(n_heads):
      q_h = q[head_size*h:head_size*(h+1)]
      if asserts: assert q_h.shape == (head_size,)

      # kv_mul consecutive query heads share the same key/value head.
      kv_head = h // kv_mul
      kv_lo = kv_head * head_size
      kv_hi = kv_lo + head_size

      # Scaled dot-product scores against every cached position.
      att = jnp.einsum('ij,j->i', key_cache_l[:, kv_lo:kv_hi], q_h)
      att = att / jnp.sqrt(head_size)
      att = softmax(att)

      # Attention-weighted sum of the cached values.
      head_outputs.append(jnp.einsum('ij,i->j', value_cache_l[:, kv_lo:kv_hi], att))

    xb = jnp.concatenate(head_outputs, axis=None)

    xb2 = wo[l, :, :] @ xb
    if asserts: assert xb2.shape == (dim, )

    # Out-of-place add: x may still be a view of the embedding table here, and
    # the table must never be modified.
    x = x + xb2

    # ---- feed-forward (SwiGLU): w2(silu(w1(x)) * w3(x)) ----
    xb = rmsnorm(x, rms_ffn_weight[l, :])

    hb = w1[l, :, :] @ xb
    hb2 = w3[l, :, :] @ xb
    hb = silu(hb) * hb2

    x = x + w2[l, :, :] @ hb

  x = rmsnorm(x, rms_final_weight)
  logits = wcls @ x

  if asserts: assert logits.shape == (vocab_size,)

  for layer_keys in new_keys:
    assert layer_keys.shape == (pos+1, kv_dim)
  key_cache = jnp.array(new_keys)
  assert key_cache.shape == (n_layers, pos+1, kv_dim)
  value_cache = jnp.array(new_values)
  assert value_cache.shape == (n_layers, pos+1, kv_dim)
  return logits, key_cache, value_cache


def str_lookup(string, vocab):
    """Return the index of the first entry of `vocab` equal to `string`, or -1 if there is none."""
    try:
        found = vocab.index(string)
    except ValueError:
        found = -1
    return found

def bpe_encode(tokenizer, text):
    """Encode `text` into token ids by greedy byte-pair-style merging.

    Every character is first mapped to its own token. Then, repeatedly, the
    adjacent pair whose concatenation exists in the vocabulary with the
    highest score is merged, until no adjacent pair can be merged.

    Args:
        tokenizer: dict with 'vocab' (list of token strings) and
            'vocab_scores' (list of scores, parallel to 'vocab').
        text: string to encode.

    Returns:
        List of token ids (indices into tokenizer['vocab']).

    Raises:
        SystemExit: if a character of `text` is not in the vocabulary (a
            message with its position is printed first).
    """
    # Local import: the function calls sys.exit, and the notebook's import
    # cell does not import sys.
    import sys

    vocab = tokenizer['vocab']
    vocab_scores = tokenizer['vocab_scores']

    # token string -> index of its FIRST occurrence in vocab. Gives the same
    # answer as a linear vocab.index() scan, but in O(1) per lookup.
    index_of = {}
    for idx, piece in enumerate(vocab):
        index_of.setdefault(piece, idx)

    # First encode every individual character in the input text
    tokens = []
    for pos, char in enumerate(text):
        token_id = index_of.get(char, -1)
        if token_id == -1:
            print(f"not a good prompt at pos {pos}")
            sys.exit(1)
        tokens.append(token_id)

    # Merge the best consecutive pair each iteration, according to the scores in vocab_scores
    while True:
        best_score = -1e10
        best_id = -1
        best_idx = -1

        for i in range(len(tokens) - 1):
            # Check if we can merge the pair (tokens[i], tokens[i+1])
            merged = vocab[tokens[i]] + vocab[tokens[i + 1]]
            token_id = index_of.get(merged, -1)
            # Strict '>' keeps the left-most pair among equal scores.
            if token_id != -1 and vocab_scores[token_id] > best_score:
                best_score = vocab_scores[token_id]
                best_id = token_id
                best_idx = i

        if best_idx == -1:
            break  # We couldn't find any more pairs to merge, so we're done

        # Replace the pair (best_idx, best_idx+1) by the single merged token
        tokens[best_idx:best_idx + 2] = [best_id]

    return tokens
    
def sample(key, logits, temperature, topp):
  """Pick the next token id from `logits`.

  Args:
    key: JAX PRNG key; only used when temperature != 0.
    logits: 1-D array of unnormalised scores, one per vocabulary entry.
    temperature: 0 selects the arg-max (greedy decoding); otherwise the
      logits are divided by it before sampling.
    topp: nucleus-sampling threshold. Values <= 0 or >= 1 disable it; values
      strictly between 0 and 1 are not implemented.

  Returns:
    Scalar integer array holding the chosen token id.

  Raises:
    NotImplementedError: if temperature != 0 and 0 < topp < 1.
  """
  if temperature == 0:
    return jnp.argmax(logits)
  if 0 < topp < 1:
    raise NotImplementedError("not implemented topp")
  # jax.random.categorical expects unnormalised log-probabilities, so the
  # scaled logits are passed directly. Applying softmax first would sample
  # from softmax(softmax(x)), which is close to uniform.
  return jax.random.categorical(key, logits / temperature)

def run(text, tokenizer, steps, config, weights, temperature=0, topp=0, seed=0):
  """Generate `steps` tokens, printing each one together with the logits.

  The prompt `text` is encoded and forced as the first tokens; after the
  prompt is exhausted the next token is drawn with `sample`.

  Args:
    text: prompt string.
    tokenizer: dict with 'vocab' and 'vocab_scores' (see bpe_encode).
    steps: number of forward passes to run.
    config: model configuration dict (see forward).
    weights: model weights dict (see forward).
    temperature: sampling temperature; 0 (default) means greedy arg-max.
    topp: nucleus threshold passed to `sample`; 0 (default) disables it.
    seed: seed of the JAX PRNG key used when temperature != 0.

  Returns:
    None; the tokens are printed.
  """
  inputs = bpe_encode(tokenizer, text)
  # A real PRNG key (not a bare int), split once per sampled token so that
  # successive draws are independent.
  key = jax.random.PRNGKey(seed)

  n_layers = config['n_layers']
  n_heads = config['n_heads']
  dim = config['dim']
  n_kv_heads = config['n_kv_heads']
  kv_dim = dim // n_heads * n_kv_heads

  # Empty caches: zero positions stored so far.
  key_cache = jnp.zeros((n_layers, 0, kv_dim))
  value_cache = jnp.zeros((n_layers, 0, kv_dim))

  # First token fed to the model. NOTE(review): presumably the tokenizer's
  # beginning-of-sequence id - confirm against the tokenizer.
  token = 1
  for pos in range(steps):
    logits, key_cache, value_cache = forward(token, config, weights, key_cache, value_cache)
    from_prompt = pos < len(inputs)
    if from_prompt:
      next_token = inputs[pos]
    else:
      key, subkey = jax.random.split(key)
      next_token = int(sample(subkey, logits, temperature, topp))
    # Guard on the token id (the position is never -1).
    res = tokenizer['vocab'][next_token] if next_token != -1 else "<-1>"
    print("token # ", pos, " ", res, " state " if from_prompt else "sampled", logits)
    token = next_token

def loss(token, pos, config, weights, key_cache, value_cache, temperature):
  """Return the temperature-scaled softmax of the logits of one forward step.

  Despite the name, this returns the probability vector over the vocabulary,
  not a scalar loss.

  Args:
    token: input token value (see forward).
    pos: unused; kept so the signature stays unchanged. forward derives the
      position from the cache length.
    config, weights, key_cache, value_cache: passed through to forward.
    temperature: divisor applied to the logits before the softmax; must be
      non-zero.

  Returns:
    Array of shape (vocab_size,) with the token probabilities.
  """
  # forward returns (logits, key_cache, value_cache); only the logits are needed.
  logits, _, _ = forward(token, config, weights, key_cache, value_cache)
  return softmax(logits / temperature)

In [3]:
# Read the whole checkpoint file into memory as raw bytes.
# NOTE(review): the path is absolute and machine-specific, and `data` is not
# used by any of the cells visible here - confirm whether this cell is still needed.
with open('/Users/wmoses/work/llama2.c/stories15M.bin', 'rb') as f:
  data = f.read()

In [59]:
import numpy as np

# Size in bytes of the integers stored in the checkpoint and tokenizer files.
INT_SIZE = 4

def load_weights(filename):
  """Load the model configuration and weights from a checkpoint file.

  The file starts with seven little-endian signed 32-bit integers (dim,
  hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, seq_len) followed by
  float32 weights. A negative vocab_size in the header means the file carries
  separate classifier weights; a positive one means the classifier shares the
  token embedding table.

  Projection matrices are returned as (n_layers, out_features, in_features),
  which is the orientation `forward` applies them in (`w[l] @ x`).

  Args:
    filename: path of the checkpoint file.

  Returns:
    (config, weights): config is a dict of the seven header integers (with
    vocab_size made positive); weights maps names to read-only views of the
    memory-mapped file.

  Raises:
    AssertionError: if dim is not divisible by n_heads, or if the file holds
      more data than the layout accounts for.
  """
  header_fields = ('dim', 'hidden_dim', 'n_layers', 'n_heads',
                   'n_kv_heads', 'vocab_size', 'seq_len')
  config = {}
  with open(filename, 'rb') as f:
    for field in header_fields:
      # signed=True is required: the sign of vocab_size is a flag.
      config[field] = int.from_bytes(f.read(INT_SIZE), byteorder='little', signed=True)

  # Weird encoding with negative size indicating actual weights present.
  shared_weights = config['vocab_size'] > 0
  config['vocab_size'] = abs(config['vocab_size'])

  # Map variables.
  dim = config['dim']
  hidden_dim = config['hidden_dim']
  n_layers = config['n_layers']
  n_heads = config['n_heads']
  n_kv_heads = config['n_kv_heads']
  vocab_size = config['vocab_size']
  seq_len = config['seq_len']
  print(f"Loading data with config: {config}")

  # Mmap all data that follows the header.
  config_byte_size = len(header_fields) * INT_SIZE
  all_weights = np.memmap(filename, dtype='float32', mode='r', offset=config_byte_size)
  offset = 0
  weights = {}
  def slice_data(name, shape):
    """Take the next prod(shape) floats from the file and store them under `name`."""
    nonlocal offset
    total = int(np.prod(shape))
    weights[name] = all_weights[offset:offset+total].reshape(shape)
    offset += total

  assert dim % n_heads == 0
  head_size = dim // n_heads

  # Take slices of mmaped data in the right order.
  slice_data("token_embedding_table", (vocab_size, dim))
  slice_data("rms_att_weight", (n_layers, dim))
  slice_data("wq", (n_layers, n_heads * head_size, dim))
  slice_data("wk", (n_layers, n_kv_heads * head_size, dim))
  slice_data("wv", (n_layers, n_kv_heads * head_size, dim))
  slice_data("wo", (n_layers, dim, n_heads * head_size))
  slice_data("rms_ffn_weight", (n_layers, dim))
  slice_data("w1", (n_layers, hidden_dim, dim))
  slice_data("w2", (n_layers, dim, hidden_dim))
  slice_data("w3", (n_layers, hidden_dim, dim))
  slice_data("rms_final_weight", (dim,))
  # Skip seq_len * head_size floats that are not loaded.
  # NOTE(review): presumably precomputed rotary tables - confirm against the exporter.
  offset += seq_len * head_size
  if shared_weights:
    weights["wcls"] = weights["token_embedding_table"]
  else:
    slice_data("wcls", (vocab_size, dim))
  assert offset == len(all_weights), "haven't read all data"
  return config, weights

import struct


def read_int_from_file(f):
  """Read INT_SIZE bytes from the binary file `f` and decode them as a little-endian unsigned integer."""
  raw = f.read(INT_SIZE)
  return int.from_bytes(raw, byteorder='little')

def load_tokenizer(filename: str, vocab_size: int):
  """Read a tokenizer file: a max-token-length integer, then `vocab_size` records.

  Each record is a 4-byte float score, an integer byte length, and that many
  bytes of UTF-8 token text.

  Args:
    filename: path of the tokenizer file.
    vocab_size: number of records to read.

  Returns:
    dict with 'vocab_scores' (list of floats), 'vocab' (list of str) and
    'max_token_length' (int).
  """
  scores = []
  pieces = []
  with open(filename, 'rb') as f:
    max_token_length = read_int_from_file(f)
    for _ in range(vocab_size):
      (score,) = struct.unpack('f', f.read(4))
      scores.append(score)
      n_bytes = read_int_from_file(f)
      raw = f.read(n_bytes)
      # Text is stored as bytes; decode unless it is already a str.
      pieces.append(raw if type(raw) is str else raw.decode('utf8'))
  return {'vocab_scores': scores, 'vocab': pieces, 'max_token_length': max_token_length}
    

In [60]:
# Load the model header and memory-mapped weights.
# NOTE(review): absolute, machine-specific paths - consider a configurable data directory.
config, weights = load_weights('/Users/wmoses/work/llama2.c/stories15M.bin')

# The tokenizer file holds one record per vocabulary entry, so it needs vocab_size.
tokenizer = load_tokenizer('/Users/wmoses/work/llama2.c/tokenizer.bin', config['vocab_size'])

Loading data with config: {'dim': 288, 'hidden_dim': 768, 'n_layers': 6, 'n_heads': 6, 'n_kv_heads': 6, 'vocab_size': 32000, 'seq_len': 256}


In [222]:
# Generate 30 tokens: the prompt tokens are forced first, the rest are sampled.
run("Dream comes true this day", tokenizer, 30, config, weights)

token #  0   D  state  [-6.7907796  0.8281164 -6.790422  ... -6.7907    -6.7906713 -6.790539 ]
token #  1   ream  state  [ 0.20367432 -1.5172932   0.20359135 ...  0.20386982  0.20382214
  0.2036686 ]
token #  2    comes  state  [-6.1911945  0.994138  -6.1906934 ... -6.191182  -6.191168  -6.1909094]
token #  3    true  state  [-7.076321   1.8050578 -7.0759583 ... -7.076381  -7.0763884 -7.076169 ]
token #  4    this  state  [-10.818716    2.4258273 -10.818541  ... -10.818831  -10.818807
 -10.818767 ]
token #  5    day  state  [-7.903948   -0.22370028 -7.90354    ... -7.904048   -7.9038086
 -7.903598  ]
token #  6   . sampled [-9.333269  2.192757 -9.332975 ... -9.33329  -9.333334 -9.33309 ]
token #  7    A sampled [-3.4395585  5.33119   -3.4393303 ... -3.4395003 -3.4394808 -3.439447 ]
token #  8    little sampled [-9.060901  -2.7534804 -9.06044   ... -9.060832  -9.060694  -9.060497 ]
token #  9    girl sampled [-14.028574   -3.9483452 -14.02855   ... -14.028534  -14.028523
 -14.028587 ]
t

In [217]:
# Unpack the model dimensions from the loaded config into notebook-level names.
n_layers = config['n_layers']
seq_len = config['seq_len']
n_heads = config['n_heads']
dim = config['dim']
n_kv_heads = config['n_kv_heads']
# Width of one cached key/value row: (dim // n_heads) per head times n_kv_heads heads.
kv_dim = dim // n_heads * n_kv_heads

def partial(func, config):
    """Return a wrapper around `func` with `config` bound as its second positional argument.

    The wrapper takes (token, weights, key_cache, value_cache) and forwards
    them as func(token, config, weights, key_cache, value_cache), so `config`
    is closed over instead of being an argument of the wrapped function.
    """
    def sfn(token, weights, key_cache, value_cache):
        # config sits between token and weights in func's signature.
        result = func(token, config, weights, key_cache, value_cache)
        return result
    return sfn

# Start from empty caches: pos = 0 positions stored.
pos = 0
key_cache = jnp.zeros((n_layers, pos,kv_dim))
value_cache = jnp.zeros((n_layers, pos,kv_dim))

# Jitted forward pass with config closed over (see partial above).
jfunc = jax.jit(partial(forward, config))
# Lower the same function for token 1 with the empty caches and fetch its IR in the MHLO dialect.
# NOTE(review): this builds a second jax.jit wrapper instead of reusing jfunc.
mlir = jax.jit(partial(forward, config)).lower(1, weights, key_cache, value_cache).compiler_ir(dialect="mhlo")

In [218]:
# Dump the lowered module as text.
print(str(mlir))

module @jit_sfn attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.sharding = "{replicated}"}, %arg1: tensor<6x288xf32> {mhlo.sharding = "{replicated}"}, %arg2: tensor<6x288xf32> {mhlo.sharding = "{replicated}"}, %arg3: tensor<288xf32> {mhlo.sharding = "{replicated}"}, %arg4: tensor<32000x288xf32> {mhlo.sharding = "{replicated}"}, %arg5: tensor<6x768x288xf32> {mhlo.sharding = "{replicated}"}, %arg6: tensor<6x288x768xf32> {mhlo.sharding = "{replicated}"}, %arg7: tensor<6x768x288xf32> {mhlo.sharding = "{replicated}"}, %arg8: tensor<32000x288xf32> {mhlo.sharding = "{replicated}"}, %arg9: tensor<6x288x288xf32> {mhlo.sharding = "{replicated}"}, %arg10: tensor<6x288x288xf32> {mhlo.sharding = "{replicated}"}, %arg11: tensor<6x288x288xf32> {mhlo.sharding = "{replicated}"}, %arg12: tensor<6x288x288xf32> {mhlo.sharding = "{replicated}"}, %arg13: tensor<6x0x288xf32> {mhlo.sharding = "{replicated}"}, %arg14: tensor<6x0x288xf3

In [220]:
# Time the jitted forward pass for token 1 with the empty caches.
%timeit jfunc(1, weights, key_cache, value_cache)

2.99 ms ± 161 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
