diff --git a/src/main/java/io/spokestack/spokestack/Spokestack.java b/src/main/java/io/spokestack/spokestack/Spokestack.java index 71d8798..08a09d0 100644 --- a/src/main/java/io/spokestack/spokestack/Spokestack.java +++ b/src/main/java/io/spokestack/spokestack/Spokestack.java @@ -2,9 +2,8 @@ import android.content.Context; import androidx.lifecycle.Lifecycle; +import io.spokestack.spokestack.nlu.NLUManager; import io.spokestack.spokestack.nlu.NLUResult; -import io.spokestack.spokestack.nlu.NLUService; -import io.spokestack.spokestack.nlu.tensorflow.TensorflowNLU; import io.spokestack.spokestack.tts.SynthesisRequest; import io.spokestack.spokestack.tts.TTSManager; import io.spokestack.spokestack.util.AsyncResult; @@ -65,7 +64,7 @@ *

* * @see SpeechPipeline - * @see TensorflowNLU + * @see io.spokestack.spokestack.nlu.tensorflow.TensorflowNLU * @see TTSManager */ public final class Spokestack extends SpokestackAdapter @@ -75,7 +74,7 @@ public final class Spokestack extends SpokestackAdapter private final boolean autoClassify; private final TranscriptEditor transcriptEditor; private SpeechPipeline speechPipeline; - private TensorflowNLU nlu; + private NLUManager nlu; private TTSManager tts; /** @@ -109,6 +108,35 @@ private Spokestack(Builder builder) throws Exception { } } + /** + * Package-private constructor used for testing with an injected NLU. + * + * @param builder The builder to use for everything but NLU. + * @param nluManager The NLU manager to inject. + */ + Spokestack(Builder builder, NLUManager nluManager) throws Exception { + this.listeners = new ArrayList<>(); + this.listeners.addAll(builder.listeners); + this.autoClassify = builder.autoClassify; + this.transcriptEditor = builder.transcriptEditor; + if (builder.useAsr) { + this.speechPipeline = builder.getPipelineBuilder() + .addOnSpeechEventListener(this) + .build(); + } + if (builder.useNLU) { + this.nlu = nluManager; + } + if (builder.useTTS) { + if (!builder.useTTSPlayback) { + builder.ttsBuilder.setOutputClass(null); + } + this.tts = builder.getTtsBuilder() + .addTTSListener(this) + .build(); + } + } + // speech pipeline /** @@ -126,14 +154,18 @@ public SpeechPipeline getSpeechPipeline() { * pipeline. */ public void start() throws Exception { - this.speechPipeline.start(); + if (this.speechPipeline != null) { + this.speechPipeline.start(); + } } /** * Stops the speech pipeline and releases all its internal resources. */ public void stop() { - this.speechPipeline.stop(); + if (this.speechPipeline != null) { + this.speechPipeline.stop(); + } } /** @@ -142,7 +174,9 @@ public void stop() { * conjunction with a microphone button. */ public void activate() { - this.speechPipeline.activate(); + if (this.speechPipeline != null) { + this.speechPipeline.activate(); + } } /** @@ -156,15 +190,17 @@ public void activate() { *

*/ public void deactivate() { - this.speechPipeline.deactivate(); + if (this.speechPipeline != null) { + this.speechPipeline.deactivate(); + } } // NLU /** - * @return The NLU service currently in use. + * @return The NLU manager currently in use. */ - public NLUService getNlu() { + public NLUManager getNlu() { return nlu; } @@ -184,7 +220,10 @@ public NLUService getNlu() { * classification. */ public AsyncResult classify(String utterance) { - return classifyInternal(utterance); + if (this.nlu != null) { + return classifyInternal(utterance); + } + return null; } // TTS @@ -205,7 +244,9 @@ public TTSManager getTts() { * @throws Exception If there is an error constructing TTS components. */ public void prepareTts() throws Exception { - this.tts.prepare(); + if (this.tts != null) { + this.tts.prepare(); + } } /** @@ -219,7 +260,9 @@ public void prepareTts() throws Exception { *

*/ public void releaseTts() { - this.tts.release(); + if (this.tts != null) { + this.tts.release(); + } } /** @@ -229,14 +272,18 @@ public void releaseTts() { * @param request The synthesis request data. */ public void synthesize(SynthesisRequest request) { - this.tts.synthesize(request); + if (this.tts != null) { + this.tts.synthesize(request); + } } /** * Stops playback of any playing or queued synthesis results. */ public void stopPlayback() { - this.tts.stopPlayback(); + if (this.tts != null) { + this.tts.stopPlayback(); + } } // listeners @@ -294,6 +341,11 @@ public void onEvent(@NotNull SpeechContext.Event event, } } + @Override + public void nluResult(@NotNull NLUResult result) { + super.nluResult(result); + } + private AsyncResult classifyInternal(String text) { AsyncResult result = this.nlu.classify(text); @@ -322,7 +374,7 @@ public void close() { */ public static class Builder { private final SpeechPipeline.Builder pipelineBuilder; - private final TensorflowNLU.Builder nluBuilder; + private final NLUManager.Builder nluBuilder; private final TTSManager.Builder ttsBuilder; private final List listeners = new ArrayList<>(); @@ -434,7 +486,7 @@ public Builder() { .setConfig(this.speechConfig) .useProfile(profileClass); this.nluBuilder = - new TensorflowNLU.Builder().setConfig(this.speechConfig); + new NLUManager.Builder().setConfig(this.speechConfig); String ttsServiceClass = "io.spokestack.spokestack.tts.SpokestackTTSService"; String ttsOutputClass = @@ -461,14 +513,12 @@ private void setDefaults(SpeechConfig config) { * for testing. * * @param pipeline the speech pipeline builder - * @param nlu the NLU builder * @param tts the TTS builder */ - Builder(SpeechPipeline.Builder pipeline, TensorflowNLU.Builder nlu, - TTSManager.Builder tts) { + Builder(SpeechPipeline.Builder pipeline, TTSManager.Builder tts) { this.speechConfig = new SpeechConfig(); this.pipelineBuilder = pipeline; - this.nluBuilder = nlu; + this.nluBuilder = new NLUManager.Builder(); this.ttsBuilder = tts; } @@ -482,7 +532,7 @@ public SpeechPipeline.Builder getPipelineBuilder() { /** * @return The builder used to configure the NLU subsystem. */ - public TensorflowNLU.Builder getNluBuilder() { + public NLUManager.Builder getNluBuilder() { return nluBuilder; } diff --git a/src/main/java/io/spokestack/spokestack/nlu/NLUManager.java b/src/main/java/io/spokestack/spokestack/nlu/NLUManager.java new file mode 100644 index 0000000..b765233 --- /dev/null +++ b/src/main/java/io/spokestack/spokestack/nlu/NLUManager.java @@ -0,0 +1,157 @@ +package io.spokestack.spokestack.nlu; + +import io.spokestack.spokestack.SpeechConfig; +import io.spokestack.spokestack.nlu.tensorflow.TensorflowNLU; +import io.spokestack.spokestack.util.AsyncResult; +import io.spokestack.spokestack.util.EventTracer; +import io.spokestack.spokestack.util.TraceListener; + +import java.util.ArrayList; +import java.util.List; + +/** + * Manager for natural language understanding (NLU) components in Spokestack. + * + *

+ * Spokestack's NLU manager follows the same setup pattern as its {@link + * io.spokestack.spokestack.SpeechPipeline} and {@link io.spokestack.spokestack.tts.TTSManager} + * modules. The manager constructs the component ultimately responsible for + * classification (an {@link NLUService}) and manages the context required to + * perform these classifications and dispatch events to registered listeners. + *

+ */ +public final class NLUManager { + private final NLUService nlu; + private final NLUContext context; + + /** + * Constructs a new {@code NLUManager} with an initialized NLU service. + * + * @param builder builder with configuration parameters + * @throws Exception If there is an error constructing the service. + */ + private NLUManager(Builder builder) throws Exception { + this.context = builder.context; + this.nlu = buildService(builder); + } + + private NLUService buildService(Builder builder) throws Exception { + Object constructed = Class + .forName(builder.serviceClass) + .getConstructor(SpeechConfig.class, NLUContext.class) + .newInstance(builder.config, builder.context); + return (NLUService) constructed; + } + + /** + * Classify a user utterance, returning a wrapper that can either block + * until the classification is complete or call a registered callback when + * the result is ready. + * + * @param utterance The utterance to classify. + * @return An object representing the result of the asynchronous + * classification. + */ + public AsyncResult classify(String utterance) { + return this.nlu.classify(utterance, this.context); + } + + /** + * Add a new listener to receive trace events from the NLU subsystem. + * + * @param listener The listener to add. + */ + public void addListener(TraceListener listener) { + this.context.addTraceListener(listener); + } + + /** + * Remove a trace listener, allowing it to be garbage collected. + * + * @param listener The listener to remove. + */ + public void removeListener(TraceListener listener) { + this.context.removeTraceListener(listener); + } + + /** + * Fluent builder interface for initializing an NLU manager. + */ + public static class Builder { + private final List traceListeners = new ArrayList<>(); + private NLUContext context; + private SpeechConfig config = new SpeechConfig(); + private String serviceClass; + + /** + * Creates a new builder instance. + */ + public Builder() { + config.put("trace-level", EventTracer.Level.ERROR.value()); + this.serviceClass = TensorflowNLU.class.getCanonicalName(); + } + + /** + * Sets the name of the NLU service class to be used. + * + * @param className The name of the NLU service class to be used. + * @return this + */ + public Builder setServiceClass(String className) { + this.serviceClass = className; + return this; + } + + /** + * Attaches a configuration object, overwriting any existing + * configuration. + * + * @param value configuration to attach + * @return this + */ + public Builder setConfig(SpeechConfig value) { + this.config = value; + return this; + } + + /** + * Sets a configuration value. + * + * @param key configuration property name + * @param value property value + * @return this + */ + public Builder setProperty(String key, Object value) { + config.put(key, value); + return this; + } + + /** + * Adds a trace listener to receive events from the NLU system. + * + * @param listener the listener to register + * @return this + */ + public Builder addTraceListener(TraceListener listener) { + this.traceListeners.add(listener); + return this; + } + + /** + * Create a new NLU service, automatically loading any necessary + * resources in the background. Any errors encountered during loading + * will be reported to registered {@link TraceListener}s. + * + * @return An initialized {@code NLUManager} instance + * @throws Exception If there is an error constructing the NLU service. + */ + public NLUManager build() throws Exception { + this.context = new NLUContext(this.config); + for (TraceListener listener : this.traceListeners) { + this.context.addTraceListener(listener); + } + return new NLUManager(this); + } + + } +} diff --git a/src/main/java/io/spokestack/spokestack/nlu/NLUService.java b/src/main/java/io/spokestack/spokestack/nlu/NLUService.java index a3da21e..78afb6f 100644 --- a/src/main/java/io/spokestack/spokestack/nlu/NLUService.java +++ b/src/main/java/io/spokestack/spokestack/nlu/NLUService.java @@ -5,6 +5,12 @@ /** * A simple interface for components that provide intent classification and slot * recognition, either on-device or via a network request. + * + *

+ * To participate in Spokestack's {@link NLUManager}, an NLUService must have a + * constructor that accepts instances of {@link io.spokestack.spokestack.SpeechConfig} + * and {@link NLUContext}. + *

*/ public interface NLUService { diff --git a/src/main/java/io/spokestack/spokestack/nlu/tensorflow/TensorflowNLU.java b/src/main/java/io/spokestack/spokestack/nlu/tensorflow/TensorflowNLU.java index 13a11dc..7187094 100644 --- a/src/main/java/io/spokestack/spokestack/nlu/tensorflow/TensorflowNLU.java +++ b/src/main/java/io/spokestack/spokestack/nlu/tensorflow/TensorflowNLU.java @@ -8,7 +8,6 @@ import io.spokestack.spokestack.nlu.NLUResult; import io.spokestack.spokestack.nlu.NLUService; import io.spokestack.spokestack.nlu.Slot; -import io.spokestack.spokestack.util.TraceListener; import io.spokestack.spokestack.nlu.tensorflow.parsers.DigitsParser; import io.spokestack.spokestack.nlu.tensorflow.parsers.IdentityParser; import io.spokestack.spokestack.nlu.tensorflow.parsers.IntegerParser; @@ -16,6 +15,7 @@ import io.spokestack.spokestack.tensorflow.TensorflowModel; import io.spokestack.spokestack.util.AsyncResult; import io.spokestack.spokestack.util.EventTracer; +import io.spokestack.spokestack.util.TraceListener; import io.spokestack.spokestack.util.Tuple; import java.io.FileReader; @@ -55,19 +55,25 @@ * wordpiece-vocab-path (string, required): file system path to the * wordpiece vocabulary file used by the wordpiece token encoder. * + *
  • + * slot-<slotType> (string, optional): class name of a slot + * parser capable of parsing slots with the {@code slotType} type. For + * example, a custom slot parser used to parse slots listed as {@code user} + * in the NLU metadata should be provided under the key {@code slot-user}. + *
  • * */ public final class TensorflowNLU implements NLUService { private final ExecutorService executor = Executors.newSingleThreadExecutor(); - private final TextEncoder textEncoder; private final NLUContext context; - private final int sepTokenId; - private final int padTokenId; - private final Thread loadThread; - private TensorflowModel nluModel = null; - private TFNLUOutput outputParser = null; + private TextEncoder textEncoder; + private int sepTokenId; + private int padTokenId; + private Thread loadThread; + private TensorflowModel nluModel; + private TFNLUOutput outputParser; private int maxTokens; private volatile boolean ready = false; @@ -80,18 +86,72 @@ public final class TensorflowNLU implements NLUService { * @param builder builder with configuration parameters */ private TensorflowNLU(Builder builder) { - String modelPath = builder.config.getString("nlu-model-path"); - String metadataPath = builder.config.getString("nlu-metadata-path"); this.context = builder.context; - this.textEncoder = builder.textEncoder; - this.loadThread = builder.threadFactory.newThread( + SpeechConfig config = transferSlotParsers( + builder.slotParserClasses, builder.config); + load(config, + builder.textEncoder, + builder.modelLoader, + builder.threadFactory); + } + + private SpeechConfig transferSlotParsers(Map parserClasses, + SpeechConfig config) { + for (Map.Entry parser : parserClasses.entrySet()) { + config.put("slot-" + parser.getKey(), parser.getValue()); + } + return config; + } + + /** + * Public constructor for {@code NLUManager} participation. Uses a {@link + * WordpieceTextEncoder} and a default TensorFlow model loader. + * + *

    + * The model and text encoder are loaded on a background thread. + *

    + * + * @param speechConfig configuration properties + * @param nluContext The context used to register listeners and deliver + * trace and error events. + */ + public TensorflowNLU(SpeechConfig speechConfig, NLUContext nluContext) { + this.context = nluContext; + load(speechConfig, + new WordpieceTextEncoder(speechConfig, this.context), + new TensorflowModel.Loader(), + Thread::new); + } + + private void load(SpeechConfig config, + TextEncoder encoder, + TensorflowModel.Loader loader, + ThreadFactory threadFactory) { + String modelPath = config.getString("nlu-model-path"); + String metadataPath = config.getString("nlu-metadata-path"); + Map slotParsers = getSlotParsers(config); + this.textEncoder = encoder; + this.loadThread = threadFactory.newThread( () -> { - loadModel(builder.modelLoader, metadataPath, modelPath); - initParsers(builder.slotParserClasses); + loadModel(loader, metadataPath, modelPath); + initParsers(slotParsers); }); this.loadThread.start(); - this.padTokenId = this.textEncoder.encodeSingle("[PAD]"); - this.sepTokenId = this.textEncoder.encodeSingle("[SEP]"); + + this.padTokenId = encoder.encodeSingle("[PAD]"); + this.sepTokenId = encoder.encodeSingle("[SEP]"); + } + + private Map getSlotParsers(SpeechConfig config) { + HashMap slotParsers = new HashMap<>(); + + for (Map.Entry prop : config.getParams().entrySet()) { + if (prop.getKey().startsWith("slot-")) { + String slotType = prop.getKey().replace("slot-", ""); + slotParsers.put(slotType, String.valueOf(prop.getValue())); + } + } + return slotParsers; } private void initParsers(Map parserClasses) { @@ -104,6 +164,7 @@ private void initParsers(Map parserClasses) { .newInstance(); slotParsers.put(slotType, parser); this.outputParser.registerSlotParsers(slotParsers); + this.ready = true; } catch (Exception e) { this.context.traceError("Error loading slot parsers: %s", e.getLocalizedMessage()); @@ -126,7 +187,6 @@ private void loadModel(TensorflowModel.Loader loader, / this.nluModel.getInputSize(); this.outputParser = new TFNLUOutput(metadata); warmup(); - this.ready = true; } catch (IOException e) { this.context.traceError("Error loading NLU model: %s", e.getLocalizedMessage()); @@ -261,6 +321,7 @@ private int[] pad(List ids) { /** * Add a new listener to receive trace events from the NLU subsystem. + * * @param listener The listener to add. */ public void addListener(TraceListener listener) { @@ -269,6 +330,7 @@ public void addListener(TraceListener listener) { /** * Remove a trace listener, allowing it to be garbage collected. + * * @param listener The listener to remove. */ public void removeListener(TraceListener listener) { @@ -328,7 +390,7 @@ public Builder setConfig(SpeechConfig value) { * @param loader The TensorFlow model loader to use. * @return this */ - public Builder setModelLoader(TensorflowModel.Loader loader) { + Builder setModelLoader(TensorflowModel.Loader loader) { this.modelLoader = loader; return this; } @@ -339,7 +401,7 @@ public Builder setModelLoader(TensorflowModel.Loader loader) { * @param encoder The text encoder to use. * @return this */ - public Builder setTextEncoder(TextEncoder encoder) { + Builder setTextEncoder(TextEncoder encoder) { this.textEncoder = encoder; return this; } diff --git a/src/main/java/io/spokestack/spokestack/nlu/tensorflow/WordpieceTextEncoder.java b/src/main/java/io/spokestack/spokestack/nlu/tensorflow/WordpieceTextEncoder.java index 7015798..3596ec9 100644 --- a/src/main/java/io/spokestack/spokestack/nlu/tensorflow/WordpieceTextEncoder.java +++ b/src/main/java/io/spokestack/spokestack/nlu/tensorflow/WordpieceTextEncoder.java @@ -28,11 +28,12 @@ final class WordpieceTextEncoder implements TextEncoder { private static final String UNKNOWN = "[UNK]"; private static final String SUFFIX_MARKER = "##"; - private NLUContext context; + private final Thread loadThread; + private final NLUContext context; + private HashMap vocabulary; private volatile boolean ready = false; - private Thread loadThread; /** * Creates a new Wordpiece token encoder. @@ -56,7 +57,7 @@ final class WordpieceTextEncoder implements TextEncoder { * thread. */ WordpieceTextEncoder(SpeechConfig config, NLUContext nluContext, - ThreadFactory threadFactory) { + ThreadFactory threadFactory) { String vocabFile = config.getString("wordpiece-vocab-path"); this.context = nluContext; this.loadThread = threadFactory.newThread(() -> loadVocab(vocabFile)); diff --git a/src/test/java/io/spokestack/spokestack/SpokestackTest.java b/src/test/java/io/spokestack/spokestack/SpokestackTest.java index 07c625b..18b42ea 100644 --- a/src/test/java/io/spokestack/spokestack/SpokestackTest.java +++ b/src/test/java/io/spokestack/spokestack/SpokestackTest.java @@ -4,10 +4,9 @@ import android.os.SystemClock; import androidx.lifecycle.LifecycleOwner; import androidx.lifecycle.LifecycleRegistry; -import io.spokestack.spokestack.nlu.NLUContext; +import io.spokestack.spokestack.nlu.NLUManager; import io.spokestack.spokestack.nlu.NLUResult; import io.spokestack.spokestack.nlu.tensorflow.NLUTestUtils; -import io.spokestack.spokestack.nlu.tensorflow.TensorflowNLU; import io.spokestack.spokestack.tts.SynthesisRequest; import io.spokestack.spokestack.tts.TTSEvent; import io.spokestack.spokestack.tts.TTSManager; @@ -72,12 +71,6 @@ public void testBuild() throws Exception { .setProperty("test", "test") .build(); - // no subsystems exist to handle these calls - assertThrows(NullPointerException.class, spokestack::start); - assertThrows(NullPointerException.class, - () -> spokestack.classify("test")); - assertThrows(NullPointerException.class, spokestack::stopPlayback); - // closing with no active subsystems is fine assertDoesNotThrow(spokestack::close); } @@ -123,13 +116,13 @@ public void testSpeechPipeline() throws Exception { public void testNlu() throws Exception { TestAdapter listener = new TestAdapter(); - Spokestack spokestack = new Spokestack - .Builder(new SpeechPipeline.Builder(), mockNlu(), mockTts()) + Spokestack.Builder builder = new Spokestack + .Builder(new SpeechPipeline.Builder(), mockTts()) .addListener(listener) .withoutWakeword() - .withoutTts() - .build(); + .withoutTts(); + Spokestack spokestack = new Spokestack(builder, mockNlu()); listener.setSpokestack(spokestack); NLUResult result = spokestack.classify("test").get(); @@ -137,8 +130,7 @@ public void testNlu() throws Exception { assertNotNull(lastResult); assertEquals(result.getIntent(), lastResult.getIntent()); - NLUContext fakeContext = new NLUContext(testConfig()); - result = spokestack.getNlu().classify("test", fakeContext).get(); + result = spokestack.getNlu().classify("test").get(); assertEquals(result.getIntent(), lastResult.getIntent()); // classification is called automatically on ASR results @@ -156,12 +148,13 @@ public void testAutoClassification() throws Exception { TestAdapter listener = new TestAdapter(); Spokestack.Builder builder = new Spokestack - .Builder(new SpeechPipeline.Builder(), mockNlu(), mockTts()) + .Builder(new SpeechPipeline.Builder(), mockTts()) .withoutWakeword() .withoutAutoClassification() .addListener(listener); - Spokestack spokestack = mockAndroidComponents(builder).build(); + builder = mockAndroidComponents(builder); + Spokestack spokestack = new Spokestack(builder, mockNlu()); // automatic classification can be disabled listener.clear(); @@ -178,12 +171,13 @@ public void testTranscriptEditing() throws Exception { TestAdapter listener = new TestAdapter(); Spokestack.Builder builder = new Spokestack - .Builder(new SpeechPipeline.Builder(), mockNlu(), mockTts()) + .Builder(new SpeechPipeline.Builder(), mockTts()) .withoutWakeword() .addListener(listener) .withTranscriptEditor(String::toUpperCase); - Spokestack spokestack = mockAndroidComponents(builder).build(); + builder = mockAndroidComponents(builder); + Spokestack spokestack = new Spokestack(builder, mockNlu()); // transcripts can be edited before automatic classification String transcript = "test"; @@ -191,7 +185,7 @@ public void testTranscriptEditing() throws Exception { SpeechContext context = spokestack.getSpeechPipeline().getContext(); context.setTranscript(transcript); context.dispatch(SpeechContext.Event.RECOGNIZE); - NLUResult result = spokestack.classify("test").get(); + NLUResult result = spokestack.classify("TEST").get(); NLUResult lastResult = listener.nluResults.poll(1, TimeUnit.SECONDS); assertNotNull(lastResult); assertEquals(result.getIntent(), lastResult.getIntent()); @@ -211,11 +205,12 @@ public void testTts() throws Exception { SpeechPipeline.Builder pipelineBuilder = new SpeechPipeline.Builder(); Spokestack.Builder builder = new Spokestack - .Builder(pipelineBuilder, mockNlu(), mockTts()) + .Builder(pipelineBuilder, mockTts()) .withoutWakeword() .addListener(listener); - Spokestack spokestack = mockAndroidComponents(builder).build(); + builder = mockAndroidComponents(builder); + Spokestack spokestack = new Spokestack(builder, mockNlu()); listener.setSpokestack(spokestack); TTSManager tts = spokestack.getTts(); @@ -253,7 +248,7 @@ public void testListenerManagement() throws Exception { TestAdapter listener2 = new TestAdapter(); Spokestack.Builder builder = new Spokestack - .Builder(new SpeechPipeline.Builder(), mockNlu(), mockTts()) + .Builder(new SpeechPipeline.Builder(), mockTts()) .withoutWakeword() .setConfig(testConfig()) .setProperty("trace-level", EventTracer.Level.INFO.value()) @@ -261,7 +256,7 @@ public void testListenerManagement() throws Exception { builder = mockAndroidComponents(builder); builder.getPipelineBuilder().setStageClasses(new ArrayList<>()); - Spokestack spokestack = builder.build(); + Spokestack spokestack = new Spokestack(builder, mockNlu()); spokestack.addListener(listener2); spokestack.getSpeechPipeline().activate(); @@ -295,11 +290,8 @@ public void testListenerManagement() throws Exception { assertEquals(error, err); } - private TensorflowNLU.Builder mockNlu() throws Exception { - mockStatic(SystemClock.class); - - NLUTestUtils.TestEnv nluEnv = new NLUTestUtils.TestEnv(); - return nluEnv.nluBuilder; + private NLUManager mockNlu() throws Exception { + return NLUTestUtils.mockManager(); } private TTSManager.Builder mockTts() { diff --git a/src/test/java/io/spokestack/spokestack/nlu/tensorflow/NLUTestUtils.java b/src/test/java/io/spokestack/spokestack/nlu/tensorflow/NLUTestUtils.java index 1bdc103..04631e9 100644 --- a/src/test/java/io/spokestack/spokestack/nlu/tensorflow/NLUTestUtils.java +++ b/src/test/java/io/spokestack/spokestack/nlu/tensorflow/NLUTestUtils.java @@ -1,11 +1,14 @@ package io.spokestack.spokestack.nlu.tensorflow; -import android.os.SystemClock; import com.google.gson.Gson; import com.google.gson.stream.JsonReader; import io.spokestack.spokestack.SpeechConfig; +import io.spokestack.spokestack.nlu.NLUContext; +import io.spokestack.spokestack.nlu.NLUManager; import io.spokestack.spokestack.nlu.NLUResult; +import io.spokestack.spokestack.nlu.NLUService; import io.spokestack.spokestack.tensorflow.TensorflowModel; +import io.spokestack.spokestack.util.AsyncResult; import java.io.FileNotFoundException; import java.io.FileReader; @@ -16,7 +19,6 @@ import java.util.concurrent.Future; import static org.mockito.Mockito.*; -import static org.powermock.api.mockito.PowerMockito.mockStatic; public class NLUTestUtils { @@ -27,6 +29,13 @@ public static SpeechConfig testConfig() { .put("wordpiece-vocab-path", "src/test/resources/vocab.txt"); } + public static NLUManager mockManager() throws Exception { + return new NLUManager.Builder() + .setServiceClass(NLUTestUtils.class.getCanonicalName() + + "$MockNLU") + .build(); + } + public static class TestModel extends TensorflowModel { public TestModel(TensorflowModel.Loader loader) { super(loader); @@ -134,4 +143,25 @@ public EncodedTokens encode(String text) { return encoded; } } + + public static class MockNLU implements NLUService { + + public MockNLU(SpeechConfig config, NLUContext context) { + // empty constructor so it can be built by the manager + } + + @Override + public AsyncResult classify( + String utterance, + NLUContext context + ) { + AsyncResult result = new AsyncResult<>(() -> + new NLUResult.Builder(utterance) + .withIntent(utterance) + .build()); + result.run(); + return result; + } + + } }