Skip to content
This repository has been archived by the owner on May 6, 2022. It is now read-only.

Commit

Permalink
fix: enforce ordering of TTS responses
Browse files Browse the repository at this point in the history
Currently, TTS requests submitted in close proximity
can result in audio being delivered to the client in a
different order than the requests were submitted.

This change keeps the requests asynchronous (as they must
be for Android networking) while enforcing ordering for the
results by introducing a request queue in the TTS manager.
  • Loading branch information
space-pope committed Apr 29, 2021
1 parent d640584 commit b327265
Show file tree
Hide file tree
Showing 7 changed files with 190 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,7 @@ public void setCredentials(String apiId, String apiSecret) {
* and any additional metadata.
*/
public void synthesize(SynthesisRequest request) {
HashMap<String, String> headers = new HashMap<>();
String requestId = request.metadata.get("id");
if (requestId != null) {
headers.put("x-request-id", requestId);
}

HashMap<String, String> variables = new HashMap<>();
String param = "text";
Expand All @@ -135,12 +131,16 @@ public void synthesize(SynthesisRequest request) {
}
variables.put("voice", request.voice);
String queryString = String.format(GRAPHQL_QUERY, param, method);
postSpeech(headers, queryString, variables);
postSpeech(requestId, queryString, variables);
}

private void postSpeech(Map<String, String> headers,
private void postSpeech(String requestId,
String queryString,
Map<String, String> variables) {
HashMap<String, String> headers = new HashMap<>();
if (requestId != null) {
headers.put("x-request-id", requestId);
}
if (this.ttsApiId == null) {
ttsCallback.onError("API ID not provided");
return;
Expand Down
10 changes: 8 additions & 2 deletions src/main/java/io/spokestack/spokestack/tts/SynthesisRequest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import java.util.HashMap;
import java.util.Map;
import java.util.UUID;

/**
* <p>
Expand Down Expand Up @@ -164,7 +165,7 @@ public static class Builder {
private CharSequence textToSynthesize;
private Mode synthesisMode = Mode.TEXT;
private String ttsVoice = "demo-male";
private Map<String, String> metadata = new HashMap<>();
private HashMap<String, String> metadata = new HashMap<>();

/**
* Create a new {@code TTSInput} builder with the only required data,
Expand Down Expand Up @@ -208,7 +209,7 @@ public Builder withVoice(String voice) {
* @return The current builder.
*/
public Builder withData(Map<String, String> requestData) {
this.metadata = requestData;
this.metadata.putAll(requestData);
return this;
}

Expand All @@ -220,6 +221,11 @@ public Builder withData(Map<String, String> requestData) {
* builder.
*/
public SynthesisRequest build() {
// add a random request ID if one doesn't exist
if (!metadata.containsKey("id")) {
UUID id = UUID.randomUUID();
metadata.put("id", id.toString());
}
return new SynthesisRequest(textToSynthesize, synthesisMode,
ttsVoice, metadata);
}
Expand Down
47 changes: 45 additions & 2 deletions src/main/java/io/spokestack/spokestack/tts/TTSManager.java
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package io.spokestack.spokestack.tts;

import android.content.Context;
import androidx.annotation.NonNull;
import io.spokestack.spokestack.SpeechConfig;
import io.spokestack.spokestack.SpeechOutput;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;

/**
* Manager for text-to-speech output in Spokestack.
Expand Down Expand Up @@ -52,11 +55,14 @@
* reallocate a manager's resources.
* </p>
*/
public final class TTSManager implements AutoCloseable {
public final class TTSManager implements AutoCloseable, TTSListener {
private final String ttsServiceClass;
private final String outputClass;
private final SpeechConfig config;
private final List<TTSListener> listeners = new ArrayList<>();
private final Queue<SynthesisRequest> requests = new ArrayDeque<>();
private final Object lock = new Object();
private boolean synthesizing = false;
private TTSService ttsService;
private SpeechOutput output;
private Context appContext;
Expand Down Expand Up @@ -104,7 +110,10 @@ public void synthesize(SynthesisRequest request) {
if (this.ttsService == null) {
throw new IllegalStateException("TTS closed; call prepare()");
}
this.ttsService.synthesize(request);
synchronized (lock) {
this.requests.add(request);
}
processQueue();
}

/**
Expand All @@ -114,6 +123,10 @@ public void stopPlayback() {
if (this.output != null) {
this.output.stopPlayback();
}
synchronized (lock) {
this.requests.clear();
this.synthesizing = false;
}
}

/**
Expand Down Expand Up @@ -183,6 +196,7 @@ public void prepare() throws Exception {
this.output.setAndroidContext(appContext);
this.ttsService.addListener(this.output);
}
this.ttsService.addListener(this);
for (TTSListener listener : this.listeners) {
this.ttsService.addListener(listener);
if (this.output != null) {
Expand Down Expand Up @@ -236,6 +250,35 @@ private void raiseError(Throwable e) {
}
}

@Override
public void eventReceived(@NonNull TTSEvent event) {
switch (event.type) {
case AUDIO_AVAILABLE:
case ERROR:
this.synthesizing = false;
processQueue();
default:
break;
}
}

private void processQueue() {
SynthesisRequest request = null;
if (!this.synthesizing) {
synchronized (lock) {
if (!this.synthesizing) {
request = this.requests.poll();
if (request != null) {
this.synthesizing = true;
}
}
}
}
if (request != null) {
this.ttsService.synthesize(request);
}
}

/**
* TTS manager builder.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,21 @@
import okhttp3.Response;
import okhttp3.ResponseBody;
import okio.Buffer;
import okio.BufferedSource;
import org.jetbrains.annotations.NotNull;
import org.junit.Before;
import org.junit.Test;

import java.io.IOException;
import java.nio.charset.Charset;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class SpokestackTTSClientTest {
private Response invalidResponse;
private static final String AUDIO_URL =
"https://spokestack.io/tts/test.mp3";
private static final Gson gson = new Gson();
private OkHttpClient httpClient;

Expand Down Expand Up @@ -223,14 +217,6 @@ public void onSynthesisResponse(AudioResponse response) {
*/
private class FakeResponder implements Interceptor {

private static final String TEXT_JSON =
"{\"data\": {\"synthesizeText\": {\"url\": \""
+ AUDIO_URL + "\"}}}";

private static final String ERROR_JSON =
"{\"data\": null, "
+ "\"errors\": [{\"message\": \"invalid_ssml\" }]}";

@NotNull
@Override
public Response intercept(@NotNull Chain chain) throws IOException {
Expand All @@ -241,30 +227,11 @@ public Response intercept(@NotNull Chain chain) throws IOException {
}
if (hasInvalidBody(request)) {
// simulate a GraphQL error, which are wrapped in HTTP 200s
return createResponse(request, ERROR_JSON);
}
return createResponse(request, TEXT_JSON);
}

private Response createResponse(Request request, String body)
throws IOException {
ResponseBody responseBody = mock(ResponseBody.class);
BufferedSource source = mock(BufferedSource.class);
when(source.readString(any(Charset.class))).thenReturn(body);
when(responseBody.source()).thenReturn(source);
Response.Builder builder = new Response.Builder()
.request(request)
.protocol(okhttp3.Protocol.HTTP_1_1)
.code(200)
.message("OK")
.body(responseBody);

String requestId = request.header("x-request-id");
if (requestId != null) {
builder.header("x-request-id", requestId);
return TTSTestUtils.createHttpResponse(request,
TTSTestUtils.ERROR_JSON);
}

return builder.build();
return TTSTestUtils.createHttpResponse(request,
TTSTestUtils.TEXT_JSON);
}

private boolean hasInvalidId(Request request) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ private class FakeResponder implements Interceptor {
public Response intercept(@NotNull Chain chain) throws IOException {
Request request = chain.request();
RequestBody body = request.body();
String requestId = request.header("x-request-id");
Buffer buffer = new Buffer();
body.writeTo(buffer);
Map json = gson.fromJson(buffer.readUtf8(), Map.class);
Expand All @@ -169,10 +170,11 @@ public Response intercept(@NotNull Chain chain) throws IOException {
throw new IOException("test exc");
}

return createResponse(text == null);
return createResponse(requestId, text == null);
}

private Response createResponse(boolean isSsml) throws IOException {
private Response createResponse(String requestId,
boolean isSsml) throws IOException {
Request request = new okhttp3.Request.Builder()
.url("http://example.com/")
.build();
Expand All @@ -185,6 +187,7 @@ private Response createResponse(boolean isSsml) throws IOException {
when(body.source()).thenReturn(responseSource);
return new Response.Builder()
.request(request)
.header("x-request-id", requestId)
.protocol(okhttp3.Protocol.HTTP_1_1)
.code(200)
.message("OK")
Expand Down

0 comments on commit b327265

Please sign in to comment.