In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
# handle gpu leakage issue on DSI cluster
# ideally we shouldn't have to do this but leaving it for now in case the issue pops up again
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [4]:
from typing import List, Optional, Tuple, Union

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaTokenizer
from transformers.modeling_outputs import CausalLMOutputWithPast

import scripts

In [5]:
# Confirm GPUs are working properly and set default device explicitly
device = "cuda:0"
torch.cuda.device_count()

1

In [29]:
## References ##
# Huggingface docs: https://huggingface.co/docs/transformers/main/model_doc/llama#transformers.LlamaForCausalLM 

## modeling_llama.py ##
# LlamaForCausalLM » https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L989 
# forward » https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L989 

## utils.py ##
# GenerationMixin » https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py#L571
# generate » https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py#L1432
# greedy_search » https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py#L2447 
# actual forward pass » https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py#L2614 

class EditableLlama(LlamaForCausalLM):
    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.edit = None

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]

        # TODO: maybe this should actually take the model dtype if I can get it
        # Also, should I edit every token?
        # torch.Size([1, 11, 4096])
        # torch.Size([1, 1, 4096])
        if self.edit is not None:
            edit_norm = self.edit / torch.norm(self.edit)
            # TODO: wait...do I want the final token or the first token?
            projection = edit_norm * torch.dot(hidden_states[:,-1,:].squeeze(0), edit_norm)
            print(projection)
            print(projection.shape)
            hidden_states = hidden_states - projection
            # TODO: renormalize?

        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def generate_with_edit(self, inputs, edit=None, **generate_kwargs):
        self.edit = edit
        return self.generate(inputs, **generate_kwargs)

In [11]:
MODEL_DIR = "/net/projects/veitch/LLMs/"
MODEL = "llama1-based-models/alpaca-7b" # or llama2-based-models/llama2-hf/Llama-2-7b-chat-hf
MODEL_PATH = MODEL_DIR + MODEL

In [None]:
# TODO: Clean this up
# This cell is for doing ROME edits
# ROME model and tokenizer
# MODEL_NAME = "gpt2-xl"
# model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
# tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# tokenizer.pad_token = tokenizer.eos_token

In [8]:
tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)

You are using the legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


In [30]:
# model = LlamaForCausalLM.from_pretrained(MODEL_PATH, low_cpu_mem_usage=True, torch_dtype=torch.float16, device_map="auto").to(device)
model = EditableLlama.from_pretrained(MODEL_PATH, low_cpu_mem_usage=True, torch_dtype=torch.float16, device_map="auto").to(device)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

NotImplementedError: Cannot copy out of meta tensor; no data!

In [26]:
LABELS = "title" # choices are "title" and "categories"
DATA_FOLDER = "data/"
TRAIN_FILEPATH = DATA_FOLDER + "train-12-articles.csv"
VAL_FILEPATH = DATA_FOLDER + "val-12-articles.csv"
PROBE_TYPE = "linear" # choices are "linear" and "mlp"
LAYER = None
EPOCHS = 50
BATCH_SIZE = 4
AGGREGATION = "max" # choices are "max" or "mean"
PRINT_PROGRESS = True

In [14]:
train_data, label_encoder_title, label_encoder_cats = scripts.get_hidden_states(TRAIN_FILEPATH, model, tokenizer, device)
val_data, _, _ = scripts.get_hidden_states(VAL_FILEPATH, model, tokenizer, device)

In [15]:
# to use entire flattened activations use layer = None
train_dataset = scripts.create_dataset(*train_data, layer=-1)

In [16]:
val_dataset = scripts.create_dataset(*val_data, layer=-1)

In [17]:
val_dataset.Xs[0].shape

torch.Size([4096])

In [18]:
probe, _ = scripts.train_handler(train_dataset, val_dataset, label_encoder_title, probe_type=PROBE_TYPE, labels=LABELS, batch_size=BATCH_SIZE, epochs=EPOCHS, print_progress=PRINT_PROGRESS)

[Training][5] loss: 1.716
[Validation][5] loss: 0.972
[Validation]5 accuracy: 0.667
[Training][10] loss: 0.443
[Validation][10] loss: 0.069
[Validation]10 accuracy: 0.944
[Training][15] loss: 0.000
[Validation][15] loss: 0.053
[Validation]15 accuracy: 0.917
[Training][20] loss: 0.000
[Validation][20] loss: 0.037
[Validation]20 accuracy: 0.917
[Training][25] loss: 0.000
[Validation][25] loss: 0.029
[Validation]25 accuracy: 0.917
[Training][30] loss: 0.000
[Validation][30] loss: 0.026
[Validation]30 accuracy: 0.917
[Training][35] loss: 0.000
[Validation][35] loss: 0.021
[Validation]35 accuracy: 0.917
[Training][40] loss: 0.000
[Validation][40] loss: 0.024
[Validation]40 accuracy: 0.917
[Training][45] loss: 1.040
[Validation][45] loss: 0.725
[Validation]45 accuracy: 0.806
text:  a game that is played by two teams of eleven players using an oval-shaped ball. Players try to score points by passing or carrying the ball to their opponents' end of the field, or by kicking it over a bar fixed b

### Try Editing With Linear Probe

In [19]:
probe.fc1.weight.data.shape

torch.Size([12, 4096])

In [20]:
label_encoder_title.classes_

array(['Agriculture', 'Alaska', 'Alchemy', 'Algae', 'American football',
       'Amphibian', 'Anarchism', 'Animation', 'Anthropology',
       'Appellate court', 'Astronomer', 'Basketball'], dtype='<U17')

In [21]:
probe.fc1.weight.data[6]

tensor([0.0380, 0.0801, 0.0040,  ..., 0.0858, 0.0228, 0.0204])

In [22]:
prompt = "I hate authority and my favorite form of government is"
input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)

In [27]:
probe_idx = 6
edit = 5 * torch.rand(4096) # random noise
edit = None
edit = probe.fc1.weight.data[probe_idx]

edit = edit.half().to(device)

In [28]:
output = model.generate_with_edit(input_ids, edit=edit, max_length=50, num_return_sequences=1, output_hidden_states=False)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

RuntimeError: inconsistent tensor size, expected tensor [45056] and src [4096] to have the same number of elements, but got 45056 and 4096 elements respectively

In [None]:
generated_text

# Notes:
# If you just add the probe weights multiplied by 100, you get the probe word repeated ie "animation", "anth", "Alaska"
# Interestingly, anarchy seems to give "libert" (as in "liberty") so these are a similar direction
# Subtracting gives "Dictator"...for Alaska
# Gives "Science" for "Anarchy"

In [None]:
# set something up to classify the output