In [1]:
import os

os.environ["GEMINI_API_KEY"] = 'api_key here'

In [3]:
import textwrap
from IPython.display import Markdown

def to_markdown(text):
  text = text.replace('•', '  *')
  return Markdown(textwrap.indent(text, '> ', predicate=lambda _: True))

### Embedding the text


In [4]:
import google.generativeai as genai
from chromadb import Documents, EmbeddingFunction, Embeddings
import os

class GeminiEmbeddingFunction(EmbeddingFunction):
    def __call__(self, input: Documents) -> Embeddings:
        gemini_api_key = os.getenv("GEMINI_API_KEY")
        if not gemini_api_key:
            raise ValueError("Gemini API Key not provided. Please provide GEMINI_API_KEY as an environment variable")
        genai.configure(api_key=gemini_api_key)
        model = "models/embedding-001"
        title = "Custom query"
        return genai.embed_content(model=model,
                                   content=input,
                                   task_type="retrieval_document",
                                   title=title)["embedding"]

  from .autonotebook import tqdm as notebook_tqdm


### Storing vectors into DB


In [5]:
import chromadb

def create_chroma_db(documents, path, name):
    chroma_client = chromadb.PersistentClient(path=path)
    db = chroma_client.create_collection(name=name, embedding_function=GeminiEmbeddingFunction())

    for i, d in enumerate(documents):
        db.add(documents=d, ids=str(i))

    return db, name

In [6]:
import csv

with open('data\data.csv', encoding='utf-8') as file:
    lines = csv.reader(file)

    documents = []
    ids = []
    id = 1

    for i, line in enumerate(lines):
        if i==0:
            # Skip the first row (the column headers)
            continue

        documents.append(f'time: {line[0]}, ticker: {line[6]}, open: {line[1]}, high: {line[2]}, low: {line[3]}, close: {line[4]}, volume: {line[5]}, comGroupCode: {line[7]}, organName: {line[8]}, organShortName: {line[9]}, icbName: {line[12]}, icbCode: {line[17]}')
        ids.append(str(id))
        id+=1

In [7]:
documents

['time: 2024-04-24, ticker: SSI, open: 35000, high: 35950, low: 34900, close: 35650, volume: 21326800, comGroupCode: HOSE, organName: Công ty Cổ phần Chứng khoán SSI, organShortName: Chứng khoán SSI, icbName: Môi giới chứng khoán, icbCode: 8777',
 'time: 2024-04-24, ticker: BCM, open: 51000, high: 52600, low: 50900, close: 52300, volume: 771300, comGroupCode: HOSE, organName: Tổng Công ty Đầu tư và Phát triển Công nghiệp - CTCP, organShortName: Becamex IDC, icbName: Bất động sản, icbCode: 8633',
 'time: 2024-04-24, ticker: VHM, open: 40350, high: 40900, low: 40250, close: 40600, volume: 5041500, comGroupCode: HOSE, organName: Công ty Cổ phần Vinhomes, organShortName: Vinhomes, icbName: Bất động sản, icbCode: 8633',
 'time: 2024-04-24, ticker: VIC, open: 41100, high: 41950, low: 41100, close: 41600, volume: 1583300, comGroupCode: HOSE, organName: Tập đoàn Vingroup - Công ty CP, organShortName: VinGroup, icbName: Bất động sản, icbCode: 8633',
 'time: 2024-04-24, ticker: VRE, open: 22200,

In [8]:
ids

['1',
 '2',
 '3',
 '4',
 '5',
 '6',
 '7',
 '8',
 '9',
 '10',
 '11',
 '12',
 '13',
 '14',
 '15',
 '16',
 '17',
 '18',
 '19',
 '20',
 '21',
 '22',
 '23',
 '24',
 '25',
 '26',
 '27',
 '28',
 '29',
 '30',
 '31',
 '32',
 '33',
 '34',
 '35',
 '36',
 '37',
 '38',
 '39',
 '40',
 '41',
 '42',
 '43',
 '44',
 '45',
 '46',
 '47',
 '48',
 '49',
 '50',
 '51',
 '52',
 '53',
 '54',
 '55',
 '56',
 '57',
 '58',
 '59',
 '60',
 '61',
 '62',
 '63',
 '64',
 '65',
 '66',
 '67',
 '68',
 '69',
 '70',
 '71',
 '72',
 '73',
 '74',
 '75',
 '76',
 '77',
 '78',
 '79',
 '80',
 '81',
 '82',
 '83',
 '84',
 '85',
 '86',
 '87',
 '88',
 '89',
 '90',
 '91',
 '92',
 '93',
 '94',
 '95',
 '96',
 '97',
 '98',
 '99',
 '100',
 '101',
 '102',
 '103',
 '104',
 '105',
 '106',
 '107',
 '108',
 '109',
 '110',
 '111',
 '112',
 '113',
 '114',
 '115',
 '116',
 '117',
 '118',
 '119',
 '120',
 '121',
 '122',
 '123',
 '124',
 '125',
 '126',
 '127',
 '128',
 '129',
 '130',
 '131',
 '132',
 '133',
 '134',
 '135',
 '136',
 '137',
 '138',
 '13

In [9]:
db,name = create_chroma_db(documents=documents, path="my_vectordb", name="rag_experiment")

In [10]:
db

Collection(name=rag_experiment)

In [11]:
def load_chroma_collection(path, name):
    chroma_client = chromadb.PersistentClient(path=path)
    db = chroma_client.get_collection(name=name, embedding_function=GeminiEmbeddingFunction())

    return db

In [12]:
db=path=load_chroma_collection("my_vectordb", name="rag_experiment")

### Retrieval

In [14]:
def get_relevant_passage(query, db, n_results):
  passage = db.query(query_texts=[query], n_results=n_results)['documents'][0]
  return passage

In [16]:
relevant_text = get_relevant_passage("Mã cổ phiếu BVH", db, 3)
relevant_text

['time: 2024-04-24, ticker: BVH, open: 39250, high: 39650, low: 39050, close: 39600, volume: 239800, comGroupCode: HOSE, organName: Tập đoàn Bảo Việt, organShortName: Tập đoàn Bảo Việt, icbName: Bảo hiểm nhân thọ, icbCode: 8575',
 'time: 2024-04-24, ticker: BVS, open: 34600, high: 36700, low: 34600, close: 36300, volume: 875900, comGroupCode: HNX, organName: Công ty Cổ phần Chứng khoán Bảo Việt, organShortName: Chứng khoán Bảo Việt, icbName: Môi giới chứng khoán, icbCode: 8777',
 'time: 2024-04-24, ticker: CMV, open: 9300, high: 9890, low: 9300, close: 9890, volume: 1500, comGroupCode: HOSE, organName: Công ty Cổ phần Thương nghiệp Cà Mau, organShortName: Thương nghiệp Cà Mau, icbName: Bán lẻ phức hợp, icbCode: 5373']

### Generation

In [17]:
def make_rag_prompt(query, relevant_passage):
  escaped = relevant_passage.replace("'", "").replace('"', "").replace("\n", " ")
  prompt = ("""You are a helpful and informative bot that answers questions using text from the reference passage included below. \
  Be sure to respond in a complete sentence, being comprehensive, including all relevant background information. \
  However, you are talking to a non-technical audience, so be sure to break down complicated concepts and \
  strike a friendly and converstional tone. \
  If the passage is irrelevant to the answer, you may ignore it.
  QUESTION: '{query}'
  PASSAGE: '{relevant_passage}'

  ANSWER:
  """).format(query=query, relevant_passage=escaped)

  return prompt

In [18]:
import google.generativeai as genai

def generate_response(prompt):
    gemini_api_key = os.getenv("GEMINI_API_KEY")
    if not gemini_api_key:
        raise ValueError("Gemini API Key not provided. Please provide GEMINI_API_KEY as an environment variable")
    genai.configure(api_key=gemini_api_key)
    model = genai.GenerativeModel('gemini-pro')
    answer = model.generate_content(prompt)
    return answer.text

In [21]:
promt = 'IUH là trường đại học nào ở việt nam?'
response = generate_response(promt)

to_markdown(response)

> Trường Đại học Công nghiệp TP.HCM (IUH)

### Bringing it all together

In [22]:
def generate_answer(db,query):
    #retrieve top 3 relevant text chunks
    relevant_text = get_relevant_passage(query,db,n_results=3)
    prompt = make_rag_prompt(query, 
                             relevant_passage="".join(relevant_text)) # joining the relevant chunks to create a single passage
    answer = generate_response(prompt)

    return answer
    

In [23]:
db=load_chroma_collection(path="my_vectordb", name="rag_experiment") 

answer = generate_answer(db, query="Giá mở cửa mã chứng khoán SSI ngày 24/04/2024")

to_markdown(answer)

> The opening price of SSI stock on April 24, 2024 was 35,000 VND.

In [27]:
answer = generate_answer(db,query="organName của mã chứng khoán ACB")
to_markdown(answer)

> The organName of the stock code ACB is Ngân hàng Thương mại Cổ phần Á Châu, which translates to Asia Commercial Joint Stock Bank.