In [6]:
import json
import os
import string
import time

import openai
import requests
from openai import AzureOpenAI
from tqdm import tqdm

# Azure OpenAI client used for every answer-extraction request.
# Secrets are read from the environment so they never live in the notebook;
# the placeholder strings are only fallbacks when the variables are unset.
client = AzureOpenAI(
    azure_endpoint=os.environ.get("AZURE_OPENAI_ENDPOINT", "YOUR-API-ENDPOINT"),
    api_key=os.environ.get("AZURE_OPENAI_API_KEY", "YOUR-API-KEY"),
    api_version=os.environ.get("OPENAI_API_VERSION", "2024-05-01-preview"),
)

def initialize_openai(api_key, api_base, api_version="2024-05-01-preview"):
    """
    Configure the ``openai`` package's module-level settings for Azure OpenAI.

    Only package-global attributes are assigned; this function does not modify
    the module-level ``client`` object that the request helpers use.

    Args:
        api_key (str): Azure OpenAI API key.
        api_base (str): Azure OpenAI endpoint URL
            (e.g. ``https://<resource>.openai.azure.com/``) -- not the API version.
        api_version (str, optional): Azure OpenAI API version. Defaults to "2024-05-01-preview".
    """
    openai.api_type = "azure"
    openai.api_key = api_key
    # openai>=1.0 reads the Azure endpoint from ``azure_endpoint``; ``api_base``
    # is the pre-1.0 setting and is still assigned for any legacy code reading it.
    openai.azure_endpoint = api_base
    openai.api_base = api_base
    openai.api_version = api_version

def get_completion_with_retry(output_text, deployment_name, max_retries=5, delay=10, api_client=None):
    """
    Ask an Azure OpenAI chat deployment to pull the answer letter out of a model's raw output.

    The request is retried on rate limiting and on any other exception raised by
    the call, sleeping ``delay`` seconds between attempts (not after the last one).

    Args:
        output_text (str): The raw text output from the model.
        deployment_name (str): The deployment name of the model in Azure OpenAI.
        max_retries (int, optional): Maximum number of API attempts. Defaults to 5.
        delay (int, optional): Seconds to wait before the next attempt after a failure. Defaults to 10.
        api_client (optional): Object exposing ``chat.completions.create`` (e.g. an
            ``AzureOpenAI`` instance). Defaults to the module-level ``client``.

    Returns:
        str: A single uppercase letter "A"-"Z", or "Invalid" if the reply is not a
        single letter or every attempt failed.
    """
    api = client if api_client is None else api_client

    # Ask for nothing but the letter so the reply can be parsed directly.
    prompt = (
        "Extract only the final answer letter (A, B, C, etc.) from the following text. "
        "Provide only the single uppercase letter without any additional text.\n\n"
        f"Text: {output_text}\n\nAnswer:"
    )

    for attempt in range(1, max_retries + 1):
        try:
            response = api.chat.completions.create(
                model=deployment_name,  # the Azure deployment name, not the base model name
                messages=[
                    {"role": "system", "content": "You are an AI assistant that helps people find information."},
                    {"role": "user", "content": prompt},
                ],
                max_tokens=2,  # a letter, possibly followed by punctuation
                temperature=0,  # deterministic extraction
                top_p=1,
                n=1,
                stop=None,
            )
            content = response.choices[0].message.content
        # openai>=1.0 exposes its exceptions at the package top level. The old
        # ``openai.error`` module no longer exists; referencing it here raised
        # AttributeError on the first failure and aborted the whole retry loop.
        except openai.RateLimitError:
            reason = "Rate limit exceeded."
        except Exception as e:
            reason = f"Error during OpenAI API call: {e}."
        else:
            # Tolerate replies such as "b", "B." or "(B)"; message content may be None.
            answer = (content or "").strip().strip(".:()[]").strip().upper()
            if len(answer) == 1 and answer in string.ascii_uppercase:
                return answer
            return "Invalid"

        if attempt < max_retries:
            print(f"{reason} Retrying {attempt}/{max_retries} in {delay} seconds...")
            time.sleep(delay)
        else:
            print(f"{reason} Giving up after {max_retries} attempts.")
    return "Invalid"

def validate_prediction(prediction, choices):
    """
    Accept a predicted letter only if it names one of the offered choices.

    Letters are assigned positionally: the first choice is "A", the second "B",
    and so on, so a question with three choices accepts only "A", "B" or "C".

    Args:
        prediction (str): The predicted answer letter.
        choices (list): The list of choice strings.

    Returns:
        str: ``prediction`` unchanged when it is a valid letter, otherwise "Invalid".
    """
    first = ord("A")
    allowed = frozenset(map(chr, range(first, first + len(choices))))
    return prediction if prediction in allowed else "Invalid"

def correct_evaluation_results(input_file, output_file, deployment_name):
    """
    Re-score every logged sample using an LLM-extracted answer letter.

    Loads ``input_file`` and walks ``data["logs"]``. For each entry it extracts
    the answer letter from ``resps[0]``, checks it against ``doc["choices"]``, and
    overwrites ``exact_match`` with 1.0 when the letter equals ``target`` and 0.0
    otherwise. The whole document is then written to ``output_file``; no other
    field is modified.

    Args:
        input_file (str): Path to the input JSON file.
        output_file (str): Path to save the corrected JSON file.
        deployment_name (str): The deployment name of the model in Azure OpenAI.
    """
    with open(input_file, "r", encoding="utf-8") as src:
        results = json.load(src)

    for record in tqdm(results["logs"], desc="Processing items"):
        # NOTE(review): if resps[0] is itself a list, the prompt receives its repr;
        # confirm the log format if extraction quality looks off.
        raw_text, options, expected = record["resps"][0], record["doc"]["choices"], record["target"]

        letter = validate_prediction(
            get_completion_with_retry(raw_text, deployment_name),
            options,
        )
        record["exact_match"] = float(letter == expected)

    with open(output_file, "w", encoding="utf-8") as dst:
        json.dump(results, dst, indent=4, ensure_ascii=False)

    print(f"\nUpdated evaluation results have been saved to '{output_file}'.")

# Credentials come from the environment; the placeholders are only fallbacks.
api_key = os.environ.get("AZURE_OPENAI_API_KEY", "YOUR-API-KEY")
# The Azure endpoint URL (e.g. https://<resource>.openai.azure.com/), not the API version.
api_base = os.environ.get("AZURE_OPENAI_ENDPOINT", "YOUR-API-ENDPOINT")
# Sets openai's package-level settings only; requests are sent through `client`.
initialize_openai(api_key, api_base)

# Evaluation run whose raw responses are re-scored.
RUN_DIR = "/Volumes/work/LAS/wzhang-lab/mingl/code/vllm/lmms_eval/logs/visual/now/0909_0911_instructblip_13b_instructblip_model_args_7539f0"
input_file = os.path.join(RUN_DIR, "raw_I-scienceqa_insert_image.json")
output_file = os.path.join(RUN_DIR, "I-scienceqa_insert_image.json")

# Name of your chat-model deployment in the Azure resource (it must support the
# chat completions API).
deployment_name = "gpt35"

# Re-score every sample and write the corrected log.
correct_evaluation_results(input_file, output_file, deployment_name)

Processing items: 100%|██████████| 2099/2099 [05:34<00:00,  6.27it/s]



Updated evaluation results have been saved to '/Volumes/work/LAS/wzhang-lab/mingl/code/vllm/lmms_eval/logs/visual/now/0909_0911_instructblip_13b_instructblip_model_args_7539f0/I-scienceqa_insert_image.json'.


In [4]:
import json

def calculate_accuracy(input_file):
    """
    Compute the fraction of logged samples whose ``exact_match`` equals 1.0.

    Prints the sample count, the match count and the accuracy as a percentage.

    Args:
        input_file (str): Path to an evaluation JSON file with a ``logs`` list.

    Returns:
        float: correct / total, or 0 when the log is empty.
    """
    with open(input_file, 'r', encoding='utf-8') as handle:
        logs = json.load(handle)['logs']

    total_samples = len(logs)
    correct_matches = sum(1 for record in logs if record['exact_match'] == 1.0)
    accuracy = 0 if not total_samples else correct_matches / total_samples

    for line in (
        f"Total samples: {total_samples}",
        f"Correct matches: {correct_matches}",
        f"Accuracy: {accuracy:.2%}",
    ):
        print(line)

    return accuracy

# instructblip Usage
# Overall exact-match accuracy for the instructblip_7b insert_image run.
output_file = (
    "/Volumes/work/LAS/wzhang-lab/mingl/code/vllm/lmms_eval/logs/visual/now/"
    "0909_0822_instructblip_7b_instructblip_model_args_a7cc8c/I-scienceqa_insert_image.json"
)
accuracy = calculate_accuracy(output_file)
accuracy

Total samples: 2099
Correct matches: 744
Accuracy: 35.45%


0.354454502143878

In [10]:
import json
from collections import defaultdict

def calculate_accuracy_by_distract_type(input_file):
    """
    Break exact-match accuracy down by ``doc["distract_type"]``.

    Prints one accuracy line per distract type (in first-seen order), followed
    by the pooled accuracy and its raw counts.

    Args:
        input_file (str): Path to an evaluation JSON file with a ``logs`` list.

    Returns:
        tuple: ``(distract_type_data, overall_accuracy)``. ``distract_type_data``
        is a ``defaultdict`` mapping each distract type to
        ``{'total': int, 'correct': int}``. ``overall_accuracy`` is the pooled
        correct/total ratio, or 0 when there are no samples.
    """
    with open(input_file, 'r', encoding='utf-8') as handle:
        logs = json.load(handle)['logs']

    per_type = defaultdict(lambda: {'total': 0, 'correct': 0})
    seen = hits = 0
    for record in logs:
        bucket = per_type[record['doc']['distract_type']]
        bucket['total'] += 1
        seen += 1
        matched = int(record['exact_match'] == 1.0)
        bucket['correct'] += matched
        hits += matched

    print("Accuracy by distract_type:")
    for name, tally in per_type.items():
        share = tally['correct'] / tally['total'] if tally['total'] > 0 else 0
        print(f"{name}: {share:.2%}")

    pooled = hits / seen if seen > 0 else 0
    print(f"\nOverall accuracy: {pooled:.2%} ({hits}/{seen})")

    return per_type, pooled

# instructblip Usage
# Per-distract_type breakdown for the run stored under the llava_13b directory below.
output_file = (
    "/Volumes/work/LAS/wzhang-lab/mingl/code/vllm/lmms_eval/logs/visual/now/"
    "0909_0802_llava_13b_llava_model_args_ba31e2/I-scienceqa_insert_image.json"
)
distract_type_data, overall_accuracy = calculate_accuracy_by_distract_type(output_file)

Accuracy by distract_type:
neutral_backgrounds: 73.20%
generic_landscapes: 68.40%
abstract_art: 68.80%
everyday_objects: 66.00%
cultural_artifacts: 72.00%
digital_creations: 70.80%
word_embeddings: 69.60%
emotional_contexts: 66.80%
diffusion_inpainting: 72.73%

Overall accuracy: 69.60% (1461/2099)


In [None]:
# # instructblip 7b insert_image
# Accuracy by distract_type:
# neutral_backgrounds: 37.60%
# generic_landscapes: 36.80%
# abstract_art: 32.40%
# everyday_objects: 30.40%
# cultural_artifacts: 39.20%
# digital_creations: 34.40%
# word_embeddings: 34.00%
# emotional_contexts: 36.00%
# diffusion_inpainting: 42.42%


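# # instructblip 13b insert_image
# # NOTE(review): these per-type figures are identical to the output printed for the
# # 0909_0802_llava_13b run in the previous cell; confirm which model they belong to.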
# Accuracy by distract_type:
# neutral_backgrounds: 73.20%
# generic_landscapes: 68.40%
# abstract_art: 68.80%
# everyday_objects: 66.00%
# cultural_artifacts: 72.00%
# digital_creations: 70.80%
# word_embeddings: 69.60%
# emotional_contexts: 66.80%
# diffusion_inpainting: 72.73%


# # instructblip 7b add_image
# Accuracy by distract_type:
# abstract_art: 38.80%
# cultural_artifacts: 37.60%
# word_embeddings: 44.40%
# digital_creations: 42.40%
# neutral_backgrounds: 37.60%
# generic_landscapes: 42.40%
# emotional_contexts: 40.00%
# everyday_objects: 45.20%

# # instructblip 13b add_image
# Accuracy by distract_type:
# abstract_art: 70.40%
# cultural_artifacts: 72.00%
# word_embeddings: 73.60%
# digital_creations: 75.20%
# neutral_backgrounds: 73.20%
# generic_landscapes: 70.40%
# emotional_contexts: 67.60%
# everyday_objects: 73.60%