In [1]:
import gzip
import json
import logging
import os
import random
import sys
import tarfile
import xml.etree.ElementTree as ET
from typing import Any, Dict, List, Tuple

import gradio as gr
import tiktoken
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.documents import Document as LangChainDocument
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_ollama import ChatOllama
from llama_index.core import VectorStoreIndex, Settings, Document, StorageContext, load_index_from_storage
#from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.node_parser import SentenceWindowNodeParser
from llama_index.core.postprocessor import MetadataReplacementPostProcessor
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.llms.ollama import Ollama
from pydantic import Field


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Send log records to stdout at INFO level.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# Raise the threshold of these loggers to WARNING so their INFO records are dropped.
for _noisy_logger in ("httpx", "llama_index.embeddings.ollama"):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)

In [3]:
# Set parameters
# Path to the tar archive holding the XMI documents. The default is the
# original author's local path; set the SCHWEIZ_TAR_PATH environment variable
# to run the notebook on another machine without editing this cell.
TAR_PATH = os.environ.get(
    "SCHWEIZ_TAR_PATH",
    r"C:\Users\ssick\OneDrive\Master_TU\NLP\Schweiz.tar",
)
# Directory the vector index is persisted to and loaded from.
PERSIST_DIR = "./storage_local_llama" 

# Upper bound on the number of documents passed to extract_raw_text.
DOC_LIMIT = 100
# Namespace-qualified tag of the XMI element whose "sofaString" attribute is read.
SOFA_NAMESPACE = "{http:///uima/cas.ecore}Sofa"



In [5]:
# In-memory chat history and the JSON file it is written to.
CHAT_LOG = []
LOG_FILE = "chat_history_llama.json"


# Utility functions
def save_log_to_disk() -> None:
    """Overwrite LOG_FILE with the current contents of CHAT_LOG.

    The log is serialised as UTF-8 JSON with two-space indentation, and
    non-ASCII characters are written as-is rather than escaped.
    """
    with open(LOG_FILE, "w", encoding="utf-8") as log_handle:
        json.dump(CHAT_LOG, log_handle, ensure_ascii=False, indent=2)



# Setup Models
# Global LlamaIndex settings: embedding model and LLM, both served by Ollama.
Settings.embed_model = OllamaEmbedding(
    model_name="nomic-embed-text",
    embed_batch_size=10,
    request_timeout=300.0
)
Settings.llm = Ollama(model="llama3.2:3b", request_timeout=120.0)

# Setup SentenceWindow 
# Each node stores its window under the "window" metadata key and its own
# sentence under "original_text".
Settings.node_parser = SentenceWindowNodeParser.from_defaults(
    window_size=3,
    window_metadata_key="window",
    original_text_metadata_key="original_text",
)

# Best-effort rebuild of the ChatOllama model class. A failure here is not
# fatal, but it is logged instead of being silently discarded so that a later
# problem when constructing ChatOllama can be traced back to it.
try:
    ChatOllama.model_rebuild()
except Exception as rebuild_error:
    logging.warning("ChatOllama.model_rebuild() failed and was skipped: %s", rebuild_error)

# Main Chat LLM 
llm = ChatOllama(model="llama3.2:3b", temperature=0)

# Initialize the tokenizer
# Used only to count tokens of the extracted raw text.
tokenizer = tiktoken.get_encoding("cl100k_base")

# Load Data
def extract_raw_text(tar_path, doc_limit):
    """Extract the text of up to ``doc_limit`` XMI documents from a tar archive.

    Members whose lower-cased name ends in ``.xmi`` or ``.xmi.gz`` are read
    (gzip members are decompressed first) and parsed as XML. The text is taken
    from the ``sofaString`` attribute of the first element matching
    ``SOFA_NAMESPACE``; line breaks are replaced by spaces and the result is
    stripped.

    Args:
        tar_path: Path of the tar archive (any compression ``tarfile`` can
            detect with mode ``"r:*"``).
        doc_limit: Maximum number of documents to return.

    Returns:
        A ``(docs, total_tokens)`` tuple. ``docs`` is a list of
        ``{"id": ..., "text": ...}`` dicts and ``total_tokens`` is the summed
        length of ``tokenizer.encode(text)`` over those documents. ``([], 0)``
        is returned when the archive does not exist or cannot be read.

    Note:
        The id is the member's base name with the substring ``".xmi"`` removed,
        so ``20150914.xmi.gz`` yields ``20150914.gz``. This format is kept
        unchanged on purpose.
    """
    docs = []
    total_tokens = 0  # Running token count over all accepted documents

    if not os.path.exists(tar_path):
        print(f"Error: Tar file not found at {tar_path}")
        return [], 0

    try:
        with tarfile.open(tar_path, "r:*") as tar:
            # Iterate lazily (instead of tar.getmembers()) so that reaching
            # doc_limit stops reading the archive.
            for m in tar:
                # Check the limit before doing any work so doc_limit=0 yields nothing.
                if len(docs) >= doc_limit:
                    break

                name = m.name.lower()
                if not name.endswith((".xmi", ".xmi.gz")):
                    continue

                f = tar.extractfile(m)
                if f is None:  # not a regular file (e.g. a directory entry)
                    continue
                with f:
                    data = f.read()

                # A single corrupt member must not discard everything that
                # was already extracted: report it and move on.
                try:
                    if name.endswith(".gz"):
                        data = gzip.decompress(data)
                    root = ET.fromstring(data)
                except Exception as member_error:
                    print(f"Skipping {m.name}: {member_error}")
                    continue

                sofa = root.find(f".//{SOFA_NAMESPACE}")
                if sofa is None:
                    continue

                text = sofa.get("sofaString")
                if not text:
                    continue

                text = text.replace('\r\n', ' ').replace('\n', ' ').strip()
                if not text:  # only whitespace: nothing to index
                    continue

                total_tokens += len(tokenizer.encode(text))
                docs.append({
                    "id": os.path.basename(m.name).replace(".xmi", ""),
                    "text": text
                })
    except Exception as e:
        # Archive-level failure (unreadable or corrupt tar): keep the original contract.
        print(f"Extraction error: {e}")
        return [], 0

    return docs, total_tokens

raw_data, total_raw_tokens = extract_raw_text(TAR_PATH, DOC_LIMIT)

# One Document per extracted entry; the entry id is used as the document id
# and is also stored under the "id_" metadata key.
documents = [
    Document(text=entry["text"], id_=entry["id"], metadata={"id_": entry["id"]})
    for entry in raw_data
]

document_ids = [document.id_ for document in documents]
print(document_ids)




['20150914.gz', '20060918.gz', '20111220.gz', '20200302.gz', '20050308.gz', '20201217.gz', '20080918.gz', '20050317.gz', '20050927.gz', '20090320.gz', '20140305.gz', '20170502.gz', '20030605.gz', '20030925.gz', '20080312.gz', '20210923.gz', '20190320.gz', '20050920.gz', '20161208.gz', '20161216.gz', '20180308.gz', '20101208.gz', '20000324.gz', '20020311.gz', '20140926.gz', '20210609.gz', '20170302.gz', '20080305.gz', '20030304.gz', '20010620.gz', '20060620.gz', '20210602.gz', '20210316.gz', '20091126.gz', '20040507.gz', '20070612.gz', '20150611.gz', '20010614.gz', '20200618.gz', '20080529.gz', '20021128.gz', '20120615.gz', '20120227.gz', '20140312.gz', '20180911.gz', '20210616.gz', '20181127.gz', '20150603.gz', '20110926.gz', '20100928.gz', '20040923.gz', '20020919.gz', '20121211.gz', '20011002.gz', '20011204.gz', '20211129.gz', '20081203.gz', '20040311.gz', '20130917.gz', '20100923.gz', '20110307.gz', '20100610.gz', '20200311.gz', '20020307.gz', '20170911.gz', '20200610.gz', '20191218

In [6]:
# Indexing
# Reuse the persisted index when PERSIST_DIR exists; otherwise build it from
# `documents` and persist it for later runs.
if os.path.exists(PERSIST_DIR):
    print(f"Load local index from {PERSIST_DIR}")
    storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
    index = load_index_from_storage(storage_context)
else:
    # Extraction returns an empty list on failure. Persisting an index built
    # from nothing would make every later run load that empty index, so stop here.
    if not documents:
        raise RuntimeError(
            f"No documents were extracted from {TAR_PATH}; refusing to build and persist an empty index."
        )

    # Create one node per sentence and store 3 before/after neighbors in metadata
    print("Creating new Sentence Window index")
    node_parser = SentenceWindowNodeParser.from_defaults(
        window_size=3,
        window_metadata_key="window",
        original_text_metadata_key="original_text",
    )
    
    nodes = node_parser.get_nodes_from_documents(documents)
    
    # Build the index 
    index = VectorStoreIndex(nodes, show_progress=True)
    index.storage_context.persist(persist_dir=PERSIST_DIR)



Load local index from ./storage_local_llama
INFO:llama_index.core.indices.loading:Loading all indices.


In [None]:
class LlamaIndexToLangChainRetriever(BaseRetriever):
    """Adapter that exposes a LlamaIndex retriever as a LangChain retriever.

    A query is sent to ``llama_retriever``, the hits are passed through
    ``postprocessor`` and each resulting node is converted into a LangChain
    document carrying the node's content and metadata.
    """

    # Object providing ``retrieve(query)``; excluded from serialisation.
    llama_retriever: Any = Field(exclude=True)
    # Object providing ``postprocess_nodes(nodes)``; excluded from serialisation.
    postprocessor: Any = Field(exclude=True)
    
    def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[LangChainDocument]:
        """Return the post-processed hits for ``query`` as LangChain documents.

        Args:
            query: The search string handed to the LlamaIndex retriever.
            run_manager: Accepted for interface compatibility; not used.

        Returns:
            One ``LangChainDocument`` per post-processed node, with
            ``page_content`` from ``get_content()`` and the node's metadata.
        """
        # Retrieve sentence nodes. The results already expose `.node` and
        # `.score`, so they are handed to the postprocessor directly instead of
        # being re-wrapped in new score objects.
        scored_nodes = self.llama_retriever.retrieve(query)

        # Swap the sentence for the bigger context
        processed_nodes = self.postprocessor.postprocess_nodes(scored_nodes)

        return [
            LangChainDocument(page_content=hit.get_content(), metadata=hit.metadata)
            for hit in processed_nodes
        ]

# Replace each retrieved node's content with the value stored under its
# "window" metadata key.
postprocessor = MetadataReplacementPostProcessor(target_metadata_key="window")

# Fetch the 5 most similar nodes per query and adapt the retriever to LangChain.
raw_retriever = index.as_retriever(similarity_top_k=5)
llama_retriever = LlamaIndexToLangChainRetriever(
    llama_retriever=raw_retriever, 
    postprocessor=postprocessor
)


# Create prompt
custom_template = """
Du bist ein Assistent für Schweizer Parlamentsprotokolle.
Deine Aufgabe ist es, Fragen basierend auf den bereitgestellten Textauszügen objektiv und faktenbasiert zu beantworten.

Regeln:
1. Nutze ausschließlich den bereitgestellten Kontext. Wenn die Information nicht enthalten ist, antworte: "Information nicht im Dokument enthalten."
2. Zitiere: Füge hinter jeder Faktenbehauptung die Source-ID (z.B. [ID: 20050927.gz]) ein.
3. Zahlen: Extrahiere Zahlenwerte mit hoher Genauigkeit.
4. Struktur: Nutze Bullet-Points für Aufzählungen. Halte dich kurz, aber verliere keine Details.

Prozess:
Gehe Schritt für Schritt vor:
- Scanne den Kontext nach relevanten Namen, Daten und Zahlen.
- Vergleiche die Informationen, falls die Frage danach verlangt.
- Erstelle die finale Antwort.

KONTEXT:
{context}

FRAGE: {input}

ANTWORT:
"""

PROMPT = ChatPromptTemplate.from_template(custom_template)

# Rule 2 of the prompt asks the model to cite source IDs, so every snippet is
# rendered together with its "id_" metadata value. Without an explicit
# document prompt only the page content would reach the model.
# NOTE(review): this requires every retrieved document to carry "id_" in its
# metadata, as the documents built in this notebook do.
DOCUMENT_PROMPT = PromptTemplate.from_template("[ID: {id_}]\n{page_content}")

document_chain = create_stuff_documents_chain(
    llm=llm,
    prompt=PROMPT,
    document_prompt=DOCUMENT_PROMPT,
)
qa_chain = create_retrieval_chain(retriever=llama_retriever, combine_docs_chain=document_chain)


In [None]:
def load_questions(input_path: str) -> List[Dict[str, Any]]:
    """Read and validate the question list stored in a JSON file.

    Args:
        input_path: Path of a UTF-8 JSON file containing an array of objects.

    Returns:
        The parsed list, unchanged.

    Raises:
        FileNotFoundError: ``input_path`` is not an existing file.
        ValueError: The top-level value is not a list, an item is not an
            object, or an item's ``question`` is missing, not a string, or
            blank.
    """
    if not os.path.isfile(input_path):
        raise FileNotFoundError(f"Input file not found: {input_path}")

    with open(input_path, "r", encoding="utf-8") as handle:
        payload = json.load(handle)

    if not isinstance(payload, list):
        raise ValueError("Input JSON must be a list of objects (array).")

    for i, entry in enumerate(payload):
        if not isinstance(entry, dict):
            raise ValueError(f"Item at index {i} is not an object.")
        if "question" not in entry:
            raise ValueError(f"Item at index {i} has no 'question' field.")

        question = entry["question"]
        has_text = isinstance(question, str) and bool(question.strip())
        if not has_text:
            raise ValueError(f"Item at index {i} has empty/invalid 'question' field.")

    return payload


def save_output(output_path: str, data: List[Dict[str, Any]]) -> None:
    """Write ``data`` to ``output_path`` as indented UTF-8 JSON.

    The parent directory of ``output_path`` is created first if it does not
    exist. Non-ASCII characters are written unescaped.

    Args:
        output_path: Destination file; overwritten if present.
        data: The list of result objects to serialise.
    """
    parent_dir = os.path.dirname(os.path.abspath(output_path))
    os.makedirs(parent_dir, exist_ok=True)

    with open(output_path, "w", encoding="utf-8") as out_handle:
        json.dump(data, out_handle, ensure_ascii=False, indent=2)


def ask_rag_only_question(qa_chain, question_text: str, snippet_len: int = 300):
    """Ask one question through the chain and collect the answer with its sources.

    Args:
        qa_chain: Object whose ``invoke({"input": ...})`` returns a mapping
            with an ``"answer"`` entry and the retrieved documents under
            ``"context"`` (or ``"documents"`` when ``"context"`` is absent or None).
        question_text: The question to ask.
        snippet_len: Maximum number of characters taken from each document's
            ``page_content`` for its snippet.

    Returns:
        ``(answer_text, source_ids, context_snippets)`` where ``answer_text``
        is the stripped answer string, ``source_ids`` lists the distinct
        document ids in first-seen order, and ``context_snippets`` holds one
        ``{"source_id", "snippet"}`` dict per document with a usable id.
        Documents without a non-empty string ``"id_"`` in their metadata are
        ignored.
    """
    result = qa_chain.invoke({"input": question_text})

    raw_answer = result.get("answer", "")
    answer_text = (raw_answer if isinstance(raw_answer, str) else str(raw_answer)).strip()

    retrieved = result.get("context", None)
    if retrieved is None:
        retrieved = result.get("documents", [])

    context_snippets = []
    for doc in retrieved:
        # Any failure while reading the metadata counts as "no id".
        try:
            doc_id = doc.metadata.get("id_")
        except Exception:
            doc_id = None

        content = getattr(doc, "page_content", "") or ""

        if not (isinstance(doc_id, str) and doc_id):
            continue

        context_snippets.append({
            "source_id": doc_id,
            "snippet": content[:snippet_len].replace("\n", " ").strip(),
        })

    # dict.fromkeys removes duplicates while keeping first-seen order.
    source_ids = list(dict.fromkeys(entry["source_id"] for entry in context_snippets))
    return answer_text, source_ids, context_snippets


def process_questions_file(qa_chain, input_path: str, output_path: str) -> None:
    """Answer every question of an input file and write the enriched items out.

    Each item loaded from ``input_path`` receives three new keys:
    ``"answer"``, ``"answer_source_id"`` and ``"answer_context_snippets"``.
    If answering a question raises, the answer becomes ``"Error: <message>"``
    and both lists stay empty, so one failure does not stop the run. The
    results are written once, after all questions have been processed.

    Args:
        qa_chain: Chain object forwarded to ``ask_rag_only_question``.
        input_path: JSON file with the questions.
        output_path: JSON file the enriched items are saved to.
    """
    items = load_questions(input_path)
    total = len(items)

    for position, item in enumerate(items, start=1):
        question_text = item["question"].strip()

        try:
            answer_text, answer_source_ids, context_snippets = ask_rag_only_question(
                qa_chain, question_text
            )
        except Exception as e:
            answer_text, answer_source_ids, context_snippets = f"Error: {str(e)}", [], []

        item["answer"] = answer_text
        item["answer_source_id"] = answer_source_ids
        item["answer_context_snippets"] = context_snippets

        print(f"[{position}/{total}] id={item.get('id', 'NA')} done")

    save_output(output_path, items)
    print(f"Saved results to {output_path}")

# Question file to read and result file to write.
input_questions_path = "data/questions_20.json"
output_answers_path = "data/ollama_local_20.json"

# Answer all questions with the RAG chain and save the results.
process_questions_file(
    qa_chain,
    input_path=input_questions_path,
    output_path=output_answers_path,
)