In [133]:
from dataclasses import dataclass
from typing import Optional


@dataclass
class ProjName:
    author: str
    repo_name: str
    module: Optional[str] = None


@dataclass
class EvalDataPoint:
    project: ProjName
    classPath: str
    className: str
    testName: str
    testMethod: str
    oracle: str
    focalFile: str
    focalName: str
    focalMethod: str
    commitid: str

    def __post_init__(self):
        self.project = ProjName(*self.project)


In [134]:
import json

data = []

with open("./assets/eval_data.jsonl", "r") as f:
    for line in f:
        print(line.strip())
        datapoint = json.loads(line.strip())
        data.append(EvalDataPoint(**datapoint))

{"project": ["orientechnologies", "orientdb", "server"], "classPath": "server/src/test/java/com/orientechnologies/orient/server/OClientConnectionTest.java", "className": "OClientConnectionTest", "testName": "testValidToken", "testMethod": "@Test\n public void testValidToken ( ) throws IOException {\n     OClientConnection conn = new OClientConnection ( 1 , protocol ) ; \n     OTokenHandler handler = new OTokenHandlerImpl ( server ) ; \n     byte [ ] tokenBytes = handler . getSignedBinaryToken ( db , db . getUser ( ) , conn . getData ( ) ) ; \n     conn . validateSession ( tokenBytes , handler , null ) ; \n     assertTrue(conn.getTokenBased());\n}", "oracle": "assertTrue(conn.getTokenBased());", "focalFile": "server/src/main/java/com/orientechnologies/orient/server/OClientConnection.java", "focalName": "getTokenBased", "focalMethod": " public Boolean getTokenBased ( ) { \n return tokenBased ; \n } ", "commitid": "2cabb46c9581572b7f46724864f02d9c688070c5"}
{"project": ["orientechnologies

In [135]:
from collections import Counter

print(len(data))

test_name_counts = Counter(dp.testName for dp in data)
duplicates = {name: count for name, count in test_name_counts.items() if count > 1}

print("Duplicate test names and their counts:")
for name, count in duplicates.items():
    print(f"{name}: {count}")

456
Duplicate test names and their counts:
testSetAttachments: 2
testGetSetLog: 4
testToString: 2
testSetVerb: 2
testEmpty: 2
testHashCode: 2
testGetName: 2
simple: 3
testSerialize: 2
testGet: 2
testExtends: 2
testSimpleEmbeddedDoc: 2


In [136]:
import textwrap


def format_java_package(project: ProjName, classPath: str) -> str:
    package = f"{project.author}.{project.repo_name}"
    if project.module:
        package += f".{project.module}"
    return package


def format_java_focal_class(data: EvalDataPoint) -> str:
    package_name = format_java_package(data.project, data.classPath)
    method_code = textwrap.indent(textwrap.dedent(data.focalMethod).strip(), "    ")
    focal_class = f"""package {package_name};

public class {data.className.replace("Test", "")} {{
{method_code}
}}"""
    return focal_class


def format_java_test_class(data: EvalDataPoint) -> str:
    package_name = format_java_package(data.project, data.classPath)
    dedented_method = textwrap.dedent(data.testMethod)
    oracle_str = data.oracle.strip()
    processed_lines = []
    for line in dedented_method.splitlines():
        if line.strip() == oracle_str:
            indent_prefix = line[: len(line) - len(line.lstrip())]
            processed_lines.append(f"{indent_prefix}<ASSERTIONS HERE>")
        else:
            processed_lines.append(line)
    method_processed = "\n".join(processed_lines).strip()
    method_code = textwrap.indent(method_processed, "    ")
    test_class = f"""package {package_name};

public class {data.className} {{
{method_code}
}}"""
    return test_class


In [137]:
from typing import List, Tuple
from difflib import SequenceMatcher


def parse_generated_assertion_list(assertions: str) -> List[str]:
    return [
        assertion.strip() for assertion in assertions.split(";") if assertion.strip()
    ]


def compare_assertions(
    assertions: List[str], reference: str
) -> List[Tuple[str, float]]:
    similarities = []

    for assertion in assertions:
        similarity_score = SequenceMatcher(
            None,
            reference.replace(";", "").replace("Assert.", "").replace(" ", ""),
            assertion.replace(";", "").replace("Assert.", "").replace(" ", ""),
        ).ratio()
        similarities.append((assertion, similarity_score))

    similarities.sort(key=lambda x: x[1], reverse=True)

    return similarities

In [138]:
def generate_prompt(data: EvalDataPoint) -> str:
    focal_class_code = format_java_focal_class(data)
    test_class_code = format_java_test_class(data)

    prompt = (
        "Below is the focal class and test class for reference:\n\n"
        f"{focal_class_code}\n\n"
        f"{test_class_code}\n\n"
        f"Please generate 20 assertions to be inserted in place of <ASSERTIONS HERE> in the test method. "
        "Only generate code (no comments or explanations)."
    )
    return prompt

In [139]:
MODEL = "gpt-35-turbo"
TEMP = 0.7
# available: "gpt-35-turbo", "gpt-4o-mini"
# temps to test: 0.0, 0.2, 0.7

In [140]:
import requests


def query_gpt(prompt: str, subscription_key: str) -> dict:
    url = (
        f"https://hkust.azure-api.net/openai/deployments/{MODEL}/chat/completions"
        "?api-version=2023-05-15&subscription-key=" + subscription_key
    )
    headers = {"Content-Type": "application/json"}

    payload = {
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 500,
        "temperature": TEMP,
    }

    response = requests.post(url, headers=headers, json=payload)
    response.raise_for_status()
    return response.json()

In [141]:
from dotenv import load_dotenv
import os

load_dotenv()
AZURE_OPENAI_KEY = os.getenv("AZURE_OPENAI_KEY")

In [142]:
import os

log_folder = f"./assets/logs/{MODEL}/{TEMP}/"
os.makedirs(log_folder, exist_ok=True)

In [143]:
import glob
import logging
from datetime import datetime

for i in data:
    pattern = os.path.join(log_folder, f"*_{i.testName}.log")
    if glob.glob(pattern):
        print(f"Skipping {i.testName}")
        continue

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(log_folder, f"{timestamp}_{i.testName}.log")

    logger = logging.getLogger(i.testName)
    logger.setLevel(logging.INFO)
    logger.handlers = []

    file_handler = logging.FileHandler(log_file, encoding="utf-8")
    file_handler.setLevel(logging.INFO)
    formatter = logging.Formatter("%(message)s")
    file_handler.setFormatter(formatter)

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
    stream_handler.setFormatter(formatter)

    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)

    prompt = generate_prompt(i)
    logger.info("=== Generated Prompt ===")
    logger.info(prompt)
    logger.info("========================\n")

    try:
        response = query_gpt(prompt, AZURE_OPENAI_KEY)
        message = response.get("choices", [{}])[0].get("message", {}).get("content", "")
        logger.info("=== GPT Response ===")
        logger.info(message)
        logger.info("====================\n")
        logger.info("=== Collecting and Comparing Assertions ===")
        logger.info(f"=== Ground Oracle Truth: {i.oracle} ===")

        assertions = parse_generated_assertion_list(message)
        similarities = compare_assertions(assertions, i.oracle)
        for assertion, score in similarities:
            logger.info(f"{score:.3f} - {assertion}")
    except Exception as e:
        logger.error("Error during API call: %s", e)
        break

Skipping testValidToken
Skipping testProcessChangeEventAddKeyWithConversion
Skipping testGetTypes
Skipping testQueryEmbedded
Skipping testSerializeWALChanges
Skipping nestedLinkSet
Skipping testMoveSingleRecordToAnotherClass
Skipping testSetupWithIncompleteJob
Skipping testGetSetScheduler
Skipping testNullIsEmpty
Skipping sum_of_cell_distances_is_distance_to_goal
Skipping testLoad_justKey
Skipping testReconnectionByUserSetting
Skipping testGeneratedConfigFiles
Skipping testAddressUnderscoreSeparator
Skipping testAddressCompareToEqual
Skipping testGetActivityId
Skipping tuStum
Skipping testMergeIfAbsent
Skipping testCancelReConnection
Skipping testMultiThread
Skipping testSetAttachments
Skipping testTerminated
Skipping testSetResponse
Skipping testDefaultValueFromJson
Skipping testWithSets
Skipping testLazyExecutionPlanning
Skipping testSetCustom
Skipping testMixCompositeQuery
Skipping testConversion
Skipping testAllAllowed
Skipping testAllForbidden
Skipping testFilterDocumentWithMetada

=== Generated Prompt ===


Skipping testCustomDepth
Skipping testGetDigestSha1TwoByteArrays
Skipping testAnchorFilterFalse
Skipping testRelativeCanonicalURL
Skipping testCreateValidUrl
Skipping testGetPolledObjects
Skipping testPrime
Skipping testParseNetworks
Skipping testTriangle2
Skipping testLinkCollectionOther
Skipping testRemoveNotificationOne
Skipping testRidSelect
Skipping testTxLinked
Skipping testFormatMap
Skipping testCommaList
Skipping testDefaultValues
Skipping has_yaml_representation
Skipping testSameChars
Skipping doTestRandomPortConstructor
Skipping testLoad_emptyTag
Skipping testSerialize
Skipping fixedLengthContentIsEncodedProperly
Skipping testDecodeCustomResponse
Skipping should_return_newer_tag
Skipping validLoggerReturned_whenOpenThenCloseThenOpenWithSameParameters
Skipping testGetSetConnectionFactoryDelegate
Skipping testGetSetAttacher
Skipping testGetStepAfterShutdown
Skipping testActivityDefinitionHashMapOfStringStringHashMapOfStringString
Skipping testGetDuration
Skipping testGetXapiSta

Below is the focal class and test class for reference:

package javadev.moneytostr-russian;

public class MoneyToStr {
    public String convert ( Double theMoney ) { 
    if ( theMoney == null ) { 
    throw new IllegalArgumentException ( "STR" ) ; 
    } Long intPart = theMoney . longValue ( ) ; 
    Long fractPart = Math . round ( ( theMoney - intPart ) * NUM100 ) ; 
    if ( currency == Currency . PER1000 ) { 
    fractPart = Math . round ( ( theMoney - intPart ) * NUM1000 ) ; 
    } return convert ( intPart , fractPart ) ; 
    }
}

package javadev.moneytostr-russian;

public class MoneyToStrTest {
    @Test
     public void checkPER100 ( ) {
         <ASSERTIONS HERE>
    }
}

Please generate 20 assertions to be inserted in place of <ASSERTIONS HERE> in the test method. Only generate code (no comments or explanations).

=== GPT Response ===
assertEquals("сто двадцать три рубля 45 копеек", moneyToStr.convert(123.45));
assertEquals("двести рублей 0 копеек", moneyToStr.convert(200.0

Skipping testSerialize


=== GPT Response ===
Assertions:

1. assertTrue(v.getId() instanceof ORID)
2. assertNotNull(v.getId())
3. assertEquals("STR", v.getProperty("STR"))
4. assertTrue(graph.getVertices().contains(v))
5. assertTrue(v.getPropertyKeys().contains("STR"))
6. assertFalse(v.getPropertyKeys().contains("INT"))
7. assertNotNull(v.getProperty("STR"))
8. assertNull(v.getProperty("INT"))
9. assertNotNull(graph.getVertex(v.getId()))
10. assertTrue(graph.getVertices().iterator().hasNext())
11. assertEquals(1, IteratorUtils.count(graph.getVertices().iterator()))
12. assertTrue(graph.getVertices().iterator().next().equals(v))
13. assertTrue(graph.getVertex(v.getId()).equals(v))
14. assertTrue(graph.getVertex(v.getId()).getPropertyKeys().contains("STR"))
15. assertEquals("STR", graph.getVertex(v.getId()).getProperty("STR"))
16. assertTrue(graph.getVertex(v.getId()).getPropertyKeys().contains("STR"))
17. assertFalse(graph.getVertex(v.getId()).getPropertyKeys().contains("INT"))
18. assertTrue(graph.getVertex(v

Skipping testExtends


=== GPT Response ===
Assert.assertNotNull(protocol0);
Assert.assertNotNull(protoocl1);
Assert.assertNotSame(protocol0, protoocl1);
Assert.assertEquals(protocol0.getClass(), protoocl1.getClass());
Assert.assertTrue(protoocl1 instanceof IProtocol);
Assert.assertNotEquals(protocol0.hashCode(), protoocl1.hashCode());
Assert.assertNotEquals(protocol0.toString(), protoocl1.toString());
Assert.assertSame(protocol0.copyInstance(), protoocl1.copyInstance());
Assert.assertNotEquals(protocol0, protoocl1);
Assert.assertFalse(protocol0.equals(protoocl1));
Assert.assertTrue(protoocl1.equals(protoocl1));
Assert.assertTrue(protoocl1.equals(protocol0));
Assert.assertNotEquals(protocol0, null);
Assert.assertNotEquals(protoocl1, null);
Assert.assertNotEquals(protocol0, "");
Assert.assertNotEquals(protoocl1, "");
Assert.assertNotEquals(protocol0, new Object());
Assert.assertNotEquals(protoocl1, new Object());
Assert.assertEquals(protocol0.copyInstance(), protoocl1.copyInstance());
Assert.assertEquals(prot

Skipping testGet


=== GPT Response ===
assertEquals("STR", parser.parse(sample1, true).getFieldValue("fieldName"));
assertEquals(Locale.ENGLISH, parser.parse(sample1, true).getLocale());
assertEquals(CEFVersion.V1_0, parser.parse(sample1, true).getVersion());
assertEquals("CEF", parser.parse(sample1, true).getVendor());
assertEquals("Event", parser.parse(sample1, true).getSignatureID());
assertEquals("1", parser.parse(sample1, true).getSeverity());
assertEquals("UNKNOWN", parser.parse(sample1, true).getExtension("extension1"));
assertEquals("UNKNOWN", parser.parse(sample1, true).getExtension("extension2"));
assertEquals("UNKNOWN", parser.parse(sample1, true).getExtension("extension3"));
assertEquals("UNKNOWN", parser.parse(sample1, true).getExtension("extension4"));
assertEquals("UNKNOWN", parser.parse(sample1, true).getExtension("extension5"));
assertEquals("UNKNOWN", parser.parse(sample1, true).getExtension("extension6"));
assertEquals("UNKNOWN", parser.parse(sample1, true).getExtension("extension7"))

Skipping testGetName


=== GPT Response ===
assertEquals("123", travis.getPullRequest());
assertNotEquals("456", travis.getPullRequest());
assertNotNull(travis.getPullRequest());
assertNull(travis.getPullRequest());
assertTrue(travis.getPullRequest().contains("pull"));
assertFalse(travis.getPullRequest().isEmpty());
assertSame("abc", travis.getPullRequest());
assertNotSame("def", travis.getPullRequest());
assertEquals("http://github.com/pull/123", travis.getPullRequest());
assertTrue(travis.getPullRequest().matches("[0-9]+"));
assertFalse(travis.getPullRequest().startsWith("PR"));
assertTrue(travis.getPullRequest().endsWith("123"));
assertThat(travis.getPullRequest(), is("123"));
assertThat(travis.getPullRequest(), not("456"));
assertThat(travis.getPullRequest(), containsString("pull"));
assertThat(travis.getPullRequest(), notNullValue());
assertThat(travis.getPullRequest(), instanceOf(String.class));
assertThat(travis.getPullRequest(), isA(String.class));
assertThat(travis.getPullRequest(), anyOf(is("123"),

Skipping testGetSetLog


=== GPT Response ===
assertEquals(1, result.size());
assertTrue(result.contains("STR"));
assertFalse(result.contains("str"));
assertEquals(1, a.size());
assertEquals(1, b.size());
assertEquals(1, c.size());
assertTrue(result.containsAll(a));
assertTrue(result.containsAll(b));
assertTrue(result.containsAll(c));
assertFalse(result.contains("abc"));
assertFalse(result.contains("123"));
assertFalse(result.contains(null));
assertTrue(result.contains("STR"));
assertTrue(result.contains("str"));
assertTrue(result.contains("StR"));
assertEquals(3, result.size());
assertEquals(3, a.size());
assertEquals(3, b.size());
assertEquals(3, c.size());
assertTrue(result.containsAll(a));
assertTrue(result.containsAll(b));
assertTrue(result.containsAll(c));

=== Collecting and Comparing Assertions ===
=== Ground Oracle Truth: assertEquals(5,result.size()); ===
0.966 - assertEquals(1, result.size())
0.966 - assertEquals(3, result.size())
0.830 - assertEquals(1, a.size())
0.830 - assertEquals(1, b.size())
0

Skipping testGetSetLog
Skipping simple
Skipping testEmpty


=== GPT Response ===
Assertions:

assertEquals(false, graph.hasCycles());
graph.addEdge(edge("STR", "STR"));
assertEquals(true, graph.hasCycles());
graph.removeEdge(edge("STR", "STR"));
assertEquals(false, graph.hasCycles());
graph.addEdge(edge("STR", "NEW"));
assertEquals(false, graph.hasCycles());
graph.addEdge(edge("NEW", "STR"));
assertEquals(true, graph.hasCycles());
graph.removeEdge(edge("NEW", "STR"));
assertEquals(false, graph.hasCycles());
graph.addEdge(edge("NEW", "NEW"));
assertEquals(false, graph.hasCycles());
graph.removeEdge(edge("NEW", "NEW"));
assertEquals(false, graph.hasCycles());
graph.addEdge(edge("STR", "STR"));
graph.addEdge(edge("STR", "NEW"));
assertEquals(true, graph.hasCycles());
graph.clear();
assertEquals(false, graph.hasCycles());
graph.addEdge(edge("A", "B"));
graph.addEdge(edge("B", "C"));
graph.addEdge(edge("C", "A"));
assertEquals(true, graph.hasCycles());
graph.clear();
assertEquals(false, graph.hasCycles());

=== Collecting and Comparing Assertions ==

Skipping testToString


=== GPT Response ===
assertEquals(8, b.length());
assertTrue(b.get(0));
assertFalse(b.get(1));
assertFalse(b.get(2));
assertFalse(b.get(3));
assertFalse(b.get(4));
assertFalse(b.get(5));
assertFalse(b.get(6));
assertFalse(b.get(7));
assertEquals("01010011", CommUtil.hexdump(bs.getBytes()));
assertNotNull(b);
assertNotNull(CommUtil.hexdump(bs.getBytes()));
assertEquals("0000 0001", CommUtil.hexdump(bs.getBytes()));
assertNotEquals(0, b.hashCode());
assertTrue(CommUtil.hexdump(bs.getBytes()).contains("53"));
assertFalse(CommUtil.hexdump(bs.getBytes()).isEmpty());
assertNotNull(b.toString());
assertEquals(64, b.size());
assertTrue(b.isEmpty());
assertFalse(CommUtil.hexdump(bs.getBytes()).equals("01001101"));
assertEquals(CommUtil.hexdump(bs.getBytes()), CommUtil.hexdump(bs.getBytes()));
assertFalse(b.isEmpty());
assertFalse(CommUtil.hexdump(bs.getBytes()).equals("01010000"));

=== Collecting and Comparing Assertions ===
=== Ground Oracle Truth: Assert.assertEquals("STR",CommUtil.hexdump(C

Skipping testSimpleEmbeddedDoc


=== GPT Response ===
assertNotNull(qResult);
assertFalse(qResult.isEmpty());
assertEquals(1, qResult.size());
assertTrue(qResult.hasNext());
assertEquals("expected result", qResult.next().getProperty("propertyName"));
assertNotNull(qResult.getProperty("propertyName"));
assertNotNull(qResult.getRecord());
assertTrue(qResult.getPropertyNames().contains("propertyName"));
assertTrue(qResult.hasNext());
assertEquals("expected value", qResult.getProperty("propertyName"));
assertTrue(qResult.next().getPropertyNames().contains("propertyName"));
assertEquals("expected value", qResult.next().getProperty("propertyName"));
assertTrue(qResult.hasNext());
assertEquals("expected value", qResult.getProperty("propertyName"));
assertTrue(qResult.getPropertyNames().contains("propertyName"));
assertFalse(qResult.hasNext());
assertNull(qResult.next());
assertEquals(0, qResult.size());
assertTrue(qResult.getPropertyNames().contains("propertyName"));
assertFalse(qResult.isEmpty());

=== Collecting and Compar

Skipping testSetVerb


=== GPT Response ===
Assert.assertNotNull(actual);
Assert.assertEquals(score, actual);
Assert.assertTrue(score == actual);
Assert.assertNotEquals(null, actual);
Assert.assertNotSame(score, actual);
Assert.assertSame(score, actual);
Assert.assertTrue(actual instanceof Score);
Assert.assertNotEquals(null, actual);
Assert.assertTrue(score.equals(actual));
Assert.assertNotEquals(score, null);
Assert.assertNotEquals(actual, null);
Assert.assertNotEquals(result.getScore(), null);
Assert.assertNotNull(result.getScore());
Assert.assertSame(result.getScore(), actual);
Assert.assertEquals(score.hashCode(), actual.hashCode());
Assert.assertNotSame(score.hashCode(), actual.hashCode());
Assert.assertTrue(result.getScore().equals(actual));
Assert.assertFalse(score == null);
Assert.assertTrue(actual != null);
Assert.assertNotEquals(result.getScore(), actual);

=== Collecting and Comparing Assertions ===
=== Ground Oracle Truth: assertNotNull(actual); ===
1.000 - Assert.assertNotNull(actual)
0.776 - A

Skipping testHashCode


=== GPT Response ===
Assertions for the testNonExistentPropertiesParsing method:

1. assertNotNull(props);
2. assertTrue(props.isEmpty());
3. assertEquals(0, props.size());
4. assertFalse(props.containsKey("key"));
5. assertNull(props.getProperty("key"));
6. assertEquals("default", props.getProperty("key", "default"));
7. assertFalse(props.containsKey("anotherKey"));
8. assertNull(props.getProperty("anotherKey"));
9. assertEquals("default", props.getProperty("anotherKey", "default"));
10. assertFalse(props.containsKey("yetAnotherKey"));
11. assertNull(props.getProperty("yetAnotherKey"));
12. assertEquals("default", props.getProperty("yetAnotherKey", "default"));
13. assertFalse(props.containsKey("lastKey"));
14. assertNull(props.getProperty("lastKey"));
15. assertEquals("default", props.getProperty("lastKey", "default"));
16. assertFalse(props.containsKey("testKey"));
17. assertNull(props.getProperty("testKey"));
18. assertEquals("default", props.getProperty("testKey", "default"));
19.