# In [None]:
import os
import sys
import time
import sqlite3
import numpy as np
import faiss
import ollama

# Ollama model identifier (presumably handed to OllamaClient -- the call site
# is not in this part of the file).
MODEL_NAME = "llama3.2"

# Default on-disk locations used by MemoryManager: SQLite rows and FAISS index.
DB_FILE = "mem.db"
INDEX_FILE = "mem.index"

# Default number of nearest memories MemoryManager.query() returns.
TOP_K = 5

# Rough token budget for the short-term buffer; ChatCLI asks the model for a
# summary once its running token count exceeds this.
SHORT_TERM_TOKENS = 1000

class OllamaClient:
    """Thin wrapper around the ``ollama`` module bound to a single model."""

    def __init__(self, model_name):
        """Remember which model every request should be sent to.

        Args:
            model_name: Name of the Ollama model used for chat and embedding.
        """
        self.model = model_name

    def chat(self, messages):
        """Send a messages list to the model and return the reply text.

        Prints the elapsed time of the request as a side effect.

        Args:
            messages: List of ``{"role": ..., "content": ...}`` dicts.

        Returns:
            The ``content`` string of the model's reply message.
        """
        # perf_counter is monotonic, so the measurement cannot go negative if
        # the system clock is adjusted while the request is in flight.
        start = time.perf_counter()
        resp = ollama.chat(model=self.model, messages=messages)
        elapsed = time.perf_counter() - start
        print(f"[Time: {elapsed:.2f}s]")
        return resp["message"]["content"]

    def embed(self, text):
        """Get an embedding vector for ``text``.

        Args:
            text: The string to embed.

        Returns:
            A 1-D ``float32`` NumPy array.

        Raises:
            KeyError: If the response carries no ``"embeddings"`` entry.
            ValueError: If the ``"embeddings"`` entry is present but empty.
        """
        resp = ollama.embed(model=self.model, input=text)
        embs = resp.get("embeddings")
        if embs is None:
            # ``keys`` must be *called*; the response may also be a non-dict
            # object without ``keys`` at all, so probe before using it.
            keys = list(resp.keys()) if hasattr(resp, "keys") else []
            raise KeyError(f"No embeddings in response-got keys: {keys}")

        arr = np.asarray(embs, dtype="float32")
        if arr.size == 0:
            raise ValueError("Embedding response contained no vectors")

        # A batch-shaped response ([[...]]) yields its first row; a flat
        # vector is returned as-is.
        return arr[0] if arr.ndim > 1 else arr
    
class MemoryManager:
    """Long-term memory: a FAISS vector index paired with a SQLite table.

    Vectors live in the FAISS index; the matching text and flags live in the
    ``memories`` table, keyed by the vector's position in the index.
    """

    def __init__(self, dim, db_file=DB_FILE, index_file=INDEX_FILE):
        """Open (or create) the vector index and the SQLite database.

        Args:
            dim: Embedding dimensionality used when a new index is created.
            db_file: Path of the SQLite database file.
            index_file: Path the FAISS index is loaded from and saved to.

        Raises:
            ValueError: If an index already on disk has a dimensionality
                different from ``dim``.
        """
        if os.path.exists(index_file):
            self.index = faiss.read_index(index_file)
            # Fail early with a clear message instead of deep inside FAISS
            # on the first add()/search() call.
            if dim is not None and self.index.d != dim:
                raise ValueError(
                    f"Index {index_file!r} has dimension {self.index.d}, "
                    f"expected {dim}"
                )
        else:
            self.index = faiss.IndexFlatL2(dim)
        self.index_file = index_file

        # SQLite for storing text and metadata
        self.conn = sqlite3.connect(db_file)
        self._ensure_tables()

    def _ensure_tables(self):
        """Create the ``memories`` table if it does not exist yet."""
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS memories (
                id          INTEGER     PRIMARY KEY,
                text        TEXT        NOT NULL,
                is_summary  INTEGER     NOT NULL DEFAULT 0,
                ts          DATETIME    DEFAULT CURRENT_TIMESTAMP)
            """)
        self.conn.commit()

    @staticmethod
    def _as_row(vec):
        """Return ``vec`` as the contiguous ``(1, dim)`` float32 array FAISS expects."""
        return np.ascontiguousarray(vec, dtype="float32").reshape(1, -1)

    def add(self, text, vec, is_summary=False):
        """Add a new memory.

        Args:
            text: The memory's text.
            vec: Its embedding vector.
            is_summary: Whether the text is a conversation summary.
        """
        idx = self.index.ntotal
        self.index.add(self._as_row(vec))
        # The DB is committed on every add but the index only reaches disk in
        # save(). After an unclean exit the table can hold ids the reloaded
        # index no longer has; REPLACE keeps row ``idx`` in step with the
        # vector now stored at that position instead of raising
        # IntegrityError.
        self.conn.execute(
            "INSERT OR REPLACE INTO memories (id, text, is_summary) "
            "VALUES (?, ?, ?)",
            (idx, text, 1 if is_summary else 0),
        )
        self.conn.commit()

    def query(self, vec, top_k=TOP_K):
        """Return up to ``top_k`` most similar memory texts, closest first.

        Args:
            vec: Query embedding vector.
            top_k: Maximum number of memories to return.

        Returns:
            A list of memory texts ordered by increasing distance; empty when
            the index is empty or ``top_k`` is not positive.
        """
        total = self.index.ntotal
        if total == 0 or top_k <= 0:
            return []

        _, neighbours = self.index.search(self._as_row(vec), min(top_k, total))
        # FAISS pads missing results with -1; drop those before hitting SQL.
        ids = [int(i) for i in neighbours[0] if i >= 0]
        if not ids:
            return []

        placeholders = ",".join("?" for _ in ids)
        rows = self.conn.execute(
            f"SELECT id, text FROM memories WHERE id IN ({placeholders})",
            ids,
        ).fetchall()

        # ``IN (...)`` does not preserve the order of its operands, so restore
        # the similarity ranking FAISS produced.
        text_by_id = dict(rows)
        return [text_by_id[i] for i in ids if i in text_by_id]

    def save(self):
        """Write the FAISS index to ``self.index_file``."""
        faiss.write_index(self.index, self.index_file)

    def close(self):
        """Close the SQLite connection (the index is not saved implicitly)."""
        self.conn.close()

def count_tokens(text):
    """Approximate a token count as the number of whitespace-separated words."""
    words = text.split()
    return len(words)

class ChatCLI:
    def __init__(self, client, memory_mgr):
        self.client = client
        self.mem_mgr = memory_mgr
        self.short_term = []
        self.token_buff = 0

    def run(self):
        print("ChatCLI (type 'quit' to exit)\n")
        try:
            while True:
                prompt = input("You: ").strip()
                if not prompt:
                    continue
                if prompt.lower() == "quit":
                    print("Goodbye!")
                    break
                    
                user_vec = self.client.embed(prompt)
                self.mem_mgr.add(prompt, user_vec, is_summary=False)

                # Fetch relevant long-term memories
                retrieved = self.mem_mgr.query(user_vec)

                # Assemble messages for the model
                messages = [{"role": "system", "content": "You are a helpful assistant. Keep your responses concise."}]

                for mem in retrieved:
                    messages.append({"role": "system", "content": f"Memory: {mem}"})

                for role, text in self.short_term:
                    messages.append({"role": role, "content": text})
                messages.append({"role": "user", "content": prompt})

                reply = self.client.chat(messages)
                print("Bot:", reply, "\n")

                # Track short term context and token count
                self.short_term.append(("user", prompt))
                self.short_term.append(('assistant', reply))
                self.token_buff += count_tokens(prompt) + count_tokens(reply)
            
                bot_vec = self.client.embed(reply)
                self.mem_mgr.add(reply, bot_vec, is_summary=False)

                if self.token_buff > SHORT_TERM_TOKENS:
                    summary_prompt = (
                        "Summarize the following conversation in about 200 tokens: \n\n"
                        + "\n".join(f"{r}: {t}" for r, t in self.short_term)
                    )
                    summary = self.client.chat([
                        {"role": "system", "content": "You are an expert summarizer."},
                        {"role": "user", "content": summary_prompt}
                    ])