### Use GPT-4o to take text chunks and image summaries (or the images themselves) as input and generate queries and reference answers that fit the article and/or figure information and the given instruction (which specifies what kinds of queries are acceptable).

In [None]:
import os

# Read the Azure OpenAI credentials from the environment so that secrets never
# have to be pasted into (and committed with) the notebook. Both values fall
# back to the empty string when the variable is unset.
AZURE_OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY", "")
AZURE_OPENAI_ENDPOINT = os.environ.get("AZURE_OPENAI_ENDPOINT", "")

In [None]:
import openai
from openai import AzureOpenAI

# Identifier passed as `model=` to every chat-completion request, and the
# API version the Azure client is pinned to.
MODEL = 'gpt4o-240513'
API_VERSION = "2024-02-15-preview"

# Single shared client used by all GPT-4o calls in this notebook.
client = AzureOpenAI(
    api_key=AZURE_OPENAI_API_KEY,
    api_version=API_VERSION,
    azure_endpoint=AZURE_OPENAI_ENDPOINT,
)

In [None]:
#get the full texts of your data
import pandas as pd
import s3fs

# Authenticated (non-anonymous) S3 access; every object matching
# "<s3_path>*.parquet" is loaded.
fs = s3fs.S3FileSystem(anon=False)
s3_path = ""
files = fs.glob(f"{s3_path}*.parquet")
print(f"Found {len(files)} files in the directory.")

# Read file by file so that one unreadable file is reported and skipped
# instead of aborting the whole load.
dfs = []
for file in files:
    try:
        dfs.append(pd.read_parquet(f"s3://{file}", filesystem=fs))
    except Exception as e:
        print(f"Skipping file {file} due to error: {e}")

# pd.concat([]) fails with an opaque "No objects to concatenate"; raise the
# same exception type with a message that points at the actual cause.
if not dfs:
    raise ValueError(
        f"No parquet file could be loaded from {s3_path!r} "
        f"({len(files)} file(s) matched); check s3_path and your S3 credentials."
    )

df_full_texts = pd.concat(dfs, ignore_index=True)
print(df_full_texts.head())

In [None]:
# Load the generated figure captions and the chunk table from local CSV files.
# NOTE(review): presumably `all_chunks_df` holds the text chunks of all articles — confirm the schema.
image_summary_df = pd.read_csv("data/captions.csv")
all_chunks_df = pd.read_csv("data/399k_all_chunks.csv")
# Last expression of the cell: displays the column names of the caption table.
image_summary_df.columns

In [None]:
# Caption table that is later passed to augment_with_ground_truth().
# NOTE(review): presumably a 50k subset of data/captions.csv — confirm.
captions = pd.read_csv("data/50k_captions.csv")

In [None]:
#### instruction prompts for different eval categories
## text-based
# Instruction for text-based queries that ask for reported numeric values.
# Interpolated into TEXT_EVAL_PROMPT; the examples show the expected
# Context / Query / Answer / text_sentence / image_sentence layout.
numerical_prompt = """Generate a query asking for numerical metrics from research papers. Follow these steps:
Extract the domain (e.g., machine learning, blockchain) from the text.
Frame the query to request exact values: 'What [metric] is reported for [concept]?'
The answer must cite values from methodology/results sections. If values are implied (e.g., 'significant improvement'), add a disclaimer like 'typically reported as...'.
Strict Rules:
Never invent unsupported numbers.
Reject queries if the context lacks numerical data and that do not relate to the domain of Computer Science.
Examples:
Context: 'PoS consensus reduces energy consumption by 99.99% compared to PoW (Buterin, 2022).'
Query: 'What energy savings are reported for Proof-of-Stake vs. Proof-of-Work?'
Answer: 'PoS reduces energy use by 99.99% compared to PoW.'
text_sentence: 'PoS consensus reduces energy consumption by 99.99% compared to PoW'
image_sentence: 'N/A'
Context: 'Sharding improves database throughput by 10x (Chen et al., 2023).'
Query: 'What throughput increase does sharding provide in distributed databases?'
Answer: 'Sharding increases throughput by 10× in horizontally scaled systems.'
text_sentence: 'Sharding improves database throughput by 10x'
image_sentence: 'N/A'"""

# Instruction for text-based queries that ask for a formal definition.
# Interpolated into TEXT_EVAL_PROMPT. The example's field name must be spelled
# exactly `text_sentence`, matching the output format the model is asked to
# follow (it was previously misspelled `text_sentece`).
factual_prompt = """Generate a query asking for a formal definition relating to the domain of Computer Science. Follow these steps:
Identify key terms in the text.
Frame the query along the lines of: 'How is [concept] defined?'
The answer MUST paraphrase definitions verbatim from literature reviews/introductions.
If definitions are incomplete, infer using related terms but add: 'Commonly understood as...'.
Strict Rules:
Never conflate concepts.
Flag uncertainty with disclaimers.
Reject queries that do not relate to the domain of Computer Science.
Example:
Context: 'Federated learning trains models on decentralized devices (Kairouz et al., 2021).'
Query: 'How is federated learning defined?'
Answer: 'Decentralized training where raw data remains on client devices.'
text_sentence: 'Federated learning trains models on decentralized devices'
image_sentence: 'N/A'"""

# Instruction for text-based compare/contrast queries between two concepts.
# Interpolated into TEXT_EVAL_PROMPT. The example's field name must be spelled
# exactly `text_sentence`, matching the output format the model is asked to
# follow (it was previously misspelled `text_sentece`).
compare_prompt = """Generate a query comparing two concepts relating to the domain of Computer Science. Follow these steps:
Identify contrasting terms in the text.
Frame the query as: 'How do [Concept A] and [Concept B] differ in [aspect]?'
The answer MUST cite explicit differences from discussion sections.
If differences are indirect, infer but note: 'Based on general trends...'.
Strict Rules:
Never assume unstated trade-offs.
Never conflate concepts.
Flag uncertainty with disclaimers.
Example:
Context: 'RSA is slower but more compatible; ECC is faster but less adopted (Chen, 2021).'
Query: 'Compare RSA and ECC encryption in speed and compatibility.'
Answer: 'RSA is slower but widely compatible; ECC is faster but less adopted (Chen, 2021).'
text_sentence: 'RSA is slower but more compatible; ECC is faster but less adopted'
image_sentence: 'N/A'"""


# Instruction for text-based multi-part queries: the components/stages of a
# concept plus how they are measured or evaluated.
compound_multipart_prompt = '''Generate a multi-part query asking for components/stages and their evaluation. Follow these steps:
1. Identify a concept with distinct components/stages from the text.
2. Frame the query as: "What are the [number] [components/stages] of [concept], and how are they measured/evaluated?"
3. The answer MUST cite methodology sections for components and evaluation sections for metrics.
4. If context lacks evaluation criteria, state: "Evaluation methods not described."

Strict Rules:
- Never invent components or metrics.
- Reject queries if components/stages are unspecified or unrelated to Computer Science.

Examples:
Context: "Blockchain consensus involves proposal, validation, and commit phases measured by latency (ms)."
Query: "What are the three phases of blockchain consensus, and how is their latency measured?"
Answer: "Proposal, validation, commit; measured in milliseconds (ms)." 
text_sentence: "Blockchain consensus involves proposal, validation, and commit phases measured by latency (ms)."
image_sentence: "N/A"
Context: "GAN training includes generator/discriminator steps evaluated with FID scores."
Query: "What are the two stages of GAN training, and what metric evaluates their performance?"
Answer: "Generator and discriminator stages; evaluated using FID scores."
text_sentence: "GAN training includes generator/discriminator steps evaluated with FID scores."
image_sentence: "N/A"'''

# Instruction for text-based queries about research trends or the relationship
# between two variables. Interpolated into TEXT_EVAL_PROMPT.
trend_analysis_prompt = '''Generate a query about research trends or variable relationships. Follow these steps:
1. Identify two variables/metrics (e.g., model size vs. accuracy) from the text.
2. Frame the query as: "What [trend/correlation] exists between [Variable A] and [Variable B] in [domain]?"
3. The answer MUST synthesize findings from results sections or survey papers.
4. If trends are implied, add: "Literature suggests..."

Strict Rules:
- Never assert causation without explicit evidence.
- Reject queries lacking variable comparisons or unrelated to Computer Science.

Examples:
Context: "Model compression reduces accuracy by 2-5% but cuts inference time by 60%."
Query: "What trade-off exists between model compression and accuracy in NLP?"
Answer: "Compression reduces accuracy by 2-5% while cutting inference time by 60%."
text_sentence: "Model compression reduces accuracy by 2-5% but cuts inference time by 60%."
image_sentence: "N/A"
Context: "Energy consumption scales quadratically with transformer depth."
Query: "How does transformer depth correlate with energy usage?"
Answer: "Quadratic scaling: doubling layers increases energy use 4×."
text_sentence: "Energy consumption scales quadratically with transformer depth."
image_sentence: "N/A"'''

# Instruction for text-based queries about the reported limitations of a
# concept or technique. Interpolated into TEXT_EVAL_PROMPT.
limitation_identification_prompt = '''Generate a query about reported limitations. Follow these steps:
1. Identify a concept/technique and its constraints from discussion/limitations sections.
2. Frame the query as: "What limitations exist for [concept] in [domain/use case]?"
3. The answer MUST paraphrase limitations verbatim. If indirect, state: "Challenges include..."

Strict Rules:
- Never conflate limitations across domains.
- Reject queries without explicit limitation mentions or unrelated to Computer Science.

Examples:
Context: "Differential privacy reduces utility in high-dimensional datasets due to noise."
Query: "What limitations does differential privacy introduce for high-dimensional data?"
Answer: "Noise addition reduces data utility in high-dimensional spaces."
text_sentence: "Differential privacy reduces utility in high-dimensional datasets due to noise."
image_sentence: "N/A"
Context: "Blockchain sharding risks cross-shard communication overhead."
Query: "What scalability challenges exist for blockchain sharding?"
Answer: "Cross-shard communication creates significant overhead."
text_sentence: "Blockchain sharding risks cross-shard communication overhead."
image_sentence: "N/A"'''

## image-summary based
# Instruction for image-summary-based queries that enumerate the design
# features of a system. Interpolated into IMAGE_EVAL_PROMPT.
feature_enumeration_prompt = """Generate a query asking for design features relating to the domain of Computer Science. Follow these steps:
Identify the system from the image summaries.
Frame the query as: 'What features characterize [system] architectures?'
The answer MUST list attributes (e.g., 'API Gateway') from the summary.
If features are missing, reply: 'Key features not described'.
Strict Rules:
Only list explicitly described features.
Reject queries that do not relate to the domain of Computer Science.
Either make queries general enough for multiple papers to possibly answer it, or make sure its clear which system/ model/ feature/ phenomenon is asked for! 
Example:
Context: 'Microservice diagram shows API Gateway and User Service.'
Query: 'What defines a microservice architecture?'
Answer: 'API Gateway and User Service.'
text_sentence: 'N/A'
image_sentence: 'Microservice diagram shows API Gateway and User Service.'"""

# Instruction for image-summary-based queries about the components shown in
# technical diagrams. Interpolated into IMAGE_EVAL_PROMPT.
visual_identification_prompt = """Generate a query asking about components in technical diagrams. Follow these steps:
Identify the system/architecture (e.g., transformers, zero-trust networks) from the image summaries relating to the domain of Computer Science.
Frame the query along the lines of: 'What [components/features] are shown in [system] diagrams?'
The answer must list elements explicitly described (e.g., 'encoder layers', 'API Gateway').
If components are unclear, infer common ones but clarify: 'typically include...'.
Strict Rules:
Reject if the summary lacks component descriptions or queries that do not relate to the domain of Computer Science.
Never assume unlabeled elements.
Either make queries general enough for multiple papers to possibly answer it, or make sure its clear which system/ model/ feature/ phenomenon is asked for! 
Examples:
Context: 'Transformer diagram shows stacked encoder/decoder layers with multi-head attention blocks.'
Query: 'What layers define a transformer architecture?'
Answer: 'Encoder/decoder stacks with multi-head attention layers.'
text_sentence: 'N/A'
image_sentence: 'Transformer diagram shows stacked encoder/decoder layers with multi-head attention blocks.'
Context: 'Zero-trust network diagram includes microsegmentation and encrypted tunnels.'
Query: 'What elements define a zero-trust network architecture?'
Answer: 'Microsegmentation and encrypted communication channels.'
text_sentence: 'N/A'
image_sentence: 'Zero-trust network diagram includes microsegmentation and encrypted tunnels.'"""

# Instruction for image-summary-based queries about long-term trends in
# plotted data. Interpolated into IMAGE_EVAL_PROMPT. The example answer is
# now properly closed with a double quote (the closing quote was missing, so
# the example's fields were not cleanly delimited).
data_interpretation_prompt = """Generate a query about trends in visualized data over extended periods (e.g., years, iterations). Follow these steps:
Identify the metric (e.g., accuracy, latency, energy efficiency) and time axis (e.g., epochs, years) from the image summary.
Frame the query to ask about long-term behavior relating to the domain of Computer Science:
'What trend does the [metric] plot reveal over [time period]?'
'How does [metric] evolve over time in [system]?'
'What does the long-term [metric] curve suggest about [system] behavior?'
The answer must synthesize multi-phase patterns (e.g., 'plateauing', 'exponential decay', 'linear growth') from the summary.
If trends are unclear or short-term, state: 'No long-term trend described.'
Strict Rules:
Never guess. Reject summaries lacking time-axis descriptions.
Use precise terms like 'asymptotic convergence' or 'periodic fluctuations'.
Either make queries general enough for multiple papers to possibly answer it, or make sure its clear which system/ model/ feature/ phenomenon is asked for! 
Examples:
Context:
Image Summary: "Energy efficiency graph declines by 15% over 5 hardware generations due to thermal throttling."
Query: "How does energy efficiency evolve across hardware iterations?"
Answer: "Efficiency declines by 15% over 5 generations, likely due to thermal limitations."
text_sentence: 'N/A'
image_sentence: 'Energy efficiency graph declines by 15% over 5 hardware generations due to thermal throttling.'"""


# Instruction for image-summary-based queries that request a specific kind of
# visualization. Interpolated into IMAGE_EVAL_PROMPT.
image_retrieval_prompt = '''Generate a query requesting a specific visualization. Follow these steps:
1. Identify a phenomenon (e.g., latency distribution) from image summaries.
2. Frame the query as: "[chart/diagram] illustrating [phenomenon]" OR "Show the [trend] of [phenomenon/ concept/ metric] recently."
3. The answer MUST reference explicit summaries (e.g., "histogram," "architecture diagram").

Strict Rules:
- Never invent visualization types.
- Reject queries if no image context exists or unrelated to Computer Science.
- Either make queries general enough for multiple papers to possibly answer it, or make sure its clear which system/ model/ feature/ phenomenon is asked for! 

Examples:
Context: "Figure 3: Training loss curve with epoch vs. accuracy."
Query: "What chart type shows the relationship between training epochs and accuracy?"
Answer: "Line chart plotting accuracy against epochs."
text_sentence: "N/A"
image_sentence: "Figure 3: Training loss curve with epoch vs. accuracy."
Context: "System diagram includes load balancers and worker nodes."
Query: "What diagram type represents the distributed system architecture?"
Answer: "Component diagram with load balancers and worker nodes."
text_sentence: "N/A"
image_sentence: "System diagram includes load balancers and worker nodes."'''

# Instruction for image-summary-based queries about the sequence of steps
# shown in a flowchart or process diagram. Interpolated into IMAGE_EVAL_PROMPT.
functional_flow_prompt = '''Generate a query about process steps in diagrams. Follow these steps:
1. Identify a system/process (e.g., API request handling) from image summaries.
2. Frame the query as: "What sequence is shown in [system] diagrams?" 
3. The answer MUST list steps (e.g., "1. Request, 2. Authentication, 3. Response") from flowcharts.

Strict Rules:
- Never infer unlabeled steps.
- Reject queries lacking flowchart/sequence context or unrelated to Computer Science.
- Either make queries general enough for multiple papers to possibly answer it, or make sure its clear which system/ model/ feature/ phenomenon is asked for! 

Examples:
Context: "Flowchart: User login → Token generation → Access grant."
Query: "What steps are shown in the authentication flowchart?"
Answer: "1. User login, 2. Token generation, 3. Access grant."
text_sentence: "N/A"
image_sentence: "Flowchart: User login → Token generation → Access grant."
Context: "Diagram illustrates data ingestion → preprocessing → model training."
Query: "What sequence does the ML pipeline diagram depict?"
Answer: "Data ingestion, preprocessing, then model training."
text_sentence: "N/A"
image_sentence: "Diagram illustrates data ingestion → preprocessing → model training."'''

# Instruction for image-summary-based queries about what labels/symbols in a
# diagram signify. Interpolated into IMAGE_EVAL_PROMPT.
annotation_prompt = '''Generate a query about diagram annotations. Follow these steps:
1. Identify labels/symbols (e.g., arrows, layers) from image summaries.
2. Frame the query as: "What do [annotations] signify in [diagram type]?"
3. The answer MUST cite explicit descriptions (e.g., "Arrows denote data flow").

Strict Rules:
- Never interpret unlabeled elements.
- Reject queries without annotation context or unrelated to Computer Science.
- Either make queries general enough for multiple papers to possibly answer it, or make sure its clear which system/ model/ feature/ phenomenon is asked for! 

Examples:
Context: "Layers are labeled 'Conv1', 'Pool1' in the CNN diagram."
Query: "How are convolutional layers annotated in neural network diagrams?"
Answer: "Labeled as 'Conv1', 'Conv2', etc."
text_sentence: "N/A"
image_sentence: "Layers are labeled 'Conv1', 'Pool1' in the CNN diagram."
Context: "Red dashed lines in the flowchart indicate error handling."
Query: "What do red dashed lines represent in the workflow diagram?"
Answer: "Error handling pathways."
text_sentence: "N/A"
image_sentence: "Red dashed lines in the flowchart indicate error handling."'''

## cross-modal 

# Cross-modal instruction: queries about how a process/workflow works, whose
# answer must combine information from the text and the image summaries.
# Interpolated into CROSS_MODAL_EVAL_PROMPT.
process_explanation_prompt = """Generate a query about a workflow. Follow these steps:
Identify the process (e.g., federated learning, gradient clipping) from text and image summaries relating to the domain of Computer Science.
Frame the query as: 'How does [process] work in [domain/ system]?'
The answer MUST fuse modalities (include appropriate info from both modalities).
If context is incomplete, state: 'Insufficient data to explain fully'.
Strict Rules:
Never force synthesis if modalities conflict.
Reject queries that do not relate to the domain of Computer Science.
Example:
Context:
Text: 'Residual connections bypass layers to mitigate vanishing gradients.'
Image: 'Diagram shows skip paths around convolutional blocks.'
Query: 'How do residual connections prevent vanishing gradients?'
Answer: 'Skip paths (diagram) allow gradients to bypass layers (text).'
text_sentence: 'Residual connections bypass layers to mitigate vanishing gradients.'
image_sentence: 'Diagram shows skip paths around convolutional blocks.'"""

# Cross-modal instruction: queries that ask for an explanation of a concept,
# whose answer must link the textual definition to what the diagrams show.
# Interpolated into CROSS_MODAL_EVAL_PROMPT.
conceptual_explanation_prompt = """Generate a query explaining a concept. Follow these steps:
Identify the concept (e.g., self-attention) from text and diagrams relating to the domain of Computer Science.
Frame the query as: 'Explain [concept] in the context of [domain/ system].'
The answer MUST link definitions to visuals (e.g., 'Q/K/V matrices in text → parallel heads in diagrams').
If links are unclear, state: 'Visual evidence is incomplete'.
Strict Rules:
Reject queries that do not relate to the domain of Computer Science.
Only assert connections explicitly described.
Example:
Context:
Text: 'Self-attention computes token interactions via Q/K/V matrices.'
Image: 'Diagram shows parallel attention heads.'
Query: 'Explain how self-attention works in transformers.'
Answer: 'Q/K/V matrices (text) process inputs through parallel heads (diagram).'
text_sentence: 'Self-attention computes token interactions via Q/K/V matrices.'
image_sentence: 'Diagram shows parallel attention heads.'"""


# Cross-modal instruction: queries about how the textual description of a
# concept maps onto its visual representation.
conceptual_alignment_prompt = '''Generate a query comparing textual representations to visual data. Follow these steps:
1. Identify a concept which describes a visual phenomenon (e.g., self-attention).
2. Frame the query as: "How do textual descriptions fit the visual representations of [concept]?"
3. The answer MUST map text (e.g., "Q/K/V matrices") to visuals (e.g., "parallel heads").

Strict Rules:
- Never assert alignment without explicit evidence.
- Reject queries lacking multi-modal context or unrelated to Computer Science.

Examples:
Context: 
Text: "Consensus algorithms require leader election and log replication phases."
Image: "State transition diagram shows candidate → leader → follower states."

Query: "How do textual and visual representations of consensus algorithms align?"
Answer: "Leader election precedes log replication, with node states progressing through candidate, leader, and follower phases during election cycles."
text_sentence: "Consensus algorithms require leader election and log replication phases."
image_sentence: "State transition diagram shows candidate → leader → follower states."'''

In [None]:
# Stand-alone prompt used when generating queries from figure images / image
# summaries only. It embeds its own seven categories instead of interpolating
# the per-category prompts. Category markers are uniformly written as [1]..[7]
# so they line up with the `chosen_category: [1-7]` output field, and the
# stray empty rule bullet has been removed.
NEW_IMAGE_EVAL_PROMPT = """Analyze COMPUTER SCIENCE image summaries to create retrieval-friendly queries. Follow these steps:

1. CATEGORY SELECTION & GUIDANCE:
[1] General Visualization (Typical Representations)
- Query Patterns: "How are [CS concepts] typically visualized...", "What diagram format shows [phenomenon]..."
- Example:
Context: "Neural architecture diagram"
Query: "How are attention mechanisms typically diagrammed?"
Answer: "Usually shown as multi-head blocks with query/key/value matrices"

[2] Specific Retrieval (Unique Elements)
- Query Patterns: "Show a [diagram type] of [system] with [features]", "Display the [chart] comparing [metrics]..."
- Example:
Context: "Blockchain consensus flowchart"
Query: "Show a flowchart of PBFT consensus with prepare/commit phases"
Answer: "PBFT diagram showing client request → pre-prepare → prepare → commit"

[3] Data Patterns (Visualization Conventions)
- Query Patterns: "What chart type displays [metric] relationships?", "How is [phenomenon] visualized over [time]..."
- Example:
Context: "Latency boxplots across regions"
Query: "What visualization shows statistical latency distributions?"
Answer: "Boxplots comparing median latency and outliers per region"

[4] Process Flows (System Sequences)
- Query Patterns: "What flowchart elements show [process] steps?", "Display workflow for [system] with [steps]..."
- Example:
Context: "CI/CD pipeline diagram"
Query: "Show a deployment pipeline with testing and rollback steps"
Answer: "Diagram shows code commit → test suite → staging → production"

[5] Annotations (Symbol Conventions)
- Query Patterns: "What annotations indicate [function]?", "How are [elements] labeled in [diagram type]..."
- Example:
Context: "Red arrows in API diagram"
Query: "What annotations show request flows in API diagrams?"
Answer: "Arrows labeled with HTTP methods and endpoints"

[6] Comparative (Contrasting Approaches)
- Query Patterns: "Compare visualization of [A] vs [B]", "How do diagrams differ between [X] and [Y]..."
- Example:
Context: "CNN/Transformer comparison figure"
Query: "Compare layer representations in CNN vs transformer diagrams"
Answer: "CNNs show convolution kernels, transformers show attention matrices"

[7] Unsuitable (Non-CS/Unclear)

2. QUERY RULES:
- MUST be answerable from image summary alone
- Either:
  a) General: Usable across papers ("How are... typically")
  b) Specific: Unique to context ("Show [exact topic and context]"), but NOT detailed plot description
- NEVER assume prior article knowledge

3. OUTPUT FORMAT:
chosen_category: [1-7]
query: [precise visual question]
answer: [elements from summary]
text_sentence: "N/A"
image_sentence: [EXACT QUOTE FROM MOST FITTING CONTEXT]

Now analyze:"""

In [None]:
##### FINAL COMBINED PROMPTS

In [None]:
# Combined text-only prompt: lists the seven text categories and interpolates
# the matching per-category instruction under the same number. Slot 3 was
# previously empty although category 3 ("Compound Analysis") is selectable,
# leaving the model without instructions for it; it now carries
# compound_multipart_prompt.
TEXT_EVAL_PROMPT = f"""Analyze this COMPUTER SCIENCE text context and:

1. CATEGORY SELECTION - Choose the MOST appropriate text-based category:
1. Numerical/Quantitative - Metrics/trends (accuracy, FLOPs, energy)
2. Factual/Definitional - Formal CS definitions
3. Compound Analysis - Multi-component systems
4. Research Trends - Variable relationships
5. Compare/Contrast - Method tradeoffs
6. Limitation Identification - System constraints
7. Unsuitable - Non-CS or incomplete

2. INSTRUCTIONS - Use corresponding prompt:
1. {numerical_prompt}
2. {factual_prompt}
3. {compound_multipart_prompt}
4. {trend_analysis_prompt}
5. {compare_prompt}
6. {limitation_identification_prompt}
7. Unsuitable

3. GENERATION RULES:
- Strict CS focus (ML, systems, security)
- Queries must NOT require visual analysis
- Generate queries that are not too specific for the given article.
- Reject speculative or multi-modal questions

4. OUTPUT FORMAT:
chosen_category: [1-7]
query: [CS question answerable through text]
answer: [technical answer from text context]
text_sentence: [exact quote | "N/A"]      #MAKE SURE THIS IS A DIRECT QUOTE FROM THE CONTEXT
image_sentence: "N/A"

Example:
chosen_category: 4
query: How does model width impact training stability in transformers?
answer: Wider models show 38% lower gradient variance but require careful initialization.
text_sentence: "Width-1024 layers exhibit σ²=0.12 gradients vs σ²=0.19 for width-512 (Table 3)."   #MAKE SURE THIS IS A DIRECT QUOTE FROM THE CONTEXT
image_sentence: "N/A"

Now process:"""

# Combined image-only prompt: lists the seven visual categories and
# interpolates the matching per-category instruction under the same number.
# The worked example now contains every field of the required output format
# (`text_sentence` was missing), and the "SYTEM" typo in the rules is fixed.
IMAGE_EVAL_PROMPT = f"""Analyze COMPUTER SCIENCE image summaries and:

1. CATEGORY SELECTION - Choose visual analysis category:
1. Design Features - Architecture components
2. Visual ID - Diagram elements
3. Data Trends - Metric visualizations
4. Image Retrieval - Visualization types
5. Functional Flow - Process sequences
6. Annotation Analysis - Diagram symbology
7. Unsuitable - Non-CS or unclear

2. INSTRUCTIONS - Use corresponding prompt:
1. {feature_enumeration_prompt}
2. {visual_identification_prompt}
3. {data_interpretation_prompt}
4. {image_retrieval_prompt}
5. {functional_flow_prompt}
6. {annotation_prompt}
7. Unsuitable

3. GENERATION RULES:
- Purely visual analysis (no text references)
- Require explicit elements from image summaries
- Reject queries needing textual explanations
- MAKE SURE THE QUERY MENTIONS THE SPECIFIC SYSTEM/ ARCHITECTURE/ METRIC

4. OUTPUT FORMAT:
chosen_category: [1-7]
query: [visual analysis question]
answer: [elements from image summaries]
text_sentence: "N/A"
image_sentence: [exact image summary quote]      #MAKE SURE THIS IS A DIRECT QUOTE FROM THE CONTEXT

Example:
chosen_category: 2
query: What layers are shown in the neural architecture diagram?
answer: Input embedding layer, four transformer blocks, classification head.
text_sentence: "N/A"
image_sentence: "Diagram labels: Embedding Layer → Transformer Block ×4 → CLS Head"

Now process:"""

# Combined cross-modal prompt. The category list, the instruction list and
# the `chosen_category: [1-4]` range now agree with each other:
#   1 -> process_explanation_prompt
#   2 -> conceptual_explanation_prompt
#   3 -> conceptual_alignment_prompt (slot 3 was previously empty)
#   4 -> Unsuitable
# The former "Method Validation" entry was dropped because no instruction
# prompt exists for it and it pushed the list to five categories while the
# output format only allows 1-4. The worked example (category 3, text/visual
# alignment) matches this numbering.
CROSS_MODAL_EVAL_PROMPT = f"""Analyze BOTH text and images in this CS context and:

1. CATEGORY SELECTION - Choose integration category:
1. Process Explanation - Workflow integration
2. Conceptual Synthesis - Theory-visual links
3. Representation Alignment - Text-diagram consistency
4. Unsuitable - Single-modality sufficient

2. INSTRUCTIONS - Use corresponding prompt:
1. {process_explanation_prompt}
2. {conceptual_explanation_prompt}
3. {conceptual_alignment_prompt}
4. Unsuitable

3. GENERATION RULES:
- MUST require both modalities to answer
- Bridge text and visual elements
- Flag modality conflicts
- Reject questions answerable with one source

4. OUTPUT FORMAT:
chosen_category: [1-4]
query: [multi-modal CS question]
answer: [synthesis of both sources]
text_sentence: [relevant text quote | "N/A"]      #MAKE SURE THIS IS A DIRECT QUOTE FROM THE TEXT CONTEXT
image_sentence: [relevant image quote | "N/A"]     #MAKE SURE THIS IS A DIRECT QUOTE FROM THE IMAGE CONTEXT

Example:
chosen_category: 3
query: How do textual descriptions of attention mechanisms align with their visual representations?
answer: Text describes parallel attention heads processing Q/K/V vectors (Vaswani 2017), while diagrams show eight-headed blocks with interleaved connections.
text_sentence: "Multi-head attention enables parallel processing of different representation subspaces."
image_sentence: "Figure 2: 8 attention heads with cross-connecting weights."

Now process:"""

### Generating Query & Answer

In [None]:
# Pick a random document key and fetch the contexts stored for it.
rndm_key = get_rndm_key()
all_contexts = get_contexts_for_key(rndm_key)
# First entry under 'full_text' is used as the article text; everything under
# 'image_descriptions' is used as the image summaries in the cells below.
rndm_full_text = all_contexts['full_text'][0]
rndm_sumaries = all_contexts['image_descriptions']

In [None]:
# Text-only: generate a query/answer pair from the article text alone
# (no image descriptions are passed).
gene_query_t = generate_test_query(
    instruction=TEXT_EVAL_PROMPT,
    text_context=rndm_full_text,
    #image_descriptions=rndm_sumaries
)
print(gene_query_t)

In [None]:
# Summaries-only: generate a query/answer pair from the image descriptions
# alone (no article text is passed), using the stand-alone image prompt.
gene_query_i = generate_test_query(
    instruction=NEW_IMAGE_EVAL_PROMPT,
    #text_context=rndm_full_text,
    image_descriptions=rndm_sumaries
)
print(gene_query_i)

In [None]:
# Cross-modal: generate a query/answer pair from both the article text and
# the image descriptions.
gene_query = generate_test_query(
    instruction=CROSS_MODAL_EVAL_PROMPT,
    text_context=rndm_full_text,
    image_descriptions=rndm_sumaries
)
print(gene_query)

In [None]:
# Attach ground-truth (GT) elements to the generated queries. This step could
# be replaced by a direct GPT-4o call that adds the GT right away, instead of
# first filtering the response of the first call.
# NOTE(review): `one_new_df` and `augment_with_ground_truth` are not defined in
# the visible part of the notebook — confirm where they come from.
augmented_df_one = augment_with_ground_truth(all_chunks_df, captions,one_new_df)   
#augmented_df_one

In [None]:
# Tag every row with its query type; set this to 'text-only', 'image-only' or
# 'cross-modal' depending on which prompt produced the rows.
augmented_df_one['type'] = 'cross-modal'    #'text-only' #'image-only' #'cross-modal'

In [None]:
#second call to get optimal GT elements 

In [None]:
# Second model pass over the query/answer rows.
# NOTE(review): `df_new` is not defined in the visible part of the notebook —
# confirm which DataFrame is meant (possibly `augmented_df_one`).
final_df = process_dataframe_final(df_new)


###  code: functions and helpers

In [None]:
# Functions that must be defined (run these cells first) before the rest of the notebook can be executed:

In [None]:

import json
import time
from typing import Optional, List

def generate_test_query(
    instruction: str,
    text_context: Optional[str] = None,
    image_descriptions: Optional[List[str]] = None
) -> str:
    """Ask the chat model for a test query and its reference answer.

    A single user message is assembled from a fixed preamble, the given
    instruction, and whichever contexts are supplied.

    Args:
        instruction: Prompt describing which kind of query to generate.
        text_context: Full article text; left out of the message when falsy.
        image_descriptions: Image descriptions of the article; when non-empty
            they are appended as a list numbered from 1.

    Returns:
        The content of the first completion choice returned by the model.
    """
    prompt_lines = [
        "INSTRUCTION: Generate a test query and the according reference answer for a multi-modal RAG system FOR THE DOMAIN OF COMPUTER SCIENCE based on:",
        f"- User instruction: {instruction}",
        "CONTEXT GROUND TRUTH:",
    ]

    if text_context:
        prompt_lines.append(f"Full article text: {text_context}")

    if image_descriptions:
        prompt_lines.append("Image descriptions from article:")
        prompt_lines.extend(
            f"{number}. {description}"
            for number, description in enumerate(image_descriptions, 1)
        )

    # The whole prompt travels as one text part of one user message.
    user_message = {
        "role": "user",
        "content": [{"type": "text", "text": "\n".join(prompt_lines)}],
    }

    completion = client.chat.completions.create(
        model=MODEL,
        messages=[user_message],
        max_tokens=1000,
        temperature=0.7,
    )
    return completion.choices[0].message.content


def ask_4o(row):
    """Ask the model which context items best support a query/answer pair.

    Builds a prompt from the query, the reference answer and the candidate
    contexts of the source article, then delegates to ``call_4o_most_appr``.

    Fixes over the previous version:
      * The response-format hint was built with ``json.dumps`` on the type
        objects ``int``/``str``, which always raised ``TypeError``; it is now
        a literal template string.
      * ``str.format`` was applied to the fully assembled prompt, so any
        ``{`` or ``}`` inside an article chunk or image summary raised
        ``KeyError``/``ValueError``/``IndexError``. Query and answer are now
        interpolated directly and context text is never re-formatted.
      * Type labels are normalised (case, surrounding whitespace, and the
        aliases 'text-only' / 'image-only' / 'crossmodal'), mirroring
        ``process_dataframe_final``.

    Args:
        row: Mapping-like row providing 'key', 'query', 'answer' and 'type'.

    Returns:
        Whatever ``call_4o_most_appr`` returns for the assembled prompt.

    Raises:
        ValueError: If the row's type is not a text, image or cross-modal label.
    """
    cur_key = row['key']
    cur_query = row['query']
    cur_answer = row['answer']

    type_aliases = {'text-only': 'text', 'image-only': 'image', 'crossmodal': 'cross-modal'}
    normalized = str(row['type']).lower().strip()
    type_ = type_aliases.get(normalized, normalized)
    if type_ not in ('text', 'image', 'cross-modal'):
        raise ValueError(f"Invalid type: {row['type']}")

    both_contexts = get_sorted_contexts_for_key(cur_key)
    text_chunks = both_contexts.get('text_chunks', [])
    img_summaries = both_contexts.get('image_contexts', [])

    sections = [
        "Analyze this computer science query and reference answer which will be used\n"
        "for testing a retrieval system. You MUST follow these rules:\n"
        "1. Be strictly factual and objective\n"
        "2. Only select items directly supporting both query AND answer\n"
        "3. Never invent or guess information\n"
        "4. If multiple candidates exist, choose the BEST match using cross-modal understanding\n"
        "\n"
        f"QUERY: {cur_query}\n"
        f"ANSWER: {cur_answer}"
    ]

    if type_ in ('text', 'cross-modal'):
        text_context_str = "\n".join(f"{idx}: {chunk}" for idx, chunk in enumerate(text_chunks))
        sections.append("TEXT CHUNKS FROM SOURCE ARTICLE:\n" + text_context_str)

    if type_ in ('image', 'cross-modal'):
        img_context_str = "\n".join(
            f"Path: {img['image_path']}\nSummary: {img['summary']}" for img in img_summaries
        )
        sections.append("IMAGE SUMMARIES FROM SOURCE ARTICLE:\n" + img_context_str)

    # Only the example constant for the active type is looked up.
    # NOTE(review): the example is inserted verbatim; if the EX_OUTPUT_*
    # constants escape braces as '{{' / '}}' for str.format, unescape them there.
    if type_ == 'image':
        example_output = EX_OUTPUT_image
    elif type_ == 'text':
        example_output = EX_OUTPUT_text
    else:
        example_output = EX_OUTPUT_both
    sections.append(str(example_output))

    sections.append(
        "YOUR RESPONSE MUST USE EXACTLY THIS FORMAT:\n"
        '{"choices": [{"text_chunk_index": <int>, "image_path": "<str>"}]}'
        "\nOnly include actually matched items!"
    )

    return call_4o_most_appr("\n\n".join(sections))

def call_4o_most_appr(prompt):
    """Send ``prompt`` to the chat model and parse the reply as JSON.

    Args:
        prompt: Complete prompt text, sent as a single user message.

    Returns:
        The decoded JSON value on success. If the reply is missing or is not
        valid JSON, a dict ``{"error": "Invalid JSON response", "raw": <reply>}``
        is returned instead of raising.
    """
    response = client.chat.completions.create(
        model=MODEL,
        messages=[{"role": "user", "content": [{"type": "text", "text": prompt}]}],
        max_tokens=300,
        temperature=0.1
    )

    raw = response.choices[0].message.content
    # The content can be None (e.g. a filtered or empty completion);
    # json.loads(None) would raise an uncaught TypeError.
    if raw is None:
        return {"error": "Invalid JSON response", "raw": raw}

    # Models often wrap JSON in a markdown fence (``` or ```json); unwrap it
    # so an otherwise valid answer is not rejected.
    candidate = raw.strip()
    if candidate.startswith("```"):
        candidate = candidate.split("\n", 1)[1] if "\n" in candidate else ""
        candidate = candidate.rsplit("```", 1)[0]

    try:
        return json.loads(candidate)
    except json.JSONDecodeError:
        return {"error": "Invalid JSON response", "raw": raw}

In [None]:
import pandas as pd
import re
from PIL import Image
import random

# Load the subset file and collect its distinct `doc_id` values as a plain
# list; get_rndm_key() samples from this list.
df_ten = pd.read_csv("data/tenk_subset.csv")   #file with keys from a subset

tenk_keys = df_ten["doc_id"].unique().tolist()


def get_rndm_key():
    """Return one key drawn at random from the module-level `tenk_keys` list."""
    chosen_key = random.choice(tenk_keys)
    return chosen_key

# Get the optimal ground-truth (GT) elements and compare them with the previously selected ones: if they match, keep the row; if not, delete it or choose a better-fitting element yourself.
def process_dataframe_final(df):
    """Ask the chat model which stored context best fits each query/answer row.

    The ``type`` column is normalised into ``clean_type`` and only rows of
    type ``text``, ``image`` or ``cross-modal`` are processed. For every such
    row a prompt listing the document's images and/or text chunks is sent to
    the model, and the ``image path:`` / ``chunk index:`` lines of the reply
    are extracted.

    NOTE: a later definition with the same name in this file replaces this
    one when the whole file is executed.

    Args:
        df: DataFrame with ``type``, ``key``, ``query`` and ``answer`` columns.

    Returns:
        The processed rows (index reset) with ``raw_response``,
        ``image_path`` and ``chunk_index`` columns appended. Rows whose
        request or parsing raised carry ``"ERROR: ..."`` as ``raw_response``.
    """
    type_aliases = {'text-only': 'text', 'image-only': 'image', 'crossmodal': 'cross-modal'}
    supported_types = ['text', 'image', 'cross-modal']

    df = df.copy()
    normalized_type = df['type'].str.lower().str.strip()
    df['clean_type'] = normalized_type.replace(type_aliases)
    valid_df = df[df['clean_type'].isin(supported_types)].copy()

    def render_images(context):
        # One "- <file>: <summary>" line per image of the document.
        entries = context.get('image_contexts', [])
        return "\n".join(
            f"- {img['image_path']}: {img.get('summary', '')}" for img in entries
        )

    def render_chunks(context):
        # One "<position>: <chunk>" line per text chunk of the document.
        entries = context.get('text_chunks', [])
        return "\n".join(f"{idx}: {chunk}" for idx, chunk in enumerate(entries))

    def get_prompt(row):
        context = get_sorted_contexts_for_key(row['key'])
        base = f"""Analyze this query and answer pair from computer science research papers and
FIND THE MOST APPROPRIATE CONTEXT TO ANSWER THE QUESTION.
Query: {row['query']}
Answer: {row['answer']}

You MUST respond in EXACTLY this format:
"""
        row_type = row['clean_type']

        if row_type == 'image':
            images = render_images(context)
            return f"""{base}image path: [exact_filename_from_below]
Available images:
{images}"""

        if row_type == 'text':
            chunks = render_chunks(context)
            return f"""{base}chunk index: [number_from_below]
Available chunks:
{chunks}"""

        # Remaining case: cross-modal rows need both listings.
        images = render_images(context)
        chunks = render_chunks(context)
        return f"""{base}image path: [exact_filename]
chunk index: [number]

Available images:
{images}

Available chunks:
{chunks}"""

    results = []
    for _, row in tqdm(valid_df.iterrows(), total=len(valid_df)):
        try:
            completion = client.chat.completions.create(
                model=MODEL,
                messages=[{"role": "user", "content": get_prompt(row)}],
                temperature=0.1,
                max_tokens=200
            )
            response = completion.choices[0].message.content

            # The last line mentioning a marker wins; the value is whatever
            # follows the final colon of that (lower-cased) line.
            extracted = {'image path:': None, 'chunk index:': None}
            for raw_line in response.split('\n'):
                lowered = raw_line.lower().strip()
                for marker in extracted:
                    if marker in lowered:
                        extracted[marker] = lowered.split(':')[-1].strip()

            chunk_text = extracted['chunk index:']
            if chunk_text and chunk_text.isdigit():
                chunk_number = int(chunk_text)
            else:
                chunk_number = None

            record = {
                'raw_response': response,
                'image_path': extracted['image path:'],
                'chunk_index': chunk_number
            }
        except Exception as e:
            # Keep one result per row so the column-wise concat stays aligned.
            record = {
                'raw_response': f"ERROR: {str(e)}",
                'image_path': None,
                'chunk_index': None
            }
        results.append(record)

    result_frame = pd.DataFrame(results)
    return pd.concat([valid_df.reset_index(drop=True), result_frame], axis=1)

def fix_extracted_values_fast(df):
    """Validate extracted ``image_path`` / ``chunk_index`` values per document.

    For every distinct ``key`` the document context is fetched once via
    ``get_sorted_contexts_for_key``. Image paths are matched
    case-insensitively against the document's image file names and replaced
    by the stored spelling (unknown paths become null). Chunk indices are
    reduced to their first run of digits and nulled when they fall outside
    ``0 .. number_of_chunks - 1``.

    Args:
        df: DataFrame with ``key``, ``image_path`` and ``chunk_index`` columns.

    Returns:
        A copy of ``df`` with cleaned ``image_path`` (object) and
        ``chunk_index`` (nullable ``Int64``) columns.
    """
    path_lookups = {}
    max_chunk_idx = {}
    for key in df['key'].unique():
        context = get_sorted_contexts_for_key(key)
        image_paths = [img['image_path'] for img in context.get('image_contexts', [])]
        path_lookups[key] = {path.lower(): path for path in image_paths}
        max_chunk_idx[key] = len(context.get('text_chunks', [])) - 1

    df = df.copy()

    # Plain Python lookup instead of groupby + .str: the .str accessor raises
    # when the column holds no strings at all (e.g. an all-null float column).
    df['image_path'] = [
        path_lookups.get(key, {}).get(str(path).lower().strip()) if pd.notna(path) else None
        for key, path in zip(df['key'], df['image_path'])
    ]

    # Go through to_numeric before the nullable-integer cast; casting the
    # extracted digit strings straight to Int64 is not supported by every
    # pandas version.
    digits = df['chunk_index'].astype(str).str.extract(r'(\d+)', expand=False)
    df['chunk_index'] = pd.to_numeric(digits, errors='coerce').astype('Int64')

    # Keep the per-row upper bound in a local Series rather than a temporary
    # 'max_chunk' column, so a caller column of that name is never clobbered.
    max_chunk = df['key'].map(max_chunk_idx).fillna(-1)
    df['chunk_index'] = df['chunk_index'].mask(
        (df['chunk_index'].lt(0)) |
        (df['chunk_index'].gt(max_chunk)),
        pd.NA
    )

    return df

def fix_extracted_values_final(df):
    """Validate extracted ``image_path`` / ``chunk_index`` values per document.

    Image paths are matched case-insensitively (after stripping whitespace)
    against the image file names of the row's document and replaced by the
    stored spelling; anything that does not match becomes ``None``. Chunk
    indices are reduced to their first run of digits and nulled unless
    ``0 <= index < number_of_chunks`` for the row's document.

    Either column is optional; a missing column is left untouched.

    Args:
        df: DataFrame with a ``key`` column and optionally ``image_path``
            and ``chunk_index``.

    Returns:
        A copy of ``df`` with the cleaned columns; ``chunk_index`` is
        returned as nullable ``Int64``.
    """
    # Fetch every document context once and reuse it for both checks.
    contexts = {key: get_sorted_contexts_for_key(key) for key in df['key'].unique()}

    path_maps = {}
    chunk_counts = {}
    for key, context in contexts.items():
        lookup = {}
        for img in context.get('image_contexts', []):
            # The contexts built in this file store the file name under
            # 'image_path'; 'path' is still accepted as a fallback. Reading
            # only 'path' left the lookup empty and nulled every image path.
            stored = img.get('image_path', img.get('path'))
            if isinstance(stored, str):
                lookup[stored.lower()] = stored
        path_maps[key] = lookup
        chunk_counts[key] = len(context.get('text_chunks', []))

    df = df.copy()
    if 'image_path' in df:
        df['image_path'] = [
            path_maps.get(key, {}).get(str(path).lower().strip())
            for key, path in zip(df['key'], df['image_path'])
        ]

    if 'chunk_index' in df:
        digits = df['chunk_index'].astype(str).str.extract(r'(\d+)', expand=False)
        chunk_index = pd.to_numeric(digits, errors='coerce').astype('Int64')

        # Rows whose key has no known chunk count get a limit of 0, i.e. no
        # index is valid for them.
        limits = df['key'].map(chunk_counts).fillna(0)
        in_range = (chunk_index.ge(0) & chunk_index.lt(limits)).fillna(False).astype(bool)
        df['chunk_index'] = chunk_index.where(in_range, pd.NA)

    return df


def process_dataframe_final(df):
    """Ask the chat model which stored context best fits each query/answer row.

    The ``type`` column is normalised into ``clean_type`` (``text-only``,
    ``image-only`` and ``crossmodal`` are mapped to ``text``, ``image`` and
    ``cross-modal``); rows of any other type are dropped. For each remaining
    row a prompt listing the document's images and/or text chunks is sent to
    the model and the ``image path:`` / ``chunk index:`` lines of the reply
    are extracted.

    Args:
        df: DataFrame with ``type``, ``key``, ``query`` and ``answer`` columns.

    Returns:
        The processed rows (index reset) with three columns appended:
        ``raw_response`` (model reply, or ``"ERROR: ..."`` when the request or
        parsing raised), ``image_path`` (as written by the model, original
        case preserved, or ``None``) and ``chunk_index`` (int or ``None``).
    """
    df = df.copy()
    df['clean_type'] = df['type'].str.lower().str.strip() \
        .replace({'text-only': 'text', 'image-only': 'image', 'crossmodal': 'cross-modal'})

    valid_df = df[df['clean_type'].isin(['text', 'image', 'cross-modal'])].copy()

    def format_images(context):
        # One "- <file>: <summary>" line per image of the document.
        return "\n".join([f"- {img['image_path']}: {img.get('summary', '')}"
                          for img in context.get('image_contexts', [])])

    def format_chunks(context):
        # One "<position>: <chunk>" line per text chunk.
        # NOTE(review): each chunk is rendered with str(); if chunks are dicts
        # carrying their own 'chunk_index', the prompt shows two numbers per
        # line - confirm which one downstream code expects.
        return "\n".join([f"{idx}: {chunk}"
                          for idx, chunk in enumerate(context.get('text_chunks', []))])

    def get_prompt(row):
        context = get_sorted_contexts_for_key(row['key'])
        base = f"""Analyze this query and answer pair from computer science research papers and
FIND THE MOST APPROPRIATE CONTEXT TO ANSWER THE QUESTION.
Query: {row['query']}
Answer: {row['answer']}

You MUST respond in EXACTLY this format:
"""
        if row['clean_type'] == 'image':
            images = format_images(context)
            return f"""{base}image path: [exact_filename_from_below]
Available images:
{images}"""

        elif row['clean_type'] == 'text':
            chunks = format_chunks(context)
            return f"""{base}chunk index: [number_from_below]
Available chunks:
{chunks}"""

        else:
            images = format_images(context)
            chunks = format_chunks(context)
            return f"""{base}image path: [exact_filename]
chunk index: [number]

Available images:
{images}

Available chunks:
{chunks}"""

    def extract_value(text, label):
        # Return what follows "<label>:" on the last line that contains it.
        # The label is matched case-insensitively but the value keeps its
        # original case (file names are case-sensitive), and it is taken
        # after the label rather than after the last colon on the line.
        label_pattern = re.compile(rf'{re.escape(label)}\s*:\s*(.*)', re.IGNORECASE)
        value = None
        for line in text.split('\n'):
            found = label_pattern.search(line)
            if found:
                # Drop the [ ] placeholder brackets, quotes or backticks the
                # model sometimes echoes around the value.
                cleaned = found.group(1).strip().strip('[]`"\'').strip()
                value = cleaned or None
        return value

    results = []
    for idx, row in tqdm(valid_df.iterrows(), total=len(valid_df)):
        try:
            response = client.chat.completions.create(
                model=MODEL,
                messages=[{"role": "user", "content": get_prompt(row)}],
                temperature=0.1,
                max_tokens=200
            ).choices[0].message.content

            reply_text = response or ''
            image_path = extract_value(reply_text, 'image path')
            chunk_index = extract_value(reply_text, 'chunk index')

            results.append({
                'raw_response': response,
                'image_path': image_path,
                'chunk_index': int(chunk_index) if chunk_index and chunk_index.isdecimal() else None
            })

        except Exception as e:
            # Best effort per row: record the failure and continue, so the
            # results stay aligned with valid_df for the concat below.
            results.append({
                'raw_response': f"ERROR: {str(e)}",
                'image_path': None,
                'chunk_index': None
            })

    return pd.concat([
        valid_df.reset_index(drop=True),
        pd.DataFrame(results)
    ], axis=1)

#this includes the option to manually input info
def parse_response_to_dataframe(response_str, key, existing_df=None):
    """Parse generated query/answer responses into a DataFrame.

    Each response is expected to contain, in this order, the labelled fields
    ``chosen category``, ``query``, ``answer``, ``text_sentence`` and
    ``image_sentence``. When a response does not match, the fields are asked
    for interactively; when validation fails, the user is asked once to
    correct the entry.

    NOTE: a later definition with the same name in this file replaces this
    one when the whole file is executed.

    Args:
        response_str: One response string or a sequence of them.
        key: One key string or a sequence of keys, paired with the responses.
        existing_df: Optional DataFrame to append the parsed rows to.

    Returns:
        DataFrame with columns ``chosen_category``, ``query``, ``answer``,
        ``text_sentence``, ``image_sentence`` and ``key``; preceded by the
        rows of ``existing_df`` when it is given.

    Raises:
        ValueError: If responses and keys differ in number, or
            ``existing_df`` lacks a required column.
        TypeError: If ``existing_df`` is not a DataFrame.
    """
    responses = response_str if not isinstance(response_str, str) else [response_str]
    keys = key if not isinstance(key, str) else [key]

    if len(responses) != len(keys):
        raise ValueError("Number of responses and keys must match")

    pattern = r"""
        chosen\s*category:\s*(?P<chosen_category>\d+).*?
        query:\s*(?P<query>.*?)(?=\s*answer:|$)
        .*?answer:\s*(?P<answer>.*?)(?=\s*text_sentence:|$)
        .*?text_sentence:\s*(?P<text_sentence>.*?)(?=\s*image_sentence:|$)
        .*?image_sentence:\s*(?P<image_sentence>.*?)(?=\s*\w+:|$)
    """.replace('\n', '')
    matcher = re.compile(pattern, re.DOTALL | re.IGNORECASE | re.VERBOSE)

    column_order = [
        'chosen_category', 'query', 'answer',
        'text_sentence', 'image_sentence', 'key'
    ]

    parsed_data = []

    for resp, k in zip(responses, keys):
        # Collapse blank lines and indentation following a newline.
        clean_resp = re.sub(r'(\n\s*)+', '\n', resp.strip())
        match = matcher.search(clean_resp)

        if match is None:
            print(f"\nFailed to parse response. Please enter fields manually for key: {k}")
            print("Original text snippet:", clean_resp[:200] + "...\n")

            entry = {
                'chosen_category': input("Category (1-9): ").strip(),
                'query': input("Query: ").strip(),
                'answer': input("Answer: ").strip(),
                'text_sentence': input("Text sentence (or 'N/A'): ").strip(),
                'image_sentence': input("Image sentence (or 'N/A'): ").strip(),
                'key': k
            }
        else:
            entry = {}
            for name, captured in match.groupdict().items():
                entry[name] = captured.strip() if captured else 'N/A'
            entry['key'] = k

        try:
            # Remove one leading and one trailing quote character, if present.
            for field in ('text_sentence', 'image_sentence'):
                entry[field] = re.sub(r'^["\']|["\']$', '', entry[field])

            category = entry['chosen_category']
            if not category.isdigit():
                raise ValueError("Category must be a number 1-9")
            if not 1 <= int(category) <= 9:
                raise ValueError("Category must be between 1-9")
            if not (entry['query'] and entry['answer']):
                raise ValueError("Query and Answer are required")
            if entry['text_sentence'] == 'N/A' and entry['image_sentence'] == 'N/A':
                raise ValueError("At least one sentence (text or image) required")

        except ValueError as e:
            print(f"\nValidation error: {e}")
            print("Please fix the entry:")
            # Empty input keeps the current value; the corrected entry is
            # accepted as entered.
            entry = {
                'chosen_category': input(f"Category [current: {entry['chosen_category']}]: ") or entry['chosen_category'],
                'query': input(f"Query [current: {entry['query']}]: ") or entry['query'],
                'answer': input(f"Answer [current: {entry['answer']}]: ") or entry['answer'],
                'text_sentence': input(f"Text sentence [current: {entry['text_sentence']}]: ") or entry['text_sentence'],
                'image_sentence': input(f"Image sentence [current: {entry['image_sentence']}]: ") or entry['image_sentence'],
                'key': k
            }

        parsed_data.append(entry)

    new_df = pd.DataFrame(parsed_data)[column_order]
    new_df['chosen_category'] = pd.to_numeric(new_df['chosen_category'], errors='coerce')

    if existing_df is None:
        return new_df

    if not isinstance(existing_df, pd.DataFrame):
        raise TypeError("existing_df must be a pandas DataFrame")

    required_columns = set(column_order)
    missing_cols = required_columns - set(existing_df.columns)
    if missing_cols:
        raise ValueError(f"existing_df missing columns: {missing_cols}")

    return pd.concat([existing_df, new_df], ignore_index=True)



def parse_response_to_dataframe(response_str, key, existing_df=None):
    """
    Robust parser that handles markdown formatting and various field arrangements

    Each response is split on its field labels (``chosen_category`` - also
    accepted as ``chosen category`` -, ``query``, ``answer``,
    ``text_sentence``, ``image_sentence``), which may appear in any order and
    may be decorated with ``*`` / ``#``. Hard-wrapped lines inside a value
    are joined with a space. If a required field is missing, or validation
    fails, the user is prompted interactively until the entry is valid.

    Args:
        response_str: One response string or a sequence of them.
        key: One key string or a sequence of keys, paired with the responses.
        existing_df: Optional DataFrame to append the parsed rows to.

    Returns:
        DataFrame with columns ``chosen_category`` (numeric), ``query``,
        ``answer``, ``text_sentence``, ``image_sentence`` and ``key``. When
        ``existing_df`` is given, its rows come first and exact duplicate
        rows are dropped.

    Raises:
        ValueError: If the number of responses and keys differ.
    """
    responses = [response_str] if isinstance(response_str, str) else response_str
    keys = [key] if isinstance(key, str) else key

    if len(responses) != len(keys):
        raise ValueError("Number of responses and keys must match")

    label_alternatives = r'chosen[ _]category|query|answer|text_sentence|image_sentence'
    field_pattern = re.compile(
        rf'(?:^|\n)\s*[*#]*({label_alternatives})[*#]*[\s:]+',
        re.IGNORECASE
    )
    # Joins a hard-wrapped line onto the previous one, but never when the
    # next line starts with a field label: gluing "3\nquery: ..." together
    # would hide the label from field_pattern, which needs it at a line start.
    unwrap_pattern = re.compile(
        rf'(?<=\w)\n(?!(?:{label_alternatives})[*#]*[\s:])(?=\w)',
        re.IGNORECASE
    )

    parsed_data = []

    for resp, k in zip(responses, keys):
        clean_resp = re.sub(r'\*\*|#{2,}', '', resp)
        clean_resp = re.sub(r'\n\s+', '\n', clean_resp)
        clean_resp = unwrap_pattern.sub(' ', clean_resp)
        clean_resp = clean_resp.strip()

        # split() with one capture group yields
        # [preamble, label, value, label, value, ...].
        pieces = field_pattern.split(clean_resp)
        fields = {}
        for raw_label, raw_value in zip(pieces[1::2], pieces[2::2]):
            raw_value = raw_value.strip()
            if not raw_value:
                continue
            name = raw_label.lower().replace(' ', '_')
            # Cut the value at the first following line that looks like
            # another "label:" line, then drop surrounding quotes.
            value = re.split(r'\n\s*(?=\S+:)', raw_value)[0]
            value = re.sub(r'^["\']|["\']$', '', value.strip())
            fields[name] = value

        required = ['chosen_category', 'query', 'answer']
        if not all(f in fields for f in required):
            print(f"\nMissing fields in response for key {k}. Original snippet:")
            print(clean_resp[:200] + "...")
            print("Please enter missing fields manually:")

            fields = {
                'chosen_category': input(f"Category (1-9) [detected: {fields.get('chosen_category')}]: ")
                                or fields.get('chosen_category'),
                'query': input(f"Query [detected: {fields.get('query')}]: ")
                         or fields.get('query'),
                'answer': input(f"Answer [detected: {fields.get('answer')}]: ")
                          or fields.get('answer'),
                'text_sentence': input(f"Text sentence [detected: {fields.get('text_sentence')}]: ")
                             or fields.get('text_sentence', 'N/A'),
                'image_sentence': input(f"Image sentence [detected: {fields.get('image_sentence')}]: ")
                              or fields.get('image_sentence', 'N/A'),
            }

        # After manual entry a field may be present with value None, so use
        # "or" fallbacks; dict.get defaults only cover absent keys.
        entry = {
            'chosen_category': str(fields.get('chosen_category') or '').strip(),
            'query': (fields.get('query') or '').replace('\n', ' ').strip(),
            'answer': (fields.get('answer') or '').replace('\n', ' ').strip(),
            'text_sentence': fields.get('text_sentence') or 'N/A',
            'image_sentence': fields.get('image_sentence') or 'N/A',
            'key': k
        }

        # Re-prompt until the entry passes validation.
        while True:
            try:
                if not entry['chosen_category'].isdigit():
                    raise ValueError("Category must be a number")
                if int(entry['chosen_category']) not in range(1,10):
                    raise ValueError("Category must be 1-9")
                if len(entry['query']) < 10 or len(entry['answer']) < 10:
                    raise ValueError("Query/Answer too short")
                break
            except ValueError as e:
                print(f"Validation error: {e}")
                entry['chosen_category'] = input(f"Category (1-9) [current: {entry['chosen_category']}]: ") or entry['chosen_category']
                entry['query'] = input(f"Query [current: {entry['query']}]: ") or entry['query']
                entry['answer'] = input(f"Answer [current: {entry['answer']}]: ") or entry['answer']

        parsed_data.append(entry)

    new_df = pd.DataFrame(parsed_data)[[
        'chosen_category', 'query', 'answer',
        'text_sentence', 'image_sentence', 'key'
    ]]
    new_df['chosen_category'] = pd.to_numeric(
        new_df['chosen_category'],
        errors='coerce',
        downcast='integer'
    )

    if existing_df is not None:
        return pd.concat([existing_df, new_df], ignore_index=True).drop_duplicates()

    return new_df


def get_contexts_for_key(key):
    """Collect the full text(s) and image summaries stored for ``key``.

    Args:
        key: Value matched against ``df_full_texts['key']`` and
            ``image_summary_df['doc_id']``.

    Returns:
        Dict with ``'full_text'`` (list of ``full_text`` values) and
        ``'image_descriptions'`` (list of ``image_summary`` values); either
        list is empty when nothing matches.
    """
    text_mask = df_full_texts['key'] == key
    image_mask = image_summary_df['doc_id'] == key

    matching_texts = df_full_texts.loc[text_mask, 'full_text']
    matching_summaries = image_summary_df.loc[image_mask, 'image_summary']

    contexts = {}
    contexts['full_text'] = matching_texts.tolist()
    contexts['image_descriptions'] = matching_summaries.tolist()
    return contexts


def get_sorted_contexts_for_key(key):
    """Return the text chunks and image summaries stored for one document.

    Args:
        key: Document id, matched against the ``doc_id`` column of
            ``all_chunks_df`` and ``image_summary_df``.

    Returns:
        Dict with two lists:

        * ``'text_chunks'``: ``{'chunk_index', 'chunk_content'}`` dicts,
          ordered by ``chunk_index``.
        * ``'image_contexts'``: ``{'image_path', 'summary'}`` dicts, where
          ``image_path`` is the file name (last ``/``-separated component)
          of ``original_image_path``.

        Either list is empty when the document has no matching rows.
    """
    chunk_rows = all_chunks_df[all_chunks_df['doc_id'] == key].sort_values('chunk_index')
    image_rows = image_summary_df[image_summary_df['doc_id'] == key]

    # Build the records with plain iteration instead of
    # DataFrame.apply(axis=1): on an empty selection apply() can hand back a
    # DataFrame, and the following .tolist() then raises AttributeError for
    # documents that have no images.
    text_chunks = [
        {'chunk_index': chunk_index, 'chunk_content': content}
        for chunk_index, content in zip(chunk_rows['chunk_index'], chunk_rows['text_content'])
    ]

    image_contexts = [
        {'image_path': original_path.split("/")[-1], 'summary': summary}
        for original_path, summary in zip(image_rows['original_image_path'],
                                          image_rows['image_summary'])
    ]

    return {
        'text_chunks': text_chunks,
        'image_contexts': image_contexts
    }



def process_image_path(path: str) -> str:
    """Open an image in the default viewer and return a placeholder string.

    Args:
        path: Image location. Unless it already starts with ``data/``, only
            its file name is kept and looked up under ``data/399k_imgs/``.

    Returns:
        The single-space string ``" "``.
    """
    if not path.startswith('data/'):
        path = f"data/399k_imgs/{path.split('/')[-1]}"
    # Use the image as a context manager so the underlying file handle is
    # closed once the viewer has been launched.
    with Image.open(path) as image:
        image.show()
    return " "

def augment_with_ground_truth(df1: pd.DataFrame,
                             df2: pd.DataFrame,
                             df3: pd.DataFrame) -> pd.DataFrame:
    """Attach ground-truth chunk positions and image paths to query rows.

    A text chunk matches a row when the first 25 characters of the row's
    ``text_sentence`` occur in the chunk; an image matches when the first
    25 characters of ``image_sentence`` occur in its caption. Only chunks
    and images of the row's own document (``key``) are considered.

    Args:
        df1: Text chunks with ``doc_id`` and ``text_content`` columns.
        df2: Images with ``image_path`` and ``generated_caption`` columns.
            The document key is taken from the file name: the text between
            the last ``/`` and the first ``_`` after it.
        df3: Query rows with ``key``, ``text_sentence`` and
            ``image_sentence`` columns.

    Returns:
        A copy of ``df3`` with two list-valued columns: ``gt_chunk_index``
        (positions of matching chunks within the document's chunk list, in
        ``df1`` row order) and ``gt_summary`` (matching image paths).
    """
    prefix_length = 25

    df2 = df2.copy()
    df2['doc_key'] = df2['image_path'].str.extract(r'.*/(.*?)_', expand=False)

    text_lookup = df1.groupby('doc_id')['text_content'].agg(list).to_dict()

    # Group (path, caption) pairs per document with a plain loop; rows whose
    # key could not be extracted are skipped, as groupby would do.
    image_lookup = {}
    for doc_key, path, caption in zip(df2['doc_key'], df2['image_path'],
                                      df2['generated_caption']):
        if pd.notna(doc_key):
            image_lookup.setdefault(doc_key, []).append((path, caption))

    def usable_prefix(sentence):
        # Return the search prefix, or None when there is nothing to search
        # for. A blank sentence is treated like 'N/A': an empty prefix would
        # otherwise be "found" in every chunk and caption.
        if not isinstance(sentence, str) or sentence == 'N/A' or not sentence.strip():
            return None
        return sentence[:prefix_length]

    def find_text_matches(row):
        prefix = usable_prefix(row['text_sentence'])
        if prefix is None:
            return []

        chunks = text_lookup.get(row['key'], [])
        # Missing (NaN) chunk text cannot contain the prefix; skipping it
        # avoids a TypeError from the "in" test.
        return [
            idx for idx, content in enumerate(chunks)
            if isinstance(content, str) and prefix in content
        ]

    def find_image_matches(row):
        prefix = usable_prefix(row['image_sentence'])
        if prefix is None:
            return []

        images = image_lookup.get(row['key'], [])
        return [
            f"{path}"
            for path, caption in images
            if isinstance(caption, str) and prefix in caption
        ]

    df3 = df3.copy()
    df3['gt_chunk_index'] = df3.apply(find_text_matches, axis=1)
    df3['gt_summary'] = df3.apply(find_image_matches, axis=1)

    return df3

def enrich_with_context(df):
    """Add the referenced chunk text and image summary to every row.

    For each row, the document context of ``key`` is looked up with
    ``get_sorted_contexts_for_key``. If the row has a ``chunk_index``, the
    chunk with that index supplies ``text_content``; if it has an
    ``image_path``, the image with that exact path supplies
    ``image_summary``. Rows without a match keep ``None``.

    Args:
        df: DataFrame with a ``key`` column and optionally ``chunk_index``
            and ``image_path`` columns.

    Returns:
        A copy of ``df`` with ``text_content``, ``image_summary`` and
        ``processed_image`` columns; ``processed_image`` is only
        initialised to ``None`` here.
    """
    df = df.copy()
    df['text_content'] = None
    df['image_summary'] = None
    df['processed_image'] = None

    # Contexts depend only on the key, so build each one once instead of
    # once per row.
    context_cache = {}

    for idx, row in df.iterrows():
        key = row['key']
        if key not in context_cache:
            context_cache[key] = get_sorted_contexts_for_key(key)
        context = context_cache[key]

        chunk_index = row.get('chunk_index')
        if pd.notna(chunk_index):
            chunk = next(
                (c for c in context.get('text_chunks', [])
                 if c.get('chunk_index') == chunk_index),
                None
            )
            # Best effort: a row pointing at an unknown chunk stays empty.
            if chunk is not None and 'chunk_content' in chunk:
                df.at[idx, 'text_content'] = chunk['chunk_content']

        image_path = row.get('image_path')
        if pd.notna(image_path):
            img = next(
                (i for i in context.get('image_contexts', [])
                 if i.get('image_path') == image_path),
                None
            )
            if img is not None and 'summary' in img:
                df.at[idx, 'image_summary'] = img['summary']

    return df

In [None]:
### parse output
# Parse the generated response for the sampled key into a DataFrame.
# The commented-out lines call the parser on other response variables
# (gene_query_t / gene_query_i); pass an existing DataFrame (e.g. new_df)
# as the third argument to append the parsed row(s) to it.
#one_new_df = parse_response_to_dataframe(gene_query_t, rndm_key)   #, new_df
#one_new_df = parse_response_to_dataframe(gene_query_i, rndm_key)   #, new_df
one_new_df = parse_response_to_dataframe(gene_query, rndm_key)   #, new_df