In [73]:
import os
import sys
import warnings
import torch
from pathlib import Path, PosixPath
from langchain.docstore.document import Document
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from transformers import AutoTokenizer, AutoModel
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain.retrievers import ContextualCompressionRetriever



from langchain_community.docstore.in_memory import InMemoryDocstore

from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import CrossEncoder

sys.path.append("..")
from scr.readers import *
from scr.download_models import *

In [None]:
from transformers import AutoTokenizer, AutoModel


In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
data_folder = Path("../data")

In [4]:
raw_documents = get_raw_documents(data_folder)

In [28]:
model_name = "deepvk/USER-bge-m3"

In [32]:
model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [36]:
CHUNK_SIZE = 512

text_splitter = RecursiveCharacterTextSplitter(
    separators=' ', 
    chunk_size=CHUNK_SIZE,
    chunk_overlap=int(CHUNK_SIZE / 5),
    strip_whitespace=True, 
    length_function=lambda x: len(tokenizer.encode(x, add_special_tokens=False))
)

chunked_documents_with_page_content = text_splitter.split_documents(raw_documents)
print(len(chunked_documents_with_page_content))

133


In [66]:
type(text_splitter)

langchain_text_splitters.character.RecursiveCharacterTextSplitter

In [40]:
chunked_documents = [doc.page_content for doc in chunked_documents_with_page_content]

In [43]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"using {device}")

embedding_model = HuggingFaceEmbeddings(
    model_name="deepvk/USER-bge-m3",
    model_kwargs={'device': device})

using cpu


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [45]:
db = FAISS.from_texts(chunked_documents, embedding_model, distance_strategy=DistanceStrategy.COSINE)

In [49]:
db.save_local(folder_path="../data/vector_dbs", index_name="faiss_first")

In [48]:
num_docs_retrieve = 5

retriever = db.as_retriever(search_kwargs={"k": num_docs_retrieve})

bm25_retriever = BM25Retriever.from_documents(
    chunked_documents_with_page_content)
bm25_retriever.k = num_docs_retrieve

In [51]:
vector_database = EnsembleRetriever(
    retrievers=[bm25_retriever, retriever], weights=[0.5, 0.5])

In [64]:
num_docs_rerank = 3

model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-v2-m3")
compressor = CrossEncoderReranker(model=model, top_n=num_docs_rerank)

In [74]:
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor, base_retriever=vector_database
)

In [75]:
docs = compression_retriever.invoke('революция')

In [102]:
def process_data(doc: Document) -> dict:
    """Редактирование документа, полученного из ретривера

    Args:
        doc (Document): документ из ретривера

    Returns:
        dict: словарь с содержанием 'содержание' и названием источника 'источник'
    """
    page_content = doc.page_content
    source = doc.metadata['source'].split("/")[-1].split(".")[0]
    return {'источник': source, 'содержание': page_content}

def process_retrieve_output(docs: list[Document]) -> list[dict]:
    """Функция обработки документов из ретривера.

    Args:
        docs (list[Document]): документы из ретривера

    Returns:
        list[dict]: обработанные документы
    """
    processed_documents = []
    for doc in docs:
        processed_doc = process_data(doc)
        processed_documents.append(processed_doc)
    return processed_documents

In [103]:
processed_documents = process_retrieve_output(docs)

___ 