# LMSYS Keras Gemma 2B

In [None]:
!pip install -q -U keras-nlp tensorflow-text
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
!pip install -q -U tensorflow-cpu

In [None]:
import jax

jax.devices()

In [None]:
import os

# Select the Keras backend through the environment. This has to be set before
# `import keras` runs, because the backend is chosen when Keras is imported.
# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate all TPU memory to minimize memory fragmentation and allocation overhead.
# NOTE(review): `jax.devices()` was already called in an earlier cell; confirm
# this XLA setting is still picked up when it is set after that call.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"

In [None]:
import keras
import keras_nlp

In [None]:
# Create a device mesh of shape (8, 1) over the axes ("batch", "model"):
# 8 devices along "batch" and 1 device along "model".
# NOTE(review): this comment used to describe a (1, 8) mesh that shards the
# weights across all 8 TPUs, which is the transpose of what the code builds.
# With a "model" axis of size 1, layouts that map a weight dimension to
# "model" presumably leave that weight unsplit — confirm which shape is
# intended.
device_mesh = keras.distribution.DeviceMesh(
    (8, 1),
    ["batch", "model"],
    devices=keras.distribution.list_devices(),
)

In [None]:
model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Variable-path regex -> one mesh-axis name per weight dimension.
# `model_dim` partitions that dimension along the mesh's "model" axis;
# `None` leaves the dimension unpartitioned.
_weight_layouts = {
    # Token embedding table.
    "token_embedding/embeddings": (model_dim, None),
    # Query, key and value projection kernels of the attention layers.
    "decoder_block.*attention.*(query|key|value)/kernel": (model_dim, None, None),
    # Attention output projection kernel.
    "decoder_block.*attention_output/kernel": (model_dim, None, None),
    # Feed-forward gating kernels.
    "decoder_block.*ffw_gating.*/kernel": (None, model_dim),
    # Feed-forward output (linear) kernel.
    "decoder_block.*ffw_linear/kernel": (model_dim, None),
}
for _path_regex, _axes in _weight_layouts.items():
    layout_map[_path_regex] = _axes

In [None]:
def remove_surrogates(text):
    """Return ``text`` with all surrogate code points (U+D800..U+DFFF) removed.

    Every other character is kept, in its original order.
    """
    kept = []
    for ch in text:
        code_point = ord(ch)
        # Keep anything strictly outside the surrogate range.
        if code_point < 0xD800 or code_point > 0xDFFF:
            kept.append(ch)
    return ''.join(kept)


In [None]:
import json

from pandas import read_csv, DataFrame

input_columns = ['prompt', 'response_a', 'response_b']
label_columns = ['winner_model_a', 'winner_model_b', 'winner_tie']


def _first_turn(cell):
    """Return the first element of a JSON-encoded list stored in a CSV cell.

    `json.loads` is used instead of `eval` so that cell contents are never
    executed as code, and so that a JSON `null` parses to `None` rather than
    having to be detected with a substring test (which also discarded any text
    that merely contained the word "null"). A `None` first element is removed
    by the `dropna()` call further down.
    """
    return json.loads(cell)[0]


def _build_prompt(row):
    """Format one row (prompt plus both responses) as a Gemma chat-template prompt."""
    # NOTE: the "B" header used to read 'B\n:'; it now matches the 'A:\n' and
    # 'C:\n' headers. Any inference code must build prompts with this same
    # template.
    return remove_surrogates(
        '<start_of_turn>user\nFind which one is the best answer for the question:\n'
        + row['prompt']
        + '\n\nA:\n' + row['response_a']
        + '\n\nB:\n' + row['response_b']
        + '\n\nC:\n both right (or) both wrong<end_of_turn>\n<start_of_turn>model\n'
    )


raw_train_dataset = read_csv('/kaggle/input/lmsys-chatbot-arena/train.csv')
# Each input column holds a JSON list of turns; only the first turn is kept.
raw_train_dataset[input_columns] = raw_train_dataset[input_columns].map(_first_turn)

# Drop rows with missing values, drop the model-name columns, and renumber rows.
raw_train_dataset = raw_train_dataset.dropna().drop(['model_a', 'model_b'], axis=1).reset_index(drop=True)


train_dataset = DataFrame({
    'text': raw_train_dataset[input_columns].apply(_build_prompt, axis=1),
    # One [winner_model_a, winner_model_b, winner_tie] list per row.
    'label': raw_train_dataset[label_columns].values.tolist(),
})

# Keep only the first 4000 rows of both frames so they stay aligned.
train_dataset = train_dataset[:4000]
raw_train_dataset = raw_train_dataset[:4000]

In [None]:
len(train_dataset)

In [None]:
# Build a ModelParallel distribution from the weight layout map; "batch" names
# the device-mesh axis used for the batch dimension of the input data.
model_parallel = keras.distribution.ModelParallel(
    layout_map=layout_map,
    batch_dim_name="batch",
)

# Install it as the global Keras distribution. This runs before the model is
# loaded in a later cell.
keras.distribution.set_distribution(model_parallel)

In [None]:
# Make float16 the default Keras float dtype before the model is created.
# NOTE(review): float16 has a narrow dynamic range and no loss scaling is set
# up in this notebook; bfloat16 is the usual choice on TPU — confirm float16
# is intentional.
keras.config.set_floatx("float16")

# Load the Gemma causal LM from the local Kaggle preset directory and print
# its layer/parameter summary.
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("/kaggle/input/gemma/keras/gemma_instruct_2b_en/2")
gemma_lm.summary()

In [None]:
gemma_lm.backbone.enable_lora(rank=8)


In [None]:
# for layer in gemma_lm.backbone.layers[:16]:
#     layer.trainable = False

In [None]:
gemma_lm.summary()

In [None]:
def preprocess_fn(text, label=None, sequence_length=3072):
    """Tokenize ``text`` with the Gemma preprocessor and keep only the model inputs.

    Args:
        text: Raw prompt string(s) to tokenize.
        label: Optional target passed through unchanged.
        sequence_length: Sequence length forwarded to the preprocessor.
            Defaults to 3072.

    Returns:
        ``(features, label)`` when ``label`` is given, otherwise ``features``
        alone. ``features`` is a dict with the ``'token_ids'`` and
        ``'padding_mask'`` entries taken from the first element of the
        preprocessor's output.
    """
    # Only the first element of the preprocessor output (the model inputs) is used.
    preprocessed = gemma_lm.preprocessor(text, sequence_length=sequence_length)[0]
    features = {
        'token_ids': preprocessed['token_ids'],
        'padding_mask': preprocessed['padding_mask'],
    }
    # The previous one-line `return a, label if ... else b` parsed as
    # `(a, (label if ... else b))`, so a missing label returned the features
    # twice. Branch explicitly instead.
    if label is None:
        return features
    return features, label

In [None]:
gemma_lm.layers[-1]

In [None]:
import gc
# Attempt to drop the last entry of `gemma_lm.layers` (displayed in the cell
# above).
# NOTE(review): in Keras 3 `Model.layers` looks like a property that builds a
# new list on every access; if so, this `del` only mutates a temporary list
# and the model itself is unchanged — confirm what this is meant to remove.
del gemma_lm.layers[-1]

# Run a garbage-collection pass to reclaim any objects that became unreachable.
gc.collect()

In [None]:
import tensorflow as tf
from keras.layers import Input, Dense, Flatten, GlobalAveragePooling1D
from keras import Model

# Two fixed-length (3072) int32 inputs, keyed and named the way the backbone
# is called below.
inputs = {
    feature: keras.Input(shape=(3072,), dtype=tf.int32, name=feature)
    for feature in ("token_ids", "padding_mask")
}

# Run the Gemma backbone on the symbolic inputs.
x = gemma_lm.backbone(inputs)
print(x.shape)

# Average over the sequence axis.
# NOTE(review): no mask is passed here, so padded positions appear to be
# averaged in together with real tokens — confirm whether the padding mask
# should be applied.
x = GlobalAveragePooling1D()(x)
print(x.shape)

# Three-way softmax head.
outputs = Dense(3, activation='softmax')(x)
model = Model(inputs=inputs, outputs=outputs)

In [None]:
# AdamW optimizer with a learning rate of 5e-5 and a weight-decay coefficient
# of 0.01.
optimizer = keras.optimizers.AdamW(
                learning_rate=5e-5,
                weight_decay=0.01,)
# Do not apply weight decay to variables whose names match "bias" or "scale".
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])


In [None]:
# Compile with categorical cross-entropy over the three-way softmax output.
# The loss is taken from the same `keras` package that built the model rather
# than from `tf.keras`, so model, optimizer and loss all come from one Keras
# installation regardless of what `tf.keras` resolves to.
model.compile(optimizer=optimizer, loss=keras.losses.CategoricalCrossentropy())

In [None]:
import tensorflow as tf


# Pair each prompt text with its label row, group into batches of 8, then
# tokenize each batch.
ds = (
    tf.data.Dataset.from_tensor_slices(
        (train_dataset.text.values, raw_train_dataset[label_columns].values)
    )
    .batch(8)
    .map(preprocess_fn)
)
# Shuffle whole batches once, with a fixed order. By default `shuffle`
# reshuffles every time the dataset is iterated, so datasets derived from `ds`
# with `take`/`skip` would each see a different order and could overlap;
# `reshuffle_each_iteration=False` (plus a seed) keeps such splits disjoint.
ds = ds.shuffle(ds.cardinality(), seed=42, reshuffle_each_iteration=False)


In [None]:
# 90/10 split, counted in batches because `ds` is already batched.
num_train_batches = int(len(ds) * 0.9)
train_split = ds.take(num_train_batches)
val_split = ds.skip(num_train_batches)

# `batch_size` is not passed: the datasets already yield batches, and the
# argument only applies to array-like inputs. The validation dataset is passed
# directly instead of being wrapped in a list.
histories = model.fit(train_split, validation_data=val_split, epochs=1)

In [None]:
# Look up the backbone inside the classifier by its layer name and save its
# LoRA weights to the Kaggle working directory.
model.get_layer("gemma_backbone").save_lora_weights('/kaggle/working/lora19.lora.h5')