In [None]:
%reload_ext autoreload
%autoreload 2
%cd llama3/llama

In [2]:
import os
from llama import Workflow, Llama
from llama.util import load_model_and_tokenizer

# Single-process "distributed" configuration. These are the standard
# torch.distributed rendezvous variables; presumably they are read by the
# model-parallel / DDP initialisation inside Workflow.build (see log below).
_single_process_env = {
    "RANK": "0",
    "WORLD_SIZE": "1",
    "MASTER_ADDR": "localhost",
    "MASTER_PORT": "29500",
}
os.environ.update(_single_process_env)

LLAMA_8B_DIR = '/scratch4/jeisner1/tjbai/llama_8b'

workflow = Workflow.build(
    ckpt_dir=LLAMA_8B_DIR,
    tokenizer_path=f'{LLAMA_8B_DIR}/tokenizer.model',
    max_seq_len=8192,
    max_batch_size=4,
    model_parallel_size=1,
    max_nodes=20,
    # LoRA adapter hyper-parameters.
    use_lora=True,
    lora_rank=8,
    lora_alpha=16,
    lora_dropout=0.1,
)

# Share of trainable parameters after LoRA conversion, in percent
# (the output ~0.085 matches 6.8M / 8.0B * 100).
workflow.model.get_trainable_param_percentage()

> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1
Converting to LoRA
Loaded in 18.88 seconds


0.08480376642881861

In [11]:
import torch
from llama.workflows.finetune import TotTrainer
from llama.workflows.tot import cot_prompt, finish_prompt, format_vote_system_prompt, format_problem

# Replay the proposal stage of one cached ToT sample by hand: insert the three
# stage prompts, then teacher-force the cached proposals with `train_step`.
BRANCHING_FACTOR = 8  # NOTE(review): presumably equals len(sample['result']['proposal_tokens']) — confirm

sample = torch.load('tot_data/problem_0.pt', weights_only=True)
trainer = TotTrainer(workflow, branching_factor=BRANCHING_FACTOR, voters=4)

# Start from an empty node graph, as the other inference cells do, so that
# re-running this cell does not stack new nodes on top of the previous run
# (presumably every inserted/generated node counts toward max_nodes=20).
workflow.reset()

problem_text = format_problem(sample['problem'])
cot, vote, finish = workflow.insert([
    {'messages': [
        {'role': 'system', 'content': cot_prompt},
        {'role': 'user', 'content': problem_text}
    ], 'parent_ids': []},
    {'messages': [
        {'role': 'system', 'content': format_vote_system_prompt(BRANCHING_FACTOR)},
        {'role': 'user', 'content': problem_text}
    ], 'parent_ids': []},
    {'messages': [
        {'role': 'system', 'content': finish_prompt},
        {'role': 'user', 'content': problem_text}
    ], 'parent_ids': []},
], training=True)

# One proposal task per branch, each a child of the CoT prompt node.
proposal_tasks = [
    {'header': ('assistant', None),
     'prefill': f'Solution #{i+1}:\n\n',
     'parent_ids': [cot['id']]}
    for i in range(BRANCHING_FACTOR)
]
# Targets are the cached proposal tokens, each terminated with end-of-turn.
target_proposal_ids = [p + [workflow.tokenizer.eot_id] for p in sample['result']['proposal_tokens']]
proposal_nodes, proposal_logprobs = workflow.train_step(proposal_tasks, target_proposal_ids)

Training 6.8M / 8.0B parameters


In [15]:
# Baseline: answer a trivial question with the LoRA adapter switched off.
workflow.reset()
# Inference only. Other cells (e.g. the forward-backward sanity check) call
# .train(), so set eval mode explicitly, as the checkpoint cell does.
workflow.model.eval()
workflow.model.set_adapter_state(enabled=False)

[system] = workflow.insert([
    {
        'messages': [{'role': 'system', 'content': 'Answer the user\'s question please.'}],
        'parent_ids': [],
    },
])

[user_1] = workflow.insert([
    {
        'messages': [{'role': 'user', 'content': 'What is the capital of France?'}],
        'parent_ids': [system['id']],
    },
])

# Generate one assistant turn whose parents are the system and user nodes.
[output], _ = workflow.step(
    [
        {
            'header': ('assistant', None),
            'prefill': '',
            'parent_ids': [system['id'], user_1['id']],
        }
    ]
)

workflow.tokenizer.decode(output)

'The capital of France is Paris.'

In [14]:
# Same question with the LoRA adapter switched on, for comparison.
workflow.reset()
# Inference only. lora_dropout=0.1 is active in train mode, and other cells
# (e.g. the forward-backward sanity check) call .train(), so set eval mode
# explicitly, as the checkpoint cell does.
workflow.model.eval()
workflow.model.set_adapter_state(enabled=True)

[system] = workflow.insert([
    {
        'messages': [{'role': 'system', 'content': 'Answer the user\'s question please.'}],
        'parent_ids': [],
    },
])

[user_1] = workflow.insert([
    {
        'messages': [{'role': 'user', 'content': 'What is the capital of France?'}],
        'parent_ids': [system['id']],
    },
])

# Generate one assistant turn whose parents are the system and user nodes.
[output], _ = workflow.step(
    [
        {
            'header': ('assistant', None),
            'prefill': '',
            'parent_ids': [system['id'], user_1['id']],
        }
    ]
)

workflow.tokenizer.decode(output)

'The capital of France is Paris.'

## Sanity check: one forward-backward step

In [5]:
import torch
from llama.workflows.finetune import TotTrainer
from llama.workflows.tot import cot_prompt, finish_prompt, format_vote_system_prompt, format_problem

# Run one full TotTrainer step on a single cached sample to confirm the
# training path works end to end before launching a fine-tuning run.
problem = torch.load('tot_data/problem_0.pt', weights_only=True)
trainer = TotTrainer(workflow, branching_factor=8, voters=4)

# Shared model: training mode, adapter on, and no stale gradients left over
# from earlier cells.
_model = workflow.model
_model.train()
_model.set_adapter_state(enabled=True)
_model.zero_grad()

_step_result = trainer.step(problem)
total_loss, metrics = _step_result

metrics

Training 6.8M / 8.0B parameters


In [None]:
from llama.workflows.finetune import finetune

# Full LoRA fine-tuning run over the cached ToT traces in tot_data/.
# NOTE(review): finetune() takes its own ckpt_dir, so it presumably loads a
# second copy of the 8B model while `workflow` is still resident in this
# kernel. Watch GPU memory (note max_seq_len here is 6144 vs 8192 above).
_scratch = '/scratch4/jeisner1/tjbai'
finetune_kwargs = dict(
    data_path='tot_data',
    ckpt_dir=f'{_scratch}/llama_8b',
    tokenizer_path=f'{_scratch}/llama_8b/tokenizer.model',
    output_dir=f'{_scratch}/checkpoints',
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    max_seq_len=6144,
)

finetune(**finetune_kwargs)

## Load and run a fine-tuned checkpoint

In [None]:
import torch

# Smoke-test the last fine-tuned LoRA checkpoint: one assistant turn whose
# parents are a system prompt and two sibling user questions.
_ckpt_path = '/scratch4/jeisner1/tjbai/checkpoints/lora_epoch-0_step-399.pt'
checkpoint = torch.load(_ckpt_path, weights_only=True)
# NOTE(review): only the checkpoint's 'lora' entry is passed in, so this relies
# on workflow.model.load_state_dict accepting a partial state dict — confirm.
workflow.model.load_state_dict(checkpoint['lora'])

# Fresh node graph, inference mode, adapter on.
workflow.reset()
workflow.model.eval()
workflow.model.set_adapter_state(enabled=True)

[system] = workflow.insert([
    {
        'messages': [{'role': 'system', 'content': 'Answer ALL of the user\'s question(s).'}],
        'parent_ids': [],
    },
])

_questions = (
    'What is the capital of France?',
    'What is the largest planet in the solar system?',
)
user_1, user_2 = workflow.insert([
    {
        'messages': [{'role': 'user', 'content': question}],
        'parent_ids': [system['id']],
    }
    for question in _questions
])

# A single assistant node with the system prompt and both questions as parents.
[output], _ = workflow.step(
    [
        {
            'header': ('assistant', None),
            'prefill': '',
            'parent_ids': [system['id'], user_1['id'], user_2['id']],
        }
    ],
)

workflow.tokenizer.decode(output)

## Evaluate tricky-prompt ToT results across checkpoints

In [None]:
import torch
import json
from llama import Llama
from llama.workflows.tot import load_math_problems, benchmark_tricky_tot
from tqdm import tqdm

# For each saved LoRA checkpoint, run benchmark_tricky_tot on the first 100
# counting & probability training problems and dump results to
# checkpoint-<step>.json.
CHECKPOINT_STEPS = [99, 199, 299, 399]

problems = load_math_problems(
    '../data/MATH',
    split='train',
    problem_types=['counting_and_probability']
)[:100]

for step in CHECKPOINT_STEPS:
    checkpoint = torch.load(f'/scratch4/jeisner1/tjbai/checkpoints/lora_epoch-0_step-{step}.pt', weights_only=True)
    workflow.model.load_state_dict(checkpoint['lora'])
    # The model object is shared with every other cell, and those cells toggle
    # train()/eval() and switch the adapter off (e.g. the judging cell). Pin the
    # inference state explicitly so each checkpoint is actually evaluated with
    # its LoRA weights active and dropout disabled.
    workflow.model.eval()
    workflow.model.set_adapter_state(enabled=True)
    llama = Llama(workflow.model, workflow.tokenizer)
    print(f'Loaded checkpoint-{step}')
    print(f'Memory allocated: {torch.cuda.memory_allocated()}')

    comps = [
        benchmark_tricky_tot(
            llama=llama,
            workflow=workflow,
            problem=problem['problem'],
            branching_factor=8,
            voters=4
        )
        for problem in tqdm(problems)
    ]

    with open(f'checkpoint-{step}.json', 'w') as f:
        json.dump(comps, f)

## Generate and evaluate final solutions across checkpoints

In [None]:
import torch
import json
from llama import Llama
from llama.workflows.tot import load_math_problems, benchmark_solution_quality
from tqdm import tqdm

# For each saved LoRA checkpoint, generate baseline vs. cached final solutions
# for the first 200 counting & probability training problems and dump them to
# checkpoint-<step>_solution_quality.json (consumed by the judging cell).
CHECKPOINT_STEPS = [99, 199, 299, 399]

problems = load_math_problems(
    '../data/MATH',
    split='train',
    problem_types=['counting_and_probability']
)[:200]

for step in CHECKPOINT_STEPS:
    checkpoint = torch.load(f'/scratch4/jeisner1/tjbai/checkpoints/lora_epoch-0_step-{step}.pt', weights_only=True)
    workflow.model.load_state_dict(checkpoint['lora'])
    # The model object is shared with every other cell, and those cells toggle
    # train()/eval() and switch the adapter off (e.g. the judging cell). Pin the
    # inference state explicitly so each checkpoint is actually evaluated with
    # its LoRA weights active and dropout disabled.
    workflow.model.eval()
    workflow.model.set_adapter_state(enabled=True)
    llama = Llama(workflow.model, workflow.tokenizer)
    print(f'Loaded checkpoint-{step}')
    print(f'Memory allocated: {torch.cuda.memory_allocated()}')

    comps = [
        benchmark_solution_quality(
            llama=llama,
            workflow=workflow,
            problem=problem['problem'],
            branching_factor=8,
            voters=4,
            compact=False,
        )
        for problem in tqdm(problems)
    ]

    with open(f'checkpoint-{step}_solution_quality.json', 'w') as f:
        json.dump(comps, f)

In [None]:
import re
import json
import random
from tqdm import tqdm
from collections import Counter
from llama import Llama
from llama.workflows.tot import load_math_problems, benchmark_solution_quality, parse_choice

# Pairwise LLM-as-judge comparison of baseline vs. cached final answers, using
# the model with its LoRA adapter disabled as the judge.
#
# Build the judge explicitly instead of reusing a `llama` left over from the
# generation cell, so this cell also runs on a fresh kernel.
workflow.model.set_adapter_state(enabled=False)
judge = Llama(workflow.model, workflow.tokenizer)


evaluator_prompt = '''
You are evaluating final answers to AMC/AIME competition problems. You will receive:

1. A problem statement
2. The ground truth solution 
3. A shared solution proposal that both contestants used
4. Two final answers based on this proposal

Your task is to evaluate how effectively each contestant converted the shared proposal into a valid solution.
Note that valid solutions may differ from the ground truth approach while remaining correct.

Evaluate both answers focusing on:
1. Answer Format Quality
- Clarity and conciseness of final statement
- Proper mathematical notation
- Inclusion of key numerical result

2. Mathematical Validity
- Correctness of final numerical answer
- Completeness (all parts answered)
- Any invalid mathematical claims

3. Justification Level
- Appropriate amount of supporting context
- Balance between brevity and explanation
- Clear connection to previous reasoning

Walk through each of these criterion and compare the 2 solutions. 

You must format your response as:

VERDICT: [1 or 2]
VERDICT_NOTE: (one sentence explanation)
'''

JUDGE_SEED = 42
# The prompt asks the judge to walk through every criterion *before* giving the
# verdict. That easily exceeds 256 tokens, and a truncated response has no
# VERDICT line at all.
JUDGE_MAX_GEN_LEN = 1024
# Accept "VERDICT: 1" as well as the literal "VERDICT: [1]" form shown in the
# prompt. Only 1 or 2 are valid verdicts.
VERDICT_RE = re.compile(r'VERDICT:\s*\[?\s*([12])\b')

problems = load_math_problems(
    '../data/MATH',
    split='train',
    problem_types=['counting_and_probability']
)[:200]

for step in [0, 99, 199, 299, 399]:
    # NOTE(review): the generation cell only writes steps 99-399;
    # checkpoint-0_solution_quality.json must come from a separate run.
    with open(f'checkpoint-{step}_solution_quality.json', 'r') as f:
        data = json.load(f)

    # Re-seeded per checkpoint: presentation order is reproducible and identical
    # across checkpoints, so position bias affects them equally.
    rng = random.Random(JUDGE_SEED)
    baseline_win = 0
    cached_win = 0
    no_votes = 0   # no parsable voter choice -> no proposal to show the judge
    unparsed = 0   # judge response without a valid VERDICT line

    for d, problem_obj in tqdm(zip(data, problems), total=min(len(data), len(problems))):
        problem = d['problem']
        solution = problem_obj['solution']
        baseline_final = d['baseline_final']
        cached_final = d['cached_final']

        # Voter choices are treated as 1-based proposal numbers (hence the -1
        # below). Drop anything that does not index an existing proposal.
        n_proposals = len(d['proposals'])
        votes = [
            choice for resp in d['voters'] if
            (choice := parse_choice(resp)) is not None and 1 <= choice <= n_proposals
        ]
        if not votes:
            no_votes += 1
            continue
        best = Counter(votes).most_common(1)[0][0] - 1
        proposal = d['proposals'][best]

        # Randomise which answer is shown first to counter position bias.
        baseline_first = rng.choice([True, False])
        ans1 = baseline_final if baseline_first else cached_final
        ans2 = cached_final if baseline_first else baseline_final

        dialog = [
            {'role': 'system', 'content': evaluator_prompt},
            {'role': 'user', 'content': f'''
PROBLEM STATEMENT:
{problem}

GROUND TRUTH SOLUTION:
{solution}

SOLUTION PROPOSAL:
{proposal}

FINAL ANSWER #1:
{ans1}

FINAL ANSWER #2:
{ans2}
'''
            }
        ]

        [evaluation] = judge.chat_completion(
            [dialog],
            max_gen_len=JUDGE_MAX_GEN_LEN,
            temperature=0.7,
            top_p=0.9,
            seed=JUDGE_SEED,
        )

        # The reasoning comes first, so the last VERDICT line is the final one.
        verdicts = VERDICT_RE.findall(evaluation['generation']['content'])
        if not verdicts:
            unparsed += 1
            continue
        # Baseline wins iff the judge picked the slot the baseline was shown in.
        if (verdicts[-1] == '1') == baseline_first:
            baseline_win += 1
        else:
            cached_win += 1

    print(f'checkpoint-{step}: baseline={baseline_win} cached={cached_win} '
          f'no_votes={no_votes} unparsed={unparsed}')

  5%|▌         | 10/200 [00:25<06:12,  1.96s/it]