Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow the use of OpenAI's rough token count #469

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions lib/langchain/llm/openai.rb
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ class OpenAI < Base
n: 1,
temperature: 0.0,
chat_completion_model_name: "gpt-3.5-turbo",
embeddings_model_name: "text-embedding-ada-002"
embeddings_model_name: "text-embedding-ada-002",
token_counter: :tiktoken
}.freeze

EMBEDDING_SIZES = {
Expand All @@ -33,7 +34,8 @@ class OpenAI < Base
# Initialize an OpenAI LLM instance
#
# @param api_key [String] The API key to use
# @param client_options [Hash] Options to pass to the OpenAI::Client constructor
# @param llm_options [Hash] Options to pass to the OpenAI::Client constructor
# @param default_options [Hash] Options to customize the default behavior of the LLM
def initialize(api_key:, llm_options: {}, default_options: {})
depends_on "ruby-openai", req: "openai"

Expand Down Expand Up @@ -200,7 +202,11 @@ def with_api_error_handling
end

def validate_max_tokens(messages, model, max_tokens = nil)
LENGTH_VALIDATOR.validate_max_tokens!(messages, model, max_tokens: max_tokens, llm: self)
LENGTH_VALIDATOR.validate_max_tokens!(
messages, model,
max_tokens: max_tokens, llm: self,
token_counter: defaults[:token_counter]
)
end

def response_from_chunks
Expand Down
14 changes: 9 additions & 5 deletions lib/langchain/utils/token_length/openai_validator.rb
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# frozen_string_literal: true

require "openai"
require "tiktoken_ruby"

module Langchain
Expand Down Expand Up @@ -61,6 +62,7 @@ class OpenAIValidator < BaseValidator
#
# @param text [String] The text to calculate the token length for
# @param model_name [String] The model name to validate against
# @param options [Hash] The options to customize the token length calculation
# @return [Integer] The token length of the text
#
def self.token_length(text, model_name, options = {})
Expand All @@ -69,8 +71,12 @@ def self.token_length(text, model_name, options = {})
model_name = "text-embedding-ada-002"
end

encoder = Tiktoken.encoding_for_model(model_name)
encoder.encode(text).length
if options[:token_counter] == :openai
::OpenAI.rough_token_count(text)
else
encoder = Tiktoken.encoding_for_model(model_name)
encoder.encode(text).length
end
end

def self.token_limit(model_name)
Expand All @@ -95,8 +101,6 @@ def self.validate_max_tokens!(content, model_name, options = {})
# @return [Integer] The token length of the messages
#
def self.token_length_from_messages(messages, model_name, options = {})
encoding = Tiktoken.encoding_for_model(model_name)

if ["gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "gpt-4-0314", "gpt-4-32k-0314", "gpt-4-0613", "gpt-4-32k-0613"].include?(model_name)
tokens_per_message = 3
tokens_per_name = 1
Expand All @@ -119,7 +123,7 @@ def self.token_length_from_messages(messages, model_name, options = {})
messages.each do |message|
num_tokens += tokens_per_message
message.each do |key, value|
num_tokens += encoding.encode(value).length
num_tokens += token_length(value, model_name, options)
num_tokens += tokens_per_name if ["name", :name].include?(key)
end
end
Expand Down
28 changes: 27 additions & 1 deletion spec/langchain/llm/openai_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,37 @@
end

it "passes correct options to the client" do
# openai-ruby sets global configuration options here: https://github.com/alexrudall/ruby-openai/blob/main/lib/openai/client.rb
# openai-ruby sets global configuration options here:
# https://github.com/alexrudall/ruby-openai/blob/main/lib/openai/client.rb
result = subject
expect(result.client.uri_base).to eq("http://localhost:1234")
end
end

# Verifies that user-supplied default_options (here :temperature and the
# new :token_counter) are merged over Langchain::LLM::OpenAI::DEFAULTS
# rather than replacing the whole hash.
context "when default_options are passed" do
let(:subject) do
described_class.new(
api_key: "123",
default_options: {temperature: 0.5, token_counter: :openai}
)
end

it "overrides the default values" do
# Build the expected hash from a copy of DEFAULTS so the frozen
# constant itself is never mutated by the spec.
default_options = Langchain::LLM::OpenAI::DEFAULTS.dup
default_options[:temperature] = 0.5
default_options[:token_counter] = :openai

expect(subject.defaults).to eq(default_options)
end
end

# Verifies the fallback path: with no default_options argument, the LLM's
# defaults are exactly Langchain::LLM::OpenAI::DEFAULTS (which now
# includes token_counter: :tiktoken).
context "when default_options are not passed" do
let(:subject) { described_class.new(api_key: "123") }

it "uses the default values" do
expect(subject.defaults).to eq(Langchain::LLM::OpenAI::DEFAULTS)
end
end
end

describe "#embed" do
Expand Down
109 changes: 81 additions & 28 deletions spec/langchain/utils/token_length/openai_validator_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,37 @@
it "raises an error" do
expect {
subject
}.to raise_error(Langchain::Utils::TokenLength::TokenLimitExceeded, "This model's maximum context length is 4097 tokens, but the given text is 45000 tokens long.")
}.to raise_error(
Langchain::Utils::TokenLength::TokenLimitExceeded,
"This model's maximum context length is 4097 tokens, but the given text is 45000 tokens long."
)
end
end

context "when the text is not too long" do
let(:content) { "lorem ipsum" * 100 }
let(:model) { "gpt-4" }

it "does not raise an error" do
expect { subject }.not_to raise_error
context "when the token_counter is tiktoken" do
it "does not raise an error" do
expect { subject }.not_to raise_error
end

it "returns the correct max_tokens" do
expect(subject).to eq(7892)
end
end

it "returns the correct max_tokens" do
expect(subject).to eq(7892)
context "when the token_counter is openai" do
subject { described_class.validate_max_tokens!(content, model, token_counter: :openai) }

it "does not raise an error" do
expect { subject }.not_to raise_error
end

it "returns the correct max_tokens" do
expect(subject).to eq(7987)
end
end
end

Expand Down Expand Up @@ -141,34 +158,70 @@
}]
}

it "returns the correct token length for gpt-3.5-turbo-0301" do
expect(
described_class.token_length_from_messages(example_messages, "gpt-3.5-turbo-0301")
).to eq(127)
end
context "when the token counter is tiktoken" do
it "returns the correct token length for gpt-3.5-turbo-0301" do
expect(
described_class.token_length_from_messages(example_messages, "gpt-3.5-turbo-0301")
).to eq(127)
end

it "returns the correct token length for gpt-3.5-turbo-0613" do
expect(
described_class.token_length_from_messages(example_messages, "gpt-3.5-turbo-0613")
).to eq(129)
end
it "returns the correct token length for gpt-3.5-turbo-0613" do
expect(
described_class.token_length_from_messages(example_messages, "gpt-3.5-turbo-0613")
).to eq(129)
end

it "returns the correct token length for gpt-3.5-turbo" do
expect(
described_class.token_length_from_messages(example_messages, "gpt-3.5-turbo")
).to eq(129)
end
it "returns the correct token length for gpt-3.5-turbo" do
expect(
described_class.token_length_from_messages(example_messages, "gpt-3.5-turbo")
).to eq(129)
end

it "returns the correct token length for gpt-4-0613" do
expect(
described_class.token_length_from_messages(example_messages, "gpt-4-0613")
).to eq(129)
it "returns the correct token length for gpt-4-0613" do
expect(
described_class.token_length_from_messages(example_messages, "gpt-4-0613")
).to eq(129)
end

it "returns the correct token length for gpt-4" do
expect(
described_class.token_length_from_messages(example_messages, "gpt-4")
).to eq(129)
end
end

it "returns the correct token length for gpt-4" do
expect(
described_class.token_length_from_messages(example_messages, "gpt-4")
).to eq(129)
context "when the token counter is openai" do
let(:options) { {token_counter: :openai} }

it "returns the correct token length for gpt-3.5-turbo-0301" do
expect(
described_class.token_length_from_messages(example_messages, "gpt-3.5-turbo-0301", options)
).to eq(141)
end

it "returns the correct token length for gpt-3.5-turbo-0613" do
expect(
described_class.token_length_from_messages(example_messages, "gpt-3.5-turbo-0613", options)
).to eq(143)
end

it "returns the correct token length for gpt-3.5-turbo" do
expect(
described_class.token_length_from_messages(example_messages, "gpt-3.5-turbo", options)
).to eq(143)
end

it "returns the correct token length for gpt-4-0613" do
expect(
described_class.token_length_from_messages(example_messages, "gpt-4-0613", options)
).to eq(143)
end

it "returns the correct token length for gpt-4" do
expect(
described_class.token_length_from_messages(example_messages, "gpt-4", options)
).to eq(143)
end
end
end
end