In [47]:
import ast
import asyncio
import json
import operator
import os
import re
from typing import Annotated, Optional, TypedDict

from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, ToolMessage
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_together import ChatTogether
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph

In [None]:
class AgentState(TypedDict):
    """Shared state passed between the nodes of the routing graph."""

    # Conversation/tool messages; `operator.add` makes node updates append
    # to the existing list instead of replacing it.
    messages: Annotated[list[AnyMessage], operator.add]
    # Name of the last step that wrote to the state ("init", "metrics", ...).
    step: str
    # Model names returned by the `list_models` tool.
    models: list[str]
    # Domain names returned by the `list_domains` tool.
    domains: list[str]
    # Domain name -> subdomain names returned by `list_sub_domains`.
    subdomains: dict[str, list[str]]
    # The user query being routed.
    query: str
    # Intent label produced by `Agent.classify_intent`.
    intent: str
    # Metrics collected per model/domain key.
    metrics: dict
    # Free-text explanation of the selection result.
    context: str
    # Key of the best-scoring entry in `metrics`; None until a model has been
    # selected (the state is initialised with None).
    selected_llm: Optional[str]

# Directory that contains the MCP server's main.py. The default is the original
# development location; set HETEROLLMS_MCP_SERVER_DIR to run the notebook on
# another machine without editing this cell.
MCP_SERVER_DIR = os.environ.get(
    "HETEROLLMS_MCP_SERVER_DIR",
    "/home/sesi/testing/agents/ER/agentic-ai/experiments/HeteroLLMs/mcp_server/",
)

# In-memory checkpointer handed to the graph when it is compiled.
memory = MemorySaver()

# MCP client that launches the metrics server as a subprocess over stdio.
client = MultiServerMCPClient({
    "csv_r": {
        "command": "python3",
        # "-u" keeps the child's stdio unbuffered.
        "args": ["-u", os.path.join(MCP_SERVER_DIR, "main.py")],
        "cwd": MCP_SERVER_DIR,
        "transport": "stdio",
    }
})

In [49]:
class Agent:
    """LangGraph-based router that picks the best-scoring LLM for a user query.

    The compiled graph runs three nodes in sequence: ``init`` (classify the
    query and fetch the model/domain catalogue from the tools),
    ``collect_metrics`` (map the query to a domain/subdomain, fetch per-model
    metrics and score them) and ``respond`` (render the outcome as a message).

    Note: ``system`` is stored on the instance but is not sent to the model by
    any method of this class.
    """

    # Intents for which per-model metrics are fetched and a best model is picked.
    _METRIC_INTENTS = ("metrics", "compare")
    _LIST_INTENT_RE = re.compile(r"\b(models|list|available)\b")
    _COMPARE_INTENT_RE = re.compile(r"\b(compare|better|best|prefer|efficient|memory|performance)\b")
    # Markdown / punctuation noise stripped from labels parsed out of LLM replies.
    _LABEL_NOISE = " \t\r*`'\".,;:"

    def __init__(self, model, tools, checkpointer, system=""):
        """Bind ``tools`` to ``model``, index the tools by name and compile the graph.

        Args:
            model: chat model exposing ``bind_tools`` and ``ainvoke``.
            tools: iterable of tool objects exposing ``name`` and ``ainvoke``.
            checkpointer: checkpointer passed to ``StateGraph.compile``.
            system: system prompt text (stored only).
        """
        self.system = system
        self.model = model.bind_tools(tools)
        self.tools = {t.name: t for t in tools}
        # Relative importance of each metric in the weighted score.
        self.metric_weights = {
            "accuracy": 0.7,
            "latency_ms": 0.2,
            "memory_mb": 0.1,
        }
        self.graph = self._build_graph(checkpointer)

    def _build_graph(self, cp):
        """Build and compile the linear init -> collect_metrics -> respond graph."""
        g = StateGraph(AgentState)
        g.add_node("init", self.init_step)
        g.add_node("collect_metrics", self.collect_metrics_step)
        g.add_node("respond", self.respond_step)
        g.add_edge("init", "collect_metrics")
        g.add_edge("collect_metrics", "respond")
        g.set_entry_point("init")
        return g.compile(checkpointer=cp)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _is_number(value):
        """Return True for real int/float values (bools are excluded)."""
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    @staticmethod
    def _as_list(value):
        """Coerce a tool result into a list.

        Lists/tuples are returned as lists. A string is parsed as JSON when it
        encodes a list; otherwise it is treated as a single item, so it is never
        iterated character by character.
        """
        if value is None:
            return []
        if isinstance(value, (list, tuple)):
            return list(value)
        if isinstance(value, str):
            text = value.strip()
            if not text:
                return []
            try:
                parsed = json.loads(text)
            except ValueError:
                return [text]
            if isinstance(parsed, list):
                return parsed
            if isinstance(parsed, str):
                return [parsed]
            return [text]
        return [value]

    @staticmethod
    def _message_text(content):
        """Flatten message content (a string or a list of parts) into one string."""
        if isinstance(content, str):
            return content
        if isinstance(content, (list, tuple)):
            parts = []
            for part in content:
                if isinstance(part, dict):
                    parts.append(str(part.get("text", "")))
                else:
                    parts.append(str(part))
            return "\n".join(parts)
        return "" if content is None else str(content)

    @staticmethod
    def _extract_label(text, field, known=()):
        """Extract the value following ``<field>:`` from an LLM reply.

        Surrounding markdown/punctuation is stripped and the search never runs
        past the end of the line. When the value starts with one of ``known``
        (case-insensitive) that prefix is returned, which allows multi-word
        names; otherwise the first whitespace-delimited token is returned.
        Returns None when the field is absent, empty, or the literal "none".
        """
        # The look-behind keeps "Domain" from matching inside a longer word.
        match = re.search(rf"(?<![A-Za-z]){field}:[ \t*]*([^\n]*)", text)
        if not match:
            return None
        label = match.group(1).strip(Agent._LABEL_NOISE)
        # Longest names first so "computer science" wins over "computer".
        for candidate in sorted((str(k) for k in known), key=len, reverse=True):
            if not candidate:
                continue
            hit = re.match(re.escape(candidate) + r"(?!\w)", label, re.IGNORECASE)
            if hit:
                return hit.group(0)
        tokens = label.split()
        token = tokens[0].strip(Agent._LABEL_NOISE) if tokens else ""
        if not token or token.lower() == "none":
            return None
        return token

    def _parse_tool_result(self, raw):
        """Turn a metrics tool result into a dict with normalised accuracy.

        Accepts a dict, a JSON string, or a Python-literal string. Anything that
        cannot be turned into a dict yields ``{}``.
        """
        data = raw
        if isinstance(raw, str):
            try:
                data = json.loads(raw)
            except ValueError:
                # Not JSON: fall back to a Python literal such as "{'accuracy': 55}".
                try:
                    data = ast.literal_eval(raw)
                except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                    data = {}
        if not isinstance(data, dict):
            return {}
        data = dict(data)  # do not mutate the tool's own object
        if "accuracy" in data:
            data["accuracy"] = self.normalize_accuracy(data["accuracy"])
        return data

    @staticmethod
    def _resolve_model_name(selected_llm, models):
        """Recover the model name from a metrics key such as ``<model>_<domain>_<sub>``.

        The known model names are matched as a prefix (longest first). When none
        matches, the first two underscore-separated parts are used.
        """
        for name in sorted((str(m) for m in models), key=len, reverse=True):
            if selected_llm == name or selected_llm.startswith(name + "_"):
                return name
        return "_".join(selected_llm.split("_")[:2])

    def _accuracy_table(self, metrics):
        """Return ``{key: accuracy}`` for every entry that has a numeric accuracy.

        A pre-built ``metrics["comparison"]`` mapping is honoured when present;
        otherwise the table is derived from the per-model metric dicts.
        """
        prebuilt = metrics.get("comparison")
        source = prebuilt if isinstance(prebuilt, dict) and prebuilt else metrics
        table = {}
        for name, entry in source.items():
            value = entry.get("accuracy") if isinstance(entry, dict) else entry
            if self._is_number(value):
                table[name] = value
        return table

    def _format_response(self, intent, context, metrics, models):
        """Render the human-readable answer text for ``respond_step``."""
        response = f"Based on the analysis and context:\nContext: {context}\n\n"
        if intent == "list_models":
            if models:
                response += f"Available models: {', '.join(str(m) for m in models)}\n"
            return response
        if intent not in ("compare", "metrics"):
            return response

        comparison = self._accuracy_table(metrics) if intent == "compare" else {}
        if comparison:
            best, best_acc = max(comparison.items(), key=lambda kv: kv[1])
            response += f"Best model: {best} (accuracy={best_acc:.2f})\n"
            response += "Model accuracies:\n" + "\n".join(f"{m}: {a:.2f}" for m, a in comparison.items()) + "\n"
        elif intent == "metrics" and metrics:
            response += "Metrics collected for models:\n"
            for m, v in metrics.items():
                response += f"{m}:\n"
                if isinstance(v, dict):
                    for metric_name, value in v.items():
                        response += f"  {metric_name}: {value}\n"
                response += "\n"
        else:
            response += "No data found for this query.\n"
        return response

    # ------------------------------------------------------------ public helpers

    def classify_intent(self, query) -> str:
        """Classify ``query`` (a string or list of strings) by keyword.

        Returns "list_models", "compare", or "metrics" (the default).
        """
        print("DEBUG: helper function 'classify_intent'")
        if isinstance(query, list):
            q = " ".join(map(str, query)).lower()
        else:
            q = str(query).lower()

        print(f"DEBUG: classify_intent received query: '{q}'")
        if self._LIST_INTENT_RE.search(q):
            return "list_models"
        if self._COMPARE_INTENT_RE.search(q):
            return "compare"
        return "metrics"

    async def map_query_to_domain_subdomain_llm(self, query, domains, subdomains):
        """Ask the model which domain/subdomain best matches ``query``.

        The prompt asks for up to two mappings; only the first one in the reply
        is used.

        Returns:
            ``(domain, subdomain, reason)``. ``domain`` and ``subdomain`` are
            None when the reply names none; ``subdomain`` is lower-cased.
        """
        print("DEBUG: helper function 'map_query_to_domain_subdomain_llm'")
        lines = []
        known_subs = []
        for d in domains:
            raw_subs = self._as_list(subdomains.get(d, []))
            known_subs.extend(raw_subs)
            subs = [str(s).lower() for s in raw_subs] or ["none"]
            lines.append(f"- {str(d).lower()}: {', '.join(subs)}")
        prompt = f"""
                You are an AI assistant.
                Available domains and subdomains:
                {chr(10).join(lines)}
                User query: "{query}"
                The query can map to multiple domains and subdomains. List top 2 closest(if any):
                Respond with EXACTLY two lines for each mapping:
                Domain: <one domain above>
                Subdomain: <one subdomain above or None>
                """
        resp = await self.model.ainvoke([HumanMessage(content=prompt)])
        text = self._message_text(resp.content)
        domain = self._extract_label(text, "Domain", domains)
        sub = self._extract_label(text, "Subdomain", known_subs)
        sub = sub.lower() if sub else None
        print(f"DEBUG: mapping result -> domain={domain}, subdomain={sub}")
        return domain, sub, f"LLM mapped to {domain}/{sub} or 'None'"

    def normalize_accuracy(self, v):
        """Convert a percentage-style accuracy (> 1) to a fraction; pass others through."""
        return v / 100 if isinstance(v, (int, float)) and v > 1 else v

    def normalize_metric(self, value, metric_type, all_values):
        """Normalize a metric to [0,1] based on min-max scaling.

        For "latency_ms" and "memory_mb" lower is better, so the scale is
        inverted. Returns 0.0 when ``all_values`` is empty and 1.0 when all
        values are equal.
        """
        if not all_values:
            return 0.0
        min_val = min(all_values)
        max_val = max(all_values)
        if max_val == min_val:
            return 1.0  # Avoid division by zero
        if metric_type in ["latency_ms", "memory_mb"]:
            return (max_val - value) / (max_val - min_val)
        return (value - min_val) / (max_val - min_val)

    def select_best_llm(self, metrics, key):
        """Pick the entry of ``metrics`` with the highest weighted score.

        Only metrics that an entry actually reports take part in the scoring: a
        missing latency or memory figure is not treated as 0 (which would be the
        best possible value), and entries without any numeric metric are ignored.

        Args:
            metrics: mapping of key -> metric dict.
            key: unused; kept for interface compatibility.

        Returns:
            ``(best_key, reason)``; ``best_key`` is None when nothing can be scored.
        """
        print("DEBUG: helper function 'select_best_llm'")
        if not metrics:
            return None, "No metrics data available"

        # Observed values per metric, used as the min-max normalisation range.
        observed = {metric: [] for metric in self.metric_weights}
        for data in metrics.values():
            if not isinstance(data, dict):
                continue
            for metric, bucket in observed.items():
                value = data.get(metric)
                if self._is_number(value):
                    bucket.append(value)

        scores = {}
        for model_key, data in metrics.items():
            if not isinstance(data, dict):
                continue
            reported = {
                metric: data[metric]
                for metric in self.metric_weights
                if self._is_number(data.get(metric))
            }
            if not reported:
                continue
            score = 0.0
            for metric, value in reported.items():
                score += self.metric_weights[metric] * self.normalize_metric(value, metric, observed[metric])
            scores[model_key] = score
            print(f"DEBUG: weighted score: {score} for {model_key}")

        if not scores:
            return None, "No valid metrics for comparison"

        best_model = max(scores.items(), key=lambda x: x[1])[0]
        return best_model, f"Selected {best_model} based on weighted metrics (accuracy: {self.metric_weights['accuracy']}, latency: {self.metric_weights['latency_ms']}, memory: {self.metric_weights['memory_mb']})"

    # -------------------------------------------------------------- graph nodes

    async def init_step(self, state: "AgentState"):
        """Graph node: classify the latest message and load the catalogue.

        Calls the ``list_models``, ``list_domains`` and ``list_sub_domains``
        tools and resets the per-run fields of the state.
        """
        print("DEBUG: Entering node INIT")
        query = state["messages"][-1].content
        intent = self.classify_intent(query)

        print("DEBUG: fetching models/domains/subdomains")
        models = self._as_list(await self.tools["list_models"].ainvoke({}))
        domains = self._as_list(await self.tools["list_domains"].ainvoke({}))
        subdomains = {}
        for d in domains:
            subdomains[d] = self._as_list(
                await self.tools["list_sub_domains"].ainvoke({"domain": d})
            )
        return {
            "step": "init",
            "query": query,
            "intent": intent,
            "models": models,
            "domains": domains,
            "subdomains": subdomains,
            "context": "",
            "metrics": {},
            "selected_llm": None,
            "messages": []
        }

    async def collect_metrics_step(self, state: "AgentState"):
        """Graph node: map the query to a domain, fetch metrics and pick a model.

        Metrics are fetched for the "metrics" and "compare" intents. A failure
        for one model is logged and does not stop the others. The incoming
        ``state`` is not mutated; all changes are returned as an update.
        """
        print("DEBUG: Entering node collect_metrics_step")
        query = state["query"]
        intent = state["intent"]
        metrics = {}
        messages = []
        selected_llm = state.get("selected_llm")
        context = state.get("context")

        domain, subdomain, mapping_reason = await self.map_query_to_domain_subdomain_llm(
            query, state["domains"], state["subdomains"]
        )
        messages.append(ToolMessage(
            tool_call_id="mapping",
            name="mapping",
            content=mapping_reason
        ))

        if intent in self._METRIC_INTENTS and domain:
            api_domain = domain.title()
            api_sub = subdomain.lower() if subdomain else None
            tool_name = "get_metric_domain_subdomain" if api_sub else "get_metric_domain"

            for model_name in state["models"]:
                print(f"DEBUG: Calling metrics for {model_name} in {api_domain}/{api_sub}")
                try:
                    args = {"model": model_name, "domain": api_domain}
                    if api_sub:
                        args["sub_domain"] = api_sub
                    raw = await self.tools[tool_name].ainvoke(args)

                    data = self._parse_tool_result(raw)
                    print(f"DEBUG: Printing data for {model_name}\n")
                    print(data)
                    print("\n")
                    key = f"{model_name}_{api_domain}_{api_sub}" if api_sub else f"{model_name}_{api_domain}"
                    metrics[key] = data
                    messages.append(ToolMessage(
                        tool_call_id=f"metrics_{key}",
                        name=tool_name,
                        content=str(data)
                    ))
                except Exception as e:
                    # Best effort: keep going with the remaining models.
                    print(f"Error fetching metrics for {model_name}: {e}")
            try:
                best_model, reason = self.select_best_llm(metrics, api_sub or api_domain)
                print(f"DEBUG: Best model: {best_model}\nReason:{reason}")
                selected_llm = best_model
                context = reason
            except Exception as e:
                print(f"Error selecting best model: {e}")
                selected_llm = None
                context = "Error during model selection."

        return {
            "step": "metrics",
            "metrics": metrics,
            "messages": messages,
            "selected_llm": selected_llm,
            "context": context
        }

    async def respond_step(self, state: "AgentState"):
        """Graph node: render the outcome and append it to the conversation.

        Prints the selected model (as before) and returns the full answer text
        as an ``AIMessage`` whose last line is ``Model: <name>`` when a model
        was selected.
        """
        print("DEBUG: Entering node respond_step\n")
        intent = state.get("intent")
        metrics = state.get("metrics") or {}
        context = state.get("context", "")
        selected_llm = state.get("selected_llm")
        models = state.get("models") or []
        print(f"Responding with intent='{intent}', context={context}, selected_llm={selected_llm}, metrics={metrics}")

        response = self._format_response(intent, context, metrics, models)

        if selected_llm:
            model_name = self._resolve_model_name(selected_llm, models)
            print(f"\nModel: {model_name}")
            response += f"Model: {model_name}\n"

        return {
            "step": "respond",
            "messages": [AIMessage(content=response)]
        }

In [51]:
# System prompt describing the router's role; main() passes it to
# Agent(..., system=prompt).
# NOTE(review): Agent.__init__ stores this text as `self.system`, but no Agent
# method visible in this notebook reads that attribute, so the text does not
# appear to reach the model -- confirm whether that is intended.
prompt = """
You are an Intelligent router responsible for handling queries within a Multi-Agent System (MAS) framework.

Upon receiving a query from the MAS, your task is to:
1. Accurately classify the query's intent.
2. Identify the relevant domain(s) and subdomain(s) associated with the query.
3. Utilize available tool calls through the MCP server to:
   - Retrieve necessary information such as available models, domains, and subdomains.
   - Fetch and normalize model performance metrics.
   - Select the most suitable LLM(s) capable of addressing the query effectively.
4. If the query spans multiple domains or subdomains, you are allowed to make multiple tool calls to ensure optimal model selection.
5. Your response must only contain the selected model's name in the following format: "Model:___"
   
"""
async def main(query=None, thread_id="4"):
    """Build the routing agent, run one query through its graph and return the final state.

    Args:
        query: the user query, as a string or a sequence of strings (joined
            with spaces). Defaults to the sample geometry question.
        thread_id: checkpointer thread id for this run.

    Returns:
        The final graph state returned by ``agent.graph.ainvoke``. In a
        notebook, assign the result (or end the call with ``;``) to avoid
        echoing the whole state.
    """
    # Other sample queries:
    # "Please respond the answers related to titration experiment?"
    # "Please respond the answers related to : My laptop keeps crashing when I open the video editing software!"
    if query is None:
        query = "please respond the answers related to historical geometry?"
    elif not isinstance(query, str):
        # The state schema declares `query: str`; passing a list would put the
        # list's repr into the LLM prompt.
        query = " ".join(str(part) for part in query)

    tools = await client.get_tools()
    model = ChatTogether(model_name="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free", temperature=0.0)
    agent = Agent(model, tools, checkpointer=memory, system=prompt)

    thread = {"configurable": {"thread_id": thread_id}}
    state = {
        "messages": [HumanMessage(content=query)],
        "step": "init",
        "models": [],
        "domains": [],
        "subdomains": {},
        "query": query,
        "intent": "",
        "metrics": {},
        "context": "",
        "selected_llm": None
    }
    return await agent.graph.ainvoke(state, thread)

# Top-level `await` relies on the notebook (IPython) kernel's event loop;
# in a plain Python script this would be `asyncio.run(main())` instead.
await main()

DEBUG: Entering node INIT
DEBUG: helper function 'classify_intent'
DEBUG: classify_intent received query: 'please respond the answers related to historical geometry?'
DEBUG: fetching models/domains/subdomains
DEBUG: Entering node collect_metrics_step
DEBUG: helper function 'map_query_to_domain_subdomain_llm'
DEBUG: mapping result -> domain=mathematics, subdomain=geometry
DEBUG: Calling metrics for Llama3.2_1B in Mathematics/geometry
DEBUG: Printing data for Llama3.2_1B

{'accuracy': 0.55, 'latency_ms': 7466.44, 'memory_mb': 0.0, 'cpu_ms': 1.2904, 'total_duration_ms': 7466.44, 'load_duration_ms': 91.056, 'prompt_eval_count': 52.875, 'prompt_eval_duration_ms': 417.27, 'eval_count': 149.275, 'eval_duration_ms': 6958.106}


DEBUG: Calling metrics for Qwen2.5_1.5B in Mathematics/geometry
DEBUG: Printing data for Qwen2.5_1.5B

{'accuracy': 0.675, 'latency_ms': 14657.977, 'memory_mb': 0.0, 'cpu_ms': 1.642, 'total_duration_ms': 14649.94, 'load_duration_ms': 69.45, 'prompt_eval_count': 58.0, 'p