# Modular Rag를 통한 First-Stage Retrieve 단계 성능 개선


In [None]:
import os
import pandas as pd
from tqdm import tqdm
tqdm.pandas()

from typing import List
from pydantic import BaseModel, Field

import sys

  from .autonotebook import tqdm as notebook_tqdm


True

In [None]:
sys.path.append('../code/graphParser')
sys.path.append('../code/ragas_custom')
from rateLimit import handle_rate_limits
from retrieve.sparse import BM25
from retrieve.config import generate_retriever_configs
from evaluation.retrieve import optimization


from langchain_core.prompts import load_prompt
from langchain.prompts import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser, PydanticOutputParser
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings

from ragas.testset.graph import KnowledgeGraph


from dotenv import load_dotenv
load_dotenv()

In [2]:
origin_path = '../data/document/역도/generated_persona_dataset.csv'
custom_path = '../data/document/역도/customDataset.csv'

origin_dataset = pd.read_csv(origin_path)
custom_dataset = pd.read_csv(custom_path)

origin_dataset['reference_contexts'] = origin_dataset['reference_contexts'].apply(lambda x : eval(x))
custom_dataset['reference_contexts'] = custom_dataset['reference_contexts'].apply(lambda x : eval(x))

merged_dataset = pd.concat([origin_dataset, custom_dataset], ignore_index=True)
merged_dataset = merged_dataset.iloc[:, :4]

In [3]:
from langchain_core.documents import Document

kg = KnowledgeGraph.load('../data/document/역도/kg.json')

documents = [Document(page_content=node.properties['page_content'],
                      metadata=node.properties['document_metadata'])
                       for node in kg.nodes]

## 1. Prompt Transform
* 아래의 5 가지 기법을 활용하여 input query를 변경 후, first-stage retrieve의 성능 개선 도모
    * prompt decomposition
    * query rewrite
    * query rewrite + prompt decomposition
    * query2doc
    * query expansion

### 결과
* 실험 결과 transform을 통해 특정 문제 유형에 대한 성능 향상이 가능하지만, 그 외 유형에 부정적인 결과를 초래
* 결과적으로 전체 유형에 대한 평가 지표가 original query에 대한 최적화 성능보다 떨어짐

In [4]:
configs = generate_retriever_configs(
    k_values=[15], 
    analyzers=["bm25_kiwi"],
    hybrid_alphas=[20, 40, 60, 80],
    fetch_k=[1.5, 2, 2.5],
    lambda_mult=[],
    score_threshold=[0.3, 0.5, 0.7]
)
configs = configs[4:]

In [5]:
kiwi = BM25(k=15, type='kiwi')
texts = [node.properties['page_content'] for node in kg.nodes]
kiwi.from_texts(texts)

embeddings = OpenAIEmbeddings()
db = FAISS.from_documents(documents, embeddings)

In [6]:
def retrieve_for_list(query_list, retriever, k):
    results = []
    base_k = k // len(query_list)
    remainder = k % len(query_list)
    
    for i, query in enumerate(query_list):
        if i == len(query_list) - 1 and remainder > 0:
            curr_k = base_k + remainder
        else:
            curr_k = base_k
            
        if hasattr(retriever, 'similarity_search_with_score'):
            search_results = retriever.similarity_search_with_score(query, k=curr_k)
        else:
            search_results = retriever.search(query)[:curr_k]
            
        results.extend(search_results)
        
    return results

def precompute_list(column, df, db, kiwi, k=15):        
    df['precompute_dense'] = df[column].progress_apply(lambda x: retrieve_for_list(x, db, k))
    df['precompute_sparse_bm25_kiwi'] = df[column].progress_apply(lambda x: retrieve_for_list(x, kiwi, k))
    
    return df

In [7]:
merged_dataset = pd.read_csv('../data/document/역도/first-retrieve.csv')
for column in ['reference_contexts', 'query_decomposition', 'query_rewrite_decomposition']:
    merged_dataset[column] = merged_dataset[column].apply(lambda x : eval(x))

### 1-1. Query Decomposition

In [7]:
@handle_rate_limits
def decomposition_chain(data_batches, model_name='gpt-4o-mini', current_api_key=None, max_concurrency=5) -> List:
    class DecompositionOutput(BaseModel):
        sub_questions: List[str] = Field(..., description="List of decomposed sub-questions in Korean")

    parser = PydanticOutputParser(pydantic_object=DecompositionOutput)
    prompt = load_prompt('../prompt/firtst-stage retrieve/decomposition.yaml')
    prompt = prompt.partial(format=parser.get_format_instructions())

    
    llm = ChatOpenAI(model=model_name, temperature=0.2, api_key=current_api_key)
    decomposition_chain = prompt | llm 

    results = decomposition_chain.batch(data_batches, config={"max_concurrency": max_concurrency})
    results = [parser.parse(result.content).sub_questions for result in results]

    return results

In [19]:
max_concurrency = 5
input_querys = merged_dataset['user_input'].to_list()

results = []
for i in tqdm(range(len(input_querys) // max_concurrency)):
    # results.extend(decomposition_chain(input_querys[i * max_concurrency: (i + 1) * max_concurrency], max_concurrency=max_concurrency))

if len(input_querys) % max_concurrency != 0:
    # results.extend(decomposition_chain(input_querys[len(input_querys) // max_concurrency * max_concurrency:], max_concurrency=max_concurrency))
    
# merged_dataset['query_decomposition'] = results
# merged_dataset.to_csv('../data/document/역도/first-retrieve.csv', index=False)

100%|██████████| 22/22 [02:29<00:00,  6.79s/it]


In [9]:
decompo_df = merged_dataset.copy()
decompo_df = decompo_df[decompo_df.columns[:4].to_list() + ['query_decomposition']]
decompo_df = precompute_list('query_decomposition', decompo_df, db, kiwi)
decompo_results = optimization(configs, decompo_df)

100%|██████████| 112/112 [02:05<00:00,  1.12s/it]
100%|██████████| 112/112 [00:00<00:00, 259.28it/s]
100%|██████████| 12/12 [00:00<00:00, 110.52it/s]


In [10]:
decompo_results.sort_values(by=['recall', 'map', 'ndcg'], ascending=False)

Unnamed: 0,k,alpha,dense_type,morphological_analyzer,fetch_k,lambda_mult,score_threshold,ndcg,recall,map,recall_body part,recall_exercise name,"recall_exercise name, body part","recall_exercise name, exercise phase",recall_multi_hop_abstract_query_synthesizer,recall_multi_hop_specific_query_synthesizer,recall_single_hop_specifc_query_synthesizer
0,15,20,threshold,bm25_kiwi,,,0.3,0.334008,0.556335,0.28551,0.461538,0.653846,0.636364,0.32,0.390476,0.933333,0.733333
3,15,40,threshold,bm25_kiwi,,,0.3,0.331565,0.552764,0.283218,0.423077,0.653846,0.636364,0.34,0.370476,0.933333,0.733333
6,15,60,threshold,bm25_kiwi,,,0.3,0.321201,0.550978,0.270761,0.461538,0.653846,0.545455,0.34,0.360476,0.966667,0.733333
9,15,80,threshold,bm25_kiwi,,,0.3,0.298001,0.546514,0.242043,0.461538,0.653846,0.5,0.36,0.360476,0.933333,0.733333
1,15,20,threshold,bm25_kiwi,,,0.5,0.316596,0.513095,0.276785,0.384615,0.576923,0.636364,0.3,0.323333,0.866667,0.733333
2,15,20,threshold,bm25_kiwi,,,0.7,0.316596,0.513095,0.276785,0.384615,0.576923,0.636364,0.3,0.323333,0.866667,0.733333
4,15,40,threshold,bm25_kiwi,,,0.5,0.316596,0.513095,0.276785,0.384615,0.576923,0.636364,0.3,0.323333,0.866667,0.733333
5,15,40,threshold,bm25_kiwi,,,0.7,0.316596,0.513095,0.276785,0.384615,0.576923,0.636364,0.3,0.323333,0.866667,0.733333
7,15,60,threshold,bm25_kiwi,,,0.5,0.316596,0.513095,0.276785,0.384615,0.576923,0.636364,0.3,0.323333,0.866667,0.733333
8,15,60,threshold,bm25_kiwi,,,0.7,0.316596,0.513095,0.276785,0.384615,0.576923,0.636364,0.3,0.323333,0.866667,0.733333


### 1-2. Rewrite-Retrieve-Read

In [21]:
@handle_rate_limits
def rewrite_chain(data_batches, model_name='gpt-4o-mini', current_api_key=None, max_concurrency=5) -> List:
    prompt = load_prompt('../prompt/firtst-stage retrieve/rewrite_domain.yaml')
    llm = ChatOpenAI(model=model_name, temperature=0.2, api_key=current_api_key)
    rewrite_chain = prompt | llm | StrOutputParser()

    results = rewrite_chain.batch(data_batches, config={"max_concurrency": max_concurrency})

    return results

In [22]:
max_concurrency = 5
input_querys = merged_dataset['user_input'].to_list()

results = []
for i in tqdm(range(len(input_querys) // max_concurrency)):
    # results.extend(rewrite_chain(input_querys[i * max_concurrency: (i + 1) * max_concurrency], max_concurrency=max_concurrency))

if len(input_querys) % max_concurrency != 0:
    # results.extend(rewrite_chain(input_querys[len(input_querys) // max_concurrency * max_concurrency:], max_concurrency=max_concurrency))
    
merged_dataset['query_rewrite'] = results
# merged_dataset.to_csv('../data/document/역도/first-retrieve.csv', index=False)

100%|██████████| 22/22 [01:40<00:00,  4.55s/it]


In [13]:
rewrite_df = merged_dataset.copy()
rewrite_df = rewrite_df[rewrite_df.columns[:4].to_list() + ['query_rewrite']]
rewrite_df['precompute_dense'] = rewrite_df['query_rewrite'].progress_apply(lambda x: db.similarity_search_with_score(x, k=15))
rewrite_df['precompute_sparse_bm25_kiwi'] = rewrite_df['query_rewrite'].apply(lambda x : kiwi.search(x))
rewrite_results = optimization(configs, rewrite_df)

100%|██████████| 112/112 [00:49<00:00,  2.25it/s]
100%|██████████| 12/12 [00:00<00:00, 101.15it/s]


In [15]:
rewrite_results

Unnamed: 0,k,alpha,dense_type,morphological_analyzer,fetch_k,lambda_mult,score_threshold,ndcg,recall,map,recall_body part,recall_exercise name,"recall_exercise name, body part","recall_exercise name, exercise phase",recall_multi_hop_abstract_query_synthesizer,recall_multi_hop_specific_query_synthesizer,recall_single_hop_specifc_query_synthesizer
0,15,20,threshold,bm25_kiwi,,,0.3,0.390375,0.568835,0.371724,0.5,0.576923,0.590909,0.46,0.435476,0.8,0.733333
1,15,20,threshold,bm25_kiwi,,,0.5,0.392721,0.568835,0.37147,0.461538,0.615385,0.590909,0.44,0.435476,0.833333,0.733333
2,15,20,threshold,bm25_kiwi,,,0.7,0.392721,0.568835,0.37147,0.461538,0.615385,0.590909,0.44,0.435476,0.833333,0.733333
3,15,40,threshold,bm25_kiwi,,,0.3,0.383337,0.573299,0.361216,0.5,0.576923,0.590909,0.44,0.435476,0.8,0.8
4,15,40,threshold,bm25_kiwi,,,0.5,0.392721,0.568835,0.37147,0.461538,0.615385,0.590909,0.44,0.435476,0.833333,0.733333
5,15,40,threshold,bm25_kiwi,,,0.7,0.392721,0.568835,0.37147,0.461538,0.615385,0.590909,0.44,0.435476,0.833333,0.733333
6,15,60,threshold,bm25_kiwi,,,0.3,0.340727,0.549702,0.308278,0.5,0.576923,0.545455,0.4,0.428333,0.8,0.733333
7,15,60,threshold,bm25_kiwi,,,0.5,0.392721,0.568835,0.37147,0.461538,0.615385,0.590909,0.44,0.435476,0.833333,0.733333
8,15,60,threshold,bm25_kiwi,,,0.7,0.392721,0.568835,0.37147,0.461538,0.615385,0.590909,0.44,0.435476,0.833333,0.733333
9,15,80,threshold,bm25_kiwi,,,0.3,0.311349,0.53631,0.272995,0.5,0.538462,0.545455,0.38,0.428333,0.766667,0.733333


### 1-3. Query2Doc

In [2]:
@handle_rate_limits
def query2Doc_chain(data_batches, model_name='gpt-4o-mini', current_api_key=None, max_concurrency=5) -> List:
    prompt = load_prompt('../prompt/firtst-stage retrieve/query2Doc.yaml')
    llm = ChatOpenAI(model=model_name, temperature=0.2, api_key=current_api_key)
    query2Doc_chain = prompt | llm | StrOutputParser()

    results = query2Doc_chain.batch(data_batches, config={"max_concurrency": max_concurrency})

    return results

In [5]:
max_concurrency = 5
input_querys = merged_dataset['user_input'].to_list()

results = []
for i in tqdm(range(len(input_querys) // max_concurrency)):
    # results.extend(query2Doc_chain(input_querys[i * max_concurrency: (i + 1) * max_concurrency], max_concurrency=max_concurrency))

if len(input_querys) % max_concurrency != 0:
    # results.extend(query2Doc_chain(input_querys[len(input_querys) // max_concurrency * max_concurrency:], max_concurrency=max_concurrency))
    
merged_dataset['query2Doc'] = results
# merged_dataset.to_csv('../data/document/역도/first-retrieve.csv', index=False)

  llm = ChatOpenAI(model=model_name, temperature=0.2, api_key=current_api_key)
100%|██████████| 22/22 [02:29<00:00,  6.82s/it]


In [16]:
query2Doc_df = merged_dataset.copy()
query2Doc_df = query2Doc_df[query2Doc_df.columns[:4].to_list() + ['query2Doc']]
query2Doc_df['precompute_dense'] = query2Doc_df['query2Doc'].progress_apply(lambda x: db.similarity_search_with_score(x, k=15))
query2Doc_df['precompute_sparse_bm25_kiwi'] = query2Doc_df['query2Doc'].apply(lambda x : kiwi.search(x))
query2Doc_results = optimization(configs, query2Doc_df)

100%|██████████| 112/112 [00:43<00:00,  2.58it/s]
100%|██████████| 12/12 [00:00<00:00, 101.15it/s]


In [17]:
query2Doc_results

Unnamed: 0,k,alpha,dense_type,morphological_analyzer,fetch_k,lambda_mult,score_threshold,ndcg,recall,map,recall_body part,recall_exercise name,"recall_exercise name, body part","recall_exercise name, exercise phase",recall_multi_hop_abstract_query_synthesizer,recall_multi_hop_specific_query_synthesizer,recall_single_hop_specifc_query_synthesizer
0,15,20,threshold,bm25_kiwi,,,0.3,0.311993,0.45744,0.3019,0.307692,0.461538,0.454545,0.32,0.236667,0.766667,0.8
1,15,20,threshold,bm25_kiwi,,,0.5,0.310035,0.452976,0.301581,0.307692,0.461538,0.409091,0.32,0.236667,0.766667,0.8
2,15,20,threshold,bm25_kiwi,,,0.7,0.310035,0.452976,0.301581,0.307692,0.461538,0.409091,0.32,0.236667,0.766667,0.8
3,15,40,threshold,bm25_kiwi,,,0.3,0.310645,0.45744,0.2987,0.307692,0.461538,0.454545,0.32,0.236667,0.766667,0.8
4,15,40,threshold,bm25_kiwi,,,0.5,0.310035,0.452976,0.301581,0.307692,0.461538,0.409091,0.32,0.236667,0.766667,0.8
5,15,40,threshold,bm25_kiwi,,,0.7,0.310035,0.452976,0.301581,0.307692,0.461538,0.409091,0.32,0.236667,0.766667,0.8
6,15,60,threshold,bm25_kiwi,,,0.3,0.295997,0.45744,0.281141,0.307692,0.461538,0.454545,0.34,0.236667,0.733333,0.8
7,15,60,threshold,bm25_kiwi,,,0.5,0.310035,0.452976,0.301581,0.307692,0.461538,0.409091,0.32,0.236667,0.766667,0.8
8,15,60,threshold,bm25_kiwi,,,0.7,0.310035,0.452976,0.301581,0.307692,0.461538,0.409091,0.32,0.236667,0.766667,0.8
9,15,80,threshold,bm25_kiwi,,,0.3,0.285528,0.45744,0.267606,0.307692,0.461538,0.454545,0.34,0.236667,0.733333,0.8


### 1-4. Rewrite Retrieve Read + Query Decomposition 조합

In [8]:
max_concurrency = 5
input_querys = merged_dataset['query_rewrite'].to_list()

results = []
for i in tqdm(range(len(input_querys) // max_concurrency)):
    # results.extend(decomposition_chain(input_querys[i * max_concurrency: (i + 1) * max_concurrency], max_concurrency=max_concurrency))

if len(input_querys) % max_concurrency != 0:
    # results.extend(decomposition_chain(input_querys[len(input_querys) // max_concurrency * max_concurrency:], max_concurrency=max_concurrency))
    
merged_dataset['query_rewrite_decomposition'] = results
# merged_dataset.to_csv('../data/document/역도/first-retrieve.csv', index=False)

100%|██████████| 22/22 [06:51<00:00, 18.72s/it]


In [11]:
rewrite_decompo_df = merged_dataset.copy()
rewrite_decompo_df = rewrite_decompo_df[rewrite_decompo_df.columns[:4].to_list() + ['query_rewrite_decomposition']]
rewrite_decompo_df = precompute_list('query_rewrite_decomposition', rewrite_decompo_df, db, kiwi)
rewrite_decompo_results = optimization(configs, rewrite_decompo_df)

100%|██████████| 112/112 [02:16<00:00,  1.22s/it]
100%|██████████| 112/112 [00:00<00:00, 225.81it/s]
100%|██████████| 12/12 [00:00<00:00, 106.41it/s]


In [12]:
decompo_results.sort_values(by=['recall', 'map', 'ndcg'], ascending=False)

Unnamed: 0,k,alpha,dense_type,morphological_analyzer,fetch_k,lambda_mult,score_threshold,ndcg,recall,map,recall_body part,recall_exercise name,"recall_exercise name, body part","recall_exercise name, exercise phase",recall_multi_hop_abstract_query_synthesizer,recall_multi_hop_specific_query_synthesizer,recall_single_hop_specifc_query_synthesizer
0,15,20,threshold,bm25_kiwi,,,0.3,0.334008,0.556335,0.28551,0.461538,0.653846,0.636364,0.32,0.390476,0.933333,0.733333
3,15,40,threshold,bm25_kiwi,,,0.3,0.331565,0.552764,0.283218,0.423077,0.653846,0.636364,0.34,0.370476,0.933333,0.733333
6,15,60,threshold,bm25_kiwi,,,0.3,0.321201,0.550978,0.270761,0.461538,0.653846,0.545455,0.34,0.360476,0.966667,0.733333
9,15,80,threshold,bm25_kiwi,,,0.3,0.298001,0.546514,0.242043,0.461538,0.653846,0.5,0.36,0.360476,0.933333,0.733333
1,15,20,threshold,bm25_kiwi,,,0.5,0.316596,0.513095,0.276785,0.384615,0.576923,0.636364,0.3,0.323333,0.866667,0.733333
2,15,20,threshold,bm25_kiwi,,,0.7,0.316596,0.513095,0.276785,0.384615,0.576923,0.636364,0.3,0.323333,0.866667,0.733333
4,15,40,threshold,bm25_kiwi,,,0.5,0.316596,0.513095,0.276785,0.384615,0.576923,0.636364,0.3,0.323333,0.866667,0.733333
5,15,40,threshold,bm25_kiwi,,,0.7,0.316596,0.513095,0.276785,0.384615,0.576923,0.636364,0.3,0.323333,0.866667,0.733333
7,15,60,threshold,bm25_kiwi,,,0.5,0.316596,0.513095,0.276785,0.384615,0.576923,0.636364,0.3,0.323333,0.866667,0.733333
8,15,60,threshold,bm25_kiwi,,,0.7,0.316596,0.513095,0.276785,0.384615,0.576923,0.636364,0.3,0.323333,0.866667,0.733333


### 1-5. Query Expansion

In [13]:
@handle_rate_limits
def queryExpansion_chain(data_batches, model_name='gpt-4o-mini', current_api_key=None, max_concurrency=5) -> List:
    class QueryExpansionOutput(BaseModel):
        expansioned_queries: List[str] = Field(..., description="List of expansioned queries in Korean")
    
    parser = PydanticOutputParser(pydantic_object=QueryExpansionOutput)
    prompt = load_prompt('../prompt/firtst-stage retrieve/query_expandsion.yaml')
    prompt = prompt.partial(format=parser.get_format_instructions())

    llm = ChatOpenAI(model=model_name, temperature=0.2, api_key=current_api_key)
    queryExpansion_chain = prompt | llm 

    results = queryExpansion_chain.batch(data_batches, config={"max_concurrency": max_concurrency})
    results = [parser.parse(result.content).expansioned_queries for result in results]

    return results

In [14]:
max_concurrency = 5
input_querys = merged_dataset['user_input'].to_list()

results = []
for i in tqdm(range(len(input_querys) // max_concurrency)):
    results.extend(queryExpansion_chain(input_querys[i * max_concurrency: (i + 1) * max_concurrency], max_concurrency=max_concurrency))

if len(input_querys) % max_concurrency != 0:
    results.extend(queryExpansion_chain(input_querys[len(input_querys) // max_concurrency * max_concurrency:], max_concurrency=max_concurrency))
    
merged_dataset['query_expansion'] = results
merged_dataset.to_csv('../data/document/역도/first-retrieve.csv', index=False)

100%|██████████| 22/22 [00:54<00:00,  2.47s/it]


In [15]:
queryExpansion_df = merged_dataset.copy()
queryExpansion_df = queryExpansion_df[queryExpansion_df.columns[:4].to_list() + ['query_expansion']]
queryExpansion_df = precompute_list('query_expansion', queryExpansion_df, db, kiwi)
queryExpansion_results = optimization(configs, queryExpansion_df)

100%|██████████| 112/112 [02:27<00:00,  1.32s/it]
100%|██████████| 112/112 [00:00<00:00, 221.19it/s]
100%|██████████| 12/12 [00:00<00:00, 104.55it/s]


In [16]:
queryExpansion_results

Unnamed: 0,k,alpha,dense_type,morphological_analyzer,fetch_k,lambda_mult,score_threshold,ndcg,recall,map,recall_body part,recall_exercise name,"recall_exercise name, body part","recall_exercise name, exercise phase",recall_multi_hop_abstract_query_synthesizer,recall_multi_hop_specific_query_synthesizer,recall_single_hop_specifc_query_synthesizer
0,15,20,threshold,bm25_kiwi,,,0.3,0.303868,0.500595,0.267352,0.5,0.538462,0.727273,0.34,0.228333,0.766667,0.666667
1,15,20,threshold,bm25_kiwi,,,0.5,0.283226,0.447024,0.260202,0.423077,0.461538,0.681818,0.26,0.178333,0.733333,0.666667
2,15,20,threshold,bm25_kiwi,,,0.7,0.283226,0.447024,0.260202,0.423077,0.461538,0.681818,0.26,0.178333,0.733333,0.666667
3,15,40,threshold,bm25_kiwi,,,0.3,0.308393,0.500595,0.272353,0.5,0.538462,0.727273,0.34,0.228333,0.766667,0.666667
4,15,40,threshold,bm25_kiwi,,,0.5,0.283226,0.447024,0.260202,0.423077,0.461538,0.681818,0.26,0.178333,0.733333,0.666667
5,15,40,threshold,bm25_kiwi,,,0.7,0.283226,0.447024,0.260202,0.423077,0.461538,0.681818,0.26,0.178333,0.733333,0.666667
6,15,60,threshold,bm25_kiwi,,,0.3,0.293074,0.496131,0.255321,0.5,0.538462,0.727273,0.32,0.228333,0.766667,0.666667
7,15,60,threshold,bm25_kiwi,,,0.5,0.283226,0.447024,0.260202,0.423077,0.461538,0.681818,0.26,0.178333,0.733333,0.666667
8,15,60,threshold,bm25_kiwi,,,0.7,0.283226,0.447024,0.260202,0.423077,0.461538,0.681818,0.26,0.178333,0.733333,0.666667
9,15,80,threshold,bm25_kiwi,,,0.3,0.295951,0.500595,0.258801,0.5,0.538462,0.727273,0.34,0.228333,0.766667,0.666667


## 2. Adaptive Pattern