Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/code eval metrics #29

Merged
merged 7 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions continuous_eval/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,7 @@
)
from continuous_eval.metrics.retrieval_precision_recall_f1 import PrecisionRecallF1
from continuous_eval.metrics.retrieval_ranked_metrics import RankedRetrievalMetrics
from continuous_eval.metrics.code_deterministic_metrics import (
CodeStringMatch,
PythonASTSimilarity,
)
319 changes: 319 additions & 0 deletions continuous_eval/metrics/code_deterministic_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,319 @@
import ast

from munkres import Munkres
from thefuzz import fuzz

from continuous_eval.metrics.base import Metric


class CodeStringMatch(Metric):
def calculate(self, answer, ground_truths, **kwargs):
max_exact_match = 0
max_similarity_score = 0
for gt in ground_truths:
exact_match = float(answer == gt)
similarity_score = fuzz.ratio(answer, gt) / 100
if exact_match > max_exact_match:
max_exact_match = exact_match
if similarity_score > max_similarity_score:
max_similarity_score = similarity_score
return {
"Exact_Match_Score": max_exact_match,
"Fuzzy_Match_Score": max_similarity_score,
}


class PythonASTSimilarity(Metric):
"""
The following functions are adapted from python-ast-comparison by Pedro Salazar Paredes
Copyright (c) 2023 Pedro Salazar Paredes
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The copyright year in the docstring seems to be incorrect. It is mentioned as 2023, which is a future year. Please correct it.

Licensed under the MIT License
Source: https://github.com/PedroSalazarParedes/python-ast-comparison
Modifications: Adjusted to be used in the context of generated code evaluation
"""

def _compare_ASTs(self, ast_a: ast.AST, ast_b: ast.AST, reorder_depth: int) -> int:
"""Compare two ASTs corresponding to python programs.

Args:
ast_a: The first program AST to compare.
ast_b: The first program AST to compare.
reorder_depth: The maximum children reorder depth for better
performance.

Returns:
The number of matching nodes in the ASTs.
"""
children_a = list(ast.iter_child_nodes(ast_a))
children_b = list(ast.iter_child_nodes(ast_b))
if (type(ast_a) == type(ast_b)) and len(list(children_a)) == 0 and len(list(children_b)) == 0:
return 1

if (type(ast_a) != type(ast_b)) or (len(children_a) != len(children_b)):
return 0

if reorder_depth == 0:
match_index = sum(
map(
lambda pairs: self._compare_ASTs(pairs[0], pairs[1], reorder_depth),
zip(children_a, children_b),
)
)
return match_index + 1

elif reorder_depth > 0:
match_index = self._reorder_children_compare(ast_a, ast_b, reorder_depth - 1)
return match_index + 1

return 0

def _reorder_children_compare(self, ast_a: ast.AST, ast_b: ast.AST, reorder_depth: int) -> int:
"""Reorders child nodes and compares them.

Args:
ast_a: The first AST for child comparison.
ast_b: The second AST for child comparison.
reorder_depth: The maximum children reorder depth for better
performance.

Returns:
True if there is a way to match 1-1 every child node of ast_a
with every child node of ast_b, otherwise False.
"""
comparison_matrix = []
cost_matrix = []
best_match_value = 0
children_a = list(ast.iter_child_nodes(ast_a))
children_b = list(ast.iter_child_nodes(ast_b))

if len(children_a) <= 1 or len(children_b) <= 1:
for child_a in children_a:
for child_b in children_b:
best_match_value += self._compare_ASTs(child_a, child_b, reorder_depth)
else:
for child_a in children_a:
row = []
cost_row = []
for child_b in children_b:
similarity = self._compare_ASTs(child_a, child_b, reorder_depth)
row.append(similarity)
cost_row.append(10000000 - similarity)

comparison_matrix.append(row)
cost_matrix.append(cost_row)

m = Munkres()
indices = m.compute(cost_matrix) # type: ignore

for row, col in indices:
best_match_value += comparison_matrix[row][col]

return best_match_value

def _compare_subtrees(self, sig_subtrees_p1: list, sig_subtrees_p2: list, reorder_depth: int) -> tuple:
"""Compare two significant subtree lists reordering up to a certain depth.

Args:
sig_subtrees_p1: The first significant AST list for comparison.
sig_subtrees_p2: The second significant AST list for comparison.
reorder_depth: The maximum children reorder depth for better
performance.

Returns:
A tuple with the ratio of matching to non-matching nodes of the
significant subtrees, and a list with the best matching of subtrees.
"""
comparison_matrix = []
cost_matrix = []
best_match = []
best_match_value = 0
best_match_weight = 0
children_a = sig_subtrees_p1.copy()
children_b = sig_subtrees_p2.copy()

if len(children_a) <= 1 or len(children_b) <= 1:
for child_a in children_a:
best_match += [child_a]
for child_b in children_b:
best_match_value += self._compare_ASTs(child_a, child_b, reorder_depth)
best_match += [child_b]
else:
for child_a in children_a:
row = []
cost_row = []
for child_b in children_b:
similarity = self._compare_ASTs(child_a, child_b, reorder_depth)
row.append(similarity)
cost_row.append(10000000 - similarity)

comparison_matrix.append(row)
cost_matrix.append(cost_row)

m = Munkres()
indices = m.compute(cost_matrix) # type: ignore

for row, col in indices:
best_match_weight += self._apply_weights_to_subtrees_mult(
comparison_matrix[row][col],
sig_subtrees_p1[row],
sig_subtrees_p2[col],
)
best_match += [sig_subtrees_p1[row], sig_subtrees_p2[col]]

all_subtrees_weight = sum(
map(
lambda tree: self._apply_weights_to_subtrees(self._get_num_nodes(tree), tree),
sig_subtrees_p1,
)
) + sum(
map(
lambda tree: self._apply_weights_to_subtrees(self._get_num_nodes(tree), tree),
sig_subtrees_p2,
)
)

similarity = 2 * best_match_weight / all_subtrees_weight

return round(similarity, 4), best_match

def _is_significant(self, root: ast.AST) -> bool:
"""Determine if an AST is significant.

Args:
root: The AST whose significance we want.

Returns:
True for if it is significant, False otherwise.
"""
return (
isinstance(root, ast.Import)
or isinstance(root, ast.FunctionDef)
or isinstance(root, ast.If)
or isinstance(root, ast.ClassDef)
or isinstance(root, ast.While)
or isinstance(root, ast.For)
or isinstance(root, ast.comprehension)
or isinstance(root, ast.Return)
)

def _get_significant_subtrees(self, root: ast.AST) -> list:
"""Find the significant subtrees of an AST.

Args:
root: The root of the main AST.

Returns:
A list with all the significant subtrees of root.
"""
significant_subtrees = []
for node in ast.walk(root):
if self._is_significant(node):
significant_subtrees.append(node)
return significant_subtrees

def _get_num_nodes(self, root: ast.AST) -> int:
"""Find the number of nodes for a given tree.

Args:
root: The root of the tree whose size we want.

Returns:
The number of nodes in the tree.
"""
return len(list(ast.walk(root)))

def _apply_weights_to_subtrees(self, weight: float, subtree: ast.AST) -> float:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider making the weights applied to the subtrees in the _apply_weights_to_subtrees method configurable through method parameters or class attributes for greater flexibility.

"""Apply weights to subtrees according to the time por their roots.

Args:
weight: The number of nodes in the subtree.
subtree: The subtree.

Returns:
The weighed weight of the tree.
"""
new_weight = weight
if isinstance(subtree, ast.Import):
new_weight *= 0.3
elif isinstance(subtree, ast.Module):
new_weight *= 1
elif isinstance(subtree, ast.FunctionDef):
new_weight *= 1.2
elif isinstance(subtree, ast.If):
new_weight *= 0.5
elif isinstance(subtree, ast.ClassDef):
new_weight *= 1
elif isinstance(subtree, ast.While):
new_weight *= 1
elif isinstance(subtree, ast.For):
new_weight *= 1
elif isinstance(subtree, ast.comprehension):
new_weight *= 1
elif isinstance(subtree, ast.Return):
new_weight *= 1
return new_weight

def _apply_weights_to_subtrees_mult(self, weight: float, ast_1: ast.AST, ast_2: ast.AST) -> float:
"""Find the average weight of both trees in order to weigh the comparison.

Args:
weight: The weight of the comparison.
ast_1: The first compared tree.
ast_2: The second compared tree.

Returns:
The average of the subtrees' weights.
"""
if weight == 0:
return 0
else:
return (self._apply_weights_to_subtrees(weight, ast_1) + self._apply_weights_to_subtrees(weight, ast_2)) / 2

def _compare_many(self, programs: list) -> list:
"""Compare all of the programs in the list.

Args:
programs: A list of strings with python programs.

Returns:
A matrix with the similarity rating of between all the programs.
"""
tree_list = list(map(lambda prog: self._get_significant_subtrees(ast.parse(prog)), programs))

matrix = []
for program_1_tree_num in range(0, len(tree_list)):
for program_2_tree_num in range(program_1_tree_num, len(tree_list)):
if program_1_tree_num == program_2_tree_num:
continue

subtrees1 = tree_list[program_1_tree_num]
subtrees2 = tree_list[program_2_tree_num]

result = self._compare_subtrees(subtrees1, subtrees2, 1000)[0]

matrix.append((program_1_tree_num, program_2_tree_num, result))
matrix.append((program_2_tree_num, program_1_tree_num, result))

return matrix

def calculate(self, answer, ground_truths, **kwargs):

try:
answer_tree = ast.parse(answer, mode="exec")
ground_truth_trees = [ast.parse(gt, mode="exec") for gt in ground_truths]
except SyntaxError as e:
return {"Python_AST_Similarity": -1.0}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Returning a Python_AST_Similarity score of -1.0 when a SyntaxError is caught could be misleading. Consider returning None or raising an exception to indicate that an error occurred.


answer_subtree = self._get_significant_subtrees(answer_tree)
ground_truth_subtrees = [
self._get_significant_subtrees(ground_truth_tree) for ground_truth_tree in ground_truth_trees
]

similarity_scores = [
self._compare_subtrees(answer_subtree, ground_truth_subtree, 1000)[0]
for ground_truth_subtree in ground_truth_subtrees
]

return {
"Python_AST_Similarity": max(similarity_scores),
}
39 changes: 1 addition & 38 deletions docs/README.md
Original file line number Diff line number Diff line change
@@ -1,41 +1,7 @@
# Starlight Starter Kit: Basics
# Documentation

[![Built with Starlight](https://astro.badg.es/v2/built-with-starlight/tiny.svg)](https://starlight.astro.build)

```
npm create astro@latest -- --template starlight
```

[![Open in StackBlitz](https://developer.stackblitz.com/img/open_in_stackblitz.svg)](https://stackblitz.com/github/withastro/starlight/tree/main/examples/basics)
[![Open with CodeSandbox](https://assets.codesandbox.io/github/button-edit-lime.svg)](https://codesandbox.io/p/sandbox/github/withastro/starlight/tree/main/examples/basics)
[![Deploy with Vercel](https://vercel.com/button)](https://vercel.com/new/clone?repository-url=https%3A%2F%2Fgithub.com%2Fwithastro%2Fstarlight%2Ftree%2Fmain%2Fexamples%2Fbasics&project-name=my-starlight-docs&repository-name=my-starlight-docs)

> 🧑‍🚀 **Seasoned astronaut?** Delete this file. Have fun!

## 🚀 Project Structure

Inside of your Astro + Starlight project, you'll see the following folders and files:

```
.
├── public/
├── src/
│ ├── assets/
│ ├── content/
│ │ ├── docs/
│ │ └── config.ts
│ └── env.d.ts
├── astro.config.mjs
├── package.json
└── tsconfig.json
```

Starlight looks for `.md` or `.mdx` files in the `src/content/docs/` directory. Each file is exposed as a route based on its file name.

Images can be added to `src/assets/` and embedded in Markdown with a relative link.

Static assets, like favicons, can be placed in the `public/` directory.

## 🧞 Commands

All commands are run from the root of the project, from a terminal:
Expand All @@ -49,6 +15,3 @@ All commands are run from the root of the project, from a terminal:
| `npm run astro ...` | Run CLI commands like `astro add`, `astro check` |
| `npm run astro -- --help` | Get help using the Astro CLI |

## 👀 Want to learn more?

Check out [Starlight’s docs](https://starlight.astro.build/), read [the Astro documentation](https://docs.astro.build), or jump into the [Astro Discord server](https://astro.build/chat).
4 changes: 4 additions & 0 deletions docs/astro.config.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ export default defineConfig({
},
]
},
{
label: 'Code',
autogenerate: {directory: '/metrics/Code/'}
},
{
label: 'Metric Ensembling',
autogenerate: { directory: '/metrics/ensembling/' },
Expand Down
Loading
Loading