In [2]:
import copy, json, random, re
import logging
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence
import pandas as pd
import matplotlib.pyplot as plt
from plotnine import ggplot, aes, geom_line, theme_minimal
from matplotlib.ticker import MaxNLocator
plt.rcParams.update({'font.size': 20, 'font.family': 'Sans'})

import torch
import transformers
from datasets import Dataset, load_dataset
from transformers import Trainer

from pyreft import (
    TaskType,
    get_reft_model,
    ReftConfig,
    ReftTrainerForCausalLM, 
    ReftDataCollator,
    ReftSupervisedDataset,
    make_last_position_supervised_data_module,
    ConsreftIntervention,
    LoreftIntervention
)

# Label id conventionally skipped by PyTorch's cross-entropy loss.
# NOTE(review): not referenced by any other cell visible in this notebook.
IGNORE_INDEX = -100

# Pick the accelerator once; every later cell reuses this single definition.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Instruction-style prompt; the single %s slot receives the instruction text.
prompt_no_input_template = """Below is an instruction that \
describes a task. Write a response that appropriately \
completes the request.

### Instruction:
%s

### Response:
"""

# Load the tokenizer and the base causal LM from the same checkpoint.
model_name_or_path = "yahma/llama-7b-hf"
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name_or_path, model_max_length=2048, padding_side="right", use_fast=False)
# Reuse the unknown token as the padding token.
tokenizer.pad_token = tokenizer.unk_token

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.bfloat16, device_map=device)

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
# Index of the transformer layer whose "block_output" component is intervened on.
TARGET_LAYER = 15

# get reft model: a single rank-4 LoReFT intervention sized to the model's
# hidden dimension, attached at TARGET_LAYER.
loreft_intervention = LoreftIntervention(
    embed_dim=model.config.hidden_size,
    low_rank_dimension=4,
)
representation_spec = {
    "layer": TARGET_LAYER,
    "component": "block_output",
    "intervention": loreft_intervention,
}
reft_config = ReftConfig(representations=representation_spec)
reft_model = get_reft_model(model, reft_config)

# Report how many parameters are trainable.
reft_model.print_trainable_parameters()

trainable intervention params: 32,772 || trainable model params: 0
model params: 6,738,415,616 || trainable%: 0.00048634578018881287


In [4]:
# Load the "d0rj/wikisum" dataset. `load_dataset` is imported in the
# notebook's top import cell (next to `Dataset`) so that all dependencies are
# declared in one place.
ds = load_dataset("d0rj/wikisum")

In [5]:
# Collect the first 10 rows of the train split as [article, summary] pairs.
train_split = ds["train"]
training_examples = [
    [train_split[idx]["article"], train_split[idx]["summary"]]
    for idx in range(10)
]

In [6]:
# Alias for the [article, summary] pairs used as training data.
adapt_responses = training_examples

# Each article is wrapped in the instruction template; its summary is the target.
# NOTE(review): the token counts printed at the end of this notebook show
# several articles longer than the tokenizer's model_max_length (2048) --
# confirm how the data module handles over-length prompts.
train_prompts = [prompt_no_input_template % pair[0] for pair in adapt_responses]
train_targets = [pair[1] for pair in adapt_responses]
data_module = make_last_position_supervised_data_module(
    tokenizer,
    model,
    train_prompts,
    train_targets,
    nonstop=False,
)

# train
# Batch size 1 per device with 4 accumulation steps; loss is logged every
# 20 steps and no external reporting integration is enabled.
training_args = transformers.TrainingArguments(
    output_dir="./tmp",
    num_train_epochs=100.0,
    learning_rate=4e-3,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    logging_steps=20,
    report_to=[],
)
trainer = ReftTrainerForCausalLM(
    model=reft_model,
    tokenizer=tokenizer,
    args=training_args,
    **data_module,
)
_ = trainer.train()

Step,Training Loss
20,8.4448
40,3.0093
60,1.0091
80,0.2835
100,0.0821
120,0.033
140,0.0154
160,0.0116
180,0.0079
200,0.0085


In [7]:
# Article text that the ReFT model is asked to respond to.
instruction = "Take your pencil and on the top of the page, about two inches from the right start drawing a line to 2 inches (5.1\u00a0cm) below the right top corner. The line should loop down and have points. (see the picture) Draw straight lines from the points in your first line to the corner.  Make lines parallel to your fist line going all the way up. You should have 5 or 6 lines. Get a paper and make a cross on it, try to make both lines a similar length (using a ruler will help)\n  Draw diagonal lines through the centre, dividing the paper up from 4 to 8 sections. Make sure they are smaller than the cross you made before. Start connecting the lines with inverted arcs, this is an arc ), from the inside out.  Once you have reached the end of the web, elongate the diagonal lines, (this will make it look like it has supports).  Draw a spider by making a fuzzy ball, then drawing legs (eight of them) on your web. Or see the spider drawing tip. Finished.  Draw a circle and draw a cross section which also extends outside the circle.  Draw two diagonal lines at the midpoint of the cross-sections which form an X-mark.  Draw squares which descend in size as it nears the centre point. Draw the corners or vertices of the square along the diagonal lines. Draw diamond shapes descending in size as it nears the centre point. Draw the vertices along the lines of the cross-section. Draw curves to connect the lines \u2013 from the squares to the diamonds, much like forming bridges.  Trace with a pen and erase unnecessary lines. You may add drawings for spiders. Color to your liking!"

# tokenize and prepare the input
# The raw string and the tokenized batch get separate names; `prompt` ends up
# holding the tokenized inputs on `device`.
prompt_text = prompt_no_input_template % instruction
prompt = tokenizer(prompt_text, return_tensors="pt").to(device)

# Intervene at the final token of the prompt.
base_unit_location = prompt["input_ids"].shape[-1] - 1  # last position
intervention_locations = {"sources->base": (None, [[[base_unit_location]]])}

# Greedy decoding (do_sample=False), capped at 512 new tokens.
generation_kwargs = dict(
    max_new_tokens=512,
    do_sample=False,
    eos_token_id=tokenizer.eos_token_id,
    early_stopping=True,
)
_, reft_response = reft_model.generate(
    prompt,
    unit_locations=intervention_locations,
    intervene_on_prompt=True,
    **generation_kwargs,
)

decoded_response = tokenizer.decode(reft_response[0], skip_special_tokens=True)
print(decoded_response)



Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Take your pencil and on the top of the page, about two inches from the right start drawing a line to 2 inches (5.1 cm) below the right top corner. The line should loop down and have points. (see the picture) Draw straight lines from the points in your first line to the corner.  Make lines parallel to your fist line going all the way up. You should have 5 or 6 lines. Get a paper and make a cross on it, try to make both lines a similar length (using a ruler will help)
  Draw diagonal lines through the centre, dividing the paper up from 4 to 8 sections. Make sure they are smaller than the cross you made before. Start connecting the lines with inverted arcs, this is an arc ), from the inside out.  Once you have reached the end of the web, elongate the diagonal lines, (this will make it look like it has supports).  Draw a spider by making a fuzzy ball, then drawing leg

In [10]:
# Baseline: prompt the un-intervened base model with an explicit summarization
# instruction followed by the same article text.
instruction = "Summarize the text in a few sentences. Using original phrases or paraphrasing them if necessary. Do not include new information beyond the given passages.\nTake your pencil and on the top of the page, about two inches from the right start drawing a line to 2 inches (5.1\u00a0cm) below the right top corner. The line should loop down and have points. (see the picture) Draw straight lines from the points in your first line to the corner.  Make lines parallel to your fist line going all the way up. You should have 5 or 6 lines. Get a paper and make a cross on it, try to make both lines a similar length (using a ruler will help)\n  Draw diagonal lines through the centre, dividing the paper up from 4 to 8 sections. Make sure they are smaller than the cross you made before. Start connecting the lines with inverted arcs, this is an arc ), from the inside out.  Once you have reached the end of the web, elongate the diagonal lines, (this will make it look like it has supports).  Draw a spider by making a fuzzy ball, then drawing legs (eight of them) on your web. Or see the spider drawing tip. Finished.  Draw a circle and draw a cross section which also extends outside the circle.  Draw two diagonal lines at the midpoint of the cross-sections which form an X-mark.  Draw squares which descend in size as it nears the centre point. Draw the corners or vertices of the square along the diagonal lines. Draw diamond shapes descending in size as it nears the centre point. Draw the vertices along the lines of the cross-section. Draw curves to connect the lines \u2013 from the squares to the diamonds, much like forming bridges.  Trace with a pen and erase unnecessary lines. You may add drawings for spiders. Color to your liking!"


# tokenize and prepare the input
# No instruction template here: the raw instruction plus a trailing newline.
prompt_text = "%s\n" % instruction
prompt = tokenizer(prompt_text, return_tensors="pt").to(device)

# generate
# Sampling is stochastic; seed the RNGs first so that Restart & Run All
# reproduces the same baseline text instead of a different sample each run.
GENERATION_SEED = 42
transformers.set_seed(GENERATION_SEED)
model_response = model.generate(
    **prompt,
    max_new_tokens=512, do_sample=True,
    eos_token_id=tokenizer.eos_token_id, early_stopping=True
)
print(tokenizer.decode(model_response[0], skip_special_tokens=True))

Summarize the text in a few sentences. Using original phrases or paraphrasing them if necessary. Do not include new information beyond the given passages.
Take your pencil and on the top of the page, about two inches from the right start drawing a line to 2 inches (5.1 cm) below the right top corner. The line should loop down and have points. (see the picture) Draw straight lines from the points in your first line to the corner.  Make lines parallel to your fist line going all the way up. You should have 5 or 6 lines. Get a paper and make a cross on it, try to make both lines a similar length (using a ruler will help)
  Draw diagonal lines through the centre, dividing the paper up from 4 to 8 sections. Make sure they are smaller than the cross you made before. Start connecting the lines with inverted arcs, this is an arc ), from the inside out.  Once you have reached the end of the web, elongate the diagonal lines, (this will make it look like it has supports).  Draw a spider by making

In [13]:
# Print the tokenized length of each training article (first item of each pair),
# one count per line.
for example in training_examples:
    article_text = example[0]
    article_tokens = tokenizer.tokenize(article_text)
    print(len(article_tokens))

1494
3498
1712
1508
3387
2714
1422
2526
1081
1927
