In [8]:
import os
import sys
import importlib

# Add the parent directory of the current working directory to the Python path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '.')))
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

try:
    import llama3_jax
    print("llama3_jax package imported successfully")
except ImportError as e:
    print(f"Error importing felafax: {e}")

llama3_jax package imported successfully


In [9]:
from llama3_jax.trainer_engine import setup
setup.setup_environment()

In [10]:
from llama3_jax.trainer_engine import utils, jax_utils
from llama3_jax.trainer_engine import automodel_lib, checkpoint_lib, trainer_lib, convert_lib
from llama3_jax import llama_config

setup.reload_modules("llama3_jax")

  from .autonotebook import tqdm as notebook_tqdm


Reloaded all felafax modules.


In [11]:
from typing import (Any, Dict, List, Mapping, Optional, Sequence, Tuple,
                    Union)

import jax
import jax.numpy as jnp
import chex
import optax

import torch

from datasets import load_dataset
from transformers import default_data_collator

In [5]:
HUGGINGFACE_USERNAME = input("INPUT: Please provide your HUGGINGFACE_USERNAME: ")
HUGGINGFACE_TOKEN = input("INPUT: Please provide your HUGGINGFACE_TOKEN: ")

INPUT: Please provide your HUGGINGFACE_USERNAME:  felarof01
INPUT: Please provide your HUGGINGFACE_TOKEN:  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


In [6]:
# Select a supported model from above list to use!
MODEL_NAME = "Meta-Llama-3.1-8B"

In [22]:
# Constants for paths
FELAFAX_DIR = os.path.dirname(os.path.dirname(felafax.llama3_jax.__file__))
GCS_DIR = "/home/felafax-storage/"
EXPORT_DIR = os.path.join(FELAFAX_DIR, "llama3_jax", "export")

HF_COMPATIBLE_EXPORT_DIR = os.path.join(GCS_DIR, "llama3_jax", "hf_export")
HF_REPO_ID = f"{HUGGINGFACE_USERNAME}/llama3.1_finetuned_8B"

In [8]:
model_path, model, model_configurator, tokenizer = automodel_lib.AutoJAXModelForCausalLM.from_pretrained("llama-3.1-8B-JAX",
                                                                           HUGGINGFACE_TOKEN)

Downloading model llama-3.1-8B-JAX...


Fetching 3 files: 100%|██████████| 3/3 [00:01<00:00,  2.78it/s]

llama-3.1-8B-JAX was downloaded to /home/felafax-storage/hf/models--felafax--llama-3.1-8B-JAX/snapshots/ebca17f216e4c02e0f31cc47264a9d65a4f5b9a9/llama3.1_8b_serialized.flax.





# Dataset pipeline

In [9]:
def get_dataset(*, tokenizer, batch_size=1, seq_length=32, max_examples=None):
    # Define Alpaca prompt template
    alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
    
    ### Instruction: {}
    
    ### Input: {}
    
    ### Response: {}"""
    
    EOS_TOKEN = tokenizer.eos_token
    
    # Defines formatting function.
    def _format_prompts(examples):
        instructions = examples["instruction"]
        inputs = examples["input"]
        outputs = examples["output"]
        texts = []
        for instruction, input, output in zip(instructions, inputs, outputs):
            text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
            texts.append(text)
        return {"text": texts}

    def _tokenize(examples):
        tokenized = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=seq_length+1)
        return {
            'input_tokens': [input_id[:-1] for input_id in tokenized['input_ids']],
            'target_tokens': [input_id[1:] for input_id in tokenized['input_ids']],
            'loss_masks': [input_id[1:] for input_id in tokenized['attention_mask']]
        }

    def _custom_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, jnp.ndarray]:
        """
        Collates batch items and converts PyTorch tensors to JAX arrays.
        Applies default_data_collator, then converts tensors to JAX format.
        """
        collated = default_data_collator(batch)
        jax_batch = {}
        for key, value in collated.items():
            jax_batch[key] = jnp.array(value.numpy()) if isinstance(value, torch.Tensor) else value
        
        return jax_batch

    # Load and preprocess the dataset
    dataset = load_dataset("yahma/alpaca-cleaned", split="train")
    if max_examples:
        dataset = dataset.select(range(max_examples))
    dataset = dataset.map(_format_prompts, batched=True)

    # Create train and test dataset.
    ds = dataset.train_test_split(test_size=0.15)
    for split in ['train', 'test']:
        ds[split] = ds[split].map(_tokenize, batched=True, remove_columns=dataset.column_names)

    # Create DataLoaders
    dataloader_args = dict(shuffle=True, batch_size=batch_size, collate_fn=_custom_collate_fn)
    train_dataloader = torch.utils.data.DataLoader(ds['train'], **dataloader_args)
    test_dataloader = torch.utils.data.DataLoader(ds['test'], **dataloader_args)

    return train_dataloader, test_dataloader

In [10]:
def test_dataset_pipeline(tokenizer):
    """Print shapes of first batch to verify dataset pipeline."""
    train_loader, _ = get_dataset(tokenizer=tokenizer, batch_size=1, seq_length=32, max_examples=32)
    batch = next(iter(train_loader))
    print("Input tokens shape:", batch['input_tokens'].shape)
    print("Target mask shape:", batch['target_tokens'].shape)
test_dataset_pipeline(tokenizer)

Map: 100%|██████████| 27/27 [00:02<00:00, 13.05 examples/s]
Map: 100%|██████████| 5/5 [00:02<00:00,  2.46 examples/s]

Input tokens shape: (1, 32)
Target mask shape: (1, 32)





# Training loop

In [11]:
@chex.dataclass(frozen=True)
class TrainingConfig:
    learning_rate: float = 1e-4
    num_epochs: int = 1
    max_steps: int | None = 5
    batch_size: int = 32
    seq_length: int = 64
    dataset_size_limit: int | None = 32
    print_every_n_steps: int = 1
    eval_every_n_steps: int = 1000


training_cfg = TrainingConfig()
optimizer = optax.sgd(training_cfg.learning_rate)

In [12]:
# Prepare dataset
train_dataloader, val_dataloader = get_dataset(
    tokenizer=tokenizer,
    seq_length=training_cfg.seq_length,
    max_examples=training_cfg.dataset_size_limit,
)

Map: 100%|██████████| 27/27 [00:01<00:00, 14.26 examples/s] 
Map: 100%|██████████| 5/5 [00:01<00:00,  2.58 examples/s]


In [13]:
trainer = trainer_lib.CausalLMTrainer(
    model=model,
    model_ckpt_path=model_path,
    model_configurator=model_configurator,
    optimizer=optimizer,
    training_config=training_cfg,
    mesh=jax_utils.MESH, 
)

Loading causal language model...


In [14]:
state = trainer.train(train_dataloader, val_dataloader, run_jitted=True)

Starting epoch 0 of training...
Epoch 0, Step 0, Train Loss: 2.8467, Accuracy: 0.4531
Epoch 0, Step 0, Eval Loss: 2.6411, Accuracy: 0.4688
Epoch 0, Step 1, Train Loss: 2.8010, Accuracy: 0.4062
Epoch 0, Step 2, Train Loss: 2.5753, Accuracy: 0.4219
Epoch 0, Step 3, Train Loss: 2.2409, Accuracy: 0.5312
Epoch 0, Step 4, Train Loss: 2.2621, Accuracy: 0.5517
Epoch 0, Step 5, Train Loss: 2.3617, Accuracy: 0.4531


In [None]:
export_path = os.path.join(EXPORT_DIR, "llama3.flax")
trainer.save_checkpoint(state, path=export_path)

In [None]:
convert_lib.save_hf_compatible_checkpoint(f'flax_params::{export_path}',
                                          HF_COMPATIBLE_EXPORT_DIR,
                                          model_configurator)

In [None]:
# convert_lib.upload_checkpoint_to_hf(HF_COMPATIBLE_EXPORT_DIR, HF_REPO_ID,
#                                     HUGGINGFACE_TOKEN)