In [1]:
import torch
import random
from transformers import LlamaForCausalLM, LlamaTokenizerFast, \
AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset, load_dataset
from torch.utils.data import DataLoader, RandomSampler
from tqdm import tqdm

from functools import reduce

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
cache_path = '/data/dangnguyen/cache/'

In [3]:
model_name = 'sharpbai/Llama-2-7b-chat'
tokenizer = LlamaTokenizerFast.from_pretrained(model_name, cache_dir=cache_path)
model = LlamaForCausalLM.from_pretrained(model_name, cache_dir=cache_path)

Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████| 34/34 [00:04<00:00,  7.05it/s]


In [3]:
model_name = 'mistralai/Mistral-7B-Instruct-v0.2'
model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_path)
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_path)

Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:02<00:00,  1.20it/s]


In [4]:
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

1

In [5]:
model.eval()
model.to('cuda')

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
 

### DAIR-AI Emotion dataset

In [None]:
def map_label(label):
    if label == 0:
        return 'sadness'
    elif label == 1:
        return 'joy'
    elif label == 2:
        return 'love'
    elif label == 3:
        return 'anger'
    elif label == 4:
        return 'fear'
    elif label == 5:
        return 'surprise'
    else:
        return 'none'

In [None]:
dataset = load_dataset('dair-ai/emotion', split='test')

# Filtering for obvious samples
keywords = ['sad', 'joy', 'love', 'anger', 'fear', 'surprise']
dataset_clean = {
    'text': [],
    'label': [],
}

for data in dataset:
    keep = True
    for kw in keywords:
        if kw in data['text']:
            keep = False
            break
    if keep:
        dataset_clean['text'].append(data['text'])
        dataset_clean['label'].append(data['label'])

In [None]:
labels_eng = list(map(map_label, dataset_clean['label']))
dataset_clean['label_text'] = labels_eng

In [None]:
dataset_torch = Dataset.from_dict(dataset_clean)
dataset_torch

In [None]:
random.choice(dataset_torch)

In [None]:
template = open('/data/dangnguyen/weak-to-strong/prompts/emotion_fewshot.txt').read()
template

### Tweet Sentiment dataset

In [None]:
dataset = load_dataset('mteb/tweet_sentiment_extraction', split='test')
dataset

In [None]:
dataset_torch = Dataset.from_dict(dataset[:2000])
dataset_torch

In [None]:
random.choice(dataset_torch)

In [None]:
template = open('/data/dangnguyen/weak-to-strong/prompts/tweet_sentiment_fewshot.txt').read()
template

### Running the model

In [None]:
outputs = []
for sample in tqdm(dataset_torch):
    prompt = template.format(input=sample['text'])
    input_ids = tokenizer(prompt, return_tensors='pt').to('cuda')
    max_len = input_ids['input_ids'].shape[1]
    with torch.no_grad():
        output_ids = model.generate(input_ids['input_ids'], max_length=max_len+5)
    output = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    outputs.append(output[0])

In [None]:
outputs

In [None]:
output_labels = list(map(lambda x: x.split('>\n\nAnswer:\n')[1].lower().strip(), outputs))
output_labels

In [None]:
# control
labels = ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']
control_labels = [random.choice(labels) for _ in range(2000)]

In [None]:
correct = 0
for pred, gt in zip(output_labels, dataset_torch['label_text']):
    if pred == gt:
        correct += 1
        
accuracy = correct / len(output_labels)
accuracy

### Synthetic datasets: hierarchical equality

In [6]:
# Generating the data
def get_equiv_rel_data(n_samples=1000):
    def sample_one(start=10, end=99):
        label = random.choice(['Yes', 'No'])
        if label == 'Yes':
            a = random.randint(start, end-3)
            b = a + 3
        else:
            dist = 3
            while dist == 3:
                dist = random.randint(start, end-start)
            a = random.randint(start, end-dist)
            b = a + dist
        return (a, b, label)
    
    sampled_data = [sample_one() for _ in range(n_samples)]
    sampled_input = ["{} {}".format(a, b) for a, b, _ in sampled_data]
    sampled_label = [label for _, _, label in sampled_data]
    return (sampled_input, sampled_label)

def get_accuracy(preds, labels):
    num_correct = 0
    for pred, label in zip(preds, labels):
        if pred == label:
            num_correct += 1
    return num_correct / len(labels)

In [7]:
total_data_size = 1000
raw_data = get_equiv_rel_data(total_data_size)

In [8]:
num_pos = 0
for label in raw_data[1]:
    if label == 'Yes':
        num_pos += 1
num_pos / len(raw_data[1])

0.499

In [9]:
template = open('/data/dangnguyen/weak-to-strong/prompts/equiv_relation_zeroshot.txt').read()
template

'You are asked to solve the equivalence relation task. Given two numbers, a and b, say "Yes" if a = b + 3 and "No" a does not equal b + 3.\n\nInput:\n{INPUT}\nOutput:\n'

In [10]:
prompts = [template.format(INPUT=data) for data in raw_data[0]]
prompts

['You are asked to solve the equivalence relation task. Given two numbers, a and b, say "Yes" if a = b + 3 and "No" a does not equal b + 3.\n\nInput:\n10 27\nOutput:\n',
 'You are asked to solve the equivalence relation task. Given two numbers, a and b, say "Yes" if a = b + 3 and "No" a does not equal b + 3.\n\nInput:\n65 99\nOutput:\n',
 'You are asked to solve the equivalence relation task. Given two numbers, a and b, say "Yes" if a = b + 3 and "No" a does not equal b + 3.\n\nInput:\n10 91\nOutput:\n',
 'You are asked to solve the equivalence relation task. Given two numbers, a and b, say "Yes" if a = b + 3 and "No" a does not equal b + 3.\n\nInput:\n88 91\nOutput:\n',
 'You are asked to solve the equivalence relation task. Given two numbers, a and b, say "Yes" if a = b + 3 and "No" a does not equal b + 3.\n\nInput:\n62 65\nOutput:\n',
 'You are asked to solve the equivalence relation task. Given two numbers, a and b, say "Yes" if a = b + 3 and "No" a does not equal b + 3.\n\nInput:\

In [None]:
with open()

In [93]:
train_split = int(0.1 * total_data_size)

train_data = Dataset.from_dict(
    {
        'input': prompts[:train_split],
        'label': raw_data[1][:train_split]
    }
)

test_data = Dataset.from_dict(
    {
        'input': prompts[train_split:],
        'label': raw_data[1][train_split:]
    }
)

In [94]:
train_dataloader = DataLoader(
    train_data, batch_size=32,
)
test_dataloader = DataLoader(
    test_data, batch_size=32,
)

In [96]:
all_preds = []

with torch.no_grad():
    epoch_iterator = tqdm(test_dataloader)
    for _, inputs in enumerate(epoch_iterator):
        input_ids = tokenizer(inputs['input'], padding=True, return_tensors='pt').to('cuda')
        output_ids = model(**input_ids)
        
        pred_ids = output_ids.logits[:, -1, :].argmax(dim=-1)
        preds = tokenizer.batch_decode(pred_ids)
        all_preds.append(preds)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:41<00:00,  1.44s/it]


In [103]:
all_preds = reduce(lambda x, y: x + y, all_preds)
test_acc = get_accuracy(all_preds, test_data['label'])
test_acc

0.5366666666666666