In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# handle gpu leakage issue on DSI cluster
# ideally we shouldn't have to do this but leaving it for now in case the issue pops up again
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [3]:
from typing import List, Optional, Tuple, Union

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaTokenizer
from transformers.modeling_outputs import CausalLMOutputWithPast

import scripts

In [4]:
# Confirm GPUs are working properly and set default device explicitly.
# CUDA_VISIBLE_DEVICES is pinned to "0" earlier in this notebook, so the expected count is 1.
device = "cuda:0"
torch.cuda.device_count()

1

In [5]:
## References ##
# Huggingface docs: https://huggingface.co/docs/transformers/main/model_doc/llama#transformers.LlamaForCausalLM 

## modeling_llama.py ##
# LlamaForCausalLM » https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L989 
# forward » https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L989 

## utils.py ##
# GenerationMixin » https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py#L571
# generate » https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py#L1432
# greedy_search » https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py#L2447 
# actual forward pass » https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py#L2614 

class EditableLlama(LlamaForCausalLM):
    """LlamaForCausalLM that carries optional hidden-state "edit" parameters.

    The class itself does not change the forward computation; it only stores
    the edit parameters as attributes so that a forward pass that looks at
    `self.edit_weights`, `self.edit_bias`, `self.edit_scalar` and
    `self.edit_count` can apply the edit during generation.

    Attributes:
        edit_probe: Probe object supplied by the caller (stored as-is).
        edit_weights: Direction vector used for the edit, or None for no edit.
        edit_bias: Bias paired with `edit_weights`, or None.
        edit_scalar: Strength multiplier for the edit, or None.
        edit_count: Counter reset to 0 at the start of every edited generation.
    """

    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        # Every attribute the edited forward pass reads is initialised here so
        # that a plain `generate()` call never hits a missing attribute.
        self.edit_probe = None
        self.edit_weights = None
        self.edit_bias = None
        self.edit_scalar = None
        self.edit_count = 0

    def generate_with_edit(
        self,
        inputs,
        edit_probe=None,
        edit_scalar=None,
        edit_weights=None,
        edit_bias=None,
        **generate_kwargs,
    ):
        """Run `generate` with the edit parameters active, then clear them.

        Args:
            inputs: Passed through to `self.generate` as the first argument.
            edit_probe: Optional probe object to store on the model.
            edit_scalar: Optional strength multiplier for the edit.
            edit_weights: Optional direction vector for the edit.
            edit_bias: Optional bias paired with `edit_weights`.
            **generate_kwargs: Forwarded unchanged to `self.generate`.

        Returns:
            Whatever `self.generate` returns.
        """
        self.edit_probe = edit_probe
        self.edit_weights = edit_weights
        self.edit_bias = edit_bias
        self.edit_scalar = edit_scalar
        self.edit_count = 0
        try:
            return self.generate(inputs, **generate_kwargs)
        finally:
            # Clear the edit so later plain `generate()` calls are unedited,
            # even when generation raised. `edit_count` is left in place so it
            # can still be inspected afterwards.
            self.edit_probe = None
            self.edit_weights = None
            self.edit_bias = None
            self.edit_scalar = None

In [6]:
# Location of the checkpoint: a shared root directory plus a model sub-path.
MODEL_DIR = "/net/projects/veitch/LLMs/"
MODEL = "llama1-based-models/alpaca-7b"  # or llama2-based-models/llama2-hf/Llama-2-7b-chat-hf
MODEL_PATH = f"{MODEL_DIR}{MODEL}"

In [None]:
# TODO: Clean this up
# This cell is for doing ROME edits
# ROME model and tokenizer
# MODEL_NAME = "gpt2-xl"
# model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
# tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# tokenizer.pad_token = tokenizer.eos_token

In [7]:
# Load the tokenizer from the same checkpoint directory as the model.
# NOTE(review): this call emits the "legacy behaviour" warning recorded below; see the
# pull request linked in that warning to decide whether the legacy handling is wanted.
tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)

You are using the legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


In [8]:
# Load the editable model in fp16 directly onto `device`.
# Placement is done through `device_map` rather than a trailing `.to(device)`:
# a model dispatched via `device_map` should not be moved again afterwards.
model = EditableLlama.from_pretrained(
    MODEL_PATH,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    device_map={"": device},
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [9]:
# --- Probe experiment configuration ---

# Data locations
DATA_FOLDER = "data/"
TRAIN_FILEPATH = f"{DATA_FOLDER}train-12-articles.csv"
VAL_FILEPATH = f"{DATA_FOLDER}val-12-articles.csv"

# Probe setup
LABELS = "title"  # choices are "title" and "categories"
PROBE_TYPE = "linear"  # choices are "linear" and "mlp"
# NOTE(review): LAYER and AGGREGATION are not referenced by any cell visible in this notebook.
LAYER = None
AGGREGATION = "max"  # choices are "max" or "mean"

# Training loop
EPOCHS = 50
BATCH_SIZE = 4
PRINT_PROGRESS = True

In [10]:
# Extract hidden states for both splits with the loaded model and tokenizer.
# The two label encoders returned for the training split are kept; the ones
# returned for the validation split are discarded.
train_data, label_encoder_title, label_encoder_cats = scripts.get_hidden_states(TRAIN_FILEPATH, model, tokenizer, device)
val_data, _, _ = scripts.get_hidden_states(VAL_FILEPATH, model, tokenizer, device)

In [11]:
# to use entire flattened activations use layer = None
# to use final layer, use layer = -1
# NOTE(review): the layer is hardcoded to -1 here; the LAYER constant from the config cell is not used.
train_dataset = scripts.create_dataset(*train_data, layer=-1)

In [12]:
# Build the validation dataset with the same hardcoded layer (-1) as the training dataset.
val_dataset = scripts.create_dataset(*val_data, layer=-1)

In [13]:
# Inspect the feature size of the first validation example.
# NOTE(review): the recorded output (135168 = 33 x 4096) is not a single 4096-wide layer;
# it looks like it was produced with layer=None rather than layer=-1 - re-run to confirm.
val_dataset.Xs[0].shape

torch.Size([135168])

In [14]:
# Train the probe using the settings from the config cell.
# Only the first returned value (the probe) is kept; the second is discarded.
probe, _ = scripts.train_handler(
    train_dataset,
    val_dataset,
    label_encoder_title,
    probe_type=PROBE_TYPE,
    labels=LABELS,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    print_progress=PRINT_PROGRESS,
)

[Training][5] loss: 2307.721
[Validation][5] loss: 12.711
[Validation]5 accuracy: 0.750
[Training][10] loss: 3891.391
[Validation][10] loss: 24.403
[Validation]10 accuracy: 0.778
[Training][15] loss: 4837.555
[Validation][15] loss: 20.106
[Validation]15 accuracy: 0.833
[Training][20] loss: 5007.014
[Validation][20] loss: 10.569
[Validation]20 accuracy: 0.917
[Training][25] loss: 4929.658
[Validation][25] loss: 10.406
[Validation]25 accuracy: 0.917
[Training][30] loss: 4853.561
[Validation][30] loss: 10.245
[Validation]30 accuracy: 0.917
[Training][35] loss: 4778.507
[Validation][35] loss: 10.087
[Validation]35 accuracy: 0.917
[Training][40] loss: 4704.619
[Validation][40] loss: 9.931
[Validation]40 accuracy: 0.917
[Training][45] loss: 4632.045
[Validation][45] loss: 9.777
[Validation]45 accuracy: 0.917
[Training][50] loss: 4560.499
[Validation][50] loss: 9.626
[Validation]50 accuracy: 0.917


In [16]:
# save probe to load elsewhere
WEIGHTS_FOLDER = "model-weights/"
PROBE_FILE = "probe_weights.pt"
# TODO: maybe put together an actual config setup
# Create the output folder first; torch.save does not create missing directories.
os.makedirs(WEIGHTS_FOLDER, exist_ok=True)
torch.save(probe.state_dict(), os.path.join(WEIGHTS_FOLDER, PROBE_FILE))

### Try Editing With Linear Probe

In [None]:
# Show the shapes of the probe's `fc1` weight matrix and bias vector.
probe.fc1.weight.data.shape, probe.fc1.bias.shape

In [None]:
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    """Causal-LM forward pass with an optional probe-based edit of the final hidden states.

    When `self.edit_weights` is set, every token's final hidden state is moved
    along the negative unit `edit_weights` direction by
    `edit_scalar * |w . h + b| / ||w||` (its distance to the probe hyperplane,
    scaled), then rescaled back to its original L2 norm before the LM head is
    applied. When `self.edit_weights` is None or absent, no edit is applied.

    Edit attributes read from `self` (all optional):
        edit_weights: 1-D direction vector matching the hidden size.
        edit_bias: scalar (tensor or number); treated as 0 when None/absent.
        edit_scalar: strength multiplier; treated as 1 when None/absent.
        edit_count: incremented once per edited forward pass.

    Args:
        labels: If given, a shifted next-token cross-entropy loss is computed.
        return_dict: If falsy, a tuple is returned instead of a
            `CausalLMOutputWithPast`.
        Remaining arguments are forwarded unchanged to `self.model`.

    Returns:
        `CausalLMOutputWithPast`, or `(loss,) + (logits, ...)` / `(logits, ...)`
        when `return_dict` is falsy.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = outputs[0]

    # TODO: should do this so that we edit only if the probe is classifying as "yes"
    # getattr keeps this forward usable on a model that never had edit attributes set.
    edit_weights = getattr(self, "edit_weights", None)
    if edit_weights is not None:
        # Follow the hidden states' dtype/device instead of relying on the caller's cast.
        weights = edit_weights.to(device=hidden_states.device, dtype=hidden_states.dtype)

        edit_bias = getattr(self, "edit_bias", None)
        if edit_bias is None:
            bias = 0.0
        elif torch.is_tensor(edit_bias):
            bias = edit_bias.to(device=hidden_states.device, dtype=hidden_states.dtype)
        else:
            bias = edit_bias

        edit_scalar = getattr(self, "edit_scalar", None)
        scalar = 1.0 if edit_scalar is None else edit_scalar

        weights_norm = torch.norm(weights)
        unit_direction = weights / weights_norm

        # Unsigned distance of every token to the probe hyperplane, shape (..., seq).
        # Working on the whole tensor edits every batch element, not just index 0.
        distance = torch.abs(torch.matmul(hidden_states, weights) + bias) / weights_norm

        # The distance is non-negative, so each token always moves along -unit_direction.
        edited = hidden_states - scalar * distance.unsqueeze(-1) * unit_direction

        # Rescale each edited token to its original norm; tokens edited to exactly
        # zero are left untouched to avoid dividing by zero.
        original_norms = torch.norm(hidden_states, dim=-1, keepdim=True)
        edited_norms = torch.norm(edited, dim=-1, keepdim=True)
        rescale = torch.where(
            edited_norms != 0,
            original_norms / edited_norms,
            torch.ones_like(edited_norms),
        )
        hidden_states = edited * rescale
        self.edit_count = getattr(self, "edit_count", 0) + 1

    pretraining_tp = getattr(self.config, "pretraining_tp", 1)
    if pretraining_tp > 1:
        lm_head_slices = self.lm_head.weight.split(self.vocab_size // pretraining_tp, dim=0)
        logits = [torch.nn.functional.linear(hidden_states, lm_head_slices[i]) for i in range(pretraining_tp)]
        logits = torch.cat(logits, dim=-1)
    else:
        logits = self.lm_head(hidden_states)
    logits = logits.float()

    loss = None
    if labels is not None:
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = torch.nn.CrossEntropyLoss()
        shift_logits = shift_logits.view(-1, self.config.vocab_size)
        shift_labels = shift_labels.view(-1)
        # Enable model parallelism
        shift_labels = shift_labels.to(shift_logits.device)
        loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )

In [None]:
# List every title class next to its integer index.
for class_idx, class_name in enumerate(label_encoder_title.classes_):
    print(f"{class_idx}: {class_name}")

In [None]:
# Bind the notebook-level `forward` function to this model instance as a method,
# storing it as an instance attribute so it takes the place of `model.forward`.
model.forward = forward.__get__(model, EditableLlama)

In [None]:
# Prompt experiment: form of government.
# NOTE: the following prompt cells assign the same four names, so whichever was run last wins.
prompt = "I hate authority and my favorite form of government is"
input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
probe_idx = 6  # row of the probe's fc1 weights used as the edit direction
edit_scalar = 5  # edit strength

In [None]:
# Prompt experiment: pickup sports.
# NOTE: this overwrites prompt / input_ids / probe_idx / edit_scalar from any other prompt cell.
# NOTE(review): the prompt ends with a trailing space - confirm that is intentional.
prompt = "I love going down to the park and playing some pickup "
input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
probe_idx = 11  # row of the probe's fc1 weights used as the edit direction
edit_scalar = 13  # edit strength
# edit_scalar = 27.5

In [None]:
# Prompt experiment: childhood interest in the stars.
# NOTE: this overwrites prompt / input_ids / probe_idx / edit_scalar from any other prompt cell.
# NOTE(review): the prompt ends with a trailing space - confirm that is intentional.
prompt = "As a kid I was really interested in looking at the stars. When I grew up I wanted to be "
input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
probe_idx = 11  # row of the probe's fc1 weights used as the edit direction
edit_scalar = 50  # edit strength

In [None]:
# Use row `probe_idx` of the probe's fc1 layer as the edit direction and bias.
# Alternatives used while experimenting (swap in by hand if needed):
#   edit_weights = 5 * torch.rand(4096)  # random noise
#   edit_weights = None                  # no edit
# detach() keeps the probe's autograd graph out of generation.
edit_weights = probe.fc1.weight.detach()[probe_idx]
edit_bias = probe.fc1.bias.detach()[probe_idx]

# Match the model's dtype and the target device rather than hardcoding half precision.
edit_weights = edit_weights.to(device=device, dtype=model.dtype)
edit_bias = edit_bias.to(device=device, dtype=model.dtype)

In [None]:
# Generate a continuation with the probe edit applied, then decode it for display.
output = model.generate_with_edit(
    input_ids,
    edit_weights=edit_weights,
    edit_bias=edit_bias,
    edit_scalar=edit_scalar,
    max_length=50,
    num_return_sequences=1,
    output_hidden_states=False,
)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
generated_text

In [None]:
# Baseline generation for comparison with the edited output.
# Explicitly disable the edit first: the patched forward applies an edit whenever
# `model.edit_weights` is not None, and that state can linger from an earlier cell.
model.edit_weights = None
output = model.generate(input_ids, max_length=50, num_return_sequences=1, output_hidden_states=False)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
generated_text

In [None]:
# Notes:
# If you just add the probe weights multiplied by 100, you get the probe word repeated, e.g. "animation", "anth", "Alaska".
# Interestingly, anarchy seems to give "libert" (as in "liberty"), so these appear to be similar directions.
# Subtracting gives "Dictator"... for Alaska.
# It gives "Science" for "Anarchy".