# BatFit LoRA Training on Colab
Use this notebook to fine-tune BatFit adapters on a Colab GPU while mirroring the CLI workflow.

**Workflow overview**
1. Clone or verify the repo checkout.
2. Install the Colab-friendly requirements.
3. (Optional) mount Google Drive for persistent artifacts.
4. Configure hyperparameters via `BATFIT_*` env vars.
5. Step through each modular stage (prompt, data, tokenizer, trainer) to debug.
6. Launch training via `trainer.train()`, `main()`, or the CLI fallback.


## Step 1 – Clone or verify repository
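The repo is public, so no credentials are needed. Set `GIT_REPO` to override the clone URL and `BATFIT_REPO_DIR` to change the checkout directory. If the current directory already contains a `.git` folder, the cell skips cloning and simply reports the current working directory.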


In [None]:
import os
import subprocess
from pathlib import Path

repo_url = os.environ.get('GIT_REPO', 'https://github.com/wahajaslm/batfit.git')
repo_dir = Path(os.environ.get('BATFIT_REPO_DIR', 'batfit'))
if Path('.git').exists():
    print(f'Already inside repo: {Path.cwd()}')
else:
    if not repo_dir.exists():
        subprocess.check_call(['git', 'clone', repo_url, str(repo_dir)])
    os.chdir(repo_dir)
    print(f'Working directory -> {Path.cwd()}')


## Step 2 – Install Colab dependencies
Installs the lightweight requirements listed in `requirements-colab.txt` (transformers, datasets, peft, etc.). Run once per runtime.


In [None]:
!pip install -q -r requirements-colab.txt
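
After the install finishes, a quick import check can confirm that the runtime picked up the packages. The following is a minimal sketch and not part of the repo: `missing_modules` is a hypothetical helper, and the module list only covers the packages named above.

```python
import importlib.util


def missing_modules(names):
    """Return the top-level module names that cannot be located in this runtime."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Packages named in the requirements description above.
missing = missing_modules(["transformers", "datasets", "peft"])
print("Missing modules:", missing or "none")
```

An empty result only shows that the modules can be found; it does not verify their versions.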

## Step 3 – (Optional) Mount Drive
Only needed when you want checkpoints/logs to persist beyond the Colab session.


In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

## Step 4 – Configure training knobs
Set `BATFIT_BASE_MODEL`, `BATFIT_MAX_LEN`, `BATFIT_EPOCHS`, or any other `BATFIT_*` overrides (batch size, learning rate, etc.) before building datasets.


In [None]:
import os

# setdefault keeps any value that is already exported in this runtime
os.environ.setdefault('BATFIT_BASE_MODEL', 'TinyLlama/TinyLlama-1.1B-Chat-v1.0')
os.environ.setdefault('BATFIT_MAX_LEN', '768')
os.environ.setdefault('BATFIT_EPOCHS', '2')
print('Base model:', os.environ['BATFIT_BASE_MODEL'])
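
Environment variables are always strings, so numeric overrides such as `BATFIT_MAX_LEN` have to be parsed somewhere. How `scripts/train_lora.py` does this is not shown here; the sketch below is one way to read such a value defensively, and `env_int` is a hypothetical helper rather than a function from the repo.

```python
import os


def env_int(name, default):
    """Read an integer override from the environment, falling back to default."""
    raw = os.environ.get(name)
    if raw is None or raw.strip() == "":
        return default
    try:
        return int(raw)
    except ValueError:
        raise ValueError(f"{name} must be an integer, got {raw!r}") from None


# Defaults mirror the values set in the cell above.
max_len = env_int("BATFIT_MAX_LEN", 768)
epochs = env_int("BATFIT_EPOCHS", 2)
print("max_len:", max_len, "epochs:", epochs)
```

Failing early on a malformed value is usually cheaper than discovering it after the model has loaded.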

## Step 5 – Modular training pipeline
Each sub-step below calls a helper from `scripts/train_lora.py` so you can inspect intermediate artifacts without launching a full training run.


### 5a. Load system prompt
`resolve_system_prompt()` gives precedence to `BATFIT_SYSTEM_PROMPT`, then `data/common/prompts/system.txt`, then the script default.


In [None]:
import os
import sys
from pathlib import Path

# Make the repo root importable so that `scripts.train_lora` resolves
repo_root = Path.cwd()
if str(repo_root) not in sys.path:
    sys.path.append(str(repo_root))

from scripts.train_lora import (
    load_from_manifest,
    resolve_system_prompt,
    prepare_tokenizer,
    tokenize_splits,
    resolve_device,
    prepare_model,
    build_trainer,
    main,
)

system_prompt = resolve_system_prompt()
print('System prompt loaded (chars):', len(system_prompt))
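
The precedence described for `resolve_system_prompt()` (environment variable, then file, then built-in default) can be sketched as below. This illustrates the lookup order only and is not the repo's implementation: `pick_system_prompt` is a hypothetical helper, and the fallback string is a placeholder, not the script's real default.

```python
import os
from pathlib import Path


def pick_system_prompt(env_name, prompt_file, fallback):
    """Return the first available prompt: env var, then file, then fallback."""
    from_env = os.environ.get(env_name, "").strip()
    if from_env:
        return from_env
    path = Path(prompt_file)
    if path.is_file():
        text = path.read_text(encoding="utf-8").strip()
        if text:
            return text
    return fallback


prompt = pick_system_prompt(
    "BATFIT_SYSTEM_PROMPT",
    "data/common/prompts/system.txt",
    "You are a helpful assistant.",  # placeholder, not the script's real default
)
print("Resolved prompt (chars):", len(prompt))
```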


### 5b. Load manifest-defined datasets
`load_from_manifest()` applies manifest weights, normalizes JSONL rows, and optionally carves a validation split.


In [None]:
train_raw, val_raw = load_from_manifest()
print('Train rows:', len(train_raw))
print('Val rows:', len(val_raw) if val_raw is not None else 0)
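
To make the weighting and validation split concrete, here is a small sketch on stand-in rows. The manifest schema and the exact semantics of a weight are not shown in this notebook, so everything here is an assumption: `apply_weight` and `split_rows` are hypothetical helpers, and a weight is interpreted as a multiplier on the number of rows drawn from a source.

```python
import random


def apply_weight(rows, weight, rng):
    """Scale a source: weight 2.0 repeats it twice, 0.5 keeps a random half."""
    rows = list(rows)
    whole, frac = divmod(weight, 1.0)
    picked = rows * int(whole)
    picked += rng.sample(rows, int(round(len(rows) * frac)))
    return picked


def split_rows(rows, val_fraction, seed=42):
    """Shuffle deterministically and carve off a validation slice."""
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)
    n_val = int(len(shuffled) * val_fraction)
    return shuffled[n_val:], shuffled[:n_val]


rng = random.Random(0)
source = [{"id": i} for i in range(20)]  # stand-in for normalized JSONL rows
mixed = apply_weight(source, 1.5, rng)   # all 20 rows plus a random 10
train_rows, val_rows = split_rows(mixed, val_fraction=0.2)
print(len(mixed), len(train_rows), len(val_rows))
```

Note that splitting after up-weighting, as in this sketch, lets repeated rows land on both sides of the split; whether the real helper avoids that is worth checking in `scripts/train_lora.py`.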


### 5c. Prepare tokenizer
`prepare_tokenizer()` ensures the tokenizer has a pad token and uses right-side padding before batching.


In [None]:
tokenizer = prepare_tokenizer()
print('Tokenizer vocab size:', tokenizer.vocab_size)
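
The padding setup can be sketched as follows. Reusing the EOS token as the pad token is a common convention and an assumption about what `prepare_tokenizer()` does; `ensure_padding` is a hypothetical helper, and a `SimpleNamespace` stands in for a real tokenizer so the snippet runs without downloading anything.

```python
from types import SimpleNamespace


def ensure_padding(tokenizer):
    """Give the tokenizer a pad token (reusing EOS) and pad on the right."""
    if getattr(tokenizer, "pad_token", None) is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    return tokenizer


# Stand-in object with the same attribute names a Hugging Face tokenizer exposes.
demo = SimpleNamespace(pad_token=None, eos_token="</s>", padding_side="left")
ensure_padding(demo)
print("pad_token:", demo.pad_token, "padding_side:", demo.padding_side)
```

On the notebook's own `tokenizer`, printing `tokenizer.pad_token` and `tokenizer.padding_side` is a quick way to confirm the result.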


### 5d. Tokenize datasets
`tokenize_splits()` converts normalized rows into LM inputs/labels so you can inspect lengths and spot data issues.


In [None]:
train_dataset, val_dataset = tokenize_splits(train_raw, val_raw, tokenizer, system_prompt)
print('Tokenized train len:', len(train_dataset))
print('Tokenized val len:', len(val_dataset) if val_dataset is not None else 0)
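
A length summary is a cheap way to spot data issues such as many rows being cut off at the length cap. The sketch below works on a plain list of token counts; `length_report` is a hypothetical helper, not part of `scripts/train_lora.py`.

```python
def length_report(lengths, max_len):
    """Summarize token counts and count rows sitting at the length cap."""
    lengths = sorted(lengths)
    if not lengths:
        return {"rows": 0, "min": 0, "median": 0, "max": 0, "at_cap": 0}
    return {
        "rows": len(lengths),
        "min": lengths[0],
        "median": lengths[len(lengths) // 2],
        "max": lengths[-1],
        "at_cap": sum(1 for n in lengths if n >= max_len),
    }


report = length_report([120, 340, 768, 95, 768], max_len=768)
print(report)
```

In the notebook, the lengths could come from something like `[len(row["input_ids"]) for row in train_dataset]`, assuming the tokenized rows expose an `input_ids` field. A high `at_cap` count suggests raising `BATFIT_MAX_LEN` or trimming long examples.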


### 5e. Resolve device + dtype
`resolve_device()` checks CUDA/MPS availability and picks a matching `device_map` and dtype.


In [None]:
use_cuda, use_mps, device_map, dtype = resolve_device()
print('CUDA:', use_cuda, 'MPS:', use_mps, 'device_map:', device_map, 'dtype:', dtype)
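
The kind of decision `resolve_device()` makes can be sketched as a pure function over availability flags. The specific dtype choices below are assumptions for illustration (bf16 or fp16 on CUDA, fp32 elsewhere) and may differ from the script; `pick_device` is a hypothetical helper.

```python
def pick_device(cuda_available, mps_available, bf16_supported=False):
    """Map availability flags to a (device, dtype_name) pair."""
    if cuda_available:
        return ("cuda", "bfloat16" if bf16_supported else "float16")
    if mps_available:
        return ("mps", "float32")
    return ("cpu", "float32")


# A few representative runtimes.
for flags in [(True, False, True), (True, False, False), (False, True, False), (False, False, False)]:
    print(flags, "->", pick_device(*flags))
```

With PyTorch installed, the flags would come from `torch.cuda.is_available()`, `torch.backends.mps.is_available()`, and `torch.cuda.is_bf16_supported()`.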


### 5f. Build base model + LoRA adapters
`prepare_model()` loads the base checkpoint, disables cache, enables gradient checkpointing, and injects LoRA modules.


In [None]:
base_model = os.environ.get('BATFIT_BASE_MODEL', 'TinyLlama/TinyLlama-1.1B-Chat-v1.0')
model = prepare_model(base_model, dtype, device_map, use_mps)
print('Model with LoRA adapters ready.')
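
To see why LoRA keeps the trainable footprint small, it helps to do the arithmetic for a single weight matrix. For a frozen matrix of shape `d_out x d_in`, a rank-`r` adapter adds two small matrices with `r * (d_in + d_out)` parameters in total. The dimensions and rank below are illustrative values, not settings read from the repo.

```python
def lora_extra_params(d_in, d_out, rank):
    """Parameters added by one LoRA pair: A is rank x d_in, B is d_out x rank."""
    return rank * (d_in + d_out)


d_in = d_out = 2048  # illustrative projection size
rank = 8             # illustrative LoRA rank
added = lora_extra_params(d_in, d_out, rank)
full = d_in * d_out
print(f"LoRA adds {added:,} params vs {full:,} in the frozen matrix ({added / full:.2%})")
# prints: LoRA adds 32,768 params vs 4,194,304 in the frozen matrix (0.78%)
```

If `prepare_model()` returns a PEFT-wrapped model, `model.print_trainable_parameters()` reports the actual totals across all adapted layers.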


### 5g. Assemble Trainer
`build_trainer()` wires the model, tokenizer, datasets, and default hyperparameters together. Rerun if you tweak settings above.


In [None]:
trainer = build_trainer(model, tokenizer, train_dataset, val_dataset)
print('Trainer ready.')


### Step 6a – Train interactively
Uncomment `trainer.train()` once you're satisfied with the inspected artifacts.


In [None]:
# trainer.train()


### Step 6b – Run `main()` end-to-end
Call `main()` if you want the exact CLI behavior without stepping through each helper.


In [None]:
# from scripts.train_lora import main
# _ = main()


### Step 6c – Legacy CLI fallback
Shelling out remains available for parity with older instructions.


In [None]:
# !python scripts/train_lora.py
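
If you prefer to shell out from Python rather than use the `!` magic, the `BATFIT_*` overrides can be passed through the child process environment. This is a sketch under that assumption; `cli_invocation` is a hypothetical helper, and the actual launch is left commented out so the snippet has no side effects.

```python
import os
import sys


def cli_invocation(overrides):
    """Build the command and environment for running the training script."""
    cmd = [sys.executable, "scripts/train_lora.py"]
    # Child inherits the current environment plus stringified overrides.
    env = {**os.environ, **{key: str(value) for key, value in overrides.items()}}
    return cmd, env


cmd, env = cli_invocation({"BATFIT_EPOCHS": 1})
print("Command:", cmd[1], "| BATFIT_EPOCHS =", env["BATFIT_EPOCHS"])

# To actually launch training:
# import subprocess
# subprocess.run(cmd, env=env, check=True)
```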
