In [None]:
# Install the Hugging Face Hub client and log in interactively; the credentials stored by
# notebook_login() are what later Hub downloads/uploads can fall back on.
!pip install huggingface_hub
from huggingface_hub import notebook_login
notebook_login()

In [None]:
# Install the enhanced_hooking helper package directly from GitHub (no version/commit pin);
# get_blocks, clear_hooks and attach_activation_hooks are imported from it in the next cell.
!pip install git+https://github.com/cma1114/enhanced_hooking.git

In [None]:
#!pip install -q -U torch transformers matplotlib pandas scikit-learn seaborn datasets
import torch 
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from tqdm import tqdm
from datetime import datetime
import torch.nn.functional as F
import random
from collections import defaultdict
from enhanced_hooking import get_blocks, clear_hooks, attach_activation_hooks
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import pickle
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Read the Hugging Face access token from the environment instead of embedding it in the
# notebook. A token that was ever stored here in plain text must be treated as leaked and
# revoked on the Hub.
# When HF_TOKEN is unset or empty this is None, and `from_pretrained(..., token=None)` falls
# back to the credentials saved by the login step.
HF_TOKEN = os.environ.get("HF_TOKEN") or None
def load_model(model_path, device, center_weights=True):
    """Load a half-precision causal LM together with its tokenizer.

    Args:
        model_path: Hub id or local path handed to both ``from_pretrained`` calls.
        device: Device the model is moved to after loading.
        center_weights: When true, every parameter whose last two name components are one of
            ``wte.weight``, ``wpe.weight``, ``c_proj.weight`` or ``c_proj.bias`` has its own
            mean subtracted in place, and its name, new mean and size are printed.

    Returns:
        The model, with the tokenizer attached as ``model.tokenizer``. The tokenizer is set
        to left padding and its pad token id is set to the EOS token id.
    """
    centered_suffixes = ['wte.weight', 'wpe.weight', 'c_proj.weight', 'c_proj.bias']

    loaded = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, token=HF_TOKEN)
    model = loaded.to(device)

    if center_weights:
        for param_name, weights in model.named_parameters():
            suffix = '.'.join(param_name.split('.')[-2:])
            if suffix not in centered_suffixes:
                continue
            weights.data -= weights.data.mean()
            print(param_name, weights.data.mean(), weights.size())

    model.tokenizer = AutoTokenizer.from_pretrained(model_path, token=HF_TOKEN)
    # Left padding is for batched generation; right is the gpt2 default and is used for training.
    model.tokenizer.padding_side = "left"
    model.tokenizer.pad_token_id = model.tokenizer.eos_token_id
    return model


# Drop any model left over from a previous run of this cell and release its memory
# before a new one is loaded.
model=None
import gc
gc.collect()
torch.cuda.empty_cache()
# Autograd is switched off globally from here on; any later code that calls .backward()
# has to re-enable it first (e.g. torch.enable_grad() / torch.set_grad_enabled(True)).
_ = torch.set_grad_enabled(False)
model_path: str = "meta-llama/Llama-2-13b-chat-hf"#"meta-llama/Llama-2-13b-chat-hf" #even on an A40 I have to load 13b in half precision
device: str =  "cuda" if torch.cuda.is_available() else "cpu" 
center_weights=False

# Loads the model in float16 and attaches the tokenizer as model.tokenizer (see load_model).
model = load_model(model_path, device, center_weights=center_weights)

In [None]:
outputdir = "./"
datadir = "data/"

# Prompt tags and the config attribute that holds the layer count depend on the model family.
if "gpt" in model_path:
    user_tag = "|<end_of_text>| "
    asst_tag = ""
    model_numlayers = model.config.n_layer
else:  # llama / mistral
    user_tag = "[INST] "
    asst_tag = " [/INST]"
    model_numlayers = model.config.num_hidden_layers


def strip_quotes(s):
    """Remove one pair of matching surrounding quotes (both double or both single), if present."""
    if len(s) <= 1:
        return s
    double_quoted = s.startswith('"') and s.endswith('"')
    single_quoted = s.startswith("'") and s.endswith("'")
    return s[1:-1] if (double_quoted or single_quoted) else s


pos_label, neg_label = "A", "B"
# The last id of each encoding is taken as the label's token id.
pos_label_token_id = model.tokenizer.encode(pos_label)[-1]
neg_label_token_id = model.tokenizer.encode(neg_label)[-1]

df = pd.read_csv(datadir + "facts_true_false.csv")  # train
true_statements = df.loc[df['label'] == 1, 'statement']
false_statements = df.loc[df['label'] == 0, 'statement']

# Fixed-seed split: 306 statements of each class for training, the rest for testing.
true_statements_train = true_statements.sample(n=306, random_state=42)
false_statements_train = false_statements.sample(n=306, random_state=42)
true_statements_test = true_statements.drop(true_statements_train.index)
false_statements_test = false_statements.drop(false_statements_train.index)

def template(stmt: str) -> str:
    """Wrap a statement in the true/false multiple-choice prompt.

    The statement is passed through ``strip_quotes`` and re-quoted with double quotes. The
    prompt starts with ``user_tag`` and ends with ``asst_tag`` followed by one space.
    """
    # NOTE(review): the system block is closed with "<<SYS>>" rather than "<</SYS>>" —
    # confirm this is intended before changing it; the string is reproduced unchanged here.
    claim = strip_quotes(stmt)
    prompt = f"{user_tag}<<SYS>> Respond only with the letter A or B <<SYS>> Question: Is the following claim true or false? \"{claim}\"\n\nChoices:\n (A) True\n (B) False\n\nAnswer:{asst_tag} "
    return prompt

# Every statement is paired with BOTH answer letters, so each (true, false) statement pair
# contributes four examples in the order (truth, A), (truth, B), (lie, B), (lie, A).
train_prompts = []
train_labels = []
for truth, lie in zip(true_statements_train.values.tolist(), false_statements_train.values.tolist()):
    truth_prompt = template(truth)
    lie_prompt = template(lie)
    train_prompts.extend([truth_prompt, truth_prompt, lie_prompt, lie_prompt])
    train_labels.extend([pos_label, neg_label, neg_label, pos_label])

In [None]:
from torch.utils.data import Sampler
from typing import List
def collate_batch(batch):
    """Pad a list of dataset items into batch tensors.

    Each item is a dict holding ``input_ids``, ``labels`` and ``attention_mask`` tensors.
    Every field is padded (via ``pad_sequence``, batch dimension first) to the longest
    sequence in the batch: input ids with the pad token id of the global ``model``'s
    tokenizer, labels with -100 and attention masks with 0.

    Returns:
        Dict with the padded ``input_ids``, ``labels`` and ``attention_mask``.
    """
    # Fill value per field; -100 marks label positions to be ignored in the loss calculation.
    fill_values = {
        "input_ids": model.tokenizer.pad_token_id,
        "labels": -100,
        "attention_mask": 0,
    }
    return {
        field: pad_sequence([item[field] for item in batch], batch_first=True, padding_value=fill)
        for field, fill in fill_values.items()
    }


class PromptCompletionDataset(Dataset):
    """Dataset of (prompt, completion) string pairs tokenized for causal-LM training.

    Each item contains the token ids of ``prompt + completion``, the matching attention
    mask, and labels of the same length in which every position except the trailing
    completion tokens is -100.

    Args:
        prompts: Prompt strings.
        completions: Completion strings, aligned index-by-index with ``prompts``.
        tokenizer: Callable tokenizer. It is called with ``return_tensors='pt'`` for the
            joined text and with ``add_special_tokens=False`` for the completion alone, and
            its result must expose ``input_ids`` (and ``attention_mask`` for the joined text).
    """

    def __init__(self, prompts: List[str], completions: List[str], tokenizer):
        self.prompts = prompts
        self.completions = completions
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``labels`` and ``attention_mask`` tensors for example ``idx``."""
        prompt_text = self.prompts[idx]
        completion_text = self.completions[idx]

        # Tokenize prompt and completion together: this is the sequence the model sees.
        encoded_pair = self.tokenizer(prompt_text + completion_text, return_tensors='pt')
        input_ids = encoded_pair.input_ids.squeeze(0)
        attention_mask = encoded_pair.attention_mask.squeeze(0)

        # Derive the labels from input_ids so both always have the same length and alignment.
        # Concatenating separately tokenized prompt/completion ids (without special tokens)
        # can come out shorter or shifted relative to input_ids, e.g. when the tokenizer
        # prepends a BOS token. The completion is taken to be the last
        # len(completion_ids) tokens; everything before it is ignored in the loss.
        completion_ids = self.tokenizer(completion_text, add_special_tokens=False).input_ids
        prompt_len = max(input_ids.size(0) - len(completion_ids), 0)
        labels = input_ids.clone()
        labels[:prompt_len] = -100

        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask
        }
        
class CustomBatchSampler(Sampler):
    """Index sampler that shuffles on two levels while keeping chunk membership fixed.

    The data is viewed as consecutive chunks of ``batch_size`` indices. Each pass visits the
    chunks in a freshly shuffled order and shuffles the indices inside every chunk, but never
    mixes indices from different chunks. Items beyond the last full chunk are not yielded.
    """

    def __init__(self, data_source, batch_size):
        self.data_source = data_source
        self.batch_size = batch_size
        self.num_batches = len(data_source) // batch_size
        self.batch_indices = [*range(self.num_batches)]

    def __iter__(self):
        # New chunk order for this pass (shuffled in place, so it persists on the instance).
        random.shuffle(self.batch_indices)
        for chunk in self.batch_indices:
            first = chunk * self.batch_size
            members = list(range(first, first + self.batch_size))
            # New order of the indices inside this chunk.
            random.shuffle(members)
            yield from members

    def __len__(self):
        return self.batch_size * self.num_batches
        
# Training uses right padding (load_model leaves the tokenizer on "left" for generation).
model.tokenizer.padding_side = "right"
batch_size = 4
dataset = PromptCompletionDataset(train_prompts, train_labels, model.tokenizer)
# The sampler yields indices chunk by chunk (chunks of `batch_size` consecutive items, in
# shuffled order), and the DataLoader groups every `batch_size` yielded indices into a batch.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    sampler=CustomBatchSampler(dataset, batch_size),
    collate_fn=collate_batch,
)
# Unshuffled alternative:
#dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_batch, shuffle=False)

In [None]:
# Choose which parameters are fine-tuned and load the per-layer direction vectors that the
# projection loss compares activations against.
layers_to_train = [14, 15, 16, 17, 18, 19]  
# Parameter-name prefix of the transformer blocks; the block index is the third dot-separated
# component of the name (used below via name.split('.')[2]).
layer_prefix = 'transformer.h.' if "gpt2" in model_path else 'model.layers.'
layernorm_name = '_ln' if "gpt2" in model_path else '_layernorm'
model.config.use_cache = False
model.config.pretraining_tp = 1 #for llama

#for name, param in model.named_parameters():
#    if (name.startswith(layer_prefix) and int(name.split('.')[2]) not in layers_to_train) or not name.startswith(layer_prefix) or layernorm_name in name or "mlp" in name or "attn.o_" in name:
#        #print(f"Freezing name={name}, layer={int(name.split('.')[2])}")
#        param.requires_grad = False
# Two parameter lists, one per loss. Both filters are currently identical (attention q/k/v
# projections of the selected layers), so the lists hold the same parameters; every parameter
# matched by neither filter is frozen.
token_loss_params = []
projection_loss_params = []
for name, param in model.named_parameters():
    need=False
    if (name.startswith(layer_prefix) and int(name.split('.')[2]) in layers_to_train) and ("attn.q_" in name or "attn.k_" in name or "attn.v_" in name):
        token_loss_params.append(param)
        need=True
    if (name.startswith(layer_prefix) and int(name.split('.')[2]) in layers_to_train) and ("attn.q_" in name or "attn.k_" in name or "attn.v_" in name):
        projection_loss_params.append(param)
        need=True
    if not need: param.requires_grad = False

# train_direction_in selects the projection-loss form and flip_direction negates the loaded
# direction (both are read in the training loop).
train_direction_in = True
flip_direction = True
# NOTE: projection_weight and num_epochs are assigned again in the training cell below.
projection_weight = 1.0
num_epochs=3
fname = 'directions_llama2_13b_f16_persona_lasttoken_pc2raw.pkl'
#fname = 'directions_gpt2xl_f32_persona_lasttoken_pc2raw.pkl'
#fname = 'directions_gpt2xl_gpt4facts_persona_lasttoken_pc2raw.pkl'
# SECURITY: pickle.load can execute arbitrary code contained in the file; only load
# direction files that you produced yourself or otherwise trust.
with open(outputdir+fname, 'rb') as f:
    directions = pickle.load(f)

# Cast every direction tensor to float16 and report any that contain NaNs.
for layer, tensors in directions.items(): 
    directions[layer] = [tensor.to(dtype=torch.float16) for tensor in tensors]
    for tensor in directions[layer]:
        if torch.isnan(tensor).any(): print(f"NaN in layer {layer}")


In [None]:
from transformers import get_linear_schedule_with_warmup
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR, SequentialLR
#from torch.utils.tensorboard import SummaryWriter
#writer = SummaryWriter('runs/experiment_1')
#%load_ext tensorboard
#%tensorboard --logdir runs

def print_parameter_stats(model):
    """Print the norm of every parameter of ``model`` that has ``requires_grad`` set."""
    trainable = ((n, p) for n, p in model.named_parameters() if p.requires_grad)
    for name, param in trainable:
        print(f"{name} - norm: {param.norm().item()}")

def print_gradient_stats(model):
    """Print the gradient norm of every trainable parameter that currently has a gradient."""
    for name, param in model.named_parameters():
        if not param.requires_grad or param.grad is None:
            continue
        print(f"{name} - grad norm: {param.grad.norm().item()}")

def print_gradient_summary_stats(model):
    """Print min, max, mean and std of the per-parameter gradient norms.

    Only parameters with ``requires_grad`` set and a non-``None`` gradient are considered.
    The min and max lines also name the parameter they come from. Nothing is printed when
    no parameter has a gradient.
    """
    minval = float('inf')
    maxval = float('-inf')
    cumsum = cumsumsq = cnt = 0
    summary_dict = {}
    for name, param in model.named_parameters():
        if param.requires_grad and param.grad is not None:
            cnt += 1
            v = param.grad.norm().item()
            cumsum += v
            cumsumsq += v**2
            if v < minval:
                minval = v
                summary_dict['minval'] = {"name": name, "val": v}
            if v > maxval:
                maxval = v
                summary_dict['maxval'] = {"name": name, "val": v}
    if cnt == 0:
        return
    mean = cumsum / cnt
    # E[v^2] - mean^2 can come out marginally negative through floating-point rounding (e.g.
    # when all norms are equal). A negative float raised to 0.5 is a complex number in
    # Python, so clamp the variance at zero before taking the root.
    variance = max(cumsumsq / cnt - mean**2, 0.0)
    summary_dict['mean'] = mean
    summary_dict['std'] = variance**0.5
    for k, v in summary_dict.items():
        if isinstance(v, dict):
            print(f"grad norm: {k} = {v['val']:.6f} at {v['name']}")
        else:
            print(f"grad norm: {k} = {v:.6f}")
            
def print_summary_statistics(tensor, name):
    """Print min, max, mean and std of ``tensor`` prefixed by ``name``; do nothing for ``None``."""
    if tensor is None:
        return
    lowest, highest = tensor.min().item(), tensor.max().item()
    average, spread = tensor.mean().item(), tensor.std().item()
    print(f"{name} - min: {lowest}, max: {highest}, mean: {average}, std: {spread}")

def check_running_averages(optimizer):
    """Print summary statistics of the optimizer's ``exp_avg`` / ``exp_avg_sq`` state tensors.

    Parameters without a gradient are skipped, as are parameters whose optimizer state does
    not contain both keys.
    """
    for group in optimizer.param_groups:
        with_grad = (p for p in group['params'] if p.grad is not None)
        for p in with_grad:
            state = optimizer.state[p]
            if not ('exp_avg' in state and 'exp_avg_sq' in state):
                continue
            first_moment = state['exp_avg']
            second_moment = state['exp_avg_sq']
            print_summary_statistics(first_moment, "exp_avg")
            print_summary_statistics(second_moment, "exp_avg_sq")

def check_parameters(model, message):
    """Print summary statistics of each trainable parameter, labelled ``"<message> - <name>"``."""
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        print_summary_statistics(param, f"{message} - {name}")

# Target distribution for the answer token: half of the probability mass on each of the two
# label tokens, zero elsewhere. All rows are identical; the (batch_size, vocab_size) shape is
# kept for any other code that reads `soft_targets`.
soft_targets = torch.zeros(batch_size, model.config.vocab_size)
soft_targets[:, [pos_label_token_id, neg_label_token_id]] = 0.5
def soft_target_cross_entropy(logits):
    """Cross-entropy between the softmax of ``logits`` and the 50/50 label-token target.

    Args:
        logits: Tensor indexed as (batch, vocab); softmax and the sum run over dim 1.

    Returns:
        Scalar tensor: the per-row cross-entropy averaged over the batch.
    """
    log_softmax_logits = F.log_softmax(logits, dim=1)
    # Every row of soft_targets is the same, so broadcast a single row. This gives the same
    # value as before for full batches and also works when the number of logit rows differs
    # from `batch_size` (e.g. a smaller final batch), where the full matrix would not broadcast.
    target_row = soft_targets[:1].to(log_softmax_logits.device)
    return -(target_row * log_softmax_logits).sum(dim=1).mean()

def focused_prob_loss(logits):
    """Negative log of the probability mass on the two label tokens, averaged over the batch.

    Author's note: this does not work in practice — the model quickly learns to put all of
    its probability mass on one (randomly chosen) of the two tokens, and never varies because
    the loss is 0.

    Args:
        logits (torch.Tensor): The raw output logits from the model (batch_size, vocab_size).

    Returns:
        torch.Tensor: The computed loss. It approaches 0 as the summed probability of the
        two label tokens (``pos_label_token_id`` and ``neg_label_token_id``) approaches 1.
    """
    probabilities = F.softmax(logits, dim=1)

    num_rows, _ = logits.shape
    # Summed probability of the two label tokens for every row, written into a
    # default-dtype buffer on the logits' device.
    correct_probs = torch.zeros(num_rows, device=logits.device)
    correct_probs.copy_(probabilities[:, pos_label_token_id] + probabilities[:, neg_label_token_id])

    return -torch.log(correct_probs).mean()

# --- Training configuration ---
# Hooked positions are counted back from the end of each sequence: the key token sits at
# (number of attended tokens - key_token_offset), so 1 means the last real token.
key_token_offset=1
# Number of extra positions before the key token that are hooked as well.
priortoks=0
# Seed both RNGs in play: torch's, and Python's `random`, which CustomBatchSampler uses to
# shuffle batches. torch.manual_seed alone does not make the batch order reproducible.
torch.manual_seed(123)
random.seed(123)
clear_hooks(model)

total_losses = []
token_losses = []
projection_losses = []
# These override the values assigned in the earlier configuration cell.
# NOTE: projection_weight is only referenced from commented-out code in the training loop.
num_epochs=1
projection_weight=10
#train_direction_in=False
#flip_direction=False
activation_storage = defaultdict(lambda: defaultdict(list))

learning_rate=5e-5
# Three parameter groups with identical hyper-parameters: token-loss-only parameters,
# projection-loss-only parameters, and parameters used by both. While both parameter lists
# are built with the same filter, the first two groups are empty.
optimizer = torch.optim.AdamW([
    {'params': list(set(token_loss_params) - set(projection_loss_params)), 'lr': learning_rate, 'eps': 1e-04, 'weight_decay': 0.01},
    {'params': list(set(projection_loss_params) - set(token_loss_params)), 'lr': learning_rate, 'eps': 1e-04, 'weight_decay': 0.01},
    {'params': list(set(token_loss_params).intersection(projection_loss_params)), 'lr': learning_rate, 'eps': 1e-04, 'weight_decay': 0.01}
])
#optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, eps=1e-04, weight_decay=0.01)
###scaler = GradScaler()
###lr_scheduler = get_linear_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=10, num_training_steps=(len(dataloader) * num_epochs),)

# One scheduler step per batch; the first 5% of the steps are warm-up.
total_steps = (len(dataloader)) * num_epochs
warmup_steps = total_steps // 20

def warmup_lambda(step):
    """Learning-rate multiplier for linear warm-up.

    Rises linearly from 0 at step 0 to 1 at ``warmup_steps`` and stays at 1 afterwards.
    ``warmup_steps`` is ``total_steps // 20`` and is therefore 0 for runs with fewer than 20
    steps; in that case there is no warm-up and the multiplier is 1.0, instead of dividing
    by zero.
    """
    if warmup_steps <= 0:
        return 1.0
    return min(1.0, step / warmup_steps)

# Learning-rate schedule: LambdaLR warm-up (multiplier from warmup_lambda) up to the
# `warmup_steps` milestone, then cosine annealing towards eta_min over the remaining
# total_steps - warmup_steps steps. SequentialLR chains the two at the milestone.
warmup_scheduler = LambdaLR(optimizer, lr_lambda=warmup_lambda)
cosine_scheduler = CosineAnnealingLR(optimizer, T_max=(total_steps - warmup_steps), eta_min=1e-6)

lr_scheduler = SequentialLR(optimizer, schedulers=[warmup_scheduler, cosine_scheduler], milestones=[warmup_steps])

# Training loop. For every batch two losses are back-propagated separately:
#   * a token loss: soft 50/50 target over the two answer letters at each sequence's key token;
#   * a projection loss: cosine similarity between the hooked activations and the stored
#     direction for each hooked layer/position.
# GradScaler/autocast mixed precision and gradient clipping are not used in this loop.
#
# Autograd is switched off globally when the model is loaded (torch.set_grad_enabled(False)),
# and backward() cannot run without it, so it is enabled for the duration of training only.
with torch.enable_grad():
    for epoch in range(num_epochs):
        model.train()
        epoch_token_losses = []
        epoch_projection_losses = []
        print(f"Epoch {epoch+1}/{num_epochs}")
        for i, batch in enumerate(tqdm(dataloader)):
            optimizer.zero_grad()
            # Per-sequence index of the key token: number of attended tokens minus the offset.
            last_token_positions = (batch['attention_mask'].sum(dim=1) - key_token_offset).tolist()
            # For each trained layer, the positions to hook in every sequence: from
            # `priortoks` tokens before the key token up to the key token itself.
            layers_positions = {}
            for layer in layers_to_train:
                layers_positions[layer] = [[pos-back for back in range(priortoks,-1,-1)] for pos in last_token_positions]

            activation_storage = defaultdict(lambda: defaultdict(list))
            attach_activation_hooks(model, layers_positions, activation_storage, "end")
            inputs = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**inputs)
            # Batches are right-padded, so column -1 is a pad position for every sequence
            # shorter than the longest one in the batch. Read the logits at each sequence's
            # own key token instead, i.e. the same positions the activation hooks use.
            row_index = torch.arange(outputs.logits.size(0), device=outputs.logits.device)
            key_index = torch.tensor(last_token_positions, device=outputs.logits.device)
            logits = outputs.logits[row_index, key_index, :]
            loss = soft_target_cross_entropy(logits)

            skip_token_loss = loss.item() < 0.7 * 5 # no need to do this if you're already close to the theoretical min of -ln(0.5); save time and prevent overfitting

            if not skip_token_loss:
                loss*=0.1
                # retain_graph: the projection loss below back-propagates through the same forward pass.
                loss.backward(retain_graph=True)
                # NOTE(review): both set differences used in this loop are empty while the two
                # parameter lists are built with the same filter.
                for param in set(projection_loss_params) - set(token_loss_params):
                    param.grad = None

            predicted_labels = torch.argmax(logits, dim=-1)

            if torch.isnan(loss) or torch.isinf(loss):
                print(f"NaN or Inf detected in model output (batch {i})")
                raise SystemExit

            cum_projection_loss = 0
            for layer, positions in activation_storage.items():
                for pos, tensor_list in positions.items():#each of these is a list of batchsize d-embed tensors for a given position
                    batch_tensor = torch.stack(tensor_list, dim=0)
                    if torch.isnan(batch_tensor).any() or torch.isinf(batch_tensor).any():
                        print(f"NaN or Inf detected in batch tensor (layer {layer}, batch {i})")
                        raise SystemExit
                    direction = (directions[layer][pos] * (-1 if flip_direction else 1)).to(batch_tensor.device)
                    if torch.isnan(directions[layer][pos]).any() or torch.isinf(directions[layer][pos]).any():
                        print(f"NaN or Inf detected in direction (layer {layer}, batch {i})")
                        raise SystemExit
                    # Cosine similarity between each activation and the (possibly flipped) direction.
                    projection = (batch_tensor @ direction) / (torch.norm(direction) * torch.norm(batch_tensor, dim=1))
                    if torch.isnan(projection).any():
                        print(f"NaN in layer {layer}")
                        raise SystemExit
                    # (1 - cos) / 2 is 0 for perfect alignment and 1 for perfect anti-alignment; it is scaled by 0.1 here.
                    if train_direction_in: projection_loss = ((1 - projection) / 2) * .1
                    else: projection_loss = torch.abs(projection) # 0 if no projection, 1 if perfectly (anti-)aligned
                    cum_projection_loss += projection_loss.mean() / (len(layers_to_train) * len(positions.items())) #average over batch, as with token loss

            cum_projection_loss.backward()
            if not skip_token_loss:
                for param in set(token_loss_params) - set(projection_loss_params):
                    param.grad = None

            # Stop before the update is applied if any gradient is NaN, so the weights are
            # not modified by a bad step.
            for name, param in model.named_parameters():
                if param.grad is not None and torch.isnan(param.grad).any():
                    print(f"NaN detected in gradient of {name} after unscaling")
                    raise SystemExit  # Or handle the error gracefully

            optimizer.step()
            if (i+1)%40 == 0:
                print_gradient_summary_stats(model)
            lr_scheduler.step()

            # Remove this batch's hooks and captured activations before the next batch.
            for block in get_blocks(model): block._forward_hooks.clear()
            activation_storage.clear()

            epoch_token_losses.append(loss.item())
            epoch_projection_losses.append(cum_projection_loss.item())
            if (i+1)%5 == 0:
                print(f"Prediction: {model.tokenizer.convert_ids_to_tokens(predicted_labels.tolist())}")
                print(f"Token Loss: {loss.item():.4f}, Projection loss: {cum_projection_loss.item():.4f}")
        token_losses.append(sum(epoch_token_losses) / len(epoch_token_losses))
        projection_losses.append(sum(epoch_projection_losses) / len(epoch_projection_losses))
        print(f"Avg Token Prediction Loss: {token_losses[epoch]:.4f}")
        print(f"Avg Projection Loss: {projection_losses[epoch]:.4f}")
import matplotlib.pyplot as plt

# One point per epoch for each of the two epoch-averaged losses.
fig, ax = plt.subplots(figsize=(12, 8))
ax.plot(token_losses, label='Token Prediction Loss')
ax.plot(projection_losses, label='Projection Loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('Loss Trends over Epochs')
ax.legend()
plt.show()


In [None]:
# Upload the fine-tuned model to the Hugging Face Hub repository named below. Only `model`
# is pushed here; the tokenizer attached as model.tokenizer is not uploaded by this call.
model.push_to_hub("pansamuel/llama_13b_honest_tune_in")