diff --git a/src/main/java/io/spokestack/spokestack/Spokestack.java b/src/main/java/io/spokestack/spokestack/Spokestack.java
index 71d8798..08a09d0 100644
--- a/src/main/java/io/spokestack/spokestack/Spokestack.java
+++ b/src/main/java/io/spokestack/spokestack/Spokestack.java
@@ -2,9 +2,8 @@
import android.content.Context;
import androidx.lifecycle.Lifecycle;
+import io.spokestack.spokestack.nlu.NLUManager;
import io.spokestack.spokestack.nlu.NLUResult;
-import io.spokestack.spokestack.nlu.NLUService;
-import io.spokestack.spokestack.nlu.tensorflow.TensorflowNLU;
import io.spokestack.spokestack.tts.SynthesisRequest;
import io.spokestack.spokestack.tts.TTSManager;
import io.spokestack.spokestack.util.AsyncResult;
@@ -65,7 +64,7 @@
*
*
* @see SpeechPipeline
- * @see TensorflowNLU
+ * @see io.spokestack.spokestack.nlu.tensorflow.TensorflowNLU
* @see TTSManager
*/
public final class Spokestack extends SpokestackAdapter
@@ -75,7 +74,7 @@ public final class Spokestack extends SpokestackAdapter
private final boolean autoClassify;
private final TranscriptEditor transcriptEditor;
private SpeechPipeline speechPipeline;
- private TensorflowNLU nlu;
+ private NLUManager nlu;
private TTSManager tts;
/**
@@ -109,6 +108,35 @@ private Spokestack(Builder builder) throws Exception {
}
}
+ /**
+ * Package-private constructor used for testing with an injected NLU.
+ *
+ * @param builder The builder to use for everything but NLU.
+ * @param nluManager The NLU manager to inject.
+ */
+ Spokestack(Builder builder, NLUManager nluManager) throws Exception {
+ this.listeners = new ArrayList<>();
+ this.listeners.addAll(builder.listeners);
+ this.autoClassify = builder.autoClassify;
+ this.transcriptEditor = builder.transcriptEditor;
+ if (builder.useAsr) {
+ this.speechPipeline = builder.getPipelineBuilder()
+ .addOnSpeechEventListener(this)
+ .build();
+ }
+ if (builder.useNLU) {
+ this.nlu = nluManager;
+ }
+ if (builder.useTTS) {
+ if (!builder.useTTSPlayback) {
+ builder.ttsBuilder.setOutputClass(null);
+ }
+ this.tts = builder.getTtsBuilder()
+ .addTTSListener(this)
+ .build();
+ }
+ }
+
// speech pipeline
/**
@@ -126,14 +154,18 @@ public SpeechPipeline getSpeechPipeline() {
* pipeline.
*/
public void start() throws Exception {
- this.speechPipeline.start();
+ if (this.speechPipeline != null) {
+ this.speechPipeline.start();
+ }
}
/**
* Stops the speech pipeline and releases all its internal resources.
*/
public void stop() {
- this.speechPipeline.stop();
+ if (this.speechPipeline != null) {
+ this.speechPipeline.stop();
+ }
}
/**
@@ -142,7 +174,9 @@ public void stop() {
* conjunction with a microphone button.
*/
public void activate() {
- this.speechPipeline.activate();
+ if (this.speechPipeline != null) {
+ this.speechPipeline.activate();
+ }
}
/**
@@ -156,15 +190,17 @@ public void activate() {
*
*/
public void deactivate() {
- this.speechPipeline.deactivate();
+ if (this.speechPipeline != null) {
+ this.speechPipeline.deactivate();
+ }
}
// NLU
/**
- * @return The NLU service currently in use.
+ * @return The NLU manager currently in use.
*/
- public NLUService getNlu() {
+ public NLUManager getNlu() {
return nlu;
}
@@ -184,7 +220,10 @@ public NLUService getNlu() {
* classification.
*/
public AsyncResult classify(String utterance) {
- return classifyInternal(utterance);
+ if (this.nlu != null) {
+ return classifyInternal(utterance);
+ }
+ return null;
}
// TTS
@@ -205,7 +244,9 @@ public TTSManager getTts() {
* @throws Exception If there is an error constructing TTS components.
*/
public void prepareTts() throws Exception {
- this.tts.prepare();
+ if (this.tts != null) {
+ this.tts.prepare();
+ }
}
/**
@@ -219,7 +260,9 @@ public void prepareTts() throws Exception {
*
*/
public void releaseTts() {
- this.tts.release();
+ if (this.tts != null) {
+ this.tts.release();
+ }
}
/**
@@ -229,14 +272,18 @@ public void releaseTts() {
* @param request The synthesis request data.
*/
public void synthesize(SynthesisRequest request) {
- this.tts.synthesize(request);
+ if (this.tts != null) {
+ this.tts.synthesize(request);
+ }
}
/**
* Stops playback of any playing or queued synthesis results.
*/
public void stopPlayback() {
- this.tts.stopPlayback();
+ if (this.tts != null) {
+ this.tts.stopPlayback();
+ }
}
// listeners
@@ -294,6 +341,11 @@ public void onEvent(@NotNull SpeechContext.Event event,
}
}
+ @Override
+ public void nluResult(@NotNull NLUResult result) {
+ super.nluResult(result);
+ }
+
private AsyncResult classifyInternal(String text) {
AsyncResult result =
this.nlu.classify(text);
@@ -322,7 +374,7 @@ public void close() {
*/
public static class Builder {
private final SpeechPipeline.Builder pipelineBuilder;
- private final TensorflowNLU.Builder nluBuilder;
+ private final NLUManager.Builder nluBuilder;
private final TTSManager.Builder ttsBuilder;
private final List listeners = new ArrayList<>();
@@ -434,7 +486,7 @@ public Builder() {
.setConfig(this.speechConfig)
.useProfile(profileClass);
this.nluBuilder =
- new TensorflowNLU.Builder().setConfig(this.speechConfig);
+ new NLUManager.Builder().setConfig(this.speechConfig);
String ttsServiceClass =
"io.spokestack.spokestack.tts.SpokestackTTSService";
String ttsOutputClass =
@@ -461,14 +513,12 @@ private void setDefaults(SpeechConfig config) {
* for testing.
*
* @param pipeline the speech pipeline builder
- * @param nlu the NLU builder
* @param tts the TTS builder
*/
- Builder(SpeechPipeline.Builder pipeline, TensorflowNLU.Builder nlu,
- TTSManager.Builder tts) {
+ Builder(SpeechPipeline.Builder pipeline, TTSManager.Builder tts) {
this.speechConfig = new SpeechConfig();
this.pipelineBuilder = pipeline;
- this.nluBuilder = nlu;
+ this.nluBuilder = new NLUManager.Builder();
this.ttsBuilder = tts;
}
@@ -482,7 +532,7 @@ public SpeechPipeline.Builder getPipelineBuilder() {
/**
* @return The builder used to configure the NLU subsystem.
*/
- public TensorflowNLU.Builder getNluBuilder() {
+ public NLUManager.Builder getNluBuilder() {
return nluBuilder;
}
diff --git a/src/main/java/io/spokestack/spokestack/nlu/NLUManager.java b/src/main/java/io/spokestack/spokestack/nlu/NLUManager.java
new file mode 100644
index 0000000..b765233
--- /dev/null
+++ b/src/main/java/io/spokestack/spokestack/nlu/NLUManager.java
@@ -0,0 +1,157 @@
+package io.spokestack.spokestack.nlu;
+
+import io.spokestack.spokestack.SpeechConfig;
+import io.spokestack.spokestack.nlu.tensorflow.TensorflowNLU;
+import io.spokestack.spokestack.util.AsyncResult;
+import io.spokestack.spokestack.util.EventTracer;
+import io.spokestack.spokestack.util.TraceListener;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Manager for natural language understanding (NLU) components in Spokestack.
+ *
+ *
+ * Spokestack's NLU manager follows the same setup pattern as its {@link
+ * io.spokestack.spokestack.SpeechPipeline} and {@link io.spokestack.spokestack.tts.TTSManager}
+ * modules. The manager constructs the component ultimately responsible for
+ * classification (an {@link NLUService}) and manages the context required to
+ * perform these classifications and dispatch events to registered listeners.
+ *
+ */
+public final class NLUManager {
+ private final NLUService nlu;
+ private final NLUContext context;
+
+ /**
+ * Constructs a new {@code NLUManager} with an initialized NLU service.
+ *
+ * @param builder builder with configuration parameters
+ * @throws Exception If there is an error constructing the service.
+ */
+ private NLUManager(Builder builder) throws Exception {
+ this.context = builder.context;
+ this.nlu = buildService(builder);
+ }
+
+ private NLUService buildService(Builder builder) throws Exception {
+ Object constructed = Class
+ .forName(builder.serviceClass)
+ .getConstructor(SpeechConfig.class, NLUContext.class)
+ .newInstance(builder.config, builder.context);
+ return (NLUService) constructed;
+ }
+
+ /**
+ * Classify a user utterance, returning a wrapper that can either block
+ * until the classification is complete or call a registered callback when
+ * the result is ready.
+ *
+ * @param utterance The utterance to classify.
+ * @return An object representing the result of the asynchronous
+ * classification.
+ */
+ public AsyncResult classify(String utterance) {
+ return this.nlu.classify(utterance, this.context);
+ }
+
+ /**
+ * Add a new listener to receive trace events from the NLU subsystem.
+ *
+ * @param listener The listener to add.
+ */
+ public void addListener(TraceListener listener) {
+ this.context.addTraceListener(listener);
+ }
+
+ /**
+ * Remove a trace listener, allowing it to be garbage collected.
+ *
+ * @param listener The listener to remove.
+ */
+ public void removeListener(TraceListener listener) {
+ this.context.removeTraceListener(listener);
+ }
+
+ /**
+ * Fluent builder interface for initializing an NLU manager.
+ */
+ public static class Builder {
+ private final List traceListeners = new ArrayList<>();
+ private NLUContext context;
+ private SpeechConfig config = new SpeechConfig();
+ private String serviceClass;
+
+ /**
+ * Creates a new builder instance.
+ */
+ public Builder() {
+ config.put("trace-level", EventTracer.Level.ERROR.value());
+ this.serviceClass = TensorflowNLU.class.getCanonicalName();
+ }
+
+ /**
+ * Sets the name of the NLU service class to be used.
+ *
+ * @param className The name of the NLU service class to be used.
+ * @return this
+ */
+ public Builder setServiceClass(String className) {
+ this.serviceClass = className;
+ return this;
+ }
+
+ /**
+ * Attaches a configuration object, overwriting any existing
+ * configuration.
+ *
+ * @param value configuration to attach
+ * @return this
+ */
+ public Builder setConfig(SpeechConfig value) {
+ this.config = value;
+ return this;
+ }
+
+ /**
+ * Sets a configuration value.
+ *
+ * @param key configuration property name
+ * @param value property value
+ * @return this
+ */
+ public Builder setProperty(String key, Object value) {
+ config.put(key, value);
+ return this;
+ }
+
+ /**
+ * Adds a trace listener to receive events from the NLU system.
+ *
+ * @param listener the listener to register
+ * @return this
+ */
+ public Builder addTraceListener(TraceListener listener) {
+ this.traceListeners.add(listener);
+ return this;
+ }
+
+ /**
+ * Create a new NLU service, automatically loading any necessary
+ * resources in the background. Any errors encountered during loading
+ * will be reported to registered {@link TraceListener}s.
+ *
+ * @return An initialized {@code NLUManager} instance
+ * @throws Exception If there is an error constructing the NLU service.
+ */
+ public NLUManager build() throws Exception {
+ this.context = new NLUContext(this.config);
+ for (TraceListener listener : this.traceListeners) {
+ this.context.addTraceListener(listener);
+ }
+ return new NLUManager(this);
+ }
+
+ }
+}
diff --git a/src/main/java/io/spokestack/spokestack/nlu/NLUService.java b/src/main/java/io/spokestack/spokestack/nlu/NLUService.java
index a3da21e..78afb6f 100644
--- a/src/main/java/io/spokestack/spokestack/nlu/NLUService.java
+++ b/src/main/java/io/spokestack/spokestack/nlu/NLUService.java
@@ -5,6 +5,12 @@
/**
* A simple interface for components that provide intent classification and slot
* recognition, either on-device or via a network request.
+ *
+ *
+ * To participate in Spokestack's {@link NLUManager}, an NLUService must have a
+ * constructor that accepts instances of {@link io.spokestack.spokestack.SpeechConfig}
+ * and {@link NLUContext}.
+ *
*/
public interface NLUService {
diff --git a/src/main/java/io/spokestack/spokestack/nlu/tensorflow/TensorflowNLU.java b/src/main/java/io/spokestack/spokestack/nlu/tensorflow/TensorflowNLU.java
index 13a11dc..7187094 100644
--- a/src/main/java/io/spokestack/spokestack/nlu/tensorflow/TensorflowNLU.java
+++ b/src/main/java/io/spokestack/spokestack/nlu/tensorflow/TensorflowNLU.java
@@ -8,7 +8,6 @@
import io.spokestack.spokestack.nlu.NLUResult;
import io.spokestack.spokestack.nlu.NLUService;
import io.spokestack.spokestack.nlu.Slot;
-import io.spokestack.spokestack.util.TraceListener;
import io.spokestack.spokestack.nlu.tensorflow.parsers.DigitsParser;
import io.spokestack.spokestack.nlu.tensorflow.parsers.IdentityParser;
import io.spokestack.spokestack.nlu.tensorflow.parsers.IntegerParser;
@@ -16,6 +15,7 @@
import io.spokestack.spokestack.tensorflow.TensorflowModel;
import io.spokestack.spokestack.util.AsyncResult;
import io.spokestack.spokestack.util.EventTracer;
+import io.spokestack.spokestack.util.TraceListener;
import io.spokestack.spokestack.util.Tuple;
import java.io.FileReader;
@@ -55,19 +55,25 @@
* wordpiece-vocab-path (string, required): file system path to the
* wordpiece vocabulary file used by the wordpiece token encoder.
*
+ *
+ * slot-<slotType> (string, optional): class name of a slot
+ * parser capable of parsing slots with the {@code slotType} type. For
+ * example, a custom slot parser used to parse slots listed as {@code user}
+ * in the NLU metadata should be provided under the key {@code slot-user}.
+ *
*
*/
public final class TensorflowNLU implements NLUService {
private final ExecutorService executor =
Executors.newSingleThreadExecutor();
- private final TextEncoder textEncoder;
private final NLUContext context;
- private final int sepTokenId;
- private final int padTokenId;
- private final Thread loadThread;
- private TensorflowModel nluModel = null;
- private TFNLUOutput outputParser = null;
+ private TextEncoder textEncoder;
+ private int sepTokenId;
+ private int padTokenId;
+ private Thread loadThread;
+ private TensorflowModel nluModel;
+ private TFNLUOutput outputParser;
private int maxTokens;
private volatile boolean ready = false;
@@ -80,18 +86,72 @@ public final class TensorflowNLU implements NLUService {
* @param builder builder with configuration parameters
*/
private TensorflowNLU(Builder builder) {
- String modelPath = builder.config.getString("nlu-model-path");
- String metadataPath = builder.config.getString("nlu-metadata-path");
this.context = builder.context;
- this.textEncoder = builder.textEncoder;
- this.loadThread = builder.threadFactory.newThread(
+ SpeechConfig config = transferSlotParsers(
+ builder.slotParserClasses, builder.config);
+ load(config,
+ builder.textEncoder,
+ builder.modelLoader,
+ builder.threadFactory);
+ }
+
+ private SpeechConfig transferSlotParsers(Map parserClasses,
+ SpeechConfig config) {
+ for (Map.Entry parser : parserClasses.entrySet()) {
+ config.put("slot-" + parser.getKey(), parser.getValue());
+ }
+ return config;
+ }
+
+ /**
+ * Public constructor for {@code NLUManager} participation. Uses a {@link
+ * WordpieceTextEncoder} and a default TensorFlow model loader.
+ *
+ *
+ * The model and text encoder are loaded on a background thread.
+ *
+ *
+ * @param speechConfig configuration properties
+ * @param nluContext The context used to register listeners and deliver
+ * trace and error events.
+ */
+ public TensorflowNLU(SpeechConfig speechConfig, NLUContext nluContext) {
+ this.context = nluContext;
+ load(speechConfig,
+ new WordpieceTextEncoder(speechConfig, this.context),
+ new TensorflowModel.Loader(),
+ Thread::new);
+ }
+
+ private void load(SpeechConfig config,
+ TextEncoder encoder,
+ TensorflowModel.Loader loader,
+ ThreadFactory threadFactory) {
+ String modelPath = config.getString("nlu-model-path");
+ String metadataPath = config.getString("nlu-metadata-path");
+ Map slotParsers = getSlotParsers(config);
+ this.textEncoder = encoder;
+ this.loadThread = threadFactory.newThread(
() -> {
- loadModel(builder.modelLoader, metadataPath, modelPath);
- initParsers(builder.slotParserClasses);
+ loadModel(loader, metadataPath, modelPath);
+ initParsers(slotParsers);
});
this.loadThread.start();
- this.padTokenId = this.textEncoder.encodeSingle("[PAD]");
- this.sepTokenId = this.textEncoder.encodeSingle("[SEP]");
+
+ this.padTokenId = encoder.encodeSingle("[PAD]");
+ this.sepTokenId = encoder.encodeSingle("[SEP]");
+ }
+
+ private Map getSlotParsers(SpeechConfig config) {
+ HashMap slotParsers = new HashMap<>();
+
+ for (Map.Entry prop : config.getParams().entrySet()) {
+ if (prop.getKey().startsWith("slot-")) {
+ String slotType = prop.getKey().replace("slot-", "");
+ slotParsers.put(slotType, String.valueOf(prop.getValue()));
+ }
+ }
+ return slotParsers;
}
private void initParsers(Map parserClasses) {
@@ -104,6 +164,7 @@ private void initParsers(Map parserClasses) {
.newInstance();
slotParsers.put(slotType, parser);
this.outputParser.registerSlotParsers(slotParsers);
+ this.ready = true;
} catch (Exception e) {
this.context.traceError("Error loading slot parsers: %s",
e.getLocalizedMessage());
@@ -126,7 +187,6 @@ private void loadModel(TensorflowModel.Loader loader,
/ this.nluModel.getInputSize();
this.outputParser = new TFNLUOutput(metadata);
warmup();
- this.ready = true;
} catch (IOException e) {
this.context.traceError("Error loading NLU model: %s",
e.getLocalizedMessage());
@@ -261,6 +321,7 @@ private int[] pad(List ids) {
/**
* Add a new listener to receive trace events from the NLU subsystem.
+ *
* @param listener The listener to add.
*/
public void addListener(TraceListener listener) {
@@ -269,6 +330,7 @@ public void addListener(TraceListener listener) {
/**
* Remove a trace listener, allowing it to be garbage collected.
+ *
* @param listener The listener to remove.
*/
public void removeListener(TraceListener listener) {
@@ -328,7 +390,7 @@ public Builder setConfig(SpeechConfig value) {
* @param loader The TensorFlow model loader to use.
* @return this
*/
- public Builder setModelLoader(TensorflowModel.Loader loader) {
+ Builder setModelLoader(TensorflowModel.Loader loader) {
this.modelLoader = loader;
return this;
}
@@ -339,7 +401,7 @@ public Builder setModelLoader(TensorflowModel.Loader loader) {
* @param encoder The text encoder to use.
* @return this
*/
- public Builder setTextEncoder(TextEncoder encoder) {
+ Builder setTextEncoder(TextEncoder encoder) {
this.textEncoder = encoder;
return this;
}
diff --git a/src/main/java/io/spokestack/spokestack/nlu/tensorflow/WordpieceTextEncoder.java b/src/main/java/io/spokestack/spokestack/nlu/tensorflow/WordpieceTextEncoder.java
index 7015798..3596ec9 100644
--- a/src/main/java/io/spokestack/spokestack/nlu/tensorflow/WordpieceTextEncoder.java
+++ b/src/main/java/io/spokestack/spokestack/nlu/tensorflow/WordpieceTextEncoder.java
@@ -28,11 +28,12 @@ final class WordpieceTextEncoder implements TextEncoder {
private static final String UNKNOWN = "[UNK]";
private static final String SUFFIX_MARKER = "##";
- private NLUContext context;
+ private final Thread loadThread;
+ private final NLUContext context;
+
private HashMap vocabulary;
private volatile boolean ready = false;
- private Thread loadThread;
/**
* Creates a new Wordpiece token encoder.
@@ -56,7 +57,7 @@ final class WordpieceTextEncoder implements TextEncoder {
* thread.
*/
WordpieceTextEncoder(SpeechConfig config, NLUContext nluContext,
- ThreadFactory threadFactory) {
+ ThreadFactory threadFactory) {
String vocabFile = config.getString("wordpiece-vocab-path");
this.context = nluContext;
this.loadThread = threadFactory.newThread(() -> loadVocab(vocabFile));
diff --git a/src/test/java/io/spokestack/spokestack/SpokestackTest.java b/src/test/java/io/spokestack/spokestack/SpokestackTest.java
index 07c625b..18b42ea 100644
--- a/src/test/java/io/spokestack/spokestack/SpokestackTest.java
+++ b/src/test/java/io/spokestack/spokestack/SpokestackTest.java
@@ -4,10 +4,9 @@
import android.os.SystemClock;
import androidx.lifecycle.LifecycleOwner;
import androidx.lifecycle.LifecycleRegistry;
-import io.spokestack.spokestack.nlu.NLUContext;
+import io.spokestack.spokestack.nlu.NLUManager;
import io.spokestack.spokestack.nlu.NLUResult;
import io.spokestack.spokestack.nlu.tensorflow.NLUTestUtils;
-import io.spokestack.spokestack.nlu.tensorflow.TensorflowNLU;
import io.spokestack.spokestack.tts.SynthesisRequest;
import io.spokestack.spokestack.tts.TTSEvent;
import io.spokestack.spokestack.tts.TTSManager;
@@ -72,12 +71,6 @@ public void testBuild() throws Exception {
.setProperty("test", "test")
.build();
- // no subsystems exist to handle these calls
- assertThrows(NullPointerException.class, spokestack::start);
- assertThrows(NullPointerException.class,
- () -> spokestack.classify("test"));
- assertThrows(NullPointerException.class, spokestack::stopPlayback);
-
// closing with no active subsystems is fine
assertDoesNotThrow(spokestack::close);
}
@@ -123,13 +116,13 @@ public void testSpeechPipeline() throws Exception {
public void testNlu() throws Exception {
TestAdapter listener = new TestAdapter();
- Spokestack spokestack = new Spokestack
- .Builder(new SpeechPipeline.Builder(), mockNlu(), mockTts())
+ Spokestack.Builder builder = new Spokestack
+ .Builder(new SpeechPipeline.Builder(), mockTts())
.addListener(listener)
.withoutWakeword()
- .withoutTts()
- .build();
+ .withoutTts();
+ Spokestack spokestack = new Spokestack(builder, mockNlu());
listener.setSpokestack(spokestack);
NLUResult result = spokestack.classify("test").get();
@@ -137,8 +130,7 @@ public void testNlu() throws Exception {
assertNotNull(lastResult);
assertEquals(result.getIntent(), lastResult.getIntent());
- NLUContext fakeContext = new NLUContext(testConfig());
- result = spokestack.getNlu().classify("test", fakeContext).get();
+ result = spokestack.getNlu().classify("test").get();
assertEquals(result.getIntent(), lastResult.getIntent());
// classification is called automatically on ASR results
@@ -156,12 +148,13 @@ public void testAutoClassification() throws Exception {
TestAdapter listener = new TestAdapter();
Spokestack.Builder builder = new Spokestack
- .Builder(new SpeechPipeline.Builder(), mockNlu(), mockTts())
+ .Builder(new SpeechPipeline.Builder(), mockTts())
.withoutWakeword()
.withoutAutoClassification()
.addListener(listener);
- Spokestack spokestack = mockAndroidComponents(builder).build();
+ builder = mockAndroidComponents(builder);
+ Spokestack spokestack = new Spokestack(builder, mockNlu());
// automatic classification can be disabled
listener.clear();
@@ -178,12 +171,13 @@ public void testTranscriptEditing() throws Exception {
TestAdapter listener = new TestAdapter();
Spokestack.Builder builder = new Spokestack
- .Builder(new SpeechPipeline.Builder(), mockNlu(), mockTts())
+ .Builder(new SpeechPipeline.Builder(), mockTts())
.withoutWakeword()
.addListener(listener)
.withTranscriptEditor(String::toUpperCase);
- Spokestack spokestack = mockAndroidComponents(builder).build();
+ builder = mockAndroidComponents(builder);
+ Spokestack spokestack = new Spokestack(builder, mockNlu());
// transcripts can be edited before automatic classification
String transcript = "test";
@@ -191,7 +185,7 @@ public void testTranscriptEditing() throws Exception {
SpeechContext context = spokestack.getSpeechPipeline().getContext();
context.setTranscript(transcript);
context.dispatch(SpeechContext.Event.RECOGNIZE);
- NLUResult result = spokestack.classify("test").get();
+ NLUResult result = spokestack.classify("TEST").get();
NLUResult lastResult = listener.nluResults.poll(1, TimeUnit.SECONDS);
assertNotNull(lastResult);
assertEquals(result.getIntent(), lastResult.getIntent());
@@ -211,11 +205,12 @@ public void testTts() throws Exception {
SpeechPipeline.Builder pipelineBuilder = new SpeechPipeline.Builder();
Spokestack.Builder builder = new Spokestack
- .Builder(pipelineBuilder, mockNlu(), mockTts())
+ .Builder(pipelineBuilder, mockTts())
.withoutWakeword()
.addListener(listener);
- Spokestack spokestack = mockAndroidComponents(builder).build();
+ builder = mockAndroidComponents(builder);
+ Spokestack spokestack = new Spokestack(builder, mockNlu());
listener.setSpokestack(spokestack);
TTSManager tts = spokestack.getTts();
@@ -253,7 +248,7 @@ public void testListenerManagement() throws Exception {
TestAdapter listener2 = new TestAdapter();
Spokestack.Builder builder = new Spokestack
- .Builder(new SpeechPipeline.Builder(), mockNlu(), mockTts())
+ .Builder(new SpeechPipeline.Builder(), mockTts())
.withoutWakeword()
.setConfig(testConfig())
.setProperty("trace-level", EventTracer.Level.INFO.value())
@@ -261,7 +256,7 @@ public void testListenerManagement() throws Exception {
builder = mockAndroidComponents(builder);
builder.getPipelineBuilder().setStageClasses(new ArrayList<>());
- Spokestack spokestack = builder.build();
+ Spokestack spokestack = new Spokestack(builder, mockNlu());
spokestack.addListener(listener2);
spokestack.getSpeechPipeline().activate();
@@ -295,11 +290,8 @@ public void testListenerManagement() throws Exception {
assertEquals(error, err);
}
- private TensorflowNLU.Builder mockNlu() throws Exception {
- mockStatic(SystemClock.class);
-
- NLUTestUtils.TestEnv nluEnv = new NLUTestUtils.TestEnv();
- return nluEnv.nluBuilder;
+ private NLUManager mockNlu() throws Exception {
+ return NLUTestUtils.mockManager();
}
private TTSManager.Builder mockTts() {
diff --git a/src/test/java/io/spokestack/spokestack/nlu/tensorflow/NLUTestUtils.java b/src/test/java/io/spokestack/spokestack/nlu/tensorflow/NLUTestUtils.java
index 1bdc103..04631e9 100644
--- a/src/test/java/io/spokestack/spokestack/nlu/tensorflow/NLUTestUtils.java
+++ b/src/test/java/io/spokestack/spokestack/nlu/tensorflow/NLUTestUtils.java
@@ -1,11 +1,14 @@
package io.spokestack.spokestack.nlu.tensorflow;
-import android.os.SystemClock;
import com.google.gson.Gson;
import com.google.gson.stream.JsonReader;
import io.spokestack.spokestack.SpeechConfig;
+import io.spokestack.spokestack.nlu.NLUContext;
+import io.spokestack.spokestack.nlu.NLUManager;
import io.spokestack.spokestack.nlu.NLUResult;
+import io.spokestack.spokestack.nlu.NLUService;
import io.spokestack.spokestack.tensorflow.TensorflowModel;
+import io.spokestack.spokestack.util.AsyncResult;
import java.io.FileNotFoundException;
import java.io.FileReader;
@@ -16,7 +19,6 @@
import java.util.concurrent.Future;
import static org.mockito.Mockito.*;
-import static org.powermock.api.mockito.PowerMockito.mockStatic;
public class NLUTestUtils {
@@ -27,6 +29,13 @@ public static SpeechConfig testConfig() {
.put("wordpiece-vocab-path", "src/test/resources/vocab.txt");
}
+ public static NLUManager mockManager() throws Exception {
+ return new NLUManager.Builder()
+ .setServiceClass(NLUTestUtils.class.getCanonicalName()
+ + "$MockNLU")
+ .build();
+ }
+
public static class TestModel extends TensorflowModel {
public TestModel(TensorflowModel.Loader loader) {
super(loader);
@@ -134,4 +143,25 @@ public EncodedTokens encode(String text) {
return encoded;
}
}
+
+ public static class MockNLU implements NLUService {
+
+ public MockNLU(SpeechConfig config, NLUContext context) {
+ // empty constructor so it can be built by the manager
+ }
+
+ @Override
+ public AsyncResult classify(
+ String utterance,
+ NLUContext context
+ ) {
+ AsyncResult result = new AsyncResult<>(() ->
+ new NLUResult.Builder(utterance)
+ .withIntent(utterance)
+ .build());
+ result.run();
+ return result;
+ }
+
+ }
}