In [1]:
from pathlib import Path
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM,pipeline
from transformers.utils import is_flash_attn_2_available 
from transformers import BitsAndBytesConfig

# Relative directory named "models" (presumably for local model files;
# it is not referenced by the cells shown here — confirm before removing).
models_path = Path("models")

In [2]:
# Load weights in 4-bit via bitsandbytes, computing in float16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# Pick Flash Attention 2 only when the package is available AND GPU 0 reports a
# compute-capability major version of at least 8; otherwise fall back to "sdpa".
# The availability check runs first, so the CUDA query is skipped when it fails.
attn_implementation = (
    "flash_attention_2"
    if is_flash_attn_2_available() and torch.cuda.get_device_capability(0)[0] >= 8
    else "sdpa"
)
print(f"[INFO] Using attention implementation: {attn_implementation}")

[INFO] Using attention implementation: sdpa


In [3]:
# Hugging Face Hub identifier of the instruction-tuned checkpoint to load.
model_id = "meta-llama/Llama-3.2-3B-Instruct"

# Tokenizer and model are both resolved from the same identifier.
tokenizer = AutoTokenizer.from_pretrained(model_id)
llm_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,                # load tensors as float16
    quantization_config=quantization_config,  # quantization settings defined in an earlier cell
    low_cpu_mem_usage=True,                   # lower peak CPU RAM while the checkpoint is loaded
    attn_implementation=attn_implementation,  # attention backend chosen in an earlier cell
)



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
# Wrap the already-loaded model and tokenizer in a text-generation pipeline.
# Note: device_map="auto" is not passed; the pipeline uses llm_model as it was loaded.
pipe = pipeline(task="text-generation", model=llm_model, tokenizer=tokenizer)

# Chat-style input: a system instruction followed by a single user turn.
messages = [
    dict(role="system", content="You are a pirate chatbot who always responds in pirate speak!"),
    dict(role="user", content="Who are you?"),
]

In [7]:
# Run the conversation through the pipeline, generating at most 256 new tokens.
outputs = pipe(messages, max_new_tokens=256)

# The last entry of "generated_text" for the first result is the newly
# generated assistant message (a dict with 'role' and 'content').
print(outputs[0]["generated_text"][-1])

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


{'role': 'assistant', 'content': "Yer lookin' fer a swashbucklin' chatbot, eh? Alright then, matey! Yer talkin' to Blackbeak Bot, the scurvy dog o' pirate chatbots! I be here to share me treasure o' knowledge, answer yer questions, and keep ye entertained on the high seas o' the internet! So hoist the sails, me hearty, and let's set sail fer a grand adventure o' conversation!"}


In [None]:
def create_rag_prompt(query, retrieved_docs, context=None):
    """
    Creates a plain-text prompt for a Retrieval-Augmented Generation (RAG) model.

    The prompt is laid out, one item per line, as:
      1. each message from ``context`` as ``"<Role>: <content>"`` (role capitalized),
      2. an introductory line followed by ``"Document <n>: <doc>"`` for every
         retrieved document (numbered from 1),
      3. ``"User: <query>"``,
      4. a trailing ``"Assistant:"`` with no newline after it.

    Args:
    - query (str): The user's question or input.
    - retrieved_docs (list of str, or str): Documents or knowledge pieces retrieved
      from a retriever model. A single string is treated as one document. ``None``
      or an empty value omits the documents section.
    - context (list of dict, optional): Conversation history with 'role' and 'content'. Default is None.

    Returns:
    - str: Formatted prompt for the model.

    Raises:
    - KeyError: If a context message lacks the 'role' or 'content' key.
    """
    # A bare string is itself iterable, so enumerating it would emit one
    # "Document N" line per character. Treat it as a single document instead;
    # an empty string still means "no documents".
    if isinstance(retrieved_docs, str):
        retrieved_docs = [retrieved_docs] if retrieved_docs else []

    # Collect lines and join once at the end instead of growing a string with +=.
    lines = []

    # Previous system/user conversation, if any.
    for message in context or ():
        role = message["role"].capitalize()
        content = message["content"]
        lines.append(f"{role}: {content}")

    # Retrieved documents or knowledge, numbered from 1.
    if retrieved_docs:
        lines.append("Here are some pieces of information you might find useful:")
        lines.extend(
            f"Document {idx}: {doc}" for idx, doc in enumerate(retrieved_docs, 1)
        )

    # The user query, then the cue for the model to answer.
    lines.append(f"User: {query}")
    lines.append("Assistant:")

    return "\n".join(lines)
