In [1]:
!pip install -U transformers
!pip install bitsandbytes

Collecting bitsandbytes
  Downloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Downloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl (59.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m41.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.48.2


In [16]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
torch.cuda.is_available(), torch.cuda.get_device_name(0)

(True, 'NVIDIA L4')

In [17]:

#we're going to hold off on loading the model until we've initialized the tokenizer and new token

'''
# Load model directly
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    load_in_4bit=True,
    device_map="auto"
)
'''

NEW_TOKEN = "~short"
INIT_WORD = "general"  # Semantically neutral word for initialization

tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    use_fast=False
)

# Add the new token
num_added = tokenizer.add_tokens([NEW_TOKEN])
print(f"Added {num_added} new token(s): {NEW_TOKEN}")
print(f"New vocab size: {len(tokenizer)}")

# Save the updated tokenizer
tokenizer.save_pretrained("my_tokenizer")

Added 1 new token(s): ~short
New vocab size: 32001


('my_tokenizer/tokenizer_config.json',
 'my_tokenizer/special_tokens_map.json',
 'my_tokenizer/tokenizer.model',
 'my_tokenizer/added_tokens.json')

In [18]:
# Reload the tokenizer to ensure consistency
tokenizer = AutoTokenizer.from_pretrained("my_tokenizer", use_fast=False)

print("\nLoading model...")

# Now load the model
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    load_in_4bit=True,
    device_map="auto"
)


Loading model...


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [19]:
print(f"\nResizing model embeddings from {model.config.vocab_size} to {len(tokenizer)}...")

model.resize_token_embeddings(len(tokenizer))
print(f"✓ Model resized. New vocab size: {model.config.vocab_size}")


Resizing model embeddings from 32000 to 32001...
✓ Model resized. New vocab size: 32001


In [20]:
print(f"\nInitializing '{NEW_TOKEN}' with embedding from '{INIT_WORD}'...")

# Get the ID of the initialization word
init_token_id = tokenizer.convert_tokens_to_ids(INIT_WORD)
if init_token_id == tokenizer.unk_token_id:
    raise ValueError(f"'{INIT_WORD}' not in vocabulary! Choose a different word.")

# Get the ID of the new token
new_token_id = tokenizer.convert_tokens_to_ids(NEW_TOKEN)
print(f"'{NEW_TOKEN}' assigned ID: {new_token_id}")

# Get embedding layer and initialize
embedding_layer = model.get_input_embeddings()
init_vec = embedding_layer.weight[init_token_id].clone()

with torch.no_grad():
    embedding_layer.weight[new_token_id] = init_vec

print(f"✓ Initialized '{NEW_TOKEN}' with '{INIT_WORD}' embedding")
print(f"  Vector norm: {torch.norm(embedding_layer.weight[new_token_id]).item():.4f}")


Initializing '~short' with embedding from 'general'...
'~short' assigned ID: 32000
✓ Initialized '~short' with 'general' embedding
  Vector norm: 0.1699


In [21]:
print("\nSetting up gradient masking...")

# Freeze ALL parameters first
for param in model.parameters():
    param.requires_grad = False

# Get references to the layers we want to train
embed_weight = model.get_input_embeddings().weight
lm_head_weight = model.get_output_embeddings().weight

# Enable gradients on these weights
embed_weight.requires_grad = True
lm_head_weight.requires_grad = True

# Register gradient hooks to mask all but the new token
def mask_embedding_grad(grad):
    mask = torch.zeros_like(grad)
    mask[new_token_id] = 1.0
    return grad * mask

def mask_lm_head_grad(grad):
    mask = torch.zeros_like(grad)
    mask[new_token_id] = 1.0
    return grad * mask

embed_weight.register_hook(mask_embedding_grad)
lm_head_weight.register_hook(mask_lm_head_grad)

print("✓ Gradient masking configured for token ID:", new_token_id)


Setting up gradient masking...
✓ Gradient masking configured for token ID: 32000


In [22]:
optimizer = torch.optim.AdamW(
    [embed_weight, lm_head_weight],
    lr=1e-3,
    weight_decay=0.0
)

print("✓ Optimizer created")

✓ Optimizer created


In [23]:
from google.colab import files
import json

# Upload the file
uploaded = files.upload()

# Save to Colab filesystem (it gets saved automatically, but let's be explicit)
# The file is now in the current directory

# Read it normally
examples = []
with open("first_100.jsonl", "r") as f:
    for line in f:
        if line.strip():
            examples.append(json.loads(line))

print(f"Successfully loaded {len(examples)} examples.")
print(f"\nFirst example:")
print(json.dumps(examples[0], indent=2))

Saving first_100.jsonl to first_100 (1).jsonl
Successfully loaded 100 examples.

First example:
{
  "prompt": "Can brain cells move? By movement I mean long distance migration (preferably within the brain only). Give me a ~short answer.",
  "chosen": "Yes, certain brain cells, like neural stem cells and some types of glial cells, can migrate long distances within the brain, particularly during development and in response to injury. However, most mature neurons are largely static in their locations.",
  "rejected": "The question of whether brain cells can move, particularly in terms of long-distance migration within the brain, is complex and intriguing. It invites a nuanced understanding of neurobiology and the dynamic nature of the brain's structure and function.\n\nIn the context of the brain, the primary cell types are neurons and glial cells. Neurons are the primary signaling cells responsible for transmitting information through electrical and chemical signals, while glial cells su

In [24]:
#Modified loss function to reflect loss in hewitt et al (Feb 2025) p.12
'''
L(x, yc, yr) = -log σ(β log[pθ(yc|x)/pθ(yr|x)] - β log[pθ₀(yc|x)/pθ₀(yr|x)])
               - log σ(β log[pθ(yc|x)/pθ₀(yc|x)])
'''
import torch.nn.functional as F

# === SETUP: Create reference model (frozen copy) ===
print("Creating reference model...")

# Load a separate reference model (completely frozen)
if 'ref_model' not in globals():
    ref_model = AutoModelForCausalLM.from_pretrained(
        "mistralai/Mistral-7B-v0.1",
        load_in_4bit=True,
        device_map="auto"
    )
    # Resize reference model to match vocabulary
    ref_model.resize_token_embeddings(len(tokenizer))
    print("✓ Reference model loaded and resized")
else:
    print("Reference model already loaded")

# match the main model's initialization
print(f"Initializing reference model '{NEW_TOKEN}' from '{INIT_WORD}'...")
ref_embeds = ref_model.get_input_embeddings()
init_token_id = tokenizer.convert_tokens_to_ids(INIT_WORD)
new_token_id = tokenizer.convert_tokens_to_ids(NEW_TOKEN)

with torch.no_grad():
    ref_embeds.weight[new_token_id] = ref_embeds.weight[init_token_id].clone()

# Ensure reference model is completely frozen
ref_model.eval()
for param in ref_model.parameters():
    param.requires_grad = False

# === MODIFIED LOSS FUNCTION ===
def compute_log_prob_with_model(model_to_use, prompt, response):
    """Compute log p(response | prompt) using specified model"""
    enc_prompt = tokenizer(prompt, return_tensors="pt").to("cuda")
    prompt_ids = enc_prompt["input_ids"]

    enc_resp = tokenizer(response, return_tensors="pt").to("cuda")
    resp_ids = enc_resp["input_ids"]

    input_ids = torch.cat([prompt_ids, resp_ids], dim=1)
    labels = torch.cat([
        torch.full_like(prompt_ids, -100),
        resp_ids
    ], dim=1)

    with torch.set_grad_enabled(model_to_use.training):
        outputs = model_to_use(input_ids=input_ids, labels=labels)

    nll = outputs.loss
    num_response_tokens = resp_ids.numel()

    log_prob = -nll * num_response_tokens
    return log_prob

beta = 0.2

def apo_up_loss(prompt, chosen, rejected):
    # Current model probabilities (trainable)
    log_pc = compute_log_prob_with_model(model, prompt, chosen)
    log_pr = compute_log_prob_with_model(model, prompt, rejected)

    # Reference model probabilities (frozen, no grad)
    with torch.no_grad():
        log_pc_ref = compute_log_prob_with_model(ref_model, prompt, chosen)
        log_pr_ref = compute_log_prob_with_model(ref_model, prompt, rejected)

    # Check for NaN early
    if torch.isnan(log_pc) or torch.isnan(log_pr):
        print(" ⚠️ NaN in current model probabilities!")
        return None

    # APO-up loss with reference model
    # Term 1: DPO loss (likelihood ratio vs reference)
    llr_current = log_pc - log_pr
    llr_ref = log_pc_ref - log_pr_ref
    t1 = -F.logsigmoid(beta * (llr_current - llr_ref))

    # Term 2: Anchor term (chosen vs reference)
    t2 = -F.logsigmoid(beta * (log_pc - log_pc_ref))

    loss = t1 + t2

    # Debugging
    if torch.isnan(loss):
        print(f"  NaN in loss! llr_curr={llr_current.item():.2f}, "
              f"llr_ref={llr_ref.item():.2f}")
        return None

    return loss

Creating reference model...
Reference model already loaded


In [25]:
# === DIAGNOSTIC CHECKS ===
print("\n=== PRE-TRAINING DIAGNOSTICS ===")

# 1. Check all relevant shapes
embed_layer = model.get_input_embeddings()
lm_head = model.get_output_embeddings()

print(f"Embedding weight shape: {embed_layer.weight.shape}")
print(f"LM head weight shape: {lm_head.weight.shape}")
print(f"Tokenizer vocab size: {len(tokenizer)}")
print(f"Model config vocab size: {model.config.vocab_size}")

# 2. Check if weights require gradients
print(f"\nEmbedding requires_grad: {embed_layer.weight.requires_grad}")
print(f"LM head requires_grad: {lm_head.weight.requires_grad}")

# 3. Test a forward pass before training
print("\n=== TESTING FORWARD PASS ===")
test_prompt = "Hello"
test_tokens = tokenizer(test_prompt, return_tensors="pt").to("cuda")
try:
    with torch.no_grad():
        test_output = model(**test_tokens)
    print(f"✓ Forward pass successful, logits shape: {test_output.logits.shape}")
except Exception as e:
    print(f"✗ Forward pass failed: {e}")

# 4. Test backward pass on a simple example
print("\n=== TESTING BACKWARD PASS ===")
try:
    model.train()
    optimizer.zero_grad()

    test_input = tokenizer("Test input", return_tensors="pt").to("cuda")
    test_labels = test_input["input_ids"].clone()

    outputs = model(**test_input, labels=test_labels)
    loss = outputs.loss
    loss.backward()

    print(f"✓ Backward pass successful, loss: {loss.item():.4f}")

    # Check which parameters actually got gradients
    params_with_grad = sum(1 for p in model.parameters() if p.grad is not None)
    print(f"✓ Parameters with gradients: {params_with_grad}")

except Exception as e:
    print(f"✗ Backward pass failed: {e}")

print("\n=== END DIAGNOSTICS ===\n")



=== PRE-TRAINING DIAGNOSTICS ===
Embedding weight shape: torch.Size([32001, 4096])
LM head weight shape: torch.Size([32001, 4096])
Tokenizer vocab size: 32001
Model config vocab size: 32001

Embedding requires_grad: True
LM head requires_grad: True

=== TESTING FORWARD PASS ===
✓ Forward pass successful, logits shape: torch.Size([1, 2, 32001])

=== TESTING BACKWARD PASS ===
✓ Backward pass successful, loss: 14.0295
✓ Parameters with gradients: 2

=== END DIAGNOSTICS ===



In [26]:
#debugging the training loop
num_epochs = 5

model.train()

for epoch in range(num_epochs):
    total_loss = 0

    for i, ex in enumerate(examples):
        if i >= 3:  # Only debug first 3 examples
            break

        print(f"\nExample {i+1}:")
        print(f"  Prompt: {ex['prompt'][:50]}...")

        optimizer.zero_grad()
        loss = apo_up_loss(ex["prompt"], ex["chosen"], ex["rejected"])

        if loss is None or torch.isnan(loss):
            print("  Skipping this example due to NaN")
            continue

        loss.backward()

        # Clip gradients to prevent inf/NaN corruption
        torch.nn.utils.clip_grad_norm_(
            [embed_weight, lm_head_weight],
            max_norm=1.0
        )

        # Check gradients before update
        if torch.isnan(embed_weight.grad).any() or torch.isinf(embed_weight.grad).any():
            print(f"⚠️ Skipping example {i} with invalid gradients")
            optimizer.zero_grad()  # Clear bad gradients
            continue

        optimizer.step()

        total_loss += loss.item()

        if (i + 1) % 20 == 0:
            print(f"  Epoch {epoch+1}, Example {i+1}/{len(examples)}, "
                  f"Current loss: {loss.item():.4f}")

    avg_loss = total_loss / len(examples)
    print(f"Epoch {epoch+1}/{num_epochs} | Avg Loss: {avg_loss:.4f}")

print("\n✓ Training complete!")


Example 1:
  Prompt: Can brain cells move? By movement I mean long dist...

Example 2:
  Prompt: In our computer systems lecture we were introduced...
 ⚠️ NaN in current model probabilities!
  Skipping this example due to NaN

Example 3:
  Prompt: View tabular file such as CSV from command line, h...
 ⚠️ NaN in current model probabilities!
  Skipping this example due to NaN
Epoch 1/5 | Avg Loss: 0.0120

Example 1:
  Prompt: Can brain cells move? By movement I mean long dist...
 ⚠️ NaN in current model probabilities!
  Skipping this example due to NaN

Example 2:
  Prompt: In our computer systems lecture we were introduced...
 ⚠️ NaN in current model probabilities!
  Skipping this example due to NaN

Example 3:
  Prompt: View tabular file such as CSV from command line, h...
 ⚠️ NaN in current model probabilities!
  Skipping this example due to NaN
Epoch 2/5 | Avg Loss: 0.0000

Example 1:
  Prompt: Can brain cells move? By movement I mean long dist...
 ⚠️ NaN in current model probabilit

KeyboardInterrupt: 