This is a demo of Weights and Biases' support for automatic loggign of system metrics for TPUs.

This combines the work of Felafax.ai with Weights and Biases instrumentation for logging.

Please see the original colab and backend from Felfax [here](https://github.com/felafax/felfax), and checkout their website at [Felefax.ai](felefax.ai).

# Setup

In [None]:
!pip install --upgrade git+https://github.com/felafax/felafax -q
!pip uninstall -y tensorflow
!pip install tensorflow-cpu -q

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m77.9/77.9 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.8/52.8 kB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m527.3/527.3 kB[0m [31m13.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m177.6/177.6 kB[0m [31m12.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m34.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = ""
import sys
from typing import Any, Dict, List

# Add the parent directory of the current working directory to the Python path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), ".")))
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

from llama3_jax.trainer_engine import setup

setup.setup_environment(base_dir="/")

from llama3_jax.trainer_engine import automodel_lib, jax_utils, trainer_lib

setup.reload_modules("llama3_jax")

import jax
import jax.numpy as jnp
import chex
import optax

import torch
from datasets import load_dataset


Reloaded all felafax modules.


# Step 0: Configure LoRA params and precision for training (jnp.bfloat16 or jnp.float32)

In [None]:
MODEL_NAME = "colab-llama-3.1-8B-Instruct-JAX"

In [None]:
model_path, model, model_configurator, tokenizer = (
    automodel_lib.AutoJAXModelForCausalLM.from_pretrained(
        MODEL_NAME,
        dtype=jnp.bfloat16,
        param_dtype=jnp.bfloat16,
        lora_rank=8,
        lora_alpha=16,
    )
)


Downloading model colab-llama-3.1-8B-Instruct-JAX...


Fetching 8 files:   0%|          | 0/8 [00:00<?, ?it/s]

tokenizer_config.json:   0%|          | 0.00/50.5k [00:00<?, ?B/s]

.gitattributes:   0%|          | 0.00/1.59k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/73.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/659 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

llama-3.1-8B-Instruct-JAX.flax:   0%|          | 0.00/16.1G [00:00<?, ?B/s]

colab-llama-3.1-8B-Instruct-JAX was downloaded to /hf/models--felafax--colab-llama-3.1-8B-Instruct-JAX/snapshots/7598ab3cbfab748cc81a1535035911d98483b90e/llama-3.1-8B-Instruct-JAX.flax.


# Step 1: prepare the dataset

For this colab, we're utilizing the refined **Alpaca dataset**, curated by yahma. This dataset is a carefully filtered selection of 52,000 entries from the original Alpaca collection. Feel free to substitute this section with your own data preparation code if you prefer.

It's crucial to include the EOS_TOKEN (End of Sequence Token) in your tokenized output. Failing to do so may result in endless generation loops.

In [None]:
def get_dataset(*, tokenizer, batch_size=1, seq_length=32, max_examples=None):
    # Define Alpaca prompt template
    alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

    ### Instruction: {}


    ### Input: {}

    ### Response: {}"""

    EOS_TOKEN = tokenizer.eos_token

    # Defines formatting function.
    def _format_prompts(examples):
        instructions = examples["instruction"]
        inputs = examples["input"]
        outputs = examples["output"]
        texts = []
        for instruction, input, output in zip(instructions, inputs, outputs):
            text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
            texts.append(text)
        return {"text": texts}

    def _tokenize(examples):
        tokenized = tokenizer(
            examples["text"],
            truncation=True,
            padding="max_length",
            max_length=seq_length + 1,
        )
        return {
            "input_tokens": [input_id[:-1] for input_id in tokenized["input_ids"]],
            "target_tokens": [input_id[1:] for input_id in tokenized["input_ids"]],
            "loss_masks": [input_id[1:] for input_id in tokenized["attention_mask"]],
        }

    def _custom_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, jnp.ndarray]:
        """
        Collates batch items and converts to JAX arrays.
        """
        collated = {"input_tokens": [], "target_tokens": [], "loss_masks": []}
        for item in batch:
            for key in collated:
                collated[key].append(item[key])
        jax_batch = {}
        for key, value in collated.items():
            jax_batch[key] = jnp.array(value)
        return jax_batch

    # Load and preprocess the dataset
    dataset = load_dataset("yahma/alpaca-cleaned", split="train")
    if max_examples:
        dataset = dataset.select(range(max_examples))
    dataset = dataset.map(_format_prompts, batched=True)

    # Create train and test dataset.
    ds = dataset.train_test_split(test_size=0.15)
    for split in ["train", "test"]:
        ds[split] = ds[split].map(
            _tokenize, batched=True, remove_columns=dataset.column_names
        )

    # Create DataLoaders
    dataloader_args = dict(
        shuffle=True, batch_size=batch_size, collate_fn=_custom_collate_fn
    )
    train_dataloader = torch.utils.data.DataLoader(ds["train"], **dataloader_args)
    test_dataloader = torch.utils.data.DataLoader(ds["test"], **dataloader_args)

    return train_dataloader, test_dataloader

### Uncomment below code ⬇️ if you'd like to run and test your dataset pipeline.

In [None]:
# def test_dataset_pipeline(tokenizer):
#     """Print shapes of first batch to verify dataset pipeline."""
#     train_loader, _ = get_dataset(tokenizer=tokenizer, batch_size=1, seq_length=32, max_examples=32)
#     batch = next(iter(train_loader))

#     print("Input tokens shape:", batch['input_tokens'].shape)
#     print("Target mask shape:", batch['target_tokens'].shape)
# test_dataset_pipeline(tokenizer)

# Step 2: Configure hyperparameters below and train!

In [None]:
@chex.dataclass(frozen=True)
class TrainerConfig:
    # dataset pipeline knobs
    batch_size: int = 64  # 8 Is the minimum
    seq_length: int = 32
    dataset_size_limit: int | None = None

    # training pipeline knobs
    learning_rate: float = 1e-3
    num_epochs: int = 2
    max_steps: int | None = 20

    print_every_n_steps: int = 5

    # eval
    eval_every_n_steps: int = 10
    max_eval_steps: int | None = 1


trainer_config = TrainerConfig()
optimizer = optax.sgd(trainer_config.learning_rate)


In [None]:
# Prepare dataset
train_dataloader, val_dataloader = get_dataset(
    tokenizer=tokenizer,
    batch_size=trainer_config.batch_size,
    seq_length=trainer_config.seq_length,
    max_examples=trainer_config.dataset_size_limit,
)

Downloading readme:   0%|          | 0.00/11.6k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/44.3M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/51760 [00:00<?, ? examples/s]

Map:   0%|          | 0/51760 [00:00<?, ? examples/s]

Map:   0%|          | 0/43996 [00:00<?, ? examples/s]

Map:   0%|          | 0/7764 [00:00<?, ? examples/s]

In [None]:
# Print training information
trainer_lib.pprint_training_pipeline(train_dataloader, trainer_config)


Training Configuration Summary:
Total samples: 43996
Batch size: 64
Sequence length: 32
Number of epochs: 2
Steps per epoch: 688
Total training steps: 20
*Note*: Total steps limited by max_steps setting (20)


In [None]:
!pip install wandb
import wandb
from google.colab import userdata

wandb.login(key=userdata.get("WANDB_API_KEY"))

Collecting wandb
  Downloading wandb-0.18.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.7 kB)
Collecting docker-pycreds>=0.4.0 (from wandb)
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl.metadata (1.8 kB)
Collecting gitpython!=3.1.29,>=1.0.0 (from wandb)
  Downloading GitPython-3.1.43-py3-none-any.whl.metadata (13 kB)
Collecting sentry-sdk>=1.0.0 (from wandb)
  Downloading sentry_sdk-2.16.0-py2.py3-none-any.whl.metadata (9.8 kB)
Collecting setproctitle (from wandb)
  Downloading setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.9 kB)
Collecting gitdb<5,>=4.0.1 (from gitpython!=3.1.29,>=1.0.0->wandb)
  Downloading gitdb-4.0.11-py3-none-any.whl.metadata (1.2 kB)
Collecting smmap<6,>=3.0.1 (from gitdb<5,>=4.0.1->gitpython!=3.1.29,>=1.0.0->wandb)
  Downloading smmap-5.0.1-py3-none-any.whl.metadata (4.3 kB)
Downloading wandb-0.18.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [None]:
PROJECT_NAME = "llama-3.1-8b-fine-tune-TPU"
ENTITY = "wandb"
run = wandb.init(project=PROJECT_NAME, entity=ENTITY, tags=["Colab"], save_code=False)

# Organize Metrics
run.define_metric("train_step")
run.define_metric("epoch")

run.define_metric("train/*", step_metric="train_step")
run.define_metric("val/*", step_metric="train_step")

# Log config
from dataclasses import asdict

config_dict = asdict(trainer_config)
config_dict.update({"model_name": MODEL_NAME})
run.config.update(config_dict)

[34m[1mwandb[0m: Currently logged in as: [33mwyler-zahm[0m ([33mwandb[0m). Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01111395401111142, max=1.0)…

In [None]:
trainer = trainer_lib.CausalLMTrainer(
    model_name=MODEL_NAME,
    model=model,
    model_ckpt_path=model_path,
    model_configurator=model_configurator,
    optimizer=optimizer,
    training_config=trainer_config,
    mesh=jax_utils.MESH,
    dtype=jnp.bfloat16,  # precision to use for training
)

In [None]:
# Redefine the CausalLMTrainer.train function to instrument with W&B and add eval

import time
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as PS


def train(trainer, train_dataloader, eval_dataloader):
    total_training_time = 0
    total_steps = 0

    for epoch in range(trainer.training_config.num_epochs):
        print(f"Starting epoch {epoch} of training...")

        for step, train_batch in enumerate(train_dataloader):
            trainer.current_step = epoch * len(train_dataloader) + step

            train_batch = jax.device_put(
                train_batch, NamedSharding(trainer.mesh, PS("dp", "fsdp"))
            )

            sharded_rng = jax_utils.next_rng()

            # Start timing
            step_start_time = time.time()

            trainer.train_state, sharded_rng, metrics = trainer.train_step(
                trainer.train_state, train_batch, sharded_rng, run_jitted=False
            )

            # End timing
            step_end_time = time.time()

            # Calculate step duration
            step_duration = step_end_time - step_start_time
            total_training_time += step_duration
            total_steps += 1

            # Calculate steps per second
            steps_per_sec = 1 / step_duration

            to_log = {}
            to_log.update(
                {
                    "train_step": step,
                    "train/loss": metrics["loss"],
                    "train/accuracy": metrics["accuracy"],
                    "train/step_time": step_duration,
                    "train/step_hz": steps_per_sec,
                    "epoch": epoch,
                }
            )

            if step % trainer.training_config.print_every_n_steps == 0:
                print(
                    f"Epoch {epoch}, Step {step}, "
                    f"Train Loss: {metrics['loss']:.4f}, "
                    f"Accuracy: {metrics['accuracy']:.4f}, "
                    f"Step Time: {step_duration:.4f}s, "
                    f"Steps/sec: {steps_per_sec:.2f}"
                )

            # Evaluate if applicable (ADDED)
            # if step % trainer.training_config.eval_every_n_steps == 0:
            #     eval_step_start_time = time.time()
            #     #{'loss': avg_loss, 'accuracy': avg_accuracy}
            #     eval_metrics = trainer.evaluate(trainer.train_state, eval_dataloader, run_jitted=False)
            #     eval_time = time.time()-eval_step_start_time

            #     print(f"Epoch {epoch}, Step {step}, "
            #           f"Val Loss: {eval_metrics['avg_loss']:.4f}, "
            #           f"Val Accuracy: {eval_metrics['avg_accuracy']:.4f}, "
            #           f"Step Time: {eval_time:.4f}s")

            #     to_log.update({
            #         "val/loss": eval_metrics['avg_loss'],
            #         "val/accuracy": eval_metrics['avg_accuracy'],
            #         "val/eval_time": eval_time})

            # Log metrics to W&B
            print("Logging:", to_log)
            run.log(to_log)

            if (
                trainer.training_config.max_steps
                and step >= trainer.training_config.max_steps
            ):
                break

    avg_steps_per_sec = total_steps / total_training_time
    print(f"Average Steps per Second: {avg_steps_per_sec:.2f}")

    return trainer.train_state

In [None]:
state = train(trainer, train_dataloader, val_dataloader)
run.finish()

Starting epoch 0 of training...
Epoch 0, Step 0, Train Loss: 3.7673, Accuracy: 0.3438, Step Time: 70.6932s, Steps/sec: 0.01
Logging: {'train_step': 0, 'train/loss': Array(3.7673268, dtype=float32), 'train/accuracy': Array(0.34375, dtype=float32), 'train/step_time': 70.69318437576294, 'train/step_hz': 0.014145635238110008, 'epoch': 0}
Logging: {'train_step': 1, 'train/loss': Array(3.6964705, dtype=float32), 'train/accuracy': Array(0.34375, dtype=float32), 'train/step_time': 17.022202968597412, 'train/step_hz': 0.05874680274020946, 'epoch': 0}
Logging: {'train_step': 2, 'train/loss': Array(3.609604, dtype=float32), 'train/accuracy': Array(0.34375, dtype=float32), 'train/step_time': 17.354307889938354, 'train/step_hz': 0.057622580303520944, 'epoch': 0}
Logging: {'train_step': 3, 'train/loss': Array(3.5263896, dtype=float32), 'train/accuracy': Array(0.34375, dtype=float32), 'train/step_time': 17.398077964782715, 'train/step_hz': 0.05747761344811798, 'epoch': 0}
Logging: {'train_step': 4, '

VBox(children=(Label(value='0.019 MB of 0.019 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁████████████████████
train/accuracy,▁▁▁▁▁▁▁▁▁▂▂▂▂▃▃▃▃▃▄▄▅▆▅▆▆▄▆▇▇███████████
train/loss,████▇▇▇▇▇▆▆▆▆▅▅▅▅▄▄▄▃▃▃▂▃▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁
train/step_hz,▁▇▇▇█▇▇█▇██▇▇█▇▇█▇█▇▇▇▇█▇▇█▇▇█▇▇██▇▇█▇▇▇
train/step_time,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_step,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇█▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇█

0,1
epoch,1.0
train/accuracy,0.97119
train/loss,0.17833
train/step_hz,0.0583
train/step_time,17.15172
train_step,20.0
