In [1]:
!pip install -Uq sentence-transformers

In [2]:
import os

import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
# Local fine-tuned Llama checkpoint. Set the LLAMA_CHECKPOINT environment
# variable to point elsewhere instead of editing this machine-specific path.
checkpoint = os.environ.get("LLAMA_CHECKPOINT", "/home/xuan/llama3.2-8b-train-py")

# Use the GPU in half precision when available; otherwise fall back to CPU in
# float32 so the notebook still runs (slowly) rather than failing on "cuda".
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
llama_model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=dtype)
generator = pipeline("text-generation", model=llama_model, tokenizer=tokenizer, device=device)

Loading checkpoint shards: 100%|██████████| 2/2 [01:41<00:00, 50.67s/it]


In [12]:
# Toy knowledge base for retrieval: each string is one candidate context
# snippet. The retriever returns exactly one of these per query.
text_snippets = [
    "Fiona thanked Ethan for his unwavering support and promised to cherish their friendship.",
    "As they ventured deeper into the forest, they encountered a wide array of obstacles.",
    "Ethan and Fiona crossed treacherous ravines using rickety bridges, relying on each other's strength.",
    "Overwhelmed with joy, Fiona thanked Ethan and disappeared into the embrace of her family.",
    "Ethan returned to his cottage, heart full of memories and a smile brighter than ever before.",
]

In [13]:
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"

# Sentence-embedding model used for retrieval. The whole snippet corpus is
# embedded once, up front, so each query only has to encode itself.
model = SentenceTransformer(EMBEDDING_MODEL_NAME)
embeddings_text_snippets = model.encode(sentences=text_snippets)

In [14]:
def retrieve_snippet(query, snippets=None, snippet_embeddings=None, encoder=None):
    """Return the snippet most semantically similar to ``query``.

    Args:
        query: Natural-language question to match against the corpus.
        snippets: Candidate texts to search. Defaults to the notebook-level
            ``text_snippets``.
        snippet_embeddings: Pre-computed embeddings of ``snippets``, in the
            same order. When ``snippets`` is defaulted, this defaults to
            ``embeddings_text_snippets``; when custom ``snippets`` are given
            without embeddings, they are encoded on the fly with ``encoder``.
        encoder: Object exposing ``encode`` and ``similarity`` (e.g. a
            ``SentenceTransformer``). Defaults to the notebook-level ``model``.

    Returns:
        The single best-matching snippet string.

    Raises:
        ValueError: If there are no snippets to search.
    """
    # Resolve defaults lazily so the notebook globals are only required when
    # the caller does not supply its own corpus / encoder.
    if encoder is None:
        encoder = model
    if snippets is None:
        snippets = text_snippets
        if snippet_embeddings is None:
            snippet_embeddings = embeddings_text_snippets
    if len(snippets) == 0:
        raise ValueError("retrieve_snippet: no snippets to search")
    if snippet_embeddings is None:
        snippet_embeddings = encoder.encode(snippets)

    query_embedding = encoder.encode([query])
    # similarity(a, b) yields a (len(a), len(b)) score matrix; with a single
    # query, row 0 holds one score per snippet, in snippet order.
    scores = encoder.similarity(query_embedding, snippet_embeddings)[0]
    return snippets[int(scores.argmax())]

In [15]:
# Generation step of RAG: answer the query with LLaMA, grounded on the snippet returned by the retriever.

def ask_query(query, max_new_tokens=128, retriever=None, text_generator=None):
    """Answer ``query`` with the LLM, grounded on the best-matching snippet.

    Retrieves a context snippet, places it in the system prompt, generates a
    reply, and prints the query, the context, and the answer.

    Args:
        query: The user's question.
        max_new_tokens: Upper bound on the number of tokens to generate.
        retriever: Callable mapping a query string to a context string.
            Defaults to the notebook-level ``retrieve_snippet``.
        text_generator: Chat-style text-generation pipeline. Defaults to the
            notebook-level ``generator``.

    Returns:
        None; results are printed.
    """
    if retriever is None:
        retriever = retrieve_snippet
    if text_generator is None:
        text_generator = generator

    retrieved_texts = retriever(query)

    # Adjacent string literals are concatenated with no separator, so every
    # sentence ends with an explicit space, and the context is set off by a
    # blank line and a label so the model can tell instructions from context.
    system_prompt = (
        "You are a helpful AI assistant. "
        "Provide exactly one answer, and answer ONLY the following query, based on the context provided below. "
        "Do not generate or answer any other questions. "
        "Do not make up or infer any information that is not directly stated in the context. "
        "Provide a concise answer.\n\n"
        f"Context:\n{retrieved_texts}"
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": query},
    ]

    # With chat-message input, each output's "generated_text" is the whole
    # conversation; its last message is the assistant's reply.
    outputs = text_generator(messages, max_new_tokens=max_new_tokens)
    response = outputs[-1]["generated_text"][-1]["content"]

    print(f"Query: \n\t{query}")
    print(f"Context: \n\t{retrieved_texts}")
    print(f"Answer: \n\t{response}")

In [16]:
# Sample question whose answer is stated verbatim in the first snippet.
query = "Why did Fiona thank Ethan?"
ask_query(query=query)

Query: 
	Why did Fiona thank Ethan?
Context: 
	Fiona thanked Ethan for his unwavering support and promised to cherish their friendship.
Answer: 
	Fiona thanked Ethan for his unwavering support.
