In [1]:
import os
import torch
# Sanity check: report whether PyTorch can see a CUDA device before any model is loaded.
cuda_available = torch.cuda.is_available()
print(cuda_available)

True


In [2]:
"""Training script for hallucination detection probes."""

import os
import json
import atexit
from pathlib import Path
from typing import List
from dataclasses import asdict
import argparse

import torch
import wandb
from torch.utils.data import Subset
from transformers import TrainingArguments
from dotenv import load_dotenv

from utils.file_utils import save_jsonl, save_json, load_yaml
from utils.model_utils import load_model_and_tokenizer, print_trainable_parameters
from utils.probe_loader import upload_probe_to_hf

from probe.dataset import TokenizedProbingDataset, create_probing_dataset, tokenized_probing_collate_fn
from probe.config import TrainingConfig
from probe.value_head_probe import setup_probe
from probe.trainer import ProbeTrainer

  from .autonotebook import tqdm as notebook_tqdm


In [3]:

def _prepare_model_for_training(model, enable_gradient_checkpointing):
    """Best-effort: disable the generation KV cache and optionally enable gradient checkpointing.

    Failures are reported as warnings instead of being silently swallowed, so a model
    that cannot be configured does not quietly train without checkpointing.
    """
    if hasattr(model, "config"):
        try:
            # The KV cache is only useful for generation, not for training.
            model.config.use_cache = False
        except Exception as exc:
            print(f"Warning: could not disable model.config.use_cache: {exc}")
    if enable_gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"):
        try:
            model.gradient_checkpointing_enable()
        except Exception as exc:
            print(f"Warning: could not enable gradient checkpointing: {exc}")


def _build_train_dataset(datasets, num_train_samples, seed):
    """Concatenate the training datasets and optionally keep a seeded random subset.

    Args:
        datasets: Non-empty list of datasets that support ``+=`` concatenation.
        num_train_samples: If not None, keep at most this many samples (clamped to
            ``[0, len(dataset)]``), chosen by a random permutation seeded with ``seed``.
        seed: Seed for the permutation, so the selected subset is reproducible.

    Returns:
        The concatenated dataset, or a ``torch.utils.data.Subset`` of it.

    Raises:
        ValueError: If ``datasets`` is empty.
    """
    if not datasets:
        raise ValueError("No training datasets configured (train_dataset_configs is empty).")

    train_dataset = datasets[0]
    for dataset in datasets[1:]:
        train_dataset += dataset

    if num_train_samples is None:
        return train_dataset

    total = len(train_dataset)
    num = max(0, min(int(num_train_samples), total))
    if num < total:
        generator = torch.Generator()
        generator.manual_seed(seed)
        perm = torch.randperm(total, generator=generator).tolist()
        train_dataset = Subset(train_dataset, perm[:num])
        print(f"Using a subset of the training dataset: {num}/{total} samples")
    return train_dataset


def main(training_config: TrainingConfig, model=None, tokenizer=None):
    """Train a hallucination-detection probe, save it, evaluate it, and optionally upload it.

    Args:
        training_config: Full training configuration (probe, datasets, optimisation, logging).
        model: Pre-loaded base model. If omitted together with ``tokenizer``, both are
            loaded from ``training_config.probe_config.model_name``.
        tokenizer: Tokenizer matching ``model``; must be passed together with ``model``.

    Returns:
        The metrics returned by the final ``trainer.evaluate`` call.

    Raises:
        ValueError: If only one of ``model`` / ``tokenizer`` is given, or if no training
            datasets are configured.
    """
    probe_config = training_config.probe_config
    probe_path = probe_config.probe_path

    # Resolve model/tokenizer explicitly instead of relying on notebook globals.
    if model is None and tokenizer is None:
        print(f"Loading model: {probe_config.model_name}")
        model, tokenizer = load_model_and_tokenizer(probe_config.model_name)
    elif model is None or tokenizer is None:
        raise ValueError("Pass both `model` and `tokenizer`, or neither.")

    _prepare_model_for_training(model, training_config.enable_gradient_checkpointing)

    print(f"Setting up probe: {probe_config.probe_id}")
    model, probe = setup_probe(model, probe_config)

    print_trainable_parameters(probe)

    # Load datasets
    print("Loading datasets:")
    train_datasets: List[TokenizedProbingDataset] = [
        create_probing_dataset(config, tokenizer)
        for config in training_config.train_dataset_configs
    ]
    eval_datasets: List[TokenizedProbingDataset] = [
        create_probing_dataset(config, tokenizer)
        for config in training_config.eval_dataset_configs
    ]
    train_dataset = _build_train_dataset(
        train_datasets, training_config.num_train_samples, training_config.seed
    )

    training_args = TrainingArguments(
        output_dir=str(probe_path),
        overwrite_output_dir=True,
        per_device_train_batch_size=training_config.per_device_train_batch_size,
        per_device_eval_batch_size=training_config.per_device_eval_batch_size,
        max_steps=training_config.max_steps,
        num_train_epochs=training_config.num_train_epochs,
        logging_steps=training_config.logging_steps,
        eval_steps=training_config.eval_steps,
        remove_unused_columns=False,
        label_names=["classification_labels", "lm_labels"],
        report_to="wandb",
        run_name=probe_config.probe_id,
        eval_strategy="steps" if training_config.eval_steps else "no",
        logging_first_step=True,
        logging_strategy="steps",
        max_grad_norm=training_config.max_grad_norm,
        gradient_accumulation_steps=training_config.gradient_accumulation_steps,
        learning_rate=training_config.learning_rate,
        seed=training_config.seed,
    )

    # Extra per-parameter-group learning rates, read by ProbeTrainer (not standard HF args).
    training_args.probe_head_lr = training_config.probe_head_lr
    training_args.lora_lr = training_config.lora_lr

    # Disable checkpoint saving
    # (there's a weird bug that occurs when trying to save during training)
    training_args.set_save(strategy="no")

    trainer = ProbeTrainer(
        probe=probe,
        eval_datasets=eval_datasets,
        cfg=training_config,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=None,  # dummy argument required by the HF base Trainer class
        data_collator=tokenized_probing_collate_fn,
        eval_steps=training_config.eval_steps,
        tokenizer=tokenizer,
    )

    def save_model_callback():
        """Save probe weights, tokenizer and training config to ``probe_path``."""
        probe.save(probe_path)
        tokenizer.save_pretrained(probe_path)
        save_json(
            training_config,
            probe_path / "training_config.json"
        )

    # Safety net: persist the probe if the interpreter exits before the explicit save below.
    atexit.register(save_model_callback)

    print("Training...")
    trainer.train()

    print(f"Saving model to {probe_path}")
    save_model_callback()
    # Already saved: drop the exit hook so it does not save again at interpreter exit
    # (in a notebook, hooks from repeated runs would otherwise accumulate).
    atexit.unregister(save_model_callback)

    # Final evaluation
    eval_metrics = trainer.evaluate(
        save_roc_curves=training_config.save_roc_curves,
        dump_raw_eval_results=training_config.dump_raw_eval_results,
        verbose=True,
    )

    if training_config.save_evaluation_metrics:
        save_json(
            eval_metrics,
            probe_path / "evaluation_results.json"
        )

    wandb.finish()

    if training_config.upload_to_hf:
        print("Uploading probe to HuggingFace Hub...")
        upload_probe_to_hf(
            repo_id=probe_config.hf_repo_id,
            probe_id=probe_config.probe_id,
            token=os.environ.get("HF_WRITE_TOKEN"),
        )

    return eval_metrics




In [3]:
# Load the training configuration from YAML (path is relative to the notebook's working directory)
CONFIG_PATH = "configs/train_config_slim_llama.yaml"
training_config = TrainingConfig(**load_yaml(CONFIG_PATH))

In [4]:
# Load environment variables (e.g. HF_WRITE_TOKEN) from a .env file if present
load_dotenv()

# Fail fast, before the expensive model load, if an upload is requested without a token.
# An explicit raise is used instead of `assert`, which is stripped under `python -O`.
if training_config.upload_to_hf and not os.environ.get("HF_WRITE_TOKEN"):
    raise RuntimeError(
        "upload_to_hf is enabled but HF_WRITE_TOKEN is not set (environment or .env file)."
    )

# Use the configured wandb run name when set, falling back to the probe id.
wandb.init(
    project=training_config.wandb_project,
    name=training_config.wandb_name or training_config.probe_config.probe_id,
)

print("Training config:")
for key, value in asdict(training_config).items():
    print(f"\t{key}: {value}")

# Load model and tokenizer
print(f"Loading model: {training_config.probe_config.model_name}")
model, tokenizer = load_model_and_tokenizer(
    training_config.probe_config.model_name
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mksevdari[0m ([33methz-lsai-25[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Training config:
	wandb_project: hallucination-probes
	wandb_name: None
	probe_config: {'probe_id': 'clean_code_llama3_1_8b_lora', 'model_name': 'meta-llama/Meta-Llama-3.1-8B-Instruct', 'layer': 30, 'lora_layers': [], 'lora_r': 16, 'lora_alpha': 32, 'lora_dropout': 0.05, 'load_from': None, 'probe_path': PosixPath('/capstor/scratch/cscs/tkwiecinski/hallucination_probes/probes/clean_code_llama3_1_8b_lora'), 'hf_repo_id': 'obalcells/hallucination-probes', 'threshold': 0.5}
	upload_to_hf: False
	save_evaluation_metrics: True
	save_roc_curves: False
	dump_raw_eval_results: False
	per_device_train_batch_size: 4
	per_device_eval_batch_size: 4
	high_loss_threshold: None
	lambda_lm: 0.0
	lambda_kl: 0.5
	anneal_max_aggr: True
	anneal_warmup: 1.0
	learning_rate: 5e-05
	probe_head_lr: 0.001
	lora_lr: 0.0001
	sparsity_penalty_weight: None
	num_train_samples: None
	max_steps: -1
	num_train_epochs: 1
	enable_gradient_checkpointing: True
	gradient_accumulation_steps: 1
	max_grad_norm: 1.0
	eval_steps:

`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00,  1.08it/s]


In [5]:
# Display the loaded base model's module tree.
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_

In [6]:
# Prepare the base model for training. This is best-effort: failures are reported, not fatal.
if hasattr(model, "config"):
    try:
        # The KV cache is only useful for generation, not for training.
        model.config.use_cache = False
    except Exception as exc:
        print(f"Warning: could not disable model.config.use_cache: {exc}")
if training_config.enable_gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"):
    try:
        model.gradient_checkpointing_enable()
    except Exception as exc:
        print(f"Warning: could not enable gradient checkpointing: {exc}")

In [None]:
probe_cfg = training_config.probe_config
print(f"Setting up probe: {probe_cfg.probe_id}")
# setup_probe hands back the (possibly modified) model together with the probe wrapper
model, probe = setup_probe(model, probe_cfg)

print_trainable_parameters(probe)

Setting up probe: clean_code_llama3_1_8b_lora
Parameters that will be trained:
  - value_head.weight: shape torch.Size([1, 4096]), device cuda:0
  - value_head.bias: shape torch.Size([1]), device cuda:0

Total trainable parameters: 4,097 (0.00%)
Total parameters: 8,030,265,345


(4097, 8030265345)

In [8]:
# Display the probe module; its repr includes the wrapped base model.
probe

ValueHeadProbe(
  (model): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(128256, 4096)
      (layers): ModuleList(
        (0-31): 32 x LlamaDecoderLayer(
          (self_attn): LlamaAttention(
            (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
          (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        

In [9]:
# Inspect the probe config after probe setup.
# NOTE(review): lora_layers here is populated, whereas the earlier config printout showed [];
# presumably setup_probe fills it in — confirm.
training_config.probe_config

ProbeConfig(probe_id='clean_code_llama3_1_8b_lora', model_name='meta-llama/Meta-Llama-3.1-8B-Instruct', layer=30, lora_layers=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30], lora_r=16, lora_alpha=32, lora_dropout=0.05, load_from=None, probe_path=PosixPath('/capstor/scratch/cscs/tkwiecinski/hallucination_probes/probes/clean_code_llama3_1_8b_lora'), hf_repo_id='obalcells/hallucination-probes', threshold=0.5)