From 80f6d169a850cb149d0a7ffd1cf149b1137d122d Mon Sep 17 00:00:00 2001 From: Victor Ferrer Date: Thu, 17 Aug 2023 15:09:42 +0200 Subject: [PATCH] feat: Add Run Pipeline methods to HASSWSApi --- .../KnownEnumHelpersTests.cs | 28 ++++++++ .../Helpers/KnownEnumHelpers.cs | 35 +++++++++- .../Models/Events/KnownPipelineEventTypes.cs | 68 ++++++++++++++++++ src/HassClient.WS/HASSClientWebSocket.cs | 47 +++++++++++++ src/HassClient.WS/HASSWSApi.cs | 70 +++++++++++++++++++ .../Commands/Pipeline/IPipelineRunInput.cs | 9 +++ .../Commands/Pipeline/PipelineRunMessage.cs | 49 +++++++++++++ .../Pipeline/PipelineRunSampleRateInput.cs | 12 ++++ .../Commands/Pipeline/PipelineRunTextInput.cs | 12 ++++ .../Messages/Commands/Pipeline/StageTypes.cs | 23 ++++++ .../Response/PipelineEventResultInfo.cs | 38 ++++++++++ 11 files changed, 390 insertions(+), 1 deletion(-) create mode 100644 src/HassClient.Core/Models/Events/KnownPipelineEventTypes.cs create mode 100644 src/HassClient.WS/Messages/Commands/Pipeline/IPipelineRunInput.cs create mode 100644 src/HassClient.WS/Messages/Commands/Pipeline/PipelineRunMessage.cs create mode 100644 src/HassClient.WS/Messages/Commands/Pipeline/PipelineRunSampleRateInput.cs create mode 100644 src/HassClient.WS/Messages/Commands/Pipeline/PipelineRunTextInput.cs create mode 100644 src/HassClient.WS/Messages/Commands/Pipeline/StageTypes.cs create mode 100644 src/HassClient.WS/Messages/Response/PipelineEventResultInfo.cs diff --git a/src/HassClient.Core.Tests/KnownEnumHelpersTests.cs b/src/HassClient.Core.Tests/KnownEnumHelpersTests.cs index 979c35c..272543d 100644 --- a/src/HassClient.Core.Tests/KnownEnumHelpersTests.cs +++ b/src/HassClient.Core.Tests/KnownEnumHelpersTests.cs @@ -53,6 +53,34 @@ public void EmptyStringAsEventTypeThrows() Assert.Throws(() => string.Empty.AsKnownEventType()); } + [Test] + [TestCase("run-start")] + [TestCase("run-end")] + [TestCase("stt-start")] + [TestCase("stt-end")] + [TestCase("intent-start")] + [TestCase("intent-end")] + [TestCase("tts-start")] + [TestCase("tts-end")] + [TestCase("error")] + public void AllKnownPipelineEventTypesCanBeParsed(string snakeCaseValue) + { + var result = snakeCaseValue.AsKnownPipelineEventType(); + Assert.AreNotEqual(KnownPipelineEventTypes.Undefined, result); + } + + [Test] + public void NullStringAsPipelineEventTypeThrows() + { + Assert.Throws(() => ((string)null).AsKnownPipelineEventType()); + } + + [Test] + public void EmptyStringAsPipelineEventTypeThrows() + { + Assert.Throws(() => string.Empty.AsKnownPipelineEventType()); + } + [Test] [TestCase("adguard")] [TestCase("air_quality")] diff --git a/src/HassClient.Core/Helpers/KnownEnumHelpers.cs b/src/HassClient.Core/Helpers/KnownEnumHelpers.cs index 3619a31..fc4ba46 100644 --- a/src/HassClient.Core/Helpers/KnownEnumHelpers.cs +++ b/src/HassClient.Core/Helpers/KnownEnumHelpers.cs @@ -48,7 +48,7 @@ public static string ToDomainString(this KnownDomains domain) private static KnownEnumCache knownEventTypesCache = new KnownEnumCache(); /// - /// Converts a given snake case to . + /// Converts a given snake case to . /// /// /// The event type as a snake case . (e.g. state_changed). @@ -78,6 +78,39 @@ public static string ToEventTypeString(this KnownEventTypes eventType) return knownEventTypesCache.AsString(eventType); } + private static KnownEnumCache knownPipelineEventTypesCache = new KnownEnumCache(); + + /// + /// Converts a given snake case to . + /// + /// + /// The event type as a snake case . (e.g. state_changed). + /// + /// + /// The event type as a if defined; otherwise, . + /// + public static KnownPipelineEventTypes AsKnownPipelineEventType(this string eventType) + { + if (string.IsNullOrEmpty(eventType)) + { + throw new ArgumentException($"'{nameof(eventType)}' cannot be null or empty", nameof(eventType)); + } + + return knownPipelineEventTypesCache.AsEnum(eventType); + } + + /// + /// Converts a given to a snake case . + /// + /// A . + /// + /// The service as a . + /// + public static string ToEventTypeString(this KnownPipelineEventTypes eventType) + { + return knownPipelineEventTypesCache.AsString(eventType); + } + private static KnownEnumCache knownServicesCache = new KnownEnumCache(); /// diff --git a/src/HassClient.Core/Models/Events/KnownPipelineEventTypes.cs b/src/HassClient.Core/Models/Events/KnownPipelineEventTypes.cs new file mode 100644 index 0000000..59a31f4 --- /dev/null +++ b/src/HassClient.Core/Models/Events/KnownPipelineEventTypes.cs @@ -0,0 +1,68 @@ +using System.Runtime.Serialization; + +namespace HassClient.Models +{ + /// + /// Collection of built-in pipeline event types. See . + /// + public enum KnownPipelineEventTypes + { + /// + /// Used to represent a type not defined within this enum. + /// + Undefined = 0, + + /// + /// Start of pipeline run. + /// + [EnumMember(Value = "run-start")] + RunStart, + + /// + /// End of pipeline run. + /// + [EnumMember(Value = "run-end")] + RunEnd, + + /// + /// Start of speech to text. + /// + [EnumMember(Value = "stt-start")] + STTStart, + + /// + /// End of speech to text. + /// + [EnumMember(Value = "stt-end")] + STTEnd, + + /// + /// Start of intent recognition. + /// + [EnumMember(Value = "intent-start")] + IntentStart, + + /// + /// End of intent recognition. + /// + [EnumMember(Value = "intent-end")] + IntentEnd, + + /// + /// Start of text to speech. + /// + [EnumMember(Value = "tts-start")] + TTSStart, + + /// + /// End of text to speech. + /// + [EnumMember(Value = "tts-end")] + TTSEnd, + + /// + /// Error in pipeline. + /// + Error, + } +} diff --git a/src/HassClient.WS/HASSClientWebSocket.cs b/src/HassClient.WS/HASSClientWebSocket.cs index b832fbe..a093cda 100644 --- a/src/HassClient.WS/HASSClientWebSocket.cs +++ b/src/HassClient.WS/HASSClientWebSocket.cs @@ -2,6 +2,7 @@ using HassClient.Models; using HassClient.Serialization; using HassClient.WS.Messages; +using HassClient.WS.Messages.Response; using HassClient.WS.Serialization; using Newtonsoft.Json; using Newtonsoft.Json.Linq; @@ -403,6 +404,7 @@ private async Task CreateSocketListenerTask() if (incomingMessage is EventResultMessage eventResultMessage) { + Debug.WriteLine($"{TAG} Event message received {eventResultMessage}"); if (!this.receivedEventsChannel.Writer.TryWrite(eventResultMessage)) { Trace.TraceWarning($"{TAG} {nameof(this.receivedEventsChannel)} is full. One event message will discarded."); @@ -717,6 +719,51 @@ internal async Task SendCommandWithSuccessAsync(BaseOutgoingMessage comman return resultMessage.Success; } + /// + /// Sends a pipeline run command and returns a list of the received event results. + /// + /// The command message to be sent. + /// The cancellation token for the asynchronous operation. + /// + /// A task representing the asynchronous operation. + /// The result of the task is a list of the received event results. + /// + internal async Task> SendPipelineRunCommandAsync(PipelineRunMessage commandMessage, CancellationToken cancellationToken) + { + var result = await this.SendCommandWithResultAsync(commandMessage, cancellationToken); + + if (!result.Success) + { + Enumerable.Empty(); + } + + var receivedEvents = new List(); + try + { + var linkedCTS = CancellationTokenSource.CreateLinkedTokenSource(this.closeConnectionCTS.Token, cancellationToken); + var responseTCS = new TaskCompletionSource(linkedCTS.Token); + var eventCallback = new Action(eventResultMessage => + { + var eventResultInfo = eventResultMessage.DeserializeEvent(); + receivedEvents.Add(eventResultInfo); + + if (eventResultInfo.KnownType == KnownPipelineEventTypes.RunEnd) + { + responseTCS.SetResult(true); + } + }); + + this.socketEventCallbacksBySubsciptionId.Add(result.Id, eventCallback); + await responseTCS.Task; + } + finally + { + this.socketEventCallbacksBySubsciptionId.Remove(result.Id); + } + + return receivedEvents; + } + /// /// Adds an to an event subscription. /// diff --git a/src/HassClient.WS/HASSWSApi.cs b/src/HassClient.WS/HASSWSApi.cs index 12576e5..911bbc0 100644 --- a/src/HassClient.WS/HASSWSApi.cs +++ b/src/HassClient.WS/HASSWSApi.cs @@ -977,6 +977,76 @@ public Task SearchRelatedAsync(ItemTypes itemType, string return this.hassClientWebSocket.SendCommandWithResultAsync(commandMessage, cancellationToken); } + /// + /// Performs a pipeline run starting with a stage. + /// + /// The last stage to run. + /// The text to be used as input. + /// ID of the pipeline. ijdkj + /// Unique id for conversation. . + /// Amount of time before pipeline times out (default: 30 seconds). + /// + /// A cancellation token used to propagate notification that this operation should be canceled. + /// + /// + /// A task representing the asynchronous operation. + /// The result of the task is a list of the received event results. + /// + public async Task> RunIntentPipeline( + StageTypes endStage, + string text, + string pipeline = null, + string conversationId = null, + TimeSpan? timeout = default, + CancellationToken cancellationToken = default) + { + var commandMessage = new PipelineRunMessage() + { + StartStage = StageTypes.Intent, + EndStage = endStage, + Input = new PipelineRunTextInput(text), + Pipeline = pipeline, + ConversationId = conversationId, + Timeout = (float?)timeout?.TotalSeconds, + }; + return await this.hassClientWebSocket.SendPipelineRunCommandAsync(commandMessage, cancellationToken); + } + + /// + /// Performs a pipeline run starting with a stage. + /// + /// The last stage to run. + /// The text to be used as input. + /// ID of the pipeline. ijdkj + /// Unique id for conversation. . + /// Amount of time before pipeline times out (default: 30 seconds). + /// + /// A cancellation token used to propagate notification that this operation should be canceled. + /// + /// + /// A task representing the asynchronous operation. + /// The result of the task is a list of the received event results. + /// + public async Task> RunTTSPipeline( + StageTypes endStage, + string text, + string pipeline = null, + string conversationId = null, + TimeSpan? timeout = default, + CancellationToken cancellationToken = default) + { + var commandMessage = new PipelineRunMessage() + { + StartStage = StageTypes.TTS, + EndStage = endStage, + Input = new PipelineRunTextInput(text), + Pipeline = pipeline, + ConversationId = conversationId, + Timeout = (float?)timeout?.TotalSeconds, + }; + return await this.hassClientWebSocket.SendPipelineRunCommandAsync(commandMessage, cancellationToken); + } + /// /// Sends a customized command to the Home Assistant instance. This is useful when a command is not defined by the . /// diff --git a/src/HassClient.WS/Messages/Commands/Pipeline/IPipelineRunInput.cs b/src/HassClient.WS/Messages/Commands/Pipeline/IPipelineRunInput.cs new file mode 100644 index 0000000..8a5f4e8 --- /dev/null +++ b/src/HassClient.WS/Messages/Commands/Pipeline/IPipelineRunInput.cs @@ -0,0 +1,9 @@ +namespace HassClient.WS.Messages +{ + /// + /// Represents a input. + /// + internal interface IPipelineRunInput + { + } +} diff --git a/src/HassClient.WS/Messages/Commands/Pipeline/PipelineRunMessage.cs b/src/HassClient.WS/Messages/Commands/Pipeline/PipelineRunMessage.cs new file mode 100644 index 0000000..83b8bc3 --- /dev/null +++ b/src/HassClient.WS/Messages/Commands/Pipeline/PipelineRunMessage.cs @@ -0,0 +1,49 @@ +using Newtonsoft.Json; + +namespace HassClient.WS.Messages +{ + internal class PipelineRunMessage : BaseOutgoingMessage + { + public PipelineRunMessage() + : base("assist_pipeline/run") + { + } + + /// + /// The first stage to run. + /// + public StageTypes StartStage { get; set; } + + /// + /// The last stage to run. + /// + public StageTypes EndStage { get; set; } + + /// + /// Depends on . + /// + /// For , it should be an . + /// + /// + /// For and , it should be an . + /// + /// + public IPipelineRunInput Input { get; set; } + + /// + /// ID of the pipeline. + /// + public string Pipeline { get; set; } + + /// + /// Unique id for conversation. . + /// + public string ConversationId { get; set; } + + /// + /// Number of seconds before pipeline times out (default: 30). + /// + [JsonProperty(NullValueHandling = NullValueHandling.Ignore)] + public float? Timeout { get; set; } + } +} diff --git a/src/HassClient.WS/Messages/Commands/Pipeline/PipelineRunSampleRateInput.cs b/src/HassClient.WS/Messages/Commands/Pipeline/PipelineRunSampleRateInput.cs new file mode 100644 index 0000000..9885f85 --- /dev/null +++ b/src/HassClient.WS/Messages/Commands/Pipeline/PipelineRunSampleRateInput.cs @@ -0,0 +1,12 @@ +namespace HassClient.WS.Messages +{ + internal class PipelineRunSampleRateInput : IPipelineRunInput + { + public PipelineRunSampleRateInput(int sampleRate) + { + this.SampleRate = sampleRate; + } + + public int SampleRate { get; set; } + } +} diff --git a/src/HassClient.WS/Messages/Commands/Pipeline/PipelineRunTextInput.cs b/src/HassClient.WS/Messages/Commands/Pipeline/PipelineRunTextInput.cs new file mode 100644 index 0000000..499e40a --- /dev/null +++ b/src/HassClient.WS/Messages/Commands/Pipeline/PipelineRunTextInput.cs @@ -0,0 +1,12 @@ +namespace HassClient.WS.Messages +{ + internal class PipelineRunTextInput : IPipelineRunInput + { + public PipelineRunTextInput(string text) + { + this.Text = text; + } + + public string Text { get; set; } + } +} diff --git a/src/HassClient.WS/Messages/Commands/Pipeline/StageTypes.cs b/src/HassClient.WS/Messages/Commands/Pipeline/StageTypes.cs new file mode 100644 index 0000000..de7f6c5 --- /dev/null +++ b/src/HassClient.WS/Messages/Commands/Pipeline/StageTypes.cs @@ -0,0 +1,23 @@ +namespace HassClient +{ + /// + /// Well known Home Assistant stage types used during pipeline run. + /// + public enum StageTypes + { + /// + /// Speech to Text. + /// + STT, + + /// + /// Intent . + /// + Intent, + + /// + /// Text to Speech. + /// + TTS, + } +} diff --git a/src/HassClient.WS/Messages/Response/PipelineEventResultInfo.cs b/src/HassClient.WS/Messages/Response/PipelineEventResultInfo.cs new file mode 100644 index 0000000..98d7b26 --- /dev/null +++ b/src/HassClient.WS/Messages/Response/PipelineEventResultInfo.cs @@ -0,0 +1,38 @@ +using HassClient.Helpers; +using HassClient.Models; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using System; + +namespace HassClient.WS.Messages +{ + /// + /// Information of a fired Home Assistant pipeline event. + /// + public class PipelineEventResultInfo + { + /// + /// Gets or sets the pipeline event type. + /// + [JsonProperty(Required = Required.Always)] + public string Type { get; set; } + + /// + /// Gets the pipeline event type as . + /// + [JsonIgnore] + public KnownPipelineEventTypes KnownType => this.Type.AsKnownPipelineEventType(); + + /// + /// Gets or sets the data associated with the fired event. + /// + [JsonProperty(Required = Required.AllowNull)] + public JRaw Data { get; set; } + + /// + /// Gets or sets the time at which the event was fired. + /// + [JsonProperty(Required = Required.Always)] + public DateTimeOffset Timestamp { get; set; } + } +}