In [None]:
import math
import torch
from torch.utils.data import DataLoader
import sys
sys.dont_write_bytecode = True
from datasets import Dataset, load_dataset
from transformers import GPT2LMHeadModel, GPT2Tokenizer, LlamaTokenizer, LlamaForCausalLM, LlamaConfig, DataCollatorForSeq2Seq
from transformers.generation.logits_process import LogitsProcessorList
from collections import Counter
from typing import Union, List, Dict, Optional, Callable
from token_prioritization import compute_token_scores, create_prompt, preprocess_and_tokenize

In [32]:
ds = load_dataset("databricks/databricks-dolly-15k")['train']  # 15011 ['instruction', 'context', 'response', 'category']

In [8]:
print(create_prompt(ds[0]))

Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney. 
### Question: When did Virgin Australia start operating? 
### Response: 


In [12]:
model_name_or_path = "meta-llama/Llama-2-7b-chat-hf"
access_token = "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"
config = LlamaConfig.from_pretrained(model_name_or_path, token=access_token)
tokenizer = LlamaTokenizer.from_pretrained(model_name_or_path, token=access_token)
tokenizer.pad_token = tokenizer.eos_token
model = LlamaForCausalLM.from_pretrained(
    model_name_or_path, 
    config=config, 
    token=access_token,
    device_map='auto',
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:11<00:00,  5.79s/it]


In [28]:
token_scores = compute_token_scores(ds, tokenizer, model, prompt_fn=create_prompt, size=100)
# token_str_scores = {tokenizer.decode([token]): score for token, score in token_scores.items()}
# token_str_scores = dict(sorted(token_str_scores.items(), key=lambda x: x[1], reverse=True))

In [33]:
dataset = ds.map(
    preprocess_and_tokenize,
    batched=True,
).remove_columns(ds.column_names)  # ['input_ids', 'attention_mask', 'labels']

Map:   0%|          | 0/15011 [00:00<?, ? examples/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Map: 100%|██████████| 15011/15011 [00:16<00:00, 891.58 examples/s]


In [40]:
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=tokenizer.pad_token_id,
    pad_to_multiple_of=None,
)
train_dataloader = DataLoader(
    dataset,
    batch_size=5,
    collate_fn=data_collator,
)
inputs = next(iter(train_dataloader))

In [56]:
# Normal generation
outputs = model.generate(
    **inputs,
    do_sample=False,
)
responses = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(responses)



["Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney. \n### Question: When did Virgin Australia start operating? \n### Response: \nVirgin Australia started operating on 31 August 2000 as Virgin Blue, with two aircraft on a single route. ", ' \n### Question: Which is a species of fish? Tope or Rope \n### Response: �ope is a species of fish, while rope is not a species of fish.\n\nThe correct answer is "tope". "Tope" is a species of fish, while "rope" is not a species of fish.', ' \n### Question: Why can camels survive for long without w

In [59]:
for rs in responses:
    print(rs)
    print('\n')

Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney. 
### Question: When did Virgin Australia start operating? 
### Response: 
Virgin Australia started operating on 31 August 2000 as Virgin Blue, with two aircraft on a single route. 


 
### Question: Which is a species of fish? Tope or Rope 
### Response: �ope is a species of fish, while rope is not a species of fish.

The correct answer is "tope". "Tope" is a species of fish, while "rope" is not a species of fish.


 
### Question: Why can camels survive for long without water? 
### R

In [60]:
# Weighted logit generation
class DataDrivenTokenLogitsProcessor:
    def __init__(self, token_scores, boost_factor=5.0):
        self.token_scores = token_scores
        self.boost_factor = boost_factor

    def __call__(self, input_ids, logits):
        for token, score in self.token_scores.items():
            logits[:, token] += self.boost_factor * score
        return logits
    
logits_processor = LogitsProcessorList([DataDrivenTokenLogitsProcessor(token_scores)])
outputs_weighted = model.generate(
    **inputs,
    do_sample=False,
    logits_processor=logits_processor,
)
responses_weighted = tokenizer.batch_decode(outputs_weighted, skip_special_tokens=True)
print(responses_weighted)



["Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney. \n### Question: When did Virgin Australia start operating? \n### Response: operating since operation operation operation operation operation operation operation operation operation operation operation operation operation operation operation operation operation operation operation operation operation operation operation operation operation operation operation operation operation operation operation operation operation operation operation operation operation operation operation operat

In [61]:
for rs in responses:
    print(rs)
    print('\n')

Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney. 
### Question: When did Virgin Australia start operating? 
### Response: 
Virgin Australia started operating on 31 August 2000 as Virgin Blue, with two aircraft on a single route. 


 
### Question: Which is a species of fish? Tope or Rope 
### Response: �ope is a species of fish, while rope is not a species of fish.

The correct answer is "tope". "Tope" is a species of fish, while "rope" is not a species of fish.


 
### Question: Why can camels survive for long without water? 
### R