From 197393afc48d1029ac09efde8be5cecc321af3a0 Mon Sep 17 00:00:00 2001 From: Ajay Raj Date: Mon, 24 Jun 2024 13:18:13 -0700 Subject: [PATCH 01/14] [DOW-105] refactor interruptions into the output device (#562) * initial refactor works * remove notion of UtteranceAudioChunk and put all of the state in the callback * move per_chunk_allowance_seconds into output device * onboard onto vonage * rename to abstract output device and onboard other output devices * initial work to onboard twilio output device * twilio conversation works * some cleanup with better comments * unset poetry.lock * move abstract play method into ratelimitoutputdevice + dispatch to thread in fileoutputdevice * rename back to AsyncWorker * comments * work through a bit of mypy * asyncio.gather is g2g: * create interrupt lock * remove todo * remove last todo * remove log for interrupts * fmt --- .../streaming/synthesizer/synthesize.py | 10 +- tests/fakedata/conversation.py | 7 +- vocode/helpers.py | 2 +- .../streaming/client_backend/conversation.py | 2 +- vocode/streaming/models/synthesizer.py | 4 +- .../output_device/abstract_output_device.py | 16 ++ vocode/streaming/output_device/audio_chunk.py | 28 +++ .../output_device/base_output_device.py | 16 -- .../output_device/blocking_speaker_output.py | 58 ++++-- .../output_device/file_output_device.py | 39 +--- .../rate_limit_interruptions_output_device.py | 70 +++++++ .../streaming/output_device/speaker_output.py | 5 +- .../output_device/twilio_output_device.py | 120 ++++++++--- .../output_device/vonage_output_device.py | 36 ++-- .../output_device/websocket_output_device.py | 27 +-- vocode/streaming/streaming_conversation.py | 191 ++++++++++-------- .../synthesizer/azure_synthesizer.py | 4 +- .../streaming/synthesizer/base_synthesizer.py | 24 ++- .../eleven_labs_websocket_synthesizer.py | 4 +- .../synthesizer/polly_synthesizer.py | 6 +- .../abstract_phone_conversation.py | 2 - .../conversation/mark_message_queue.py | 46 ----- .../conversation/twilio_phone_conversation.py | 150 +------------- .../conversation/vonage_phone_conversation.py | 2 - .../streaming/transcriber/base_transcriber.py | 9 +- vocode/streaming/utils/__init__.py | 9 + vocode/streaming/utils/worker.py | 2 +- ...ut_device.py => abstract_output_device.py} | 7 +- .../output_device/speaker_output.py | 4 +- vocode/turn_based/turn_based_conversation.py | 4 +- 30 files changed, 451 insertions(+), 453 deletions(-) create mode 100644 vocode/streaming/output_device/abstract_output_device.py create mode 100644 vocode/streaming/output_device/audio_chunk.py delete mode 100644 vocode/streaming/output_device/base_output_device.py create mode 100644 vocode/streaming/output_device/rate_limit_interruptions_output_device.py delete mode 100644 vocode/streaming/telephony/conversation/mark_message_queue.py rename vocode/turn_based/output_device/{base_output_device.py => abstract_output_device.py} (55%) diff --git a/playground/streaming/synthesizer/synthesize.py b/playground/streaming/synthesizer/synthesize.py index 086169772..c9431620a 100644 --- a/playground/streaming/synthesizer/synthesize.py +++ b/playground/streaming/synthesizer/synthesize.py @@ -2,11 +2,13 @@ from vocode.streaming.models.message import BaseMessage from vocode.streaming.models.synthesizer import AzureSynthesizerConfig -from vocode.streaming.output_device.base_output_device import BaseOutputDevice +from vocode.streaming.output_device.abstract_output_device import AbstractOutputDevice +from vocode.streaming.output_device.audio_chunk import AudioChunk from vocode.streaming.output_device.blocking_speaker_output import BlockingSpeakerOutput from vocode.streaming.synthesizer.azure_synthesizer import AzureSynthesizer from vocode.streaming.synthesizer.base_synthesizer import BaseSynthesizer from vocode.streaming.utils import get_chunk_size_per_second +from vocode.streaming.utils.worker import InterruptibleEvent if __name__ == "__main__": import asyncio @@ -19,7 +21,7 @@ async def speak( synthesizer: BaseSynthesizer, - output_device: BaseOutputDevice, + output_device: AbstractOutputDevice, message: BaseMessage, ): message_sent = message.text @@ -38,7 +40,9 @@ async def speak( try: start_time = time.time() speech_length_seconds = seconds_per_chunk * (len(chunk_result.chunk) / chunk_size) - output_device.consume_nonblocking(chunk_result.chunk) + output_device.consume_nonblocking( + InterruptibleEvent(payload=AudioChunk(data=chunk_result.chunk)) + ) end_time = time.time() await asyncio.sleep( max( diff --git a/tests/fakedata/conversation.py b/tests/fakedata/conversation.py index 25344a006..2b6917b21 100644 --- a/tests/fakedata/conversation.py +++ b/tests/fakedata/conversation.py @@ -8,7 +8,7 @@ from vocode.streaming.models.message import BaseMessage from vocode.streaming.models.synthesizer import PlayHtSynthesizerConfig, SynthesizerConfig from vocode.streaming.models.transcriber import DeepgramTranscriberConfig, TranscriberConfig -from vocode.streaming.output_device.base_output_device import BaseOutputDevice +from vocode.streaming.output_device.abstract_output_device import AbstractOutputDevice from vocode.streaming.streaming_conversation import StreamingConversation from vocode.streaming.synthesizer.base_synthesizer import BaseSynthesizer from vocode.streaming.telephony.constants import DEFAULT_CHUNK_SIZE, DEFAULT_SAMPLING_RATE @@ -36,10 +36,13 @@ ) -class DummyOutputDevice(BaseOutputDevice): +class DummyOutputDevice(AbstractOutputDevice): def consume_nonblocking(self, chunk: bytes): pass + def interrupt(self): + pass + def create_fake_transcriber(mocker: MockerFixture, transcriber_config: TranscriberConfig): transcriber = mocker.MagicMock() diff --git a/vocode/helpers.py b/vocode/helpers.py index 29727eb92..d05b06fe5 100644 --- a/vocode/helpers.py +++ b/vocode/helpers.py @@ -33,7 +33,7 @@ def create_streaming_microphone_input_and_speaker_output( ): return _create_microphone_input_and_speaker_output( microphone_class=StreamingMicrophoneInput, - speaker_class=(BlockingStreamingSpeakerOutput), + speaker_class=BlockingStreamingSpeakerOutput, use_default_devices=use_default_devices, input_device_name=input_device_name, output_device_name=output_device_name, diff --git a/vocode/streaming/client_backend/conversation.py b/vocode/streaming/client_backend/conversation.py index 5970c475b..31b34c82b 100644 --- a/vocode/streaming/client_backend/conversation.py +++ b/vocode/streaming/client_backend/conversation.py @@ -116,7 +116,7 @@ def __init__( async def handle_event(self, event: Event): if event.type == EventType.TRANSCRIPT: transcript_event = typing.cast(TranscriptEvent, event) - self.output_device.consume_transcript(transcript_event) + await self.output_device.send_transcript(transcript_event) # logger.debug(event.dict()) def restart(self, output_device: WebsocketOutputDevice): diff --git a/vocode/streaming/models/synthesizer.py b/vocode/streaming/models/synthesizer.py index eb1032d6c..13a7fe655 100644 --- a/vocode/streaming/models/synthesizer.py +++ b/vocode/streaming/models/synthesizer.py @@ -4,7 +4,7 @@ from pydantic.v1 import validator from vocode.streaming.models.client_backend import OutputAudioConfig -from vocode.streaming.output_device.base_output_device import BaseOutputDevice +from vocode.streaming.output_device.abstract_output_device import AbstractOutputDevice from vocode.streaming.telephony.constants import DEFAULT_AUDIO_ENCODING, DEFAULT_SAMPLING_RATE from .audio import AudioEncoding, SamplingRate @@ -47,7 +47,7 @@ class Config: arbitrary_types_allowed = True @classmethod - def from_output_device(cls, output_device: BaseOutputDevice, **kwargs): + def from_output_device(cls, output_device: AbstractOutputDevice, **kwargs): return cls( sampling_rate=output_device.sampling_rate, audio_encoding=output_device.audio_encoding, diff --git a/vocode/streaming/output_device/abstract_output_device.py b/vocode/streaming/output_device/abstract_output_device.py new file mode 100644 index 000000000..746a4ffd8 --- /dev/null +++ b/vocode/streaming/output_device/abstract_output_device.py @@ -0,0 +1,16 @@ +from abc import abstractmethod +import asyncio +from vocode.streaming.output_device.audio_chunk import AudioChunk +from vocode.streaming.utils.worker import AsyncWorker, InterruptibleEvent + + +class AbstractOutputDevice(AsyncWorker[InterruptibleEvent[AudioChunk]]): + + def __init__(self, sampling_rate: int, audio_encoding): + super().__init__(input_queue=asyncio.Queue()) + self.sampling_rate = sampling_rate + self.audio_encoding = audio_encoding + + @abstractmethod + def interrupt(self): + pass diff --git a/vocode/streaming/output_device/audio_chunk.py b/vocode/streaming/output_device/audio_chunk.py new file mode 100644 index 000000000..d58e081ad --- /dev/null +++ b/vocode/streaming/output_device/audio_chunk.py @@ -0,0 +1,28 @@ +from dataclasses import dataclass, field +from enum import Enum +from uuid import UUID +import uuid + + +class ChunkState(int, Enum): + UNPLAYED = 0 + PLAYED = 1 + INTERRUPTED = 2 + + +@dataclass +class AudioChunk: + data: bytes + state: ChunkState = ChunkState.UNPLAYED + chunk_id: UUID = field(default_factory=uuid.uuid4) + + @staticmethod + def on_play(): + pass + + @staticmethod + def on_interrupt(): + pass + + def __hash__(self) -> int: + return hash(self.chunk_id) diff --git a/vocode/streaming/output_device/base_output_device.py b/vocode/streaming/output_device/base_output_device.py deleted file mode 100644 index 2ce90d5c2..000000000 --- a/vocode/streaming/output_device/base_output_device.py +++ /dev/null @@ -1,16 +0,0 @@ -from vocode.streaming.models.audio import AudioEncoding - - -class BaseOutputDevice: - def __init__(self, sampling_rate: int, audio_encoding: AudioEncoding): - self.sampling_rate = sampling_rate - self.audio_encoding = audio_encoding - - def start(self): - pass - - def consume_nonblocking(self, chunk: bytes): - raise NotImplemented - - def terminate(self): - pass diff --git a/vocode/streaming/output_device/blocking_speaker_output.py b/vocode/streaming/output_device/blocking_speaker_output.py index 99fdcd559..637a41d4b 100644 --- a/vocode/streaming/output_device/blocking_speaker_output.py +++ b/vocode/streaming/output_device/blocking_speaker_output.py @@ -6,26 +6,21 @@ import sounddevice as sd from vocode.streaming.models.audio import AudioEncoding -from vocode.streaming.output_device.base_output_device import BaseOutputDevice +from vocode.streaming.output_device.rate_limit_interruptions_output_device import ( + RateLimitInterruptionsOutputDevice, +) from vocode.streaming.utils.worker import ThreadAsyncWorker +DEFAULT_SAMPLING_RATE = 44100 -class BlockingSpeakerOutput(BaseOutputDevice, ThreadAsyncWorker): - DEFAULT_SAMPLING_RATE = 44100 - def __init__( - self, - device_info: dict, - sampling_rate: Optional[int] = None, - audio_encoding: AudioEncoding = AudioEncoding.LINEAR16, - ): +class _PlaybackWorker(ThreadAsyncWorker[bytes]): + + def __init__(self, *, device_info: dict, sampling_rate: int): + self.sampling_rate = sampling_rate self.device_info = device_info - sampling_rate = sampling_rate or int( - self.device_info.get("default_samplerate", self.DEFAULT_SAMPLING_RATE) - ) self.input_queue: asyncio.Queue[bytes] = asyncio.Queue() - BaseOutputDevice.__init__(self, sampling_rate, audio_encoding) - ThreadAsyncWorker.__init__(self, self.input_queue) + super().__init__(self.input_queue) self.stream = sd.OutputStream( channels=1, samplerate=self.sampling_rate, @@ -36,9 +31,6 @@ def __init__( self.input_queue.put_nowait(self.sampling_rate * b"\x00") self.stream.start() - def start(self): - ThreadAsyncWorker.start(self) - def _run_loop(self): while not self._ended: try: @@ -47,14 +39,38 @@ def _run_loop(self): except queue.Empty: continue - def consume_nonblocking(self, chunk): - ThreadAsyncWorker.consume_nonblocking(self, chunk) - def terminate(self): self._ended = True - ThreadAsyncWorker.terminate(self) + super().terminate() self.stream.close() + +class BlockingSpeakerOutput(RateLimitInterruptionsOutputDevice): + DEFAULT_SAMPLING_RATE = 44100 + + def __init__( + self, + device_info: dict, + sampling_rate: Optional[int] = None, + audio_encoding: AudioEncoding = AudioEncoding.LINEAR16, + ): + sampling_rate = sampling_rate or int( + device_info.get("default_samplerate", DEFAULT_SAMPLING_RATE) + ) + super().__init__(sampling_rate=sampling_rate, audio_encoding=audio_encoding) + self.playback_worker = _PlaybackWorker(device_info=device_info, sampling_rate=sampling_rate) + + async def play(self, chunk): + self.playback_worker.consume_nonblocking(chunk) + + def start(self) -> asyncio.Task: + self.playback_worker.start() + return super().start() + + def terminate(self): + self.playback_worker.terminate() + super().terminate() + @classmethod def from_default_device( cls, diff --git a/vocode/streaming/output_device/file_output_device.py b/vocode/streaming/output_device/file_output_device.py index 65a548fac..5316a8fa4 100644 --- a/vocode/streaming/output_device/file_output_device.py +++ b/vocode/streaming/output_device/file_output_device.py @@ -1,34 +1,16 @@ import asyncio import wave -from asyncio import Queue import numpy as np -from vocode.streaming.models.audio import AudioEncoding -from vocode.streaming.utils.worker import ThreadAsyncWorker - -from .base_output_device import BaseOutputDevice - +from vocode.streaming.output_device.rate_limit_interruptions_output_device import ( + RateLimitInterruptionsOutputDevice, +) -class FileWriterWorker(ThreadAsyncWorker): - def __init__(self, input_queue: Queue, wave) -> None: - super().__init__(input_queue) - self.wav = wave - - def _run_loop(self): - while True: - try: - block = self.input_janus_queue.sync_q.get() - self.wav.writeframes(block) - except asyncio.CancelledError: - return - - def terminate(self): - super().terminate() - self.wav.close() +from vocode.streaming.models.audio import AudioEncoding -class FileOutputDevice(BaseOutputDevice): +class FileOutputDevice(RateLimitInterruptionsOutputDevice): DEFAULT_SAMPLING_RATE = 44100 def __init__( @@ -39,7 +21,6 @@ def __init__( ): super().__init__(sampling_rate, audio_encoding) self.blocksize = self.sampling_rate - self.queue: Queue[np.ndarray] = Queue() wav = wave.open(file_path, "wb") wav.setnchannels(1) # Mono channel @@ -47,16 +28,14 @@ def __init__( wav.setframerate(self.sampling_rate) self.wav = wav - self.thread_worker = FileWriterWorker(self.queue, wav) - self.thread_worker.start() - - def consume_nonblocking(self, chunk): + async def play(self, chunk: bytes): chunk_arr = np.frombuffer(chunk, dtype=np.int16) for i in range(0, chunk_arr.shape[0], self.blocksize): block = np.zeros(self.blocksize, dtype=np.int16) size = min(self.blocksize, chunk_arr.shape[0] - i) block[:size] = chunk_arr[i : i + size] - self.queue.put_nowait(block) + await asyncio.to_thread(lambda: self.wav.writeframes(block.tobytes())) def terminate(self): - self.thread_worker.terminate() + self.wav.close() + super().terminate() diff --git a/vocode/streaming/output_device/rate_limit_interruptions_output_device.py b/vocode/streaming/output_device/rate_limit_interruptions_output_device.py new file mode 100644 index 000000000..2cfb94d64 --- /dev/null +++ b/vocode/streaming/output_device/rate_limit_interruptions_output_device.py @@ -0,0 +1,70 @@ +from abc import abstractmethod +import asyncio +import time + +from vocode.streaming.constants import PER_CHUNK_ALLOWANCE_SECONDS +from vocode.streaming.models.audio import AudioEncoding +from vocode.streaming.output_device.abstract_output_device import AbstractOutputDevice +from vocode.streaming.output_device.audio_chunk import ChunkState +from vocode.streaming.utils import get_chunk_size_per_second + + +class RateLimitInterruptionsOutputDevice(AbstractOutputDevice): + """Output device that works by rate limiting the chunks sent to the output. For interrupts to work properly, + the next chunk of audio can only be sent after the last chunk is played, so we send + a chunk of x seconds only after x seconds have passed since the last chunk was sent.""" + + def __init__( + self, + sampling_rate: int, + audio_encoding: AudioEncoding, + per_chunk_allowance_seconds: float = PER_CHUNK_ALLOWANCE_SECONDS, + ): + super().__init__(sampling_rate, audio_encoding) + self.per_chunk_allowance_seconds = per_chunk_allowance_seconds + + async def _run_loop(self): + while True: + start_time = time.time() + try: + item = await self.input_queue.get() + except asyncio.CancelledError: + return + + self.interruptible_event = item + audio_chunk = item.payload + + if item.is_interrupted(): + audio_chunk.on_interrupt() + audio_chunk.state = ChunkState.INTERRUPTED + continue + + speech_length_seconds = (len(audio_chunk.data)) / get_chunk_size_per_second( + self.audio_encoding, + self.sampling_rate, + ) + await self.play(audio_chunk.data) + audio_chunk.on_play() + audio_chunk.state = ChunkState.PLAYED + end_time = time.time() + await asyncio.sleep( + max( + speech_length_seconds + - (end_time - start_time) + - self.per_chunk_allowance_seconds, + 0, + ), + ) + self.interruptible_event.is_interruptible = False + + @abstractmethod + async def play(self, chunk: bytes): + """Sends an audio chunk to immediate playback""" + pass + + def interrupt(self): + """ + For conversations that use rate-limiting playback as above, + no custom logic is needed on interrupt, because to end synthesis, all we need to do is stop sending chunks. + """ + pass diff --git a/vocode/streaming/output_device/speaker_output.py b/vocode/streaming/output_device/speaker_output.py index a9a30f82e..b7861dbad 100644 --- a/vocode/streaming/output_device/speaker_output.py +++ b/vocode/streaming/output_device/speaker_output.py @@ -4,14 +4,13 @@ import numpy as np import sounddevice as sd +from .abstract_output_device import AbstractOutputDevice from vocode.streaming.models.audio import AudioEncoding -from .base_output_device import BaseOutputDevice - raise DeprecationWarning("Use BlockingSpeakerOutput instead") -class SpeakerOutput(BaseOutputDevice): +class SpeakerOutput(AbstractOutputDevice): DEFAULT_SAMPLING_RATE = 44100 def __init__( diff --git a/vocode/streaming/output_device/twilio_output_device.py b/vocode/streaming/output_device/twilio_output_device.py index 565ec65f1..c0509d57f 100644 --- a/vocode/streaming/output_device/twilio_output_device.py +++ b/vocode/streaming/output_device/twilio_output_device.py @@ -3,66 +3,124 @@ import asyncio import base64 import json -from typing import Optional +from typing import Optional, Union +import uuid from fastapi import WebSocket from fastapi.websockets import WebSocketState +from pydantic import BaseModel -from vocode.streaming.output_device.base_output_device import BaseOutputDevice +from vocode.streaming.output_device.audio_chunk import AudioChunk, ChunkState +from vocode.streaming.output_device.abstract_output_device import AbstractOutputDevice from vocode.streaming.telephony.constants import DEFAULT_AUDIO_ENCODING, DEFAULT_SAMPLING_RATE from vocode.streaming.utils.create_task import asyncio_create_task_with_done_error_log +from vocode.streaming.utils.worker import InterruptibleEvent -class TwilioOutputDevice(BaseOutputDevice): +class ChunkFinishedMarkMessage(BaseModel): + chunk_id: str + + +MarkMessage = Union[ChunkFinishedMarkMessage] # space for more mark messages + + +class TwilioOutputDevice(AbstractOutputDevice): def __init__(self, ws: Optional[WebSocket] = None, stream_sid: Optional[str] = None): super().__init__(sampling_rate=DEFAULT_SAMPLING_RATE, audio_encoding=DEFAULT_AUDIO_ENCODING) self.ws = ws self.stream_sid = stream_sid self.active = True - self.queue: asyncio.Queue[str] = asyncio.Queue() - self.process_task = asyncio_create_task_with_done_error_log(self.process()) - async def process(self): - while self.active: - message = await self.queue.get() + self.twilio_events_queue: asyncio.Queue[str] = asyncio.Queue() + self.mark_message_queue: asyncio.Queue[MarkMessage] = asyncio.Queue() + self.unprocessed_audio_chunks_queue: asyncio.Queue[InterruptibleEvent[AudioChunk]] = ( + asyncio.Queue() + ) + + def consume_nonblocking(self, item: InterruptibleEvent[AudioChunk]): + if not item.is_interrupted(): + self._send_audio_chunk_and_mark(item.payload.data) + self.unprocessed_audio_chunks_queue.put_nowait(item) + else: + audio_chunk = item.payload + audio_chunk.on_interrupt() + audio_chunk.state = ChunkState.INTERRUPTED + + async def play(self, chunk: bytes): + """ + For Twilio, we send all of the audio chunks to be played at once, + and then consume the mark messages to know when to send the on_play / on_interrupt callbacks + """ + pass + + def interrupt(self): + self._send_clear_message() + + def enqueue_mark_message(self, mark_message: MarkMessage): + self.mark_message_queue.put_nowait(mark_message) + + async def _send_twilio_messages(self): + while True: + try: + twilio_event = await self.twilio_events_queue.get() + except asyncio.CancelledError: + return if self.ws.application_state == WebSocketState.DISCONNECTED: break - await self.ws.send_text(message) + await self.ws.send_text(twilio_event) + + async def _process_mark_messages(self): + while True: + try: + # mark messages are tagged with the chunk ID that is attached to the audio chunk + # but they are guaranteed to come in the same order as the audio chunks, and we + # don't need to build resiliency there + await self.mark_message_queue.get() + item = await self.unprocessed_audio_chunks_queue.get() + except asyncio.CancelledError: + return + + self.interruptible_event = item + audio_chunk = item.payload + + if item.is_interrupted(): + audio_chunk.on_interrupt() + audio_chunk.state = ChunkState.INTERRUPTED + continue - def consume_nonblocking(self, chunk: bytes): - twilio_message = { + audio_chunk.on_play() + audio_chunk.state = ChunkState.PLAYED + + self.interruptible_event.is_interruptible = False + + async def _run_loop(self): + send_twilio_messages_task = asyncio_create_task_with_done_error_log( + self._send_twilio_messages() + ) + process_mark_messages_task = asyncio_create_task_with_done_error_log( + self._process_mark_messages() + ) + await asyncio.gather(send_twilio_messages_task, process_mark_messages_task) + + def _send_audio_chunk_and_mark(self, chunk: bytes): + media_message = { "event": "media", "streamSid": self.stream_sid, "media": {"payload": base64.b64encode(chunk).decode("utf-8")}, } - self.queue.put_nowait(json.dumps(twilio_message)) - - def send_chunk_finished_mark(self, utterance_id, chunk_idx): - mark_message = { - "event": "mark", - "streamSid": self.stream_sid, - "mark": { - "name": f"chunk-{utterance_id}-{chunk_idx}", - }, - } - self.queue.put_nowait(json.dumps(mark_message)) - - def send_utterance_finished_mark(self, utterance_id): + self.twilio_events_queue.put_nowait(json.dumps(media_message)) mark_message = { "event": "mark", "streamSid": self.stream_sid, "mark": { - "name": f"utterance-{utterance_id}", + "name": str(uuid.uuid4()), }, } - self.queue.put_nowait(json.dumps(mark_message)) + self.twilio_events_queue.put_nowait(json.dumps(mark_message)) - def send_clear_message(self): + def _send_clear_message(self): clear_message = { "event": "clear", "streamSid": self.stream_sid, } - self.queue.put_nowait(json.dumps(clear_message)) - - def terminate(self): - self.process_task.cancel() + self.twilio_events_queue.put_nowait(json.dumps(clear_message)) diff --git a/vocode/streaming/output_device/vonage_output_device.py b/vocode/streaming/output_device/vonage_output_device.py index 76aef2b12..8385bfaaf 100644 --- a/vocode/streaming/output_device/vonage_output_device.py +++ b/vocode/streaming/output_device/vonage_output_device.py @@ -1,21 +1,21 @@ -import asyncio from typing import Optional from fastapi import WebSocket from fastapi.websockets import WebSocketState -from vocode.streaming.output_device.base_output_device import BaseOutputDevice from vocode.streaming.output_device.blocking_speaker_output import BlockingSpeakerOutput +from vocode.streaming.output_device.rate_limit_interruptions_output_device import ( + RateLimitInterruptionsOutputDevice, +) from vocode.streaming.telephony.constants import ( PCM_SILENCE_BYTE, VONAGE_AUDIO_ENCODING, VONAGE_CHUNK_SIZE, VONAGE_SAMPLING_RATE, ) -from vocode.streaming.utils.create_task import asyncio_create_task_with_done_error_log -class VonageOutputDevice(BaseOutputDevice): +class VonageOutputDevice(RateLimitInterruptionsOutputDevice): def __init__( self, ws: Optional[WebSocket] = None, @@ -23,30 +23,18 @@ def __init__( ): super().__init__(sampling_rate=VONAGE_SAMPLING_RATE, audio_encoding=VONAGE_AUDIO_ENCODING) self.ws = ws - self.active = True - self.queue: asyncio.Queue[bytes] = asyncio.Queue() - self.process_task = asyncio_create_task_with_done_error_log(self.process()) self.output_to_speaker = output_to_speaker if output_to_speaker: self.output_speaker = BlockingSpeakerOutput.from_default_device( sampling_rate=VONAGE_SAMPLING_RATE, blocksize=VONAGE_CHUNK_SIZE // 2 ) - async def process(self): - while self.active: - chunk = await self.queue.get() - if self.ws.application_state == WebSocketState.DISCONNECTED: - break - if self.output_to_speaker: - self.output_speaker.consume_nonblocking(chunk) - for i in range(0, len(chunk), VONAGE_CHUNK_SIZE): - subchunk = chunk[i : i + VONAGE_CHUNK_SIZE] - if len(subchunk) % 2 == 1: - subchunk += PCM_SILENCE_BYTE # pad with silence, Vonage goes crazy otherwise + async def play(self, chunk: bytes): + if self.output_to_speaker: + self.output_speaker.consume_nonblocking(chunk) + for i in range(0, len(chunk), VONAGE_CHUNK_SIZE): + subchunk = chunk[i : i + VONAGE_CHUNK_SIZE] + if len(subchunk) % 2 == 1: + subchunk += PCM_SILENCE_BYTE # pad with silence, Vonage goes crazy otherwise + if self.ws and self.ws.application_state == WebSocketState.DISCONNECTED: await self.ws.send_bytes(subchunk) - - def consume_nonblocking(self, chunk: bytes): - self.queue.put_nowait(chunk) - - def terminate(self): - self.process_task.cancel() diff --git a/vocode/streaming/output_device/websocket_output_device.py b/vocode/streaming/output_device/websocket_output_device.py index ca0133c14..9b5ce1de4 100644 --- a/vocode/streaming/output_device/websocket_output_device.py +++ b/vocode/streaming/output_device/websocket_output_device.py @@ -7,11 +7,12 @@ from vocode.streaming.models.audio import AudioEncoding from vocode.streaming.models.transcript import TranscriptEvent from vocode.streaming.models.websocket import AudioMessage, TranscriptMessage -from vocode.streaming.output_device.base_output_device import BaseOutputDevice -from vocode.streaming.utils.create_task import asyncio_create_task_with_done_error_log +from vocode.streaming.output_device.rate_limit_interruptions_output_device import ( + RateLimitInterruptionsOutputDevice, +) -class WebsocketOutputDevice(BaseOutputDevice): +class WebsocketOutputDevice(RateLimitInterruptionsOutputDevice): def __init__(self, ws: WebSocket, sampling_rate: int, audio_encoding: AudioEncoding): super().__init__(sampling_rate, audio_encoding) self.ws = ws @@ -20,25 +21,15 @@ def __init__(self, ws: WebSocket, sampling_rate: int, audio_encoding: AudioEncod def start(self): self.active = True - self.process_task = asyncio_create_task_with_done_error_log(self.process()) + return super().start() def mark_closed(self): self.active = False - async def process(self): - while self.active: - message = await self.queue.get() - await self.ws.send_text(message) + async def play(self, chunk: bytes): + await self.ws.send_text(AudioMessage.from_bytes(chunk).json()) - def consume_nonblocking(self, chunk: bytes): - if self.active: - audio_message = AudioMessage.from_bytes(chunk) - self.queue.put_nowait(audio_message.json()) - - def consume_transcript(self, event: TranscriptEvent): + async def send_transcript(self, event: TranscriptEvent): if self.active: transcript_message = TranscriptMessage.from_event(event) - self.queue.put_nowait(transcript_message.json()) - - def terminate(self): - self.process_task.cancel() + await self.ws.send_text(transcript_message.json()) diff --git a/vocode/streaming/streaming_conversation.py b/vocode/streaming/streaming_conversation.py index aadf6bfa6..b6d2def00 100644 --- a/vocode/streaming/streaming_conversation.py +++ b/vocode/streaming/streaming_conversation.py @@ -9,7 +9,6 @@ import typing from typing import ( Any, - AsyncGenerator, Awaitable, Callable, Generic, @@ -39,7 +38,6 @@ from vocode.streaming.constants import ( ALLOWED_IDLE_TIME, CHECK_HUMAN_PRESENT_MESSAGE_CHOICES, - PER_CHUNK_ALLOWANCE_SECONDS, TEXT_TO_SPEECH_CHUNK_SIZE_SECONDS, ) from vocode.streaming.models.actions import EndOfTurn @@ -48,7 +46,8 @@ from vocode.streaming.models.message import BaseMessage, BotBackchannel, LLMToken, SilenceMessage from vocode.streaming.models.transcriber import TranscriberConfig, Transcription from vocode.streaming.models.transcript import Message, Transcript, TranscriptCompleteEvent -from vocode.streaming.output_device.base_output_device import BaseOutputDevice +from vocode.streaming.output_device.audio_chunk import AudioChunk, ChunkState +from vocode.streaming.output_device.abstract_output_device import AbstractOutputDevice from vocode.streaming.synthesizer.base_synthesizer import ( BaseSynthesizer, FillerAudio, @@ -57,7 +56,11 @@ from vocode.streaming.synthesizer.input_streaming_synthesizer import InputStreamingSynthesizer from vocode.streaming.transcriber.base_transcriber import BaseTranscriber from vocode.streaming.transcriber.deepgram_transcriber import DeepgramTranscriber -from vocode.streaming.utils import create_conversation_id, get_chunk_size_per_second +from vocode.streaming.utils import ( + create_conversation_id, + enumerate_async_iter, + get_chunk_size_per_second, +) from vocode.streaming.utils.create_task import asyncio_create_task_with_done_error_log from vocode.streaming.utils.events_manager import EventsManager from vocode.streaming.utils.speed_manager import SpeedManager @@ -102,7 +105,7 @@ LOW_INTERRUPT_SENSITIVITY_BACKCHANNEL_UTTERANCE_LENGTH_THRESHOLD = 3 -OutputDeviceType = TypeVar("OutputDeviceType", bound=BaseOutputDevice) +OutputDeviceType = TypeVar("OutputDeviceType", bound=AbstractOutputDevice) class StreamingConversation(Generic[OutputDeviceType]): @@ -136,7 +139,7 @@ def create_interruptible_agent_response_event( self.conversation.interruptible_events.put_nowait(interruptible_event) return interruptible_event - class TranscriptionsWorker(AsyncQueueWorker): + class TranscriptionsWorker(AsyncQueueWorker[Transcription]): """Processes all transcriptions: sends an interrupt if needed and sends final transcriptions to the output queue""" @@ -267,7 +270,7 @@ async def process(self, transcription: Transcription): ) if not self.conversation.is_human_speaking: self.conversation.current_transcription_is_interrupt = ( - self.conversation.broadcast_interrupt() + await self.conversation.broadcast_interrupt() ) self.has_associated_unignored_utterance = not transcription.is_final if self.conversation.current_transcription_is_interrupt: @@ -590,7 +593,6 @@ def __init__( synthesizer: BaseSynthesizer, speed_coefficient: float = 1.0, conversation_id: Optional[str] = None, - per_chunk_allowance_seconds: float = PER_CHUNK_ALLOWANCE_SECONDS, events_manager: Optional[EventsManager] = None, ): self.id = conversation_id or create_conversation_id() @@ -655,7 +657,6 @@ def __init__( self.events_manager = events_manager or EventsManager() self.events_task: Optional[asyncio.Task] = None - self.per_chunk_allowance_seconds = per_chunk_allowance_seconds self.transcript = Transcript() self.transcript.attach_events_manager(self.events_manager) @@ -678,6 +679,8 @@ def __init__( self.agent.get_agent_config().allowed_idle_time_seconds or ALLOWED_IDLE_TIME ) + self.interrupt_lock = asyncio.Lock() + def create_state_manager(self) -> ConversationStateManager: return ConversationStateManager(conversation=self) @@ -806,28 +809,30 @@ def warmup_synthesizer(self): def mark_last_action_timestamp(self): self.last_action_timestamp = time.time() - def broadcast_interrupt(self): + async def broadcast_interrupt(self): """Stops all inflight events and cancels all workers that are sending output Returns true if any events were interrupted - which is used as a flag for the agent (is_interrupt) """ - num_interrupts = 0 - while True: - try: - interruptible_event = self.interruptible_events.get_nowait() - if not interruptible_event.is_interrupted(): - if interruptible_event.interrupt(): - logger.debug( - f"Interrupting event {type(interruptible_event.payload)} {interruptible_event.payload}", - ) - num_interrupts += 1 - except queue.Empty: - break - self.agent.cancel_current_task() - self.agent_responses_worker.cancel_current_task() - if self.actions_worker: - self.actions_worker.cancel_current_task() - return num_interrupts > 0 + async with self.interrupt_lock: + num_interrupts = 0 + while True: + try: + interruptible_event = self.interruptible_events.get_nowait() + if not interruptible_event.is_interrupted(): + if interruptible_event.interrupt(): + logger.debug( + f"Interrupting event {type(interruptible_event.payload)} {interruptible_event.payload}", + ) + num_interrupts += 1 + except queue.Empty: + break + self.output_device.interrupt() + self.agent.cancel_current_task() + self.agent_responses_worker.cancel_current_task() + if self.actions_worker: + self.actions_worker.cancel_current_task() + return num_interrupts > 0 def is_interrupt(self, transcription: Transcription): return transcription.confidence >= ( @@ -880,83 +885,97 @@ async def send_speech_to_output( - If the stop_event is set, the output is stopped - Sets started_event when the first chunk is sent - Importantly, we rate limit the chunks sent to the output. For interrupts to work properly, - the next chunk of audio can only be sent after the last chunk is played, so we send - a chunk of x seconds only after x seconds have passed since the last chunk was sent. - Returns the message that was sent up to, and a flag if the message was cut off """ + seconds_spoken = 0.0 - async def get_chunks( - output_queue: asyncio.Queue[Optional[SynthesisResult.ChunkResult]], - chunk_generator: AsyncGenerator[SynthesisResult.ChunkResult, None], + def create_on_play_callback( + chunk_idx: int, + processed_event: asyncio.Event, ): - try: - async for chunk_result in chunk_generator: - await output_queue.put(chunk_result) - except asyncio.CancelledError: - pass - finally: - await output_queue.put(None) # sentinel + def _on_play(): + if chunk_idx == 0: + if started_event: + started_event.set() + if first_chunk_span: + self._track_first_chunk(first_chunk_span, synthesis_result) + + nonlocal seconds_spoken + + self.mark_last_action_timestamp() + + seconds_spoken += seconds_per_chunk + if transcript_message: + transcript_message.text = synthesis_result.get_message_up_to(seconds_spoken) + + processed_event.set() + + return _on_play + + def create_on_interrupt_callback( + processed_event: asyncio.Event, + ): + def _on_interrupt(): + processed_event.set() + + return _on_interrupt if self.transcriber.get_transcriber_config().mute_during_speech: logger.debug("Muting transcriber") self.transcriber.mute() - message_sent = message - cut_off = False - chunk_size = self._get_synthesizer_chunk_size(seconds_per_chunk) - chunk_idx = 0 - seconds_spoken = 0.0 logger.debug(f"Start sending speech {message} to output") first_chunk_span = self._maybe_create_first_chunk_span(synthesis_result, message) - chunk_queue: asyncio.Queue[Optional[SynthesisResult.ChunkResult]] = asyncio.Queue() - get_chunks_task = asyncio_create_task_with_done_error_log( - get_chunks(chunk_queue, synthesis_result.chunk_generator), - ) - first = True - while True: - chunk_result = await chunk_queue.get() - if chunk_result is None: - break - if first and first_chunk_span: - self._track_first_chunk(first_chunk_span, synthesis_result) - first = False - start_time = time.time() - speech_length_seconds = seconds_per_chunk * (len(chunk_result.chunk) / chunk_size) - seconds_spoken = chunk_idx * seconds_per_chunk + audio_chunks: List[AudioChunk] = [] + processed_events: List[asyncio.Event] = [] + async for chunk_idx, chunk_result in enumerate_async_iter(synthesis_result.chunk_generator): if stop_event.is_set(): - logger.debug( - "Interrupted, stopping text to speech after {} chunks".format(chunk_idx), - ) - message_sent = synthesis_result.get_message_up_to(seconds_spoken) - cut_off = True + logger.debug("Interrupted before all chunks were sent") break - if chunk_idx == 0: - if started_event: - started_event.set() - self.output_device.consume_nonblocking(chunk_result.chunk) - end_time = time.time() - await asyncio.sleep( - max( - speech_length_seconds - - (end_time - start_time) - - self.per_chunk_allowance_seconds, - 0, - ), + processed_event = asyncio.Event() + audio_chunk = AudioChunk( + data=chunk_result.chunk, ) - self.mark_last_action_timestamp() - chunk_idx += 1 - seconds_spoken += seconds_per_chunk - if transcript_message: - transcript_message.text = synthesis_result.get_message_up_to(seconds_spoken) - get_chunks_task.cancel() + # register callbacks + setattr(audio_chunk, "on_play", create_on_play_callback(chunk_idx, processed_event)) + setattr( + audio_chunk, + "on_interrupt", + create_on_interrupt_callback(processed_event), + ) + async with self.interrupt_lock: + self.output_device.consume_nonblocking( + InterruptibleEvent( + payload=audio_chunk, + is_interruptible=True, + interruption_event=stop_event, + ), + ) + audio_chunks.append(audio_chunk) + processed_events.append(processed_event) + + logger.debug("Finished sending chunks to the output device") + + await asyncio.gather(*(processed_event.wait() for processed_event in processed_events)) + + maybe_first_interrupted_audio_chunk = next( + ( + audio_chunk + for audio_chunk in audio_chunks + if audio_chunk.state == ChunkState.INTERRUPTED + ), + None, + ) + cut_off = maybe_first_interrupted_audio_chunk is not None + if not cut_off: # if the audio was not cut off, we can set the transcript message to the full message + transcript_message.text = synthesis_result.get_message_up_to(None) + if self.transcriber.get_transcriber_config().mute_during_speech: logger.debug("Unmuting transcriber") self.transcriber.unmute() if transcript_message: - transcript_message.text = message_sent transcript_message.is_final = not cut_off + message_sent = transcript_message.text if transcript_message and cut_off else message if synthesis_result.synthesis_total_span: synthesis_result.synthesis_total_span.finish() return message_sent, cut_off @@ -966,7 +985,7 @@ def mark_terminated(self, bot_disconnect: bool = False): async def terminate(self): self.mark_terminated() - self.broadcast_interrupt() + await self.broadcast_interrupt() self.events_manager.publish_event( TranscriptCompleteEvent( conversation_id=self.id, diff --git a/vocode/streaming/synthesizer/azure_synthesizer.py b/vocode/streaming/synthesizer/azure_synthesizer.py index f2aaf2122..eee831e29 100644 --- a/vocode/streaming/synthesizer/azure_synthesizer.py +++ b/vocode/streaming/synthesizer/azure_synthesizer.py @@ -218,9 +218,11 @@ def get_message_up_to( self, message: str, ssml: str, - seconds: float, + seconds: Optional[float], word_boundary_event_pool: WordBoundaryEventPool, ) -> str: + if seconds is None: + return message events = word_boundary_event_pool.get_events_sorted() for event in events: if event["audio_offset"] > seconds: diff --git a/vocode/streaming/synthesizer/base_synthesizer.py b/vocode/streaming/synthesizer/base_synthesizer.py index e3e86661f..0c58d2c8d 100644 --- a/vocode/streaming/synthesizer/base_synthesizer.py +++ b/vocode/streaming/synthesizer/base_synthesizer.py @@ -50,6 +50,13 @@ def encode_as_wav(chunk: bytes, synthesizer_config: SynthesizerConfig) -> bytes: class SynthesisResult: + """Holds audio bytes for an utterance and method to know how much of utterance was spoken + + @param chunk_generator - an async generator that that yields ChunkResult objects, which contain chunks of audio and a flag indicating if it is the last chunk + @param get_message_up_to - takes in the number of seconds spoken and returns the message up to that point + - *if seconds is None, then it should return the full messages* + """ + class ChunkResult: def __init__(self, chunk: bytes, is_last_chunk: bool): self.chunk = chunk @@ -58,7 +65,7 @@ def __init__(self, chunk: bytes, is_last_chunk: bool): def __init__( self, chunk_generator: AsyncGenerator[ChunkResult, None], - get_message_up_to: Callable[[float], str], + get_message_up_to: Callable[[Optional[float]], str], cached: bool = False, is_first: bool = False, synthesis_total_span: Optional[SentrySpan] = None, @@ -111,7 +118,7 @@ async def chunk_generator(chunk_transform=lambda x: x): ) else: output_generator = chunk_generator() - return SynthesisResult(output_generator, lambda seconds: self.message.text) + return SynthesisResult(output_generator, lambda _: self.message.text) class CachedAudio: @@ -155,7 +162,7 @@ def get_message_up_to(seconds): else: - def get_message_up_to(seconds): + def get_message_up_to(seconds: Optional[float]): return BaseSynthesizer.get_message_cutoff_from_total_response_length( self.synthesizer_config, self.message, seconds, len(self.audio_data) ) @@ -273,20 +280,27 @@ def ready_synthesizer(self, chunk_size: int): def get_message_cutoff_from_total_response_length( synthesizer_config: SynthesizerConfig, message: BaseMessage, - seconds: float, + seconds: Optional[float], size_of_output: int, ) -> str: estimated_output_seconds = size_of_output / synthesizer_config.sampling_rate if not message.text: return message.text + if seconds is None: + return message.text + estimated_output_seconds_per_char = estimated_output_seconds / len(message.text) return message.text[: int(seconds / estimated_output_seconds_per_char)] @staticmethod def get_message_cutoff_from_voice_speed( - message: BaseMessage, seconds: float, words_per_minute: int + message: BaseMessage, seconds: Optional[float], words_per_minute: int ) -> str: + + if seconds is None: + return message.text + words_per_second = words_per_minute / 60 estimated_words_spoken = math.floor(words_per_second * seconds) tokens = word_tokenize(message.text) diff --git a/vocode/streaming/synthesizer/eleven_labs_websocket_synthesizer.py b/vocode/streaming/synthesizer/eleven_labs_websocket_synthesizer.py index c4644ad6c..d95271c41 100644 --- a/vocode/streaming/synthesizer/eleven_labs_websocket_synthesizer.py +++ b/vocode/streaming/synthesizer/eleven_labs_websocket_synthesizer.py @@ -329,11 +329,11 @@ def ready_synthesizer(self, chunk_size: int): self.establish_websocket_listeners(chunk_size) ) - def get_current_message_so_far(self, seconds: float) -> str: + def get_current_message_so_far(self, seconds: Optional[float]) -> str: seconds_idx = 0.0 buffer = "" for utterance, duration in self.current_turn_utterances_by_chunk: - if seconds_idx > seconds: + if seconds is not None and seconds_idx > seconds: return buffer buffer += utterance seconds_idx += duration diff --git a/vocode/streaming/synthesizer/polly_synthesizer.py b/vocode/streaming/synthesizer/polly_synthesizer.py index 4f36cd905..72385e05f 100644 --- a/vocode/streaming/synthesizer/polly_synthesizer.py +++ b/vocode/streaming/synthesizer/polly_synthesizer.py @@ -1,7 +1,7 @@ import asyncio import json from concurrent.futures import ThreadPoolExecutor -from typing import Any +from typing import Any, Optional import boto3 @@ -63,9 +63,11 @@ def get_speech_marks(self, message: str) -> Any: def get_message_up_to( self, message: str, - seconds: float, + seconds: Optional[float], word_events, ) -> str: + if seconds is None: + return message for event in word_events: # time field is in ms if event["time"] > seconds * 1000: diff --git a/vocode/streaming/telephony/conversation/abstract_phone_conversation.py b/vocode/streaming/telephony/conversation/abstract_phone_conversation.py index 308919a40..1e3cf4c0f 100644 --- a/vocode/streaming/telephony/conversation/abstract_phone_conversation.py +++ b/vocode/streaming/telephony/conversation/abstract_phone_conversation.py @@ -49,7 +49,6 @@ def __init__( conversation_id: Optional[str] = None, events_manager: Optional[EventsManager] = None, speed_coefficient: float = 1.0, - per_chunk_allowance_seconds: float = 0.01, ): conversation_id = conversation_id or create_conversation_id() ctx_conversation_id.set(conversation_id) @@ -64,7 +63,6 @@ def __init__( agent_factory.create_agent(agent_config), synthesizer_factory.create_synthesizer(synthesizer_config), conversation_id=conversation_id, - per_chunk_allowance_seconds=per_chunk_allowance_seconds, events_manager=events_manager, speed_coefficient=speed_coefficient, ) diff --git a/vocode/streaming/telephony/conversation/mark_message_queue.py b/vocode/streaming/telephony/conversation/mark_message_queue.py deleted file mode 100644 index c4b17b931..000000000 --- a/vocode/streaming/telephony/conversation/mark_message_queue.py +++ /dev/null @@ -1,46 +0,0 @@ -import asyncio -from typing import Dict, Union - -from pydantic.v1 import BaseModel - - -class ChunkFinishedMarkMessage(BaseModel): - chunk_idx: int - - -class UtteranceFinishedMarkMessage(BaseModel): - pass - - -MarkMessage = Union[ChunkFinishedMarkMessage, UtteranceFinishedMarkMessage] - - -class MarkMessageQueue: - """A keyed asyncio.Queue for MarkMessage objects""" - - def __init__(self): - self.utterance_queues: Dict[str, asyncio.Queue[MarkMessage]] = {} - - def create_utterance_queue(self, utterance_id: str): - if utterance_id in self.utterance_queues: - raise ValueError(f"utterance_id {utterance_id} already exists") - self.utterance_queues[utterance_id] = asyncio.Queue() - - def put_nowait( - self, - utterance_id: str, - mark_message: MarkMessage, - ): - if utterance_id in self.utterance_queues: - self.utterance_queues[utterance_id].put_nowait(mark_message) - - async def get( - self, - utterance_id: str, - ) -> MarkMessage: - if utterance_id not in self.utterance_queues: - raise ValueError(f"utterance_id {utterance_id} not found") - return await self.utterance_queues[utterance_id].get() - - def delete_utterance_queue(self, utterance_id: str): - del self.utterance_queues[utterance_id] diff --git a/vocode/streaming/telephony/conversation/twilio_phone_conversation.py b/vocode/streaming/telephony/conversation/twilio_phone_conversation.py index 6145d53b0..f3c833726 100644 --- a/vocode/streaming/telephony/conversation/twilio_phone_conversation.py +++ b/vocode/streaming/telephony/conversation/twilio_phone_conversation.py @@ -1,10 +1,8 @@ -import asyncio import base64 import json import os -import threading from enum import Enum -from typing import AsyncGenerator, Optional +from typing import Optional from fastapi import WebSocket from loguru import logger @@ -15,25 +13,17 @@ from vocode.streaming.models.synthesizer import SynthesizerConfig from vocode.streaming.models.telephony import PhoneCallDirection, TwilioConfig from vocode.streaming.models.transcriber import TranscriberConfig -from vocode.streaming.models.transcript import Message -from vocode.streaming.output_device.twilio_output_device import TwilioOutputDevice +from vocode.streaming.output_device.twilio_output_device import ( + ChunkFinishedMarkMessage, + TwilioOutputDevice, +) from vocode.streaming.synthesizer.abstract_factory import AbstractSynthesizerFactory -from vocode.streaming.synthesizer.base_synthesizer import SynthesisResult -from vocode.streaming.synthesizer.input_streaming_synthesizer import InputStreamingSynthesizer from vocode.streaming.telephony.client.twilio_client import TwilioClient from vocode.streaming.telephony.config_manager.base_config_manager import BaseConfigManager from vocode.streaming.telephony.conversation.abstract_phone_conversation import ( AbstractPhoneConversation, ) -from vocode.streaming.telephony.conversation.mark_message_queue import ( - ChunkFinishedMarkMessage, - MarkMessage, - MarkMessageQueue, - UtteranceFinishedMarkMessage, -) from vocode.streaming.transcriber.abstract_factory import AbstractTranscriberFactory -from vocode.streaming.utils import create_utterance_id -from vocode.streaming.utils.create_task import asyncio_create_task_with_done_error_log from vocode.streaming.utils.events_manager import EventsManager from vocode.streaming.utils.state_manager import TwilioPhoneConversationStateManager @@ -83,7 +73,6 @@ def __init__( synthesizer_factory=synthesizer_factory, speed_coefficient=speed_coefficient, ) - self.mark_message_queue: MarkMessageQueue = MarkMessageQueue() self.config_manager = config_manager self.twilio_config = twilio_config or TwilioConfig( account_sid=os.environ["TWILIO_ACCOUNT_SID"], @@ -140,135 +129,10 @@ async def _handle_ws_message(self, message) -> Optional[TwilioPhoneConversationW chunk = base64.b64decode(media["payload"]) self.receive_audio(chunk) if data["event"] == "mark": - mark_name = data["mark"]["name"] - if mark_name.startswith("chunk-"): - utterance_id, chunk_idx = mark_name.split("-")[1:] - self.mark_message_queue.put_nowait( - utterance_id=utterance_id, - mark_message=ChunkFinishedMarkMessage(chunk_idx=int(chunk_idx)), - ) - elif mark_name.startswith("utterance"): - utterance_id = mark_name.split("-")[1] - self.mark_message_queue.put_nowait( - utterance_id=utterance_id, - mark_message=UtteranceFinishedMarkMessage(), - ) + chunk_id = data["mark"]["name"] + self.output_device.enqueue_mark_message(ChunkFinishedMarkMessage(chunk_id=chunk_id)) elif data["event"] == "stop": logger.debug(f"Media WS: Received event 'stop': {message}") logger.debug("Stopping...") return TwilioPhoneConversationWebsocketAction.CLOSE_WEBSOCKET return None - - async def _send_chunks( - self, - utterance_id: str, - chunk_generator: AsyncGenerator[SynthesisResult.ChunkResult, None], - clear_message_lock: asyncio.Lock, - stop_event: threading.Event, - ): - chunk_idx = 0 - try: - async for chunk_result in chunk_generator: - async with clear_message_lock: - if stop_event.is_set(): - break - self.output_device.consume_nonblocking(chunk_result.chunk) - self.output_device.send_chunk_finished_mark(utterance_id, chunk_idx) - chunk_idx += 1 - except asyncio.CancelledError: - pass - finally: - logger.debug("Finished sending all chunks to Twilio") - self.output_device.send_utterance_finished_mark(utterance_id) - - async def send_speech_to_output( - self, - message: str, - synthesis_result: SynthesisResult, - stop_event: threading.Event, - seconds_per_chunk: float, - transcript_message: Optional[Message] = None, - started_event: Optional[threading.Event] = None, - ): - """In contrast with send_speech_to_output in the base class, this function uses mark messages - to support interruption - we send all chunks to the output device, and then wait for mark messages[0] - that indicate that each chunk has been played. This means that we don't need to depends on asyncio.sleep - to support interruptions. - - Once we receive an interruption signal: - - we send a clear message to Twilio to stop playing all queued audio - - based on the number of mark messages we've received back, we know how many chunks were played and can indicate on the transcript - - [0] https://www.twilio.com/docs/voice/twiml/stream#websocket-messages-to-twilio - """ - - if self.transcriber.get_transcriber_config().mute_during_speech: - logger.debug("Muting transcriber") - self.transcriber.mute() - message_sent = message - cut_off = False - chunk_idx = 0 - seconds_spoken = 0.0 - logger.debug(f"Start sending speech {message} to output") - - utterance_id = create_utterance_id() - self.mark_message_queue.create_utterance_queue(utterance_id) - - first_chunk_span = self._maybe_create_first_chunk_span(synthesis_result, message) - - clear_message_lock = asyncio.Lock() - - asyncio_create_task_with_done_error_log( - self._send_chunks( - utterance_id, - synthesis_result.chunk_generator, - clear_message_lock, - stop_event, - ), - ) - mark_event: MarkMessage - first = True - while True: - mark_event = await self.mark_message_queue.get(utterance_id) - if isinstance(mark_event, UtteranceFinishedMarkMessage): - break - if first and first_chunk_span: - self._track_first_chunk(first_chunk_span, synthesis_result) - first = False - seconds_spoken = mark_event.chunk_idx * seconds_per_chunk - # Lock here so that we check the stop event and send the clear message atomically - # w.r.t. the _send_chunks task which also checks the stop event - # Otherwise, we could send the clear message while _send_chunks is in the middle of sending a chunk - # and the synthesis wouldn't be cleared - async with clear_message_lock: - if stop_event.is_set(): - self.output_device.send_clear_message() - logger.debug( - "Interrupted, stopping text to speech after {} chunks".format(chunk_idx) - ) - message_sent = synthesis_result.get_message_up_to(seconds_spoken) - cut_off = True - break - if chunk_idx == 0: - if started_event: - started_event.set() - self.mark_last_action_timestamp() - chunk_idx += 1 - seconds_spoken += seconds_per_chunk - if transcript_message: - transcript_message.text = synthesis_result.get_message_up_to(seconds_spoken) - self.mark_message_queue.delete_utterance_queue(utterance_id) - if self.transcriber.get_transcriber_config().mute_during_speech: - logger.debug("Unmuting transcriber") - self.transcriber.unmute() - if transcript_message: - # For input streaming synthesizers, we have to buffer the message as it is streamed in - # What is said is federated fully by synthesis_result.get_message_up_to - if isinstance(self.synthesizer, InputStreamingSynthesizer): - message_sent = transcript_message.text - else: - transcript_message.text = message_sent - transcript_message.is_final = not cut_off - if synthesis_result.synthesis_total_span: - synthesis_result.synthesis_total_span.finish() - return message_sent, cut_off diff --git a/vocode/streaming/telephony/conversation/vonage_phone_conversation.py b/vocode/streaming/telephony/conversation/vonage_phone_conversation.py index 8af5399d5..38629f89e 100644 --- a/vocode/streaming/telephony/conversation/vonage_phone_conversation.py +++ b/vocode/streaming/telephony/conversation/vonage_phone_conversation.py @@ -47,7 +47,6 @@ def __init__( events_manager: Optional[EventsManager] = None, output_to_speaker: bool = False, speed_coefficient: float = 1.0, - per_chunk_allowance_seconds: float = 0.01, noise_suppression: bool = False, ): self.speed_coefficient = speed_coefficient @@ -67,7 +66,6 @@ def __init__( transcriber_factory=transcriber_factory, agent_factory=agent_factory, synthesizer_factory=synthesizer_factory, - per_chunk_allowance_seconds=per_chunk_allowance_seconds, ) self.vonage_config = vonage_config self.telephony_client = VonageClient( diff --git a/vocode/streaming/transcriber/base_transcriber.py b/vocode/streaming/transcriber/base_transcriber.py index 4745b79a6..8de24ace7 100644 --- a/vocode/streaming/transcriber/base_transcriber.py +++ b/vocode/streaming/transcriber/base_transcriber.py @@ -58,19 +58,18 @@ def terminate(self): pass -class BaseAsyncTranscriber(AbstractTranscriber[TranscriberConfigType], AsyncWorker): +class BaseAsyncTranscriber(AbstractTranscriber[TranscriberConfigType], AsyncWorker[bytes]): def __init__(self, transcriber_config: TranscriberConfigType): AbstractTranscriber.__init__(self, transcriber_config) AsyncWorker.__init__(self, self.input_queue, self.output_queue) - async def _run_loop(self): - raise NotImplementedError - def terminate(self): AsyncWorker.terminate(self) -class BaseThreadAsyncTranscriber(AbstractTranscriber[TranscriberConfigType], ThreadAsyncWorker): +class BaseThreadAsyncTranscriber( + AbstractTranscriber[TranscriberConfigType], ThreadAsyncWorker[bytes] +): def __init__(self, transcriber_config: TranscriberConfigType): AbstractTranscriber.__init__(self, transcriber_config) ThreadAsyncWorker.__init__(self, self.input_queue, self.output_queue) diff --git a/vocode/streaming/utils/__init__.py b/vocode/streaming/utils/__init__.py index 6f2585e09..31dd24998 100644 --- a/vocode/streaming/utils/__init__.py +++ b/vocode/streaming/utils/__init__.py @@ -135,3 +135,12 @@ async def generate_from_async_iter_with_lookahead( if buffer and stream_length <= lookahead: yield buffer return + + +async def enumerate_async_iter( + async_iter: AsyncIterator[AsyncIteratorGenericType], +) -> AsyncGenerator[Tuple[int, AsyncIteratorGenericType], None]: + i = 0 + async for item in async_iter: + yield i, item + i += 1 diff --git a/vocode/streaming/utils/worker.py b/vocode/streaming/utils/worker.py index 80021b580..15070ea60 100644 --- a/vocode/streaming/utils/worker.py +++ b/vocode/streaming/utils/worker.py @@ -93,7 +93,7 @@ def terminate(self): return super().terminate() -class AsyncQueueWorker(AsyncWorker): +class AsyncQueueWorker(AsyncWorker[WorkerInputType]): async def _run_loop(self): while True: try: diff --git a/vocode/turn_based/output_device/base_output_device.py b/vocode/turn_based/output_device/abstract_output_device.py similarity index 55% rename from vocode/turn_based/output_device/base_output_device.py rename to vocode/turn_based/output_device/abstract_output_device.py index d54c0c7fd..d111dd67a 100644 --- a/vocode/turn_based/output_device/base_output_device.py +++ b/vocode/turn_based/output_device/abstract_output_device.py @@ -1,9 +1,12 @@ +from abc import ABC, abstractmethod from pydub import AudioSegment -class BaseOutputDevice: +class AbstractOutputDevice(ABC): + + @abstractmethod def send_audio(self, audio: AudioSegment) -> None: - raise NotImplementedError + pass def terminate(self): pass diff --git a/vocode/turn_based/output_device/speaker_output.py b/vocode/turn_based/output_device/speaker_output.py index a0b35dc4d..a3f748f9e 100644 --- a/vocode/turn_based/output_device/speaker_output.py +++ b/vocode/turn_based/output_device/speaker_output.py @@ -4,10 +4,10 @@ import sounddevice as sd from pydub import AudioSegment -from vocode.turn_based.output_device.base_output_device import BaseOutputDevice +from vocode.turn_based.output_device.abstract_output_device import AbstractOutputDevice -class SpeakerOutput(BaseOutputDevice): +class SpeakerOutput(AbstractOutputDevice): DEFAULT_SAMPLING_RATE = 44100 def __init__( diff --git a/vocode/turn_based/turn_based_conversation.py b/vocode/turn_based/turn_based_conversation.py index faddb613b..0eb9e506a 100644 --- a/vocode/turn_based/turn_based_conversation.py +++ b/vocode/turn_based/turn_based_conversation.py @@ -2,7 +2,7 @@ from vocode.turn_based.agent.base_agent import BaseAgent from vocode.turn_based.input_device.base_input_device import BaseInputDevice -from vocode.turn_based.output_device.base_output_device import BaseOutputDevice +from vocode.turn_based.output_device.abstract_output_device import AbstractOutputDevice from vocode.turn_based.synthesizer.base_synthesizer import BaseSynthesizer from vocode.turn_based.transcriber.base_transcriber import BaseTranscriber @@ -14,7 +14,7 @@ def __init__( transcriber: BaseTranscriber, agent: BaseAgent, synthesizer: BaseSynthesizer, - output_device: BaseOutputDevice, + output_device: AbstractOutputDevice, ): self.input_device = input_device self.transcriber = transcriber From 10254e3b245c532ef966ea2af54b077c2febd68b Mon Sep 17 00:00:00 2001 From: Ajay Raj Date: Wed, 26 Jun 2024 18:19:13 -0700 Subject: [PATCH 02/14] fix mypy --- vocode/streaming/streaming_conversation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vocode/streaming/streaming_conversation.py b/vocode/streaming/streaming_conversation.py index b6d2def00..41c668225 100644 --- a/vocode/streaming/streaming_conversation.py +++ b/vocode/streaming/streaming_conversation.py @@ -967,7 +967,9 @@ def _on_interrupt(): None, ) cut_off = maybe_first_interrupted_audio_chunk is not None - if not cut_off: # if the audio was not cut off, we can set the transcript message to the full message + if ( + transcript_message and not cut_off + ): # if the audio was not cut off, we can set the transcript message to the full message transcript_message.text = synthesis_result.get_message_up_to(None) if self.transcriber.get_transcriber_config().mute_during_speech: From 6f7ac5c6c3f4148eaf36517404964fff042f68f6 Mon Sep 17 00:00:00 2001 From: Ajay Raj Date: Thu, 27 Jun 2024 10:28:03 -0700 Subject: [PATCH 03/14] fix mypy --- vocode/streaming/synthesizer/base_synthesizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vocode/streaming/synthesizer/base_synthesizer.py b/vocode/streaming/synthesizer/base_synthesizer.py index 0c58d2c8d..ea368cc15 100644 --- a/vocode/streaming/synthesizer/base_synthesizer.py +++ b/vocode/streaming/synthesizer/base_synthesizer.py @@ -157,7 +157,7 @@ async def chunk_generator(): if isinstance(self.message, BotBackchannel): - def get_message_up_to(seconds): + def get_message_up_to(seconds: Optional[float]): return self.message.text else: From d69b260aebd785ec64960e6817d4ebdc79f2b1e1 Mon Sep 17 00:00:00 2001 From: Ajay Raj Date: Thu, 27 Jun 2024 11:53:36 -0700 Subject: [PATCH 04/14] isort --- vocode/streaming/output_device/abstract_output_device.py | 3 ++- vocode/streaming/output_device/audio_chunk.py | 2 +- vocode/streaming/output_device/file_output_device.py | 3 +-- .../output_device/rate_limit_interruptions_output_device.py | 2 +- vocode/streaming/output_device/speaker_output.py | 3 ++- vocode/streaming/output_device/twilio_output_device.py | 4 ++-- vocode/streaming/streaming_conversation.py | 2 +- vocode/turn_based/output_device/abstract_output_device.py | 1 + 8 files changed, 11 insertions(+), 9 deletions(-) diff --git a/vocode/streaming/output_device/abstract_output_device.py b/vocode/streaming/output_device/abstract_output_device.py index 746a4ffd8..d675f39ad 100644 --- a/vocode/streaming/output_device/abstract_output_device.py +++ b/vocode/streaming/output_device/abstract_output_device.py @@ -1,5 +1,6 @@ -from abc import abstractmethod import asyncio +from abc import abstractmethod + from vocode.streaming.output_device.audio_chunk import AudioChunk from vocode.streaming.utils.worker import AsyncWorker, InterruptibleEvent diff --git a/vocode/streaming/output_device/audio_chunk.py b/vocode/streaming/output_device/audio_chunk.py index d58e081ad..df669d19b 100644 --- a/vocode/streaming/output_device/audio_chunk.py +++ b/vocode/streaming/output_device/audio_chunk.py @@ -1,7 +1,7 @@ +import uuid from dataclasses import dataclass, field from enum import Enum from uuid import UUID -import uuid class ChunkState(int, Enum): diff --git a/vocode/streaming/output_device/file_output_device.py b/vocode/streaming/output_device/file_output_device.py index 5316a8fa4..99e72c520 100644 --- a/vocode/streaming/output_device/file_output_device.py +++ b/vocode/streaming/output_device/file_output_device.py @@ -3,12 +3,11 @@ import numpy as np +from vocode.streaming.models.audio import AudioEncoding from vocode.streaming.output_device.rate_limit_interruptions_output_device import ( RateLimitInterruptionsOutputDevice, ) -from vocode.streaming.models.audio import AudioEncoding - class FileOutputDevice(RateLimitInterruptionsOutputDevice): DEFAULT_SAMPLING_RATE = 44100 diff --git a/vocode/streaming/output_device/rate_limit_interruptions_output_device.py b/vocode/streaming/output_device/rate_limit_interruptions_output_device.py index 2cfb94d64..98493b9d8 100644 --- a/vocode/streaming/output_device/rate_limit_interruptions_output_device.py +++ b/vocode/streaming/output_device/rate_limit_interruptions_output_device.py @@ -1,6 +1,6 @@ -from abc import abstractmethod import asyncio import time +from abc import abstractmethod from vocode.streaming.constants import PER_CHUNK_ALLOWANCE_SECONDS from vocode.streaming.models.audio import AudioEncoding diff --git a/vocode/streaming/output_device/speaker_output.py b/vocode/streaming/output_device/speaker_output.py index b7861dbad..c416c814e 100644 --- a/vocode/streaming/output_device/speaker_output.py +++ b/vocode/streaming/output_device/speaker_output.py @@ -4,9 +4,10 @@ import numpy as np import sounddevice as sd -from .abstract_output_device import AbstractOutputDevice from vocode.streaming.models.audio import AudioEncoding +from .abstract_output_device import AbstractOutputDevice + raise DeprecationWarning("Use BlockingSpeakerOutput instead") diff --git a/vocode/streaming/output_device/twilio_output_device.py b/vocode/streaming/output_device/twilio_output_device.py index c0509d57f..9599a7217 100644 --- a/vocode/streaming/output_device/twilio_output_device.py +++ b/vocode/streaming/output_device/twilio_output_device.py @@ -3,15 +3,15 @@ import asyncio import base64 import json -from typing import Optional, Union import uuid +from typing import Optional, Union from fastapi import WebSocket from fastapi.websockets import WebSocketState from pydantic import BaseModel -from vocode.streaming.output_device.audio_chunk import AudioChunk, ChunkState from vocode.streaming.output_device.abstract_output_device import AbstractOutputDevice +from vocode.streaming.output_device.audio_chunk import AudioChunk, ChunkState from vocode.streaming.telephony.constants import DEFAULT_AUDIO_ENCODING, DEFAULT_SAMPLING_RATE from vocode.streaming.utils.create_task import asyncio_create_task_with_done_error_log from vocode.streaming.utils.worker import InterruptibleEvent diff --git a/vocode/streaming/streaming_conversation.py b/vocode/streaming/streaming_conversation.py index 41c668225..90e50d997 100644 --- a/vocode/streaming/streaming_conversation.py +++ b/vocode/streaming/streaming_conversation.py @@ -46,8 +46,8 @@ from vocode.streaming.models.message import BaseMessage, BotBackchannel, LLMToken, SilenceMessage from vocode.streaming.models.transcriber import TranscriberConfig, Transcription from vocode.streaming.models.transcript import Message, Transcript, TranscriptCompleteEvent -from vocode.streaming.output_device.audio_chunk import AudioChunk, ChunkState from vocode.streaming.output_device.abstract_output_device import AbstractOutputDevice +from vocode.streaming.output_device.audio_chunk import AudioChunk, ChunkState from vocode.streaming.synthesizer.base_synthesizer import ( BaseSynthesizer, FillerAudio, diff --git a/vocode/turn_based/output_device/abstract_output_device.py b/vocode/turn_based/output_device/abstract_output_device.py index d111dd67a..c2eb148f2 100644 --- a/vocode/turn_based/output_device/abstract_output_device.py +++ b/vocode/turn_based/output_device/abstract_output_device.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod + from pydub import AudioSegment From 67149e81321757746af31b095ebedb89e0825a48 Mon Sep 17 00:00:00 2001 From: Ajay Raj Date: Thu, 27 Jun 2024 12:04:10 -0700 Subject: [PATCH 05/14] creates first test and adds scaffolding --- tests/fakedata/conversation.py | 32 +++++++++- ..._rate_limit_interruptions_output_device.py | 0 .../test_twilio_output_device.py | 0 .../streaming/test_streaming_conversation.py | 59 ++++++++++++++++++- .../output_device/abstract_output_device.py | 8 +++ .../output_device/twilio_output_device.py | 7 --- 6 files changed, 96 insertions(+), 10 deletions(-) create mode 100644 tests/streaming/output_device/test_rate_limit_interruptions_output_device.py create mode 100644 tests/streaming/output_device/test_twilio_output_device.py diff --git a/tests/fakedata/conversation.py b/tests/fakedata/conversation.py index 2b6917b21..92f4348f4 100644 --- a/tests/fakedata/conversation.py +++ b/tests/fakedata/conversation.py @@ -1,3 +1,5 @@ +import asyncio +import time from typing import Optional from pytest_mock import MockerFixture @@ -9,6 +11,7 @@ from vocode.streaming.models.synthesizer import PlayHtSynthesizerConfig, SynthesizerConfig from vocode.streaming.models.transcriber import DeepgramTranscriberConfig, TranscriberConfig from vocode.streaming.output_device.abstract_output_device import AbstractOutputDevice +from vocode.streaming.output_device.audio_chunk import ChunkState from vocode.streaming.streaming_conversation import StreamingConversation from vocode.streaming.synthesizer.base_synthesizer import BaseSynthesizer from vocode.streaming.telephony.constants import DEFAULT_CHUNK_SIZE, DEFAULT_SAMPLING_RATE @@ -37,8 +40,33 @@ class DummyOutputDevice(AbstractOutputDevice): - def consume_nonblocking(self, chunk: bytes): - pass + def process(self, item): + self.interruptible_event = item + audio_chunk = item.payload + + if item.is_interrupted(): + audio_chunk.on_interrupt() + audio_chunk.state = ChunkState.INTERRUPTED + else: + audio_chunk.on_play() + audio_chunk.state = ChunkState.PLAYED + self.interruptible_event.is_interruptible = False + + async def _run_loop(self): + while True: + try: + item = await self.input_queue.get() + except asyncio.CancelledError: + return + self.process(item) + + def flush(self): + while True: + try: + item = self.input_queue.get_nowait() + except asyncio.QueueEmpty: + break + self.process(item) def interrupt(self): pass diff --git a/tests/streaming/output_device/test_rate_limit_interruptions_output_device.py b/tests/streaming/output_device/test_rate_limit_interruptions_output_device.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/streaming/output_device/test_twilio_output_device.py b/tests/streaming/output_device/test_twilio_output_device.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/streaming/test_streaming_conversation.py b/tests/streaming/test_streaming_conversation.py index 429943837..614ece854 100644 --- a/tests/streaming/test_streaming_conversation.py +++ b/tests/streaming/test_streaming_conversation.py @@ -1,5 +1,6 @@ import asyncio -from typing import List +import threading +from typing import List, Optional from unittest.mock import MagicMock import pytest @@ -16,6 +17,7 @@ from vocode.streaming.models.events import Sender from vocode.streaming.models.transcriber import Transcription from vocode.streaming.models.transcript import ActionStart, Message, Transcript +from vocode.streaming.synthesizer.base_synthesizer import SynthesisResult from vocode.streaming.utils.worker import AsyncWorker @@ -451,3 +453,58 @@ async def test_transcriptions_worker_interrupts_immediately_before_bot_has_begun assert streaming_conversation.broadcast_interrupt.called streaming_conversation.transcriptions_worker.terminate() + + +def _create_dummy_synthesis_result(message: str = "Hi there", num_audio_chunks: int = 3): + async def chunk_generator(): + for i in range(num_audio_chunks): + yield SynthesisResult.ChunkResult(chunk=b"", is_last_chunk=i == num_audio_chunks - 1) + + def get_message_up_to(seconds: Optional[float]): + if seconds is None: + return message + return message[: len(message) // 2] + + return SynthesisResult(chunk_generator=chunk_generator(), get_message_up_to=get_message_up_to) + + +@pytest.mark.asyncio +async def test_send_speech_to_output_uninterrupted( + mocker: MockerFixture, +): + streaming_conversation = await _mock_streaming_conversation_constructor(mocker) + synthesis_result = _create_dummy_synthesis_result() + stop_event = threading.Event() + transcript_message = Message( + text="", + sender=Sender.BOT, + ) + + streaming_conversation.output_device.start() + message_sent, cut_off = await streaming_conversation.send_speech_to_output( + message="Hi there", + synthesis_result=synthesis_result, + stop_event=stop_event, + seconds_per_chunk=0.1, + transcript_message=transcript_message, + ) + streaming_conversation.output_device.flush() + + assert message_sent == "Hi there" + assert not cut_off + assert transcript_message.text == "Hi there" + assert transcript_message.is_final + + +@pytest.mark.asyncio +async def test_send_speech_to_output_interrupted_before_all_chunks_sent( + mocker: MockerFixture, +): + pass + + +@pytest.mark.asyncio +async def test_send_speech_to_output_interrupted_during_playback( + mocker: MockerFixture, +): + pass diff --git a/vocode/streaming/output_device/abstract_output_device.py b/vocode/streaming/output_device/abstract_output_device.py index d675f39ad..14002075d 100644 --- a/vocode/streaming/output_device/abstract_output_device.py +++ b/vocode/streaming/output_device/abstract_output_device.py @@ -6,6 +6,13 @@ class AbstractOutputDevice(AsyncWorker[InterruptibleEvent[AudioChunk]]): + """Output devices are workers that are responsible for playing back audio. + + As part of processing: + - it must call AudioChunk.on_play() when the chunk is played back and set AudioChunk.state = ChunkState.PLAYED + - it must call AudioChunk.on_interrupt() when the chunk is interrupted and set AudioChunk.state = ChunkState.INTERRUPTED + - if the interruptible event marker is set, then it must also mark the chunk as interrupted + """ def __init__(self, sampling_rate: int, audio_encoding): super().__init__(input_queue=asyncio.Queue()) @@ -14,4 +21,5 @@ def __init__(self, sampling_rate: int, audio_encoding): @abstractmethod def interrupt(self): + """Must interrupt the currently playing audio""" pass diff --git a/vocode/streaming/output_device/twilio_output_device.py b/vocode/streaming/output_device/twilio_output_device.py index 9599a7217..87e9ddf3e 100644 --- a/vocode/streaming/output_device/twilio_output_device.py +++ b/vocode/streaming/output_device/twilio_output_device.py @@ -46,13 +46,6 @@ def consume_nonblocking(self, item: InterruptibleEvent[AudioChunk]): audio_chunk.on_interrupt() audio_chunk.state = ChunkState.INTERRUPTED - async def play(self, chunk: bytes): - """ - For Twilio, we send all of the audio chunks to be played at once, - and then consume the mark messages to know when to send the on_play / on_interrupt callbacks - """ - pass - def interrupt(self): self._send_clear_message() From 23554f95dbe7266354f009bbaee071f123ccabee Mon Sep 17 00:00:00 2001 From: Ajay Raj Date: Thu, 27 Jun 2024 14:33:20 -0700 Subject: [PATCH 06/14] adds two other send_speech_to_output tests --- tests/fakedata/conversation.py | 21 ++++- .../streaming/test_streaming_conversation.py | 76 +++++++++++++++++-- .../output_device/abstract_output_device.py | 3 +- vocode/streaming/streaming_conversation.py | 6 +- 4 files changed, 97 insertions(+), 9 deletions(-) diff --git a/tests/fakedata/conversation.py b/tests/fakedata/conversation.py index 92f4348f4..69e407d33 100644 --- a/tests/fakedata/conversation.py +++ b/tests/fakedata/conversation.py @@ -40,7 +40,20 @@ class DummyOutputDevice(AbstractOutputDevice): - def process(self, item): + + def __init__( + self, + sampling_rate: int, + audio_encoding: AudioEncoding, + wait_for_interrupt: bool = False, + chunks_before_interrupt: int = 1, + ): + super().__init__(sampling_rate, audio_encoding) + self.wait_for_interrupt = wait_for_interrupt + self.chunks_before_interrupt = chunks_before_interrupt + self.interrupt_event = asyncio.Event() + + async def process(self, item): self.interruptible_event = item audio_chunk = item.payload @@ -53,12 +66,16 @@ def process(self, item): self.interruptible_event.is_interruptible = False async def _run_loop(self): + chunk_counter = 0 while True: try: item = await self.input_queue.get() except asyncio.CancelledError: return - self.process(item) + if self.wait_for_interrupt and chunk_counter == self.chunks_before_interrupt: + await self.interrupt_event.wait() + await self.process(item) + chunk_counter += 1 def flush(self): while True: diff --git a/tests/streaming/test_streaming_conversation.py b/tests/streaming/test_streaming_conversation.py index 614ece854..1d4ef026d 100644 --- a/tests/streaming/test_streaming_conversation.py +++ b/tests/streaming/test_streaming_conversation.py @@ -1,6 +1,6 @@ import asyncio import threading -from typing import List, Optional +from typing import AsyncGenerator, List, Optional from unittest.mock import MagicMock import pytest @@ -455,7 +455,11 @@ async def test_transcriptions_worker_interrupts_immediately_before_bot_has_begun streaming_conversation.transcriptions_worker.terminate() -def _create_dummy_synthesis_result(message: str = "Hi there", num_audio_chunks: int = 3): +def _create_dummy_synthesis_result( + message: str = "Hi there", + num_audio_chunks: int = 3, + chunk_generator_override: Optional[AsyncGenerator[SynthesisResult.ChunkResult, None]] = None, +): async def chunk_generator(): for i in range(num_audio_chunks): yield SynthesisResult.ChunkResult(chunk=b"", is_last_chunk=i == num_audio_chunks - 1) @@ -465,7 +469,10 @@ def get_message_up_to(seconds: Optional[float]): return message return message[: len(message) // 2] - return SynthesisResult(chunk_generator=chunk_generator(), get_message_up_to=get_message_up_to) + return SynthesisResult( + chunk_generator=chunk_generator_override or chunk_generator(), + get_message_up_to=get_message_up_to, + ) @pytest.mark.asyncio @@ -500,11 +507,70 @@ async def test_send_speech_to_output_uninterrupted( async def test_send_speech_to_output_interrupted_before_all_chunks_sent( mocker: MockerFixture, ): - pass + streaming_conversation = await _mock_streaming_conversation_constructor(mocker) + synthesis_result = _create_dummy_synthesis_result() + stop_event = threading.Event() + transcript_message = Message( + text="", + sender=Sender.BOT, + ) + stop_event.set() + + streaming_conversation.output_device.start() + message_sent, cut_off = await streaming_conversation.send_speech_to_output( + message="Hi there", + synthesis_result=synthesis_result, + stop_event=stop_event, + seconds_per_chunk=0.1, + transcript_message=transcript_message, + ) + streaming_conversation.output_device.flush() + + assert message_sent != "Hi there" + assert cut_off + assert transcript_message.text != "Hi there" + assert not transcript_message.is_final @pytest.mark.asyncio async def test_send_speech_to_output_interrupted_during_playback( mocker: MockerFixture, ): - pass + finished_sending_chunks = asyncio.Event() + + async def chunk_generator(): + yield SynthesisResult.ChunkResult(chunk=b"", is_last_chunk=False) + yield SynthesisResult.ChunkResult(chunk=b"", is_last_chunk=False) + yield SynthesisResult.ChunkResult(chunk=b"", is_last_chunk=True) + finished_sending_chunks.set() + + streaming_conversation = await _mock_streaming_conversation_constructor(mocker) + synthesis_result = _create_dummy_synthesis_result(chunk_generator_override=chunk_generator()) + stop_event = threading.Event() + transcript_message = Message( + text="", + sender=Sender.BOT, + ) + + streaming_conversation.output_device.wait_for_interrupt = True + + streaming_conversation.output_device.start() + send_speech_to_output_task = asyncio.create_task( + streaming_conversation.send_speech_to_output( + message="Hi there", + synthesis_result=synthesis_result, + stop_event=stop_event, + seconds_per_chunk=0.1, + transcript_message=transcript_message, + ) + ) + await finished_sending_chunks.wait() + stop_event.set() + streaming_conversation.output_device.interrupt_event.set() + message_sent, cut_off = await send_speech_to_output_task + streaming_conversation.output_device.terminate() + + assert message_sent != "Hi there" + assert cut_off + assert transcript_message.text != "Hi there" + assert not transcript_message.is_final diff --git a/vocode/streaming/output_device/abstract_output_device.py b/vocode/streaming/output_device/abstract_output_device.py index 14002075d..985077f63 100644 --- a/vocode/streaming/output_device/abstract_output_device.py +++ b/vocode/streaming/output_device/abstract_output_device.py @@ -1,6 +1,7 @@ import asyncio from abc import abstractmethod +from vocode.streaming.models.audio import AudioEncoding from vocode.streaming.output_device.audio_chunk import AudioChunk from vocode.streaming.utils.worker import AsyncWorker, InterruptibleEvent @@ -14,7 +15,7 @@ class AbstractOutputDevice(AsyncWorker[InterruptibleEvent[AudioChunk]]): - if the interruptible event marker is set, then it must also mark the chunk as interrupted """ - def __init__(self, sampling_rate: int, audio_encoding): + def __init__(self, sampling_rate: int, audio_encoding: AudioEncoding): super().__init__(input_queue=asyncio.Queue()) self.sampling_rate = sampling_rate self.audio_encoding = audio_encoding diff --git a/vocode/streaming/streaming_conversation.py b/vocode/streaming/streaming_conversation.py index 90e50d997..f8c255b66 100644 --- a/vocode/streaming/streaming_conversation.py +++ b/vocode/streaming/streaming_conversation.py @@ -928,9 +928,11 @@ def _on_interrupt(): first_chunk_span = self._maybe_create_first_chunk_span(synthesis_result, message) audio_chunks: List[AudioChunk] = [] processed_events: List[asyncio.Event] = [] + interrupted_before_all_chunks_sent = False async for chunk_idx, chunk_result in enumerate_async_iter(synthesis_result.chunk_generator): if stop_event.is_set(): logger.debug("Interrupted before all chunks were sent") + interrupted_before_all_chunks_sent = True break processed_event = asyncio.Event() audio_chunk = AudioChunk( @@ -966,7 +968,9 @@ def _on_interrupt(): ), None, ) - cut_off = maybe_first_interrupted_audio_chunk is not None + cut_off = ( + interrupted_before_all_chunks_sent or maybe_first_interrupted_audio_chunk is not None + ) if ( transcript_message and not cut_off ): # if the audio was not cut off, we can set the transcript message to the full message From 1678d39de3229d4d1e69c3f95c23c954843f48be Mon Sep 17 00:00:00 2001 From: Ajay Raj Date: Thu, 27 Jun 2024 14:56:24 -0700 Subject: [PATCH 07/14] make send_speech_to_output more efficient --- vocode/streaming/streaming_conversation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vocode/streaming/streaming_conversation.py b/vocode/streaming/streaming_conversation.py index f8c255b66..0e84e4285 100644 --- a/vocode/streaming/streaming_conversation.py +++ b/vocode/streaming/streaming_conversation.py @@ -958,7 +958,8 @@ def _on_interrupt(): logger.debug("Finished sending chunks to the output device") - await asyncio.gather(*(processed_event.wait() for processed_event in processed_events)) + if processed_events: + await processed_events[-1].wait() maybe_first_interrupted_audio_chunk = next( ( From 29ccecaca5c5f47138946425eaba31036af0c414 Mon Sep 17 00:00:00 2001 From: Ajay Raj Date: Thu, 27 Jun 2024 14:56:37 -0700 Subject: [PATCH 08/14] adds tests for rate limit interruptions output device --- ..._rate_limit_interruptions_output_device.py | 68 +++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/tests/streaming/output_device/test_rate_limit_interruptions_output_device.py b/tests/streaming/output_device/test_rate_limit_interruptions_output_device.py index e69de29bb..44c6679ab 100644 --- a/tests/streaming/output_device/test_rate_limit_interruptions_output_device.py +++ b/tests/streaming/output_device/test_rate_limit_interruptions_output_device.py @@ -0,0 +1,68 @@ +import asyncio + +import pytest + +from vocode.streaming.models.audio import AudioEncoding +from vocode.streaming.output_device.audio_chunk import AudioChunk, ChunkState +from vocode.streaming.output_device.rate_limit_interruptions_output_device import ( + RateLimitInterruptionsOutputDevice, +) +from vocode.streaming.utils.worker import InterruptibleEvent + + +class DummyRateLimitInterruptionsOutputDevice(RateLimitInterruptionsOutputDevice): + async def play(self, chunk: bytes): + pass + + +@pytest.mark.asyncio +async def test_calls_callbacks(): + output_device = DummyRateLimitInterruptionsOutputDevice( + sampling_rate=16000, audio_encoding=AudioEncoding.LINEAR16 + ) + + played_event = asyncio.Event() + interrupted_event = asyncio.Event() + uninterruptible_played_event = asyncio.Event() + + def on_play(): + played_event.set() + + def on_interrupt(): + interrupted_event.set() + + def uninterruptible_on_play(): + uninterruptible_played_event.set() + + played_audio_chunk = AudioChunk(data=b"") + played_audio_chunk.on_play = on_play + + interrupted_audio_chunk = AudioChunk(data=b"") + interrupted_audio_chunk.on_interrupt = on_interrupt + + uninterruptible_audio_chunk = AudioChunk(data=b"") + uninterruptible_audio_chunk.on_play = uninterruptible_on_play + + interruptible_event = InterruptibleEvent(payload=interrupted_audio_chunk, is_interruptible=True) + interruptible_event.interruption_event.set() + + uninterruptible_event = InterruptibleEvent( + payload=uninterruptible_audio_chunk, is_interruptible=False + ) + uninterruptible_event.interruption_event.set() + + output_device.consume_nonblocking(InterruptibleEvent(payload=played_audio_chunk)) + output_device.consume_nonblocking(interruptible_event) + output_device.consume_nonblocking(uninterruptible_event) + output_device.start() + + await played_event.wait() + assert played_audio_chunk.state == ChunkState.PLAYED + + await interrupted_event.wait() + assert interrupted_audio_chunk.state == ChunkState.INTERRUPTED + + await uninterruptible_played_event.wait() + assert uninterruptible_audio_chunk.state == ChunkState.PLAYED + + output_device.terminate() From 418db81e087ae472b2a042fca18a932581fb17ca Mon Sep 17 00:00:00 2001 From: Ajay Raj Date: Thu, 27 Jun 2024 16:12:37 -0700 Subject: [PATCH 09/14] makes some variables private and also makes the chunk id coming back from the mark match the incoming audio chunk --- .../output_device/twilio_output_device.py | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/vocode/streaming/output_device/twilio_output_device.py b/vocode/streaming/output_device/twilio_output_device.py index 87e9ddf3e..062d998dd 100644 --- a/vocode/streaming/output_device/twilio_output_device.py +++ b/vocode/streaming/output_device/twilio_output_device.py @@ -31,16 +31,18 @@ def __init__(self, ws: Optional[WebSocket] = None, stream_sid: Optional[str] = N self.stream_sid = stream_sid self.active = True - self.twilio_events_queue: asyncio.Queue[str] = asyncio.Queue() - self.mark_message_queue: asyncio.Queue[MarkMessage] = asyncio.Queue() - self.unprocessed_audio_chunks_queue: asyncio.Queue[InterruptibleEvent[AudioChunk]] = ( + self._twilio_events_queue: asyncio.Queue[str] = asyncio.Queue() + self._mark_message_queue: asyncio.Queue[MarkMessage] = asyncio.Queue() + self._unprocessed_audio_chunks_queue: asyncio.Queue[InterruptibleEvent[AudioChunk]] = ( asyncio.Queue() ) def consume_nonblocking(self, item: InterruptibleEvent[AudioChunk]): if not item.is_interrupted(): - self._send_audio_chunk_and_mark(item.payload.data) - self.unprocessed_audio_chunks_queue.put_nowait(item) + self._send_audio_chunk_and_mark( + chunk=item.payload.data, chunk_id=str(item.payload.chunk_id) + ) + self._unprocessed_audio_chunks_queue.put_nowait(item) else: audio_chunk = item.payload audio_chunk.on_interrupt() @@ -50,12 +52,12 @@ def interrupt(self): self._send_clear_message() def enqueue_mark_message(self, mark_message: MarkMessage): - self.mark_message_queue.put_nowait(mark_message) + self._mark_message_queue.put_nowait(mark_message) async def _send_twilio_messages(self): while True: try: - twilio_event = await self.twilio_events_queue.get() + twilio_event = await self._twilio_events_queue.get() except asyncio.CancelledError: return if self.ws.application_state == WebSocketState.DISCONNECTED: @@ -68,8 +70,8 @@ async def _process_mark_messages(self): # mark messages are tagged with the chunk ID that is attached to the audio chunk # but they are guaranteed to come in the same order as the audio chunks, and we # don't need to build resiliency there - await self.mark_message_queue.get() - item = await self.unprocessed_audio_chunks_queue.get() + await self._mark_message_queue.get() + item = await self._unprocessed_audio_chunks_queue.get() except asyncio.CancelledError: return @@ -95,25 +97,25 @@ async def _run_loop(self): ) await asyncio.gather(send_twilio_messages_task, process_mark_messages_task) - def _send_audio_chunk_and_mark(self, chunk: bytes): + def _send_audio_chunk_and_mark(self, chunk: bytes, chunk_id: str): media_message = { "event": "media", "streamSid": self.stream_sid, "media": {"payload": base64.b64encode(chunk).decode("utf-8")}, } - self.twilio_events_queue.put_nowait(json.dumps(media_message)) + self._twilio_events_queue.put_nowait(json.dumps(media_message)) mark_message = { "event": "mark", "streamSid": self.stream_sid, "mark": { - "name": str(uuid.uuid4()), + "name": chunk_id, }, } - self.twilio_events_queue.put_nowait(json.dumps(mark_message)) + self._twilio_events_queue.put_nowait(json.dumps(mark_message)) def _send_clear_message(self): clear_message = { "event": "clear", "streamSid": self.stream_sid, } - self.twilio_events_queue.put_nowait(json.dumps(clear_message)) + self._twilio_events_queue.put_nowait(json.dumps(clear_message)) From 385f63b149665d403e4b660383e58bf16200c438 Mon Sep 17 00:00:00 2001 From: Ajay Raj Date: Thu, 27 Jun 2024 16:12:49 -0700 Subject: [PATCH 10/14] adds twilio output device tests --- .../test_twilio_output_device.py | 139 ++++++++++++++++++ 1 file changed, 139 insertions(+) diff --git a/tests/streaming/output_device/test_twilio_output_device.py b/tests/streaming/output_device/test_twilio_output_device.py index e69de29bb..a2c4ca3e5 100644 --- a/tests/streaming/output_device/test_twilio_output_device.py +++ b/tests/streaming/output_device/test_twilio_output_device.py @@ -0,0 +1,139 @@ +import asyncio +import base64 +import json + +import pytest +from pytest_mock import MockerFixture + +from vocode.streaming.output_device.audio_chunk import AudioChunk, ChunkState +from vocode.streaming.output_device.twilio_output_device import ( + ChunkFinishedMarkMessage, + TwilioOutputDevice, +) +from vocode.streaming.utils.worker import InterruptibleEvent + + +@pytest.fixture +def mock_ws(mocker: MockerFixture): + return mocker.AsyncMock() + + +@pytest.fixture +def mock_stream_sid(): + return "stream_sid" + + +@pytest.fixture +def twilio_output_device(mock_ws, mock_stream_sid): + return TwilioOutputDevice(ws=mock_ws, stream_sid=mock_stream_sid) + + +@pytest.mark.asyncio +async def test_calls_play_callbacks(twilio_output_device: TwilioOutputDevice): + played_event = asyncio.Event() + + def on_play(): + played_event.set() + + audio_chunk = AudioChunk(data=b"") + audio_chunk.on_play = on_play + + twilio_output_device.consume_nonblocking(InterruptibleEvent(payload=audio_chunk)) + twilio_output_device.start() + twilio_output_device.enqueue_mark_message( + ChunkFinishedMarkMessage(chunk_id=str(audio_chunk.chunk_id)) + ) + + await played_event.wait() + assert audio_chunk.state == ChunkState.PLAYED + + media_message = json.loads(twilio_output_device.ws.send_text.call_args_list[0][0][0]) + assert media_message["streamSid"] == twilio_output_device.stream_sid + assert media_message["media"] == {"payload": base64.b64encode(audio_chunk.data).decode("utf-8")} + + mark_message = json.loads(twilio_output_device.ws.send_text.call_args_list[1][0][0]) + assert mark_message["streamSid"] == twilio_output_device.stream_sid + assert mark_message["mark"]["name"] == str(audio_chunk.chunk_id) + + twilio_output_device.terminate() + + +@pytest.mark.asyncio +async def test_calls_interrupt_callbacks(twilio_output_device: TwilioOutputDevice): + interrupted_event = asyncio.Event() + + def on_interrupt(): + interrupted_event.set() + + audio_chunk = AudioChunk(data=b"") + audio_chunk.on_interrupt = on_interrupt + + interruptible_event = InterruptibleEvent(payload=audio_chunk, is_interruptible=True) + + twilio_output_device.consume_nonblocking(interruptible_event) + # we start the twilio events task and the mark messages task manually to test this particular case + + # step 1: media is sent into the websocket + send_twilio_messages_task = asyncio.create_task(twilio_output_device._send_twilio_messages()) + + while not twilio_output_device._twilio_events_queue.empty(): + await asyncio.sleep(0.1) + + # step 2: we get an interrupt + interruptible_event.interrupt() + twilio_output_device.interrupt() + + # note: this means that the time between the events being interrupted and the clear message being sent, chunks + # will be marked interrupted - this is OK since the clear message is sent almost instantaneously once queued + # this is required because it stops queueing new chunks to be sent to the WS immediately + + while not twilio_output_device._twilio_events_queue.empty(): + await asyncio.sleep(0.1) + + # step 3: we get a mark message for the interrupted audio chunk after the clear message + twilio_output_device.enqueue_mark_message( + ChunkFinishedMarkMessage(chunk_id=str(audio_chunk.chunk_id)) + ) + process_mark_messages_task = asyncio.create_task(twilio_output_device._process_mark_messages()) + + await interrupted_event.wait() + assert audio_chunk.state == ChunkState.INTERRUPTED + + media_message = json.loads(twilio_output_device.ws.send_text.call_args_list[0][0][0]) + assert media_message["streamSid"] == twilio_output_device.stream_sid + assert media_message["media"] == {"payload": base64.b64encode(audio_chunk.data).decode("utf-8")} + + mark_message = json.loads(twilio_output_device.ws.send_text.call_args_list[1][0][0]) + assert mark_message["streamSid"] == twilio_output_device.stream_sid + assert mark_message["mark"]["name"] == str(audio_chunk.chunk_id) + + clear_message = json.loads(twilio_output_device.ws.send_text.call_args_list[2][0][0]) + assert clear_message["streamSid"] == twilio_output_device.stream_sid + assert clear_message["event"] == "clear" + + send_twilio_messages_task.cancel() + process_mark_messages_task.cancel() + + +@pytest.mark.asyncio +async def test_interrupted_audio_chunks_are_not_sent_but_are_marked_interrupted( + twilio_output_device: TwilioOutputDevice, +): + interrupted_event = asyncio.Event() + + def on_interrupt(): + interrupted_event.set() + + audio_chunk = AudioChunk(data=b"") + audio_chunk.on_interrupt = on_interrupt + + interruptible_event = InterruptibleEvent(payload=audio_chunk, is_interruptible=True) + interruptible_event.interrupt() + + twilio_output_device.consume_nonblocking(interruptible_event) + twilio_output_device.start() + + await interrupted_event.wait() + assert audio_chunk.state == ChunkState.INTERRUPTED + + twilio_output_device.ws.send_text.assert_not_called() From 003623edb06e6691b83148d0e89c656ada0aeb28 Mon Sep 17 00:00:00 2001 From: Ajay Raj Date: Thu, 27 Jun 2024 16:16:50 -0700 Subject: [PATCH 11/14] make typing better for output devices --- vocode/streaming/utils/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vocode/streaming/utils/worker.py b/vocode/streaming/utils/worker.py index 15070ea60..cdbff41ac 100644 --- a/vocode/streaming/utils/worker.py +++ b/vocode/streaming/utils/worker.py @@ -15,7 +15,7 @@ class AsyncWorker(Generic[WorkerInputType]): def __init__( self, - input_queue: asyncio.Queue, + input_queue: asyncio.Queue[WorkerInputType], output_queue: asyncio.Queue = asyncio.Queue(), ) -> None: self.worker_task: Optional[asyncio.Task] = None From 18e9a041ac72529dddc04b9807b6b0951ed5c808 Mon Sep 17 00:00:00 2001 From: Ajay Raj Date: Thu, 27 Jun 2024 16:18:53 -0700 Subject: [PATCH 12/14] fix mypy --- vocode/streaming/transcriber/base_transcriber.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vocode/streaming/transcriber/base_transcriber.py b/vocode/streaming/transcriber/base_transcriber.py index 8de24ace7..3e89d23d1 100644 --- a/vocode/streaming/transcriber/base_transcriber.py +++ b/vocode/streaming/transcriber/base_transcriber.py @@ -58,7 +58,7 @@ def terminate(self): pass -class BaseAsyncTranscriber(AbstractTranscriber[TranscriberConfigType], AsyncWorker[bytes]): +class BaseAsyncTranscriber(AbstractTranscriber[TranscriberConfigType], AsyncWorker[bytes]): # type: ignore def __init__(self, transcriber_config: TranscriberConfigType): AbstractTranscriber.__init__(self, transcriber_config) AsyncWorker.__init__(self, self.input_queue, self.output_queue) @@ -67,7 +67,7 @@ def terminate(self): AsyncWorker.terminate(self) -class BaseThreadAsyncTranscriber( +class BaseThreadAsyncTranscriber( # type: ignore AbstractTranscriber[TranscriberConfigType], ThreadAsyncWorker[bytes] ): def __init__(self, transcriber_config: TranscriberConfigType): From 768aa0767db4686375639fa652119d4b8515a3fb Mon Sep 17 00:00:00 2001 From: Ajay Raj Date: Thu, 20 Jun 2024 17:02:56 -0700 Subject: [PATCH 13/14] [DOW-113] deprecate output queue and manually attach workers to each other (#569) * deprecate output queues * fix quickstarts * fix mypy * fix tests * temporarily allow test to run on vocode-core-0.1.0 --- playground/streaming/agent/chat.py | 28 ++-- .../streaming/transcriber/transcribe.py | 18 ++- tests/streaming/agent/test_base_agent.py | 48 +++---- .../streaming/test_streaming_conversation.py | 47 ++++--- vocode/streaming/action/worker.py | 43 ++++--- vocode/streaming/agent/base_agent.py | 121 ++++++++++-------- .../agent/websocket_user_implemented_agent.py | 14 +- .../output_device/blocking_speaker_output.py | 4 +- vocode/streaming/streaming_conversation.py | 116 ++++++++--------- .../streaming/synthesizer/base_synthesizer.py | 9 +- .../streaming/synthesizer/miniaudio_worker.py | 26 +++- .../abstract_phone_conversation.py | 6 - .../transcriber/assembly_ai_transcriber.py | 4 +- .../transcriber/azure_transcriber.py | 4 +- .../streaming/transcriber/base_transcriber.py | 38 +++++- .../transcriber/deepgram_transcriber.py | 6 +- .../transcriber/gladia_transcriber.py | 4 +- .../transcriber/google_transcriber.py | 2 +- .../transcriber/rev_ai_transcriber.py | 6 +- .../transcriber/whisper_cpp_transcriber.py | 2 +- vocode/streaming/utils/worker.py | 83 +++++------- 21 files changed, 337 insertions(+), 292 deletions(-) diff --git a/playground/streaming/agent/chat.py b/playground/streaming/agent/chat.py index 659443c9a..e93b9eeda 100644 --- a/playground/streaming/agent/chat.py +++ b/playground/streaming/agent/chat.py @@ -21,12 +21,13 @@ from vocode.streaming.models.message import BaseMessage from vocode.streaming.models.transcript import Transcript from vocode.streaming.utils.state_manager import AbstractConversationStateManager -from vocode.streaming.utils.worker import InterruptibleAgentResponseEvent +from vocode.streaming.utils.worker import InterruptibleAgentResponseEvent, QueueConsumer load_dotenv() from vocode.streaming.agent import ChatGPTAgent from vocode.streaming.agent.base_agent import ( + AgentResponse, AgentResponseMessage, AgentResponseType, BaseAgent, @@ -96,6 +97,11 @@ async def run_agent( ): ended = False conversation_id = create_conversation_id() + agent_response_queue: asyncio.Queue[InterruptibleAgentResponseEvent[AgentResponse]] = ( + asyncio.Queue() + ) + agent_consumer = QueueConsumer(input_queue=agent_response_queue) + agent.agent_responses_consumer = agent_consumer async def receiver(): nonlocal ended @@ -106,7 +112,7 @@ async def receiver(): while not ended: try: - event = await agent.get_output_queue().get() + event = await agent_response_queue.get() response = event.payload if response.type == AgentResponseType.FILLER_AUDIO: print("Would have sent filler audio") @@ -152,6 +158,13 @@ async def receiver(): break async def sender(): + if agent.agent_config.initial_message is not None: + agent.agent_responses_consumer.consume_nonblocking( + InterruptibleAgentResponseEvent( + payload=AgentResponseMessage(message=agent.agent_config.initial_message), + agent_response_tracker=asyncio.Event(), + ) + ) while not ended: try: message = await asyncio.get_event_loop().run_in_executor( @@ -175,10 +188,10 @@ async def sender(): actions_worker = None if isinstance(agent, ChatGPTAgent): actions_worker = ActionsWorker( - input_queue=agent.actions_queue, - output_queue=agent.get_input_queue(), action_factory=agent.action_factory, ) + actions_worker.consumer = agent + agent.actions_consumer = actions_worker actions_worker.attach_conversation_state_manager(agent.conversation_state_manager) actions_worker.start() @@ -215,13 +228,6 @@ async def agent_main(): ) agent.attach_conversation_state_manager(DummyConversationManager()) agent.attach_transcript(transcript) - if agent.agent_config.initial_message is not None: - agent.output_queue.put_nowait( - InterruptibleAgentResponseEvent( - payload=AgentResponseMessage(message=agent.agent_config.initial_message), - agent_response_tracker=asyncio.Event(), - ) - ) agent.start() try: diff --git a/playground/streaming/transcriber/transcribe.py b/playground/streaming/transcriber/transcribe.py index 111c27246..0a9b72cc5 100644 --- a/playground/streaming/transcriber/transcribe.py +++ b/playground/streaming/transcriber/transcribe.py @@ -5,6 +5,15 @@ DeepgramEndpointingConfig, DeepgramTranscriber, ) +from vocode.streaming.utils.worker import AsyncWorker + + +class TranscriptionPrinter(AsyncWorker[Transcription]): + async def _run_loop(self): + while True: + transcription: Transcription = await self.input_queue.get() + print(transcription) + if __name__ == "__main__": import asyncio @@ -13,11 +22,6 @@ load_dotenv() - async def print_output(transcriber: BaseTranscriber): - while True: - transcription: Transcription = await transcriber.output_queue.get() - print(transcription) - async def listen(): microphone_input = MicrophoneInput.from_default_device() @@ -28,7 +32,9 @@ async def listen(): ) ) transcriber.start() - asyncio.create_task(print_output(transcriber)) + transcription_printer = TranscriptionPrinter() + transcriber.consumer = transcription_printer + transcription_printer.start() print("Start speaking...press Ctrl+C to end. ") while True: chunk = await microphone_input.get_audio() diff --git a/tests/streaming/agent/test_base_agent.py b/tests/streaming/agent/test_base_agent.py index c6a0adb95..a794db7ef 100644 --- a/tests/streaming/agent/test_base_agent.py +++ b/tests/streaming/agent/test_base_agent.py @@ -19,7 +19,11 @@ from vocode.streaming.models.transcriber import Transcription from vocode.streaming.models.transcript import Transcript from vocode.streaming.utils.state_manager import ConversationStateManager -from vocode.streaming.utils.worker import InterruptibleEvent +from vocode.streaming.utils.worker import ( + InterruptibleAgentResponseEvent, + InterruptibleEvent, + QueueConsumer, +) @pytest.fixture(autouse=True) @@ -51,11 +55,16 @@ def _create_agent( return agent -async def _consume_until_end_of_turn(agent: BaseAgent, timeout: float = 0.1) -> List[AgentResponse]: +async def _consume_until_end_of_turn( + agent_consumer: QueueConsumer[InterruptibleAgentResponseEvent[AgentResponse]], + timeout: float = 0.1, +) -> List[AgentResponse]: agent_responses = [] try: while True: - agent_response = await asyncio.wait_for(agent.output_queue.get(), timeout=timeout) + agent_response = await asyncio.wait_for( + agent_consumer.input_queue.get(), timeout=timeout + ) agent_responses.append(agent_response.payload) if isinstance(agent_response.payload, AgentResponseMessage) and isinstance( agent_response.payload.message, EndOfTurn @@ -127,37 +136,10 @@ async def test_generate_responses(mocker: MockerFixture): agent, Transcription(message="Hello?", confidence=1.0, is_final=True), ) + agent_consumer = QueueConsumer() + agent.agent_responses_consumer = agent_consumer agent.start() - agent_responses = await _consume_until_end_of_turn(agent) - agent.terminate() - - messages = [response.message for response in agent_responses] - - assert messages == [BaseMessage(text="Hi, how are you doing today?"), EndOfTurn()] - - -@pytest.mark.asyncio -async def test_generate_response(mocker: MockerFixture): - agent_config = ChatGPTAgentConfig( - prompt_preamble="Have a pleasant conversation about life", - generate_responses=True, - ) - agent = _create_agent(mocker, agent_config) - _mock_generate_response( - mocker, - agent, - [ - GeneratedResponse( - message=BaseMessage(text="Hi, how are you doing today?"), is_interruptible=True - ) - ], - ) - _send_transcription( - agent, - Transcription(message="Hello?", confidence=1.0, is_final=True), - ) - agent.start() - agent_responses = await _consume_until_end_of_turn(agent) + agent_responses = await _consume_until_end_of_turn(agent_consumer) agent.terminate() messages = [response.message for response in agent_responses] diff --git a/tests/streaming/test_streaming_conversation.py b/tests/streaming/test_streaming_conversation.py index 1d4ef026d..570d25093 100644 --- a/tests/streaming/test_streaming_conversation.py +++ b/tests/streaming/test_streaming_conversation.py @@ -18,7 +18,7 @@ from vocode.streaming.models.transcriber import Transcription from vocode.streaming.models.transcript import ActionStart, Message, Transcript from vocode.streaming.synthesizer.base_synthesizer import SynthesisResult -from vocode.streaming.utils.worker import AsyncWorker +from vocode.streaming.utils.worker import QueueConsumer class ShouldIgnoreUtteranceTestCase(BaseModel): @@ -27,9 +27,9 @@ class ShouldIgnoreUtteranceTestCase(BaseModel): expected: bool -async def _consume_worker_output(worker: AsyncWorker, timeout: float = 0.1): +async def _get_from_consumer_queue_if_exists(queue_consumer: QueueConsumer, timeout: float = 0.1): try: - return await asyncio.wait_for(worker.output_queue.get(), timeout=timeout) + return await asyncio.wait_for(queue_consumer.input_queue.get(), timeout=timeout) except asyncio.TimeoutError: return None @@ -174,8 +174,6 @@ def test_should_ignore_utterance( conversation = mocker.MagicMock() transcriptions_worker = StreamingConversation.TranscriptionsWorker( - input_queue=mocker.MagicMock(), - output_queue=mocker.MagicMock(), conversation=conversation, interruptible_event_factory=mocker.MagicMock(), ) @@ -253,7 +251,9 @@ async def test_transcriptions_worker_ignores_utterances_before_initial_message( is_final=True, ), ) - assert await _consume_worker_output(streaming_conversation.transcriptions_worker) is None + transcriptions_worker_consumer = QueueConsumer() + streaming_conversation.transcriptions_worker.consumer = transcriptions_worker_consumer + assert await _get_from_consumer_queue_if_exists(transcriptions_worker_consumer) is None assert not streaming_conversation.broadcast_interrupt.called streaming_conversation.transcript.add_bot_message( @@ -269,8 +269,8 @@ async def test_transcriptions_worker_ignores_utterances_before_initial_message( ), ) - transcription_agent_input = await _consume_worker_output( - streaming_conversation.transcriptions_worker + transcription_agent_input = await _get_from_consumer_queue_if_exists( + transcriptions_worker_consumer ) assert transcription_agent_input.payload.transcription.message == "hi, who is this?" assert streaming_conversation.broadcast_interrupt.called @@ -310,7 +310,10 @@ async def test_transcriptions_worker_ignores_associated_ignored_utterance( is_final=False, ), ) - assert await _consume_worker_output(streaming_conversation.transcriptions_worker) is None + transcriptions_worker_consumer = QueueConsumer() + streaming_conversation.transcriptions_worker.consumer = transcriptions_worker_consumer + + assert await _get_from_consumer_queue_if_exists(transcriptions_worker_consumer) is None assert not streaming_conversation.broadcast_interrupt.called # ignored for length of response streaming_conversation.transcript.event_logs[-1].text = ( @@ -325,7 +328,7 @@ async def test_transcriptions_worker_ignores_associated_ignored_utterance( ), ) - assert await _consume_worker_output(streaming_conversation.transcriptions_worker) is None + assert await _get_from_consumer_queue_if_exists(transcriptions_worker_consumer) is None assert not streaming_conversation.broadcast_interrupt.called # ignored for length of response streaming_conversation.transcriptions_worker.consume_nonblocking( @@ -336,8 +339,8 @@ async def test_transcriptions_worker_ignores_associated_ignored_utterance( ), ) - transcription_agent_input = await _consume_worker_output( - streaming_conversation.transcriptions_worker + transcription_agent_input = await _get_from_consumer_queue_if_exists( + transcriptions_worker_consumer ) assert ( transcription_agent_input.payload.transcription.message == "I have not yet gotten a chance." @@ -377,7 +380,10 @@ async def test_transcriptions_worker_interrupts_on_interim_transcripts( ), ) - assert await _consume_worker_output(streaming_conversation.transcriptions_worker) is None + transcriptions_worker_consumer = QueueConsumer() + streaming_conversation.transcriptions_worker.consumer = transcriptions_worker_consumer + + assert await _get_from_consumer_queue_if_exists(transcriptions_worker_consumer) is None assert streaming_conversation.broadcast_interrupt.called streaming_conversation.transcriptions_worker.consume_nonblocking( @@ -388,8 +394,8 @@ async def test_transcriptions_worker_interrupts_on_interim_transcripts( ), ) - transcription_agent_input = await _consume_worker_output( - streaming_conversation.transcriptions_worker + transcription_agent_input = await _get_from_consumer_queue_if_exists( + transcriptions_worker_consumer ) assert ( transcription_agent_input.payload.transcription.message @@ -421,7 +427,10 @@ async def test_transcriptions_worker_interrupts_immediately_before_bot_has_begun is_final=False, ), ) - assert await _consume_worker_output(streaming_conversation.transcriptions_worker) is None + transcriptions_worker_consumer = QueueConsumer() + streaming_conversation.transcriptions_worker.consumer = transcriptions_worker_consumer + + assert await _get_from_consumer_queue_if_exists(transcriptions_worker_consumer) is None assert streaming_conversation.broadcast_interrupt.called streaming_conversation.transcriptions_worker.consume_nonblocking( @@ -431,8 +440,8 @@ async def test_transcriptions_worker_interrupts_immediately_before_bot_has_begun is_final=True, ), ) - transcription_agent_input = await _consume_worker_output( - streaming_conversation.transcriptions_worker + transcription_agent_input = await _get_from_consumer_queue_if_exists( + transcriptions_worker_consumer ) assert transcription_agent_input.payload.transcription.message == "Sorry, what?" assert streaming_conversation.broadcast_interrupt.called @@ -449,7 +458,7 @@ async def test_transcriptions_worker_interrupts_immediately_before_bot_has_begun ), ) - assert await _consume_worker_output(streaming_conversation.transcriptions_worker) is None + assert await _get_from_consumer_queue_if_exists(transcriptions_worker_consumer) is None assert streaming_conversation.broadcast_interrupt.called streaming_conversation.transcriptions_worker.terminate() diff --git a/vocode/streaming/action/worker.py b/vocode/streaming/action/worker.py index 26c51b240..b256489cb 100644 --- a/vocode/streaming/action/worker.py +++ b/vocode/streaming/action/worker.py @@ -12,6 +12,7 @@ ) from vocode.streaming.utils.state_manager import AbstractConversationStateManager from vocode.streaming.utils.worker import ( + AbstractWorker, InterruptibleEvent, InterruptibleEventFactory, InterruptibleWorker, @@ -19,16 +20,14 @@ class ActionsWorker(InterruptibleWorker): + consumer: AbstractWorker[InterruptibleEvent[ActionResultAgentInput]] + def __init__( self, action_factory: AbstractActionFactory, - input_queue: asyncio.Queue[InterruptibleEvent[ActionInput]], - output_queue: asyncio.Queue[InterruptibleEvent[AgentInput]], interruptible_event_factory: InterruptibleEventFactory = InterruptibleEventFactory(), ): super().__init__( - input_queue=input_queue, - output_queue=output_queue, interruptible_event_factory=interruptible_event_factory, ) self.action_factory = action_factory @@ -43,22 +42,24 @@ async def process(self, item: InterruptibleEvent[ActionInput]): action = self.action_factory.create_action(action_input.action_config) action.attach_conversation_state_manager(self.conversation_state_manager) action_output = await action.run(action_input) - self.produce_interruptible_event_nonblocking( - ActionResultAgentInput( - conversation_id=action_input.conversation_id, - action_input=action_input, - action_output=action_output, - vonage_uuid=( - action_input.vonage_uuid - if isinstance(action_input, VonagePhoneConversationActionInput) - else None - ), - twilio_sid=( - action_input.twilio_sid - if isinstance(action_input, TwilioPhoneConversationActionInput) - else None + self.consumer.consume_nonblocking( + self.interruptible_event_factory.create_interruptible_event( + ActionResultAgentInput( + conversation_id=action_input.conversation_id, + action_input=action_input, + action_output=action_output, + vonage_uuid=( + action_input.vonage_uuid + if isinstance(action_input, VonagePhoneConversationActionInput) + else None + ), + twilio_sid=( + action_input.twilio_sid + if isinstance(action_input, TwilioPhoneConversationActionInput) + else None + ), + is_quiet=action.quiet, ), - is_quiet=action.quiet, - ), - is_interruptible=False, + is_interruptible=False, + ) ) diff --git a/vocode/streaming/agent/base_agent.py b/vocode/streaming/agent/base_agent.py index f6a56a7e7..7dd5a3fe7 100644 --- a/vocode/streaming/agent/base_agent.py +++ b/vocode/streaming/agent/base_agent.py @@ -38,6 +38,7 @@ from vocode.streaming.utils import unrepeating_randomizer from vocode.streaming.utils.speed_manager import SpeedManager from vocode.streaming.utils.worker import ( + AbstractWorker, InterruptibleAgentResponseEvent, InterruptibleEvent, InterruptibleEventFactory, @@ -154,6 +155,9 @@ def get_cut_off_response(self) -> str: class BaseAgent(AbstractAgent[AgentConfigType], InterruptibleWorker): + agent_responses_consumer: AbstractWorker[InterruptibleAgentResponseEvent[AgentResponse]] + actions_consumer: Optional[AbstractWorker[InterruptibleEvent[ActionInput]]] + def __init__( self, agent_config: AgentConfigType, @@ -161,18 +165,12 @@ def __init__( interruptible_event_factory: InterruptibleEventFactory = InterruptibleEventFactory(), ): self.input_queue: asyncio.Queue[InterruptibleEvent[AgentInput]] = asyncio.Queue() - self.output_queue: asyncio.Queue[InterruptibleAgentResponseEvent[AgentResponse]] = ( - asyncio.Queue() - ) AbstractAgent.__init__(self, agent_config=agent_config) InterruptibleWorker.__init__( self, - input_queue=self.input_queue, - output_queue=self.output_queue, interruptible_event_factory=interruptible_event_factory, ) self.action_factory = action_factory - self.actions_queue: asyncio.Queue[InterruptibleEvent[ActionInput]] = asyncio.Queue() self.transcript: Optional[Transcript] = None self.functions = self.get_functions() if self.agent_config.actions else None @@ -211,11 +209,6 @@ def get_input_queue( ) -> asyncio.Queue[InterruptibleEvent[AgentInput]]: return self.input_queue - def get_output_queue( - self, - ) -> asyncio.Queue[InterruptibleAgentResponseEvent[AgentResponse]]: - return self.output_queue - def is_first_response(self): assert self.transcript is not None @@ -299,14 +292,16 @@ async def handle_generate_response( continue agent_response_tracker = agent_input.agent_response_tracker or asyncio.Event() - self.produce_interruptible_agent_response_event_nonblocking( - AgentResponseMessage( - message=generated_response.message, - is_first=is_first_response_of_turn, + self.agent_responses_consumer.consume_nonblocking( + self.interruptible_event_factory.create_interruptible_agent_response_event( + AgentResponseMessage( + message=generated_response.message, + is_first=is_first_response_of_turn, + ), + is_interruptible=self.agent_config.allow_agent_to_be_cut_off + and generated_response.is_interruptible, + agent_response_tracker=agent_response_tracker, ), - is_interruptible=self.agent_config.allow_agent_to_be_cut_off - and generated_response.is_interruptible, - agent_response_tracker=agent_response_tracker, ) if isinstance(generated_response.message, BaseMessage): responses_buffer = f"{responses_buffer} {generated_response.message.text}" @@ -330,14 +325,15 @@ async def handle_generate_response( end_of_turn_agent_response_tracker = ( agent_input.agent_response_tracker or asyncio.Event() ) - self.produce_interruptible_agent_response_event_nonblocking( - AgentResponseMessage( - message=EndOfTurn(), - is_first=is_first_response_of_turn, + self.agent_responses_consumer.consume_nonblocking( + self.interruptible_event_factory.create_interruptible_agent_response_event( + AgentResponseMessage( + message=EndOfTurn(), + is_first=is_first_response_of_turn, + ), + is_interruptible=self.agent_config.allow_agent_to_be_cut_off, + agent_response_tracker=end_of_turn_agent_response_tracker, ), - is_interruptible=self.agent_config.allow_agent_to_be_cut_off - and generated_response.is_interruptible, - agent_response_tracker=end_of_turn_agent_response_tracker, ) phrase_trigger_match = ( @@ -374,13 +370,17 @@ async def handle_respond(self, transcription: Transcription, conversation_id: st response = None return True if response: - self.produce_interruptible_agent_response_event_nonblocking( - AgentResponseMessage(message=BaseMessage(text=response)), - is_interruptible=self.agent_config.allow_agent_to_be_cut_off, + self.agent_responses_consumer.consume_nonblocking( + self.interruptible_event_factory.create_interruptible_agent_response_event( + AgentResponseMessage(message=BaseMessage(text=response)), + is_interruptible=self.agent_config.allow_agent_to_be_cut_off, + ) ) - self.produce_interruptible_agent_response_event_nonblocking( - AgentResponseMessage(message=EndOfTurn()), - is_interruptible=self.agent_config.allow_agent_to_be_cut_off, + self.agent_responses_consumer.consume_nonblocking( + self.interruptible_event_factory.create_interruptible_agent_response_event( + AgentResponseMessage(message=EndOfTurn()), + is_interruptible=self.agent_config.allow_agent_to_be_cut_off, + ) ) return should_stop else: @@ -408,15 +408,19 @@ async def process(self, item: InterruptibleEvent[AgentInput]): logger.debug("Action is quiet, skipping response generation") return if agent_input.action_output.canned_response is not None: - self.produce_interruptible_agent_response_event_nonblocking( - AgentResponseMessage( - message=agent_input.action_output.canned_response, - is_sole_text_chunk=True, - ), - is_interruptible=True, + self.agent_responses_consumer.consume_nonblocking( + self.interruptible_event_factory.create_interruptible_agent_response_event( + AgentResponseMessage( + message=agent_input.action_output.canned_response, + is_sole_text_chunk=True, + ), + is_interruptible=True, + ) ) - self.produce_interruptible_agent_response_event_nonblocking( - AgentResponseMessage(message=EndOfTurn()), + self.agent_responses_consumer.consume_nonblocking( + self.interruptible_event_factory.create_interruptible_agent_response_event( + AgentResponseMessage(message=EndOfTurn()), + ) ) return transcription = Transcription( @@ -432,8 +436,10 @@ async def process(self, item: InterruptibleEvent[AgentInput]): return if self.agent_config.send_filler_audio: - self.produce_interruptible_agent_response_event_nonblocking( - AgentResponseFillerAudio(), + self.agent_responses_consumer.consume_nonblocking( + self.interruptible_event_factory.create_interruptible_agent_response_event( + AgentResponseFillerAudio(), + ) ) logger.debug("Responding to transcription") @@ -451,7 +457,11 @@ async def process(self, item: InterruptibleEvent[AgentInput]): if should_stop: logger.debug("Agent requested to stop") - self.produce_interruptible_agent_response_event_nonblocking(AgentResponseStop()) + self.agent_responses_consumer.consume_nonblocking( + self.interruptible_event_factory.create_interruptible_agent_response_event( + AgentResponseStop(), + ) + ) return except asyncio.CancelledError: pass @@ -478,16 +488,20 @@ async def call_function(self, function_call: FunctionCall, agent_input: AgentInp if "user_message" in params: user_message = params["user_message"] user_message_tracker = asyncio.Event() - self.produce_interruptible_agent_response_event_nonblocking( - AgentResponseMessage( - message=BaseMessage(text=user_message), - is_sole_text_chunk=True, - ), - is_interruptible=action.is_interruptible, + self.agent_responses_consumer.consume_nonblocking( + self.interruptible_event_factory.create_interruptible_agent_response_event( + AgentResponseMessage( + message=BaseMessage(text=user_message), + is_sole_text_chunk=True, + ), + is_interruptible=action.is_interruptible, + ) ) - self.produce_interruptible_agent_response_event_nonblocking( - AgentResponseMessage(message=EndOfTurn()), - agent_response_tracker=user_message_tracker, + self.agent_responses_consumer.consume_nonblocking( + self.interruptible_event_factory.create_interruptible_agent_response_event( + AgentResponseMessage(message=EndOfTurn()), + agent_response_tracker=user_message_tracker, + ) ) action_input = self.create_action_input(action, agent_input, params, user_message_tracker) self.enqueue_action_input(action, action_input, agent_input.conversation_id) @@ -534,6 +548,9 @@ def enqueue_action_input( action_input: ActionInput, conversation_id: str, ): + if self.actions_consumer is None: + logger.warning("No actions consumer attached, skipping action") + return event = self.interruptible_event_factory.create_interruptible_event( action_input, is_interruptible=action.is_interruptible, @@ -543,7 +560,7 @@ def enqueue_action_input( action_input=action_input, conversation_id=conversation_id, ) - self.actions_queue.put_nowait(event) + self.actions_consumer.consume_nonblocking(event) async def respond( self, diff --git a/vocode/streaming/agent/websocket_user_implemented_agent.py b/vocode/streaming/agent/websocket_user_implemented_agent.py index 1a09e199d..f8233588f 100644 --- a/vocode/streaming/agent/websocket_user_implemented_agent.py +++ b/vocode/streaming/agent/websocket_user_implemented_agent.py @@ -27,7 +27,6 @@ class WebSocketUserImplementedAgent(BaseAgent[WebSocketUserImplementedAgentConfig]): input_queue: asyncio.Queue[InterruptibleEvent[AgentInput]] - output_queue: asyncio.Queue[InterruptibleAgentResponseEvent[AgentResponse]] def __init__( self, @@ -63,8 +62,11 @@ def _handle_incoming_socket_message(self, message: WebSocketAgentMessage) -> Non raise Exception("Unknown Socket message type") logger.info("Putting interruptible agent response event in output queue") - self.produce_interruptible_agent_response_event_nonblocking( - agent_response, self.get_agent_config().allow_agent_to_be_cut_off + self.agent_responses_consumer.consume_nonblocking( + self.interruptible_event_factory.create_interruptible_agent_response_event( + agent_response, + is_interruptible=self.get_agent_config().allow_agent_to_be_cut_off, + ) ) async def _process(self) -> None: @@ -137,5 +139,9 @@ async def receiver(ws: WebSocketClientProtocol) -> None: await asyncio.gather(sender(ws), receiver(ws)) def terminate(self): - self.produce_interruptible_agent_response_event_nonblocking(AgentResponseStop()) + self.agent_responses_consumer.consume_nonblocking( + self.interruptible_event_factory.create_interruptible_agent_response_event( + AgentResponseStop() + ) + ) super().terminate() diff --git a/vocode/streaming/output_device/blocking_speaker_output.py b/vocode/streaming/output_device/blocking_speaker_output.py index 637a41d4b..8eee5b781 100644 --- a/vocode/streaming/output_device/blocking_speaker_output.py +++ b/vocode/streaming/output_device/blocking_speaker_output.py @@ -20,7 +20,7 @@ def __init__(self, *, device_info: dict, sampling_rate: int): self.sampling_rate = sampling_rate self.device_info = device_info self.input_queue: asyncio.Queue[bytes] = asyncio.Queue() - super().__init__(self.input_queue) + super().__init__() self.stream = sd.OutputStream( channels=1, samplerate=self.sampling_rate, @@ -28,7 +28,7 @@ def __init__(self, *, device_info: dict, sampling_rate: int): device=int(self.device_info["index"]), ) self._ended = False - self.input_queue.put_nowait(self.sampling_rate * b"\x00") + self.consume_nonblocking(self.sampling_rate * b"\x00") self.stream.start() def _run_loop(self): diff --git a/vocode/streaming/streaming_conversation.py b/vocode/streaming/streaming_conversation.py index 0e84e4285..d8e9f34a2 100644 --- a/vocode/streaming/streaming_conversation.py +++ b/vocode/streaming/streaming_conversation.py @@ -66,11 +66,13 @@ from vocode.streaming.utils.speed_manager import SpeedManager from vocode.streaming.utils.state_manager import ConversationStateManager from vocode.streaming.utils.worker import ( + AbstractWorker, AsyncQueueWorker, InterruptibleAgentResponseEvent, InterruptibleAgentResponseWorker, InterruptibleEvent, InterruptibleEventFactory, + InterruptibleWorker, ) from vocode.utils.sentry_utils import ( CustomSentrySpans, @@ -143,16 +145,14 @@ class TranscriptionsWorker(AsyncQueueWorker[Transcription]): """Processes all transcriptions: sends an interrupt if needed and sends final transcriptions to the output queue""" + consumer: AbstractWorker[InterruptibleEvent[Transcription]] + def __init__( self, - input_queue: asyncio.Queue[Transcription], - output_queue: asyncio.Queue[InterruptibleEvent[AgentInput]], conversation: "StreamingConversation", interruptible_event_factory: InterruptibleEventFactory, ): - super().__init__(input_queue, output_queue) - self.input_queue = input_queue - self.output_queue = output_queue + super().__init__() self.conversation = conversation self.interruptible_event_factory = interruptible_event_factory self.in_interrupt_endpointing_config = False @@ -313,9 +313,9 @@ async def process(self, transcription: Transcription): agent_response_tracker=agent_response_tracker, ), ) - self.output_queue.put_nowait(event) + self.consumer.consume_nonblocking(event) - class FillerAudioWorker(InterruptibleAgentResponseWorker): + class FillerAudioWorker(InterruptibleWorker[InterruptibleAgentResponseEvent[FillerAudio]]): """ - Waits for a configured number of seconds and then sends filler audio to the output - Exposes wait_for_filler_audio_to_finish() which the AgentResponsesWorker waits on before @@ -324,11 +324,9 @@ class FillerAudioWorker(InterruptibleAgentResponseWorker): def __init__( self, - input_queue: asyncio.Queue[InterruptibleAgentResponseEvent[FillerAudio]], conversation: "StreamingConversation", ): - super().__init__(input_queue=input_queue) - self.input_queue = input_queue + super().__init__() self.conversation = conversation self.current_filler_seconds_per_chunk: Optional[int] = None self.filler_audio_started_event: Optional[threading.Event] = None @@ -369,26 +367,21 @@ async def process(self, item: InterruptibleAgentResponseEvent[FillerAudio]): except asyncio.CancelledError: pass - class AgentResponsesWorker(InterruptibleAgentResponseWorker): + class AgentResponsesWorker(InterruptibleWorker[InterruptibleAgentResponseEvent[AgentResponse]]): """Runs Synthesizer.create_speech and sends the SynthesisResult to the output queue""" + consumer: AbstractWorker[ + InterruptibleAgentResponseEvent[ + Tuple[Union[BaseMessage, EndOfTurn], Optional[SynthesisResult]] + ] + ] + def __init__( self, - input_queue: asyncio.Queue[InterruptibleAgentResponseEvent[AgentResponse]], - output_queue: asyncio.Queue[ - InterruptibleAgentResponseEvent[ - Tuple[Union[BaseMessage, EndOfTurn], Optional[SynthesisResult]] - ] - ], conversation: "StreamingConversation", interruptible_event_factory: InterruptibleEventFactory, ): - super().__init__( - input_queue=input_queue, - output_queue=output_queue, - ) - self.input_queue = input_queue - self.output_queue = output_queue + super().__init__() self.conversation = conversation self.interruptible_event_factory = interruptible_event_factory self.chunk_size = self.conversation._get_synthesizer_chunk_size() @@ -437,10 +430,12 @@ async def process(self, item: InterruptibleAgentResponseEvent[AgentResponse]): logger.debug("Sending end of turn") if isinstance(self.conversation.synthesizer, InputStreamingSynthesizer): await self.conversation.synthesizer.handle_end_of_turn() - self.produce_interruptible_agent_response_event_nonblocking( - (agent_response_message.message, None), - is_interruptible=item.is_interruptible, - agent_response_tracker=item.agent_response_tracker, + self.consumer.consume_nonblocking( + self.interruptible_event_factory.create_interruptible_agent_response_event( + (agent_response_message.message, None), + is_interruptible=item.is_interruptible, + agent_response_tracker=item.agent_response_tracker, + ), ) self.is_first_text_chunk = True return @@ -507,10 +502,12 @@ async def process(self, item: InterruptibleAgentResponseEvent[AgentResponse]): if not synthesis_result.cached and synthesis_span: synthesis_result.synthesis_total_span = synthesis_span synthesis_result.ttft_span = ttft_span - self.produce_interruptible_agent_response_event_nonblocking( - (agent_response_message.message, synthesis_result), - is_interruptible=item.is_interruptible, - agent_response_tracker=item.agent_response_tracker, + self.consumer.consume_nonblocking( + self.interruptible_event_factory.create_interruptible_agent_response_event( + (agent_response_message.message, synthesis_result), + is_interruptible=item.is_interruptible, + agent_response_tracker=item.agent_response_tracker, + ), ) self.last_agent_response_tracker = item.agent_response_tracker if not isinstance(agent_response_message.message, SilenceMessage): @@ -518,20 +515,20 @@ async def process(self, item: InterruptibleAgentResponseEvent[AgentResponse]): except asyncio.CancelledError: pass - class SynthesisResultsWorker(InterruptibleAgentResponseWorker): + class SynthesisResultsWorker( + InterruptibleWorker[ + InterruptibleAgentResponseEvent[ + Tuple[Union[BaseMessage, EndOfTurn], Optional[SynthesisResult]] + ] + ] + ): """Plays SynthesisResults from the output queue on the output device""" def __init__( self, - input_queue: asyncio.Queue[ - InterruptibleAgentResponseEvent[ - Tuple[Union[BaseMessage, EndOfTurn], Optional[SynthesisResult]] - ] - ], conversation: "StreamingConversation", ): - super().__init__(input_queue=input_queue) - self.input_queue = input_queue + super().__init__() self.conversation = conversation self.last_transcript_message: Optional[Message] = None @@ -604,49 +601,52 @@ def __init__( self.interruptible_events: queue.Queue[InterruptibleEvent] = queue.Queue() self.interruptible_event_factory = self.QueueingInterruptibleEventFactory(conversation=self) - self.agent.set_interruptible_event_factory(self.interruptible_event_factory) self.synthesis_results_queue: asyncio.Queue[ InterruptibleAgentResponseEvent[ Tuple[Union[BaseMessage, EndOfTurn], Optional[SynthesisResult]] ] ] = asyncio.Queue() - self.filler_audio_queue: asyncio.Queue[InterruptibleAgentResponseEvent[FillerAudio]] = ( - asyncio.Queue() - ) self.state_manager = self.create_state_manager() + + # Transcriptions Worker self.transcriptions_worker = self.TranscriptionsWorker( - input_queue=self.transcriber.output_queue, - output_queue=self.agent.get_input_queue(), conversation=self, interruptible_event_factory=self.interruptible_event_factory, ) + self.transcriber.consumer = self.transcriptions_worker + + # Agent + self.transcriptions_worker.consumer = self.agent + self.agent.set_interruptible_event_factory(self.interruptible_event_factory) self.agent.attach_conversation_state_manager(self.state_manager) + + # Agent Responses Worker self.agent_responses_worker = self.AgentResponsesWorker( - input_queue=self.agent.get_output_queue(), - output_queue=self.synthesis_results_queue, conversation=self, interruptible_event_factory=self.interruptible_event_factory, ) + self.agent.agent_responses_consumer = self.agent_responses_worker + + # Actions Worker self.actions_worker = None if self.agent.get_agent_config().actions: self.actions_worker = ActionsWorker( - input_queue=self.agent.actions_queue, - output_queue=self.agent.get_input_queue(), - interruptible_event_factory=self.interruptible_event_factory, action_factory=self.agent.action_factory, + interruptible_event_factory=self.interruptible_event_factory, ) self.actions_worker.attach_conversation_state_manager(self.state_manager) - self.synthesis_results_worker = self.SynthesisResultsWorker( - input_queue=self.synthesis_results_queue, - conversation=self, - ) + self.actions_worker.consumer = self.agent + self.agent.actions_consumer = self.actions_worker + + # Synthesis Results Worker + self.synthesis_results_worker = self.SynthesisResultsWorker(conversation=self) + self.agent_responses_worker.consumer = self.synthesis_results_worker + + # Filler Audio Worker self.filler_audio_worker = None self.filler_audio_config: Optional[FillerAudioConfig] = None if self.agent.get_agent_config().send_filler_audio: - self.filler_audio_worker = self.FillerAudioWorker( - input_queue=self.filler_audio_queue, - conversation=self, - ) + self.filler_audio_worker = self.FillerAudioWorker(conversation=self) self.speed_coefficient = speed_coefficient self.speed_manager = SpeedManager( diff --git a/vocode/streaming/synthesizer/base_synthesizer.py b/vocode/streaming/synthesizer/base_synthesizer.py index ea368cc15..0f0fea0d3 100644 --- a/vocode/streaming/synthesizer/base_synthesizer.py +++ b/vocode/streaming/synthesizer/base_synthesizer.py @@ -22,6 +22,7 @@ from vocode.streaming.utils import convert_wav, get_chunk_size_per_second from vocode.streaming.utils.async_requester import AsyncRequestor from vocode.streaming.utils.create_task import asyncio_create_task_with_done_error_log +from vocode.streaming.utils.worker import QueueConsumer FILLER_PHRASES = [ BaseMessage(text="Um..."), @@ -410,14 +411,12 @@ async def experimental_mp3_streaming_output_generator( response: aiohttp.ClientResponse, chunk_size: int, ) -> AsyncGenerator[SynthesisResult.ChunkResult, None]: - miniaudio_worker_input_queue: asyncio.Queue[Union[bytes, None]] = asyncio.Queue() - miniaudio_worker_output_queue: asyncio.Queue[Tuple[bytes, bool]] = asyncio.Queue() + miniaudio_worker_consumer: QueueConsumer = QueueConsumer() miniaudio_worker = MiniaudioWorker( self.synthesizer_config, chunk_size, - miniaudio_worker_input_queue, - miniaudio_worker_output_queue, ) + miniaudio_worker.consumer = miniaudio_worker_consumer miniaudio_worker.start() stream_reader = response.content @@ -433,7 +432,7 @@ async def send_chunks(): # Await the output queue of the MiniaudioWorker and yield the wav chunks in another loop while True: # Get the wav chunk and the flag from the output queue of the MiniaudioWorker - wav_chunk, is_last = await miniaudio_worker.output_queue.get() + wav_chunk, is_last = await miniaudio_worker_consumer.input_queue.get() if self.synthesizer_config.should_encode_as_wav: wav_chunk = encode_as_wav(wav_chunk, self.synthesizer_config) diff --git a/vocode/streaming/synthesizer/miniaudio_worker.py b/vocode/streaming/synthesizer/miniaudio_worker.py index 92d33adc3..fcba60460 100644 --- a/vocode/streaming/synthesizer/miniaudio_worker.py +++ b/vocode/streaming/synthesizer/miniaudio_worker.py @@ -10,23 +10,39 @@ from vocode.streaming.models.synthesizer import SynthesizerConfig from vocode.streaming.utils import convert_wav from vocode.streaming.utils.mp3_helper import decode_mp3 -from vocode.streaming.utils.worker import ThreadAsyncWorker +from vocode.streaming.utils.worker import AbstractWorker, ThreadAsyncWorker class MiniaudioWorker(ThreadAsyncWorker[Union[bytes, None]]): + consumer: AbstractWorker[Tuple[bytes, bool]] + def __init__( self, synthesizer_config: SynthesizerConfig, chunk_size: int, - input_queue: asyncio.Queue[Union[bytes, None]], - output_queue: asyncio.Queue[Tuple[bytes, bool]], ) -> None: - super().__init__(input_queue, output_queue) - self.output_queue = output_queue # for typing + super().__init__() self.synthesizer_config = synthesizer_config self.chunk_size = chunk_size self._ended = False + async def run_thread_forwarding(self): + try: + await asyncio.gather( + self._forward_to_thread(), + self._forward_from_thread(), + ) + except asyncio.CancelledError: + return + + async def _forward_from_thread(self): + while True: + try: + chunk, done = await self.output_janus_queue.async_q.get() + self.consumer.consume_nonblocking((chunk, done)) + except asyncio.CancelledError: + break + def _run_loop(self): # tracks the mp3 so far current_mp3_buffer = bytearray() diff --git a/vocode/streaming/telephony/conversation/abstract_phone_conversation.py b/vocode/streaming/telephony/conversation/abstract_phone_conversation.py index 1e3cf4c0f..cf876a180 100644 --- a/vocode/streaming/telephony/conversation/abstract_phone_conversation.py +++ b/vocode/streaming/telephony/conversation/abstract_phone_conversation.py @@ -66,12 +66,6 @@ def __init__( events_manager=events_manager, speed_coefficient=speed_coefficient, ) - self.transcriptions_worker = self.TranscriptionsWorker( - input_queue=self.transcriber.output_queue, - output_queue=self.agent.get_input_queue(), - conversation=self, - interruptible_event_factory=self.interruptible_event_factory, - ) self.config_manager = config_manager def attach_ws(self, ws: WebSocket): diff --git a/vocode/streaming/transcriber/assembly_ai_transcriber.py b/vocode/streaming/transcriber/assembly_ai_transcriber.py index 47fa12c36..fe10e476a 100644 --- a/vocode/streaming/transcriber/assembly_ai_transcriber.py +++ b/vocode/streaming/transcriber/assembly_ai_transcriber.py @@ -75,7 +75,7 @@ def send_audio(self, chunk): if ( len(self.buffer) / (2 * self.transcriber_config.sampling_rate) ) >= self.transcriber_config.buffer_size_seconds: - self.input_queue.put_nowait(self.buffer) + self.consume_nonblocking(self.buffer) self.buffer = bytearray() def terminate(self): @@ -133,7 +133,7 @@ async def receiver(ws): is_final = "message_type" in data and data["message_type"] == "FinalTranscript" if "text" in data and data["text"]: - self.output_queue.put_nowait( + self.produce_nonblocking( Transcription( message=data["text"], confidence=data["confidence"], diff --git a/vocode/streaming/transcriber/azure_transcriber.py b/vocode/streaming/transcriber/azure_transcriber.py index c272fdf89..2cc666340 100644 --- a/vocode/streaming/transcriber/azure_transcriber.py +++ b/vocode/streaming/transcriber/azure_transcriber.py @@ -79,12 +79,12 @@ def recognized_sentence_final(self, evt): op=CustomSentrySpans.LATENCY_OF_CONVERSATION, start_timestamp=datetime.now(tz=timezone.utc), ) - self.output_janus_queue.sync_q.put_nowait( + self.produce_nonblocking( Transcription(message=evt.result.text, confidence=1.0, is_final=True) ) def recognized_sentence_stream(self, evt): - self.output_janus_queue.sync_q.put_nowait( + self.produce_nonblocking( Transcription(message=evt.result.text, confidence=1.0, is_final=False) ) diff --git a/vocode/streaming/transcriber/base_transcriber.py b/vocode/streaming/transcriber/base_transcriber.py index 3e89d23d1..beea852b1 100644 --- a/vocode/streaming/transcriber/base_transcriber.py +++ b/vocode/streaming/transcriber/base_transcriber.py @@ -8,18 +8,19 @@ from vocode.streaming.models.audio import AudioEncoding from vocode.streaming.models.transcriber import TranscriberConfig, Transcription from vocode.streaming.utils.speed_manager import SpeedManager -from vocode.streaming.utils.worker import AsyncWorker, ThreadAsyncWorker +from vocode.streaming.utils.worker import AbstractWorker, AsyncWorker, ThreadAsyncWorker TranscriberConfigType = TypeVar("TranscriberConfigType", bound=TranscriberConfig) -class AbstractTranscriber(Generic[TranscriberConfigType], ABC): +class AbstractTranscriber(Generic[TranscriberConfigType], AbstractWorker[bytes]): + consumer: AbstractWorker[Transcription] + def __init__(self, transcriber_config: TranscriberConfigType): + AbstractWorker.__init__(self) self.transcriber_config = transcriber_config self.is_muted = False self.speed_manager: Optional[SpeedManager] = None - self.input_queue: asyncio.Queue[bytes] = asyncio.Queue() - self.output_queue: asyncio.Queue[Transcription] = asyncio.Queue() def attach_speed_manager(self, speed_manager: SpeedManager): self.speed_manager = speed_manager @@ -47,12 +48,15 @@ def create_silent_chunk(self, chunk_size, sample_width=2): async def _run_loop(self): pass - def send_audio(self, chunk): + def send_audio(self, chunk: bytes): if not self.is_muted: self.consume_nonblocking(chunk) else: self.consume_nonblocking(self.create_silent_chunk(len(chunk))) + def produce_nonblocking(self, item: Transcription): + self.consumer.consume_nonblocking(item) + @abstractmethod def terminate(self): pass @@ -61,7 +65,7 @@ def terminate(self): class BaseAsyncTranscriber(AbstractTranscriber[TranscriberConfigType], AsyncWorker[bytes]): # type: ignore def __init__(self, transcriber_config: TranscriberConfigType): AbstractTranscriber.__init__(self, transcriber_config) - AsyncWorker.__init__(self, self.input_queue, self.output_queue) + AsyncWorker.__init__(self) def terminate(self): AsyncWorker.terminate(self) @@ -72,11 +76,31 @@ class BaseThreadAsyncTranscriber( # type: ignore ): def __init__(self, transcriber_config: TranscriberConfigType): AbstractTranscriber.__init__(self, transcriber_config) - ThreadAsyncWorker.__init__(self, self.input_queue, self.output_queue) + ThreadAsyncWorker.__init__(self) def _run_loop(self): raise NotImplementedError + async def run_thread_forwarding(self): + try: + await asyncio.gather( + self._forward_to_thread(), + self._forward_from_thread(), + ) + except asyncio.CancelledError: + return + + async def _forward_from_thread(self): + while True: + try: + transcription = await self.output_janus_queue.async_q.get() + self.consumer.consume_nonblocking(transcription) + except asyncio.CancelledError: + break + + def produce_nonblocking(self, item: Transcription): + self.output_janus_queue.sync_q.put_nowait(item) + def terminate(self): ThreadAsyncWorker.terminate(self) diff --git a/vocode/streaming/transcriber/deepgram_transcriber.py b/vocode/streaming/transcriber/deepgram_transcriber.py index eef281d58..18bf369de 100644 --- a/vocode/streaming/transcriber/deepgram_transcriber.py +++ b/vocode/streaming/transcriber/deepgram_transcriber.py @@ -190,7 +190,7 @@ def terminate(self): }, ) terminate_msg = json.dumps({"type": "CloseStream"}).encode("utf-8") - self.input_queue.put_nowait(terminate_msg) + self.consume_nonblocking(terminate_msg) # todo (dow-107): typing self._ended = True super().terminate() @@ -485,7 +485,7 @@ async def receiver(ws: WebSocketClientProtocol): is_final_ts=is_final_ts, output_ts=output_ts, ) - self.output_queue.put_nowait( + self.produce_nonblocking( Transcription( message=buffer, confidence=buffer_avg_confidence, @@ -513,7 +513,7 @@ async def receiver(ws: WebSocketClientProtocol): else: interim_message = buffer - self.output_queue.put_nowait( + self.produce_nonblocking( Transcription( message=interim_message, confidence=deepgram_response.top_choice.confidence, diff --git a/vocode/streaming/transcriber/gladia_transcriber.py b/vocode/streaming/transcriber/gladia_transcriber.py index 54366035a..f1b033934 100644 --- a/vocode/streaming/transcriber/gladia_transcriber.py +++ b/vocode/streaming/transcriber/gladia_transcriber.py @@ -53,7 +53,7 @@ def send_audio(self, chunk): if ( len(self.buffer) / (2 * self.transcriber_config.sampling_rate) ) >= self.transcriber_config.buffer_size_seconds: - self.input_queue.put_nowait(self.buffer) + self.consume_nonblocking(self.buffer) self.buffer = bytearray() def terminate(self): @@ -104,7 +104,7 @@ async def receiver(ws): is_final = data["type"] == "final" if "transcription" in data and data["transcription"]: - self.output_queue.put_nowait( + self.produce_nonblocking( Transcription( message=data["transcription"], confidence=data["confidence"], diff --git a/vocode/streaming/transcriber/google_transcriber.py b/vocode/streaming/transcriber/google_transcriber.py index 5f3da9a68..97c2c34b5 100644 --- a/vocode/streaming/transcriber/google_transcriber.py +++ b/vocode/streaming/transcriber/google_transcriber.py @@ -80,7 +80,7 @@ def _on_response(self, response): message = top_choice.transcript confidence = top_choice.confidence - self.output_janus_queue.sync_q.put_nowait( + self.produce_nonblocking( Transcription(message=message, confidence=confidence, is_final=result.is_final) ) diff --git a/vocode/streaming/transcriber/rev_ai_transcriber.py b/vocode/streaming/transcriber/rev_ai_transcriber.py index 856768695..49f37eb98 100644 --- a/vocode/streaming/transcriber/rev_ai_transcriber.py +++ b/vocode/streaming/transcriber/rev_ai_transcriber.py @@ -118,12 +118,12 @@ async def receiver(ws: WebSocketClientProtocol): confidence = 1.0 if is_done: - self.output_queue.put_nowait( + self.produce_nonblocking( Transcription(message=buffer, confidence=confidence, is_final=True) ) buffer = "" else: - self.output_queue.put_nowait( + self.produce_nonblocking( Transcription( message=buffer, confidence=confidence, @@ -137,5 +137,5 @@ async def receiver(ws: WebSocketClientProtocol): def terminate(self): terminate_msg = json.dumps({"type": "CloseStream"}) - self.input_queue.put_nowait(terminate_msg) + self.consume_nonblocking(terminate_msg) self.closed = True diff --git a/vocode/streaming/transcriber/whisper_cpp_transcriber.py b/vocode/streaming/transcriber/whisper_cpp_transcriber.py index 6f1967de5..c12c6333c 100644 --- a/vocode/streaming/transcriber/whisper_cpp_transcriber.py +++ b/vocode/streaming/transcriber/whisper_cpp_transcriber.py @@ -72,7 +72,7 @@ def _run_loop(self): message_buffer += message is_final = any(message_buffer.endswith(ending) for ending in SENTENCE_ENDINGS) in_memory_wav, audio_buffer = self.create_new_buffer() - self.output_queue.put_nowait( + self.produce_nonblocking( Transcription(message=message_buffer, confidence=confidence, is_final=is_final) ) if is_final: diff --git a/vocode/streaming/utils/worker.py b/vocode/streaming/utils/worker.py index cdbff41ac..5d5b7fda0 100644 --- a/vocode/streaming/utils/worker.py +++ b/vocode/streaming/utils/worker.py @@ -2,7 +2,8 @@ import asyncio import threading -from typing import Any, Generic, Optional, TypeVar +from abc import ABC, abstractmethod +from typing import Any, Dict, Generic, List, Optional, TypeVar import janus from loguru import logger @@ -12,15 +13,39 @@ WorkerInputType = TypeVar("WorkerInputType") -class AsyncWorker(Generic[WorkerInputType]): +class AbstractWorker(Generic[WorkerInputType], ABC): + @abstractmethod + def start(self): + raise NotImplementedError + + @abstractmethod + def consume_nonblocking(self, item: WorkerInputType): + raise NotImplementedError + + def terminate(self): + pass + + +class QueueConsumer(AbstractWorker[WorkerInputType]): + def __init__( + self, + input_queue: Optional[asyncio.Queue[WorkerInputType]] = None, + ) -> None: + self.input_queue: asyncio.Queue[WorkerInputType] = input_queue or asyncio.Queue() + + def consume_nonblocking(self, item: WorkerInputType): + self.input_queue.put_nowait(item) + + def start(self): + pass + + +class AsyncWorker(AbstractWorker[WorkerInputType]): def __init__( self, - input_queue: asyncio.Queue[WorkerInputType], - output_queue: asyncio.Queue = asyncio.Queue(), ) -> None: self.worker_task: Optional[asyncio.Task] = None - self.input_queue = input_queue - self.output_queue = output_queue + self.input_queue: asyncio.Queue[WorkerInputType] = asyncio.Queue() def start(self) -> asyncio.Task: self.worker_task = asyncio_create_task_with_done_error_log( @@ -33,9 +58,6 @@ def start(self) -> asyncio.Task: def consume_nonblocking(self, item: WorkerInputType): self.input_queue.put_nowait(item) - def produce_nonblocking(self, item): - self.output_queue.put_nowait(item) - async def _run_loop(self): raise NotImplementedError @@ -49,10 +71,8 @@ def terminate(self): class ThreadAsyncWorker(AsyncWorker[WorkerInputType]): def __init__( self, - input_queue: asyncio.Queue[WorkerInputType], - output_queue: asyncio.Queue = asyncio.Queue(), ) -> None: - super().__init__(input_queue, output_queue) + super().__init__() self.worker_thread: Optional[threading.Thread] = None self.input_janus_queue: janus.Queue[WorkerInputType] = janus.Queue() self.output_janus_queue: janus.Queue = janus.Queue() @@ -69,10 +89,7 @@ def start(self) -> asyncio.Task: async def run_thread_forwarding(self): try: - await asyncio.gather( - self._forward_to_thread(), - self._forward_from_thead(), - ) + await self._forward_to_thread() except asyncio.CancelledError: return @@ -81,17 +98,9 @@ async def _forward_to_thread(self): item = await self.input_queue.get() self.input_janus_queue.async_q.put_nowait(item) - async def _forward_from_thead(self): - while True: - item = await self.output_janus_queue.async_q.get() - self.output_queue.put_nowait(item) - def _run_loop(self): raise NotImplementedError - def terminate(self): - return super().terminate() - class AsyncQueueWorker(AsyncWorker[WorkerInputType]): async def _run_loop(self): @@ -180,39 +189,15 @@ def create_interruptible_agent_response_event( class InterruptibleWorker(AsyncWorker[InterruptibleEventType]): def __init__( self, - input_queue: asyncio.Queue[InterruptibleEventType], - output_queue: asyncio.Queue = asyncio.Queue(), interruptible_event_factory: InterruptibleEventFactory = InterruptibleEventFactory(), max_concurrency=2, ) -> None: - super().__init__(input_queue, output_queue) - self.input_queue = input_queue + super().__init__() self.max_concurrency = max_concurrency self.interruptible_event_factory = interruptible_event_factory self.current_task = None self.interruptible_event = None - def produce_interruptible_event_nonblocking(self, item: Any, is_interruptible: bool = True): - interruptible_event = self.interruptible_event_factory.create_interruptible_event( - item, is_interruptible=is_interruptible - ) - return super().produce_nonblocking(interruptible_event) - - def produce_interruptible_agent_response_event_nonblocking( - self, - item: Any, - is_interruptible: bool = True, - agent_response_tracker: Optional[asyncio.Event] = None, - ): - interruptible_utterance_event = ( - self.interruptible_event_factory.create_interruptible_agent_response_event( - item, - is_interruptible=is_interruptible, - agent_response_tracker=agent_response_tracker or asyncio.Event(), - ) - ) - return super().produce_nonblocking(interruptible_utterance_event) - async def _run_loop(self): # TODO Implement concurrency with max_nb_of_thread while True: From bbed1645e4111140913ff225e3e18083f463c6d9 Mon Sep 17 00:00:00 2001 From: Ajay Raj Date: Thu, 20 Jun 2024 19:02:25 -0700 Subject: [PATCH 14/14] [DOW-107] refactor synthesizer as a worker (#571) * rename create_speech_uncached * deprecate agentstop - should use terminate_conversation instead * deprecate filleraudio from agent - if reimplemented it should go around the inner agent * [unstable] move agentresponsesworker logic into synthesizer * hook everything up * deprecate AgentResponseMessage and just use AgentResponse * few other respond refs * add comment for tear_down vs terminate * fix ref to create_speech_uncached * fix playground --- apps/telephony_app/speller_agent.py | 6 +- docs/open-source/create-your-own-agent.mdx | 10 +- docs/open-source/telephony.mdx | 8 +- playground/streaming/agent/chat.py | 75 ++++---- .../streaming/synthesizer/synthesize.py | 2 +- tests/streaming/agent/test_base_agent.py | 4 +- vocode/streaming/agent/base_agent.py | 70 ++----- vocode/streaming/agent/echo_agent.py | 6 +- vocode/streaming/agent/gpt4all_agent.py | 4 +- .../agent/restful_user_implemented_agent.py | 9 +- vocode/streaming/agent/vertex_ai_agent.py | 6 +- .../agent/websocket_user_implemented_agent.py | 22 +-- vocode/streaming/streaming_conversation.py | 182 ++---------------- .../synthesizer/azure_synthesizer.py | 2 +- .../streaming/synthesizer/bark_synthesizer.py | 2 +- .../streaming/synthesizer/base_synthesizer.py | 175 ++++++++++++++++- .../synthesizer/cartesia_synthesizer.py | 2 +- .../synthesizer/coqui_tts_synthesizer.py | 2 +- .../synthesizer/eleven_labs_synthesizer.py | 2 +- .../eleven_labs_websocket_synthesizer.py | 2 +- .../synthesizer/google_synthesizer.py | 2 +- .../streaming/synthesizer/gtts_synthesizer.py | 2 +- .../synthesizer/play_ht_synthesizer.py | 2 +- .../synthesizer/play_ht_synthesizer_v2.py | 2 +- .../synthesizer/polly_synthesizer.py | 2 +- .../streaming/synthesizer/rime_synthesizer.py | 2 +- .../stream_elements_synthesizer.py | 2 +- vocode/streaming/utils/worker.py | 4 - 28 files changed, 278 insertions(+), 331 deletions(-) diff --git a/apps/telephony_app/speller_agent.py b/apps/telephony_app/speller_agent.py index fc4f441aa..9420e27e5 100644 --- a/apps/telephony_app/speller_agent.py +++ b/apps/telephony_app/speller_agent.py @@ -31,7 +31,7 @@ async def respond( human_input: str, conversation_id: str, is_interrupt: bool = False, - ) -> Tuple[Optional[str], bool]: + ) -> Optional[str]: """Generates a response from the SpellerAgent. The response is generated by joining each character in the human input with a space. @@ -43,9 +43,9 @@ async def respond( is_interrupt (bool): A flag indicating whether the agent was interrupted. Returns: - Tuple[Optional[str], bool]: The generated response and a flag indicating whether to stop. + Optional[str]: The generated response """ - return "".join(c + " " for c in human_input), False + return "".join(c + " " for c in human_input) class SpellerAgentFactory(AbstractAgentFactory): diff --git a/docs/open-source/create-your-own-agent.mdx b/docs/open-source/create-your-own-agent.mdx index f4df445f6..5b7fa6c65 100644 --- a/docs/open-source/create-your-own-agent.mdx +++ b/docs/open-source/create-your-own-agent.mdx @@ -7,15 +7,18 @@ You can subclass a [`RespondAgent`](https://github.com/vocodedev/vocode-python/b To do so, you must create an agent type, create an agent config, and then create your agent subclass. In the examples below, we will create an agent that responds with the same message no matter what is said to it, called `BrokenRecordAgent`. ### Agent type + Each agent has a unique agent type string that is checked in various parts of Vocode, most notably in the factories that create agents. So, you must create a new type for your custom agent. See the `AgentType` enum in `vocode/streaming/models/agent.py` for examples. For our `BrokenRecordAgent`, we will use "agent_broken_record" as our type. ### Agent config + Your agent must have a corresponding agent config that is a subclass of `AgentConfig` and is ([JSON-serializable](https://docs.pydantic.dev/latest/concepts/serialization/#modelmodel_dump_json)). Serialization is automatically handled by [Pydantic](https://docs.pydantic.dev/latest/). The agent config should only contain the information you need to deterministically create the same agent each time. This means with the same parameters in your config, the corresponding agent should have the same behavior each time you create it. For our `BrokenRecordAgent`, we create a config like: + ```python from vocode.streaming.models.agent import AgentConfig @@ -24,21 +27,24 @@ class BrokenRecordAgentConfig(AgentConfig, type="agent_broken_record"): ``` ### Custom Agent + Now, you can create your custom agent subclass of `RespondAgent`. In your class header, pass in `RespondAgent` with a your agent type as a type hint. This should look like `RespondAgent[Your_Agent_Type]`. -Each agent should override the `generate_response()` async method to support streaming and `respond()` method to support turn-based conversations. +Each agent should override the `generate_response()` async method to support streaming and `respond()` method to support turn-based conversations. + > If you want to only support turn-based conversations, you do not have to overwrite `generate_response()` but you MUST set `generate_response=False` in your agent config (see `ChatVertexAIAgentConfig` in `vocode/streaming/models/agent.py` for an example). Otherwise, you must ALWAYS implement the `generate_response()` async method. The `generate_response()` method returns an `AsyncGenerator` of tuples containing each message/sentence and a boolean for whether that message can be interrupted by the human speaking. You can automatically create this generator by yielding instead of returning (see example below). We will now define our `BrokenRecordAgent`. Since we simply return the same message each time, we can return and yield that message in `respond()` and `generate_response()`, respectively: + ```python class BrokenRecordAgent(RespondAgent[BrokenRecordAgentConfig]): # is_interrupt is True when the human has just interrupted the bot's last response def respond( self, human_input, is_interrupt: bool = False - ) -> tuple[Optional[str], bool]: + ) -> Optional[str] return self.agent_config.message async def generate_response( diff --git a/docs/open-source/telephony.mdx b/docs/open-source/telephony.mdx index 6537e1e3b..8cbde66b6 100644 --- a/docs/open-source/telephony.mdx +++ b/docs/open-source/telephony.mdx @@ -146,8 +146,8 @@ class SpellerAgent(RespondAgent[SpellerAgentConfig]): human_input: str, conversation_id: str, is_interrupt: bool = False, - ) -> Tuple[Optional[str], bool]: - return "".join(c + " " for c in human_input), False + ) -> Optional[str]: + return "".join(c + " " for c in human_input) class SpellerAgentFactory(AbstractAgentFactory): @@ -182,10 +182,10 @@ class SpellerAgent(BaseAgent): human_input: str, conversation_id: str, is_interrupt: bool = False, - ) -> Tuple[Optional[str], bool]: + ) -> Optional[str]: call_config = self.config_manager.get_config(conversation_id) if call_config is not None: from_phone = call_config.twilio_from to_phone = call_config.twilio_to - return "".join(c + " " for c in human_input), False + return "".join(c + " " for c in human_input) ``` diff --git a/playground/streaming/agent/chat.py b/playground/streaming/agent/chat.py index e93b9eeda..d0c3589ff 100644 --- a/playground/streaming/agent/chat.py +++ b/playground/streaming/agent/chat.py @@ -28,8 +28,7 @@ from vocode.streaming.agent import ChatGPTAgent from vocode.streaming.agent.base_agent import ( AgentResponse, - AgentResponseMessage, - AgentResponseType, + AgentResponse, BaseAgent, TranscriptionAgentInput, ) @@ -113,47 +112,39 @@ async def receiver(): while not ended: try: event = await agent_response_queue.get() - response = event.payload - if response.type == AgentResponseType.FILLER_AUDIO: - print("Would have sent filler audio") - elif response.type == AgentResponseType.STOP: - print("Agent returned stop") - ended = True - break - elif response.type == AgentResponseType.MESSAGE: - agent_response = typing.cast(AgentResponseMessage, response) - - if isinstance(agent_response.message, EndOfTurn): - ignore_until_end_of_turn = False - if random.random() < backchannel_probability: - backchannel = random.choice(BACKCHANNELS) - print("Human: " + f"[{backchannel}]") - agent.transcript.add_human_message( - backchannel, - conversation_id, - is_backchannel=True, - ) - elif isinstance(agent_response.message, BaseMessage): - if ignore_until_end_of_turn: - continue - - message_sent: str - is_final: bool - # TODO: consider allowing the user to interrupt the agent manually by responding fast - if random.random() < interruption_probability: - stop_idx = random.randint(0, len(agent_response.message.text)) - message_sent = agent_response.message.text[:stop_idx] - ignore_until_end_of_turn = True - is_final = False - else: - message_sent = agent_response.message.text - is_final = True - - agent.transcript.add_bot_message( - message_sent, conversation_id, is_final=is_final + agent_response = event.payload + + if isinstance(agent_response.message, EndOfTurn): + ignore_until_end_of_turn = False + if random.random() < backchannel_probability: + backchannel = random.choice(BACKCHANNELS) + print("Human: " + f"[{backchannel}]") + agent.transcript.add_human_message( + backchannel, + conversation_id, + is_backchannel=True, ) + elif isinstance(agent_response.message, BaseMessage): + if ignore_until_end_of_turn: + continue + + message_sent: str + is_final: bool + # TODO: consider allowing the user to interrupt the agent manually by responding fast + if random.random() < interruption_probability: + stop_idx = random.randint(0, len(agent_response.message.text)) + message_sent = agent_response.message.text[:stop_idx] + ignore_until_end_of_turn = True + is_final = False + else: + message_sent = agent_response.message.text + is_final = True + + agent.transcript.add_bot_message( + message_sent, conversation_id, is_final=is_final + ) - print("AI: " + message_sent + ("-" if not is_final else "")) + print("AI: " + message_sent + ("-" if not is_final else "")) except asyncio.CancelledError: break @@ -161,7 +152,7 @@ async def sender(): if agent.agent_config.initial_message is not None: agent.agent_responses_consumer.consume_nonblocking( InterruptibleAgentResponseEvent( - payload=AgentResponseMessage(message=agent.agent_config.initial_message), + payload=AgentResponse(message=agent.agent_config.initial_message), agent_response_tracker=asyncio.Event(), ) ) diff --git a/playground/streaming/synthesizer/synthesize.py b/playground/streaming/synthesizer/synthesize.py index c9431620a..d600683af 100644 --- a/playground/streaming/synthesizer/synthesize.py +++ b/playground/streaming/synthesizer/synthesize.py @@ -31,7 +31,7 @@ async def speak( synthesizer.get_synthesizer_config().sampling_rate, ) # ClientSession needs to be created within the async task - synthesis_result = await synthesizer.create_speech_uncached( + synthesis_result = await synthesizer.create_speech( message=message, chunk_size=int(chunk_size), ) diff --git a/tests/streaming/agent/test_base_agent.py b/tests/streaming/agent/test_base_agent.py index a794db7ef..2eb1429f3 100644 --- a/tests/streaming/agent/test_base_agent.py +++ b/tests/streaming/agent/test_base_agent.py @@ -7,7 +7,7 @@ from vocode.streaming.action.abstract_factory import AbstractActionFactory from vocode.streaming.agent.base_agent import ( AgentResponse, - AgentResponseMessage, + AgentResponse, BaseAgent, GeneratedResponse, TranscriptionAgentInput, @@ -66,7 +66,7 @@ async def _consume_until_end_of_turn( agent_consumer.input_queue.get(), timeout=timeout ) agent_responses.append(agent_response.payload) - if isinstance(agent_response.payload, AgentResponseMessage) and isinstance( + if isinstance(agent_response.payload, AgentResponse) and isinstance( agent_response.payload.message, EndOfTurn ): break diff --git a/vocode/streaming/agent/base_agent.py b/vocode/streaming/agent/base_agent.py index 7dd5a3fe7..523aa31d5 100644 --- a/vocode/streaming/agent/base_agent.py +++ b/vocode/streaming/agent/base_agent.py @@ -88,18 +88,7 @@ class ActionResultAgentInput(AgentInput, type=AgentInputType.ACTION_RESULT.value is_quiet: bool = False -class AgentResponseType(str, Enum): - BASE = "agent_response_base" - MESSAGE = "agent_response_message" - STOP = "agent_response_stop" - FILLER_AUDIO = "agent_response_filler_audio" - - -class AgentResponse(TypedModel, type=AgentResponseType.BASE.value): # type: ignore - pass - - -class AgentResponseMessage(AgentResponse, type=AgentResponseType.MESSAGE.value): # type: ignore +class AgentResponse(BaseModel): message: Union[BaseMessage, EndOfTurn] is_interruptible: bool = True # Whether the message is the first message in the response; has metrics implications @@ -108,17 +97,6 @@ class AgentResponseMessage(AgentResponse, type=AgentResponseType.MESSAGE.value): is_sole_text_chunk: bool = False -class AgentResponseStop(AgentResponse, type=AgentResponseType.STOP.value): # type: ignore - pass - - -class AgentResponseFillerAudio( - AgentResponse, - type=AgentResponseType.FILLER_AUDIO.value, # type: ignore -): - pass - - class GeneratedResponse(BaseModel): message: Union[BaseMessage, FunctionCall, EndOfTurn] is_interruptible: bool @@ -248,7 +226,7 @@ async def handle_generate_response( self, transcription: Transcription, agent_input: AgentInput, - ) -> bool: + ): conversation_id = agent_input.conversation_id responses = self._maybe_prepend_interrupt_responses( transcription=transcription, @@ -294,7 +272,7 @@ async def handle_generate_response( agent_response_tracker = agent_input.agent_response_tracker or asyncio.Event() self.agent_responses_consumer.consume_nonblocking( self.interruptible_event_factory.create_interruptible_agent_response_event( - AgentResponseMessage( + AgentResponse( message=generated_response.message, is_first=is_first_response_of_turn, ), @@ -327,7 +305,7 @@ async def handle_generate_response( ) self.agent_responses_consumer.consume_nonblocking( self.interruptible_event_factory.create_interruptible_agent_response_event( - AgentResponseMessage( + AgentResponse( message=EndOfTurn(), is_first=is_first_response_of_turn, ), @@ -353,14 +331,12 @@ async def handle_generate_response( ) self.enqueue_action_input(action, action_input, agent_input.conversation_id) - # TODO: implement should_stop for generate_responses if function_call and self.agent_config.actions is not None: await self.call_function(function_call, agent_input) - return False async def handle_respond(self, transcription: Transcription, conversation_id: str) -> bool: try: - response, should_stop = await self.respond( + response = await self.respond( transcription.message, is_interrupt=transcription.is_interrupt, conversation_id=conversation_id, @@ -372,17 +348,16 @@ async def handle_respond(self, transcription: Transcription, conversation_id: st if response: self.agent_responses_consumer.consume_nonblocking( self.interruptible_event_factory.create_interruptible_agent_response_event( - AgentResponseMessage(message=BaseMessage(text=response)), + AgentResponse(message=BaseMessage(text=response)), is_interruptible=self.agent_config.allow_agent_to_be_cut_off, ) ) self.agent_responses_consumer.consume_nonblocking( self.interruptible_event_factory.create_interruptible_agent_response_event( - AgentResponseMessage(message=EndOfTurn()), + AgentResponse(message=EndOfTurn()), is_interruptible=self.agent_config.allow_agent_to_be_cut_off, ) ) - return should_stop else: logger.debug("No response generated") return False @@ -410,7 +385,7 @@ async def process(self, item: InterruptibleEvent[AgentInput]): if agent_input.action_output.canned_response is not None: self.agent_responses_consumer.consume_nonblocking( self.interruptible_event_factory.create_interruptible_agent_response_event( - AgentResponseMessage( + AgentResponse( message=agent_input.action_output.canned_response, is_sole_text_chunk=True, ), @@ -419,7 +394,7 @@ async def process(self, item: InterruptibleEvent[AgentInput]): ) self.agent_responses_consumer.consume_nonblocking( self.interruptible_event_factory.create_interruptible_agent_response_event( - AgentResponseMessage(message=EndOfTurn()), + AgentResponse(message=EndOfTurn()), ) ) return @@ -435,15 +410,7 @@ async def process(self, item: InterruptibleEvent[AgentInput]): logger.debug("Agent is muted, skipping processing") return - if self.agent_config.send_filler_audio: - self.agent_responses_consumer.consume_nonblocking( - self.interruptible_event_factory.create_interruptible_agent_response_event( - AgentResponseFillerAudio(), - ) - ) - logger.debug("Responding to transcription") - should_stop = False if self.agent_config.generate_responses: # TODO (EA): this is quite ugly but necessary to have the agent act properly after an action completes if not isinstance(agent_input, ActionResultAgentInput): @@ -451,18 +418,9 @@ async def process(self, item: InterruptibleEvent[AgentInput]): sentry_callable=sentry_sdk.start_span, op=CustomSentrySpans.LANGUAGE_MODEL_TIME_TO_FIRST_TOKEN, ) - should_stop = await self.handle_generate_response(transcription, agent_input) + await self.handle_generate_response(transcription, agent_input) else: - should_stop = await self.handle_respond(transcription, agent_input.conversation_id) - - if should_stop: - logger.debug("Agent requested to stop") - self.agent_responses_consumer.consume_nonblocking( - self.interruptible_event_factory.create_interruptible_agent_response_event( - AgentResponseStop(), - ) - ) - return + await self.handle_respond(transcription, agent_input.conversation_id) except asyncio.CancelledError: pass @@ -490,7 +448,7 @@ async def call_function(self, function_call: FunctionCall, agent_input: AgentInp user_message_tracker = asyncio.Event() self.agent_responses_consumer.consume_nonblocking( self.interruptible_event_factory.create_interruptible_agent_response_event( - AgentResponseMessage( + AgentResponse( message=BaseMessage(text=user_message), is_sole_text_chunk=True, ), @@ -499,7 +457,7 @@ async def call_function(self, function_call: FunctionCall, agent_input: AgentInp ) self.agent_responses_consumer.consume_nonblocking( self.interruptible_event_factory.create_interruptible_agent_response_event( - AgentResponseMessage(message=EndOfTurn()), + AgentResponse(message=EndOfTurn()), agent_response_tracker=user_message_tracker, ) ) @@ -567,7 +525,7 @@ async def respond( human_input, conversation_id: str, is_interrupt: bool = False, - ) -> Tuple[Optional[str], bool]: + ) -> Optional[str]: raise NotImplementedError def generate_response( diff --git a/vocode/streaming/agent/echo_agent.py b/vocode/streaming/agent/echo_agent.py index 010925773..e55a02249 100644 --- a/vocode/streaming/agent/echo_agent.py +++ b/vocode/streaming/agent/echo_agent.py @@ -1,4 +1,4 @@ -from typing import AsyncGenerator, Tuple +from typing import AsyncGenerator, Optional, Tuple from vocode.streaming.agent.base_agent import GeneratedResponse, RespondAgent from vocode.streaming.models.agent import EchoAgentConfig @@ -11,8 +11,8 @@ async def respond( human_input, conversation_id: str, is_interrupt: bool = False, - ) -> Tuple[str, bool]: - return human_input, False + ) -> Optional[str]: + return human_input async def generate_response( self, diff --git a/vocode/streaming/agent/gpt4all_agent.py b/vocode/streaming/agent/gpt4all_agent.py index f87d8d24f..4d832478e 100644 --- a/vocode/streaming/agent/gpt4all_agent.py +++ b/vocode/streaming/agent/gpt4all_agent.py @@ -23,5 +23,5 @@ async def respond( human_input, conversation_id: str, is_interrupt: bool = False, - ) -> Tuple[Optional[str], bool]: - return (await self.turn_based_agent.respond_async(human_input)), False + ) -> Optional[str]: + return await self.turn_based_agent.respond_async(human_input) diff --git a/vocode/streaming/agent/restful_user_implemented_agent.py b/vocode/streaming/agent/restful_user_implemented_agent.py index 02999617c..c7f202139 100644 --- a/vocode/streaming/agent/restful_user_implemented_agent.py +++ b/vocode/streaming/agent/restful_user_implemented_agent.py @@ -29,7 +29,7 @@ async def respond( human_input, conversation_id: str, is_interrupt: bool = False, - ) -> Tuple[Optional[str], bool]: + ) -> Optional[str]: config = self.agent_config.respond try: # TODO: cache session @@ -46,12 +46,11 @@ async def respond( assert response.status == 200 output: RESTfulAgentOutput = RESTfulAgentOutput.parse_obj(await response.json()) output_response = None - should_stop = False if output.type == RESTfulAgentOutputType.TEXT: output_response = cast(RESTfulAgentText, output).response elif output.type == RESTfulAgentOutputType.END: - should_stop = True - return output_response, should_stop + await self.conversation_state_manager.terminate_conversation() + return output_response except Exception as e: logger.error(f"Error in response from RESTful agent: {e}") - return None, True + return None diff --git a/vocode/streaming/agent/vertex_ai_agent.py b/vocode/streaming/agent/vertex_ai_agent.py index 264676733..2ac28b41e 100644 --- a/vocode/streaming/agent/vertex_ai_agent.py +++ b/vocode/streaming/agent/vertex_ai_agent.py @@ -1,6 +1,6 @@ import asyncio from concurrent.futures import ThreadPoolExecutor -from typing import Tuple +from typing import Optional, Tuple from langchain import ConversationChain from langchain.memory import ConversationBufferMemory @@ -44,7 +44,7 @@ async def respond( human_input, conversation_id: str, is_interrupt: bool = False, - ) -> Tuple[str, bool]: + ) -> Optional[str]: # Vertex AI doesn't allow async, so we run in a separate thread text = await asyncio.get_event_loop().run_in_executor( self.thread_pool_executor, @@ -53,4 +53,4 @@ async def respond( ) logger.debug(f"LLM response: {text}") - return text, False + return text diff --git a/vocode/streaming/agent/websocket_user_implemented_agent.py b/vocode/streaming/agent/websocket_user_implemented_agent.py index f8233588f..5757382c5 100644 --- a/vocode/streaming/agent/websocket_user_implemented_agent.py +++ b/vocode/streaming/agent/websocket_user_implemented_agent.py @@ -8,8 +8,7 @@ from vocode.streaming.agent.base_agent import ( AgentInput, AgentResponse, - AgentResponseMessage, - AgentResponseStop, + AgentResponse, BaseAgent, TranscriptionAgentInput, ) @@ -48,15 +47,15 @@ async def _run_loop(self) -> None: restarts += 1 logger.debug("Socket Agent connection died, restarting, num_restarts: %s", restarts) - def _handle_incoming_socket_message(self, message: WebSocketAgentMessage) -> None: + async def _handle_incoming_socket_message(self, message: WebSocketAgentMessage) -> None: logger.info("Handling incoming message from Socket Agent: %s", message) agent_response: AgentResponse if isinstance(message, WebSocketAgentTextMessage): - agent_response = AgentResponseMessage(message=BaseMessage(text=message.data.text)) + agent_response = AgentResponse(message=BaseMessage(text=message.data.text)) elif isinstance(message, WebSocketAgentStopMessage): - agent_response = AgentResponseStop() + await self.conversation_state_manager.terminate_conversation() self.has_ended = True else: raise Exception("Unknown Socket message type") @@ -92,9 +91,6 @@ async def sender( ) agent_request_json = agent_request.json() logger.info(f"Sending data to web socket agent: {agent_request_json}") - if isinstance(agent_request, AgentResponseStop): - # In practice, it doesn't make sense for the client to send a text and stop message to the agent service - self.has_ended = True await ws.send(agent_request_json) @@ -116,7 +112,7 @@ async def receiver(ws: WebSocketClientProtocol) -> None: logger.info("Received data from web socket agent") data = json.loads(msg) message = WebSocketAgentMessage.parse_obj(data) - self._handle_incoming_socket_message(message) + await self._handle_incoming_socket_message(message) except websockets.exceptions.ConnectionClosed as e: logger.error(f'WebSocket Agent Receive Error: Connection Closed - "{e}"') @@ -137,11 +133,3 @@ async def receiver(ws: WebSocketClientProtocol) -> None: logger.debug("Terminating Web Socket User Implemented Agent receiver") await asyncio.gather(sender(ws), receiver(ws)) - - def terminate(self): - self.agent_responses_consumer.consume_nonblocking( - self.interruptible_event_factory.create_interruptible_agent_response_event( - AgentResponseStop() - ) - ) - super().terminate() diff --git a/vocode/streaming/streaming_conversation.py b/vocode/streaming/streaming_conversation.py index d8e9f34a2..f7dbbb09c 100644 --- a/vocode/streaming/streaming_conversation.py +++ b/vocode/streaming/streaming_conversation.py @@ -28,9 +28,7 @@ from vocode.streaming.agent.base_agent import ( AgentInput, AgentResponse, - AgentResponseFillerAudio, - AgentResponseMessage, - AgentResponseStop, + AgentResponse, BaseAgent, TranscriptionAgentInput, ) @@ -69,7 +67,6 @@ AbstractWorker, AsyncQueueWorker, InterruptibleAgentResponseEvent, - InterruptibleAgentResponseWorker, InterruptibleEvent, InterruptibleEventFactory, InterruptibleWorker, @@ -367,154 +364,6 @@ async def process(self, item: InterruptibleAgentResponseEvent[FillerAudio]): except asyncio.CancelledError: pass - class AgentResponsesWorker(InterruptibleWorker[InterruptibleAgentResponseEvent[AgentResponse]]): - """Runs Synthesizer.create_speech and sends the SynthesisResult to the output queue""" - - consumer: AbstractWorker[ - InterruptibleAgentResponseEvent[ - Tuple[Union[BaseMessage, EndOfTurn], Optional[SynthesisResult]] - ] - ] - - def __init__( - self, - conversation: "StreamingConversation", - interruptible_event_factory: InterruptibleEventFactory, - ): - super().__init__() - self.conversation = conversation - self.interruptible_event_factory = interruptible_event_factory - self.chunk_size = self.conversation._get_synthesizer_chunk_size() - self.last_agent_response_tracker: Optional[asyncio.Event] = None - self.is_first_text_chunk = True - - def send_filler_audio(self, agent_response_tracker: Optional[asyncio.Event]): - assert self.conversation.filler_audio_worker is not None - logger.debug("Sending filler audio") - if self.conversation.synthesizer.filler_audios: - filler_audio = random.choice(self.conversation.synthesizer.filler_audios) - logger.debug(f"Chose {filler_audio.message.text}") - event = self.interruptible_event_factory.create_interruptible_agent_response_event( - filler_audio, - is_interruptible=filler_audio.is_interruptible, - agent_response_tracker=agent_response_tracker, - ) - self.conversation.filler_audio_worker.consume_nonblocking(event) - else: - logger.debug("No filler audio available for synthesizer") - - async def process(self, item: InterruptibleAgentResponseEvent[AgentResponse]): - if not self.conversation.synthesis_enabled: - logger.debug("Synthesis disabled, not synthesizing speech") - return - try: - agent_response = item.payload - if isinstance(agent_response, AgentResponseFillerAudio): - self.send_filler_audio(item.agent_response_tracker) - return - if isinstance(agent_response, AgentResponseStop): - logger.debug("Agent requested to stop") - if self.last_agent_response_tracker is not None: - await self.last_agent_response_tracker.wait() - item.agent_response_tracker.set() - self.conversation.mark_terminated(bot_disconnect=True) - return - - agent_response_message = typing.cast(AgentResponseMessage, agent_response) - - if self.conversation.filler_audio_worker is not None: - if self.conversation.filler_audio_worker.interrupt_current_filler_audio(): - await self.conversation.filler_audio_worker.wait_for_filler_audio_to_finish() - - if isinstance(agent_response_message.message, EndOfTurn): - logger.debug("Sending end of turn") - if isinstance(self.conversation.synthesizer, InputStreamingSynthesizer): - await self.conversation.synthesizer.handle_end_of_turn() - self.consumer.consume_nonblocking( - self.interruptible_event_factory.create_interruptible_agent_response_event( - (agent_response_message.message, None), - is_interruptible=item.is_interruptible, - agent_response_tracker=item.agent_response_tracker, - ), - ) - self.is_first_text_chunk = True - return - - synthesizer_base_name: Optional[str] = ( - synthesizer_base_name_if_should_report_to_sentry(self.conversation.synthesizer) - ) - create_speech_span: Optional[Span] = None - ttft_span: Optional[Span] = None - synthesis_span: Optional[Span] = None - if synthesizer_base_name and agent_response_message.is_first: - complete_span_by_op(CustomSentrySpans.LANGUAGE_MODEL_TIME_TO_FIRST_TOKEN) - - sentry_create_span( - sentry_callable=sentry_sdk.start_span, - op=CustomSentrySpans.SYNTHESIS_TIME_TO_FIRST_TOKEN, - ) - - synthesis_span = sentry_create_span( - sentry_callable=sentry_sdk.start_span, - op=f"{synthesizer_base_name}{CustomSentrySpans.SYNTHESIZER_SYNTHESIS_TOTAL}", - ) - if synthesis_span: - ttft_span = sentry_create_span( - sentry_callable=synthesis_span.start_child, - op=f"{synthesizer_base_name}{CustomSentrySpans.SYNTHESIZER_TIME_TO_FIRST_TOKEN}", - ) - if ttft_span: - create_speech_span = sentry_create_span( - sentry_callable=ttft_span.start_child, - op=f"{synthesizer_base_name}{CustomSentrySpans.SYNTHESIZER_CREATE_SPEECH}", - ) - maybe_synthesis_result: Optional[SynthesisResult] = None - if isinstance( - self.conversation.synthesizer, - InputStreamingSynthesizer, - ) and isinstance(agent_response_message.message, LLMToken): - logger.debug("Sending chunk to synthesizer") - await self.conversation.synthesizer.send_token_to_synthesizer( - message=agent_response_message.message, - chunk_size=self.chunk_size, - ) - else: - logger.debug("Synthesizing speech for message") - maybe_synthesis_result = await self.conversation.synthesizer.create_speech( - agent_response_message.message, - self.chunk_size, - is_first_text_chunk=self.is_first_text_chunk, - is_sole_text_chunk=agent_response_message.is_sole_text_chunk, - ) - if create_speech_span: - create_speech_span.finish() - # For input streaming synthesizers, subsequent chunks are contained in the same SynthesisResult - if isinstance(self.conversation.synthesizer, InputStreamingSynthesizer): - if not self.is_first_text_chunk: - maybe_synthesis_result = None - elif isinstance(agent_response_message.message, LLMToken): - maybe_synthesis_result = ( - self.conversation.synthesizer.get_current_utterance_synthesis_result() - ) - if maybe_synthesis_result is not None: - synthesis_result = maybe_synthesis_result - synthesis_result.is_first = agent_response_message.is_first - if not synthesis_result.cached and synthesis_span: - synthesis_result.synthesis_total_span = synthesis_span - synthesis_result.ttft_span = ttft_span - self.consumer.consume_nonblocking( - self.interruptible_event_factory.create_interruptible_agent_response_event( - (agent_response_message.message, synthesis_result), - is_interruptible=item.is_interruptible, - agent_response_tracker=item.agent_response_tracker, - ), - ) - self.last_agent_response_tracker = item.agent_response_tracker - if not isinstance(agent_response_message.message, SilenceMessage): - self.is_first_text_chunk = False - except asyncio.CancelledError: - pass - class SynthesisResultsWorker( InterruptibleWorker[ InterruptibleAgentResponseEvent[ @@ -620,12 +469,10 @@ def __init__( self.agent.set_interruptible_event_factory(self.interruptible_event_factory) self.agent.attach_conversation_state_manager(self.state_manager) - # Agent Responses Worker - self.agent_responses_worker = self.AgentResponsesWorker( - conversation=self, - interruptible_event_factory=self.interruptible_event_factory, - ) - self.agent.agent_responses_consumer = self.agent_responses_worker + # Synthesizer + self.agent.agent_responses_consumer = self.synthesizer + self.synthesizer.set_interruptible_event_factory(self.interruptible_event_factory) + self.synthesizer.attach_conversation_state_manager(self.state_manager) # Actions Worker self.actions_worker = None @@ -640,7 +487,7 @@ def __init__( # Synthesis Results Worker self.synthesis_results_worker = self.SynthesisResultsWorker(conversation=self) - self.agent_responses_worker.consumer = self.synthesis_results_worker + self.synthesizer.consumer = self.synthesis_results_worker # Filler Audio Worker self.filler_audio_worker = None @@ -687,7 +534,6 @@ def create_state_manager(self) -> ConversationStateManager: async def start(self, mark_ready: Optional[Callable[[], Awaitable[None]]] = None): self.transcriber.start() self.transcriptions_worker.start() - self.agent_responses_worker.start() self.synthesis_results_worker.start() self.output_device.start() if self.filler_audio_worker is not None: @@ -708,6 +554,7 @@ async def start(self, mark_ready: Optional[Callable[[], Awaitable[None]]] = None await self.synthesizer.set_filler_audios(self.filler_audio_config) self.agent.start() + self.synthesizer.start() initial_message = self.agent.get_agent_config().initial_message if initial_message: asyncio_create_task_with_done_error_log( @@ -779,15 +626,15 @@ async def send_single_message( ): agent_response_event = ( self.interruptible_event_factory.create_interruptible_agent_response_event( - AgentResponseMessage(message=message, is_sole_text_chunk=True), + AgentResponse(message=message, is_sole_text_chunk=True), is_interruptible=False, agent_response_tracker=message_tracker, ) ) - self.agent_responses_worker.consume_nonblocking(agent_response_event) - self.agent_responses_worker.consume_nonblocking( + self.synthesizer.consume_nonblocking(agent_response_event) + self.synthesizer.consume_nonblocking( self.interruptible_event_factory.create_interruptible_agent_response_event( - AgentResponseMessage(message=EndOfTurn()), + AgentResponse(message=EndOfTurn()), is_interruptible=True, ), ) @@ -829,7 +676,7 @@ async def broadcast_interrupt(self): break self.output_device.interrupt() self.agent.cancel_current_task() - self.agent_responses_worker.cancel_current_task() + self.synthesizer.cancel_current_task() if self.actions_worker: self.actions_worker.cancel_current_task() return num_interrupts > 0 @@ -1007,6 +854,7 @@ async def terminate(self): self.events_task.cancel() await self.events_manager.flush() logger.debug("Tearing down synthesizer") + # TODO (DOW-114): we won't need this once terminate() is async await self.synthesizer.tear_down() logger.debug("Terminating agent") if isinstance(self.agent, ChatGPTAgent) and self.agent.agent_config.vector_db_config: @@ -1022,8 +870,8 @@ async def terminate(self): self.transcriber.terminate() logger.debug("Terminating transcriptions worker") self.transcriptions_worker.terminate() - logger.debug("Terminating final transcriptions worker") - self.agent_responses_worker.terminate() + logger.debug("Terminating synthesizer") + self.synthesizer.terminate() logger.debug("Terminating synthesis results worker") self.synthesis_results_worker.terminate() if self.filler_audio_worker is not None: diff --git a/vocode/streaming/synthesizer/azure_synthesizer.py b/vocode/streaming/synthesizer/azure_synthesizer.py index eee831e29..63887d3fb 100644 --- a/vocode/streaming/synthesizer/azure_synthesizer.py +++ b/vocode/streaming/synthesizer/azure_synthesizer.py @@ -231,7 +231,7 @@ def get_message_up_to( return ssml_fragment.split(">")[-1] return message - async def create_speech_uncached( + async def create_speech( self, message: BaseMessage, chunk_size: int, diff --git a/vocode/streaming/synthesizer/bark_synthesizer.py b/vocode/streaming/synthesizer/bark_synthesizer.py index e6f9058ba..c5e5cbbc5 100644 --- a/vocode/streaming/synthesizer/bark_synthesizer.py +++ b/vocode/streaming/synthesizer/bark_synthesizer.py @@ -25,7 +25,7 @@ def __init__( preload_models(**self.synthesizer_config.preload_kwargs) self.thread_pool_executor = ThreadPoolExecutor(max_workers=1) - async def create_speech( + async def create_speech_with_cache( self, message: BaseMessage, chunk_size: int, diff --git a/vocode/streaming/synthesizer/base_synthesizer.py b/vocode/streaming/synthesizer/base_synthesizer.py index 0f0fea0d3..41c8cce0e 100644 --- a/vocode/streaming/synthesizer/base_synthesizer.py +++ b/vocode/streaming/synthesizer/base_synthesizer.py @@ -4,25 +4,54 @@ import math import os import wave -from typing import Any, AsyncGenerator, Callable, Generic, List, Optional, Tuple, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Callable, + Generic, + List, + Optional, + Tuple, + TypeVar, + Union, +) import aiohttp from loguru import logger from nltk.tokenize import word_tokenize from nltk.tokenize.treebank import TreebankWordDetokenizer +import sentry_sdk from sentry_sdk.tracing import Span as SentrySpan +from vocode.streaming.agent.base_agent import AgentResponse, AgentResponse +from vocode.streaming.models.actions import EndOfTurn from vocode.streaming.models.agent import FillerAudioConfig from vocode.streaming.models.audio import AudioEncoding, SamplingRate -from vocode.streaming.models.message import BaseMessage, BotBackchannel, SilenceMessage +from vocode.streaming.models.message import BaseMessage, BotBackchannel, LLMToken, SilenceMessage from vocode.streaming.models.synthesizer import SynthesizerConfig from vocode.streaming.synthesizer.audio_cache import AudioCache +from vocode.streaming.synthesizer.input_streaming_synthesizer import InputStreamingSynthesizer from vocode.streaming.synthesizer.miniaudio_worker import MiniaudioWorker from vocode.streaming.telephony.constants import MULAW_SILENCE_BYTE, PCM_SILENCE_BYTE from vocode.streaming.utils import convert_wav, get_chunk_size_per_second from vocode.streaming.utils.async_requester import AsyncRequestor from vocode.streaming.utils.create_task import asyncio_create_task_with_done_error_log -from vocode.streaming.utils.worker import QueueConsumer +from vocode.streaming.utils.worker import AbstractWorker, InterruptibleWorker, QueueConsumer +from vocode.streaming.utils.worker import ( + InterruptibleAgentResponseEvent, + InterruptibleEventFactory, +) +from vocode.utils.sentry_utils import ( + CustomSentrySpans, + complete_span_by_op, + sentry_create_span, + synthesizer_base_name_if_should_report_to_sentry, +) +from sentry_sdk.tracing import Span + +if TYPE_CHECKING: + from vocode.streaming.utils.state_manager import ConversationStateManager FILLER_PHRASES = [ BaseMessage(text="Um..."), @@ -223,11 +252,24 @@ def create_synthesis_result(self, chunk_size) -> SynthesisResult: SynthesizerConfigType = TypeVar("SynthesizerConfigType", bound=SynthesizerConfig) -class BaseSynthesizer(Generic[SynthesizerConfigType]): +class BaseSynthesizer( + Generic[SynthesizerConfigType], + InterruptibleWorker[InterruptibleAgentResponseEvent[AgentResponse]], +): + conversation_state_manager: "ConversationStateManager" + interruptible_event_factory: InterruptibleEventFactory + + consumer: AbstractWorker[ + InterruptibleAgentResponseEvent[ + Tuple[Union[BaseMessage, EndOfTurn], Optional[SynthesisResult]] + ] + ] + def __init__( self, synthesizer_config: SynthesizerConfigType, ): + InterruptibleWorker.__init__(self) self.synthesizer_config = synthesizer_config if synthesizer_config.audio_encoding == AudioEncoding.MULAW: assert ( @@ -238,6 +280,125 @@ def __init__( self.total_chars: int = 0 self.cost_per_char: Optional[float] = None + self.last_agent_response_tracker: Optional[asyncio.Event] = None + self.is_first_text_chunk = True + + async def process( + self, item: InterruptibleAgentResponseEvent[AgentResponse] + ): # todo (dow-107): fix typing + if not self.conversation_state_manager._conversation.synthesis_enabled: + logger.debug("Synthesis disabled, not synthesizing speech") + return + try: + agent_response = item.payload + + # todo (dow-107): resupport filler audio + filler_audio_worker = self.conversation_state_manager._conversation.filler_audio_worker + if filler_audio_worker is not None: + if filler_audio_worker.interrupt_current_filler_audio(): + await filler_audio_worker.wait_for_filler_audio_to_finish() + + if isinstance(agent_response.message, EndOfTurn): + logger.debug("Sending end of turn") + if isinstance(self, InputStreamingSynthesizer): + await self.handle_end_of_turn() + self.consumer.consume_nonblocking( + self.interruptible_event_factory.create_interruptible_agent_response_event( + (agent_response.message, None), + is_interruptible=item.is_interruptible, + agent_response_tracker=item.agent_response_tracker, + ), + ) + self.is_first_text_chunk = True + return + + synthesizer_base_name: Optional[str] = synthesizer_base_name_if_should_report_to_sentry( + self + ) + create_speech_span: Optional[Span] = None + ttft_span: Optional[Span] = None + synthesis_span: Optional[Span] = None + if synthesizer_base_name and agent_response.is_first: + complete_span_by_op(CustomSentrySpans.LANGUAGE_MODEL_TIME_TO_FIRST_TOKEN) + + sentry_create_span( + sentry_callable=sentry_sdk.start_span, + op=CustomSentrySpans.SYNTHESIS_TIME_TO_FIRST_TOKEN, + ) + + synthesis_span = sentry_create_span( + sentry_callable=sentry_sdk.start_span, + op=f"{synthesizer_base_name}{CustomSentrySpans.SYNTHESIZER_SYNTHESIS_TOTAL}", + ) + if synthesis_span: + ttft_span = sentry_create_span( + sentry_callable=synthesis_span.start_child, + op=f"{synthesizer_base_name}{CustomSentrySpans.SYNTHESIZER_TIME_TO_FIRST_TOKEN}", + ) + if ttft_span: + create_speech_span = sentry_create_span( + sentry_callable=ttft_span.start_child, + op=f"{synthesizer_base_name}{CustomSentrySpans.SYNTHESIZER_CREATE_SPEECH}", + ) + maybe_synthesis_result: Optional[SynthesisResult] = None + if isinstance( + self, + InputStreamingSynthesizer, + ) and isinstance(agent_response.message, LLMToken): + logger.debug("Sending chunk to synthesizer") + await self.send_token_to_synthesizer( + message=agent_response.message, + chunk_size=self.chunk_size, + ) + else: + logger.debug("Synthesizing speech for message") + maybe_synthesis_result = await self.create_speech_with_cache( + agent_response.message, + self.chunk_size, + is_first_text_chunk=self.is_first_text_chunk, + is_sole_text_chunk=agent_response.is_sole_text_chunk, + ) + if create_speech_span: + create_speech_span.finish() + # For input streaming synthesizers, subsequent chunks are contained in the same SynthesisResult + if isinstance(self, InputStreamingSynthesizer): + if not self.is_first_text_chunk: + maybe_synthesis_result = None + elif isinstance(agent_response.message, LLMToken): + maybe_synthesis_result = self.get_current_utterance_synthesis_result() + if maybe_synthesis_result is not None: + synthesis_result = maybe_synthesis_result + synthesis_result.is_first = agent_response.is_first + if not synthesis_result.cached and synthesis_span: + synthesis_result.synthesis_total_span = synthesis_span + synthesis_result.ttft_span = ttft_span + self.consumer.consume_nonblocking( + self.interruptible_event_factory.create_interruptible_agent_response_event( + (agent_response.message, synthesis_result), + is_interruptible=item.is_interruptible, + agent_response_tracker=item.agent_response_tracker, + ), + ) + self.last_agent_response_tracker = item.agent_response_tracker + if not isinstance(agent_response.message, SilenceMessage): + self.is_first_text_chunk = False + except asyncio.CancelledError: + pass + + @property + def chunk_size(self) -> int: + return self.conversation_state_manager._conversation._get_synthesizer_chunk_size() + + def attach_conversation_state_manager( + self, conversation_state_manager: "ConversationStateManager" + ): + self.conversation_state_manager = conversation_state_manager + + def set_interruptible_event_factory( + self, interruptible_event_factory: InterruptibleEventFactory + ): + self.interruptible_event_factory = interruptible_event_factory + @classmethod def get_voice_identifier(cls, synthesizer_config: SynthesizerConfigType) -> str: raise NotImplementedError @@ -325,7 +486,7 @@ async def get_cached_audio( trailing_silence_seconds = message.trailing_silence_seconds return CachedAudio(message, audio_data, self.synthesizer_config, trailing_silence_seconds) - async def create_speech_uncached( + async def create_speech( self, message: BaseMessage, chunk_size: int, @@ -334,7 +495,7 @@ async def create_speech_uncached( ) -> SynthesisResult: raise NotImplementedError - async def create_speech( + async def create_speech_with_cache( self, message: BaseMessage, chunk_size: int, @@ -350,7 +511,7 @@ async def create_speech( maybe_cached_audio = await self.get_cached_audio(message) if maybe_cached_audio is not None: return maybe_cached_audio.create_synthesis_result(chunk_size) - return await self.create_speech_uncached( + return await self.create_speech( message, chunk_size, is_first_text_chunk=is_first_text_chunk, diff --git a/vocode/streaming/synthesizer/cartesia_synthesizer.py b/vocode/streaming/synthesizer/cartesia_synthesizer.py index 84f779ca0..fa7ccdab7 100644 --- a/vocode/streaming/synthesizer/cartesia_synthesizer.py +++ b/vocode/streaming/synthesizer/cartesia_synthesizer.py @@ -60,7 +60,7 @@ def __init__( self.client = self.cartesia_tts(api_key=self.api_key) self.voice_embedding = self.client.get_voice_embedding(voice_id=self.voice_id) - async def create_speech_uncached( + async def create_speech( self, message: BaseMessage, chunk_size: int, diff --git a/vocode/streaming/synthesizer/coqui_tts_synthesizer.py b/vocode/streaming/synthesizer/coqui_tts_synthesizer.py index ad7d6b7de..b10ca42bc 100644 --- a/vocode/streaming/synthesizer/coqui_tts_synthesizer.py +++ b/vocode/streaming/synthesizer/coqui_tts_synthesizer.py @@ -23,7 +23,7 @@ def __init__( self.language = synthesizer_config.language self.thread_pool_executor = ThreadPoolExecutor(max_workers=1) - async def create_speech( + async def create_speech_with_cache( self, message: BaseMessage, chunk_size: int, diff --git a/vocode/streaming/synthesizer/eleven_labs_synthesizer.py b/vocode/streaming/synthesizer/eleven_labs_synthesizer.py index 714b5bf72..39a7d685e 100644 --- a/vocode/streaming/synthesizer/eleven_labs_synthesizer.py +++ b/vocode/streaming/synthesizer/eleven_labs_synthesizer.py @@ -65,7 +65,7 @@ def __init__( f"Unsupported audio encoding: {self.synthesizer_config.audio_encoding}" ) - async def create_speech_uncached( + async def create_speech( self, message: BaseMessage, chunk_size: int, diff --git a/vocode/streaming/synthesizer/eleven_labs_websocket_synthesizer.py b/vocode/streaming/synthesizer/eleven_labs_websocket_synthesizer.py index d95271c41..2b79fcac5 100644 --- a/vocode/streaming/synthesizer/eleven_labs_websocket_synthesizer.py +++ b/vocode/streaming/synthesizer/eleven_labs_websocket_synthesizer.py @@ -270,7 +270,7 @@ def get_current_utterance_synthesis_result(self): lambda seconds: self.get_current_message_so_far(seconds), ) - async def create_speech_uncached( + async def create_speech( self, message: BaseMessage, chunk_size: int, diff --git a/vocode/streaming/synthesizer/google_synthesizer.py b/vocode/streaming/synthesizer/google_synthesizer.py index 225ab2d5d..b32af76fc 100644 --- a/vocode/streaming/synthesizer/google_synthesizer.py +++ b/vocode/streaming/synthesizer/google_synthesizer.py @@ -56,7 +56,7 @@ def synthesize(self, message: str) -> Any: ) # TODO: make this nonblocking, see speech.TextToSpeechAsyncClient - async def create_speech( + async def create_speech_with_cache( self, message: BaseMessage, chunk_size: int, diff --git a/vocode/streaming/synthesizer/gtts_synthesizer.py b/vocode/streaming/synthesizer/gtts_synthesizer.py index 8ad8a31d5..f63334487 100644 --- a/vocode/streaming/synthesizer/gtts_synthesizer.py +++ b/vocode/streaming/synthesizer/gtts_synthesizer.py @@ -19,7 +19,7 @@ def __init__( self.thread_pool_executor = ThreadPoolExecutor(max_workers=1) - async def create_speech( + async def create_speech_with_cache( self, message: BaseMessage, chunk_size: int, diff --git a/vocode/streaming/synthesizer/play_ht_synthesizer.py b/vocode/streaming/synthesizer/play_ht_synthesizer.py index 05405af9d..f64a7b875 100644 --- a/vocode/streaming/synthesizer/play_ht_synthesizer.py +++ b/vocode/streaming/synthesizer/play_ht_synthesizer.py @@ -34,7 +34,7 @@ def __init__( self.max_backoff_retries = max_backoff_retries self.backoff_retry_delay = backoff_retry_delay - async def create_speech_uncached( + async def create_speech( self, message: BaseMessage, chunk_size: int, diff --git a/vocode/streaming/synthesizer/play_ht_synthesizer_v2.py b/vocode/streaming/synthesizer/play_ht_synthesizer_v2.py index 6db09ef7d..5b7f4ef1e 100644 --- a/vocode/streaming/synthesizer/play_ht_synthesizer_v2.py +++ b/vocode/streaming/synthesizer/play_ht_synthesizer_v2.py @@ -76,7 +76,7 @@ def playht_client(self) -> AsyncClient: return self.playht_client_on_prem return self.playht_client_saas - async def create_speech_uncached( + async def create_speech( self, message: BaseMessage, chunk_size: int, diff --git a/vocode/streaming/synthesizer/polly_synthesizer.py b/vocode/streaming/synthesizer/polly_synthesizer.py index 72385e05f..eb327f464 100644 --- a/vocode/streaming/synthesizer/polly_synthesizer.py +++ b/vocode/streaming/synthesizer/polly_synthesizer.py @@ -74,7 +74,7 @@ def get_message_up_to( return message[: event["start"]] return message - async def create_speech( + async def create_speech_with_cache( self, message: BaseMessage, chunk_size: int, diff --git a/vocode/streaming/synthesizer/rime_synthesizer.py b/vocode/streaming/synthesizer/rime_synthesizer.py index f776da4a8..10ca0e73f 100644 --- a/vocode/streaming/synthesizer/rime_synthesizer.py +++ b/vocode/streaming/synthesizer/rime_synthesizer.py @@ -51,7 +51,7 @@ def get_voice_identifier(cls, synthesizer_config: RimeSynthesizerConfig): ) ) - async def create_speech_uncached( + async def create_speech( self, message: BaseMessage, chunk_size: int, diff --git a/vocode/streaming/synthesizer/stream_elements_synthesizer.py b/vocode/streaming/synthesizer/stream_elements_synthesizer.py index 3d656f5cc..b88c658fd 100644 --- a/vocode/streaming/synthesizer/stream_elements_synthesizer.py +++ b/vocode/streaming/synthesizer/stream_elements_synthesizer.py @@ -18,7 +18,7 @@ def __init__( super().__init__(synthesizer_config) self.voice = synthesizer_config.voice - async def create_speech( + async def create_speech_with_cache( self, message: BaseMessage, chunk_size: int, diff --git a/vocode/streaming/utils/worker.py b/vocode/streaming/utils/worker.py index 5d5b7fda0..e0d909b23 100644 --- a/vocode/streaming/utils/worker.py +++ b/vocode/streaming/utils/worker.py @@ -244,7 +244,3 @@ def cancel_current_task(self): return self.current_task.cancel() return False - - -class InterruptibleAgentResponseWorker(InterruptibleWorker[InterruptibleAgentResponseEvent]): - pass