# Create dataset for SFT (supervised fine-tuning)

In [1]:
import os
import unicodedata

import torch
import pandas as pd
from tqdm import tqdm
import fitz  # PyMuPDF
import pickle

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    BitsAndBytesConfig,
    TrainingArguments
)
from accelerate import Accelerator

# Langchain 관련
from langchain.llms import HuggingFacePipeline
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import PromptTemplate
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.output_parser import StrOutputParser

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def process_pdf(file_path, chunk_size=800, chunk_overlap=50):
    """Extract the text of a PDF and split it into chunked ``Document`` objects.

    Args:
        file_path: Path of the PDF file, opened with PyMuPDF (``fitz``).
        chunk_size: Forwarded to ``RecursiveCharacterTextSplitter``.
        chunk_overlap: Forwarded to ``RecursiveCharacterTextSplitter``.

    Returns:
        A list of ``Document`` objects, one per chunk, in the order the
        splitter produced them.
    """
    # The context manager closes the PDF handle even if extraction raises;
    # the original never closed the document it opened.
    with fitz.open(file_path) as doc:
        # Join the page texts once instead of growing a string page by page.
        text = ''.join(page.get_text() for page in doc)

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap
    )
    # Wrap every text chunk in a Document so it can be indexed downstream.
    return [Document(page_content=piece) for piece in splitter.split_text(text)]


def create_vector_db(chunks, model_path="intfloat/multilingual-e5-base", device=None):
    """Build a FAISS vector store from document chunks.

    Args:
        chunks: Documents to embed and index.
        model_path: Model name or path handed to ``HuggingFaceEmbeddings``.
        device: Device for the embedding model. When ``None`` (default),
            ``'cuda'`` is used if torch reports an available CUDA device,
            otherwise ``'cpu'``.

    Returns:
        The store returned by ``FAISS.from_documents``.
    """
    if device is None:
        # The original hard-coded 'cuda', which fails on hosts without a GPU.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    embeddings = HuggingFaceEmbeddings(
        model_name=model_path,
        model_kwargs={'device': device},
        encode_kwargs={'normalize_embeddings': True}
    )
    return FAISS.from_documents(chunks, embedding=embeddings)

def normalize_path(path):
    """Return ``path`` converted to Unicode NFC (composed) form."""
    nfc_path = unicodedata.normalize('NFC', path)
    return nfc_path


def process_pdfs_from_dataframe(df, base_directory):
    """Build a vector DB and retriever for every distinct PDF listed in ``df``.

    Args:
        df: DataFrame with a 'Source_path' column holding PDF paths.
        base_directory: Directory that relative 'Source_path' values are
            resolved against. Absolute paths are used as given.

    Returns:
        Dict keyed by the PDF file name without its extension. Each value is
        a dict with the FAISS store under 'db' and its MMR retriever under
        'retriever'.
    """
    pdf_databases = {}

    for path in tqdm(df['Source_path'].unique(), desc="Processing PDFs"):
        normalized_path = normalize_path(path)
        if os.path.isabs(normalized_path):
            full_path = normalized_path
        else:
            # normpath already collapses a leading "./". The original used
            # lstrip('./'), which strips a *set of characters*, so paths such
            # as "../x.pdf" or ".hidden/x.pdf" were silently rewritten to a
            # different location.
            full_path = os.path.normpath(os.path.join(base_directory, normalized_path))

        # The file name without extension is the lookup key for this PDF.
        pdf_title = os.path.splitext(os.path.basename(full_path))[0]
        print(f"Processing {pdf_title}...")

        # Extract and chunk the PDF, then index the chunks.
        chunks = process_pdf(full_path)
        db = create_vector_db(chunks)

        # MMR retriever: fetch 8 candidates, return the 3 selected ones.
        retriever = db.as_retriever(search_type="mmr",
                                    search_kwargs={'k': 3, 'fetch_k': 8})

        pdf_databases[pdf_title] = {
            'db': db,
            'retriever': retriever
        }
    return pdf_databases


In [3]:
# Locate the training CSV relative to the working directory and load it.
base_directory = "./"
dataset_directory = os.path.join(base_directory, "train.csv")
df = pd.read_csv(dataset_directory)

# Build one FAISS DB + retriever per PDF referenced in the CSV.
pdf_databases = process_pdfs_from_dataframe(df, base_directory)

# Save the whole PDF database mapping to a pickle file.
with open('pdf_databases_e5_base.pickle', 'wb') as pickle_file:
    pickle.dump(pdf_databases, pickle_file)

print("pdf_databases가 pickle 파일로 저장되었습니다.")

Processing PDFs:   0%|          | 0/16 [00:00<?, ?it/s]

Processing 1-1 2024 주요 재정통계 1권...


  warn_deprecated(
Processing PDFs:   6%|▋         | 1/16 [00:05<01:28,  5.92s/it]

Processing 2024 나라살림 예산개요...


Processing PDFs:  12%|█▎        | 2/16 [00:12<01:26,  6.21s/it]

Processing 재정통계해설...


Processing PDFs:  19%|█▉        | 3/16 [00:17<01:12,  5.54s/it]

Processing 국토교통부_전세임대(융자)...


Processing PDFs:  25%|██▌       | 4/16 [00:20<00:58,  4.84s/it]

Processing 고용노동부_청년일자리창출지원...


Processing PDFs:  31%|███▏      | 5/16 [00:24<00:48,  4.41s/it]

Processing 고용노동부_내일배움카드(일반)...


Processing PDFs:  38%|███▊      | 6/16 [00:28<00:42,  4.21s/it]

Processing 보건복지부_노인일자리 및 사회활동지원...


Processing PDFs:  44%|████▍     | 7/16 [00:32<00:37,  4.12s/it]

Processing 중소벤처기업부_창업사업화지원...


Processing PDFs:  50%|█████     | 8/16 [00:35<00:31,  3.95s/it]

Processing 보건복지부_생계급여...


Processing PDFs:  56%|█████▋    | 9/16 [00:39<00:27,  3.91s/it]

Processing 국토교통부_소규모주택정비사업...


Processing PDFs:  62%|██████▎   | 10/16 [00:42<00:22,  3.73s/it]

Processing 국토교통부_민간임대(융자)...


Processing PDFs:  69%|██████▉   | 11/16 [00:46<00:18,  3.74s/it]

Processing 고용노동부_조기재취업수당...


Processing PDFs:  75%|███████▌  | 12/16 [00:50<00:14,  3.65s/it]

Processing 2024년도 성과계획서(총괄편)...


Processing PDFs:  81%|████████▏ | 13/16 [01:35<00:48, 16.28s/it]

Processing 「FIS 이슈 & 포커스」 23-3호 《조세지출 연계관리》...


Processing PDFs:  88%|████████▊ | 14/16 [01:42<00:26, 13.43s/it]

Processing 「FIS 이슈 & 포커스」 22-3호 《재정융자사업》...


Processing PDFs:  94%|█████████▍| 15/16 [01:50<00:11, 11.88s/it]

Processing 월간 나라재정 2023년 12월호...


Processing PDFs: 100%|██████████| 16/16 [02:52<00:00, 10.79s/it]


pdf_databases가 pickle 파일로 저장되었습니다.


In [2]:
# Reload the training CSV from the working directory.
base_directory = "./"
dataset_directory = os.path.join(base_directory, "train.csv")
df = pd.read_csv(dataset_directory)

# Restore the PDF database mapping from its pickle file.
# NOTE(review): pickle.load can execute arbitrary code; only load a pickle
# file that this notebook produced itself.
with open('pdf_databases_e5_base.pickle', 'rb') as pickle_file:
    pdf_databases = pickle.load(pickle_file)

  return torch.load(io.BytesIO(b))


In [4]:
def normalize_string(s):
    """Return ``s`` converted to Unicode NFC (composed) form."""
    nfc_text = unicodedata.normalize('NFC', s)
    return nfc_text

def format_docs(docs):
    """Concatenate the ``page_content`` of each document, each followed by a newline."""
    return "".join(doc.page_content + '\n' for doc in docs)

# Collects one {"Context", "Question", "Answer"} record per row of df.
results = []

# Number of df rows taken per outer-loop step.
batch_size = 1

# Build the NFC-normalized key lookup once. The original rebuilt this dict
# for every row even though pdf_databases does not change inside the loop.
normalized_keys = {normalize_string(k): v for k, v in pdf_databases.items()}

for start in tqdm(range(0, len(df), batch_size), desc="Creating Q&A including RAG info"):
    # Rows belonging to the current batch.
    batch_rows = df.iloc[start:start + batch_size]

    for _, row in batch_rows.iterrows():
        # Normalize the source name so it matches the normalized lookup keys.
        source = normalize_string(row['Source'])
        question = row['Question']
        answer = row['Answer']

        # Pick the retriever built for this row's source PDF.
        retriever = normalized_keys[source]['retriever']

        # Retrieve documents for the question and flatten them into one string.
        context = format_docs(retriever.invoke(question))

        results.append({
            "Context": context,
            "Question": question,
            "Answer": answer
        })

Creating Q&A including RAG info: 100%|██████████| 496/496 [00:07<00:00, 65.82it/s]


In [6]:
# Row index where the 80/20 split falls, rounded to the nearest row.
split_index = int(len(results) * 0.8 + 0.5)

# Train takes the first 80% and eval takes the remaining 20%. The original
# started the eval slice at the 20% mark, so 60% of the rows were written to
# both files and the eval set leaked training data.
stf_train_df = pd.DataFrame(results[:split_index])
stf_eval_df = pd.DataFrame(results[split_index:])

# NOTE(review): rows are split in their existing order, not shuffled —
# confirm that is intended.
stf_train_df.to_csv("stf_e5_base_train.csv", index=False, encoding="UTF-8")
stf_eval_df.to_csv("stf_e5_base_eval.csv", index=False, encoding="UTF-8")

In [None]:
# Load a training CSV and show its first question as a quick sanity check.
# NOTE(review): this path differs from the "stf_e5_base_train.csv" file this
# notebook writes — confirm which file is intended.
stf_train = pd.read_csv("data/stf_train.csv")
first_question = stf_train["Question"][0]
print(first_question)