# Import

In [1]:
import os

from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
import pandas as pd

from configs.exp import build_exp
from configs.env import EnvDefineUnit
from data_engineering.dataset.precendent import CSVPrecendentDataset, install_pipeline
from data_engineering.prompt_engineering.LLM_template import get_prompt_template
from data_engineering.prompt_engineering.precendent_to_docs import get_prompt_precendent
from data_engineering.prompt_engineering.precendent_to_question import get_prompt_question
from data_engineering.dataset.guideline import PDFDataset
from data_engineering.RAG import build_vectorstore
from model.LLM import load_llm_model_huggingface

  from .autonotebook import tqdm as notebook_tqdm
2025-03-05 08:43:04,005	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [None]:
env = EnvDefineUnit()
config_exp = build_exp('exp_1')

# Paths to the train/test CSVs, resolved against the project data directory.
path_train = os.path.join(env.PATH_DATA_DIR, config_exp.train)
path_test = os.path.join(env.PATH_DATA_DIR, config_exp.test)

# Guideline documents: every entry of the raw construction-safety guideline
# directory. os.listdir returns entries in arbitrary order, so sort them to
# keep the document order fed to the guideline vector store reproducible.
dir_pdf = os.path.join(env.PATH_DATA_DIR, 'raw', '건설안전지침')
paths_pdf = [os.path.join(dir_pdf, name) for name in sorted(os.listdir(dir_pdf))]


In [3]:

# Experiment parameters, copied out of the experiment config in one step.
(
    encoding,
    pipeline,
    prompt_template,
    chain_type1,
    chain_type2,
    model_name,
) = (
    config_exp.data_encoding,
    config_exp.data_pipeline,
    config_exp.prompt_template,
    config_exp.RAG_chain_type1,
    config_exp.RAG_chain_type2,
    config_exp.model_name,
)

# Data Load & Pre-processing

In [4]:
# Build the preprocessing pipeline from the config key. Read the key from
# config_exp directly rather than from `pipeline`, because this cell rebinds
# `pipeline`; re-running it would otherwise pass the installed pipeline back
# into install_pipeline instead of the configured key.
pipeline = install_pipeline(config_exp.data_pipeline)

In [None]:
# Load data: test set, training precedents, and guideline PDFs.
test_data = pd.read_csv(path_test, encoding = encoding)
precendent = pd.read_csv(path_train, encoding = encoding)
guidelines = PDFDataset(paths_pdf)

# Apply the configured preprocessing pipeline to both CSVs.
test_data = pipeline(test_data)
precendent = pipeline(precendent)

# Turn each training row into a precedent document for the retriever;
# the row index is not needed.
precendents = [get_prompt_precendent(row) for _, row in precendent.iterrows()]

# Vector store 생성

In [7]:

# Build one retriever per knowledge source.
retriever_precendent = build_vectorstore(precendents)
print("벡터스토어 생성 완료")

# NOTE(review): retriever_guidelines is not passed to any chain in this
# notebook; keep only if a guideline-based chain is planned.
retriever_guidelines = build_vectorstore(guidelines)
print("벡터스토어 생성 완료")

# Prompt template. Bind the resolved template to a new name instead of
# overwriting `prompt_template` (the config key); overwriting it meant a
# re-run of this cell would pass the resolved template object back into
# get_prompt_template instead of the configured key.
template_spec = get_prompt_template(exp = prompt_template)
prompt = PromptTemplate(
    input_variables=["context", "question"],
    template=template_spec.template,
)


  embedding = HuggingFaceEmbeddings(model_name=embedding_model_name)
23322it [00:00, 3381600.51it/s]


벡터스토어 생성 완료


1883it [00:00, 1168151.82it/s]


벡터스토어 생성 완료


# Model loading

In [6]:

# Load the LLM named by the experiment config (model_name) via the project
# helper load_llm_model_huggingface, then print a completion message.
llm = load_llm_model_huggingface(model_name)
print("모델로드완료")

Loading checkpoint shards: 100%|██████████| 4/4 [00:11<00:00,  2.91s/it]
Device set to use cuda:0


모델로드완료


  llm = HuggingFacePipeline(pipeline=text_generation_pipeline)


# RAG chain 생성

In [8]:
# RAG chain backed by the precedent (training-DataFrame) retriever.
# The custom prompt is passed through to the underlying chain, and source
# documents are returned alongside the answer.
chain_kwargs = {"prompt": prompt}
chain_df = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever_precendent,
    chain_type=chain_type1,
    chain_type_kwargs=chain_kwargs,
    return_source_documents=True,
)

# Inference

In [9]:
# Inference: query the precedent RAG chain once per test row and keep only
# the generated answer text ('result'). Only chain_df is queried; no
# guideline-based result is combined here.
# NOTE(review): if this loop is interrupted, test_results is partial and the
# submission step will not match the sample file length.
test_results = []
for _, row in test_data.iterrows():
    question = get_prompt_question(row)
    result_df = chain_df.invoke(question)
    test_results.append(result_df['result'])

You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


KeyboardInterrupt: 

# Submission

In [None]:
from sentence_transformers import SentenceTransformer

# Embed each generated answer with the SBERT model named below.
embedding_model_name = "jhgan/ko-sbert-sts"
embedding = SentenceTransformer(embedding_model_name)
pred_embeddings = embedding.encode(test_results)  # one vector per answer

# The original author noted the expected shape as (num_samples, 768).
print(pred_embeddings.shape)

In [None]:
# Fill the sample submission: column 1 gets the generated answer text,
# columns 2 onward get its embedding vector.
# NOTE(review): absolute path; consider deriving it from env.PATH_DATA_DIR.
submission = pd.read_csv('/workspace/Storage/hansoldeco3/Data/sample_submission.csv', encoding = 'utf-8-sig')

# Fail clearly if inference did not finish (e.g. an interrupted loop left
# test_results shorter than the sample file).
if len(test_results) != len(submission):
    raise ValueError(
        f"Expected {len(submission)} predictions, got {len(test_results)}"
    )

submission.iloc[:, 1] = test_results
submission.iloc[:, 2:] = pred_embeddings

# Save the final result to CSV.
submission.to_csv('./baseline_submission.csv', index=False, encoding='utf-8-sig')

# Last expression of the cell, so the notebook actually displays the preview
# (a mid-cell `.head()` result is discarded).
submission.head()