In [8]:
! pip install flax tensorflow tensorflow_datasets

  pid, fd = os.forkpty()


Defaulting to user installation because normal site-packages is not writeable
[0m

In [9]:
! pip install "jax[tpu]"

Defaulting to user installation because normal site-packages is not writeable
[0m

In [10]:
import jax
# Sanity check: list the accelerators JAX can see. The training cells reshape
# this list into an (FSDP, TENSOR) mesh, so its length must equal FSDP * TENSOR.
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]

## With Positional Embeddings

In [11]:
import tensorflow as tf
import tensorflow_datasets as tfds

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.training import train_state

import functools

import flax.linen.attention as attention

import numpy as np

import optax

import time



In [12]:
# ---- Batch / sequence shape ----
BATCH_IN_SEQUENCES = 384   # sequences per training batch
SEQUENCE_LENGTH = 128      # bytes per sequence; longer texts are truncated, shorter ones zero-padded

# ---- Model size ----
VOCAB_DIM = 256            # one token per possible byte value
EMBED_DIM = 512
FF_DIM = 2048

NUM_HEADS = 4
HEAD_DIM = 128
HEAD_DEPTH = HEAD_DIM      # kept for compatibility; the model below uses HEAD_DIM

LAYERS = 2

# ---- Optimization ----
LEARNING_RATE = 1e-3

# ---- Device mesh: FSDP * TENSOR must equal len(jax.devices()) ----
FSDP = 4 # 8 OR 4
TENSOR = 1 # 1 OR 2


def attention_ourselves(_Q, _K, _V, scale=1.0):
    """Causal multi-head dot-product attention.

    Args:
      _Q: queries, indexed [batch, query_pos, head, head_dim] ("BSHD").
      _K: keys, indexed [batch, key_pos, head, head_dim] ("BTHD").
      _V: values, indexed [batch, key_pos, head, head_dim] ("BTHD").
      scale: multiplier applied to the raw Q.K scores before masking and the
        softmax. The default 1.0 keeps the original unscaled behaviour;
        standard Transformer attention uses 1 / sqrt(head_dim).

    Returns:
      Array [batch, query_pos, head, head_dim]. Each query position gets a
      softmax-weighted average of the values at key positions <= its own.
    """
    scores = jnp.einsum("BSHD,BTHD->BHST", _Q, _K) * scale  # (B, H, S, T)
    # Build the causal mask from the actual input lengths rather than the
    # global SEQUENCE_LENGTH, so other sequence lengths work too.
    # Query i may attend to key j only when j <= i.
    q_len, k_len = _Q.shape[1], _K.shape[1]
    allowed = jnp.tril(jnp.ones((q_len, k_len), dtype=bool))
    # Masked scores become the most negative finite value, so after softmax
    # they contribute exactly 0. Every row keeps at least key 0 unmasked.
    scores = jnp.where(allowed, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("BHST,BTHD->BSHD", weights, _V)

class OurModel(nn.Module):
  """Byte-level causal language model with a learned positional embedding.

  Pipeline: token embedding + positional embedding, then LAYERS blocks of
  [x @ W_ff -> ReLU -> @ W_embed -> ReLU -> causal multi-head attention ->
  output projection], then logits via the transposed token embedding
  (input and output weights are tied).

  NOTE(review): the blocks have no residual connections and no
  normalization; each block's output projection overwrites `x`.

  Every parameter carries logical mesh-axis names ("fsdp", "tp") through
  nn.with_partitioning so nn.get_sharding can derive its sharding. Renaming
  or reordering the self.param calls changes the parameter tree (and
  possibly the initial values), so keep them stable.
  """
  @nn.compact
  def __call__(self, x):
    '''
        x is [BATCH, SEQUENCE]: integer token ids used to index the
        (VOCAB_DIM, EMBED_DIM) embedding table. SEQUENCE must match
        SEQUENCE_LENGTH because the positional embedding is sized by it.

        Returns raw logits of shape [BATCH, SEQUENCE, VOCAB_DIM].
    '''
    # Token embedding table; reused (transposed) as the output layer at the end.
    embedding = self.param(
        'embedding',
        nn.with_partitioning(nn.initializers.normal(1), ("tp", "fsdp")),
        (VOCAB_DIM, EMBED_DIM),
        jnp.float32,
    )
    x = embedding[x] ##OUTPUT should be [BATCH, SEQUENCE, EMBED]

    # Learned absolute positional embedding; the leading 1 broadcasts over batch.
    positional_embedding = self.param(
        'positional_embedding',
        nn.with_partitioning(nn.initializers.normal(1), (None, None, "fsdp")),
        (1, SEQUENCE_LENGTH, EMBED_DIM),
        jnp.float32,
    )

    x += positional_embedding


    for i in range(LAYERS):
      # Position-wise feed-forward: EMBED -> FF -> EMBED, ReLU after each matmul.
      feedforward = self.param(
          'feedforward_' + str(i),
          nn.with_partitioning(nn.initializers.lecun_normal(), ('fsdp', 'tp')),
          (EMBED_DIM, FF_DIM),
          jnp.float32,
      )
      x = x @ feedforward
      x = jax.nn.relu(x)
      embed = self.param(
          'embed_' + str(i),
          nn.with_partitioning(nn.initializers.lecun_normal(), ('tp', 'fsdp')),
          (FF_DIM, EMBED_DIM),
          jnp.float32,
      )
      x = x @ embed
      x = jax.nn.relu(x)

      # Per-head projections: [B, S, E] -> [B, S, NUM_HEADS, HEAD_DIM].
      q_proj = self.param(
          'qproj_' + str(i),
          nn.with_partitioning(nn.initializers.lecun_normal(), ('fsdp', 'tp')),
          (EMBED_DIM, NUM_HEADS, HEAD_DIM),
          jnp.float32,
      )
      q = jnp.einsum("BSE,EHD->BSHD",x, q_proj )

      k_proj = self.param(
          'kproj_' + str(i),
          nn.with_partitioning(nn.initializers.lecun_normal(), ('fsdp', 'tp')),
          (EMBED_DIM, NUM_HEADS, HEAD_DIM),
          jnp.float32,
      )
      k = jnp.einsum("BSE,EHD->BSHD",x, k_proj )

      v_proj = self.param(
          'vproj_' + str(i),
          nn.with_partitioning(nn.initializers.lecun_normal(), ('fsdp', 'tp')),
          (EMBED_DIM, NUM_HEADS, HEAD_DIM),
          jnp.float32,
      )
      v = jnp.einsum("BSE,EHD->BSHD",x, v_proj )

      # Causal self-attention: each position attends to itself and earlier ones.
      o = attention_ourselves(q,k,v)

      # Merge heads back to the embedding width: [B, S, H, D] -> [B, S, E].
      # The result replaces x (no residual add).
      o_proj = self.param(
          'oproj_' + str(i),
          nn.with_partitioning(nn.initializers.lecun_normal(), ('fsdp', 'tp')),
          (NUM_HEADS, HEAD_DIM, EMBED_DIM),
          jnp.float32,
      )
      x = jnp.einsum("BSHD,HDE->BSE",o, o_proj )

    # Tied output layer: [B, S, E] @ [E, VOCAB] -> raw logits. No softmax here;
    # calculate_loss passes these logits to optax.softmax_cross_entropy.
    return x @ embedding.T

def convert_to_ascii(string_array, max_length):
  """Encode a batch of byte strings as a zero-padded uint8 matrix.

  Args:
    string_array: sequence of `bytes` (e.g. the object array returned by
      `example['text'].numpy()`); `str` entries are UTF-8 encoded first.
    max_length: number of columns. Longer strings are truncated to this
      length; shorter ones are right-padded with 0.

  Returns:
    np.ndarray of shape (len(string_array), max_length), dtype uint8, where
    each entry is a raw byte value (so VOCAB_DIM = 256 covers every token).
  """
  result = np.zeros((len(string_array), max_length), dtype=np.uint8)
  for row, string in enumerate(string_array):
    encoded = string.encode("utf-8") if isinstance(string, str) else bytes(string)
    # Truncate at the caller's max_length. The previous version used the
    # global SEQUENCE_LENGTH here, which overflowed for shorter max_length.
    clipped = encoded[:max_length]
    result[row, :len(clipped)] = list(clipped)
  return result

def input_to_output(np_array):
    """Build model inputs from a [batch, seq] target matrix.

    Despite the name, this maps *targets* to *inputs*. Every row is shifted
    right by one position and position 0 is filled with 0. The model
    therefore sees targets[:, :t] when predicting targets[:, t].

    Args:
      np_array: 2-D integer array of target token ids.

    Returns:
      New uint8 array with the same shape as `np_array`; the input is not
      modified. The shape comes from the input rather than the globals, so a
      short final batch also works.
    """
    shifted = np.zeros(np_array.shape, dtype=np.uint8)
    shifted[:, 1:] = np_array[:, :-1]
    return shifted

def calculate_loss(params, model, inputs, outputs):
    """Mean next-token cross-entropy of `model` on one batch.

    Args:
      params: variables passed to `model.apply`.
      model: Flax module whose output is treated as logits over VOCAB_DIM ids.
      inputs: integer token ids fed to the model.
      outputs: integer target ids, same shape as `inputs`.

    Returns:
      Scalar cross-entropy averaged over every batch row and position.
      NOTE(review): convert_to_ascii zero-pads short texts, and those padding
      positions are included in the mean.
    """
    logits = model.apply(params, inputs)
    targets = jax.nn.one_hot(outputs, VOCAB_DIM)
    per_position = optax.softmax_cross_entropy(logits, targets)
    return per_position.mean()


def step(state, model, inputs, outputs):
    """Run one optimizer update.

    Returns:
      (loss, new_state): the loss at the pre-update params, and the
      TrainState after the gradients have been applied.
    """
    loss_and_grad_fn = jax.value_and_grad(calculate_loss)
    loss, grads = loss_and_grad_fn(state.params, model, inputs, outputs)
    new_state = state.apply_gradients(grads=grads)
    return loss, new_state


# Device mesh of shape (FSDP, TENSOR). The axis names must match the logical
# axes used in OurModel's nn.with_partitioning annotations.
mesh = jax.sharding.Mesh(np.reshape(jax.devices(), (FSDP, TENSOR)), ["fsdp", "tp"])

# drop_remainder keeps every batch at exactly BATCH_IN_SEQUENCES rows, matching
# the fixed shapes used to initialize and jit the model.
ds = tfds.load('lm1b', split='train', shuffle_files=False)
ds = ds.batch(BATCH_IN_SEQUENCES, drop_remainder=True)

rngkey = jax.random.key(0)
model = OurModel()

# Trace init abstractly to get the parameter tree, then turn the partitioning
# annotations into concrete shardings on `mesh`.
input_spec = jax.ShapeDtypeStruct((BATCH_IN_SEQUENCES, SEQUENCE_LENGTH), dtype=jnp.uint8)
shaped_init = jax.eval_shape(functools.partial(model.init, rngkey), input_spec)
state_sharding = nn.get_sharding(shaped_init, mesh)
# Initialize directly into the sharded layout. A concrete dummy batch is passed
# instead of relying on jit accepting an abstract ShapeDtypeStruct as a runtime
# argument. OurModel's params depend only on rngkey and fixed shapes.
dummy_inputs = jnp.zeros(input_spec.shape, dtype=input_spec.dtype)
_params = jax.jit(model.init, out_shardings=state_sharding)(rngkey, dummy_inputs)

tx = optax.adam(learning_rate=LEARNING_RATE)
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=_params,
    tx=tx,
)

# `model` is passed as a static (hashable) argument, so jit specializes on it.
static_step = jax.jit(step, static_argnums=1)

LOG_EVERY = 10
MAX_STEPS = 1000

last_step_time = time.time()
for stepnum, example in enumerate(ds, start=1):
    outputs = convert_to_ascii(example['text'].numpy(), SEQUENCE_LENGTH)
    inputs = input_to_output(outputs)  # targets shifted right by one

    loss, state = static_step(state, model, inputs, outputs)

    if stepnum % LOG_EVERY == 0:
        # Formatting `loss` waits for the asynchronously dispatched step, so
        # the elapsed time covers the last LOG_EVERY steps. The first window
        # presumably also includes jit compilation.
        new_time = time.time()
        time_elapsed_seconds = new_time - last_step_time
        last_step_time = new_time
        print(f"{stepnum - 1} -> {loss} {time_elapsed_seconds}")

    if stepnum == MAX_STEPS:
        break


9 -> 4.610074996948242 15.437712669372559
19 -> 3.3865103721618652 0.5307536125183105
29 -> 3.1280136108398438 0.5627996921539307
39 -> 2.9500465393066406 0.5636236667633057
49 -> 2.6155314445495605 0.5631210803985596
59 -> 2.5496368408203125 0.565514087677002
69 -> 2.5889205932617188 0.5621740818023682
79 -> 2.585242748260498 0.5648152828216553
89 -> 2.5686445236206055 0.5646486282348633
99 -> 2.53019380569458 0.5624356269836426
109 -> 2.8131237030029297 0.5647952556610107
119 -> 2.562844753265381 0.5635154247283936
129 -> 2.545961856842041 0.563058614730835
139 -> 2.5722155570983887 0.5657546520233154
149 -> 2.4941372871398926 0.5623538494110107
159 -> 2.4459774494171143 0.5646297931671143
169 -> 2.435457229614258 0.5642399787902832
179 -> 2.4510440826416016 0.5650138854980469
189 -> 2.4776196479797363 0.564751148223877
199 -> 2.3942651748657227 0.5634090900421143
209 -> 2.4091174602508545 0.5627098083496094
219 -> 2.4573934078216553 0.5660290718078613
229 -> 2.446984052658081 0.5634

In [None]:
import tensorflow as tf

def predict(input_str, model, params):
    """Run `model` on a single string and return its logits.

    The string is wrapped in a length-1 tf.string tensor, converted to a
    (1, SEQUENCE_LENGTH) zero-padded uint8 array by convert_to_ascii, and
    passed through `model.apply` with `params`.
    """
    batch_of_one = tf.constant([input_str], shape=(1,), dtype=tf.string).numpy()
    return model.apply(params, convert_to_ascii(batch_of_one, SEQUENCE_LENGTH))

# Order check: the same string twice should give identical logits, and a
# reordering should not. Caveat: the causal mask in attention_ourselves lets
# position 0 attend only to its own token, so "test a" and "a test" already
# differ at position 0 even without a positional embedding. This check alone
# therefore cannot isolate the positional embedding's effect.
logits_1, logits_2, logits_3 = (
    predict(text, model, state.params) for text in ("test a", "test a", "a test")
)
print(logits_1.shape)
print("logits 1 and 2 are the same:", jnp.array_equal(logits_1, logits_2))
print("logits 1 and 3 are NOT the same:", not jnp.array_equal(logits_1, logits_3))

(1, 128, 256)
logits 1 and 2 are the same: True
logits 1 and 3 are NOT the same: True


## Without Positional Embeddings: Rerun the Order Check

In [16]:
# ---- Batch / sequence shape ----
BATCH_IN_SEQUENCES = 384   # sequences per training batch
SEQUENCE_LENGTH = 128      # bytes per sequence; longer texts are truncated, shorter ones zero-padded

# ---- Model size ----
VOCAB_DIM = 256            # one token per possible byte value
EMBED_DIM = 512
FF_DIM = 2048
LAYERS = 2

NUM_HEADS = 4
HEAD_DIM = 128
HEAD_DEPTH = HEAD_DIM      # same value as HEAD_DIM; the model below uses HEAD_DIM

# ---- Optimization ----
LEARNING_RATE = 1e-3

# ---- Device mesh: FSDP * TENSOR must equal len(jax.devices()) ----
FSDP = 4
TENSOR = 1


def attention_ourselves(_Q, _K, _V, scale=1.0):
    """Causal multi-head dot-product attention.

    Args:
      _Q: queries, indexed [batch, query_pos, head, head_dim] ("BSHD").
      _K: keys, indexed [batch, key_pos, head, head_dim] ("BTHD").
      _V: values, indexed [batch, key_pos, head, head_dim] ("BTHD").
      scale: multiplier applied to the raw Q.K scores before masking and the
        softmax. The default 1.0 keeps the original unscaled behaviour;
        standard Transformer attention uses 1 / sqrt(head_dim).

    Returns:
      Array [batch, query_pos, head, head_dim]. Each query position gets a
      softmax-weighted average of the values at key positions <= its own.
    """
    scores = jnp.einsum("BSHD,BTHD->BHST", _Q, _K) * scale  # (B, H, S, T)
    # Build the causal mask from the actual input lengths rather than the
    # global SEQUENCE_LENGTH, so other sequence lengths work too.
    # Query i may attend to key j only when j <= i.
    q_len, k_len = _Q.shape[1], _K.shape[1]
    allowed = jnp.tril(jnp.ones((q_len, k_len), dtype=bool))
    # Masked scores become the most negative finite value, so after softmax
    # they contribute exactly 0. Every row keeps at least key 0 unmasked.
    scores = jnp.where(allowed, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("BHST,BTHD->BSHD", weights, _V)

class OurModel(nn.Module):
  """Byte-level causal language model WITHOUT a positional embedding.

  Identical to the earlier variant except that the positional embedding is
  disabled (commented out below). The only position-dependent operation left
  is the causal mask in attention_ourselves.

  Pipeline: token embedding, then LAYERS blocks of [x @ W_ff -> ReLU ->
  @ W_embed -> ReLU -> causal multi-head attention -> output projection],
  then logits via the transposed token embedding (tied weights).

  NOTE(review): the blocks have no residual connections and no
  normalization; each block's output projection overwrites `x`.

  Renaming or reordering the self.param calls changes the parameter tree
  (and possibly the initial values), so keep them stable.
  """
  @nn.compact
  def __call__(self, x):
    '''
        x is [BATCH, SEQUENCE]: integer token ids used to index the
        (VOCAB_DIM, EMBED_DIM) embedding table.
        NOTE(review): attention_ourselves may size its causal mask by
        SEQUENCE_LENGTH; confirm before using other sequence lengths.

        Returns raw logits of shape [BATCH, SEQUENCE, VOCAB_DIM].
    '''
    # Token embedding table; reused (transposed) as the output layer at the end.
    embedding = self.param(
        'embedding',
        nn.with_partitioning(nn.initializers.normal(1), ("tp", "fsdp")),
        (VOCAB_DIM, EMBED_DIM),
        jnp.float32,
    )
    x = embedding[x] ##OUTPUT should be [BATCH, SEQUENCE, EMBED]

    # Positional embedding intentionally disabled for this experiment.
    # positional_embedding = self.param(
    #     'positional_embedding',
    #     nn.with_partitioning(nn.initializers.normal(1), (None, None, "fsdp")),
    #     (1, SEQUENCE_LENGTH, EMBED_DIM),
    #     jnp.float32,
    # )

    # x += positional_embedding


    for i in range(LAYERS):
      # Position-wise feed-forward: EMBED -> FF -> EMBED, ReLU after each matmul.
      feedforward = self.param(
          'feedforward_' + str(i),
          nn.with_partitioning(nn.initializers.lecun_normal(), ('fsdp', 'tp')),
          (EMBED_DIM, FF_DIM),
          jnp.float32,
      )
      x = x @ feedforward
      x = jax.nn.relu(x)
      embed = self.param(
          'embed_' + str(i),
          nn.with_partitioning(nn.initializers.lecun_normal(), ('tp', 'fsdp')),
          (FF_DIM, EMBED_DIM),
          jnp.float32,
      )
      x = x @ embed
      x = jax.nn.relu(x)

      # Per-head projections: [B, S, E] -> [B, S, NUM_HEADS, HEAD_DIM].
      q_proj = self.param(
          'qproj_' + str(i),
          nn.with_partitioning(nn.initializers.lecun_normal(), ('fsdp', 'tp')),
          (EMBED_DIM, NUM_HEADS, HEAD_DIM),
          jnp.float32,
      )
      q = jnp.einsum("BSE,EHD->BSHD",x, q_proj )

      k_proj = self.param(
          'kproj_' + str(i),
          nn.with_partitioning(nn.initializers.lecun_normal(), ('fsdp', 'tp')),
          (EMBED_DIM, NUM_HEADS, HEAD_DIM),
          jnp.float32,
      )
      k = jnp.einsum("BSE,EHD->BSHD",x, k_proj )

      v_proj = self.param(
          'vproj_' + str(i),
          nn.with_partitioning(nn.initializers.lecun_normal(), ('fsdp', 'tp')),
          (EMBED_DIM, NUM_HEADS, HEAD_DIM),
          jnp.float32,
      )
      v = jnp.einsum("BSE,EHD->BSHD",x, v_proj )

      # Causal self-attention: each position attends to itself and earlier ones.
      o = attention_ourselves(q,k,v)

      # Merge heads back to the embedding width: [B, S, H, D] -> [B, S, E].
      # The result replaces x (no residual add).
      o_proj = self.param(
          'oproj_' + str(i),
          nn.with_partitioning(nn.initializers.lecun_normal(), ('fsdp', 'tp')),
          (NUM_HEADS, HEAD_DIM, EMBED_DIM),
          jnp.float32,
      )
      x = jnp.einsum("BSHD,HDE->BSE",o, o_proj )

    return x @ embedding.T # raw logits; no softmax needed here: calculate_loss passes them to optax.softmax_cross_entropy

def convert_to_ascii(string_array, max_length):
  """Encode a batch of byte strings as a zero-padded uint8 matrix.

  Args:
    string_array: sequence of `bytes` (e.g. the object array returned by
      `example['text'].numpy()`); `str` entries are UTF-8 encoded first.
    max_length: number of columns. Longer strings are truncated to this
      length; shorter ones are right-padded with 0.

  Returns:
    np.ndarray of shape (len(string_array), max_length), dtype uint8, where
    each entry is a raw byte value (so VOCAB_DIM = 256 covers every token).
  """
  result = np.zeros((len(string_array), max_length), dtype=np.uint8)
  for row, string in enumerate(string_array):
    encoded = string.encode("utf-8") if isinstance(string, str) else bytes(string)
    # Truncate at the caller's max_length. The previous version used the
    # global SEQUENCE_LENGTH here, which overflowed for shorter max_length.
    clipped = encoded[:max_length]
    result[row, :len(clipped)] = list(clipped)
  return result

def input_to_output(np_array):
    """Build model inputs from a [batch, seq] target matrix.

    Despite the name, this maps *targets* to *inputs*. Every row is shifted
    right by one position and position 0 is filled with 0. The model
    therefore sees targets[:, :t] when predicting targets[:, t].

    Args:
      np_array: 2-D integer array of target token ids.

    Returns:
      New uint8 array with the same shape as `np_array`; the input is not
      modified. The shape comes from the input rather than the globals, so a
      short final batch also works.
    """
    shifted = np.zeros(np_array.shape, dtype=np.uint8)
    shifted[:, 1:] = np_array[:, :-1]
    return shifted

def calculate_loss(params, model, inputs, outputs):
    """Mean next-token cross-entropy of `model` on one batch.

    Args:
      params: variables passed to `model.apply`.
      model: Flax module whose output is treated as logits over VOCAB_DIM ids.
      inputs: integer token ids fed to the model.
      outputs: integer target ids, same shape as `inputs`.

    Returns:
      Scalar cross-entropy averaged over every batch row and position.
      NOTE(review): convert_to_ascii zero-pads short texts, and those padding
      positions are included in the mean.
    """
    logits = model.apply(params, inputs)
    targets = jax.nn.one_hot(outputs, VOCAB_DIM)
    per_position = optax.softmax_cross_entropy(logits, targets)
    return per_position.mean()


def step(state, model, inputs, outputs):
    """Run one optimizer update.

    Returns:
      (loss, new_state): the loss at the pre-update params, and the
      TrainState after the gradients have been applied.
    """
    loss_and_grad_fn = jax.value_and_grad(calculate_loss)
    loss, grads = loss_and_grad_fn(state.params, model, inputs, outputs)
    new_state = state.apply_gradients(grads=grads)
    return loss, new_state


# Device mesh of shape (FSDP, TENSOR). The axis names must match the logical
# axes used in OurModel's nn.with_partitioning annotations.
mesh = jax.sharding.Mesh(np.reshape(jax.devices(), (FSDP, TENSOR)), ["fsdp", "tp"])

# drop_remainder keeps every batch at exactly BATCH_IN_SEQUENCES rows, matching
# the fixed shapes used to initialize and jit the model.
ds = tfds.load('lm1b', split='train', shuffle_files=False)
ds = ds.batch(BATCH_IN_SEQUENCES, drop_remainder=True)

rngkey = jax.random.key(0)
model = OurModel()

# Trace init abstractly to get the parameter tree, then turn the partitioning
# annotations into concrete shardings on `mesh`.
input_spec = jax.ShapeDtypeStruct((BATCH_IN_SEQUENCES, SEQUENCE_LENGTH), dtype=jnp.uint8)
shaped_init = jax.eval_shape(functools.partial(model.init, rngkey), input_spec)
state_sharding = nn.get_sharding(shaped_init, mesh)
# Initialize directly into the sharded layout. A concrete dummy batch is passed
# instead of relying on jit accepting an abstract ShapeDtypeStruct as a runtime
# argument. OurModel's params depend only on rngkey and fixed shapes.
dummy_inputs = jnp.zeros(input_spec.shape, dtype=input_spec.dtype)
_params = jax.jit(model.init, out_shardings=state_sharding)(rngkey, dummy_inputs)

tx = optax.adam(learning_rate=LEARNING_RATE)
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=_params,
    tx=tx,
)

# `model` is passed as a static (hashable) argument, so jit specializes on it.
static_step = jax.jit(step, static_argnums=1)

LOG_EVERY = 10
MAX_STEPS = 1000

last_step_time = time.time()
for stepnum, example in enumerate(ds, start=1):
    outputs = convert_to_ascii(example['text'].numpy(), SEQUENCE_LENGTH)
    inputs = input_to_output(outputs)  # targets shifted right by one

    loss, state = static_step(state, model, inputs, outputs)

    if stepnum % LOG_EVERY == 0:
        # Formatting `loss` waits for the asynchronously dispatched step, so
        # the elapsed time covers the last LOG_EVERY steps. The first window
        # presumably also includes jit compilation.
        new_time = time.time()
        time_elapsed_seconds = new_time - last_step_time
        last_step_time = new_time
        print(f"{stepnum - 1} -> {loss} {time_elapsed_seconds}")

    if stepnum == MAX_STEPS:
        break


9 -> 3.7439961433410645 14.990861177444458
19 -> 3.3444061279296875 0.524946928024292
29 -> 2.9133424758911133 0.5640227794647217
39 -> 2.644909381866455 0.5637977123260498
49 -> 2.5924177169799805 0.5632658004760742
59 -> 2.5294244289398193 0.5666470527648926
69 -> 2.569258451461792 0.5604920387268066
79 -> 2.562154531478882 0.5642075538635254
89 -> 2.514817237854004 0.5647234916687012
99 -> 2.442938804626465 0.5620982646942139
109 -> 2.5612828731536865 0.5638775825500488
119 -> 2.4192612171173096 0.5632226467132568
129 -> 2.3162553310394287 0.5659706592559814
139 -> 2.3402962684631348 0.563025951385498
149 -> 2.225848436355591 0.5639653205871582
159 -> 2.185987949371338 0.5623829364776611
169 -> 2.1876444816589355 0.5629620552062988
179 -> 2.204556465148926 0.5645332336425781
189 -> 2.2342259883880615 0.5638141632080078
199 -> 2.1401021480560303 0.5638136863708496
209 -> 2.1552047729492188 0.5641112327575684
219 -> 2.2246217727661133 0.5627484321594238
229 -> 2.161456823348999 0.5646

In [19]:
import tensorflow as tf

def predict(input_str, model, params):
    """Run `model` on a single string and return its logits.

    The string is wrapped in a length-1 tf.string tensor, converted to a
    (1, SEQUENCE_LENGTH) zero-padded uint8 array by convert_to_ascii, and
    passed through `model.apply` with `params`.
    """
    batch_of_one = tf.constant([input_str], shape=(1,), dtype=tf.string).numpy()
    return model.apply(params, convert_to_ascii(batch_of_one, SEQUENCE_LENGTH))

# Same order check with the positional embedding removed. The reordered
# string still produces different logits. The causal mask in
# attention_ourselves lets position 0 attend only to its own token, so
# "test a" and "a test" already differ at position 0. Removing the positional
# embedding therefore does not make this model order-invariant.
logits_1, logits_2, logits_3 = (
    predict(text, model, state.params) for text in ("test a", "test a", "a test")
)
print(logits_1.shape)
print("logits 1 and 2 are the same:", jnp.array_equal(logits_1, logits_2))
print("logits 1 and 3 are the same:", jnp.array_equal(logits_1, logits_3))

(1, 128, 256)
logits 1 and 2 are the same: True
logits 1 and 3 are the same: False


In [None]:
# Compare the model on hand-built token sequences (bypassing text encoding).

def pad(raw_input_1, length=None):
    """Right-pad with 0, or truncate, a [batch, seq] array to `length` columns.

    Args:
      raw_input_1: 2-D array of token ids.
      length: target number of columns. Defaults to SEQUENCE_LENGTH, read at
        call time.

    Returns:
      jnp array of shape (batch, length) with the same dtype as the input.
    """
    if length is None:
        length = SEQUENCE_LENGTH
    clipped = jnp.asarray(raw_input_1)[:, :length]
    missing = length - clipped.shape[1]  # >= 0 after clipping
    return jnp.pad(clipped, ((0, 0), (0, missing)), mode='constant', constant_values=0)


# Two 2-token sequences that are permutations of each other, zero-padded to
# SEQUENCE_LENGTH.
raw_input_1, raw_input_2 = (jnp.array([pair], dtype=jnp.int32) for pair in ([10, 5], [5, 10]))
input_1, input_2 = pad(raw_input_1), pad(raw_input_2)

def predict_tensor(ascii_tensor, model, params):
    """Apply `model` to an already-encoded token array and return its logits."""
    return model.apply(params, ascii_tensor)

# logits_1 vs logits_2: permuted inputs. Position 0 holds a different token and
# attends only to itself, so these are expected to differ.
# logits_2 vs logits_3: the identical input run twice (a determinism check).
logits_1, logits_2, logits_3 = (
    predict_tensor(tokens, model, state.params) for tokens in (input_1, input_2, input_2)
)
print("Logits shape:", logits_1.shape)
print("logits 1 and 2 are the same:", jnp.array_equal(logits_1, logits_2))
print("logits 2 and 3 are the same:", jnp.array_equal(logits_2, logits_3))
print(f"{logits_1[0][0][0]=}")
print(f"{logits_2[0][0][0]=}")

Logits shape: (1, 128, 256)
logits 1 and 2 are the same: False
logits 2 and 3 are the same: True
logits_1[0][0][0]=Array(8.40613, dtype=float32)
logits_2[0][0][0]=Array(1.3883572, dtype=float32)
