In [88]:
import os
import sys
from pathlib import Path

# Resolve the repository root (one level above the working directory) and make
# it importable, so project modules such as `config` and `sygus.utils` resolve.
CURRENT_DIRECTORY = Path.cwd()
ROOT_DIRECTORY = (CURRENT_DIRECTORY / "..").resolve()

print(f"Current directory: {CURRENT_DIRECTORY}")
print(f"Root directory: {ROOT_DIRECTORY}")

# Guard so re-running this cell in the same kernel does not pile up duplicates.
if str(ROOT_DIRECTORY) not in sys.path:
    sys.path.append(str(ROOT_DIRECTORY))

Current directory: /home/ubuntu/arga-arc/sygus
Root directory: /home/ubuntu/arga-arc


In [106]:
import typing as t
import json
from pprint import pprint
from dataclasses import dataclass
import tensorflow as tf
import numpy as np
import math
from config import CONFIG
from openai import OpenAI
import re
from collections import Counter
import random
import ast
import traceback
import sexpdata as sexp
from sygus.utils import BENCHMARK_NAMES, compute_output_file, SygusBenchmark, EXAMPLE_FILES, BENCHMARK_DIRECTORIES

# List which config entries are available (keys only, so no secret values are printed).
pprint(CONFIG.__dict__.keys())


dict_keys(['OPENAI_SECRET_KEY', 'OPENAI_ORGANIZATION', 'TOGETHER_SECRET_KEY', 'TOGETHER_BASE_URL', 'OCTO_SECRET_KEY'])


In [107]:
# Model whose completion output file is loaded, and which benchmark kinds to load.
MODEL_NAME = "deepseek-ai/deepseek-coder-33b-instruct"
BENCHMARK_KINDS = ["string"]

BENCHMARKS = [
    SygusBenchmark.read_from_file(
        benchmark,
        compute_output_file(benchmark, MODEL_NAME),
        BENCHMARK_DIRECTORIES[benchmark],
    )
    for benchmark in BENCHMARK_KINDS
]
pprint(BENCHMARKS)

[<sygus.utils.SygusBenchmark object at 0x7f75baafc550>]


## Extract code from completions

In [108]:
# Sample raw completions used to exercise the extraction helpers in the cells below.
EXAMPLE_COMPLETIONS = [
    # normal/well-behaved example: a bare expression followed by a stray ")" line
    "\n(ite (= _arg_1 1) (str.substr _arg_0 0 (str.indexof _arg_0 \",\")) (str.substr _arg_0 (+ (str.indexof _arg_0 \",\") 1) (str.len _arg_0)))\n)\n",
    # preceding text, code block: prose, a fenced define-fun, then an explanation
    # that stops mid-sentence (presumably a truncated generation)
    "\nHere is a possible solution in SMT-LIB syntax:\n\n```\n(define-fun f ((_arg_0 String) (_arg_1 Int)) String\n  (let ((comma_pos (str.indexof _arg_0 \",\" 0)))\n    (if (= comma_pos -1)\n        _arg_0\n        (str.substr _arg_0\n                    (ite (= _arg_1 1) 0 (+ comma_pos 1))\n                    (ite (= _arg_1 1) comma_pos (str.len _arg_0))))))\n```\n\nThis function works by first finding the position of the comma in the input string. If there is no comma, it simply returns the input string. Otherwise, it returns a substring of the input string depending on the value of `_arg_1`. If `_arg_1",
    # text after: a bare expression and a stray ")" line, then a blank line and prose
    "\n(str.replace _arg_0 \"<b>\" \"\")\n)\n\nThe function `str.replace` is used to replace all occurrences of \"<b>\" with an empty string. This effectively removes the \"<b>\" tags from the input.\n"
]

In [109]:
CODE_BLOCK_REGEX = r"```((?:(?!\n).)*\n)?((?:(?!```).)*)```"

def extract_code_block(completion: str) -> t.Optional[str]:
    """Return the body of the first complete ``` fenced block, or None.

    The optional first capture group swallows the remainder of the opening
    fence line (e.g. a language tag); the last group is the block body, which
    is returned with trailing whitespace removed.
    """
    found = re.search(CODE_BLOCK_REGEX, completion, re.DOTALL)
    if found is None:
        return None
    *_, body = found.groups()
    return body.rstrip()

In [110]:
# Show each example completion next to the fenced code block found in it (if any).
for sample in EXAMPLE_COMPLETIONS:
    print("## completion:", sample, sep="\n")
    print("### extracted code block:", extract_code_block(sample), sep="\n")
    print()

## completion:

(ite (= _arg_1 1) (str.substr _arg_0 0 (str.indexof _arg_0 ",")) (str.substr _arg_0 (+ (str.indexof _arg_0 ",") 1) (str.len _arg_0)))
)

### extracted code block:
None

## completion:

Here is a possible solution in SMT-LIB syntax:

```
(define-fun f ((_arg_0 String) (_arg_1 Int)) String
  (let ((comma_pos (str.indexof _arg_0 "," 0)))
    (if (= comma_pos -1)
        _arg_0
        (str.substr _arg_0
                    (ite (= _arg_1 1) 0 (+ comma_pos 1))
                    (ite (= _arg_1 1) comma_pos (str.len _arg_0))))))
```

This function works by first finding the position of the comma in the input string. If there is no comma, it simply returns the input string. Otherwise, it returns a substring of the input string depending on the value of `_arg_1`. If `_arg_1
### extracted code block:
(define-fun f ((_arg_0 String) (_arg_1 Int)) String
  (let ((comma_pos (str.indexof _arg_0 "," 0)))
    (if (= comma_pos -1)
        _arg_0
        (str.substr _arg_0
            

In [111]:
def extract_plain_code(completion: str) -> t.Optional[str]:
    """Return the first paragraph of a completion that has no code fence.

    The completion is stripped before splitting, so a completion that starts
    with blank lines no longer yields an empty first paragraph. A paragraph
    break is any blank line, including lines that contain only whitespace.
    """
    first_paragraph = re.split(r"\n\s*\n", completion.strip(), maxsplit=1)[0]
    return first_paragraph.strip()

In [112]:
# Show what the plain-text (unfenced) extractor returns for each example completion.
for completion in EXAMPLE_COMPLETIONS:
    print("## completion:")
    print(completion)
    extracted = extract_plain_code(completion)
    print("### extracted plain code:")
    print(extracted)
    print()

## completion:

(ite (= _arg_1 1) (str.substr _arg_0 0 (str.indexof _arg_0 ",")) (str.substr _arg_0 (+ (str.indexof _arg_0 ",") 1) (str.len _arg_0)))
)

### extracted plain coode:
(ite (= _arg_1 1) (str.substr _arg_0 0 (str.indexof _arg_0 ",")) (str.substr _arg_0 (+ (str.indexof _arg_0 ",") 1) (str.len _arg_0)))
)

## completion:

Here is a possible solution in SMT-LIB syntax:

```
(define-fun f ((_arg_0 String) (_arg_1 Int)) String
  (let ((comma_pos (str.indexof _arg_0 "," 0)))
    (if (= comma_pos -1)
        _arg_0
        (str.substr _arg_0
                    (ite (= _arg_1 1) 0 (+ comma_pos 1))
                    (ite (= _arg_1 1) comma_pos (str.len _arg_0))))))
```

This function works by first finding the position of the comma in the input string. If there is no comma, it simply returns the input string. Otherwise, it returns a substring of the input string depending on the value of `_arg_1`. If `_arg_1
### extracted plain coode:
Here is a possible solution in SMT-LIB syntax:

In [113]:
def remove_leading_close_paren(completion: str) -> str:
    """Strip surrounding whitespace, then drop a single leading ")" if present."""
    trimmed = completion.strip()
    if not trimmed.startswith(")"):
        return trimmed
    return trimmed[1:]

In [114]:
def extract_code(completion: str) -> t.Optional[str]:
    """Extract the candidate program text from a raw model completion.

    A fenced code block is preferred when present. Otherwise the first
    paragraph of the completion is used. A single stray leading ")" is then
    dropped.

    Returns:
        The extracted code, or ``None`` when nothing non-empty remains. This
        lets callers count the completion as an extraction failure.
    """
    code_block_result = extract_code_block(completion)
    if code_block_result is not None:
        ans = code_block_result
    else:
        ans = extract_plain_code(completion)

    # extract_plain_code is annotated Optional; treat None like an empty result.
    if ans is None:
        return None

    cleaned = remove_leading_close_paren(ans)
    return cleaned if cleaned else None

In [115]:
# Compare each raw example completion with the code that extract_code pulls out of it.
for sample in EXAMPLE_COMPLETIONS:
    print("## completion:", sample, sep="\n")
    print("### extracted code:", extract_code(sample), sep="\n")
    print()

## completion:

(ite (= _arg_1 1) (str.substr _arg_0 0 (str.indexof _arg_0 ",")) (str.substr _arg_0 (+ (str.indexof _arg_0 ",") 1) (str.len _arg_0)))
)

### extracted code:
(ite (= _arg_1 1) (str.substr _arg_0 0 (str.indexof _arg_0 ",")) (str.substr _arg_0 (+ (str.indexof _arg_0 ",") 1) (str.len _arg_0)))
)

## completion:

Here is a possible solution in SMT-LIB syntax:

```
(define-fun f ((_arg_0 String) (_arg_1 Int)) String
  (let ((comma_pos (str.indexof _arg_0 "," 0)))
    (if (= comma_pos -1)
        _arg_0
        (str.substr _arg_0
                    (ite (= _arg_1 1) 0 (+ comma_pos 1))
                    (ite (= _arg_1 1) comma_pos (str.len _arg_0))))))
```

This function works by first finding the position of the comma in the input string. If there is no comma, it simply returns the input string. Otherwise, it returns a substring of the input string depending on the value of `_arg_1`. If `_arg_1
### extracted code:
(define-fun f ((_arg_0 String) (_arg_1 Int)) String
  (let (

## Normalize extracted code

In [116]:
# Sample extracted-code inputs for normalize_code, covering three shapes of code.
NORMALIZE_CODE_INPUTS = [
    # bare body (no define-fun) followed by two stray closing parens
    """(str.replace _arg_0 "<b>" "")
))""",
    # bare body with balanced parens (no define-fun)
    """(ite (= _arg_1 1) (str.substr _arg_0 0 (str.indexof _arg_0 ",")) (str.substr _arg_0 (+ (str.indexof _arg_0 ",") 1) (str.len _arg_0)))""",
    # complete, balanced define-fun (note: uses `if` rather than `ite`)
"""(define-fun f ((_arg_0 String) (_arg_1 Int)) String
  (let ((comma_pos (str.indexof _arg_0 "," 0)))
    (if (= comma_pos -1)
        _arg_0
        (str.substr _arg_0
                    (ite (= _arg_1 1) 0 (+ comma_pos 1))
                    (ite (= _arg_1 1) comma_pos (str.len _arg_0))))))"""
]

In [117]:
def add_definition(code: str, definition: str) -> str:
    """Prefix ``code`` with ``definition`` unless it already contains a
    ``define-fun`` form, in which case it is returned unchanged."""
    already_defined = "define-fun" in code
    return code if already_defined else definition + code

In [118]:
def add_closing_bracket(completion: str) -> str:
    """Return ``completion`` with one ")" appended."""
    return f"{completion})"

def remove_closing_bracket(completion: str) -> str:
    """Return ``completion`` without its final character.

    The caller expects that character to be an extra ")"; it is not checked.
    """
    return completion[: len(completion) - 1]

def balance_parens(completion: str, max_attempts: int = 10) -> t.Optional[str]:
    """Repair unbalanced parentheses and return a normalized s-expression string.

    The text is parsed repeatedly with ``sexpdata``:

    - On a "Not enough closing brackets." error, a ")" is appended.
    - On a "Too many closing brackets." error, the last character is dropped.

    Each repair is applied to the latest attempt, so inputs that are several
    brackets off can still be fixed within ``max_attempts`` tries.

    Args:
        completion: Candidate program text.
        max_attempts: Maximum number of parse attempts (default 10).

    Returns:
        ``sexp.dumps`` of the parsed expression, or ``None``. ``None`` means
        the text could not be balanced within ``max_attempts`` tries, or
        parsing failed with an unrelated error (which is printed).
    """
    current_completion = completion.strip()
    for _ in range(max_attempts):
        try:
            parsed = sexp.loads(current_completion)
            return sexp.dumps(parsed)
        except Exception as e:
            message = str(e)
            if "Not enough closing brackets." in message:
                # Accumulate repairs on the current attempt, not the original input.
                current_completion = add_closing_bracket(current_completion)
            elif "Too many closing brackets." in message:
                # Re-strip so the next removal hits a bracket, not trailing whitespace.
                current_completion = remove_closing_bracket(current_completion).rstrip()
            else:
                print("Caught unexpected error parsing completion:")
                print(f"{completion}")
                print(traceback.format_exc())
                return None
    # Could not balance the parentheses within the allowed number of attempts.
    return None

In [119]:
def normalize_code(code: str, definition: str) -> t.Optional[str]:
    """Prefix ``code`` with ``definition`` (unless it already has a define-fun)
    and repair/normalize its parentheses via ``balance_parens``.

    Returns ``None`` when the parentheses could not be balanced, so callers
    must handle that case.
    """
    return balance_parens(add_definition(code, definition))

In [120]:
# Show each sample input and its normalized form. The loop variable avoids the
# name `input`, which would shadow the builtin in the notebook namespace.
for code_input in NORMALIZE_CODE_INPUTS:
    print("## input")
    print(code_input)
    print()
    print("### normalized code")
    print(normalize_code(code_input, "(define-fun f ((_arg_0 String) (_arg_1 Int)) String "))
    print()
    print()

## input
(str.replace _arg_0 "<b>" "")
))

### normalized code
((define-fun f ((_arg_0 String) (_arg_1 Int)) String (str.replace _arg_0 "<b>" "")))


## input
(ite (= _arg_1 1) (str.substr _arg_0 0 (str.indexof _arg_0 ",")) (str.substr _arg_0 (+ (str.indexof _arg_0 ",") 1) (str.len _arg_0)))

### normalized code
((define-fun f ((_arg_0 String) (_arg_1 Int)) String (ite (= _arg_1 1) (str.substr _arg_0 0 (str.indexof _arg_0 ",")) (str.substr _arg_0 (+ (str.indexof _arg_0 ",") 1) (str.len _arg_0)))))


## input
(define-fun f ((_arg_0 String) (_arg_1 Int)) String
  (let ((comma_pos (str.indexof _arg_0 "," 0)))
    (if (= comma_pos -1)
        _arg_0
        (str.substr _arg_0
                    (ite (= _arg_1 1) 0 (+ comma_pos 1))
                    (ite (= _arg_1 1) comma_pos (str.len _arg_0))))))

### normalized code
((define-fun f ((_arg_0 String) (_arg_1 Int)) String (let ((comma_pos (str.indexof _arg_0 "," 0))) (if (= comma_pos -1) _arg_0 (str.substr _arg_0 (ite (= _arg_1 1) 0 (+ co

## Testing extraction and normalization on all completions

In [121]:
# Function signature prepended to completions that contain only a function body.
STRING_DEFINITION = "(define-fun f ((_arg_0 String) (_arg_1 Int)) String "

failed_stats = {
    "total": 0,
    "total_failed": 0,
    "not_extracted": 0,
    "invalid_syntax": 0,
    "num_failed_per_task": [],
}

for _filename, output in BENCHMARKS[0].output.items():
    num_failed = 0
    for completion in output["completions"]:
        failed_stats["total"] += 1

        extracted = extract_code(completion)
        if extracted is None:
            failed_stats["total_failed"] += 1
            failed_stats["not_extracted"] += 1
            num_failed += 1
            print("## failed completion: (not extracted)")
            print('"""' + completion + '"""')
            print()
            continue

        # normalize_code returns None when the parentheses cannot be repaired;
        # report that directly instead of relying on sexp.loads(None) raising
        # an unrelated TypeError.
        normalized = normalize_code(extracted, STRING_DEFINITION)
        error = None
        if normalized is None:
            error = "could not balance parentheses"
        else:
            try:
                sexp.loads(normalized)
            except Exception as e:
                error = e

        if error is not None:
            failed_stats["total_failed"] += 1
            failed_stats["invalid_syntax"] += 1
            num_failed += 1
            print("## failed completion: (invalid syntax)")
            print('"""' + completion + '"""')
            print("### extracted code:")
            print(extracted)
            print()
            print(error)
            print()
    failed_stats["num_failed_per_task"].append(num_failed)

print("total", failed_stats["total"])
print("total failed", failed_stats["total_failed"])
print("not_extracted", failed_stats["not_extracted"])
print("invalid_syntax", failed_stats["invalid_syntax"])
print()

# np.min / np.max raise on an empty sequence, so guard the per-task summary.
num_failed_per_task = failed_stats["num_failed_per_task"]
if num_failed_per_task:
    print("average failed per task", np.mean(num_failed_per_task))
    print("median failed per task", np.median(num_failed_per_task))
    print("min failed per task", np.min(num_failed_per_task))
    print("max failed per task", np.max(num_failed_per_task))
else:
    print("no tasks found; per-task statistics are undefined")

## failed completion: (invalid syntax)
"""(str.len (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace (str.replace 