In [None]:
import sys
sys.path.append('../..')  # Add parent directory to path
import os
import torch as t
import torch.nn as nn
t.cuda.empty_cache()
from transformers import Qwen2ForCausalLM, PreTrainedTokenizer, PreTrainedModel
from hf import HF
from evaluate import load

class SafetyAdapter(nn.Module):
    """Residual bottleneck adapter: ``x + up(relu(down(x)))``.

    The bottleneck width is ``hidden_size // 4``. The up-projection is
    zero-initialised, so a freshly constructed adapter returns its input
    unchanged.

    Args:
        hidden_size: Size of the last dimension of the tensors the adapter receives.
        dtype: Optional dtype for the adapter's linear layers.
    """

    def __init__(self, hidden_size, dtype=None):
        super().__init__()
        # A small adapter: down-project, non-linearity, then up-project.
        self.adapter = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 4, dtype=dtype),
            nn.ReLU(),
            nn.Linear(hidden_size // 4, hidden_size, dtype=dtype),
        )
        # Zero weight and bias on the final layer make the adapter branch output
        # zeros, so the residual sum initially equals the input.
        nn.init.zeros_(self.adapter[2].weight)
        nn.init.zeros_(self.adapter[2].bias)

    def forward(self, hidden_states):
        """Apply the adapter and add it to the input as a residual.

        The adapter branch is computed in the adapter's own dtype/device; its
        result is then cast back so the returned tensor always has the same
        dtype and device as ``hidden_states``.
        """
        reference_weight = self.adapter[0].weight

        # Cast only the copy fed to the adapter; keep the original tensor intact
        # so the residual path is not degraded by a round-trip conversion.
        adapter_input = hidden_states
        if adapter_input.dtype != reference_weight.dtype or adapter_input.device != reference_weight.device:
            adapter_input = adapter_input.to(dtype=reference_weight.dtype, device=reference_weight.device)

        adapter_out = self.adapter(adapter_input)

        # Bring the adapter branch back to the caller's dtype/device before summing.
        if adapter_out.dtype != hidden_states.dtype or adapter_out.device != hidden_states.device:
            adapter_out = adapter_out.to(dtype=hidden_states.dtype, device=hidden_states.device)

        return hidden_states + adapter_out

class AdapterWrapper(nn.Module):
    """Wrap an existing module and run an adapter on its output.

    If the wrapped module returns a tuple, only the first element is passed
    through the adapter and the remaining elements are returned untouched;
    any other return value is passed to the adapter as-is.

    Attributes that are not defined on the wrapper itself are looked up on the
    wrapped module, so code that reads attributes from the original layer keeps
    working after the layer has been replaced by this wrapper.
    """

    def __init__(self, original_module, adapter):
        super().__init__()
        self.original_module = original_module
        self.adapter = adapter

    def __getattr__(self, name):
        """Fall back to the wrapped module for attributes the wrapper lacks."""
        try:
            # nn.Module.__getattr__ resolves registered parameters/buffers/submodules.
            return super().__getattr__(name)
        except AttributeError:
            # Never forward dunder lookups (copy/pickle protocol probes).
            if name.startswith("__") and name.endswith("__"):
                raise
            # Read _modules through __dict__ to avoid recursing into __getattr__
            # while the wrapper is still being constructed.
            wrapped = self.__dict__.get("_modules", {}).get("original_module")
            if wrapped is None:
                raise
            return getattr(wrapped, name)

    def forward(self, *args, **kwargs):
        """Apply the original module and then the adapter."""
        outputs = self.original_module(*args, **kwargs)

        if isinstance(outputs, tuple):
            # Adapt only the leading element; pass the rest of the tuple through.
            return (self.adapter(outputs[0]),) + outputs[1:]
        return self.adapter(outputs)

def inject_adapter(model: PreTrainedModel, layer_idx: int) -> PreTrainedModel:
    """Wrap one decoder layer of ``model`` with a new ``SafetyAdapter``.

    The layer at ``model.model.layers[layer_idx]`` is replaced in place by an
    ``AdapterWrapper`` around the original layer. The adapter is created with
    the dtype and device of the layer's first parameter and is also recorded in
    ``model.adapters`` under the key ``safety_adapter_<layer_idx>``.

    Args:
        model: Model exposing ``config.hidden_size`` and ``model.layers``.
        layer_idx: Zero-based index of the layer to wrap; negative values are rejected.

    Returns:
        The same ``model`` object, now modified.

    Raises:
        ValueError: If ``layer_idx`` is outside ``[0, len(layers))``.
    """
    hidden_size = model.config.hidden_size
    layers = model.model.layers

    if not (0 <= layer_idx < len(layers)):
        raise ValueError(f"Layer index {layer_idx} is out of bounds. Model has {len(layers)} layers.")

    target_layer = layers[layer_idx]

    # One parameter of the target layer tells us where the adapter has to live.
    reference_param = next(target_layer.parameters())
    layer_device, layer_dtype = reference_param.device, reference_param.dtype
    print(f"Target layer {layer_idx} is using dtype: {layer_dtype} on device: {layer_device}")

    adapter = SafetyAdapter(hidden_size, dtype=layer_dtype).to(layer_device)

    # Swap the original layer for a wrapper that runs the adapter on its output.
    layers[layer_idx] = AdapterWrapper(target_layer, adapter)

    # Keep a handle to the adapter on the model for later lookup.
    if not hasattr(model, "adapters"):
        model.adapters = {}
    model.adapters[f"safety_adapter_{layer_idx}"] = adapter

    return model


# Load the base model
# HF.load_model is a project helper (local `hf` module); its result is unpacked
# here as a (model, tokenizer) pair.
model_name = "Qwen/Qwen2.5-14B-Instruct"
base_model, base_tokenizer = HF.load_model(model_name)
# Bare annotation for editors/type checkers; it does not rebind or convert base_model.
base_model: Qwen2ForCausalLM

# Inject adapter at layer layer_idx
# inject_adapter replaces the layer inside the model in place and returns the
# same object, so base_model and model_with_adapter refer to one model.
layer_idx = 40
model_with_adapter = inject_adapter(base_model, layer_idx=layer_idx)

In [None]:
# Smoke test: send one benign prompt through the adapter-wrapped model before any training.
# NOTE(review): HF.query is a project helper not shown here; presumably it generates a reply — confirm.
HF.query(model_with_adapter, base_tokenizer, "Hello, what's your name?")

In [None]:
# Freeze every parameter of the model in one call...
model_with_adapter.requires_grad_(False)

# ...then re-enable gradients for the injected adapter only.
model_with_adapter.model.layers[layer_idx].adapter.requires_grad_(True)


# Report the parameters that will be updated during training.
for name, param in model_with_adapter.named_parameters():
  if not param.requires_grad:
    continue
  print(f"Name: {name}. Requires grad: {param.requires_grad}")

In [None]:
import pandas as pd
import os
import random

# Load the forbidden prompts and pair each one with a canned refusal.
DATA_FOLDER = "../../data"

# Seed for the refusal assignment so the training labels are identical on every run.
REFUSAL_SEED = 42

# Candidate refusal texts used as the target response.
rejection_response = ["Sorry, I can't assist with that.", "I'm not able to assist with this request.", "I can't help with that."]

rejection_raw_df = pd.read_csv(os.path.join(DATA_FOLDER, "strongreject_dataset.csv"))

# A dedicated Random instance keeps the draw reproducible and independent of
# whatever else has consumed the global `random` state in this kernel.
_refusal_rng = random.Random(REFUSAL_SEED)

# Build the final frame without mutating the raw one: rename the prompt column,
# attach one sampled refusal per row, and keep only the two columns used downstream.
rejection_df = (
    rejection_raw_df
    .rename(columns={"forbidden_prompt": "prompt"})
    .assign(response=[_refusal_rng.choice(rejection_response) for _ in range(len(rejection_raw_df))])
    .loc[:, ["prompt", "response"]]
)

In [None]:
import pandas as pd

# Parquet files of the instruction dataset, keyed by split name.
splits = {
    'train': 'data/train-00000-of-00001-d3450385c0ae3f98.parquet',
    'test': 'data/test-00000-of-00001-8d5da67d5c6856ed.parquet',
}

# Read the train split from the Hugging Face hub and keep its first 500 rows.
qa_train_path = "hf://datasets/alespalla/chatbot_instruction_prompts/" + splits["train"]
qa_df = pd.read_parquet(qa_train_path).iloc[:500]

In [None]:
# Stack the QA rows on top of the refusal rows and renumber the index from 0.
data_df = pd.concat([qa_df, rejection_df], ignore_index=True)

data_df

In [None]:
from datasets import load_dataset, Dataset
from transformers import  PreTrainedTokenizer

# Load Data
def tokenize_dataset(example, tokenizer: PreTrainedTokenizer):
  """Turn one prompt/response record into causal-LM training features.

  The prompt is rendered as a single user turn through the tokenizer's chat
  template (with the generation prompt appended); the response is encoded
  with the EOS token appended and without extra special tokens.

  Args:
    example: Mapping with ``"prompt"`` and ``"response"`` entries.
    tokenizer: Tokenizer providing ``apply_chat_template``, ``encode`` and ``eos_token``.

  Returns:
    Dict with ``input_ids`` (prompt ids followed by response ids), an all-ones
    ``attention_mask`` of the same length, and ``labels`` in which every prompt
    position is -100 so only the response contributes to the loss.
  """
  prompt_ids = tokenizer.apply_chat_template(
    [{"role": "user", "content": example["prompt"]}],
    tokenize=True,
    add_generation_prompt=True,
  )
  completion_ids = tokenizer.encode(example["response"] + tokenizer.eos_token, add_special_tokens=False)

  full_ids = prompt_ids + completion_ids
  masked_prompt = [-100] * len(prompt_ids)

  return {
    "input_ids": full_ids,
    "attention_mask": [1] * len(full_ids),
    "labels": masked_prompt + completion_ids,
  }

# Simple data collator function
def data_collator(batch, tokenizer=None):
  """Right-pad a list of tokenized examples to a common length and tensorize them.

  Args:
    batch: Non-empty list of dicts with ``input_ids``, ``attention_mask`` and
      ``labels`` lists.
    tokenizer: Optional tokenizer supplying ``pad_token_id`` / ``eos_token_id``.
      When omitted, the notebook-level ``base_tokenizer`` is used.

  Returns:
    Dict of tensors ``input_ids``, ``attention_mask`` and ``labels``, each with
    one row per example. Padding uses the pad token id for ``input_ids``
    (falling back to the EOS id only when no pad id is defined), 0 for
    ``attention_mask`` and -100 for ``labels``.

  Raises:
    ValueError: If ``batch`` is empty.
  """
  if not batch:
    raise ValueError("data_collator received an empty batch")

  active_tokenizer = base_tokenizer if tokenizer is None else tokenizer

  # Explicit None check: a pad token id of 0 is valid and must not fall
  # through to the EOS id the way `pad or eos` would.
  pad_token_id = active_tokenizer.pad_token_id
  if pad_token_id is None:
    pad_token_id = active_tokenizer.eos_token_id

  # Find max length
  max_length = max(len(item["input_ids"]) for item in batch)

  def _pad_field(key, fill_value):
    # Right-pad every example's `key` list with `fill_value` up to max_length.
    return [item[key] + [fill_value] * (max_length - len(item[key])) for item in batch]

  return {
    "input_ids": t.tensor(_pad_field("input_ids", pad_token_id)),
    "attention_mask": t.tensor(_pad_field("attention_mask", 0)),
    "labels": t.tensor(_pad_field("labels", -100))
  }


def get_and_tokenize_dataset(df: pd.DataFrame, tokenizer: PreTrainedTokenizer):
  """Convert a prompt/response DataFrame into tokenized train/test splits.

  The frame is turned into a ``Dataset``, split 70/30 with a fixed seed (42),
  and every record is mapped through ``tokenize_dataset``. All original
  columns are dropped, leaving only the fields that function returns.

  Args:
    df: DataFrame holding the columns ``tokenize_dataset`` reads.
    tokenizer: Tokenizer forwarded to ``tokenize_dataset``.

  Returns:
    The mapped result of ``train_test_split``, with ``"train"`` and ``"test"`` entries.
  """
  splits = Dataset.from_pandas(df).train_test_split(test_size=0.3, seed=42)

  return splits.map(
    tokenize_dataset,
    fn_kwargs={"tokenizer": tokenizer},
    remove_columns=splits["train"].column_names,
  )


# Build the tokenized train/test splits from the combined QA + refusal frame.
dataset = get_and_tokenize_dataset(data_df, base_tokenizer)

In [None]:
from transformers import TrainingArguments, Trainer
import wandb

# Hyperparameters are defined once so that the values logged to wandb are
# exactly the ones handed to the Trainer.
NUM_TRAIN_EPOCHS = 5
PER_DEVICE_TRAIN_BATCH_SIZE = 2
GRADIENT_ACCUMULATION_STEPS = 8
LEARNING_RATE = 5e-4
WARMUP_RATIO = 0.1
WEIGHT_DECAY = 0.01

# Initialize wandb before creating TrainingArguments: the run name is reused
# for the output directory and the Trainer's run_name.
wandb.init(
  project="misalignment-safety_layer",
  name="finetune-safety_layer",
  # Track the hyperparameters of this run.
  config={
    "model": model_name,
    "dataset": {"rejection": "https://github.com/alexandrasouly/strongreject/blob/main/strongreject_dataset/strongreject_dataset.csv", "qa":"alespalla/chatbot_instruction_prompts"},
    "adapter_layer_idx": layer_idx,
    "epochs": NUM_TRAIN_EPOCHS,
    "batch_size": PER_DEVICE_TRAIN_BATCH_SIZE,
    "gradient_accumulation_steps": GRADIENT_ACCUMULATION_STEPS,
    "learning_rate": LEARNING_RATE,
    "warmup_ratio": WARMUP_RATIO,
    "weight_decay": WEIGHT_DECAY,
  }
)

output_dir = f"../models/{wandb.run.name}"

# try/finally guarantees the wandb run is closed even if setup or training raises.
try:
  training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=LEARNING_RATE,
    warmup_ratio=WARMUP_RATIO,
    weight_decay=WEIGHT_DECAY,
    save_strategy="epoch",
    eval_strategy="epoch",                   # Evaluate at the end of every epoch
    logging_steps=20,
    bf16=True,
    fp16=False,
    label_names=["labels"],
    report_to="wandb",
    run_name=wandb.run.name,
    load_best_model_at_end=True,             # Restore the best checkpoint after training
    metric_for_best_model="eval_loss"        # "Best" is judged by evaluation loss
  )

  # Initialize trainer
  trainer = Trainer(
    model=model_with_adapter,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    data_collator=data_collator,
  )

  # Train and save
  trainer.train()
  # NOTE(review): the layer at layer_idx is an AdapterWrapper here, so its
  # parameter names carry the wrapper's `original_module` / `adapter` prefixes —
  # confirm the saved directory can be reloaded the way you intend.
  model_with_adapter.save_pretrained(output_dir)
finally:
  # close wandb
  wandb.finish()

In [None]:
# Probe the model after training with a harmful request to inspect its reply.
# NOTE(review): HF.query is a project helper not shown here; presumably it generates a reply — confirm.
HF.query(model_with_adapter, base_tokenizer, "How do I make a bomb?")

In [None]:
# Print each adapter parameter's name followed by its current values.
trained_adapter = model_with_adapter.model.layers[layer_idx].adapter
for param_name, param_value in trained_adapter.named_parameters():
  print(f"{param_name}", param_value, sep="\n")