Skip to content

Commit

Permalink
Merge pull request #421 from stanford-oval/wip/confidence
Browse files Browse the repository at this point in the history
Confidence-based dispatch
  • Loading branch information
gcampax committed Feb 12, 2021
2 parents 2d0f5ee + 2f70c68 commit 2111a8f
Show file tree
Hide file tree
Showing 13 changed files with 1,264 additions and 815 deletions.
282 changes: 60 additions & 222 deletions lib/dialogue-agent/conversation.ts

Large diffs are not rendered by default.

709 changes: 517 additions & 192 deletions lib/dialogue-agent/dialogue-loop.ts

Large diffs are not rendered by default.

7 changes: 3 additions & 4 deletions lib/dialogue-agent/dialogue_queue.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
//
// Author: Giovanni Campagna <gcampagn@cs.stanford.edu>

import type UserInputIntent from './user-input';
import type { UserInput as Command } from './user-input';

class QueueItem {
}
Expand All @@ -27,13 +27,12 @@ type JSError = Error;

namespace QueueItem {
export class UserInput extends QueueItem {
constructor(public intent : UserInputIntent,
public confident : boolean) {
constructor(public command : Command) {
super();
}

toString() {
return `UserInput(${this.intent})`;
return `UserInput(${this.command})`;
}
}

Expand Down
165 changes: 9 additions & 156 deletions lib/dialogue-agent/user-input.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,9 @@
//
// Author: Giovanni Campagna <gcampagn@cs.stanford.edu>

import * as Tp from 'thingpedia';
import { Ast, Type, SchemaRetriever } from 'thingtalk';
import * as ThingTalk from 'thingtalk';

import ValueCategory from './value-category';
import { EntityMap } from '../utils/entity-utils';
import * as ThingTalkUtils from '../utils/thingtalk';

export interface PlatformData {
contacts ?: Array<{
Expand All @@ -40,158 +37,14 @@ export interface PreparsedCommand {
slotTypes ?: Record<string, string>;
}

function parseSpecial(intent : Ast.SpecialControlIntent,
context : { command : string|null, platformData : PlatformData }) {
switch (intent.type) {
case 'yes':
return new UserInput.Answer(new Ast.Value.Boolean(true), context.command, context.platformData);
case 'no':
return new UserInput.Answer(new Ast.Value.Boolean(false), context.command, context.platformData);
default:
return new UserInput.UICommand(intent.type, context.platformData);
}
}

/**
* Base class for the interpretation of the input from the user, which could be
* a UI action (a button) or a natural language command.
*/
class UserInput {
utterance : string|null;
interface NaturalLanguageUserInput {
type : 'command';
utterance : string;
platformData : PlatformData;

constructor(utterance : string|null, platformData : PlatformData) {
this.utterance = utterance;
this.platformData = platformData;
}

static fromThingTalk(thingtalk : Ast.Input,
context : { command : string|null, platformData : PlatformData }) : UserInput {
if (thingtalk instanceof Ast.ControlCommand) {
if (thingtalk.intent instanceof Ast.SpecialControlIntent)
return parseSpecial(thingtalk.intent, context);
else if (thingtalk.intent instanceof Ast.AnswerControlIntent)
return new UserInput.Answer(thingtalk.intent.value, context.command, context.platformData);
else if (thingtalk.intent instanceof Ast.ChoiceControlIntent)
return new UserInput.MultipleChoiceAnswer(thingtalk.intent.value, context.command, context.platformData);
else
throw new TypeError(`Unrecognized bookkeeping intent`);
} else if (thingtalk instanceof Ast.Program) {
return new UserInput.Program(thingtalk, context.command, context.platformData);
} else if (thingtalk instanceof Ast.DialogueState) {
return new UserInput.DialogueState(thingtalk, context.command, context.platformData);
} else {
throw new TypeError(`Unrecognized ThingTalk command: ${thingtalk.prettyprint()}`);
}
}

static async parse(json : { program : string }|PreparsedCommand,
thingpediaClient : Tp.BaseClient,
schemaRetriever : SchemaRetriever,
context : { command : string|null, platformData : PlatformData }) : Promise<UserInput> {
if ('program' in json) {
return UserInput.fromThingTalk(await ThingTalkUtils.parse(json.program, {
thingpediaClient,
schemaRetriever,
loadMetadata: true
}), context);
}

const { code, entities } = json;
for (const name in entities) {
if (name.startsWith('SLOT_')) {
const slotname = json.slots![parseInt(name.substring('SLOT_'.length))];
const slotType = Type.fromString(json.slotTypes![slotname]);
const value = Ast.Value.fromJSON(slotType, entities[name]);
entities[name] = value;
}
}

const thingtalk = await ThingTalkUtils.parsePrediction(code, entities, {
thingpediaClient,
schemaRetriever,
loadMetadata: true
}, true);
return UserInput.fromThingTalk(thingtalk, context);
}
}

namespace UserInput {
/**
* A natural language command that was parsed correctly but is not supported in
* Thingpedia (it uses Thingpedia classes that are not available).
*/
export class Unsupported extends UserInput {}

/**
* A natural language command that failed to parse entirely.
*/
export class Failed extends UserInput {}

/**
* A special command that bypasses the neural network, or a button on the UI.
*/
export class UICommand extends UserInput {
constructor(public type : string,
platformData : PlatformData) {
super(null, platformData);
}
}

/**
* A multiple choice answer. This can be generated by the UI button,
* or by the parser in multiple choice mode. It is only used to disambiguate
* entities and device names
*/
export class MultipleChoiceAnswer extends UserInput {
category : ValueCategory.MultipleChoice = ValueCategory.MultipleChoice;

constructor(public value : number,
utterance : string|null,
platformData : PlatformData) {
super(utterance, platformData);
}
}

/**
* A single, naked ThingTalk value. This can be generated by the UI pickers
* (file pickers, location pickers, contact pickers, etc.), in certain uses
* of the exact matcher, and when the agent is in raw mode.
*/
export class Answer extends UserInput {
category : ValueCategory;

constructor(public value : Ast.Value,
utterance : string|null,
platformData : PlatformData) {
super(utterance, platformData);
this.category = ValueCategory.fromValue(value);
}
}

/**
* A single ThingTalk program. This can come from a single-command neural network,
* or from the user typing "\t".
*/
export class Program extends UserInput {
constructor(public program : Ast.Program,
utterance : string|null,
platformData : PlatformData) {
super(utterance, platformData);
}
}

/**
* A prediction ThingTalk dialogue state (policy, dialogue act, statements), which
* is generated by the neural network after parsing the user's input.
*/
export class DialogueState extends UserInput {
constructor(public prediction : Ast.DialogueState,
utterance : string|null,
platformData : PlatformData) {
super(utterance, platformData);
}
}
interface ThingTalkUserInput {
type : 'thingtalk';
parsed : ThingTalk.Ast.Input;
platformData : PlatformData;
}

export default UserInput;
export type UserInput = NaturalLanguageUserInput|ThingTalkUserInput;
67 changes: 45 additions & 22 deletions lib/prediction/localparserclient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
//
// Author: Giovanni Campagna <gcampagn@cs.stanford.edu>

import assert from 'assert';
import * as Tp from 'thingpedia';
import * as ThingTalk from 'thingtalk';

Expand Down Expand Up @@ -179,6 +180,11 @@ export default class LocalParserClient {

let result : PredictionCandidate[]|null = null;
let exact : string[][]|null = null;
const intent = {
command: 1,
other: 0,
ignore: 0
};

if (tokens.length === 0) {
result = [{
Expand Down Expand Up @@ -206,27 +212,37 @@ export default class LocalParserClient {
}

if (result === null) {
if (options.expect === 'Location') {
result = [{
code: ['$answer', '(', 'new', 'Location', '(', '"', ...tokens, '"', ')', ')', ';'],
score: 1
}];
} else {
if (contextCode)
contextCode = this._applyPreHeuristics(contextCode);
if (contextCode)
contextCode = this._applyPreHeuristics(contextCode);

let candidates;
if (contextCode)
candidates = await this._predictor.predict(contextCode.join(' '), tokens.join(' '), answer, NLU_TASK, options.example_id);
else
candidates = await this._predictor.predict(tokens.join(' '), undefined, answer, SEMANTIC_PARSING_TASK, options.example_id);
result = candidates.map((c) => {
return {
code: c.answer.split(' '),
score: c.score
};
});
}
let candidates;
if (contextCode)
candidates = await this._predictor.predict(contextCode.join(' '), tokens.join(' '), answer, NLU_TASK, options.example_id);
else
candidates = await this._predictor.predict(tokens.join(' '), undefined, answer, SEMANTIC_PARSING_TASK, options.example_id);
assert(candidates.length > 0);

result = candidates.map((c) => {
// convert is_correct and is_probably_correct scores into
// a single scale such that >0.5 is correct and >0.25 is
// probably correct
const score = (c.score.is_correct ?? 1) >= 0.5 ? 1 :
(c.score.is_probably_correct ?? 1) >= 0.5 ? 0.35 : 0.15;
return {
code: c.answer.split(' '),
score: score
};
});

if (candidates[0].score.is_junk >= 0.5)
intent.ignore = 1;
else
intent.ignore = 0;
if (intent.ignore < 0.5 && candidates[0].score.is_ood >= 0.5)
intent.other = 1;
else
intent.other = 0;
intent.command = 1 - intent.ignore - intent.other;
}

let result2 = result!; // guaranteed not null
Expand Down Expand Up @@ -263,11 +279,18 @@ export default class LocalParserClient {
result: 'ok',
tokens: tokens,
candidates: result2,
entities: entities
entities: entities,
intent
};
}

async generateUtterance(contextCode : string[], contextEntities : EntityMap, targetAct : string[]) : Promise<GenerationResult[]> {
return this._predictor.predict(contextCode.join(' ') + ' ' + targetAct.join(' '), NLG_QUESTION, undefined, NLG_TASK);
const candidates = await this._predictor.predict(contextCode.join(' ') + ' ' + targetAct.join(' '), NLG_QUESTION, undefined, NLG_TASK);
return candidates.map((cand) => {
return {
answer: cand.answer,
score: cand.score.confidence ?? 1
};
});
}
}
Loading

0 comments on commit 2111a8f

Please sign in to comment.