In [1]:
import os
from dotenv import load_dotenv
import pandas as pd
import json
from datetime import datetime
from markdown import markdown
from IPython.display import Markdown, display, HTML
import warnings
# Notebook-wide setup: ignore every warning, lift the pandas column display
# limit, then load environment variables via python-dotenv. load_dotenv() is
# kept as the final expression so its return value is the cell output.
warnings.simplefilter("ignore")
pd.options.display.max_columns = None
load_dotenv()

True

In [2]:
# Switch the working directory to the parent of the directory the kernel
# started in (presumably so the `src.` imports in the following cells resolve).
# The target is computed only once per kernel session and remembered in
# PROJECT_ROOT: re-running this cell returns to the same directory instead of
# climbing one more level on every execution.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.dirname(os.getcwd())
os.chdir(PROJECT_ROOT)

In [3]:
from src.embedding_models.models import OpenAIEmbeddingsConfig, OpenAIEmbeddings
from src.embedding_models.models import ColbertReranker
from src.db.lancedb import LanceDBConfig, LanceDB

In [4]:
# Build the vector-store handle for the "lance-citations" collection.
# replace_collection=False is passed so existing data is kept; the wrapper
# logs "Not replacing collection" when the collection already exists.
embedding_config = OpenAIEmbeddingsConfig()
db_config = LanceDBConfig(
    collection_name="lance-citations",
    replace_collection=False,
    embedding=embedding_config,
)
db = LanceDB(config=db_config)

Non-empty Collection lance-citations already exists
Not replacing collection


In [5]:
# Show the collection names reported by the LanceDB wrapper.
db.list_collections()

['lance-citations', 'reddit-legal', 'test-opinions']

In [6]:
# Open the "lance-citations" table through the wrapper's `client` attribute
# (presumably the raw LanceDB connection — confirm in src/db/lancedb.py).
table = db.client.open_table("lance-citations")

In [7]:
# (rows, columns) of the table as a pandas DataFrame.
# NOTE(review): to_pandas() looks like it materialises the whole table,
# vector column included, just to read the shape — confirm whether a cheaper
# row count is available on the table object.
table.to_pandas().shape

(2995, 4)

In [8]:
# Preview the first rows of the table.
# NOTE(review): this converts the table to pandas again before taking the
# head; presumably the full table is loaded for a five-row preview — confirm.
table.to_pandas().head()

Unnamed: 0,id,vector,content,metadata
0,3987a43c-95d3-6461-a97a-f9e7c79d2d73,"[-0.01016574, -0.001256924, 0.0021155635, -0.0...",351 F.3d 1229 \n RECORDING INDUSTRY ASSOCIAT...,"{'source': 'context', 'is_chunk': True, 'id': ..."
1,2ad8e729-a1be-6cbb-bb90-d4dd84fd94b9,"[-0.0087040225, -0.006643959, -0.0086300885, -...","GINSBURG, Chief Judge: \n This case concer...","{'source': 'context', 'is_chunk': True, 'id': ..."
2,8605964f-ba70-81df-a37f-3c67bca750dc,"[-0.0070057022, -0.029567793, 0.009822634, -0....","Napster, Inc., 239 F.3d 1004 (9th Cir.2001),...","{'source': 'context', 'is_chunk': True, 'id': ..."
3,3e03d7ea-fe4d-c746-a611-6879bb8d5fc7,"[-0.019195106, -0.012144928, -0.011466515, -0....",The RIAA has used the subpoena provisions of §...,"{'source': 'context', 'is_chunk': True, 'id': ..."
4,728d7a22-51ac-e8b6-9f46-40d03ee67c8e,"[-0.01852692, -0.011990808, 0.00037747578, -0....","On July 24, 2002 the RIAA served Verizon with ...","{'source': 'context', 'is_chunk': True, 'id': ..."


In [5]:
import enum
import logging
import instructor

from typing import List
from openai import OpenAI
from pydantic import Field, BaseModel
from pydantic_settings import BaseSettings

from src.llm.basemodel import OpenAIConfig, OpenAIGPT
from src.embedding_models.models import OpenAIEmbeddingsConfig, OpenAIEmbeddings, ColbertReranker
from src.db.lancedb import LanceDBConfig, LanceDB


# Module-level logger, plus an OpenAI client passed through instructor.patch
# so that chat completion calls can take a `response_model=` argument.
logger = logging.getLogger(__name__)
openai_client = OpenAI()
client = instructor.patch(openai_client)

# One node of a decomposed query.
#   id           - identifier for this question
#   query        - text of the (sub)question
#   subquestions - integer references to the questions this one is composed
#                  of; defaults to an empty list
#
# Documented with `#` comments rather than a class docstring on purpose:
# pydantic copies a model's docstring into the generated JSON schema, and this
# model is used as a `response_model`, so a docstring would alter what the LLM
# is sent. The Field descriptions below are part of that schema as well.
class Question(BaseModel):
    id: int = Field(..., description="A unique identifier for the question")
    query: str = Field(..., description="The question decomposed as much as possible")
    subquestions: List[int] = Field(
        default_factory=list,
        description="The subquestions that this question is composed of",
    )


# Structured answer requested from the LLM: the user's original question plus
# the list of Question nodes that make up the plan to answer it.
# Kept free of a class docstring so the JSON schema pydantic generates for
# this model stays exactly as it was.
class QueryPlan(BaseModel):
    root_question: str = Field(..., description="The root question that the user asked")
    plan: List[Question] = Field(
        ...,
        description="The plan to answer the root question and its subquestions",
    )


# Ask the model to decompose one sample question; `response_model=QueryPlan`
# makes the patched client return a QueryPlan instance.
system_prompt = "You are a query understanding system capable of decomposing a question into subquestions."
user_question = "What is the difference between the population of jason's home country and canada?"

retrieval = client.chat.completions.create(
    model="gpt-4-1106-preview",
    response_model=QueryPlan,
    messages=[
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_question},
    ],
)

# Printed (rather than left as a bare expression) so the indented JSON is
# shown with real line breaks instead of as an escaped string repr.
print(retrieval.model_dump_json(indent=4))

{
    "root_question": "What is the difference between the population of jason's home country and canada?",
    "plan": [
        {
            "id": 1,
            "query": "What is the population of Jason's home country?",
            "subquestions": []
        },
        {
            "id": 2,
            "query": "What is the population of Canada?",
            "subquestions": []
        },
        {
            "id": 3,
            "query": "What is the difference in population between two countries?",
            "subquestions": [
                1,
                2
            ]
        }
    ]
}
