
Commit d2120d4

.Net: Implement OnnxRuntimeGenAIChatCompletionService on OnnxRuntimeGenAIChatClient (#12197)

Authored by stephentoub
Co-authored-by: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
Co-authored-by: Mark Wallace <127216156+markwallace-microsoft@users.noreply.github.com>

1 parent aabfc9c · commit d2120d4
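The commit replaces the service's hand-rolled inference loop with a thin wrapper over the OnnxRuntimeGenAIChatClient shipped in Microsoft.ML.OnnxRuntimeGenAI (a Microsoft.Extensions.AI IChatClient), adapted back to IChatCompletionService via AsChatCompletionService(). The public surface is unchanged, so existing callers are unaffected. A minimal usage sketch, assuming a locally downloaded ONNX Runtime GenAI model; the model id and path below are placeholders:

    using Microsoft.SemanticKernel.ChatCompletion;
    using Microsoft.SemanticKernel.Connectors.Onnx;

    // Placeholder id/path; point modelPath at a folder containing an
    // ONNX Runtime GenAI model (e.g. a Phi-3 ONNX download).
    using var service = new OnnxRuntimeGenAIChatCompletionService(
        modelId: "phi-3",
        modelPath: @"C:\models\phi-3-onnx");

    var history = new ChatHistory();
    history.AddUserMessage("Summarize what ONNX Runtime GenAI does.");

    // Same IChatCompletionService call as before the rewrite; internally it
    // now delegates to an OnnxRuntimeGenAIChatClient created on first use.
    var replies = await service.GetChatMessageContentsAsync(history);
    Console.WriteLine(replies[0].Content);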

File tree (2 files changed: +29 −166 lines)

  dotnet/Directory.Packages.props
  dotnet/src/Connectors/Connectors.Onnx/OnnxRuntimeGenAIChatCompletionService.cs

dotnet/Directory.Packages.props
Lines changed: 4 additions & 4 deletions

@@ -68,7 +68,7 @@
     <PackageVersion Include="Microsoft.Identity.Client" Version="4.67.2" />
     <PackageVersion Include="Microsoft.Identity.Client.Extensions.Msal" Version="4.67.2" />
     <PackageVersion Include="Microsoft.IdentityModel.JsonWebTokens" Version="7.5.1" />
-    <PackageVersion Include="Microsoft.ML.OnnxRuntime" Version="1.21.0" />
+    <PackageVersion Include="Microsoft.ML.OnnxRuntime" Version="1.22.0" />
     <PackageVersion Include="Microsoft.ML.Tokenizers.Data.Cl100kBase" Version="1.0.1" />
     <PackageVersion Include="Microsoft.SemanticKernel.Abstractions" Version="1.47.0" />
     <PackageVersion Include="Microsoft.SemanticKernel.Connectors.OpenAI" Version="1.47.0" />
@@ -215,9 +215,9 @@
       <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
     </PackageReference>
     <!-- OnnxRuntimeGenAI -->
-    <PackageVersion Include="Microsoft.ML.OnnxRuntimeGenAI" Version="0.7.0-rc2" />
-    <PackageVersion Include="Microsoft.ML.OnnxRuntimeGenAI.Cuda" Version="0.7.0-rc2" />
-    <PackageVersion Include="Microsoft.ML.OnnxRuntimeGenAI.DirectML" Version="0.7.0-rc2" />
+    <PackageVersion Include="Microsoft.ML.OnnxRuntimeGenAI" Version="0.8.1" />
+    <PackageVersion Include="Microsoft.ML.OnnxRuntimeGenAI.Cuda" Version="0.8.1" />
+    <PackageVersion Include="Microsoft.ML.OnnxRuntimeGenAI.DirectML" Version="0.8.1" />
     <!-- SpectreConsole-->
     <PackageVersion Include="Spectre.Console" Version="0.49.1" />
     <PackageVersion Include="Spectre.Console.Cli" Version="0.49.1" />

dotnet/src/Connectors/Connectors.Onnx/OnnxRuntimeGenAIChatCompletionService.cs
Lines changed: 25 additions & 162 deletions

@@ -2,8 +2,6 @@
 
 using System;
 using System.Collections.Generic;
-using System.Diagnostics.CodeAnalysis;
-using System.Runtime.CompilerServices;
 using System.Text;
 using System.Text.Json;
 using System.Threading;
@@ -20,12 +18,13 @@ namespace Microsoft.SemanticKernel.Connectors.Onnx;
 /// </summary>
 public sealed class OnnxRuntimeGenAIChatCompletionService : IChatCompletionService, IDisposable
 {
-    private readonly string _modelId;
     private readonly string _modelPath;
-    private readonly JsonSerializerOptions? _jsonSerializerOptions;
-    private Model? _model;
-    private Tokenizer? _tokenizer;
-    private Dictionary<string, object?> AttributesInternal { get; } = new();
+    private OnnxRuntimeGenAIChatClient? _chatClient;
+    private IChatCompletionService? _chatClientWrapper;
+    private readonly Dictionary<string, object?> _attributesInternal = [];
+
+    /// <inheritdoc/>
+    public IReadOnlyDictionary<string, object?> Attributes => this._attributesInternal;
 
     /// <summary>
     /// Initializes a new instance of the OnnxRuntimeGenAIChatCompletionService class.
@@ -43,174 +42,38 @@ public OnnxRuntimeGenAIChatCompletionService(
         Verify.NotNullOrWhiteSpace(modelId);
         Verify.NotNullOrWhiteSpace(modelPath);
 
-        this._modelId = modelId;
+        this._attributesInternal.Add(AIServiceExtensions.ModelIdKey, modelId);
         this._modelPath = modelPath;
-        this._jsonSerializerOptions = jsonSerializerOptions;
-        this.AttributesInternal.Add(AIServiceExtensions.ModelIdKey, this._modelId);
-    }
-
-    /// <inheritdoc />
-    public IReadOnlyDictionary<string, object?> Attributes => this.AttributesInternal;
-
-    /// <inheritdoc />
-    public async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
-    {
-        var result = new StringBuilder();
-
-        await foreach (var content in this.RunInferenceAsync(chatHistory, executionSettings, cancellationToken).ConfigureAwait(false))
-        {
-            result.Append(content);
-        }
-
-        return new List<ChatMessageContent>
-        {
-            new(
-                role: AuthorRole.Assistant,
-                modelId: this._modelId,
-                content: result.ToString())
-        };
-    }
-
-    /// <inheritdoc />
-    public async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessageContentsAsync(
-        ChatHistory chatHistory,
-        PromptExecutionSettings? executionSettings = null,
-        Kernel? kernel = null,
-        [EnumeratorCancellation] CancellationToken cancellationToken = default)
-    {
-        await foreach (var content in this.RunInferenceAsync(chatHistory, executionSettings, cancellationToken).ConfigureAwait(false))
-        {
-            yield return new StreamingChatMessageContent(AuthorRole.Assistant, content, modelId: this._modelId);
-        }
     }
 
-    private async IAsyncEnumerable<string> RunInferenceAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings, [EnumeratorCancellation] CancellationToken cancellationToken)
+    private IChatCompletionService GetChatCompletionService()
     {
-        OnnxRuntimeGenAIPromptExecutionSettings onnxPromptExecutionSettings = this.GetOnnxPromptExecutionSettingsSettings(executionSettings);
-
-        var prompt = this.GetPrompt(chatHistory, onnxPromptExecutionSettings);
-        using var tokens = this.GetTokenizer().Encode(prompt);
-
-        using var generatorParams = new GeneratorParams(this.GetModel());
-        this.UpdateGeneratorParamsFromPromptExecutionSettings(generatorParams, onnxPromptExecutionSettings);
-
-        using var generator = new Generator(this.GetModel(), generatorParams);
-        generator.AppendTokenSequences(tokens);
-
-        bool removeNextTokenStartingWithSpace = true;
-        while (!generator.IsDone())
+        this._chatClient ??= new OnnxRuntimeGenAIChatClient(this._modelPath, new OnnxRuntimeGenAIChatClientOptions()
         {
-            cancellationToken.ThrowIfCancellationRequested();
-
-            yield return await Task.Run(() =>
+            PromptFormatter = (messages, options) =>
             {
-                generator.GenerateNextToken();
-
-                var outputTokens = generator.GetSequence(0);
-                var newToken = outputTokens[outputTokens.Length - 1];
-
-                using var tokenizerStream = this.GetTokenizer().CreateStream();
-                string output = tokenizerStream.Decode(newToken);
-
-                if (removeNextTokenStartingWithSpace && output[0] == ' ')
+                StringBuilder promptBuilder = new();
+                foreach (var message in messages)
                 {
-                    removeNextTokenStartingWithSpace = false;
-                    output = output.TrimStart();
+                    promptBuilder.Append($"<|{message.Role}|>\n{message.Text}");
                 }
+                promptBuilder.Append("<|end|>\n<|assistant|>");
 
-                return output;
-            }, cancellationToken).ConfigureAwait(false);
-        }
-    }
-
-    private Model GetModel() => this._model ??= new Model(this._modelPath);
-
-    private Tokenizer GetTokenizer() => this._tokenizer ??= new Tokenizer(this.GetModel());
+                return promptBuilder.ToString();
+            }
+        });
 
-    private string GetPrompt(ChatHistory chatHistory, OnnxRuntimeGenAIPromptExecutionSettings onnxRuntimeGenAIPromptExecutionSettings)
-    {
-        var promptBuilder = new StringBuilder();
-        foreach (var message in chatHistory)
-        {
-            promptBuilder.Append($"<|{message.Role}|>\n{message.Content}");
-        }
-        promptBuilder.Append("<|end|>\n<|assistant|>");
-
-        return promptBuilder.ToString();
+        return this._chatClientWrapper ??= this._chatClient.AsChatCompletionService();
     }
 
-    private void UpdateGeneratorParamsFromPromptExecutionSettings(GeneratorParams generatorParams, OnnxRuntimeGenAIPromptExecutionSettings onnxRuntimeGenAIPromptExecutionSettings)
-    {
-        if (onnxRuntimeGenAIPromptExecutionSettings.TopP.HasValue)
-        {
-            generatorParams.SetSearchOption("top_p", onnxRuntimeGenAIPromptExecutionSettings.TopP.Value);
-        }
-        if (onnxRuntimeGenAIPromptExecutionSettings.TopK.HasValue)
-        {
-            generatorParams.SetSearchOption("top_k", onnxRuntimeGenAIPromptExecutionSettings.TopK.Value);
-        }
-        if (onnxRuntimeGenAIPromptExecutionSettings.Temperature.HasValue)
-        {
-            generatorParams.SetSearchOption("temperature", onnxRuntimeGenAIPromptExecutionSettings.Temperature.Value);
-        }
-        if (onnxRuntimeGenAIPromptExecutionSettings.RepetitionPenalty.HasValue)
-        {
-            generatorParams.SetSearchOption("repetition_penalty", onnxRuntimeGenAIPromptExecutionSettings.RepetitionPenalty.Value);
-        }
-        if (onnxRuntimeGenAIPromptExecutionSettings.PastPresentShareBuffer.HasValue)
-        {
-            generatorParams.SetSearchOption("past_present_share_buffer", onnxRuntimeGenAIPromptExecutionSettings.PastPresentShareBuffer.Value);
-        }
-        if (onnxRuntimeGenAIPromptExecutionSettings.NumReturnSequences.HasValue)
-        {
-            generatorParams.SetSearchOption("num_return_sequences", onnxRuntimeGenAIPromptExecutionSettings.NumReturnSequences.Value);
-        }
-        if (onnxRuntimeGenAIPromptExecutionSettings.NoRepeatNgramSize.HasValue)
-        {
-            generatorParams.SetSearchOption("no_repeat_ngram_size", onnxRuntimeGenAIPromptExecutionSettings.NoRepeatNgramSize.Value);
-        }
-        if (onnxRuntimeGenAIPromptExecutionSettings.MinTokens.HasValue)
-        {
-            generatorParams.SetSearchOption("min_length", onnxRuntimeGenAIPromptExecutionSettings.MinTokens.Value);
-        }
-        if (onnxRuntimeGenAIPromptExecutionSettings.MaxTokens.HasValue)
-        {
-            generatorParams.SetSearchOption("max_length", onnxRuntimeGenAIPromptExecutionSettings.MaxTokens.Value);
-        }
-        if (onnxRuntimeGenAIPromptExecutionSettings.LengthPenalty.HasValue)
-        {
-            generatorParams.SetSearchOption("length_penalty", onnxRuntimeGenAIPromptExecutionSettings.LengthPenalty.Value);
-        }
-        if (onnxRuntimeGenAIPromptExecutionSettings.EarlyStopping.HasValue)
-        {
-            generatorParams.SetSearchOption("early_stopping", onnxRuntimeGenAIPromptExecutionSettings.EarlyStopping.Value);
-        }
-        if (onnxRuntimeGenAIPromptExecutionSettings.DoSample.HasValue)
-        {
-            generatorParams.SetSearchOption("do_sample", onnxRuntimeGenAIPromptExecutionSettings.DoSample.Value);
-        }
-        if (onnxRuntimeGenAIPromptExecutionSettings.DiversityPenalty.HasValue)
-        {
-            generatorParams.SetSearchOption("diversity_penalty", onnxRuntimeGenAIPromptExecutionSettings.DiversityPenalty.Value);
-        }
-    }
-
-    [UnconditionalSuppressMessage("Trimming", "IL2026:Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access otherwise can break functionality when trimming application code", Justification = "JSOs are required only in cases where the supplied settings are not Onnx-specific. For these cases, JSOs can be provided via the class constructor.")]
-    [UnconditionalSuppressMessage("AOT", "IL3050:Calling members annotated with 'RequiresDynamicCodeAttribute' may break functionality when AOT compiling.", Justification = "JSOs are required only in cases where the supplied settings are not Onnx-specific. For these cases, JSOs can be provided via class constructor.")]
-    private OnnxRuntimeGenAIPromptExecutionSettings GetOnnxPromptExecutionSettingsSettings(PromptExecutionSettings? executionSettings)
-    {
-        if (this._jsonSerializerOptions is not null)
-        {
-            return OnnxRuntimeGenAIPromptExecutionSettings.FromExecutionSettings(executionSettings, this._jsonSerializerOptions);
-        }
+    /// <inheritdoc/>
+    public void Dispose() => this._chatClient?.Dispose();
 
-        return OnnxRuntimeGenAIPromptExecutionSettings.FromExecutionSettings(executionSettings);
-    }
+    /// <inheritdoc/>
+    public Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) =>
+        this.GetChatCompletionService().GetChatMessageContentsAsync(chatHistory, executionSettings, kernel, cancellationToken);
 
     /// <inheritdoc/>
-    public void Dispose()
-    {
-        this._tokenizer?.Dispose();
-        this._model?.Dispose();
-    }
+    public IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) =>
+        this.GetChatCompletionService().GetStreamingChatMessageContentsAsync(chatHistory, executionSettings, kernel, cancellationToken);
 }
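Two design points are visible in the new code. First, GetChatCompletionService() constructs the OnnxRuntimeGenAIChatClient lazily with ??=, preserving the old behavior of not loading the (potentially multi-gigabyte) model until the first request, and it caches the AsChatCompletionService() wrapper so repeated calls reuse one adapter. Second, because the service is now only an adapter, callers can also talk to the underlying Microsoft.Extensions.AI client directly. A sketch of that direct path, under the assumptions that the client's options argument is optional and that the installed Microsoft.Extensions.AI version exposes the GetResponseAsync extension; the model path is a placeholder:

    using Microsoft.Extensions.AI;
    using Microsoft.ML.OnnxRuntimeGenAI;

    // Assumption: the options parameter can be omitted; the diff above only
    // shows the (modelPath, options) form. Path is a placeholder.
    using IChatClient client = new OnnxRuntimeGenAIChatClient(@"C:\models\phi-3-onnx");

    // GetResponseAsync is the Microsoft.Extensions.AI convenience extension
    // that wraps a single user message in a chat request.
    ChatResponse response = await client.GetResponseAsync("Hello!");
    Console.WriteLine(response.Text);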
