In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
import torch
from typing import Optional, Tuple, List, Union
import types

In [23]:
# Checkpoint id, named once so the tokenizer and the model are always loaded
# from the same checkpoint.
MODEL_NAME = "google/gemma-2b-it"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype=torch.bfloat16,  # weights are loaded in bfloat16
)

config.json:   0%|          | 0.00/627 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/13.5k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/67.1M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

In [24]:
def generate_with_edit(self, edit_direction=None, edit_coefficient=None, **generate_kwargs):
    """Call ``self.generate`` with a temporary edit direction and coefficient.

    ``edit_direction`` and ``edit_coefficient`` are stored as attributes on
    ``self`` for the duration of the ``generate`` call, so that code invoked
    by ``generate`` can read them. Whatever values were set before the call
    (or ``None`` if the attributes did not exist) are put back afterwards,
    even when ``generate`` raises, so the edit never leaks into later calls.

    Args:
        edit_direction: Value stored as ``self.edit_direction`` during the
            call. ``None`` (the default) means no edit.
        edit_coefficient: Value stored as ``self.edit_coefficient`` during
            the call.
        **generate_kwargs: Forwarded unchanged to ``self.generate``.

    Returns:
        Whatever ``self.generate(**generate_kwargs)`` returns.
    """
    previous_direction = getattr(self, "edit_direction", None)
    previous_coefficient = getattr(self, "edit_coefficient", None)
    self.edit_direction = edit_direction
    self.edit_coefficient = edit_coefficient
    try:
        return self.generate(**generate_kwargs)
    finally:
        # Restore the prior state so a later plain generate() is not steered.
        self.edit_direction = previous_direction
        self.edit_coefficient = previous_coefficient
# Bind the function to this model instance so it can be called as
# model.generate_with_edit(...), with `model` passed as `self`.
model.generate_with_edit = types.MethodType(generate_with_edit, model)

In [2]:
# The direction files are stored at /net/projects/veitch/geometry_llms/directions/intervention/sentiment_{MODEL_NAME}.pt.
# Each file should contain a single tensor of shape (num_directions, hidden_size).
# I’ve computed the same five directions for three models: gemma-2b, Mistral-7B-v0.2, and Mistral-7B-Instruct-v0.2.
# The recommended delta values for each model are, respectively: 20, 200, and 500 (add for positive, subtract for negative).

In [79]:
# Replacement causal-LM forward pass. It follows the usual decoder -> lm_head ->
# optional loss flow and adds one step: an optional shift of the final hidden
# states along `self.edit_direction` before they are projected to logits.
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> "Union[Tuple, CausalLMOutputWithPast]":
    r"""Run the decoder, optionally steer its final hidden states, and compute logits.

    If ``self.edit_direction`` is set (not ``None``), then
    ``edit_direction * self.edit_coefficient`` is added to the decoder's
    output hidden states before ``self.lm_head`` is applied. The direction is
    moved to the device and dtype of the hidden states, and the converted
    tensor is stored back on ``self.edit_direction`` so later calls reuse it.
    A model with no ``edit_direction`` attribute is treated as unedited.

    Args:
        input_ids, attention_mask, position_ids, past_key_values,
        inputs_embeds, use_cache, cache_position:
            Passed through unchanged to ``self.model``.
        labels (`torch.LongTensor`, *optional*):
            When given, a cross-entropy loss is computed between the logits at
            position ``n`` and the label at position ``n + 1``.
        output_attentions, output_hidden_states, return_dict (`bool`, *optional*):
            Fall back to the corresponding ``self.config`` values when ``None``.

    Returns:
        A ``CausalLMOutputWithPast`` when ``return_dict`` resolves to true;
        otherwise a tuple ``(logits, *outputs[1:])``, prefixed with the loss
        when ``labels`` was given.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]
    # getattr: this forward may run before any edit state was attached to the model.
    edit_direction = getattr(self, "edit_direction", None)
    if edit_direction is not None:
        # Match the hidden states rather than assuming bfloat16, so the sum keeps
        # the dtype that lm_head expects whatever precision the model was loaded in.
        edit_direction = edit_direction.to(device=hidden_states.device, dtype=hidden_states.dtype)
        # Keep the converted tensor so subsequent decoding steps skip the conversion.
        self.edit_direction = edit_direction
        hidden_states = hidden_states + edit_direction * self.edit_coefficient
    logits = self.lm_head(hidden_states)
    logits = logits.float()
    loss = None
    if labels is not None:
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = torch.nn.CrossEntropyLoss()
        shift_logits = shift_logits.view(-1, self.config.vocab_size)
        shift_labels = shift_labels.view(-1)
        # Enable model parallelism
        shift_labels = shift_labels.to(shift_logits.device)
        loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )

In [80]:
# Bind the `forward` function defined above to this model instance, so that
# model.forward(...) runs it with `model` as `self`.
model.forward = types.MethodType(forward, model)

In [81]:
EDIT_FOLDER = "/net/projects/veitch/geometry_llms/directions/intervention/"
# Deserialize onto the CPU first so the file loads regardless of which device it
# was saved from, then move it to the model's device. torch.load is pickle-based;
# weights_only=True restricts unpickling to tensor data.
edit_tensor = torch.load(
    EDIT_FOLDER + 'sentiment_gemma-2b-it.pt',
    map_location="cpu",
    weights_only=True,
).to(model.device)
edit_tensor.shape, edit_tensor.dtype

(torch.Size([5, 2048]), torch.float32)

In [89]:
# Prompt, tokenized and moved to the model's device.
input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt").to(model.device)

# Steering setup used by the edited runs: the first row of the loaded direction
# tensor, and the magnitude it is scaled by.
edit_direction = edit_tensor[0]
edit_coefficient = 10

# Clear any direction currently stored on the model.
model.edit_direction = None

In [90]:
# Baseline (unedited) generation. Going through generate_with_edit with no
# direction sets model.edit_direction to None for this call, so the result does
# not depend on whatever an earlier edited run left on the model.
outputs = model.generate_with_edit(max_length=100, **input_ids)
print(tokenizer.decode(outputs[0]))

<bos>Write me a poem about Machine Learning.

Machines, they weave and they learn,
From the data, they discern.
Algorithms, a symphony,
Unleash the power of the machine.

Data as the canvas, a masterpiece,
Machine paints, a new perspective.
From the past, the future takes flight,
With the wisdom of machines, day and night.

Algorithms, a dance of the mind,
Unleash the power of the machine.



In [91]:
# Edited generation: steer along the chosen direction with a positive coefficient.
outputs = model.generate_with_edit(
    edit_direction=edit_direction,
    edit_coefficient=edit_coefficient,
    max_length=100,
    **input_ids,
)
print(tokenizer.decode(outputs[0]))

<bos>Write me a poem about Machine Learning.

Machines, vast and deep, with algorithms bright,
Unravel patterns, day and night.
From data's flow, they learn and adapt,
A symphony of algorithms, a wondrous fact.

With each iteration, they refine their art,
Solving problems, fulfilling every part.
From medical scans to financial trends,
They weave insights, where once there were none.

But with great power comes a moral sway,
Bias


In [92]:
# Edited generation: same direction, with the sign of the coefficient flipped.
outputs = model.generate_with_edit(
    edit_direction=edit_direction,
    edit_coefficient=-edit_coefficient,
    max_length=100,
    **input_ids,
)
print(tokenizer.decode(outputs[0]))

<bos>Write me a poem about Machine Learning.

Machines, they learn and they grow,
Algorithms that dance, a symphony.
Data as their canvas, they paint,
Unleashing the power of the human brain.

From medical diagnosis to financial trade,
They predict, they forecast, they pave the way.
Unveiling the secrets of the unknown,
Unleashing the potential of the unknown.

But with power comes responsibility,
A responsibility to be responsible.
