
Update default max tokens and model selection in AI chat methods
Removed the default max-tokens value and hard-coded model selection from the AI chat-related code. The AI service can now determine the optimal model and the number of tokens from the dialog input, making responses more flexible and adaptable. Default max tokens are now nullable (int?), so the limit can be omitted entirely. The version has also been bumped from 4.0.0-alpha to 4.1.0-alpha.
rodion-m committed Dec 17, 2023
1 parent 78e1ae8 commit 613d5f6
Showing 12 changed files with 82 additions and 97 deletions.
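In practice, call sites can now simply omit maxTokens and let the service decide. A minimal sketch of the new call pattern, assuming the UserMessage type and the OpenAiClient(string apiKey) constructor from this library's public API (the using directives and setup are illustrative, not part of this commit):

using OpenAI.ChatGpt;
using OpenAI.ChatGpt.Models.ChatCompletion;
using OpenAI.ChatGpt.Models.ChatCompletion.Messaging;

var client = new OpenAiClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY")!);

// Before 4.1.0-alpha, maxTokens silently defaulted to
// ChatCompletionRequest.MaxTokensDefault (64), truncating answers unless
// a limit was passed explicitly. Now maxTokens is int? and defaults to
// null, so max_tokens is omitted from the request and the service decides:
string answer = await client.GetChatCompletions(
    new UserMessage("Summarize this commit in one paragraph."));

// An explicit cap still works exactly as before:
string capped = await client.GetChatCompletions(
    new UserMessage("Answer in a single word."),
    maxTokens: 16);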
2 changes: 1 addition & 1 deletion src/Directory.Build.props
@@ -1,6 +1,6 @@
<Project>
<PropertyGroup>
- <Version>4.0.0-alpha</Version>
+ <Version>4.1.0-alpha</Version>
<TargetFrameworks>net6.0;net7.0;net8.0</TargetFrameworks>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
21 changes: 16 additions & 5 deletions src/OpenAI.ChatGpt.AspNetCore/AiClientFromConfiguration.cs
@@ -47,7 +47,7 @@ private static void ThrowUnkownProviderException(string provider)
/// <inheritdoc />
public Task<string> GetChatCompletions(
UserOrSystemMessage dialog,
- int maxTokens = ChatCompletionRequest.MaxTokensDefault,
+ int? maxTokens = null,
string model = ChatCompletionModels.Default, float temperature = ChatCompletionTemperatures.Default,
string? user = null, bool jsonMode = false, long? seed = null,
Action<ChatCompletionRequest>? requestModifier = null,
@@ -60,7 +60,7 @@ private static void ThrowUnkownProviderException(string provider)
/// <inheritdoc />
public Task<string> GetChatCompletions(
IEnumerable<ChatCompletionMessage> messages,
- int maxTokens = ChatCompletionRequest.MaxTokensDefault,
+ int? maxTokens = null,
string model = ChatCompletionModels.Default, float temperature = ChatCompletionTemperatures.Default,
string? user = null, bool jsonMode = false, long? seed = null,
Action<ChatCompletionRequest>? requestModifier = null,
@@ -73,7 +73,7 @@ private static void ThrowUnkownProviderException(string provider)
/// <inheritdoc />
public Task<ChatCompletionResponse> GetChatCompletionsRaw(
IEnumerable<ChatCompletionMessage> messages,
- int maxTokens = ChatCompletionRequest.MaxTokensDefault,
+ int? maxTokens = null,
string model = ChatCompletionModels.Default, float temperature = ChatCompletionTemperatures.Default,
string? user = null, bool jsonMode = false, long? seed = null,
Action<ChatCompletionRequest>? requestModifier = null,
@@ -86,7 +86,7 @@ private static void ThrowUnkownProviderException(string provider)
/// <inheritdoc />
public IAsyncEnumerable<string> StreamChatCompletions(
IEnumerable<ChatCompletionMessage> messages,
- int maxTokens = ChatCompletionRequest.MaxTokensDefault,
+ int? maxTokens = null,
string model = ChatCompletionModels.Default, float temperature = ChatCompletionTemperatures.Default,
string? user = null, bool jsonMode = false, long? seed = null,
Action<ChatCompletionRequest>? requestModifier = null,
@@ -99,7 +99,7 @@ private static void ThrowUnkownProviderException(string provider)
/// <inheritdoc />
public IAsyncEnumerable<string> StreamChatCompletions(
UserOrSystemMessage messages,
- int maxTokens = ChatCompletionRequest.MaxTokensDefault, string model = ChatCompletionModels.Default,
+ int? maxTokens = null, string model = ChatCompletionModels.Default,
float temperature = ChatCompletionTemperatures.Default, string? user = null, bool jsonMode = false,
long? seed = null, Action<ChatCompletionRequest>? requestModifier = null,
CancellationToken cancellationToken = default)
@@ -108,6 +108,12 @@ private static void ThrowUnkownProviderException(string provider)
messages, maxTokens, model, temperature, user, jsonMode, seed, requestModifier, cancellationToken);
}

+ /// <inheritdoc />
+ public int? GetDefaultMaxTokens(string model)
+ {
+     return _client.GetDefaultMaxTokens(model);
+ }

/// <inheritdoc />
public IAsyncEnumerable<string> StreamChatCompletions(
ChatCompletionRequest request,
@@ -124,5 +130,10 @@ private static void ThrowUnkownProviderException(string provider)
return _client.StreamChatCompletionsRaw(request, cancellationToken);
}

+ public string GetOptimalModel(ChatCompletionMessage[] messages)
+ {
+     return _client.GetOptimalModel(messages);
+ }

internal IAiClient GetInnerClient() => _client;
}
16 changes: 4 additions & 12 deletions src/OpenAI.ChatGpt/ChatService.cs
@@ -102,11 +102,10 @@ public async ValueTask DisposeAsync()
IsWriting = true;
try
{
- var (model, maxTokens) = FindOptimalModelAndMaxToken(messages);
var response = await _client.GetChatCompletionsRaw(
messages,
- maxTokens: maxTokens,
- model: model,
+ maxTokens: Topic.Config.MaxTokens,
+ model: Topic.Config.Model ?? _client.GetOptimalModel(message),
user: Topic.Config.PassUserIdToOpenAiRequests is true ? UserId : null,
requestModifier: Topic.Config.ModifyRequest,
cancellationToken: cancellationToken
@@ -125,12 +124,6 @@ public async ValueTask DisposeAsync()
}
}

- private (string model, int maxTokens) FindOptimalModelAndMaxToken(ChatCompletionMessage[] messages)
- {
-     return ChatCompletionMessage.FindOptimalModelAndMaxToken(
-         messages, Topic.Config.Model, Topic.Config.MaxTokens);
- }

public IAsyncEnumerable<string> StreamNextMessageResponse(
string message,
bool throwOnCancellation = true,
@@ -159,11 +152,10 @@ public async ValueTask DisposeAsync()
var messages = history.Append(message).ToArray();
var sb = new StringBuilder();
IsWriting = true;
- var (model, maxTokens) = FindOptimalModelAndMaxToken(messages);
var stream = _client.StreamChatCompletions(
messages,
- maxTokens: maxTokens,
- model: model,
+ maxTokens: Topic.Config.MaxTokens,
+ model: Topic.Config.Model ?? _client.GetOptimalModel(message),
user: Topic.Config.PassUserIdToOpenAiRequests is true ? UserId : null,
requestModifier: Topic.Config.ModifyRequest,
cancellationToken: cancellationToken
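With this change, ChatService no longer computes a (model, maxTokens) pair itself: it forwards Topic.Config.MaxTokens as-is (possibly null) and falls back to IAiClient.GetOptimalModel only when Topic.Config.Model is null. A sketch of the two configuration modes, assuming ChatGPTConfig supports object-initializer syntax (illustrative, not shown in this diff):

// Fully explicit: the client never second-guesses these values.
var pinned = new ChatGPTConfig
{
    Model = ChatCompletionModels.Default,
    MaxTokens = 512
};

// Fully delegated: a null Model routes through GetOptimalModel, and a
// null MaxTokens leaves max_tokens out of the request entirely.
var delegated = new ChatGPTConfig
{
    Model = null,
    MaxTokens = null
};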
24 changes: 19 additions & 5 deletions src/OpenAI.ChatGpt/IAiClient.cs
@@ -8,6 +8,17 @@ namespace OpenAI.ChatGpt;
/// </summary>
public interface IAiClient
{
+ /// <summary>
+ /// Retrieves the default maximum number of tokens for a given model.
+ /// </summary>
+ /// <param name="model">
+ /// The model name for which to retrieve the maximum number of tokens.
+ /// </param>
+ /// <returns>
+ /// The default maximum number of tokens, or null if the decision is delegated to the AI service.
+ /// </returns>
+ int? GetDefaultMaxTokens(string model);

/// <summary>
/// Get a chat completion response as a string
/// </summary>
@@ -41,7 +52,7 @@ public interface IAiClient
/// <returns>The chat completion response as a string</returns>
Task<string> GetChatCompletions(
UserOrSystemMessage dialog,
- int maxTokens = ChatCompletionRequest.MaxTokensDefault,
+ int? maxTokens = null,
string model = ChatCompletionModels.Default,
float temperature = ChatCompletionTemperatures.Default,
string? user = null,
@@ -84,7 +95,7 @@ public interface IAiClient
/// <returns>The chat completion response as a string</returns>
Task<string> GetChatCompletions(
IEnumerable<ChatCompletionMessage> messages,
- int maxTokens = ChatCompletionRequest.MaxTokensDefault,
+ int? maxTokens = null,
string model = ChatCompletionModels.Default,
float temperature = ChatCompletionTemperatures.Default,
string? user = null,
@@ -126,7 +137,7 @@ public interface IAiClient
/// <returns>The raw chat completion response</returns>
Task<ChatCompletionResponse> GetChatCompletionsRaw(
IEnumerable<ChatCompletionMessage> messages,
- int maxTokens = ChatCompletionRequest.MaxTokensDefault,
+ int? maxTokens = null,
string model = ChatCompletionModels.Default,
float temperature = ChatCompletionTemperatures.Default,
string? user = null,
@@ -167,7 +178,7 @@ public interface IAiClient
/// <returns>Chunks of LLM's response, one by one.</returns>
IAsyncEnumerable<string> StreamChatCompletions(
IEnumerable<ChatCompletionMessage> messages,
- int maxTokens = ChatCompletionRequest.MaxTokensDefault,
+ int? maxTokens = null,
string model = ChatCompletionModels.Default,
float temperature = ChatCompletionTemperatures.Default,
string? user = null,
@@ -201,7 +212,7 @@ public interface IAiClient
/// <returns>Chunks of LLM's response, one by one</returns>
IAsyncEnumerable<string> StreamChatCompletions(
UserOrSystemMessage messages,
- int maxTokens = ChatCompletionRequest.MaxTokensDefault,
+ int? maxTokens = null,
string model = ChatCompletionModels.Default,
float temperature = ChatCompletionTemperatures.Default,
string? user = null,
@@ -227,4 +238,7 @@ public interface IAiClient
/// <returns>A stream of raw chat completion responses</returns>
IAsyncEnumerable<ChatCompletionResponse> StreamChatCompletionsRaw(
ChatCompletionRequest request, CancellationToken cancellationToken = default);

+ string GetOptimalModel(ChatCompletionMessage[] messages);
+ string GetOptimalModel(UserOrSystemMessage dialog) => GetOptimalModel(dialog.GetMessages().ToArray());
}
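The interface now expresses the model policy as two members, with a default interface implementation bridging UserOrSystemMessage to the array overload. For implementations that want to restore the removed size-based heuristic, a sketch (the 6000-token threshold mirrors the old FindOptimalModelAndMaxToken code; the helper class itself is illustrative):

// Helper mirroring the removed heuristic: switch to the 16k-context model
// once the prompt grows large, otherwise stay on the default model.
public static class SizeBasedModelPolicy
{
    public static string ChooseModel(ChatCompletionMessage[] messages)
    {
        var approxTokens = ChatCompletionMessage.CalculateApproxTotalTokenCount(messages);
        return approxTokens > 6000
            ? ChatCompletionModels.Gpt3_5_Turbo_16k
            : ChatCompletionModels.Default;
    }

    // Returning null delegates the token budget to the AI service,
    // matching the new library-wide default.
    public static int? ChooseMaxTokens(string model) => null;
}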
@@ -14,9 +14,6 @@ namespace OpenAI.ChatGpt.Models.ChatCompletion;
/// </remarks>
public class ChatCompletionRequest
{
- public const int MaxTokensDefault = 64;
-
- private int _maxTokens = MaxTokensDefault;
private string _model = ChatCompletionModels.Default;
private float _temperature = ChatCompletionTemperatures.Default;
private IEnumerable<ChatCompletionMessage> _messages;
@@ -87,7 +84,6 @@ public float Temperature

/// <summary>
/// The maximum number of tokens allowed for the generated answer.
- /// Defaults to <see cref="MaxTokensDefault"/>.
/// This value is validated and limited with <see cref="ChatCompletionModels.GetMaxTokensLimitForModel"/> method.
/// It's possible to calculate approximately tokens count using <see cref="ChatCompletionMessage.CalculateApproxTotalTokenCount()"/> method.
/// </summary>
@@ -98,15 +94,7 @@ public float Temperature
/// Encoding algorithm can be found here: https://github.com/latitudegames/GPT-3-Encoder
/// </remarks>
[JsonPropertyName("max_tokens")]
- public int MaxTokens
- {
-     get => _maxTokens;
-     set
-     {
-         ChatCompletionModels.EnsureMaxTokensIsSupported(Model, value);
-         _maxTokens = value;
-     }
- }
+ public int? MaxTokens { get; set; } = null;

/// <summary>
/// Number between -2.0 and 2.0.
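Note that replacing the MaxTokens property with a plain nullable auto-property also drops the setter-side validation against the model's token limit. Callers who relied on that guard can reinstate it explicitly; a sketch, assuming ChatCompletionModels.EnsureMaxTokensIsSupported stays public and the UserMessage type from this library:

var messages = new[] { new UserMessage("Hello!") };
var request = new ChatCompletionRequest(messages)
{
    Model = ChatCompletionModels.Default,
    MaxTokens = 4096 // no longer validated by the setter
};

// Reinstate the old guard manually whenever a limit is set:
if (request.MaxTokens is int maxTokens)
{
    ChatCompletionModels.EnsureMaxTokensIsSupported(request.Model, maxTokens);
}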
@@ -113,34 +113,4 @@ public override string ToString()
? $"{Role}: {Content}"
: string.Join(Environment.NewLine, _messages.Select(m => $"{m.Role}: {m.Content}"));
}

- public static (string model, int maxTokens) FindOptimalModelAndMaxToken(
-     IEnumerable<ChatCompletionMessage> messages,
-     string? model,
-     int? maxTokens,
-     string smallModel = ChatCompletionModels.Default,
-     string bigModel = ChatCompletionModels.Gpt3_5_Turbo_16k,
-     bool useMaxPossibleTokens = true)
- {
-     var tokenCount = CalculateApproxTotalTokenCount(messages);
-     switch (model, maxTokens)
-     {
-         case (null, null):
-         {
-             model = tokenCount > 6000 ? bigModel : smallModel;
-             maxTokens = GetMaxPossibleTokens(model);
-             break;
-         }
-         case (null, _):
-             model = smallModel;
-             break;
-         case (_, null):
-             maxTokens = useMaxPossibleTokens ? GetMaxPossibleTokens(model) : ChatCompletionRequest.MaxTokensDefault;
-             break;
-     }
-
-     return (model, maxTokens.Value);
-
-     int GetMaxPossibleTokens(string s) => ChatCompletionModels.GetMaxTokensLimitForModel(s) - tokenCount - 500;
- }
}
1 change: 0 additions & 1 deletion src/OpenAI.ChatGpt/Models/ChatGPTConfig.cs
@@ -56,7 +56,6 @@ public class ChatGPTConfig

/// <summary>
/// The maximum number of tokens allowed for the generated answer.
- /// Defaults to <see cref="ChatCompletionRequest.MaxTokensDefault"/>.
/// This value is validated and limited with <see cref="ChatCompletionModels.GetMaxTokensLimitForModel"/> method.
/// It's possible to calculate approximately tokens count using <see cref="ChatCompletionMessage.CalculateApproxTotalTokenCount()"/> method.
/// Maps to: <see cref="ChatCompletionRequest.MaxTokens"/>
30 changes: 21 additions & 9 deletions src/OpenAI.ChatGpt/OpenAiClient.cs
@@ -152,7 +152,7 @@ private static Uri ValidateHost(string? host)
/// <inheritdoc />
public async Task<string> GetChatCompletions(
UserOrSystemMessage dialog,
- int maxTokens = ChatCompletionRequest.MaxTokensDefault,
+ int? maxTokens = null,
string model = ChatCompletionModels.Default,
float temperature = ChatCompletionTemperatures.Default,
string? user = null,
@@ -185,7 +185,7 @@ private static Uri ValidateHost(string? host)
/// <inheritdoc />
public async Task<string> GetChatCompletions(
IEnumerable<ChatCompletionMessage> messages,
- int maxTokens = ChatCompletionRequest.MaxTokensDefault,
+ int? maxTokens = null,
string model = ChatCompletionModels.Default,
float temperature = ChatCompletionTemperatures.Default,
string? user = null,
@@ -218,7 +218,7 @@ private static Uri ValidateHost(string? host)
/// <inheritdoc />
public async Task<ChatCompletionResponse> GetChatCompletionsRaw(
IEnumerable<ChatCompletionMessage> messages,
- int maxTokens = ChatCompletionRequest.MaxTokensDefault,
+ int? maxTokens = null,
string model = ChatCompletionModels.Default,
float temperature = ChatCompletionTemperatures.Default,
string? user = null,
@@ -277,7 +277,7 @@ protected virtual string GetChatCompletionsEndpoint()
/// <inheritdoc />
public IAsyncEnumerable<string> StreamChatCompletions(
IEnumerable<ChatCompletionMessage> messages,
- int maxTokens = ChatCompletionRequest.MaxTokensDefault,
+ int? maxTokens = null,
string model = ChatCompletionModels.Default,
float temperature = ChatCompletionTemperatures.Default,
string? user = null,
@@ -304,9 +304,9 @@ protected virtual string GetChatCompletionsEndpoint()
return StreamChatCompletions(request, cancellationToken);
}

- private static ChatCompletionRequest CreateChatCompletionRequest(
+ private ChatCompletionRequest CreateChatCompletionRequest(
IEnumerable<ChatCompletionMessage> messages,
- int maxTokens,
+ int? maxTokens,
string model,
float temperature,
string? user,
@@ -316,6 +316,7 @@ protected virtual string GetChatCompletionsEndpoint()
Action<ChatCompletionRequest>? requestModifier)
{
ArgumentNullException.ThrowIfNull(messages);
+ maxTokens ??= GetDefaultMaxTokens(model);
var request = new ChatCompletionRequest(messages)
{
Model = model,
@@ -330,10 +331,15 @@ protected virtual string GetChatCompletionsEndpoint()
return request;
}

+ public int? GetDefaultMaxTokens(string model)
+ {
+     return null;
+ }

/// <inheritdoc />
public IAsyncEnumerable<string> StreamChatCompletions(
UserOrSystemMessage messages,
- int maxTokens = ChatCompletionRequest.MaxTokensDefault,
+ int? maxTokens = null,
string model = ChatCompletionModels.Default,
float temperature = ChatCompletionTemperatures.Default,
string? user = null,
@@ -346,7 +352,8 @@ protected virtual string GetChatCompletionsEndpoint()
if (model == null) throw new ArgumentNullException(nameof(model));
EnsureJsonModeIsSupported(model, jsonMode);
ThrowIfDisposed();
- var request = CreateChatCompletionRequest(messages.GetMessages(),
+ var request = CreateChatCompletionRequest(
+     messages.GetMessages(),
maxTokens,
model,
temperature,
@@ -393,7 +400,12 @@ await foreach (var response in StreamChatCompletionsRaw(request, cancellationTok
cancellationToken
);
}


+ public string GetOptimalModel(ChatCompletionMessage[] messages)
+ {
+     return ChatCompletionModels.Gpt4Turbo;
+ }

private static void EnsureJsonModeIsSupported(string model, bool jsonMode)
{
if(jsonMode && !ChatCompletionModels.IsJsonModeSupported(model))
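As wired here, OpenAiClient.GetDefaultMaxTokens returns null for every model and GetOptimalModel always answers ChatCompletionModels.Gpt4Turbo, so CreateChatCompletionRequest only sets max_tokens when the caller does. A quick streaming call under the new defaults (client setup as in the earlier sketch):

// No maxTokens argument: the request goes out without max_tokens and the
// service decides when to stop generating.
await foreach (var chunk in client.StreamChatCompletions(
                   new UserMessage("Write a haiku about version bumps."),
                   model: ChatCompletionModels.Gpt4Turbo))
{
    Console.Write(chunk);
}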
@@ -11,7 +11,8 @@ public static GeneratedOpenAiClient CreateGeneratedOpenAiClient(HttpClient httpC
ArgumentNullException.ThrowIfNull(httpClient);
var authProvider = new AnonymousAuthenticationProvider();
var adapter = new HttpClientRequestAdapter(authProvider, httpClient: httpClient);
- return new GeneratedOpenAiClient(adapter);
+ var openAiClient = new GeneratedOpenAiClient(adapter);
+ return openAiClient;
}

public static GeneratedAzureOpenAiClient CreateGeneratedAzureOpenAiClient(HttpClient httpClient)
@@ -114,17 +114,11 @@ public static class OpenAiClientExtensions
{
editMsg.Content += GetAdditionalJsonResponsePrompt(responseFormat, examples, jsonSerializerOptions);

- (model, maxTokens) = FindOptimalModelAndMaxToken(
-     dialog.GetMessages(),
-     model,
-     maxTokens,
-     smallModel: ChatCompletionModels.Gpt4,
-     bigModel: ChatCompletionModels.Gpt4
- );
+ model ??= client.GetOptimalModel(dialog);

var response = await client.GetChatCompletions(
dialog,
- maxTokens.Value,
+ maxTokens,
model,
temperature,
user,
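The structured-response extensions previously pinned both smallModel and bigModel to GPT-4; now a null model is resolved through client.GetOptimalModel(dialog) and maxTokens is forwarded as-is. A hypothetical call — the GetStructuredResponse name and the Dialog.StartAsUser helper are assumptions about this repo's surrounding API, not shown in this diff:

// Hypothetical usage: with no model argument, the extension now resolves
// the model via GetOptimalModel instead of hard-coding GPT-4.
record Haiku(string Line1, string Line2, string Line3);

var haiku = await client.GetStructuredResponse<Haiku>(
    Dialog.StartAsUser("Compose a haiku about nullable max tokens."));
Console.WriteLine(haiku);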
