In [None]:
import torch
from models.common import RangeWeight
from models.modeling_mistral import MistralForCausalLM
from transformers import AutoTokenizer

In [None]:
# Load the tokenizer and the project-local MistralForCausalLM (the variant that
# accepts `range_weights` in generate()) from the same checkpoint, in bfloat16.
MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.2"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = MistralForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

In [3]:
# A question paired with two deliberately contradictory documents.
prompt = """Answer the following question based on the documents provided.
Document 1: Today, the weather is raining.
Document 2: Today, the weather is sunny.
Question: What is the weather like today?"""

# Single-turn conversation rendered through the model's chat template as text;
# add_generation_prompt=True appends the assistant-turn prefix.
messages = [{"role": "user", "content": prompt}]
chat_prompt = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=False,
)

# NOTE(review): add_special_tokens=False presumably avoids a duplicate BOS
# because the chat template already emits one — confirm against the template.
encoded_prompt = tokenizer(
    chat_prompt,
    return_tensors="pt",
    add_special_tokens=False,
)
inputs = encoded_prompt.to(model.device)

In [4]:
# Baseline: greedy decoding with no attention reweighting.
prompt_length = inputs["input_ids"].shape[-1]
outputs = model.generate(
    **inputs,
    max_new_tokens=100,
    do_sample=False,  # deterministic, so it is comparable with the scaled run
    pad_token_id=tokenizer.eos_token_id,
)
# generate() returns prompt + continuation; keep only the continuation.
response = outputs[0][prompt_length:]
print("Model output WITHOUT attention scaling:")
print(tokenizer.decode(response))

Model output WITHOUT attention scaling:
Based on the provided documents, there is a discrepancy between Document 1 and Document 2 regarding the weather today. Document 1 states that it is raining, while Document 2 states that it is sunny. Without additional information or clarification, it is impossible to determine which document is accurate. Therefore, I cannot answer the question definitively based on the given documents.</s>


In [None]:
# Show how the chat-templated prompt was tokenized, one token per line, so the
# token indices used for the RangeWeight spans below can be read off directly.
# Those indices are positions in inputs["input_ids"][0] (the templated prompt),
# not character offsets into `prompt`, so both lengths are labelled explicitly.
prompt_token_ids = inputs["input_ids"][0]
print(
    f"Raw prompt: {len(prompt)} characters; "
    f"chat-templated prompt: {len(prompt_token_ids)} tokens"
)
for position, token in enumerate(
    tokenizer.convert_ids_to_tokens(prompt_token_ids.tolist())
):
    print(f"{position:3d} {token!r}")

In [5]:
# Pair each RangeWeight with the document line it is meant to cover. The
# start/end values are token indices into inputs["input_ids"][0]; they are tied
# to this exact prompt, chat template and tokenizer, so each span is checked
# against its intended text instead of trusting the magic numbers silently.
# NOTE(review): the slicing below treats `end` as exclusive; presumably
# RangeWeight does too — confirm in models.common.
document_spans = [
    (RangeWeight(15, 28, 1), "Document 1: Today, the weather is raining."),
    (RangeWeight(28, 41, 0), "Document 2: Today, the weather is sunny."),
]
range_weights = [range_weight for range_weight, _ in document_spans]

prompt_token_ids = inputs["input_ids"][0]
for range_weight, expected_text in document_spans:
    span_text = tokenizer.decode(
        prompt_token_ids[range_weight.start : range_weight.end]
    )
    # Fail loudly if the indices no longer line up with the intended document
    # (e.g. after editing the prompt or switching tokenizer/template).
    assert span_text.strip() == expected_text, (
        f"RangeWeight({range_weight.start}, {range_weight.end}) decodes to "
        f"{span_text!r}, expected {expected_text!r}"
    )
    print(f"Giving the following text a weight of {range_weight.weight}:")
    print(span_text)
    print("---")

Giving the following text a weight of 1:
Document 1: Today, the weather is raining.

---
Giving the following text a weight of 0:
Document 2: Today, the weather is sunny.

---


In [6]:
# Same greedy decoding as the baseline, but passing the project-specific
# `range_weights` kwarg so the local MistralForCausalLM applies its attention
# scaling to the given token spans (see models.common.RangeWeight).
prompt_length = inputs["input_ids"].shape[-1]
outputs = model.generate(
    **inputs,
    range_weights=range_weights,
    max_new_tokens=100,
    do_sample=False,
    pad_token_id=tokenizer.eos_token_id,
)
# Drop the echoed prompt and decode only the newly generated tokens.
response = outputs[0][prompt_length:]
print("Model output WITH attention scaling:")
print(tokenizer.decode(response))

Model output WITH attention scaling:
Based on Document 1, the weather is raining today.</s>
