🗂️ move actual prompts out of chains
shroominic committed Jul 17, 2023
1 parent d7fa612 commit 000213d
Showing 5 changed files with 77 additions and 86 deletions.
59 changes: 7 additions & 52 deletions codeinterpreterapi/chains/modifications_check.py
@@ -1,51 +1,11 @@
import json
from json import JSONDecodeError
from typing import List

from langchain.base_language import BaseLanguageModel
from langchain.chat_models.openai import ChatOpenAI
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.schema import (
    AIMessage,
    OutputParserException,
    SystemMessage,
)


prompt = ChatPromptTemplate(
    input_variables=["code"],
    messages=[
        SystemMessage(
            content="The user will input some code and you will need to determine if the code makes any changes to the file system. \n"
            "With changes it means creating new files or modifying exsisting ones.\n"
            "Answer with a function call `determine_modifications` and list them inside.\n"
            "If the code does not make any changes to the file system, still answer with the function call but return an empty list.\n",
        ),
        HumanMessagePromptTemplate.from_template("{code}"),
    ],
)

functions = [
    {
        "name": "determine_modifications",
        "description": "Based on code of the user determine if the code makes any changes to the file system. \n"
        "With changes it means creating new files or modifying exsisting ones.\n",
        "parameters": {
            "type": "object",
            "properties": {
                "modifications": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "The filenames that are modified by the code.",
                },
            },
            "required": ["modifications"],
        },
    }
]
from langchain.schema import AIMessage, OutputParserException

from codeinterpreterapi.prompts import determine_modifications_function, determine_modifications_prompt


async def get_file_modifications(
@@ -55,8 +15,8 @@ async def get_file_modifications(
) -> List[str] | None:
    if retry < 1:
        return None
    messages = prompt.format_prompt(code=code).to_messages()
    message = await llm.apredict_messages(messages, functions=functions)
    messages = determine_modifications_prompt.format_prompt(code=code).to_messages()
    message = await llm.apredict_messages(messages, functions=[determine_modifications_function])

    if not isinstance(message, AIMessage):
        raise OutputParserException("Expected an AIMessage")
@@ -71,7 +31,7 @@


async def test():
    llm = ChatOpenAI(model="gpt-3.5-turbo-0613") # type: ignore
    llm = ChatOpenAI(model="gpt-3.5") # type: ignore

    code = """
import matplotlib.pyplot as plt
@@ -87,17 +47,12 @@ async def test():
plt.show()
"""

    code2 = "import pandas as pd\n\n# Read the Excel file\ndata = pd.read_excel('Iris.xlsx')\n\n# Convert the data to CSV\ndata.to_csv('Iris.csv', index=False)"

    modifications = await get_file_modifications(code2, llm)

    print(modifications)
    print(await get_file_modifications(code, llm))


if __name__ == "__main__":
    import asyncio
    from dotenv import load_dotenv

    load_dotenv()

    asyncio.run(test())
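
The fold above hides the middle of `get_file_modifications` between the two hunks. As a rough guide to how the relocated prompt and function schema fit together in this chain, here is a minimal sketch assuming the mid-2023 LangChain function-calling conventions (a `function_call` entry in `additional_kwargs` with JSON-encoded `arguments`); the function name, parsing, and retry details below are illustrative, not the repository's exact hidden lines.

import json
from json import JSONDecodeError
from typing import List

from langchain.base_language import BaseLanguageModel
from langchain.schema import AIMessage, OutputParserException

from codeinterpreterapi.prompts import (
    determine_modifications_function,
    determine_modifications_prompt,
)


async def get_file_modifications_sketch(
    code: str,
    llm: BaseLanguageModel,
    retry: int = 2,
) -> List[str] | None:
    if retry < 1:
        return None
    # Render the chat messages from the relocated prompt template.
    messages = determine_modifications_prompt.format_prompt(code=code).to_messages()
    # Ask the model, constrained to the relocated function schema.
    message = await llm.apredict_messages(
        messages, functions=[determine_modifications_function]
    )

    if not isinstance(message, AIMessage):
        raise OutputParserException("Expected an AIMessage")

    # The function-call arguments arrive as a JSON string; retry on malformed output.
    function_call = message.additional_kwargs.get("function_call")
    if function_call is None:
        return await get_file_modifications_sketch(code, llm, retry=retry - 1)
    try:
        args = json.loads(function_call["arguments"])
    except JSONDecodeError:
        return await get_file_modifications_sketch(code, llm, retry=retry - 1)
    return args.get("modifications", [])
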
41 changes: 7 additions & 34 deletions codeinterpreterapi/chains/remove_download_link.py
@@ -1,39 +1,15 @@
from langchain.base_language import BaseLanguageModel
from langchain.chat_models.openai import ChatOpenAI
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.schema import (
    AIMessage,
    OutputParserException,
    SystemMessage,
    HumanMessage,
)


prompt = ChatPromptTemplate(
    input_variables=["input_response"],
    messages=[
        SystemMessage(
            content="The user will send you a response and you need to remove the download link from it.\n"
            "Reformat the remaining message so no whitespace or half sentences are still there.\n"
            "If the response does not contain a download link, return the response as is.\n"
        ),
        HumanMessage(
            content="The dataset has been successfully converted to CSV format. You can download the converted file [here](sandbox:/Iris.csv)."
        ),
        AIMessage(content="The dataset has been successfully converted to CSV format."),
        HumanMessagePromptTemplate.from_template("{input_response}"),
    ],
)
from langchain.schema import AIMessage, OutputParserException

from codeinterpreterapi.prompts import remove_dl_link_prompt


async def remove_download_link(
    input_response: str,
    llm: BaseLanguageModel,
) -> str:
    messages = prompt.format_prompt(input_response=input_response).to_messages()
    messages = remove_dl_link_prompt.format_prompt(input_response=input_response).to_messages()
    message = await llm.apredict_messages(messages)

    if not isinstance(message, AIMessage):
@@ -47,15 +23,12 @@ async def test():

    example = "I have created the plot to your dataset.\n\nLink to the file [here](sandbox:/plot.png)."

    modifications = await remove_download_link(example, llm)

    print(modifications)
    print(await remove_download_link(example, llm))


if __name__ == "__main__":
    import asyncio
    import dotenv

    dotenv.load_dotenv()
    from dotenv import load_dotenv
    load_dotenv()

    asyncio.run(test())
2 changes: 2 additions & 0 deletions codeinterpreterapi/prompts/__init__.py
@@ -1 +1,3 @@
from .system_message import system_message as code_interpreter_system_message
from .modifications_check import determine_modifications_function, determine_modifications_prompt
from .remove_dl_link import remove_dl_link_prompt
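
With these re-exports in place, downstream modules can pull every prompt from a single package path. A short sketch of the resulting import surface; the snippet of user code passed to `format_prompt` is just example input.

from codeinterpreterapi.prompts import (
    code_interpreter_system_message,
    determine_modifications_function,
    determine_modifications_prompt,
    remove_dl_link_prompt,
)

# Render the modification-check messages for an arbitrary piece of user code.
messages = determine_modifications_prompt.format_prompt(
    code="open('notes.txt', 'w').write('hi')"
).to_messages()
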
35 changes: 35 additions & 0 deletions codeinterpreterapi/prompts/modifications_check.py
@@ -0,0 +1,35 @@

from langchain.schema import SystemMessage
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate


determine_modifications_prompt = ChatPromptTemplate(
    input_variables=["code"],
    messages=[
        SystemMessage(
            content="The user will input some code and you will need to determine if the code makes any changes to the file system. \n"
            "With changes it means creating new files or modifying existing ones.\n"
            "Answer with a function call `determine_modifications` and list them inside.\n"
            "If the code does not make any changes to the file system, still answer with the function call but return an empty list.\n",
        ),
        HumanMessagePromptTemplate.from_template("{code}"),
    ],
)


determine_modifications_function = {
    "name": "determine_modifications",
    "description": "Based on code of the user determine if the code makes any changes to the file system. \n"
    "With changes it means creating new files or modifying existing ones.\n",
    "parameters": {
        "type": "object",
        "properties": {
            "modifications": {
                "type": "array",
                "items": {"type": "string"},
                "description": "The filenames that are modified by the code.",
            },
        },
        "required": ["modifications"],
    },
}
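
For orientation, a well-formed model response under the schema above carries a `function_call` whose `arguments` field is a JSON object with a single `modifications` array. The filenames below are invented for the example.

import json

# Example of the arguments string the schema is designed to elicit (hypothetical filenames).
arguments = '{"modifications": ["Iris.csv", "plot.png"]}'
modified_files = json.loads(arguments)["modifications"]
print(modified_files)  # ['Iris.csv', 'plot.png']
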
26 changes: 26 additions & 0 deletions codeinterpreterapi/prompts/remove_dl_link.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.schema import (
    AIMessage,
    SystemMessage,
    HumanMessage,
)


remove_dl_link_prompt = ChatPromptTemplate(
    input_variables=["input_response"],
    messages=[
        SystemMessage(
            content="The user will send you a response and you need to remove the download link from it.\n"
            "Reformat the remaining message so no whitespace or half sentences are still there.\n"
            "If the response does not contain a download link, return the response as is.\n"
        ),
        HumanMessage(
            content="The dataset has been successfully converted to CSV format. You can download the converted file [here](sandbox:/Iris.csv)."
        ),
        AIMessage(content="The dataset has been successfully converted to CSV format."),
        HumanMessagePromptTemplate.from_template("{input_response}"),
    ],
)
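
A brief usage sketch of the relocated prompt: formatting it yields the system instruction, the one-shot example pair, and the new user turn, and any chat model can then return the link-free text. The model name and the expected output in the comment are assumptions for illustration, not part of this commit.

import asyncio

from langchain.chat_models.openai import ChatOpenAI

from codeinterpreterapi.prompts import remove_dl_link_prompt


async def main() -> None:
    llm = ChatOpenAI(model="gpt-3.5-turbo")  # assumed model choice for the sketch
    messages = remove_dl_link_prompt.format_prompt(
        input_response="The plot has been saved. Download it [here](sandbox:/plot.png)."
    ).to_messages()
    reply = await llm.apredict_messages(messages)
    print(reply.content)  # expected along the lines of: "The plot has been saved."


asyncio.run(main())
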
