In [None]:
# Pipeline Setup
# Uncomment the line below to install the required packages into the kernel's environment.
# %pip install torch torchvision torchaudio transformers sentencepiece accelerate bitsandbytes

In [None]:
import json
import os
import re
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
# Helper Utilities

def safe_name(s: str) -> str:
    """Turn an arbitrary string into a short token usable in file/folder names.

    Runs of characters other than ASCII letters, digits, ``.``, ``_``, ``-``
    and space become a single underscore, runs of whitespace become a single
    underscore, surrounding underscores are trimmed and the result is cut to
    80 characters. If nothing is left, ``"dependency"`` is returned.
    """
    replaced = re.sub(r"[^A-Za-z0-9._\- ]+", "_", s)
    underscored = re.sub(r"\s+", "_", replaced)
    trimmed = underscored.strip("_")[:80]
    if not trimmed:
        return "dependency"
    return trimmed

def write_file(path: Path, content: str):
    """Write ``content`` to ``path`` as UTF-8, creating missing parent folders first."""
    parent_dir = path.parent
    parent_dir.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as handle:
        handle.write(content)

In [None]:
# Model Loader

def load_model(model_name: str = "codellama/CodeLlama-7b-Instruct-hf"):
    """Load a causal-LM checkpoint and its tokenizer for local inference.

    With CUDA available the weights are loaded in float16 with
    ``device_map="auto"``; otherwise they are loaded in float32 with no
    device map.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    print(f"[info] Loading model: {model_name}")
    has_cuda = torch.cuda.is_available()
    device = "cuda" if has_cuda else "cpu"

    # The slow (non-fast) tokenizer implementation is requested explicitly.
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

    if has_cuda:
        load_options = {"torch_dtype": torch.float16, "device_map": "auto"}
    else:
        load_options = {"torch_dtype": torch.float32, "device_map": None}
    model = AutoModelForCausalLM.from_pretrained(model_name, **load_options)

    print(f"[info] Model loaded successfully on {device}")
    return model, tokenizer



In [None]:
# Instantiate the model/tokenizer pair once; test_generator reads these two globals.
model, tokenizer = load_model("codellama/CodeLlama-7b-Instruct-hf")

In [None]:
# Test Generation Function

def _extract_java_code(generated_text: str) -> str:
    """Pull the Java source out of a model reply.

    Returns the contents of the first fenced block (``` or ```java). If the
    reply opens a fence but never closes it, the opening marker is dropped and
    everything after it is returned. With no fence at all, the stripped reply
    is returned as-is.
    """
    text = generated_text.strip()
    fenced = re.search(r"```(?:java)?\s*(.*?)\s*```", text, re.DOTALL)
    if fenced:
        return fenced.group(1).strip()
    # Unterminated fence: keep only what follows the opening marker.
    opening = re.search(r"```(?:java)?[ \t]*\n?", text)
    if opening:
        return text[opening.end():].strip()
    return text


def test_generator(
    prompt: str,
    dependency_data: dict,
    model_name: str = "codellama/CodeLlama-7b-Instruct-hf",
    output_root: str = "execution_results",
):
    """Generate a Java test file for one dependency and save it to disk.

    Uses the module-level ``model`` and ``tokenizer`` (created by the
    model-loading cell) to answer ``prompt``, extracts the Java code from the
    reply and writes three files into ``<output_root>/<package>_<timestamp>/``:
    ``prompt.txt``, ``<package>_GeneratedTest.java`` and ``metadata.json``.

    Args:
        prompt: Full user prompt sent to the model.
        dependency_data: Must contain the keys ``"entryPoint"``,
            ``"thirdPartyMethod"`` and ``"thirdPartyPackage"``.
        model_name: Recorded in ``metadata.json`` only; the model actually
            used is whatever the global ``model`` holds.
        output_root: Directory under which the result folder is created.

    Returns:
        Dict with absolute paths under the keys ``"folder"``, ``"java_file"``,
        ``"prompt_file"`` and ``"metadata_file"``.

    Raises:
        KeyError: If ``dependency_data`` lacks one of the required keys.
    """
    # UTC timestamp used both in the folder name and in the metadata.
    # `datetime` is the class here, so the tzinfo comes from `timezone.utc`.
    timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    dep_safe = safe_name(dependency_data["thirdPartyPackage"])
    folder = Path(output_root) / f"{dep_safe}_{timestamp}"
    folder.mkdir(parents=True, exist_ok=True)

    metadata = {
        "entry_point": dependency_data["entryPoint"],
        "third_party_method": dependency_data["thirdPartyMethod"],
        "third_party_package": dependency_data["thirdPartyPackage"],
        "timestamp": timestamp,
        "model": model_name,
    }

    messages = [
        {"role": "user", "content": prompt},
    ]

    # add_generation_prompt=True asks the chat template to end the input at
    # the start of the assistant turn, which is what generation needs.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    print("[info] Generating message...")
    outputs = model.generate(**inputs, max_new_tokens=4096)
    # Decode only the newly generated tokens and drop special tokens so that
    # markers such as end-of-sequence never reach the saved Java file.
    prompt_length = inputs["input_ids"].shape[-1]
    decoded = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    print("[info] Generation complete. \n")

    # Clean output: remove an echoed prompt, then unwrap the code fence.
    result_text = decoded.replace(prompt, "").strip()
    java_code = _extract_java_code(result_text)

    # Save files
    java_filename = f"{dep_safe}_GeneratedTest.java"
    write_file(folder / "prompt.txt", prompt)
    write_file(folder / java_filename, java_code)
    write_file(folder / "metadata.json", json.dumps(metadata, indent=2))

    print(f"[done] Output written to {folder.resolve()}")

    return {
        "folder": str(folder.resolve()),
        "java_file": str((folder / java_filename).resolve()),
        "prompt_file": str((folder / 'prompt.txt').resolve()),
        "metadata_file": str((folder / 'metadata.json').resolve()),
    }


## Sample Usage of The Test Generator on a single example

In [None]:
# Prompt building blocks used by the prompt cell and by the batch generator.
# You can change any of these to try different prompting techniques, but note 'dependency_data' must remain a dictionary.

# Task description placed at the top of the prompt.
llm_instruction = "Generate a complete JUnit 5 test suite that exercises the chain of methods starting from the entry point entryPoint, and ensures that the third-party method thirdPartyMethod from thirdPartyPackage is invoked and tested. Include tests for both successful execution and error scenarios. Use Mockito where appropriate. Output only the Java test code, without any explanations or text outside the code block."

# Glossary explaining the fields of the dependency data to the model.
dependency_data_definitions = """
    Dependency Information Meaning:
    - Entry point: path to the method in the project that eventually calls the third-party dependency.
    - Entry point body: the entire code block of the entry point method
    - Third-party method: the external method that should be tested.
    - Third-party package: the package containing the third-party method.
    - Path: the sequence of method calls from entry point to the third-party method.
    - Full methods: the full source code of the relevant methods from the entry point to the dependency.
    - Method Slices: the code slices of all relevant methods from the entry point to the dependency
"""

# Sample dependency record: one entry point, the third-party method it reaches,
# the call path between them and the full source of both methods.
dependency_data = {
    "entryPoint": "com.graphhopper.util.GHUtility.loadCustomModelFromJar",
    "thirdPartyMethod": "com.fasterxml.jackson.databind.ObjectMapper.readValue",
    "thirdPartyPackage": "com.fasterxml.jackson.databind",
    "path": [
        "com.graphhopper.util.GHUtility.loadCustomModelFromJar",
        "com.fasterxml.jackson.databind.ObjectMapper.readValue",
    ],
    "fullMethods": [
        "public static CustomModel loadCustomModelFromJar(String name) {\n        try {\n            InputStream is = GHUtility.class.getResourceAsStream(\"/com/graphhopper/custom_models/\" + name);\n            if (is == null)\n                throw new IllegalArgumentException(\"There is no built-in custom model '\" + name + \"'\");\n            String json = readJSONFileWithoutComments(new InputStreamReader(is));\n            ObjectMapper objectMapper = Jackson.newObjectMapper();\n            return objectMapper.readValue(json, CustomModel.class);\n        } catch (IOException e) {\n            throw new IllegalArgumentException(\"Could not load built-in custom model '\" + name + \"'\", e);\n        }\n    }",
        "/**\n     * Method to deserialize JSON content from given JSON content String.\n     *\n     * @throws StreamReadException if underlying input contains invalid content\n     *    of type {@link JsonParser} supports (JSON for default case)\n     * @throws DatabindException if the input JSON structure does not match structure\n     *   expected for result type (or has other mismatch issues)\n     */\n    public <T> T readValue(String content, Class<T> valueType)\n        throws JsonProcessingException, JsonMappingException\n    {\n        _assertNotNull(\"content\", content);\n        return readValue(content, _typeFactory.constructType(valueType));\n    }",
    ],
}

# Formatting rules placed at the end of the prompt, after the dependency data.
llm_output_instruction = """
    - Please produce a single Java file (JUnit 5) with imports, test methods, and short comments explaining each test briefly. Use Mockito for mocking InputStream or ObjectMapper where appropriate. Ensure the test class can compile in a typical Maven/Gradle project.
    - Output only valid Java test code inside a ```java``` block.
    - Do not include any explanations or markdown outside the code block and do not repeat the prompt back.
"""

In [None]:
# Concatenate prompt
#
# The dependency data is embedded as indented JSON (not the Python dict repr),
# the same serialization generate_etheo_example_tests uses for its prompts.

prompt = f"""
Instructions
{llm_instruction}

Dependency Data Definitions
{dependency_data_definitions}

Dependency Data
{json.dumps(dependency_data, indent=2, ensure_ascii=False)}

Output Instructions
{llm_output_instruction}
"""

In [None]:
# Generate test (disabled by default: uncomment the block below to run one generation for the sample prompt)

# if __name__ == "__main__":
#     results = test_generator(
#         prompt,
#         dependency_data
#     )

#     print(json.dumps(results, indent=2))


# Automating test generation for etheo code examples 

 - Make sure to upload the example JSON files from the etheo GitHub repo before running the cells below.

In [None]:
def _build_dependency_data_from_example(example: Dict[str, Any]) -> Dict[str, Any]:
    """
    Normalize example object to the dependency_data shape expected by test_generator.
    Falls back sensibly when fields are missing.
    """
    dep = {}
    dep["entryPoint"] = example.get("entryPoint") or "UNKNOWN_ENTRY_POINT"
    dep["thirdPartyMethod"] = example.get("thirdPartyMethod") or "UNKNOWN_THIRDPARTY_METHOD"
    dep["thirdPartyPackage"] = example.get("thirdPartyPackage") or "UNKNOWN_PACKAGE"
    
    if "fullMethods" in example and isinstance(example["fullMethods"], list):
        dep["fullMethods"] = example["fullMethods"] or []
    elif "methodSlices" in example and isinstance(example["methodSlices"], list):
        dep["methodSlices"] = example.get("methodSlices") or []
    else:
        dep["entryPointBody"] = example.get("entryPointBody") or []
        
    dep["path"] = example.get("path") or []
    
    return dep

In [None]:
def _coerce_examples(parsed: Any) -> List[Any]:
    """Normalize the decoded content of one example file to a list of examples."""
    if isinstance(parsed, list):
        return parsed
    if isinstance(parsed, dict):
        # An object that already looks like one example is a single-item batch;
        # without this check its "path" list would be mistaken for the examples.
        if "entryPoint" in parsed or "thirdPartyMethod" in parsed:
            return [parsed]
        # Otherwise tolerate a wrapper object holding the array under some key:
        # use its first list-valued entry.
        for value in parsed.values():
            if isinstance(value, list):
                return value
        return []
    # Scalars (numbers, strings, booleans, null) cannot hold examples.
    return []


def generate_etheo_example_tests(
    examples_dir: str,
    output_root: str = "execution_results_batch",
    model_name: str = "codellama/CodeLlama-7b-Instruct-hf",
    max_examples: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """
    Run `test_generator` for every example found in the JSON files of examples_dir.

    Each ``*.json`` file (processed in sorted order) may hold a list of
    examples, a single example object, or a wrapper object whose first
    list-valued entry is the example list. For every example a prompt is built
    from the notebook-level ``llm_instruction``, ``dependency_data_definitions``
    and ``llm_output_instruction`` strings plus the example's dependency data
    as JSON, and passed to the existing `test_generator` function.

    Unreadable files and failing examples are reported with an ``[error]``
    line and skipped, so one bad input does not stop the batch.

    Args:
        examples_dir: Directory containing the example ``*.json`` files.
        output_root: Passed through to `test_generator`.
        model_name: Passed through to `test_generator`.
        max_examples: Stop after this many successfully generated examples;
            ``None`` means no limit.

    Returns:
        One dict per successful example with the keys ``source_file``,
        ``example_index``, ``entryPoint``, ``thirdPartyMethod`` and ``result``
        (the dict returned by `test_generator`).

    Raises:
        FileNotFoundError: If examples_dir does not exist.
    """
    results: List[Dict[str, Any]] = []
    examples_path = Path(examples_dir)
    if not examples_path.exists():
        raise FileNotFoundError(f"Examples directory not found: {examples_dir}")

    processed = 0

    for jf in sorted(examples_path.glob("*.json")):
        try:
            raw = jf.read_text(encoding="utf-8")
            examples = _coerce_examples(json.loads(raw))
        except Exception as e:
            print(f"[error] Failed to read/parse {jf.name}: {e}")
            continue

        if not examples:
            print(f"[skip] {jf.name} - no examples found")
            continue

        for idx, example in enumerate(examples):
            # Only successful generations count towards the limit.
            if max_examples is not None and processed >= max_examples:
                print("[info] reached max_examples limit")
                return results

            try:
                dependency_data = _build_dependency_data_from_example(example)
                dep_json = json.dumps(dependency_data, indent=2, ensure_ascii=False)

                prompt = (
                    f"Instructions\n"
                    f"{llm_instruction}\n\n"
                    f"Dependency Data Definitions\n"
                    f"{dependency_data_definitions}\n\n"
                    f"Dependency Data\n"
                    f"{dep_json}\n\n"
                    f"Output Instructions\n"
                    f"{llm_output_instruction}\n"
                )

                out = test_generator(
                    prompt=prompt,
                    dependency_data=dependency_data,
                    model_name=model_name,
                    output_root=output_root,
                )

                results.append({
                    "source_file": jf.name,
                    "example_index": idx,
                    "entryPoint": dependency_data.get("entryPoint"),
                    "thirdPartyMethod": dependency_data.get("thirdPartyMethod"),
                    "result": out,
                })
                processed += 1
                print(f"[ok] {jf.name} #{idx} -> {dependency_data.get('thirdPartyMethod')} \n")

            except Exception as e:
                # Best-effort batch: report the failure and move on.
                print(f"[error] {jf.name} #{idx} failed: {e}")
                continue

    return results

In [None]:
# Run the batch generator over the example JSON files in /examples.
result = generate_etheo_example_tests("/examples")