In [2]:
import pandas as pd
import numpy as np

# import matplotlib.pyplot as plt
# import seaborn as sns

from tqdm import tqdm
from datetime import datetime
import re
import json

import os
from FlagEmbedding import FlagReranker
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
import chromadb

In [13]:
DATA_PATH = '../data/russianPoetryWithTheme_deduped.csv'
EVAL_DATA_PATH = '../data/EvalBench_retriever.csv'

EMBED_MODEL_NAME = "ai-forever/FRIDA"
RERANK_MODEL_NAME ='BAAI/bge-reranker-v2-m3'

AUTHORS_COL = 'author'
POEMS_COL = 'name'
TXT_COL = 'text'
RAG_METADATA_COLS = ['date_to', 'author', 'name', 'idx']
RAG_TXT_COL = 'text'
RAG_SEARCH_METHOD = 'marginal'

In [8]:
data = pd.read_csv(DATA_PATH)
data = data.reset_index(names='idx')

eval_data = pd.read_csv(EVAL_DATA_PATH)
eval_data = eval_data.merge(data[['author', 'name', 'idx']], how='inner', on=['author', 'name'])

In [12]:
EMBED_MODEL_NAME.split('/')[1]

'FRIDA'

In [19]:
class RAGService:
    def __init__(self, embed_model, rerank_model, data, persist_directory='chroma'):
        self.embed_model = embed_model
        self.rerank_model = rerank_model
        self.persist_directory = persist_directory
        self.ini_data = data
        self.db = None

    def load_db(self):
        if os.path.exists(self.persist_directory) and os.listdir(self.persist_directory):
            print("[RAGService] Loading existing ChromaDB from disk...")
            self.db = Chroma(
                persist_directory=self.persist_directory,
                embedding_function=self.embed_model
            )
        else:
            print("[RAGService] No existing DB found at path: ", self.persist_directory)
            self.db = None

    def create_from_data(
        self,
        metadata_cols,
        txt_col,
        rag_separators=["\n\n", "\n", ".", " ", ""],
        prefix_document='search_document: ',
        chunk_size=300,
        chunk_overlap=25
    ):
        all_cols = [txt_col] + metadata_cols
        data = self.ini_data[all_cols]

        print("[RAGService] Creating new ChromaDB...")
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separators=rag_separators
        )

        docs = []
        for row in data.values:
            text = row[0]
            metadata_vals = [row[data.columns.get_loc(col)] for col in metadata_cols]
            chs = splitter.split_text(text)
            for ch in chs:
                doc = Document(
                    page_content=f"{prefix_document}{ch.strip()}",
                    metadata=dict(zip(metadata_cols, metadata_vals))
                )
                docs.append(doc)

        self.db = Chroma.from_documents(
            documents=docs,
            embedding=self.embed_model,
            persist_directory=self.persist_directory
        )
        print("[RAGService] New DB created and saved.")

    def _rerank(self, query, results):
      to_rerank = [[query, res.page_content.replace('search_document: ', '') ] for res in results]
      ranks = self.rerank_model.compute_score(to_rerank, normalize=True)
      reranked = sorted(zip(results, ranks), key=lambda x: -x[1])

      return [i[0] for i in reranked]

    # def search(self, ini_query, query, prefix_query='search_query: ', method="similarity", k=20, filters=None, rerank=True, rerank_k=5):
    #     if self.db is None:
    #         raise ValueError("Database is not loaded. Please create or load a database first.")

    #     final_query = f'{prefix_query}{query}'
    #     if method == "similarity":
    #         if filters:
    #             results = self.db.similarity_search(final_query, k=k, filter=filters)
    #         else:
    #             results = self.db.similarity_search(final_query, k=k)
    #     elif method == "marginal":
    #         if filters:
    #             results = self.db.max_marginal_relevance_search(final_query, k=k, filter=filters)
    #         else:
    #             results = self.db.max_marginal_relevance_search(final_query, k=k)
    #     else:
    #         raise ValueError(f"Unknown search method: {method}")

    #     if rerank:
    #       results = self._rerank(f"Запрос: {ini_query}. Образы: {query}", results)
    #       results = results[:rerank_k]

    #     return results

    def search(self, query, prefix_query='search_query: ', method="similarity", k=20, filters=None, rerank=True, rerank_k=5):
        if self.db is None:
            raise ValueError("Database is not loaded. Please create or load a database first.")

        final_query = f'{prefix_query}{query}'
        if method == "similarity":
            if filters:
                results = self.db.similarity_search(final_query, k=k, filter=filters)
            else:
                results = self.db.similarity_search(final_query, k=k)
        elif method == "marginal":
            if filters:
                results = self.db.max_marginal_relevance_search(final_query, k=k, filter=filters)
            else:
                results = self.db.max_marginal_relevance_search(final_query, k=k)
        else:
            raise ValueError(f"Unknown search method: {method}")

        if rerank:
          results = self._rerank(query, results)
          results = results[:rerank_k]

        return results


In [28]:
def calc_mrr(target, predictions, n):
    for rank, item in enumerate(predictions[:n], start=1):
        if item == target:
            return 1 / rank
    return 0

def get_metrics(rag_svc):
  metrics = {
      'long': {'mrr@1': [], 'mrr@3': [], 'mrr@5': [], 'mrr@10': []},
      'short': {'mrr@1': [], 'mrr@3': [], 'mrr@5': [], 'mrr@10': []},
  }
  for row in eval_data.values:
      author, name, long_query, short_query, quote_query, idx = row 
      
      results_long = rag_svc.search(long_query, rerank_k=10)
      results_long = [i.metadata['idx'] for i in results_long]
      for n in [1, 3, 5, 10]:
          metrics['long'][f'mrr@{n}'].append(calc_mrr(idx, results_long, n))
      
      results_short = rag_svc.search(short_query, rerank_k=10)
      results_short = [i.metadata['idx'] for i in results_short]
      for n in [1, 3, 5, 10]:
          metrics['short'][f'mrr@{n}'].append(calc_mrr(idx, results_short, n))

  return metrics


def get_fin_res(lst):
  rows = []
  for model_name, metrics in lst:
      for query_type in ['long', 'short']:
          loc_lst = []
          for n in [1, 3, 5, 10]:
              mean_mrr = np.mean(metrics[query_type][f'mrr@{n}'])
              rows.append({
                  'model': model_name,
                  'query_length': query_type,
                  'top_n': f'MRR@{n}',
                  'value': mean_mrr.round(3)
              })
              loc_lst.append(mean_mrr.round(3))
          loc_mean = np.mean(loc_lst).round(3)
          rows.append({
                  'model': model_name,
                  'query_length': query_type,
                  'top_n': 'MEAN',
                  'value': loc_mean
              })
  df = pd.DataFrame(rows)
  df_pivot = df.pivot(index='model', columns=['query_length', 'top_n'], values='value')

  return df_pivot

In [76]:
embed_model = HuggingFaceEmbeddings(model_name=EMBED_MODEL_NAME)
# embed_model = HuggingFaceEmbeddings(model_name='BAAI/bge-m3')

In [80]:
# reranker = FlagReranker(RERANK_MODEL_NAME, use_fp16=True)
reranker = FlagReranker('Alibaba-NLP/gte-multilingual-reranker-base', use_fp16=True, trust_remote_code=True)

In [20]:
print(datetime.now())
svc_path = '../data/chroma_FRIDA_150_25' # !!!!!!!
rag_svc = RAGService(embed_model, reranker, data, persist_directory = svc_path)
# rag_svc.create_from_data(
#     metadata_cols = RAG_METADATA_COLS,
#     txt_col = RAG_TXT_COL,
#     chunk_size=150, # !!!!!!!
#     chunk_overlap=25 # !!!!!!!
# )
rag_svc.load_db()
print(datetime.now())

2025-05-17 17:32:19.183426
[RAGService] Loading existing ChromaDB from disk...
2025-05-17 17:32:19.188988


  self.db = Chroma(


In [32]:
svc_path = '../data/chroma_FRIDA_300_25' # !!!!!!!
rag_svc_300 = RAGService(embed_model, reranker, data, persist_directory = svc_path)
rag_svc_300.create_from_data(
    metadata_cols = RAG_METADATA_COLS,
    txt_col = RAG_TXT_COL,
    chunk_size=300, # !!!!!!!
    chunk_overlap=25 # !!!!!!!
)

[RAGService] Creating new ChromaDB...
[RAGService] New DB created and saved.


In [36]:
svc_path = '../data/chroma_FRIDA_450_25' # !!!!!!!
rag_svc_450 = RAGService(embed_model, reranker, data, persist_directory = svc_path)
rag_svc_450.create_from_data(
    metadata_cols = RAG_METADATA_COLS,
    txt_col = RAG_TXT_COL,
    chunk_size=450, # !!!!!!!
    chunk_overlap=25 # !!!!!!!
)

[RAGService] Creating new ChromaDB...
[RAGService] New DB created and saved.


In [82]:
svc_path = '../data/chroma_FRIDA_600_25' # !!!!!!!
rag_svc_600 = RAGService(embed_model, reranker, data, persist_directory = svc_path)
# rag_svc_600.create_from_data(
#     metadata_cols = RAG_METADATA_COLS,
#     txt_col = RAG_TXT_COL,
#     chunk_size=600, # !!!!!!!
#     chunk_overlap=25 # !!!!!!!
# )
rag_svc_600.load_db()

[RAGService] Loading existing ChromaDB from disk...


In [45]:
svc_path = '../data/chroma_FRIDA_750_25' # !!!!!!!
rag_svc_750 = RAGService(embed_model, reranker, data, persist_directory = svc_path)
rag_svc_750.create_from_data(
    metadata_cols = RAG_METADATA_COLS,
    txt_col = RAG_TXT_COL,
    chunk_size=750, # !!!!!!!
    chunk_overlap=25 # !!!!!!!
)

[RAGService] Creating new ChromaDB...
[RAGService] New DB created and saved.


In [55]:
svc_path = '../data/chroma_FRIDA_600_50' # !!!!!!!
rag_svc_600_50 = RAGService(embed_model, reranker, data, persist_directory = svc_path)
rag_svc_600_50.create_from_data(
    metadata_cols = RAG_METADATA_COLS,
    txt_col = RAG_TXT_COL,
    chunk_size=600, # !!!!!!!
    chunk_overlap=50 # !!!!!!!
)

[RAGService] Creating new ChromaDB...
[RAGService] New DB created and saved.


In [61]:
svc_path = '../data/chroma_BGE_600_25' # !!!!!!!
rag_svc_600_bge = RAGService(embed_model, reranker, data, persist_directory = svc_path)
rag_svc_600_bge.create_from_data(
    metadata_cols = RAG_METADATA_COLS,
    txt_col = RAG_TXT_COL,
    chunk_size=600, # !!!!!!!
    chunk_overlap=25 # !!!!!!!
)

[RAGService] Creating new ChromaDB...
[RAGService] New DB created and saved.


In [69]:
svc_path = '../data/chroma_BGE_450_25' # !!!!!!!
rag_svc_450_bge = RAGService(embed_model, reranker, data, persist_directory = svc_path)
rag_svc_450_bge.create_from_data(
    metadata_cols = RAG_METADATA_COLS,
    txt_col = RAG_TXT_COL,
    chunk_size=450, # !!!!!!!
    chunk_overlap=25 # !!!!!!!
)

[RAGService] Creating new ChromaDB...
[RAGService] New DB created and saved.


In [72]:
svc_path = '../data/chroma_BGE_300_25' # !!!!!!!
rag_svc_300_bge = RAGService(embed_model, reranker, data, persist_directory = svc_path)
rag_svc_300_bge.create_from_data(
    metadata_cols = RAG_METADATA_COLS,
    txt_col = RAG_TXT_COL,
    chunk_size=300, # !!!!!!!
    chunk_overlap=25 # !!!!!!!
)

[RAGService] Creating new ChromaDB...
[RAGService] New DB created and saved.


In [73]:
svc_path = '../data/chroma_BGE_150_25' # !!!!!!!
rag_svc_150_bge = RAGService(embed_model, reranker, data, persist_directory = svc_path)
rag_svc_150_bge.create_from_data(
    metadata_cols = RAG_METADATA_COLS,
    txt_col = RAG_TXT_COL,
    chunk_size=150, # !!!!!!!
    chunk_overlap=25 # !!!!!!!
)

[RAGService] Creating new ChromaDB...
[RAGService] New DB created and saved.


In [83]:
svc_path = '../data/chroma_FRIDA_600_25' # !!!!!!!
rag_svc_600_gte = RAGService(embed_model, reranker, data, persist_directory = svc_path)
# rag_svc_600.create_from_data(
#     metadata_cols = RAG_METADATA_COLS,
#     txt_col = RAG_TXT_COL,
#     chunk_size=600, # !!!!!!!
#     chunk_overlap=25 # !!!!!!!
# )
rag_svc_600_gte.load_db()

[RAGService] Loading existing ChromaDB from disk...


In [88]:
svc_path = '../data/chroma_FRIDA_450_25' # !!!!!!!
rag_svc_450_gte = RAGService(embed_model, reranker, data, persist_directory = svc_path)
# rag_svc_600.create_from_data(
#     metadata_cols = RAG_METADATA_COLS,
#     txt_col = RAG_TXT_COL,
#     chunk_size=600, # !!!!!!!
#     chunk_overlap=25 # !!!!!!!
# )
rag_svc_450_gte.load_db()

[RAGService] Loading existing ChromaDB from disk...


In [91]:
svc_path = '../data/chroma_FRIDA_300_25' # !!!!!!!
rag_svc_300_gte = RAGService(embed_model, reranker, data, persist_directory = svc_path)
# rag_svc_600.create_from_data(
#     metadata_cols = RAG_METADATA_COLS,
#     txt_col = RAG_TXT_COL,
#     chunk_size=600, # !!!!!!!
#     chunk_overlap=25 # !!!!!!!
# )
rag_svc_300_gte.load_db()

[RAGService] Loading existing ChromaDB from disk...


In [94]:
svc_path = '../data/chroma_FRIDA_750_25' # !!!!!!!
rag_svc_750_gte = RAGService(embed_model, reranker, data, persist_directory = svc_path)
# rag_svc_600.create_from_data(
#     metadata_cols = RAG_METADATA_COLS,
#     txt_col = RAG_TXT_COL,
#     chunk_size=600, # !!!!!!!
#     chunk_overlap=25 # !!!!!!!
# )
rag_svc_750_gte.load_db()

[RAGService] Loading existing ChromaDB from disk...


In [111]:
svc_path = '../data/chroma_FRIDA_600_50' # !!!!!!!
rag_svc_600_50_gte = RAGService(embed_model, reranker, data, persist_directory = svc_path)
# rag_svc_600.create_from_data(
#     metadata_cols = RAG_METADATA_COLS,
#     txt_col = RAG_TXT_COL,
#     chunk_size=600, # !!!!!!!
#     chunk_overlap=25 # !!!!!!!
# )
rag_svc_600_50_gte.load_db()

[RAGService] Loading existing ChromaDB from disk...


In [112]:
metrics = get_metrics(rag_svc)
metrics300 = get_metrics(rag_svc_300)
metrics450 = get_metrics(rag_svc_450)
metrics600 = get_metrics(rag_svc_600)
metrics750 = get_metrics(rag_svc_750)
metrics600_50 = get_metrics(rag_svc_600_50)
metrics600_bge = get_metrics(rag_svc_600_bge)
metrics450_bge = get_metrics(rag_svc_450_bge)
metrics300_bge = get_metrics(rag_svc_300_bge)
metrics150_bge = get_metrics(rag_svc_150_bge)
metrics600_gte = get_metrics(rag_svc_600_gte)
metrics450_gte = get_metrics(rag_svc_450_gte)
metrics300_gte = get_metrics(rag_svc_300_gte)
metrics750_gte = get_metrics(rag_svc_750_gte)
metrics600_50_gte = get_metrics(rag_svc_600_50_gte)

In [113]:
res_data = [
    ('FRIDA/bge (150 | 25)', metrics), 
    ('FRIDA/bge (300 | 25)', metrics300),
    ('FRIDA/bge (450 | 25)', metrics450),
    ('FRIDA/bge (600 | 25)', metrics600),
    ('FRIDA/bge (750 | 25)', metrics750),
    ('FRIDA/bge (600 | 50)', metrics600_50),
    ('bge-m3/bge (600 | 25)', metrics600_bge),
    ('bge-m3/bge (450 | 25)', metrics450_bge),
    ('bge-m3/bge (300 | 25)', metrics300_bge),
    ('bge-m3/bge (150 | 25)', metrics150_bge),
    ('FRIDA/gte (600 | 25)', metrics600_gte),
    ('FRIDA/gte (450 | 25)', metrics450_gte),
    ('FRIDA/gte (300 | 25)', metrics300_gte),
    ('FRIDA/gte (750 | 25)', metrics750_gte),
    ('FRIDA/gte (600 | 50)', metrics600_50_gte),
]

df_pivot = get_fin_res(res_data)
df_pivot[('overall', 'MEAN')] = df_pivot.mean(axis=1).round(3)
df_pivot

query_length,long,long,long,long,long,short,short,short,short,short,overall
top_n,MRR@1,MRR@3,MRR@5,MRR@10,MEAN,MRR@1,MRR@3,MRR@5,MRR@10,MEAN,MEAN
model,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2
FRIDA/bge (150 | 25),0.0,0.0,0.0,0.014,0.004,0.0,0.067,0.067,0.098,0.058,0.031
FRIDA/bge (300 | 25),0.0,0.05,0.05,0.081,0.045,0.0,0.05,0.05,0.062,0.04,0.043
FRIDA/bge (450 | 25),0.0,0.133,0.153,0.164,0.113,0.0,0.05,0.075,0.113,0.06,0.086
FRIDA/bge (600 | 25),0.1,0.15,0.19,0.207,0.162,0.2,0.25,0.275,0.298,0.256,0.209
FRIDA/bge (600 | 50),0.0,0.15,0.15,0.177,0.119,0.0,0.133,0.133,0.156,0.106,0.112
FRIDA/bge (750 | 25),0.0,0.183,0.183,0.194,0.14,0.0,0.083,0.103,0.132,0.08,0.11
FRIDA/gte (300 | 25),0.0,0.083,0.108,0.108,0.075,0.2,0.2,0.2,0.21,0.202,0.139
FRIDA/gte (450 | 25),0.1,0.15,0.225,0.225,0.175,0.1,0.2,0.225,0.239,0.191,0.183
FRIDA/gte (600 | 25),0.1,0.15,0.19,0.207,0.162,0.2,0.25,0.275,0.298,0.256,0.209
FRIDA/gte (600 | 50),0.1,0.15,0.22,0.22,0.172,0.2,0.25,0.275,0.299,0.256,0.214


In [116]:
# ress = rag_svc_750_gte.search('Письмо матери от сына, полное любви и вины', rerank_k=10)
# ress = rag_svc_600_gte.search('Письмо матери от сына, полное любви и вины', rerank_k=10)
ress = rag_svc_300_gte.search('Письмо матери от сына, полное любви и вины', rerank_k=10)

In [109]:
queries = [
    'Стихотворение о Бородинской битве',
    'Прощание с другом, написанное с болью и принятием',
    'Стихотворение о нестандартной любви к Родине без пафоса',
    'Письмо матери от сына, полное любви и вины',
    'Стихотворение о грозе, бурное и радостное, как оживление природы',
    'Женщина отказывается быть слабой и зависимой, с холодной решимостью',
    'Терпеливое ожидание любви, герой готов ждать ради настоящего чувства',
    "Стихотворение, где поэт с нежностью обращается к собаке, как к другу",
    "Стихотворение от лица человека, томящегося в заточении, мечтающего о свободе",
    "Сказание о том, как работа не ладится, если каждый преследует свои интересы"
]
for q in queries:
    # ress = rag_svc_750_gte.search(q, rerank_k=10)
    ress = rag_svc_600_gte.search(q, rerank_k=10)
    print(q)
    for i in ress:
        # print(i.page_content)
        # print('-'*80)
        print(i.metadata.get('author'), i.metadata.get('name'))
    print('-'*80)

Стихотворение о Бородинской битве
Василий Жуковский Бородинская годовщина
Михаил Лермонтов Бородино
Василий Жуковский Бородинская годовщина
Михаил Лермонтов Бородино
Василий Жуковский Бородинская годовщина
Василий Жуковский Бородинская годовщина
Василий Жуковский Бородинская годовщина
Александр Пушкин Бородинская годовщина
Евдокия Ростопчина Одним меньше
Каролина Павлова Москва
--------------------------------------------------------------------------------
Прощание с другом, написанное с болью и принятием
Алексей Плещеев Прости
Николай Огарев К М. Л. Огаревой (Расстались мы...)
Сергей Есенин Прощай, Баку! Тебя я не увижу...
Василий Жуковский К Филалету
Гаврила Державин Разлука
Семен Надсон На разлуку
Николай Тихонов Как след от весла, от берега ушедший...
Владимир Раевский Мое прости друзьям
Сергей Есенин До свиданья, друг мой...
Ольга Берггольц Осень сорок первого
--------------------------------------------------------------------------------
Стихотворение о нестандартной любви к Ро

In [117]:
for i in ress:
    print(i.page_content)
    print('-'*80)

search_document: Отвозит дочь.
Тоска-печаль в душе Алины
И день и ночь.
Три года длилося изгнанье;
Не усладил
Ни разу друг ее страданье:
Но все он мил.
Однажды... о! как свет коварен!..
Сказала мать:
«Любовник твой неблагодарен»,
И ей читать
Она дает письмо Альсима.
Его черты:
--------------------------------------------------------------------------------
search_document: В нескромный час меж вечера и света,
Без матери, одна, полуодета,
Зачем его должна ты принимать?..
Но я любим... Наедине со мною
Ты так нежна! Лобзания твои
Так пламенны! Слова твоей любви
Так искренно полны твоей душою!
Тебе смешны мучения мои;
Но я любим, тебя я понимаю.
--------------------------------------------------------------------------------
search_document: Я пулей ранен был;
Что умер честно за царя,
Что плохи наши лекаря
И что родному краю
Поклон я посылаю.
Отца и мать мою едва ль
Застанешь ты в живых...
Признаться, право, было б жаль
Мне опечалить их;
Но если кто из них и жив,
Скажи, что я писать ленив,