# Query Transform을 통한 First-Stage Retrieve 단계 성능 개선
* 기존의 basic + custom 데이터셋을 통해 최적화한 결과, 목표 성능(recall: 0.8)에 도달
    * k: 20, alpha: 40, morphological_analyzer: bm25_kiwi_pos, score_threshold: 0.1
    * 위의 조합에서 recall: 0.815476 달성
* rag 시스템 연구를 위해 k: 10 조합을 기준으로 Query Transform 관련 실험 진행
* 아래의 5 가지 기법을 활용하여 input query를 변경 후, first-stage retrieve의 성능 개선 도모
    * query decomposition
    * query rewrite
    * query rewrite + query decomposition
    * query2doc
    * query expansion

### 결과
* 실험 결과 transform을 통해 특정 문제 유형에 대한 성능 향상이 가능하지만, 그 외 유형에 부정적인 결과를 초래
* 결과적으로 전체 유형에 대한 평가 지표가 original query에 대한 최적화 성능보다 떨어짐

In [1]:
import ast
import os
from typing import List

import pandas as pd
from pydantic import BaseModel, Field
from tqdm import tqdm

tqdm.pandas()

In [2]:
import sys
sys.path.append('../code/graphParser')
sys.path.append('../code/ragas_custom')
from rateLimit import handle_rate_limits
from retrieve.sparse import BM25
from retrieve.config import generate_retriever_configs
from evaluation.retrieve import optimization, combine_hybrid_results


from langchain_core.prompts import load_prompt
from langchain.prompts import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser, PydanticOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from langchain_core.documents import Document

from ragas.testset.graph import KnowledgeGraph


from dotenv import load_dotenv
load_dotenv()

  from .autonotebook import tqdm as notebook_tqdm


True

In [5]:
# Load the validated dataset and keep only its first four columns.
merged_dataset = pd.read_csv('../data/rag/valid_dataset.csv')
# `reference_contexts` holds Python-literal text (presumably a stringified list -- TODO confirm).
# ast.literal_eval parses literals only, so nothing stored in the CSV can run as code.
merged_dataset['reference_contexts'] = merged_dataset['reference_contexts'].apply(ast.literal_eval)
merged_dataset = merged_dataset.iloc[:, :4]
# Disabled one-off export; presumably how first_retrieve.csv was first created -- TODO confirm.
# merged_dataset.to_csv('../data/rag/first_retrieve.csv', index=False)

In [3]:
# Reload the working dataset that the transform cells below extend and re-save.
merged_dataset = pd.read_csv('../data/rag/first_retrieve.csv')
# Parse the literal text with ast.literal_eval instead of eval so that CSV
# contents are never executed as code.
merged_dataset['reference_contexts'] = merged_dataset['reference_contexts'].apply(ast.literal_eval)

In [4]:
# Load the saved knowledge graph and wrap each node's text and metadata
# in a Document for the dense index.
kg = KnowledgeGraph.load('../data/rag/kg.json')

documents = [
    Document(
        page_content=graph_node.properties['page_content'],
        metadata=graph_node.properties['document_metadata'],
    )
    for graph_node in kg.nodes
]

In [5]:
# Search space for the k=10 experiments. The first six generated
# configurations are discarded (presumably the ones that do not use a
# score threshold -- TODO confirm against generate_retriever_configs).
configs = generate_retriever_configs(
    k_values=[10],
    analyzers=["bm25_kiwi", "bm25_kiwi_pos"],
    hybrid_alphas=[20, 40, 60, 80],
    score_threshold=[0.1, 0.2, 0.3, 0.4],
    fetch_k=[],
    lambda_mult=[],
)[6:]

In [6]:
# Two BM25 retrievers (types 'kiwi_pos' and 'kiwi') indexed over the page
# text of every knowledge-graph node.
kiwi_pos, kiwi = BM25(k=10, type='kiwi_pos'), BM25(k=10, type='kiwi')
texts = [graph_node.properties['page_content'] for graph_node in kg.nodes]
kiwi.from_texts(texts)
kiwi_pos.from_texts(texts)

# Dense index over the same nodes, embedded with OpenAI embeddings.
embeddings = OpenAIEmbeddings()
db = FAISS.from_documents(documents, embeddings)

In [7]:
def retrieve_for_list(query_list, retriever, k):
    """Run several queries against one retriever, sharing a budget of ``k`` hits.

    Every query gets ``k // len(query_list)`` hits and the last query also
    receives the remainder, so the per-query quotas add up to ``k``.

    Args:
        query_list: Queries to run, in order.
        retriever: Object exposing ``similarity_search_with_score(query, k=...)``
            or ``search(query)``; the former is used when it exists.
        k: Total number of hits to split across the queries.

    Returns:
        A flat list holding each query's hits in query order. Empty when
        ``query_list`` is empty.
    """
    n_queries = len(query_list)
    if n_queries == 0:
        # Nothing to search; also avoids dividing by zero below.
        return []

    base_k, remainder = divmod(k, n_queries)
    has_scored_search = hasattr(retriever, 'similarity_search_with_score')

    results = []
    for i, query in enumerate(query_list):
        # Only the final query absorbs the leftover quota.
        curr_k = base_k + remainder if i == n_queries - 1 else base_k
        if curr_k <= 0:
            # A zero quota contributes nothing, so skip the retriever call.
            continue

        if has_scored_search:
            search_results = retriever.similarity_search_with_score(query, k=curr_k)
        else:
            search_results = retriever.search(query)[:curr_k]

        results.extend(search_results)

    return results

def precompute_list(column, df, db, pos, k=15, plain=None):
    """Cache dense and sparse retrieval results for a column of query lists.

    Adds three columns to ``df`` in place, each filled by calling
    ``retrieve_for_list`` on every row of ``df[column]``:
    ``precompute_dense``, ``precompute_sparse_bm25_kiwi`` and
    ``precompute_sparse_bm25_kiwi_pos``.

    Args:
        column: Name of the column holding one list of queries per row.
        df: DataFrame to extend; it is modified in place.
        db: Retriever used for the ``precompute_dense`` column.
        pos: Retriever used for the ``precompute_sparse_bm25_kiwi_pos`` column.
        k: Total hit budget per row handed to ``retrieve_for_list``.
        plain: Retriever used for the ``precompute_sparse_bm25_kiwi`` column.
            Defaults to the module-level ``kiwi`` retriever, which keeps
            existing four-argument calls working unchanged.

    Returns:
        The same ``df`` object with the three columns added.
    """
    # Accept the plain retriever explicitly instead of always reading the global.
    plain_retriever = kiwi if plain is None else plain

    targets = (
        ('precompute_dense', db),
        ('precompute_sparse_bm25_kiwi', plain_retriever),
        ('precompute_sparse_bm25_kiwi_pos', pos),
    )
    for target_column, retriever in targets:
        # Bind the retriever as a default so each lambda keeps its own one.
        df[target_column] = df[column].progress_apply(
            lambda queries, retriever=retriever: retrieve_for_list(queries, retriever, k)
        )

    return df

## 1. Query Decomposition

In [8]:
@handle_rate_limits
def decomposition_chain(data_batches, model_name='gpt-4o-mini', current_api_key=None, max_concurrency=5) -> List:
    """Ask the chat model to split each input query into sub-questions.

    Loads the decomposition prompt, fills its ``format`` slot with the
    parser's format instructions, runs ``prompt | llm`` over ``data_batches``
    in one ``batch`` call and parses each reply.

    Args:
        data_batches: Inputs handed to the chain's ``batch`` call.
        model_name: Model identifier passed to ``ChatOpenAI``.
        current_api_key: API key passed to ``ChatOpenAI`` (presumably injected
            by ``handle_rate_limits`` -- TODO confirm).
        max_concurrency: Value of ``max_concurrency`` in the batch config.

    Returns:
        One list of sub-question strings per reply returned by ``batch``.
    """
    class DecompositionOutput(BaseModel):
        sub_questions: List[str] = Field(..., description="List of decomposed sub-questions in Korean")

    output_parser = PydanticOutputParser(pydantic_object=DecompositionOutput)
    template = load_prompt('../prompt/transform/decomposition.yaml').partial(
        format=output_parser.get_format_instructions()
    )
    chat_model = ChatOpenAI(model=model_name, temperature=0.2, api_key=current_api_key)

    # Parsing is done after the batch call rather than as a chain step.
    replies = (template | chat_model).batch(data_batches, config={"max_concurrency": max_concurrency})
    return [output_parser.parse(reply.content).sub_questions for reply in replies]

In [10]:
# Decompose every user query, sending `max_concurrency` queries per call.
max_concurrency = 5
input_querys = merged_dataset['user_input'].to_list()

results = []
# A stepped range also yields the final, shorter batch, so no separate
# remainder call is needed and the progress bar counts every batch.
for start in tqdm(range(0, len(input_querys), max_concurrency)):
    batch = input_querys[start:start + max_concurrency]
    results.extend(decomposition_chain(batch, max_concurrency=max_concurrency))

# Store the sub-questions next to their source rows and persist the dataset.
merged_dataset['query_decomposition'] = results
merged_dataset.to_csv('../data/rag/first_retrieve.csv', index=False)

  llm = ChatOpenAI(model=model_name, temperature=0.2, api_key=current_api_key)
Failed to multipart ingest runs: langsmith.utils.LangSmithRateLimitError: Rate limit exceeded for https://api.smith.langchain.com/runs/multipart. HTTPError('429 Client Error: Too Many Requests for url: https://api.smith.langchain.com/runs/multipart', '{"error":"Too many requests: tenant exceeded usage limits: Monthly unique traces usage limit exceeded"}\n')trace=b4ffad3f-9e19-4194-b286-a9d032881d11,id=b4ffad3f-9e19-4194-b286-a9d032881d11; trace=b4ffad3f-9e19-4194-b286-a9d032881d11,id=8f440ab6-be55-4a1b-9bb3-179dfb42dca8; trace=b4ffad3f-9e19-4194-b286-a9d032881d11,id=c62fab11-c5aa-4607-8215-ab6977444e22; trace=ef2bb0cc-198f-4421-89d5-6f738cc66721,id=ef2bb0cc-198f-4421-89d5-6f738cc66721; trace=ef2bb0cc-198f-4421-89d5-6f738cc66721,id=a3b7a119-5359-428c-a275-8119228a16c2; trace=ef2bb0cc-198f-4421-89d5-6f738cc66721,id=7b51699f-cf5b-4b5e-9015-0f3179494368; trace=9ad33790-5e57-4dfb-a3ec-307dc7d919bc,id=9ad33790-5e5

Failed to send compressed multipart ingest: langsmith.utils.LangSmithRateLimitError: Rate limit exceeded for https://api.smith.langchain.com/runs/multipart. HTTPError('429 Client Error: Too Many Requests for url: https://api.smith.langchain.com/runs/multipart', '{"error":"Too many requests: tenant exceeded usage limits: Monthly unique traces usage limit exceeded"}\n')trace=17c41ee7-8246-4352-9266-9ce7e444e8b2,id=a54b1b84-981e-40dd-871a-28249c6c37ec; trace=17c41ee7-8246-4352-9266-9ce7e444e8b2,id=17c41ee7-8246-4352-9266-9ce7e444e8b2
Failed to send compressed multipart ingest: langsmith.utils.LangSmithRateLimitError: Rate limit exceeded for https://api.smith.langchain.com/runs/multipart. HTTPError('429 Client Error: Too Many Requests for url: https://api.smith.langchain.com/runs/multipart', '{"error":"Too many requests: tenant exceeded usage limits: Monthly unique traces usage limit exceeded"}\n')trace=1a4711f4-694d-413f-abf0-2c2cecde89a3,id=1a4711f4-694d-413f-abf0-2c2cecde89a3; trace=c0b

In [11]:
# Work on a copy holding the first four columns plus the decomposed queries,
# cache retrieval results for each row, then score every configuration.
decompo_df = merged_dataset.copy()[[*merged_dataset.columns[:4], 'query_decomposition']]
decompo_df = precompute_list('query_decomposition', decompo_df, db, kiwi_pos)
decompo_results = optimization(configs, decompo_df)

100%|██████████| 56/56 [01:25<00:00,  1.52s/it]
100%|██████████| 56/56 [00:00<00:00, 256.86it/s]
100%|██████████| 56/56 [00:00<00:00, 398.55it/s]
100%|██████████| 32/32 [00:00<00:00, 229.15it/s]


In [12]:
# Ten best configurations: highest recall first, ties broken by MAP, then nDCG.
(
    decompo_results
    .sort_values(by=['recall', 'map', 'ndcg'], ascending=False)
    .head(10)
)

Unnamed: 0,k,alpha,dense_type,morphological_analyzer,fetch_k,lambda_mult,score_threshold,ndcg,recall,map,recall_body part,recall_exercise name,"recall_exercise name, body part",recall_multi_hop_abstract_query_synthesizer,recall_multi_hop_specific_query_synthesizer,recall_single_hop_specifc_query_synthesizer
28,10,80,threshold,bm25_kiwi_pos,,,0.1,0.429981,0.604167,0.434729,0.6,0.4,0.4375,0.702381,0.75,0.75
4,10,20,threshold,bm25_kiwi_pos,,,0.1,0.481304,0.60119,0.504578,0.5,0.5,0.4375,0.690476,0.75,0.75
7,10,20,threshold,bm25_kiwi_pos,,,0.4,0.477779,0.60119,0.500935,0.5,0.55,0.4375,0.654762,0.75,0.75
15,10,40,threshold,bm25_kiwi_pos,,,0.4,0.477779,0.60119,0.500935,0.5,0.55,0.4375,0.654762,0.75,0.75
23,10,60,threshold,bm25_kiwi_pos,,,0.4,0.477779,0.60119,0.500935,0.5,0.55,0.4375,0.654762,0.75,0.75
31,10,80,threshold,bm25_kiwi_pos,,,0.4,0.477779,0.60119,0.500935,0.5,0.55,0.4375,0.654762,0.75,0.75
29,10,80,threshold,bm25_kiwi_pos,,,0.2,0.41519,0.599702,0.411451,0.6,0.4,0.4375,0.684524,0.75,0.75
20,10,60,threshold,bm25_kiwi_pos,,,0.1,0.445601,0.590774,0.473427,0.6,0.45,0.4375,0.613095,0.75,0.75
21,10,60,threshold,bm25_kiwi_pos,,,0.2,0.434491,0.590774,0.453387,0.6,0.45,0.4375,0.613095,0.75,0.75
5,10,20,threshold,bm25_kiwi_pos,,,0.2,0.474108,0.583333,0.514101,0.45,0.5,0.4375,0.654762,0.75,0.75


## 2. Rewrite-Retrieve-Read

In [13]:
@handle_rate_limits
def rewrite_chain(data_batches, model_name='gpt-4o-mini', current_api_key=None, max_concurrency=5) -> List:
    """Rewrite each input query with the chat model.

    Runs ``prompt | llm | StrOutputParser()`` over ``data_batches`` in a
    single ``batch`` call, using the domain rewrite prompt.

    Args:
        data_batches: Inputs handed to the chain's ``batch`` call.
        model_name: Model identifier passed to ``ChatOpenAI``.
        current_api_key: API key passed to ``ChatOpenAI`` (presumably injected
            by ``handle_rate_limits`` -- TODO confirm).
        max_concurrency: Value of ``max_concurrency`` in the batch config.

    Returns:
        The chain's batch output, one parsed string per reply.
    """
    template = load_prompt('../prompt/transform/rewrite_domain.yaml')
    chat_model = ChatOpenAI(model=model_name, temperature=0.2, api_key=current_api_key)
    pipeline = template | chat_model | StrOutputParser()

    return pipeline.batch(data_batches, config={"max_concurrency": max_concurrency})

In [14]:
# Rewrite every user query, sending `max_concurrency` queries per call.
max_concurrency = 5
input_querys = merged_dataset['user_input'].to_list()

results = []
# A stepped range also yields the final, shorter batch, so no separate
# remainder call is needed and the progress bar counts every batch.
for start in tqdm(range(0, len(input_querys), max_concurrency)):
    batch = input_querys[start:start + max_concurrency]
    results.extend(rewrite_chain(batch, max_concurrency=max_concurrency))

# Store the rewritten queries next to their source rows and persist the dataset.
merged_dataset['query_rewrite'] = results
merged_dataset.to_csv('../data/rag/first_retrieve.csv', index=False)

100%|██████████| 11/11 [00:27<00:00,  2.51s/it]


In [15]:
# Work on a copy holding the first four columns plus the rewritten query.
rewrite_df = merged_dataset.copy()[[*merged_dataset.columns[:4], 'query_rewrite']]

# Cache the dense hits (15 per query) and both BM25 result lists for each row.
rewrite_df['precompute_dense'] = rewrite_df['query_rewrite'].progress_apply(
    lambda query: db.similarity_search_with_score(query, k=15)
)
rewrite_df['precompute_sparse_bm25_kiwi'] = rewrite_df['query_rewrite'].apply(kiwi.search)
rewrite_df['precompute_sparse_bm25_kiwi_pos'] = rewrite_df['query_rewrite'].apply(kiwi_pos.search)

# Score every retriever configuration against the cached results.
rewrite_results = optimization(configs, rewrite_df)

100%|██████████| 56/56 [00:24<00:00,  2.26it/s]
100%|██████████| 32/32 [00:00<00:00, 229.55it/s]


In [16]:
# Ten best configurations: highest recall first, ties broken by MAP, then nDCG.
(
    rewrite_results
    .sort_values(by=['recall', 'map', 'ndcg'], ascending=False)
    .head(10)
)

Unnamed: 0,k,alpha,dense_type,morphological_analyzer,fetch_k,lambda_mult,score_threshold,ndcg,recall,map,recall_body part,recall_exercise name,"recall_exercise name, body part",recall_multi_hop_abstract_query_synthesizer,recall_multi_hop_specific_query_synthesizer,recall_single_hop_specifc_query_synthesizer
20,10,60,threshold,bm25_kiwi_pos,,,0.1,0.518104,0.645833,0.553026,0.7,0.55,0.5,0.511905,0.85,1.0
28,10,80,threshold,bm25_kiwi_pos,,,0.1,0.503466,0.645833,0.531417,0.7,0.55,0.5,0.511905,0.85,1.0
21,10,60,threshold,bm25_kiwi_pos,,,0.2,0.464989,0.623512,0.47669,0.7,0.55,0.5,0.494048,0.85,0.75
12,10,40,threshold,bm25_kiwi_pos,,,0.1,0.516721,0.619048,0.564626,0.65,0.45,0.5,0.547619,0.8,1.0
29,10,80,threshold,bm25_kiwi_pos,,,0.2,0.438696,0.614583,0.445507,0.7,0.55,0.5,0.494048,0.8,0.75
24,10,80,threshold,bm25_kiwi,,,0.1,0.485679,0.599702,0.536997,0.7,0.45,0.4375,0.470238,0.8,1.0
16,10,60,threshold,bm25_kiwi,,,0.1,0.492917,0.592262,0.567049,0.65,0.45,0.4375,0.440476,0.85,1.0
13,10,40,threshold,bm25_kiwi_pos,,,0.2,0.486111,0.592262,0.530201,0.65,0.45,0.5,0.511905,0.8,0.75
4,10,20,threshold,bm25_kiwi_pos,,,0.1,0.472933,0.592262,0.50496,0.55,0.45,0.5,0.511905,0.8,1.0
17,10,60,threshold,bm25_kiwi,,,0.2,0.449747,0.583333,0.48969,0.65,0.45,0.4375,0.404762,0.85,1.0


## 3. Query2Doc

In [17]:
@handle_rate_limits
def query2Doc_chain(data_batches, model_name='gpt-4o-mini', current_api_key=None, max_concurrency=5) -> List:
    """Generate query2Doc text for each input query with the chat model.

    Runs ``prompt | llm | StrOutputParser()`` over ``data_batches`` in a
    single ``batch`` call, using the query2Doc prompt.

    Args:
        data_batches: Inputs handed to the chain's ``batch`` call.
        model_name: Model identifier passed to ``ChatOpenAI``.
        current_api_key: API key passed to ``ChatOpenAI`` (presumably injected
            by ``handle_rate_limits`` -- TODO confirm).
        max_concurrency: Value of ``max_concurrency`` in the batch config.

    Returns:
        The chain's batch output, one parsed string per reply.
    """
    template = load_prompt('../prompt/transform/query2Doc.yaml')
    chat_model = ChatOpenAI(model=model_name, temperature=0.2, api_key=current_api_key)
    pipeline = template | chat_model | StrOutputParser()

    return pipeline.batch(data_batches, config={"max_concurrency": max_concurrency})

In [18]:
# Generate query2Doc text for every user query, sending `max_concurrency` queries per call.
max_concurrency = 2
input_querys = merged_dataset['user_input'].to_list()

results = []
# A stepped range also yields the final, shorter batch, so no separate
# remainder call is needed and the progress bar counts every batch.
for start in tqdm(range(0, len(input_querys), max_concurrency)):
    batch = input_querys[start:start + max_concurrency]
    results.extend(query2Doc_chain(batch, max_concurrency=max_concurrency))

# Store the generated text next to its source rows and persist the dataset.
merged_dataset['query2Doc'] = results
merged_dataset.to_csv('../data/rag/first_retrieve.csv', index=False)

100%|██████████| 28/28 [01:44<00:00,  3.72s/it]


In [19]:
# Work on a copy holding the first four columns plus the query2Doc text.
query2Doc_df = merged_dataset.copy()[[*merged_dataset.columns[:4], 'query2Doc']]

# Cache the dense hits (15 per query) and both BM25 result lists for each row.
query2Doc_df['precompute_dense'] = query2Doc_df['query2Doc'].progress_apply(
    lambda query: db.similarity_search_with_score(query, k=15)
)
query2Doc_df['precompute_sparse_bm25_kiwi'] = query2Doc_df['query2Doc'].apply(kiwi.search)
query2Doc_df['precompute_sparse_bm25_kiwi_pos'] = query2Doc_df['query2Doc'].apply(kiwi_pos.search)

# Score every retriever configuration against the cached results.
query2Doc_results = optimization(configs, query2Doc_df)

100%|██████████| 56/56 [00:29<00:00,  1.92it/s]
100%|██████████| 32/32 [00:00<00:00, 220.90it/s]


In [20]:
# Ten best configurations: highest recall first, ties broken by MAP, then nDCG.
(
    query2Doc_results
    .sort_values(by=['recall', 'map', 'ndcg'], ascending=False)
    .head(10)
)

Unnamed: 0,k,alpha,dense_type,morphological_analyzer,fetch_k,lambda_mult,score_threshold,ndcg,recall,map,recall_body part,recall_exercise name,"recall_exercise name, body part",recall_multi_hop_abstract_query_synthesizer,recall_multi_hop_specific_query_synthesizer,recall_single_hop_specifc_query_synthesizer
20,10,60,threshold,bm25_kiwi_pos,,,0.1,0.513968,0.659226,0.530938,0.7,0.4,0.625,0.708333,0.7,1.0
12,10,40,threshold,bm25_kiwi_pos,,,0.1,0.523184,0.659226,0.530754,0.7,0.5,0.625,0.636905,0.7,1.0
28,10,80,threshold,bm25_kiwi_pos,,,0.1,0.478755,0.630952,0.520359,0.7,0.4,0.5625,0.666667,0.65,1.0
4,10,20,threshold,bm25_kiwi_pos,,,0.1,0.507927,0.619048,0.537727,0.65,0.4,0.5625,0.619048,0.7,1.0
5,10,20,threshold,bm25_kiwi_pos,,,0.2,0.49137,0.596726,0.527863,0.65,0.45,0.5625,0.565476,0.7,0.75
21,10,60,threshold,bm25_kiwi_pos,,,0.2,0.386143,0.596726,0.367829,0.65,0.4,0.625,0.636905,0.7,0.5
24,10,80,threshold,bm25_kiwi,,,0.1,0.464203,0.595238,0.519558,0.65,0.35,0.5625,0.666667,0.55,1.0
7,10,20,threshold,bm25_kiwi_pos,,,0.4,0.485863,0.587798,0.51909,0.7,0.4,0.5,0.565476,0.7,0.75
15,10,40,threshold,bm25_kiwi_pos,,,0.4,0.485863,0.587798,0.51909,0.7,0.4,0.5,0.565476,0.7,0.75
23,10,60,threshold,bm25_kiwi_pos,,,0.4,0.485863,0.587798,0.51909,0.7,0.4,0.5,0.565476,0.7,0.75


## 4. Rewrite-Retrieve-Read + Query Decomposition 조합

In [10]:
# Decompose the already rewritten queries, sending `max_concurrency` queries per call.
max_concurrency = 5
input_querys = merged_dataset['query_rewrite'].to_list()

results = []
# A stepped range also yields the final, shorter batch, so no separate
# remainder call is needed and the progress bar counts every batch.
for start in tqdm(range(0, len(input_querys), max_concurrency)):
    batch = input_querys[start:start + max_concurrency]
    results.extend(decomposition_chain(batch, max_concurrency=max_concurrency))

# Store the sub-questions next to their source rows and persist the dataset.
merged_dataset['query_rewrite_decomposition'] = results
merged_dataset.to_csv('../data/rag/first_retrieve.csv', index=False)

  llm = ChatOpenAI(model=model_name, temperature=0.2, api_key=current_api_key)
Failed to multipart ingest runs: langsmith.utils.LangSmithRateLimitError: Rate limit exceeded for https://api.smith.langchain.com/runs/multipart. HTTPError('429 Client Error: Too Many Requests for url: https://api.smith.langchain.com/runs/multipart', '{"error":"Too many requests: tenant exceeded usage limits: Monthly unique traces usage limit exceeded"}\n')trace=1d84f83f-c5f5-44d2-b9d5-bebafac50c26,id=1d84f83f-c5f5-44d2-b9d5-bebafac50c26; trace=1d84f83f-c5f5-44d2-b9d5-bebafac50c26,id=c531e607-282b-4c56-a987-628994db54f8; trace=1d84f83f-c5f5-44d2-b9d5-bebafac50c26,id=9fe630db-329f-4751-baf3-0c7eb943f86f; trace=126b09c2-c076-4c01-a559-4305c4e3d223,id=126b09c2-c076-4c01-a559-4305c4e3d223; trace=126b09c2-c076-4c01-a559-4305c4e3d223,id=ac6ee521-e746-4699-b7b5-0ae0470e6714; trace=126b09c2-c076-4c01-a559-4305c4e3d223,id=f371fb68-054b-45c4-a417-05d7c06268c8; trace=491711ad-e393-401c-a577-495083a68c6f,id=491711ad-e39

Failed to send compressed multipart ingest: langsmith.utils.LangSmithRateLimitError: Rate limit exceeded for https://api.smith.langchain.com/runs/multipart. HTTPError('429 Client Error: Too Many Requests for url: https://api.smith.langchain.com/runs/multipart', '{"error":"Too many requests: tenant exceeded usage limits: Monthly unique traces usage limit exceeded"}\n')trace=c584f720-adbe-4985-9298-a222679bd393,id=733c0017-585d-42c8-a366-af196bab8a10; trace=c584f720-adbe-4985-9298-a222679bd393,id=c584f720-adbe-4985-9298-a222679bd393
Failed to send compressed multipart ingest: langsmith.utils.LangSmithRateLimitError: Rate limit exceeded for https://api.smith.langchain.com/runs/multipart. HTTPError('429 Client Error: Too Many Requests for url: https://api.smith.langchain.com/runs/multipart', '{"error":"Too many requests: tenant exceeded usage limits: Monthly unique traces usage limit exceeded"}\n')trace=b20c5b28-6e08-4d47-870f-26549fe0ab92,id=b20c5b28-6e08-4d47-870f-26549fe0ab92; trace=2d6

In [11]:
# Work on a copy holding the first four columns plus the rewrite+decomposition
# queries, cache retrieval results for each row, then score every configuration.
rewrite_decompo_df = merged_dataset.copy()[
    [*merged_dataset.columns[:4], 'query_rewrite_decomposition']
]
rewrite_decompo_df = precompute_list('query_rewrite_decomposition', rewrite_decompo_df, db, kiwi_pos)
rewrite_decompo_results = optimization(configs, rewrite_decompo_df)

100%|██████████| 56/56 [01:40<00:00,  1.79s/it]
100%|██████████| 56/56 [00:00<00:00, 248.20it/s]
100%|██████████| 56/56 [00:00<00:00, 355.77it/s]
100%|██████████| 32/32 [00:00<00:00, 217.15it/s]


In [12]:
# Ten best configurations: highest recall first, ties broken by MAP, then nDCG.
(
    rewrite_decompo_results
    .sort_values(by=['recall', 'map', 'ndcg'], ascending=False)
    .head(10)
)

Unnamed: 0,k,alpha,dense_type,morphological_analyzer,fetch_k,lambda_mult,score_threshold,ndcg,recall,map,recall_body part,recall_exercise name,"recall_exercise name, body part",recall_multi_hop_abstract_query_synthesizer,recall_multi_hop_specific_query_synthesizer,recall_single_hop_specifc_query_synthesizer
20,10,60,threshold,bm25_kiwi_pos,,,0.1,0.469599,0.645833,0.481548,0.6,0.6,0.5,0.654762,0.8,0.75
21,10,60,threshold,bm25_kiwi_pos,,,0.2,0.459447,0.645833,0.462415,0.6,0.6,0.5,0.654762,0.8,0.75
12,10,40,threshold,bm25_kiwi_pos,,,0.1,0.48939,0.63244,0.516454,0.55,0.6,0.5,0.672619,0.75,0.75
28,10,80,threshold,bm25_kiwi_pos,,,0.1,0.430422,0.630952,0.433585,0.65,0.5,0.375,0.702381,0.8,0.75
13,10,40,threshold,bm25_kiwi_pos,,,0.2,0.485458,0.627976,0.51318,0.55,0.6,0.5,0.654762,0.75,0.75
24,10,80,threshold,bm25_kiwi,,,0.1,0.398258,0.626488,0.382731,0.65,0.55,0.375,0.684524,0.75,0.75
29,10,80,threshold,bm25_kiwi_pos,,,0.2,0.420433,0.622024,0.423739,0.65,0.5,0.375,0.666667,0.8,0.75
25,10,80,threshold,bm25_kiwi,,,0.2,0.394374,0.622024,0.379854,0.65,0.55,0.375,0.666667,0.75,0.75
4,10,20,threshold,bm25_kiwi_pos,,,0.1,0.484467,0.614583,0.519211,0.5,0.55,0.5,0.672619,0.75,0.75
5,10,20,threshold,bm25_kiwi_pos,,,0.2,0.484317,0.614583,0.518892,0.5,0.55,0.5,0.672619,0.75,0.75


## 5. Query Expansion

In [13]:
@handle_rate_limits
def queryExpansion_chain(data_batches, model_name='gpt-4o-mini', current_api_key=None, max_concurrency=5) -> List:
    """Ask the chat model for expanded variants of each input query.

    Loads the query-expansion prompt, fills its ``format`` slot with the
    parser's format instructions, runs ``prompt | llm`` over ``data_batches``
    in one ``batch`` call and parses each reply.

    Args:
        data_batches: Inputs handed to the chain's ``batch`` call.
        model_name: Model identifier passed to ``ChatOpenAI``.
        current_api_key: API key passed to ``ChatOpenAI`` (presumably injected
            by ``handle_rate_limits`` -- TODO confirm).
        max_concurrency: Value of ``max_concurrency`` in the batch config.

    Returns:
        One list of expanded query strings per reply returned by ``batch``.
    """
    class QueryExpansionOutput(BaseModel):
        expansioned_queries: List[str] = Field(..., description="List of expansioned queries in Korean")

    output_parser = PydanticOutputParser(pydantic_object=QueryExpansionOutput)
    template = load_prompt('../prompt/transform/query_expandsion.yaml').partial(
        format=output_parser.get_format_instructions()
    )
    chat_model = ChatOpenAI(model=model_name, temperature=0.2, api_key=current_api_key)

    # Parsing is done after the batch call rather than as a chain step.
    replies = (template | chat_model).batch(data_batches, config={"max_concurrency": max_concurrency})
    return [output_parser.parse(reply.content).expansioned_queries for reply in replies]

In [14]:
# Expand every user query, sending `max_concurrency` queries per call.
max_concurrency = 5
input_querys = merged_dataset['user_input'].to_list()

results = []
# A stepped range also yields the final, shorter batch, so no separate
# remainder call is needed and the progress bar counts every batch.
for start in tqdm(range(0, len(input_querys), max_concurrency)):
    batch = input_querys[start:start + max_concurrency]
    results.extend(queryExpansion_chain(batch, max_concurrency=max_concurrency))

# Store the expanded queries next to their source rows and persist the dataset.
merged_dataset['query_expansion'] = results
merged_dataset.to_csv('../data/rag/first_retrieve.csv', index=False)

100%|██████████| 11/11 [00:32<00:00,  2.96s/it]


In [15]:
# Work on a copy holding the first four columns plus the expanded queries,
# cache retrieval results for each row, then score every configuration.
queryExpansion_df = merged_dataset.copy()[[*merged_dataset.columns[:4], 'query_expansion']]
queryExpansion_df = precompute_list('query_expansion', queryExpansion_df, db, kiwi_pos)
queryExpansion_results = optimization(configs, queryExpansion_df)

100%|██████████| 56/56 [01:53<00:00,  2.02s/it]
100%|██████████| 56/56 [00:00<00:00, 218.39it/s]
100%|██████████| 56/56 [00:00<00:00, 327.87it/s]
100%|██████████| 32/32 [00:00<00:00, 223.76it/s]


In [16]:
# Ten best configurations: highest recall first, ties broken by MAP, then nDCG.
(
    queryExpansion_results
    .sort_values(by=['recall', 'map', 'ndcg'], ascending=False)
    .head(10)
)

Unnamed: 0,k,alpha,dense_type,morphological_analyzer,fetch_k,lambda_mult,score_threshold,ndcg,recall,map,recall_body part,recall_exercise name,"recall_exercise name, body part",recall_multi_hop_abstract_query_synthesizer,recall_multi_hop_specific_query_synthesizer,recall_single_hop_specifc_query_synthesizer
21,10,60,threshold,bm25_kiwi_pos,,,0.2,0.429664,0.552083,0.456037,0.6,0.4,0.5625,0.529762,0.6,0.75
20,10,60,threshold,bm25_kiwi_pos,,,0.1,0.423226,0.547619,0.452664,0.6,0.4,0.5625,0.511905,0.6,0.75
29,10,80,threshold,bm25_kiwi_pos,,,0.2,0.398926,0.543155,0.408121,0.6,0.35,0.5625,0.529762,0.6,0.75
13,10,40,threshold,bm25_kiwi_pos,,,0.2,0.423826,0.540179,0.448795,0.55,0.4,0.4375,0.517857,0.7,0.75
22,10,60,threshold,bm25_kiwi_pos,,,0.3,0.397323,0.540179,0.407051,0.55,0.45,0.4375,0.553571,0.6,0.75
28,10,80,threshold,bm25_kiwi_pos,,,0.1,0.406656,0.53869,0.424582,0.6,0.35,0.5625,0.511905,0.6,0.75
12,10,40,threshold,bm25_kiwi_pos,,,0.1,0.421811,0.535714,0.449986,0.55,0.4,0.4375,0.5,0.7,0.75
14,10,40,threshold,bm25_kiwi_pos,,,0.3,0.42076,0.53125,0.447534,0.5,0.45,0.4375,0.553571,0.6,0.75
24,10,80,threshold,bm25_kiwi,,,0.1,0.350387,0.525298,0.347548,0.65,0.3,0.5625,0.422619,0.65,0.75
4,10,20,threshold,bm25_kiwi_pos,,,0.1,0.413917,0.522321,0.437961,0.55,0.45,0.375,0.517857,0.6,0.75
