# Fine-Tune LLM with Data Cache

This notebook demonstrates how to fine-tune a DistilBERT model with a Kubeflow TrainJob that streams its training data from a distributed cache.

We will use an Apache Iceberg table containing [the Amazon reviews dataset](https://huggingface.co/datasets/fancyzhx/amazon_polarity), served by a data cache cluster that provides efficient data streaming for distributed training workloads.

## Install Kubeflow SDK

In [None]:
# Install the Kubeflow SDK
!pip install -U kubeflow

## List Available Training Runtimes

Get the available Kubeflow Training Runtimes using the `list_runtimes()` API.

In [None]:
from kubeflow.trainer import TrainerClient

# Client for the Kubeflow Trainer API; later cells reuse `client`.
client = TrainerClient()

# Show every training runtime installed in the cluster, one per paragraph.
available_runtimes = client.list_runtimes()
for rt in available_runtimes:
    print(f'{rt}\n')

## Define Training Function with DataCacheDataset

This training function uses DataCacheDataset to stream data from the cache cluster.

In [None]:
def train_func():
    """Fine-tune DistilBERT on Amazon reviews streamed from the data cache.

    Runs in every training process. Imports and helper classes live inside
    this function so it stays self-contained (it is handed to CustomTrainer
    and presumably executed on the training pods in isolation).

    Environment variables read:
        TRAIN_JOB_NAME: used to build the cache endpoint
            ``grpc://<TRAIN_JOB_NAME>-cache-service:50051``.
        LOCAL_RANK: index of this process on its node (defaults to 0).
    """
    import os
    import time

    import pyarrow as pa
    import pyarrow.flight
    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel
    from torch.utils.data import IterableDataset
    from transformers import DistilBertForSequenceClassification, DistilBertTokenizer

    class DataCacheDataset(IterableDataset):
        """Stream fixed-size RecordBatch chunks from the cache's Arrow Flight service.

        Each (rank, DataLoader worker) pair requests its own shard from the
        cache; subclasses convert each chunk into training samples by
        implementing ``from_arrow_rb_to_tensor``.
        """

        def __init__(self, seed=0, batch_size=128):
            train_job_name = os.getenv("TRAIN_JOB_NAME")
            self.endpoint = f'grpc://{train_job_name}-cache-service:50051'
            # Remember the base seed so set_epoch() does not accumulate offsets.
            self._base_seed = seed
            self.seed = seed
            self.batch_size = batch_size

        def set_epoch(self, epoch):
            """Rotate the shard assignment for ``epoch`` relative to the base seed."""
            self.seed = self._base_seed + epoch

        def from_arrow_rb_to_tensor(self, chunk):
            """Yield training samples built from the RecordBatch ``chunk``.

            Raises:
                NotImplementedError: always; subclasses must override this.
            """
            raise NotImplementedError("subclasses must implement from_arrow_rb_to_tensor")

        def stream_recordbatch_chunks(self, recordbatch_stream, chunk_size):
            """Re-slice a Flight stream into chunks of exactly ``chunk_size`` rows.

            Rows left at the end of one RecordBatch are prepended to the next
            one, so only the very last chunk of the stream may be shorter.
            """
            leftover = None
            for rb in recordbatch_stream:
                batch = rb.data
                if leftover is not None:
                    arrays = [
                        pa.concat_arrays([leftover.column(i), batch.column(i)])
                        for i in range(batch.num_columns)
                    ]
                    batch = pa.RecordBatch.from_arrays(arrays, batch.schema.names)
                    leftover = None

                for start in range(0, batch.num_rows, chunk_size):
                    chunk = batch.slice(start, chunk_size)
                    if chunk.num_rows < chunk_size:
                        # Partial tail: carry it over into the next RecordBatch.
                        leftover = chunk
                        break
                    yield from self.from_arrow_rb_to_tensor(chunk)

            if leftover is not None and leftover.num_rows > 0:
                yield from self.from_arrow_rb_to_tensor(leftover)

        def __iter__(self):
            worker_info = torch.utils.data.get_worker_info()
            if worker_info is None:
                worker_id, num_workers = 0, 1
            else:
                worker_id, num_workers = worker_info.id, worker_info.num_workers

            world_size = dist.get_world_size()
            rank = dist.get_rank()

            # Give every (DataLoader worker, rank) pair a distinct shard index,
            # rotated by the per-epoch seed.
            total_workers = num_workers * world_size
            global_worker_id = worker_id * world_size + rank
            index = (global_worker_id + self.seed) % total_workers

            descriptor = pa.flight.FlightDescriptor.for_path(str(index), str(total_workers))
            with pa.flight.connect(self.endpoint) as client:
                flight_info = client.get_flight_info(descriptor)
                for endpoint in flight_info.endpoints:
                    # The locations of one endpoint are alternatives serving the
                    # same data, so read from only one of them. An empty list
                    # means the ticket is redeemed on the issuing service.
                    if endpoint.locations:
                        with pa.flight.connect(endpoint.locations[0]) as data_client:
                            reader = data_client.do_get(endpoint.ticket)
                            yield from self.stream_recordbatch_chunks(reader, self.batch_size)
                    else:
                        reader = client.do_get(endpoint.ticket)
                        yield from self.stream_recordbatch_chunks(reader, self.batch_size)

    class AmazonReviewDataset(DataCacheDataset):
        """Tokenize the ``content`` column and pair it with the ``label`` column."""

        def __init__(self, tokenizer, seed=0):
            super().__init__(seed)
            self.tokenizer = tokenizer
            self.row_count = 0  # rows consumed by this process

        def from_arrow_rb_to_tensor(self, chunk):
            self.row_count += chunk.num_rows
            texts = chunk.column("content").to_pylist()
            labels = chunk.column("label").to_pylist()

            encoding = self.tokenizer(
                texts,
                padding="max_length",
                truncation=True,
                max_length=128,
                return_tensors="pt",
            )

            # Tokenizing a list of texts already yields batch-first tensors; do
            # not squeeze, or a 1-row chunk would lose its batch dimension.
            yield {
                "input_ids": encoding["input_ids"],
                "attention_mask": encoding["attention_mask"],
                "labels": torch.tensor(labels),
            }

    use_cuda = torch.cuda.is_available()
    device, backend = ("cuda", "nccl") if use_cuda else ("cpu", "gloo")
    print(f"Using Device: {device}, Backend: {backend}")

    local_rank = int(os.getenv("LOCAL_RANK", 0))
    if use_cuda:
        # Bind this process to its GPU before NCCL initialisation.
        torch.cuda.set_device(local_rank)
    dist.init_process_group(backend=backend)
    print(
        "Distributed Training for WORLD_SIZE: {}, RANK: {}, LOCAL_RANK: {}".format(
            dist.get_world_size(),
            dist.get_rank(),
            local_rank,
        )
    )
    device = torch.device(f"cuda:{local_rank}") if use_cuda else torch.device("cpu")

    # NOTE(review): amazon_polarity labels are binary; num_labels=5 works but
    # only two logits are ever trained - confirm against the cached table.
    model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=5)
    model.to(device)
    # Wrap in DDP so gradients are averaged across ranks; without it every rank
    # would train an independent copy of the model.
    model = DistributedDataParallel(model, device_ids=[local_rank] if use_cuda else None)
    # from_pretrained() returns the model in eval mode (dropout disabled).
    model.train()
    # torch.optim.AdamW replaces the deprecated transformers.AdamW; eps and
    # weight_decay keep the old transformers defaults.
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, eps=1e-6, weight_decay=0.0)

    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

    start = time.time()
    dataset = AmazonReviewDataset(tokenizer)
    # The dataset already yields whole batches, so disable automatic batching.
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=None)

    # Training loop
    for epoch in range(1):
        dataset.set_epoch(epoch)
        # Shards may hold different row counts per rank; join() lets ranks that
        # run out of data early shadow the remaining collectives instead of
        # leaving the others hanging in all-reduce.
        with model.join():
            for batch in dataloader:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)

                optimizer.zero_grad()
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs.loss
                loss.backward()
                optimizer.step()

                if dist.get_rank() == 0:
                    print(f"Epoch {epoch}, Loss: {loss.item()}")

    dist.barrier()
    if dist.get_rank() == 0:
        print("Training is finished")

    end = time.time()
    print(f"Time taken: {end - start:.6f} seconds")
    print(f"Total rows processed: {dataset.row_count}")

    dist.destroy_process_group()

## Create TrainJob with Data Cache

Use the `train()` API with `DataCacheInitializer` to create a TrainJob that streams data from the cache.

In [None]:
from kubeflow.trainer import CustomTrainer, DataCacheInitializer, Initializer

# Find the runtime that provisions the data cache next to the trainer nodes.
torch_cache_runtime = next(
    (rt for rt in client.list_runtimes() if rt.name == "torch-distributed-with-cache"),
    None,
)
if torch_cache_runtime is None:
    raise ValueError("torch-distributed-with-cache runtime not found. Please ensure the runtime is installed.")

# Configure your data cache settings.
# Replace these placeholders with your Iceberg metadata location, the cache
# storage URI, and an IAM role that can read the table.
METADATA_LOC = "Iceberg table metadata file path on s3"
STORAGE_URI = "cache://<schema_name>/<table_name>"
IAM_ROLE = "IAM role with access to table"

job_name = client.train(
    trainer=CustomTrainer(
        func=train_func,
        # Number of PyTorch nodes used for distributed training.
        num_nodes=8,
        # Resources requested for each PyTorch node.
        resources_per_node={
            "cpu": 3,
            "memory": "8Gi",
            # One GPU per node; remove this entry to run on CPU only
            # (train_func falls back to CPU/gloo when CUDA is unavailable).
            "nvidia.com/gpu": 1,
        },
        packages_to_install=["pyarrow==19.0.0", "transformers==4.45.2"],
    ),
    # Dataset config with the data cache initializer.
    initializer=Initializer(
        dataset=DataCacheInitializer(
            num_data_nodes=4,
            metadata_loc=METADATA_LOC,
            storage_uri=STORAGE_URI,
            iam_role=IAM_ROLE,
        )
    ),
    runtime=torch_cache_runtime,
)

print(f"TrainJob created: {job_name}")

## Check the TrainJob Info

Use the list_jobs() and get_job() APIs to get information about the created TrainJob.

In [None]:
# Print a one-line summary for every TrainJob known to the client.
all_jobs = client.list_jobs()
for train_job in all_jobs:
    summary = f"TrainJob: {train_job.name}, Status: {train_job.status}, Created at: {train_job.creation_timestamp}"
    print(summary)

In [None]:
# Show each step of the TrainJob with its status and device allocation.
created_job = client.get_job(name=job_name)
for job_step in created_job.steps:
    print(f"Step: {job_step.name}, Status: {job_step.status}, Devices: {job_step.device} x {job_step.device_count}\n")

## Watch the TrainJob Logs

Use the get_job_logs() API to follow the TrainJob logs in real-time.

In [None]:
# Block until the TrainJob reports the "Running" status.
client.wait_for_job_status(job_name, status={"Running"})
print("Job is in running state")

# Stream the job's log lines as they are produced.
log_stream = client.get_job_logs(job_name, follow=True)
for log_line in log_stream:
    print(log_line)

## Delete the TrainJob

When the TrainJob is finished, you can delete the resource.

In [None]:
# Uncomment to delete the job
# client.delete_job(job_name)