In [1]:
import csv
import difflib

# Helper function to calculate string similarity
def str_similarity(str1, str2):
    """
    Return the difflib similarity ratio between two strings.

    The ratio is a float in [0.0, 1.0]: 1.0 for identical strings, 0.0 when
    they share no matching characters.
    """
    return difflib.SequenceMatcher(isjunk=None, a=str1, b=str2).ratio()

def find_most_similar_index(str_list, target_str):
    """
    Return the index of the string in ``str_list`` most similar to ``target_str``.

    Similarity is measured with ``str_similarity``. When several candidates
    share the highest score, the earliest index wins.

    Returns:
        The index of the best match, or None when ``str_list`` is empty or
        when every candidate has a similarity of 0.0.
    """
    best_index = None
    best_score = 0
    for index, candidate in enumerate(str_list):
        score = str_similarity(candidate, target_str)
        # Strict ">" keeps the first of equally similar candidates and
        # never selects a candidate whose similarity is 0.
        if score > best_score:
            best_index, best_score = index, score
    return best_index

# Function to calculate open accuracy using similarity measure
def calculate_open_accuracy(csv_file_path):
    """
    Sum label/prediction string similarities over an open-question result CSV.

    The CSV is read with ``csv.DictReader`` and must have ``Label`` and
    ``Pred`` columns. Each row contributes ``str_similarity(pred, label)``,
    a score in [0, 1], rather than a 0/1 hit.

    Returns:
        (total_similarity, total_questions), so the caller can compute a
        weighted accuracy.
    """
    total_similarity = 0
    total_questions = 0
    # newline='' is required by the csv module so quoted fields containing
    # line breaks are parsed correctly; an explicit encoding avoids
    # depending on the platform's locale.
    with open(csv_file_path, 'r', encoding='utf-8', newline='') as file:
        for row in csv.DictReader(file):
            # Scoring the single prediction directly: a "most similar" search
            # over a one-element list only fails when the similarity is 0,
            # which contributes the same 0 to the total.
            total_similarity += str_similarity(row['Pred'], row['Label'])
            total_questions += 1
    return total_similarity, total_questions

# Function to calculate closed accuracy using direct comparison
def calculate_closed_accuracy(csv_file_path):
    """
    Count rows of a closed-question result CSV whose prediction contains the label.

    The CSV is read with ``csv.DictReader`` and must have ``Label`` and
    ``Pred`` columns. A row is a hit when the lower-cased label occurs as a
    substring of the lower-cased prediction.

    NOTE(review): substring matching is lenient. A label "no" also matches a
    prediction such as "nodule", and an empty label always counts as a hit.

    Returns:
        (hits, total_questions).
    """
    hits = 0
    total_questions = 0
    # newline='' is required by the csv module so quoted fields containing
    # line breaks are parsed correctly; an explicit encoding avoids
    # depending on the platform's locale.
    with open(csv_file_path, 'r', encoding='utf-8', newline='') as file:
        for row in csv.DictReader(file):
            label, pred = row['Label'], row['Pred']
            if label.lower() in pred.lower():
                hits += 1
            total_questions += 1
    return hits, total_questions

# Function to calculate overall accuracy considering both open and closed questions
def calculate_overall_accuracy(open_csv_file_path, closed_csv_file_path):
    """
    Compute open, closed and pooled accuracy from two result CSVs.

    The open CSV is scored with ``calculate_open_accuracy`` and the closed CSV
    with ``calculate_closed_accuracy``. The pooled figure is
    (open similarity sum + closed hits) / (open questions + closed questions),
    so every question carries equal weight. Any ratio whose denominator is
    zero is reported as 0.

    Returns:
        (open_accuracy, closed_accuracy, overall_accuracy).
    """
    def safe_ratio(numerator, denominator):
        # Guard against ZeroDivisionError when a CSV has no rows.
        if denominator > 0:
            return numerator / denominator
        return 0

    open_score, open_count = calculate_open_accuracy(open_csv_file_path)
    closed_score, closed_count = calculate_closed_accuracy(closed_csv_file_path)

    return (
        safe_ratio(open_score, open_count),
        safe_ratio(closed_score, closed_count),
        safe_ratio(open_score + closed_score, open_count + closed_count),
    )

# Result CSVs for the VQA-RAD open- and closed-question sets.
# These are concrete paths from one environment; edit them for other runs.
open_csv_file_path = '/root/Uni-Med/src/Uni/result_final_greedy_VQA_RAD_openQA_no_pretrain_no_aug_test1_VQA_RAD.csv'
closed_csv_file_path = '/root/Uni-Med/src/Uni/result_final_greedy_VQA_RAD_closeQA_no_pretrain_no_aug_test1_VQA_RAD.csv'

# Score both files, then report each accuracy on its own line.
(
    open_accuracy,
    closed_accuracy,
    overall_accuracy,
) = calculate_overall_accuracy(
    open_csv_file_path,
    closed_csv_file_path,
)

print("Open Accuracy:", open_accuracy)
print("Closed Accuracy:", closed_accuracy)
print("Overall Accuracy:", overall_accuracy)


Open Accuracy: 0.0859589040738053
Closed Accuracy: 0.7610294117647058
Overall Accuracy: 0.49309677123993606
