Allow passing in max_tokens: to Langchain::LLM::Azure #404

Open · wants to merge 2 commits into base: main
lib/langchain/llm/azure.rb (2 changes: 1 addition, 1 deletion)
@@ -64,7 +64,7 @@ def complete(prompt:, **params)
  parameters = compose_parameters @defaults[:completion_model_name], params

  parameters[:messages] = compose_chat_messages(prompt: prompt)
- parameters[:max_tokens] = validate_max_tokens(parameters[:messages], parameters[:model])
+ parameters[:max_tokens] = params[:max_tokens] || validate_max_tokens(parameters[:messages], parameters[:model])
Review comment:
I think this is an improvement; however, I think the following might be better:

validate_max_tokens(parameters[:messages], parameters[:model]) # raises exception if maximum exceeded
parameters[:max_tokens] = params[:max_tokens] if params[:max_tokens]

This should still perform the validation and only pass max_tokens to the OpenAI Client if it is set as an argument to this method.
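
For context, the surrounding `complete` method would then look roughly like this (a sketch only, reusing the helper names from the diff above; `validate_max_tokens` is assumed to raise when the prompt already exceeds the model's context window):

```ruby
# Sketch of the suggested flow: always validate, but only forward max_tokens
# when the caller supplied one.
def complete(prompt:, **params)
  parameters = compose_parameters @defaults[:completion_model_name], params
  parameters[:messages] = compose_chat_messages(prompt: prompt)

  validate_max_tokens(parameters[:messages], parameters[:model]) # raises if the context window is exceeded
  parameters[:max_tokens] = params[:max_tokens] if params[:max_tokens]

  with_api_error_handling { chat_client.chat(parameters: parameters) }
end
```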

Collaborator Author:

Here's another approach, which I'd forgotten, that @bricolage tackled: https://github.com/andreibondarev/langchainrb/pull/388/files

validate_max_tokens() accepts a third argument, the user passed-in max_tokens, and then selects the smaller (min) of the allowed max_tokens and the passed-in one.

Thoughts on this approach?
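
For illustration, the idea is roughly this (a sketch only, not the actual #388 diff; `token_limit` and `token_count` are hypothetical stand-ins for the gem's token-length utilities):

```ruby
# Sketch: validate_max_tokens takes the caller's max_tokens as a third argument
# and returns the smaller of that value and what still fits in the context.
def validate_max_tokens(messages, model, user_max_tokens = nil)
  remaining = token_limit(model) - token_count(messages, model) # hypothetical helpers
  raise "Prompt exceeds the #{model} context window" if remaining <= 0

  user_max_tokens ? [remaining, user_max_tokens].min : remaining
end
```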

Review comment:

I think the problem with that approach is the case where max_tokens is not passed in: it sets max_tokens to the maximum number of tokens supported by the model, and Azure will then consume up to that many tokens.

In that scenario I think it would be better not to set anything, so that the Azure defaults are used:
https://learn.microsoft.com/en-us/azure/ai-services/openai/reference
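
Concretely (illustration only, the numbers are made up), the difference is whether an explicit cap gets sent at all:

```ruby
# Illustration only; 4_052 is a made-up "remaining context" figure.
messages = [{role: "user", content: "Hello World"}]

# Min-based fallback: an explicit cap is always sent, even when the caller
# passed no max_tokens.
with_fallback = {model: "gpt-3.5-turbo", messages: messages, max_tokens: 4_052}

# Omitting the key lets the Azure service defaults apply instead.
without_key = {model: "gpt-3.5-turbo", messages: messages}
```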

Collaborator Author:

We could also skip the whole thing by adding a skip_max_tokens_validation true/false flag to each LLM class:

llm = Langchain::LLM::Azure.new ...
llm.skip_max_tokens_validation = true
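
Roughly (a sketch only; skip_max_tokens_validation is not an existing option):

```ruby
# Sketch of the opt-out flag idea; skip_max_tokens_validation does not exist today.
attr_accessor :skip_max_tokens_validation

def complete(prompt:, **params)
  parameters = compose_parameters @defaults[:completion_model_name], params
  parameters[:messages] = compose_chat_messages(prompt: prompt)

  # Only compute and force max_tokens when validation hasn't been opted out of.
  unless skip_max_tokens_validation
    parameters[:max_tokens] = validate_max_tokens(parameters[:messages], parameters[:model])
  end

  with_api_error_handling { chat_client.chat(parameters: parameters) }
end
```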

Review comment:

That would provide a solution, but I don't think it's ideal. I think we always want to run the validation part of the validate_max_tokens method, because it's good to know whether the payload exceeds the maximum supported context before making a request.

I think there are 4 use cases we are trying to support:

  1. No configuration: max_tokens is not set, so the Azure defaults are used.
  2. A default max_tokens set on the LLM, used for all requests.
  3. A per-request max_tokens override (e.g. via the complete or chat methods).
  4. (Current behaviour) max_tokens set to the remaining tokens in the context window.

I think the following would satisfy these, but it's messy:

llm = Langchain::LLM::Azure.new ...
llm.default_max_tokens = 300 # use 300 tokens by default in request
# or
llm.default_max_tokens = :max # use all tokens available in request context.
# or (not set)
llm.default_max_tokens = nil 
...

def complete(max_tokens: 300) # or max_tokens: :max
  ...
  remaining_tokens = validate_max_tokens(parameters[:messages], parameters[:model]) # raises exception if maximum exceeded

  max_tokens = params.fetch(:max_tokens, llm.default_max_tokens) # if not passed in, use the llm default max_tokens

  parameters[:max_tokens] = remaining_tokens if max_tokens == :max
  parameters[:max_tokens] ||= max_tokens if max_tokens
  ...
  # call OpenAI Client with parameters
end
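
Usage under that sketch would then look something like this (default_max_tokens and the :max sentinel are part of the sketch above, not the current API; constructor options are elided):

```ruby
# Covers the four cases above; default_max_tokens and :max are proposal-only.
llm = Langchain::LLM::Azure.new(api_key: ENV["AZURE_OPENAI_API_KEY"])

llm.complete(prompt: "Hi")                 # 1. no configuration: Azure defaults apply
llm.default_max_tokens = 300
llm.complete(prompt: "Hi")                 # 2. per-LLM default: max_tokens 300 on every request
llm.complete(prompt: "Hi", max_tokens: 50) # 3. per-request override: max_tokens 50
llm.default_max_tokens = :max
llm.complete(prompt: "Hi")                 # 4. current behaviour: all remaining context tokens
```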


  response = with_api_error_handling do
    chat_client.chat(parameters: parameters)
spec/langchain/llm/azure_spec.rb (10 changes: 8 additions, 2 deletions)
@@ -224,12 +224,18 @@

context "with prompt and parameters" do
let(:parameters) do
{parameters: {n: 1, model: "gpt-3.5-turbo", messages: [{content: "Hello World", role: "user"}], temperature: 1.0, max_tokens: 4086}}
{parameters: {n: 1, model: "gpt-3.5-turbo", messages: [{content: "Hello World", role: "user"}], temperature: 1.0, max_tokens: 16}}
end

it "returns a completion" do
response = subject.complete(prompt: "Hello World", model: "gpt-3.5-turbo", temperature: 1.0)
response = subject.complete(
prompt: "Hello World",
model: "gpt-3.5-turbo",
temperature: 1.0,
max_tokens: 16 # `max_tokens` can be passed in and overwritten.
)

expect(response.completion_tokens).to eq(16)
expect(response.completion).to eq("The meaning of life is subjective and can vary from person to person.")
end
end