# In [None]:
import torch
import gc
from vllm import LLM, SamplingParams
from datasets import load_dataset
import pandas as pd
from tqdm import tqdm
import re
from collections import Counter
import numpy as np
import time
import logging

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Constants
N = 3  # Number of candidates per problem (NOTE(review): not referenced by the code visible in this file)
M = 2  # Maximum number of generation rounds per problem (loop bound in process_batch)
BATCH_SIZE = 4  # Number of problems handed to vLLM per batch
TIME_LIMIT = 18 * 3600  # Wall-clock budget for the evaluation loop, in seconds (18 hours)
TEST_SIZE = 1  # Number of GSM8K problems to test

# Initialize vLLM
MODEL_PATH = "/nlp_group/decapoda-research/Llama-2-7b-chat-hf"
# tensor_parallel_size=1 keeps the model on a single GPU; raise it to shard
# the model across several GPUs.
llm = LLM(model=MODEL_PATH, tensor_parallel_size=1)
sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.95,
    max_tokens=1024,
    # vLLM cuts the completion *before* a matched stop string. Stopping on a
    # bare "```" or "```python" therefore removed the opening fence (and all
    # code after it), so extract_code could never find a fenced block.
    # Only stop on strings that can appear at the *closing* fence, and keep
    # the matched text so the closing "```" stays in the completion.
    stop=["```\n", "```output"],
    include_stop_str_in_output=True,
)

def generate_prompt(problem):
    """Build the code-generation prompt for one math problem.

    Args:
        problem: Problem statement; inserted verbatim between the fixed
            header and the formatting instructions.

    Returns:
        The full prompt text. It asks for a fenced ``python`` block defining
        ``solve_problem`` and ends with a trailing newline.
    """
    header = "Solve this math problem using Python:\n\n"
    instructions = (
        "\n\nProvide your solution as a Python function named 'solve_problem' "
        "that takes no arguments and returns the answer. Store the final "
        "answer in a variable named 'answer'. Enclose your code in triple "
        "backticks with 'python' specified, like this:\n\n"
    )
    # Skeleton shown to the model as the expected answer format.
    skeleton = "\n".join([
        "```python",
        "def solve_problem():",
        "    # Your code here",
        "    answer = # Final calculated answer",
        "    return answer",
        "```",
    ])
    footer = "\n\nDo not include any explanations or additional text outside the code block.\n"
    return f"{header}{problem}{instructions}" + skeleton + footer

def extract_code(completion):
    """Return the body of the first ```python fenced block in *completion*.

    Args:
        completion: Text produced by the model.

    Returns:
        The fenced code with surrounding whitespace stripped, or ``None``
        when no complete ```python ... ``` block is present.
    """
    # DOTALL lets the lazy group span multiple lines; the first closing fence wins.
    fenced = re.search(r'```python\s*(.*?)\s*```', completion, re.DOTALL)
    return None if fenced is None else fenced.group(1).strip()

# Standard-library modules that generated solutions are allowed to import.
_EXEC_IMPORTABLE = frozenset({'math', 'fractions', 'decimal', 'itertools'})


def _restricted_import(name, globals_=None, locals_=None, fromlist=(), level=0):
    """Import hook for generated code: only whitelisted top-level modules."""
    if level != 0 or name.split('.')[0] not in _EXEC_IMPORTABLE:
        raise ImportError(f"import of '{name}' is not allowed")
    return __import__(name, globals_, locals_, fromlist, level)


# Builtins exposed to generated code: the original three plus the common
# arithmetic/collection helpers that simple word-problem solutions rely on.
_EXEC_BUILTINS = {
    'print': print, 'int': int, 'float': float,
    'abs': abs, 'all': all, 'any': any, 'bool': bool, 'dict': dict,
    'divmod': divmod, 'enumerate': enumerate, 'len': len, 'list': list,
    'max': max, 'min': min, 'pow': pow, 'range': range, 'round': round,
    'set': set, 'sorted': sorted, 'str': str, 'sum': sum, 'tuple': tuple,
    'zip': zip,
    '__import__': _restricted_import,
}


def execute_code(code):
    """Run generated solution code and return the answer it computes.

    The code is executed in a single fresh namespace (so functions can see
    module-level helpers and constants), then ``solve_problem()`` is called
    if it was defined.

    Args:
        code: Python source text, expected to define ``solve_problem()``.

    Returns:
        A module-level ``answer`` variable if the code defined one (the value
        the original implementation looked for); otherwise the return value
        of ``solve_problem()``. ``None`` if the code raises, or defines
        neither ``solve_problem`` nor ``answer``.
    """
    # SECURITY: `code` is untrusted model output. Trimming __builtins__ is NOT
    # a sandbox (object introspection can escape it) and there is no timeout,
    # so a non-terminating solution will hang the caller. Run this inside an
    # isolated process/container if the model output cannot be trusted.
    namespace = {'__builtins__': dict(_EXEC_BUILTINS)}
    try:
        # One dict serves as both globals and locals. With separate dicts,
        # top-level definitions land in the locals dict, which function
        # bodies cannot see.
        exec(code, namespace)
        solver = namespace.get('solve_problem')
        result = solver() if callable(solver) else None
    except Exception as e:
        logging.error(f"Error executing code: {e}")
        return None

    if 'answer' in namespace:
        return namespace['answer']
    if not callable(solver):
        logging.error("Error executing code: neither 'solve_problem' nor 'answer' was defined")
        return None
    # `answer` assigned inside solve_problem is local to that call, so the
    # return value is the only way to observe it.
    logging.info(f"The final answer is {result}")
    return result

def process_batch(problems):
    """Generate and execute solutions for a batch of problems, retrying failures.

    Every problem is attempted once; problems whose attempt produced no result
    (no code block, or execution returned ``None``) are retried with a
    follow-up instruction appended to their prompt, for at most ``M`` rounds.

    Args:
        problems: Sequence of problem statements.

    Returns:
        A list with exactly one entry per problem, in input order. Each entry
        is the list of that problem's attempt results (``None`` for a failed
        attempt), holding between 1 and ``M`` items.
    """
    if not problems:
        return []

    prompts = [generate_prompt(problem) for problem in problems]
    results_per_problem = [[] for _ in problems]
    # Indices of problems that still need a valid result. Tracking indices
    # (instead of filtering the prompt list) keeps every retry attached to
    # the problem it belongs to.
    pending = list(range(len(problems)))

    for _ in range(M):  # At most M attempts for each problem
        outputs = llm.generate([prompts[idx] for idx in pending], sampling_params)

        still_pending = []
        for idx, output in zip(pending, outputs):
            code = extract_code(output.outputs[0].text)
            result = execute_code(code) if code else None
            results_per_problem[idx].append(result)
            if result is None:
                prompts[idx] += "\n\nYour previous attempt was incorrect or incomplete. Please try again."
                still_pending.append(idx)

        pending = still_pending
        if not pending:  # Every problem has a valid result
            break

    return results_per_problem

def _parse_reference_answer(answer_text):
    """Return the number in the last token of a reference solution, or None."""
    tokens = answer_text.split()
    if not tokens:
        return None
    # Strip thousands separators defensively (e.g. "1,234").
    token = tokens[-1].replace(",", "")
    for convert in (int, float):
        try:
            return convert(token)
        except ValueError:
            continue
    return None


def _majority_vote(candidates):
    """Return the most frequent non-None hashable candidate, or None."""
    votes = Counter()
    for candidate in candidates:
        if candidate is None:
            continue
        try:
            votes[candidate] += 1
        except TypeError:
            # Unhashable result (e.g. a list) cannot be a numeric answer.
            continue
    return votes.most_common(1)[0][0] if votes else None


def _answers_match(predicted, expected):
    """Compare numerically with a small tolerance; fall back to ``==``."""
    try:
        return abs(float(predicted) - float(expected)) < 1e-6
    except (TypeError, ValueError, OverflowError):
        return predicted == expected


def evaluate_gsm8k():
    """Evaluate the model on the first ``TEST_SIZE`` GSM8K test problems.

    Problems are processed in batches of ``BATCH_SIZE``; for each problem the
    majority answer among its candidates is compared with the reference.
    Evaluation stops early once ``TIME_LIMIT`` seconds have elapsed.

    Returns:
        Accuracy as a float in [0, 1]; ``0.0`` if no problem was processed.
    """
    dataset = load_dataset("gsm8k", "main")
    test_split = dataset["test"]
    # Clamp so an oversized TEST_SIZE does not index past the split.
    test_data = test_split.select(range(min(TEST_SIZE, len(test_split))))
    num_problems = len(test_data)

    correct = 0
    total = 0
    start_time = time.time()

    for i in range(0, num_problems, BATCH_SIZE):
        if time.time() - start_time > TIME_LIMIT:
            logging.info("Time limit reached. Stopping evaluation.")
            break

        # Slicing a Dataset yields a dict of columns ({'question': [...],
        # 'answer': [...]}), not a list of rows, so read the columns directly.
        batch = test_data[i:i+BATCH_SIZE]
        problems = list(batch['question'])
        true_answers = [_parse_reference_answer(text) for text in batch['answer']]

        all_candidates = process_batch(problems)

        # Iterate over the references so every problem is counted even if
        # fewer candidate lists than problems come back.
        for idx, true_answer in enumerate(true_answers):
            candidates = all_candidates[idx] if idx < len(all_candidates) else []
            predicted_answer = _majority_vote(candidates)
            if true_answer is None:
                logging.warning(f"Could not parse reference answer for problem {i + idx}")
            elif predicted_answer is not None and _answers_match(predicted_answer, true_answer):
                correct += 1
            total += 1

        running_accuracy = correct / total if total else 0.0
        logging.info(f"Processed {total}/{num_problems} problems. Current accuracy: {correct}/{total} = {running_accuracy:.2%}")

    # Guard against an empty run (TEST_SIZE == 0 or the time limit hit first).
    final_accuracy = correct / total if total else 0.0
    logging.info(f"Final Accuracy: {correct}/{total} = {final_accuracy:.2%}")
    return final_accuracy

# Main execution
if __name__ == "__main__":
    try:
        accuracy = evaluate_gsm8k()
        print(f"Evaluation completed. Final accuracy: {accuracy:.2%}")
    except Exception as e:
        # This is the only top-level handler, so log the full traceback;
        # logging.error alone would discard where the failure happened.
        logging.exception(f"An error occurred during evaluation: {e}")