# 04 · Fine-tune Slice *t* (Unsloth QLoRA)

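Load 4-bit Llama weights, attach LoRA adapters, and train every slice (with and without 10% replay) within the per-slice token budget set by `TOKENS_PER_SLICE` (the code below currently uses 3M tokens per slice, not 25M).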

In [1]:
!pip install unsloth



In [2]:
# Persistent Drive + run mode setup
import os
import sys
from pathlib import Path

# Colab mounts Google Drive at DRIVE_MOUNT; the ``MyDrive`` sub-directory only
# exists once a mount is live, so an empty leftover mountpoint directory is not
# mistaken for a mounted drive.
DRIVE_MOUNT = Path('/content/drive')
MY_DRIVE = DRIVE_MOUNT / 'MyDrive'

try:
    from google.colab import drive  # type: ignore
except ImportError as exc:  # pragma: no cover - not running inside Colab
    print(f'Colab drive mount skipped: {exc}')
else:
    if not MY_DRIVE.exists():
        try:
            drive.mount('/content/drive')
        except Exception as exc:  # pragma: no cover - best effort; fall back to $HOME below
            print(f'Colab drive mount skipped: {exc}')

# Prefer the mounted Drive; otherwise fall back to the user's home directory
# (e.g. when running outside Colab).
DRIVE_ROOT = MY_DRIVE.resolve() if MY_DRIVE.exists() else Path.home().resolve()

PROJECT_ROOT = DRIVE_ROOT / 'secure-llm-mia'
if not PROJECT_ROOT.exists():
    raise FileNotFoundError(
        f'Run 00_colab_setup.ipynb first to clone the repo on Drive (expected {PROJECT_ROOT}).'
    )

# Make the repo's ``src`` package importable and run relative paths from the repo root.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

os.environ['SECURE_LLM_MIA_ROOT'] = str(PROJECT_ROOT)
os.chdir(PROJECT_ROOT)

from src.utils.runtime import current_run_mode

# Resolve the active run mode once; later cells key artifact and checkpoint
# paths on RUN_MODE.name.
RUN_MODE = current_run_mode()
print('PROJECT_ROOT:', PROJECT_ROOT)
print('Active run mode:', RUN_MODE.name, '-', RUN_MODE.description)

# Project-local working directories, created up front so downstream cells can
# write into them without checking for existence first.
DATA_ROOT, ARTIFACTS_DIR, CHECKPOINT_ROOT = (
    PROJECT_ROOT / subdir for subdir in ('data', 'artifacts', 'checkpoints')
)
for path in (DATA_ROOT, ARTIFACTS_DIR, CHECKPOINT_ROOT):
    path.mkdir(parents=True, exist_ok=True)

# The BHC dataset directory sits beside (not inside) the repo under DRIVE_ROOT.
BHC_DATA_DIR = DRIVE_ROOT.joinpath('mimic-iv-bhc')
BHC_DATA_DIR.mkdir(parents=True, exist_ok=True)
BHC_CSV_PATH = BHC_DATA_DIR.joinpath('mimic-iv-bhc.csv')
print('BHC CSV path:', BHC_CSV_PATH)


PROJECT_ROOT: /content/drive/MyDrive/secure-llm-mia
Active run mode: subset - Quick debugging subset (<=2k rows) for lightweight Colab smoke tests.
BHC CSV path: /content/drive/MyDrive/mimic-iv-bhc/mimic-iv-bhc.csv


In [None]:
import math
from typing import Dict

import torch
from datasets import Dataset, concatenate_datasets
from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments
from unsloth import FastLanguageModel

from src.modeling.lora import LoRAHyperParams, compute_gradient_accumulation
from src.modeling.train import TokenBudgetTracker

# Packed per-slice datasets produced by notebook 03, one parquet file per slice.
PACKED_DIR = ARTIFACTS_DIR / 'packed' / RUN_MODE.name
if not PACKED_DIR.exists():
    raise FileNotFoundError('Packed datasets missing. Run notebook 03 before fine-tuning.')

slice_paths = sorted(PACKED_DIR.glob('slice_*.parquet'))
if not slice_paths:
    raise FileNotFoundError('No per-slice parquet files found in packed directory.')


def _slice_id_from(parquet_path: Path) -> int:
    """Return the integer after the first underscore of the file stem (``slice_<id>``).

    Raises ``IndexError`` when the stem has no underscore and ``ValueError``
    when the token is not an integer.
    """
    return int(parquet_path.stem.split('_')[1])


slice_datasets: Dict[int, Dataset] = {}
for path in slice_paths:
    try:
        slice_id = _slice_id_from(path)
    except (IndexError, ValueError) as exc:
        # Unparseable file names are reported and skipped rather than aborting the run.
        print(f'Skipping {path}: unable to parse slice id ({exc}).')
        continue
    ds = Dataset.from_parquet(str(path))
    slice_datasets[slice_id] = ds
    print(f'Slice {slice_id}: {len(ds)} training sequences loaded.')

if not slice_datasets:
    raise RuntimeError('No valid slice datasets available for training.')

# Slice ids that loaded successfully, in ascending order, and the two training
# variants run for each slice.
SLICES = sorted(slice_datasets)
TRACKS = ['noreplay', 'replay10']
REPLAY_FRACTION = 0.10  # replay sample count = int(len(slice t) * REPLAY_FRACTION) in 'replay10'
MODEL_NAME = os.getenv('UNSLOTH_MODEL_NAME', 'unsloth/Meta-Llama-3.1-8B-bnb-4bit')
MAX_SEQ_LENGTH = 4096
TOKENS_PER_SLICE = 3_000_000  # per-slice training token budget
TOKENS_PER_STEP = 128_000  # tokens targeted per optimizer step
MICRO_BATCH = 1
AVG_TOKENS_PER_SAMPLE = 3_000  # assumed mean sequence length - TODO confirm against packed data

accum_steps = compute_gradient_accumulation(TOKENS_PER_STEP, MICRO_BATCH, AVG_TOKENS_PER_SAMPLE)
print('Gradient accumulation:', accum_steps)

# Mixed precision: bf16 on compute-capability >= 8 GPUs that report bf16
# support, fp16 on any other CUDA device, neither on CPU.
cuda_ok = torch.cuda.is_available()
is_ampere_plus = cuda_ok and torch.cuda.get_device_capability(0)[0] >= 8
use_bf16 = bool(is_ampere_plus and torch.cuda.is_bf16_supported())
use_fp16 = cuda_ok and not use_bf16

# Optimizer steps needed to spend the slice budget (integer ceiling division).
max_steps = -(-TOKENS_PER_SLICE // TOKENS_PER_STEP)
print('Max steps per slice:', max_steps)

def init_model():
    """Load the 4-bit base model and attach a fresh set of LoRA adapters.

    Uses the module-level ``MODEL_NAME`` and ``MAX_SEQ_LENGTH``.

    Returns:
        ``(model, tokenizer)``: the model returned by
        ``FastLanguageModel.get_peft_model`` (LoRA r=32, alpha=32, no dropout on
        all attention and MLP projections) and its tokenizer, configured for
        right padding with ``model_max_length = MAX_SEQ_LENGTH``.
    """
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=MAX_SEQ_LENGTH,
        load_in_4bit=True,
    )
    tokenizer.padding_side = 'right'
    # DataCollatorForLanguageModeling(mlm=False) sets labels to -100 wherever
    # input_ids == pad_token_id. Forcing pad == EOS would therefore drop every
    # genuine EOS token from the loss, so only fall back to EOS when the
    # tokenizer ships without a pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.model_max_length = MAX_SEQ_LENGTH
    lora_cfg = LoRAHyperParams(
        r=32,
        alpha=32,
        dropout=0.0,
        target_modules=('q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'),
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=lora_cfg.r,
        target_modules=list(lora_cfg.target_modules),
        lora_alpha=lora_cfg.alpha,
        lora_dropout=lora_cfg.dropout,
        bias='none',
        # Unsloth's checkpointing variant trades recompute for activation memory.
        use_gradient_checkpointing='unsloth',
    )
    return model, tokenizer

import gc


def _build_train_dataset(track, slice_id):
    """Return the training set for ``slice_id`` under ``track``.

    For ``'noreplay'`` (or the first slice) this is the slice's packed dataset.
    For ``'replay10'`` it appends ``min(int(len(slice) * REPLAY_FRACTION), N_prior)``
    sequences drawn with a fixed shuffle seed from all earlier slices combined.
    """
    base_ds = slice_datasets[slice_id]
    if track != 'replay10':
        return base_ds
    prior_ids = [sid for sid in SLICES if sid < slice_id]
    if not prior_ids:
        return base_ds
    combined = concatenate_datasets([slice_datasets[sid] for sid in prior_ids])
    replay_count = min(int(len(base_ds) * REPLAY_FRACTION), len(combined))
    if replay_count <= 0:
        return base_ds
    replay_subset = combined.shuffle(seed=17).select(range(replay_count))
    train_ds = concatenate_datasets([base_ds, replay_subset])
    print(f'Slice {slice_id}: added {len(replay_subset)} replay samples from previous slices.')
    return train_ds


for track in TRACKS:
    print(f'=== Track: {track} ===')
    for slice_id in SLICES:
        # NOTE(review): every slice starts from fresh base weights + new adapters,
        # so 'replay10' mixes earlier data into an independent run rather than
        # continuing from slice t-1's adapters - confirm this is intended.
        model, tokenizer = init_model()
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, pad_to_multiple_of=8)

        train_ds = _build_train_dataset(track, slice_id).with_format(
            type='torch', columns=['input_ids', 'attention_mask']
        )
        output_dir = CHECKPOINT_ROOT / f'slice_{slice_id}' / track / RUN_MODE.name

        training_args = TrainingArguments(
            output_dir=str(output_dir),
            per_device_train_batch_size=MICRO_BATCH,
            gradient_accumulation_steps=accum_steps,
            learning_rate=1e-4,
            warmup_steps=10,
            max_steps=max_steps,
            logging_steps=max(1, max_steps // 5),
            save_steps=max_steps,
            bf16=use_bf16,
            fp16=use_fp16,
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_ds,
            data_collator=data_collator,
        )

        print(f'--- Training slice {slice_id} ({track}) ---')
        trainer.train()

        # Approximate only: walks train_ds in storage order (not the trainer's
        # shuffled order) and covers at most one pass, whereas max_steps may
        # cycle the data more than once. Presumably update() returns truthy once
        # the budget is reached - TODO confirm against TokenBudgetTracker.
        tracker = TokenBudgetTracker(tokens_per_slice=TOKENS_PER_SLICE)
        for seq in train_ds['input_ids']:
            if tracker.update(len(seq)):
                break
        print(f'Slice {slice_id}: approx tokens consumed {tracker.consumed_tokens:,}')

        output_dir.mkdir(parents=True, exist_ok=True)
        # Save through the PEFT-wrapped model itself (documented Unsloth/PEFT
        # path), which writes the LoRA adapter weights and config.
        model.save_pretrained(str(output_dir))
        tokenizer.save_pretrained(str(output_dir))
        print(f'Saved adapters + tokenizer to {output_dir}')

        # Drop every reference to the model, then collect before emptying the
        # CUDA cache so the next init_model() starts with the GPU memory freed.
        del trainer, model, tokenizer, data_collator
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
