In [None]:
import argparse
from tot.methods.bfs import solve
from tot.models import gpt
from tot.tasks.game24 import Game24Task
from tot.tasks.text import TextTask

import itertools
import numpy as np
from functools import partial

# Search configuration passed to ``solve`` below.  ``task`` and ``naive_run``
# are not read by any code in this cell.
args = argparse.Namespace(
    backend='gpt-4-0613',          # model name bound into ``gpt`` by solve
    temperature=0.7,               # sampling temperature bound into ``gpt``
    task='text',
    naive_run=False,
    prompt_sample="standard",      # 'standard' or 'cot' (see get_samples)
    method_generate='sample',      # 'sample' or 'propose'
    method_evaluate='vote',        # 'vote' or 'value'
    method_select='greedy',        # 'greedy' or 'sample'
    n_generate_sample=5,           # completions per candidate when sampling
    n_evaluate_sample=3,           # completions per evaluation prompt
    n_select_sample=5,             # candidates kept after each step
)

def get_value(task, x, y, n_evaluate_sample, cache_value=True):
    """Score one partial output ``y`` for input ``x`` using the value prompt.

    Args:
        task: object providing ``value_prompt_wrap``, ``value_outputs_unwrap``
            and a dict-like ``value_cache``.
        x: the task input.
        y: the partial output to score.
        n_evaluate_sample: number of completions requested from ``gpt``.
        cache_value: when true, read from and write to ``task.value_cache``.

    Returns:
        Whatever ``task.value_outputs_unwrap`` produces for the completions
        (or the cached result for the same prompt).
    """
    prompt = task.value_prompt_wrap(x, y)
    # The prompt text itself is the cache key, so an identical prompt is
    # only sent to the model once per task.
    if cache_value and prompt in task.value_cache:
        return task.value_cache[prompt]
    score = task.value_outputs_unwrap(
        x, y, gpt(prompt, n=n_evaluate_sample, stop=None)
    )
    if cache_value:
        task.value_cache[prompt] = score
    return score

def get_values(task, x, ys, n_evaluate_sample, cache_value=True):
    """Score each candidate in ``ys`` with ``get_value``.

    Only the first occurrence of a repeated candidate is actually scored;
    every later duplicate is given a value of 0 (presumably so the same
    candidate is not favoured twice during selection).

    Args:
        task: task object forwarded to ``get_value``.
        x: the task input.
        ys: list of partial outputs to score.
        n_evaluate_sample: completions per evaluation, forwarded to
            ``get_value``.
        cache_value: forwarded to ``get_value``.

    Returns:
        A list of values, one per entry of ``ys``, in the same order.
    """
    already_scored = set()
    scores = []
    for candidate in ys:  # each partial output
        if candidate in already_scored:
            scores.append(0)
            continue
        already_scored.add(candidate)
        scores.append(
            get_value(task, x, candidate, n_evaluate_sample, cache_value=cache_value)
        )
    return scores

def get_votes(task, x, ys, n_evaluate_sample):
    """Ask the model to vote among all candidates ``ys`` in a single prompt.

    Args:
        task: object providing ``vote_prompt_wrap`` and ``vote_outputs_unwrap``.
        x: the task input.
        ys: the candidate outputs being compared.
        n_evaluate_sample: number of vote completions requested from ``gpt``.

    Returns:
        The result of ``task.vote_outputs_unwrap`` given the completions and
        the number of candidates.
    """
    prompt = task.vote_prompt_wrap(x, ys)
    print(f"Running Prompt: {prompt}")
    ballots = gpt(prompt, n=n_evaluate_sample, stop=None)
    candidate_count = len(ys)
    tallies = task.vote_outputs_unwrap(ballots, candidate_count)
    print("Completed Prompt!")
    return tallies

def get_proposals(task, x, y):
    """Extend ``y`` once per line of a single "propose" completion.

    Args:
        task: object providing ``propose_prompt_wrap``.
        x: the task input.
        y: the partial output to extend.

    Returns:
        One new candidate per line of the completion, each formed as
        ``y + line + '\\n'``.
    """
    prompt = task.propose_prompt_wrap(x, y)
    completion = gpt(prompt, n=1, stop=None)[0]
    lines = completion.split('\n')
    return [y + line + '\n' for line in lines]

def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):
    """Draw ``n_generate_sample`` completions that continue ``y``.

    Args:
        task: object providing ``standard_prompt_wrap`` and ``cot_prompt_wrap``.
        x: the task input.
        y: the partial output being continued.
        n_generate_sample: number of completions requested from ``gpt``.
        prompt_sample: ``'standard'`` or ``'cot'``; selects the prompt wrapper.
        stop: forwarded unchanged to ``gpt`` as ``stop``.

    Returns:
        A list of ``y + completion`` strings, one per completion.

    Raises:
        ValueError: if ``prompt_sample`` is neither ``'standard'`` nor ``'cot'``.
    """
    if prompt_sample not in ('standard', 'cot'):
        raise ValueError(f'prompt_sample {prompt_sample} not recognized')
    wrap = task.standard_prompt_wrap if prompt_sample == 'standard' else task.cot_prompt_wrap
    prompt = wrap(x, y)
    print(f"Running Prompt: {prompt}")
    completions = gpt(prompt, n=n_generate_sample, stop=stop)
    print("Completed Prompt!")
    return [y + completion for completion in completions]

def solve(args, task, idx, x, to_print=True):
    """Run a breadth-first Tree-of-Thoughts search over ``task.steps`` steps.

    At each step every kept candidate is expanded (``args.method_generate``),
    all expansions are scored (``args.method_evaluate``), and
    ``args.n_select_sample`` of them are kept (``args.method_select``).

    Side effect: rebinds the module-level ``gpt`` to a ``functools.partial``
    carrying ``args.backend`` and ``args.temperature``. The helper functions
    in this file call ``gpt`` directly, so this is how those settings reach
    them.

    Args:
        args: namespace with ``backend``, ``temperature``, ``method_generate``
            ('sample' | 'propose'), ``method_evaluate`` ('vote' | 'value'),
            ``method_select`` ('sample' | 'greedy'), ``n_generate_sample``,
            ``n_evaluate_sample``, ``n_select_sample`` and ``prompt_sample``.
        task: task object exposing ``steps``, ``stops`` and the prompt
            wrappers used by the helpers.
        idx: not used; kept for interface compatibility. The input is
            supplied directly as ``x``.
        x: the task input.
        to_print: when true, print per-step diagnostics.

    Returns:
        ``(ys, {'steps': infos})``: the final candidates and a per-step log.

    Raises:
        ValueError: if any ``args.method_*`` value is not recognized.
    """
    global gpt
    # Validate strategy names before any API call. Previously an unknown name
    # only failed later, as an UnboundLocalError/NameError inside the loop.
    choices = {
        'method_generate': ('sample', 'propose'),
        'method_evaluate': ('vote', 'value'),
        'method_select': ('sample', 'greedy'),
    }
    for field, allowed in choices.items():
        chosen = getattr(args, field)
        if chosen not in allowed:
            raise ValueError(f'{field} {chosen!r} not recognized; expected one of {allowed}')

    gpt = partial(gpt, model=args.backend, temperature=args.temperature)
    if to_print:
        print(gpt)
        print('x:', x)

    ys = ['']  # current output candidates
    infos = []
    for step in range(task.steps):
        # generation: expand every kept candidate
        if args.method_generate == 'sample':
            expansions = [
                get_samples(task, x, y, args.n_generate_sample,
                            prompt_sample=args.prompt_sample, stop=task.stops[step])
                for y in ys
            ]
        else:  # 'propose'
            expansions = [get_proposals(task, x, y) for y in ys]
        new_ys = list(itertools.chain.from_iterable(expansions))
        ids = list(range(len(new_ys)))

        # evaluation
        if args.method_evaluate == 'vote':
            values = get_votes(task, x, new_ys, args.n_evaluate_sample)
        else:  # 'value'
            values = get_values(task, x, new_ys, args.n_evaluate_sample)
        if to_print:
            print(new_ys)

        # selection
        if args.method_select == 'sample':
            weights = np.asarray(values, dtype=float)
            total = weights.sum()
            # All-zero scores would make the probabilities NaN; p=None makes
            # np.random.choice sample uniformly instead.
            ps = weights / total if total > 0 else None
            select_ids = np.random.choice(ids, size=args.n_select_sample, p=ps).tolist()
        else:  # 'greedy'
            select_ids = sorted(ids, key=lambda i: values[i], reverse=True)[:args.n_select_sample]
        select_new_ys = [new_ys[i] for i in select_ids]

        # log (built without zip(*...) so an empty candidate list cannot crash it)
        if to_print:
            ranked = sorted(zip(new_ys, values), key=lambda pair: pair[1], reverse=True)
            sorted_new_ys = tuple(y for y, _ in ranked)
            sorted_values = tuple(v for _, v in ranked)
            print(f'-- new_ys --: {sorted_new_ys}\n-- sol values --: {sorted_values}\n-- choices --: {select_new_ys}\n')

        infos.append({'step': step, 'x': x, 'ys': ys, 'new_ys': new_ys, 'values': values, 'select_new_ys': select_new_ys})
        ys = select_new_ys

    if to_print:
        print(ys)
    return ys, {'steps': infos}


# Generate an input text with the model, then run the ToT search on it.
task = TextTask()
# ``x`` is the list of completions returned by ``gpt`` (n=1), so the text is x[0].
x = gpt('Provide a four sentence summary of a restoration response to a sudden decline in california sage scrub diversity and density.', model="gpt-4", temperature=0.7, max_tokens=1000, n=1, stop=None)
print("Summary of Tasks: ", x)
# solve's signature is (args, task, idx, x). The previous positional call
# passed (x[0], 0), which made the task input 0 instead of the summary.
# Keywords make the intended binding explicit.
ys, infos = solve(args, task, idx=0, x=x[0])
print(ys[0])

Summary of Tasks:  ["In response to a sudden decline in California sage scrub diversity and density, restoration efforts involve replanting native species and removing invasive ones. This includes conducting thorough research and monitoring to identify the cause of the decline. Additionally, proactive measures such as controlled burns are used to maintain the ecosystem's health and prevent wildfires which can harm the sage scrub. Public education and community involvement are also essential components of the restoration response, fostering a sense of stewardship and promoting conservation efforts."]
functools.partial(<function gpt at 0x112891a80>, model='gpt-4-0613', temperature=0.7)
x: 0
Running Prompt: 
Write a coherent plan for ecological recovery of 4 short paragraphs. Each paragraph must address: 0

Completed Prompt!
Running Prompt: Given an instruction and several choices, decide which choice is most promising. Analyze each choice in detail, then conclude in the last line "The be