From b572dd840d7a750cbfba68fa8e3ec7a636c53ff8 Mon Sep 17 00:00:00 2001 From: Josh Ziegler Date: Wed, 21 Apr 2021 12:19:10 -0400 Subject: [PATCH] Rasa dialogue policy tests --- .../spokestack/rasa/RasaDialoguePolicy.java | 32 +++-- .../rasa/RasaDialoguePolicyTest.java | 115 ++++++++++++++++++ 2 files changed, 139 insertions(+), 8 deletions(-) create mode 100644 src/test/java/io/spokestack/spokestack/rasa/RasaDialoguePolicyTest.java diff --git a/src/main/java/io/spokestack/spokestack/rasa/RasaDialoguePolicy.java b/src/main/java/io/spokestack/spokestack/rasa/RasaDialoguePolicy.java index 4fc2d87..82d2e5a 100644 --- a/src/main/java/io/spokestack/spokestack/rasa/RasaDialoguePolicy.java +++ b/src/main/java/io/spokestack/spokestack/rasa/RasaDialoguePolicy.java @@ -1,6 +1,7 @@ package io.spokestack.spokestack.rasa; import com.google.gson.Gson; +import com.google.gson.JsonSyntaxException; import com.google.gson.reflect.TypeToken; import io.spokestack.spokestack.SpeechConfig; import io.spokestack.spokestack.dialogue.ConversationData; @@ -66,29 +67,44 @@ public void handleTurn( String intent = userTurn.getIntent(); if (!intent.equals(RasaOpenSourceNLU.RASA_INTENT)) { // we can't handle non-Rasa intents - dispatchError(eventDispatcher, intent); + dispatchError(eventDispatcher, "non-Rasa intent: " + intent); } - List responses = getResponses(userTurn); + List responses = getResponses(userTurn, eventDispatcher); for (RasaResponse response : responses) { - dispatchResponse(eventDispatcher, response); + // guard against trailing commas in the json + if (response != null) { + dispatchResponse(eventDispatcher, response); + } } } - private void dispatchError(DialogueDispatcher dispatcher, String intent) { + private void dispatchError(DialogueDispatcher dispatcher, String msg) { ConversationState state = new ConversationState.Builder() - .withError("non-Rasa intent: " + intent) + .withError(msg) .build(); DialogueEvent event = new DialogueEvent(DialogueEvent.Type.ERROR, state); dispatcher.dispatch(event); } - private List getResponses(NLUResult userTurn) { + private List getResponses(NLUResult userTurn, + DialogueDispatcher dispatcher) { Object response = userTurn.getContext() .get(RasaOpenSourceNLU.RESPONSE_KEY); - String json = response != null ? response.toString() : "{}"; - return this.gson.fromJson(json, RasaResponse.TYPE); + String json = String.valueOf(response); + List responses = null; + try { + responses = this.gson.fromJson(json, RasaResponse.TYPE); + } catch (JsonSyntaxException e) { + // let the null check below handle the error + } + + if (responses == null) { + dispatchError(dispatcher, "invalid server response: " + json); + return new ArrayList<>(); + } + return responses; } private void dispatchResponse(DialogueDispatcher dispatcher, diff --git a/src/test/java/io/spokestack/spokestack/rasa/RasaDialoguePolicyTest.java b/src/test/java/io/spokestack/spokestack/rasa/RasaDialoguePolicyTest.java new file mode 100644 index 0000000..66ac24a --- /dev/null +++ b/src/test/java/io/spokestack/spokestack/rasa/RasaDialoguePolicyTest.java @@ -0,0 +1,115 @@ +package io.spokestack.spokestack.rasa; + +import io.spokestack.spokestack.SpeechConfig; +import io.spokestack.spokestack.dialogue.ConversationData; +import io.spokestack.spokestack.dialogue.DialogueDispatcher; +import io.spokestack.spokestack.dialogue.DialogueEvent; +import io.spokestack.spokestack.dialogue.DialogueListener; +import io.spokestack.spokestack.dialogue.InMemoryConversationData; +import io.spokestack.spokestack.dialogue.Prompt; +import io.spokestack.spokestack.nlu.NLUResult; +import io.spokestack.spokestack.util.EventTracer; +import junit.framework.TestListener; +import org.jetbrains.annotations.NotNull; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.*; + +public class RasaDialoguePolicyTest { + + @Test + public void unusedMethods() { + RasaDialoguePolicy policy = new RasaDialoguePolicy(testConfig()); + ConversationData dataStore = new InMemoryConversationData(); + dataStore.set("key", "val"); + assertEquals("", policy.dump(dataStore)); + policy.load("state", dataStore); + } + + @Test + public void testEvents() throws InterruptedException { + RasaDialoguePolicy policy = new RasaDialoguePolicy(testConfig()); + ConversationData dataStore = new InMemoryConversationData(); + TestListener listener = new TestListener(); + DialogueDispatcher dispatcher = testDispatcher(listener); + // null response throws an error + String response = null; + NLUResult result = rasaResult(response); + policy.handleTurn(result, dataStore, dispatcher); + DialogueEvent event = listener.events.poll(500, TimeUnit.MILLISECONDS); + assertNotNull(event); + assertEquals(DialogueEvent.Type.ERROR, event.type); + + // empty response throws an error + response = "{}"; + result = rasaResult(response); + policy.handleTurn(result, dataStore, dispatcher); + event = listener.events.poll(500, TimeUnit.MILLISECONDS); + assertNotNull(event); + assertEquals(DialogueEvent.Type.ERROR, event.type); + + String prompt = "hi"; + String imageURL = "https://example.com"; + response = "[" + + "{\"recipient_id\": \"id\", \"text\": \"" + prompt + "\"}," + + "{\"recipient_id\": \"id\", \"image\": \"" + imageURL + "\"}," + + "]"; + result = rasaResult(response); + policy.handleTurn(result, dataStore, dispatcher); + List events = new ArrayList<>(); + listener.events.drainTo(events); + event = events.get(0); + assertEquals(DialogueEvent.Type.PROMPT, event.type); + String receivedPrompt = event.state.getPrompt().getText(dataStore); + assertEquals(prompt, receivedPrompt); + event = events.get(1); + assertEquals(DialogueEvent.Type.ACTION, event.type); + String receivedURL = event.state.getPrompt().getText(dataStore); + assertEquals(imageURL, receivedURL); + } + + private SpeechConfig testConfig() { + SpeechConfig config = new SpeechConfig(); + config.put("sample-rate", 16000); + config.put("frame-width", 20); + config.put("buffer-width", 300); + return config; + } + + private DialogueDispatcher testDispatcher(DialogueListener listener) { + int level = EventTracer.Level.INFO.value(); + List listeners = new ArrayList<>(); + listeners.add(listener); + return new DialogueDispatcher(level, listeners); + } + + private NLUResult rasaResult(String response) { + HashMap rasaContext = new HashMap<>(); + rasaContext.put(RasaOpenSourceNLU.RESPONSE_KEY, response); + return new NLUResult.Builder("test utterance") + .withIntent(RasaOpenSourceNLU.RASA_INTENT) + .withContext(rasaContext) + .build(); + } + + static class TestListener implements DialogueListener { + LinkedBlockingQueue events = new LinkedBlockingQueue<>(); + + @Override + public void onDialogueEvent(@NotNull DialogueEvent event) { + events.add(event); + } + + @Override + public void onTrace(@NotNull EventTracer.Level level, + @NotNull String message) { + // no-op + } + } +} \ No newline at end of file