## LLM-as-a-Judge evaluation 

In this notebook, we test some approaches to evaluating the responses Kai generates for different migration scenarios, using an LLM as a judge.

The responses are generated by running Kai manually in VS Code with different models. The diff files of the responses are stored as artifacts for simplicity.

#### Goal

The goal of this exercise is to:
- verify & finalize the evaluation metrics
- finalize the evaluation prompts

### Process

Run Tools  -->  Run Kai (manual, responses collected)  -->  Run Tools  -->  Evaluate

#### Running tools

In this notebook, we run the following tools before and after applying fixes from Kai:
- `mvn compile`
- `mvn test` (only if tests are present for a given test case)
- `kantra analyze` (static analysis with kantra)

#### Evaluate

For evaluating, three metrics are calculated:

1. Completeness (C): Measures whether the issue is completely resolved
2. Functional Parity (F): Measures whether existing functionality is maintained
3. Knock-on Effort (E): Measures how much effort is needed to address new issues caused by the fix

Each metric is normalized, and the weighted sum is scaled to a final score between 0 and 10:

```
Final Score = 10 * (0.5 * C + 0.3 * F + 0.2 * E)
```
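
As a worked example of the formula, the sketch below combines three metric values into the final score. The function name `final_score` and the sample inputs are illustrative assumptions, not part of the notebook's code. It also assumes each metric is already normalized to the range 0 to 1, with higher values being better (for E, that would mean less knock-on effort).

```python
def final_score(c: float, f: float, e: float) -> float:
    """Weighted 0-10 score; c, f and e are assumed to be normalized to [0, 1]."""
    for name, value in (("C", c), ("F", f), ("E", e)):
        if not 0.0 <= value <= 1.0:
            raise ValueError(f"{name} must be in [0, 1], got {value}")
    return 10 * (0.5 * c + 0.3 * f + 0.2 * e)


# Hypothetical inputs: completeness 0.9, full functional parity, no knock-on effort
print(round(final_score(0.9, 1.0, 1.0), 2))  # 9.5
```

Because completeness carries half of the weight, a fix that leaves the issue unresolved cannot score above 5 even when the other two metrics are perfect.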

#### Pre-requisites

- Create a virtualenv with Jupyter for running the cells in this notebook. Install the required dependencies from [requirements.txt](../requirements.txt).
- Copy the `.env.sample` file to `.env`. Select the model you want to use for evaluation (only Bedrock and ChatOpenAI are supported) and add the LLM key for that model. Once set up, run the following cell to load the `.env` file.
- Make sure you have _java_ and _mvn_ installed.
- Download the latest _kantra_ binary.
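
A quick way to confirm the tooling prerequisites is sketched below. The helper `check_prerequisites` is hypothetical, and the kantra location is an assumption that matches the default used by the common cell in this notebook; adjust it for your system.

```python
import shutil
from pathlib import Path


def check_prerequisites(kantra_binary: Path, commands: tuple[str, ...] = ("java", "mvn")) -> dict[str, bool]:
    """Report which prerequisites are available on this machine."""
    # shutil.which returns None when a command is not on PATH
    status = {cmd: shutil.which(cmd) is not None for cmd in commands}
    status["kantra"] = kantra_binary.is_file()
    return status


# Assumed location of the kantra binary (~/.kantra/kantra)
for name, ok in check_prerequisites(Path.home() / ".kantra" / "kantra").items():
    print(f"{name}: {'found' if ok else 'MISSING'}")
```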


In [1]:
%load_ext dotenv
%dotenv
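
Once `.env` is loaded, a check like the one below can flag missing settings before any model call is made. This is a minimal sketch: the variable names are taken from `get_model()` in this notebook, while `REQUIRED_SETTINGS` and `missing_settings` are hypothetical names introduced for illustration. For the OpenAI provider it only looks at the `api_key` variable.

```python
import os

# Variable names as read by get_model() in this notebook
REQUIRED_SETTINGS = {
    "chatbedrock": ["model_id", "aws_access_key_id", "aws_secret_access_key", "region"],
    "chatopenai": ["model_id", "api_key"],
}


def missing_settings(env: dict[str, str]) -> list[str]:
    """Return the names of settings that are unset or empty for the selected provider."""
    provider = env.get("model_provider", "")
    if provider not in REQUIRED_SETTINGS:
        # Provider itself is missing or unsupported
        return ["model_provider"]
    return [name for name in REQUIRED_SETTINGS[provider] if not env.get(name)]


print("Missing settings:", missing_settings(dict(os.environ)))
```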

Before we begin, we write some common code in the following cell which we will need later on. Run this cell before proceeding.

In [6]:
# some common functions we will need later on for the evaluation
import os
import sys
import yaml
import subprocess
from git import Repo
from pathlib import Path
from langchain_openai import ChatOpenAI
from langchain_aws import ChatBedrockConverse

test_data_path = Path("test-data").absolute()
apps_repo_path = (test_data_path / "apps").absolute()
artifacts_path = (test_data_path / "artifacts").absolute()
apps_repo_path.mkdir(parents=True, exist_ok=True)
artifacts_path.mkdir(parents=True, exist_ok=True)
## NOTE: make sure this is correct for your system
kantra_path = Path.home() / ".kantra" / "kantra"
## NOTE: set these values
JAVA_HOME = Path("/usr/lib/jvm/java-21-openjdk/")
JAVA_BIN = JAVA_HOME / "bin"

def clone_repo(url: str, branch: str, path: Path):
    try:
        Repo.clone_from(url, depth=1, single_branch=True, branch=branch, to_path=path)
    except Exception as e:
        if "already exists" not in str(e):
            print("fatal error cloning repo")
            sys.exit(1)

def get_model():
    provider = os.getenv("model_provider")
    model_id = os.getenv("model_id")
    if not model_id or not provider:
        raise ValueError("model_id and/or model_provider are not set")
    match provider:
        case "chatbedrock":
            key_id = os.getenv("aws_access_key_id")
            access_key = os.getenv("aws_secret_access_key")
            region = os.getenv("region")
            if not region or not access_key or not key_id:
                raise ValueError("region and/or aws_secret_access_key and/or aws_access_key_id is not set")
            return ChatBedrockConverse(
                model_id=model_id,
                aws_access_key_id=key_id,
                aws_secret_access_key=access_key,
                region_name=region,
                temperature=0.0,
            )
        case "chatopenai":
            api_key = os.getenv("OPENAI_API_KEY") or os.getenv("api_key")
            if not api_key:
                raise ValueError("OPENAI_API_KEY or api_key is not set")
            return ChatOpenAI(model=model_id, api_key=api_key, temperature=0.0)
        case _:
            raise ValueError(f"Invalid model provider: {provider}")

def parse_yaml(path: Path):
    with open(path, "r") as f:
        return yaml.safe_load(f)

def clone_app(app_path: Path) -> Path:
    parsed = parse_yaml(app_path)
    repo_url = parsed["source_code"]["git"]["url"]
    branch = parsed["source_code"]["git"]["branch"]
    path = Path("test-data") / "apps" / parsed["name"]
    clone_repo(repo_url, branch, path)
    return path.absolute()

def parse_from_tc(tc_path: str, key: str):
    parsed = parse_yaml(tc_path)
    return parsed[key]

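def run_command(cmd: list[str], stdout_path: Path, stderr_path: Path, cwd: Path, env_vars: dict[str, str] | None = None):
    # Extend the current environment instead of replacing it, so PATH and friends stay visible to the child
    env = {**os.environ, **(env_vars or {})}
    # Pass cwd to subprocess rather than changing the notebook's own working directory
    result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, cwd=cwd, env=env)
    with open(stdout_path, "w") as f:
        f.write(result.stdout)
    with open(stderr_path, "w") as f:
        f.write(result.stderr)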

def apply_diff(diff_path: Path, app_path: Path):
    repo = Repo(app_path)
    repo.git.apply(diff_path)

def git_reset(app_path: Path):
    repo = Repo(app_path)
    repo.git.reset("--hard")

def run_mvn(tc_name: str, app_path: Path, output_sub_dir: Path = Path(""), env_vars: dict[str, str] = {}):
    output_dir = (artifacts_path / tc_name / output_sub_dir / "mvn").absolute()
    output_dir.mkdir(parents=True, exist_ok=True)
    stdout_path = os.path.join(output_dir, "mvn_compile.log")
    stderr_path = os.path.join(output_dir, "mvn_compile.err")
    run_command(["mvn", "compile"], stdout_path, stderr_path, app_path, env_vars)

def run_mvn_test(tc_name: str, app_path: Path, test_selectors: list[str], output_sub_dir: Path = Path(""), env_vars: dict[str, str] = {}):
    output_dir = (artifacts_path / tc_name / output_sub_dir / "mvn").absolute()
    output_dir.mkdir(parents=True, exist_ok=True)
    stdout_path = os.path.join(output_dir, "mvn_test.log")
    stderr_path = os.path.join(output_dir, "mvn_test.err")
    mvn_test_cmd = ["mvn", "test"]
    if test_selectors:
        # Surefire expects a single -Dtest=<selectors> argument; multiple selectors are comma-separated
        mvn_test_cmd.append("-Dtest=" + ",".join(test_selectors))
    run_command(mvn_test_cmd, stdout_path, stderr_path, app_path, env_vars)

def run_kantra(tc_name: str, app_path: Path, targets: list[str], output_sub_dir: Path = Path(""), env_vars: dict[str, str] = {}):
    output_dir = (artifacts_path / tc_name / output_sub_dir / "kantra").absolute()
    output_dir.mkdir(parents=True, exist_ok=True)
    stdout_path = os.path.join(output_dir, "kantra.log")
    stderr_path = os.path.join(output_dir, "kantra.err")
    kantra_cmd = [kantra_path, "analyze", "--overwrite", "--input", app_path, "--output", output_dir]
    for target in targets:
        kantra_cmd.append("--target")
        kantra_cmd.append(target.strip('"'))
    # Keep caller-supplied variables and point kantra at the configured JDK
    env_vars = {**os.environ, **env_vars}
    env_vars["JAVA_HOME"] = str(JAVA_HOME)
    env_vars["PATH"] = os.environ["PATH"] + os.pathsep + str(JAVA_BIN)
    run_command(kantra_cmd, stdout_path, stderr_path, app_path, env_vars)

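METRICS = {
    # COMPLETE is one of the ratings the completeness prompt can return; a cost of 0 normalizes to 1.0
    "COMPLETE": 0,
    "TRIVIAL_CHANGES_NEEDED": 1,
    "COMPLEX_CHANGES_NEEDED": 3,
    "REDESIGN_NEEDED": 9,
    "NOT_FIXED": 10,
}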

In the following cell, we have our LLM-as-a-Judge code. Run this cell before proceeding.

In [25]:
import re
from git import Repo, Head
from pathlib import Path
from typing import Annotated
from dataclasses import dataclass, field
from langchain_core.tools import tool
from langgraph.prebuilt import create_react_agent

@dataclass
class EvaluationTools:
    app_path: Path
    repo_path: Path 
    tc_path: Path
    diff_path: Path
    changed_filepaths: list[str] = field(default_factory=list)
    changed_filelist: list[str] = field(default_factory=list)
    changed_files: dict[str, str] = field(default_factory=dict)
    architecture_md: str = "Not available"

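    def init(self):
        # architecture.md is optional; keep the "Not available" default when it is missing
        architecture_path = self.app_path / "architecture.md"
        if architecture_path.exists():
            self.architecture_md = architecture_path.read_text()
        self.repo = Repo(self.repo_path)
        self.repo.git.reset("--hard")
        self.repo.git.apply(self.diff_path)
        # One path per line; splitlines() keeps paths that contain spaces intact
        self.changed_filepaths = self.repo.git.diff("--name-only").splitlines()
        # Snapshot of every file in the repo while the diff is applied
        self.changed_filelist = [str(p.relative_to(self.repo_path)) for p in Path(self.repo_path).rglob("*") if p.is_file()]
        for path in self.changed_filepaths:
            self.changed_files[path] = (self.repo_path / Path(path)).read_text()
        # Restore the pre-migration state so later reads from disk return the original files
        self.repo.git.reset("--hard")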

    def get_architecture_tool(self):
        architecture_md = self.architecture_md
        @tool
        def get_architecture():
            """Returns the architecture of the application in markdown format"""
            return architecture_md
        return get_architecture
    
    def get_changed_files_tool(self):
        changed_filepaths = self.changed_filepaths
        @tool
        def get_changed_files():
            """Returns the list of files that have been changed after applying the fixes"""
            return "\n".join(changed_filepaths)
        return get_changed_files
    
    def get_file_content_tool(self):
        changed_files = self.changed_files
        repo_path = self.repo_path
        @tool
        def get_file_content(
            file_path: Annotated[str, "Relative path to the file"],
            pre_migration: Annotated[bool, "Return a pre-migration version, defaults to False"] = False):
            """Returns the current content of a file, if pre_migration is set, returns its old content"""
            if pre_migration or file_path not in changed_files:
                return (repo_path / file_path).read_text()
            return changed_files[file_path]
        return get_file_content

    def list_files_tool(self):
        changed_filelist = self.changed_filelist
        repo_path = self.repo_path
        @tool
        def list_files(
            pre_migration: Annotated[bool, "Set to True to return pre-migration list of files, defaults to False"] = False):
            """Returns a list of files in the source code"""
            if pre_migration:
                return "\n".join([str(p.relative_to(repo_path)) for p in Path(repo_path).rglob("*") if p.is_file()])
            return "\n".join(changed_filelist)
        return list_files

def run_completeness_agent(
    issue: str,
    issue_notes: str,
    source_tech: str,
    target_tech: str,
    app_path: Path,
    repo_path: Path,
    tc_path: Path,
    diff_path: Path,
) -> float:
    COMPLETENESS_PROMPTS = {
    "system": """You are a senior software engineer expert in migrating applications from one technology to another.""",
    "user_summarize": """You are reviewing a code change made to fix a migration issue identified by a static analysis tool.
The application is being migrated from {source_tech} to {target_tech}.
You are provided with:
- The original issue description
- Developer notes describing the intended fix (consider these the only source of truth)
- Access to a set of tools that can inspect the codebase

## Your task

Your goal is to evaluate whether the migration issue has been completely fixed by comparing:
- What the notes say should have been done, and
- What was actually done in the codebase.

## Evaluation Rules
- Stick strictly to the notes when comparing. They are the authoritative reference.
- Do not assume additional requirements, behaviors, or conventions.
- Use tools thoughtfully. Only gather information necessary to confirm alignment between the notes and the actual code changes.
- Be specific. When identifying missing or incomplete work, clearly describe what evidence led you to that conclusion.
- Be concise but thorough. Focus on correctness and completeness.
- Refer explicitly to evidence from the codebase when making your assessment.
- Avoid subjective language like "it seems" or "probably".

## Output Format

If the issue is completely fixed, respond in the following format:

```
<Brief summary explaining how the applied changes match the notes and why they fully resolve the issue.>

COMPLETELY_FIXED
```

If the issue is not completely fixed, respond in the following format:

```
<Brief summary explaining why the fix is incomplete or incorrect, and what parts are missing or deviate from the notes.>

## Issues

1. <First issue summary>
2. <Second issue summary>
...
```

Each issue should represent a *logically distinct* missing or incorrect aspect of the fix. Do not produce duplicate issues. Group related issues together logically.

Here are your inputs:
## The issue identified was:
{issue}

## Here are the notes about the issue:
{issue_notes}
""",
    "user_rate": """We used a static analysis tool to identify a migration issue in the application which needs to be migrated from {source_tech} to {target_tech}.
We fixed the issue and asked a senior software engineer to review the changes. The engineer reviewed the changes and provided a summary of the changes made.

## Your task

Your goal is to:

- Analyze each *distinct unresolved issue* described in the summary.
- Assess the complexity of fixing each issue for the given application.
- Finally, provide a rating for "completeness" of the fix based on the following scale:
    - TRIVIAL_CHANGES_NEEDED: Requires only a few small, localized modifications to files in the codebase (e.g., renaming, updating a parameter, minor logic tweak, or adding a missing import).
    - COMPLEX_CHANGES_NEEDED: Requires multiple related changes across files, complex logic changes, or coordination between components or services.
    - REDESIGN_NEEDED: Indicates a fundamental design or architectural flaw requiring a significant refactor or reimplementation of core logic.
    - COMPLETE: No issues remain; the migration fix is fully complete.

## Guidelines
- Focus only on issues explicitly described in the summary.
- Do not invent or infer new issues beyond what is stated.
- Use architectural reasoning (e.g., cross-file dependencies, data flow, API contracts) to decide the appropriate rating.
- Each issue should have a short justification describing why it falls into that category.
- Maintain objectivity: avoid vague or subjective language like "probably complex" or "seems fine."
- If there are no issues identified in the summary, rate the migration as COMPLETE.
- Look at the architecture of the application to help you make the decision.

Produce your response in the following format:

```
<Brief reasoning behind the rating.>

Rating: <RATING>
```

Here are your inputs:

## Issue we fixed
{issue}

## Summary
{summary}
""",
}

    tool_factory = EvaluationTools(app_path, repo_path, tc_path, diff_path)
    tool_factory.init()
    tools = [
        tool_factory.get_architecture_tool(),
        tool_factory.get_changed_files_tool(),
        tool_factory.get_file_content_tool(),
    ]
    agent = create_react_agent(
        model=get_model(),
        tools=tools,
        prompt=COMPLETENESS_PROMPTS["system"],
    )
    response = agent.invoke(
        {
            "messages": [
                {
                    "role": "user",
                    "content": COMPLETENESS_PROMPTS["user_summarize"].format(
                        source_tech=source_tech,
                        target_tech=target_tech,
                        issue=issue,
                        issue_notes=issue_notes,
                    )
                },
            ]
        }
    )
    rate_agent = create_react_agent(
        model=get_model(),
        tools=[tool_factory.get_architecture_tool()],
        prompt=COMPLETENESS_PROMPTS["system"],
    )
    response = rate_agent.invoke(
        {
            "messages": [
                {
                    "role": "user",
                    "content": COMPLETENESS_PROMPTS["user_rate"].format(
                        source_tech=source_tech,
                        target_tech=target_tech,
                        issue=issue,
                        summary=response["messages"][-1].content,
                    )
                },
            ]
        }
    )
    content = response["messages"][-1].content
    match = re.search(r"Rating:\s*([^\n\r]+)", content)
    rating = -1
    if match:
        try:
            val = int(METRICS.get(match.group(1).strip(), -1))
            if val == -1:
                return val
            rating = (max(METRICS.values()) - int(val)) / max(METRICS.values())
        except ValueError:
            pass
    return rating

def run_functional_correctness_agent(
    issue: str,
    issue_notes: str,
    source_tech: str,
    target_tech: str,
    repo_path: Path,
    tc_path: Path,
    diff_path: Path,
):
    FUNCTIONAL_CORRECTNESS_PROMPTS = {
        "system": "You are a senior engineering expert in migrating source code as well as reviewing source code that is migrated from {source_tech} to {target_tech}.",
        "user_summarize": """You are evaluating whether the migrated code behaves functionally equivalent to its pre-migration behavior and/or meets the acceptance criteria defined for the migration.

You are given:
- The original migration issue that was intended to be fixed.
- Detailed notes describing the expected fix and acceptance criteria (this is the sole source of truth).
- Various tools to access post migration artifacts such as changed files, contents of pre and post migration files, results of behavioral tests (if present) and more.

## Your task

Determine whether the migrated code preserves the functional behavior of the pre-migration code.
Do not rely on assumptions, intuition, or unstated expectations.
Follow a deterministic and evidence-driven process:

- Examine pre-migration code to identify control flow, data flow, API contracts, input constraints, and edge-case handling.
- Refer to architecture.md (if available) to understand context and dependencies.
- Use the migration notes as the only authoritative specification for what the fix should achieve.
- Do not invent new requirements or infer functionality not explicitly mentioned.
- For each described behavior, check that the migrated version maintains equivalent logic, inputs/outputs, and side effects.
- Pay attention to API signatures, data transformations, validation rules, and error handling.
- Use behavioral tests to understand the expected runtime behavior.
- If tests fail, determine why:
  - If due to test harness or environment adaptation (e.g., dependency mismatch, configuration differences), this does not count against equivalence.
  - If due to missing or altered functionality, consider this a deviation.
- You are assessing functional parity, not quality, maintainability, or completeness of the migration itself.
- Your conclusion should be based solely on behavior and conformance to the notes.

## Your output

Finally, rate "functional equivalence" on a scale defined below:
* EQUIVALENT
  - The migrated code preserves all functional behavior compared to its pre-migration version.
  - Any test failures are attributable only to the harness or environment, not to business logic or data flow differences.
* SOMEWHAT_EQUIVALENT
  - The core behavior is largely preserved, but minor functional gaps exist that could be fixed with small code tweaks.
  - Tests fail due to small missing pieces or outdated expectations, but the intended behavior remains intact.
* NOT_EQUIVALENT
  - Critical functionality is altered or missing in the migrated version of the code.
  - Behavioral expectations in the notes or pre-migration logic are not met.
  - Test failures reflect true functional regressions, not harness issues.

Produce your output in the format below:

```
<Brief, objective reasoning for your rating. Reference specific files, functions, or behaviors as evidence.>

Rating: <EQUIVALENT | SOMEWHAT_EQUIVALENT | NOT_EQUIVALENT>
```

Here are your inputs:

## Original issue

{issue}

## Issue notes

{issue_notes}

## Applicable behavioral tests
{}
""",
    }

    pass


## Scenario 1: Ehcache 2 to 3 upgrade

| App         | Complexity |
|-------------|------------|
| Petclinic   |    High    |

See description of the issue found [here](../apps/petclinic/test_cases/ehcache-2-to-3/tc.yaml).

See notes on expected fix [here](../apps/petclinic/test_cases/ehcache-2-to-3/notes.md).


In [4]:
# run tools -> apply diff -> run tools

tc_1_name = "ehcache-2-to-3"
tc_1_app_path = Path("../apps/petclinic/app.yaml").absolute()
tc_1_path = Path("../apps/petclinic/test_cases/ehcache-2-to-3/tc.yaml").absolute()
tc_1_selectors = parse_from_tc(tc_1_path, "testSelectors")
tc_1_targets = parse_from_tc(tc_1_app_path, "targets")
tc_1_repo_path = clone_app(tc_1_app_path)
run_mvn(tc_1_name, tc_1_repo_path, "before")
run_mvn_test(tc_1_name, tc_1_repo_path, tc_1_selectors, "before")
run_kantra(tc_1_name, tc_1_repo_path, tc_1_targets, "before")

tc_1_agent_diff = Path("test-data/diffs/ehcache-2-to-3/gpt_4o_agent.diff").absolute()
tc_1_non_agent_diff = Path("test-data/diffs/ehcache-2-to-3/gpt_4o_non_agent.diff").absolute()
apply_diff(tc_1_agent_diff, tc_1_repo_path)
run_mvn(tc_1_name, tc_1_repo_path, "after")
run_mvn_test(tc_1_name, tc_1_repo_path, tc_1_selectors, "after")
run_kantra(tc_1_name, tc_1_repo_path, tc_1_targets, "after")
git_reset(tc_1_repo_path)

We now have all the data we need from before and after the fix. Next, we use an LLM to evaluate the responses.

In [None]:
## Evaluating for completeness metric
tc_1_notes = (tc_1_path.parent / "notes.md").read_text()
tc_1_issue = parse_from_tc(tc_1_path, "description")
print("Completeness score: ", 
    run_completeness_agent(tc_1_issue, tc_1_notes, 
    "Spring Framework 5", "Spring Framework 6", 
    tc_1_app_path.parent, tc_1_repo_path, tc_1_path, tc_1_agent_diff))

Completeness score:  0.9
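
A score of 0.9 is what the normalization in `run_completeness_agent` yields for a `TRIVIAL_CHANGES_NEEDED` rating. The standalone sketch below mirrors that step: it extracts the `Rating:` line from a reply and computes `(max - cost) / max`. The cost values are copied from the `METRICS` table in the common cell; `RATING_COST`, `score_from_reply`, and the sample reply text are hypothetical.

```python
import re

# Cost per rating, copied from the METRICS table in the common cell
RATING_COST = {
    "TRIVIAL_CHANGES_NEEDED": 1,
    "COMPLEX_CHANGES_NEEDED": 3,
    "REDESIGN_NEEDED": 9,
    "NOT_FIXED": 10,
}


def score_from_reply(reply: str) -> float:
    """Return (max - cost) / max for the rating in the reply, or -1 if it cannot be parsed."""
    match = re.search(r"Rating:\s*([^\n\r]+)", reply)
    if not match:
        return -1
    cost = RATING_COST.get(match.group(1).strip())
    if cost is None:
        return -1
    worst = max(RATING_COST.values())
    return (worst - cost) / worst


# Hypothetical judge reply
print(score_from_reply("Only an import is missing.\n\nRating: TRIVIAL_CHANGES_NEEDED"))  # 0.9
```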
