In [None]:
# !pip install -r requirements.txt 
# MAIN CLASS FOR RAG PIPELINE
import os
import re
import time
import math
from typing import List, Dict, Tuple, Optional

from pydantic import BaseModel

# from mem0 import Memory
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_mistralai import ChatMistralAI
from langchain_text_splitters import RecursiveCharacterTextSplitter
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder

# Новое: адаптер поверх реального MemOS/MOS
from memos.configs.mem_os import MOSConfig
from memos.mem_os.main import MOS
import uuid



class MemOSAdapter:
    def __init__(self, config_path: str, memcube_dir: str = "./mem_cube", user_id: Optional[str] = None):
        config_path = os.path.abspath(config_path)
        self.mos = MOS(MOSConfig.from_json_file(config_path))
        self.user_id = user_id or str(uuid.uuid4())
        self.mos.create_user(user_id=self.user_id)
        self.mos.register_mem_cube(memcube_dir, user_id=self.user_id)
        self._text2meta: Dict[int, Dict] = {}

    def add(self, messages: List[Dict], user_id: str = 'user', metadata: Optional[Dict] = None, infer: bool = False):

        self.mos.add(messages=messages, user_id=self.user_id)

    def _to_dict(self, obj):
        if isinstance(obj, dict):
            return obj
        if hasattr(obj, "model_dump"): 
            return obj.model_dump()

        out = {}
        for k in ("id", "memory", "content", "text", "metadata", "score", "updated_at"):
            if hasattr(obj, k):
                out[k] = getattr(obj, k)
        return out

    def search(self, query: str, user_id: str = 'user', limit: int = 5, filters: Optional[Dict] = None) -> Dict:
        res = self.mos.search(query=query, user_id=self.user_id, top_k=100 )

        text_mem = res.get('text_mem')
        items = text_mem[0].get('memories')

        out = []
        seen_content = set()
        num_unique = 0
        for it in items:
            d = self._to_dict(it)
            if d.get('memory') in seen_content:
                continue
            seen_content.add(d.get('memory'))
            content = d.get('memory')
            num_unique += 1
            

            meta = ''
            score = float(d.get('score', 0.0))
            out.append({"memory": content, "score": score, "metadata": meta})
            if num_unique > limit:
                break

        return {"results": out}


import warnings
warnings.filterwarnings("ignore")

MistralAPIKey = os.getenv("MISTRAL_API_KEY", "GrUMa9JZTcJk8KodxVCEZspBaQgEiqXH")

def tokenize(text: str):
    return re.findall(r"\w+", text.lower(), flags=re.UNICODE)


class RetrieverCandidate(BaseModel):
    memory: str
    score: float
    metadata: Dict

    @property
    def id(self) -> str:
        md = self.metadata or {}
        if 'doc_basename' in md and 'chunk_idx' in md:
            return f"{md['doc_basename']}_{md['chunk_idx']}"
        return f"mem_{abs(hash(self.memory))}"



class RAGPipeline:
    def __init__(
        self,
        mistral_model: str = "mistral-medium",
        mistral_api_key: str = MistralAPIKey,
        temperature: float = 0.7,
        embed_model_name: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
        chunk_size: int = 900,
        chunk_overlap: int = 150,
    ):
        os.environ["MEM0_DIR"] = os.path.abspath(f"./.mem0_tmp_{int(time.time())}")

        self.llm = ChatMistralAI(
            model=mistral_model,
            mistral_api_key=mistral_api_key,
            temperature=temperature,
        )
        self.text2meta: Dict[int, Dict] = {} 


        self.embed_model_name = embed_model_name
        self.embedder = HuggingFaceEmbeddings(model_name=self.embed_model_name)
        self.index_dir = os.path.abspath(
            f"./faiss_{self.embed_model_name.split('/')[-1]}_384_{int(time.time())}"
        )

        self.system_prompt = (
            "You answer strictly from the provided context. "
            "If the answer is not present, say you don't know."
            "Do not tell about the context, just answer the question."
            "If you don't know the answer, say 'I don't know'."
            "Answer in language of the question, which is after the word 'Question:'. "
        )


        self.mem = MemOSAdapter(
            config_path="/Users/andreisuhov/Desktop/memos/config.json",
            memcube_dir="/Users/andreisuhov/Desktop/memos/mem_cube_2"
        )

        self.splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separators=["\n\n", "\n", " ", ""]
        )

        self.chunk_texts: List[str] = []
        self.chunk_meta: List[Dict] = []
        self.bm25: Optional[BM25Okapi] = None
        self.cross_encoder_name = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
        self.reranker = CrossEncoder(self.cross_encoder_name, device="cpu")

    def read_txt(self, path: str) -> str:
        with open(path, "r", encoding="utf-8") as f:
            return f.read()

    def _infer_chunk_idx_from_text(self, text: str) -> Optional[int]:

        if self.bm25 is None or not self.chunk_texts:
            return None
        q_tokens = tokenize(text)
        scores = self.bm25.get_scores(q_tokens)
        if not len(scores):
            return None
        best_i = max(range(len(scores)), key=lambda i: scores[i])
        return int(best_i)

    def _get_chunk_idx_safe(self, c: "RetrieverCandidate") -> Optional[int]:
        md = c.metadata or {}
        if "chunk_idx" in md:
            return int(md["chunk_idx"])

        m2 = self.text2meta.get(hash(c.memory))
        if m2 and "chunk_idx" in m2:
            return int(m2["chunk_idx"])

        return self._infer_chunk_idx_from_text(c.memory)


    def add_mem(self, text_paths: list[str]):
        for text_path in text_paths:
            text = self.read_txt(text_path)
            chunks = self.splitter.split_text(text)
            for idx, piece in enumerate(chunks):
                self.mem.add(
                    [{"role": "user", "content": piece, "metadata": {"doc_basename": os.path.basename(text_path), "chunk_idx": idx}}],
                )
                self.chunk_texts.append(piece)
                self.chunk_meta.append({"doc_basename": os.path.basename(text_path), "chunk_idx": idx})
                self.text2meta[hash(piece)] = {"doc_basename": os.path.basename(text_path), "chunk_idx": idx}

        tokenized_chunks = [tokenize(t) for t in self.chunk_texts]
        self.bm25 = BM25Okapi(tokenized_chunks)
        return self.bm25, self.chunk_texts, self.chunk_meta


    def search_mem(self, query: str, k: int = 5, user_id: str = "user", filters: Dict | None = None) -> List[RetrieverCandidate]:
        q = f"query: {query}"

        raw = self.mem.search(q, user_id=user_id, limit=k, filters=filters)
        out = []
        for res in raw['results']:
            mem = res['memory']
            if "passage: " in mem:
                mem = mem.split("passage: ")[1]
            meta = dict(res.get('metadata') or {})
            if 'chunk_idx' not in meta or 'doc_basename' not in meta:
                m2 = self.text2meta.get(hash(mem))

            out.append(RetrieverCandidate(memory=mem, score=0.0, metadata={"doc_basename": m2["doc_basename"], "chunk_idx": m2['chunk_idx']}))
        return out

    

    def search_bm25(self, query: str, n: int = 5) -> List[RetrieverCandidate]:
        if self.bm25 is None or not self.chunk_texts:
            return []
        q_tokens = tokenize(query)  # <-- раньше сюда шла строка!
        scores = self.bm25.get_scores(q_tokens)
        order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:n]
        out = [RetrieverCandidate(memory=self.chunk_texts[i],
                                score=float(scores[i]),
                                metadata=self.chunk_meta[i]) for i in order]
        return out


    def _make_id(self, item: Dict | RetrieverCandidate) -> str:
        if isinstance(item, RetrieverCandidate):
            return item.id
        md = item.get("metadata") or {}
        return md['doc_basename'] + '_' + str(md['chunk_idx'])

    def fuse_with_rrf(
        self,
        hits_mem: List[RetrieverCandidate],
        hits_bm25: List[RetrieverCandidate],
        top_k: Optional[int] = None,
        rrf_k: int = 60,
        weights: List[float] = [0.4, 0.6]
    ) -> List[RetrieverCandidate]:
        wb, wm = weights

        ranks_mem = { c.id: i for i, c in enumerate(hits_mem) }
        ranks_bm25 = { c.id: i for i, c in enumerate(hits_bm25) }

        pool: Dict[str, RetrieverCandidate] = {}
        for c in hits_mem + hits_bm25:
            if c.id not in pool:
                pool[c.id] = RetrieverCandidate(memory=c.memory, score=0.0, metadata=dict(c.metadata))
            else:
                pool[c.id].metadata |= dict(c.metadata)

        fused_scores: List[Tuple[str, float]] = []
        for cid in pool.keys():
            part_mem = wb*(1.0 / (rrf_k + 1 + ranks_mem[cid])) if cid in ranks_mem else 0.0
            part_bm  = wm*(1.0 / (rrf_k + 1 + ranks_bm25[cid])) if cid in ranks_bm25 else 0.0
            fused_scores.append((cid, part_mem + part_bm))

        fused_scores.sort(key=lambda t: t[1], reverse=True)
        if top_k is not None:
            fused_scores = fused_scores[:top_k]

        out: List[RetrieverCandidate] = []
        for cid, sc in fused_scores:
            out.append(pool[cid].copy(update={"score": float(sc)}))
        return out


    def rerank_with_cross_encoder(self, query: str, candidates: List[RetrieverCandidate], batch_size: int = 32) -> List[RetrieverCandidate]:
        pairs = [(query, c.memory) for c in candidates]
        scores = self.reranker.predict(pairs, batch_size=batch_size)
        order = sorted(range(len(candidates)), key=lambda i: scores[i], reverse=True)
        return [candidates[i] for i in order]

    def ask(self, question: str, k: int = 5, user_id: str = "user", filters: Dict | None = None) -> str:
        hits_vec = self.search_mem(question, k=k, user_id=user_id, filters=filters)
        hits_bm25 = self.search_bm25(question, n=k)
        fused = self.fuse_with_rrf(hits_vec, hits_bm25)
        reranked = self.rerank_with_cross_encoder(question, fused)
        context = "\n\n---\n\n".join(c.memory for c in reranked)

        prompt = f"{self.system_prompt}\n\nContext:\n{context}\n\nQuestion: {question}\nAnswer:"
        res = self.llm.invoke(prompt)
        return res.content



rag = RAGPipeline()
rag.add_mem(["/Users/andreisuhov/Desktop/memos/data/sarts.txt", '/Users/andreisuhov/Desktop/memos/data/robert_fittcpatrik_sprosi_mamu_k.txt'])



  from .autonotebook import tqdm as notebook_tqdm
`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 2/2 [00:07<00:00,  3.53s/it]




Loading checkpoint shards: 100%|██████████| 2/2 [00:11<00:00,  5.85s/it]




(<rank_bm25.BM25Okapi at 0x1532dd950>,
 ["Jean-Paul Sartre was born on 21 June 1905 in Paris as the only child of Jean-Baptiste Sartre, an officer of the French Navy, and Anne-Marie (Schweitzer).[7] When Sartre was two years old, his father died of an illness, which he most likely contracted in Indochina. Anne-Marie moved back to her parents' house in Meudon, where she raised Sartre with help from her father Charles Schweitzer, a teacher of German who taught Sartre mathematics and introduced him to classical literature at a very early age.[8] When he was twelve, Sartre's mother remarried, and the family moved to La Rochelle, where he was frequently bullied, in part due to the wandering of his blind right eye (sensory exotropia).[9]",
  'As a teenager in the 1920s, Sartre became attracted to philosophy upon reading Henri Bergson\'s essay Time and Free Will: An Essay on the Immediate Data of Consciousness.[10] He attended the Cours Hattemer, a private school in Paris.[11] He studied and 

In [None]:
# GENERATE QA DATASET

def return_llm(api_key, model= 'mistral-large-latest'):
    return ChatMistralAI(
            model=model,
            mistral_api_key=api_key,
            temperature=0.1,
        )

SYS_PROMPT = (
    "Ты создаёшь датасеты для оценки RAG. "
    "Тебе дают ОДИН отрывок книги. Сформируй информативные вопросы, ответы на которые "
    "ПОЛНОСТЬЮ выводятся из этого отрывка, без внешних знаний. "
    "Дай точные ответы. Для каждого примера верни также evidence_span — дословную цитату "
    "Из отрывка, которая подтверждает ответ. Если текста недостаточно — пропускай такие вопросы.\n\n"
    "Требования к разнообразию: разные аспекты содержания, факты, определения, причинно-следственные связи."
)

HUMAN_TEMPLATE = """Отрывок:
---
{chunk}
---

Сгенерируй до {n} пар в формате JSON-массива:
[
  {{
    "question": "...",
    "answer": "...",
    "evidence_span": "дословная цитата из отрывка"
  }}
]

Только JSON. Без комментариев и пояснений.
"""

def strip_code_fences(s: str) -> str:
    s = s.strip()
    if s.startswith("```"):
        s = re.sub(r"^```[a-zA-Z0-9]*\n", "", s)
        if s.endswith("```"):
            s = s[:-3]
    return s.strip()


def safe_json_loads(s: str) -> Any:
    try:
        return json.loads(s)
    except Exception:
        return None


def dedup_lower(seq: List[str]) -> List[str]:
    seen, out = set(), []
    for x in seq:
        xl = x.strip().lower()
        if xl and xl not in seen:
            seen.add(xl)
            out.append(x)
    return out

import time
import httpx

def is_capacity_error(e):
    msg = str(e)
    if isinstance(e, httpx.HTTPStatusError):
        if e.response.status_code == 429:
            return True
        try:
            err_json = e.response.json()
            if "capacity" in err_json.get("message", "").lower():
                return True
        except Exception:
            pass
    if "capacity" in msg.lower() or "429" in msg:
        return True
    return False

def generate_qa_for_chunk(llm: ChatMistralAI, chunk_id: int, chunk: str, n: int = 3, max_retries: int = 2, api_keys: list = None, timeout: float = 5.0) -> List[Dict[str, str]]:
    """
    llm: ChatMistralAI instance (will be replaced with new instance if api key rotates)
    api_keys: list of API keys to try if capacity error occurs
    timeout: seconds to wait between retries or on busy
    """
    output = []
    messages = [
        SystemMessage(content=SYS_PROMPT),
        HumanMessage(content=HUMAN_TEMPLATE.format(chunk=chunk, n=n))
    ]
    api_keys = api_keys or [getattr(llm, "mistral_api_key", None)]
    api_keys = [k for k in api_keys if k] 
    key_idx = 0
    last_exception = None

    for attempt in range(max_retries + 1):
        for _ in range(len(api_keys)):
            current_key = api_keys[key_idx % len(api_keys)]
            if getattr(llm, "mistral_api_key", None) != current_key:
                llm = ChatMistralAI(
                    model=getattr(llm, "model", "mistral-medium"),
                    mistral_api_key=current_key,
                    temperature=getattr(llm, "temperature", 0.3),
                )
            try:
                resp = llm.invoke(messages)
                raw = strip_code_fences(resp.content)
                data = safe_json_loads(raw)
                if isinstance(data, list):
                    out = []
                    for item in data:
                        if not isinstance(item, dict):
                            continue
                        q = (item.get("question") or "").strip()
                        a = (item.get("answer") or "").strip()
                        ev = (item.get("evidence_span") or "").strip()
                        if q and a and ev and ev in chunk:
                            out.append({"question": q, "answer": a, "evidence_span": ev, "chunk_id": chunk_id})
                    if out:
                        seen, clean = set(), []
                        for it in out:
                            key = it["question"].lower()
                            if key not in seen:
                                seen.add(key)
                                clean.append(it)
                        return clean
                break
            except Exception as e:
                last_exception = e
                if is_capacity_error(e):
                    key_idx += 1
                    time.sleep(timeout)
                    continue
                else:

                    time.sleep(timeout)
                    break
        else:

            time.sleep(timeout)

    if last_exception:
        print(f"Failed to generate QA for chunk {chunk_id}: {last_exception}")
    return []

def save_qa_to_pickle(qa_list, filename):
    import pickle
    filename = f"/Users/andreisuhov/Desktop/memos/qas.pkl"
    with open(filename, 'wb') as f:
        pickle.dump(qa_list, f)

all_qa = []
from tqdm import tqdm
pickle.dump(rag.chunk_texts, open("/Users/andreisuhov/Desktop/memos/chunk_texts.pkl", "wb"))
for text_chunk_id, text_chunk in tqdm(enumerate(rag.chunk_texts), total=len(rag.chunk_texts)):
    api_keys = [
    #API_KYES_HERE
    ]
    try:
        qa_list = generate_qa_for_chunk(llm, text_chunk_id, text_chunk, n=3, max_retries=2, api_keys=api_keys, timeout=5.0)
    except Exception as e:
        print(f"Error on chunk {text_chunk_id}: {e}")
        qa_list = []
    for qa in qa_list:
        question = qa["question"].strip()
        answer = qa["answer"].strip()
        evidence = qa["evidence_span"].strip()
        chunk_id = qa["chunk_id"]
        all_qa.append({
            "question": question,
            "answer": answer,
            "evidence": evidence,
            "chunk_id": chunk_id
        })
    print(f"Processed chunk {text_chunk_id} with {len(qa_list)} QA pairs")
    save_qa_to_pickle(all_qa, f"/Users/andreisuhov/Desktop/memos/qas.pkl")

# 250


In [3]:
import pickle
qa_list = pickle.load(open("/Users/andreisuhov/Desktop/memos/qas.pkl", "rb"))

In [7]:
# metric calculation
import math
from collections import defaultdict
from tqdm import tqdm

def _rank(pred_ids, rel_id):
    try: return pred_ids.index(rel_id)
    except ValueError: return 10**9

def _per_query_metrics(pred_ids, rel_id, ks):
    r = _rank(pred_ids, rel_id)
    out = {"MRR": 0.0}
    out["MRR"] = 1.0/(r+1) if r < 10**9 else 0.0
    for k in ks:
        hit = 1.0 if r < k else 0.0
        out[f"Hit@{k}"] = hit
        out[f"Recall@{k}"] = hit  
        out[f"P@{k}"] = hit/float(k)
        out[f"nDCG@{k}"] = (1.0/math.log2(r+2)) if r < k else 0.0
        out[f"MAP@{k}"] = (1.0/(r+1)) if r < k else 0.0
    return out

def _avg(dicts):
    acc = defaultdict(float); n = len(dicts) or 1
    for d in dicts:
        for k,v in d.items(): acc[k]+=v
    return {k: acc[k]/n for k in acc}

def evaluate_retrievers(rag, qa_list, ks=(1,3,5,10,20), top_k=20):
    all_metrics = {
        "vec": [], "bm25": [], "rrf": [], "rerank": []
    }
    for ex in tqdm(qa_list):
        q, rel = ex["question"], int(ex["chunk_id"])
        hv = rag.search_mem(q, k=top_k)
        hb = rag.search_bm25(q, n=top_k)
        fr = rag.fuse_with_rrf(hv, hb, top_k=top_k)
        rr = rag.rerank_with_cross_encoder(q, fr)

        def ids(lst):
            xs = []
            for c in lst:
                ci = rag._get_chunk_idx_safe(c)
                if ci is not None:
                    xs.append(ci)
            return xs

        ids_vec   = ids(hv)
        ids_bm25  = ids(hb)
        ids_rrf   = ids(fr)
        ids_rer   = ids(rr)


        all_metrics["vec"   ].append(_per_query_metrics(ids_vec,  rel, ks))
        all_metrics["bm25"  ].append(_per_query_metrics(ids_bm25, rel, ks))
        all_metrics["rrf"   ].append(_per_query_metrics(ids_rrf,  rel, ks))
        all_metrics["rerank"].append(_per_query_metrics(ids_rer,  rel, ks))

    return {name: _avg(ms) for name, ms in all_metrics.items()}


metrics = evaluate_retrievers(rag, qa_list, ks=(1,3,5,10,20), top_k=20)
from pprint import pprint; pprint(metrics)


Batches: 100%|██████████| 1/1 [00:00<00:00,  1.04it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  1.20it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  1.37it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  1.39it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  1.30it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  1.34it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  1.25it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  1.16it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  1.25it/s]
Batches: 100%|██████████| 1/1 [00:01<00:00,  1.17s/it]
Batches: 100%|██████████| 1/1 [00:00<00:00,  1.11it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  1.19it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  1.33it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  1.44it/s]
Batches: 100%|██████████| 1/1 [00:01<00:00,  1.18s/it]
Batches: 100%|██████████| 1/1 [00:00<00:00,  1.05it/s]
Batches: 100%|██████████| 1/1 [00:01<00:00,  1.19s/it]
Batches: 100%|██████████| 1/1 [00:01<00:00,  2.00s/it]
Batches: 1

{'bm25': {'Hit@1': 0.4807692307692308,
          'Hit@10': 0.7045454545454546,
          'Hit@20': 0.7692307692307693,
          'Hit@3': 0.6013986013986014,
          'Hit@5': 0.6416083916083916,
          'MAP@1': 0.4807692307692308,
          'MAP@10': 0.5524482461982462,
          'MAP@20': 0.5570203467166895,
          'MAP@3': 0.5349650349650351,
          'MAP@5': 0.5442307692307693,
          'MRR': 0.5570203467166895,
          'P@1': 0.4807692307692308,
          'P@10': 0.07045454545454598,
          'P@20': 0.03846153846153877,
          'P@3': 0.2004662004661995,
          'P@5': 0.1283216783216792,
          'Recall@1': 0.4807692307692308,
          'Recall@10': 0.7045454545454546,
          'Recall@20': 0.7692307692307693,
          'Recall@3': 0.6013986013986014,
          'Recall@5': 0.6416083916083916,
          'nDCG@1': 0.4807692307692308,
          'nDCG@10': 0.5888674559893995,
          'nDCG@20': 0.6053273543537964,
          'nDCG@3': 0.5520710282717307,
      




In [6]:
print(rag.ask("Как производить деление клиентов?"))

Batches: 100%|██████████| 1/1 [00:00<00:00,  3.28it/s]


Для деления клиентской базы необходимо:

1. **Выделить тех, кому нравится ваш продукт**, и сосредоточиться на этой группе.
2. **Провести детальную сегментацию** для общения с клиентами, выходящую за рамки общей стратегии продвижения.
3. **Фиксировать результаты встреч** и распределять работу по изучению клиентов между всеми членами команды, чтобы избежать "бутылочного горлышка" (когда информация сосредоточена в одной голове).
4. **Четко формулировать проблемы и цели** клиентов, чтобы сегмент не оставался размытым.
5. **Просить клиентов продемонстрировать ситуации** (действия), а не просто описывать их словами, чтобы получить объективные данные.

Также важно:
- Анализировать отклики, даже если они противоречивые, и корректировать стратегию.
- Использовать вопрос: *«С кем еще мне следует переговорить?»* в конце беседы, чтобы расширять сегмент.
- Избегать поверхностных комплиментов и добиваться конкретных обязательств от клиентов.
