In [1]:
import os
import json
from tqdm import tqdm
import torch
import numpy as np
from openai import OpenAI
import backoff

In [2]:
# load direct pred logs
# Paths to the direct-prediction (DeepSeek-distilled Llama-70B) run logs, keyed by
# task name. An empty string means no run has been recorded for that task yet.
DIRECT_PRED_LOG_PATHS = {
    'gpqa': '',
    'aime': '/fsx-comem/diwu0162/Search-o1/outputs/runs.baselines/aime.ds-llama-70b.direct/test.7.3,21:37.json',
    'amc': '/fsx-comem/diwu0162/Search-o1/outputs/runs.baselines/amc.ds-llama-70b.direct/test.7.3,21:23.json',
    'math500': '/fsx-comem/diwu0162/Search-o1/outputs/runs.baselines/math500.ds-llama-70b.direct/test.7.3,21:59.json',
    'livecode': '/fsx-comem/diwu0162/Search-o1/outputs/runs.baselines/livecode.ds-llama-70b.direct/test_1to4.7.3,22:24.json',
    'bamboogle': '',
}


def load_direct_pred_logs(task):
    """Load the JSON log of a direct-prediction run for ``task``.

    Args:
        task: one of the keys of ``DIRECT_PRED_LOG_PATHS``.

    Returns:
        The parsed JSON content of the log file (whatever ``json.load`` yields).

    Raises:
        NotImplementedError: ``task`` is not a known task.
        FileNotFoundError: no log path is configured for ``task`` (empty entry),
            or the configured file does not exist.
    """
    if task not in DIRECT_PRED_LOG_PATHS:
        raise NotImplementedError(f'Unknown task: {task!r}')
    path = DIRECT_PRED_LOG_PATHS[task]
    if not path:
        raise FileNotFoundError(f'No direct-prediction log configured for task {task!r}')
    # `with` closes the handle; JSON is UTF-8 by spec, so don't depend on the locale default.
    with open(path, encoding='utf-8') as f:
        return json.load(f)

In [5]:
# Thought segmentation

def segment_thoughts_v1(x):
    """Split a reasoning trace into paragraphs on blank-line separators.

    Surrounding whitespace of the whole trace is trimmed first, then the text
    is cut at every double newline. An empty trace yields a single empty string.
    """
    trimmed = x.strip()
    paragraphs = trimmed.split('\n\n')
    return paragraphs

def segment_thoughts_v2(x):
    """Segment a reasoning trace into steps using discourse-marker heuristics.

    The trace is first split into paragraphs with ``segment_thoughts_v1``. A
    paragraph starts a new step when its lowercased text begins with one of the
    reasoning markers below (e.g. "wait", "alternatively", "hmm"). Otherwise it
    is appended to the current step, re-joined with a blank line. The first
    paragraph always starts a step.

    Args:
        x: the raw thinking text (str).

    Returns:
        list[str]: the steps, in order. Never empty; an empty trace yields [''].
    """
    # note: excluding things like "so" "therefore", "but", "let me"
    reasoning_word_list = [
        'okay', 'hmm', 'wait', 'but wait', 'oh wait', 'but let me', 'but actually', 'alternatively', 'now', 'the question',
        'ah', 'oh', 'next', 'another angle', 'another approach', 'also', 'hold on', 'looking it up', 'another point',
        'I don\'t think', 'perhaps I', 'putting this together', 'Putting it all together', 'i\'m', 'but i\'m',
        'let me think again', 'I don\'t see'
    ]
    # Lowercase the markers once instead of once per paragraph; str.startswith
    # accepts a tuple and returns True if any prefix matches.
    step_markers = tuple(w.lower() for w in reasoning_word_list)
    # NOTE(review): this is a plain prefix test with no word boundary, so e.g. "now"
    # also matches "nowhere" and "oh" matches "ohm". It is kept as-is so step counts
    # stay comparable with previously reported numbers.

    final_thoughts = []
    for paragraph in segment_thoughts_v1(x):
        if final_thoughts and not paragraph.lower().startswith(step_markers):
            # Continuation of the current step.
            final_thoughts[-1] += '\n\n' + paragraph
        else:
            final_thoughts.append(paragraph)
    return final_thoughts

def segment_thoughts_v3(x):
    """Placeholder for a future segmenter; currently returns ``x`` unchanged.

    NOTE(review): unlike v1/v2 this does not return a list of segments, so
    ``len()`` of the result counts characters of a str input, not steps.
    """
    # maybe implement with a small auxiliary model?
    unsegmented = x
    return unsegmented


def visualize_thoughts_v2(x):
    """Pretty-print the v2 segmentation of trace ``x`` as JSON.

    Prints a JSON object (indent=4) that maps "Step 1", "Step 2", ... to the
    steps from ``segment_thoughts_v2``, in order. Returns None.
    """
    steps = segment_thoughts_v2(x)
    labelled = {}
    for step_no, step_text in enumerate(steps, start=1):
        labelled[f'Step {step_no}'] = step_text
    print(json.dumps(labelled, indent=4))


In [7]:
# analyses (direct prediction thought length)
#
# For each task, take every model output's reasoning trace (the text before
# '</think>'), segment it, and print one line with these values, each rounded to a
# whole number:
#   mean #paragraphs (v1), mean #steps (v2),
#   mean #steps (v2) on correct examples, mean #steps (v2) on wrong examples.

# gpqa and bamboogle are omitted: load_direct_pred_logs has no log path for them yet.
TASKS = ['aime', 'amc', 'math500', 'livecode']

# Key under entry['Metrics'] holding each task's correctness flag. Tasks not listed
# (gpqa, aime, amc, math500) use 'math_equal'.
CORRECTNESS_METRIC = {'bamboogle': 'em', 'livecode': 'pass@1'}
DEFAULT_CORRECTNESS_METRIC = 'math_equal'


def _mean_steps(counts):
    """Mean of ``counts`` rounded to 0 decimals; NaN (without a numpy warning) if empty."""
    return round(float(np.mean(counts)), 0) if counts else float('nan')


for task in TASKS:
    logs = load_direct_pred_logs(task)
    print(f'Task: {task}')

    # Examples are bucketed by id below, so ids must be unique. Otherwise a
    # duplicated id with mixed correctness would be silently counted as wrong.
    ids = [x['id'] for x in logs]
    assert len(set(ids)) == len(ids), f'{task}: duplicate example ids'

    metric = CORRECTNESS_METRIC.get(task, DEFAULT_CORRECTNESS_METRIC)
    ids_correct = {x['id'] for x in logs if x['Metrics'][metric]}
    ids_wrong = {x['id'] for x in logs if not x['Metrics'][metric]}
    assert len(ids_correct) + len(ids_wrong) == len(logs)
    print(f'Loaded {len(logs)} examples: {len(ids_correct)} correct and {len(ids_wrong)} wrong.')

    n_steps_v1, n_steps_v2, n_steps_v2_correct, n_steps_v2_wrong = [], [], [], []
    for entry in logs:
        # Reasoning trace = text before the first '</think>' (whole output if the tag is absent).
        thoughts = entry['Output'].split('</think>')[0]
        n_v1 = len(segment_thoughts_v1(thoughts))
        n_v2 = len(segment_thoughts_v2(thoughts))
        n_steps_v1.append(n_v1)
        n_steps_v2.append(n_v2)
        if entry['id'] in ids_wrong:
            n_steps_v2_wrong.append(n_v2)
        else:
            n_steps_v2_correct.append(n_v2)

    print(_mean_steps(n_steps_v1), _mean_steps(n_steps_v2),
          _mean_steps(n_steps_v2_correct), _mean_steps(n_steps_v2_wrong))

Task: aime
Loaded 30 examples: 18 correct and 12 wrong.
257.0 60.0 30.0 106.0
Task: amc
Loaded 40 examples: 39 correct and 1 wrong.
122.0 23.0 23.0 13.0
Task: math500
Loaded 500 examples: 437 correct and 63 wrong.
62.0 14.0 10.0 41.0
Task: livecode
Loaded 112 examples: 40 correct and 72 wrong.
338.0 67.0 35.0 85.0
