In [None]:
# Notebook setup cell: (re)install the library versions this evaluation uses.
print("Installing required libraries")
# Remove any preinstalled copies first (presumably so they cannot shadow the
# versions requested below -- confirm).
!pip uninstall -y numpy pandas torch torchvision transformers

# Quiet, cache-free install. numpy/pandas/torch/torchvision/transformers are
# lower bounds; peft/accelerate/trl/datasets/bitsandbytes are pinned exactly;
# numpy_financial is unpinned.
!pip install --no-cache-dir -q "numpy>=2.0.0" "pandas>=2.2.3" "torch>=2.3.0" \
    "torchvision>=0.18.0" "transformers>=4.42.4" "peft==0.11.1" "accelerate==0.30.1" \
    "trl==0.9.4" "datasets==2.19.2" "bitsandbytes==0.43.1" "numpy_financial"

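# IMPORTANT: restart the runtime after this cell finishes, before running the next cell,
# so that the freshly installed package versions are the ones that get imported.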

In [None]:

import json
import os
import torch
import warnings
import textwrap
import ast
import re
import pandas as pd
import numpy as np
from numpy import ma
import re
import pickle
import numpy_financial as npf
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Input and output file locations (bare file names, so they resolve against
# the current working directory).
VALIDATION_DATA_PATH = "validation_data.json"
DETAILED_RESULTS_CSV = "evaluation_detailed_results.csv"
SUMMARY_METRICS_CSV = "evaluation_summary_metrics.csv"

# Global namespace handed to exec()/eval() when running generated code and
# when evaluating the test-case input / expected-output expressions.
# NOTE(review): this exposes `os` and `pickle` to model-generated code that is
# executed without sandboxing -- only run this against trusted models/data.
EVAL_GLOBALS = {
    'np': np,
    'ma': ma,
    'os': os,
    'ast': ast,
    'pickle': pickle,
    'npf': npf,
}

# Hugging Face Hub IDs of the base (not fine-tuned) models.
HF_BASE_IDS = ["google/gemma-2-2b-it", 
               "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
               "microsoft/phi-2", 
               "Qwen/Qwen1.5-1.8B-Chat",
               "Qwen/Qwen2.5-0.5B"
               ]

# Hugging Face Hub IDs of the fine-tuned models. The entries appear to line up
# index-by-index with HF_BASE_IDS -- TODO confirm.
HF_FINETUNED_IDS = ["priyam-turakhia/gemma-2-2b-it-numpy-refactor-merged-v1", 
                    "julianpins/tinyllama-1.1b-chat-numpy-refactor-v1", 
                    "priyam-turakhia/phi-2-numpy-modernization-v1",
                    "julianpins/qwen1.5-1.8b-chat-numpy-refactor-merged-v2",
                    "julianpins/qwen0.5b-numpy-refactor-v1"
                    ]

# Select the model to evaluate. The prompt builder and output parser branch on
# whether this ID contains "phi" or "gemma"; index 1 is the TinyLlama fine-tune.
HF_REPO_ID = HF_FINETUNED_IDS[1]

# Instruction text sent to the model. The output parser searches for the two
# section headers requested here ('### Refactored Code' and
# '### Deprecation Context'), so keep them in sync if this text is edited.
SYSTEM_PROMPT = (
    "You are a Python code refactoring tool for NumPy. Your task is to replace only the deprecated functions in the given code snippet with their modern equivalents.\n"
    "Your response must be structured with two markdown sections:\n"
    "1. A '### Refactored Code' section containing ONLY the updated Python code block.\n"
    "2. A '### Deprecation Context' section containing a brief explanation of the deprecation.\n"
    "Do not change the code's logic. If no functions are deprecated, return the original code and state that no changes were needed in the context section."
)

def build_user_prompt(sample, tokenizer):
    """Build the full text prompt for one validation sample.

    The layout depends on which model family ``HF_REPO_ID`` names:

    * "phi"   -- a plain ``Instruct: ... Output:`` string; the tokenizer is
      not used.
    * "gemma" -- a single user turn holding the system prompt followed by the
      code, rendered through the tokenizer's chat template.
    * otherwise -- separate system and user turns rendered through the
      tokenizer's chat template.

    Args:
        sample: Mapping with an ``'input'`` entry holding the code snippet.
        tokenizer: Object providing ``apply_chat_template``.

    Returns:
        The prompt as a string.
    """
    # The fenced input snippet is the same for every model family.
    code_section = f"### INPUT CODE:\n```python\n{sample['input']}\n```"

    if "phi" in HF_REPO_ID:
        return f"Instruct: {SYSTEM_PROMPT}\n\n{code_section}\nOutput:"

    if "gemma" in HF_REPO_ID:
        # The instructions are folded into the user turn for this family.
        conversation = [
            {"role": "user", "content": f"{SYSTEM_PROMPT}\n\n{code_section}"},
        ]
    else:
        conversation = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": code_section},
        ]

    return tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )


def parse_model_output(raw_output):
    """Extract the refactored code and the deprecation context from a generation.

    ``raw_output`` is the pipeline's generated text, which still contains the
    prompt. The prompt is cut off at the model family's assistant-turn marker
    (selected via ``HF_REPO_ID``) before the two markdown sections are
    searched for.

    Args:
        raw_output: Full generated text (prompt + completion).

    Returns:
        ``(code, context)``. ``code`` is the body of the fenced python block
        under '### Refactored Code' ("" if absent); ``context`` is the
        stripped text after '### Deprecation Context' ("" if absent). If
        neither section is found, the prompt-stripped output is returned as
        the code together with the placeholder ``"Context not found."``.
    """
    if "gemma" in HF_REPO_ID:
        output = raw_output.split("<start_of_turn>model")[-1]
    elif "phi" in HF_REPO_ID:
        # The Instruct/Output prompt has no turn marker to split on.
        output = raw_output
    else:
        # ChatML-style templates (Qwen) emit "<|im_start|>assistant"; the stock
        # TinyLlama chat template emits "<|assistant|>". Try both so the
        # prompt is stripped for either; the ChatML marker is tried first, so
        # outputs that contain it are handled exactly as before.
        output = raw_output
        for marker in ("<|im_start|>assistant\n", "<|assistant|>"):
            if marker in raw_output:
                output = raw_output.split(marker)[-1]
                break

    code_match = re.search(r"### Refactored Code\s*```python\n(.*?)\n```", output, re.DOTALL)
    context_match = re.search(r"### Deprecation Context\n(.*?)$", output, re.DOTALL)
    code = code_match.group(1) if code_match else ""
    context = context_match.group(1).strip() if context_match else ""

    if not code and not context:
        # The model ignored the requested structure: treat everything it
        # produced as the code and let the downstream checks judge it.
        return output, "Context not found."
    return code, context


def clean_for_ast(code):
    """Return ``code`` without blank lines and without comment-only lines.

    Trailing whitespace is removed from every kept line; leading indentation
    and trailing inline comments are left untouched.
    """
    kept_lines = []
    for raw_line in code.splitlines():
        trimmed = raw_line.rstrip()
        if not trimmed:
            # Empty or whitespace-only line.
            continue
        if raw_line.strip().startswith("#"):
            # The whole line is a comment.
            continue
        kept_lines.append(trimmed)
    return "\n".join(kept_lines)

def compare_outputs(actual, expected):
    """Return a truthy value if ``actual`` matches ``expected``.

    Comparison rules, applied in order:

    * 0-d masked arrays are collapsed to their scalar, or to ``np.ma.masked``
      when masked; two masked values are equal.
    * Tuples are compared element-wise (recursively) and must have the same
      length.
    * Strings are compared after removing spaces and newlines.
    * Array-likes (ndarray, list, tuple) are converted with ``np.asarray`` and
      must have the same shape; if either side has a numeric dtype they are
      compared with ``np.allclose(..., equal_nan=True)``, otherwise with
      ``np.array_equal``.
    * An array-like never equals a non-array-like.
    * Anything else falls back to ``==``.

    NOTE(review): ``np.asarray`` drops the mask of N-d masked arrays, so only
    the underlying data is compared for those.
    """
    # Collapse 0-d masked arrays so they can be compared like scalars.
    if isinstance(actual, np.ma.MaskedArray) and actual.ndim == 0:
        actual = actual.item() if not actual.mask else np.ma.masked
    if isinstance(expected, np.ma.MaskedArray) and expected.ndim == 0:
        expected = expected.item() if not expected.mask else np.ma.masked
    if actual is np.ma.masked and expected is np.ma.masked:
        return True

    if isinstance(expected, tuple) and isinstance(actual, tuple):
        if len(expected) != len(actual):
            return False
        return all(compare_outputs(a, e) for a, e in zip(actual, expected))

    if isinstance(actual, str) and isinstance(expected, str):
        return actual.replace(" ", "").replace("\n", "") == expected.replace(" ", "").replace("\n", "")

    is_actual_arraylike = isinstance(actual, (np.ndarray, list, tuple))
    is_expected_arraylike = isinstance(expected, (np.ndarray, list, tuple))
    if is_actual_arraylike and is_expected_arraylike:
        try:
            actual_arr, expected_arr = np.asarray(actual), np.asarray(expected)
            # np.allclose broadcasts its operands, so without this check
            # e.g. [1, 1, 1] would "equal" [1]. np.array_equal is already
            # shape-strict; this makes the numeric path agree with it.
            if actual_arr.shape != expected_arr.shape:
                return False
            if any(np.issubdtype(arr.dtype, np.number) for arr in [actual_arr, expected_arr]):
                return np.allclose(actual_arr, expected_arr, equal_nan=True)
            return np.array_equal(actual_arr, expected_arr)
        except (ValueError, TypeError):
            # Ragged input or dtypes that cannot be compared numerically.
            return False
    if is_actual_arraylike != is_expected_arraylike:
        return False
    return actual == expected

def check_compiles(f, code):
    """Execute ``code`` and fetch the function named ``f`` that it defines.

    The code runs in a fresh namespace seeded from ``EVAL_GLOBALS``.

    Args:
        f: Name of the function the code is expected to define.
        code: Complete Python source to execute.

    Returns:
        ``(True, "OK", function)`` on success, otherwise
        ``(False, "<ExceptionType>: <message>", None)``.

    NOTE: this executes model-generated code without sandboxing; only use it
    with trusted models and data.
    """
    # Use ONE dict as both globals and locals. With two separate dicts, exec
    # runs the code as if it were a class body: names defined at the top level
    # of the snippet (imports, helper functions) land in the locals dict and
    # are invisible to the functions the snippet defines. Copying also keeps
    # the shared EVAL_GLOBALS from being mutated across samples.
    namespace = dict(EVAL_GLOBALS)
    try:
        exec(code, namespace)
        compiled_function = namespace.get(f)
        if not callable(compiled_function):
            raise NameError(f"Function '{f}' was not defined correctly.")
        return True, "OK", compiled_function
    except Exception as e:
        # Report the real exception type rather than labelling everything a
        # SyntaxError (genuine syntax errors still read "SyntaxError: ...").
        return False, f"{type(e).__name__}: {e}", None

def check_indentation(f, code):
    """Check that ``code`` executes as given, without any re-indentation.

    The code is run with ``EVAL_GLOBALS`` as globals and a fresh dict as
    locals. ``f`` is looked up in that dict afterwards, but the result of the
    lookup is not used.

    Returns:
        ``(True, "OK")`` if execution raised nothing, otherwise
        ``(False, "IndentationError: <message>")`` whatever the exception
        type actually was.
    """
    local_names = {}
    try:
        exec(code, EVAL_GLOBALS, local_names)
        local_names.get(f)
    except Exception as e:
        return False , f"IndentationError: {e}"
    else:
        return True, "OK"

def check_no_deprecations(f, input):
    """Call ``f`` once and report whether it relied on deprecated features.

    Args:
        f: The function under test.
        input: Argument(s) for the call; a tuple is unpacked into positional
            arguments, anything else is passed as a single argument.

    Returns:
        ``(True, "OK")`` if the call finished without raising and without
        emitting a ``DeprecationWarning`` (or subclass); otherwise
        ``(False, reason)``.
    """
    if not callable(f):
        return False, "Function not callable"

    with warnings.catch_warnings(record=True) as caught:
        # Record every warning, including ones normally shown only once.
        warnings.simplefilter("always")
        try:
            if isinstance(input, tuple):
                f(*input)
            else:
                f(input)
        except AttributeError as e:
            # A missing attribute on the numpy module is treated as use of a
            # removed (formerly deprecated) API.
            if "module 'numpy' has no attribute" in str(e):
                return False, "Output contained deprecated features"
            # Any other AttributeError is still a failed call; previously it
            # was swallowed here and the function could be reported as OK.
            return False, f"Error during execution, most likely due to deprecation: {e}"
        except Exception as e:
            return False, f"Error during execution, most likely due to deprecation: {e}"

        if any(issubclass(item.category, DeprecationWarning) for item in caught):
            return False, "Output contained deprecated features"
    return True, "OK"

def check_functionality(fun, test_cases: list):
    """Run ``fun`` against every test case and stop at the first failure.

    Each case is a mapping whose ``'input'`` and ``'expected_output'`` entries
    are Python expressions (strings) evaluated with ``EVAL_GLOBALS``. A tuple
    input is unpacked into positional arguments; anything else is passed as a
    single argument. Results are compared with ``compare_outputs``.

    Returns:
        ``(True, "All Test Cases passed")`` or ``(False, reason)`` describing
        the first mismatch or error.
    """
    for case_number, case in enumerate(test_cases, start=1):
        try:
            call_args = eval(case['input'], EVAL_GLOBALS)
            expected_output = eval(case['expected_output'], EVAL_GLOBALS)
            if isinstance(call_args, tuple):
                actual_output = fun(*call_args)
            else:
                actual_output = fun(call_args)
            if compare_outputs(actual_output, expected_output):
                continue
            return False, f"Failed: [Expected: {repr(expected_output)}, Got: {repr(actual_output)}]"
        except Exception as e:
            return False, f"Error during execution of Test Case {case_number}: {e}"
    return True, "All Test Cases passed"

#evaluation
def main():
    """Evaluate the model selected by ``HF_REPO_ID`` on the validation set.

    For every sample in ``VALIDATION_DATA_PATH`` the model generates a
    refactored snippet, which is spliced between the sample's ``code_before``
    and ``code_after`` and then put through four checks: compilation,
    indentation as generated, absence of deprecations, and functional
    correctness against the sample's test cases. Per-sample results are
    written to ``DETAILED_RESULTS_CSV`` and aggregate counts plus a weighted
    score to ``SUMMARY_METRICS_CSV``.

    Returns early (after printing a message) if the model cannot be loaded or
    the validation file does not exist.

    NOTE: model-generated code is executed without sandboxing.
    """
    # Every result row carries all of these columns so the summary step can
    # rely on them even when samples are skipped or the data set is empty.
    result_columns = [
        'sample_index',
        'compiles',
        'correct_indentation',
        'no_deprecations',
        'correct_functionality',
    ]

    print("Starting evaluation script...")

    print(f"Loading model from Hugging Face Hub: {HF_REPO_ID}")
    try:
        pipe = pipeline("text-generation", model=HF_REPO_ID, device_map="auto", torch_dtype=torch.bfloat16)
        tokenizer = pipe.tokenizer
    except Exception as e:
        print(f"Failed to load model: {e}")
        return

    print(f"Loading validation data from: {VALIDATION_DATA_PATH}")

    print(f"VALIDATION_DATA_PATH: {repr(VALIDATION_DATA_PATH)}")
    print("Absolute path:", os.path.abspath(VALIDATION_DATA_PATH))
    print("File exists:", os.path.exists(VALIDATION_DATA_PATH))

    if not os.path.exists(VALIDATION_DATA_PATH):
        print(f"Validation file not found! Please upload '{VALIDATION_DATA_PATH}'.")
        return
    with open(VALIDATION_DATA_PATH, 'r', encoding='utf-8') as data_file:
        validation_data = json.load(data_file)

    results_list = []
    total_samples = len(validation_data)

    for i, sample in enumerate(validation_data):
        print(f"\nProcessing sample {i+1}/{total_samples}:")

        prompt = build_user_prompt(sample, tokenizer)
        raw_output = pipe(prompt, max_new_tokens=256, do_sample=False)[0]['generated_text']
        generated_code, _ = parse_model_output(raw_output)
        print(generated_code)
        # Debugging aid: uncomment to evaluate the reference solution instead
        # of the model output.
        #generated_code = sample['output']

        if not generated_code:
            print("Model did not generate code. Skipping.")
            results_list.append({
                'sample_index': i,
                'compiles': 'Fail: No code generated',
                'correct_indentation': 'Fail: Skipped',
                'no_deprecations': 'Fail: Skipped',
                'correct_functionality': 'Fail: Skipped',
            })
            continue

        # Two variants are checked: one with the snippet normalised to a
        # four-space function body, and one with the model's own indentation.
        dedented_code = textwrap.dedent(generated_code).strip()
        indented_code = textwrap.indent(dedented_code, "    ")
        full_code = sample['code_before'] + "\n" + indented_code + "\n" + sample['code_after']
        full_code_ni = sample['code_before'] + "\n" + generated_code + "\n" + sample['code_after']

        # Name of the first function defined in code_before.
        function_name = sample['code_before'].split('def ')[1].split('(')[0].strip()

        #COMPILATION CHECK
        compiles, compiles_msg, compiled_function = check_compiles(function_name, full_code)

        #INDENTATION CHECK
        # Pass the function name here (the original passed the already-closed
        # validation file handle by mistake).
        if compiles:
            indentation, indentation_msg = check_indentation(function_name, full_code_ni)
        else:
            indentation, indentation_msg = False, "Skipped"

        #DEPRECATION CHECK
        if compiles:
            try:
                test_input = eval(sample['test_cases'][0]['input'], EVAL_GLOBALS)
            except Exception as e:
                # A malformed or missing test input must not abort the whole
                # run; record it against this sample and move on.
                no_deprecations, no_deprecations_msg = False, f"Could not evaluate test input: {e}"
            else:
                no_deprecations, no_deprecations_msg = check_no_deprecations(compiled_function, test_input)
        else:
            no_deprecations, no_deprecations_msg = False, "Skipped"

        #FUNCTIONALITY CHECK
        if compiles and no_deprecations:
            functionality, functionality_msg = check_functionality(compiled_function, sample['test_cases'])
        else:
            functionality, functionality_msg = False, "Skipped"

        results_list.append({
            'sample_index': i,
            'compiles': 'Pass' if compiles else f'Fail: {compiles_msg}',
            'correct_indentation': 'Pass' if indentation else f'Fail: {indentation_msg}',
            'no_deprecations': 'Pass' if no_deprecations else f'Fail: {no_deprecations_msg}',
            'correct_functionality': 'Pass' if functionality else f'Fail: {functionality_msg}',
        })

    print("\nEvaluation DONE")

    detailed_df = pd.DataFrame(results_list, columns=result_columns)
    detailed_df.to_csv(DETAILED_RESULTS_CSV, index=False)
    print("Detailed Results Table:")
    print(detailed_df.to_string())

    metrics = {}
    metrics['total_samples'] = total_samples
    metrics['compiles'] = detailed_df['compiles'].str.startswith('Pass').sum()
    metrics['correct_indentation'] = detailed_df['correct_indentation'].str.startswith('Pass').sum()
    metrics['no_deprecations'] = detailed_df['no_deprecations'].str.startswith('Pass').sum()
    metrics['correct_functionality'] = detailed_df['correct_functionality'].str.startswith('Pass').sum()

    # Functional correctness is weighted three times as heavily as each of
    # the other checks.
    summary_score = (
        metrics['compiles'] +
        metrics['correct_indentation'] +
        metrics['no_deprecations'] +
        (3 * metrics['correct_functionality'])
    )
    metrics['summary_score'] = summary_score

    summary_df = pd.DataFrame([metrics])
    summary_df.to_csv(SUMMARY_METRICS_CSV, index=False)
    print("Summary Metrics:")
    print(summary_df.to_string())

# Run the evaluation only when this file is executed directly (or as a
# notebook cell), not when it is imported as a module.
if __name__ == "__main__":
    main()
