Skip to content

Commit

Permalink
feat: Add Run Pipeline methods to HASSWSApi
Browse files Browse the repository at this point in the history
  • Loading branch information
vicfergar committed Aug 17, 2023
1 parent b4c0899 commit 80f6d16
Show file tree
Hide file tree
Showing 11 changed files with 390 additions and 1 deletion.
28 changes: 28 additions & 0 deletions src/HassClient.Core.Tests/KnownEnumHelpersTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,34 @@ public void EmptyStringAsEventTypeThrows()
Assert.Throws<ArgumentException>(() => string.Empty.AsKnownEventType());
}

[Test]
[TestCase("run-start")]
[TestCase("run-end")]
[TestCase("stt-start")]
[TestCase("stt-end")]
[TestCase("intent-start")]
[TestCase("intent-end")]
[TestCase("tts-start")]
[TestCase("tts-end")]
[TestCase("error")]
public void AllKnownPipelineEventTypesCanBeParsed(string snakeCaseValue)
{
var result = snakeCaseValue.AsKnownPipelineEventType();
Assert.AreNotEqual(KnownPipelineEventTypes.Undefined, result);
}

[Test]
public void NullStringAsPipelineEventTypeThrows()
{
Assert.Throws<ArgumentException>(() => ((string)null).AsKnownPipelineEventType());
}

[Test]
public void EmptyStringAsPipelineEventTypeThrows()
{
Assert.Throws<ArgumentException>(() => string.Empty.AsKnownPipelineEventType());
}

[Test]
[TestCase("adguard")]
[TestCase("air_quality")]
Expand Down
35 changes: 34 additions & 1 deletion src/HassClient.Core/Helpers/KnownEnumHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public static string ToDomainString(this KnownDomains domain)
private static KnownEnumCache<KnownEventTypes> knownEventTypesCache = new KnownEnumCache<KnownEventTypes>();

/// <summary>
/// Converts a given snake case <paramref name="eventType"/> to <see cref="KnownDomains"/>.
/// Converts a given snake case <paramref name="eventType"/> to <see cref="KnownEventTypes"/>.
/// </summary>
/// <param name="eventType">
/// The event type as a snake case <see cref="string"/>. (e.g. <c>state_changed</c>).
Expand Down Expand Up @@ -78,6 +78,39 @@ public static string ToEventTypeString(this KnownEventTypes eventType)
return knownEventTypesCache.AsString(eventType);
}

private static KnownEnumCache<KnownPipelineEventTypes> knownPipelineEventTypesCache = new KnownEnumCache<KnownPipelineEventTypes>();

/// <summary>
/// Converts a given snake case <paramref name="eventType"/> to <see cref="KnownPipelineEventTypes"/>.
/// </summary>
/// <param name="eventType">
/// The event type as a snake case <see cref="string"/>. (e.g. <c>state_changed</c>).
/// </param>
/// <returns>
/// The event type as a <see cref="KnownPipelineEventTypes"/> if defined; otherwise, <see cref="KnownPipelineEventTypes.Undefined"/>.
/// </returns>
public static KnownPipelineEventTypes AsKnownPipelineEventType(this string eventType)
{
if (string.IsNullOrEmpty(eventType))
{
throw new ArgumentException($"'{nameof(eventType)}' cannot be null or empty", nameof(eventType));
}

return knownPipelineEventTypesCache.AsEnum(eventType);
}

/// <summary>
/// Converts a given <see cref="KnownPipelineEventTypes"/> to a snake case <see cref="string"/>.
/// </summary>
/// <param name="eventType">A <see cref="KnownPipelineEventTypes"/>.</param>
/// <returns>
/// The service as a <see cref="string"/>.
/// </returns>
public static string ToEventTypeString(this KnownPipelineEventTypes eventType)
{
return knownPipelineEventTypesCache.AsString(eventType);
}

private static KnownEnumCache<KnownServices> knownServicesCache = new KnownEnumCache<KnownServices>();

/// <summary>
Expand Down
68 changes: 68 additions & 0 deletions src/HassClient.Core/Models/Events/KnownPipelineEventTypes.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
using System.Runtime.Serialization;

namespace HassClient.Models
{
/// <summary>
/// Collection of built-in pipeline event types. See <see href="https://developers.home-assistant.io/docs/voice/pipelines/#events"/>.
/// </summary>
public enum KnownPipelineEventTypes
{
/// <summary>
/// Used to represent a type not defined within this enum.
/// </summary>
Undefined = 0,

/// <summary>
/// Start of pipeline run.
/// </summary>
[EnumMember(Value = "run-start")]
RunStart,

/// <summary>
/// End of pipeline run.
/// </summary>
[EnumMember(Value = "run-end")]
RunEnd,

/// <summary>
/// Start of speech to text.
/// </summary>
[EnumMember(Value = "stt-start")]
STTStart,

/// <summary>
/// End of speech to text.
/// </summary>
[EnumMember(Value = "stt-end")]
STTEnd,

/// <summary>
/// Start of intent recognition.
/// </summary>
[EnumMember(Value = "intent-start")]
IntentStart,

/// <summary>
/// End of intent recognition.
/// </summary>
[EnumMember(Value = "intent-end")]
IntentEnd,

/// <summary>
/// Start of text to speech.
/// </summary>
[EnumMember(Value = "tts-start")]
TTSStart,

/// <summary>
/// End of text to speech.
/// </summary>
[EnumMember(Value = "tts-end")]
TTSEnd,

/// <summary>
/// Error in pipeline.
/// </summary>
Error,
}
}
47 changes: 47 additions & 0 deletions src/HassClient.WS/HASSClientWebSocket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using HassClient.Models;
using HassClient.Serialization;
using HassClient.WS.Messages;
using HassClient.WS.Messages.Response;

Check failure on line 5 in src/HassClient.WS/HASSClientWebSocket.cs

View workflow job for this annotation

GitHub Actions / build (stable)

The type or namespace name 'Response' does not exist in the namespace 'HassClient.WS.Messages' (are you missing an assembly reference?)

Check failure on line 5 in src/HassClient.WS/HASSClientWebSocket.cs

View workflow job for this annotation

GitHub Actions / build (stable)

The type or namespace name 'Response' does not exist in the namespace 'HassClient.WS.Messages' (are you missing an assembly reference?)
using HassClient.WS.Serialization;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
Expand Down Expand Up @@ -403,6 +404,7 @@ private async Task CreateSocketListenerTask()

if (incomingMessage is EventResultMessage eventResultMessage)
{
Debug.WriteLine($"{TAG} Event message received {eventResultMessage}");
if (!this.receivedEventsChannel.Writer.TryWrite(eventResultMessage))
{
Trace.TraceWarning($"{TAG} {nameof(this.receivedEventsChannel)} is full. One event message will discarded.");
Expand Down Expand Up @@ -717,6 +719,51 @@ internal async Task<bool> SendCommandWithSuccessAsync(BaseOutgoingMessage comman
return resultMessage.Success;
}

/// <summary>
/// Sends a pipeline run command and returns a list of the received event results.
/// </summary>
/// <param name="commandMessage">The command message to be sent.</param>
/// <param name="cancellationToken">The cancellation token for the asynchronous operation.</param>
/// <returns>
/// A task representing the asynchronous operation.
/// The result of the task is a list of the received event results.
/// </returns>
internal async Task<IEnumerable<PipelineEventResultInfo>> SendPipelineRunCommandAsync(PipelineRunMessage commandMessage, CancellationToken cancellationToken)
{
var result = await this.SendCommandWithResultAsync(commandMessage, cancellationToken);

if (!result.Success)
{
Enumerable.Empty<PipelineEventResultInfo>();
}

var receivedEvents = new List<PipelineEventResultInfo>();
try
{
var linkedCTS = CancellationTokenSource.CreateLinkedTokenSource(this.closeConnectionCTS.Token, cancellationToken);
var responseTCS = new TaskCompletionSource<bool>(linkedCTS.Token);
var eventCallback = new Action<EventResultMessage>(eventResultMessage =>
{
var eventResultInfo = eventResultMessage.DeserializeEvent<PipelineEventResultInfo>();
receivedEvents.Add(eventResultInfo);
if (eventResultInfo.KnownType == KnownPipelineEventTypes.RunEnd)
{
responseTCS.SetResult(true);
}
});

this.socketEventCallbacksBySubsciptionId.Add(result.Id, eventCallback);
await responseTCS.Task;
}
finally
{
this.socketEventCallbacksBySubsciptionId.Remove(result.Id);
}

return receivedEvents;
}

/// <summary>
/// Adds an <see cref="EventHandler{TEventArgs}"/> to an event subscription.
/// </summary>
Expand Down
70 changes: 70 additions & 0 deletions src/HassClient.WS/HASSWSApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -977,6 +977,76 @@ public Task<SearchRelatedResponse> SearchRelatedAsync(ItemTypes itemType, string
return this.hassClientWebSocket.SendCommandWithResultAsync<SearchRelatedResponse>(commandMessage, cancellationToken);
}

/// <summary>
/// Performs a pipeline run starting with a <see cref="StageTypes.Intent"/> stage.
/// </summary>
/// <param name="endStage">The last stage to run.</param>
/// <param name="text">The text to be used as input.</param>
/// <param name="pipeline">ID of the pipeline.</param> ijdkj
/// <param name="conversationId">Unique id for conversation. <see href="https://developers.home-assistant.io/docs/intent_conversation_api#conversation-id"/>.</param>
/// <param name="timeout">Amount of time before pipeline times out (default: 30 seconds).</param>
/// <param name="cancellationToken">
/// A cancellation token used to propagate notification that this operation should be canceled.
/// </param>
/// <returns>
/// A task representing the asynchronous operation.
/// The result of the task is a list of the received event results.
/// </returns>
public async Task<IEnumerable<PipelineEventResultInfo>> RunIntentPipeline(
StageTypes endStage,
string text,
string pipeline = null,
string conversationId = null,
TimeSpan? timeout = default,
CancellationToken cancellationToken = default)
{
var commandMessage = new PipelineRunMessage()
{
StartStage = StageTypes.Intent,
EndStage = endStage,
Input = new PipelineRunTextInput(text),
Pipeline = pipeline,
ConversationId = conversationId,
Timeout = (float?)timeout?.TotalSeconds,
};
return await this.hassClientWebSocket.SendPipelineRunCommandAsync(commandMessage, cancellationToken);
}

/// <summary>
/// Performs a pipeline run starting with a <see cref="StageTypes.TTS"/> stage.
/// </summary>
/// <param name="endStage">The last stage to run.</param>
/// <param name="text">The text to be used as input.</param>
/// <param name="pipeline">ID of the pipeline.</param> ijdkj
/// <param name="conversationId">Unique id for conversation. <see href="https://developers.home-assistant.io/docs/intent_conversation_api#conversation-id"/>.</param>
/// <param name="timeout">Amount of time before pipeline times out (default: 30 seconds).</param>
/// <param name="cancellationToken">
/// A cancellation token used to propagate notification that this operation should be canceled.
/// </param>
/// <returns>
/// A task representing the asynchronous operation.
/// The result of the task is a list of the received event results.
/// </returns>
public async Task<IEnumerable<PipelineEventResultInfo>> RunTTSPipeline(
StageTypes endStage,
string text,
string pipeline = null,
string conversationId = null,
TimeSpan? timeout = default,
CancellationToken cancellationToken = default)
{
var commandMessage = new PipelineRunMessage()
{
StartStage = StageTypes.TTS,
EndStage = endStage,
Input = new PipelineRunTextInput(text),
Pipeline = pipeline,
ConversationId = conversationId,
Timeout = (float?)timeout?.TotalSeconds,
};
return await this.hassClientWebSocket.SendPipelineRunCommandAsync(commandMessage, cancellationToken);
}

/// <summary>
/// Sends a customized command to the Home Assistant instance. This is useful when a command is not defined by the <see cref="HassWSApi"/>.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
namespace HassClient.WS.Messages
{
/// <summary>
/// Represents a <see cref="PipelineRunMessage"/> input.
/// </summary>
internal interface IPipelineRunInput
{
}
}
49 changes: 49 additions & 0 deletions src/HassClient.WS/Messages/Commands/Pipeline/PipelineRunMessage.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
using Newtonsoft.Json;

namespace HassClient.WS.Messages
{
internal class PipelineRunMessage : BaseOutgoingMessage
{
public PipelineRunMessage()
: base("assist_pipeline/run")
{
}

/// <summary>
/// The first stage to run.
/// </summary>
public StageTypes StartStage { get; set; }

/// <summary>
/// The last stage to run.
/// </summary>
public StageTypes EndStage { get; set; }

/// <summary>
/// Depends on <see cref="StartStage"/>.
/// <para>
/// For <see cref="StageTypes.STT"/>, it should be an <see cref="PipelineRunSampleRateInput"/>.
/// </para>
/// <para>
/// For <see cref="StageTypes.Intent"/> and <see cref="StageTypes.TTS"/>, it should be an <see cref="PipelineRunTextInput"/>.
/// </para>
/// </summary>
public IPipelineRunInput Input { get; set; }

/// <summary>
/// ID of the pipeline.
/// </summary>
public string Pipeline { get; set; }

/// <summary>
/// Unique id for conversation. <see href="https://developers.home-assistant.io/docs/intent_conversation_api#conversation-id"/>.
/// </summary>
public string ConversationId { get; set; }

/// <summary>
/// Number of seconds before pipeline times out (default: 30).
/// </summary>
[JsonProperty(NullValueHandling = NullValueHandling.Ignore)]
public float? Timeout { get; set; }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
namespace HassClient.WS.Messages
{
internal class PipelineRunSampleRateInput : IPipelineRunInput
{
public PipelineRunSampleRateInput(int sampleRate)
{
this.SampleRate = sampleRate;
}

public int SampleRate { get; set; }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
namespace HassClient.WS.Messages
{
internal class PipelineRunTextInput : IPipelineRunInput
{
public PipelineRunTextInput(string text)
{
this.Text = text;
}

public string Text { get; set; }
}
}
Loading

0 comments on commit 80f6d16

Please sign in to comment.