2
2
3
3
using System ;
4
4
using System . Collections . Generic ;
5
- using System . Diagnostics . CodeAnalysis ;
6
- using System . Runtime . CompilerServices ;
7
5
using System . Text ;
8
6
using System . Text . Json ;
9
7
using System . Threading ;
@@ -20,12 +18,13 @@ namespace Microsoft.SemanticKernel.Connectors.Onnx;
20
18
/// </summary>
21
19
public sealed class OnnxRuntimeGenAIChatCompletionService : IChatCompletionService , IDisposable
22
20
{
23
- private readonly string _modelId ;
24
21
private readonly string _modelPath ;
25
- private readonly JsonSerializerOptions ? _jsonSerializerOptions ;
26
- private Model ? _model ;
27
- private Tokenizer ? _tokenizer ;
28
- private Dictionary < string , object ? > AttributesInternal { get ; } = new ( ) ;
22
+ private OnnxRuntimeGenAIChatClient ? _chatClient ;
23
+ private IChatCompletionService ? _chatClientWrapper ;
24
+ private readonly Dictionary < string , object ? > _attributesInternal = [ ] ;
25
+
26
+ /// <inheritdoc/>
27
+ public IReadOnlyDictionary < string , object ? > Attributes => this . _attributesInternal ;
29
28
30
29
/// <summary>
31
30
/// Initializes a new instance of the OnnxRuntimeGenAIChatCompletionService class.
@@ -43,174 +42,38 @@ public OnnxRuntimeGenAIChatCompletionService(
43
42
Verify . NotNullOrWhiteSpace ( modelId ) ;
44
43
Verify . NotNullOrWhiteSpace ( modelPath ) ;
45
44
46
- this . _modelId = modelId ;
45
+ this . _attributesInternal . Add ( AIServiceExtensions . ModelIdKey , modelId ) ;
47
46
this . _modelPath = modelPath ;
48
- this . _jsonSerializerOptions = jsonSerializerOptions ;
49
- this . AttributesInternal . Add ( AIServiceExtensions . ModelIdKey , this . _modelId ) ;
50
- }
51
-
52
- /// <inheritdoc />
53
- public IReadOnlyDictionary < string , object ? > Attributes => this . AttributesInternal ;
54
-
55
- /// <inheritdoc />
56
- public async Task < IReadOnlyList < ChatMessageContent > > GetChatMessageContentsAsync ( ChatHistory chatHistory , PromptExecutionSettings ? executionSettings = null , Kernel ? kernel = null , CancellationToken cancellationToken = default )
57
- {
58
- var result = new StringBuilder ( ) ;
59
-
60
- await foreach ( var content in this . RunInferenceAsync ( chatHistory , executionSettings , cancellationToken ) . ConfigureAwait ( false ) )
61
- {
62
- result . Append ( content ) ;
63
- }
64
-
65
- return new List < ChatMessageContent >
66
- {
67
- new (
68
- role : AuthorRole . Assistant ,
69
- modelId : this . _modelId ,
70
- content : result . ToString ( ) )
71
- } ;
72
- }
73
-
74
- /// <inheritdoc />
75
- public async IAsyncEnumerable < StreamingChatMessageContent > GetStreamingChatMessageContentsAsync (
76
- ChatHistory chatHistory ,
77
- PromptExecutionSettings ? executionSettings = null ,
78
- Kernel ? kernel = null ,
79
- [ EnumeratorCancellation ] CancellationToken cancellationToken = default )
80
- {
81
- await foreach ( var content in this . RunInferenceAsync ( chatHistory , executionSettings , cancellationToken ) . ConfigureAwait ( false ) )
82
- {
83
- yield return new StreamingChatMessageContent ( AuthorRole . Assistant , content , modelId : this . _modelId ) ;
84
- }
85
47
}
86
48
87
- private async IAsyncEnumerable < string > RunInferenceAsync ( ChatHistory chatHistory , PromptExecutionSettings ? executionSettings , [ EnumeratorCancellation ] CancellationToken cancellationToken )
49
+ private IChatCompletionService GetChatCompletionService ( )
88
50
{
89
- OnnxRuntimeGenAIPromptExecutionSettings onnxPromptExecutionSettings = this . GetOnnxPromptExecutionSettingsSettings ( executionSettings ) ;
90
-
91
- var prompt = this . GetPrompt ( chatHistory , onnxPromptExecutionSettings ) ;
92
- using var tokens = this . GetTokenizer ( ) . Encode ( prompt ) ;
93
-
94
- using var generatorParams = new GeneratorParams ( this . GetModel ( ) ) ;
95
- this . UpdateGeneratorParamsFromPromptExecutionSettings ( generatorParams , onnxPromptExecutionSettings ) ;
96
-
97
- using var generator = new Generator ( this . GetModel ( ) , generatorParams ) ;
98
- generator . AppendTokenSequences ( tokens ) ;
99
-
100
- bool removeNextTokenStartingWithSpace = true ;
101
- while ( ! generator . IsDone ( ) )
51
+ this . _chatClient ??= new OnnxRuntimeGenAIChatClient ( this . _modelPath , new OnnxRuntimeGenAIChatClientOptions ( )
102
52
{
103
- cancellationToken . ThrowIfCancellationRequested ( ) ;
104
-
105
- yield return await Task . Run ( ( ) =>
53
+ PromptFormatter = ( messages , options ) =>
106
54
{
107
- generator . GenerateNextToken ( ) ;
108
-
109
- var outputTokens = generator . GetSequence ( 0 ) ;
110
- var newToken = outputTokens [ outputTokens . Length - 1 ] ;
111
-
112
- using var tokenizerStream = this . GetTokenizer ( ) . CreateStream ( ) ;
113
- string output = tokenizerStream . Decode ( newToken ) ;
114
-
115
- if ( removeNextTokenStartingWithSpace && output [ 0 ] == ' ' )
55
+ StringBuilder promptBuilder = new ( ) ;
56
+ foreach ( var message in messages )
116
57
{
117
- removeNextTokenStartingWithSpace = false ;
118
- output = output . TrimStart ( ) ;
58
+ promptBuilder . Append ( $ "<|{ message . Role } |>\n { message . Text } ") ;
119
59
}
60
+ promptBuilder . Append ( "<|end|>\n <|assistant|>" ) ;
120
61
121
- return output ;
122
- } , cancellationToken ) . ConfigureAwait ( false ) ;
123
- }
124
- }
125
-
126
- private Model GetModel ( ) => this . _model ??= new Model ( this . _modelPath ) ;
127
-
128
- private Tokenizer GetTokenizer ( ) => this . _tokenizer ??= new Tokenizer ( this . GetModel ( ) ) ;
62
+ return promptBuilder . ToString ( ) ;
63
+ }
64
+ } ) ;
129
65
130
- private string GetPrompt ( ChatHistory chatHistory , OnnxRuntimeGenAIPromptExecutionSettings onnxRuntimeGenAIPromptExecutionSettings )
131
- {
132
- var promptBuilder = new StringBuilder ( ) ;
133
- foreach ( var message in chatHistory )
134
- {
135
- promptBuilder . Append ( $ "<|{ message . Role } |>\n { message . Content } ") ;
136
- }
137
- promptBuilder . Append ( "<|end|>\n <|assistant|>" ) ;
138
-
139
- return promptBuilder . ToString ( ) ;
66
+ return this . _chatClientWrapper ??= this . _chatClient . AsChatCompletionService ( ) ;
140
67
}
141
68
142
- private void UpdateGeneratorParamsFromPromptExecutionSettings ( GeneratorParams generatorParams , OnnxRuntimeGenAIPromptExecutionSettings onnxRuntimeGenAIPromptExecutionSettings )
143
- {
144
- if ( onnxRuntimeGenAIPromptExecutionSettings . TopP . HasValue )
145
- {
146
- generatorParams . SetSearchOption ( "top_p" , onnxRuntimeGenAIPromptExecutionSettings . TopP . Value ) ;
147
- }
148
- if ( onnxRuntimeGenAIPromptExecutionSettings . TopK . HasValue )
149
- {
150
- generatorParams . SetSearchOption ( "top_k" , onnxRuntimeGenAIPromptExecutionSettings . TopK . Value ) ;
151
- }
152
- if ( onnxRuntimeGenAIPromptExecutionSettings . Temperature . HasValue )
153
- {
154
- generatorParams . SetSearchOption ( "temperature" , onnxRuntimeGenAIPromptExecutionSettings . Temperature . Value ) ;
155
- }
156
- if ( onnxRuntimeGenAIPromptExecutionSettings . RepetitionPenalty . HasValue )
157
- {
158
- generatorParams . SetSearchOption ( "repetition_penalty" , onnxRuntimeGenAIPromptExecutionSettings . RepetitionPenalty . Value ) ;
159
- }
160
- if ( onnxRuntimeGenAIPromptExecutionSettings . PastPresentShareBuffer . HasValue )
161
- {
162
- generatorParams . SetSearchOption ( "past_present_share_buffer" , onnxRuntimeGenAIPromptExecutionSettings . PastPresentShareBuffer . Value ) ;
163
- }
164
- if ( onnxRuntimeGenAIPromptExecutionSettings . NumReturnSequences . HasValue )
165
- {
166
- generatorParams . SetSearchOption ( "num_return_sequences" , onnxRuntimeGenAIPromptExecutionSettings . NumReturnSequences . Value ) ;
167
- }
168
- if ( onnxRuntimeGenAIPromptExecutionSettings . NoRepeatNgramSize . HasValue )
169
- {
170
- generatorParams . SetSearchOption ( "no_repeat_ngram_size" , onnxRuntimeGenAIPromptExecutionSettings . NoRepeatNgramSize . Value ) ;
171
- }
172
- if ( onnxRuntimeGenAIPromptExecutionSettings . MinTokens . HasValue )
173
- {
174
- generatorParams . SetSearchOption ( "min_length" , onnxRuntimeGenAIPromptExecutionSettings . MinTokens . Value ) ;
175
- }
176
- if ( onnxRuntimeGenAIPromptExecutionSettings . MaxTokens . HasValue )
177
- {
178
- generatorParams . SetSearchOption ( "max_length" , onnxRuntimeGenAIPromptExecutionSettings . MaxTokens . Value ) ;
179
- }
180
- if ( onnxRuntimeGenAIPromptExecutionSettings . LengthPenalty . HasValue )
181
- {
182
- generatorParams . SetSearchOption ( "length_penalty" , onnxRuntimeGenAIPromptExecutionSettings . LengthPenalty . Value ) ;
183
- }
184
- if ( onnxRuntimeGenAIPromptExecutionSettings . EarlyStopping . HasValue )
185
- {
186
- generatorParams . SetSearchOption ( "early_stopping" , onnxRuntimeGenAIPromptExecutionSettings . EarlyStopping . Value ) ;
187
- }
188
- if ( onnxRuntimeGenAIPromptExecutionSettings . DoSample . HasValue )
189
- {
190
- generatorParams . SetSearchOption ( "do_sample" , onnxRuntimeGenAIPromptExecutionSettings . DoSample . Value ) ;
191
- }
192
- if ( onnxRuntimeGenAIPromptExecutionSettings . DiversityPenalty . HasValue )
193
- {
194
- generatorParams . SetSearchOption ( "diversity_penalty" , onnxRuntimeGenAIPromptExecutionSettings . DiversityPenalty . Value ) ;
195
- }
196
- }
197
-
198
- [ UnconditionalSuppressMessage ( "Trimming" , "IL2026:Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access otherwise can break functionality when trimming application code" , Justification = "JSOs are required only in cases where the supplied settings are not Onnx-specific. For these cases, JSOs can be provided via the class constructor." ) ]
199
- [ UnconditionalSuppressMessage ( "AOT" , "IL3050:Calling members annotated with 'RequiresDynamicCodeAttribute' may break functionality when AOT compiling." , Justification = "JSOs are required only in cases where the supplied settings are not Onnx-specific. For these cases, JSOs can be provided via class constructor." ) ]
200
- private OnnxRuntimeGenAIPromptExecutionSettings GetOnnxPromptExecutionSettingsSettings ( PromptExecutionSettings ? executionSettings )
201
- {
202
- if ( this . _jsonSerializerOptions is not null )
203
- {
204
- return OnnxRuntimeGenAIPromptExecutionSettings . FromExecutionSettings ( executionSettings , this . _jsonSerializerOptions ) ;
205
- }
69
+ /// <inheritdoc/>
70
+ public void Dispose ( ) => this . _chatClient ? . Dispose ( ) ;
206
71
207
- return OnnxRuntimeGenAIPromptExecutionSettings . FromExecutionSettings ( executionSettings ) ;
208
- }
72
+ /// <inheritdoc/>
73
+ public Task < IReadOnlyList < ChatMessageContent > > GetChatMessageContentsAsync ( ChatHistory chatHistory , PromptExecutionSettings ? executionSettings = null , Kernel ? kernel = null , CancellationToken cancellationToken = default ) =>
74
+ this . GetChatCompletionService ( ) . GetChatMessageContentsAsync ( chatHistory , executionSettings , kernel , cancellationToken ) ;
209
75
210
76
/// <inheritdoc/>
211
- public void Dispose ( )
212
- {
213
- this . _tokenizer ? . Dispose ( ) ;
214
- this . _model ? . Dispose ( ) ;
215
- }
77
+ public IAsyncEnumerable < StreamingChatMessageContent > GetStreamingChatMessageContentsAsync ( ChatHistory chatHistory , PromptExecutionSettings ? executionSettings = null , Kernel ? kernel = null , CancellationToken cancellationToken = default ) =>
78
+ this . GetChatCompletionService ( ) . GetStreamingChatMessageContentsAsync ( chatHistory , executionSettings , kernel , cancellationToken ) ;
216
79
}
0 commit comments