In [None]:
import os
import re
import requests
import json
import ast
import subprocess

# Gemini API key, read from the environment so that no secret lives in source
# control. Set GEMINI_API_KEY before running the script. Any key that was ever
# committed to a file should be considered leaked and revoked.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
if not GEMINI_API_KEY:
    print("⚠️ GEMINI_API_KEY is not set; Gemini API calls will fail.")

# Gemini REST API URL (the key is passed as the `key` query parameter).
GEMINI_URL = f"https://generativelanguage.googleapis.com/v1/models/gemini-2.0-flash:generateContent?key={GEMINI_API_KEY}"

# Directory Paths
SRC_DIR = "D:\\original"
TEST_DIR = "D:\\tests"


class CodeAnalyzer:
    """Analyzes Python code to extract detailed information for better test generation.

    The source is parsed once at construction time with ``ast.parse``; a
    ``SyntaxError`` for malformed source propagates to the caller.
    """

    def __init__(self, code):
        """Parse *code* into an AST.

        Args:
            code: Python source code as a string.

        Raises:
            SyntaxError: if *code* is not valid Python.
        """
        self.code = code
        self.tree = ast.parse(code)

    @staticmethod
    def _annotation_to_str(annotation):
        """Render an annotation node as source text, or "unknown" when absent."""
        if annotation is None:
            return "unknown"
        if isinstance(annotation, ast.Name):
            return annotation.id
        # String (forward-reference) annotations such as ``x: "Foo"``.
        if isinstance(annotation, ast.Constant) and isinstance(annotation.value, str):
            return annotation.value
        # ast.unparse exists on Python 3.9+; older interpreters fall back to "unknown".
        unparse = getattr(ast, "unparse", None)
        if unparse is not None:
            return unparse(annotation)
        return "unknown"

    def get_functions_with_params(self):
        """Extract every function and method (sync or async, at any depth) with its signature.

        Returns:
            A list of dicts with keys:
              - ``name``: the function name;
              - ``params``: list of ``(name, type)`` tuples in signature order
                (positional-only, regular, ``*args``, keyword-only, ``**kwargs``;
                star parameters keep their ``*``/``**`` prefix);
              - ``return_type``: rendered return annotation or "unknown";
              - ``docstring``: the docstring or "No docstring available".
        """
        render = self._annotation_to_str
        functions = []

        for node in ast.walk(self.tree):
            if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                continue

            args = node.args
            positional = list(getattr(args, "posonlyargs", [])) + list(args.args)
            params = [(arg.arg, render(arg.annotation)) for arg in positional]
            if args.vararg is not None:
                params.append(("*" + args.vararg.arg, render(args.vararg.annotation)))
            params.extend((arg.arg, render(arg.annotation)) for arg in args.kwonlyargs)
            if args.kwarg is not None:
                params.append(("**" + args.kwarg.arg, render(args.kwarg.annotation)))

            functions.append({
                "name": node.name,
                "params": params,
                "return_type": render(node.returns),
                "docstring": ast.get_docstring(node) or "No docstring available"
            })

        return functions

    def get_classes_with_methods(self):
        """Extract every class (including nested ones) with its directly defined methods.

        Returns:
            A list of dicts with keys ``name``, ``methods`` (names of sync and
            async methods in definition order) and ``docstring``.
        """
        method_types = (ast.FunctionDef, ast.AsyncFunctionDef)
        return [
            {
                "name": node.name,
                "methods": [item.name for item in node.body if isinstance(item, method_types)],
                "docstring": ast.get_docstring(node) or "No docstring available"
            }
            for node in ast.walk(self.tree)
            if isinstance(node, ast.ClassDef)
        ]


def get_python_files(directory):
    """Return the names of the Python module files directly inside *directory*.

    ``__init__.py`` package markers are excluded, as are sub-directories whose
    names happen to end in ``.py``; sub-directories are not searched. Names are
    sorted so that runs process files in a deterministic order.
    """
    return sorted(
        name
        for name in os.listdir(directory)
        if name.endswith(".py")
        and name != "__init__.py"
        and os.path.isfile(os.path.join(directory, name))
    )


def read_file(filepath):
    """Return the entire text content of *filepath*, decoded as UTF-8."""
    with open(filepath, encoding="utf-8", mode="r") as handle:
        contents = handle.read()
    return contents


def write_test_file(module_name, test_code):
    """Write generated test cases to a file with correct imports and sys.path setup."""
    os.makedirs(TEST_DIR, exist_ok=True)
    test_filepath = os.path.join(TEST_DIR, f"test_{module_name}.py")

    # Prepend sys.path.append to the test code
    sys_path_code = (
        "import sys\n"
        "import os\n"
        "sys.path.append(os.path.abspath(\"D:/\"))\n\n"
    )

    # Force-correct the import line
    import_statement = f"from original.{module_name} import *\n\n"

    # Remove any incorrect import lines from test_code (e.g., from id_11 import ...)
    test_code_lines = test_code.splitlines()
    cleaned_code_lines = [
        line for line in test_code_lines
        if not re.match(r"^\s*from\s+\S+\s+import\s+", line)
    ]

    cleaned_code = "\n".join(cleaned_code_lines)

    with open(test_filepath, "w", encoding="utf-8") as file:
        file.write(sys_path_code + import_statement + cleaned_code)

    print(f"✅ Generated test file: {test_filepath}")
    return test_filepath




def call_gemini_api(prompt, *, timeout=120):
    """Call Gemini API using REST URL to generate pytest test cases.

    Args:
        prompt: Text sent as the single user message.
        timeout: Seconds to wait for the HTTP response. Without a timeout a
            stalled connection would block the whole run indefinitely.

    Returns:
        The text of the first candidate (all of its text parts concatenated),
        stripped of surrounding whitespace, or None if the request failed or
        the response did not have the expected shape.
    """
    payload = {
        "contents": [
            {
                "parts": [
                    {"text": prompt}
                ]
            }
        ],
        "generationConfig": {
            "temperature": 0.2,  # Lower temperature for more focused outputs
            "maxOutputTokens": 8192  # Ensure we get enough detailed test code
        }
    }

    try:
        # `json=` serialises the payload and sets the JSON Content-Type header.
        response = requests.post(GEMINI_URL, json=payload, timeout=timeout)
        response.raise_for_status()
        parts = response.json()["candidates"][0]["content"]["parts"]
        # `parts` is a list; keep the text of every part, not just the first.
        text = "".join(part.get("text", "") for part in parts)
    except (requests.RequestException, ValueError, KeyError,
            IndexError, TypeError, AttributeError) as e:
        message = str(e)
        if GEMINI_API_KEY:
            # HTTPError messages include the request URL, which carries ?key=...
            message = message.replace(GEMINI_API_KEY, "***")
        print(f"❌ Gemini API error: {message}")
        return None

    return text.strip()


def create_enhanced_prompt(python_code, module_name, code_analyzer, *, package_name="original"):
    """Create an enhanced prompt with detailed code analysis for better test generation.

    The prompt is assembled line by line so the module source is embedded with
    its indentation intact. Interpolating it into an indented triple-quoted
    f-string shifted only its first line. The prompt also states which import
    the test file will begin with, so the model does not invent its own.

    Args:
        python_code: Source text of the module under test.
        module_name: Module name without ``.py``.
        code_analyzer: Object providing ``get_functions_with_params()`` and
            ``get_classes_with_methods()`` (e.g. a ``CodeAnalyzer``).
        package_name: Package the generated tests import the module from.

    Returns:
        The prompt text.
    """
    def indent_continuation(text):
        # Keep multi-line docstrings visually nested under their "Docstring:" label.
        return text.replace("\n", "\n    ")

    function_lines = []
    for func in code_analyzer.get_functions_with_params():
        params_str = ", ".join(f"{name}: {type_name}" for name, type_name in func["params"])
        function_lines.append(f"- Function: {func['name']}({params_str}) -> {func['return_type']}")
        function_lines.append(f"  Docstring: {indent_continuation(func['docstring'])}")

    class_lines = []
    for cls in code_analyzer.get_classes_with_methods():
        class_lines.append(f"- Class: {cls['name']}")
        class_lines.append(f"  Methods: {', '.join(cls['methods'])}")
        class_lines.append(f"  Docstring: {indent_continuation(cls['docstring'])}")

    test_import = f"from {package_name}.{module_name} import *"
    prompt_lines = [
        f"Generate comprehensive pytest unit tests for the following Python module '{module_name}.py'.",
        f"The test file will automatically begin with `{test_import}`; "
        "do not import the module under test yourself.",
        "",
        f"--- BEGIN {module_name}.py ---",
        python_code.rstrip("\n"),
        f"--- END {module_name}.py ---",
        "",
        "Code analysis:",
        "Functions:",
        "\n".join(function_lines) or "(none)",
        "",
        "Classes:",
        "\n".join(class_lines) or "(none)",
        "",
        "Requirements for the tests:",
        "1. Create at least 3 test cases for each function and method",
        "2. Test edge cases (empty inputs, None values, boundary values, etc.)",
        "3. Include tests for expected exceptions",
        "4. Aim for maximum code coverage (at least 85%)",
        "5. Use pytest.parametrize for testing multiple inputs",
        "6. Use mocks where appropriate (for external dependencies)",
        "7. Provide tests for all branches in conditional logic",
        "8. Include appropriate assertions that verify expected behavior",
        "",
        "Output only valid Python code.",
        "Do not include any explanations, comments outside of code, or extra text.",
        "Do not wrap the code in triple backticks or markdown syntax.",
    ]
    return "\n".join(prompt_lines) + "\n"


def extract_test_code(response_text):
    """Extract and clean test code from the API response.

    If the response contains complete Markdown code fences, only their
    contents are kept (several fenced blocks are joined with a blank line).
    Otherwise, stray fence-marker lines (e.g. an unterminated ```python) are
    removed. If no "import" appears in the first 200 characters, leading text
    before the first import line is dropped as presumed prose.

    Args:
        response_text: Raw model output; may be None or empty.

    Returns:
        None when *response_text* is empty or None; otherwise the extracted
        code with surrounding whitespace stripped (possibly an empty string).
    """
    if not response_text:
        return None

    fenced_blocks = re.findall(
        r"```[ \t]*(?:python3?|py)?[ \t]*\r?\n(.*?)```",
        response_text,
        flags=re.DOTALL | re.IGNORECASE,
    )
    if fenced_blocks:
        cleaned = "\n\n".join(block.strip() for block in fenced_blocks)
    else:
        cleaned = re.sub(r"^[ \t]*```[\w+-]*[ \t]*$", "", response_text, flags=re.MULTILINE)
    cleaned = cleaned.strip()

    # Check if there's any non-code text and attempt to extract just the code
    if "import" not in cleaned[:200]:
        # Anchor on a whole import line so "from x import y" is not cut mid-line.
        import_match = re.search(
            r"^[ \t]*(?:from[ \t]+\S+[ \t]+)?import[ \t]+\w",
            cleaned,
            flags=re.MULTILINE,
        )
        if import_match:
            cleaned = cleaned[import_match.start():]

    return cleaned


# Generate and write test cases
python_files = get_python_files(SRC_DIR)
total_files = len(python_files)
successful_files = 0

# Make sure the output directory exists (existing test files are overwritten per module).
os.makedirs(TEST_DIR, exist_ok=True)

print(f"🔍 Found {total_files} Python files to generate tests for\n")

for python_file in python_files:
    module_name = os.path.splitext(python_file)[0]
    module_path = os.path.join(SRC_DIR, python_file)

    try:
        # Read inside the try so one unreadable / non-UTF-8 file cannot abort the whole run.
        python_code = read_file(module_path)

        print(f"📝 Analyzing {module_name}.py...")
        analyzer = CodeAnalyzer(python_code)

        # Generate enhanced prompt
        prompt = create_enhanced_prompt(python_code, module_name, analyzer)

        print("🤖 Generating tests via Gemini API...")
        test_code = call_gemini_api(prompt)

        # Clean up the response; an empty result is a failure, not a test file.
        cleaned = extract_test_code(test_code)

        if cleaned:
            write_test_file(module_name, cleaned)
            print(f"✓ Test file for {module_name}.py generated successfully\n")
            successful_files += 1
        else:
            print(f"❌ Failed to generate tests for {module_name}.py\n")

    except Exception as e:
        # Top-level boundary: report and continue with the next module.
        print(f"❌ Error processing {module_name}.py: {e}\n")

# Display overall results
print("\n✅ Test generation completed!")
print(f"✓ Successfully generated tests for {successful_files}/{total_files} files")
print(f"📁 Test files saved to: {TEST_DIR}")
print(f"\nTo run tests with coverage, use: pytest {TEST_DIR} --cov={SRC_DIR}\n")