In [None]:
#load context dataset of all images
#includes: keywords, mentions, original_caption, path to figure image and the full text
import pandas as pd
import s3fs
# Authenticated (non-anonymous) S3 handle, used both to list and to read shards.
fs = s3fs.S3FileSystem(anon=False)
s3_path = "s3://___.parquet/"

# List every parquet shard under the prefix.
files = fs.glob(f"{s3_path}*.parquet")
print(f"Found {len(files)} files in the directory.")

# Read each shard (the "s3://" scheme is re-added to every listed path),
# then stack all shards into one frame with a fresh 0..n-1 index.
dfs = [pd.read_parquet(f"s3://{file}", filesystem=fs) for file in files]
df = pd.concat(dfs, ignore_index=True)
df.shape

In [None]:
import os

from huggingface_hub import login

# Authenticate against the Hugging Face Hub.
# The token is read from the HF_TOKEN environment variable instead of being
# written into the notebook, so the secret is never saved or committed with it.
# An unset/empty variable is passed as None; per the huggingface_hub docs,
# login() then falls back to its interactive prompt — confirm for your version.
login(token=os.environ.get("HF_TOKEN") or None)


In [None]:
#needed to run captioning
import torch
import pandas as pd
from PIL import Image
import numpy as np
import os
import json
from tqdm.auto import tqdm
import math
import datetime
import time
import concurrent.futures
import warnings
import multiprocessing as mp 
from itertools import chain
from worker_utils_mult_GPU_context import inference_worker_batch    #loading and calling the model with batches (also pasted at the bottom of this notebook) 

In [None]:
# Short baseline instruction for retrieval-oriented image summaries.
# The trailing backslashes inside the triple-quoted string join the lines, so
# the leading spaces of every continuation line are part of the prompt text.
base_repo_prompt = """You are an assistant tasked with summarizing images for retrieval. \
                    You are given contextual information as well as the image. \
                    These summaries will be embedded and used to retrieve the raw image. \
                    Give a concise summary of the image that is well optimized for retrieval."""

In [None]:
# Prompt generation.
# `final_prompt` is the instruction text that generate_prompt() (defined below
# in this cell) appends after the per-figure context tags.
final_prompt = """Analyze the scientific figure and any provided context to create a concise, retrieval-optimized caption. Follow these guidelines:

Overview: Summarize the figure’s main subject, purpose and experimental focus.

Visual Details: Identify and describe important visual elements (e.g. charts, diagrams, symbols, numeric labels, legends).

Relationship & Patterns: Explain interactions, trends, or hierarchies evident in the figure.

Terminology & Keywords: Use relevant domain-specific terms (with plain-language equivalents if needed), then list essential keywords for retrieval.

Relevance & Accuracy: Include only information visible in the figure or clearly stated in the context. Avoid speculation, repetition or irrelevant details.

Formatting: Present the description coherently, in a paragraph of 300- to 400 words. Omit statements about missing elements unless explicitly relevant. """


# 0= original caption only, 1= title, abstract, keywords; 2: sentences, mention_section; 3: sentences, keywords, original caption


def generate_prompt(row, config_num):
    """Build the captioning prompt for one figure: tagged context, then `final_prompt`.

    Args:
        row: Mapping describing one figure. It must support `.get('sentences')`
            and item access for the fields the chosen config reads.
        config_num: Selects the context fields to include:
            0 = original caption;
            1 = title, abstract, keywords;
            2 = mention sentences (if any), mention section;
            3 = mention sentences (if any), keywords, original caption.
            Any other value prints 'INCORRECT CONFIG_NUM' and adds no context.

    Returns:
        str: the tagged context, a blank line and `final_prompt`; only
        `final_prompt` when no context was produced.

    Raises:
        KeyError: for a dict `row` that lacks a field read by item access.
    """
    def mention_block():
        # Mention sentences are optional: contribute nothing when absent/empty.
        mentions = row.get('sentences')
        if not (hasattr(mentions, '__len__') and len(mentions) > 0):
            return []
        if isinstance(mentions, (list, np.ndarray)):
            mention_text = " ".join(str(sentence) for sentence in mentions)
        else:
            mention_text = str(mentions)
        return [f"[MENTION-START]{mention_text}[MENTION-END]\n"]

    if config_num == 0:
        pieces = [f"[CAPTION-START]{row['original_caption']}[CAPTION-END]\n"]
    elif config_num == 1:
        pieces = [
            f"[TITLE-START]{ row['title']}[TITLE-END]\n",
            f"[ABSTRACT-START]{row['abstract']}[ABSTRACT-END]\n",
            f"[KEYWORDS-START]{row['keywords']}[KEYWORDS-END]\n",
        ]
    elif config_num == 2:
        pieces = mention_block() + [
            f"[SECTION-START] This Figure is from the section: {row['mention_section']}[SECTION-END]\n",
        ]
    elif config_num == 3:
        pieces = mention_block() + [
            f"[KEYWORDS-START]{row['keywords']}[KEYWORDS-END]\n",
            f"[CAPTION-START]{row['original_caption']}[CAPTION-END]\n",
        ]
    else:
        pieces = []
        print('INCORRECT CONFIG_NUM')

    context = "".join(pieces)
    if not context:
        return final_prompt
    return context + "\n" + final_prompt

In [None]:
#check for multiGPU
# Force the 'spawn' start method for worker processes; a RuntimeError raised
# by set_start_method is reported as a warning and is not fatal.
try:
    mp.set_start_method('spawn', force=True)
except RuntimeError as e:
    print(f"Warning: {e}")
else:
    print("spawn set")


# Stop immediately when torch reports that CUDA is unavailable.
if not torch.cuda.is_available():
    raise SystemError("no CUDA!")

# Number of CUDA devices torch reports; later cells use it as the worker count.
NUM_GPUS = torch.cuda.device_count() #- 2
print(f"Found {NUM_GPUS} GPUs")
if not NUM_GPUS:
     raise SystemError("No GPUs")

In [None]:
#needed arguments/ params
WORKER_BATCH_SIZE = 8
MODEL_ID = "openbmb/MiniCPM-V-2_6"
IMAGE_COLUMN_NAME = 'image_path'
INPUT_DF_NAME = 'test_subset' 

# Run tag: it is embedded in the name of every output file created below.
model_type = 'miniCPM_testing_baseline_prmpt_REPO_run' #where to save

# All run artifacts live under the cache directory (created when missing).
cache_path = 'cache/'
os.makedirs(cache_path, exist_ok=True)
caption_cache_file, fail_log_file, prompts_to_process_file = (
    os.path.join(cache_path, file_name)
    for file_name in (
        f'image_captions_{model_type}.csv',
        f'captioning_fails_{model_type}.txt',
        f'prompts_to_process_{model_type}.jsonl',
    )
)

In [None]:
#loading data & set up cache to save captions
# Work on a copy so the source frame is left untouched.
df_input = test_subset.copy()
if IMAGE_COLUMN_NAME not in df_input.columns:
     print(f"Need image paths avalable!")


processed_captions_list = []
# Despite the name, this set holds (image_path, config_num) pairs that already
# have a caption in the cache file.
processed_image_paths = set()
if os.path.exists(caption_cache_file):
    print("Loading existing captions")
    try:
        df_cache = pd.read_csv(caption_cache_file, usecols=['image_path', 'generated_caption', 'config_num'])
        if not df_cache.empty:
            # Column-wise zip builds the same pairs as a row-wise apply, without
            # constructing one Series per row.
            cached_records = list(zip(df_cache['image_path'], df_cache['config_num']))
            processed_image_paths.update(cached_records)
        else:
            print("Cache file exists but is empty.")
    except Exception as cache_err:
        # This used to be a bare `except:`, which also caught KeyboardInterrupt
        # and hid the reason the cache was unreadable. The fallback (start with
        # an empty cache) is unchanged; the cause is now reported.
        print("New cache.")
        print(f"(existing cache file could not be read: {cache_err!r})")
        processed_image_paths = set()
else:
    print("Starting fresh.")

In [None]:
#Generating prompts first
# For every row with an existing image file, queue one prompt per context
# config (0-3). Queued items go to `prompts_to_process_data` and are appended
# to the JSONL file; rows that cannot be processed go to the fail log.
#
# Fixes compared with the previous version of this cell:
#   * (image_path, config_num) pairs already present in `processed_image_paths`
#     (loaded from the caption cache) are skipped and counted. Before, the
#     cache was loaded but never consulted, so `skipped_cached_count` stayed 0.
#     NOTE(review): the cache key does not include the prompt variant, so a run
#     with a different prompt needs a different `model_type` — confirm.
#     NOTE(review): assumes the cache stores the same normalized full path
#     that is written to `prompt_data['image_path']` — confirm against the worker.
#   * The fail log and the JSONL file are opened once for the whole loop
#     instead of once per row / per prompt.
IMAGE_BASE_PATH = "./data/399k_imgs/"

prompts_to_process_data = [] 
skipped_cached_count = 0
skipped_invalid_count = 0
invalid_paths_logged_phase2 = set()

print("generating prompts")
# Mode 'w' starts a fresh fail log for this run; the JSONL file is appended to.
with open(fail_log_file, 'w') as f_fail, open(prompts_to_process_file, 'a') as f_prompts:
    f_fail.write(f"Captioning Fail Log - Run Started: {datetime.datetime.now()}\n")
    f_fail.write("--- Phase 2 Failures (Path/Prompt Gen) ---\n")

    for index, row in tqdm(df_input.iterrows(), total=len(df_input), desc="Checking Images & Generating Prompts"):
        row_dict = row.to_dict()
        row_index_for_log = index

        # Rows without a usable image path are logged once and skipped.
        relative_image_path = row_dict.get(IMAGE_COLUMN_NAME)
        if not relative_image_path or pd.isna(relative_image_path):
            log_key = f"Row {row_index_for_log}: Missing Path"
            if log_key not in invalid_paths_logged_phase2:
                f_fail.write(f"Skipped row {row_index_for_log}: Missing image path in DataFrame\n")
                invalid_paths_logged_phase2.add(log_key)
            skipped_invalid_count += 1
            continue

        full_image_path = os.path.normpath(os.path.join(IMAGE_BASE_PATH, str(relative_image_path)))

        # Each missing file is logged once, even when several rows point at it.
        if not os.path.exists(full_image_path):
            log_key = f"no file: {full_image_path}"
            if log_key not in invalid_paths_logged_phase2:
                f_fail.write(f"File not found at '{full_image_path}'\n")
                invalid_paths_logged_phase2.add(log_key)
            skipped_invalid_count += 1
            continue

        try:
            for config_num in range(4):
                # Already captioned in a previous run: do not queue it again.
                if (full_image_path, config_num) in processed_image_paths:
                    skipped_cached_count += 1
                    continue
                prompt = generate_prompt(row_dict, config_num)
                prompt_data = {
                    'image_path': full_image_path,
                    'prompt': prompt,
                    'config_num': config_num  
                }
                prompts_to_process_data.append(prompt_data)
                f_prompts.write(json.dumps(prompt_data) + '\n')

        except Exception as e:
            # A failure on any config counts the whole row as invalid; prompts
            # queued for earlier configs of the same row stay queued.
            log_key = f"Prompt Gen Error: {full_image_path}"
            if log_key not in invalid_paths_logged_phase2:
                print(f"Failed to generate prompt for row {row_index_for_log}, image: {full_image_path}. Error: {e}")
                f_fail.write(f"Failed row {row_index_for_log} during prompt generation: {full_image_path} | Error: {str(e)}\n")
                invalid_paths_logged_phase2.add(log_key)
            skipped_invalid_count += 1

num_prompts_to_process = len(prompts_to_process_data)
print(f"\nIdentified {num_prompts_to_process} new images")
print(f"Skipped {skipped_cached_count} in cache.")
print(f"Skipped {skipped_invalid_count} rows")
if skipped_invalid_count > 0:
     print(f"(Check '{fail_log_file}' for details on failed items)")
if num_prompts_to_process > 0:
    print(f"Details of items to process saved to: {prompts_to_process_file}")

In [None]:
#captioning code for any number GPUs
# Distributes the queued prompts over worker processes, collects the results,
# appends successful captions to the cache CSV and prints a summary.
# Runs only when the prompt-generation cell queued at least one prompt.
if num_prompts_to_process == 0:
    print("no new images.")
else:
    print(f"using {NUM_GPUS} GPUs")
    start_time = time.time()
    
    # Build one chunk of prompts per worker.
    if num_prompts_to_process < NUM_GPUS:
         # Fewer prompts than GPUs: one prompt per worker; the list is padded
         # with empty chunks up to NUM_GPUS entries.
         actual_num_workers = num_prompts_to_process
         print("Warning: less items than GPUs")
         data_chunks = [[item] for item in prompts_to_process_data]
         while len(data_chunks) < NUM_GPUS:
             data_chunks.append([])
    else:
         actual_num_workers = NUM_GPUS 
         data_chunks_np = np.array_split(prompts_to_process_data, actual_num_workers)
         data_chunks = [chunk.tolist() for chunk in data_chunks_np]


    with open(fail_log_file, 'a') as f_fail:
        f_fail.write("\n--- Phase 3 Failures (Worker Processing) ---\n")

    # One argument tuple per worker; gpu_id doubles as the index of its chunk.
    # Workers whose chunk is empty are left out.
    pool_args = [
        (gpu_id, data_chunks[gpu_id], MODEL_ID, WORKER_BATCH_SIZE, fail_log_file)
        for gpu_id in range(actual_num_workers) 
         if len(data_chunks[gpu_id]) > 0
    ]

    all_results_list = []
    print("parallel part")
    with mp.Pool(processes=actual_num_workers) as pool:
        try:
            all_results_list = pool.starmap(inference_worker_batch, pool_args)
        except Exception as pool_exc:
             # A pool failure is reported but not re-raised: all_results_list
             # stays empty and the cell continues to the summary below.
             print("CRITICAL ERROR  PARALLEL PART")
             print(f"error: {pool_exc}")
             import traceback
             traceback.print_exc()


    end_time = time.time()

    print(f"Parallel part took: {end_time - start_time:.2f} secs.")



    # Flatten the per-worker result lists into one list of result dicts.
    all_processed_items = list(chain.from_iterable(all_results_list))
    print(f"Got: {len(all_processed_items)} captions")

    # Items with status 'success' but an empty caption land in neither list.
    successful_captions = [item for item in all_processed_items if item['status'] == 'success' and item.get('generated_caption')]
    failed_items_from_workers = [item for item in all_processed_items if item['status'] == 'failed']

    newly_processed_count = len(successful_captions)
    failed_during_processing_count = len(failed_items_from_workers)

    print(f"Successfully generated {newly_processed_count} new captions.")
    print(f"Detected {failed_during_processing_count} failures during worker processing.")

    if successful_captions:
        # Requires every successful result dict to carry these three keys.
        df_new_captions = pd.DataFrame(successful_captions)[['image_path', 'generated_caption', 'config_num']]

        try:
            file_exists = os.path.exists(caption_cache_file)
            is_empty = not file_exists or os.path.getsize(caption_cache_file) == 0

            # Append to the cache; the header is written only for a new/empty file.
            df_new_captions.to_csv(caption_cache_file, mode='a', header=is_empty, index=False)
        except Exception as e:
            # If appending fails, dump the new captions to a timestamped side file.
            print(f"ERROR sending to cache file: {e}")
            temp_fail_path = f"{caption_cache_file}.failed_append_{datetime.datetime.now():%Y%m%d_%H%M%S}.csv"
            df_new_captions.to_csv(temp_fail_path, index=False)
            print(f"failed data at: {temp_fail_path}")

    else:
        print("Not working!!")


    print("FINAL PROCESSED NUMBERS: ")
    print(f"images processed successfully: {newly_processed_count}")
    # Failures = rows skipped during prompt generation + worker failures.
    total_failures_reported = skipped_invalid_count + failed_during_processing_count
    print(f"failed: {total_failures_reported}")


print(f"\nScript finished at: {datetime.datetime.now()}")

In [None]:
# Step-by-step ("chain of thought" style) captioning instruction; try_4o_cot()
# appends it after the per-figure context string.
base_cot_prompt = """Generate a retrieval-optimized caption through these steps:
1. VISUAL FOUNDATION (Image Analysis Only):
Identify primary visualization type and core components
Catalog essential elements: axes, labels, legends, data representations
Note quantitative relationships and spatial patterns
Flag ambiguous elements requiring contextual clarification
2. CONTEXTUAL MEANING (Text Analysis Only):
Extract research purpose and experimental focus
Identify domain-specific terminology and claims
Note methodological parameters and key hypotheses
Highlight contextual expectations for the figure
3. ALIGNED SYNTHESIS (Intelligent Combination):
Visual-Text Alignment: Where context describes visuals, incorporate directly
Visual-Only Elements: Describe objectively without forcing textual justification
Context-Only Insights: Include critical claims without visual evidence 
Conflict Resolution: Note contradictions between text and visuals
Formulate integrated overview of purpose + main subject
4. CAPTION GENERATION (Composition Guidelines):
Compose a single 300-400 word paragraph that:
• Starts with figure purpose and main subject
• Integrates visual components with contextual meaning
• Explains relationships using domain terminology
• Incorporates keywords naturally (technical + plain-language)
• Maintains objective tone without speculation
5. SELF-VALIDATION (Quality Control):
Verify claims traceable to visual/textual evidence
Confirm no unsupported speculation
Ensure 300-400 word length
Check keyword coverage for retrieval
Validate objective description of visual-textual relationships
Remove redundant statements
Check for completeness of visual information"""

In [None]:

def generate_prompt(row, config_num):
    """Baseline prompt variant: always returns `base_repo_prompt`.

    This definition shadows any `generate_prompt` defined earlier in the
    notebook. The context fields for `config_num` are still read and joined
    (so a dict `row` missing a required field raises KeyError, and an unknown
    config prints 'INCORRECT CONFIG_NUM'), but the joined context is not part
    of the returned prompt.

    Args:
        row: Mapping describing one figure; must support `.get('sentences')`
            and item access for the fields of the chosen config.
        config_num: 0 = original caption; 1 = title, abstract, keywords;
            2 = mention sentences, mention section;
            3 = mention sentences, keywords, original caption.

    Returns:
        str: `base_repo_prompt`, unchanged.
    """
    def mention_block():
        # Mention sentences are optional: contribute nothing when absent/empty.
        mentions = row.get('sentences')
        if not (hasattr(mentions, '__len__') and len(mentions) > 0):
            return []
        if isinstance(mentions, (list, np.ndarray)):
            mention_text = " ".join(str(sentence) for sentence in mentions)
        else:
            mention_text = str(mentions)
        return [f"[MENTION-START]{mention_text}[MENTION-END]\n"]

    if config_num == 0:
        pieces = [f"[CAPTION-START]{row['original_caption']}[CAPTION-END]\n"]
    elif config_num == 1:
        pieces = [
            f"[TITLE-START]{ row['title']}[TITLE-END]\n",
            f"[ABSTRACT-START]{row['abstract']}[ABSTRACT-END]\n",
            f"[KEYWORDS-START]{row['keywords']}[KEYWORDS-END]\n",
        ]
    elif config_num == 2:
        pieces = mention_block() + [
            f"[SECTION-START] This Figure is from the section: {row['mention_section']}[SECTION-END]\n",
        ]
    elif config_num == 3:
        pieces = mention_block() + [
            f"[KEYWORDS-START]{row['keywords']}[KEYWORDS-END]\n",
            f"[CAPTION-START]{row['original_caption']}[CAPTION-END]\n",
        ]
    else:
        pieces = []
        print('INCORRECT CONFIG_NUM')

    # NOTE(review): the context is assembled but never used, although the
    # baseline prompt text says contextual information is given. Confirm
    # whether dropping it is intended for the baseline run.
    context_str = "".join(pieces)

    return base_repo_prompt


In [None]:
# Captioning driver, variant with an in-process path for a single worker.
# Distributes the queued prompts, collects the results, appends successful
# captions to the cache CSV and prints a summary.
if num_prompts_to_process == 0:
    print("no new images.")
else:
    print(f"using {NUM_GPUS} GPUs")
    start_time = time.time()
    
    # Build one chunk of prompts per worker.
    if num_prompts_to_process < NUM_GPUS:
         # Fewer prompts than GPUs: one prompt per worker; the list is padded
         # with empty chunks up to NUM_GPUS entries.
         actual_num_workers = num_prompts_to_process
         print("Warning: less items than GPUs")
         data_chunks = [[item] for item in prompts_to_process_data]
         while len(data_chunks) < NUM_GPUS:
             data_chunks.append([])
    else:
         actual_num_workers = NUM_GPUS 
         data_chunks_np = np.array_split(prompts_to_process_data, actual_num_workers)
         data_chunks = [chunk.tolist() for chunk in data_chunks_np]


    with open(fail_log_file, 'a') as f_fail:
        f_fail.write("\n--- Phase 3 Failures (Worker Processing) ---\n")

    # One argument tuple per worker; gpu_id doubles as the index of its chunk.
    # Workers whose chunk is empty are left out.
    pool_args = [
        (gpu_id, data_chunks[gpu_id], MODEL_ID, WORKER_BATCH_SIZE, fail_log_file)
        for gpu_id in range(actual_num_workers) 
         if len(data_chunks[gpu_id]) > 0
    ]

    all_results_list = []
    print("parallel part")
    if actual_num_workers == 1:
        # Single worker: call the worker function directly in this process
        # instead of creating a pool. The unpacking rebinds the module-level
        # fail_log_file to the value that was placed in pool_args.
        # Unlike the pool path, an exception here is not caught.
        gpu_id, data_chunk, model_id, worker_batch_size, fail_log_file = pool_args[0]
        result = inference_worker_batch(gpu_id, data_chunk, model_id, worker_batch_size, fail_log_file)
        all_results_list = [result]
    else:
        with mp.Pool(processes=actual_num_workers) as pool:
            try:
                all_results_list = pool.starmap(inference_worker_batch, pool_args)
            except Exception as pool_exc:
                # A pool failure is reported but not re-raised: all_results_list
                # stays empty and the cell continues to the summary below.
                print("CRITICAL ERROR  PARALLEL PART")
                print(f"error: {pool_exc}")
                import traceback
                traceback.print_exc()

    end_time = time.time()

    print(f"Parallel part took: {end_time - start_time:.2f} secs.")

    # Flatten the per-worker result lists into one list of result dicts.
    all_processed_items = list(chain.from_iterable(all_results_list))
    print(f"Got: {len(all_processed_items)} captions")

    # Items with status 'success' but an empty caption land in neither list.
    successful_captions = [item for item in all_processed_items if item['status'] == 'success' and item.get('generated_caption')]
    failed_items_from_workers = [item for item in all_processed_items if item['status'] == 'failed']

    newly_processed_count = len(successful_captions)
    failed_during_processing_count = len(failed_items_from_workers)

    print(f"Successfully generated {newly_processed_count} new captions.")
    print(f"Detected {failed_during_processing_count} failures during worker processing.")

    if successful_captions:
        # Requires every successful result dict to carry these three keys.
        df_new_captions = pd.DataFrame(successful_captions)[['image_path', 'generated_caption', 'config_num']]

        try:
            file_exists = os.path.exists(caption_cache_file)
            is_empty = not file_exists or os.path.getsize(caption_cache_file) == 0

            # Append to the cache; the header is written only for a new/empty file.
            df_new_captions.to_csv(caption_cache_file, mode='a', header=is_empty, index=False)
        except Exception as e:
            # If appending fails, dump the new captions to a timestamped side file.
            print(f"ERROR sending to cache file: {e}")
            temp_fail_path = f"{caption_cache_file}.failed_append_{datetime.datetime.now():%Y%m%d_%H%M%S}.csv"
            df_new_captions.to_csv(temp_fail_path, index=False)
            print(f"failed data at: {temp_fail_path}")


    else:
        print("Not working!!")


    print("FINAL PROCESSED NUMBERS: ")
    print(f"images processed successfully: {newly_processed_count}")
    # Failures = rows skipped during prompt generation + worker failures.
    total_failures_reported = skipped_invalid_count + failed_during_processing_count
    print(f"failed: {total_failures_reported}")


print(f"\nScript finished at: {datetime.datetime.now()}")

In [None]:
# FINAL: EVAL BY L(V)LM of generated captions + human eval on rubrics


In [None]:
#Generating prompts first
# For every row with an existing image file, queue a single prompt built with
# context config 3 (mention sentences, keywords, original caption).
#
# Fixes compared with the previous version of this cell:
#   * (image_path, 3) pairs already present in `processed_image_paths` (loaded
#     from the caption cache) are skipped and counted. Before, the cache was
#     loaded but never consulted, so `skipped_cached_count` stayed 0.
#     NOTE(review): the cache key does not include the prompt variant, so a run
#     with a different prompt needs a different `model_type` — confirm.
#     NOTE(review): assumes the cache stores the same normalized full path
#     that is written to `prompt_data['image_path']` — confirm against the worker.
#   * The fail log and the JSONL file are opened once for the whole loop
#     instead of once per row / per prompt.
IMAGE_BASE_PATH = "./data/399k_imgs/"

prompts_to_process_data = [] 
skipped_cached_count = 0
skipped_invalid_count = 0
invalid_paths_logged_phase2 = set()

print("generating prompts")
# Mode 'w' starts a fresh fail log for this run; the JSONL file is appended to.
with open(fail_log_file, 'w') as f_fail, open(prompts_to_process_file, 'a') as f_prompts:
    f_fail.write(f"Captioning Fail Log - Run Started: {datetime.datetime.now()}\n")
    f_fail.write("--- Phase 2 Failures (Path/Prompt Gen) ---\n")

    for index, row in tqdm(df_input.iterrows(), total=len(df_input), desc="Checking Images & Generating Prompts"):
        row_dict = row.to_dict()
        row_index_for_log = index

        # Rows without a usable image path are logged once and skipped.
        relative_image_path = row_dict.get(IMAGE_COLUMN_NAME)
        if not relative_image_path or pd.isna(relative_image_path):
            log_key = f"Row {row_index_for_log}: Missing Path"
            if log_key not in invalid_paths_logged_phase2:
                f_fail.write(f"Skipped row {row_index_for_log}: Missing image path in DataFrame\n")
                invalid_paths_logged_phase2.add(log_key)
            skipped_invalid_count += 1
            continue

        full_image_path = os.path.normpath(os.path.join(IMAGE_BASE_PATH, str(relative_image_path)))

        # Each missing file is logged once, even when several rows point at it.
        if not os.path.exists(full_image_path):
            log_key = f"no file: {full_image_path}"
            if log_key not in invalid_paths_logged_phase2:
                f_fail.write(f"File not found at '{full_image_path}'\n")
                invalid_paths_logged_phase2.add(log_key)
            skipped_invalid_count += 1
            continue

        # Already captioned in a previous run: do not queue it again.
        if (full_image_path, 3) in processed_image_paths:
            skipped_cached_count += 1
            continue

        try:
            prompt = generate_prompt(row_dict, 3)
            prompt_data = {
                'image_path': full_image_path,
                'prompt': prompt,
                'config_num': 3 
            }
            prompts_to_process_data.append(prompt_data)
            f_prompts.write(json.dumps(prompt_data) + '\n')

        except Exception as e:
            log_key = f"Prompt Gen Error: {full_image_path}"
            if log_key not in invalid_paths_logged_phase2:
                print(f"Failed to generate prompt for row {row_index_for_log}, image: {full_image_path}. Error: {e}")
                f_fail.write(f"Failed row {row_index_for_log} during prompt generation: {full_image_path} | Error: {str(e)}\n")
                invalid_paths_logged_phase2.add(log_key)
            skipped_invalid_count += 1

num_prompts_to_process = len(prompts_to_process_data)
print(f"\nIdentified {num_prompts_to_process} new images")
print(f"Skipped {skipped_cached_count} in cache.")
print(f"Skipped {skipped_invalid_count} rows")
if skipped_invalid_count > 0:
     print(f"(Check '{fail_log_file}' for details on failed items)")
if num_prompts_to_process > 0:
    print(f"Details of items to process saved to: {prompts_to_process_file}")

In [None]:
# Five scoring rubrics, each with a 1-4 scale; the text is substituted into
# `evaluation_prompt` through its {base_eval_rubrics} placeholder.
base_eval_rubrics = """Technical Correctness: Accuracy of scientific terminology, quantitative values, and conceptual relationships depicted in the figure.
Scoring Scale:
Incorrect (1): Contains factual errors or misrepresents data relationships
Partially Correct (2): Generally accurate but contains minor technical inaccuracies 
Mostly Correct (3): Precise technical language with isolated omissions 
Fully Correct (4): Technically precise with complete quantitative details 
Completeness: Coverage of all critical visual elements and conceptual components necessary for figure interpretation.
Scoring Scale:
Incomplete (1): Misses >50% of salient elements 
Partially Complete (2): Covers primary elements with notable gaps 
Substantially Complete (3): Includes most elements with minor omissions 
Exhaustively Complete (4): Comprehensive coverage including secondary elements 
Conciseness & Focus: Precision in highlighting essential information while eliminating redundancy
Scoring Scale:
Redundant/Unfocused (1): Contains significant extraneous information obscuring key insights
Partially Focused (2): Communicates main points but with noticeable digressions
Mostly Concise (3): Direct presentation with minor non-essential details
Precisely Focused (4): Economical delivery of maximum relevant information
Method-Context Integration: Appropriate incorporation of experimental methodology relevant to figure interpretation.
Scoring Scale:
Absent (1): No methodological references 
Basic (2): Generic method mentions
Contextualized (3): Specific technique references 
Interpretive (4): Methodological details enabling critical analysis 
Visual-Context Synthesis: Effective utilization of visual elements beyond basic description.
Scoring Scale:
Superficial (1): Basic element listing 
Descriptive (2): Visual feature identification 
Analytical (3): Integrated visual-textual analysis 
Interpretive (4): Synthesized visual-data insights"""

In [None]:
# Azure OpenAI settings: API_VERSION is passed to the client constructor in
# get_azure_client(); MODEL_ID is the `model` argument of the chat requests.
# NOTE(review): this rebinds MODEL_ID, which the parameter cell earlier in the
# notebook sets to the MiniCPM checkpoint id; cells run after this one see the
# value below. Presumably this is an Azure deployment name — confirm.
API_VERSION = "2024-02-15-preview"
MODEL_ID = "gpt4o-240513"

def get_azure_client(): 
    """Create an AzureOpenAI client from the module-level connection settings.

    Reads AZURE_ENDPOINT, AZURE_API_KEY and API_VERSION at call time.

    Returns:
        A new AzureOpenAI client instance.
    """
    connection_settings = {
        "azure_endpoint": AZURE_ENDPOINT,
        "api_key": AZURE_API_KEY,
        "api_version": API_VERSION,
    }
    return AzureOpenAI(**connection_settings)

def encode_image(image_path):
    """Return the contents of the file at `image_path` as a base64 ASCII string.

    Args:
        image_path: Path of the file to read (opened in binary mode).

    Returns:
        str: base64 encoding of the raw file bytes.

    Raises:
        OSError: if the file cannot be opened or read.
    """
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('ascii')

def inference_worker_batch(worker_id, data_chunk, model_id, batch_size, fail_log_file):
    """Caption every item of `data_chunk` with the Azure OpenAI chat model.

    Fixes compared with the previous version:
      * `model_id` is now used for the request (it used to be ignored in
        favour of the module-level MODEL_ID, which is still the fallback when
        `model_id` is empty).
      * each result carries the item's `config_num`, which the captioning
        cells need when they select the 'config_num' column for the cache.

    Args:
        worker_id: Identifier used only in fail-log messages.
        data_chunk: Iterable of dicts with 'image_path' and 'prompt' keys and
            an optional 'config_num' key.
        model_id: Model/deployment name passed as `model` to the chat request.
        batch_size: Unused here (items are sent one request at a time); kept
            so the signature matches the calls made by the captioning cells.
        fail_log_file: Path of the log file that receives one line per failure.

    Returns:
        list[dict]: one dict per item with keys 'image_path', 'config_num',
        'generated_caption', 'status' ('success' or 'failed') and 'error'.

    Raises:
        KeyError: if an item has no 'image_path' (read outside the try block).
    """
    client = get_azure_client()
    deployment = model_id if model_id else MODEL_ID
    results = []

    for item in data_chunk:
        image_path = item['image_path']
        config_num = item.get('config_num')
        try:
            base64_image = encode_image(image_path)
            # One user message: the text prompt plus the image as a data URL.
            # The MIME type is always declared as image/jpeg.
            messages = [{
                "role": "user",
                "content": [
                    {"type": "text", "text": item['prompt']},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}},
                ]
            }]

            response = client.chat.completions.create(
                model=deployment,
                messages=messages,
                max_tokens=512
            )

            results.append({
                'image_path': image_path,
                'config_num': config_num,
                'generated_caption': response.choices[0].message.content.strip(),
                'status': 'success',
                'error': None
            })

        except Exception as e:
            # Any failure (file read, request, empty response) marks only this
            # item as failed; the remaining items are still processed.
            error_msg = str(e)
            results.append({
                'image_path': image_path,
                'config_num': config_num,
                'generated_caption': None,
                'status': 'failed',
                'error': error_msg
            })
            with open(fail_log_file, 'a') as f_fail:
                f_fail.write(f"[Worker {worker_id}] Error processing {image_path}: {error_msg}\n")

    return results

In [None]:
# NOTE(review): stray scratch expression; it only echoes the string as the
# cell output. Candidate for deletion.
'grfe'

In [None]:
import pandas as pd
import numpy as np
import base64
from openai import AzureOpenAI
import time

# Template for the caption-evaluation request. It is filled with str.format(),
# using the placeholders {context_str}, {generated_caption} and
# {base_eval_rubrics}; the doubled braces in the JSON example are format
# escapes that render as single literal braces.
evaluation_prompt = """Evaluate the provided figure caption based on the image, context, and these criteria:

[FIGURE CONTEXT]
{context_str}

[GENERATED CAPTION]
{generated_caption}

[EVALUATION RUBRICS]
{base_eval_rubrics}

[INSTRUCTIONS]
1. Carefully examine the figure and context
2. For each rubric category:
   - Compare caption content against visual evidence
   - Check alignment with contextual information
   - Assign score (1-4) based on rubric definitions
   - Provide 1-sentence justification referencing specific evidence
3. Maintain strict objectivity: Base scores only on visible/contextual evidence
4. Output JSON format: {{"rubric": {{"score": int, "justification": str}}}}

[OUTPUT REQUIREMENTS]
- Valid JSON object with all 5 rubrics
- Scores must be integers 1-4
- Justifications reference figure/context specifics
- No additional commentary
"""

def build_context_str(context_row):
    """Render the tagged context block (mentions, keywords, caption) for one figure.

    Args:
        context_row: Mapping with 'keywords' and 'original_caption' entries
            and an optional 'sentences' entry (read with `.get`).

    Returns:
        str: an optional [MENTION-...] line, then the [KEYWORDS-...] and
        [CAPTION-...] lines, each terminated by a newline.

    Raises:
        KeyError: for a dict `context_row` without 'keywords' or
            'original_caption'.
    """
    parts = []

    # Mention sentences are optional: skip the tag when absent or empty.
    mentions = context_row.get('sentences')
    if hasattr(mentions, '__len__') and len(mentions) > 0:
        mention_text = (
            " ".join(str(sentence) for sentence in mentions)
            if isinstance(mentions, (list, np.ndarray))
            else str(mentions)
        )
        parts.append(f"[MENTION-START]{mention_text}[MENTION-END]\n")

    parts.append(f"[KEYWORDS-START]{context_row['keywords']}[KEYWORDS-END]\n")
    parts.append(f"[CAPTION-START]{context_row['original_caption']}[CAPTION-END]\n")
    return "".join(parts)

In [None]:


def evaluate_captions_with_vllm(df, context_df, fail_log_file="fail_log.txt",
                                model_id=None, max_tokens=512, sleep_seconds=8):
    """Ask the Azure chat model to score every generated caption against its figure.

    The previously hard-coded model, token limit and pause between requests
    are now keyword parameters whose defaults reproduce the old behaviour.

    Args:
        df: Frame with 'image_path' (path of the image file to open) and
            'generated_caption' columns. It is modified in place: an
            'eval_answer' column is added.
        context_df: Frame with an 'image_path' column plus the fields read by
            build_context_str(). Rows are looked up by the file name (the
            part after the last '/') of `df['image_path']`.
        fail_log_file: Path of the log file that receives request failures.
        model_id: Model/deployment name for the request; None uses the
            module-level MODEL_ID.
        max_tokens: `max_tokens` value of the chat request.
        sleep_seconds: Pause after each request (throttling); 0 or None
            disables it.

    Returns:
        The same `df` object. Each 'eval_answer' is either the model's reply
        text or a dict {'status': 'failed', 'error': ...}.
    """
    client = get_azure_client()
    deployment = model_id if model_id is not None else MODEL_ID
    eval_answers = []

    context_lookup = context_df.set_index('image_path').to_dict('index')

    for _, row in df.iterrows():
        og_path = row['image_path']
        # The context frame is keyed by the bare file name.
        image_path = og_path.split('/')[-1]
        generated_caption = row['generated_caption']
        context_row = context_lookup.get(image_path)
        if context_row is None:
            eval_answers.append({'status': 'failed', 'error': 'Context not found'})
            continue

        context_str = build_context_str(context_row)

        prompt = evaluation_prompt.format(
            context_str=context_str,
            generated_caption=generated_caption,
            base_eval_rubrics=base_eval_rubrics
        )

        try:
            base64_image = encode_image(og_path)
        except Exception as e:
            eval_answers.append({'status': 'failed', 'error': f'Image encoding error: {str(e)}'})
            continue

        messages = [{
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
            ]
        }]

        # Progress marker: wall-clock timestamp printed before each request.
        print(time.time())
        try:
            response = client.chat.completions.create(
                model=deployment,
                messages=messages,
                max_tokens=max_tokens
            )
            eval_answer = response.choices[0].message.content.strip()
            eval_answers.append(eval_answer)
        except Exception as e:
            error_msg = str(e)
            eval_answers.append({'status': 'failed', 'error': error_msg})
            with open(fail_log_file, 'a') as f_fail:
                f_fail.write(f"Error processing {image_path}: {error_msg}\n")

        if sleep_seconds:
            time.sleep(sleep_seconds)

    # Exactly one answer was appended per row, so the lengths match.
    df['eval_answer'] = eval_answers
    return df

In [None]:

def try_4o_cot(df, context_df, fail_log_file="fail_log.txt",
               model_id=None, max_tokens=512, sleep_seconds=3,
               image_dir='data/399k_imgs/'):
    """Generate captions with the Azure chat model using the step-by-step prompt.

    The previously hard-coded model, token limit, pause and image directory
    are now keyword parameters whose defaults reproduce the old behaviour.

    Args:
        df: Frame with an 'image_path' column holding paths relative to
            `image_dir`. It is modified in place: 'generated_caption' is
            added or overwritten.
        context_df: Frame with an 'image_path' column (matched against
            `df['image_path']` as-is) plus the fields read by
            build_context_str().
        fail_log_file: Path of the log file that receives request failures.
        model_id: Model/deployment name for the request; None uses the
            module-level MODEL_ID.
        max_tokens: `max_tokens` value of the chat request.
        sleep_seconds: Pause after each request (throttling); 0 or None
            disables it.
        image_dir: Prefix concatenated (plain string concatenation) in front
            of each relative image path.

    Returns:
        The same `df` object. Each 'generated_caption' is either the model's
        reply text or a dict {'status': 'failed', 'error': ...}.
    """
    client = get_azure_client()
    deployment = model_id if model_id is not None else MODEL_ID
    eval_answers = []

    context_lookup = context_df.set_index('image_path').to_dict('index')

    for _, row in df.iterrows():
        og_path = row['image_path']
        image_path = image_dir + og_path
        context_row = context_lookup.get(og_path)
        if context_row is None:
            eval_answers.append({'status': 'failed', 'error': 'Context not found'})
            continue

        # Context tags first, then the step-by-step instruction.
        context_str = build_context_str(context_row)
        prompt = context_str + "\n" +  base_cot_prompt

        try:
            base64_image = encode_image(image_path)
        except Exception as e:
            eval_answers.append({'status': 'failed', 'error': f'Image encoding error: {str(e)}'})
            continue

        messages = [{
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
            ]
        }]

        # Progress marker: wall-clock timestamp printed before each request.
        print(time.time())
        try:
            response = client.chat.completions.create(
                model=deployment,
                messages=messages,
                max_tokens=max_tokens
            )
            answer = response.choices[0].message.content.strip()
            eval_answers.append(answer)
        except Exception as e:
            error_msg = str(e)
            eval_answers.append({'status': 'failed', 'error': error_msg})
            with open(fail_log_file, 'a') as f_fail:
                f_fail.write(f"Error processing {image_path}: {error_msg}\n")

        if sleep_seconds:
            time.sleep(sleep_seconds)

    # Exactly one answer was appended per row, so the lengths match.
    df['generated_caption'] = eval_answers
    return df

In [None]:
# Evaluate the "context" experiment captions, from row position 1698 of df_1k
# onward (presumably resuming an interrupted run — confirm).
# NOTE(review): cache_file_path / experiment_name are accepted only by the
# evaluate_captions_with_vllm definition that appears further down in this
# notebook; that cell has to be run before this one.
results = evaluate_captions_with_vllm(
    df=df_1k.iloc[1698:],   #[1436:], #   + 262
    context_df=df_full_texts,
    cache_file_path="cache/caption_exp_contexts_4.csv",
    experiment_name="context",
    fail_log_file="cache/errors_exp.log"
)

In [None]:
# Evaluate the "cot" experiment captions, from row position 100 of new_cot
# onward (presumably resuming an interrupted run — confirm).
# NOTE(review): requires the evaluate_captions_with_vllm definition that takes
# cache_file_path / experiment_name, which appears further down in this notebook.
results_2 = evaluate_captions_with_vllm(
    df=new_cot.iloc[100:],
    context_df=df_full_texts,
    cache_file_path="cache/caption_exp_cpm_cot_2.csv",
    experiment_name="cot",
    fail_log_file="cache/errors_exp.log"
)

In [None]:
# Evaluate the "4o" experiment captions, from row position 100 of filtered_4o
# onward (presumably resuming an interrupted run — confirm).
# NOTE(review): requires the evaluate_captions_with_vllm definition that takes
# cache_file_path / experiment_name, which appears further down in this notebook.
results_3 = evaluate_captions_with_vllm(
    df=filtered_4o.iloc[100:],
    context_df=df_full_texts,
    cache_file_path="cache/caption_exp_4o_2.csv",
    experiment_name="4o",
    fail_log_file="cache/errors_exp.log"
)

In [None]:
# Evaluate the "cpm" captions (from row position 225) and the "qwen" captions
# (from row position 100); both runs append failures to the same error log.
# NOTE(review): requires the evaluate_captions_with_vllm definition that takes
# cache_file_path / experiment_name, which appears further down in this notebook.
results_4 = evaluate_captions_with_vllm(
    df=filtered_cpm.iloc[225:],
    context_df=df_full_texts,
    cache_file_path="cache/caption_exp_cpm_2.csv",
    experiment_name="cpm",
    fail_log_file="cache/errors_exp.log"
)

results_5 = evaluate_captions_with_vllm(
    df=filtered_qwen.iloc[100:],
    context_df=df_full_texts,
    cache_file_path="cache/caption_exp_qwen_2.csv",
    experiment_name="qwen",
    fail_log_file="cache/errors_exp.log"
)

In [None]:
def evaluate_captions_with_vllm(df, context_df, cache_file_path, experiment_name, fail_log_file="fail_log.txt"):
    """Evaluate every generated caption in ``df`` with the model ``MODEL_ID``.

    For each row the evaluation prompt is built from the row's context (looked
    up in ``context_df`` by ``image_path``) and its ``generated_caption``, and
    sent together with the base64-encoded image through the client returned by
    ``get_azure_client()``. Every answer is appended to a CSV cache keyed by
    ``(image path, caption, experiment)`` so an interrupted run can be resumed
    without repeating requests.

    Args:
        df: DataFrame with ``image_path`` (relative to ``data/399k_imgs/``)
            and ``generated_caption`` columns.
        context_df: DataFrame with an ``image_path`` column; the remaining
            columns of the matching row are passed to ``build_context_str``.
        cache_file_path: CSV file used to persist answers between runs. It is
            created (with its parent directory) when missing.
        experiment_name: Label stored with every cached answer and used as
            part of the cache key.
        fail_log_file: Text file that receives one line per failed request.

    Returns:
        The same ``df`` object (modified in place) with an ``eval_answer``
        column. The value is the model's answer, a JSON string with
        ``status: failed`` when context lookup, image encoding or the request
        failed, or ``None`` for rows whose caption is not a string.
    """
    client = get_azure_client()
    cache_columns = ['image_path', 'generated_caption', 'experiment', 'eval_answer']

    # Load previously cached answers so finished rows are not sent again.
    cache_dict = {}
    if os.path.exists(cache_file_path):
        try:
            cache_df = pd.read_csv(cache_file_path)
            cache_dict = {
                (path, caption, experiment): answer
                for path, caption, experiment, answer in zip(
                    cache_df['image_path'],
                    cache_df['generated_caption'],
                    cache_df['experiment'],
                    cache_df['eval_answer'],
                )
            }
        except pd.errors.EmptyDataError:
            pass  # zero-byte cache file: nothing to load
        except KeyError as e:
            # Best effort: an unexpected schema means we start without a cache.
            print(f"Ignoring cache {cache_file_path}: missing column {e}")

    cache_dir = os.path.dirname(cache_file_path)
    if cache_dir:  # os.makedirs('') raises when the cache sits in the CWD
        os.makedirs(cache_dir, exist_ok=True)

    # Write the header for a new cache, and also for a zero-byte one, which
    # would otherwise stay header-less and be misread on the next run.
    if not os.path.exists(cache_file_path) or os.path.getsize(cache_file_path) == 0:
        with open(cache_file_path, 'w', newline='', encoding='utf-8') as f_cache:
            csv.writer(f_cache).writerow(cache_columns)

    context_lookup = context_df.set_index('image_path').to_dict('index')

    eval_answers = []
    for _, row in df.iterrows():
        image_path = row['image_path']
        og_path = 'data/399k_imgs/' + image_path
        generated_caption = row['generated_caption']
        if not isinstance(generated_caption, str):
            # Keep a placeholder so eval_answers stays aligned with df's rows;
            # dropping the row here made the final column assignment fail.
            eval_answers.append(None)
            continue
        cache_key = (og_path, generated_caption, experiment_name)

        if cache_key in cache_dict:
            eval_answers.append(cache_dict[cache_key])
            continue

        context_row = context_lookup.get(image_path)
        if context_row is None:
            error_str = json.dumps({'status': 'failed', 'error': 'Context not found'})
            eval_answers.append(error_str)
            cache_dict[cache_key] = error_str
            continue

        context_str = build_context_str(context_row)
        prompt = evaluation_prompt.format(
            context_str=context_str,
            generated_caption=generated_caption,
            base_eval_rubrics=base_eval_rubrics
        )

        try:
            base64_image = encode_image(og_path)
        except Exception as e:
            error_str = json.dumps({'status': 'failed', 'error': f'Image encoding error: {str(e)}'})
            eval_answers.append(error_str)
            cache_dict[cache_key] = error_str
            continue

        messages = [{
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
            ]
        }]

        try:
            response = client.chat.completions.create(
                model=MODEL_ID,
                messages=messages,
                max_tokens=512
            )
            eval_answer = response.choices[0].message.content.strip()
        except Exception as e:
            eval_answer = json.dumps({'status': 'failed', 'error': str(e)})
            with open(fail_log_file, 'a', encoding='utf-8') as f_fail:
                f_fail.write(f"Error processing {og_path}: {str(e)}\n")

        # NOTE: failed requests are persisted as well, so they are served from
        # the cache on the next run; remove their rows from the CSV to retry.
        with open(cache_file_path, 'a', newline='', encoding='utf-8') as f_cache:
            csv.writer(f_cache).writerow([og_path, generated_caption, experiment_name, eval_answer])

        cache_dict[cache_key] = eval_answer
        eval_answers.append(eval_answer)
        time.sleep(3)  # fixed pause between requests

    df['eval_answer'] = eval_answers
    return df

In [None]:
#same function but using the original captions instead of the generated ones for evaluation
def evaluate_captions_with_vllm(df, context_df, cache_file_path, experiment_name, fail_log_file="fail_log.txt"):
    """Evaluate the *original* caption of every image in ``df`` with ``MODEL_ID``.

    Variant of the generated-caption evaluation: the caption under test is the
    ``original_caption`` stored in ``context_df`` for the image's file name,
    not a column of ``df``. Every answer is appended to a CSV cache keyed by
    ``(image path, caption, experiment)`` so an interrupted run can be resumed.

    NOTE(review): this definition reuses the name of the generated-caption
    variant defined earlier in the notebook, so running this cell replaces it.

    Args:
        df: DataFrame with an ``image_path`` column holding paths that
            ``encode_image`` can read as they are.
        context_df: DataFrame with ``image_path`` (matched against the part of
            the path after the last ``/``) and ``original_caption`` columns;
            the matching row is also passed to ``build_context_str``.
        cache_file_path: CSV file used to persist answers between runs. It is
            created (with its parent directory) when missing.
        experiment_name: Label stored with every cached answer and used as
            part of the cache key.
        fail_log_file: Text file that receives one line per failed request.

    Returns:
        The same ``df`` object (modified in place) with an ``eval_answer``
        column. The value is the model's answer, a JSON string with
        ``status: failed`` when context lookup, image encoding or the request
        failed, or ``None`` for rows whose original caption is not a string.
    """
    client = get_azure_client()
    cache_columns = ['image_path', 'generated_caption', 'experiment', 'eval_answer']

    # Load previously cached answers so finished rows are not sent again.
    cache_dict = {}
    if os.path.exists(cache_file_path):
        try:
            cache_df = pd.read_csv(cache_file_path)
            cache_dict = {
                (path, caption, experiment): answer
                for path, caption, experiment, answer in zip(
                    cache_df['image_path'],
                    cache_df['generated_caption'],
                    cache_df['experiment'],
                    cache_df['eval_answer'],
                )
            }
        except pd.errors.EmptyDataError:
            pass  # zero-byte cache file: nothing to load
        except KeyError as e:
            # Best effort: an unexpected schema means we start without a cache.
            print(f"Ignoring cache {cache_file_path}: missing column {e}")

    cache_dir = os.path.dirname(cache_file_path)
    if cache_dir:  # os.makedirs('') raises when the cache sits in the CWD
        os.makedirs(cache_dir, exist_ok=True)

    # Write the header for a new cache, and also for a zero-byte one, which
    # would otherwise stay header-less and be misread on the next run.
    if not os.path.exists(cache_file_path) or os.path.getsize(cache_file_path) == 0:
        with open(cache_file_path, 'w', newline='', encoding='utf-8') as f_cache:
            csv.writer(f_cache).writerow(cache_columns)

    context_lookup = context_df.set_index('image_path').to_dict('index')

    eval_answers = []
    for _, row in df.iterrows():
        image_path = row['image_path']
        small_path = image_path.split('/')[-1]
        og_path = image_path

        # Resolve the context first: the caption is read from it, and calling
        # .get() on a missing entry used to raise AttributeError before the
        # "Context not found" branch could run.
        context_row = context_lookup.get(small_path)
        if context_row is None:
            eval_answers.append(json.dumps({'status': 'failed', 'error': 'Context not found'}))
            continue

        caption = context_row.get('original_caption')
        if not isinstance(caption, str):
            # Keep a placeholder so eval_answers stays aligned with df's rows;
            # dropping the row here made the final column assignment fail.
            eval_answers.append(None)
            continue
        cache_key = (og_path, caption, experiment_name)

        if cache_key in cache_dict:
            eval_answers.append(cache_dict[cache_key])
            continue

        context_str = build_context_str(context_row)
        prompt = evaluation_prompt.format(
            context_str=context_str,
            generated_caption=caption,
            base_eval_rubrics=base_eval_rubrics
        )

        try:
            base64_image = encode_image(og_path)
        except Exception as e:
            error_str = json.dumps({'status': 'failed', 'error': f'Image encoding error: {str(e)}'})
            eval_answers.append(error_str)
            cache_dict[cache_key] = error_str
            continue

        messages = [{
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
            ]
        }]

        try:
            response = client.chat.completions.create(
                model=MODEL_ID,
                messages=messages,
                max_tokens=512
            )
            eval_answer = response.choices[0].message.content.strip()
        except Exception as e:
            eval_answer = json.dumps({'status': 'failed', 'error': str(e)})
            with open(fail_log_file, 'a', encoding='utf-8') as f_fail:
                f_fail.write(f"Error processing {og_path}: {str(e)}\n")

        # NOTE: failed requests are persisted as well, so they are served from
        # the cache on the next run; remove their rows from the CSV to retry.
        with open(cache_file_path, 'a', newline='', encoding='utf-8') as f_cache:
            csv.writer(f_cache).writerow([og_path, caption, experiment_name, eval_answer])

        cache_dict[cache_key] = eval_answer
        eval_answers.append(eval_answer)
        time.sleep(3)  # fixed pause between requests

    df['eval_answer'] = eval_answers
    return df

In [None]:
#to pick images for (human) evaluation:
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import os

def categorize_images(df, shuffle=False, random_state=None):
    """Interactively pick and label images for human evaluation.

    Each image listed in ``df['image_path']`` is displayed and the user is
    asked whether to include it (``y``), skip it (``n``) or stop (``s``).
    Included images are tagged with a free-text category. Ctrl-C ends the
    session and keeps what was collected so far.

    Args:
        df: DataFrame with an ``image_path`` column of readable file paths.
        shuffle: When True, present the images in random order. The default
            (False) keeps the order of ``df``.
        random_state: Seed forwarded to ``DataFrame.sample`` when ``shuffle``
            is True, for a repeatable order.

    Returns:
        DataFrame with ``image_path`` and ``category`` columns, one row per
        included image, in the order the images were accepted.
    """
    columns = ['image_path', 'category']
    records = []  # collected as plain dicts; the DataFrame is built once at the end

    # NOTE(review): nothing in this function writes to temp_images; the call is
    # kept in case other cells rely on the directory existing.
    os.makedirs('temp_images', exist_ok=True)

    working_df = df.sample(frac=1, random_state=random_state) if shuffle else df
    total = len(working_df)

    try:
        # Count positions ourselves: index labels may be non-numeric or
        # non-sequential (filtered or shuffled frames).
        for position, (_, row) in enumerate(working_df.iterrows(), start=1):
            img_path = row['image_path']

            if not os.path.exists(img_path):
                print(f"Image not found: {img_path}")
                continue

            try:
                # The context manager releases the file handle once the pixels
                # have been handed to matplotlib.
                with Image.open(img_path) as img:
                    plt.imshow(img)
            except OSError as e:
                # An unreadable file should not abort the labelling session.
                print(f"Could not read image {img_path}: {e}")
                continue
            plt.axis('off')
            plt.title(f"Image {position}/{total}")
            plt.show()

            while True:
                action = input("Action? [y=include, n=skip, s=stop]: ").strip().lower()

                if action == 's':
                    print("Stopping categorization...")
                    return pd.DataFrame(records, columns=columns)

                if action == 'n':
                    print("Skipping image...")
                    break

                if action == 'y':
                    category = input("Enter category for this image: ").strip()
                    if not category:
                        print("Category cannot be empty!")
                        continue

                    records.append({'image_path': img_path, 'category': category})
                    print(f"Added as: {category}")
                    break

                print("Invalid input! Please enter y, n, or s")

            plt.close()

    except KeyboardInterrupt:
        print("\nProcess interrupted by user")
    finally:
        # Also runs on the early 's' return, so no figure is left open.
        plt.close('all')

    return pd.DataFrame(records, columns=columns)


In [None]:
# Legend: MAIN, CPM_COT, 4o, 4oCoT; 1 = original caption only; 2 = title, abstract, keywords; 3 = sentences, mention_section

In [None]:
# For human evaluation: show one figure together with the captions stored for
# it in each of the seven caption frames and its entry in context_df.
# NOTE(review): display_image_and_captions is defined in a later cell of this
# notebook; that cell has to be run before this one.
display_image_and_captions('data/399k_imgs/S0098135407000828_fig_gr2.jpg', df4, df_mini_cot, df_4o, df_4o_cot, df1, df2, df3, context_df)

In [None]:
# Azure credentials are read from the environment so that secrets never have
# to be pasted into (and saved with) the notebook. Both fall back to "" when
# the variable is not set, which matches the previous placeholder values.
import os

AZURE_API_KEY = os.environ.get("AZURE_API_KEY", "")
AZURE_ENDPOINT = os.environ.get("AZURE_ENDPOINT", "")

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

def display_image_and_captions(image_path, df1, df2, df3, df4, df5, df6, df7, context_df):
    """Show one figure, the caption each frame holds for it, and its context.

    Args:
        image_path: Path of the image file to display; also the key used to
            find the image's rows in the caption frames.
        df1..df7: Caption DataFrames with ``image_path`` and
            ``generated_caption`` columns. They are printed as
            "Model 1" .. "Model 7" in argument order.
        context_df: DataFrame with ``image_path`` (compared with the part of
            ``image_path`` after the last ``/``), ``sentences``, ``keywords``
            and ``original_caption`` columns.

    Returns:
        None. Everything is printed/plotted. Returns early, after printing a
        message, when the image file does not exist.
    """
    full_image_path = image_path  #f'data/images/{image_path}'
    try:
        img = mpimg.imread(full_image_path)
    except FileNotFoundError:
        print(f"Image not found: {full_image_path}")
        return

    plt.figure(figsize=(10, 8))
    plt.imshow(img)
    plt.axis('off')
    plt.title(f'Image: {image_path}')
    plt.show()

    image_name = image_path.split('/')[-1]

    caption_dfs = [df1, df2, df3, df4, df5, df6, df7]
    print("\n" + "="*50)
    # Derive the count from the list; the hard-coded "Six" no longer matched.
    print(f"Generated Captions from {len(caption_dfs)} Models:")
    print("="*50)
    for i, df in enumerate(caption_dfs, 1):
        paths = df['image_path']
        matches = df[paths == image_path]
        if matches.empty:
            # Fall back to comparing file names, as the context lookup below
            # does, so frames keyed by a different path prefix still match.
            matches = df[paths.astype(str).str.split('/').str[-1] == image_name]

        if matches.empty:
            print(f"Model {i}: No caption found for '{image_path}'\n")
        else:
            caption = matches.iloc[0]['generated_caption']
            print(f"Model {i} Caption:\n{caption}\n{'-'*50}\n")

    print("\n" + "="*50)
    print("Contextual Information:")
    print("="*50)

    context_matches = context_df[context_df['image_path'] == image_name]

    if context_matches.empty:
        print(f"No contextual information found for '{image_path}'")
    else:
        context_data = context_matches.iloc[0]
        print(f"Sentences:\n{context_data['sentences']}\n")
        print(f"Keywords:\n{context_data['keywords']}\n")
        print(f"Original Caption:\n{context_data['original_caption']}")