In [None]:
import json
import base64
import requests
import time
import random
import mimetypes
import pandas as pd
import traceback
from typing import List, Dict, Any
import os
import re

In [None]:
# Constants
# Number of repetitions of the full question set per randomization condition.
NUM_RUNS: int = 10
# Retry policy used by query_llama_with_retry.
MAX_RETRIES: int = 3
RETRY_DELAY: int = 2  # seconds slept between attempts
# Chat endpoint of the local model server (presumably Ollama, given the port/path).
API_URL: str = "http://localhost:11434/api/chat"
# Generation settings handed to the model server by query_llama_with_retry.
TEMPERATURE: int = 0
MAX_TOKENS: int = 300

In [None]:
# Instructions appended after the numbered options of every VLAT question.
# The model is asked to reply with "Answer: ..." / "Explanation: ..." lines, which is the
# format extract_answer_and_explanation parses first. It is also told to pick "Omit" rather
# than guess; the parser only maps free-text "omit" to "Omit" when "Omit" is an option.
# NOTE(review): presumably vlat_skip.json lists "Omit" among each question's options — verify.
VLAT_PROMPT = """I am about to show you an image and ask you a multiple choice question about that image. 
Please structure your response in the following format:

Answer: [Enter the exact text of your chosen option]
Explanation: [Provide your reasoning]

Select the BEST answer, based only on the chart and not external knowledge. DO NOT GUESS.
If you are not sure about your answer or your answer is based on a guess, select "Omit".
Choose your answer ONLY from the provided options."""

# Instructions for CALVI questions. Same response format as VLAT_PROMPT, but the model may
# choose several options ("option(s)") and there is no "Omit"/do-not-guess instruction.
CALVI_PROMPT = """I am about to show you an image and ask you a multiple choice question about that image. 
Please structure your response in the following format:

Answer: [Enter the exact text of your chosen option(s)]
Explanation: [Provide your reasoning]

Select the BEST answer, based only on the chart and not external knowledge.
Choose your answer ONLY from the provided options."""

In [None]:
def load_questions(file_path, encoding='utf-8'):
    """Load the list of questions from a JSON file.

    Args:
        file_path: Path to a JSON document whose top level is an object with a
            "questions" key.
        encoding: Text encoding of the file. Defaults to UTF-8 (the JSON interchange
            encoding) instead of the platform default, which on some systems (e.g. Windows
            code pages) cannot decode non-ASCII question text.

    Returns:
        The value stored under "questions".

    Raises:
        KeyError: if the document has no "questions" key.
    """
    with open(file_path, 'r', encoding=encoding) as file:
        data = json.load(file)
    return data['questions']

def encode_image(image_path):
    """Return the raw bytes of the file at `image_path` as a base64 ASCII string.

    Any failure (missing file, permission problem, ...) is reported on stdout and then
    re-raised unchanged to the caller.
    """
    try:
        with open(image_path, "rb") as handle:
            payload = base64.b64encode(handle.read())
        return payload.decode('utf-8')
    except Exception as err:
        print(f"Error encoding image {image_path}: {str(err)}")
        raise

def query_llama_with_retry(prompt, image_path, model="llama3.2-vision", timeout=300):
    """Send `prompt` plus the image at `image_path` to API_URL and return the reply text.

    The request is attempted up to MAX_RETRIES times with RETRY_DELAY seconds between
    attempts. Both exceptions (connection errors, timeouts, malformed JSON) and non-200
    HTTP responses count as failed attempts and are retried.

    Args:
        prompt: Full user message (question, numbered options and instructions).
        image_path: Path of the chart image; it is sent base64-encoded.
        model: Model name to request (default "llama3.2-vision").
        timeout: `requests` connect/read timeout in seconds for each attempt, so a hung
            server cannot block the experiment forever.

    Returns:
        The stripped reply content, or the sentinel "Error: Max retries reached" when all
        attempts failed. Never returns None, so the answer parser always receives text.

    Raises:
        Whatever encode_image raises if the image cannot be read (this is not retried).
    """
    base64_image = encode_image(image_path)

    data = {
        "model": model,
        "messages": [{
            "role": "user",
            "content": prompt,
            "images": [base64_image]
        }],
        "stream": False,
        # Ollama's /api/chat reads sampling settings from "options"; top-level
        # "temperature"/"max_tokens" keys are not part of its request schema and were ignored.
        "options": {
            "temperature": TEMPERATURE,
            "num_predict": MAX_TOKENS
        }
    }
    headers = {"Content-Type": "application/json"}

    for attempt in range(1, MAX_RETRIES + 1):
        try:
            response = requests.post(API_URL, headers=headers, json=data, timeout=timeout)
            if response.status_code == 200:
                return response.json()["message"]["content"].strip()
            print(f"Error: Received status code {response.status_code}")
            print(f"Response content: {response.text}")
        except Exception as e:
            print(f"Error occurred (Attempt {attempt}/{MAX_RETRIES}): {str(e)}")

        # Reached only on failure: back off before the next attempt, if any remain.
        if attempt < MAX_RETRIES:
            print(f"Retrying in {RETRY_DELAY} seconds...")
            time.sleep(RETRY_DELAY)

    print("Max retries reached. Skipping this question.")
    return "Error: Max retries reached"


def _option_pattern(option: str) -> str:
    """Regex matching `option` literally but not as a fragment of a longer word or number.

    Boundaries are enforced only on sides where the option itself starts/ends with a word
    character, so options such as "$10" or "10%" still match next to punctuation.
    """
    left = r'(?<!\w)' if re.match(r'\w', option[:1]) else ''
    right = r'(?!\w)' if re.match(r'\w', option[-1:]) else ''
    return left + re.escape(option) + right


def _normalize_option_text(text: str) -> str:
    """Lower-case `text` and drop surrounding whitespace, quotes and trailing periods."""
    return text.strip().strip('"\'').rstrip('.').strip().lower()


def _match_options(text: str, options: List[str]) -> List[str]:
    """Options mentioned in `text` as whole tokens: exact case first, then case-insensitive.

    An option is dropped when every one of its occurrences lies inside an occurrence of a
    longer matched option (e.g. "Increasing" inside "Increasing then decreasing").
    """
    spans = {}
    for flags in (0, re.IGNORECASE):
        spans = {}
        for opt in options:
            if opt and opt not in spans:
                found = [m.span() for m in re.finditer(_option_pattern(opt), text, flags)]
                if found:
                    spans[opt] = found
        if spans:
            break
    if not spans:
        return []

    def covered(opt, start, end):
        # True if [start, end) sits inside a strictly longer span of a different option.
        return any(
            other != opt and s <= start and end <= e and e - s > end - start
            for other, other_spans in spans.items()
            for s, e in other_spans
        )

    return [opt for opt, found in spans.items()
            if not all(covered(opt, start, end) for start, end in found)]


def extract_answer_and_explanation(model_answer: str, options: List[str]) -> tuple:
    """Extract the chosen option(s) and the explanation from a model response.

    The text after the LAST "Answer:" is used (earlier ones are usually drafts). The answer
    ends at "Explanation:" if present, otherwise at the first newline, otherwise at the
    first sentence-ending period (a period followed by whitespace or end of text, so
    decimals like "3.5" stay intact). The answer is then resolved in this order:

    1. it equals an option (ignoring case, quotes, trailing period);
    2. it contains an in-range option number "N)" (options are numbered from 1 in the prompt);
    3. options mentioned as whole tokens, preferring exact case, with shorter options that
       only appear inside a longer matched option removed;
    4. "Omit" if the answer mentions "omit" and "Omit" is an option.

    If no "Answer:" marker exists, nothing matched, or parsing raised, the response is
    handed to legacy_extract_answer_and_explanation.

    Args:
        model_answer: Raw model reply (None is treated as an empty reply).
        options: Options in the order they were numbered in the prompt.

    Returns:
        (answers, explanation): list of matched option strings and the explanation text
        ("No explanation provided" when it is empty).
    """
    answer_marker = "Answer:"
    explanation_marker = "Explanation:"
    no_explanation = "No explanation provided"

    if model_answer is None:
        model_answer = ""

    try:
        if answer_marker in model_answer:
            after_marker = model_answer[model_answer.rindex(answer_marker) + len(answer_marker):]

            if explanation_marker in after_marker:
                # maxsplit=1 keeps any later "Explanation:" text inside the explanation.
                answer_part, explanation = after_marker.split(explanation_marker, 1)
            else:
                answer_text = after_marker.strip()
                if '\n' in answer_text:
                    answer_part, explanation = answer_text.split('\n', 1)
                else:
                    pieces = re.split(r'\.(?:\s+|$)', answer_text, maxsplit=1)
                    answer_part = pieces[0]
                    explanation = pieces[1] if len(pieces) > 1 else ""
            explanation = explanation.strip() or no_explanation

            # Drop markdown emphasis such as "**Answer:** **X**".
            answer_part = answer_part.strip().strip('*').strip()

            # 1. Whole answer equals an option: unambiguous, and protects option texts that
            #    themselves contain "N)" from being misread as an option number.
            wanted = _normalize_option_text(answer_part)
            if wanted:
                for opt in options:
                    if _normalize_option_text(opt) == wanted:
                        return [opt], explanation

            # 2. Option number such as "3) ..." or "(option 3)".
            option_number_match = re.search(r'(\d+)\)', answer_part)
            if option_number_match:
                option_number = int(option_number_match.group(1))
                if 1 <= option_number <= len(options):
                    return [options[option_number - 1]], explanation

            # 3. Whole-token mentions (avoids "10" matching inside "100").
            matches = _match_options(answer_part, options)

            # 4. Free-text omission ("omitted", "I omit ...").
            if not matches and "omit" in answer_part.lower() and "Omit" in options:
                matches.append("Omit")

            if matches:
                return matches, explanation

        # No "Answer:" marker or no option recognized: try the heuristic parser.
        return legacy_extract_answer_and_explanation(model_answer, options)

    except Exception as e:
        print(f"Error parsing response: {str(e)}")
        return legacy_extract_answer_and_explanation(model_answer, options)

def legacy_extract_answer_and_explanation(model_answer: str, options: List[str]) -> tuple:
    """Fallback answer extraction for responses that ignore the "Answer:/Explanation:" format.

    Heuristics run in order and the first hit wins:

    1. Regex patterns searched in the lower-cased response ("answer: X", a line starting
       with "N) ...", "the correct answer is X", "therefore X"). A hit that names an
       option, or an in-range option number "N)", is returned with the response text
       after the match as the explanation.
    2. Otherwise the response is split at the first separator found ("why:", "because",
       "as", "since", ...); options named before the separator are returned.
    3. Otherwise the last in-range "N)" before the separator selects an option.
    4. Otherwise an option named after the last "therefore" is returned (with a warning).
    5. Otherwise the unmatched text before the separator is returned as the answer.

    Options and word-like separators are matched as whole tokens, so option "10" is not
    found inside "100" and separator "as" is not found inside "has".

    Args:
        model_answer: Raw model reply (None is treated as an empty reply).
        options: Options in the order they were numbered in the prompt.

    Returns:
        (answers, explanation) where answers is a list of strings.
    """
    def token_pattern(token: str) -> str:
        # Literal `token`, not glued to surrounding word characters on sides where the token
        # itself starts/ends with one ("$10", "10%", ":" still match next to punctuation).
        left = r'(?<!\w)' if re.match(r'\w', token[:1]) else ''
        right = r'(?!\w)' if re.match(r'\w', token[-1:]) else ''
        return left + re.escape(token) + right

    def mentions(text: str, token: str) -> bool:
        return re.search(token_pattern(token), text) is not None

    if model_answer is None:
        model_answer = ""
    lowered = model_answer.lower()

    # 1. Search for patterns like "answer: X" or "X) option" anywhere in the text.
    answer_patterns = [
        r'answer:\s*([^.\n]+)',
        r'(?:^|\n)\s*(\d+\)[\s\w\d\-\.,$]+)',
        r'the correct answer is:?\s*([^.\n]+)',
        r'therefore,?\s*(?:the correct answer is:?)?\s*([^.\n]+)',
    ]

    for pattern in answer_patterns:
        for match in re.finditer(pattern, lowered, re.MULTILINE):
            potential_answer = match.group(1).strip()
            # Whatever follows the matched answer is treated as the explanation.
            explanation = model_answer[match.end():].strip() or "No explanation provided"

            for opt in options:
                if mentions(potential_answer, opt.lower()):
                    return [opt], explanation

            number_match = re.search(r'(\d+)\)', potential_answer)
            if number_match:
                option_num = int(number_match.group(1))
                if 1 <= option_num <= len(options):
                    return [options[option_num - 1]], explanation

    # 2. Split at the first separator that occurs as a whole token.
    separators = ['why:', 'because', 'as', 'since', 'explanation:', 'reasoning:', ':']
    parts = None
    for sep in separators:
        hit = re.search(token_pattern(sep), lowered)
        if hit:
            parts = [lowered[:hit.start()], lowered[hit.end():]]
            break

    if not parts:
        parts = [model_answer, "No explanation provided"]

    chosen_option = parts[0].strip()
    explanation = parts[1].strip() if len(parts) > 1 else "No explanation provided"

    matches = [opt for opt in options if mentions(chosen_option.lower(), opt.lower())]

    # 3. Last option number "N)" before the separator.
    if not matches:
        numbers = re.findall(r'(\d+)\)', chosen_option)
        if numbers:
            last_num = int(numbers[-1])
            if 1 <= last_num <= len(options):
                return [options[last_num - 1]], explanation

    # 4. Option named in the final "therefore ..." clause.
    if not matches:
        print(f"Warning: Could not match response '{chosen_option}' to any option exactly.")
        print(f"Available options were: {options}")
        if "therefore" in lowered:
            last_part = lowered.split("therefore")[-1]
            for opt in options:
                if mentions(last_part, opt.lower()):
                    return [opt], explanation

    # 5. Give up: return the raw text so the record shows what the model said.
    return matches if matches else [chosen_option], explanation

def evaluate_answer(correct_answer: str, model_answers: List[str]) -> bool:
    """Return True when any parsed model answer equals any accepted answer.

    `correct_answer` may hold several accepted answers separated by commas. Comparison
    ignores surrounding whitespace and letter case. This is an "any overlap" rule: a
    response that names extra options alongside a correct one still counts as correct.
    """
    accepted = {candidate.strip().lower() for candidate in correct_answer.split(',')}
    return any(answer.strip().lower() in accepted for answer in model_answers)

def _format_question_prompt(question_text: str, options: List[str], prompt: str) -> str:
    """Render a question, its options numbered from 1 ("1) ..."), then the test instructions."""
    numbered_options = "".join(f"{i}) {option}\n" for i, option in enumerate(options, 1))
    return f"Question: {question_text}\n\nOptions:\n{numbered_options}\n{prompt}"


def evaluate_visualization_literacy(test_name: str, questions: list, prompt: str,
                                 randomize_options: bool = False,
                                 randomize_questions: bool = False,
                                 rng=None,
                                 delay_between_questions: float = 1) -> list:
    """Ask the model every question once and return one result record per question.

    Args:
        test_name: Label stored in each record (e.g. "VLAT").
        questions: Question dicts. Each must provide 'question', 'options',
            'correct_answer', 'image_path', 'Task' and 'Chart_type'; 'Misleader' and
            'wrong_due_to_misleader' are optional.
        prompt: Test-specific instructions appended after the numbered options.
        randomize_options: Shuffle each question's options before numbering them.
        randomize_questions: Shuffle the order in which the questions are asked.
        rng: Optional ``random.Random`` instance used for both shuffles so a run can be
            reproduced; defaults to the module-level ``random`` functions.
        delay_between_questions: Seconds to sleep after each question (default 1).

    Returns:
        List of dicts, one per question, in the order the questions were asked.
        Neither `questions` nor any question's option list is modified.
    """
    shuffle = random.shuffle if rng is None else rng.shuffle

    working_questions = list(questions)
    if randomize_questions:
        shuffle(working_questions)

    results = []
    for idx, question in enumerate(working_questions, 1):
        print(f"\nProcessing {test_name} question {idx}/{len(working_questions)}")
        print(f"Conditions: Options {'Randomized' if randomize_options else 'Not Randomized'}, "
              f"Questions {'Randomized' if randomize_questions else 'Not Randomized'}")

        # Copy so shuffling never reorders the caller's option list.
        options = list(question['options'])
        if randomize_options:
            shuffle(options)

        question_prompt = _format_question_prompt(question['question'], options, prompt)

        # perf_counter is monotonic, so durations are not skewed by wall-clock adjustments.
        start_time = time.perf_counter()
        full_llama_answer = query_llama_with_retry(question_prompt, question['image_path'])
        time_taken = time.perf_counter() - start_time

        print(f"Time taken: {time_taken:.2f} seconds")
        print(f"Raw Llama response: {full_llama_answer}")

        llama_answer, explanation = extract_answer_and_explanation(full_llama_answer, options)
        is_correct = evaluate_answer(question['correct_answer'], llama_answer)

        results.append({
            'test_name': test_name,
            'question': question['question'],
            'options': ', '.join(options),
            'correct_answer': question['correct_answer'],
            'llama_answer': ', '.join(llama_answer),
            'explanation': explanation,
            'raw_response': full_llama_answer,
            'Task': question['Task'],
            'Chart_type': question['Chart_type'],
            'Misleader': question.get('Misleader', ''),
            'wrong_due_to_misleader': question.get('wrong_due_to_misleader', ''),
            'is_correct': is_correct,
            'randomized_options': randomize_options,
            'randomized_questions': randomize_questions,
            'image_path': question['image_path'],
            'time_taken': time_taken
        })

        print(f"Question: {question['question']}")
        print(f"Options: {', '.join(options)}")
        print(f"Llama's answer: {', '.join(llama_answer)}")
        print(f"Llama's explanation: {explanation}")
        print(f"Correct answer: {question['correct_answer']}")
        print(f"Result: {'Correct' if is_correct else 'Incorrect'}")

        # Small pause between requests to the model server.
        time.sleep(delay_between_questions)

    return results

def run_experiment(test_name: str, file_path: str, prompt: str):
    print(f"\nStarting {test_name} experiment...")
    questions = load_questions(file_path)
    
    conditions = [
         (False, False, "No_Randomization"),
        (True, False, "Randomized_Options"),
        (False, True, "Randomized_Questions"),
        (True, True, "Both_Randomized")
    ]
    
    all_results = []
    
    for randomize_options, randomize_questions, condition_name in conditions:
        print(f"\n=== Running {test_name} - {condition_name} ===")
        
        for run in range(1, NUM_RUNS + 1):
            print(f"\n--- Run {run}/{NUM_RUNS} ---")
            try:
                results = evaluate_visualization_literacy(
                    test_name,
                    questions, 
                    prompt,
                    randomize_options=randomize_options,
                    randomize_questions=randomize_questions
                )
                
                # Add condition information to results
                for result in results:
                    result['condition'] = condition_name
                    result['run'] = run
                
                all_results.extend(results)
                
                # Save individual run results
                df_run = pd.DataFrame(results)
                output_file = f'llama_vision_{test_name.lower()}_{condition_name}_run_{run}.csv'
                df_run.to_csv(output_file, index=False)
                
                # Calculate and print run score
                score = (df_run['is_correct'].sum() / len(df_run)) * 100
                avg_time = df_run['time_taken'].mean()
                print(f"\nScore for {condition_name} Run {run}: {score:.2f}%")
                print(f"Average time per question: {avg_time:.2f} seconds")
                
            except Exception as e:
                print(f"Error in run {run}: {str(e)}")
                traceback.print_exc()
                continue
    
    # Combine all results
    if all_results:
        combined_df = pd.DataFrame(all_results)
        
        # Calculate and print overall statistics
        print("\n=== Overall Results ===")
        for condition in combined_df['condition'].unique():
            condition_df = combined_df[combined_df['condition'] == condition]
            print(f"\n{condition}:")
            print(f"Mean accuracy: {condition_df['is_correct'].mean() * 100:.2f}%")
            print(f"Best question accuracy: {condition_df.groupby('question')['is_correct'].mean().max() * 100:.2f}%")
            print(f"Worst question accuracy: {condition_df.groupby('question')['is_correct'].mean().min() * 100:.2f}%")
            print(f"Average time per question: {condition_df['time_taken'].mean():.2f} seconds")
            print(f"Fastest question: {condition_df['time_taken'].min():.2f} seconds")
            print(f"Slowest question: {condition_df['time_taken'].max():.2f} seconds")
        
        # Statistics by various dimensions
        print("\nAverage Statistics by Task:")
        print(combined_df.groupby(['Task', 'condition'])['is_correct'].mean().unstack())
        
        print("\nAverage Time by Task:")
        print(combined_df.groupby(['Task', 'condition'])['time_taken'].mean().unstack())
        
        print("\nAverage Statistics by Chart Type:")
        print(combined_df.groupby(['Chart_type', 'condition'])['is_correct'].mean().unstack())
        
        print("\nAverage Time by Chart Type:")
        print(combined_df.groupby(['Chart_type', 'condition'])['time_taken'].mean().unstack())
        
        if 'Misleader' in combined_df.columns:
            print("\nAverage Statistics by Misleader Type:")
            print(combined_df.groupby(['Misleader', 'condition'])['is_correct'].mean().unstack())
        
        # Save combined results
        output_file = f'llama_vision_{test_name.lower()}_all_results.csv'
        combined_df.to_csv(output_file, index=False)
        print(f"\nCombined results saved to {output_file}")

In [None]:
if __name__ == "__main__":
    # Experiments run in this order; an exception in one aborts the rest.
    experiment_specs = (
        ("VLAT", "vlat_skip.json", VLAT_PROMPT),
        ("CALVI", "calvi.json", CALVI_PROMPT),
    )
    try:
        for spec_name, spec_file, spec_prompt in experiment_specs:
            run_experiment(spec_name, spec_file, spec_prompt)
    except Exception as e:
        print(f"Fatal error: {str(e)}")
        traceback.print_exc()