Skip to content

Commit

Permalink
Clean up exposure and handling of CancellationToken (microsoft#612)
Browse files Browse the repository at this point in the history
### Motivation and Context

Clean up exposure and handling of CancellationToken in public APIs.

### Description

- In public API, they should always be the last parameter (unless they
can't be, e.g. due to a params array), named "cancellationToken", and be
non-nullable.
- Ensure they're forwarded appropriately from one API to the next. There
may be some more missing cases, but I fixed all the ones I found.
- Make the XML docs consistent.
- Remove unnecessary cancellation-related analyzer suppressions
  • Loading branch information
stephentoub authored and codebrain committed May 16, 2023
1 parent 37c2122 commit f28579e
Show file tree
Hide file tree
Showing 47 changed files with 405 additions and 409 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public abstract class ClientBase
/// </summary>
/// <param name="text">The prompt to complete.</param>
/// <param name="requestSettings">Request settings for the completion API</param>
/// <param name="cancellationToken">Cancellation token</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>Text generated by the remote model</returns>
protected async Task<string> InternalCompleteTextAsync(
string text,
Expand Down Expand Up @@ -87,7 +87,7 @@ public abstract class ClientBase
/// Generates an embedding from the given <paramref name="data"/>.
/// </summary>
/// <param name="data">List of strings to generate embeddings for</param>
/// <param name="cancellationToken">Cancellation token</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>List of embeddings</returns>
protected async Task<IList<Embedding<float>>> InternalGenerateTextEmbeddingsAsync(
IList<string> data,
Expand Down Expand Up @@ -193,7 +193,7 @@ protected ChatHistory InternalCreateNewChat(string instructions = "")
/// </summary>
/// <param name="text">The prompt to complete.</param>
/// <param name="requestSettings">Request settings for the completion API</param>
/// <param name="cancellationToken">Cancellation token</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>Text generated by the remote model</returns>
protected async Task<string> InternalCompleteTextUsingChatAsync(
string text,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ internal OpenAIClientBase(ILogger? log = null, IDelegatingHandlerFactory? handle
/// </summary>
/// <param name="url">URL for the text embedding request API</param>
/// <param name="requestBody">Request payload</param>
/// <param name="cancellationToken">Cancellation token</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>List of text embeddings</returns>
/// <exception cref="AIException">AIException thrown during the request.</exception>
protected async Task<IList<Embedding<float>>> ExecuteTextEmbeddingRequestAsync(
Expand Down Expand Up @@ -93,7 +93,7 @@ internal OpenAIClientBase(ILogger? log = null, IDelegatingHandlerFactory? handle
/// </summary>
/// <param name="url">URL for the image generation request API</param>
/// <param name="requestBody">Request payload</param>
/// <param name="cancellationToken">Cancellation token</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>List of image URLs</returns>
/// <exception cref="AIException">AIException thrown during the request.</exception>
protected async Task<IList<string>> ExecuteImageUrlGenerationRequestAsync(
Expand All @@ -119,7 +119,7 @@ internal OpenAIClientBase(ILogger? log = null, IDelegatingHandlerFactory? handle
/// </summary>
/// <param name="url">URL for the image generation request API</param>
/// <param name="requestBody">Request payload</param>
/// <param name="cancellationToken">Cancellation token</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>List of images serialized in base64</returns>
/// <exception cref="AIException">AIException thrown during the request.</exception>
protected async Task<IList<string>> ExecuteImageBase64GenerationRequestAsync(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public sealed class AzureTextEmbeddingGeneration : AzureOpenAIClientBase, IEmbed
/// Generates an embedding from the given <paramref name="data"/>.
/// </summary>
/// <param name="data">List of strings to generate embeddings for</param>
/// <param name="cancellationToken">Cancellation token</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>List of embeddings</returns>
public Task<IList<Embedding<float>>> GenerateEmbeddingsAsync(
IList<string> data,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public sealed class OpenAITextEmbeddingGeneration : OpenAIClientBase, IEmbedding
/// Generates an embedding from the given <paramref name="data"/>.
/// </summary>
/// <param name="data">List of strings to generate embeddings for</param>
/// <param name="cancellationToken">Cancellation token</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>List of embeddings</returns>
public Task<IList<Embedding<float>>> GenerateEmbeddingsAsync(
IList<string> data,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ private CosmosMemoryStore()
/// <param name="client">Client with endpoint and authentication to the Azure CosmosDB Account.</param>
/// <param name="databaseName">The name of the database to back the memory store.</param>
/// <param name="log">Optional logger.</param>
/// <param name="cancel">Optional cancellation token.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <exception cref="CosmosException"></exception>
public static async Task<CosmosMemoryStore> CreateAsync(CosmosClient client, string databaseName, ILogger? log = null, CancellationToken cancel = default)
public static async Task<CosmosMemoryStore> CreateAsync(CosmosClient client, string databaseName, ILogger? log = null, CancellationToken cancellationToken = default)
{
var newStore = new CosmosMemoryStore();

newStore._databaseName = databaseName;
newStore._log = log ?? NullLogger<CosmosMemoryStore>.Instance;
var response = await client.CreateDatabaseIfNotExistsAsync(newStore._databaseName, cancellationToken: cancel).ConfigureAwait(false);
var response = await client.CreateDatabaseIfNotExistsAsync(newStore._databaseName, cancellationToken: cancellationToken).ConfigureAwait(false);

if (response.StatusCode == HttpStatusCode.Created)
{
Expand All @@ -69,7 +69,7 @@ public static async Task<CosmosMemoryStore> CreateAsync(CosmosClient client, str
}

/// <inheritdoc />
public IAsyncEnumerable<string> GetCollectionsAsync(CancellationToken cancel = default)
public IAsyncEnumerable<string> GetCollectionsAsync(CancellationToken cancellationToken = default)
{
// Azure Cosmos DB does not support listing all Containers, this does not break the interface but it is not ideal.
this._log.LogWarning("Listing all containers is not supported by Azure Cosmos DB, returning empty list.");
Expand All @@ -78,9 +78,9 @@ public IAsyncEnumerable<string> GetCollectionsAsync(CancellationToken cancel = d
}

/// <inheritdoc />
public async Task CreateCollectionAsync(string collectionName, CancellationToken cancel = default)
public async Task CreateCollectionAsync(string collectionName, CancellationToken cancellationToken = default)
{
var response = await this._database.CreateContainerIfNotExistsAsync(collectionName, "/" + collectionName, cancellationToken: cancel).ConfigureAwait(false);
var response = await this._database.CreateContainerIfNotExistsAsync(collectionName, "/" + collectionName, cancellationToken: cancellationToken).ConfigureAwait(false);

if (response.StatusCode == HttpStatusCode.Created)
{
Expand All @@ -97,20 +97,20 @@ public async Task CreateCollectionAsync(string collectionName, CancellationToken
}

/// <inheritdoc />
public Task<bool> DoesCollectionExistAsync(string collectionName, CancellationToken cancel = default)
public Task<bool> DoesCollectionExistAsync(string collectionName, CancellationToken cancellationToken = default)
{
// Azure Cosmos DB does not support checking if container exists without attempting to create it.
// Note that CreateCollectionIfNotExistsAsync() is idempotent. This does not break the interface but it is not ideal.
return Task.FromResult(false);
}

/// <inheritdoc />
public async Task DeleteCollectionAsync(string collectionName, CancellationToken cancel = default)
public async Task DeleteCollectionAsync(string collectionName, CancellationToken cancellationToken = default)
{
var container = this._database.Client.GetContainer(this._databaseName, collectionName);
try
{
await container.DeleteContainerAsync(cancellationToken: cancel).ConfigureAwait(false);
await container.DeleteContainerAsync(cancellationToken: cancellationToken).ConfigureAwait(false);
}
catch (CosmosException ex)
{
Expand All @@ -119,15 +119,15 @@ public async Task DeleteCollectionAsync(string collectionName, CancellationToken
}

/// <inheritdoc />
public async Task<MemoryRecord?> GetAsync(string collectionName, string key, bool withEmbedding = false, CancellationToken cancel = default)
public async Task<MemoryRecord?> GetAsync(string collectionName, string key, bool withEmbedding = false, CancellationToken cancellationToken = default)
{
var id = this.ToCosmosFriendlyId(key);
var partitionKey = PartitionKey.None;

var container = this._database.Client.GetContainer(this._databaseName, collectionName);
MemoryRecord? memoryRecord = null;

var response = await container.ReadItemAsync<CosmosMemoryRecord>(id, partitionKey, cancellationToken: cancel).ConfigureAwait(false);
var response = await container.ReadItemAsync<CosmosMemoryRecord>(id, partitionKey, cancellationToken: cancellationToken).ConfigureAwait(false);

if (response == null)
{
Expand Down Expand Up @@ -158,11 +158,11 @@ public async Task DeleteCollectionAsync(string collectionName, CancellationToken

/// <inheritdoc/>
public async IAsyncEnumerable<MemoryRecord> GetBatchAsync(string collectionName, IEnumerable<string> keys, bool withEmbeddings = false,
[EnumeratorCancellation] CancellationToken cancel = default)
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
foreach (var key in keys)
{
var record = await this.GetAsync(collectionName, key, withEmbeddings, cancel).ConfigureAwait(false);
var record = await this.GetAsync(collectionName, key, withEmbeddings, cancellationToken).ConfigureAwait(false);

if (record != null)
{
Expand All @@ -172,7 +172,7 @@ public async Task DeleteCollectionAsync(string collectionName, CancellationToken
}

/// <inheritdoc />
public async Task<string> UpsertAsync(string collectionName, MemoryRecord record, CancellationToken cancel = default)
public async Task<string> UpsertAsync(string collectionName, MemoryRecord record, CancellationToken cancellationToken = default)
{
record.Key = this.ToCosmosFriendlyId(record.Metadata.Id);

Expand All @@ -187,7 +187,7 @@ public async Task<string> UpsertAsync(string collectionName, MemoryRecord record

var container = this._database.Client.GetContainer(this._databaseName, collectionName);

var response = await container.UpsertItemAsync(entity, cancellationToken: cancel).ConfigureAwait(false);
var response = await container.UpsertItemAsync(entity, cancellationToken: cancellationToken).ConfigureAwait(false);

if (response.StatusCode == HttpStatusCode.OK || response.StatusCode == HttpStatusCode.Created)
{
Expand All @@ -202,22 +202,22 @@ public async Task<string> UpsertAsync(string collectionName, MemoryRecord record
}

/// <inheritdoc/>
public async IAsyncEnumerable<string> UpsertBatchAsync(string collectionName, IEnumerable<MemoryRecord> records, [EnumeratorCancellation] CancellationToken cancel = default)
public async IAsyncEnumerable<string> UpsertBatchAsync(string collectionName, IEnumerable<MemoryRecord> records, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
foreach (var r in records)
{
yield return await this.UpsertAsync(collectionName, r, cancel).ConfigureAwait(false);
yield return await this.UpsertAsync(collectionName, r, cancellationToken).ConfigureAwait(false);
}
}

/// <inheritdoc />
public async Task RemoveAsync(string collectionName, string key, CancellationToken cancel = default)
public async Task RemoveAsync(string collectionName, string key, CancellationToken cancellationToken = default)
{
var container = this._database.Client.GetContainer(this._databaseName, collectionName);
var response = await container.DeleteItemAsync<CosmosMemoryRecord>(
key,
PartitionKey.None,
cancellationToken: cancel).ConfigureAwait(false);
cancellationToken: cancellationToken).ConfigureAwait(false);

if (response.StatusCode == HttpStatusCode.OK)
{
Expand All @@ -230,9 +230,9 @@ public async Task RemoveAsync(string collectionName, string key, CancellationTok
}

/// <inheritdoc/>
public async Task RemoveBatchAsync(string collectionName, IEnumerable<string> keys, CancellationToken cancel = default)
public async Task RemoveBatchAsync(string collectionName, IEnumerable<string> keys, CancellationToken cancellationToken = default)
{
await Task.WhenAll(keys.Select(k => this.RemoveAsync(collectionName, k, cancel))).ConfigureAwait(false);
await Task.WhenAll(keys.Select(k => this.RemoveAsync(collectionName, k, cancellationToken))).ConfigureAwait(false);
}

/// <inheritdoc />
Expand All @@ -242,7 +242,7 @@ public async Task RemoveBatchAsync(string collectionName, IEnumerable<string> ke
int limit,
double minRelevanceScore = 0,
bool withEmbeddings = false,
[EnumeratorCancellation] CancellationToken cancel = default)
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
{
if (limit <= 0)
Expand All @@ -253,7 +253,7 @@ public async Task RemoveBatchAsync(string collectionName, IEnumerable<string> ke
var collectionMemories = new List<MemoryRecord>();
TopNCollection<MemoryRecord> embeddings = new(limit);

await foreach (var record in this.GetAllAsync(collectionName, cancel))
await foreach (var record in this.GetAllAsync(collectionName, cancellationToken))
{
if (record != null)
{
Expand All @@ -279,18 +279,18 @@ await foreach (var record in this.GetAllAsync(collectionName, cancel))

/// <inheritdoc/>
public async Task<(MemoryRecord, double)?> GetNearestMatchAsync(string collectionName, Embedding<float> embedding, double minRelevanceScore = 0, bool withEmbedding = false,
CancellationToken cancel = default)
CancellationToken cancellationToken = default)
{
return await this.GetNearestMatchesAsync(
collectionName: collectionName,
embedding: embedding,
limit: 1,
minRelevanceScore: minRelevanceScore,
withEmbeddings: withEmbedding,
cancel: cancel).FirstOrDefaultAsync(cancellationToken: cancel).ConfigureAwait(false);
cancellationToken: cancellationToken).FirstOrDefaultAsync(cancellationToken: cancellationToken).ConfigureAwait(false);
}

private async IAsyncEnumerable<MemoryRecord> GetAllAsync(string collectionName, [EnumeratorCancellation] CancellationToken cancel = default)
private async IAsyncEnumerable<MemoryRecord> GetAllAsync(string collectionName, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var container = this._database.Client.GetContainer(this._databaseName, collectionName);
var query = new QueryDefinition("SELECT * FROM c");
Expand All @@ -299,7 +299,7 @@ private async IAsyncEnumerable<MemoryRecord> GetAllAsync(string collectionName,

while (iterator.HasMoreResults) //read all result in batch
{
var items = await iterator.ReadNextAsync(cancel).ConfigureAwait(false);
var items = await iterator.ReadNextAsync(cancellationToken).ConfigureAwait(false);

foreach (var item in items)
{
Expand Down

0 comments on commit f28579e

Please sign in to comment.