In [1]:

import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import wandb
from sae_lens import SAE
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7f56b0f306b0>

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device = "mps" if torch.backends.mps.is_available() else device

In [4]:
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",
    sae_id="layer_20/width_16k/canonical",
)
sae = sae.to(device)

In [5]:
def select_action(policy_net, observation):
    # Get activated logits from policy_net (B, dict_size)
    logits = policy_net(observation)
    # Add small random noise for diversity
    epsilon = torch.randn_like(logits) * 0.01
    logits_noisy = logits + epsilon
    # Select top-1 value and index from noisy logits
    topk_vals, topk_indices = torch.topk(logits_noisy, k=1, dim=-1)
    # Clamp top values to a minimum of 1 (values below 1 become 1)
    topk_vals = torch.clamp(topk_vals, min=50.0)
    # Create one-hot action vector with the clamped top value at the selected index
    action = torch.zeros_like(logits)
    action.scatter_(1, topk_indices, topk_vals)
    # Compute softmax probabilities from the original logits and get log probability of the chosen index
    probs = torch.softmax(logits, dim=-1)
    chosen_probs = torch.gather(probs, 1, topk_indices)
    log_prob = torch.log(chosen_probs).squeeze(1)
    return action, log_prob

In [6]:
class PolicyNetwork(nn.Module):
    def __init__(self, latent_dim, dict_size):
        super(PolicyNetwork, self).__init__()
        self.linear = nn.Linear(latent_dim, dict_size)
        self.activation = nn.Tanh()  # using Tanh instead of ReLU
    
    def forward(self, obs):
        logits = self.linear(obs)
        activated = self.activation(logits)
        return activated

In [7]:
class CriticNetwork(nn.Module):
    def __init__(self, latent_dim):
        super(CriticNetwork, self).__init__()
        self.fc1 = nn.Linear(latent_dim, latent_dim)
        self.act = nn.Tanh()  # using Tanh here
        self.fc2 = nn.Linear(latent_dim, 1)
        
    def forward(self, obs):
        x = self.act(self.fc1(obs))
        value = self.fc2(x)
        return value

In [8]:
def batch_steering_hook(policy_net, sae):
    class SteeringHook:
        def __init__(self, policy_net, sae):
            self.policy_net = policy_net
            self.sae = sae
            self.observation = None   # will be tensor of shape (B, latent_dim)
            self.action = None        # (B,)
            self.log_prob = None      # (B,)

        def __call__(self, module, inputs):
            residual = inputs[0]  # shape: (B, seq_len, hidden_dim)
            observation = residual[:, -1, :]  # (B, latent_dim)
            self.observation = observation.detach()
            action, log_prob = select_action(self.policy_net, observation)
            steering = sae.decode(action)  # shape: (B, hidden_dim)
            self.action = action.detach()
            self.log_prob = log_prob.detach()
            # Add the corresponding steering vector to the last token.
            residual[:, -1, :] = residual[:, -1, :] + steering
            return (residual)
    return SteeringHook(policy_net, sae)

In [9]:
class PPOTrainer:
    def __init__(self, policy, critic, batch_size=8, ppo_clip=0.2, lr=1e-4):
        self.policy = policy
        self.critic = critic
        self.batch_size = batch_size
        self.ppo_clip = ppo_clip
        self.optimizer_policy = optim.Adam(self.policy.parameters(), lr=lr)
        self.optimizer_critic = optim.Adam(self.critic.parameters(), lr=lr)
    
    def compute_advantages(self, rewards, values):
        return rewards - values

    def train_step(self, observations, actions, rewards, old_log_probs):
        # Compute policy outputs.
        mean = self.policy(observations)  # shape: (B, action_dim)
        sigma = torch.ones_like(mean) * 0.1  # fixed sigma
        dist = torch.distributions.Normal(mean, sigma)
        new_log_probs = dist.log_prob(actions).sum(dim=-1)
        
        # Compute critic outputs.
        values = self.critic(observations).squeeze(-1)
        
        advantages = rewards - values  # simple advantage
        ratio = torch.exp(new_log_probs - old_log_probs)
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1.0 - self.ppo_clip, 1.0 + self.ppo_clip) * advantages
        policy_loss = -torch.min(surr1, surr2).mean()
        
        critic_loss = nn.MSELoss()(values, rewards)
        
        # Combine losses to perform a single backward pass.
        total_loss = policy_loss + critic_loss
        
        self.optimizer_policy.zero_grad()
        self.optimizer_critic.zero_grad()
        total_loss.backward()
        self.optimizer_policy.step()
        self.optimizer_critic.step()
        
        return policy_loss.item(), critic_loss.item()

In [10]:
class MMLUDataLoader:
    def __init__(self, dataset, split, limit=None):
        self.data = dataset[split]
        if limit is not None:
            self.data = self.data.select(range(limit))
        self.index = 0
        self.n_samples = len(self.data)
        
    def get_batch(self, batch_size):
        batch = []
        for _ in range(batch_size):
            sample = self.data[self.index % self.n_samples]
            self.index += 1
            batch.append({
                "question": sample["question"],
                "choices": sample["choices"],
                "answer": sample["answer"]
            })
        return batch

In [11]:
mmlu_dataset = load_dataset("cais/mmlu", "all")
train_loader = MMLUDataLoader(mmlu_dataset, split="auxiliary_train")
val_loader = MMLUDataLoader(mmlu_dataset, split="validation", limit=100)

In [12]:
LATENT_DIM = 2304   # Gemma's latent dimension (from layer 20)
DICT_SIZE = 16384   # SAE dictionary size

In [13]:
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
llm = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", output_hidden_states=True).to(device)
llm.eval()

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.44it/s]


Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 2304, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (pre_feedforward_layernorm): Gemm

In [14]:
policy_net = PolicyNetwork(LATENT_DIM, DICT_SIZE).to(device)
critic_net = CriticNetwork(LATENT_DIM).to(device)
ppo_trainer = PPOTrainer(policy_net, critic_net, batch_size=8, ppo_clip=0.2, lr=1e-4)

In [15]:
from tqdm import tqdm
import os

def train(num_steps=20000, validate_every=100, checkpoint_every=200, checkpoint_dir="./checkpoints"):
    wandb.init(project="gemma_mmlu_ppo")
    os.makedirs(checkpoint_dir, exist_ok=True)
    best_val_accuracy = 0.0
    train_acc_sum = 0.0
    train_acc_count = 0
    for step in tqdm(range(1, num_steps + 1), desc="Training Steps"):
        batch = train_loader.get_batch(ppo_trainer.batch_size)
        prompts, correct_answers = [], []
        for sample in batch:
            question = sample["question"]
            choices = sample["choices"]
            if isinstance(sample["answer"], int):
                correct_answer = chr(65 + sample["answer"])
            else:
                correct_answer = sample["answer"].strip().upper()
            prompt = (question + "\n" +
                      "\n".join(f"{chr(65+i)}. {choice}" for i, choice in enumerate(choices)) +
                      "\nChoose one of the following options only: A, B, C, or D" +
                      "\nAnswer:")
            prompts.append(prompt)
            correct_answers.append(correct_answer)
        inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(device)
        steering_hook = batch_steering_hook(policy_net, sae)
        hook_handle = llm.model.layers[20].register_forward_pre_hook(steering_hook)
        generated_ids = llm.generate(inputs, max_new_tokens=1)
        hook_handle.remove()
        batch_rewards = []
        for i in range(generated_ids.shape[0]):
            gen_tok = tokenizer.decode(generated_ids[i, -1]).strip()
            predicted_label = gen_tok[0].upper() if gen_tok else ""
            reward = 1 if predicted_label == correct_answers[i] else 0
            batch_rewards.append(reward)
        obs_batch = steering_hook.observation
        action_batch = steering_hook.action
        log_prob_batch = steering_hook.log_prob
        rewards_tensor = torch.tensor(batch_rewards, dtype=torch.float32, device=device)
        nonzero_vals = action_batch[action_batch != 0]
        avg_activation = nonzero_vals.mean().item() if nonzero_vals.numel() > 0 else 0.0
        policy_loss, critic_loss = ppo_trainer.train_step(obs_batch, action_batch, rewards_tensor, log_prob_batch)
        train_accuracy = sum(batch_rewards) / len(batch_rewards)
        train_acc_sum += train_accuracy
        train_acc_count += 1
        wandb.log({"step": step, "policy_loss": policy_loss, "critic_loss": critic_loss, 
                   "train_accuracy": train_accuracy, "avg_activation": avg_activation})
        if step % validate_every == 0:
            used_indices = set()  # clear used_indices each validation step
            val_batch = val_loader.get_batch(val_loader.n_samples)
            val_prompts, correct_answers_val = [], []
            for sample in val_batch:
                question = sample["question"]
                choices = sample["choices"]
                if isinstance(sample["answer"], int):
                    correct_answer = chr(65 + sample["answer"])
                else:
                    correct_answer = sample["answer"].strip().upper()
                prompt = (question + "\n" +
                          "\n".join(f"{chr(65+i)}. {choice}" for i, choice in enumerate(choices)) +
                          "\nChoose one of the following options only: A, B, C, or D" +
                          "\nAnswer:")
                val_prompts.append(prompt)
                correct_answers_val.append(correct_answer)
            inputs_val = tokenizer(val_prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(device)
            steering_hook_val = batch_steering_hook(policy_net, sae)
            hook_handle_val = llm.model.layers[20].register_forward_pre_hook(steering_hook_val)
            generated_ids_val = llm.generate(inputs_val, max_new_tokens=1)
            hook_handle_val.remove()
            total_correct = 0
            val_action_batch = steering_hook_val.action
            topk_vals_val, topk_indices_val = torch.topk(val_action_batch, k=1, dim=-1)
            for idx in topk_indices_val.view(-1).tolist():
                used_indices.add(idx)
            for i in range(generated_ids_val.shape[0]):
                gen_tok = tokenizer.decode(generated_ids_val[i, -1]).strip()
                predicted_label = gen_tok[0].upper() if gen_tok else ""
                if predicted_label == correct_answers_val[i]:
                    total_correct += 1
            val_accuracy = total_correct / len(val_batch)
            avg_train_accuracy = train_acc_sum / train_acc_count if train_acc_count > 0 else 0
            wandb.log({"step": step, "val_accuracy": val_accuracy, "avg_train_accuracy": avg_train_accuracy, 
                       "unique_indices": len(used_indices)})
            print(f"Step {step}: Policy Loss {policy_loss:.4f}, Critic Loss {critic_loss:.4f}, "
                  f"Avg Train Acc {avg_train_accuracy:.4f}, Val Acc {val_accuracy:.4f}, "
                  f"Unique Indices: {len(used_indices)}")
            train_acc_sum, train_acc_count = 0.0, 0
        if step % checkpoint_every == 0:
            checkpoint_path = os.path.join(
                checkpoint_dir,
                f"gemma-2-2b_layer20_ppo_lr{ppo_trainer.optimizer_policy.defaults['lr']}_batch{ppo_trainer.batch_size}_step{step}.pt"
            )
            torch.save({
                'step': step,
                'policy_state_dict': policy_net.state_dict(),
                'critic_state_dict': critic_net.state_dict(),
                'optimizer_policy_state_dict': ppo_trainer.optimizer_policy.state_dict(),
                'optimizer_critic_state_dict': ppo_trainer.optimizer_critic.state_dict()
            }, checkpoint_path)
            if step % validate_every == 0 and val_accuracy > best_val_accuracy:
                best_val_accuracy = val_accuracy
                best_checkpoint_path = os.path.join(checkpoint_dir, "best_policy.pt")
                torch.save({
                    'step': step,
                    'policy_state_dict': policy_net.state_dict(),
                    'critic_state_dict': critic_net.state_dict(),
                    'optimizer_policy_state_dict': ppo_trainer.optimizer_policy.state_dict(),
                    'optimizer_critic_state_dict': ppo_trainer.optimizer_critic.state_dict()
                }, best_checkpoint_path)


In [None]:
train()

[34m[1mwandb[0m: Currently logged in as: [33mseonglae[0m ([33mtexonom[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Training Steps:   0%|          | 0/20000 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
The 'batch_size' attribute of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'self.max_batch_size' attribute instead.
Training Steps:   0%|          | 100/20000 [01:17<9:04:38,  1.64s/it]

Step 100: Policy Loss 0.0409, Critic Loss 0.3624, Avg Train Acc 0.3825, Val Acc 0.4800, Unique Indices: 100


Training Steps:   1%|          | 199/20000 [02:00<58:08,  5.68it/s]  

Step 200: Policy Loss -0.0000, Critic Loss 0.0144, Avg Train Acc 0.6550, Val Acc 0.4900, Unique Indices: 98


Training Steps:   2%|▏         | 300/20000 [02:38<6:23:43,  1.17s/it] 

Step 300: Policy Loss 0.1686, Critic Loss 0.1412, Avg Train Acc 0.7300, Val Acc 0.4800, Unique Indices: 99


Training Steps:   2%|▏         | 399/20000 [03:18<3:31:29,  1.54it/s]

Step 400: Policy Loss 0.0563, Critic Loss 0.0481, Avg Train Acc 0.7937, Val Acc 0.4800, Unique Indices: 98


Training Steps:   2%|▎         | 500/20000 [04:33<8:29:47,  1.57s/it] 

Step 500: Policy Loss -0.0000, Critic Loss 0.0733, Avg Train Acc 0.8938, Val Acc 0.4800, Unique Indices: 99


Training Steps:   3%|▎         | 599/20000 [05:12<1:08:47,  4.70it/s]

Step 600: Policy Loss -0.0000, Critic Loss 0.0716, Avg Train Acc 0.8612, Val Acc 0.5000, Unique Indices: 99


Training Steps:   4%|▎         | 700/20000 [06:58<11:13:19,  2.09s/it]

Step 700: Policy Loss 0.0097, Critic Loss 0.3348, Avg Train Acc 0.7013, Val Acc 0.4800, Unique Indices: 100


Training Steps:   4%|▍         | 799/20000 [08:34<4:56:35,  1.08it/s] 

Step 800: Policy Loss 0.0429, Critic Loss 0.0936, Avg Train Acc 0.7100, Val Acc 0.4900, Unique Indices: 97


Training Steps:   4%|▍         | 900/20000 [10:17<9:48:30,  1.85s/it] 

Step 900: Policy Loss 0.1717, Critic Loss 0.1383, Avg Train Acc 0.7025, Val Acc 0.4900, Unique Indices: 99


Training Steps:   5%|▍         | 999/20000 [11:50<4:45:29,  1.11it/s]

Step 1000: Policy Loss 0.1444, Critic Loss 0.1196, Avg Train Acc 0.7238, Val Acc 0.4900, Unique Indices: 99


Training Steps:   6%|▌         | 1100/20000 [14:17<9:49:11,  1.87s/it] 

Step 1100: Policy Loss 0.1838, Critic Loss 0.3074, Avg Train Acc 0.6787, Val Acc 0.4900, Unique Indices: 99


Training Steps:   6%|▌         | 1199/20000 [15:47<5:06:54,  1.02it/s]

Step 1200: Policy Loss 0.0298, Critic Loss 0.1865, Avg Train Acc 0.7200, Val Acc 0.5100, Unique Indices: 97


Training Steps:   6%|▋         | 1300/20000 [17:38<9:51:52,  1.90s/it] 

Step 1300: Policy Loss 0.0414, Critic Loss 0.3279, Avg Train Acc 0.6863, Val Acc 0.5000, Unique Indices: 99


Training Steps:   7%|▋         | 1399/20000 [19:15<5:05:49,  1.01it/s]

Step 1400: Policy Loss 0.0871, Critic Loss 0.2009, Avg Train Acc 0.6925, Val Acc 0.4900, Unique Indices: 99


Training Steps:   8%|▊         | 1500/20000 [21:04<10:37:24,  2.07s/it]

Step 1500: Policy Loss 0.0243, Critic Loss 0.2891, Avg Train Acc 0.6937, Val Acc 0.5000, Unique Indices: 100


Training Steps:   8%|▊         | 1599/20000 [22:41<5:16:15,  1.03s/it] 

Step 1600: Policy Loss 0.2449, Critic Loss 0.2176, Avg Train Acc 0.6875, Val Acc 0.5000, Unique Indices: 97


Training Steps:   8%|▊         | 1700/20000 [24:38<9:55:42,  1.95s/it] 

Step 1700: Policy Loss 0.1278, Critic Loss 0.2253, Avg Train Acc 0.6987, Val Acc 0.4900, Unique Indices: 99


Training Steps:   9%|▉         | 1799/20000 [26:18<5:10:42,  1.02s/it]

Step 1800: Policy Loss 0.0486, Critic Loss 0.2593, Avg Train Acc 0.7050, Val Acc 0.5000, Unique Indices: 100


Training Steps:  10%|▉         | 1900/20000 [28:06<9:33:34,  1.90s/it] 

Step 1900: Policy Loss 0.0580, Critic Loss 0.0870, Avg Train Acc 0.7113, Val Acc 0.4900, Unique Indices: 100


Training Steps:  10%|▉         | 1999/20000 [29:42<5:04:17,  1.01s/it]

Step 2000: Policy Loss 0.1487, Critic Loss 0.1989, Avg Train Acc 0.6975, Val Acc 0.4800, Unique Indices: 97


Training Steps:  10%|█         | 2100/20000 [31:42<9:20:54,  1.88s/it] 

Step 2100: Policy Loss 0.1816, Critic Loss 0.1425, Avg Train Acc 0.7150, Val Acc 0.5000, Unique Indices: 99


Training Steps:  11%|█         | 2199/20000 [33:20<4:41:21,  1.05it/s]

Step 2200: Policy Loss 0.0984, Critic Loss 0.1764, Avg Train Acc 0.7075, Val Acc 0.5100, Unique Indices: 98


Training Steps:  12%|█▏        | 2300/20000 [35:13<9:10:36,  1.87s/it] 

Step 2300: Policy Loss 0.1493, Critic Loss 0.2851, Avg Train Acc 0.7050, Val Acc 0.4800, Unique Indices: 99


Training Steps:  12%|█▏        | 2399/20000 [36:50<7:32:34,  1.54s/it]

Step 2400: Policy Loss 0.1623, Critic Loss 0.1573, Avg Train Acc 0.6700, Val Acc 0.4900, Unique Indices: 100


Training Steps:  12%|█▎        | 2500/20000 [38:36<10:41:23,  2.20s/it]

Step 2500: Policy Loss 0.0960, Critic Loss 0.1678, Avg Train Acc 0.6925, Val Acc 0.5200, Unique Indices: 100


Training Steps:  13%|█▎        | 2599/20000 [40:15<5:08:48,  1.06s/it] 

Step 2600: Policy Loss -0.0000, Critic Loss 0.0685, Avg Train Acc 0.6737, Val Acc 0.5000, Unique Indices: 97


Training Steps:  14%|█▎        | 2700/20000 [42:05<9:43:53,  2.03s/it] 

Step 2700: Policy Loss 0.0116, Critic Loss 0.0228, Avg Train Acc 0.7075, Val Acc 0.5000, Unique Indices: 100


Training Steps:  14%|█▍        | 2799/20000 [43:42<4:45:46,  1.00it/s]

Step 2800: Policy Loss 0.0662, Critic Loss 0.1831, Avg Train Acc 0.6750, Val Acc 0.4900, Unique Indices: 99


Training Steps:  14%|█▍        | 2900/20000 [45:28<10:21:57,  2.18s/it]

Step 2900: Policy Loss 0.0715, Critic Loss 0.1299, Avg Train Acc 0.7013, Val Acc 0.4800, Unique Indices: 100


Training Steps:  15%|█▍        | 2999/20000 [46:59<4:56:17,  1.05s/it] 

Step 3000: Policy Loss 0.1003, Critic Loss 0.2673, Avg Train Acc 0.6825, Val Acc 0.4700, Unique Indices: 99


Training Steps:  16%|█▌        | 3100/20000 [50:01<9:53:30,  2.11s/it]  

Step 3100: Policy Loss 0.2471, Critic Loss 0.2874, Avg Train Acc 0.6963, Val Acc 0.4900, Unique Indices: 99


Training Steps:  16%|█▌        | 3199/20000 [51:36<5:07:56,  1.10s/it]

Step 3200: Policy Loss 0.1095, Critic Loss 0.1565, Avg Train Acc 0.7200, Val Acc 0.5000, Unique Indices: 99


Training Steps:  16%|█▋        | 3300/20000 [53:27<9:15:07,  1.99s/it] 

Step 3300: Policy Loss 0.0634, Critic Loss 0.1897, Avg Train Acc 0.6663, Val Acc 0.5000, Unique Indices: 99


Training Steps:  17%|█▋        | 3399/20000 [55:01<4:13:40,  1.09it/s]

Step 3400: Policy Loss 0.1591, Critic Loss 0.1481, Avg Train Acc 0.6600, Val Acc 0.4900, Unique Indices: 99


Training Steps:  18%|█▊        | 3500/20000 [56:50<8:52:32,  1.94s/it] 

Step 3500: Policy Loss 0.1417, Critic Loss 0.3130, Avg Train Acc 0.7050, Val Acc 0.4900, Unique Indices: 100


Training Steps:  18%|█▊        | 3599/20000 [58:25<4:55:25,  1.08s/it]

Step 3600: Policy Loss 0.1367, Critic Loss 0.3177, Avg Train Acc 0.6637, Val Acc 0.4800, Unique Indices: 99


Training Steps:  18%|█▊        | 3700/20000 [1:00:13<8:45:22,  1.93s/it]

Step 3700: Policy Loss 0.2872, Critic Loss 0.3500, Avg Train Acc 0.6863, Val Acc 0.4900, Unique Indices: 98


Training Steps:  19%|█▉        | 3768/20000 [1:01:17<4:05:14,  1.10it/s]