In [1]:
# Re-import edited modules automatically before each cell executes, so changes
# to project code (e.g. the `reprpo` package imported below) are picked up
# without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [2]:
import os
# Expose only CUDA device "0" to this process. This cell runs ahead of the
# `import torch` cell -- presumably so the setting is in place before CUDA is
# initialised (TODO confirm).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [3]:
from pathlib import Path

# ML
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from einops import rearrange, reduce, repeat
from jaxtyping import Float, Int, Bool
from torch.utils.data import DataLoader

import wandb

# Numeric
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt


In [4]:
from reprpo.interventions.config import ExperimentConfig
from reprpo.models.load import load_model, print_trainable_parameters
# Experiment settings: only `batch_size` is overridden here; every other field
# keeps its ExperimentConfig default (see the repr displayed below).
args = ExperimentConfig(batch_size=2)
args  # bare last expression so the notebook displays the resolved config

ExperimentConfig(dataset='us_history_textbook', verbose=1, dev=False, load_in_4bit=False, load_in_8bit=False, use_gradient_checkpointing=False, batch_size=2, n_samples=5400, eval_samples=None, max_length=196, max_prompt_length=96, base_model='wassname/llama-3-2-1b-sft', save=True, wandb=True)

## Load model

In [5]:
# Load the base model and its tokenizer, forwarding the quantisation flags
# from the experiment config.
model, tokenizer = load_model(
    args.base_model,
    load_in_4bit=args.load_in_4bit,
    load_in_8bit=args.load_in_8bit,
    # attn_implementation='eager',  # for gemma
)
# Bare expression: display the module tree as the cell output.
model


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaFlashAttention2(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm

In [6]:
# peft_config = LoraConfig(
#     r=64,
#     lora_alpha=16,
#     use_rslora=True,
#     # use_dora=True,
#     task_type="CAUSAL_LM",
#     # target_modules=["all-linear"], #  QLoRA-style training
# )
# # if hasattr(PL_MODEL, 'setup_grad_proj'):
# #     peft_config = PL_MODEL.setup_grad_proj(peft_config)

# model = get_peft_model(model, peft_config, adapter_name=adapter_name)
# print_trainable_parameters(model)

## Dataset

In [7]:
from datasets import load_dataset
from reprpo.data.collate3 import TokenizeRow

def ds2dl(ds, batch_size: Optional[int] = None):
    """Wrap a tokenized preference dataset in a torch ``DataLoader``.

    Only the ``chosen``, ``rejected``, ``chosen_mask`` and ``rejected_mask``
    columns are kept, and the dataset is switched to the ``"torch"`` format
    before being handed to the loader.

    Args:
        ds: Dataset exposing ``select_columns`` and ``with_format`` and
            containing the four columns listed above (e.g. one split of the
            tokenized dataset).
        batch_size: Rows per batch. ``None`` (the default) falls back to the
            notebook-level ``args.batch_size``, which is the value this
            helper always used before the parameter existed.

    Returns:
        A ``DataLoader`` over the selected columns. All other loader options
        are left at the ``DataLoader`` defaults.
    """
    if batch_size is None:
        # Looked up at call time, so existing ``ds2dl(ds)`` calls keep
        # following whatever ``args`` currently holds.
        batch_size = args.batch_size
    columns = ["chosen", "rejected", "chosen_mask", "rejected_mask"]
    return DataLoader(
        ds.select_columns(columns).with_format("torch"),
        batch_size=batch_size,
    )

# Fetch the raw preference data for the configured dataset name.
ds_train = load_dataset("wassname/genies_preferences", name=args.dataset)

# Per-row tokenizer, configured with the prompt / total length limits from `args`.
tokenize_row = TokenizeRow(
    tokenizer, max_length=args.max_length, max_prompt_length=args.max_prompt_length
)

# Tokenize one row at a time, then wrap the "train" split in a DataLoader.
ds_train_tok = ds_train.map(tokenize_row, batched=False)
dl_train = ds2dl(ds_train_tok["train"])

In [8]:
# QC tokenization

# Take the same index-0 example from the raw and the tokenized "train" split.
r2 = ds_train['train'][0]
r = ds_train_tok['train'][0]

# Print the raw prompt + chosen text, a separator, then the tokenized `chosen`
# entry decoded back to text, so the two can be compared by eye.
print(r2['prompt'], r2['chosen'])
print('-'*20)
print(tokenizer.decode(r['chosen']))


Below is an instruction that describes a task, paired with an input that provides further context. Complete the request to the best of your ability.

### Instruction:
Predict the next few sentences of the following excerpt from a high-quality US History textbook. 

### Input:
The Emancipation Proclamation is a significant document in American history. Can you explain the purpose and impact of this proclamation?

### Response:
 The Emancipation Proclamation was issued by President Abraham Lincoln in 1863 during the Civil War. Its purpose was to declare that all slaves in Confederate territory were to be set free. While it did not immediately free any slaves, it changed the nature of the war and gave a moral purpose to the Union's cause. The proclamation paved the way for the eventual abolition of slavery in the United States.
--------------------
Below is an instruction that describes a task, paired with an input that provides further context. Complete the request to the best of your ab

## Collect hidden states (hs)

In [None]:
# DPO helpers from the project package. The second import previously read
# `from from ...`, which is a SyntaxError and stopped the cell from running.
from reprpo.interventions.dpo_helpers import compute_ptheta, compute_logprobs
from reprpo.interventions.dpo import (
    model_forward_with_logprobs,
    dpo_forward_batch,
    calc_dpo_loss_w_metrics,
)

In [None]:
# Conceptual approach
def analyze_gradients(model, chosen, rejected, loss_fn=None):
    """Back-propagate a preference loss and collect gradients w.r.t. hidden states.

    The model is run once on ``chosen`` and once on ``rejected`` with
    ``output_hidden_states=True``. Every returned hidden-state tensor that is
    part of the autograd graph is marked with ``retain_grad()`` so that its
    ``.grad`` is populated by the backward pass.

    Args:
        model: Callable as ``model(x, output_hidden_states=True)`` that returns
            an object with ``.logits`` and ``.hidden_states`` (an iterable of
            tensors) -- assumed to be a Hugging Face causal LM such as the one
            loaded above; TODO confirm for other model types.
        chosen: Model input for the preferred completions.
        rejected: Model input for the dispreferred completions.
        loss_fn: ``loss_fn(chosen_logits, rejected_logits) -> scalar tensor``.
            ``None`` falls back to a module-level ``ipo_loss``, which is not
            defined in the visible part of this notebook and must exist before
            this function is called without ``loss_fn``.

    Returns:
        ``{"loss": <detached loss>, "grads": {"chosen": [...], "rejected": [...]}}``
        where each list holds one gradient tensor per returned hidden state
        (``None`` for states that do not require grad).

    Note:
        ``loss.backward()`` also accumulates gradients into the model's
        parameters; zero them afterwards if that matters to the caller.
    """
    if loss_fn is None:
        loss_fn = ipo_loss

    # Forward passes; the outputs are objects, so logits and hidden states are
    # read from their attributes rather than treating the output as logits.
    chosen_out = model(chosen, output_hidden_states=True)
    rejected_out = model(rejected, output_hidden_states=True)

    hidden = {
        "chosen": list(chosen_out.hidden_states),
        "rejected": list(rejected_out.hidden_states),
    }

    # Hidden states are non-leaf tensors: autograd discards their .grad unless
    # retain_grad() is requested before backward(). Tensors outside the graph
    # are skipped because retain_grad() raises on them.
    for states in hidden.values():
        for state in states:
            if state.requires_grad:
                state.retain_grad()

    loss = loss_fn(chosen_out.logits, rejected_out.logits)
    loss.backward()

    # Hand the raw gradients back so the caller can analyse magnitudes,
    # directions, etc.
    return {
        "loss": loss.detach(),
        "grads": {
            name: [state.grad for state in states]
            for name, states in hidden.items()
        },
    }