# *Kithara* - Finetune LLMs on TPU and GPU

## Overview

1. Introduction to Kithara
2. Kithara Demo

## Introduction to Kithara

Kithara will be an accelerator-agnostic, lightweight library offering tools and recipes for tuning popular open source LLMs on TPUs and GPUs. 

go/kithara-dd

go/kithara-slides  

go/kithara-design-review-recording 

## Kithara Demo 

The goal of this demo is to show the following key features of Kithara. 

1. Native integration with HuggingFace: Load and save models in HuggingFace format

2. LoRA support

3. Easy scaling from single-host to multihost workloads

4. GPU/TPU fungibility: the same code runs on both GPUs and TPUs

5. Extensive dataset format support

6. Smart defaults for model optimizations and parallelism

7. Support for tuning MaxText models 


### This demo is currently running on v4-8 (single-host)

In [3]:
import os
os.environ["KERAS_BACKEND"] = "jax"

import keras
devices = keras.distribution.list_devices()
print(f"Available devices: {devices}")

import jax 
jax.config.update("jax_compilation_cache_dir", "/dev/shm/temp/xla_cache")

Available devices: ['tpu:0', 'tpu:1', 'tpu:2', 'tpu:3']


#### Imports

In [1]:
import ray
import keras_tuner as kithara



### Load model from HuggingFace Hub, Enable LoRA

In [2]:
model_handle = "hf://google/gemma-2-2b"

In [None]:
model = kithara.KerasHubModel(
    model_handle=model_handle,
    precision="mixed_bfloat16",
    lora_rank=4,
    # Predefined Sharding Strategy 
    sharding_strategy=kithara.ShardingStrategy(
        parallelism="fsdp", model="gemma"
    ),
    # Flash Attention is activated by default
    use_flash_attention=True,
)
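
To make `lora_rank=4` concrete, here is a minimal sketch of the general LoRA idea in plain Python. It is not Kithara's implementation: the layer sizes, the `matmul` and `lora_forward` helpers, and the `scale` parameter are hypothetical and chosen only for illustration. The frozen weight `w` is left untouched, and only the two small matrices `a` and `b` would be trained.

```python
def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def lora_forward(x, w, a, b, scale=1.0):
    """Frozen base projection plus a low-rank update: x @ w + scale * (x @ a @ b)."""
    base = matmul(x, w)
    update = matmul(matmul(x, a), b)
    return [[p + scale * q for p, q in zip(r1, r2)] for r1, r2 in zip(base, update)]


# Hypothetical layer sizes; these are NOT Gemma's real dimensions.
d_in, d_out, rank = 64, 32, 4

w = [[0.1 * (i + j) for j in range(d_out)] for i in range(d_in)]  # frozen weights
a = [[0.01 * (i - j) for j in range(rank)] for i in range(d_in)]  # trainable, d_in x rank
b = [[0.0] * d_out for _ in range(rank)]                          # trainable, zero-initialised
x = [[1.0] * d_in]                                                # one input row

full_params = d_in * d_out           # weights a full fine-tune would update
lora_params = rank * (d_in + d_out)  # weights LoRA updates instead

print(full_params, lora_params)
# With b initialised to zero, the adapted layer starts out identical to the base layer.
print(lora_forward(x, w, a, b) == matmul(x, w))
```

For this toy layer, the trainable parameter count drops from 2048 to 384. The saving grows with layer size because the LoRA count scales with `d_in + d_out` rather than `d_in * d_out`.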

### Load Dataset

Kithara will support an extensive list of datasets and dataset formats, including HuggingFace, CSV, JSON, JSONL, and more. 

*Features:* 
- Streaming Dataset
- Multihost distributed dataloading
- Integration with Cloud: GCS, Azure, AWS
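
As a rough illustration of how streaming and multihost dataloading fit together, the sketch below splits one record stream across hosts without loading it into memory. It is an assumption about the general technique, not a description of how Kithara or Ray implement it; `shard_stream`, `batched`, and `stream` are hypothetical helpers.

```python
from itertools import islice


def shard_stream(records, host_index, num_hosts):
    """Lazily keep every num_hosts-th record, starting at host_index."""
    return islice(records, host_index, None, num_hosts)


def batched(records, batch_size):
    """Group a stream into lists of batch_size, dropping a trailing partial batch."""
    it = iter(records)
    while True:
        batch = list(islice(it, batch_size))
        if len(batch) < batch_size:
            return
        yield batch


def stream():
    """Stand-in for a streaming dataset; records are produced one at a time."""
    for i in range(10):
        yield {"id": i}


num_hosts = 2
host_batches = [
    list(batched(shard_stream(stream(), host, num_hosts), batch_size=2))
    for host in range(num_hosts)
]

for host, batches in enumerate(host_batches):
    print(host, [[record["id"] for record in batch] for batch in batches])
```

Each host sees a disjoint slice of the stream, so no record is processed twice, and every host produces the same number of full batches per pass.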


#### Data source: HuggingFace

In [11]:
"""Load the C4 dataset from HuggingFace. Load in streaming mode. """
from datasets import load_dataset
train_ds = load_dataset("allenai/c4", "en", split="train", streaming=True)
test_ds = load_dataset("allenai/c4", "en", split="validation", streaming=True)
train_ds, test_ds = ray.data.from_huggingface(train_ds), ray.data.from_huggingface(test_ds)

#### Data source: CSV

In [None]:
"""Load CSV dataset from Cloud"""
ds = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
train_ds, test_ds = ds.train_test_split(test_size=50)

#### Data source: TFRecords

In [None]:
"""Load TFRecords from Cloud"""
ds = ray.data.read_tfrecords("s3://anonymous@ray-example-data/iris.tfrecords")
train_ds, test_ds = ds.train_test_split(test_size=50)

#### Data source: Python Dict

In [None]:
"""Create a toy dataset from a Python dictionary for the demo."""

# We want to teach the model to answer a question with a specific response.
dataset_items = [
    {
    "prompt": "What is your name?",
    "answer": "My name is Kithara"
    } 
]* 1000

ds = ray.data.from_items(dataset_items)
train_ds, test_ds = ds.train_test_split(test_size=50)

#### Other supported data formats include JSON, Parquet, BigQuery, MongoDB, Spark, TFDS, Torch.data and more

 ----------

### Preprocess Data 

In [8]:
# Create the preprocessor
preprocessor = kithara.preprocessor.SFTPreprocessor(
    tokenizer_handle=model_handle, seq_len=2048
)

# Create data loaders
train_dataloader = kithara.Dataloader(
    train_ds,
    per_device_batch_size=1,
)
eval_dataloader = kithara.Dataloader(
    test_ds,
    per_device_batch_size=1,
)
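
The sketch below shows what a supervised fine-tuning preprocessor typically does with a `prompt`/`answer` record like the toy dataset above: tokenize both parts, pad to `seq_len`, and mask the prompt and padding out of the loss. This is a common convention and only an assumption about `SFTPreprocessor`; `build_sft_example` and the whitespace `toy_tokenize` are hypothetical stand-ins, not Kithara APIs.

```python
def build_sft_example(prompt, answer, tokenize, seq_len, pad_id=0):
    """Concatenate prompt and answer tokens, pad to seq_len, and mark answer tokens for the loss."""
    prompt_ids = tokenize(prompt)
    answer_ids = tokenize(answer)
    token_ids = (prompt_ids + answer_ids)[:seq_len]
    # 1 marks positions that contribute to the loss (answer tokens only).
    loss_mask = ([0] * len(prompt_ids) + [1] * len(answer_ids))[:seq_len]
    padding = seq_len - len(token_ids)
    return {
        "token_ids": token_ids + [pad_id] * padding,
        "loss_mask": loss_mask + [0] * padding,
    }


vocab = {}


def toy_tokenize(text):
    """Hypothetical whitespace tokenizer: ids start at 1, in order of first appearance."""
    return [vocab.setdefault(word, len(vocab) + 1) for word in text.split()]


example = build_sft_example(
    "What is your name?", "My name is Kithara", toy_tokenize, seq_len=10
)
print(example["token_ids"])
print(example["loss_mask"])
```

Masking the prompt means the model is only penalized for getting the answer wrong, which is usually what you want when teaching a specific response.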

### Run SFT 

In [None]:
rm -rf /tmp/demo

In [None]:
# Initialize SFT trainer
trainer = kithara.SFTTrainer(
    model=model,
    preprocessor=preprocessor,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    steps=150,
    eval_steps_interval=50,
    log_steps_interval=1,
    tensorboard_dir="/tmp/demo",
    optimizer=keras.optimizers.AdamW(learning_rate=5e-5, weight_decay=0.01),
)

# Start training
trainer.train()

# Test after tuning
pred = trainer.generate("What is your name?")
print("Tuned model generates:", pred)
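
To see how `steps`, `eval_steps_interval`, and `log_steps_interval` interact, here is a small sketch of the schedule they imply. It assumes steps are counted from 1 and that an event fires whenever the step number is a multiple of its interval; the trainer's actual behaviour may differ, and `plan_schedule` is a hypothetical helper.

```python
def plan_schedule(steps, eval_steps_interval, log_steps_interval):
    """Return the steps at which evaluation and logging would fire (assumed 1-indexed)."""
    eval_at = [s for s in range(1, steps + 1) if s % eval_steps_interval == 0]
    log_at = [s for s in range(1, steps + 1) if s % log_steps_interval == 0]
    return eval_at, log_at


eval_at, log_at = plan_schedule(steps=150, eval_steps_interval=50, log_steps_interval=1)
print(eval_at)      # [50, 100, 150]
print(len(log_at))  # 150
```

Under these assumptions, the configuration above evaluates three times and logs after every training step.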

In [None]:
# Launch TensorBoard (shell command; prefix with `!` when running inside a notebook cell)
tensorboard --logdir=/tmp/demo

#### Save model in HuggingFace format

Note: this feature is not fully designed yet. For now, model conversion relies on a CLI script.

In [10]:
# Step 1: Save the model weights
trainer.model.save_weights("/dev/shm/temp/tuned_gemma2_2b.weights.h5")

In [None]:
# Step 2: Convert the model weights to HF format
python keras-hub/tools/gemma/export_gemma_to_hf.py \
  --weights_file /dev/shm/temp/tuned_gemma2_2b.weights.h5 \
  --gemma_version 2 \
  --size 2b \
  --output_dir /dev/shm/temp/tuned_model_in_hf_format

#### Load model in HF, or vLLM

In [30]:
# Loading checkpoint into a HuggingFace model
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("/dev/shm/temp/tuned_model_in_hf_format")


In [None]:
# Load the checkpoint into vLLM

from vllm import LLM

llm = LLM(model="/dev/shm/temp/tuned_model_in_hf_format")  # Checkpoint tuned with Kithara
outputs = llm.generate("What is your name?")
print(outputs[0].outputs[0].text)  # generate() returns one RequestOutput per prompt

#### Tune a MaxText model

MaxText model implementations offer best-in-class performance on TPUs. Kithara supports tuning the models available in MaxText.

In [None]:
model = kithara.MaxTextModel(
    model_name="gemma2-9b",
    seq_len=4096,
    per_device_batch_size=1,
)

### Multihost Example

In the previous section, we tuned a Gemma2-2b model with LoRA on a single-host v4-8 machine. This section shows how to scale up and tune a Gemma2-9b model with LoRA on a multihost v4-32 machine.
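
A quick back-of-the-envelope sketch of what scaling out changes when `per_device_batch_size` stays at 1. The v4-8 device count (4) comes from the device listing earlier in this demo. The v4-32 device count (16) is an assumption, as is reusing `seq_len=2048` from the preprocessor. Both helper functions are hypothetical.

```python
def global_batch_size(per_device_batch_size, num_devices):
    """Number of sequences processed per training step across all devices."""
    return per_device_batch_size * num_devices


def tokens_per_step(per_device_batch_size, num_devices, seq_len):
    """Upper bound on tokens per step, assuming every sequence is filled to seq_len."""
    return global_batch_size(per_device_batch_size, num_devices) * seq_len


# v4-8: 4 devices, as listed by keras.distribution.list_devices() earlier in this demo.
single_host = tokens_per_step(per_device_batch_size=1, num_devices=4, seq_len=2048)
# v4-32: assumed to expose 16 devices in total across its hosts.
multi_host = tokens_per_step(per_device_batch_size=1, num_devices=16, seq_len=2048)

print(single_host)  # 8192
print(multi_host)   # 32768
```

Because the global batch grows with the device count, hyperparameters tuned on a single host (such as the learning rate) may need revisiting after scaling out.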

Kithara offers an orchestration layer built on Ray, which works with resources provisioned through GCE, GKE, XPK, and QRs.

Note that the core Kithara library works with any orchestrator. The Ray orchestration layer is provided for users who are not familiar with multihost development.

#### Step 1: Launch a Ray Cluster

A Ray cluster is a group of machines (with CPUs, TPUs, or GPUs) made up of a host machine and worker machines. The host machine is responsible for scheduling jobs onto the worker machines.

In [None]:
# Launch a TPU cluster from the provided YAML file using the Ray CLI
ray up -y orchestration/multihost/ray/TPU/cluster.yaml

# Or launch a GPU cluster (path assumed to mirror the TPU layout; the original repeated the TPU path)
ray up -y orchestration/multihost/ray/GPU/cluster.yaml

# You can also launch a cluster with both TPUs and GPUs

#### Step 2: Launch the Ray dashboard

The following command prints a link for opening the Ray dashboard on localhost.

In [None]:
ray dashboard orchestration/multihost/ray/TPU/cluster.yaml

#### Submit a multihost job on the TPU Ray cluster

In [None]:
python orchestration/multihost/ray/submit_ray_job.py "python examples/multihost/ray/TPU/hf_sft_gemma_example_via_ray.py" --hf-token your_token

#### Submit a multihost job on the GPU Ray cluster

In [None]:
python orchestration/multihost/ray/submit_ray_job.py "python examples/multihost/ray/GPU/hf_gemma_example_via_ray.py" --hf-token your_token