In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers.utils import is_flash_attn_2_available 
from nn_rag import Retrieval

### Retrieval

In [2]:
# Build the retrieval component used by ask() to fetch context passages.
# NOTE(review): from_memory() presumably creates an in-memory (non-persisted) component — confirm against nn_rag.
rag = Retrieval.from_memory()
# Point the retriever at its data source; the 'chroma://' scheme presumably selects a Chroma vector store — confirm.
# NOTE(review): '/hadron/data/' is a hard-coded absolute path; consider making it configurable.
rag.set_source_uri('chroma:///hadron/data/')

<nn_rag.components.retrieval.Retrieval at 0x7f864bb68e20>

### Model

In [3]:
# model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
model_id = 'google/gemma-2b-it'

# Explicit switch for 4-bit loading. It defaults to False, which reproduces the
# previous behaviour (quantization_config=None) while making it obvious that the
# bnb_config defined below is otherwise not applied.
use_4bit_quantization = False

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Use FlashAttention-2 only when we are on a CUDA device, the package is available,
# and the GPU reports a compute-capability major version >= 8. Checking the device
# first keeps get_device_capability() from ever being called on a CPU-only machine.
if (device == 'cuda'
        and is_flash_attn_2_available()
        and torch.cuda.get_device_capability(0)[0] >= 8):
    attn_implementation = "flash_attention_2"
else:
    attn_implementation = "sdpa"

# 4 bit quantization settings (only passed to the model when use_4bit_quantization is True)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16, # gain back some speed
    bnb_4bit_quant_type="nf4", # 4-bit NormalFloat quantization data type
    bnb_4bit_use_double_quant=True, # better memory footprint
    llm_int8_enable_fp32_cpu_offload=True, # offload weights across GPU and CPU
)

# NOTE(review): 4-bit loading is gated on CUDA here on the assumption that the
# bitsandbytes backend needs a GPU — confirm for the installed version.
quantization_config = bnb_config if (use_4bit_quantization and device == 'cuda') else None

# instantiate the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    quantization_config=quantization_config,
    device_map=device,
    attn_implementation=attn_implementation
)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
def get_model_num_params(model: torch.nn.Module):
    """Return the total number of elements across all parameters of ``model``."""
    total = 0
    for tensor in model.parameters():
        # numel() is the element count of one parameter tensor
        total += tensor.numel()
    return total

# Show the parameter count of the loaded model with thousands separators
# (as the cell's last expression, the string is rendered as the cell output).
f"Model Size: {get_model_num_params(model):,.0f} parameters"

'Model Size: 2,506,172,416 parameters'

### Prompt Engineering

In [5]:
def prompt_formatter(query: str, context_items: list[str]) -> str:
    """
    Build the chat-formatted prompt that augments ``query`` with retrieved context.

    Args:
        query: The user question.
        context_items: Text passages to ground the answer on. ``None`` or blank
            entries are skipped so they neither break the join nor leave empty bullets.

    Returns:
        The prompt string produced by the notebook-level ``tokenizer``'s chat
        template, with the generation prompt appended.
    """
    # Render each usable passage as its own "- " bullet line
    passages = [str(item) for item in context_items
                if item is not None and str(item).strip()]
    context = "\n".join(f"- {passage}" for passage in passages)

    # Instruction template; {context} and {query} are filled in below
    base_prompt = """I am a professional summarizer with an academic audience. 
Based on the following context items, please answer the query. 
Give yourself room to think by extracting relevant passages from the context before answering the query. 
Don't return the thinking, only return the answer. Make sure your answers are as explanatory as possible.
Now use the following context items to answer the user query:
{context}
Relevant passages: <extract relevant passages from the context here>
User query: {query}
Answer:"""

    # Update base prompt with context items and query
    base_prompt = base_prompt.format(context=context, query=query)

    # Create prompt template for instruction-tuned model (single user turn)
    dialogue_template = [
        {"role": "user", "content": base_prompt}
    ]

    # Apply the chat template; tokenize=False returns the prompt as text
    prompt = tokenizer.apply_chat_template(conversation=dialogue_template,
                                          tokenize=False,
                                          add_generation_prompt=True)

    return prompt

### Generate

In [6]:
def ask(query,
        temperature=0.7,
        max_new_tokens=512,
        format_answer_text=True,
        return_answer_only=True):
    """
    Answer ``query`` with the language model, grounded in retrieved context.

    Relies on the notebook-level ``rag``, ``tokenizer``, ``model`` and ``device``.

    Args:
        query: The user question.
        temperature: Sampling temperature passed to ``model.generate``.
        max_new_tokens: Upper bound on the number of generated tokens.
        format_answer_text: When True, return only the newly generated text with
            special tokens and the boilerplate lead-in removed. When False, return
            the raw decoding of the whole sequence (prompt included).
        return_answer_only: When True return just the answer text; otherwise
            return ``(answer_text, context_items)``.

    Returns:
        ``str`` or ``(str, list)`` depending on ``return_answer_only``.
    """
    # Get the re-ranked passages for the query and keep their text
    context_items = rag.tools.query_reranker(query=query, bi_limit=20, cross_limit=10)['text'].to_pylist()

    # Format the prompt with context items
    prompt = prompt_formatter(query=query, context_items=context_items)

    # Tokenize the prompt. The chat template has already written the special
    # tokens (e.g. the BOS token) into the prompt text, so they must not be
    # added a second time here.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(device)

    # Collect end-of-generation ids: the tokenizer's EOS plus whichever end-of-turn
    # marker this tokenizer actually knows. Markers that resolve to None or to the
    # unknown-token id are not in the vocabulary and are left out.
    terminators = []
    if tokenizer.eos_token_id is not None:
        terminators.append(tokenizer.eos_token_id)
    for end_of_turn_token in ("<|eot_id|>", "<end_of_turn>"):
        token_id = tokenizer.convert_tokens_to_ids(end_of_turn_token)
        if (token_id is not None
                and token_id != tokenizer.unk_token_id
                and token_id not in terminators):
            terminators.append(token_id)

    # Generate an output of tokens
    outputs = model.generate(**inputs,
                             eos_token_id=terminators or None,
                             temperature=temperature,
                             do_sample=True,
                             max_new_tokens=max_new_tokens,
                             top_p=0.9,
                            )

    if format_answer_text:
        # The returned sequence starts with the prompt tokens; slice them off
        # rather than string-replacing the prompt, which depends on an exact
        # decode round-trip.
        prompt_length = inputs["input_ids"].shape[-1]
        output_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        # Drop the unnecessary help message the model tends to prepend
        output_text = output_text.replace("Sure, here is the answer to the user query:\n\n", "").strip()
    else:
        # Raw decoding of the full sequence, prompt and special tokens included
        output_text = tokenizer.decode(outputs[0])

    # Only return the answer without the context items
    if return_answer_only:
        return output_text

    return output_text, context_items

In [7]:
# response = outputs[0][input_ids.shape[-1]:]
# print(tokenizer.decode(response, skip_special_tokens=True))

In [8]:
import textwrap

def print_wrapped(text, wrap_length=80):
    """Print ``text`` re-flowed into lines of at most ``wrap_length`` characters."""
    lines = textwrap.wrap(text, width=wrap_length)
    print("\n".join(lines))


In [9]:
import random

# Sample questions used to exercise the retrieval + generation pipeline.
# NOTE(review): the chosen string is passed verbatim to ask(), i.e. to both the
# retriever and the prompt. A few entries contain wording slips ("effect" for
# "affect", "addresse", a missing "?") — confirm whether they should be corrected.
questions = [
    "What are the steps of the responsible fine-tuning flow?",
    "What is the impact and best practices for mitigating ethical and responsibility risks in the deployment of artificial intelligence (AI) models?",
    "How does the broader impact of AI technologies effect potential biases, privacy concerns, and the societal implications of AI deployment?",
    "How should I design AI systems that are fair and do not perpetuate existing biases?",
    "What is the importance of transparency through clear communication in AI systems?",
    "How does privacy and security secure against malicious attacks in the deployment of artificial intelligence (AI) models?",
    "How does the guide addresse the significance of accountability in responsible AI",    
]

# Pick one question at random; no seed is set, so the choice differs between runs.
query = random.choice(questions)

In [10]:
print(f"Query: {query}")

# Run the retrieve-then-generate pipeline and keep the retrieved passages for inspection
answer, context_items = ask(query=query, temperature=0.7, max_new_tokens=768, return_answer_only=False)

print("Answer:")
print_wrapped(answer)
print("\nContext items:")
# Last expression of the cell: rendered as the cell output
context_items

Query: How should I design AI systems that are fair and do not perpetuate existing biases?




Answer:
The context does not provide any information about how to design AI systems that
are fair and do not perpetuate existing biases, so I cannot answer this question
from the provided context.

Context items:


['Fine-tuning adapts the model to domain- or application- specific requirements and introduces additional layers of safety mitigations. Examples of fine-tuning for a pretrained LLM include: • Text summarization: By using a pretrained language model, the model can be fine-tuned on a dataset that includes pairs of long-form documents and corresponding summaries. This fine-tuned model can then generate concise summaries for new documents. Question answering: Fine-tuning a language model on a Q&A dataset such as SQuAD (Stanford Question Answering Dataset) allows the model to learn how to answer questions based on a given context paragraph. The fine-tuned model can then be used to answer questions on various topics.',
 '13 JULY 2023 Address input- and output-level risks Without proper safeguards at the input and output levels, it is hard to ensure that the model will respond properly to adversarial inputs and will be protected from efforts to circumvent content policies and safeguard measur