# RAG System with TF-IDF (BM25 Retrieval)

## Imports

In [1]:
import os
import json
import logging
import sys
import gzip
import os
import tarfile
import xml.etree.ElementTree as ET
import re
import random

from llama_index.core import (
    VectorStoreIndex,
    Settings,
    Document,
    StorageContext,
    load_index_from_storage
)
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.node_parser import SentenceSplitter

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_classic.chains import create_retrieval_chain
from langchain_core.retrievers import BaseRetriever
from langchain_core.documents import Document as LangChainDocument
from langchain_core.prompts import PromptTemplate
from typing import List, Any
import gradio as gr

from langchain_core.prompts import ChatPromptTemplate
from langchain_classic.chains.combine_documents import create_stuff_documents_chain
from langchain_classic.chains import create_retrieval_chain

from langchain_openai import ChatOpenAI
from typing import Any, Dict, List, Tuple

from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.node_parser import SentenceSplitter

from pathlib import Path

resource module not available on Windows


## Configs

In [2]:
# --- 1. CONFIGURATION ---
# Send log records of level INFO and above to stdout so they appear in the
# notebook output. basicConfig() already attaches a stdout StreamHandler to
# the root logger, so no second handler is added here: an additional one
# would print every record twice and pile up on each re-run of the cell.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# Install langchain and its dependencies if not already installed
!pip install -qU langchain langchain-community

## Paths and APIs settings

In [3]:
# Set parameters
# Local path of the tar archive that holds the XMI files
TAR_PATH = r"Schweiz.tar"

# API key handed to the chat model. It is read from the environment so the
# secret never has to be written into the notebook; when OPENAI_API_KEY is
# not set this falls back to an empty string.
api_key = os.environ.get("OPENAI_API_KEY", "")

# Directory in which the BM25 index is persisted between runs
PERSIST_DIR = "./storage_tfdf"
# Project root: the parent of the current working directory
ROOT_DIR = Path.cwd().parent
# Input questions and output answers live in <ROOT_DIR>/data
input_questions_path = ROOT_DIR / "data" / "questions_20.json"
output_answers_path = ROOT_DIR / "data" / "tfdf_20.json"

# Maximum number of documents extracted from the archive
DOC_LIMIT = 25
# Namespace-qualified tag of the XMI element whose sofaString holds the text
SOFA_NAMESPACE = "{http:///uima/cas.ecore}Sofa"

## Text extraction and model downloading

In [4]:
# In-memory record of Q&A interactions and the file it is written to
CHAT_LOG = []
LOG_FILE = "chat_history_tfdf.json"


def save_log_to_disk():
    """Write every entry of CHAT_LOG to LOG_FILE as indented UTF-8 JSON."""
    entry_count = len(CHAT_LOG)
    with open(LOG_FILE, mode="w", encoding="utf-8") as log_handle:
        json.dump(CHAT_LOG, log_handle, ensure_ascii=False, indent=2)
    print(f"Saved {entry_count} interactions to {LOG_FILE}")


# Register a SentenceSplitter (chunk_size=1024, chunk_overlap=100) as the
# global node parser in the LlamaIndex Settings. No model is configured here.
Settings.node_parser = SentenceSplitter(chunk_size=1024, chunk_overlap=100)


# Load Data
def extract_raw_text(tar_path, doc_limit, sofa_tag=None):
    """
    Extract the raw document text from the XMI files inside a tar archive.

    Every archive member whose lower-cased name ends in ``.xmi`` or
    ``.xmi.gz`` is read (and gunzipped when the name ends in ``.gz``), parsed
    with ElementTree and searched for the first Sofa element; the value of
    its ``sofaString`` attribute is the document text. Line breaks in the
    text are replaced with spaces and surrounding whitespace is stripped.
    Members that cannot be decompressed or parsed are reported and skipped,
    so a single damaged file does not discard the rest of the archive.

    Args:
        tar_path: Path of the tar archive (opened with mode ``"r:*"``).
        doc_limit: Maximum number of documents to return. ``None`` means no
            limit; zero or a negative number yields an empty list.
        sofa_tag: Namespace-qualified tag of the element carrying the text.
            Defaults to the module-level ``SOFA_NAMESPACE``.

    Returns:
        A list of ``{"id": ..., "text": ...}`` dicts in archive order. The id
        is the member's base name with ``.xmi`` removed, so gzipped members
        keep their ``.gz`` suffix. An empty list is returned when the archive
        itself cannot be opened or read.
    """
    import zlib  # local import: only needed to recognise corrupt gzip data

    if sofa_tag is None:
        sofa_tag = SOFA_NAMESPACE

    if doc_limit is not None and doc_limit <= 0:
        return []

    docs = []

    try:
        with tarfile.open(tar_path, "r:*") as tar:
            # Iterating the TarFile reads member headers one at a time, so
            # the scan can stop at the limit without listing the whole
            # archive first.
            for member in tar:
                lowered = member.name.lower()

                # Filter for XMI files (".xmi.gz" also covers ".xmi.xmi.gz")
                if not lowered.endswith((".xmi", ".xmi.gz")):
                    continue

                handle = tar.extractfile(member)
                if handle is None:
                    # Member has no file content to read
                    continue
                with handle:
                    payload = handle.read()

                try:
                    if lowered.endswith(".gz"):
                        payload = gzip.decompress(payload)
                    root = ET.fromstring(payload)
                except (ET.ParseError, OSError, EOFError, zlib.error) as e:
                    # Damaged member: report it and keep the other documents
                    print(f"Skipping unreadable member {member.name}: {e}")
                    continue

                # Find the Sofa element that carries the text
                sofa = root.find(f".//{sofa_tag}")
                text = sofa.get("sofaString") if sofa is not None else None
                if not text:
                    continue

                # Clean up line breaks before further processing
                text = text.replace('\r\n', ' ').replace('\n', ' ').strip()
                docs.append({
                    "id": os.path.basename(member.name).replace(".xmi", ""),
                    "text": text
                })

                # Stop after reaching the defined limit
                if doc_limit is not None and len(docs) >= doc_limit:
                    print(f"Reached document limit of {doc_limit}")
                    break

    except tarfile.TarError as e:
        print(f"Error reading tar file: {e}")
        return []
    except Exception as e:
        # Top-level boundary: e.g. a missing archive file ends up here
        print(f"An unexpected error occurred during extraction: {e}")
        return []

    print(f"Successfully extracted raw text from {len(docs)} documents.")
    return docs

# Pull the raw texts out of the archive; with no text there is nothing to index.
raw_data = extract_raw_text(TAR_PATH, DOC_LIMIT)
if not raw_data:
    print("ERROR: No documents were found! Check your TAR_PATH and XML Namespace.")
    sys.exit(1)


print("Converting dictionaries to Documents")
# Wrap every extracted record in a Document; the id is stored both as the
# document id and in the metadata under the key "id_".
documents = [
    Document(
        text=record["text"],
        id_=record["id"],
        metadata={"id_": record["id"]},
    )
    for record in raw_data
]

document_ids = [document.id_ for document in documents]
print(document_ids)


Reached document limit of 25
Successfully extracted raw text from 25 documents.
Converting dictionaries to Documents
['20150914.gz', '20060918.gz', '20111220.gz', '20200302.gz', '20050308.gz', '20201217.gz', '20080918.gz', '20050317.gz', '20050927.gz', '20090320.gz', '20140305.gz', '20170502.gz', '20030605.gz', '20030925.gz', '20080312.gz', '20210923.gz', '20190320.gz', '20050920.gz', '20161208.gz', '20161216.gz', '20180308.gz', '20101208.gz', '20000324.gz', '20020311.gz', '20140926.gz']


## Indexing

In [5]:
# Indexing
# Split the documents into nodes for the BM25 retriever.
splitter = SentenceSplitter(chunk_size=1024, chunk_overlap=100)
nodes = splitter.get_nodes_from_documents(documents)

# A persisted index counts as present only when the directory exists and is
# not empty.
# NOTE(review): an existing index is loaded as-is, without using `nodes`, so
# it is not refreshed when the documents above change.
has_persisted_index = os.path.exists(PERSIST_DIR) and bool(os.listdir(PERSIST_DIR))

if has_persisted_index:
    print("Loading existing BM25 index from storage")
    bm25_retriever = BM25Retriever.from_persist_dir(PERSIST_DIR)
else:
    print("Creating new BM25 index.")
    # Build the retriever from scratch, then save it for later runs
    bm25_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=5)
    os.makedirs(PERSIST_DIR, exist_ok=True)
    bm25_retriever.persist(PERSIST_DIR)

# Adapter Class
class LlamaIndexToLangChainRetriever(BaseRetriever):
    """Adapter that exposes a LlamaIndex retriever as a LangChain retriever."""

    # Wrapped retriever; only its ``retrieve(query)`` method is called.
    llama_retriever: Any

    def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[LangChainDocument]:
        """Run the wrapped retriever and convert each hit to a LangChain document."""
        hits = self.llama_retriever.retrieve(query)
        return [
            LangChainDocument(page_content=hit.get_content(), metadata=hit.metadata)
            for hit in hits
        ]

Creating new BM25 index.


Finding newlines for mmindex:   0%|          | 0.00/4.90M [00:00<?, ?B/s]

## RAG Setup

In [6]:
# Connect to Langchain
# Wrap the BM25 retriever in the adapter class so it can be used as the
# retriever of a LangChain retrieval chain.
llama_retriever = LlamaIndexToLangChainRetriever(llama_retriever=bm25_retriever)

# Chat model that writes the answers: temperature 0, replies limited to
# max_tokens=200, authenticated with the api_key set in the settings cell.
llm = ChatOpenAI(
    model="gpt-4.1-mini",
    api_key=api_key,
    temperature=0,
    max_tokens=200
)


# Create Prompt
# German system prompt for answering questions about Swiss parliamentary
# protocols. It has two placeholders: {context} (the retrieved text excerpts)
# and {input} (the user's question). Rule 3 is labelled "Zahlen" (numbers);
# the label was previously garbled.
custom_template = """
    Du bist ein Assistent für Schweizer Parlamentsprotokolle.
    Deine Aufgabe ist es, Fragen basierend auf den bereitgestellten Textauszügen objektiv und faktenbasiert zu beantworten.
    
    Regeln:
    1. Nutze ausschließlich den bereitgestellten Kontext. Wenn die Information nicht enthalten ist, antworte: "Information nicht im Dokument enthalten."
    2. Zitiere: Füge hinter jeder Faktenbehauptung die Source-ID (z.B. [ID: 20050927.gz]) ein.
    3. Zahlen: Extrahiere Zahlenwerte mit hoher Genauigkeit.
    4. Struktur: Nutze Bullet-Points für Aufzählungen. Halte dich kurz, aber verliere keine Details.
    
    Prozess:
    Gehe Schritt für Schritt vor:
    - Scanne den Kontext nach relevanten Namen, Daten und Zahlen.
    - Vergleiche die Informationen, falls die Frage danach verlangt.
    - Erstelle die finale Antwort.
    
    KONTEXT:
    {context}
    
    FRAGE: {input}
    
    ANTWORT:
    """

# Turn the template text into a chat prompt.
PROMPT = ChatPromptTemplate.from_template(custom_template)

# Chain that hands the retrieved documents to the LLM together with PROMPT.
document_chain = create_stuff_documents_chain(
    llm=llm,
    prompt=PROMPT
)

# Full RAG chain: the retriever fetches documents for the question and the
# document chain produces the answer. Its result is read further down via
# the keys "answer" and "context".
qa_chain = create_retrieval_chain(
    retriever=llama_retriever,
    combine_docs_chain=document_chain
)

## Questions pipeline

In [7]:
def load_questions(input_path: str) -> List[Dict[str, Any]]:
    """Read the questions JSON file and validate its structure.

    The file must hold a JSON array of objects, each with a non-blank string
    under the key "question". The parsed list is returned unchanged.

    Raises:
        FileNotFoundError: if input_path is not an existing file.
        ValueError: if the top-level value is not a list, an item is not an
            object, or an item has no usable "question" value.
    """
    if not os.path.isfile(input_path):
        raise FileNotFoundError(f"Input file not found: {input_path}")

    with open(input_path, "r", encoding="utf-8") as handle:
        questions = json.loads(handle.read())

    if not isinstance(questions, list):
        raise ValueError("Input JSON must be a list of objects (array).")

    for position, record in enumerate(questions):
        if not isinstance(record, dict):
            raise ValueError(f"Item at index {position} is not an object.")
        if "question" not in record:
            raise ValueError(f"Item at index {position} has no 'question' field.")

        question = record["question"]
        is_usable = isinstance(question, str) and bool(question.strip())
        if not is_usable:
            raise ValueError(f"Item at index {position} has empty/invalid 'question' field.")

    return questions


def save_output(output_path: str, data: List[Dict[str, Any]]) -> None:
    """Write data to output_path as indented UTF-8 JSON.

    The parent directory of output_path is created first if it is missing.
    """
    target_dir, _ = os.path.split(os.path.abspath(output_path))
    os.makedirs(target_dir, exist_ok=True)

    with open(output_path, mode="w", encoding="utf-8") as sink:
        json.dump(data, sink, ensure_ascii=False, indent=2)


def ask_rag_only_question(qa_chain, question_text: str) -> Tuple[str, List[str]]:
    """Ask the chain one question and return the answer with its source ids.

    The chain is invoked with ``{"input": question_text}``. The "answer"
    value of the result is converted to a string if necessary and stripped.
    The source ids are the non-empty string values found under
    ``metadata["id_"]`` of the items in the result's "context", de-duplicated
    in order of first appearance.
    """
    result = qa_chain.invoke({"input": question_text})

    raw_answer = result.get("answer", "")
    answer = (raw_answer if isinstance(raw_answer, str) else str(raw_answer)).strip()

    # A dict keeps insertion order, so it doubles as an ordered set of ids.
    ordered_ids: Dict[str, None] = {}
    for retrieved in result.get("context", []):
        try:
            candidate = retrieved.metadata.get("id_")
        except Exception:
            # Best effort: an item without usable metadata contributes no id.
            candidate = None
        if isinstance(candidate, str) and candidate:
            ordered_ids.setdefault(candidate, None)

    return answer, list(ordered_ids)


def process_questions_file(qa_chain, input_path: str, output_path: str) -> None:
    """Answer every question of the input file and save the enriched items.

    Each item gets two new keys: "answer" (the answer text) and
    "answer_source_id" (the list of source ids). If answering a question
    raises, the answer becomes "Error: <message>" with an empty source list
    and processing continues with the next item. Progress is printed per
    item and the complete list is written to output_path at the end.
    """
    questions = load_questions(input_path)
    total = len(questions)

    for position, entry in enumerate(questions, start=1):
        stripped_question = entry["question"].strip()

        try:
            answer, sources = ask_rag_only_question(qa_chain, stripped_question)
        except Exception as e:
            # Keep going: record the failure in place of the answer.
            answer, sources = f"Error: {str(e)}", []

        entry["answer"] = answer
        entry["answer_source_id"] = sources

        print(f"[{position}/{total}] id={entry.get('id', 'NA')} done")

    save_output(output_path, questions)
    print(f"Saved results to {output_path}")



# Run the pipeline: answer the questions from input_questions_path with the
# RAG chain and write the results to output_answers_path.
process_questions_file(qa_chain, input_questions_path, output_answers_path)

[1/20] id=1 done
[2/20] id=2 done
[3/20] id=3 done
[4/20] id=4 done
[5/20] id=5 done
[6/20] id=6 done
[7/20] id=7 done
[8/20] id=8 done
[9/20] id=9 done
[10/20] id=10 done
[11/20] id=11 done
[12/20] id=12 done
[13/20] id=13 done
[14/20] id=14 done
[15/20] id=15 done
[16/20] id=16 done
[17/20] id=17 done
[18/20] id=18 done
[19/20] id=19 done
[20/20] id=20 done
Saved results to D:\Study\NLP\tuw-ds-ws2025-nlp-g25-t13-main\data\tfdf_20.json
