In [50]:
from datetime import datetime
from langchain.schema.document import Document
from langchain.vectorstores.chroma import Chroma
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings as STE
from embedding import EmbeddingLoader
import pandas as pd
import re

In [51]:
#decorator
def checktime(func):
    """Decorator that prints the wall-clock duration of every call to ``func``.

    The wrapped callable receives the same positional and keyword arguments
    and its return value is passed through unchanged. After each successful
    call a line of the form ``Function call <name> took <seconds>s to run.``
    is printed.

    Args:
        func: the callable to time.

    Returns:
        A wrapper with the same call signature as ``func``.
    """
    # Imported locally because the notebook's import cell does not provide functools.
    import functools

    # functools.wraps keeps func's __name__ / __doc__ on the wrapper, so
    # introspection (and stacked decorators) still see the original function.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start_time = datetime.now()
        result = func(*args, **kwargs)
        end_time = datetime.now()
        print(f"Function call {func.__name__} took {(end_time - start_time).total_seconds()}s to run.\n")
        return result
    return wrapper

#get model file from .txt file
@checktime
def get_hf_model_names() -> list:
    """Read the model names listed in ``model/model_list.txt``.

    Every non-blank line is stripped of surrounding whitespace and prefixed
    with ``"model/"``.

    Returns:
        list: the prefixed entries in file order. If the file cannot be
        opened, a message is printed and an empty list is returned.
    """
    # Bind the result first so an unreadable file yields an empty list
    # instead of an UnboundLocalError at the return statement.
    model_list = []
    try:
        with open(file="model/model_list.txt", mode="r", encoding="utf-8") as list_file:
            # Blank lines (e.g. a trailing newline) are skipped; they would
            # otherwise turn into a bare "model/" entry.
            model_list = ["model/" + line.strip() for line in list_file if line.strip()]
    except OSError:
        # Only file-access problems are handled here; anything else propagates.
        print("""file not exsist. check directory or file.""")

    return model_list

@checktime
def get_collection_names(model_list:list, if_openai=True) -> list:
    """Derive one collection name per model path.

    The collection name is the text after the last ``/`` of each entry in
    ``model_list``. When ``if_openai`` is truthy, ``"text-embedding-ada-002"``
    is added as the final element.

    Args:
        model_list: model path strings.
        if_openai: whether to append the OpenAI collection name.

    Returns:
        list: a new list of collection names, in the order of ``model_list``.
    """
    names = []
    for model_path in model_list:
        # Only the last path component is kept.
        names.append(model_path.rsplit("/", 1)[-1])

    if if_openai:
        names.append("text-embedding-ada-002")

    return names

@checktime
def loading_hf_embedding(model_path):
    """Load one sentence-transformer embedding per entry of ``model_path``.

    For each path an ``EmbeddingLoader.SentenceTransformerEmbedding`` is built
    with ``normalize_embeddings`` enabled in ``encode_kwargs`` and its
    ``load()`` result is collected.

    Args:
        model_path: iterable of model paths (despite the singular name).

    Returns:
        list: the ``load()`` results, in the same order as ``model_path``.
    """
    # A fresh encode_kwargs dict is created for every loader, as before.
    return [
        EmbeddingLoader.SentenceTransformerEmbedding(
            model_name=single_path,
            encode_kwargs={'normalize_embeddings': True},
        ).load()
        for single_path in model_path
    ]

In [52]:
### Resolve the model paths and derive the matching collection names
hf_model_path = get_hf_model_names()
collection_names = get_collection_names(hf_model_path)

### Load one embedding per entry of hf_model_path
embedding_memorystore = loading_hf_embedding(hf_model_path) # results are kept in a list

### Also load the OpenAI embedding and store it with the others
openai_embedding = EmbeddingLoader.OpenAIEmbedding().load() # OpenAI embedding object (original note: @text-ada-002)
embedding_memorystore.append(openai_embedding) # appended last, mirroring "text-embedding-ada-002" being last in collection_names

Function call get_hf_model_names took 0.00051s to run.

Function call get_collection_names took 0.0s to run.

embedding model in path <model/sentence-transformers/paraphrase-multilingual-mpnet-base-v2> has been loaded successfully.
Function call load took 3.755441s to run.

embedding model in path <model/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2> has been loaded successfully.
Function call load took 2.191913s to run.

embedding model in path <model/sentence-transformers/distiluse-base-multilingual-cased-v2> has been loaded successfully.
Function call load took 1.60226s to run.

embedding model in path <model/sentence-transformers/stsb-xlm-r-multilingual> has been loaded successfully.
Function call load took 3.754748s to run.

embedding model in path <model/jhgan/ko-sroberta-multitask> has been loaded successfully.
Function call load took 1.970871s to run.

embedding model in path <model/snunlp/KR-SBERT-V40K-klueNLI-augSTS> has been loaded successfully.
Function call l

In [53]:
# Expected count: 12 HF embeddings + 1 OpenAI embedding (text-embedding-ada-002)
print(len(embedding_memorystore))
# Check the collection names before launching the evaluation
print(collection_names)

13
['paraphrase-multilingual-mpnet-base-v2', 'paraphrase-multilingual-MiniLM-L12-v2', 'distiluse-base-multilingual-cased-v2', 'stsb-xlm-r-multilingual', 'ko-sroberta-multitask', 'KR-SBERT-V40K-klueNLI-augSTS', 'moco-sentencedistilbertV2.1', 'kpf-sbert-128d-v1', 'M-BERT-Distil-40', 'canine-c', 'roberta-ko-small-tsdae', 'KoSimCSE-roberta-multitask', 'text-embedding-ada-002']


In [54]:
def _ab_getter(collection_name:str)->list:
    "get embedding list and add collection names"
    footer = ["-a", "-b"]
    return [collection_name+footer_string for footer_string in footer]

def _strip_replace_text(s: str)->str:
    "regex"
    regex = '([^가-힣0-9a-zA-Z])'
    s = re.sub(pattern=regex, repl="", string=s)
    return s

In [55]:
## check regexed document name
# for check in hitrate_doc:
#     print(_strip_replace_text(check))

#### 여기에서부터 QA query로 넣어서 hitrate 평가(hitrate_test_qa.csv)
- k 개수에 따라서 (@1, @3, @5, @10) hitrate 평가 후 csv file로 저장

In [56]:
## wrapper 함수들 ->

def get_hitrate(embedding_function, collection_list:list, k:int, hitrate_csv:pd.DataFrame, result_df:pd.DataFrame, persist_directory:str="chroma") -> pd.DataFrame:
    """Measure the hit rate at ``k`` for each collection and record it in ``result_df``.

    For every collection name, a ``Chroma`` store is opened from
    ``persist_directory`` and each question is searched with
    ``similarity_search(k=k)``. A question counts as a hit when
    ``search_result_rating`` accepts the results for its expected document.
    The percentage of hits is written to ``result_df.loc['@<k>', <collection>]``.

    Args:
        embedding_function: embedding object handed to ``Chroma``.
        collection_list: names of the collections to evaluate.
        k: number of results requested per query.
        hitrate_csv: frame with a ``"Question"`` and a ``"Documents"`` column.
        result_df: frame receiving the percentages; it is modified in place.
        persist_directory: directory of the persisted Chroma store
            (defaults to ``"chroma"``, the previously hard-coded value).

    Returns:
        pd.DataFrame: the same ``result_df`` object that was passed in.

    Raises:
        ValueError: if ``hitrate_csv`` contains no questions.
    """
    hitrate_q = hitrate_csv["Question"]
    hitrate_doc = hitrate_csv["Documents"]

    # Fail early with a clear message instead of a ZeroDivisionError after
    # the first collection has already been opened and searched.
    total_questions = len(hitrate_q)
    if total_questions == 0:
        raise ValueError("hitrate_csv contains no questions; the hit rate is undefined.")

    for collection_name_complete in collection_list:
        hit = 0

        db = Chroma(persist_directory=persist_directory, collection_name=collection_name_complete, embedding_function=embedding_function)
        # The last character of the collection name is shown upper-cased as the team label.
        print(f"with document set <Team {collection_name_complete[-1].upper()}>: {db._collection.name}, {db._collection.count()}, with k={k}")

        for hitq, hitdoc in zip(hitrate_q, hitrate_doc):
            search_result = db.similarity_search(query=hitq, k=k)

            if search_result_rating(search_result, answer_document=hitdoc):
                print(f"question <{hitq}> hit!")
                hit += 1

        # Compute the percentage once and reuse it for storage and display.
        hit_rate = (hit/total_questions)*100
        result_df.loc[f'@{k}', collection_name_complete] = hit_rate
        print(hit_rate)

        ### result separator
        print("="*80,"\n")

    return result_df

def search_result_rating(search_result:Document, answer_document:str)->bool:
    """
        Tell whether any retrieved document matches the expected document title.

        Both sides are normalised with ``_strip_replace_text`` before they are
        compared for equality.

        Args:
            search_result: iterated as a collection of Langchain Document
                objects; each one must provide ``metadata["title"]``.
            answer_document: title of the expected document.

        Returns:
            True as soon as one normalised title equals the normalised
            ``answer_document``; False when none does (or nothing was retrieved).
    """
    expected_title = _strip_replace_text(answer_document)
    return any(
        _strip_replace_text(candidate.metadata["title"]) == expected_title
        for candidate in search_result
    )

In [57]:
### main
## Load the hit-rate test set (get_hitrate reads its "Question" and "Documents" columns)
hitrate_csv = pd.read_csv("hitrate_test_qa.csv", encoding='utf-8')
result_df = pd.DataFrame()

# zip() stops at the shorter input, so a length mismatch would silently drop
# models from the evaluation. Fail loudly instead.
if len(embedding_memorystore) != len(collection_names):
    raise ValueError(
        f"embedding/collection count mismatch: {len(embedding_memorystore)} embeddings "
        f"vs {len(collection_names)} collection names"
    )

# Each embedding is paired by position with its collection name.
for embedding_function, collection_name in zip(embedding_memorystore, collection_names):
    # Every model is evaluated against its "-a" and "-b" collections.
    ab_names = _ab_getter(collection_name=collection_name)

    # One result row per k: @1, @3, @5, @10.
    for knum in [1,3,5,10]:
        result_df = get_hitrate(embedding_function=embedding_function, collection_list=ab_names, k=knum, hitrate_csv=hitrate_csv, result_df=result_df)

result_df.to_csv("hitrate_result.csv", encoding="utf-8")

with document set <Team A>: paraphrase-multilingual-mpnet-base-v2-a, 1782, with k=1


question <우리 집이 최근 화재로 인해 살 곳을 잃었어요. 긴급하게 지원받을 수 있는 방법이 있을까요?> hit!
question <제가 저소득층인데, 쌀 지원을 위한 프로그램 신청은 어떻게 하나요?> hit!
question <집안의 낡은 전구를 에너지 효율이 좋은 램프로 교체하고 싶어요. 지원을 받으려면 어디로 연락해야 하나요?> hit!
question <최근 직장을 그만두었는데, 건강보험을 계속해서 유지하고 싶어요. 어떤 제도를 활용할 수 있을까요?> hit!
question <고용보험에 18개월간 가입했지만, 총 근무일수가 180일이 안 되는데 실업급여를 신청할 수 있나요?> hit!
question <범죄 경력이 있는 20살 청년이 취업을 희망할 때 이용할 수 있는 프로그램은 무엇인가요?> hit!
question <제가 고용보험에 가입하지 않았는데, 출산 후 지원을 받고 싶어요. 어디로 문의해야 하나요?> hit!
question <임신 중에 어떤 예방접종을 해야 하는지 알고 싶습니다. 정보를 어디서 찾을 수 있나요?> hit!
question <조산이나 저체중으로 태어난 아기가 병원 치료를 받을 때 부모님의 부담금 비율은 어떻게 되나요?> hit!
question <만 1세 된 아이를 어린이집에 보내지 않고 가정에서 양육할 때 부모가 받을 수 있는 현금 지원은 어떻게 되나요?> hit!
question <열 살짜리 아이가 입원을 해서 병원비를 내야 하는데, 어떤 부분에 대해 지원을 받을 수 있나요?> hit!
question <만 19세인데 학교를 다니지 않아요. 건강검진을 받을 수 있나요?> hit!
question <위기상황에 처한 청소년이 안전의 위협을 받을때 어디로 연락해야할까요?> hit!
question <고등학교 1학년 학생이 폰 없이는 못 사는 것 같아 걱정이에요. 무슨 조치를 취할 수 있을까요?> hit!
question <대체복무요원으로 일하고 있는데 가정에 긴급한 경제적 문제가 생겼어요. 이럴 때 병역을 감면받을 수 있