# PLC KG ChatBot (Single Notebook)

This notebook builds a **single** chatbot on top of your PLC knowledge graph (TTL/RDF).

Design goals:
- **Deterministic where possible** (tools for call graph, variable info, trace, similarity)
- **LLM only for planning, the Text2SPARQL fallback and phrasing the final answer**
- **Guardrails**: SELECT only, enforced LIMIT, code fences stripped
- **Plan → Execute → Answer** flow (debuggable)

References / best practices:
- Plan-and-Execute agent pattern (LangGraph)
- SPARQL QA chains and the SPARQL extraction helper (LangChain)
- Tool guardrails and role isolation against prompt injection
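The Plan → Execute → Answer flow can be exercised without an LLM or a KG, which helps when debugging the control flow on its own. A minimal sketch, assuming a planner that returns `{"steps": [{"tool": ..., "args": {...}}]}` as in section 9; `fake_planner`, `DEMO_TOOLS`, `fake_answerer` and `run_pipeline` are hypothetical stand-ins, not part of the notebook.

```python
from typing import Any, Callable, Dict

# Hypothetical stand-ins for the LLM planner, the tool registry and the LLM answerer
def fake_planner(question: str) -> Dict[str, Any]:
    return {"steps": [{"tool": "echo", "args": {"text": question}}]}

DEMO_TOOLS: Dict[str, Callable[..., Any]] = {"echo": lambda text: text.upper()}

def fake_answerer(question: str, results: Dict[str, Any]) -> str:
    return f"{question} -> {list(results.values())}"

def run_pipeline(question: str) -> Dict[str, Any]:
    plan = fake_planner(question)                        # 1) Plan
    results: Dict[str, Any] = {}
    for i, step in enumerate(plan["steps"], start=1):    # 2) Execute (deterministic tools)
        fn = DEMO_TOOLS.get(step["tool"])
        key = f"step_{i}_{step['tool']}"
        results[key] = fn(**step["args"]) if fn else {"error": "unknown tool"}
    answer = fake_answerer(question, results)            # 3) Answer from tool results only
    return {"plan": plan, "tool_results": results, "answer": answer}

print(run_pipeline("list programs")["answer"])  # list programs -> ['LIST PROGRAMS']
```

Keeping the three stages as separate functions is what makes the notebook debuggable: the plan and the raw tool results can be inspected before any answer text is generated.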


## 0) Installation (optional)
If something is missing in your local environment, install the dependencies here.

In [1]:
# Optional: run once (locally)
%pip install -U rdflib pandas ipywidgets langchain-core langchain-openai langchain-community pydantic



## 1) Configuration
Adjust the paths and models. If `TTL_PATH` does not exist, the code looks for a file named `filename` in the current working directory and under /mnt/data.

In [17]:
from pathlib import Path

# === Path to the TTL file ===
# 1) Local: set your absolute path here.
TTL_PATH = r"D:\MA_Python_Agent\MSRGuard_Anpassung\KGs\Test2_filled.ttl"
filename = "Test2_filled.ttl"

# 2) Autodetect (e.g. sandbox): fall back to the current folder or /mnt/data
if not Path(TTL_PATH).exists():
    for candidate in (Path.cwd() / filename, Path("/mnt/data") / filename):
        if candidate.exists():
            TTL_PATH = str(candidate)
            break

print("TTL_PATH =", TTL_PATH)

# === Index file (similarity / routine index) ===
index_name = filename.replace(".ttl", "_routine_index.json")
INDEX_DIR = Path(r"D:\MA_Python_Agent\MSRGuard_Anpassung\KGs\ChatBotRoutinen")
INDEX_PATH = str(INDEX_DIR / index_name)
print("INDEX_PATH =", INDEX_PATH)

# === LLM backend ===
# "openai" (via langchain_openai). You can add "gemini" later.
LLM_BACKEND = "openai"

# OpenAI (LangChain) settings
OPENAI_MODEL = "gpt-4o-mini"
OPENAI_TEMPERATURE = 0

# Limits
MAX_SPARQL_ROWS = 200

TTL_PATH = D:\MA_Python_Agent\MSRGuard_Anpassung\KGs\Test2_filled.ttl
INDEX_PATH = D:\MA_Python_Agent\MSRGuard_Anpassung\KGs\ChatBotRoutinen\Test2_filled_routine_index.json


## 2) Load the graph

In [18]:
from rdflib import Graph

g = Graph()
g.parse(TTL_PATH, format="turtle")

print("✅ Graph loaded")
print("Triples:", len(g))
print("Namespaces (excerpt):", list(g.namespaces())[:10])

✅ Graph loaded
Triples: 21335
Namespaces (excerpt): [('brick', rdflib.term.URIRef('https://brickschema.org/schema/Brick#')), ('csvw', rdflib.term.URIRef('http://www.w3.org/ns/csvw#')), ('dc', rdflib.term.URIRef('http://purl.org/dc/elements/1.1/')), ('dcat', rdflib.term.URIRef('http://www.w3.org/ns/dcat#')), ('dcmitype', rdflib.term.URIRef('http://purl.org/dc/dcmitype/')), ('dcterms', rdflib.term.URIRef('http://purl.org/dc/terms/')), ('dcam', rdflib.term.URIRef('http://purl.org/dc/dcam/')), ('doap', rdflib.term.URIRef('http://usefulinc.com/ns/doap#')), ('foaf', rdflib.term.URIRef('http://xmlns.com/foaf/0.1/')), ('geo', rdflib.term.URIRef('http://www.opengis.net/ont/geosparql#'))]


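## 3) Schema card (compact KG overview)
Lists the most frequent classes (`rdf:type`) and properties with their counts. This overview is embedded in the Text2SPARQL system prompt (section 8) so the LLM only uses vocabulary that actually occurs in the KG.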

In [19]:
from collections import Counter
from rdflib.namespace import RDF

def schema_card(graph: Graph, top_n: int = 15) -> str:
    pred_counts = Counter()
    type_counts = Counter()

    for s, p, o in graph:
        try:
            pred_counts[graph.qname(p)] += 1
        except Exception:
            pred_counts[str(p)] += 1

        if p == RDF.type:
            try:
                type_counts[graph.qname(o)] += 1
            except Exception:
                type_counts[str(o)] += 1

    lines = []
    lines.append("TOP CLASSES (rdf:type):")
    for k, v in type_counts.most_common(top_n):
        lines.append(f"  - {k}: {v}")
    lines.append("")
    lines.append("TOP PROPERTIES:")
    for k, v in pred_counts.most_common(top_n):
        lines.append(f"  - {k}: {v}")
    return "\n".join(lines)

SCHEMA_CARD = schema_card(g, top_n=15)
print(SCHEMA_CARD[:2000])

TOP CLASSES (rdf:type):
  - owl:NamedIndividual: 3640
  - ap:class_Variable: 1276
  - ap:class_ParameterAssignment: 988
  - ap:class_Port: 595
  - ap:class_FBInstance: 267
  - ap:class_POUCall: 190
  - ap:class_SignalSource: 154
  - ap:class_PortInstance: 98
  - ap:class_IOChannel: 79
  - ap:class_FBType: 57
  - ap:class_Skill: 48
  - ap:class_CustomFBType: 47
  - owl:DatatypeProperty: 35
  - owl:ObjectProperty: 31
  - ap:class_SkillImplementationHypothesis: 31

TOP PROPERTIES:
  - rdf:type: 7612
  - dp:hasVariableName: 1364
  - dp:hasVariableType: 1246
  - op:hasInternalVariable: 1158
  - op:usesVariable: 1158
  - op:hasAssignment: 988
  - op:assignsFrom: 988
  - op:assignsToPort: 930
  - op:hasPort: 595
  - dp:hasPortDirection: 595
  - dp:hasPortName: 595
  - dp:hasPortType: 595
  - op:implementsPort: 555
  - op:isInstanceOfFBType: 267
  - op:representsFBInstance: 267


## 4) SPARQL helper (guardrails)
- SELECT only
- blocks update and federation keywords (e.g. INSERT, DELETE, SERVICE)
- enforces a LIMIT
- returns results as a list of dicts
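One subtlety of a keyword blocklist: scanning the raw query text also matches keywords inside string literals, so a harmless search for variables containing "add" is rejected as if it were an `ADD` operation. The sketch below contrasts a naive scan with one that first blanks out literals, IRIs and comments and ignores `?variables` and prefixed names; `naive_hits`, `masked_hits` and `MASK` are illustrative helpers, not a full SPARQL tokenizer.

```python
import re

FORBIDDEN = ["INSERT", "DELETE", "LOAD", "CLEAR", "CREATE", "DROP", "MOVE", "COPY", "ADD",
             "SERVICE", "WITH", "USING", "GRAPH"]

# Matches "double" / 'single' quoted literals, <IRIs> and # comments
MASK = re.compile(r'"(?:[^"\\\n]|\\.)*"|\'(?:[^\'\\\n]|\\.)*\'|<[^<>"\s]*>|#[^\n]*')

def naive_hits(query: str) -> list:
    """Keyword scan on the raw query text."""
    q = query.upper()
    return [kw for kw in FORBIDDEN if re.search(rf"\b{kw}\b", q)]

def masked_hits(query: str) -> list:
    """Keyword scan after blanking literals/IRIs/comments; also skips ?vars and prefix:names."""
    q = MASK.sub(" ", query).upper()
    return [kw for kw in FORBIDDEN if re.search(rf"(?<![?$:\w]){kw}\b", q)]

search = 'SELECT ?name WHERE { ?v dp:hasVariableName ?name . FILTER(CONTAINS(LCASE(STR(?name)), "add")) }'
update = 'INSERT DATA { <urn:a> <urn:b> "c" }'

print(naive_hits(search), masked_hits(search))   # ['ADD'] []
print(naive_hits(update), masked_hits(update))   # ['INSERT'] ['INSERT']
```

The masked scan still rejects real update and SERVICE clauses, but no longer penalizes user search terms that happen to be SPARQL keywords.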

If an LLM wraps the SPARQL in code fences or surrounds it with prose, we extract the query robustly.
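The extraction step can be sketched in isolation: take the fenced part if there is one, then keep everything from the first `PREFIX` or `SELECT` on, so prefix declarations are not lost. A minimal sketch; `extract_query` is a hypothetical helper and only handles a single fenced block.

```python
import re

def extract_query(llm_text: str) -> str:
    """Strip ``` fences and leading prose; keep PREFIX lines and the SELECT query."""
    t = llm_text.strip()
    fenced = re.search(r"```[a-zA-Z]*\s*(.*?)```", t, flags=re.DOTALL)
    if fenced:
        t = fenced.group(1)
    m = re.search(r"((?:PREFIX|SELECT)\b.*)", t, flags=re.IGNORECASE | re.DOTALL)
    return (m.group(1) if m else t).strip()

reply = ("Here is the query:\n```sparql\nPREFIX dp: <http://example.org/dp_>\n"
         "SELECT ?n WHERE { ?x dp:name ?n }\n```")
print(extract_query(reply))
```

Searching only for `SELECT` would drop the `PREFIX dp:` line, and the query would then depend on default prefixes being added later.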


In [20]:
import re
from typing import Any, Dict, List

DEFAULT_PREFIXES = """PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
PREFIX ag:  <http://www.semanticweb.org/AgentProgramParams/>
PREFIX dp:  <http://www.semanticweb.org/AgentProgramParams/dp_>
PREFIX op:  <http://www.semanticweb.org/AgentProgramParams/op_>
"""

def _normalize_ws(s: str) -> str:
    return re.sub(r"\s+", " ", s).strip()

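# Blank out string literals, <IRIs> and # comments so the keyword scan only sees query syntax
# (otherwise e.g. FILTER(CONTAINS(?name, "add")) would trip the ADD blocklist entry).
_SCAN_MASK = re.compile(r'"(?:[^"\\\n]|\\.)*"|\'(?:[^\'\\\n]|\\.)*\'|<[^<>"\s]*>|#[^\n]*')

def _mask_for_scan(q: str) -> str:
    return _SCAN_MASK.sub(lambda m: " " if m.group(0).startswith("#") else '""', q)

def enforce_select_only(query: str, max_limit: int = 200) -> str:
    q = query.strip()
    scan = _normalize_ws(_mask_for_scan(q)).upper()

    # Skip the prologue (PREFIX/BASE declarations); the query form itself must be SELECT,
    # which also rejects CONSTRUCT/ASK/DESCRIBE.
    body = re.sub(r'^(?:PREFIX\s*[A-Z0-9_.-]*:\s*""\s*|BASE\s*""\s*)*', "", scan)
    if not body.startswith("SELECT"):
        raise ValueError("Only SELECT queries are allowed (optionally with PREFIX).")

    forbidden = [
        "INSERT","DELETE","LOAD","CLEAR","CREATE","DROP","MOVE","COPY","ADD",
        "SERVICE","WITH","USING","GRAPH"
    ]
    for kw in forbidden:
        # The lookbehind skips ?variables and prefixed names (e.g. ?graph, op:add...)
        if re.search(rf"(?<![?$:\w]){kw}\b", scan):
            raise ValueError(f"Forbidden SPARQL keyword detected: {kw}")

    m = re.search(r"\bLIMIT\s+(\d+)\b", scan)
    if m:
        lim = int(m.group(1))
        if lim > max_limit:
            q = re.sub(r"(?i)\bLIMIT\s+\d+\b", f"LIMIT {max_limit}", q)
    else:
        q = q.rstrip() + f"\nLIMIT {max_limit}\n"
    return q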

def strip_code_fences(text: str) -> str:
    t = text.strip()
    t = re.sub(r"^```[a-zA-Z]*\s*", "", t)
    t = re.sub(r"\s*```$", "", t)
    return t.strip()

try:
    from langchain_community.chains.graph_qa.neptune_sparql import extract_sparql as lc_extract_sparql
except Exception:
    lc_extract_sparql = None

def extract_sparql_from_llm(text: str) -> str:
    if lc_extract_sparql is not None:
        try:
            return lc_extract_sparql(text).strip()
        except Exception:
            pass
    t = strip_code_fences(text)
    # Keep PREFIX declarations if the LLM emitted them before SELECT
    m = re.search(r"((?:PREFIX|SELECT)\b.*)", t, flags=re.IGNORECASE | re.DOTALL)
    return (m.group(1).strip() if m else t)

def sparql_select_raw(query: str, max_rows: int = 200) -> List[Dict[str, Any]]:
    q = query.strip()
    if "PREFIX" not in q.upper():
        q = DEFAULT_PREFIXES + "\n" + q
    q = enforce_select_only(q, max_limit=max_rows)

    res = g.query(q)
    vars_ = [str(v) for v in res.vars]

    out: List[Dict[str, Any]] = []
    for row in res:
        item = {}
        for i, v in enumerate(vars_):
            val = row[i]
            item[v] = None if val is None else str(val)
        out.append(item)
    return out

## 5) Deterministic tools (no LLM)
These tools answer typical questions without the LLM having to generate free-form SPARQL.
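`KGStore.get_reachable_pous` walks `op:containsPOUCall` → `op:callsPOU` edges breadth-first. The same traversal on a plain dict (a stand-in for the KG edges, with hypothetical POU names) shows that the root itself is included and that call cycles terminate thanks to the `visited` set. This sketch uses `collections.deque`, which avoids the O(n) shift of `list.pop(0)`.

```python
from collections import deque
from typing import Dict, List, Set

def reachable(call_graph: Dict[str, List[str]], root: str) -> Set[str]:
    """Breadth-first traversal over caller -> callees edges; `visited` makes cycles safe."""
    visited: Set[str] = set()
    queue = deque([root])
    while queue:
        cur = queue.popleft()
        if cur in visited:
            continue
        visited.add(cur)
        queue.extend(c for c in call_graph.get(cur, []) if c not in visited)
    return visited

# Hypothetical POU names; "Job" and "Helper" call each other (a cycle)
calls = {"Main": ["Job", "Log"], "Job": ["Helper"], "Helper": ["Job"], "Log": []}
print(sorted(reachable(calls, "Main")))  # ['Helper', 'Job', 'Log', 'Main']
```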


In [21]:
from dataclasses import dataclass
from typing import Optional, Set, Tuple
from rdflib import URIRef, Literal, Namespace
from rdflib.namespace import RDF

AG = Namespace("http://www.semanticweb.org/AgentProgramParams/")
DP = Namespace("http://www.semanticweb.org/AgentProgramParams/dp_")
OP = Namespace("http://www.semanticweb.org/AgentProgramParams/op_")

@dataclass
class SensorSnapshot:
    program_name: str
    sensor_values: Dict[str, Any]

@dataclass
class RoutineSignature:
    pou_name: str
    reachable_pous: List[str]
    called_pou_names: List[str]
    used_variable_names: List[str]
    hardware_addresses: List[str]
    port_names: List[str]

    def as_dict(self) -> Dict[str, Any]:
        return {
            "pou_name": self.pou_name,
            "reachable_pous": self.reachable_pous,
            "called_pou_names": self.called_pou_names,
            "used_variable_names": self.used_variable_names,
            "hardware_addresses": self.hardware_addresses,
            "port_names": self.port_names,
        }

class KGStore:
    def __init__(self, graph: Graph):
        self.g = graph
        self._pou_by_name: Dict[str, URIRef] = {}
        self._build_cache()

    def _build_cache(self) -> None:
        for pou, _, name in self.g.triples((None, DP.hasPOUName, None)):
            if isinstance(name, Literal):
                self._pou_by_name[str(name)] = pou

    def pou_uri_by_name(self, pou_name: str) -> Optional[URIRef]:
        return self._pou_by_name.get(pou_name)

    def pou_name(self, pou_uri: URIRef) -> str:
        v = self.g.value(pou_uri, DP.hasPOUName)
        return str(v) if v else str(pou_uri)

    def get_reachable_pous(self, root_pou_uri: URIRef) -> Set[URIRef]:
        visited: Set[URIRef] = set()
        queue: List[URIRef] = [root_pou_uri]
        while queue:
            cur = queue.pop(0)
            if cur in visited:
                continue
            visited.add(cur)
            for call in self.g.objects(cur, OP.containsPOUCall):
                for called in self.g.objects(call, OP.callsPOU):
                    if isinstance(called, URIRef) and called not in visited:
                        queue.append(called)
        return visited

    def get_called_pous(self, pou_uri: URIRef) -> Set[URIRef]:
        called: Set[URIRef] = set()
        for call in self.g.objects(pou_uri, OP.containsPOUCall):
            for target in self.g.objects(call, OP.callsPOU):
                if isinstance(target, URIRef):
                    called.add(target)
        return called

    def get_used_variables(self, pou_uri: URIRef) -> Set[URIRef]:
        vars_: Set[URIRef] = set()
        for v in self.g.objects(pou_uri, OP.usesVariable):
            if isinstance(v, URIRef):
                vars_.add(v)
        for v in self.g.objects(pou_uri, OP.hasInternalVariable):
            if isinstance(v, URIRef):
                vars_.add(v)
        return vars_

    def get_variable_names(self, var_uri: URIRef) -> Set[str]:
        names: Set[str] = set()
        for _, _, name in self.g.triples((var_uri, DP.hasVariableName, None)):
            if isinstance(name, Literal):
                names.add(str(name))
        return names

    def get_hardware_address(self, var_uri: URIRef) -> Optional[str]:
        v = self.g.value(var_uri, DP.hasHardwareAddress)
        return str(v) if v else None

    def get_ports_of_pou(self, pou_uri: URIRef) -> Set[URIRef]:
        ports: Set[URIRef] = set()
        for p in self.g.objects(pou_uri, OP.hasPort):
            if isinstance(p, URIRef):
                ports.add(p)
        return ports

    def get_port_name(self, port_uri: URIRef) -> str:
        v = self.g.value(port_uri, DP.hasPortName)
        return str(v) if v else ""

kg = KGStore(g)

def _sparql_str(value: str) -> str:
    """Quote and escape a value as a SPARQL string literal (prevents breaking out of the quotes)."""
    escaped = (str(value).replace("\\", "\\\\").replace('"', '\\"')
               .replace("\n", "\\n").replace("\r", "\\r"))
    return f'"{escaped}"'

def tool_list_programs() -> List[Dict[str, Any]]:
    q = """
    SELECT ?programName WHERE {
      ?program rdf:type ag:class_Program ;
               dp:hasProgramName ?programName .
    } ORDER BY ?programName
    """
    return sparql_select_raw(q, max_rows=MAX_SPARQL_ROWS)

def tool_get_program_overview(program_name: str) -> List[Dict[str, Any]]:
    q = f"""
    SELECT ?report WHERE {{
      ?program rdf:type ag:class_Program ;
               dp:hasProgramName {_sparql_str(program_name)} .
      OPTIONAL {{ ?program dp:hasConsistencyReport ?report }}
    }}
    """
    return sparql_select_raw(q, max_rows=MAX_SPARQL_ROWS)

def tool_get_called_pous(program_name: str) -> List[Dict[str, Any]]:
    q = f"""
    SELECT DISTINCT ?calleeName WHERE {{
      ?program rdf:type ag:class_Program ;
               dp:hasProgramName {_sparql_str(program_name)} ;
               op:containsPOUCall ?call .
      ?call op:callsPOU ?callee .
      OPTIONAL {{ ?callee dp:hasPOUName ?calleeName }}
    }} ORDER BY ?calleeName
    """
    return sparql_select_raw(q, max_rows=MAX_SPARQL_ROWS)

def tool_get_pou_code(pou_name: str) -> List[Dict[str, Any]]:
    q = f"""
    SELECT ?lang ?code ?report WHERE {{
      ?pou dp:hasPOUName {_sparql_str(pou_name)} .
      OPTIONAL {{ ?pou dp:hasPOULanguage ?lang }}
      OPTIONAL {{ ?pou dp:hasPOUCode ?code }}
      OPTIONAL {{ ?pou dp:hasConsistencyReport ?report }}
    }}
    """
    return sparql_select_raw(q, max_rows=MAX_SPARQL_ROWS)

def tool_search_variables(name_contains: str) -> List[Dict[str, Any]]:
    q = f"""
    SELECT DISTINCT ?name ?type ?addr WHERE {{
      ?var rdf:type ag:class_Variable ;
           dp:hasVariableName ?name ;
           dp:hasVariableType ?type .
      FILTER(CONTAINS(LCASE(STR(?name)), LCASE({_sparql_str(name_contains)})))
      OPTIONAL {{ ?var dp:hasHardwareAddress ?addr }}
    }} ORDER BY ?name
    """
    return sparql_select_raw(q, max_rows=MAX_SPARQL_ROWS)

def tool_get_variable_trace(name_contains: str) -> List[Dict[str, Any]]:
    q = f"""
    SELECT DISTINCT ?varName ?exprText ?calleeName WHERE {{
      ?var rdf:type ag:class_Variable ;
           dp:hasVariableName ?varName .
      FILTER(CONTAINS(LCASE(STR(?varName)), LCASE({_sparql_str(name_contains)})))

      OPTIONAL {{
        ?expr rdf:type ag:class_Expression ;
              dp:hasExpressionText ?exprText ;
              op:isExpressionCreatedBy ?var .
        OPTIONAL {{
          ?assign rdf:type ag:class_ParameterAssignment ;
                  op:assignsFrom ?expr .
          OPTIONAL {{
            ?pouCall rdf:type ag:class_POUCall ;
                     op:hasAssignment ?assign ;
                     op:callsPOU ?callee .
            OPTIONAL {{ ?callee dp:hasPOUName ?calleeName }}
          }}
        }}
      }}
    }}
    """
    return sparql_select_raw(q, max_rows=MAX_SPARQL_ROWS)

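## 6) Routine signatures + similarity index
Stores one signature per POU (reachable POUs, called POUs, variables, hardware addresses, ports) in a JSON file at `INDEX_PATH`, so similarity checks do not have to walk the whole KG on every request.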
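`RoutineIndex.find_similar` combines three Jaccard similarities with fixed weights (hardware addresses 0.55, variables 0.25, called POUs 0.20). A worked example on hypothetical signatures, using plain sets instead of `RoutineSignature` lists; `weighted_score` is an illustrative helper. Note that any routine with non-empty sets scores 1.0 against itself, so the target should not be compared with its own index entry.

```python
def jaccard(a: set, b: set) -> float:
    union = a | b
    return len(a & b) / len(union) if union else 0.0

def weighted_score(tgt: dict, cand: dict) -> float:
    # Same weights as RoutineIndex.find_similar: hardware 0.55, variables 0.25, called POUs 0.20
    return (0.55 * jaccard(tgt["hw"], cand["hw"])
            + 0.25 * jaccard(tgt["vars"], cand["vars"])
            + 0.20 * jaccard(tgt["called"], cand["called"]))

# Hypothetical signatures
target = {"hw": {"%I0.0", "%I0.1"}, "vars": {"Start", "Stop", "Speed"}, "called": {"Log"}}
cand = {"hw": {"%I0.0"}, "vars": {"Start", "Stop"}, "called": {"Log", "Retry"}}

# hw: 1/2, vars: 2/3, called: 1/2  ->  0.55*0.5 + 0.25*0.6667 + 0.20*0.5 = 0.5417
print(round(weighted_score(target, cand), 4))
```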

In [23]:
import json
from pathlib import Path

def jaccard(a: Set[str], b: Set[str]) -> float:
    if not a and not b:
        return 0.0
    inter = len(a & b)
    union = len(a | b)
    return inter / union if union else 0.0

class SignatureExtractor:
    def __init__(self, kg: KGStore):
        self.kg = kg

    def extract_signature(self, pou_name: str) -> RoutineSignature:
        pou_uri = self.kg.pou_uri_by_name(pou_name)
        if pou_uri is None:
            raise ValueError(f"POU '{pou_name}' not found in KG.")

        reachable = self.kg.get_reachable_pous(pou_uri)

        reachable_names: Set[str] = set()
        called_names: Set[str] = set()
        used_var_names: Set[str] = set()
        hw_addrs: Set[str] = set()
        port_names: Set[str] = set()

        for rp in reachable:
            reachable_names.add(self.kg.pou_name(rp))
            for callee in self.kg.get_called_pous(rp):
                called_names.add(self.kg.pou_name(callee))
            for var in self.kg.get_used_variables(rp):
                used_var_names |= self.kg.get_variable_names(var)
                ha = self.kg.get_hardware_address(var)
                if ha:
                    hw_addrs.add(ha)
            for port in self.kg.get_ports_of_pou(rp):
                pn = self.kg.get_port_name(port)
                if pn:
                    port_names.add(pn)

        return RoutineSignature(
            pou_name=pou_name,
            reachable_pous=sorted(reachable_names),
            called_pou_names=sorted(called_names),
            used_variable_names=sorted(used_var_names),
            hardware_addresses=sorted(hw_addrs),
            port_names=sorted(port_names),
        )

class RoutineIndex:
    def __init__(self, signatures: List[RoutineSignature]):
        self.signatures = signatures

    def save(self, path: str) -> None:
        Path(path).parent.mkdir(parents=True, exist_ok=True)  # INDEX_DIR may not exist yet
        Path(path).write_text(
            json.dumps([s.as_dict() for s in self.signatures], indent=2, ensure_ascii=False),
            encoding="utf-8"
        )

    @staticmethod
    def load(path: str) -> "RoutineIndex":
        data = json.loads(Path(path).read_text(encoding="utf-8-sig"))  # tolerate a UTF-8 BOM
        sigs = [RoutineSignature(**d) for d in data]
        return RoutineIndex(sigs)

    @staticmethod
    def build_from_kg(kg: KGStore, only_pous: Optional[List[str]] = None) -> "RoutineIndex":
        extractor = SignatureExtractor(kg)
        if only_pous is None:
            only_pous = sorted(kg._pou_by_name.keys())

        sigs: List[RoutineSignature] = []
        for name in only_pous:
            try:
                sigs.append(extractor.extract_signature(name))
            except Exception:
                # Skip POUs whose signature cannot be extracted instead of aborting the whole build
                continue
        return RoutineIndex(sigs)

    def find_similar(self, target: RoutineSignature, top_k: int = 5) -> List[Dict[str, Any]]:
        tgt_hw = set(target.hardware_addresses)
        tgt_vars = set(target.used_variable_names)
        tgt_called = set(target.called_pou_names)

        scored: List[Tuple[float, RoutineSignature]] = []
        for cand in self.signatures:
            if cand.pou_name == target.pou_name:
                continue  # skip the target itself, which would otherwise typically rank first
            cand_hw = set(cand.hardware_addresses)
            cand_vars = set(cand.used_variable_names)
            cand_called = set(cand.called_pou_names)

            # jaccard() already returns 0.0 when both sets are empty
            sim_hw = jaccard(tgt_hw, cand_hw)
            sim_vars = jaccard(tgt_vars, cand_vars)
            sim_called = jaccard(tgt_called, cand_called)

            # Hardware addresses dominate; shared variables and callees refine the ranking
            score = 0.55 * sim_hw + 0.25 * sim_vars + 0.20 * sim_called
            scored.append((score, cand))

        scored.sort(key=lambda x: x[0], reverse=True)
        return [{"score": round(s, 4), "pou_name": r.pou_name} for s, r in scored[:top_k]]

def classify_checkable_sensors(snapshot: SensorSnapshot, sig: RoutineSignature) -> Dict[str, str]:
    checkable_set = set(sig.used_variable_names) | set(sig.hardware_addresses)
    return {k: ("checkable" if k in checkable_set else "not_checkable") for k in snapshot.sensor_values.keys()}

# Build / load index
# Note: an existing index is reused even if the TTL changed; delete the file to force a rebuild.
from json import JSONDecodeError

p = Path(INDEX_PATH)

def try_load_index(path: Path):
    try:
        if not path.exists() or path.stat().st_size == 0:
            return None
        # BOM-safe read, strip surrounding whitespace
        raw = path.read_text(encoding="utf-8-sig").strip()
        if not raw:
            return None
        data = json.loads(raw)
        sigs = [RoutineSignature(**d) for d in data]
        return RoutineIndex(sigs)
    except (JSONDecodeError, UnicodeError, TypeError):
        # Corrupt or outdated index (e.g. changed RoutineSignature fields): rebuild it
        return None

routine_index = try_load_index(p)
if routine_index is None:
    routine_index = RoutineIndex.build_from_kg(kg)
    routine_index.save(str(p))
    print("✅ RoutineIndex rebuilt & saved:", p)
else:
    print("✅ RoutineIndex loaded:", p)

✅ RoutineIndex rebuilt & saved: D:\MA_Python_Agent\MSRGuard_Anpassung\KGs\ChatBotRoutinen\Test2_filled_routine_index.json


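## 7) LLM setup
A single wrapper `llm_invoke(system, user) -> str` is shared by the planner, Text2SPARQL and the answerer. `ChatOpenAI` reads the API key from the `OPENAI_API_KEY` environment variable; if it is not set, the cell fails with the error shown below.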
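Because every LLM call goes through the same `(system, user) -> str` signature, a fake can replace the real model for offline tests, e.g. when no `OPENAI_API_KEY` is available. A sketch; `make_fake_llm` is a hypothetical helper that replays canned responses in order and records the prompts it received.

```python
from typing import Callable, List, Tuple

def make_fake_llm(responses: List[str]) -> Callable[[str, str], str]:
    """Return an llm_invoke-compatible callable that replays canned responses in order."""
    queue = list(responses)
    calls: List[Tuple[str, str]] = []

    def _invoke(system: str, user: str) -> str:
        calls.append((system, user))
        if not queue:
            raise RuntimeError("fake LLM ran out of responses")
        return queue.pop(0)

    _invoke.calls = calls  # expose the recorded prompts for assertions
    return _invoke

fake = make_fake_llm(['{"steps": [{"tool": "list_programs", "args": {}}]}'])
print(fake("planner prompt", "Welche Programme gibt es?"))
```

Since `make_plan`, `text2sparql` and `make_answer` look up the global `llm_invoke` at call time, assigning `llm_invoke = make_fake_llm([...])` before calling them would route those calls to the fake.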

In [24]:
from typing import Callable

def get_llm_invoke() -> Callable[[str, str], str]:
    if LLM_BACKEND == "openai":
        try:
            from langchain_openai import ChatOpenAI
            from langchain_core.messages import SystemMessage, HumanMessage
        except Exception as e:
            raise RuntimeError(
                "Please install langchain-openai + langchain-core.\n"
                "pip install -U langchain-openai langchain-core"
            ) from e

        llm = ChatOpenAI(model=OPENAI_MODEL, temperature=OPENAI_TEMPERATURE, max_tokens=1200)

        def _invoke(system: str, user: str) -> str:
            msgs = [SystemMessage(content=system), HumanMessage(content=user)]
            return llm.invoke(msgs).content

        return _invoke

    raise ValueError("LLM_BACKEND not supported. Set LLM_BACKEND='openai' or extend the wrapper.")

llm_invoke = get_llm_invoke()
print("✅ LLM wrapper ready:", LLM_BACKEND, OPENAI_MODEL)

OpenAIError: The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable

## 8) Text2SPARQL (Fallback)
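The fallback makes a single LLM call; if the generated query is rejected by the guardrails, the tool step only records the error. A common refinement is to feed the rejection reason back once. A sketch with hypothetical names (`generate_with_feedback`, `stub_llm`, `only_select`); in this notebook, `validate` could be `enforce_select_only` (which raises `ValueError`) and `llm` could be `llm_invoke`.

```python
from typing import Callable, List, Tuple

def generate_with_feedback(question: str,
                           llm: Callable[[str, str], str],
                           validate: Callable[[str], str],
                           system: str,
                           max_retries: int = 1) -> str:
    """Ask the LLM for a query; on a ValueError from `validate`, retry with the error appended."""
    user = question
    for attempt in range(max_retries + 1):
        candidate = llm(system, user).strip()
        try:
            return validate(candidate)
        except ValueError as e:
            if attempt == max_retries:
                raise
            user = (f"{question}\n\nYour previous query was rejected: {e}\n"
                    "Return a corrected SPARQL SELECT query only.")
    raise ValueError("max_retries must be >= 0")

# Hypothetical stubs: a scripted "LLM" and a minimal SELECT-only validator
seen: List[Tuple[str, str]] = []
replies = iter(["DELETE WHERE { ?s ?p ?o }", "SELECT ?s WHERE { ?s ?p ?o }"])

def stub_llm(system: str, user: str) -> str:
    seen.append((system, user))
    return next(replies)

def only_select(q: str) -> str:
    if not q.upper().startswith("SELECT"):
        raise ValueError("only SELECT is allowed")
    return q

print(generate_with_feedback("list subjects", stub_llm, only_select, system="sparql-only"))
```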

In [None]:
TEXT2SPARQL_SYSTEM = f"""
Du erzeugst ausschließlich eine SPARQL SELECT Query für einen RDF Knowledge Graph eines SPS Programms.
Regeln:
- Gib NUR SPARQL zurück (keine Erklärung, kein Markdown).
- Nur SELECT (kein INSERT/DELETE/UPDATE, kein SERVICE).
- Nutze die Prefixes: rdf, ag, dp, op.
Schema Card:
{SCHEMA_CARD}
"""

def text2sparql(question: str) -> str:
    raw = llm_invoke(TEXT2SPARQL_SYSTEM, question)
    return extract_sparql_from_llm(raw).strip()

def tool_text2sparql_select(question: str, max_rows: int = 50) -> Dict[str, Any]:
    q = text2sparql(question)
    rows = sparql_select_raw(q, max_rows=max_rows)
    return {"sparql": q, "rows": rows}

## 9) Tools + Planner + Executor
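`execute_plan` catches bad tool calls at runtime and stores them as errors. Validating a plan before running it surfaces the same problems earlier, e.g. to feed them into a repair prompt. A sketch, assuming tools are plain callables as in `TOOLS`; `validate_plan`, `demo_tools` and `demo_plan` are hypothetical, and `inspect.signature(...).bind(...)` checks argument names without calling the tool.

```python
import inspect
from typing import Any, Callable, Dict, List

def validate_plan(plan: Any, tools: Dict[str, Callable[..., Any]], max_steps: int = 3) -> List[str]:
    """Return a list of problems; an empty list means every step can be called as planned."""
    steps = plan.get("steps") if isinstance(plan, dict) else None
    if not isinstance(steps, list):
        return ["'steps' must be a list"]
    problems: List[str] = []
    if len(steps) > max_steps:
        problems.append(f"too many steps: {len(steps)} > {max_steps}")
    for i, step in enumerate(steps, start=1):
        if not isinstance(step, dict):
            problems.append(f"step {i}: not an object")
            continue
        fn = tools.get(step.get("tool"))
        if fn is None:
            problems.append(f"step {i}: unknown tool {step.get('tool')!r}")
            continue
        try:
            # bind() checks argument names/arity against the signature without calling the tool
            inspect.signature(fn).bind(**(step.get("args") or {}))
        except TypeError as e:
            problems.append(f"step {i}: {e}")
    return problems

# Hypothetical registry mirroring two entries of TOOLS
demo_tools = {"list_programs": lambda: [], "pou_code": lambda pou_name: []}
demo_plan = {"steps": [{"tool": "pou_code", "args": {"name": "X"}},
                       {"tool": "drop_db", "args": {}}]}
for problem in validate_plan(demo_plan, demo_tools):
    print(problem)
```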

In [None]:
def tool_exception_prep(program_name: str, snapshot: Dict[str, Any], top_k: int = 5) -> Dict[str, Any]:
    extractor = SignatureExtractor(kg)
    sig = extractor.extract_signature(program_name)
    snap = SensorSnapshot(program_name=program_name, sensor_values=snapshot)
    check_map = classify_checkable_sensors(snap, sig)
    similar = routine_index.find_similar(sig, top_k=top_k)
    return {
        "signature": sig.as_dict(),
        "checkable": check_map,
        "similar": similar,
    }

TOOLS = {
    "list_programs": lambda: tool_list_programs(),
    "program_overview": lambda program_name: tool_get_program_overview(program_name),
    "called_pous": lambda program_name: tool_get_called_pous(program_name),
    "pou_code": lambda pou_name: tool_get_pou_code(pou_name),
    "search_variables": lambda name_contains: tool_search_variables(name_contains),
    "variable_trace": lambda name_contains: tool_get_variable_trace(name_contains),
    "text2sparql_select": lambda question, max_rows=50: tool_text2sparql_select(question, max_rows=max_rows),
    "exception_prep": lambda program_name, snapshot, top_k=5: tool_exception_prep(program_name, snapshot, top_k=top_k),
}

TOOL_DESCRIPTIONS = """
Erlaubte Tools:
- list_programs()
- program_overview(program_name)
- called_pous(program_name)
- pou_code(pou_name)
- search_variables(name_contains)
- variable_trace(name_contains)
- text2sparql_select(question, max_rows)
- exception_prep(program_name, snapshot, top_k)

Regeln:
- Wenn eine Frage mit den Tools beantwortbar ist, nutze Tools, nicht freie SPARQL.
- text2sparql_select ist Fallback.
- Maximal 3 Schritte.
"""

In [None]:
import json

PLANNER_SYSTEM = f"""
Du bist ein Planner für einen PLC Knowledge-Graph ChatBot.
Du planst Tool-Aufrufe als JSON.
{TOOL_DESCRIPTIONS}

Ausgabeformat (NUR JSON, kein Markdown!):
{{
  "steps": [
    {{"tool": "list_programs", "args": {{}}}}
  ]
}}

Heuristiken:
- Wenn User Sensorwerte/Statuswerte nennt oder 'Exception'/'Fehlerbild' -> exception_prep
- Wenn User nach 'welche POUs', 'aufrufen', 'Call-Graph' -> called_pous
- Wenn User nach 'Code' -> pou_code
- Wenn User nach 'Variable', 'Adresse', 'I/O' -> search_variables oder variable_trace
- Sonst text2sparql_select
Max 3 Steps.
"""

def safe_json_loads(s: str) -> Dict[str, Any]:
    t = strip_code_fences(s)
    m = re.search(r"(\{.*\})", t, flags=re.DOTALL)
    t = m.group(1) if m else t
    return json.loads(t)

def make_plan(user_message: str, max_retries: int = 2) -> Dict[str, Any]:
    user = user_message.strip()
    last_err = None
    repair_system = "Du reparierst JSON. Gib NUR gültiges JSON zurück."

    out = llm_invoke(PLANNER_SYSTEM, user)
    for attempt in range(max_retries + 1):
        try:
            plan = safe_json_loads(out)
            if not isinstance(plan, dict) or not isinstance(plan.get("steps"), list):
                raise ValueError("Missing or invalid 'steps'")
            return plan
        except Exception as e:
            last_err = e
            if attempt < max_retries:
                # Parse the repaired output in the next iteration (instead of re-planning from scratch)
                repair_user = f"Repariere zu gültigem JSON:\n{out}"
                out = llm_invoke(repair_system, repair_user)

    raise RuntimeError(f"Planner failed to return valid JSON: {last_err}")

def execute_plan(plan: Dict[str, Any], max_steps: int = 3) -> Dict[str, Any]:
    results: Dict[str, Any] = {}
    # Enforce the "max 3 steps" rule from the prompt; the LLM may ignore it
    for i, step in enumerate(plan.get("steps", [])[:max_steps], start=1):
        if not isinstance(step, dict):
            results[f"step_{i}"] = {"error": "step is not an object", "step": step}
            continue
        tool = step.get("tool")
        args = step.get("args", {}) or {}
        fn = TOOLS.get(tool)
        if fn is None:
            results[f"step_{i}_{tool}"] = {"error": "unknown tool", "tool": tool, "args": args}
            continue
        try:
            results[f"step_{i}_{tool}"] = fn(**args)
        except Exception as e:
            results[f"step_{i}_{tool}"] = {"error": str(e), "tool": tool, "args": args}
    return results

## 10) Answerer
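`make_answer` cuts the serialized payload at 12,000 characters, which can hand the answerer truncated JSON and silently drop later tool results. An alternative is to halve the largest row lists until the JSON fits. A sketch, assuming tool results are top-level lists of rows as returned by the deterministic tools; nested results such as the `{"sparql", "rows"}` dict from `text2sparql_select` would need an extra rule, and `fit_payload` is a hypothetical helper.

```python
import json
from typing import Any, Dict

def fit_payload(payload: Dict[str, Any], max_chars: int = 12000) -> str:
    """Serialize payload as JSON, halving the longest row list until the text fits into max_chars."""
    data = json.loads(json.dumps(payload, default=str))  # deep copy; non-JSON values become str
    text = json.dumps(data, ensure_ascii=False, indent=2)
    while len(text) > max_chars:
        results = data.get("tool_results", {})
        lists = [(k, v) for k, v in results.items() if isinstance(v, list) and len(v) > 1]
        if not lists:
            return text[:max_chars]  # nothing left to shrink: fall back to a hard cut
        key, rows = max(lists, key=lambda kv: len(kv[1]))
        results[key] = rows[: len(rows) // 2]
        text = json.dumps(data, ensure_ascii=False, indent=2)
    return text

payload = {"question": "q", "plan": {},
           "tool_results": {"step_1_search_variables": [{"name": f"Var{i}"} for i in range(500)]}}
text = fit_payload(payload, max_chars=2000)
print(len(text), len(json.loads(text)["tool_results"]["step_1_search_variables"]))
```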

In [None]:
ANSWER_SYSTEM = """
Du bist ein SPS-Assistent. Antworte auf Deutsch.
WICHTIG:
- Nutze ausschließlich die Fakten aus den Tool-Ergebnissen.
- Wenn Daten fehlen oder leer sind: sag klar, dass der KG dazu nichts liefert.
- Wenn SPARQL Ergebnisse da sind: fasse sie strukturiert zusammen (max 10 Punkte).
- Keine erfundenen Klassen/Properties.
"""

def make_answer(user_message: str, plan: Dict[str, Any], tool_results: Dict[str, Any]) -> str:
    payload = {
        "question": user_message,
        "plan": plan,
        "tool_results": tool_results,
    }
    # Hard cut at 12k characters keeps the prompt bounded (the JSON may end mid-value)
    user = "Hier sind Plan und Tool-Ergebnisse als JSON:\n" + json.dumps(payload, ensure_ascii=False, indent=2, default=str)[:12000]
    return llm_invoke(ANSWER_SYSTEM, user)

def chat_once(user_message: str, debug: bool = True) -> Dict[str, Any]:
    plan = make_plan(user_message)
    results = execute_plan(plan)
    answer = make_answer(user_message, plan, results)
    out = {"answer": answer}
    if debug:
        out["plan"] = plan
        out["tool_results"] = results
    return out

## 11) Chat UI (ipywidgets)

In [None]:
import ipywidgets as widgets
from IPython.display import display, Markdown

debug_toggle = widgets.Checkbox(value=True, description="Debug (show plan + tool results)")
input_box = widgets.Textarea(
    placeholder="Ask a question... (e.g. 'Welche POUs ruft HRL_SkillSet auf?')",
    layout=widgets.Layout(width="100%", height="90px")
)
send_btn = widgets.Button(description="Send", button_style="primary")
out = widgets.Output()

display(debug_toggle, input_box, send_btn, out)

def on_send(_):
    out.clear_output()
    user_msg = input_box.value.strip()
    if not user_msg:
        return

    with out:
        print("User:", user_msg)
        resp = chat_once(user_msg, debug=debug_toggle.value)
        display(Markdown("### Answer"))
        print(resp["answer"])

        if debug_toggle.value:
            display(Markdown("### Plan"))
            print(json.dumps(resp["plan"], ensure_ascii=False, indent=2))
            display(Markdown("### Tool results (truncated)"))
            print(json.dumps(resp["tool_results"], ensure_ascii=False, indent=2)[:8000])

send_btn.on_click(on_send)

## 12) Quick Tests

In [None]:
tests = [
    "Welche Programme gibt es?",
    "Welche POUs ruft HRL_SkillSet auf?",
    "Zeig mir den Code von JobMethode_Schablone",
    "Suche Variablen, die 'NotAus' enthalten",
    "Trace für DI04_EncoderStart02",
]

for q in tests:
    print("\n---\n", q)
    try:
        resp = chat_once(q, debug=False)
        print(resp["answer"][:700])
    except Exception as e:
        print("Error:", e)