In [None]:
# ==========================================================
# GSM8K MASTER RESEARCH SCRIPT
# Baseline vs Flat Retrieval vs Hierarchical Memory
# ==========================================================

!pip install datasets sentence-transformers faiss-cpu -q

import json
import os
import random
import re
import time
from collections import Counter

import faiss
import numpy as np
from datasets import load_dataset
from openai import OpenAI
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

# ==========================================================
# CONFIGURATION (TOKEN SAFE)
# ==========================================================

# Never hard-code credentials in a notebook: they leak through version control,
# shared copies and exported outputs. Provide the key via the environment,
# e.g. `export CEREBRAS_API_KEY=...` or `%env CEREBRAS_API_KEY=...` in Jupyter.
# NOTE(review): the key that was previously committed here must be treated as
# leaked and revoked/rotated on the provider side.
API_KEY = os.environ.get("CEREBRAS_API_KEY")
BASE_URL = "https://api.cerebras.ai/v1"
MODEL_NAME = "llama3.1-8b"

# Number of GSM8K test questions evaluated by each system.
DATASET_SIZE = 180
# Number of GSM8K train questions indexed for flat retrieval.
TRAIN_RETRIEVAL_SIZE = 250

NUM_RUNS = 1  # NOTE(review): not referenced by any visible code
NUM_SAMPLES = 1  # completions per question; a majority vote picks the answer
TEMPERATURE = 0.1
MAX_TOKENS = 200  # max completion tokens per request

TOP_K = 3  # training examples injected by flat retrieval
TOP_K_EPISODIC = 2  # solved examples recalled by hierarchical memory
TOP_K_FAILURE = 1  # past-mistake warnings recalled by hierarchical memory

MAX_MEMORY_SIZE = 120  # capacity of each hierarchical-memory tier

# ==========================================================
# CLIENT
# ==========================================================

# OpenAI-compatible client pointed at the Cerebras endpoint.
# A full run issues DATASET_SIZE * NUM_SAMPLES sequential requests per system.
# That is hundreds of calls, and a single unhandled transient error (rate limit,
# 5xx, dropped connection) would abort the experiment and lose every result.
# Raise the SDK's built-in retry-with-backoff above its default of 2.
# Bound each request with a timeout so that a hung call cannot stall the run.
client = OpenAI(
    api_key=API_KEY,
    base_url=BASE_URL,
    max_retries=5,
    timeout=60.0,
)

# ==========================================================
# PROMPT
# ==========================================================

# System instructions shared by all evaluated systems. The closing
# "#### <integer>" line fixes the answer format that extract_answer() parses.
SYSTEM_PROMPT = (
    "You are a careful mathematical reasoner.\n"
    "\n"
    "Solve step by step.\n"
    "\n"
    "Final answer must be on a new line:\n"
    "#### <integer>\n"
)

# ==========================================================
# UTILS
# ==========================================================

# "####", optional whitespace, an optional "$", then a (possibly negative,
# possibly decimal) number. Commas are removed before matching.
_ANSWER_PATTERN = re.compile(r"####\s*\$?\s*(-?\d+\.?\d*)")


def extract_answer(text):
    """Return the number following the last ``####`` marker in *text*.

    Thousands separators are stripped before matching, and a dollar sign
    directly after the marker is tolerated, so ``"#### $1,200"`` yields
    ``"1200"``. When the marker appears several times, the last occurrence
    wins. This scores a model that restates or revises its answer on its
    final one.

    Args:
        text: a model completion or a GSM8K reference solution; may be None.

    Returns:
        The matched number as a string, or None when *text* is None or no
        ``####`` number is present.
    """
    if text is None:
        return None
    matches = _ANSWER_PATTERN.findall(text.replace(",", ""))
    return matches[-1] if matches else None

def normalize_answer(ans):
    """Convert an extracted answer string to ``float`` for comparison.

    Args:
        ans: the string produced by extract_answer(), or None.

    Returns:
        The numeric value, or None when *ans* is None or not numeric. Callers
        treat both cases as "no prediction".
    """
    try:
        return float(ans)
    except (TypeError, ValueError):
        # TypeError: ans is None (nothing extracted); ValueError: not a number.
        # A bare `except:` would also swallow KeyboardInterrupt/SystemExit.
        return None

def embed_normalized(model, text):
    """Embed one string and return it as a single L2-normalised float32 row.

    The result is a 2-D array with one row, the layout FAISS ``add``/``search``
    take. Normalising to unit length makes inner-product search equivalent
    to cosine similarity.
    """
    vector = np.array([model.encode([text])[0]], dtype="float32")
    faiss.normalize_L2(vector)  # normalises each row in place
    return vector

# ==========================================================
# LOAD DATA
# ==========================================================

print("Loading dataset...")
# Evaluation questions: the first DATASET_SIZE items of the GSM8K test split.
dataset = load_dataset("gsm8k", "main", split="test").select(range(DATASET_SIZE))

# Retrieval corpus for the flat system: the first TRAIN_RETRIEVAL_SIZE items
# of the GSM8K train split.
train_dataset = load_dataset("gsm8k", "main", split="train").select(
    range(TRAIN_RETRIEVAL_SIZE)
)

# ==========================================================
# EMBEDDING MODEL (LOAD ONCE)
# ==========================================================

embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Ask the model for its output width instead of hard-coding 384. Otherwise,
# swapping the model would build FAISS indexes whose dimension disagrees
# with the embeddings.
EMBED_DIM = embed_model.get_sentence_embedding_dimension()
if EMBED_DIM is None:
    # Some module stacks cannot report the width statically; measure it once.
    EMBED_DIM = len(embed_model.encode(["dimension probe"])[0])

# ==========================================================
# BASELINE
# ==========================================================

def run_baseline():
    """Evaluate the model with no retrieved context (zero-shot baseline).

    Each question in ``dataset`` is sent NUM_SAMPLES times with only the
    system prompt. The majority-voted numeric answer is compared with the
    GSM8K reference answer.

    Returns:
        A tuple ``(accuracy, mean latency in seconds, mean tokens per question)``.
        - Accuracy counts unanswered questions as wrong.
        - Latency is the mean over questions of each question's mean request time.
        - Tokens are summed over the samples of a question; a response without
          usage data contributes nothing.
    """
    n_correct = 0
    per_question_latency = []
    per_question_tokens = []

    for idx in tqdm(range(len(dataset))):
        row = dataset[idx]
        reference = normalize_answer(extract_answer(row["answer"]))
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": row["question"]},
        ]

        predictions, elapsed, tokens = [], [], []
        for _ in range(NUM_SAMPLES):
            t0 = time.time()
            response = client.chat.completions.create(
                model=MODEL_NAME,
                messages=messages,
                temperature=TEMPERATURE,
                max_tokens=MAX_TOKENS,
            )
            elapsed.append(time.time() - t0)

            guess = normalize_answer(extract_answer(response.choices[0].message.content))
            if guess is not None:
                predictions.append(guess)
            if response.usage:
                tokens.append(response.usage.total_tokens)

        # Majority vote; Counter.most_common breaks ties by first occurrence.
        if predictions and Counter(predictions).most_common(1)[0][0] == reference:
            n_correct += 1

        per_question_latency.append(np.mean(elapsed))
        per_question_tokens.append(np.sum(tokens))

    accuracy = n_correct / len(dataset)
    return accuracy, np.mean(per_question_latency), np.mean(per_question_tokens)

# ==========================================================
# FLAT RETRIEVAL
# ==========================================================

print("Embedding flat retrieval DB...")
# One embedding row per training question, in dataset order, so FAISS row
# id i corresponds to train_dataset[i].
train_questions = [row["question"] for row in train_dataset]
train_embeddings = embed_model.encode(train_questions, convert_to_numpy=True)
train_embeddings = train_embeddings.astype("float32")
faiss.normalize_L2(train_embeddings)  # in place: inner product == cosine

# Exact (brute-force) inner-product index over the normalised rows.
flat_index = faiss.IndexFlatIP(EMBED_DIM)
flat_index.add(train_embeddings)

def retrieve_flat(query, max_chars=1500):
    """Build few-shot context from the training questions closest to *query*.

    Looks up the TOP_K nearest training questions in ``flat_index`` and
    renders each as an ``Example:`` block holding the question and its
    reference solution.

    Only whole examples are kept, in similarity order, while the total stays
    within *max_chars*. An example that would overflow the budget is skipped
    rather than sliced, so the prompt never contains a half question or half
    solution. If not even one example fits, the closest one is truncated so
    some context is still returned.

    Args:
        query: the question to find examples for.
        max_chars: character budget for the returned context.

    Returns:
        The concatenated examples, or "" if the search produced no hits.
    """
    q_emb = embed_normalized(embed_model, query)
    _, indices = flat_index.search(q_emb, TOP_K)

    examples = []
    for idx in indices[0]:
        if idx < 0:
            # FAISS pads with -1 when the index holds fewer than TOP_K rows.
            # train_dataset[-1] would silently return the last row.
            continue
        item = train_dataset[int(idx)]
        examples.append(f"Example:\nQ: {item['question']}\nA: {item['answer']}\n\n")

    kept, used = [], 0
    for example in examples:
        if used + len(example) <= max_chars:
            kept.append(example)
            used += len(example)

    if not kept and examples:
        return examples[0][:max_chars]
    return "".join(kept)

def run_flat():
    """Evaluate the model with flat few-shot retrieval from the train split.

    For each test question, retrieve_flat() supplies the most similar
    training questions with their solutions. That context is placed before
    the question in the user prompt. Retrieval runs outside the timed
    region, so it is not part of the reported latency.

    Returns:
        A tuple ``(accuracy, mean latency in seconds, mean tokens per question)``,
        aggregated the same way as in run_baseline().
    """
    hits = 0
    latency_per_q = []
    tokens_per_q = []

    for idx in tqdm(range(len(dataset))):
        row = dataset[idx]
        question = row["question"]
        reference = normalize_answer(extract_answer(row["answer"]))

        prompt = f"{retrieve_flat(question)}\nNow solve:\n{question}"
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ]

        votes, durations, usage_totals = [], [], []
        for _ in range(NUM_SAMPLES):
            began = time.time()
            response = client.chat.completions.create(
                model=MODEL_NAME,
                messages=messages,
                temperature=TEMPERATURE,
                max_tokens=MAX_TOKENS,
            )
            durations.append(time.time() - began)

            parsed = normalize_answer(extract_answer(response.choices[0].message.content))
            if parsed is not None:
                votes.append(parsed)
            if response.usage:
                usage_totals.append(response.usage.total_tokens)

        if votes:
            winner = Counter(votes).most_common(1)[0][0]
            if winner == reference:
                hits += 1

        latency_per_q.append(np.mean(durations))
        tokens_per_q.append(np.sum(usage_totals))

    return hits / len(dataset), np.mean(latency_per_q), np.mean(tokens_per_q)

# ==========================================================
# IMPROVED HIERARCHICAL MEMORY
# ==========================================================

class HierarchicalMemory:
    """Two-tier, bounded memory that is filled while the evaluation runs.

    * Episodic tier: questions answered correctly with a vote share of at
      least 0.7, stored with the model's reasoning and replayed as worked
      examples for similar later questions.
    * Failure tier: questions answered incorrectly. Similar later questions
      get a generic "double-check" warning.

    Each tier holds at most MAX_MEMORY_SIZE entries and evicts the oldest
    first. FAISS row ``i`` of a tier's index always corresponds to element
    ``i`` of that tier's store.
    """

    def __init__(self):
        self.episodic_index = faiss.IndexFlatIP(EMBED_DIM)
        self.failure_index = faiss.IndexFlatIP(EMBED_DIM)

        self.episodic_store = []
        self.failure_store = []

        # Embedding rows kept parallel to each store, so that an index can be
        # rebuilt after the oldest entry is evicted.
        self._episodic_embs = []
        self._failure_embs = []

    @staticmethod
    def _append_bounded(index, store, embs, emb, item):
        """Append one memory to a tier, evicting the oldest past MAX_MEMORY_SIZE."""
        store.append(item)
        embs.append(emb)
        if len(store) <= MAX_MEMORY_SIZE:
            index.add(emb)
            return
        # Popping only the Python list would leave the stale row in the FAISS
        # index. Every id would then be off by one, and eventually an id would
        # point past the end of the store (IndexError). Rebuilding from the
        # retained rows keeps ids aligned and is cheap at this capacity.
        del store[0]
        del embs[0]
        index.reset()
        if embs:
            index.add(np.vstack(embs))

    @staticmethod
    def _pack(pieces, max_chars):
        """Join whole pieces in order, skipping any that would exceed max_chars.

        If no piece fits whole, the first piece is truncated to max_chars.
        """
        kept, used = [], 0
        for piece in pieces:
            if used + len(piece) <= max_chars:
                kept.append(piece)
                used += len(piece)
        if not kept and pieces:
            return pieces[0][:max_chars]
        return "".join(kept)

    def add_episode(self, question, reasoning, correct, vote_count):
        """Record the outcome of one answered question.

        Args:
            question: question text; its embedding is the retrieval key.
            reasoning: model output to replay if the episode is stored.
            correct: whether the voted answer matched the reference answer.
            vote_count: number of samples that agreed on the voted answer.

        A correct answer is stored only when ``vote_count / NUM_SAMPLES`` is at
        least 0.7. A correct but low-confidence answer is not stored anywhere.
        """
        confidence = vote_count / NUM_SAMPLES

        if correct and confidence >= 0.7:
            self._append_bounded(
                self.episodic_index,
                self.episodic_store,
                self._episodic_embs,
                embed_normalized(embed_model, question),
                {"question": question, "reasoning": reasoning},
            )
        elif not correct:
            self._append_bounded(
                self.failure_index,
                self.failure_store,
                self._failure_embs,
                embed_normalized(embed_model, question),
                {"question": question, "hint": "Double-check arithmetic and units."},
            )

    def retrieve(self, question, max_chars=1500):
        """Build a prompt prefix from memories similar to *question*.

        The prefix holds up to TOP_K_EPISODIC solved examples, followed by up to
        TOP_K_FAILURE mistake warnings, packed whole within *max_chars*. A
        long solved example therefore cannot crowd out the short warning by
        being sliced off at the end.

        Returns:
            The context string, or "" when both tiers are empty.
        """
        if self.episodic_index.ntotal == 0 and self.failure_index.ntotal == 0:
            return ""

        # Embed once and reuse the vector for both tiers.
        q_emb = embed_normalized(embed_model, question)
        pieces = []

        if self.episodic_index.ntotal > 0:
            k = min(TOP_K_EPISODIC, self.episodic_index.ntotal)
            _, idx = self.episodic_index.search(q_emb, k)
            for i in idx[0]:
                item = self.episodic_store[int(i)]
                pieces.append(f"Similar solved:\nQ:{item['question']}\n{item['reasoning']}\n\n")

        if self.failure_index.ntotal > 0:
            k = min(TOP_K_FAILURE, self.failure_index.ntotal)
            _, idx = self.failure_index.search(q_emb, k)
            for i in idx[0]:
                item = self.failure_store[int(i)]
                pieces.append(f"Past mistake warning: {item['hint']}\n\n")

        return self._pack(pieces, max_chars)

def run_hierarchical():
    """Evaluate the model with a hierarchical memory that grows during the run.

    Questions are processed in order. Before each one, HierarchicalMemory
    supplies similar previously solved questions and past-mistake warnings.
    After answering, the episode is written back to memory.

    NOTE(review): whether an episode counts as correct is decided with the
    test-set reference answer. The memory therefore receives ground-truth
    feedback from the evaluation split, which the baseline and flat systems
    do not. Keep this in mind when comparing the numbers.

    Returns:
        A tuple ``(accuracy, mean latency in seconds, mean tokens per question)``.
        Memory retrieval and writes run outside the timed region.
    """
    memory = HierarchicalMemory()

    correct = 0
    latencies = []
    token_usages = []

    for i in tqdm(range(len(dataset))):
        question = dataset[i]["question"]
        gt = normalize_answer(extract_answer(dataset[i]["answer"]))

        context = memory.retrieve(question)
        prompt = context + "\nNow solve:\n" + question

        answers = []
        samples = []  # (raw model output, parsed prediction or None)
        sample_lat = []
        sample_tok = []

        for _ in range(NUM_SAMPLES):
            start = time.time()
            response = client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": prompt},
                ],
                temperature=TEMPERATURE,
                max_tokens=MAX_TOKENS,
            )
            sample_lat.append(time.time() - start)

            output = response.choices[0].message.content
            pred = normalize_answer(extract_answer(output))
            samples.append((output, pred))
            if pred is not None:
                answers.append(pred)

            if response.usage:
                sample_tok.append(response.usage.total_tokens)

        if answers:
            final, vote_count = Counter(answers).most_common(1)[0]
            is_correct = final == gt
            if is_correct:
                correct += 1
        else:
            final = None
            vote_count = 0
            is_correct = False

        # Store the reasoning of a sample that actually produced the voted
        # answer. Blindly taking sample 0 could replay a chain of thought that
        # reached a different (wrong) number as a "Similar solved" example.
        # Fall back to the first sample when no answer was voted.
        best_reasoning = next(
            (out for out, pred in samples if pred is not None and pred == final),
            samples[0][0] if samples else "",
        )
        # message.content may be None; never store None as reasoning text.
        memory.add_episode(question, best_reasoning or "", is_correct, vote_count)

        latencies.append(np.mean(sample_lat))
        token_usages.append(np.sum(sample_tok))

    return correct / len(dataset), np.mean(latencies), np.mean(token_usages)

# ==========================================================
# RUN ALL SYSTEMS
# ==========================================================

# Each runner returns (accuracy, mean latency in s, mean tokens per question).
_systems = (
    ("BASELINE", run_baseline),
    ("FLAT RETRIEVAL", run_flat),
    ("HIERARCHICAL MEMORY", run_hierarchical),
)
_results = []
for _title, _runner in _systems:
    print(f"\n===== {_title} =====")
    _results.append(_runner())
baseline, flat, hier = _results

# ==========================================================
# FINAL COMPARISON
# ==========================================================

print("\n================ FINAL COMPARISON ================")
print("System                Acc     Latency    Tokens")
print("--------------------------------------------------")
# Each label is pre-padded so that the numeric columns line up.
for _label, (_acc, _latency, _tokens) in (
    ("Baseline             ", baseline),
    ("Flat Retrieval       ", flat),
    ("Hierarchical         ", hier),
):
    print(f"{_label}{_acc:.4f}   {_latency:.2f}s   {_tokens:.1f}")
print("===================================================")


Loading dataset...


Loading weights:   0%|          | 0/103 [00:00<?, ?it/s]

BertModel LOAD REPORT from: sentence-transformers/all-MiniLM-L6-v2
Key                     | Status     |  | 
------------------------+------------+--+-
embeddings.position_ids | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


Embedding flat retrieval DB...

===== BASELINE =====


100%|██████████| 180/180 [05:50<00:00,  1.95s/it]



===== FLAT RETRIEVAL =====


100%|██████████| 180/180 [05:24<00:00,  1.80s/it]



===== HIERARCHICAL MEMORY =====


100%|██████████| 180/180 [06:25<00:00,  2.14s/it]


System                Acc     Latency    Tokens
--------------------------------------------------
Baseline             0.4722   1.95s   289.9
Flat Retrieval       0.5500   1.76s   720.0
Hierarchical         0.4667   2.02s   684.8



