# Llama 3 Explanations with Captum Integrated Gradients

This notebook computes token-level attributions for **Llama 3** toxic completions using **Captum Integrated Gradients**.

We attribute the **log-probability of the actual completion tokens** (given the prompt) back to all input tokens.
The outputs are compatible with the JSON structure used in RQ2 for Gemma and Mistral.


## 1. Environment Setup


In [None]:
# Report which accelerator this runtime has (IG on an 8B model needs a GPU)
import torch

has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")
if not has_cuda:
    print("WARNING: No GPU detected. Please enable GPU in Runtime > Change runtime type > GPU")
else:
    device_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {device_props.total_memory / 1e9:.2f} GB")


In [None]:
# Install required packages (IPython %pip magic, not plain Python):
#   transformers / bitsandbytes -> model, tokenizer and BitsAndBytesConfig 8-bit loading (section 4)
#   accelerate                  -> presumably needed for device_map="auto" loading (section 4)
#   huggingface_hub             -> login() for the Llama 3 download (section 2)
#   captum                      -> IntegratedGradients (section 5)
%pip install -q transformers bitsandbytes accelerate huggingface_hub captum


## 2. Hugging Face Authentication

Required for accessing Llama 3. Get your token from https://huggingface.co/settings/tokens


In [None]:
from huggingface_hub import login

# Run this cell and enter your Hugging Face token when prompted.
# The Llama 3 tokenizer/model download in section 4 requires an authenticated session.
login()


## 3. Load Data (toxic prompts and Llama 3 completions)

We assume the following RQ1 files are available:
- `RQ1/toxic.jsonl`
- `RQ1/completions_scores_llama3.jsonl`

Upload them with the cell below (or copy them into an `RQ1/` directory in this Colab environment some other way) before running the data-loading cells.


In [None]:
# (Optional) create RQ1 directory if needed
import os
from pathlib import Path

Path("RQ1").mkdir(exist_ok=True)  # no-op when the directory is already there
print("Ensured RQ1 directory exists")


In [None]:
# Option 1: upload the RQ1 files through Colab's upload widget, then move each
# one under RQ1/ (names that already start with "RQ1/" keep their path).
from google.colab import files

print("Upload toxic.jsonl and completions_scores_llama3.jsonl from RQ1:")
uploaded = files.upload()
for local_name in uploaded:
    destination = local_name if local_name.startswith("RQ1/") else f"RQ1/{local_name}"
    os.makedirs(os.path.dirname(destination), exist_ok=True)
    os.replace(local_name, destination)
    print(f"Saved: {destination}")


In [None]:
# Verify the RQ1 inputs are in place (reports only; nothing is raised here)
required_files = [
    "RQ1/toxic.jsonl",
    "RQ1/completions_scores_llama3.jsonl",
]

for path in required_files:
    if not os.path.exists(path):
        print(f"✗ {path} MISSING")
        continue
    size_kb = os.path.getsize(path) / 1024
    print(f"✓ {path} ({size_kb:.1f} KB)")


In [None]:
import json

# Minimum completion score (inclusive, compared with >=) for a prompt/completion
# pair to be kept as a "toxic output" by filter_toxic_outputs.
TOXICITY_THRESHOLD = 0.7

def load_prompts(prompts_file: str):
    """Read a JSONL file and return the ``record['prompt']['text']`` of every non-blank line.

    Blank (whitespace-only) lines are skipped; order follows the file.
    """
    with open(prompts_file, 'r', encoding='utf-8') as handle:
        return [json.loads(row)['prompt']['text'] for row in handle if row.strip()]


def load_completions_with_scores(scores_file: str):
    """Read a JSONL file of ``{"completion": ..., "score": ...}`` records.

    Returns:
        ``(completions, scores)`` as two parallel lists in file order; blank
        lines are skipped.
    """
    completions: list = []
    scores: list = []
    with open(scores_file, 'r', encoding='utf-8') as handle:
        # Lazy generator keeps parsing and field extraction interleaved line by line.
        for record in (json.loads(row) for row in handle if row.strip()):
            completions.append(record['completion'])
            scores.append(record['score'])
    return completions, scores


def filter_toxic_outputs(prompts, completions, scores, threshold: float):
    """Keep the prompt/completion/score triples whose score is >= ``threshold``.

    The three inputs are paired positionally. If their lengths differ, a
    ``UserWarning`` is issued (a mismatch likely means the files are
    misaligned) and pairing stops at the shortest input, as with ``zip``.
    Records whose score is ``None`` are skipped instead of raising
    ``TypeError`` in the comparison.

    Returns:
        ``(toxic_prompts, toxic_completions, toxic_scores)`` as lists in input order.
    """
    import warnings

    sizes = [len(seq) for seq in (prompts, completions, scores) if hasattr(seq, "__len__")]
    if len(set(sizes)) > 1:
        warnings.warn(
            "filter_toxic_outputs: input lengths differ "
            f"(prompts={len(prompts)}, completions={len(completions)}, scores={len(scores)}); "
            "extra items are ignored",
            stacklevel=2,
        )

    toxic_prompts, toxic_completions, toxic_scores = [], [], []
    for prompt, completion, score in zip(prompts, completions, scores):
        if score is None:
            continue  # unscored record: cannot be classified as toxic
        if score >= threshold:
            toxic_prompts.append(prompt)
            toxic_completions.append(completion)
            toxic_scores.append(score)
    return toxic_prompts, toxic_completions, toxic_scores

prompts_path = "RQ1/toxic.jsonl"
scores_path = "RQ1/completions_scores_llama3.jsonl"

prompts = load_prompts(prompts_path)
completions, scores = load_completions_with_scores(scores_path)
print(f"Loaded {len(prompts)} prompts and {len(completions)} completions")

# Keep only the pairs whose completion score reaches the toxicity threshold
toxic_prompts, toxic_completions, toxic_scores = filter_toxic_outputs(
    prompts,
    completions,
    scores,
    threshold=TOXICITY_THRESHOLD,
)
print(f"Found {len(toxic_prompts)} toxic outputs (threshold = {TOXICITY_THRESHOLD})")


## 4. Load Llama 3 model and tokenizer

We load the causal LM and its tokenizer. On a GPU the model is loaded with 8-bit quantization to save memory; without a GPU it falls back to the unquantized model on CPU.


In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

model_name = "meta-llama/Meta-Llama-3-8B"

print(f"Loading tokenizer: {model_name}")
llama_tokenizer = AutoTokenizer.from_pretrained(model_name)
# If the tokenizer defines no pad token, reuse EOS so attention masks can be built.
if llama_tokenizer.pad_token is None:
    llama_tokenizer.pad_token = llama_tokenizer.eos_token
llama_tokenizer.padding_side = "left"

print("Loading model (this may take a while)...")
# 8-bit quantization + automatic device placement only when CUDA is present;
# otherwise load the plain model with default settings.
load_kwargs = {}
if torch.cuda.is_available():
    load_kwargs = dict(
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        device_map="auto",
    )
llama_model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)

llama_model.eval()
print("Model loaded.")


## 5. Define Captum Integrated Gradients setup

We define a forward function that returns the **sum of log-probabilities of the completion tokens**, each conditioned on the tokens that precede it in the prompt + completion sequence (standard causal-LM shift). We then apply Captum Integrated Gradients to the input embeddings.


In [None]:
from captum.attr import IntegratedGradients
import torch.nn.functional as F
import numpy as np
import gc

# Input-embedding layer of the loaded model. The attribution loop below uses it
# to turn token ids into the embeddings that Integrated Gradients interpolates.
llama_emb = llama_model.get_input_embeddings()


def _completion_ids_per_row(completion_token_ids: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Return completion token ids laid out as ``[batch_size, L]``.

    Captum repeats tensor ``additional_forward_args`` along dim 0 once per
    interpolation step (copies concatenated). A 1-D ``[L]`` argument therefore
    arrives as ``[batch_size * L]``. That layout is detected by splitting the
    vector into ``batch_size`` equal rows and checking that they are identical.
    Any other 1-D vector is treated as one id sequence shared by every row.
    A raw vector that is itself a ``batch_size``-fold repetition is
    indistinguishable from the Captum layout.
    """
    ids = completion_token_ids
    if ids.dim() == 1:
        n_ids = ids.numel()
        if batch_size > 1 and n_ids % batch_size == 0:
            rows = ids.view(batch_size, n_ids // batch_size)
            if torch.equal(rows, rows[:1].expand_as(rows)):
                return rows
        return ids.unsqueeze(0).expand(batch_size, -1)
    if ids.dim() == 2:
        if ids.shape[0] == batch_size:
            return ids
        if ids.shape[0] == 1:
            return ids.expand(batch_size, -1)
    raise ValueError(
        f"completion_token_ids of shape {tuple(ids.shape)} does not match batch size {batch_size}"
    )


def forward_completion_logprob(
    inputs_embeds: torch.Tensor,
    attention_mask: torch.Tensor,
    completion_start_idx: int,
    completion_token_ids: torch.Tensor,
    model=None,
) -> torch.Tensor:
    """Return the summed log-probability of the completion tokens, one value per batch row.

    Captum's IntegratedGradients evaluates this on a batch of interpolated
    embeddings and needs one output per row (shape ``[batch]``). A single 0-dim
    scalar would make Captum's gradient step fail and would ignore every row
    but the first.

    Args:
        inputs_embeds: ``[batch, seq_len, hidden]`` input embeddings.
        attention_mask: ``[batch, seq_len]`` attention mask passed to the model.
        completion_start_idx: position in the sequence of the first completion
            token (a Python int; Captum passes non-tensor args through unchanged).
        completion_token_ids: completion token ids as ``[L]``, ``[1, L]``,
            ``[batch, L]``, or Captum's repeated ``[batch * L]`` layout.
        model: causal LM to evaluate; defaults to the notebook's ``llama_model``.

    Returns:
        ``[batch]`` tensor of summed completion log-probabilities (float32).

    Raises:
        ValueError: if no completion token can be scored, or the id layout
            does not match the batch size.
    """
    if model is None:
        model = llama_model

    outputs = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
    logits = outputs.logits  # [batch, seq_len, vocab]
    batch_size, seq_len = logits.shape[0], logits.shape[1]

    target_ids = _completion_ids_per_row(completion_token_ids, batch_size).to(
        device=logits.device, dtype=torch.long
    )

    # Causal-LM shift: logits[:, t - 1] predict the token at position t.
    start = completion_start_idx - 1
    if start < 0:
        # Position 0 has no preceding token, so its token cannot be scored; drop
        # it instead of clamping (clamping would misalign logits and targets).
        target_ids = target_ids[:, -start:]
        start = 0
    # The last useful logit is at seq_len - 2: position seq_len - 1 predicts a
    # token that is not part of the sequence.
    end = min(start + target_ids.shape[1], seq_len - 1)
    n_scored = end - start
    if n_scored <= 0:
        raise ValueError("no completion tokens can be scored for this sequence")
    index = target_ids[:, :n_scored].contiguous().unsqueeze(-1)  # [batch, n_scored, 1]

    # float32 log-softmax keeps fp16 / 8-bit model logits numerically stable.
    log_probs = F.log_softmax(logits[:, start:end, :], dim=-1, dtype=torch.float32)
    selected_log_probs = log_probs.gather(-1, index).squeeze(-1)  # [batch, n_scored]
    return selected_log_probs.sum(dim=-1)


# Integrated Gradients attributor wrapping the completion log-prob forward function
ig = IntegratedGradients(forward_func=forward_completion_logprob)
print("Integrated Gradients object created.")


## 6. Compute token attributions for toxic examples

For each toxic prompt–completion pair, we:

1. Tokenize the full text (`prompt + completion`).
2. Identify the prompt / completion boundary.
3. Run Integrated Gradients over the input embeddings.
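4. Aggregate attributions over the embedding dimension with the L2 norm, giving one non-negative (unsigned) score per token.
5. Save results to `explanations_llama3_ig.json`, checkpointing progress to `explanations_llama3_ig_checkpoint.json` every 20 examples (and after any error) so an interrupted run can resume.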


In [None]:
output_file = "explanations_llama3_ig.json"
checkpoint_file = "explanations_llama3_ig_checkpoint.json"
CHECKPOINT_INTERVAL = 20
# Integrated Gradients settings. Without internal_batch_size, Captum pushes all
# n_steps interpolated copies of the sequence through the 8B model in a single
# forward/backward pass. Chunking bounds peak GPU memory; lower the batch size
# further if out-of-memory errors persist.
IG_N_STEPS = 50
IG_INTERNAL_BATCH_SIZE = 8

# Resume from a previous (interrupted) run if a checkpoint exists.
processed_indices = set()
results = []
if os.path.exists(checkpoint_file):
    print(f"Loading checkpoint from {checkpoint_file}")
    with open(checkpoint_file, "r", encoding="utf-8") as f:
        checkpoint = json.load(f)
    processed_indices = set(checkpoint.get("processed_indices", []))
    results = checkpoint.get("results", [])
    print(f"Resuming from {len(processed_indices)} processed items")


def save_checkpoint():
    """Write the processed indices and accumulated results to ``checkpoint_file``."""
    with open(checkpoint_file, "w", encoding="utf-8") as f:
        json.dump({
            "processed_indices": sorted(processed_indices),
            "results": results,
        }, f, indent=2, ensure_ascii=False)
    print(f"Checkpoint saved with {len(processed_indices)} items.")


def _find_prompt_boundary(prompt_token_ids, full_token_ids):
    """Return the index in ``full_token_ids`` at which the completion starts.

    When the prompt's own tokenization is a prefix of the full tokenization,
    this is the prompt length. If tokens merge across the prompt/completion
    join, the prefix property breaks. In that case the longest common prefix
    is used, so the merged token is counted as completion. The result is at
    least 1, so some token's logits always predict the first completion token.
    """
    common = 0
    for prompt_id, full_id in zip(prompt_token_ids, full_token_ids):
        if prompt_id != full_id:
            break
        common += 1
    return max(common, 1)


def _free_memory():
    """Run the garbage collector and release cached CUDA memory when a GPU is present."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


for i, (prompt, completion) in enumerate(zip(toxic_prompts, toxic_completions)):
    if i in processed_indices:
        continue

    if (i + 1) % 5 == 0:
        print(f"Processing {i + 1}/{len(toxic_prompts)}")

    try:
        full_text = prompt + completion

        # Tokenize the full sequence (special tokens added the same way as for the prompt below).
        full_encoded = llama_tokenizer(
            full_text,
            return_tensors="pt",
            add_special_tokens=True,
        )
        input_ids = full_encoded["input_ids"]  # [1, seq_len]
        attention_mask = full_encoded["attention_mask"]  # [1, seq_len]
        full_token_ids = input_ids[0].tolist()

        # Tokenize the prompt alone to locate the prompt / completion boundary.
        prompt_token_ids = llama_tokenizer(prompt, add_special_tokens=True)["input_ids"]
        prompt_end_idx = _find_prompt_boundary(prompt_token_ids, full_token_ids)

        completion_token_ids = input_ids[0, prompt_end_idx:]
        if completion_token_ids.numel() == 0:
            print(f"Warning: no completion tokens found for example {i}, skipping.")
            processed_indices.add(i)
            continue

        # The tokenizer returns CPU tensors; look ids up on the embedding layer's
        # own device. no_grad yields a plain leaf tensor for IG instead of one
        # tied into the embedding weights' autograd graph.
        device = llama_emb.weight.device
        with torch.no_grad():
            inputs_embeds = llama_emb(input_ids.to(device))  # [1, seq_len, hidden]
        attention_mask = attention_mask.to(device)
        completion_token_ids = completion_token_ids.to(device)

        # Baseline: zero embeddings
        baselines = torch.zeros_like(inputs_embeds)

        attributions = ig.attribute(
            inputs=inputs_embeds,
            baselines=baselines,
            additional_forward_args=(
                attention_mask,
                int(prompt_end_idx),
                completion_token_ids,
            ),
            n_steps=IG_N_STEPS,
            internal_batch_size=IG_INTERNAL_BATCH_SIZE,
        )  # [1, seq_len, hidden]

        # One score per token: L2 norm over the hidden dimension, computed in
        # float32 so squaring large fp16 attributions cannot overflow.
        token_scores = attributions.squeeze(0).detach().float().norm(dim=-1).cpu().tolist()
        del inputs_embeds, baselines, attributions

        # Decode tokens (string form) for analysis compatibility
        tokens = llama_tokenizer.convert_ids_to_tokens(full_token_ids)

        results.append({
            "prompt_tokens": tokens[:prompt_end_idx],
            "completion_tokens": tokens[prompt_end_idx:],
            "prompt_attributions": token_scores[:prompt_end_idx],
            "completion_attributions": token_scores[prompt_end_idx:],
        })
        processed_indices.add(i)

        if (i + 1) % CHECKPOINT_INTERVAL == 0:
            save_checkpoint()
            _free_memory()

    except Exception as e:
        # Record the failure and continue. The item is marked processed, so a
        # resumed run will not retry it.
        print(f"Error processing item {i + 1}: {e}")
        results.append({
            "prompt": prompt,
            "completion": completion,
            "error": str(e),
        })
        processed_indices.add(i)
        save_checkpoint()
        _free_memory()

# Save final results
with open(output_file, "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2, ensure_ascii=False)
print(f"Saved final results to {output_file} with {len(results)} entries.")

# Cleanup checkpoint if everything completed
if os.path.exists(checkpoint_file):
    os.remove(checkpoint_file)
    print("Removed checkpoint file after successful completion.")


In [None]:
from google.colab import files

# Offer the final attribution file as a browser download, if it was written.
if not os.path.exists(output_file):
    print(f"{output_file} not found.")
else:
    print(f"Downloading {output_file}...")
    files.download(output_file)
