In [None]:
!pip install -q -U keras-nlp
!pip install -q -U keras>=3

In [None]:
import os

# Backend configuration is done through environment variables, set here
# before the cell that imports `keras`.
os.environ.update(
    {
        # Run Keras on the JAX backend.
        "KERAS_BACKEND": "jax",
        # Fraction of device memory the XLA client may use.
        "XLA_PYTHON_CLIENT_MEM_FRACTION": "1.00",
    }
)

In [None]:
import keras
import keras_nlp

In [None]:
from datasets import load_dataset

# Identifier of the instruction dataset used for fine-tuning.
dataset_id = "abhinand/tamil-alpaca-orca"

# `ds` is indexed by split name further down (e.g. ds["train"]).
ds = load_dataset(dataset_id)

In [None]:
from datasets import DatasetDict

# Prompt layout shared by the training examples and the inference prompt
# built later in the notebook. It is defined once, outside the loop, so it
# exists even if no example passes the filter below.
template = "Instruction:\n{instruction}\n\nResponse:\n{output}"

# Upper bound on the number of training examples, to keep the run manageable.
MAX_EXAMPLES = 1000

# Access the 'train' split of the dataset.
train_data = ds["train"]

# Formatted training strings, one per kept example.
data = []

for example in train_data:
    # Keep only the simple cases: skip examples that carry a non-empty 'input'.
    if example["input"]:
        continue

    # `str.format` ignores the extra keys in `example`; only 'instruction'
    # and 'output' are substituted.
    data.append(template.format(**example))

    # Stop as soon as the cap is reached instead of formatting the whole
    # split and slicing afterwards; the resulting list is the same.
    if len(data) >= MAX_EXAMPLES:
        break

# Display the first few formatted examples (optional).
for i, example in enumerate(data[:5]):
    print(f"Example {i + 1}:\n{example}\n")


In [None]:
# Name of the KerasNLP preset the causal LM is built from.
preset_name = "gemma2_2b_en"

# Instantiate the Gemma causal language model from the preset and print
# its summary.
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(preset_name)
gemma_lm.summary()

In [None]:
# Rank of the LoRA adapters added to the backbone.
lora_rank = 4

# Switch the backbone to LoRA fine-tuning, then print the summary again so
# it can be compared with the one printed before LoRA was enabled.
gemma_lm.backbone.enable_lora(rank=lora_rank)
gemma_lm.summary()

In [None]:
# Cap the preprocessor's sequence length at 256 to control memory usage.
gemma_lm.preprocessor.sequence_length = 256

# AdamW optimizer; variables whose names match "bias" or "scale"
# (bias and layernorm terms) are excluded from weight decay.
optimizer = keras.optimizers.AdamW(learning_rate=5e-5, weight_decay=0.01)
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

# Loss on raw logits, plus accuracy reported as a weighted metric.
token_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
token_accuracy = keras.metrics.SparseCategoricalAccuracy()

gemma_lm.compile(
    loss=token_loss,
    optimizer=optimizer,
    weighted_metrics=[token_accuracy],
)

# One pass over the formatted examples, one example per batch.
gemma_lm.fit(data, epochs=1, batch_size=1)

In [None]:
# Build the inference prompt from the same template used for training.
# The response slot is left empty so the model writes the answer itself;
# filling it with the instruction text would hand the model a prompt that
# already contains a "response".
prompt = template.format(
    instruction="ஒரு கதை எழுது",
    output="",
)

# Top-k sampling (k=5) with a fixed seed.
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)

# Re-compile so that generation uses the sampler configured above.
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))