In [None]:
import json
import os

from langchain.chat_models import ChatOpenAI
from langchain.chains import ConversationChain

from taogpt.orchestrator import *
from taogpt.utils import *
from taogpt.llm_model import LangChainLLM
from taogpt.prompts import PromptDb
import taogpt.utils as utils

# Configure taogpt's debugging switch for this session.
# NOTE(review): 0 presumably means "debugging off"; confirm the accepted levels in taogpt.utils.enable_debugging.
utils.enable_debugging(0)

In [None]:
# Reload edited modules (e.g. taogpt) automatically before each cell executes,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Sampling temperature shared by both ChatOpenAI clients (gpt-3.5 and gpt-4) created below.
TEMPERATURE: float = 0.7

In [None]:
# Load OpenAI credentials from a local JSON file with "key" and "url" entries, and export
# them as OPENAI_API_KEY / OPENAI_API_BASE, which ChatOpenAI reads from the environment.
# The path defaults to ~/.ssh/openai-zillow.json. It can be overridden with the
# OPENAI_CREDENTIALS_FILE environment variable.
# `expanduser` is used instead of os.environ['HOME'], which raises KeyError where HOME
# is unset (e.g. on Windows).
credentials_path = os.environ.get(
    'OPENAI_CREDENTIALS_FILE',
    os.path.join(os.path.expanduser('~'), '.ssh', 'openai-zillow.json'),
)
with open(credentials_path, 'r') as credentials_file:
    credentials = json.load(credentials_file)
os.environ["OPENAI_API_KEY"] = credentials['key']
os.environ["OPENAI_API_BASE"] = credentials['url']

llm3_5 = ChatOpenAI(model_name='gpt-3.5-turbo-16k', temperature=TEMPERATURE)
llm4 = ChatOpenAI(model_name='gpt-4-32k', temperature=TEMPERATURE)

# One cheap round-trip through the 3.5 model to confirm the credentials and endpoint work.
conversation = ConversationChain(llm=llm3_5)
conversation.predict(input="What's your model version?")

In [None]:
# Display which model each client is configured with (gpt-3.5 client, gpt-4 client).
(llm3_5.model_name, llm4.model_name)

In [None]:
# Build the orchestrator and start solving the task. It uses the default prompt database,
# writes a Markdown log of the run, and uses the GPT-4 client as its LLM.
experiment_name = 'example'  # names the final transcript written in a later cell

prompts = PromptDb.load_defaults()
# Make sure the log directory exists on a fresh checkout (harmless if MarkdownLogger
# already creates missing parents).
os.makedirs('logs', exist_ok=True)
logger = MarkdownLogger('logs/taogpt_log.md')
config = Config(
    ask_user_before_execute_codes=False,
    ask_user_questions_in_one_prompt=True,
    initial_expansion=2,
    max_tree_branches=6,
    max_tokens=10000,  # NOTE(review): same budget is passed to executor.resume() later; keep in sync
    check_final=True
)
executor = Orchestrator(
    config=config,
    llm=LangChainLLM(llm4, logger=logger),
    prompts=prompts,
    markdown_logger=logger,
    # Optional Orchestrator argument, currently disabled:
    # sage_llm=LangChainLLM(llm4, logger=logger),
)

# The task handed to the orchestrator: a 4x4 Sudoku rendered as an ASCII grid.
SUDOKU_TASK = """
Solve this 4x4 Sudoku:

```text
+-+-+-+-+
| |3| |1|
+-+-+-+-+
|1| | |3|
+-+-+-+-+
|2| | |4|
+-+-+-+-+
|3|4| |2|
+-+-+-+-+
```
"""
executor.start(SUDOKU_TASK, analyze_first=True)

In [None]:
# Continue the run started by executor.start() above.
# NOTE(review): the argument presumably mirrors Config(max_tokens=10000); confirm against Orchestrator.resume.
executor.resume(10_000)

In [None]:
# Write the full conversation thread (including extras) and the total token count to a
# per-experiment Markdown file.
# A separate name is used so the `logger` handed to the orchestrator is not shadowed.
os.makedirs('examples', exist_ok=True)  # ensure the output directory exists
final_logger = MarkdownLogger(f'examples/{experiment_name}.final.md')
final_logger.log_conversation(executor.show_conversation_thread(with_extras=True))
final_logger.log(f"**total tokens**: {executor.llm.total_tokens}")


In [None]:
import pickle

# Snapshot the `step` of every entry in the executor's chain into a pickle file.
# NOTE(review): "sukudo" looks like a typo for "sudoku". The path is left unchanged in case
# other tooling reads this exact filename.
steps_snapshot = [invocation.step for invocation in executor.chain]
with open("sukudo4x4_example.pkl", "wb") as pickle_file:
    pickle.dump(steps_snapshot, pickle_file)

In [None]:
# Keep a copy of the chain before the next cell truncates it. Presumably this lets the
# full run be restored with `executor._chain = backup_chain`.
# NOTE(review): assuming `chain` is a list, `.copy()` is shallow. The entries are shared
# with the live chain, so in-place changes to them will also show up here.
backup_chain = executor.chain.copy()

In [None]:
# Rewind the orchestrator to its first REWIND_TO_STEP chain entries so a later
# executor.resume() continues from that point.
# NOTE: this writes the private `_chain` attribute directly. `backup_chain` holds the
# pre-rewind chain.
REWIND_TO_STEP = 8
executor._chain = executor.chain[:REWIND_TO_STEP]

In [None]:
# Inspect the most recent step in the (possibly rewound) chain and its description.
latest_step = executor.chain[-1].step
print(latest_step)
print(latest_step.description)

In [None]:
# Swap a freshly loaded default prompt database into the orchestrator's private `_prompts`
# slot (presumably to pick up edited prompt templates), then show its tao templates.
refreshed_prompts = PromptDb.load_defaults()
executor._prompts = refreshed_prompts

print(executor.prompts.tao_templates)

In [None]:
# Resume again, now from the rewound chain and with the reloaded prompts.
# NOTE(review): same argument as the first resume; presumably mirrors Config(max_tokens=10000).
executor.resume(10_000)