Inference in float16 should take ??? hours (TODO: fill in the measured runtime).


Training notebook:

https://www.kaggle.com/code/pranshubahadur/tf-gemma-2-9b-lmsys-training-tpu

In [None]:
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now.
# These environment variables are set before `keras` is imported below.
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate the full accelerator memory fraction (1.0) to minimize memory
# fragmentation and allocation overhead.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"
import keras
import keras_nlp
# Create a device mesh with (1, 2) shape over the two listed GPUs: the "batch"
# axis has size 1 and the "model" axis has size 2, so weights mapped to the
# "model" axis are split across both GPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 2),
    ["batch", "model"],
    devices=['gpu:0', 'gpu:1'],
)
# Name of the mesh axis used for sharding weights in the layout map below.
model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Each key is a pattern matched against variable paths; each value lists, per
# tensor axis, the mesh axis to shard along (None = not sharded on that axis).

# Embedding tables: first axis sharded along the "model" mesh axis.
layout_map["token_embedding/embeddings"] = (model_dim, None)
layout_map["position_embedding/embeddings"] = (model_dim, None)

# Regex to match against the query, key and value matrices in attention layers
layout_map["decoder_block.*attention.*(query|key|value)/kernel"] = (model_dim, None, None)
layout_map["decoder_block.*attention_output/kernel"] = (model_dim, None, None)
# Feed-forward kernels: gating kernels shard their second axis, the linear
# kernel its first axis.
layout_map["decoder_block.*ffw_gating.*/kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear/kernel"] = (model_dim, None)

# Layer-norm parameters are 1-D and sharded along the "model" mesh axis too.
layout_map["decoder_block.*layer_norm/scale"] = (model_dim,)
layout_map["decoder_block.*layer_norm/bias"] = (model_dim,)
model_parallel = keras.distribution.ModelParallel(
    layout_map=layout_map,
    batch_dim_name="batch",
)

# Install the distribution globally so models created afterwards pick it up.
keras.distribution.set_distribution(model_parallel)


In [None]:
import jax
# NOTE(review): this *assigns* to the attribute `jax.default_device` rather than
# calling it. Per the JAX docs the default device is normally selected with
# `jax.config.update("jax_default_device", ...)` or `with jax.default_device(...)`,
# so this line presumably does not change device placement — confirm the intent
# before relying on it.
jax.default_device = jax.devices('cpu')[0]
# Cell output: the list of devices JAX reports.
jax.devices()

In [None]:
# Set the global Keras dtype policy to "float16" before the model is created.
keras.config.set_dtype_policy("float16")


In [None]:
# Load the Gemma causal LM from the local Kaggle preset directory
# (gemma2_instruct_9b_en), with the weights marked non-trainable.
# NOTE(review): dtype='int8' is passed here, while the dtype policy set via
# keras.config.set_dtype_policy is "float16" and the notebook header talks about
# float16 inference — confirm which precision is actually intended.
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("/kaggle/input/gemma2/keras/gemma2_instruct_9b_en/1", trainable=False, dtype='int8')

# Cell output: layer/parameter overview of the loaded model.
gemma_lm.summary()

In [None]:
def remove_surrogates(text):
    """Return `text` with every UTF-16 surrogate code point removed.

    Characters whose code point lies in the inclusive range U+D800..U+DFFF are
    dropped; all other characters are kept in their original order.
    """
    kept_chars = []
    for symbol in text:
        code_point = ord(symbol)
        # Keep anything strictly below or strictly above the surrogate range.
        if code_point < 0xD800 or code_point > 0xDFFF:
            kept_chars.append(symbol)
    return ''.join(kept_chars)


In [None]:
from pandas import read_csv, DataFrame

# Columns concatenated into the model's text input, and the three class columns
# used for the submission frame.
input_columns = ['prompt', 'response_a', 'response_b']
label_columns = ['winner_model_a', 'winner_model_b', 'winner_tie']

raw_test_dataset = read_csv('/kaggle/input/lmsys-chatbot-arena/test.csv')
# Disabled preprocessing steps, kept for reference:
#raw_test_dataset[input_columns] = raw_test_dataset[input_columns].map(lambda x: eval(x)[0])
#raw_test_dataset =raw_test_dataset.dropna().reset_index(drop=True)


def _format_example(joined_fields):
    """Prepend the PROMPT header to one joined row and strip surrogate characters."""
    return remove_surrogates('\n\nPROMPT\n\n' + joined_fields)


# Per row: prompt, response_a and response_b joined with the RESPONSE separator.
joined_fields = raw_test_dataset[input_columns].agg('\n\nRESPONSE:\n\n'.join, axis=1)

# NOTE: despite the name, `train_dataset` is built from the test CSV.
train_dataset = DataFrame({'text': joined_fields.apply(_format_example)})

In [None]:
# Use the public `preprocessor` property instead of the private `_preprocessor`
# attribute. The name `tokenizer` is kept because later cells refer to it; the
# object is the task's preprocessor, which is called on raw text.
tokenizer = gemma_lm.preprocessor
# The Gemma backbone, reused as the feature extractor of the classifier.
backbone = gemma_lm.backbone

In [None]:
def preprocess_fn(text, label=None):
    """Tokenize one text example into the two inputs the classifier consumes.

    Args:
        text: Raw text passed to the module-level `tokenizer` callable.
        label: Unused by the body; kept so the signature stays unchanged.

    Returns:
        A dict with only the `token_ids` and `padding_mask` entries taken from
        the first element of the tokenizer's output.
    """
    # The tokenizer is called with a fixed sequence_length of 512; only the
    # first element of its result is used.
    preprocessed = tokenizer(text, sequence_length=512)[0]
    # Ensure the preprocess function returns only the necessary inputs
    return {'token_ids' : preprocessed['token_ids'], 'padding_mask' : preprocessed['padding_mask']}

In [None]:
import tensorflow as tf
# `Input` and `Flatten` are imported here but not used in the visible notebook.
from keras.layers import Input, Dense, Flatten, GlobalAveragePooling1D
from keras import Model


# Build the 3-class classifier: Gemma backbone -> GlobalAveragePooling1D ->
# Dense(3, softmax). The layer creation order here determines the positions in
# `model.layers` that later cells index into.

# Symbolic inputs of length 512, the same sequence_length preprocess_fn passes
# to the tokenizer.
inputs = {
        "token_ids": keras.Input(shape=(512,), dtype=tf.int32, name="token_ids"),
        "padding_mask": keras.Input(shape=(512,), dtype=tf.int32, name="padding_mask"),
    }
x = backbone(inputs)
print(x.shape)
# NOTE(review): no mask is passed to the pooling layer, so padded positions
# presumably contribute to the average — confirm this matches the training notebook.
x = GlobalAveragePooling1D()(x)
print(x.shape)

# Classification head: one softmax probability per label column.
outputs = Dense(3, 'softmax')(x)
model = Model(inputs, outputs)

In [None]:

# Optimizer handed to `model.compile` in a later cell.
adamw_settings = {"learning_rate": 5e-5, "weight_decay": 0.01}
optimizer = keras.optimizers.AdamW(**adamw_settings)
# Exclude variables matched by the names "bias" and "scale" from weight decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])


In [None]:
# Use the loss from the same `keras` package (JAX backend) that built the model
# and optimizer; depending on the installed TensorFlow version, `tf.keras` can
# resolve to a different Keras implementation.
model.compile(optimizer, loss=keras.losses.CategoricalCrossentropy())

In [None]:
# Load the LoRA weights saved by the training notebook directly into the
# backbone object, rather than looking it up by position in `model.layers`,
# which silently depends on how the functional model orders its layers.
backbone.load_lora_weights("/kaggle/input/tf-gemma-2-9b-lmsys-training-tpu/model.lora.h5")

In [None]:
import numpy as np

# Read the saved kernel and bias of the classification head, in that order.
dense_1_weights, dense_1_biases = (
    np.load(weights_path)
    for weights_path in (
        '/kaggle/input/tf-gemma-2-9b-lmsys-training-tpu/dense_1_kernel.npy',
        '/kaggle/input/tf-gemma-2-9b-lmsys-training-tpu/dense_1_bias.npy',
    )
)
dense_1_combined = [dense_1_weights, dense_1_biases]

# The last layer of the model is the Dense softmax head.
classification_head = model.layers[-1]
classification_head.set_weights(dense_1_combined)


In [None]:
# Mark every layer of the classifier as non-trainable.
for frozen_layer in model.layers:
    frozen_layer.trainable = False

In [None]:
# Cell output: layer/parameter overview of the assembled classifier.
model.summary()

In [None]:
# Inference pipeline, one stage per line:
# raw strings -> tokenized feature dicts -> batches of 16.
text_values = train_dataset.text.values
ds = tf.data.Dataset.from_tensor_slices(text_values)
ds = ds.map(preprocess_fn)
ds = ds.batch(16)


In [None]:
from tqdm import tqdm

# Per-batch model outputs, stored as host NumPy arrays.
preds = []

# The loop variable is named `batch` so it does not overwrite the module-level
# `inputs` dict of Keras Input tensors that the model was built from.
for batch in tqdm(ds):
    # Reset Keras global state (and ask it to free memory) before each forward
    # pass. One call per iteration is enough: a second call at the end of an
    # iteration would be followed immediately by this one.
    keras.backend.clear_session(free_memory=True)
    # Convert straight to NumPy so the list does not hold on to backend arrays
    # for every batch processed so far.
    preds.append(keras.ops.convert_to_numpy(model(batch)))

# Final cleanup after the last batch.
keras.backend.clear_session()

    



In [None]:
import numpy as np

# Join the per-batch prediction arrays along the first (batch) axis.
results = np.concatenate(preds, axis=0)

In [None]:
import pandas

# One row per test example, indexed by the test-set `id`, with one column per
# entry of `label_columns`.
submission = pandas.DataFrame(
    results,
    columns=label_columns,
    index=raw_test_dataset.id,
)

In [None]:
# Write the `id` index as the first CSV column. With index=False the ids stored
# in the DataFrame index are dropped, leaving rows that cannot be matched back
# to their test examples.
submission.to_csv('submission.csv', index=True, index_label='id')

In [None]:
# Cell output: preview of the first rows of the submission frame.
submission.head()