Use the latest environment version and a TPU accelerator (not a GPU).

In [None]:
!pip install -q -U keras-nlp tensorflow-text
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
!pip install -q -U tensorflow-cpu

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tf-keras 2.16.0 requires tensorflow<2.17,>=2.16, but you have tensorflow 2.17.0 which is incompatible.[0m[31m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [None]:
import jax

# Display the accelerator devices JAX can see; as the last expression of the
# cell, the returned list is rendered as the cell output.
jax.devices()

In [None]:
import os

# Runtime configuration, applied in one mapping update.
os.environ.update(
    {
        # The Keras 3 distribution API is only implemented for the JAX backend for now
        "KERAS_BACKEND": "jax",
        # Pre-allocate all TPU memory to minimize memory fragmentation and allocation overhead.
        "XLA_PYTHON_CLIENT_MEM_FRACTION": "1.0",
    }
)

In [None]:
import keras
import keras_nlp

In [None]:
# Build a 2D (batch, model) mesh sized from the accelerators that are actually
# attached. A hard-coded (3, 6) shape only works with exactly 18 devices, because
# the product of the mesh shape has to equal the number of devices supplied.
# All devices are placed on the "model" axis so weights are sharded across every
# device; the "batch" axis has size 1 (no data parallelism).
devices = keras.distribution.list_devices()
device_mesh = keras.distribution.DeviceMesh(
    (1, len(devices)),
    ["batch", "model"],
    devices=devices,
)

In [None]:
# Name of the mesh axis that weights are partitioned along.
model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Variable-path pattern -> per-axis placement. An entry of `model_dim` splits that
# weight axis over the "model" mesh axis; `None` leaves the axis unsplit.
_sharding_rules = {
    # Weights that match 'token_embedding/embeddings' are sharded on their first axis
    "token_embedding/embeddings": (model_dim, None),
    # Regex to match against the query, key and value matrices in attention layers
    "decoder_block.*attention.*(query|key|value)/kernel": (model_dim, None, None),
    # Attention output projection
    "decoder_block.*attention_output/kernel": (model_dim, None, None),
    # Feed-forward gating kernels are split on their second axis
    "decoder_block.*ffw_gating.*/kernel": (None, model_dim),
    # Feed-forward output projection is split on its first axis
    "decoder_block.*ffw_linear/kernel": (model_dim, None),
}
for _path_pattern, _axis_spec in _sharding_rules.items():
    layout_map[_path_pattern] = _axis_spec

In [None]:
def remove_surrogates(text):
    """Return ``text`` with every surrogate code point (U+D800 to U+DFFF) dropped.

    Args:
        text: an iterable of single characters, normally a ``str``.

    Returns:
        str: the remaining characters joined in their original order.
    """
    kept = []
    for ch in text:
        code_point = ord(ch)
        # Keep anything strictly outside the surrogate range.
        if code_point < 0xD800 or code_point > 0xDFFF:
            kept.append(ch)
    return ''.join(kept)


In [None]:
import pandas as pd

# Define the prompt template and EOS token.
# The template has two positional slots: the conversation text and the expected output.
alpaca_prompt = """Below is a conversation, identify the reason for the conversation that the client had to call

# Conversation:
{}

# Output:
{}"""

# `tokenizer` is not created in any cell of this notebook, so reading
# `tokenizer.eos_token` raised NameError on a fresh kernel. The model used here is
# Gemma, whose end-of-sequence token text is "<eos>", so the literal is used directly.
# NOTE(review): confirm this matches the tokenizer that was originally intended.
EOS_TOKEN = "<eos>"  # Must add EOS_TOKEN

# Define the formatting function
def formatting_prompts_func(df):
    """Render one prompt string per row of ``df``.

    Pairs ``df["call_transcript"]`` with ``df["primary_call_reason"]`` element-wise,
    fills each pair into ``alpaca_prompt`` and appends ``EOS_TOKEN``.

    Args:
        df: a mapping/frame exposing the two columns named above.

    Returns:
        list[str]: the formatted prompts, in row order.
    """
    transcripts = df["call_transcript"]
    reasons = df["primary_call_reason"]
    # Must add EOS_TOKEN, otherwise your generation will go on forever!
    return [
        alpaca_prompt.format(transcript, reason) + EOS_TOKEN
        for transcript, reason in zip(transcripts, reasons)
    ]

# Add the rendered prompts as a new 'formatted_text' column.
# NOTE(review): `df` is not defined in any cell visible above this one (the CSV is
# only read in a later cell) — confirm which frame is meant before running.
df['formatted_text'] = formatting_prompts_func(df)

# View the result
#df[['text', 'formatted_text']]

In [None]:
from pandas import read_csv, DataFrame


# Read the raw training table. The path is left blank and must be filled in.
raw_train_dataset = read_csv('') #filepath


# Assemble the training frame:
#   text  -> the input column(s) passed through remove_surrogates
#   label -> each row of the label columns collapsed into a plain Python list
train_dataset = DataFrame(dict(
    text=raw_train_dataset[input_columns].apply(remove_surrogates),
    label=raw_train_dataset[label_columns].apply(lambda row: row.values.tolist(), axis=1),
))

In [None]:
# Model-parallel strategy: weights follow `layout_map`, and "batch" names the mesh
# axis used for the data batch dimension.
model_parallel = keras.distribution.ModelParallel(layout_map=layout_map, batch_dim_name="batch")

# Install the strategy globally; this has to happen before the model is created.
keras.distribution.set_distribution(model_parallel)


In [None]:
#keras.config.set_floatx("bfloat16")

In [None]:
# Load the Gemma causal LM from a preset directory on disk, then print its layer/parameter summary.
# NOTE(review): the path is an absolute Kaggle input path — it only resolves when
# that dataset is attached to the session.
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("/kaggle/input/gemma2/keras/gemma2_instruct_9b_en/1")
gemma_lm.summary()

In [None]:
# Turn on LoRA adapters in the backbone with rank 16.
gemma_lm.backbone.enable_lora(rank=16)

In [None]:
# Freeze the first 16 top-level layers of the backbone.
# Uses the public `backbone` attribute (as the other cells do) rather than the
# private `_backbone` one.
# NOTE(review): this runs after LoRA was enabled, so presumably any adapters inside
# these layers stop training too — confirm that is intended.
for layer in gemma_lm.backbone.layers[:16]:
    layer.trainable = False

In [None]:
# Print the summary again to check the trainable parameter counts after the changes above.
gemma_lm.summary()

In [None]:
def preprocess_fn(text, label=None):
    """Tokenize ``text`` and pair the model inputs with a target.

    Runs the Gemma preprocessor with ``sequence_length=1024`` and takes the first
    element of its result, keeping only the 'token_ids' and 'padding_mask' entries.

    Args:
        text: raw text (or batch of text) handed to the preprocessor.
        label: optional target; when omitted, ``text`` itself is returned as the target.

    Returns:
        tuple: ``(inputs_dict, target)``.
    """
    features = gemma_lm._preprocessor(text, sequence_length=1024)[0]
    # Ensure the preprocess function returns only the necessary inputs
    model_inputs = {key: features[key] for key in ('token_ids', 'padding_mask')}
    if label is None:
        return model_inputs, text
    return model_inputs, label

In [None]:
import tensorflow as tf
from keras.layers import Input, Dense, Flatten, GlobalAveragePooling1D
from keras import Model

# Symbolic inputs: one int32 tensor of length 1024 per feature the backbone consumes.
inputs = {
    feature_name: keras.Input(shape=(1024,), dtype=tf.int32, name=feature_name)
    for feature_name in ("token_ids", "padding_mask")
}

# Per-token outputs from the Gemma backbone.
hidden_states = gemma_lm.backbone(inputs)
print(hidden_states.shape)

# Average over the sequence axis to get one vector per example.
# NOTE(review): no mask is passed, so padded positions presumably contribute to
# the average — confirm this is acceptable.
x = GlobalAveragePooling1D()(hidden_states)
print(x.shape)

# 54-way softmax classification head on top of the pooled vector.
outputs = Dense(54, activation='softmax')(x)
model = Model(inputs=inputs, outputs=outputs)

In [None]:
# AdamW with a small learning rate and decoupled weight decay.
optimizer = keras.optimizers.AdamW(learning_rate=5e-5, weight_decay=0.01)
# Variables whose names contain "bias" or "scale" are exempt from weight decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])


In [None]:
# Compile with categorical cross-entropy (targets are expected as one-hot /
# probability vectors matching the softmax head).
# The loss is taken from the standalone `keras` package, the same one used to build
# the model on the JAX backend, instead of reaching through `tf.keras`.
model.compile(optimizer=optimizer, loss=keras.losses.CategoricalCrossentropy())

In [None]:
import tensorflow as tf
# Pair each text with its label row, group into batches of 4, then tokenize per batch.
ds = (
    tf.data.Dataset.from_tensor_slices(
        (train_dataset.text.values, raw_train_dataset[label_columns].values)
    )
    .batch(4)
    .map(preprocess_fn)
)
# Shuffle the batches once with a full-size buffer. `reshuffle_each_iteration=False`
# keeps the order identical on every pass over `ds`; with the default (True) a
# later take()/skip() split would see a different order each time it is iterated,
# letting the same batches appear in both the training and validation parts.
ds = ds.shuffle(ds.cardinality(), reshuffle_each_iteration=False)


In [None]:
# 90/10 split by batch count. The validation part is simply "everything after the
# training batches", so no trailing batch is lost to rounding (taking
# int(len * 0.1) batches could silently drop one).
# NOTE: the two parts are only disjoint if `ds` yields batches in the same order
# on every pass.
num_train_batches = int(len(ds) * 0.9)
train_split = ds.take(num_train_batches)
val_split = ds.skip(num_train_batches)

# The datasets are already batched, so `batch_size` is not passed to fit(), and the
# validation dataset is handed over directly rather than wrapped in a list.
histories = model.fit(train_split, validation_data=val_split, epochs=1)

In [None]:
import numpy as np
# Fetch the classification head by its auto-generated layer name and split its
# weights into kernel and bias.
layer = model.get_layer(name='dense')
weights = layer.get_weights()
kernel, bias = weights

# Save the kernel and bias separately
np.save('dense_1_kernel.npy', kernel)
np.save('dense_1_bias.npy', bias)
# Save the LoRA adapter weights straight from the backbone. Looking it up as
# `model.layers[2]` depended on the wrapper model's layer ordering and would
# silently pick a different layer if that ordering ever changed.
gemma_lm.backbone.save_lora_weights("model.lora.h5")