# Take merged verification files and merge them with the rollout files

In [3]:
import re
import json

model_name = "o4-mini"
dataset_name = "AI2D"
merged_verification_file = f"/mnt/fast10/brandon/mmr_rollout_data/merged_verification_files/{dataset_name}_final_verification_processed_{model_name}.jsonl"
# output_path = f"/mnt/fast10/brandon/mmr_rollout_data/processed_full_verification_files/{dataset_name}_final_mc_and_verification_merged_{model_name}.jsonl"

# Extract verification_solutions from merged file
verification_solutions = []
solution_pattern = re.compile(r'<solution>(.*?)</solution>', re.DOTALL)

with open(merged_verification_file, 'r') as f:
    for line_num, line in enumerate(f, 1):
        item = json.loads(line)
        try:
            text = item["body"]["messages"][0]["content"][0]["text"]
            # Find all matches and get the second one
            matches = solution_pattern.findall(text)
            if len(matches) >= 2:
                solution_text = matches[1].strip()  # Get second occurrence
                if solution_text:  # Only add non-emptyverification_solutions 
                    verification_solutions.append({
                        "custom_id": item.get("custom_id", "ERROR: custom_id not found"),
                        "unique_key": solution_text
                    })
            elif len(matches) == 1:
                print(f"Warning: Only one <solution> tag found in line {line_num}")
            else:
                print(f"Warning: No <solution> tags found in line {line_num}")
        except (KeyError, IndexError, TypeError) as e:
            print(f"Error accessing text in line {line_num}: {e}")

print(f"Extracted {len(verification_solutions)} valid verification_solutions")

Extracted 24595 valid verification_solutions


In [15]:
# Load flattened file once into memory
full_raw_rollout_data_file = "/mnt/fast10/brandon/mmr_rollout_data/flattened_rollout_files/AI2D_flattened.jsonl"
full_raw_rollout_data_array = []

with open(full_raw_rollout_data_file, 'r') as f:
    for line in f:
        item = json.loads(line)
        full_raw_rollout_data_array.append({
            "response": item.get("response", ""),
            "response_uid": item.get("uid", ""),
            "image_path": item.get("image_path", ""),
        })

print(f"Loaded {len(full_raw_rollout_data_array)} items from flattened file")

Loaded 25557 items from flattened file


In [16]:
# full_raw_rollout_data_array[0].keys()

print(full_raw_rollout_data_array[0].keys())


for k,v in full_raw_rollout_data_array[0].items():
    print(k)
    print(v)

dict_keys(['response', 'response_uid', 'image_path'])
response
[Visual Elements]
<step_1>
Identify all organisms in the food web: short-eared owl, vole, meadow pippit, emperor moth larvae, red grouse, heather, fox, brown hare, red kite or hen harrier.
</step_1>
<step_2>
Note the arrows indicating feeding relationships: arrows point from food to consumer.
</step_2>
<step_3>
Observe that the fox eats: red grouse, brown hare, and vole (arrows pointing from each of these to fox).
</step_3>
<step_4>
Determine what eats fox: no arrows point to fox, indicating it is a top predator.
</step_4>
<step_5>
Determine what else eats red grouse: Arrow from red grouse to fox, and red kite or hen harrier. So red grouse is eaten by fox and red kite/hen harrier.
</step_5>
<step_6>
Determine who eats meadow pippit: Arrow from meadow pippit to fox only.
</step_6>
<step_7>
Note which options are present: more grouse, more pippit, less grouse, less owl.
</step_7>

[Reasoning]
<step_1>
The question asks what w

In [None]:
# collision_errors = []
# no_matches_array = []
# for sol in verification_solutions:
#     unique_key = sol["unique_key"]
    
#     # Find all matches
#     matches = [item for item in full_raw_rollout_data_array if item["response"].strip() == unique_key]
    
#     if len(matches) > 1:
#         collision_errors.append({
#             "solution_unique_key": unique_key,
#             "solution_custom_id": sol["custom_id"],
#             "matches": matches
#         })
#     elif len(matches) == 0:
#         no_matches_array.append({
#             "solution_unique_key": unique_key,
#             "solution_custom_id": sol["custom_id"],
#         })

# # Report collision errors
# if collision_errors:
#     print(f"\n🚨 COLLISION ERRORS FOUND: {len(collision_errors)} unique_keys have multiple matches!")
#     for error in collision_errors:
#         print(f"\nCollision for rollout_uid: {error['rollout_uid']}")
#         print(f"solution_unique_key: {error['solution_unique_key'][:100]}...")
#         print(f"Found {len(error['matches'])} matches:")
#         for match in error['matches']:
#             print(f"  - uid: {match['uid']}, response: {match['response'][:50]}...")
    
#     raise ValueError(f"{len(collision_errors)} collision errors found. See details above.")
# else:
#     print(f"\n✅ No collisions found! All {len(verification_solutions)} solutions have at most one match.")

# if no_matches_array:
#     print(f"\n🚨 NO MATCHES FOUND: {len(no_matches_array)} unique_keys have no matches!")
#     for error in no_matches_array:
#         print(f"\nNo match found for solution_custom_id: {error['solution_custom_id']}")
#         print(f"solution_unique_key: {error['solution_unique_key'][:100]}...")
    
#     raise ValueError(f"{len(no_matches_array)} no matches found. See details above.")
# else:
#     print(f"\n✅ No no matches found! All {len(verification_solutions)} solutions have at least one match.")


✅ No collisions found! All 24595 solutions have at most one match.

✅ No no matches found! All 24595 solutions have at least one match.


In [None]:
def check_for_collisions(verification_solutions, full_raw_rollout_data_array):
    """
    Check for collisions and missing matches between verification solutions and rollout data.
    
    Returns:
        tuple: (collision_errors, no_matches_array, has_collisions, has_no_matches)
    """
    collision_errors = []
    no_matches_array = []
    
    for sol in verification_solutions:
        unique_key = sol["unique_key"]
        
        # Find all matches
        matches = [item for item in full_raw_rollout_data_array if item["response"].strip() == unique_key]
        
        if len(matches) > 1:
            collision_errors.append({
                "solution_unique_key": unique_key,
                "solution_custom_id": sol["custom_id"],
                "matches": matches
            })
        elif len(matches) == 0:
            no_matches_array.append({
                "solution_unique_key": unique_key,
                "solution_custom_id": sol["custom_id"],
            })

    # Report collision errors
    has_collisions = len(collision_errors) > 0
    has_no_matches = len(no_matches_array) > 0
    
    if has_collisions:
        print(f"\n�� COLLISION ERRORS FOUND: {len(collision_errors)} unique_keys have multiple matches!")
        for error in collision_errors:
            print(f"\nCollision for solution_custom_id: {error['solution_custom_id']}")
            print(f"solution_unique_key: {error['solution_unique_key'][:100]}...")
            print(f"Found {len(error['matches'])} matches:")
            for match in error['matches']:
                print(f"  - response_uid: {match['response_uid']}, response: {match['response'][:50]}...")
    else:
        print(f"\n✅ No collisions found! All {len(verification_solutions)} solutions have at most one match.")

    if has_no_matches:
        print(f"\n🚨 NO MATCHES FOUND: {len(no_matches_array)} unique_keys have no matches!")
        for error in no_matches_array:
            print(f"\nNo match found for solution_custom_id: {error['solution_custom_id']}")
            print(f"solution_unique_key: {error['solution_unique_key'][:100]}...")
    else:
        print(f"\n✅ No no matches found! All {len(verification_solutions)} solutions have at least one match.")
    
    return collision_errors, no_matches_array, has_collisions, has_no_matches

# Test the function
collision_errors, no_matches_array, has_collisions, has_no_matches = check_for_collisions(
    verification_solutions, full_raw_rollout_data_array
)

In [14]:
def merge_rollout_and_verification_data(verification_solutions, full_raw_rollout_data_array, output_path):
    """
    Merge verification solutions with rollout data when no collisions are detected.
    Uses full_raw_rollout_data_array as the reference point.
    
    Args:
        verification_solutions: List of verification solution dicts
        full_raw_rollout_data_array: List of rollout data dicts
        output_path: Path to save the merged output file
    """
    # First check for collisions
    collision_errors, no_matches_array, has_collisions, has_no_matches = check_for_collisions(
        verification_solutions, full_raw_rollout_data_array
    )
    
    if has_collisions:
        raise ValueError(f"{len(collision_errors)} collision errors found. Cannot proceed with merge.")
    
    # Create lookup dictionary for verification solutions
    verification_lookup = {sol["unique_key"]: sol for sol in verification_solutions}
    
    # Initialize trackers
    rollouts_without_verification = 0
    rollouts_with_verification = 0
    invalid_verification_values = []
    
    # Merge the data - iterate over rollout data as reference
    merged_data = []
    for rollout_item in full_raw_rollout_data_array:
        response_key = rollout_item["response"].strip()
        
        # Check if this rollout has a corresponding verification solution
        if response_key in verification_lookup:
            # Found matching verification solution
            verification_sol = verification_lookup[response_key]
            
            # Validate o4-mini_isVerified value
            is_verified_value = verification_sol.get("o4-mini_isVerified")
            if is_verified_value not in [True, False]:
                invalid_verification_values.append({
                    "custom_id": verification_sol["custom_id"],
                    "o4-mini_isVerified_value": is_verified_value,
                    "type": type(is_verified_value).__name__
                })
            
            merged_item = {
                "verification_custom_id": verification_sol["custom_id"],
                "response_uid": rollout_item["response_uid"],
                "rollout_response": rollout_item["response"],
                "rollout_image_path": rollout_item["image_path"],
                "o4-mini_verification_solution": verification_sol["verification_response"],
                "o4-mini_isVerified": is_verified_value
            }
            rollouts_with_verification += 1
        else:
            # No matching verification solution found
            merged_item = {
                "verification_custom_id": None,
                "response_uid": rollout_item["response_uid"],
                "rollout_response": rollout_item["response"],
                "rollout_image_path": rollout_item["image_path"],
                "o4-mini_verification_solution": None,
                "o4-mini_isVerified": None
            }
            rollouts_without_verification += 1
        
        merged_data.append(merged_item)
    
    # Save to file
    with open(output_path, 'w') as f:
        for item in merged_data:
            f.write(json.dumps(item) + '\n')
    
    print(f"\n✅ Successfully merged {len(merged_data)} items to {output_path}")
    print(f"📊 Summary:")
    print(f"   - Rollouts with verification: {rollouts_with_verification}")
    print(f"   - Rollouts without verification: {rollouts_without_verification}")
    print(f"   - Total rollouts: {len(merged_data)}")
    
    # Report invalid verification values
    if invalid_verification_values:
        print(f"\n⚠️  INVALID VERIFICATION VALUES FOUND: {len(invalid_verification_values)} items")
        print("These custom_ids have invalid o4-mini_isVerified values (not True/False):")
        for item in invalid_verification_values:
            print(f"   - custom_id: {item['custom_id']}, value: {item['value']} (type: {item['type']})")
    else:
        print(f"\n✅ All verification values are valid (True/False)")
    
    return merged_data

# Define output path and run merge
output_path = f"/mnt/fast10/brandon/mmr_rollout_data/processed_full_verification_files/{dataset_name}_final_mc_and_verification_merged_{model_name}.jsonl"

merged_data = merge_rollout_and_verification_data(
    verification_solutions, 
    full_raw_rollout_data_array, 
    output_path
)


✅ No collisions found! All 24595 solutions have at most one match.

✅ No no matches found! All 24595 solutions have at least one match.


KeyError: 'response_uid'