### Retrieve Image

In [None]:
import os
from langchain_core.documents import Document
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema import BaseRetriever
from typing import List
import oracledb
import pandas as pd
import base64
from dotenv import find_dotenv, load_dotenv

from langchain_community.embeddings import OCIGenAIEmbeddings
from langchain_community.chat_models import ChatOCIGenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import HumanMessage, SystemMessage

# Load variables from the nearest .env file into the process environment
# before any of the os.getenv() lookups below run.
_ = load_dotenv(find_dotenv())
# Initialise the Oracle Client libraries for python-oracledb; called with no
# arguments, so the library location comes from the system configuration.
oracledb.init_oracle_client()

# Database credentials / connect string and the OCI compartment used by every
# Generative AI call in this notebook. Each value is None when the
# corresponding environment variable is not set.
UN = os.getenv("UN")
PW = os.getenv("PW")
DSN = os.getenv("DSN")
OCI_COMPARTMENT_ID = os.getenv("OCI_COMPARTMENT_ID")

In [None]:
# Utils
def get_embedding(text: str) -> list:
    """Embed ``text`` with the OCI Generative AI multilingual Cohere model.

    A new ``OCIGenAIEmbeddings`` client is constructed on every call, using
    the module-level ``OCI_COMPARTMENT_ID``.

    Args:
        text: The text to embed.

    Returns:
        Whatever ``OCIGenAIEmbeddings.embed_query`` returns for ``text``.
    """
    client_options = {
        "model_id": "cohere.embed-multilingual-v3.0",
        "service_endpoint": "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
        "compartment_id": OCI_COMPARTMENT_ID,
    }
    embedder = OCIGenAIEmbeddings(**client_options)
    vector = embedder.embed_query(text)
    return vector

def chat_with_image(image_path: str, prompt: str, system_prompt: str = None) -> str:
    """Ask the OCI vision chat model a question about a local image file.

    The image is read from disk, base64-encoded and sent inline as a data URL
    together with ``prompt``.

    Args:
        image_path: Path of the image, relative to
            ``/home/opc/multimodal_oci_genai/`` (the two are concatenated).
        prompt: The user question sent alongside the image.
        system_prompt: Optional system instruction. When ``None`` no system
            message is sent.

    Returns:
        The ``content`` of the model's reply.

    Raises:
        OSError: If the image file cannot be opened or read.
    """
    # Function-scope import so this block does not depend on the file's
    # top-of-file import list.
    import mimetypes

    full_path = "/home/opc/multimodal_oci_genai/" + image_path
    with open(full_path, "rb") as img_file:
        image_data = base64.b64encode(img_file.read()).decode("utf-8")

    # Declare the real image type in the data URL; fall back to PNG (the
    # previous hard-coded value) when the extension is unknown or not an image.
    mime_type, _ = mimetypes.guess_type(full_path)
    if mime_type is None or not mime_type.startswith("image/"):
        mime_type = "image/png"

    prompt_with_image = []
    # SystemMessage(content=None) is not a valid message, so only add the
    # system message when a prompt was actually supplied.
    if system_prompt is not None:
        prompt_with_image.append(SystemMessage(content=system_prompt))
    prompt_with_image.append(
        HumanMessage(
            content=[
                {"type": "text", "text": prompt},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "data:" + mime_type + ";base64," + image_data,
                    },
                },
            ]
        )
    )

    llm = ChatOCIGenAI(
        model_id="meta.llama-3.2-90b-vision-instruct",
        service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
        compartment_id=OCI_COMPARTMENT_ID,
    )
    result = llm.invoke(prompt_with_image)
    print(f"Result: {result}")
    return result.content

In [None]:
class CustomImageRetriever(BaseRetriever):
    """
    Retriever that returns the stored images closest to a text query.

    The query is embedded with ``get_embedding`` and compared, by cosine
    distance, against the ``embedding`` column of the ``image_contents``
    table. The three nearest rows are returned.
    """

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Run the vector search for ``query``.

        Args:
            query: Text to search for.
            run_manager: LangChain callback manager (unused here).

        Returns:
            Up to three ``Document`` objects, nearest first. ``page_content``
            is the row's ``summary``; ``metadata`` holds ``file_id``,
            ``file_path`` (the ``image_path`` column) and ``vector_index``
            (1-based rank). The list may be empty or partial if a
            non-database error occurred, since such errors are only printed.

        Raises:
            oracledb.DatabaseError: Re-raised after being printed.
        """
        docs: List[Document] = []
        # The embedding is sent as its string form and converted to a VECTOR
        # in SQL by to_vector().
        embed_query = str(get_embedding(query))
        select_sql = """
            SELECT
                file_id,
                image_path,
                summary
            FROM
                image_contents
            ORDER BY VECTOR_DISTANCE(embedding, to_vector(:1, 1024, FLOAT32), COSINE)
            FETCH FIRST 3 ROWS ONLY
        """
        try:
            # Both context managers release their resources on exit, so no
            # explicit connection.close() is needed.
            with oracledb.connect(user=UN, password=PW, dsn=DSN) as connection:
                with connection.cursor() as cursor:
                    # NOTE(review): the bind is declared as VECTOR while the
                    # value passed is a string; kept as in the original --
                    # confirm the driver accepts this combination.
                    cursor.setinputsizes(oracledb.DB_TYPE_VECTOR)
                    cursor.execute(select_sql, [embed_query])
                    for index, row in enumerate(cursor, start=1):
                        file_id, image_path, summary = row
                        docs.append(
                            Document(
                                page_content=summary,
                                metadata={
                                    'file_id': file_id,
                                    'file_path': image_path,
                                    'vector_index': index,
                                },
                            )
                        )
        except oracledb.DatabaseError as e:
            print(f"Database error: {e}")
            raise
        except Exception as e:
            # Best-effort: report and return whatever was collected so far.
            print("Error Vector Search:", e)

        return docs

### Retrieve Image & Answer

In [None]:
def _save_image_blob(file_id, dest_path: str = 'tmp_image.png') -> bool:
    """Copy the ``image_blob`` stored for ``file_id`` into ``dest_path``.

    Returns:
        True when the image was written, False when no row (or a NULL blob)
        was found for ``file_id``.
    """
    select_sql = """
        SELECT
            image_blob
        FROM
            image_contents
        WHERE file_id = :1
    """
    with oracledb.connect(user=UN, password=PW, dsn=DSN) as connection:
        with connection.cursor() as cursor:
            # file_id is an ordinary bind value, so no VECTOR input size is
            # declared for it.
            cursor.execute(select_sql, [file_id])
            row = cursor.fetchone()
            if row is None or row[0] is None:
                return False
            blob, = row
            # The LOB is read while the connection is still open, in
            # fixed-size chunks; LOB offsets are 1-based.
            offset = 1
            bufsize = 65536
            with open(dest_path, 'wb') as f:
                while True:
                    data = blob.read(offset, bufsize)
                    if data:
                        f.write(data)
                    if len(data) < bufsize:
                        break
                    offset += bufsize
    return True


def get_text_by_image(query: str):
    """
    Answer ``query`` from retrieved image summaries and save the top image.

    The images closest to ``query`` are retrieved, their documents are given
    to the chat model as context, and the best-matching image is written to
    ``tmp_image.png`` in the current working directory.

    Args:
        query: The user question.

    Returns:
        The model's answer as a string.

    Raises:
        oracledb.DatabaseError: Re-raised after being printed, from either
            the retrieval or the image download.
    """
    llm = ChatOCIGenAI(
        model_id="cohere.command-r-08-2024",
        service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
        compartment_id=OCI_COMPARTMENT_ID,
        )

    prompt = ChatPromptTemplate([
        ("system", "あなたは質疑応答のAIアシスタントです。必ず日本語で答えてください。"),
        ("human", "{query} 以下のコンテキストに基づいて答えてください。{context}"),
    ])

    # Retrieve once and keep the documents: the chain output is plain text
    # (StrOutputParser), so the image metadata must come from the retriever
    # result rather than from the chain result.
    retriever = CustomImageRetriever()
    retrieved_docs = retriever.invoke(query)

    chain = prompt | llm | StrOutputParser()
    answer = chain.invoke({'query': query, 'context': retrieved_docs})
    print(answer)

    if not retrieved_docs:
        # Nothing was retrieved, so there is no image to download.
        return answer

    file_id = retrieved_docs[0].metadata['file_id']
    try:
        if not _save_image_blob(file_id):
            print(f"No image found for file_id: {file_id}")
    except oracledb.DatabaseError as e:
        print(f"Database error: {e}")
        raise
    except Exception as e:
        # Best-effort: the answer is still returned if saving the image fails.
        print("Error saving image:", e)

    return answer

In [None]:
def get_text_by_text_with_image(question: str):
    """Answer ``question`` by showing the best-matching image to the vision model.

    The question is translated to English with the chat model, the closest
    image is looked up with ``CustomImageRetriever`` (using the original,
    untranslated question), and the vision model is asked the translated
    question about that image.

    Args:
        question: The user question (the translation prompt expects Japanese).

    Returns:
        The vision model's answer as returned by ``chat_with_image``.

    Raises:
        IndexError: If the retriever returns no images.
    """
    prompt = ChatPromptTemplate([
        ("system", "あなたは言語翻訳のAIアシスタントです。日本語を英語に翻訳してください。"),
        ("human", "{input} "),
    ])

    llm = ChatOCIGenAI(
        model_id="cohere.command-r-08-2024",
        service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
        compartment_id=OCI_COMPARTMENT_ID,
    )
    # The prompt is fed the input dict directly. Wrapping it in
    # {'input': RunnablePassthrough()} would forward the whole dict as the
    # value of {input}, so the model would see the dict's repr, not the text.
    chain = prompt | llm | StrOutputParser()
    question_en = chain.invoke({"input": question})

    retriever = CustomImageRetriever()
    result_images = retriever.invoke(question)
    print(result_images)

    if not result_images:
        # Same exception type an empty-list lookup would raise, but with an
        # explanation of what went wrong.
        raise IndexError("No images were retrieved for the question.")

    image_path = result_images[0].metadata['file_path']

    res = chat_with_image(
        image_path=image_path,
        prompt=question_en,
        system_prompt="You are a AI assistant. Please answer the question based on the image."
    )
    return res