In [None]:
from langchain_community.document_loaders import Docx2txtLoader
from langchain_community.document_loaders import UnstructuredMarkdownLoader

from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_text_splitters import MarkdownHeaderTextSplitter

from typing import List
from langchain_core.documents import Document
import os
import ollama
import chromadb

# Persistent Chroma client; its data is kept under the "chroma_storage" path.
client = chromadb.PersistentClient(path="chroma_storage")
# Fetch the "docs" collection, creating it when it does not exist yet.
collection = client.get_or_create_collection(name="docs")

# Directory that holds the hint files to index.
folder_path = "/home/ruta/irishep/hint_files"

In [None]:
def load_documents(folder_path: str) -> List[Document]:
    """Read every Markdown (``.md``) file in ``folder_path`` into a Document.

    Args:
        folder_path: Directory to scan. Only its direct entries are looked at;
            sub-directories are not descended into.

    Returns:
        One Document per ``.md`` file, in sorted filename order. The whole
        file text is the ``page_content`` and the bare filename is stored
        under the ``"source"`` metadata key.

    Raises:
        OSError: If the folder cannot be listed or a file cannot be read.
        UnicodeDecodeError: If a file is not valid UTF-8.
    """
    documents = []
    # os.listdir() returns entries in arbitrary order; sorting keeps the
    # document order (and any index-based ids derived from it) reproducible.
    for filename in sorted(os.listdir(folder_path)):
        file_path = os.path.join(folder_path, filename)
        # Skip non-markdown entries and directories that merely end in ".md".
        # Handle other file types here if needed.
        if not filename.endswith('.md') or not os.path.isfile(file_path):
            continue
        # Decode explicitly as UTF-8 instead of relying on the locale default.
        with open(file_path, encoding="utf-8") as f:
            content = f.read()
        documents.append(Document(page_content=content, metadata={"source": filename}))
    return documents
# NOTE: the triple-quoted string below is a disabled, loader-based variant of
# load_documents. It is a bare string expression, so none of it is executed.
# It refers to PyPDFLoader, which is not imported in the visible cells.
'''
def load_documents(folder_path: str) -> List[Document]:
    documents = []
    for filename in os.listdir(folder_path):
        file_path = os.path.join(folder_path, filename)
        if filename.endswith('.pdf'):
            loader = PyPDFLoader(file_path)
        elif filename.endswith('.docx'):
            loader = Docx2txtLoader(file_path)
        elif filename.endswith('.md'):
            loader = UnstructuredMarkdownLoader(file_path)
        else:
            print(f"unsupported file type: {filename}")
            continue
        documents.extend(loader.load())
        
    return documents
'''

# (header prefix, label) pairs passed to MarkdownHeaderTextSplitter in split_docs.
headers_to_split_on = [
    # ("#", "Header 1") is left out on purpose: the hint files also contain
    # "#" comment lines, so their real headers start at "##".
    ("##", "Header 2"),
    ("###", "Header 3"),
    ("####", "Header 4")
]

def split_docs(documents: List[Document]) -> List[Document]:
    """Break markdown documents into header-delimited chunks.

    A document whose ``"source"`` metadata ends in ``.md`` is split with
    MarkdownHeaderTextSplitter (configured by ``headers_to_split_on``); each
    resulting chunk carries the parent's metadata merged with the chunk's own
    metadata, the chunk's values winning on key clashes. Any other document
    is passed through unchanged.

    Args:
        documents: Documents to split.

    Returns:
        The chunks and passed-through documents, in input order.
    """
    header_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
    pieces = []
    for parent in documents:
        # Guard clause: anything that is not markdown is kept as-is.
        if not parent.metadata.get("source", "").endswith(".md"):
            pieces.append(parent)
            continue
        pieces.extend(
            Document(
                page_content=section.page_content,
                metadata={**parent.metadata, **section.metadata},
            )
            for section in header_splitter.split_text(parent.page_content)
        )
    return pieces


def embednstore(splits, collection):
    """Embed every chunk with Ollama and store it in the given collection.

    Args:
        splits: Sequence of chunks; only their ``page_content`` is read.
        collection: Chroma collection (or any object with a compatible
            ``upsert(ids=..., embeddings=..., documents=...)`` method).

    Returns:
        None. A preview of each chunk and of its embedding is printed.

    Each record's id is the chunk's position in ``splits`` (as a string).
    Chunks whose text is empty or whitespace-only are skipped; their index is
    still consumed so the ids of the remaining chunks do not shift.
    """
    for i, doc in enumerate(splits):
        text = doc.page_content
        print(f"Chunk {i}:", text[:200])  #first 200 chars

        if not text.strip():
            # Nothing meaningful to embed; do not send it to the model.
            print(f"skipping chunk {i}: empty text \n")
            continue

        response = ollama.embed(model="mxbai-embed-large", input=text)
        embedding = response["embeddings"][0]

        # upsert rather than add: the ids are plain indices, so on a re-run
        # against the persistent collection they already exist. add() never
        # overwrites a record whose id is present, which would leave stale
        # text/embeddings behind after the hint files change.
        collection.upsert(
            ids=[str(i)],
            embeddings=[embedding],
            documents=[text]
        )
        print(f"embedding {i} length: {len(embedding)} | preview: {embedding[:5]} \n")
    

    
# Build the index: load the hint files, split them on headers, then embed and
# store every chunk. The folder comes from the `folder_path` setting defined
# in the setup cell instead of repeating the literal path here.
documents = load_documents(folder_path)
print(f"loaded {len(documents)} documents from the folder \n")

splits = split_docs(documents)
print(f"split the documents into {len(splits)} chunks \n")

embednstore(splits, collection)
print(f"stored {len(splits)} embedded chunks \n")


In [None]:
# Prompts to run through the retrieve -> generate -> execute loop.
# Entries that are commented out are currently disabled.
questions = [
    #"Q1: Plot the missing transverse energy of all events in XYZ dataset (or file)",
    #"Q2: Plot the transverse momentum of all jets of all events in XYZ dataset",
    "Q3: Plot the transverse momentum of all jets with |η| < 1 in the first 10000 events of XYZ dataset",
    #"Q4: Plot the transverse momentum of all jets with |η| < 1 in the first 10000 events of datasets XYZ and ABC and overlay the results",
    "Q5: Plot the missing transverse energy of events that have at least two jets with pT > 40 GeV of XYZ dataset",
    "Q6: Plot the missing transverse energy of events that have an opposite-charge muon pair with an invariant mass between 60 and 120 GeV in XYZ dataset"
]

# Loop-invariant settings, defined once instead of on every iteration.
max_context_chars = 24000  # cap on retrieved text pasted into the prompt
max_tries = 5              # generate -> execute attempts per question

_FENCE = "```"
_PYTHON_TAGS = ("python", "python3", "py")


def _extract_code_block(response):
    """Return the stripped contents of the first ``` fenced block, or None.

    Only the first fenced block is used, so prose or further blocks after it
    never end up in the code. A leading language tag line (python / python3 /
    py, any case) is dropped, as is a "python" tag glued directly to the
    opening fence in the single-line form.
    """
    start = response.find(_FENCE)
    if start == -1:
        return None
    start += len(_FENCE)
    end = response.find(_FENCE, start)
    if end == -1:
        return None  # opening fence without a closing one

    body = response[start:end]
    first_line, newline, rest = body.partition("\n")
    if newline and first_line.strip().lower() in _PYTHON_TAGS:
        body = rest
    elif body.startswith("python"):
        # Single-line form such as: ```python print(1) ```
        body = body[len("python"):]
    return body.strip()


# The loop variable is `question`, not `input`, so the builtin input() is not
# shadowed for the rest of the notebook.
for question in questions:
    print(f"\n=== QUESTION: {question} ===\n")
    resp = ollama.embed(model="mxbai-embed-large", input=question)
    query_embedding = resp["embeddings"][0]

    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=5
    )

    # Limit context window: append retrieved chunks, in the order returned,
    # until the next one would exceed the character budget.
    chunks = results['documents'][0]
    data = ""
    for chunk in chunks:
        if len(data) + len(chunk) > max_context_chars:
            break
        data += "\n\n" + chunk

    success = False
    error_message = ""
    last_code = ""

    for trial in range(max_tries):
        prompt = f"""You are a helpful assistant with access to these CMS specific hint files with python code snippets: {data}
Only use the above data to answer the following question, without hallucinating or making up your own statements: {question}
The expected output is a python code snippet that only contains code between triple backticks like this: ``` [code] ```
If the answer is not in the provided data, say "I don't know based on the available information."
If you get an error, here is the error message: {error_message}
If you tried code previously, here is the last attempt:
{last_code}
Please fix the code if there was an error, otherwise try again.
"""
        output = ollama.generate(
            model="llama3",
            prompt=prompt,
        )
        print(f"(TRY {trial+1}):\n", output['response'])

        code = _extract_code_block(output['response'])
        if not code:
            # Covers both a missing fence and an empty block; an empty block
            # must not be reported as a successful run.
            print("no code block found in response")
            error_message = "no code block found in response"
            last_code = ""
            continue

        last_code = code
        try:
            # SECURITY: this executes model-generated code with full
            # privileges; only run it in an environment you can afford to
            # break. A fresh namespace keeps the generated code from
            # overwriting this cell's own variables (data, trial, success, ...)
            # between retries. __name__ is set so "if __name__ == '__main__'"
            # guards in the generated code still run.
            exec(code, {"__name__": "__main__"})
        except Exception as e:
            # Include the exception type: str(e) alone is often ambiguous
            # (e.g. a KeyError only shows the key).
            error_message = f"{type(e).__name__}: {e}"
            print(f"Error running code: {error_message}")
        else:
            print("SUCCESS")
            success = True
            break
    if not success:
        print("no valid code snippet ran without errors after max trials")
        