Relocate test and refine structure for ai-client multi-provider support
This commit relocates ChatGptEntityFrameworkIntegrationTests from the IntegrationTests project to the UnitTests project to prevent parallel execution. It also introduces minor structural changes, validating base addresses and authorization headers where appropriate, and adds a new AddAiClient method so that AI-client registration can be reused across services. The refactoring improves support for multiple providers and covers adjustments to variable lifetimes, argument ordering, and exception messages, as well as the encapsulation of shared setup logic, all in the interest of better code design and maintainability.

To keep the design structured and ready for likely future changes, the commit removes unused method arguments, adjusts boolean flags, and refactors and overloads several methods. It also changes how the IConfiguration object flows through the registration extensions and relocates AzureOpenAiClient's settings validation, improving code organization, clarity, and logic flow.
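
A minimal consumer-side sketch of the reworked registration flow (assuming an ASP.NET Core host; only the AddChatGptInMemoryIntegration signature is taken from this commit, the host wiring and values are illustrative):

// Sketch only: the IConfiguration instance is now passed to the integration
// extensions explicitly instead of being resolved behind the scenes.
using OpenAI.ChatGpt.AspNetCore.Extensions;

var builder = WebApplication.CreateBuilder(args);

builder.Services.AddChatGptInMemoryIntegration(
    builder.Configuration,                  // new required IConfiguration argument
    injectInMemoryChatService: true,        // default per the updated signature
    validateAiClientProviderOnStart: true); // registers the startup validation hosted service

var app = builder.Build();
app.Run();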
rodion-m committed Dec 5, 2023
1 parent 86fd5b7 commit 78e1ae8
Showing 14 changed files with 310 additions and 141 deletions.
4 changes: 3 additions & 1 deletion samples/ChatGpt.TelegramBotExample/Helpers.cs
@@ -1,4 +1,6 @@
namespace ChatGpt.TelegramBotExample;
using OpenAI.ChatGpt;

namespace ChatGpt.TelegramBotExample;

public static class Helpers
{
2 changes: 1 addition & 1 deletion src/Directory.Build.props
@@ -1,6 +1,6 @@
<Project>
<PropertyGroup>
<Version>4.0.0</Version>
<Version>4.0.0-alpha</Version>
<TargetFrameworks>net6.0;net7.0;net8.0</TargetFrameworks>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
27 changes: 18 additions & 9 deletions src/OpenAI.ChatGpt.AspNetCore/AiClientFromConfiguration.cs
@@ -45,7 +45,8 @@ private static void ThrowUnkownProviderException(string provider)
}

/// <inheritdoc />
public Task<string> GetChatCompletions(UserOrSystemMessage dialog,
public Task<string> GetChatCompletions(
UserOrSystemMessage dialog,
int maxTokens = ChatCompletionRequest.MaxTokensDefault,
string model = ChatCompletionModels.Default, float temperature = ChatCompletionTemperatures.Default,
string? user = null, bool jsonMode = false, long? seed = null,
@@ -57,7 +58,8 @@ private static void ThrowUnkownProviderException(string provider)
}

/// <inheritdoc />
public Task<string> GetChatCompletions(IEnumerable<ChatCompletionMessage> messages,
public Task<string> GetChatCompletions(
IEnumerable<ChatCompletionMessage> messages,
int maxTokens = ChatCompletionRequest.MaxTokensDefault,
string model = ChatCompletionModels.Default, float temperature = ChatCompletionTemperatures.Default,
string? user = null, bool jsonMode = false, long? seed = null,
@@ -69,7 +71,8 @@ private static void ThrowUnkownProviderException(string provider)
}

/// <inheritdoc />
public Task<ChatCompletionResponse> GetChatCompletionsRaw(IEnumerable<ChatCompletionMessage> messages,
public Task<ChatCompletionResponse> GetChatCompletionsRaw(
IEnumerable<ChatCompletionMessage> messages,
int maxTokens = ChatCompletionRequest.MaxTokensDefault,
string model = ChatCompletionModels.Default, float temperature = ChatCompletionTemperatures.Default,
string? user = null, bool jsonMode = false, long? seed = null,
@@ -81,7 +84,8 @@ private static void ThrowUnkownProviderException(string provider)
}

/// <inheritdoc />
public IAsyncEnumerable<string> StreamChatCompletions(IEnumerable<ChatCompletionMessage> messages,
public IAsyncEnumerable<string> StreamChatCompletions(
IEnumerable<ChatCompletionMessage> messages,
int maxTokens = ChatCompletionRequest.MaxTokensDefault,
string model = ChatCompletionModels.Default, float temperature = ChatCompletionTemperatures.Default,
string? user = null, bool jsonMode = false, long? seed = null,
@@ -93,27 +97,32 @@ private static void ThrowUnkownProviderException(string provider)
}

/// <inheritdoc />
public IAsyncEnumerable<string> StreamChatCompletions(UserOrSystemMessage messages,
public IAsyncEnumerable<string> StreamChatCompletions(
UserOrSystemMessage messages,
int maxTokens = ChatCompletionRequest.MaxTokensDefault, string model = ChatCompletionModels.Default,
float temperature = ChatCompletionTemperatures.Default, string? user = null, bool jsonMode = false,
long? seed = null, Action<ChatCompletionRequest>? requestModifier = null,
CancellationToken cancellationToken = default)
{
return _client.StreamChatCompletions(messages, maxTokens, model, temperature, user, jsonMode, seed,
requestModifier, cancellationToken);
return _client.StreamChatCompletions(
messages, maxTokens, model, temperature, user, jsonMode, seed, requestModifier, cancellationToken);
}

/// <inheritdoc />
public IAsyncEnumerable<string> StreamChatCompletions(ChatCompletionRequest request,
public IAsyncEnumerable<string> StreamChatCompletions(
ChatCompletionRequest request,
CancellationToken cancellationToken = default)
{
return _client.StreamChatCompletions(request, cancellationToken);
}

/// <inheritdoc />
public IAsyncEnumerable<ChatCompletionResponse> StreamChatCompletionsRaw(ChatCompletionRequest request,
public IAsyncEnumerable<ChatCompletionResponse> StreamChatCompletionsRaw(
ChatCompletionRequest request,
CancellationToken cancellationToken = default)
{
return _client.StreamChatCompletionsRaw(request, cancellationToken);
}

internal IAiClient GetInnerClient() => _client;
}
@@ -4,11 +4,9 @@ namespace OpenAI.ChatGpt.AspNetCore;

internal class AiClientStartupValidationBackgroundService : BackgroundService
{
private readonly AiClientFromConfiguration _aiClient;

public AiClientStartupValidationBackgroundService(AiClientFromConfiguration aiClient)
public AiClientStartupValidationBackgroundService(AiClientFromConfiguration _)
{
_aiClient = aiClient ?? throw new ArgumentNullException(nameof(aiClient));
}

protected override Task ExecuteAsync(CancellationToken stoppingToken) => Task.CompletedTask;
@@ -1,5 +1,5 @@
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Options;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;

namespace OpenAI.ChatGpt.AspNetCore.Extensions;

@@ -14,12 +14,16 @@ public static class ServiceCollectionExtensions

public static IServiceCollection AddChatGptInMemoryIntegration(
this IServiceCollection services,
IConfiguration configuration,
bool injectInMemoryChatService = true,
string credentialsConfigSectionPath = OpenAiCredentialsConfigSectionPathDefault,
string completionsConfigSectionPath = ChatGPTConfigSectionPathDefault,
string credentialsConfigSectionPath = OpenAiCredentialsConfigSectionPathDefault,
string azureOpenAiCredentialsConfigSectionPath = AzureOpenAiCredentialsConfigSectionPathDefault,
string openRouterCredentialsConfigSectionPath = OpenRouterCredentialsConfigSectionPathDefault,
bool validateAiClientProviderOnStart = true)
{
ArgumentNullException.ThrowIfNull(services);
ArgumentNullException.ThrowIfNull(configuration);
if (string.IsNullOrWhiteSpace(credentialsConfigSectionPath))
{
throw new ArgumentException("Value cannot be null or whitespace.",
@@ -32,15 +36,29 @@ public static class ServiceCollectionExtensions
nameof(completionsConfigSectionPath));
}

if (string.IsNullOrWhiteSpace(azureOpenAiCredentialsConfigSectionPath))
{
throw new ArgumentException("Value cannot be null or whitespace.",
nameof(azureOpenAiCredentialsConfigSectionPath));
}
if (string.IsNullOrWhiteSpace(openRouterCredentialsConfigSectionPath))
{
throw new ArgumentException("Value cannot be null or whitespace.",
nameof(openRouterCredentialsConfigSectionPath));
}

services.AddSingleton<IChatHistoryStorage, InMemoryChatHistoryStorage>();
if (injectInMemoryChatService)
{
services.AddScoped<ChatService>(CreateChatService);
}

return services.AddChatGptIntegrationCore(
credentialsConfigSectionPath: credentialsConfigSectionPath,
configuration,
completionsConfigSectionPath: completionsConfigSectionPath,
credentialsConfigSectionPath: credentialsConfigSectionPath,
azureOpenAiCredentialsConfigSectionPath,
openRouterCredentialsConfigSectionPath,
validateAiClientProviderOnStart: validateAiClientProviderOnStart
);
}
@@ -69,14 +87,16 @@ private static ChatService CreateChatService(IServiceProvider provider)
}

public static IServiceCollection AddChatGptIntegrationCore(this IServiceCollection services,
string credentialsConfigSectionPath = OpenAiCredentialsConfigSectionPathDefault,
IConfiguration configuration,
string completionsConfigSectionPath = ChatGPTConfigSectionPathDefault,
string credentialsConfigSectionPath = OpenAiCredentialsConfigSectionPathDefault,
string azureOpenAiCredentialsConfigSectionPath = AzureOpenAiCredentialsConfigSectionPathDefault,
string openRouterCredentialsConfigSectionPath = OpenRouterCredentialsConfigSectionPathDefault,
ServiceLifetime gptFactoryLifetime = ServiceLifetime.Scoped,
bool validateAiClientProviderOnStart = true)
{
ArgumentNullException.ThrowIfNull(services);
ArgumentNullException.ThrowIfNull(configuration);
if (string.IsNullOrWhiteSpace(credentialsConfigSectionPath))
{
throw new ArgumentException("Value cannot be null or whitespace.",
@@ -89,12 +109,55 @@ private static ChatService CreateChatService(IServiceProvider provider)
nameof(completionsConfigSectionPath));
}

if (string.IsNullOrWhiteSpace(azureOpenAiCredentialsConfigSectionPath))
{
throw new ArgumentException("Value cannot be null or whitespace.",
nameof(azureOpenAiCredentialsConfigSectionPath));
}
if (string.IsNullOrWhiteSpace(openRouterCredentialsConfigSectionPath))
{
throw new ArgumentException("Value cannot be null or whitespace.",
nameof(openRouterCredentialsConfigSectionPath));
}

services.AddOptions<ChatGPTConfig>()
.BindConfiguration(completionsConfigSectionPath)
.Configure(_ => { }) //make optional
.ValidateDataAnnotations()
.ValidateOnStart();

services.AddSingleton<ITimeProvider, TimeProviderUtc>();
services.Add(new ServiceDescriptor(typeof(ChatGPTFactory), typeof(ChatGPTFactory), gptFactoryLifetime));

services.AddAiClient(configuration, credentialsConfigSectionPath, azureOpenAiCredentialsConfigSectionPath, openRouterCredentialsConfigSectionPath, validateAiClientProviderOnStart);

return services;
}

internal static void AddAiClient(
this IServiceCollection services,
IConfiguration configuration,
string credentialsConfigSectionPath,
string azureOpenAiCredentialsConfigSectionPath,
string openRouterCredentialsConfigSectionPath,
bool validateAiClientProviderOnStart)
{
ArgumentNullException.ThrowIfNull(services);
ArgumentNullException.ThrowIfNull(configuration);
if (string.IsNullOrWhiteSpace(credentialsConfigSectionPath))
throw new ArgumentException("Value cannot be null or whitespace.", nameof(credentialsConfigSectionPath));
if (string.IsNullOrWhiteSpace(azureOpenAiCredentialsConfigSectionPath))
throw new ArgumentException("Value cannot be null or whitespace.",
nameof(azureOpenAiCredentialsConfigSectionPath));
if (string.IsNullOrWhiteSpace(openRouterCredentialsConfigSectionPath))
throw new ArgumentException("Value cannot be null or whitespace.",
nameof(openRouterCredentialsConfigSectionPath));

services.AddOptions<OpenAICredentials>()
.BindConfiguration(credentialsConfigSectionPath)
.Configure(_ => { }) //make optional
.ValidateDataAnnotations()
.ValidateOnStart();
.ValidateOnStart();
services.AddOptions<AzureOpenAICredentials>()
.BindConfiguration(azureOpenAiCredentialsConfigSectionPath)
.Configure(_ => { }) //make optional
@@ -105,30 +168,21 @@ private static ChatService CreateChatService(IServiceProvider provider)
.Configure(_ => { }) //make optional
.ValidateDataAnnotations()
.ValidateOnStart();

services.AddOptions<ChatGPTConfig>()
.BindConfiguration(completionsConfigSectionPath)
.Configure(_ => { }) //make optional
.ValidateDataAnnotations()
.ValidateOnStart();

services.AddSingleton<ITimeProvider, TimeProviderUtc>();
services.Add(new ServiceDescriptor(typeof(ChatGPTFactory), typeof(ChatGPTFactory), gptFactoryLifetime));

services.AddHttpClient(nameof(OpenAiClient));
services.AddHttpClient(nameof(AzureOpenAiClient));
services.AddHttpClient(nameof(OpenRouterClient));

services.AddSingleton<IAiClient, AiClientFromConfiguration>();
services.AddSingleton<AiClientFactory>();
#pragma warning disable CS0618 // Type or member is obsolete
// will be removed in 5.0
services.AddSingleton<IOpenAiClient, AiClientFromConfiguration>();
#pragma warning restore CS0618 // Type or member is obsolete

if (validateAiClientProviderOnStart)
{
services.AddHostedService<AiClientStartupValidationBackgroundService>();
}

return services;
}
}
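
For reference, a hedged sketch of calling the reworked core registration directly; the configuration source and service collection below are illustrative, while the parameter names and defaults come from the diff above:

// Sketch: AddChatGptIntegrationCore now takes IConfiguration up front and forwards
// the credentials section paths to the shared AddAiClient registration shown above.
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using OpenAI.ChatGpt.AspNetCore.Extensions;

IConfiguration configuration = new ConfigurationBuilder()
    .AddJsonFile("appsettings.json", optional: true) // illustrative configuration source
    .Build();

var services = new ServiceCollection();
services.AddChatGptIntegrationCore(
    configuration,
    gptFactoryLifetime: ServiceLifetime.Scoped,      // default per the diff
    validateAiClientProviderOnStart: true);          // adds AiClientStartupValidationBackgroundService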
@@ -1,4 +1,5 @@
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using static OpenAI.ChatGpt.AspNetCore.Extensions.ServiceCollectionExtensions;

@@ -9,17 +10,20 @@ public static class ServiceCollectionExtensions
/// <summary>
/// Adds the <see cref="IChatHistoryStorage"/> implementation using Entity Framework Core.
/// </summary>
public static IServiceCollection AddChatGptEntityFrameworkIntegration(this IServiceCollection services,
public static IServiceCollection AddChatGptEntityFrameworkIntegration(
this IServiceCollection services,
Action<DbContextOptionsBuilder> optionsAction,
string credentialsConfigSectionPath = OpenAiCredentialsConfigSectionPathDefault,
IConfiguration configuration,
string completionsConfigSectionPath = ChatGPTConfigSectionPathDefault,
string credentialsConfigSectionPath = OpenAiCredentialsConfigSectionPathDefault,
string azureOpenAiCredentialsConfigSectionPath = AzureOpenAiCredentialsConfigSectionPathDefault,
string openRouterCredentialsConfigSectionPath = OpenRouterCredentialsConfigSectionPathDefault,
ServiceLifetime serviceLifetime = ServiceLifetime.Scoped,
bool validateAiClientProviderOnStart = true)
{
ArgumentNullException.ThrowIfNull(services);
ArgumentNullException.ThrowIfNull(optionsAction);
ArgumentNullException.ThrowIfNull(configuration);
if (string.IsNullOrWhiteSpace(credentialsConfigSectionPath))
{
throw new ArgumentException("Value cannot be null or whitespace.",
@@ -48,8 +52,10 @@ public static class ServiceCollectionExtensions
throw new ArgumentOutOfRangeException(nameof(serviceLifetime), serviceLifetime, null);
}

return services.AddChatGptIntegrationCore(credentialsConfigSectionPath: credentialsConfigSectionPath,
return services.AddChatGptIntegrationCore(
configuration,
completionsConfigSectionPath: completionsConfigSectionPath,
credentialsConfigSectionPath: credentialsConfigSectionPath,
azureOpenAiCredentialsConfigSectionPath: azureOpenAiCredentialsConfigSectionPath,
openRouterCredentialsConfigSectionPath: openRouterCredentialsConfigSectionPath,
serviceLifetime,
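
A hedged usage sketch of the reordered EF Core entry point, assuming the same WebApplication builder as in the earlier sketch; the DbContext options delegate is illustrative, only the parameter order (options delegate, then configuration) reflects this commit:

// Sketch only: the EF Core integration now also requires the IConfiguration
// instance, passed right after the DbContext options delegate.
builder.Services.AddChatGptEntityFrameworkIntegration(
    options => options.UseInMemoryDatabase("chats"), // illustrative provider; requires the EF Core InMemory package
    builder.Configuration);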
6 changes: 4 additions & 2 deletions src/OpenAI.ChatGpt/AzureOpenAiClient.cs
@@ -48,12 +48,14 @@ internal static void SetupHttpClient(HttpClient httpClient, string endpointUrl,
httpClient.DefaultRequestHeaders.Add("api-key", azureKey);
}

public AzureOpenAiClient(HttpClient httpClient, string apiVersion) : base(httpClient)
public AzureOpenAiClient(HttpClient httpClient, string apiVersion)
: base(httpClient, validateAuthorizationHeader: false, validateBaseAddress: true)
{
_apiVersion = apiVersion ?? throw new ArgumentNullException(nameof(apiVersion));
}

public AzureOpenAiClient(HttpClient httpClient) : base(httpClient)
public AzureOpenAiClient(HttpClient httpClient)
: base(httpClient, validateAuthorizationHeader: false, validateBaseAddress: true)
{
_apiVersion = DefaultApiVersion;
}
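
A sketch of what the constructor change implies for callers, assuming the HttpClient is configured by hand rather than via the named-client registrations above; endpoint, key, and API version are placeholders:

// Sketch: the base client now validates BaseAddress for AzureOpenAiClient but skips
// the Authorization-header check, since Azure authenticates via the "api-key" header
// (see SetupHttpClient above). All values below are placeholders.
var httpClient = new HttpClient
{
    BaseAddress = new Uri("https://my-resource.openai.azure.com/") // placeholder endpoint
};
httpClient.DefaultRequestHeaders.Add("api-key", "<azure-openai-api-key>"); // placeholder key

var azureClient = new AzureOpenAiClient(httpClient, apiVersion: "2023-05-15"); // placeholder API version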
11 changes: 3 additions & 8 deletions src/OpenAI.ChatGpt/Models/ChatGPTConfig.cs
@@ -14,7 +14,6 @@ public class ChatGPTConfig
};

private int? _maxTokens;
private string? _model;
private float? _temperature;

/// <summary>
@@ -75,7 +74,7 @@ public class ChatGPTConfig
{
if (value is { } maxTokens)
{
if (_model is { } model)
if (Model is { } model)
{
ChatCompletionModels.EnsureMaxTokensIsSupported(model, maxTokens);
}
@@ -93,11 +92,7 @@ public class ChatGPTConfig
/// ID of the model to use. One of: <see cref="ChatCompletionModels"/>
/// Maps to: <see cref="ChatCompletionRequest.Model"/>
/// </summary>
public string? Model
{
get => _model;
set => _model = value;
}
public string? Model { get; set; }

/// <summary>
/// What sampling temperature to use, between 0 and 2.
@@ -161,7 +156,7 @@ internal void ModifyRequest(ChatCompletionRequest request)
(not null, null) => baseConfig,
_ => new ChatGPTConfig()
{
_model = config._model ?? baseConfig._model,
Model = config.Model ?? baseConfig.Model,
_maxTokens = config._maxTokens ?? baseConfig._maxTokens,
_temperature = config._temperature ?? baseConfig._temperature,
PassUserIdToOpenAiRequests = config.PassUserIdToOpenAiRequests ??
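
A small sketch of the simplified config, assuming MaxTokens is the public property backed by _maxTokens as the surrounding diff implies; the values are illustrative:

// Sketch: Model is now a plain auto-property (the _model backing field is gone), and
// the MaxTokens setter still calls ChatCompletionModels.EnsureMaxTokensIsSupported
// against it when both are set. Initializer order matters: Model is assigned first.
var config = new ChatGPTConfig
{
    Model = ChatCompletionModels.Default,
    MaxTokens = 1024 // illustrative limit, validated against the chosen model
};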
