In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers.utils import is_flash_attn_2_available 
from nn_rag import Retrieval

### Retrieval

In [2]:
# Build a Retrieval component and point it at the store it should query.
# NOTE(review): "chroma:///hadron/data/" looks like a Chroma vector store at an
# absolute local path (/hadron/data/) — presumably populated by a separate ingestion
# step; confirm, and consider moving the path into a config constant.
rag = Retrieval.from_memory()
# Last expression of the cell: per the recorded output, its return value (the
# Retrieval object) is what the notebook displays.
rag.set_source_uri('chroma:///hadron/data/')

<nn_rag.components.retrieval.Retrieval at 0x7fc2fc3db160>

### Model

In [3]:
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

# FlashAttention-2 needs the flash-attn package AND an NVIDIA GPU with compute
# capability >= 8.0. Check torch.cuda.is_available() before querying the device so
# this cell also runs on machines without CUDA (CPU / Apple MPS). Otherwise fall
# back to PyTorch's scaled-dot-product attention ("sdpa").
if (
    is_flash_attn_2_available()
    and torch.cuda.is_available()
    and torch.cuda.get_device_capability(0)[0] >= 8
):
    attn_implementation = "flash_attention_2"
else:
    attn_implementation = "sdpa"

# 4-bit quantization via bitsandbytes.
# NOTE(review): bitsandbytes quantization has historically required CUDA. The
# out-of-memory error recorded at the end of this notebook came from the MPS
# backend, so confirm the weights are actually quantized on that machine.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,  # computation dtype; lower precision than fp32 for speed
    bnb_4bit_quant_type="nf4",  # 4-bit NormalFloat data type, designed for normally distributed weights
    bnb_4bit_use_double_quant=True,  # also quantize the quantization constants to save memory
    llm_int8_enable_fp32_cpu_offload=True,  # let modules that device_map places on CPU stay in fp32
)

# Instantiate the tokenizer and the quantized model. device_map="auto" lets
# accelerate distribute layers over the available devices.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=attn_implementation,
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]



In [4]:
def get_model_num_params(model: torch.nn.Module):
    """Return the total number of scalar elements across all of ``model``'s parameters.

    Every tensor yielded by ``model.parameters()`` is counted, trainable or frozen.
    NOTE(review): for bitsandbytes 4-bit layers, ``numel()`` may reflect packed
    storage rather than the logical parameter count — verify if that matters.
    """
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total

# Display the parameter count with thousands separators (last expression of the cell)
"Model Size: {:,.0f} parameters".format(get_model_num_params(model))

'Model Size: 8,030,261,248 parameters'

### Prompt Engineering

In [5]:
def prompt_formatter(query: str, context_items: dict, chat_tokenizer=None) -> str:
    """
    Augment ``query`` with text-based context and wrap it in the chat template.

    Args:
        query: The user question to answer.
        context_items: Retrieved context. Only ``context_items["text"]`` is read, and
            it must be an iterable of strings (e.g. a dict of lists, or a DataFrame
            column). NOTE(review): the exact type returned by
            ``rag.tools.query_reranker`` is not visible here — confirm.
        chat_tokenizer: Tokenizer whose ``apply_chat_template`` is used. Defaults to
            the notebook-level ``tokenizer`` so existing calls keep working.

    Returns:
        The fully formatted prompt as a plain ``str`` (``tokenize=False``). It holds
        no tensors, so device placement belongs after tokenization, not here.
    """
    if chat_tokenizer is None:
        chat_tokenizer = tokenizer

    # Join context items into one bulleted list, one item per line
    context = "- " + "\n- ".join(context_items["text"])

    # Create a base prompt with examples to help the model
    base_prompt = """Based on the following context items, please answer the query.
Give yourself room to think by extracting relevant passages from the context before answering the query.
Don't return the thinking, only return the answer.
Make sure your answers are as explanatory as possible.
Use the following examples as reference for the ideal answer style.
\nExample 1:
Query: What are the core principles of responsible AI mentioned in the guide?
Answer: The guide outlines core principles of responsible AI, which include fairness and inclusion, robustness and safety, privacy and security, and transparency and control. Additionally, it emphasizes the importance of governance and accountability mechanisms to ensure these principles are upheld throughout the development and deployment of AI systems.
\nExample 2:
Query: How does Meta's open approach contribute to AI innovation?
Answer: Meta's open approach to AI innovation involves open-sourcing code and datasets, contributing to the AI community's infrastructure, and making large language models available for research. This approach fosters a vibrant AI-innovation ecosystem, driving breakthroughs in various sectors and enabling exploratory research and large-scale production deployment. It also draws upon the collective wisdom and diversity of the AI community to improve and democratize AI technology.
\nExample 3:
Query: What are the stages of responsible LLM product development according to the guide?
Answer: The guide identifies four stages of responsible LLM product development: determining the use case, fine-tuning for the product, addressing input- and output-level risks, and building transparency and reporting mechanisms in user interactions. Each stage involves specific considerations and mitigation strategies to ensure the safe and effective deployment of LLM-powered products.
\nNow use the following context items to answer the user query:
{context}
\nRelevant passages: <extract relevant passages from the context here>
User query: {query}
Answer:"""

    # Fill in the context items and query (only the template itself is parsed for
    # placeholders, so braces inside context/query are safe)
    base_prompt = base_prompt.format(context=context, query=query)

    # Single-turn conversation for the instruction-tuned model
    dialogue_template = [
        {"role": "user", "content": base_prompt}
    ]

    # tokenize=False returns a plain str, so there is nothing to move to a device here
    prompt = chat_tokenizer.apply_chat_template(conversation=dialogue_template,
                                                tokenize=False,
                                                add_generation_prompt=True)

    return prompt

### Generate

In [6]:
def ask(query,
        temperature=0.7,
        max_new_tokens=512,
        format_answer_text=True,
        return_answer_only=True,
        top_p=0.9):
    """
    Answer ``query`` with retrieval-augmented generation.

    Retrieves context via ``rag.tools.query_reranker`` and builds a chat prompt with
    ``prompt_formatter``. It then samples a completion from the notebook-level
    ``model`` and ``tokenizer``.

    Args:
        query: User question.
        temperature: Sampling temperature passed to ``model.generate``.
        max_new_tokens: Maximum number of tokens to generate.
        format_answer_text: If True, return only the newly generated text, with
            special tokens removed, the known preamble stripped and surrounding
            whitespace trimmed. If False, return the raw decoded sequence (prompt +
            completion, special tokens included).
        return_answer_only: If True, return just the answer text. Otherwise return
            ``(answer_text, context_items)``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.

    Returns:
        ``str``, or ``(str, context_items)`` when ``return_answer_only`` is False.
    """
    # Retrieve (and rerank) context items for the query
    context_items = rag.tools.query_reranker(query=query)

    # Format the prompt with context items
    prompt = prompt_formatter(query=query, context_items=context_items)

    # The chat-templated prompt already carries its special tokens (including BOS),
    # so don't let the tokenizer add another BOS. Move the tensors to the device the
    # model reports.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

    # Stop on either the generic EOS token or the end-of-turn token
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]

    # Sample an output sequence (prompt tokens + new tokens)
    outputs = model.generate(**inputs,
                             eos_token_id=terminators,
                             temperature=temperature,
                             do_sample=True,
                             max_new_tokens=max_new_tokens,
                             top_p=top_p,
                             )

    if format_answer_text:
        # Decode only the tokens generated after the prompt. skip_special_tokens
        # removes markers such as <|eot_id|> that literal "<bos>"/"<eos>" replacement
        # would miss.
        prompt_length = inputs["input_ids"].shape[-1]
        output_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        # Strip a known boilerplate preamble if the model produced it
        output_text = output_text.replace("Sure, here is the answer to the user query:\n\n", "").strip()
    else:
        output_text = tokenizer.decode(outputs[0])

    # Only return the answer without the context items
    if return_answer_only:
        return output_text

    return output_text, context_items

In [7]:
# response = outputs[0][input_ids.shape[-1]:]
# print(tokenizer.decode(response, skip_special_tokens=True))

In [8]:
import textwrap

def print_wrapped(text, wrap_length=80):
    """
    Print ``text`` wrapped to ``wrap_length`` columns, preserving its line breaks.

    Calling ``textwrap.fill`` on the whole string would replace newlines with spaces,
    merging paragraphs and bullet lists into a single block. Instead, each line is
    wrapped on its own and the lines are re-joined.

    Args:
        text: The text to print (e.g. a generated answer).
        wrap_length: Maximum line width passed to ``textwrap.fill``.
    """
    wrapped_lines = [textwrap.fill(line, wrap_length) for line in text.splitlines()]
    print("\n".join(wrapped_lines))


In [9]:
query = "What considerations should be taken into account when defining content policies for LLMs?"
print(f"Query: {query}")

# Answer the query and also return the retrieved context items for inspection
answer, context_items = ask(query=query,
                            temperature=0.7,
                            max_new_tokens=512,
                            return_answer_only=False)

# Plain strings: these messages have no placeholders, so no f-prefix is needed
print("Answer:\n")
print_wrapped(answer)
print("Context items:")
context_items

Query: What considerations should be taken into account when defining content policies for LLMs?




RuntimeError: MPS backend out of memory (MPS allocated: 6.80 GB, other allocations: 2.56 MB, max allowed: 6.80 GB). Tried to allocate 10.37 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).