From ed37b30e18d360b950c198662c1972de578af20b Mon Sep 17 00:00:00 2001 From: Giovanni Campagna Date: Mon, 11 Jan 2021 17:39:04 -0800 Subject: [PATCH 1/6] Fix `genie server` --- tool/server.ts | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tool/server.ts b/tool/server.ts index 2e5b7dd90..59fbc18b4 100644 --- a/tool/server.ts +++ b/tool/server.ts @@ -24,7 +24,7 @@ import bodyParser from 'body-parser'; // FIXME //import logger from 'morgan'; import errorhandler from 'errorhandler'; -import qv from 'query-validation'; +import * as qv from 'query-validation'; import * as Tp from 'thingpedia'; import * as ThingTalk from 'thingtalk'; @@ -106,13 +106,8 @@ const QUERY_PARAMS = { async function queryNLU(params : Record, data : QueryNLUData, res : express.Response) { - const thingtalk_version = data.thingtalk_version; const app = res.app; - if (thingtalk_version !== ThingTalk.version) { - res.status(400).json({ error: 'Invalid ThingTalk version' }); - return; - } if (params.locale !== app.args.locale) { res.status(400).json({ error: 'Unsupported language' }); return; From d50bbb03161e3b0afd97f4b11ecaf33ead571ff4 Mon Sep 17 00:00:00 2001 From: Giovanni Campagna Date: Mon, 11 Jan 2021 17:52:21 -0800 Subject: [PATCH 2/6] Predictor: add support for confidence scores Thanks to MC dropout and calibration, genienlp can now output confidence scores also when not doing beam search --- lib/prediction/localparserclient.ts | 55 ++++++++++++++++------------- lib/prediction/predictor.ts | 38 ++++++++++---------- lib/prediction/types.ts | 11 ++++++ tool/server.ts | 16 +-------- 4 files changed, 61 insertions(+), 59 deletions(-) diff --git a/lib/prediction/localparserclient.ts b/lib/prediction/localparserclient.ts index 59ec2d12a..14692c59e 100644 --- a/lib/prediction/localparserclient.ts +++ b/lib/prediction/localparserclient.ts @@ -18,6 +18,7 @@ // // Author: Giovanni Campagna +import assert from 'assert'; import * as Tp from 'thingpedia'; import * as ThingTalk from 'thingtalk'; @@ -179,6 +180,11 @@ export default class LocalParserClient { let result : PredictionCandidate[]|null = null; let exact : string[][]|null = null; + const intent = { + command: 1, + other: 0, + ignore: 0 + }; if (tokens.length === 0) { result = [{ @@ -206,27 +212,26 @@ export default class LocalParserClient { } if (result === null) { - if (options.expect === 'Location') { - result = [{ - code: ['$answer', '(', 'new', 'Location', '(', '"', ...tokens, '"', ')', ')', ';'], - score: 1 - }]; - } else { - if (contextCode) - contextCode = this._applyPreHeuristics(contextCode); + if (contextCode) + contextCode = this._applyPreHeuristics(contextCode); - let candidates; - if (contextCode) - candidates = await this._predictor.predict(contextCode.join(' '), tokens.join(' '), answer, NLU_TASK, options.example_id); - else - candidates = await this._predictor.predict(tokens.join(' '), undefined, answer, SEMANTIC_PARSING_TASK, options.example_id); - result = candidates.map((c) => { - return { - code: c.answer.split(' '), - score: c.score - }; - }); - } + let candidates; + if (contextCode) + candidates = await this._predictor.predict(contextCode.join(' '), tokens.join(' '), answer, NLU_TASK, options.example_id); + else + candidates = await this._predictor.predict(tokens.join(' '), undefined, answer, SEMANTIC_PARSING_TASK, options.example_id); + assert(candidates.length > 0); + + result = candidates.map((c) => { + return { + code: c.answer.split(' '), + score: c.score.confidence ?? 1 + }; + }); + + intent.ignore = candidates[0].score.ignore ?? 0; + intent.command = (candidates[0].score.in_domain ?? 1) * (1 - intent.ignore); + intent.other = 1 - intent.ignore - intent.command; } let result2 = result!; // guaranteed not null @@ -263,18 +268,18 @@ export default class LocalParserClient { result: 'ok', tokens: tokens, candidates: result2, - entities: entities + entities: entities, + intent }; } async generateUtterance(contextCode : string[], contextEntities : EntityMap, targetAct : string[]) : Promise { - let candidates = await this._predictor.predict(contextCode.join(' ') + ' ' + targetAct.join(' '), NLG_QUESTION, undefined, NLG_TASK); - candidates = candidates.map((cand) => { + const candidates = await this._predictor.predict(contextCode.join(' ') + ' ' + targetAct.join(' '), NLG_QUESTION, undefined, NLG_TASK); + return candidates.map((cand) => { return { answer: this._langPack.postprocessNLG(cand.answer, contextEntities), - score: cand.score + score: cand.score.confidence ?? 1 }; }); - return candidates; } } diff --git a/lib/prediction/predictor.ts b/lib/prediction/predictor.ts index c30c155ca..ff5d71dc0 100644 --- a/lib/prediction/predictor.ts +++ b/lib/prediction/predictor.ts @@ -27,13 +27,13 @@ import JsonDatagramSocket from '../utils/json_datagram_socket'; const DEFAULT_QUESTION = 'translate from english to thingtalk'; -interface PredictionCandidate { +export interface RawPredictionCandidate { answer : string; - score : number; + score : Record; } interface Request { - resolve(data : PredictionCandidate[][]) : void; + resolve(data : RawPredictionCandidate[][]) : void; reject(err : Error) : void; } @@ -46,7 +46,7 @@ interface Example { answer ?: string; example_id ?: string; - resolve(data : PredictionCandidate[]) : void; + resolve(data : RawPredictionCandidate[]) : void; reject(err : Error) : void; } @@ -114,14 +114,16 @@ class LocalWorker extends events.EventEmitter { if (msg.error) { req.reject(new Error(msg.error)); } else { - req.resolve(msg.instances.map((instance : any) : PredictionCandidate[] => { + req.resolve(msg.instances.map((instance : any) : RawPredictionCandidate[] => { if (instance.candidates) { return instance.candidates; } else { - // no beam search, hence only one candidate, and fixed score + // no beam search, hence only one candidate + // the score might present or not, depending on whether + // we calibrate or not return [{ answer: instance.answer, - score: 1 + score: instance.score || {} }]; } })); @@ -136,7 +138,7 @@ class LocalWorker extends events.EventEmitter { this._requests.clear(); } - request(task : string, minibatch : Example[]) : Promise { + request(task : string, minibatch : Example[]) : Promise { const id = this._nextId ++; return new Promise((resolve, reject) => { @@ -164,23 +166,21 @@ class RemoteWorker extends events.EventEmitter { start() {} stop() {} - async request(task : string, minibatch : Example[]) : Promise { + async request(task : string, minibatch : Example[]) : Promise { const response = await Tp.Helpers.Http.post(this._url, JSON.stringify({ - id: 0, // should be ignored task, instances: minibatch }), { dataContentType: 'application/json', accept: 'application/json' }); - const parsed = JSON.parse(response); - // TODO: this needs to be updated when genienlp kfserver is fixed to avoid - // double wrapping in JSON - return JSON.parse(parsed.predictions).instances.map((instance : any) : PredictionCandidate[] => { + return JSON.parse(response).predictions.map((instance : any) : RawPredictionCandidate[] => { if (instance.candidates) { return instance.candidates; } else { - // no beam search, hence only one candidate, and fixed score + // no beam search, hence only one candidate + // the score might present or not, depending on whether + // we calibrate or not return [{ answer: instance.answer, - score: 1 + score: instance.score || {} }]; } }); @@ -254,7 +254,7 @@ export default class Predictor { } } - predict(context : string, question = DEFAULT_QUESTION, answer ?: string, task = 'almond', example_id ?: string) : Promise { + predict(context : string, question = DEFAULT_QUESTION, answer ?: string, task = 'almond', example_id ?: string) : Promise { assert(typeof context === 'string'); assert(typeof question === 'string'); @@ -262,9 +262,9 @@ export default class Predictor { if (!this._worker) this.start(); - let resolve ! : (data : PredictionCandidate[]) => void, + let resolve ! : (data : RawPredictionCandidate[]) => void, reject ! : (err : Error) => void; - const promise = new Promise((_resolve, _reject) => { + const promise = new Promise((_resolve, _reject) => { resolve = _resolve; reject = _reject; }); diff --git a/lib/prediction/types.ts b/lib/prediction/types.ts index 8f13259d1..2d2c3bc9e 100644 --- a/lib/prediction/types.ts +++ b/lib/prediction/types.ts @@ -40,11 +40,22 @@ export interface PredictionCandidate { score : number|'Infinity'; } +// this type matches the NLP web API exactly, including some +// odd aspects around "intent" export interface PredictionResult { result : 'ok'; tokens : string[]; entities : EntityMap; candidates : PredictionCandidate[]; + + // the server's best guess of whether this is a command (in-domain), + // an out of domain command (could be a new function, web question, or + // chatty sentence), or should be ignored altogether + intent : { + command : number; + other : number; + ignore : number; + } } export interface GenerationResult { diff --git a/tool/server.ts b/tool/server.ts index 59fbc18b4..45865cb34 100644 --- a/tool/server.ts +++ b/tool/server.ts @@ -113,23 +113,9 @@ async function queryNLU(params : Record, return; } - // emulate the frontend classifier for API compatibility - const intent = { - question: 0, - command: 1, - chatty: 0, - other: 0 - }; - const result = await res.app.backend.nlu.sendUtterance(data.q, data.context ? data.context.split(' ') : undefined, data.entities, data); - - res.json({ - candidates: result.candidates, - tokens: result.tokens, - entities: result.entities, - intent - }); + res.json(result); } interface QueryNLGData { From 735c865789db228622cbdf54085d8ee97fcc0da8 Mon Sep 17 00:00:00 2001 From: Giovanni Campagna Date: Fri, 5 Feb 2021 14:44:12 -0800 Subject: [PATCH 3/6] Refactor the dialogue loop Move all the code responsible for converting natural language to ThingTalk and deciding what to do with the ThingTalk code to the DialogueLoop class, leaving Conversation to only take care of maintaining history and dispatching to multiple web clients for the same conversation. --- lib/dialogue-agent/conversation.ts | 264 +++---------- lib/dialogue-agent/dialogue-loop.ts | 546 +++++++++++++++++++-------- lib/dialogue-agent/dialogue_queue.ts | 7 +- lib/dialogue-agent/user-input.ts | 162 +------- lib/utils/thingtalk/index.ts | 52 +++ test/agent/index.js | 6 +- tool/interactive-annotate.ts | 41 +- 7 files changed, 507 insertions(+), 571 deletions(-) diff --git a/lib/dialogue-agent/conversation.ts b/lib/dialogue-agent/conversation.ts index b350a5851..8790caf62 100644 --- a/lib/dialogue-agent/conversation.ts +++ b/lib/dialogue-agent/conversation.ts @@ -20,14 +20,12 @@ import * as events from 'events'; -import interpolate from 'string-interp'; import type * as Tp from 'thingpedia'; import * as ThingTalk from 'thingtalk'; import * as I18n from '../i18n'; -import * as ParserClient from '../prediction/parserclient'; -import UserInput, { PlatformData } from './user-input'; +import { PlatformData } from './user-input'; import ValueCategory from './value-category'; import DialogueLoop from './dialogue-loop'; import { MessageType, Message, RDL } from './protocol'; @@ -74,19 +72,19 @@ export interface ConversationDelegate { addMessage(msg : Message) : Promise; } -interface SetContextOptions { - explicitStrings ?: boolean; -} - interface ResultLike { toLocaleString(locale ?: string) : string; } -interface PredictionCandidate { - target : UserInput; - score : number|'Infinity'; -} - +/** + * A single session of conversation in Almond. + * + * This object is responsible for maintaining the history of the conversation + * to support clients reconnecting to the same conversation later, as well + * as tracking connected clients and inactivity timeouts. + * + * The actual conversation logic is in {@link DialogueLoop}. + */ export default class Conversation extends events.EventEmitter { private _engine : Engine; private _user : AssistantUser; @@ -98,18 +96,10 @@ export default class Conversation extends events.EventEmitter { private _options : ConversationOptions; private _debug : boolean; rng : () => number; - private _prefs : Tp.Preferences; - private _nlu : ParserClient.ParserClient; - private _nlg : ParserClient.ParserClient; - - private _raw : boolean; - private _lastCommand : ParserClient.PredictionResult|null; - private _lastCandidates : PredictionCandidate[]|null; private _loop : DialogueLoop; private _expecting : ValueCategory|null; private _context : Context; - private _choices : string[]; private _delegates : Set; private _history : Message[]; @@ -137,30 +127,18 @@ export default class Conversation extends events.EventEmitter { else this._stats = stats; - this._raw = false; this._options = options; this._debug = !!this._options.debug; this.rng = options.rng || Math.random; - this._prefs = engine.platform.getSharedPreferences(); - this._nlu = ParserClient.get(this._options.nluServerUrl, engine.platform.locale, engine.platform, - undefined, engine.thingpedia); - if (this._options.nlgServerUrl) - this._nlg = ParserClient.get(this._options.nlgServerUrl, engine.platform.locale, engine.platform); - else - this._nlg = this._nlu; - this._lastCommand = null; - this._lastCandidates = null; - - this._loop = new DialogueLoop(this, this._engine, this._debug); - this._choices = []; + this._loop = new DialogueLoop(this, this._engine, { + nluServerUrl: options.nluServerUrl, + nlgServerUrl: options.nlgServerUrl, + debug: this._debug + }); this._expecting = null; - this._context = { - code: [], - entities: {} - }; - this.setContext(null); + this._context = { code: ['null'], entities: {} }; this._delegates = new Set; this._history = []; this._nextMsgId = 0; @@ -223,157 +201,18 @@ export default class Conversation extends events.EventEmitter { return this._loop.dispatchNotifyError(appId, icon, error); } - setExpected(expecting : ValueCategory|null, raw : boolean) : void { + setExpected(expecting : ValueCategory|null, context : Context) : void { this._expecting = expecting; - this._choices = []; - this._raw = raw; + this._context = context; } async start() : Promise { - await this._nlu.start(); - if (this._nlu !== this._nlg) - await this._nlg.start(); this._resetInactivityTimeout(); return this._loop.start(!!this._options.showWelcome); } async stop() : Promise { - await this._nlu.stop(); - if (this._nlu !== this._nlg) - await this._nlg.stop(); - } - - private _isUnsupportedError(e : Error) : boolean { - // FIXME there should be a better way to do this - - // 'xxx has no actions yyy' or 'xxx has no queries yyy' - // quite likely means that the NN worked but it produced a device that - // was not approved yet (otherwise the NN itself would catch the invalid function and - // skip this result) and we don't have the necessary developer key - // in that case, we reply to the user that the command is unsupported - return /(invalid kind| has no (quer(ies|y)|actions?)) /i.test(e.message); - } - - // set confident = true only if - // 1) we are not dealing with natural language (code, gui, etc), or - // 2) we find an exact match - private _doHandleCommand(intent : UserInput, - analyzed : ParserClient.PredictionResult|null, - candidates : PredictionCandidate[], - confident=false) { - this._lastCommand = analyzed; - this._lastCandidates = candidates; - return this._loop.handle(intent, confident); - } - - private _getContext(currentCommand : string|null, platformData : PlatformData) { - return { - command: currentCommand, - previousCommand: this._lastCommand, - previousCandidates: this._lastCandidates, - platformData: platformData - }; - } - - setContext(context : ThingTalk.Ast.DialogueState|null, options : SetContextOptions = {}) { - if (context === null) { - this._context = { - code: ['null'], - entities: {} - }; - } else { - const [code, entities] = ThingTalkUtils.serializeNormalized(context); - this._context = { code, entities }; - } - } - - async generateAnswer(policyPrediction : ThingTalk.Ast.DialogueState) : Promise { - const [targetAct,] = ThingTalkUtils.serializeNormalized(policyPrediction, this._context.entities); - const result = await this._nlg.generateUtterance(this._context.code, this._context.entities, targetAct); - return result[0].answer; - } - - private async _continueHandleCommand(command : string, - analyzed : ParserClient.PredictionResult, - platformData : PlatformData) : Promise { - // parse all code sequences into an Intent - // this will correctly filter out anything that does not parse - if (analyzed.candidates.length > 0) - console.log('Analyzed message into ' + analyzed.candidates[0].code.join(' ')); - else - console.log('Failed to analyze message'); - const candidates = await Promise.all(analyzed.candidates.map(async (candidate, beamposition) => { - let parsed; - try { - parsed = await UserInput.parse({ code: candidate.code, entities: analyzed.entities }, - this.thingpedia, this.schemas, this._getContext(command, platformData)); - } catch(e) { - // Likely, a type error in the ThingTalk code; not a big deal, but we still log it - console.log(`Failed to parse beam ${beamposition}: ${e.message}`); - - if (this._isUnsupportedError(e)) - parsed = new UserInput.Unsupported(platformData); - else - return null; - } - return { target: parsed, score: candidate.score }; - })).then((candidates) => candidates.filter((c : T) : c is Exclude => c !== null)); - - // here we used to do a complex heuristic dance of probabilities and confidence scores - // we do none of that, because Almond-NNParser does not give us useful scores - - if (candidates.length > 0) { - let i = 0; - let choice = candidates[i]; - while (i < candidates.length-1 && choice.target instanceof UserInput.Unsupported && choice.score === 'Infinity') { - i++; - choice = candidates[i]; - } - - this.stats.hit('sabrina-command-good'); - const confident = choice.score === 'Infinity'; - return this._doHandleCommand(choice.target, analyzed, candidates, confident); - } else { - this._lastCommand = analyzed; - this._lastCandidates = candidates; - - this.stats.hit('sabrina-failure'); - return this._loop.handle(new UserInput.Failed(command, platformData)); - } - } - - private async _errorWrap(fn : () => Promise, platformData : PlatformData) : Promise { - try { - try { - await fn(); - } catch(e) { - if (this._isUnsupportedError(e)) - await this._doHandleCommand(new UserInput.Unsupported(platformData), null, [], true); - else - throw e; - } - } catch(e) { - if (e.code === 'EHOSTUNREACH' || e.code === 'ETIMEDOUT') { - await this.sendReply('Sorry, I cannot contact the Almond service. Please check your Internet connection and try again later.', null); - } else if (typeof e.code === 'number' && (e.code === 404 || e.code >= 500)) { - await this.sendReply('Sorry, there seems to be a problem with the Almond service at the moment. Please try again later.', null); - } else { - await this.sendReply(interpolate(this._("Sorry, I had an error processing your command: ${error}"), { - error: e.message - }, { locale: this.platform.locale, timezone: this.platform.timezone })||'', null); - console.error(e); - } - await this._loop.reset(); - await this.sendAskSpecial(); - } - } - - private _sendUtterance(utterance : string) { - return this._nlu.sendUtterance(utterance, this._context.code, this._context.entities, { - expect: this._expecting ? String(this._expecting) : undefined, - choices: this._choices, - store: this._prefs.get('sabrina-store-log') as string || 'no' - }); + return this._loop.stop(); } private _resetInactivityTimeout() { @@ -417,8 +256,7 @@ export default class Conversation extends events.EventEmitter { await Promise.all(Array.from(this._delegates).map((out) => out.addMessage(msg))); } - async handleCommand(command : string, platformData : PlatformData = {}, - postprocess ?: (analysis : ParserClient.PredictionResult) => void) : Promise { + async handleCommand(command : string, platformData : PlatformData = {}) : Promise { this.stats.hit('sabrina-command'); this.emit('active'); this._resetInactivityTimeout(); @@ -426,22 +264,7 @@ export default class Conversation extends events.EventEmitter { if (this._debug) console.log('Received assistant command ' + command); - return this._errorWrap(async () => { - if (this._raw && command !== null) { - let value; - if (this._expecting === ValueCategory.Location) - value = new ThingTalk.Ast.LocationValue(new ThingTalk.Ast.UnresolvedLocation(command)); - else - value = new ThingTalk.Ast.Value.String(command); - const intent = new UserInput.Answer(value, platformData); - return this._doHandleCommand(intent, null, [], true); - } - - const analyzed = await this._sendUtterance(command); - if (postprocess) - postprocess(analyzed); - return this._continueHandleCommand(command, analyzed, platformData); - }, platformData); + return this._loop.handleCommand({ type: 'command', utterance: command, platformData }); } async handleParsedCommand(root : any, title ?: string, platformData : PlatformData = {}) : Promise { @@ -451,6 +274,7 @@ export default class Conversation extends events.EventEmitter { if (typeof root === 'string') root = JSON.parse(root); await this._addMessage({ type: MessageType.COMMAND, command: title || '\\r ' + JSON.stringify(root), json: root }); + if (this._debug) console.log('Received pre-parsed assistant command'); if (root.example_id) { @@ -459,11 +283,25 @@ export default class Conversation extends events.EventEmitter { }); } - return this._errorWrap(async () => { - const intent = await UserInput.parse(root, this.thingpedia, this.schemas, - this._getContext(null, platformData)); - return this._doHandleCommand(intent, null, [], true); - }, platformData); + if ('program' in root) + return this.handleThingTalk(root.program, platformData); + + const { code, entities } = root; + for (const name in entities) { + if (name.startsWith('SLOT_')) { + const slotname = root.slots![parseInt(name.substring('SLOT_'.length))]; + const slotType = ThingTalk.Type.fromString(root.slotTypes![slotname]); + const value = ThingTalk.Ast.Value.fromJSON(slotType, entities[name]); + entities[name] = value; + } + } + + const parsed = await ThingTalkUtils.parsePrediction(code, entities, { + thingpediaClient: this._engine.thingpedia, + schemaRetriever: this._engine.schemas, + loadMetadata: true + }, true); + return this._loop.handleCommand({ type: 'thingtalk', parsed, platformData }); } async handleThingTalk(program : string, platformData : PlatformData = {}) : Promise { @@ -474,10 +312,12 @@ export default class Conversation extends events.EventEmitter { if (this._debug) console.log('Received ThingTalk program'); - return this._errorWrap(async () => { - const intent = await UserInput.parse({ program }, this.thingpedia, this.schemas, this._getContext(null, platformData)); - return this._doHandleCommand(intent, null, [], true); - }, platformData); + const parsed = await ThingTalkUtils.parse(program, { + thingpediaClient: this._engine.thingpedia, + schemaRetriever: this._engine.schemas, + loadMetadata: true + }); + return this._loop.handleCommand({ type: 'thingtalk', parsed, platformData }); } async setHypothesis(hypothesis : string) : Promise { @@ -522,21 +362,11 @@ export default class Conversation extends events.EventEmitter { sendChoice(idx : number, title : string) { if (this._expecting !== ValueCategory.MultipleChoice) console.log('UNEXPECTED: sendChoice while not expecting a MultipleChoice'); - - this._choices[idx] = title; if (this._debug) console.log('Genie sends multiple choice button: '+ title); return this._addMessage({ type: MessageType.CHOICE, idx, title }); } - async resendChoices() { - if (this._expecting !== ValueCategory.MultipleChoice) - console.log('UNEXPECTED: sendChoice while not expecting a MultipleChoice'); - - for (let idx = 0; idx < this._choices.length; idx++) - await this._addMessage({ type: MessageType.CHOICE, idx, title: this._choices[idx] }); - } - sendButton(title : string, json : string) { if (this._debug) console.log('Genie sends generic button: '+ title); diff --git a/lib/dialogue-agent/dialogue-loop.ts b/lib/dialogue-agent/dialogue-loop.ts index f5cf0e6e5..b4b0bf192 100644 --- a/lib/dialogue-agent/dialogue-loop.ts +++ b/lib/dialogue-agent/dialogue-loop.ts @@ -23,17 +23,22 @@ import assert from 'assert'; import * as Tp from 'thingpedia'; import * as ThingTalk from 'thingtalk'; -const Ast = ThingTalk.Ast; +import { Ast } from 'thingtalk'; import interpolate from 'string-interp'; import AsyncQueue from 'consumer-queue'; import { getProgramIcon } from '../utils/icons'; -import { computePrediction, computeNewState, prepareContextForPrediction } from '../utils/thingtalk'; +import * as ThingTalkUtils from '../utils/thingtalk'; +import { EntityMap } from '../utils/entity-utils'; import type Engine from '../engine'; +import * as ParserClient from '../prediction/parserclient'; import ValueCategory from './value-category'; import QueueItem from './dialogue_queue'; -import UserInput, { PlatformData } from './user-input'; +import { + UserInput, + PlatformData +} from './user-input'; import { CancellationError } from './errors'; import * as Helpers from './helpers'; @@ -52,9 +57,45 @@ const TERMINAL_STATES = [ 'sys_end', 'sys_action_success' ]; +enum CommandAnalysisType { + // special commands + STOP, + NEVERMIND, + WAKEUP, + DEBUG, + + // some sort of command + IN_DOMAIN_COMMAND, + OUT_OF_DOMAIN_COMMAND, + PARSE_FAILURE, + + // ignore this command and do nothing + IGNORE +} + +interface CommandAnalysisResult { + type : CommandAnalysisType; + + // not null if this command was generated as a ThingTalk $answer() + // only used by legacy ask() methods + answer : Ast.Value|number|null; + + // the user target + parsed : Ast.Input|null; +} + +interface DialogueLoopOptions { + nluServerUrl : string|undefined; + nlgServerUrl : string|undefined; + debug : boolean; +} + export default class DialogueLoop { conversation : Conversation; engine : Engine; + + private _nlu : ParserClient.ParserClient; + private _nlg : ParserClient.ParserClient; private _textFormatter : TextFormatter; private _cardFormatter : CardFormatter; @@ -68,33 +109,41 @@ export default class DialogueLoop { icon : string|null; expecting : ValueCategory|null; platformData : PlatformData; + private _raw = false; + private _choices : string[]; private _dialogueState : ThingTalk.Ast.DialogueState|null; private _executorState : undefined; private _lastNotificationApp : string|undefined; + private _stopped = false; private _mgrResolve : (() => void)|null; private _mgrPromise : Promise|null; constructor(conversation : Conversation, engine : Engine, - debug : boolean) { + options : DialogueLoopOptions) { this._userInputQueue = new AsyncQueue(); this._notifyQueue = new AsyncQueue(); - this._debug = debug; + this._debug = options.debug; this.conversation = conversation; this.engine = engine; this._prefs = engine.platform.getSharedPreferences(); + this._nlu = ParserClient.get(options.nluServerUrl || undefined, engine.platform.locale, engine.platform, + undefined, engine.thingpedia); + this._nlg = ParserClient.get(options.nlgServerUrl || undefined, engine.platform.locale, engine.platform); + this._textFormatter = new TextFormatter(engine.platform.locale, engine.platform.timezone, engine.schemas); this._cardFormatter = new CardFormatter(engine.platform.locale, engine.platform.timezone, engine.schemas); this.icon = null; this.expecting = null; + this._choices = []; this.platformData = {}; this._mgrResolve = null; this._mgrPromise = null; - this._agent = new ExecutionDialogueAgent(engine, this, debug); + this._agent = new ExecutionDialogueAgent(engine, this, options.debug); this._policy = new DialoguePolicy({ thingpedia: conversation.thingpedia, schemas: conversation.schemas, @@ -131,7 +180,7 @@ export default class DialogueLoop { })||''; } - async nextIntent() : Promise { + async nextCommand() : Promise { await this.conversation.sendAskSpecial(); this._mgrPromise = null; this._mgrResolve!(); @@ -147,27 +196,143 @@ export default class DialogueLoop { } } - private async _handleUICommand(intent : UserInput.UICommand) { - switch (intent.type) { - case 'stop': - // stop means cancel, but without a failure message - throw new CancellationError(); + private _getSpecialThingTalkType(input : Ast.Input) : CommandAnalysisType { + if (input instanceof Ast.ControlCommand) { + if (input.intent instanceof Ast.SpecialControlIntent) { + switch (input.intent.type) { + case 'stop': + return CommandAnalysisType.STOP; + case 'nevermind': + return CommandAnalysisType.NEVERMIND; + case 'wakeup': + return CommandAnalysisType.WAKEUP; + case 'debug': + return CommandAnalysisType.DEBUG; + case 'failed': + return CommandAnalysisType.PARSE_FAILURE; + } + } + } - case 'nevermind': - await this.reply(this._("Sorry I couldn't help on that.")); - throw new CancellationError(); + // anything else is automatically in-domain + return CommandAnalysisType.IN_DOMAIN_COMMAND; + } - case 'debug': - await this.reply("Current State:\n" + (this._dialogueState ? this._dialogueState.prettyprint() : "null")); - break; + private _maybeGetThingTalkAnswer(input : Ast.Input) : Ast.Value|number|null { + if (input instanceof Ast.ControlCommand) { + if (input.intent instanceof Ast.SpecialControlIntent) { + switch (input.intent.type) { + case 'yes': + case 'no': + return new Ast.Value.Boolean(input.intent.type === 'yes'); + } + } else if (input.intent instanceof Ast.AnswerControlIntent + || input.intent instanceof Ast.ChoiceControlIntent) { + return input.intent.value; + } + } + return null; + } - case 'wakeup': - // nothing to do - break; + private _prepareContextForPrediction(state : Ast.DialogueState|null, forSide : 'user'|'agent') : [string[], EntityMap] { + const prepared = ThingTalkUtils.prepareContextForPrediction(state, forSide); + return ThingTalkUtils.serializeNormalized(prepared); + } - default: - await this.fail(); + private async _analyzeCommand(command : UserInput) : Promise { + if (command.type === 'thingtalk') { + const type = this._getSpecialThingTalkType(command.parsed); + return { + type, + answer: this._maybeGetThingTalkAnswer(command.parsed), + parsed: type === CommandAnalysisType.IN_DOMAIN_COMMAND ? command.parsed : null + }; + } + + // ok so this was a natural language + + if (this._raw) { + // in "raw mode", all natural language becomes an answer + let value; + if (this.expecting === ValueCategory.Location) + value = new Ast.LocationValue(new Ast.UnresolvedLocation(command.utterance)); + else + value = new Ast.Value.String(command.utterance); + return { + type: CommandAnalysisType.IN_DOMAIN_COMMAND, + answer: value, + parsed: new Ast.ControlCommand(null, new Ast.AnswerControlIntent(null, value)) + }; + } + + // alright, let's ask parser first then + let nluResult : ParserClient.PredictionResult; + try { + const [contextCode, contextEntities] = this._prepareContextForPrediction(this._dialogueState, 'user'); + + nluResult = await this._nlu.sendUtterance(command.utterance, contextCode, contextEntities, { + expect: this.expecting ? String(this.expecting) : undefined, + choices: this._choices, + store: this._prefs.get('sabrina-store-log') as string || 'no' + }); + } catch(e) { + if (e.code === 'EHOSTUNREACH' || e.code === 'ETIMEDOUT') { + await this.reply(this._("Sorry, I cannot contact the Almond service. Please check your Internet connection and try again later."), null); + throw new CancellationError(); + } else if (typeof e.code === 'number' && (e.code === 404 || e.code >= 500)) { + await this.reply(this._("Sorry, there seems to be a problem with the Almond service at the moment. Please try again later."), null); + throw new CancellationError(); + } else { + throw e; + } + } + + // parse all code sequences into an Intent + // this will correctly filter out anything that does not parse + if (nluResult.candidates.length > 0) + this.debug('Analyzed message into ' + nluResult.candidates[0].code.join(' ')); + else + this.debug('Failed to analyze message'); + const candidates = await Promise.all(nluResult.candidates.map(async (candidate, beamposition) => { + let parsed; + try { + parsed = await ThingTalkUtils.parsePrediction(candidate.code, nluResult.entities, { + thingpediaClient: this.engine.thingpedia, + schemaRetriever: this.engine.schemas, + loadMetadata: true, + }, true); + } catch(e) { + // Likely, a type error in the ThingTalk code; not a big deal, but we still log it + console.log(`Failed to parse beam ${beamposition}: ${e.message}`); + parsed = new Ast.ControlCommand(null, new Ast.SpecialControlIntent(null, 'failed')); + } + return { parsed, score: candidate.score }; + })); + // ensure that we always have at least one candidate by pushing $failed at the end + candidates.push({ parsed: new Ast.ControlCommand(null, new Ast.SpecialControlIntent(null, 'failed')), score: 0 }); + + // ignore all candidates with score==Infinity that we failed to parse + // (these are exact matches that correspond to skills not available for + // this user) + let i = 0; + let choice = candidates[i]; + let type = this._getSpecialThingTalkType(choice.parsed); + while (i < candidates.length-1 && type === CommandAnalysisType.PARSE_FAILURE && choice.score === 'Infinity') { + i++; + choice = candidates[i]; + type = this._getSpecialThingTalkType(choice.parsed); } + + if (type === CommandAnalysisType.PARSE_FAILURE) + this.conversation.stats.hit('sabrina-failure'); + else + this.conversation.stats.hit('sabrina-command-good'); + + return { + type, + answer: this._maybeGetThingTalkAnswer(choice.parsed), + parsed: type === CommandAnalysisType.IN_DOMAIN_COMMAND ? choice.parsed : null + }; } private async _getFallbackExamples(command : string) { @@ -188,56 +353,6 @@ export default class DialogueLoop { this.replyButton(Helpers.presentExample(this, ex.utterance), JSON.stringify(ex.target)); } - private async _computePrediction(intent : UserInput) : Promise { - // handle all intents generated internally and by the UI: - // - // - Failed when parsing fails - // - Answer when the user clicks a button, or when the agent is in "raw mode" - // - NeverMind when the user clicks the X button - // - Debug when the user clicks/types "debug" - // - WakeUp when the user says the wake word and nothing else - if (intent instanceof UserInput.Failed) { - await this._getFallbackExamples(intent.utterance); - return null; - } - if (intent instanceof UserInput.Unsupported) { - this.icon = null; - await this.reply(this._("Sorry, I don't know how to do that yet.")); - throw new CancellationError(); - } - if (intent instanceof UserInput.Answer) { - const handled = await this._policy.handleAnswer(this._dialogueState, intent.value); - if (!handled) { - await this.fail(); - return null; - } - return computePrediction(this._dialogueState, handled, 'user'); - } - if (intent instanceof UserInput.MultipleChoiceAnswer) { - await this.fail(); - return null; - } - - if (intent instanceof UserInput.Program) { - // convert thingtalk programs to dialogue states so we can use "\t" without too much typing - const prediction = new Ast.DialogueState(null, 'org.thingpedia.dialogue.transaction', 'execute', null, []); - for (const stmt of intent.program.statements) { - if (stmt instanceof Ast.Assignment) - throw new Error(`Unsupported: assignment statement`); - prediction.history.push(new Ast.DialogueHistoryItem(null, stmt, null, 'accepted')); - } - return prediction; - } - - if (intent instanceof UserInput.UICommand) { - await this._handleUICommand(intent); - return null; - } - - assert(intent instanceof UserInput.DialogueState); - return intent.prediction; - } - private _useNeuralNLG() : boolean { return this._prefs.get('experimental-use-neural-nlg') as boolean; } @@ -255,14 +370,15 @@ export default class DialogueLoop { if (this._useNeuralNLG()) { [this._dialogueState, expect, , numResults] = policyResult; - const policyPrediction = computeNewState(oldState, this._dialogueState, 'agent'); + const policyPrediction = ThingTalkUtils.computeNewState(oldState, this._dialogueState, 'agent'); this.debug(`Agent act:`); this.debug(policyPrediction.prettyprint()); - const context = prepareContextForPrediction(oldState, 'agent'); - await this.conversation.setContext(context); + const [contextCode, contextEntities] = this._prepareContextForPrediction(this._dialogueState, 'agent'); - utterance = await this.conversation.generateAnswer(policyPrediction); + const [targetAct,] = ThingTalkUtils.serializeNormalized(policyPrediction, contextEntities); + const result = await this._nlg.generateUtterance(contextCode, contextEntities, targetAct); + utterance = result[0].answer; } else { [this._dialogueState, expect, utterance, numResults] = policyResult; } @@ -272,44 +388,105 @@ export default class DialogueLoop { if (expect === null && TERMINAL_STATES.includes(this._dialogueState!.dialogueAct)) throw new CancellationError(); - await this.setExpected(expect); return [expect, numResults]; } - private async _handleUserInput(intent : UserInput) { + private async _handleUICommand(type : CommandAnalysisType.STOP|CommandAnalysisType.NEVERMIND|CommandAnalysisType.DEBUG|CommandAnalysisType.WAKEUP) { + switch (type) { + case CommandAnalysisType.STOP: + // stop means cancel, but without a failure message + throw new CancellationError(); + + case CommandAnalysisType.NEVERMIND: + await this.reply(this._("Sorry I couldn't help on that.")); + throw new CancellationError(); + + case CommandAnalysisType.DEBUG: + await this.reply("Current State:\n" + (this._dialogueState ? this._dialogueState.prettyprint() : "null")); + break; + + case CommandAnalysisType.WAKEUP: + // "wakeup" means the user said "hey almond" without anything else, + // or said "hey almond wake up", or triggered one of the LaunchIntents + // in Google Assistant or Alexa, or similar "opening" statements + // we show the welcome message if the current state is null, + // and do nothing otherwise + if (this._dialogueState === null) { + this._showWelcome(); + // keep the microphone open for a while + await this.setExpected(ValueCategory.Command); + } + } + } + + private async _handleUserInput(command : UserInput) { for (;;) { - const prediction = await this._computePrediction(intent); - if (prediction === null) { - intent = await this.nextIntent(); - continue; + const analyzed = await this._analyzeCommand(command); + + switch (analyzed.type) { + case CommandAnalysisType.STOP: + case CommandAnalysisType.NEVERMIND: + case CommandAnalysisType.DEBUG: + case CommandAnalysisType.WAKEUP: + await this._handleUICommand(analyzed.type); + break; + + case CommandAnalysisType.PARSE_FAILURE: + await this._getFallbackExamples(command.type === 'command' ? command.utterance : ''); + break; + + case CommandAnalysisType.OUT_OF_DOMAIN_COMMAND: + // TODO dispatch this out + await this.reply(this._("Sorry, I don't know how to do that yet.")); + throw new CancellationError(); + + default: { + // everything else is an in-domain command + const prediction = await ThingTalkUtils.inputToDialogueState(this._policy, this._dialogueState, analyzed.parsed!); + if (prediction === null) { + // the command does not make sense in the current state + // do nothing and keep the current state + // (this can only occur with commands caught by the exact + // matcher like "yes" or "no") + await this.fail(); + break; + } + + const terminated = await this._handleNormalDialogueCommand(prediction); + if (terminated) + return; + } } - this._dialogueState = computeNewState(this._dialogueState, prediction, 'user'); - this._checkPolicy(this._dialogueState.policy); - this.icon = getProgramIcon(this._dialogueState); - //this.debug(`Before execution:`); - //this.debug(this._dialogueState.prettyprint()); + command = await this.nextCommand(); + } + } - const { newDialogueState, newExecutorState, newResults } = await this._agent.execute(this._dialogueState, this._executorState); - this._dialogueState = newDialogueState; - this._executorState = newExecutorState; - this.debug(`Execution state:`); - this.debug(this._dialogueState!.prettyprint()); + private async _handleNormalDialogueCommand(prediction : Ast.DialogueState) : Promise { + this._dialogueState = ThingTalkUtils.computeNewState(this._dialogueState, prediction, 'user'); + this._checkPolicy(this._dialogueState.policy); + this.icon = getProgramIcon(this._dialogueState); - const [expect, numResults] = await this._doAgentReply(); + //this.debug(`Before execution:`); + //this.debug(this._dialogueState.prettyprint()); - for (const [outputType, outputValue] of newResults.slice(0, numResults)) { - const formatted = await this._cardFormatter.formatForType(outputType, outputValue, { removeText: true }); + const { newDialogueState, newExecutorState, newResults } = await this._agent.execute(this._dialogueState, this._executorState); + this._dialogueState = newDialogueState; + this._executorState = newExecutorState; + this.debug(`Execution state:`); + this.debug(this._dialogueState!.prettyprint()); - for (const card of formatted) - await this.replyCard(card); - } + const [expect, numResults] = await this._doAgentReply(); - if (expect === null) - return; + for (const [outputType, outputValue] of newResults.slice(0, numResults)) { + const formatted = await this._cardFormatter.formatForType(outputType, outputValue, { removeText: true }); - intent = await this.nextIntent(); + for (const card of formatted) + await this.replyCard(card); } + + await this.setExpected(expect); + return expect === null; } private async _showNotification(appId : string, @@ -365,33 +542,38 @@ export default class DialogueLoop { } } + private async _showWelcome() { + await this._doAgentReply(); + // reset the dialogue state here; if we don't, we we'll see sys_greet as an agent + // dialogue act; this is never seen in training, because in training the user speaks + // first, so it confuses the neural network + this._dialogueState = null; + // the utterance ends with "what can i do for you?", which is expect = 'generic' + // but we don't want to keep the microphone open here, we want to go back to wake-word mode + // so we unconditionally close the round here + await this.setExpected(null); + } + private async _loop(showWelcome : boolean) { // if we want to show the welcome message, we run the policy on the `null` state, which will return the sys_greet intent - if (showWelcome) { - await this._doAgentReply(); - // reset the dialogue state here; if we don't, we we'll see sys_greet as an agent - // dialogue act; this is never seen in training, because in training the user speaks - // first, so it confuses the neural network - this._dialogueState = null; - // the utterance ends with "what can i do for you?", which is expect = 'generic' - // but we don't want to keep the microphone open here, we want to go back to wake-word mode - // so we unconditionally close the round here - await this.setExpected(null); - } + if (showWelcome) + await this._showWelcome(); - for (;;) { + while (!this._stopped) { const item = await this.nextQueueItem(); try { if (item instanceof QueueItem.UserInput) { this._lastNotificationApp = undefined; - await this._handleUserInput(item.intent); + await this._handleUserInput(item.command); } else { await this._handleAPICall(item); this._dialogueState = null; } } catch(e) { if (e.code === 'ECANCELLED') { - await this.reset(); + this.icon = null; + this._dialogueState = null; + await this.setExpected(null); } else { if (item instanceof QueueItem.UserInput) { await this.replyInterp(this._("Sorry, I had an error processing your command: ${error}."), {//" @@ -423,7 +605,7 @@ export default class DialogueLoop { this._mgrResolve!(); const queueItem = await this._notifyQueue.pop(); if (queueItem instanceof QueueItem.UserInput) - this.platformData = queueItem.intent.platformData; + this.platformData = queueItem.command.platformData; else this.platformData = {}; return queueItem; @@ -436,7 +618,7 @@ export default class DialogueLoop { await this.reply(this._("Sorry, I need you to confirm the last question first.")); } else if (this.expecting === ValueCategory.MultipleChoice) { await this.reply(this._("Could you choose one of the following?")); - this.conversation.resendChoices(); + await this._resendChoices(); } else if (this.expecting === ValueCategory.Measure) { await this.reply(this._("Could you give me a measurement?")); } else if (this.expecting === ValueCategory.Number) { @@ -488,9 +670,9 @@ export default class DialogueLoop { if (expected === undefined) throw new TypeError(); this.expecting = expected; - const context = prepareContextForPrediction(this._dialogueState, 'user'); - this.conversation.setContext(context); - this.conversation.setExpected(expected, raw); + this._raw = raw; + const [contextCode, contextEntities] = this._prepareContextForPrediction(this._dialogueState, 'user'); + this.conversation.setExpected(expected, { code: contextCode, entities: contextEntities }); } /** @@ -506,36 +688,61 @@ export default class DialogueLoop { // because otherwise we send it to the parser and the parser will // likely misbehave as it's a state that we've never seen in training await this.setExpected(expected, expected === ValueCategory.Location); - let intent = await this.nextIntent(); - while (!(intent instanceof UserInput.Answer) || intent.category !== expected) { - if (intent instanceof UserInput.UICommand) - await this._handleUICommand(intent); - else + + let analyzed = await this._analyzeCommand(await this.nextCommand()); + while (analyzed.answer === null || typeof analyzed.answer === 'number' || + ValueCategory.fromType(analyzed.answer.getType()) !== expected) { + switch (analyzed.type) { + case CommandAnalysisType.STOP: + case CommandAnalysisType.NEVERMIND: + case CommandAnalysisType.DEBUG: + case CommandAnalysisType.WAKEUP: + await this._handleUICommand(analyzed.type); + break; + + default: await this.fail(); - intent = await this.nextIntent(); + await this.lookingFor(); + } + + analyzed = await this._analyzeCommand(await this.nextCommand()); } - return intent.value; + return analyzed.answer; } + async askChoices(question : string, choices : string[]) : Promise { await this.reply(question); this.setExpected(ValueCategory.MultipleChoice); + this._choices = choices; for (let i = 0; i < choices.length; i++) - await this.replyChoice(i, choices[i]); - let intent = await this.nextIntent(); - while (!(intent instanceof UserInput.MultipleChoiceAnswer)) { - if (intent instanceof UserInput.UICommand) - await this._handleUICommand(intent); - else + await this.conversation.sendChoice(i, choices[i]); + + let analyzed = await this._analyzeCommand(await this.nextCommand()); + while (analyzed.answer === null || typeof analyzed.answer !== 'number' + || analyzed.answer < 0 || analyzed.answer >= choices.length) { + switch (analyzed.type) { + case CommandAnalysisType.STOP: + case CommandAnalysisType.NEVERMIND: + case CommandAnalysisType.DEBUG: + case CommandAnalysisType.WAKEUP: + await this._handleUICommand(analyzed.type); + break; + + default: await this.fail(); - intent = await this.nextIntent(); + await this.lookingFor(); + } + + analyzed = await this._analyzeCommand(await this.nextCommand()); } - return intent.value; + return analyzed.answer; } + private async _resendChoices() { + if (this.expecting !== ValueCategory.MultipleChoice) + console.log('UNEXPECTED: sendChoice while not expecting a MultipleChoice'); - async reset() { - this.icon = null; - this._dialogueState = null; - await this.setExpected(null); + for (let idx = 0; idx < this._choices.length; idx++) + await this.conversation.sendChoice(idx, this._choices[idx]); } async replyInterp(msg : string, args ?: Record, icon : string|null = null) { @@ -566,10 +773,6 @@ export default class DialogueLoop { } } - async replyChoice(idx : number, title : string) { - await this.conversation.sendChoice(idx, title); - } - async replyButton(text : string, json : string) { await this.conversation.sendButton(text, json); } @@ -591,8 +794,11 @@ export default class DialogueLoop { this._pushQueueItem(item); } - start(showWelcome : boolean) { - const promise = this._waitNextIntent(); + async start(showWelcome : boolean) { + await this._nlu.start(); + await this._nlg.start(); + + const promise = this._waitNextCommand(); this._loop(showWelcome).then(() => { throw new Error('Unexpected end of dialog loop'); }, (err) => { @@ -602,16 +808,48 @@ export default class DialogueLoop { return promise; } + async stop() { + this._stopped = true; + + // wait until the dialog is ready to accept commands, then inject + // a cancellation error + await this._mgrPromise; + assert(this._mgrPromise === null); + + if (this._isInDefaultState()) + this._notifyQueue.cancelWait(new CancellationError()); + else + this._userInputQueue.cancelWait(new CancellationError()); + + await this._nlu.stop(); + await this._nlg.stop(); + } + + async reset() { + // wait until the dialog is ready to accept commands + await this._mgrPromise; + assert(this._mgrPromise === null); + + if (this._isInDefaultState()) + this._notifyQueue.cancelWait(new CancellationError()); + else + this._userInputQueue.cancelWait(new CancellationError()); + } + private _pushQueueItem(item : QueueItem) { // ensure that we have something to wait on before the next // command is handled if (!this._mgrPromise) - this._waitNextIntent(); + this._waitNextCommand(); this._notifyQueue.push(item); } - private _waitNextIntent() : Promise { + /** + * Returns a promise that will resolve when the dialogue loop is + * ready to accept the next command from the user. + */ + private _waitNextCommand() : Promise { const promise = new Promise((callback, errback) => { this._mgrResolve = callback; }); @@ -619,21 +857,21 @@ export default class DialogueLoop { return promise; } - pushIntent(intent : UserInput, confident = false) { - this._pushQueueItem(new QueueItem.UserInput(intent, confident)); + pushCommand(command : UserInput) { + this._pushQueueItem(new QueueItem.UserInput(command)); } - async handle(intent : UserInput, confident = false) : Promise { + async handleCommand(command : UserInput) : Promise { // wait until the dialog is ready to accept commands await this._mgrPromise; assert(this._mgrPromise === null); - const promise = this._waitNextIntent(); + const promise = this._waitNextCommand(); if (this._isInDefaultState()) - this.pushIntent(intent, confident); + this.pushCommand(command); else - this._userInputQueue.push(intent); + this._userInputQueue.push(command); return promise; } diff --git a/lib/dialogue-agent/dialogue_queue.ts b/lib/dialogue-agent/dialogue_queue.ts index c8983f0c3..d1308a832 100644 --- a/lib/dialogue-agent/dialogue_queue.ts +++ b/lib/dialogue-agent/dialogue_queue.ts @@ -18,7 +18,7 @@ // // Author: Giovanni Campagna -import type UserInputIntent from './user-input'; +import type { UserInput as Command } from './user-input'; class QueueItem { } @@ -27,13 +27,12 @@ type JSError = Error; namespace QueueItem { export class UserInput extends QueueItem { - constructor(public intent : UserInputIntent, - public confident : boolean) { + constructor(public command : Command) { super(); } toString() { - return `UserInput(${this.intent})`; + return `UserInput(${this.command})`; } } diff --git a/lib/dialogue-agent/user-input.ts b/lib/dialogue-agent/user-input.ts index da1f89e35..4fcab903c 100644 --- a/lib/dialogue-agent/user-input.ts +++ b/lib/dialogue-agent/user-input.ts @@ -18,12 +18,9 @@ // // Author: Giovanni Campagna -import * as Tp from 'thingpedia'; -import { Ast, Type, SchemaRetriever } from 'thingtalk'; +import * as ThingTalk from 'thingtalk'; -import ValueCategory from './value-category'; import { EntityMap } from '../utils/entity-utils'; -import * as ThingTalkUtils from '../utils/thingtalk'; export interface PlatformData { contacts ?: Array<{ @@ -40,155 +37,14 @@ export interface PreparsedCommand { slotTypes ?: Record; } -function parseSpecial(intent : Ast.SpecialControlIntent, context : { platformData : PlatformData }) { - switch (intent.type) { - case 'yes': - return new UserInput.Answer(new Ast.Value.Boolean(true), context.platformData); - case 'no': - return new UserInput.Answer(new Ast.Value.Boolean(false), context.platformData); - default: - return new UserInput.UICommand(intent.type, context.platformData); - } -} - -/** - * Base class for the interpretation of the input from the user, which could be - * a UI action (a button) or a natural language command. - */ -class UserInput { +interface NaturalLanguageUserInput { + type : 'command'; + utterance : string; platformData : PlatformData; - - constructor(platformData : PlatformData) { - this.platformData = platformData; - } - - static fromThingTalk(thingtalk : Ast.Input, context : { platformData : PlatformData }) : UserInput { - if (thingtalk instanceof Ast.ControlCommand) { - if (thingtalk.intent instanceof Ast.SpecialControlIntent) - return parseSpecial(thingtalk.intent, context); - else if (thingtalk.intent instanceof Ast.AnswerControlIntent) - return new UserInput.Answer(thingtalk.intent.value, context.platformData); - else if (thingtalk.intent instanceof Ast.ChoiceControlIntent) - return new UserInput.MultipleChoiceAnswer(thingtalk.intent.value, context.platformData); - else - throw new TypeError(`Unrecognized bookkeeping intent`); - } else if (thingtalk instanceof Ast.Program) { - return new UserInput.Program(thingtalk, context.platformData); - } else if (thingtalk instanceof Ast.DialogueState) { - return new UserInput.DialogueState(thingtalk, context.platformData); - } else { - throw new TypeError(`Unrecognized ThingTalk command: ${thingtalk.prettyprint()}`); - } - } - - static async parse(json : { program : string }|PreparsedCommand, - thingpediaClient : Tp.BaseClient, - schemaRetriever : SchemaRetriever, - context : { platformData : PlatformData }) : Promise { - if ('program' in json) { - return UserInput.fromThingTalk(await ThingTalkUtils.parse(json.program, { - thingpediaClient, - schemaRetriever, - loadMetadata: true - }), context); - } - - const { code, entities } = json; - for (const name in entities) { - if (name.startsWith('SLOT_')) { - const slotname = json.slots![parseInt(name.substring('SLOT_'.length))]; - const slotType = Type.fromString(json.slotTypes![slotname]); - const value = Ast.Value.fromJSON(slotType, entities[name]); - entities[name] = value; - } - } - - const thingtalk = await ThingTalkUtils.parsePrediction(code, entities, { - thingpediaClient, - schemaRetriever, - loadMetadata: true - }, true); - return UserInput.fromThingTalk(thingtalk, context); - } } - -namespace UserInput { - /** - * A natural language command that was parsed correctly but is not supported in - * Thingpedia (it uses Thingpedia classes that are not available). - */ - export class Unsupported extends UserInput {} - - /** - * A natural language command that failed to parse entirely. - */ - export class Failed extends UserInput { - constructor(public utterance : string, - platformData : PlatformData) { - super(platformData); - } - } - - /** - * A special command that bypasses the neural network, or a button on the UI. - */ - export class UICommand extends UserInput { - constructor(public type : string, - platformData : PlatformData) { - super(platformData); - } - } - - /** - * A multiple choice answer. This can be generated by the UI button, - * or by the parser in multiple choice mode. It is only used to disambiguate - * entities and device names - */ - export class MultipleChoiceAnswer extends UserInput { - category : ValueCategory.MultipleChoice = ValueCategory.MultipleChoice; - - constructor(public value : number, - platformData : PlatformData) { - super(platformData); - } - } - - /** - * A single, naked ThingTalk value. This can be generated by the UI pickers - * (file pickers, location pickers, contact pickers, etc.), in certain uses - * of the exact matcher, and when the agent is in raw mode. - */ - export class Answer extends UserInput { - category : ValueCategory; - - constructor(public value : Ast.Value, - platformData : PlatformData) { - super(platformData); - this.category = ValueCategory.fromValue(value); - } - } - - /** - * A single ThingTalk program. This can come from a single-command neural network, - * or from the user typing "\t". - */ - export class Program extends UserInput { - constructor(public program : Ast.Program, - platformData : PlatformData) { - super(platformData); - } - } - - /** - * A prediction ThingTalk dialogue state (policy, dialogue act, statements), which - * is generated by the neural network after parsing the user's input. - */ - export class DialogueState extends UserInput { - constructor(public prediction : Ast.DialogueState, - platformData : PlatformData) { - super(platformData); - } - } +interface ThingTalkUserInput { + type : 'thingtalk'; + parsed : ThingTalk.Ast.Input; + platformData : PlatformData; } - -export default UserInput; +export type UserInput = NaturalLanguageUserInput|ThingTalkUserInput; diff --git a/lib/utils/thingtalk/index.ts b/lib/utils/thingtalk/index.ts index ef7398d2a..de170cbbb 100644 --- a/lib/utils/thingtalk/index.ts +++ b/lib/utils/thingtalk/index.ts @@ -31,6 +31,7 @@ import { extractConstants, createConstants } from './constants'; export * from './describe'; export * from './syntax'; export * from './dialogue_state_utils'; +import { computePrediction } from './dialogue_state_utils'; export * from './example-utils'; // reexport clean from misc-utils import { clean } from '../misc-utils'; @@ -130,3 +131,54 @@ class StateValidator { export function createStateValidator(policyManifest ?: string) : StateValidator { return new StateValidator(policyManifest); } + +interface DialoguePolicy { + handleAnswer(state : Ast.DialogueState, value : Ast.Value) : Promise; +} + +export async function inputToDialogueState(policy : DialoguePolicy, + context : Ast.DialogueState|null, + input : Ast.Input) : Promise { + if (input instanceof Ast.ControlCommand) { + if (context === null) + return null; + + if (input.intent instanceof Ast.SpecialControlIntent) { + switch (input.intent.type) { + case 'yes': + case 'no': { + const value = new Ast.BooleanValue(input.intent.type === 'yes'); + const handled = await policy.handleAnswer(context, value); + if (!handled) + return null; + return computePrediction(context, handled, 'user'); + } + default: + return null; + } + } + if (input.intent instanceof Ast.ChoiceControlIntent) + return null; + + if (input.intent instanceof Ast.AnswerControlIntent) { + const handled = await policy.handleAnswer(context, input.intent.value); + if (!handled) + return null; + return computePrediction(context, handled, 'user'); + } + + throw new TypeError(`Unrecognized bookkeeping intent`); + } else if (input instanceof Ast.Program) { + // convert thingtalk programs to dialogue states so we can use "\t" without too much typing + const prediction = new Ast.DialogueState(null, 'org.thingpedia.dialogue.transaction', 'execute', null, []); + for (const stmt of input.statements) { + if (stmt instanceof Ast.Assignment) + throw new Error(`Unsupported: assignment statement`); + prediction.history.push(new Ast.DialogueHistoryItem(null, stmt, null, 'accepted')); + } + return prediction; + } + + assert(input instanceof Ast.DialogueState); + return input; +} diff --git a/test/agent/index.js b/test/agent/index.js index de9597453..825909443 100644 --- a/test/agent/index.js +++ b/test/agent/index.js @@ -159,7 +159,7 @@ class MockUser { async function mockNLU(conversation) { // inject some mocking in the parser: - conversation._nlu.onlineLearn = function(utterance, targetCode) { + conversation._loop._nlu.onlineLearn = function(utterance, targetCode) { if (utterance === 'get an xkcd comic') assert.strictEqual(targetCode.join(' '), 'now => @com.xkcd.get_comic => notify'); else if (utterance === '!! test command multiple results !!') @@ -171,8 +171,8 @@ async function mockNLU(conversation) { const commands = yaml.safeLoad(await util.promisify(fs.readFile)( path.resolve(path.dirname(module.filename), './mock-nlu.yaml'))); - const realSendUtterance = conversation._nlu.sendUtterance; - conversation._nlu.sendUtterance = async function(utterance) { + const realSendUtterance = conversation._loop._nlu.sendUtterance; + conversation._loop._nlu.sendUtterance = async function(utterance) { if (utterance === '!! test command host unreach !!') { const e = new Error('Host is unreachable'); e.code = 'EHOSTUNREACH'; diff --git a/tool/interactive-annotate.ts b/tool/interactive-annotate.ts index bbafd6544..3175323c3 100644 --- a/tool/interactive-annotate.ts +++ b/tool/interactive-annotate.ts @@ -18,7 +18,6 @@ // // Author: Giovanni Campagna -import assert from 'assert'; import * as argparse from 'argparse'; import * as fs from 'fs'; import * as readline from 'readline'; @@ -450,45 +449,7 @@ class Annotator extends events.EventEmitter { } private async _inputToDialogueState(input : ThingTalk.Ast.Input) : Promise { - if (input instanceof ThingTalk.Ast.ControlCommand) { - if (input.intent instanceof ThingTalk.Ast.SpecialControlIntent) { - switch (input.intent.type) { - case 'yes': - case 'no': { - const value = new ThingTalk.Ast.BooleanValue(input.intent.type === 'yes'); - const handled = await this._dialoguePolicy.handleAnswer(this._context, value); - if (!handled) - return null; - return ThingTalkUtils.computePrediction(this._context, handled, 'user'); - } - default: - return null; - } - } - if (input.intent instanceof ThingTalk.Ast.ChoiceControlIntent) - return null; - - if (input.intent instanceof ThingTalk.Ast.AnswerControlIntent) { - const handled = await this._dialoguePolicy.handleAnswer(this._context, input.intent.value); - if (!handled) - return null; - return ThingTalkUtils.computePrediction(this._context, handled, 'user'); - } - - throw new TypeError(`Unrecognized bookkeeping intent`); - } else if (input instanceof ThingTalk.Ast.Program) { - // convert thingtalk programs to dialogue states so we can use "\t" without too much typing - const prediction = new ThingTalk.Ast.DialogueState(null, 'org.thingpedia.dialogue.transaction', 'execute', null, []); - for (const stmt of input.statements) { - if (stmt instanceof ThingTalk.Ast.Assignment) - throw new Error(`Unsupported: assignment statement`); - prediction.history.push(new ThingTalk.Ast.DialogueHistoryItem(null, stmt, null, 'accepted')); - } - return prediction; - } - - assert(input instanceof ThingTalk.Ast.DialogueState); - return input; + return ThingTalkUtils.inputToDialogueState(this._dialoguePolicy, this._context, input); } private async _handleUtterance(utterance : string) { From e1348d6d1ed968bf22f922381f424e3cdea2fd20 Mon Sep 17 00:00:00 2001 From: Giovanni Campagna Date: Fri, 5 Feb 2021 15:17:54 -0800 Subject: [PATCH 4/6] dialogue-loop: implement confidence-based dispatch With hardcoded thresholds of 0.5 confidence, we either handle the command, ask the user for confirmation, or report some kind of error. --- lib/dialogue-agent/dialogue-loop.ts | 144 +++++++++++++++++++++------- test/agent/index.js | 2 +- test/agent/tests.txt | 2 +- 3 files changed, 110 insertions(+), 38 deletions(-) diff --git a/lib/dialogue-agent/dialogue-loop.ts b/lib/dialogue-agent/dialogue-loop.ts index b4b0bf192..40a485460 100644 --- a/lib/dialogue-agent/dialogue-loop.ts +++ b/lib/dialogue-agent/dialogue-loop.ts @@ -49,23 +49,58 @@ import CardFormatter, { FormattedChunk } from './card-output/card-formatter'; import ExecutionDialogueAgent from './execution_dialogue_agent'; -const ENABLE_SUGGESTIONS = false; - // TODO: load the policy.yaml file instead const POLICY_NAME = 'org.thingpedia.dialogue.transaction'; const TERMINAL_STATES = [ 'sys_end', 'sys_action_success' ]; +// Confidence thresholds: +// +// The API returns two global scores associated with +// the utterance (the "command" score and the "ignore" score), +// and a confidence score on each candidate parse. +// (There is a third score, the "other" score, which is exactly +// 1-command-ignore) +// +// (See LocalParserClient for how these scores are computed +// from the raw confidence scores produced by genienlp) +// +// - If the "ignore" score is greater than IGNORE_THRESHOLD +// we ignore this command entirely, do nothing and say +// nothing. Typically, this means a spurious wakeword activation, +// or a command that was truncated midway by the microphone. +// In Alexa, the ring would light up, turn off, and nothing would happen. +// +// - If the "command" score is less than IN_DOMAIN_THRESHOLD, +// we ship this command out to other backends silently, or tell +// the user that the command is not supported. +// +// - If we have a best valid parse, and the confidence of that parse +// is greater than CONFIDENCE_CONFIRM_THRESHOLD, we run the command +// without further confirmation. +// +// - If we have a best valid parse, and the confidence of that parse +// is greater than CONFIDENCE_FAILURE_THRESHOLD, we ask the user +// for additional confirmation before executing. +// +// - In all other cases, we tell the user we did not understand. +const CONFIDENCE_CONFIRM_THRESHOLD = 0.5; +const CONFIDENCE_FAILURE_THRESHOLD = 0.25; +const IN_DOMAIN_THRESHOLD = 0.5; +const IGNORE_THRESHOLD = 0.5; + enum CommandAnalysisType { - // special commands + // special commands - these are generated by the exact matcher, or + // by UI buttons like the "X" button STOP, NEVERMIND, WAKEUP, DEBUG, // some sort of command - IN_DOMAIN_COMMAND, + CONFIDENT_IN_DOMAIN_COMMAND, + NONCONFIDENT_IN_DOMAIN_COMMAND, OUT_OF_DOMAIN_COMMAND, PARSE_FAILURE, @@ -215,7 +250,7 @@ export default class DialogueLoop { } // anything else is automatically in-domain - return CommandAnalysisType.IN_DOMAIN_COMMAND; + return CommandAnalysisType.CONFIDENT_IN_DOMAIN_COMMAND; } private _maybeGetThingTalkAnswer(input : Ast.Input) : Ast.Value|number|null { @@ -245,7 +280,7 @@ export default class DialogueLoop { return { type, answer: this._maybeGetThingTalkAnswer(command.parsed), - parsed: type === CommandAnalysisType.IN_DOMAIN_COMMAND ? command.parsed : null + parsed: command.parsed }; } @@ -259,7 +294,7 @@ export default class DialogueLoop { else value = new Ast.Value.String(command.utterance); return { - type: CommandAnalysisType.IN_DOMAIN_COMMAND, + type: CommandAnalysisType.CONFIDENT_IN_DOMAIN_COMMAND, answer: value, parsed: new Ast.ControlCommand(null, new Ast.AnswerControlIntent(null, value)) }; @@ -286,13 +321,26 @@ export default class DialogueLoop { throw e; } } + if (nluResult.intent.ignore >= IGNORE_THRESHOLD) { + this.debug('Ignored likely spurious command'); + return { + type: CommandAnalysisType.IGNORE, + answer: null, + parsed: null + }; + } + + if (nluResult.intent.command < IN_DOMAIN_THRESHOLD) { + this.debug('Analyzed as out-of-domain command'); + return { + type: CommandAnalysisType.OUT_OF_DOMAIN_COMMAND, + answer: null, + parsed: null + }; + } // parse all code sequences into an Intent // this will correctly filter out anything that does not parse - if (nluResult.candidates.length > 0) - this.debug('Analyzed message into ' + nluResult.candidates[0].code.join(' ')); - else - this.debug('Failed to analyze message'); const candidates = await Promise.all(nluResult.candidates.map(async (candidate, beamposition) => { let parsed; try { @@ -323,36 +371,26 @@ export default class DialogueLoop { type = this._getSpecialThingTalkType(choice.parsed); } - if (type === CommandAnalysisType.PARSE_FAILURE) + if (type === CommandAnalysisType.PARSE_FAILURE || + choice.score < CONFIDENCE_FAILURE_THRESHOLD) { + type = CommandAnalysisType.PARSE_FAILURE; + this.debug('Failed to analyze message'); this.conversation.stats.hit('sabrina-failure'); - else + } else if (choice.score < CONFIDENCE_CONFIRM_THRESHOLD) { + type = CommandAnalysisType.NONCONFIDENT_IN_DOMAIN_COMMAND; + this.debug('Dubiously analyzed message into ' + choice.parsed.prettyprint()); + this.conversation.stats.hit('sabrina-command-maybe'); + } else { + this.debug('Confidently analyzed message into ' + choice.parsed.prettyprint()); this.conversation.stats.hit('sabrina-command-good'); - + } return { type, answer: this._maybeGetThingTalkAnswer(choice.parsed), - parsed: type === CommandAnalysisType.IN_DOMAIN_COMMAND ? choice.parsed : null + parsed: choice.parsed }; } - private async _getFallbackExamples(command : string) { - const dataset = await this.conversation.thingpedia.getExamplesByKey(command); - const examples = ENABLE_SUGGESTIONS ? await Helpers.loadExamples(dataset, this.conversation.schemas, 5) : []; - - if (examples.length === 0) { - await this.reply(this._("Sorry, I did not understand that.")); - return; - } - - this.conversation.stats.hit('sabrina-fallback-buttons'); - - // don't sort the examples, they come already sorted from Thingpedia - - await this.reply(this._("Sorry, I did not understand that. Try the following instead:")); - for (const ex of examples) - this.replyButton(Helpers.presentExample(this, ex.utterance), JSON.stringify(ex.target)); - } - private _useNeuralNLG() : boolean { return this._prefs.get('experimental-use-neural-nlg') as boolean; } @@ -391,7 +429,7 @@ export default class DialogueLoop { return [expect, numResults]; } - private async _handleUICommand(type : CommandAnalysisType.STOP|CommandAnalysisType.NEVERMIND|CommandAnalysisType.DEBUG|CommandAnalysisType.WAKEUP) { + private async _handleUICommand(type : CommandAnalysisType) { switch (type) { case CommandAnalysisType.STOP: // stop means cancel, but without a failure message @@ -416,9 +454,24 @@ export default class DialogueLoop { // keep the microphone open for a while await this.setExpected(ValueCategory.Command); } + + case CommandAnalysisType.IGNORE: + // do exactly nothing + break; } } + private async _describeProgram(program : Ast.Input) { + const describer = new ThingTalkUtils.Describer(this.conversation.locale, this.conversation.timezone); + // retrieve the relevant primitive templates + const kinds = new Set(); + for (const [, prim] of program.iteratePrimitives(false)) + kinds.add(prim.selector.kind); + for (const kind of kinds) + describer.setDataset(kind, await this.engine.schemas.getExamplesByKind(kind)); + return describer.describe(program); + } + private async _handleUserInput(command : UserInput) { for (;;) { const analyzed = await this._analyzeCommand(command); @@ -428,11 +481,12 @@ export default class DialogueLoop { case CommandAnalysisType.NEVERMIND: case CommandAnalysisType.DEBUG: case CommandAnalysisType.WAKEUP: + case CommandAnalysisType.IGNORE: await this._handleUICommand(analyzed.type); break; case CommandAnalysisType.PARSE_FAILURE: - await this._getFallbackExamples(command.type === 'command' ? command.utterance : ''); + await this.fail(); break; case CommandAnalysisType.OUT_OF_DOMAIN_COMMAND: @@ -440,6 +494,22 @@ export default class DialogueLoop { await this.reply(this._("Sorry, I don't know how to do that yet.")); throw new CancellationError(); + case CommandAnalysisType.NONCONFIDENT_IN_DOMAIN_COMMAND: { + // TODO move this to the state machine, not here + const confirmation = await this._describeProgram(analyzed.parsed!); + const question = this.interpolate(this._("Did you mean ${command}?"), { command: confirmation }); + const yesNo = await this.ask(ValueCategory.YesNo, question); + assert(yesNo instanceof Ast.BooleanValue); + if (!yesNo.value) { + // preserve the dialogue state and the expecting state here + await this.reply(this._("Sorry I couldn't help on that. Would you like to try again?")); + continue; + } + + // fallthrough to the confident case + } + + case CommandAnalysisType.CONFIDENT_IN_DOMAIN_COMMAND: default: { // everything else is an in-domain command const prediction = await ThingTalkUtils.inputToDialogueState(this._policy, this._dialogueState, analyzed.parsed!); @@ -680,7 +750,7 @@ export default class DialogueLoop { * * This is a legacy method used for certain scripted interactions. */ - async ask(expected : ValueCategory.PhoneNumber|ValueCategory.EmailAddress|ValueCategory.Location|ValueCategory.Time, + async ask(expected : ValueCategory.YesNo|ValueCategory.PhoneNumber|ValueCategory.EmailAddress|ValueCategory.Location|ValueCategory.Time, question : string, args ?: Record) : Promise { await this.replyInterp(question, args); @@ -697,6 +767,7 @@ export default class DialogueLoop { case CommandAnalysisType.NEVERMIND: case CommandAnalysisType.DEBUG: case CommandAnalysisType.WAKEUP: + case CommandAnalysisType.IGNORE: await this._handleUICommand(analyzed.type); break; @@ -725,6 +796,7 @@ export default class DialogueLoop { case CommandAnalysisType.NEVERMIND: case CommandAnalysisType.DEBUG: case CommandAnalysisType.WAKEUP: + case CommandAnalysisType.IGNORE: await this._handleUICommand(analyzed.type); break; diff --git a/test/agent/index.js b/test/agent/index.js index 825909443..d4c96a64f 100644 --- a/test/agent/index.js +++ b/test/agent/index.js @@ -188,7 +188,7 @@ async function mockNLU(conversation) { err.code = command.error.code; throw err; } - return { tokens, entities, candidates: command.candidates }; + return { tokens, entities, candidates: command.candidates, intent: { ignore: 0, command: 1, other: 0 } }; } } diff --git a/test/agent/tests.txt b/test/agent/tests.txt index fd62c5be0..13cf16ddb 100644 --- a/test/agent/tests.txt +++ b/test/agent/tests.txt @@ -265,7 +265,7 @@ A: >> expecting = null U: !! test command always failed !! -A: Sorry, I did not understand that. +A: Sorry, I did not understand that. Can you rephrase it? A: >> context = null // {} A: >> expecting = null From d76980f0ff6f548ec205166706a78eb078876a13 Mon Sep 17 00:00:00 2001 From: Giovanni Campagna Date: Thu, 11 Feb 2021 12:05:17 -0800 Subject: [PATCH 5/6] Update confidence computation to the final naming --- lib/prediction/localparserclient.ts | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/lib/prediction/localparserclient.ts b/lib/prediction/localparserclient.ts index d25c8d583..84568ce1f 100644 --- a/lib/prediction/localparserclient.ts +++ b/lib/prediction/localparserclient.ts @@ -223,15 +223,20 @@ export default class LocalParserClient { assert(candidates.length > 0); result = candidates.map((c) => { + // convert is_correct and is_probably_correct scores into + // a single scale such that >0.5 is correct and >0.25 is + // probably correct + const score = (c.score.is_correct ?? 1) > 0.5 ? (c.score.is_correct ?? 1) : + ((c.score.is_probably_correct ?? 1) * 0.5); return { code: c.answer.split(' '), - score: c.score.confidence ?? 1 + score: score }; }); - intent.ignore = candidates[0].score.ignore ?? 0; - intent.command = (candidates[0].score.in_domain ?? 1) * (1 - intent.ignore); - intent.other = 1 - intent.ignore - intent.command; + intent.ignore = candidates[0].score.is_junk ?? 0; + intent.other = (candidates[0].score.is_ood ?? 0) * (1 - intent.ignore); + intent.command = 1 - intent.ignore - intent.other; } let result2 = result!; // guaranteed not null From c12bc3369b58ddc7312b547e10245af3da3087a0 Mon Sep 17 00:00:00 2001 From: Giovanni Campagna Date: Thu, 11 Feb 2021 17:16:08 -0800 Subject: [PATCH 6/6] prediction: use discrete scores for junk/OOD These scores cannot be treated as probabilities, but the API treats them like such. --- lib/prediction/localparserclient.ts | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/lib/prediction/localparserclient.ts b/lib/prediction/localparserclient.ts index 84568ce1f..c12b4e9e2 100644 --- a/lib/prediction/localparserclient.ts +++ b/lib/prediction/localparserclient.ts @@ -226,16 +226,22 @@ export default class LocalParserClient { // convert is_correct and is_probably_correct scores into // a single scale such that >0.5 is correct and >0.25 is // probably correct - const score = (c.score.is_correct ?? 1) > 0.5 ? (c.score.is_correct ?? 1) : - ((c.score.is_probably_correct ?? 1) * 0.5); + const score = (c.score.is_correct ?? 1) >= 0.5 ? 1 : + (c.score.is_probably_correct ?? 1) >= 0.5 ? 0.35 : 0.15; return { code: c.answer.split(' '), score: score }; }); - intent.ignore = candidates[0].score.is_junk ?? 0; - intent.other = (candidates[0].score.is_ood ?? 0) * (1 - intent.ignore); + if (candidates[0].score.is_junk >= 0.5) + intent.ignore = 1; + else + intent.ignore = 0; + if (intent.ignore < 0.5 && candidates[0].score.is_ood >= 0.5) + intent.other = 1; + else + intent.other = 0; intent.command = 1 - intent.ignore - intent.other; }