In [None]:
import base64
import json
import os
import random
import re
import time
import traceback
from typing import List, Dict, Any

import anthropic
import pandas as pd
from PIL import Image

In [None]:
# Read the key from the ANTHROPIC_API_KEY environment variable so the secret
# never has to be saved inside the notebook; the placeholder string is only a
# fallback for users who prefer to paste a key here directly.
client = anthropic.Anthropic(
    api_key=os.environ.get('ANTHROPIC_API_KEY', 'Add_your_api_key_here')
)

In [None]:
# Number of repeated passes over the question set for each randomization
# condition in run_experiment.
NUM_RUNS = 10
# Maximum number of API attempts per question in query_claude_with_retry.
MAX_RETRIES = 3
# Pause between a failed API attempt and the next retry.
RETRY_DELAY = 10  # seconds

In [None]:
# Instruction text appended after the question and its numbered options
# (see evaluate_visualization_literacy). This variant tells the model not to
# guess and to select "Omit" when it is unsure.
VLAT_PROMPT = """I am about to show you an image and ask you a multiple choice question about that image. 
Please structure your response in the following format:
Answer: [Enter the exact text of your chosen option]
Explanation: [Provide your reasoning]
Select the BEST answer, based only on the chart and not external knowledge. DO NOT GUESS.
If you are not sure about your answer or your answer is based on a guess, select "Omit".
Choose your answer ONLY from the provided options."""

# Same response format as VLAT_PROMPT, but the answer line allows one or more
# options ("option(s)") and there is no "Omit" / do-not-guess instruction.
CALVI_PROMPT = """I am about to show you an image and ask you a multiple choice question about that image. 
Please structure your response in the following format:
Answer: [Enter the exact text of your chosen option(s)]
Explanation: [Provide your reasoning]
Select the BEST answer, based only on the chart and not external knowledge.
Choose your answer ONLY from the provided options."""

In [None]:
def load_questions(file_path: str) -> List[Dict[str, Any]]:
    """Load the list of question records from a JSON file.

    Args:
        file_path: Path to a JSON document whose top-level object has a
            'questions' key.

    Returns:
        The value stored under 'questions'.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If the document has no 'questions' key.
    """
    # Decode explicitly as UTF-8: the platform default encoding (e.g. cp1252 on
    # Windows) would corrupt or reject non-ASCII text in questions and options.
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    return data['questions']

def get_image_mime_type(image_path: str) -> str:
    """Return the MIME type that corresponds to an image file's extension.

    The extension is compared case-insensitively. Only JPEG, PNG, GIF and WebP
    are recognised.

    Raises:
        ValueError: If the extension is not one of the recognised formats.
    """
    mime_by_extension = {
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.png': 'image/png',
        '.gif': 'image/gif',
        '.webp': 'image/webp',
    }
    _, raw_extension = os.path.splitext(image_path)
    extension = raw_extension.lower()
    if extension not in mime_by_extension:
        raise ValueError(f"Unsupported image format: {extension}")
    return mime_by_extension[extension]

def encode_image(image_path: str) -> str:
    """Read a file in binary mode and return its contents as base64 text."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode('utf-8')

def query_claude_with_retry(prompt: str, image_path: str,
                            model: str = "claude-3-5-sonnet-20240620",
                            max_tokens: int = 300) -> str:
    """Send a text prompt plus one image to the model, retrying on failure.

    Args:
        prompt: Text sent as the first content part of the user message.
        image_path: Path of the image sent as the second content part.
        model: Model identifier passed to the API.
        max_tokens: Response length limit passed to the API.

    Returns:
        The text of the first content item of the reply, or the sentinel
        string "Error: Max retries reached" when every attempt failed.

    Raises:
        Whatever encode_image / get_image_mime_type raise (missing file,
        unsupported extension): both run before the retry loop, so those
        errors are not retried or swallowed.
    """
    base64_image = encode_image(image_path)
    mime_type = get_image_mime_type(image_path)

    user_content = [
        {
            "type": "text",
            "text": prompt
        },
        {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": mime_type,
                "data": base64_image
            }
        }
    ]

    for attempt in range(MAX_RETRIES):
        try:
            message = client.messages.create(
                model=model,
                max_tokens=max_tokens,
                temperature=0,
                messages=[{"role": "user", "content": user_content}]
            )
            return message.content[0].text
        except Exception as e:
            # Best-effort: any failure (including an unexpected reply shape)
            # is logged and retried until the attempts are used up.
            print(f"Error occurred (Attempt {attempt + 1}/{MAX_RETRIES}): {str(e)}")
            if attempt < MAX_RETRIES - 1:
                print(f"Retrying in {RETRY_DELAY} seconds...")
                time.sleep(RETRY_DELAY)

    # Reached only when no attempt succeeded. Placing the sentinel return after
    # the loop also covers MAX_RETRIES < 1, which previously returned None.
    print("Max retries reached. Skipping this question.")
    return "Error: Max retries reached"

def _find_marker(text: str, markers: List[str], start: int = 0):
    """Locate the highest-priority marker present in text at or after start.

    Markers are tried in list order; matching is case-insensitive and runs on
    the original text, so the returned (begin, end) offsets index into it.
    Returns None when no marker is present.
    """
    for marker in markers:
        hit = re.compile(re.escape(marker), re.IGNORECASE).search(text, start)
        if hit:
            return hit.start(), hit.end()
    return None


def extract_answer_and_explanation(claude_answer: str, options: List[str]) -> tuple:
    """Split a raw model reply into the chosen option and its explanation.

    Args:
        claude_answer: Raw reply text from the model.
        options: The option strings that were offered for the question.

    Returns:
        A (answers, explanation) pair. answers is a list holding:
          * nothing, when claude_answer is empty (explanation is then
            "No response provided");
          * the option that equals the extracted answer, ignoring case and
            surrounding whitespace;
          * otherwise the option found inside the extracted answer text,
            preferring the earliest occurrence and, at the same position,
            the longest option;
          * otherwise the extracted answer text itself if longer than one
            character;
          * otherwise "No valid answer extracted".
    """
    # Handle empty response
    if not claude_answer:
        return [], "No response provided"

    # Markers are matched case-insensitively, in priority order.
    answer_markers = [
        "Answer:",
        "The best answer based on the information provided in the image is:",
        "The best answer based on the chart is:",
        "The best answer is:"
    ]
    explanation_markers = [
        "Explanation:",
        "Why:"
    ]

    chosen_option = ""
    explanation = ""

    answer_span = _find_marker(claude_answer, answer_markers)
    if answer_span is not None:
        answer_begin, answer_end = answer_span
        # Look for the explanation only AFTER the answer marker; an earlier
        # explanation marker must not truncate the answer to an empty string.
        explanation_span = _find_marker(claude_answer, explanation_markers, answer_end)
        if explanation_span is not None:
            chosen_option = claude_answer[answer_end:explanation_span[0]]
            explanation = claude_answer[explanation_span[1]:].strip()
        else:
            chosen_option = claude_answer[answer_end:]
            # Reply ordered "Explanation ... Answer ...": keep the text between
            # the two markers as the explanation.
            earlier_span = _find_marker(claude_answer[:answer_begin], explanation_markers)
            if earlier_span is not None:
                explanation = claude_answer[earlier_span[1]:answer_begin].strip()
    else:
        # Fallback: use the first line as answer if it's not too long
        first_line = claude_answer.split('\n')[0].strip()
        if len(first_line) < 100:  # Arbitrary length check to avoid using explanations as answers
            chosen_option = first_line

    chosen_option = chosen_option.strip()

    # Remove a leading list number such as "3) ". Anchored to the start so a
    # parenthesis later in the text ("2010 (approx) ...") is left alone.
    chosen_option = re.sub(r'^\d+\s*\)\s+', '', chosen_option)

    # Remove quotes if present
    chosen_option = chosen_option.strip('"\'')

    # Exact match (case-insensitive) wins outright.
    normalized_choice = chosen_option.lower().strip()
    for opt in options:
        if opt.lower().strip() == normalized_choice:
            return [opt], explanation

    # Partial match: options contained in the answer text. Blank options are
    # skipped because the empty string is contained in everything.
    haystack = chosen_option.lower()
    candidates = []
    for opt in options:
        needle = opt.lower().strip()
        if not needle:
            continue
        position = haystack.find(needle)
        if position >= 0:
            candidates.append((position, -len(needle), opt))

    if candidates:
        # Earliest occurrence first; at equal position the longest option, so
        # "100" beats "10" and "False" beats a later "true".
        best = min(candidates, key=lambda candidate: (candidate[0], candidate[1]))
        return [best[2]], explanation

    # If no matches but we have a chosen_option, return it
    if chosen_option and len(chosen_option) > 1:
        return [chosen_option], explanation

    return ["No valid answer extracted"], explanation

def evaluate_answer(correct_answer: str, claude_answer: List[str]) -> bool:
    """Decide whether the extracted answers hit the correct answer.

    Comparison ignores case and surrounding whitespace.

    Args:
        correct_answer: The expected answer; several acceptable answers may be
            given as a comma-separated string.
        claude_answer: The answers extracted from the model's reply.

    Returns:
        True when any extracted answer equals the whole correct answer or any
        of its comma-separated parts; False otherwise.
    """
    claude_answers = {answer.strip().lower() for answer in claude_answer}

    # Compare the unsplit string first: an answer that itself contains a comma
    # (e.g. "1,000") would otherwise be split into fragments and never match.
    if correct_answer.strip().lower() in claude_answers:
        return True

    correct_answers = {part.strip().lower() for part in correct_answer.split(',')}
    return bool(correct_answers & claude_answers)

def evaluate_visualization_literacy(test_name: str, questions: List[Dict], prompt: str,
                                 randomize_options: bool = False,
                                 randomize_questions: bool = False) -> List[Dict]:
    """Ask the model every question once and record one result row per question.

    Shuffling (when requested) is applied to copies of the question list and of
    each option list, so the caller's lists are not reordered. Each question is
    followed by a one-second pause.

    Args:
        test_name: Label stored in every result row and shown in progress output.
        questions: Question records; 'question', 'options', 'correct_answer'
            and 'image_path' are read directly, the remaining metadata keys
            default to '' when absent.
        prompt: Instruction text appended after the numbered options.
        randomize_options: Shuffle the options of each question.
        randomize_questions: Shuffle the order of the questions.

    Returns:
        One dict per question with the answer, correctness, timing and metadata.
    """
    ordered_questions = questions.copy()
    if randomize_questions:
        random.shuffle(ordered_questions)

    options_label = 'Randomized' if randomize_options else 'Not Randomized'
    questions_label = 'Randomized' if randomize_questions else 'Not Randomized'
    total = len(ordered_questions)

    records = []
    for position, item in enumerate(ordered_questions, start=1):
        print(f"\nProcessing {test_name} question {position}/{total}")
        print(f"Conditions: Options {options_label}, Questions {questions_label}")

        choices = item['options'].copy()
        if randomize_options:
            random.shuffle(choices)

        # Question, numbered options (1-based), blank line, then the instructions.
        numbered_choices = "".join(
            f"{number}) {choice}\n" for number, choice in enumerate(choices, start=1)
        )
        full_prompt = (
            f"Question: {item['question']}\n\nOptions:\n"
            + numbered_choices
            + f"\n{prompt}"
        )

        started_at = time.time()
        raw_reply = query_claude_with_retry(full_prompt, item['image_path'])
        elapsed = time.time() - started_at

        print(f"Time taken: {elapsed:.2f} seconds")
        print(f"Raw Claude response: {raw_reply}")

        parsed_answers, reasoning = extract_answer_and_explanation(raw_reply, choices)
        correct = evaluate_answer(item['correct_answer'], parsed_answers)

        choices_text = ', '.join(choices)
        answers_text = ', '.join(parsed_answers)

        # Key order is kept as-is: it determines the CSV column order downstream.
        record = {
            'test_name': test_name,
            'question': item['question'],
            'options': choices_text,
            'correct_answer': item['correct_answer'],
            'claude_answer': answers_text,
            'explanation': reasoning,
            'raw_response': raw_reply,
            'Task': item.get('Task', ''),
            'Chart_type': item.get('Chart_type', ''),
            'Misleader': item.get('Misleader', ''),
            'wrong_due_to_misleader': item.get('wrong_due_to_misleader', ''),
            'is_correct': correct,
            'randomized_options': randomize_options,
            'randomized_questions': randomize_questions,
            'image_path': item['image_path'],
            'time_taken': elapsed
        }
        records.append(record)

        verdict = 'Correct' if correct else 'Incorrect'
        print(f"Question: {item['question']}")
        print(f"Options: {choices_text}")
        print(f"Claude's answer: {answers_text}")
        print(f"Claude's explanation: {reasoning}")
        print(f"Correct answer: {item['correct_answer']}")
        print(f"Result: {verdict}")

        # Add a small delay between questions
        time.sleep(1)

    return records

def run_experiment(test_name: str, file_path: str, prompt: str):
    print(f"\nStarting {test_name} experiment...")
    questions = load_questions(file_path)
    
    conditions = [
        (False, False, "No_Randomization"),
        (True, False, "Randomized_Options"),
        (False, True, "Randomized_Questions"),
        (True, True, "Both_Randomized")
    ]
    
    all_results = []
    
    for randomize_options, randomize_questions, condition_name in conditions:
        print(f"\n=== Running {test_name} - {condition_name} ===")
        
        for run in range(1, NUM_RUNS + 1):
            print(f"\n--- Run {run}/{NUM_RUNS} ---")
            try:
                results = evaluate_visualization_literacy(
                    test_name,
                    questions, 
                    prompt,
                    randomize_options=randomize_options,
                    randomize_questions=randomize_questions
                )
                
                # Add condition information to results
                for result in results:
                    result['condition'] = condition_name
                    result['run'] = run
                
                all_results.extend(results)
                
                # Save individual run results
                df_run = pd.DataFrame(results)
                output_file = f'claude_{test_name.lower()}_{condition_name}_run_{run}.csv'
                df_run.to_csv(output_file, index=False)
                
                # Calculate and print run score
                score = (df_run['is_correct'].sum() / len(df_run)) * 100
                avg_time = df_run['time_taken'].mean()
                print(f"\nScore for {condition_name} Run {run}: {score:.2f}%")
                print(f"Average time per question: {avg_time:.2f} seconds")
                
            except Exception as e:
                print(f"Error in run {run}: {str(e)}")
                traceback.print_exc()
                continue
    
    # Combine all results
    if all_results:
        combined_df = pd.DataFrame(all_results)
        
        # Calculate and print overall statistics
        print("\n=== Overall Results ===")
        for condition in combined_df['condition'].unique():
            condition_df = combined_df[combined_df['condition'] == condition]
            print(f"\n{condition}:")
            print(f"Mean accuracy: {condition_df['is_correct'].mean() * 100:.2f}%")
            print(f"Best question accuracy: {condition_df.groupby('question')['is_correct'].mean().max() * 100:.2f}%")
            print(f"Worst question accuracy: {condition_df.groupby('question')['is_correct'].mean().min() * 100:.2f}%")
            print(f"Average time per question: {condition_df['time_taken'].mean():.2f} seconds")
            print(f"Fastest question: {condition_df['time_taken'].min():.2f} seconds")
            print(f"Slowest question: {condition_df['time_taken'].max():.2f} seconds")
        
        # Statistics by various dimensions
        print("\nAverage Statistics by Task:")
        print(combined_df.groupby(['Task', 'condition'])['is_correct'].mean().unstack())
        
        print("\nAverage Time by Task:")
        print(combined_df.groupby(['Task', 'condition'])['time_taken'].mean().unstack())
        
        print("\nAverage Statistics by Chart Type:")
        print(combined_df.groupby(['Chart_type', 'condition'])['is_correct'].mean().unstack())
        
        print("\nAverage Time by Chart Type:")
        print(combined_df.groupby(['Chart_type', 'condition'])['time_taken'].mean().unstack())
        
        if 'Misleader' in combined_df.columns and not combined_df['Misleader'].isna().all():
            print("\nAverage Statistics by Misleader Type:")
            print(combined_df.groupby(['Misleader', 'condition'])['is_correct'].mean().unstack())
            
            misleader_stats = combined_df[combined_df['wrong_due_to_misleader'].notna()]
            if not misleader_stats.empty:
                print("\nAccuracy for Questions with Misleader Issues:")
                print(misleader_stats.groupby(['wrong_due_to_misleader', 'condition'])['is_correct'].mean().unstack())
        
        # Save combined results
        output_file = f'claude_{test_name.lower()}_all_results.csv'
        combined_df.to_csv(output_file, index=False)
        print(f"\nCombined results saved to {output_file}")

In [None]:
# Entry point: run the VLAT experiment, then the CALVI experiment, each with
# its own question file and prompt.
if __name__ == "__main__":
    try:
        run_experiment("VLAT", "vlat_skip.json", VLAT_PROMPT)
        run_experiment("CALVI", "calvi.json", CALVI_PROMPT)
    except Exception as e:
        # An exception escaping the first experiment also skips the second;
        # it is reported here with its traceback instead of propagating.
        print(f"Fatal error: {str(e)}")
        traceback.print_exc()