Commit aeeda44: check on tokens

rgbkrk committed Mar 18, 2023
1 parent 1949eff commit aeeda44
Showing 6 changed files with 527 additions and 47 deletions.
97 changes: 53 additions & 44 deletions genai/context.py
@@ -9,54 +9,61 @@
 except ImportError:
     PANDAS_INSTALLED = False
 
+from . import tokens
+
+
+def craft_message(text, role="user"):
+    return {"content": text, "role": role}
+
+
 def craft_user_message(code):
-    return {
-        "content": code,
-        "role": "user",
-    }
+    return craft_message(code, "user")
+
+
+def repr_genai_pandas(output):
+    if not PANDAS_INSTALLED:
+        return repr(output)
+
+    if isinstance(output, pd.DataFrame):
+        # to_markdown() does not use the max_rows and max_columns options
+        # so we have to truncate the dataframe ourselves
+        num_columns = min(pd.options.display.max_columns, output.shape[1])
+        num_rows = min(pd.options.display.max_rows, output.shape[0])
+        sampled = output.sample(num_columns, axis=1).sample(num_rows, axis=0)
+        return sampled.to_markdown()
+
+    if isinstance(output, pd.Series):
+        # Similar truncation for series
+        num_rows = min(pd.options.display.max_rows, output.shape[0])
+        sampled = output.sample(num_rows)
+        return sampled.to_markdown()
+
+    return repr(output)
+
+
+def repr_genai(output):
+    '''Compute a GPT-3.5 friendly representation of the output of a cell.
+    For DataFrames and Series this means Markdown.
+    '''
+    if not PANDAS_INSTALLED:
+        return repr(output)
+
+    with pd.option_context(
+        'display.max_rows', 5, 'display.html.table_schema', False, 'display.max_columns', 20
+    ):
+        return repr_genai_pandas(output)
+
+
 def craft_output_message(output):
-    if PANDAS_INSTALLED:
-        with pd.option_context(
-            'display.max_rows', 5, 'display.html.table_schema', False, 'display.max_columns', 20
-        ):
-            if isinstance(output, pd.DataFrame):
-                # to_markdown() does not use the max_rows and max_columns options
-                # so we have to truncate the dataframe ourselves
-                num_columns = min(pd.options.display.max_columns, output.shape[1])
-                num_rows = min(pd.options.display.max_rows, output.shape[0])
-                sampled = output.sample(num_columns, axis=1).sample(num_rows, axis=0)
-                return {
-                    "content": sampled.to_markdown(),
-                    "role": "system",
-                }
-
-            if isinstance(output, pd.Series):
-                # Similar truncation for series
-                num_rows = min(pd.options.display.max_rows, output.shape[0])
-                sampled = output.sample(num_rows)
-                return {
-                    "content": output.to_markdown(),
-                    "role": "system",
-                }
-
-            return {
-                "content": repr(output),
-                "role": "system",
-            }
-
-    return {
-        "content": repr(output),
-        "role": "system",
-    }
-
-
-# tokens to idenfify which cells to ignore based on the start
+    """Craft a message from the output of a cell."""
+    return craft_message(repr_genai(output), "system")
+
+
+# tokens to identify which cells to ignore based on the first line
 ignore_tokens = [
     "# genai:ignore",
     "#ignore",
@@ -70,7 +77,7 @@ def craft_output_message(output):
 ]
 
 
-def get_historical_context(ipython, num_messages=5):
+def get_historical_context(ipython, num_messages=5, model="gpt-3.5-turbo-0301"):
     """Create a series of messages to use as context for ChatGPT."""
     raw_inputs = ipython.history_manager.input_hist_raw
 
@@ -96,4 +103,6 @@ def get_historical_context(ipython, num_messages=5):
         if index in outputs:
             context.append(craft_output_message(outputs[index]))
 
+    context = tokens.trim_messages_to_fit_token_limit(context, model=model)
+
     return context
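Worth noting: routing every output through repr_genai also fixes a bug in the old Series branch, which rendered the full output rather than the sampled one. Below is a minimal sketch of the refactored pipeline in use (it assumes pandas plus the tabulate dependency that to_markdown() needs; the DataFrame is hypothetical):

```python
import pandas as pd

from genai.context import craft_output_message, craft_user_message

df = pd.DataFrame({"a": range(100), "b": range(100)})

# Code cells become "user" messages.
print(craft_user_message("df.describe()"))
# {'content': 'df.describe()', 'role': 'user'}

# Outputs become "system" messages; DataFrames are sampled down to the
# display limits set in repr_genai (5 rows, 20 columns) and rendered as
# a Markdown table.
message = craft_output_message(df)
print(message["role"])          # system
print(message["content"][:80])  # start of the Markdown table of a 5-row sample
```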
57 changes: 57 additions & 0 deletions genai/tokens.py
@@ -0,0 +1,57 @@
+import tiktoken
+
+MAX_TOKENS = {
+    "gpt-3.5-turbo-0301": 2048,
+    "gpt-3.5-turbo": 2048,
+    "gpt-4": 8192,
+    "gpt-4-0314": 8192,
+}
+
+
+# Copied from https://platform.openai.com/docs/guides/chat/introduction on 3/17/2023
+# Modified to support gpt-4 as a best guess
+def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301"):
+    """Returns the number of tokens used by a list of messages."""
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+    except KeyError:
+        encoding = tiktoken.get_encoding("cl100k_base")
+    if model == "gpt-3.5-turbo-0301":  # note: future models may deviate from this
+        num_tokens = 0
+        for message in messages:
+            num_tokens += 4  # every message follows <im_start>{role/name}\n{content}<im_end>\n
+            for key, value in message.items():
+                num_tokens += len(encoding.encode(value))
+                if key == "name":  # if there's a name, the role is omitted
+                    num_tokens += -1  # role is always required and always 1 token
+        num_tokens += 2  # every reply is primed with <im_start>assistant
+        return num_tokens
+    # TODO: Watch for when the new models are released and update this
+    if model == "gpt-3.5-turbo" or model == "gpt-4" or model == "gpt-4-0314":
+        num_tokens = 0
+        for message in messages:
+            num_tokens += 4  # every message follows <im_start>{role/name}\n{content}<im_end>\n
+            for key, value in message.items():
+                num_tokens += len(encoding.encode(value))
+                if key == "name":  # if there's a name, the role is omitted
+                    num_tokens += -1  # role is always required and always 1 token
+        num_tokens += 2  # every reply is primed with <im_start>assistant
+        return num_tokens
+    else:
+        raise NotImplementedError(
+            f"""num_tokens_from_messages() is not presently implemented for model {model}.
+See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""  # noqa: E501
+        )
+
+
+def trim_messages_to_fit_token_limit(messages, model="gpt-3.5-turbo-0301", max_tokens=None):
+    """Reduce the number of messages until they are below the max token limit."""
+    num_tokens = num_tokens_from_messages(messages, model=model)
+
+    if max_tokens is None:
+        max_tokens = MAX_TOKENS[model]
+
+    while num_tokens > max_tokens:
+        messages.pop(0)
+        num_tokens = num_tokens_from_messages(messages, model=model)
+    return messages
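The trimming strategy is deliberately simple: pop whole messages from the front (the oldest history) until the conversation fits. A quick sketch of the two helpers together; the messages are made up, exact token counts depend on tiktoken's encoding, and max_tokens is set artificially low here to force a trim:

```python
from genai.tokens import num_tokens_from_messages, trim_messages_to_fit_token_limit

messages = [
    {"role": "user", "content": "print('hello')"},
    {"role": "system", "content": "hello"},
    {"role": "user", "content": "df.describe()"},
]

# Count tokens the way the OpenAI cookbook does for chat models.
print(num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301"))

# Drop the oldest messages until the total fits under max_tokens.
trimmed = trim_messages_to_fit_token_limit(
    messages, model="gpt-3.5-turbo-0301", max_tokens=25
)
print(len(trimmed))  # fewer messages than we started with
```

Note that trim_messages_to_fit_token_limit mutates the list it is given via pop(0) and returns that same object, so get_historical_context's reassignment of context is for readability rather than necessity.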