# Self-Reflection 통한 First-Stage Retrieve 성능 개선
* 기존의 basic + custom 데이터셋을 통해 최적화한 결과, 목표 성능(recall: 0.8)에 도달
    * k: 20, alpha: 40, morphological_analyzer: bm25_kiki_pos, score_threshold: 0.1
    * 위의 조합에서 recall: 0.815476 달성
* rag 시스템 연구를 위해 k: 10, alpha: 40, morphological_analyzer: bm25_kiki_pos, score_threshold: 0.1 조합을 기준으로 Self-Reflection 관련 실험 진행
* 아래의 5 가지 self-reflection을 대상으로 first-stage retrieve의 성능 개선 실험
    * retrieved contexts + input query 대상
        * Binary Classification
        * Confidence Score
        * CoT + Few-Shot Confidence Score
    * generated answer + input query 대상
        * Binary Classification
        * Confidence Score

### 결과
* retrieved contexts + input query 기반 grader는 generated answer + input query 기반 grader보다 성능이 떨어짐.
    * retrieved contexts는 10개 * 500 토큰 약 5000 토큰 이상의 텍스트를 입력받음
    * 문서 평가 및 종합 과정에서 과도한 정보량으로 인해 정확한 예측이 어려움
* Binary classification과 Confidence Score 비교
    * Confidence Score와 threshold 값을 통해 Binary classification보다 좋은 성능의 로직 구축을 희망했으나 목표 달성 실패
    * Binary Classification이 더 안정적이고 높은 성능(precision: 0.70-0.76, f1-score: 0.73)을 보임
* Self-Reflection을 통한 First-Stage Retrieve 성능 개선은 가능하나, 구현 복잡도와 비용 대비 효과가 제한적

In [27]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
tqdm.pandas()

from typing import List
from pydantic import BaseModel, Field
from sklearn.metrics import classification_report, f1_score, accuracy_score, precision_score, recall_score

In [4]:
import sys
sys.path.append('../code/graphParser')
sys.path.append('../code/ragas_custom')
from rateLimit import handle_rate_limits
from retrieve.sparse import BM25
from retrieve.config import generate_retriever_configs
from evaluation.retrieve import optimization, combine_hybrid_results

from langchain_core.prompts import load_prompt
from langchain.prompts import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser, PydanticOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from langchain_ollama import OllamaLLM

from ragas.testset.graph import KnowledgeGraph

from dotenv import load_dotenv
load_dotenv()

True

In [3]:
from langchain_core.documents import Document

kg = KnowledgeGraph.load('../data/rag/kg.json')

documents = [Document(page_content=node.properties['page_content'],
                      metadata=node.properties['document_metadata'])
                       for node in kg.nodes]

In [4]:
standard_config = {
    'k': 10,
    'alpha': 40,
    'dense_type': 'threshold',
    'dense_params': {'score_threshold': 0.1},
    'morphological_analyzer': 'bm25_kiwi_pos'
}

In [6]:
kiwi_pos = BM25(k=15, type='kiwi_pos')
texts = [node.properties['page_content'] for node in kg.nodes]
kiwi_pos.from_texts(texts)

embeddings = OpenAIEmbeddings()
db = FAISS.from_documents(documents, embeddings)

In [None]:
merged_dataset = pd.read_csv('../data/rag/valid_dataset.csv')
merged_dataset['reference_contexts'] = merged_dataset['reference_contexts'].apply(lambda x : eval(x))

for column in merged_dataset.columns[5:]:
    merged_dataset[column] = merged_dataset[column].apply(lambda x : eval(x))

merged_dataset['precompute_dense'] = merged_dataset['user_input'].apply(lambda x : db.similarity_search_with_score(x, k=int(15)))
merged_dataset['precompute_sparse_bm25_kiwi_pos'] = merged_dataset['user_input'].apply(lambda x : kiwi_pos.search(x))

merged_dataset['precompute_dense'] = merged_dataset['precompute_dense'].apply(lambda results: [doc.page_content for result in results for doc, score in [result] if score >= standard_config['dense_params']['score_threshold']])
merged_dataset['result_retrieve'] = combine_hybrid_results(merged_dataset['precompute_dense'], merged_dataset['precompute_sparse_bm25_kiwi_pos'], standard_config['alpha'], standard_config['k'])
merged_dataset['need_retrieve'] = merged_dataset.apply(lambda x : 'no' if len(set(x['reference_contexts']).intersection(set(x['result_retrieve']))) == len(x['reference_contexts']) else 'yes', axis=1)

merged_dataset = merged_dataset[merged_dataset.columns[:5].to_list() + ['result_retrieve', 'need_retrieve']]
# merged_dataset.to_csv('../data/rag/self_reflection.csv', index=False)

In [16]:
def evaluate_binary_sufficiency(df: pd.DataFrame, column: str):
    df = df.copy()
    df['prediction'] = df[column].apply(lambda x : x[0].lower())

    df['label'] = df['need_retrieve'].str.lower()

    y_true = df['label']
    y_pred = df['prediction']

    metrics = {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, pos_label='yes'),
        "recall": recall_score(y_true, y_pred, pos_label='yes'),
        "f1_score": f1_score(y_true, y_pred, pos_label='yes'),
    }

    report = classification_report(y_true, y_pred, target_names=["no", "yes"])

    return metrics, report

In [28]:
def analyze_confidence_thresholds(df, score_col='score', label_col='need_retrieve', thresholds=None):
    if thresholds is None:
        thresholds = np.arange(0.0, 1.01, 0.1)

    results = []

    df[score_col] = df[score_col].apply(lambda x : x[0])

    for thresh in thresholds:
        df['pred'] = df[score_col] >= thresh

        y_true = df[label_col].map({'yes': 1, 'no': 0})
        y_pred = df['pred'].astype(int)

        precision = precision_score(y_true, y_pred, zero_division=0)
        recall = recall_score(y_true, y_pred, zero_division=0)
        f1 = f1_score(y_true, y_pred, zero_division=0)

        total = len(df)
        predicted_positive = y_pred.sum()
        retrieve_rate = predicted_positive / total

        # 실제 필요했던 retrieve 수
        true_positive = ((y_pred == 1) & (y_true == 1)).sum()
        false_positive = ((y_pred == 1) & (y_true == 0)).sum()
        actual_positive = y_true.sum()

        results.append({
            'threshold': thresh,
            'precision': round(precision, 3),
            'recall': round(recall, 3),
            'f1_score': round(f1, 3),
            'retrieve_count': predicted_positive,
            'retrieve_rate': round(retrieve_rate, 3),
            'true_positive': true_positive,
            'false_positive': false_positive,
            'missed_retrieves': actual_positive - true_positive
        })

    return pd.DataFrame(results)

## 1. Retrieved Contexts 대상 Grader

### 1-1. Binary Classification

In [17]:
# @handle_rate_limits
def binaryClassification_chain(data_batches, model_name='gpt-4o-mini', current_api_key=None, max_concurrency=5) -> List:
    class BinaryClassificationOutput(BaseModel):
        binary_classification: str = Field(..., description="Binary Response, 'Yes' or 'No'")
        reasoning: str = Field(..., description="Maximum 3 sentences explaining the decision")
    
    parser = PydanticOutputParser(pydantic_object=BinaryClassificationOutput)
    prompt = load_prompt('../prompt/reflection/binaryClassification.yaml')
    prompt = prompt.partial(format=parser.get_format_instructions())
    # llm = ChatOpenAI(model=model_name, temperature=0.2, api_key=current_api_key)
    llm = OllamaLLM(model='gemma3:12b')
    
    binaryClassification_chain = prompt | llm

    results = binaryClassification_chain.batch(data_batches, config={"max_concurrency": max_concurrency})
    results = [(parser.parse(result).binary_classification, parser.parse(result).reasoning) for result in results]
    
    return results

In [18]:
max_concurrency = 2
input_querys = merged_dataset['user_input'].to_list()
documents = merged_dataset['result_retrieve'].apply(lambda x : '\n\n'.join([f"[Document {i+1}]\n{doc.strip()}" for i, doc in enumerate(x)])).to_list()
input_data = [{'document': document, 'question': question} for document, question in zip(documents, input_querys)]

results = []
for i in tqdm(range(len(input_data) // max_concurrency)):
    results.extend(binaryClassification_chain(input_data[i * max_concurrency: (i + 1) * max_concurrency], max_concurrency=max_concurrency))

if len(input_data) % max_concurrency != 0:
    results.extend(binaryClassification_chain(input_data[len(input_data) // max_concurrency * max_concurrency:], max_concurrency=max_concurrency))

merged_dataset['binary_classification'] = results
merged_dataset.to_csv('../data/rag/self_reflection.csv', index=False)

100%|██████████| 28/28 [06:30<00:00, 13.94s/it]


In [29]:
merged_dataset.head(2)

Unnamed: 0,user_input,reference_contexts,reference,synthesizer_name,reference_contexts_section,result_retrieve,need_retrieve,binary_classification,confidence_score,generated_answer,generated_need_retrieve,generated_confidence_need_retrieve
0,바벨 잡는 방법과 앉아받기 동작의 올바른 수행을 결합하여 최적의 리프팅 기술을 설명해줘.,"[바벨을 잡는 방법에는 크게 오버그립(over grip), 언더그립(under gr...","바벨 잡는 방법에는 오버그립, 언더그립, 리버스그립, 훅그립이 있으며, 각각의 그립...",multi_hop_abstract_query_synthesizer,"['Ⅲ. 역도경기 기술의 구조와 훈련법', 'Ⅲ. 역도경기 기술의 구조와 훈련법']",['순간적으로 무거운 물체를 들어 올리는 데에는 근력 외에도 강인한 정신력이 요구\...,yes,"(Yes, Combining a proper barbell grip with a c...",0.85,바벨을 잡는 방법과 앉아받기 동작의 올바른 수행을 결합하여 최적의 리프팅 기술을 설...,No,0.7
1,"스포츠 심리학자들은 선수의 성과 향상에 어떤 역할을 하며, 그들의 개인 문제에 대한...",[성공적인 상\n담 진행을 위해서 상담사는 내담자의 감정에 공감할 수 있어야 한다....,스포츠 심리학자들은 운동선수의 성과를 향상시키기 위해 여러 가지 역할을 수행합니다....,multi_hop_abstract_query_synthesizer,"['II. 역도의 스포츠 과학적 원리', 'II. 역도의 스포츠 과학적 원리']",['스포츠심리학을 연구하고 스포츠심리학의 연구 성과를 스포츠 현장에 적용하는\n스포...,no,"(Yes, Sports psychologists play a vital role i...",0.85,"스포츠 심리학자들은 선수의 성과 향상에 중요한 역할을 하며, 선수의 개인 문제에 대...",No,0.6


In [32]:
metrics, report = evaluate_binary_sufficiency(merged_dataset, 'binary_classification')
print(metrics)
print(report)

{'accuracy': 0.4642857142857143, 'precision': 0.5, 'recall': 0.7666666666666667, 'f1_score': 0.6052631578947368}
              precision    recall  f1-score   support

          no       0.30      0.12      0.17        26
         yes       0.50      0.77      0.61        30

    accuracy                           0.46        56
   macro avg       0.40      0.44      0.39        56
weighted avg       0.41      0.46      0.40        56



### 1-2. Confidence Score

In [36]:
# @handle_rate_limits
def confidenceScore_chain(data_batches, model_name='gpt-4o-mini', current_api_key=None, max_concurrency=5) -> List:
    class ConfidenceScoreOutput(BaseModel):
        confidence_score: float = Field(..., description="Confidence Score, 0.00 to 1.00")
        reasoning: str = Field(..., description="Maximum 3 sentences explaining the decision")
    
    parser = PydanticOutputParser(pydantic_object=ConfidenceScoreOutput)
    prompt = load_prompt('../prompt/reflection/confidenceScore.yaml')
    prompt = prompt.partial(format=parser.get_format_instructions())
    # llm = ChatOpenAI(model=model_name, temperature=0.2, api_key=current_api_key)
    llm = OllamaLLM(model='gemma3:12b')
    
    confidenceScore_chain = prompt | llm

    results = confidenceScore_chain.batch(data_batches, config={"max_concurrency": max_concurrency})
    results = [(parser.parse(result).confidence_score, parser.parse(result).reasoning) for result in results]
    
    return results

In [39]:
max_concurrency = 3
input_querys = merged_dataset['user_input'].to_list()
documents = merged_dataset['result_retrieve'].apply(lambda x : '\n\n'.join([f"[Document {i+1}]\n{doc.strip()}" for i, doc in enumerate(x)])).to_list()
input_data = [{'document': document, 'question': question} for document, question in zip(documents, input_querys)]

results = []
for i in tqdm(range(len(input_data) // max_concurrency)):
    results.extend(confidenceScore_chain(input_data[i * max_concurrency: (i + 1) * max_concurrency], max_concurrency=max_concurrency))

if len(input_data) % max_concurrency != 0:
    results.extend(confidenceScore_chain(input_data[len(input_data) // max_concurrency * max_concurrency:], max_concurrency=max_concurrency))

merged_dataset['confidence_score'] = results
merged_dataset.to_csv('../data/rag/self_reflection.csv', index=False)

100%|██████████| 18/18 [06:28<00:00, 21.57s/it]


In [41]:
tmp = merged_dataset.copy()

tmp['confidence_score'] = tmp['confidence_score'].apply(lambda x : x[0])
tmp.groupby(['need_retrieve', 'confidence_score']).size().reset_index(name='count')

Unnamed: 0,need_retrieve,confidence_score,count
0,no,0.7,3
1,no,0.75,3
2,no,0.85,3
3,no,0.9,2
4,no,0.95,15
5,yes,0.7,1
6,yes,0.75,5
7,yes,0.85,2
8,yes,0.95,22


### 1-3. CoT + Few-Shot Confidence Score

In [42]:
# @handle_rate_limits
def gradeDocument_chain(data_batches, model_name='gpt-4o-mini', current_api_key=None, max_concurrency=5) -> List:
    class GradeDocument(BaseModel):
        """Score from 0.00 to 1.00 indicating how necessary it is to re-retrieve documents to answer the question."""
        re_retrieve_score: float = Field(
            description="A score between 0.0 (no re-retrieval needed) and 1.0 (re-retrieval strongly needed)"
        )
        reasoning: str = Field(..., description="Maximum 3 sentences explaining the decision")

    parser = PydanticOutputParser(pydantic_object=GradeDocument)
    prompt = load_prompt('../prompt/reflection/retrieveGrader.yaml')
    prompt = prompt.partial(format=parser.get_format_instructions())

    # llm = ChatOpenAI(model=model_name, temperature=0.2)
    llm = OllamaLLM(model='gemma3:12b')
    gradeDocuments_chain = prompt | llm

    results = gradeDocuments_chain.batch(data_batches, config={"max_concurrency": max_concurrency})
    results = [(parser.parse(result).re_retrieve_score, parser.parse(result).reasoning) for result in results]

    return results

In [None]:
max_concurrency = 3
input_querys = merged_dataset['user_input'].to_list()
documents = merged_dataset['result_retrieve'].apply(lambda x : '\n\n'.join([f"[Document {i+1}]\n{doc.strip()}" for i, doc in enumerate(x)])).to_list()

input_data = [{'document': document, 'question': question} for document, question in zip(documents, input_querys)]

results = []
for i in tqdm(range(len(input_data) // max_concurrency)):
    results.extend(gradeDocument_chain(input_data[i * max_concurrency: (i + 1) * max_concurrency], max_concurrency=max_concurrency))

if len(input_data) % max_concurrency != 0:
    results.extend(gradeDocument_chain(input_data[len(input_data) // max_concurrency * max_concurrency:], max_concurrency=max_concurrency))

result_df = merged_dataset.copy()
result_df = result_df[result_df.columns[:5].to_list() + ['need_retrieve']]
result_df['gradeDocument'] = results

merged_dataset.to_csv('../data/rag/self_reflection.csv', index=False)

In [None]:
analyze_confidence_thresholds(result_df, 'gradeDocument')

In [23]:
result_df.groupby(['need_retrieve', 'gradeDocument'])['user_input'].count().reset_index()                                                          

Unnamed: 0,need_retrieve,gradeDocument,user_input
0,no,0.75,8
1,no,0.8,18
2,yes,0.75,8
3,yes,0.8,19
4,yes,1.0,3


## 2. Generated answer 대상 Grader

In [11]:
merged_dataset = pd.read_csv('../data/rag/self_reflection.csv')
merged_dataset['reference_contexts'] = merged_dataset['reference_contexts'].apply(lambda x : eval(x))

In [12]:
@handle_rate_limits
def generate_chain(data_batches, model_name='gpt-4o-mini', current_api_key=None, max_concurrency=5) -> List:
    prompt = PromptTemplate.from_template(
        """You are an assistant for question-answering tasks. 
    Use the following pieces of retrieved context to answer the question. 
    If you don't know the answer, just say that you don't know. 
    Answer in Korean.

    #Question: 
    {question} 
    #Context: 
    {context} 

    #Answer:"""
    )

    # llm = OllamaLLM(model='gemma3:12b')
    llm = ChatOpenAI(model=model_name, temperature=0.2, api_key=current_api_key)
    generate_chain = prompt | llm

    results = generate_chain.batch(data_batches, config={"max_concurrency": max_concurrency})

    return results

In [13]:
max_concurrency = 2
input_querys = merged_dataset['user_input'].to_list()
documents = merged_dataset['result_retrieve'].apply(lambda x : '\n\n'.join([doc for doc in x])).to_list()
input_data = [{'context': document, 'question': question} for document, question in zip(documents, input_querys)]

results = []
for i in tqdm(range(len(input_data) // max_concurrency)):
    results.extend(generate_chain(input_data[i * max_concurrency: (i + 1) * max_concurrency], max_concurrency=max_concurrency))

if len(input_data) % max_concurrency != 0:
    results.extend(generate_chain(input_data[len(input_data) // max_concurrency * max_concurrency:], max_concurrency=max_concurrency))

results = [result.content for result in results]
merged_dataset['generated_answer'] = results
merged_dataset.to_csv('../data/rag/self_reflection.csv', index=False)

  llm = ChatOpenAI(model=model_name, temperature=0.2, api_key=current_api_key)
100%|██████████| 28/28 [02:27<00:00,  5.25s/it]


### 2-1. Binary Classification

In [14]:
# @handle_rate_limits
def generate_binary_chain(data_batches, model_name='gpt-4o-mini', current_api_key=None, max_concurrency=5) -> List:
    class NeedRetrieve(BaseModel):
        need_retrieve: str = Field(description="Whether the retrieval step in a retrieval-augmented generation (RAG) system should be re-executed to improve the answer quality. Respond with 'Yes' or 'No'.")
        reasoning: str = Field(description="Maximum 3 sentences explaining the decision")

    parser = PydanticOutputParser(pydantic_object=NeedRetrieve)

    prompt = PromptTemplate.from_template("""
        You are tasked with evaluating whether the retrieval step in a retrieval-augmented generation (RAG) system should be re-executed to improve the generated answer.

        Consider the following factors:
        1. Does the generated answer fully and directly address the user's intent?
        2. Are there important missing details, factual inaccuracies, or hallucinations?
        3. Would retrieving better or more specific documents likely improve the response quality?

        Follow this process:
        - First, provide a brief reasoning explaining your judgment in korean. Be specific: refer to any missing or inadequate parts of the answer.
        - Then, decide whether retrieval should be redone by responding with **"Yes"** or **"No"**.

        Instructions:
        - Respond with "Yes" if redoing retrieval is likely to produce a significantly better answer.
        - Respond with "No" if the current answer is sufficient and retrieval would not meaningfully improve it.

        ---

        User Input:
        {user_input}

        Generated Answer:
        {generated_answer}

        ---

        Reasoning:
        (Write your explanation here.)

        Final decision (Yes/No):

        format:
        {format}
        """)

    prompt = prompt.partial(format=parser.get_format_instructions())

    llm = OllamaLLM(model='gemma3:12b')
    # llm = ChatOpenAI(model=model_name, temperature=0.2, api_key=current_api_key)
    generate_binary_chain = prompt | llm

    results = generate_binary_chain.batch(data_batches, config={"max_concurrency": max_concurrency})
    results = [(parser.parse(result).need_retrieve, parser.parse(result).reasoning) for result in results]

    return results

In [15]:
max_concurrency = 3
input_querys = merged_dataset['user_input'].to_list()
generated_answers = merged_dataset['generated_answer'].to_list()
input_data = [{'user_input': question, 'generated_answer': answer} for question, answer in zip(input_querys, generated_answers)]

results = []
for i in tqdm(range(len(input_data) // max_concurrency)):
    results.extend(generate_binary_chain(input_data[i * max_concurrency: (i + 1) * max_concurrency], max_concurrency=max_concurrency))

if len(input_data) % max_concurrency != 0:
    results.extend(generate_binary_chain(input_data[len(input_data) // max_concurrency * max_concurrency:], max_concurrency=max_concurrency))

merged_dataset['generated_need_retrieve'] = results
merged_dataset.to_csv('../data/rag/self_reflection.csv', index=False)

100%|██████████| 18/18 [04:25<00:00, 14.76s/it]


In [17]:
metrics, report = evaluate_binary_sufficiency(merged_dataset, 'generated_need_retrieve')
print(metrics)
print(report)

{'accuracy': 0.7321428571428571, 'precision': 0.7586206896551724, 'recall': 0.7333333333333333, 'f1_score': 0.7457627118644068}
              precision    recall  f1-score   support

          no       0.70      0.73      0.72        26
         yes       0.76      0.73      0.75        30

    accuracy                           0.73        56
   macro avg       0.73      0.73      0.73        56
weighted avg       0.73      0.73      0.73        56



### 2-2. Confidence Score

In [32]:
# @handle_rate_limits
def generate_confidence_chain(data_batches, model_name='gpt-4o-mini', current_api_key=None, max_concurrency=5) -> List:
    class NeedRetrieve(BaseModel):
        retrieval_necessity: float = Field(
            description="A confidence score between 0.0 and 1.0 indicating whether the retrieval step should be re-executed. 0.0 = no need, 1.0 = strongly needs to be re-executed."
        )
        reasoning: str = Field(description="Maximum 3 sentences explaining the decision")

    parser = PydanticOutputParser(pydantic_object=NeedRetrieve)

    prompt = PromptTemplate.from_template("""
        Given the user's input and the generated answer, evaluate how strongly the retrieval step in a
        retrieval-augmented generation (RAG) system should be re-executed to improve the answer quality.
        Output a confidence score between 0.0 and 1.0.
        before you output the score, reason about the factors that influence the score in korean.
        ---

        Consider the following factors:

        - How well the generated answer addresses the user's intent (0.0 = completely irrelevant, 1.0 = fully aligned)
        - Whether the answer contains hallucinations, missing key details, or irrelevant information
        - Whether better or more specific supporting documents could significantly improve the response

        User Input:
        {user_input}

        Generated Answer:
        {generated_answer}

        format:
        {format}
        """
    )
    prompt = prompt.partial(format=parser.get_format_instructions())

    llm = OllamaLLM(model='gemma3:12b')
    generate_confidence_chain = prompt | llm

    results = generate_confidence_chain.batch(data_batches, config={"max_concurrency": max_concurrency})
    results = [(parser.parse(result).retrieval_necessity, parser.parse(result).reasoning) for result in results]

    return results    

In [33]:
max_concurrency = 3
input_querys = merged_dataset['user_input'].to_list()
generated_answers = merged_dataset['generated_answer'].to_list()
input_data = [{'user_input': question, 'generated_answer': answer} for question, answer in zip(input_querys, generated_answers)]

results = []
for i in tqdm(range(len(input_data) // max_concurrency)):
    results.extend(generate_confidence_chain(input_data[i * max_concurrency: (i + 1) * max_concurrency], max_concurrency=max_concurrency))

if len(input_data) % max_concurrency != 0:
    results.extend(generate_confidence_chain(input_data[len(input_data) // max_concurrency * max_concurrency:], max_concurrency=max_concurrency))

merged_dataset['generated_confidence_need_retrieve'] = results
merged_dataset.to_csv('../data/rag/self_reflection.csv', index=False)

100%|██████████| 18/18 [05:32<00:00, 18.46s/it]


In [34]:
analyze_confidence_thresholds(merged_dataset, 'generated_confidence_need_retrieve')

Unnamed: 0,threshold,precision,recall,f1_score,retrieve_count,retrieve_rate,true_positive,false_positive,missed_retrieves
0,0.0,0.536,1.0,0.698,56,1.0,30,26,0
1,0.1,0.536,1.0,0.698,56,1.0,30,26,0
2,0.2,0.536,1.0,0.698,56,1.0,30,26,0
3,0.3,0.618,0.7,0.656,34,0.607,21,13,9
4,0.4,0.618,0.7,0.656,34,0.607,21,13,9
5,0.5,0.618,0.7,0.656,34,0.607,21,13,9
6,0.6,1.0,0.033,0.065,1,0.018,1,0,29
7,0.7,0.0,0.0,0.0,0,0.0,0,0,30
8,0.8,0.0,0.0,0.0,0,0.0,0,0,30
9,0.9,0.0,0.0,0.0,0,0.0,0,0,30


In [37]:
tmp = merged_dataset.copy()
# tmp['generated_confidence_need_retrieve'] = tmp['generated_confidence_need_retrieve'].apply(lambda x : x[0])

tmp.groupby(['need_retrieve', 'generated_confidence_need_retrieve']).size().reset_index(name='count')

Unnamed: 0,need_retrieve,generated_confidence_need_retrieve,count
0,no,0.3,13
1,no,0.6,13
2,yes,0.3,9
3,yes,0.6,20
4,yes,0.7,1
