## Simple RAG

Oracle Database を検索システムとして用いたシンプルな RAG 構成のコードです。

In [None]:
import os
from dotenv import load_dotenv, find_dotenv

from langfuse import Langfuse
from langfuse.callback import CallbackHandler

import oracledb

from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_community.vectorstores import OracleVS
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_community.embeddings.oci_generative_ai import OCIGenAIEmbeddings
from langchain_community.chat_models.oci_generative_ai import ChatOCIGenAI

必要な環境変数を `.env` から読み込みます。

In [None]:
# Populate the process environment from the `.env` file located by find_dotenv().
_ = load_dotenv(find_dotenv())

# Oracle Database: credentials, DSN, wallet settings and the target table.
un = os.environ.get("ORACLE_USERNAME")
pw = os.environ.get("ORACLE_PASSWORD")
dsn = os.environ.get("ORACLE_DSN")
# The same fixed directory is used for both the config dir and the wallet.
config_dir = wallet_location = "/tmp/wallet"
wallet_password = os.environ.get("WALLET_PASSWORD")
table_name = os.environ.get("TABLE_NAME")

# OCI: compartment and service endpoint passed to the OCI Generative AI clients.
compartment_id = os.environ.get("COMPARTMENT_ID")
service_endpoint = os.environ.get("SERVICE_ENDPOINT")

# Langfuse: API keys and host. Any variable that is not set resolves to None.
secret_key = os.environ.get("LANGFUSE_SECRET_KEY")
public_key = os.environ.get("LANGFUSE_PUBLIC_KEY")
langfuse_host = os.environ.get("LANGFUSE_HOST")

Langfuse のクライアントと、LangChain 用のコールバックハンドラーを設定します。

In [None]:
# Credentials and host shared by the Langfuse client and the callback handler.
_langfuse_auth = {
    "secret_key": secret_key,
    "public_key": public_key,
    "host": langfuse_host,
}

# Langfuse client; used later in this file to fetch a managed prompt.
langfuse = Langfuse(**_langfuse_auth)

# Callback handler passed to the chain run via `config={"callbacks": [...]}`.
# NOTE(review): sample_rate=0.5 presumably means only about half of the runs
# are traced — confirm against the Langfuse SDK documentation.
langfuse_handler = CallbackHandler(**_langfuse_auth, sample_rate=0.5)

モデルに与えるパラメータを定義します。  
アプリケーション化する際は、これらをユーザーが選択できる項目にします。

In [None]:
# Chat model identifier and whether the response should be streamed;
# both are passed to the chat model constructor later in this file.
model_name, is_stream = "cohere.command-r-plus", True

# Generation parameters handed to the chat model as `model_kwargs`.
model_args = dict(
    temperature=0.3,
    max_tokens=1024,
    top_p=0.75,
    top_k=0,
    frequency_penalty=0,
    presence_penalty=0,
)

In [None]:
# Build and run the simple RAG chain while the Oracle Database connection is
# open. The vector store is given this connection as its client, so the
# retriever and the chain should only be used inside the `with` block.
with oracledb.connect(
    user=un,
    password=pw,
    dsn=dsn,
    config_dir=config_dir,
    wallet_location=wallet_location,
    wallet_password=wallet_password
) as connection:
    # Embedding model handed to the vector store as its embedding function.
    embeddings = OCIGenAIEmbeddings(
        auth_type="INSTANCE_PRINCIPAL",
        model_id="cohere.embed-multilingual-v3.0",
        service_endpoint=service_endpoint,
        compartment_id=compartment_id,
    )
    # NOTE(review): `query` looks like a sample text OracleVS uses internally
    # (presumably to probe the embedding dimension) — confirm against the
    # langchain_community OracleVS documentation.
    oracle_vs = OracleVS(
        client=connection,
        embedding_function=embeddings,
        table_name=table_name,
        distance_strategy=DistanceStrategy.COSINE,
        query="What is a Oracle Database"
    )
    # Use the vector store (Oracle Database 23ai) as a retriever.
    retriever = oracle_vs.as_retriever()
    chat = ChatOCIGenAI(
        auth_type="INSTANCE_PRINCIPAL",
        service_endpoint=service_endpoint,
        compartment_id=compartment_id,
        model_id=model_name,
        is_stream=is_stream,
        model_kwargs=model_args
    )
    # Fetch the prompt managed in Langfuse. Build the template with
    # `from_messages`, the documented way to turn the message list returned
    # by `get_langchain_prompt()` into a ChatPromptTemplate, instead of
    # relying on the constructor to coerce raw message representations.
    chat_prompt = ChatPromptTemplate.from_messages(
        langfuse.get_prompt(
            name="ochat-prompt-with-tools", type="chat"
        ).get_langchain_prompt()
    )
    # Simple RAG chain: the input string is passed through as "query" and is
    # also sent to the retriever, whose result fills "context".
    chain = (
        {"query": RunnablePassthrough(), "context": retriever}
        | chat_prompt
        | chat
        | StrOutputParser()
    )

    # Stream the answer, reporting the run to Langfuse through the callback.
    res = chain.stream(
        "OCHaCafeってなんですか？",
        config={"callbacks": [langfuse_handler]},
    )

    # Print each chunk as it arrives, without inserting newlines between them.
    for chunk in res:
        print(chunk, end="")