In [None]:
# IPython autoreload extension; mode 2 reloads imported modules before each
# cell executes, so source edits are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [None]:
import os
from llama import Workflow, Llama

# Single-process rendezvous settings. These must be in the environment before
# Workflow.build runs (presumably read by torch.distributed initialisation
# inside it -- TODO confirm against the llama package).
os.environ.update({
    "RANK": "0",
    "WORLD_SIZE": "1",
    "MASTER_ADDR": "localhost",
    "MASTER_PORT": "29500",
})

# Build the shared workflow (model + tokenizer) used by every cell below.
workflow = Workflow.build(
    ckpt_dir='/scratch4/jeisner1/tjbai/llama_8b',
    tokenizer_path='/scratch4/jeisner1/tjbai/llama_8b/tokenizer.model',
    max_seq_len=16 * 512,
    max_batch_size=4,
    model_parallel_size=1,
)

# Plain Llama wrapper over the same model and tokenizer objects.
llama = Llama(workflow.model, workflow.tokenizer)

In [None]:
import re
from collections import Counter

# System prompt for the proposal stage: asks for a high-level plan
# (INSIGHT / POTENTIAL CHALLENGES / APPROACH) rather than a full solution.
cot_prompt = '''
You are a creative problem solver with deep expertise in competition mathematics.
Your goal is to find a clean and insightful approach to solving the provided problem.

Before proceeding, first plan out your approach. Here are some tips:
- Break down the reasoning into clear, atomic steps
- Explicitly state any assumptions or key insights

Keep your proposal high-level and concise. You do not need to solve the entire problem.

Format your response as:
INSIGHT: (1 sentence summary)
POTENTIAL CHALLENGES: (2-3 sentences)
APPROACH:
1. ...
2. ...
'''

# System prompt for the final stage: asks for a concise solution that leans on
# a previously generated proposal and ends in an ANSWER line.
finish_prompt = '''
You are a creative problem solver with deep expertise in competition mathematics.
Your goal is to solve the provided problem. You may find it helpful to utilize the previously generated proposal.
Keep your solution relatively concise, leaning mostly on previously developed ideas.

Format your response as:
ANSWER: (2-3 sentence summary of solution and final answer)
'''

def format_vote_prompt(n):
    """Return the system prompt for the voting stage.

    Args:
        n: Number of proposals the evaluator will be shown; interpolated into
            the instruction so the model knows the valid index range 1..n.

    Returns:
        The prompt text. The requested output format includes a
        ``BEST CHOICE: <index>`` line, which is what ``parse_choice`` looks for.
    """
    # NOTE(review): unlike cot_prompt / finish_prompt, every line of this
    # literal carries the function's 4-space indent into the prompt text.
    # Left as-is to avoid perturbing generations; confirm whether intended.
    return f'''
    You are a rigorous mathematical evaluator with deep expertise in competition mathematics.
    You will be shown several different solution strategies for a math problem. Your task is to analyze each and select the most promising one to pursue.

    It may help to evaluate on the following principles:
    - Mathematical soundness: Are all the steps logically valid?
    - Tractability: Is the proposed solution computationally complex? Are there potential dead ends?
    - Elegance: Does this approach effectively leverage key problem structure?

    Vote on the best proposal and justify your choice. You will see {n} proposals, so respond with the proposal 1 through {n}.
    Do not attempt to solve the problem. You only need to evaluate each proposal and select the best option.

    Format your response as:
    BEST CHOICE: (index of best solution)
    RATIONALE: (1 sentence justification)
    '''

def format_problem(problem):
    """Wrap the problem statement in the user-message text given to each stage."""
    template = 'Here is the provided problem:\n{}'
    return template.format(problem)

def parse_choice(text):
    """Pull the voted proposal index out of an evaluator response.

    Returns the integer following the first ``BEST CHOICE:`` marker in
    ``text``, or ``None`` when no such marker followed by digits is present.
    """
    found = re.search(r'BEST CHOICE:\s*(\d+)', text)
    return None if found is None else int(found.group(1))

def solve(problem, n=5, num_voters=5):
    """Run a propose -> vote -> finish pipeline on the shared ``workflow``.

    Stages:
      1. Generate ``n`` proposals conditioned on the proposal system prompt
         and the problem.
      2. Generate ``num_voters`` votes, each conditioned on the voting prompt,
         the problem, and all proposal nodes.
      3. If at least one usable vote was parsed, generate a final answer
         conditioned on the finish prompt, the problem, and the proposal that
         received the most votes.

    Args:
        problem: Problem statement text.
        n: Number of proposals to generate.
        num_voters: Number of independent votes to sample (defaults to the
            previously hard-coded 5).

    Returns:
        A tuple ``(proposal_tokens, vote_tokens, res, votes)`` where
        ``proposal_tokens`` / ``vote_tokens`` are the first values returned by
        ``workflow.step`` for those stages, ``res`` is the single token
        sequence of the final answer (``None`` if no usable vote was found),
        and ``votes`` is the list of every index parsed from the vote
        responses, including any that were out of range.
    """
    workflow.reset()

    # One independent system-prompt root per stage.
    cot, vote, finish = workflow.insert([
        {
            'message': {'role': 'system', 'content': cot_prompt},
            'parent_ids': []
        },
        {
            'message': {'role': 'system', 'content': format_vote_prompt(n)},
            'parent_ids': []
        },
        {
            'message': {'role': 'system', 'content': finish_prompt},
            'parent_ids': []
        },
    ])

    # seems to matter that these are separate
    # (each stage gets its own copy of the problem message under its own root)
    problem_text = format_problem(problem)
    cot_user, vote_user, finish_user = workflow.insert([
        {
            'message': {'role': 'user', 'content': problem_text},
            'parent_ids': [cot['id']]
        },
        {
            'message': {'role': 'user', 'content': problem_text},
            'parent_ids': [vote['id']]
        },
        {
            'message': {'role': 'user', 'content': problem_text},
            'parent_ids': [finish['id']]
        },
    ])

    # Stage 1: n proposals, each a sibling under the same prompt + problem.
    proposal_tokens, proposal_nodes = workflow.step(
        [
            {
                'expects': ('assistant', f'solution {i+1}'),
                'parent_ids': [cot['id'], cot_user['id']]
            }
            for i in range(n)
        ],
        compact=False,
        prefill=True,
        max_gen_len=1024,
        temperature=0.7,
        top_p=0.9,
        seed=42,
    )

    # Stage 2: every voter sees all proposals.
    proposal_ids = [p['id'] for p in proposal_nodes]
    vote_tokens, vote_nodes = workflow.step(
        [
            {
                'expects': ('assistant', None),
                'parent_ids': [vote['id'], vote_user['id']] + proposal_ids
            }
            for _ in range(num_voters)
        ],
        compact=False,
        prefill=True,
        max_gen_len=256,
        temperature=0.7,
        top_p=0.9,
        seed=42,
    )

    res = None
    votes = [
        choice for resp in vote_tokens
        if (choice := parse_choice(workflow.tokenizer.decode(resp))) is not None
    ]

    # Only 1-based indices that name an actual proposal may be tallied: a vote
    # of 0 would silently select the last proposal via negative indexing, and
    # a vote above the proposal count would raise IndexError.
    valid_votes = [v for v in votes if 1 <= v <= len(proposal_nodes)]

    if valid_votes:
        # Ties resolve to the choice that was seen first.
        best = Counter(valid_votes).most_common(1)[0][0]
        # Stage 3: a single finishing generation on top of the winning proposal.
        [res], _ = workflow.step(
            [
                {
                    'expects': ('assistant', None),
                    'parent_ids': [finish['id'], finish_user['id']] + [proposal_nodes[best-1]['id']]
                }
            ],
            compact=False,
            prefill=True,
            max_gen_len=256,
            temperature=0.7,
            top_p=0.9,
            seed=42,
        )

    return proposal_tokens, vote_tokens, res, votes