# In [None]:  (notebook cell marker from the .ipynb export)
from langchain_text_splitters import  MarkdownHeaderTextSplitter
from langchain_core.documents import Document
from typing import List, Tuple
import os
import ollama
import chromadb
import tiktoken

from config import CONFIG

import logging
import os

os.makedirs("logs", exist_ok=True)


def _attach_file_handler(target: logging.Logger, path: str) -> logging.FileHandler:
    """Attach a UTF-8 FileHandler for *path* to *target* once, and return it.

    Re-running this cell/module would otherwise add another handler to the same
    logger every time, duplicating every line in the log file. A FileHandler
    already attached for the same absolute path is reused instead.
    """
    wanted = os.path.abspath(path)
    for existing in target.handlers:
        if isinstance(existing, logging.FileHandler) and existing.baseFilename == wanted:
            return existing
    handler = logging.FileHandler(path, encoding="utf-8")
    handler.setFormatter(logging.Formatter("%(asctime)s | %(message)s"))
    target.addHandler(handler)
    return handler


# General pipeline logger: the root config writes INFO+ records to logs/main.log
# and to the console. basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
    handlers=[
        logging.FileHandler("logs/main.log", encoding="utf-8"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger("main")

# LLM interactions logger: full prompts/responses go only to logs/llm.log.
# Propagated records reach the root handlers regardless of the root level, so
# without propagate=False every (very large) DEBUG prompt would also be written
# to main.log and printed to the console.
llm_logger = logging.getLogger("llm")
llm_handler = _attach_file_handler(llm_logger, "logs/llm.log")
llm_logger.setLevel(logging.DEBUG)
llm_logger.propagate = False

# Code generation / retries logger. It still propagates, so retry and failure
# messages also appear in main.log and on the console, besides logs/codegen.log.
code_logger = logging.getLogger("codegen")
code_handler = _attach_file_handler(code_logger, "logs/codegen.log")
code_logger.setLevel(logging.DEBUG)


# On-disk Chroma client; the storage directory comes from CONFIG["paths"]["chroma_storage"].
client = chromadb.PersistentClient(path=CONFIG["paths"]["chroma_storage"])
# Ensure a "docs" collection exists so the delete below has something to drop.
# NOTE(review): presumably delete_collection raises for a missing collection in the
# installed chromadb version — confirm before simplifying this sequence.
collection = client.get_or_create_collection(name="docs")

# Reset: drop and recreate "docs" so each run starts empty. embednstore() assigns
# ids "0", "1", ... per run, so stale vectors from a previous run must not remain.
# This runs at import time: importing this module wipes the stored collection.
client.delete_collection(name="docs")
collection = client.get_or_create_collection(name="docs")


# Hint files are read once and prepended verbatim to every prompt context by
# build_context(). Read them as UTF-8 explicitly so the result does not depend on
# the platform's default locale encoding.
with open(CONFIG["paths"]["hint_variables"], encoding="utf-8") as f:
    variable_hints_text = f.read()
with open(CONFIG["paths"]["hint_errors"], encoding="utf-8") as f:
    error_hints_text = f.read()

# Header levels passed to MarkdownHeaderTextSplitter in split_docs().
headers_to_split_on = CONFIG["retriever"]["headers_to_split_on"]
# Token budget for question + context, enforced by truncate_context_if_needed().
max_context_tokens = CONFIG["context"]["max_context_tokens"]

# Local tokenizer used only for counting/truncating tokens.
# NOTE(review): cl100k_base is presumably not the Ollama model's own tokenizer,
# so counts are approximate — confirm if the budget must be exact.
encoding = tiktoken.get_encoding("cl100k_base")
EMBED_MODEL = CONFIG["models"]["embedding"]
LLM_MODEL = CONFIG["models"]["llm"]

def load_documents(folder_path: str) -> List[Document]:
    """Load every Markdown (``.md``) file directly inside *folder_path*.

    Entries are visited in sorted name order. ``os.listdir`` returns them in
    arbitrary order, and chunk order determines the integer ids assigned by
    embednstore(), so sorting keeps runs reproducible across platforms.
    Entries that are not regular ``.md`` files (including directories whose
    name ends in ``.md``) are skipped. The scan is not recursive.

    Args:
        folder_path: Directory to scan.

    Returns:
        One Document per file, with ``metadata={"source": <file name>}``.
    """
    documents = []
    for filename in sorted(os.listdir(folder_path)):
        if not filename.endswith('.md'):
            continue
        file_path = os.path.join(folder_path, filename)
        if not os.path.isfile(file_path):
            continue
        # Explicit UTF-8 so decoding does not depend on the platform locale.
        with open(file_path, encoding="utf-8") as f:
            content = f.read()
        documents.append(Document(page_content=content, metadata={"source": filename}))
    return documents


def split_docs(documents: List[Document]) -> List[Document]:
    """Break Markdown documents into header-delimited chunks.

    A document whose ``metadata["source"]`` ends in ``.md`` is split with a
    MarkdownHeaderTextSplitter configured by the module-level
    ``headers_to_split_on``. Each resulting chunk carries the parent's metadata,
    overlaid with whatever metadata the splitter attached to that chunk. Any
    other document is passed through unchanged. Input order is preserved.
    """
    splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
    chunks: List[Document] = []
    for parent in documents:
        if not parent.metadata.get("source", "").endswith(".md"):
            chunks.append(parent)
            continue
        chunks.extend(
            Document(page_content=piece.page_content,
                     metadata={**parent.metadata, **piece.metadata})
            for piece in splitter.split_text(parent.page_content)
        )
    return chunks


def embednstore(splits, collection):
    """Embed each chunk with EMBED_MODEL and add it to *collection*, one at a time.

    Chunks receive string ids "0", "1", ... in iteration order; only the chunk
    text and its vector are stored. NOTE(review): the chunk metadata (source file,
    headers) is not passed to the collection — confirm that is intended.
    Progress (a text preview and an embedding preview per chunk) is printed to stdout.
    """
    for chunk_id, chunk in enumerate(splits):
        text = chunk.page_content
        print(f"Chunk {chunk_id}:", text[:CONFIG["logging"]["preview_chars"]])

        # ollama.embed returns a list of embeddings; one input -> take the first.
        vector = ollama.embed(model=EMBED_MODEL, input=text)["embeddings"][0]
        collection.add(ids=[str(chunk_id)], embeddings=[vector], documents=[text])

        print(f"embedding {chunk_id} length: {len(vector)} | preview: {vector[:5]} \n")


def load_questions(path: str = CONFIG["paths"]["questions_file"]) -> List[str]:
    """Read one question per line from *path*.

    Surrounding whitespace is stripped and blank lines are skipped. The file is
    decoded as UTF-8 explicitly, consistent with the other readers in this
    module, so the result does not depend on the platform locale.
    """
    with open(path, encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]


def embed_question(question: str) -> List[float]:
    """Return the EMBED_MODEL embedding of *question* (the first and only vector Ollama returns)."""
    vectors = ollama.embed(model=EMBED_MODEL, input=question)["embeddings"]
    return vectors[0]


def retrieve_top_chunks(query_embedding, n_results: int = CONFIG["retriever"]["top_k"]) -> List[str]:
    """Return the stored texts of the *n_results* chunks nearest to *query_embedding*.

    Queries the module-level ``collection``.
    """
    hits = collection.query(query_embeddings=[query_embedding], n_results=n_results)
    # Results are grouped per query embedding; exactly one was sent.
    first_query_documents = hits['documents'][0]
    return first_query_documents


def build_context(chunks: List[str]) -> str:
    """Assemble the prompt context.

    The context is the variable hints, then the error hints, then the retrieved
    *chunks*, all separated by blank lines. The separator after the error hints
    is present even when *chunks* is empty.
    """
    retrieved = "\n\n".join(chunks)
    return "\n\n".join((variable_hints_text, error_hints_text, retrieved))


def truncate_context_if_needed(question: str, full_context: str) -> Tuple[str, int]:
    """Trim *full_context* so that question + context fit in ``max_context_tokens``.

    Tokens are counted with the module-level tiktoken ``encoding``. Only the
    context is cut, keeping its beginning. The question is counted against the
    budget but never truncated, and it is never mixed into the returned context:
    the caller passes the question to the prompt separately.

    Args:
        question: The user question; counted, not modified.
        full_context: Context text to trim if the combined size is over budget.

    Returns:
        ``(context, token_count)``: the possibly truncated context alone, and the
        number of question tokens plus context tokens after truncation.
    """
    question_tokens = encoding.encode(question)
    context_tokens = encoding.encode(full_context)
    if len(question_tokens) + len(context_tokens) > max_context_tokens:
        # Tokens left for the context once the question is accounted for; if the
        # question alone exceeds the limit, the context is dropped entirely.
        budget = max(0, max_context_tokens - len(question_tokens))
        context_tokens = context_tokens[:budget]
        full_context = encoding.decode(context_tokens)
    return full_context, len(question_tokens) + len(context_tokens)


def extract_code_from_response(text: str) -> Tuple[bool, str]:
    """Pull the first triple-backtick fenced code block out of an LLM response.

    The segment after the first fence is used, up to the next fence or to the
    end of *text* if the fence is never closed. If that segment spans several
    lines, its first line is the fence's info string (``python``, ``py``,
    ``Python``, or empty) and is dropped. A single-line segment is taken as code.
    Only the info-string line is removed, so code that merely contains the word
    "python" (e.g. ``import pythonic``) is left intact.

    Returns:
        ``(True, code)`` with surrounding whitespace stripped;
        ``(False, "no code block found")`` when *text* has no fence;
        ``(False, "empty code block")`` when the fenced block contains no code.
    """
    segments = text.split("```")
    if len(segments) < 2:
        return False, "no code block found"
    block = segments[1]
    if "\n" in block:
        # Multi-line fence: the opening line holds only the language tag.
        _info, _sep, block = block.partition("\n")
    code = block.strip()
    if not code:
        return False, "empty code block"
    return True, code


def try_exec_generated_code(code: str) -> Tuple[bool, str]:
    """Execute generated Python *code* and report whether it ran without raising.

    The code runs in a fresh namespace whose ``__name__`` is ``"__main__"``, so a
    ``if __name__ == "__main__":`` guard in the snippet still fires. A bare
    ``exec(code)`` inside a function would bind the snippet's top-level names
    (imports, helpers) into a throwaway locals dict. Functions the snippet defines
    would then fail with NameError on those names, and the snippet would share
    this module's globals.

    SECURITY: this still runs untrusted model output in-process with full
    privileges; use it only in an environment you are willing to expose.

    Returns:
        ``(True, "")`` on success, including ``sys.exit()`` / ``sys.exit(0)``;
        otherwise ``(False, "<ExceptionType>: <message>")``. This text is fed
        back to the LLM on retry, so the exception type is included.
    """
    namespace = {"__name__": "__main__"}
    try:
        exec(code, namespace)
    except SystemExit as exc:
        # A generated script calling sys.exit() must not abort the whole batch.
        if exc.code in (None, 0):
            return True, ""
        return False, f"SystemExit: {exc.code}"
    except Exception as exc:
        return False, f"{type(exc).__name__}: {exc}"
    return True, ""

def generate_code_with_retries(question: str, context: str, max_tries: int = CONFIG["generation"]["max_retries"]) -> Tuple[str, bool, str]:
    """Ask LLM_MODEL for code answering *question*, executing it, for up to *max_tries* attempts.

    Each prompt embeds *context*, the previous attempt's code and its error
    message, so the model can repair its last answer. The extracted code is run
    with try_exec_generated_code() (i.e. executed in-process). Prompts and
    responses go to the "llm" logger; attempt outcomes go to the "codegen" logger.

    Returns:
        ``(code, True, "")`` for the first attempt that executes cleanly.
        Otherwise ``(last_code, False, last_error)``, where *last_code* is ""
        if the final attempt yielded no code block.
    """
    last_code = ""
    error_message = ""

    for attempt in range(1, max_tries + 1):
        prompt = f"""You are a helpful assistant with access to CMS hint files containing Python code snippets, 
                variable names, and common error messages with solutions: {context} 
                Use only this data to answer the following question: {question}
                Expected output: Python code snippet only.
                If the answer is not in the data, respond: "I don't know based on the available information."
                Last attempt: {last_code}
                Error message (if any): {error_message}
                Please fix the code if there was an error; otherwise, provide a solution.
                """

        llm_logger.debug(f"\n--- TRY {attempt} ---\nPrompt:\n{prompt}\n")

        sampling_options = {
            "temperature": CONFIG["generation"]["temperature"],
            "top_p": CONFIG["generation"]["top_p"],
            "top_k": CONFIG["generation"]["top_k"],
        }
        response_text = ollama.generate(model=LLM_MODEL, prompt=prompt, options=sampling_options)['response']
        llm_logger.debug(f"Response:\n{response_text}\n")

        found, payload = extract_code_from_response(response_text)
        if not found:
            # No usable code: feed the extraction message back, forget the old code.
            code_logger.warning(f"TRY {attempt}: No code block found")
            last_code, error_message = "", payload
            continue

        last_code = payload
        ok, err = try_exec_generated_code(payload)
        if not ok:
            code_logger.error(f"TRY {attempt}: Execution error: {err}")
            error_message = err
            continue

        code_logger.info(f"TRY {attempt}: SUCCESS\n{payload}")
        return payload, True, ""

    code_logger.error("Max retries reached. No working code found.")
    return last_code, False, error_message

def process_questions_and_run(path: str = CONFIG["paths"]["questions_file"]):
    """For each question in *path*, retrieve context and run the code-generation loop.

    Per question, the steps are:
    1. Embed the question and fetch the top chunks from the collection.
    2. Build and, if needed, truncate the context.
    3. Call generate_code_with_retries().

    Failures are logged to the "codegen" logger. Nothing is returned.
    """
    for question in load_questions(path):
        logger.info(f"=== QUESTION: {question} ===")

        chunks = retrieve_top_chunks(embed_question(question))

        logger.debug("=== TOP CHUNKS ===")
        for idx, chunk_text in enumerate(chunks):
            logger.debug(f"Chunk {idx} preview: {chunk_text[:500]}")

        context, token_count = truncate_context_if_needed(question, build_context(chunks))
        logger.info(f"Final context length: {token_count} tokens")

        _code, success, error_message = generate_code_with_retries(question, context)
        if not success:
            code_logger.error(f"QUESTION FAILED: '{question}' , error: {error_message}")

def main():
    """Run the end-to-end pipeline: index the hint documents, then answer every question.

    1. Load ``.md`` files from CONFIG["paths"]["hint_code"], split them into
       header-based chunks, and embed/store them in the module-level ``collection``.
    2. Run retrieval plus code generation for each question in
       CONFIG["paths"]["questions_file"].
    """
    logger.info("Starting...")
    paths = CONFIG["paths"]

    # Indexing: load -> split -> embed and store.
    documents = load_documents(paths["hint_code"])
    logger.info(f"Loaded {len(documents)} documents")

    splits = split_docs(documents)
    chunk_count = len(splits)
    logger.info(f"Split into {chunk_count} chunks")

    embednstore(splits, collection)
    logger.info(f"Stored {chunk_count} embedded chunks")

    # Answering: retrieve context and generate code per question.
    logger.info("Processing questions...")
    process_questions_and_run(paths["questions_file"])

    logger.info("Finished successfully.")


# Run the whole pipeline only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()
