# Example Notebook: SFT with LoRA on Gemma 2 2B

## **Before You Begin: Connect to a TPU runtime. The free v2-8 TPU runtime will work for this example.**

# Install Kithara

In [None]:
# Install Kithara with its `tpu` extra.
#   -f <url>                 : extra page pip searches for package links
#                              (here, the libtpu releases page).
#   --extra-index-url <url>  : additional package index consulted alongside
#                              PyPI (here, the PyTorch CPU wheel index).
!pip install kithara[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html --extra-index-url https://download.pytorch.org/whl/cpu

# Colab specific set up
# Replace the installed torchvision with 0.19.0 from the PyTorch CPU wheel
# index, then pin flask to 2.1.3.
# NOTE(review): the reason for these exact pins is not stated in this
# notebook -- presumably they avoid version conflicts with packages that are
# preinstalled on Colab; confirm before changing them.
!pip uninstall torchvision -y && pip install torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cpu
!pip install flask==2.1.3

# Log in to Hugging Face with your access token


In [None]:
import os
# Select JAX as the Keras backend. This assignment is deliberately placed
# before `import keras`: Keras reads KERAS_BACKEND when it is first imported,
# so setting it afterwards would have no effect.
os.environ["KERAS_BACKEND"] = "jax"
import keras
import jax
from typing import Tuple
import ray
from ray.data import Dataset
from typing import Union, Optional, List
# Kithara building blocks used by the cells below (model wrappers, dataset and
# dataloader helpers, and the trainer).
from kithara import KerasHubModel, MaxTextModel, Dataloader, Trainer, PredefinedShardingStrategy, SFTDataset,PredefinedShardingStrategy
from transformers import AutoTokenizer

Installing MaxText... This should only happen once when Kithara is first initiated.
MaxText installed successfully
JAX compilation cached at /root/.keras/jax_cache
       '==='
        |||
     '- ||| -'
    /  |||||  \   Kithara. Platform: Linux. JAX: 0.5.2
   |   (|||)   |  Hardware: TPU v2. Device count: 8.
   |   |◕‿◕|   |  HBM Per Device: 7.48 GB. Total HBM Memory: 59.86 GB
    \  |||||  /   Free Apache license: http://github.com/ai-hypercomputer/kithara
     --|===|--


In [None]:
from getpass import getpass
import os

from huggingface_hub import login

# Authenticate with the Hugging Face Hub so the model weights can be downloaded.
#
# The token is never written into the notebook source (a saved or shared
# notebook would leak it). Instead it is taken from the HF_TOKEN environment
# variable when that is set, and otherwise requested through a hidden prompt.
hf_token = os.environ.get("HF_TOKEN", "").strip()
if not hf_token:
    hf_token = getpass("Hugging Face access token: ").strip()

# add_to_git_credential=False: do not store the token in the git credential helper.
login(token=hf_token, add_to_git_credential=False)

# Create Model

In [None]:
# Load Gemma 2 2B through Kithara's KerasHub wrapper. Passing `lora_rank`
# is where the LoRA rank for fine-tuning is chosen.
model_options = {
    "precision": "mixed_bfloat16",
    "lora_rank": 4,  # Specify LoRA Rank here
}
model = KerasHubModel.from_preset("hf://google/gemma-2-2b", **model_options)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/818 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/481M [00:00<?, ?B/s]

# Create Data


In [None]:
# Build a toy corpus: the same prompt/answer pair repeated 1000 times.
dataset_items = []
for _ in range(1000):
    dataset_items.append(
        {"prompt": "What is your name?", "answer": "My name is Kithara"}
    )
dataset = ray.data.from_items(dataset_items)

# Split into train/eval sets, requesting test_size=500 for the eval split.
train_ds, eval_ds = dataset.train_test_split(test_size=500)

# Create the SFT datasets. Both splits share the same tokenizer handle and
# maximum sequence length, so the common settings are declared once.
sft_options = {
    "tokenizer_handle": "hf://google/gemma-2-2b",
    "max_seq_len": 1024,
}
train_dataset = SFTDataset(train_ds, **sft_options)
eval_dataset = SFTDataset(eval_ds, **sft_options)

# Create the dataloaders, each with a per-device batch size of 1.
train_dataloader = Dataloader(train_dataset, per_device_batch_size=1)
eval_dataloader = Dataloader(eval_dataset, per_device_batch_size=1)


## Initialize trainer and start training

In [None]:
# AdamW optimizer used to update the trainable weights.
optimizer = keras.optimizers.AdamW(learning_rate=2e-4, weight_decay=0.01)

# Collect the trainer settings in one place, then build the trainer from them.
trainer_options = dict(
    model=model,
    optimizer=optimizer,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    steps=200,  # You can also use epochs instead of steps
    eval_steps_interval=50,
    max_eval_samples=50,
    log_steps_interval=10,
)
trainer = Trainer(**trainer_options)

# Kick off the fine-tuning run.
trainer.train()

       '==='
        |||
     '- ||| -'
    /  |||||  \   Kithara | Device Count = 8
   |   (|||)   |  Steps = 200 | Batch size per device = 1
   |   |◕‿◕|   |  Total batch size = 8 | Total parameters = 9.750(GB)
    \  |||||  /   Trainable parameters = 0.011(GB) (0.11%) | Non-trainable = 9.739(GB)
     --|===|--   
model <kithara.model.kerashub.keras_hub_model.KerasHubModel object at 0x7f8a4c02d150>
optimizer <keras.src.optimizers.adamw.AdamW object at 0x7f92901e89d0>
train_dataloader <kithara.dataset.dataloader.Dataloader object at 0x7f88d4180a50>
eval_dataloader <kithara.dataset.dataloader.Dataloader object at 0x7f929029af10>
steps 200
epochs None
tensorboard_dir None
step_count 0
epoch_count 0
eval_steps_interval 10
eval_epochs_interval None
max_eval_samples 50
log_steps_interval 10
global_batch_size 8
profiler None
checkpointer None
callbacks <keras.src.callbacks.callback_list.CallbackList object at 0x7f88d41a3310>
train_step <PjitFunction of <bound method Trainer._train_step of <

## Prompt the model

In [None]:
# Ask the tuned model the prompt used in the toy training data and request
# decoded text back (return_decoded=True) rather than token ids.
prompt = "What is your name?"
generation_options = {
    "max_length": 30,
    "tokenizer_handle": "hf://google/gemma-2-2b",
    "return_decoded": True,
}
pred = model.generate(prompt, **generation_options)
print("Tuned model generates:", pred)

Tuned model generates: ['What is your name?My name is Kithara']
