# Context Aware Decoding Demo

In [25]:
from '../../Representation\ Engineering/representation-engineering-main' import setup

SyntaxError: invalid syntax (1601499739.py, line 1)

In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList, LogitsProcessor
from torch.nn import functional as F


model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 2048)
    (layers): ModuleList(
      (0-21): 22 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (up_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (down_proj): Linear(in_features=5632, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): 

In [36]:
test_case_1 = {
        'context' : "The current year is 2027. Argentina won World Cups in 1978,1986,2022 and 2026.",
        'question' : "How many world cups has Argentina won?",        
        'answer' : '''The current year is 2027. Argentina won World Cups in 1978,1986,2022 and 2026. How many world cups has Argentina won?
How many world cups has Argentina won?
Argentina has won 19 World Cups, the most of any country.
Argentina won its first World Cup in 1978, beating West Germany ''' 
    }
    
test_case_2 = {
    'context' : "Prison Link Cymru had 1099 referrals in 2015-16 and said some ex-offenders were living rough for up to a year before finding suitable accommodation ......",
    'question' : "Summarize the article in one sentence. ",
    'answer' : '''Prison Link Cymru had 1099 referrals in 2015-16 and said some ex-offenders were living rough for up to a year before finding suitable accommodation ...... Summarize the article in one sentence. ›
The article discusses the challenges faced by ex-offenders in finding suitable accommodation, and the efforts of Prison Link Cymru to address this issue.''' 
}

test_case_3 = {
    'context' : '''Write a quote that ends in the word "early":''',
    'question' : "Better late than",
    'answer' : '''Context-Aware Decoding Output:
 Write a quote that ends in the word "early": Better late than never.

Vocabulary: late, end, end, end, end, end, end, end, end, end, end, end, end, end, end, end, end, end, end, end,'''
}

context = test_case_3['context']
question = test_case_3['question']

context_input = tokenizer(context, return_tensors="pt").input_ids.to(device)
question_input = tokenizer(question, return_tensors="pt").input_ids.to(device)

input_ids = torch.cat([context_input, question_input], dim=-1)


def standard_decoding(input_ids, max_length=128, temperature=1.0, top_k=50, top_p=0.9):
    output_ids = model.generate(
        input_ids,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

def context_aware_sampling(model, tokenizer, input_ids, context_ids, alpha=0.9, max_length=128, temperature=1.0):
    generated_tokens = input_ids.clone()
    
    for _ in range(max_length):
        with torch.no_grad():
            full_context_outputs = model(generated_tokens)
            full_context_logits = full_context_outputs.logits[:, -1, :] 

            question_only_input = generated_tokens[:, len(context_ids):]
            question_only_outputs = model(question_only_input)
            question_only_logits = question_only_outputs.logits[:, -1, :] 

        adjusted_logits = (1 + alpha) * full_context_logits - alpha * question_only_logits
        adjusted_probs = F.softmax(adjusted_logits / temperature, dim=-1)

        next_token = torch.multinomial(adjusted_probs, num_samples=1)

        generated_tokens = torch.cat([generated_tokens, next_token], dim=-1)

        if next_token.item() == tokenizer.eos_token_id:
            break

    return generated_tokens

In [37]:
model.eval()
standard_output = standard_decoding(input_ids)
output_tokens = context_aware_sampling(
                                        model,
                                        tokenizer,
                                        input_ids,
                                        context_ids=context_input,
                                        alpha=0.5,
                                        max_length=50,
                                        temperature=0.0000000000000000000000000001,
                                    )

context_aware_output = tokenizer.decode(output_tokens[0], skip_special_tokens=True)


print("Standard Decoding Output:\n", standard_output)
print("__" * 50)
print("Context-Aware Decoding Output:\n", context_aware_output)


Standard Decoding Output:
 Write a quote that ends in the word "early": Better late than never, it turns out.
As of 2015, there were still more than 17 million American children living in homes with parents who worked 18 hours or more each week, up from 13 million in 1990.
____________________________________________________________________________________________________
Context-Aware Decoding Output:
 Write a quote that ends in the word "early": Better late than never.

Vocabulary: late, end, end, end, end, end, end, end, end, end, end, end, end, end, end, end, end, end, end, end,


In [35]:
context_aware_output == test_case_2['answer']

True

In [None]:
context_aware_output

In [7]:
'''The current year is 2027. Argentina won World Cups in 1978,1986,2022 and 2026. How many world cups has Argentina won? (Answer: 6)\nThe world football championships, better known as world cups, are countless in number. In Spanish speaking countries called "the World Cup" is called. A country that qualifies for a World Cup is usually''' == context_aware_output

True

In [40]:
json_dict = {'paper_name':'Trusting Your Evidence: Hallucinate Less with Context-aware Decoding',
             'paper_url':'https://arxiv.org/abs/2305.14739',
             'year':'2023',
             'repo_url':'https://github.com/xhan77/context-aware-decoding.git',
             'repo_path':'representation-engineering-main',
             'implementations':'include only one method introduced in the paper for the agent to be implemented.'}
# Specify the file name
filename = 'info.json'

# Writing JSON data
with open(filename, 'w') as f:
    json.dump(json_dict, f, indent=4)

In [16]:
“paper_name”: the full name of the paper. 
“paper_url”: the paper url of its ACL version
“year”: in which year of ACL it was accepted
“repo_url”: official Github repo url
“repo_path”: "benchmark/datasets/{year}}/{folder_name}/{repo_name}"
“implementations”: it includes a list of methods for the agent to be implemented and each method is a structure. For each structure, it includes 
“instruction”: The method you would like the agent to do.
“index”: the index of the method in this paper, starting from 1.
“category”: the category of the method, you can find a phrase to represent here and we can discuss the possible categories after we annotate some papers.
“goal_file”: The masked file which the agent needs to implement. You need to provide the path.
“golden_file”: Path of the golden file.
“retrieval_content”: required code blocks in the repository to help with the implementation. For each code block, you need to annotate 3 attributes: The identifier is the variable or function name used in the goal file and it can be empty. The path is the file of the code block. The snippet is the code block content.
“unit_test_file”: path of the unit test file.


2

In [None]:

def test_function(context, question):