In [2]:
import pandas as pd
import numpy as np
from datetime import datetime

import chromadb
from chromadb.config import Settings

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline


In [3]:
model_id = 'TinyLlama/TinyLlama-1.1B-Chat-v1.0'
tokenizer = AutoTokenizer.from_pretrained(model_id)
lm_model = AutoModelForCausalLM.from_pretrained(model_id,trust_remote_code=True)

pipe = pipeline('text-generation', 
                model=lm_model,
                tokenizer=tokenizer, 
                max_new_tokens=256,
                device_map='auto')

In [4]:
class NewsRagger:
    def __init__(self, MAX_NEWS=500,DOCUMENT='title',TOPIC='topic'):
        news = pd.read_csv('./datasets/technology_dataset.csv', sep=';')
        news['id'] = news.index
        
        self.subset_news = news.head(MAX_NEWS)
        #display(self.subset_news)
        
        self.document = DOCUMENT
        self.topic = TOPIC
        self.max_news = MAX_NEWS

        
    def create_db_collection(self):
        chroma_client = chromadb.PersistentClient(path='chromadb')
        collection_name = 'news_tech_collection' #+ datetime.now().strftime('%S')
        #print(collection_name)

        if len(chroma_client.list_collections()) > 0 and collection_name in [chroma_client.list_collections()[0].name]:
            chroma_client.delete_collection(collection_name)
        
        self.collection = chroma_client.create_collection(name=collection_name) 
        self.collection.add(
                                documents=self.subset_news[self.document].tolist(),
                                metadatas=[{self.topic: topic} for topic in self.subset_news[self.topic].tolist()],
                                ids=[f'id{x}' for x in range(self.max_news)]
                            )
        return self
    
    def collection_query(self,query_texts,n_results=10):
        results = self.collection.query(query_texts=query_texts,n_results=n_results)
        self.query_results = results['documents'][0]
        return self
    
    def question(self,question):
        context = ' '.join([f"#{str(i)}" for i in self.query_results])
        prompt_template = f"""
            Relevant context: {context}
            Considering the relevant context, answer the questions?
            Question: {question}
            Answer:
        """

        lm_response = pipe(prompt_template)
        self.final_response = lm_response[0]['generated_text']
        return self



In [5]:
rag = (
    NewsRagger()
        .create_db_collection()
        .collection_query(['laptop'])
)
    

In [6]:
rag.question('Can i buy a new toshiba laptop?').final_response

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


"\n            Relevant context: #The Legendary Toshiba is Officially Done With Making Laptops #Lenovo and HP control half of the global laptop market #Acer Swift 3 featuring a 10th-generation Intel Ice Lake CPU, 2K screen, and more launched in India for INR 64999 (US$865) #Apple's Next MacBook Could Be the Cheapest in Company's History #Features of Huawei's Desktop Computer Revealed #Redmi to launch its first gaming laptop on August 14: Here are all the details #Toshiba shuts the lid on laptops after 35 years #Apple to Reportedly Launch Its Cheapest MacBook Ever #Dell announces the premium Latitude 7410 Chromebook Enterprise: available now #Surface Reveals Microsoft’s Turbocharged Android\n            Considering the relevant context, answer the questions?\n            Question: Can i buy a new toshiba laptop?\n            Answer:\n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n     