In [1]:
import sys
sys.path.append("/data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src")

import os
import time
import json
import random
import warnings
import anthropic
import threading
import huggingface_hub

from tqdm import tqdm
from openai import OpenAI
from itertools import product
from scipy.spatial.distance import cosine
from concurrent.futures import ThreadPoolExecutor, as_completed

from langchain.retrievers import EnsembleRetriever
from langchain_text_splitters import RecursiveCharacterTextSplitter

os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore", category=FutureWarning)

from dotenv import load_dotenv
load_dotenv("../keys.env")

upstage_api_key = os.getenv("UPSTAGE_API_KEY")
os.environ['UPSTAGE_API_KEY'] = upstage_api_key

openai_api_key = os.getenv('OPENAI_API_KEY')
os.environ['OPENAI_API_KEY'] = openai_api_key

anthropic_api_key = os.getenv('ANTHROPIC_API_KEY')
os.environ['ANTHROPIC_API_KEY'] = anthropic_api_key

hf_token = os.getenv("HF_TOKEN")
huggingface_hub.login(hf_token)

from config import Args
from data.data import load_document, load_query, chunking

from sparse_retriever.model import load_sparse_model
from dense_retriever.model import load_dense_model, load_hf_encoder, load_upstage_encoder, load_openai_encoder, load_voyage_encoder

  from .autonotebook import tqdm as notebook_tqdm


The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to /data/ephemeral/home/.cache/huggingface/token
Login successful


In [2]:
args = Args()

documents_data_path = "../dataset/gpt_contextual_retrieval_documents_v3.jsonl"
questions_data_path = "../dataset/processed_documents_queries.jsonl"

documents = load_document(documents_data_path)
questions = load_query(questions_data_path)

random.shuffle(questions)
questions = questions[:100]

In [3]:
retriever_weights_list = [[0.5, 0.5], [0.4, 0.6], [0.3, 0.7], [0.2, 0.8], [0.1, 0.9]]
ensemble_weights_list = [[0.3, 0.3, 0.4], [0.25, 0.25, 5], [0.2, 0.2, 0.6], [0.15, 0.15, 0.7], [0.1, 0.1, 0.8], [0.05, 0.05, 0.9], [0, 0, 1]]

In [4]:
def calc_map(gt, pred):    
    sum_average_precision = 0    
    for j in pred:        
        if gt[j["eval_id"]]:            
            hit_count = 0            
            sum_precision = 0            
            for i,docid in enumerate(j["topk"][:3]):                
                if docid in gt[j["eval_id"]]:                    
                    hit_count += 1                    
                    sum_precision += hit_count/(i+1)            
            average_precision = sum_precision / hit_count if hit_count > 0 else 0        
        else:            
            average_precision = 0 if j["topk"] else 1        
        sum_average_precision += average_precision    
    return sum_average_precision/len(pred)

In [5]:
# folder_path = f"./index_files/{args.encoder_method}"
# os.makedirs(folder_path, exist_ok=True)

In [6]:
best_map = 0
best_params = {}

args.faiss_index_file = f"/data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3"
args.ensemble_models = [
    {'type' : 'hf', 'name' : "BAAI/bge-m3"},
    {'type': 'hf', 'name': "dragonkue/bge-m3-ko"},
    {'type': 'upstage', 'name': "solar-embedding-1-large-query"},
]

ensemble_encoders = []
for model_info in args.ensemble_models:
    model_type = model_info.get('type', 'hf')
    model_name = model_info['name']
    if model_type == 'hf':
        encoder = load_hf_encoder(model_name, args.model_kwargs, args.encode_kwargs)
    elif model_type == 'upstage':
        encoder = load_upstage_encoder(model_name)
    elif model_type == 'voyage':
        encoder = load_voyage_encoder(model_name)
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    ensemble_encoders.append(encoder)

# 각 파라미터 조합에 대해 검색 및 평가 수행
for retriever_weights, ensemble_weights in product(retriever_weights_list, ensemble_weights_list):
    # 문서 chunking
    # chunk_documents = chunking(args, documents)

    # sparse retriever 로드
    sparse_retriever = load_sparse_model(documents, args.src_lang)
    sparse_retriever.k = 10

    # dense retriever 로드
    dense_retriever = load_dense_model(args, documents).as_retriever(search_kwargs={"k": 10})

    # 앙상블 retriever 설정
    retriever = EnsembleRetriever(
        retrievers=[sparse_retriever, dense_retriever],
        weights=args.retriever_weights,
        search_type="mrr"
    )

    # 정답 레이블 생성
    gt = {}
    for question in questions:
        query, question_id = question['query'], question['metadata']['docid']
        gt[question_id] = [question_id]
    
    # 예측 결과 수집
    pred = []
    for question in tqdm(questions):
        query, question_id = question['query'], question['metadata']['docid']

        # 쿼리 임베딩 계산 및 가중치 적용
        query_embeddings = []
        for idx, encoder in enumerate(ensemble_encoders):
            query_embedding = encoder.embed_query(query)
            query_embeddings.append((query_embedding, ensemble_weights[idx]))
        
        # 검색 수행
        search_result = retriever.invoke(query)
        # for result in search_result:
            # print(result)
        
        combined_scores = []
        for doc in search_result:
            combined_similarity = 0
            for idx, (query_embedding, weight) in enumerate(query_embeddings):
                # 문서 임베딩 가져오기
                doc_embedding_key = f'embedding_{args.ensemble_models[idx]["name"]}'
                doc_embedding = doc.metadata.get(doc_embedding_key) or ensemble_encoders[idx].embed_query(doc.page_content)
                
                # 유사도 계산 및 가중치 적용
                similarity = 1 - cosine(query_embedding, doc_embedding)
                combined_similarity += weight * similarity
            combined_scores.append((doc, combined_similarity))
        
        # top-k 결과 수집
        topk_result = [doc.metadata.get('docid') for doc, _ in sorted(combined_scores, key=lambda x: x[1], reverse=True)]
        
        pred.append({
            "eval_id": question_id,
            "topk": topk_result
        })
    
    # MAP 계산
    mean_average_precision = calc_map(gt, pred)
    print(f"Parameters: chunk_size={args.chunk_size}, chunk_overlap={args.chunk_overlap}, retriever_weights={retriever_weights}, ensemble_weights={ensemble_weights}")
    print(f"Mean Average Precision (MAP): {mean_average_precision:.4f}\n")
    
    # 최적의 MAP 값과 파라미터 저장
    if mean_average_precision > best_map:
        best_map = mean_average_precision
        best_params = {
            'chunk_size': args.chunk_size,
            'chunk_overlap': args.chunk_overlap,
            'retriever_weights': retriever_weights,
            'ensemble_weights': ensemble_weights
        }


  return self.fget.__get__(instance, owner)()


FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [16:15<00:00,  9.76s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.5, 0.5], ensemble_weights=[0.3, 0.3, 0.4]
Mean Average Precision (MAP): 0.9158

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [16:18<00:00,  9.79s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.5, 0.5], ensemble_weights=[0.25, 0.25, 5]
Mean Average Precision (MAP): 0.9033

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [16:32<00:00,  9.93s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.5, 0.5], ensemble_weights=[0.2, 0.2, 0.6]
Mean Average Precision (MAP): 0.9175

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [16:23<00:00,  9.84s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.5, 0.5], ensemble_weights=[0.15, 0.15, 0.7]
Mean Average Precision (MAP): 0.9158

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [16:14<00:00,  9.75s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.5, 0.5], ensemble_weights=[0.1, 0.1, 0.8]
Mean Average Precision (MAP): 0.8992

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [16:04<00:00,  9.65s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.5, 0.5], ensemble_weights=[0.05, 0.05, 0.9]
Mean Average Precision (MAP): 0.9017

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [16:13<00:00,  9.73s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.5, 0.5], ensemble_weights=[0, 0, 1]
Mean Average Precision (MAP): 0.9025

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [16:28<00:00,  9.88s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.4, 0.6], ensemble_weights=[0.3, 0.3, 0.4]
Mean Average Precision (MAP): 0.9158

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [16:16<00:00,  9.76s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.4, 0.6], ensemble_weights=[0.25, 0.25, 5]
Mean Average Precision (MAP): 0.9033

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [16:04<00:00,  9.64s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.4, 0.6], ensemble_weights=[0.2, 0.2, 0.6]
Mean Average Precision (MAP): 0.9175

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [16:23<00:00,  9.83s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.4, 0.6], ensemble_weights=[0.15, 0.15, 0.7]
Mean Average Precision (MAP): 0.9158

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [16:25<00:00,  9.86s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.4, 0.6], ensemble_weights=[0.1, 0.1, 0.8]
Mean Average Precision (MAP): 0.8992

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [15:58<00:00,  9.59s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.4, 0.6], ensemble_weights=[0.05, 0.05, 0.9]
Mean Average Precision (MAP): 0.9017

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [15:46<00:00,  9.47s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.4, 0.6], ensemble_weights=[0, 0, 1]
Mean Average Precision (MAP): 0.9025

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [17:56<00:00, 10.76s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.3, 0.7], ensemble_weights=[0.3, 0.3, 0.4]
Mean Average Precision (MAP): 0.9158

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [19:38<00:00, 11.78s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.3, 0.7], ensemble_weights=[0.25, 0.25, 5]
Mean Average Precision (MAP): 0.9033

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [19:29<00:00, 11.69s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.3, 0.7], ensemble_weights=[0.2, 0.2, 0.6]
Mean Average Precision (MAP): 0.9175

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [19:33<00:00, 11.73s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.3, 0.7], ensemble_weights=[0.15, 0.15, 0.7]
Mean Average Precision (MAP): 0.9158

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [19:37<00:00, 11.78s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.3, 0.7], ensemble_weights=[0.1, 0.1, 0.8]
Mean Average Precision (MAP): 0.8992

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [19:36<00:00, 11.77s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.3, 0.7], ensemble_weights=[0.05, 0.05, 0.9]
Mean Average Precision (MAP): 0.9017

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [19:11<00:00, 11.52s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.3, 0.7], ensemble_weights=[0, 0, 1]
Mean Average Precision (MAP): 0.9025

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [19:18<00:00, 11.58s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.2, 0.8], ensemble_weights=[0.3, 0.3, 0.4]
Mean Average Precision (MAP): 0.9158

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [18:17<00:00, 10.97s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.2, 0.8], ensemble_weights=[0.25, 0.25, 5]
Mean Average Precision (MAP): 0.9033

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [18:46<00:00, 11.27s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.2, 0.8], ensemble_weights=[0.2, 0.2, 0.6]
Mean Average Precision (MAP): 0.9175

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [23:43<00:00, 14.24s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.2, 0.8], ensemble_weights=[0.15, 0.15, 0.7]
Mean Average Precision (MAP): 0.9158

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [17:04<00:00, 10.25s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.2, 0.8], ensemble_weights=[0.1, 0.1, 0.8]
Mean Average Precision (MAP): 0.8992

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [17:00<00:00, 10.21s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.2, 0.8], ensemble_weights=[0.05, 0.05, 0.9]
Mean Average Precision (MAP): 0.9017

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [17:06<00:00, 10.26s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.2, 0.8], ensemble_weights=[0, 0, 1]
Mean Average Precision (MAP): 0.9025

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [18:59<00:00, 11.39s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.1, 0.9], ensemble_weights=[0.3, 0.3, 0.4]
Mean Average Precision (MAP): 0.9158

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [17:10<00:00, 10.31s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.1, 0.9], ensemble_weights=[0.25, 0.25, 5]
Mean Average Precision (MAP): 0.9033

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [17:15<00:00, 10.35s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.1, 0.9], ensemble_weights=[0.2, 0.2, 0.6]
Mean Average Precision (MAP): 0.9175

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [18:28<00:00, 11.08s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.1, 0.9], ensemble_weights=[0.15, 0.15, 0.7]
Mean Average Precision (MAP): 0.9158

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [28:33<00:00, 17.14s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.1, 0.9], ensemble_weights=[0.1, 0.1, 0.8]
Mean Average Precision (MAP): 0.8992

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [23:11<00:00, 13.92s/it]


Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.1, 0.9], ensemble_weights=[0.05, 0.05, 0.9]
Mean Average Precision (MAP): 0.9017

FAISS 인덱스 로드 중: /data/ephemeral/home/Upstage_Ai_Lab/Final/IR/src/index_files/upstage/CRV3
FAISS 인덱스 로드 완료, 총 문서 수: 24799


100%|██████████| 100/100 [23:12<00:00, 13.93s/it]

Parameters: chunk_size=100, chunk_overlap=50, retriever_weights=[0.1, 0.9], ensemble_weights=[0, 0, 1]
Mean Average Precision (MAP): 0.9025






In [7]:
print(f"Best Mean Average Precision (MAP): {best_map}")
print(f"Best Parameters:")
print(f"  chunk_size: {best_params['chunk_size']}")
print(f"  chunk_overlap: {best_params['chunk_overlap']}")
print(f"  retriever_weights: {best_params['retriever_weights']}")
print(f"  ensemble_weights: {best_params['ensemble_weights']}")

Best Mean Average Precision (MAP): 0.9175000000000001
Best Parameters:
  chunk_size: 100
  chunk_overlap: 50
  retriever_weights: [0.5, 0.5]
  ensemble_weights: [0.2, 0.2, 0.6]
