Skip to content

Commit

Permalink
Merge pull request #442 from stanford-oval/wip/simulate-with-parser
Browse files Browse the repository at this point in the history
`simulate-dialogs` with parser
  • Loading branch information
s-jse committed Feb 5, 2021
2 parents 63974d5 + b37b82b commit d418bae
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 23 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ lib/engine/devices/builtins/faq.json

workdir*
genienlp/
.embeddings/

env/

Expand Down
15 changes: 10 additions & 5 deletions languages/thingtalk/en/dlg/results.genie
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@ result_name_pair : D.NameList = {
const r1 = results[0];
const r2 = results[1];
assert(r1.value.id.equals(n1));
assert(r2.value.id.equals(n2));
if (!r1.value.id.equals(n1)) // can happen if the types are different but the value is the same
return null;
if (!r2.value.id.equals(n2))
return null;
return { ctx, results: results.slice(0, 2) };
};
}
Expand All @@ -53,8 +55,10 @@ result_name_list : D.NameList = {

const r1 = results[0];
const r2 = results[1];
assert(r1.value.id.equals(n1));
assert(r2.value.id.equals(n2));
if (!r1.value.id.equals(n1))
return null;
if (!r2.value.id.equals(n2))
return null;
return { ctx, results: results.slice(0, 2) };
};

Expand All @@ -64,7 +68,8 @@ result_name_list : D.NameList = {
assert(results && results.length >= 3);
assert(ctx.resultInfo!.idType !== null);
const r3 = results[2];
assert(r3.value.id.equals(n3));
if (!r3.value.id.equals(n3))
return null;
return { ctx, results: results.slice(0, 3) };
};
}
Expand Down
2 changes: 1 addition & 1 deletion lib/utils/entity-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ function makeDummyEntity(token : string) : AnyEntity {
function makeDummyEntities(preprocessed : string) : EntityMap {
const entities : EntityMap = {};
for (const token of preprocessed.split(' ')) {
if (/^[A-Z]/.test(token))
if (ENTITY_MATCH_REGEX.test(token))
entities[token] = makeDummyEntity(token);
}
return entities;
Expand Down
178 changes: 161 additions & 17 deletions tool/simulate-dialogs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@ import {
DialogueExample,
} from '../lib/dataset-tools/parsers';
import DialoguePolicy from '../lib/dialogue-agent/dialogue_policy';
import * as ParserClient from '../lib/prediction/parserclient';

import { readAllLines } from './lib/argutils';
import MultiJSONDatabase from './lib/multi_json_database';
import { PredictionResult } from '../lib/prediction/parserclient';

export function initArgparse(subparsers : argparse.SubParser) {
const parser = subparsers.add_parser('simulate-dialogs', {
Expand Down Expand Up @@ -80,36 +82,110 @@ export function initArgparse(subparsers : argparse.SubParser) {
type: fs.createReadStream,
help: 'Input dialog file'
});
parser.add_argument('--nlu-server', {
required: false,
help: `The URL of the natural language server to parse user utterances. Use a file:// URL pointing to a model directory to use a local instance of genienlp.
If provided, will be used to parse the last user utterance instead of reading the parse from input_file.`
});
parser.add_argument('--output-mistakes-only', {
action: 'store_true',
help: 'If set and --nlu-server is provided, will only output partial dialogues where a parsing mistake happens.',
default: false
});
parser.add_argument('--all-turns', {
action: 'store_true',
help: `If set, will run simulation on all dialogue turns as opposed to only the last turn (but still for one turn only).
The output will have as many partial dialogues as there are dialogue turns in the input.`,
default: false
});
}

class SimulatorStream extends Stream.Transform {
private _simulator : ThingTalkUtils.Simulator;
private _schemas : ThingTalk.SchemaRetriever;
private _dialoguePolicy : DialoguePolicy;
private _parser : ParserClient.ParserClient | null;
private _tpClient : Tp.BaseClient;
private _outputMistakesOnly : boolean;
private _locale : string;

constructor(policy : DialoguePolicy,
simulator : ThingTalkUtils.Simulator,
schemas : ThingTalk.SchemaRetriever) {
schemas : ThingTalk.SchemaRetriever,
parser : ParserClient.ParserClient | null,
tpClient : Tp.BaseClient,
outputMistakesOnly : boolean,
locale : string) {
super({ objectMode : true });

this._dialoguePolicy = policy;
this._simulator = simulator;
this._schemas = schemas;
this._parser = parser;
this._tpClient = tpClient;
this._outputMistakesOnly = outputMistakesOnly;
this._locale = locale;
}

async _run(dlg : ParsedDialogue) : Promise<DialogueExample> {
async _run(dlg : ParsedDialogue) : Promise<void> {
console.log('dialogue = ', dlg.id);
const lastTurn = dlg[dlg.length-1];

let state = null;
let contextCode, contextEntities;
if (lastTurn.context) {
const context = await ThingTalkUtils.parse(lastTurn.context, this._schemas);
assert(context instanceof ThingTalk.Ast.DialogueState);
const agentTarget = await ThingTalkUtils.parse(lastTurn.agent_target!, this._schemas);
assert(agentTarget instanceof ThingTalk.Ast.DialogueState);
state = ThingTalkUtils.computeNewState(context, agentTarget, 'agent');
[contextCode, contextEntities] = ThingTalkUtils.serializeNormalized(ThingTalkUtils.prepareContextForPrediction(state, 'user'));
} else {
contextCode = ['null'];
contextEntities = {};
}

const userTarget = await ThingTalkUtils.parse(lastTurn.user_target, this._schemas);
let userTarget : ThingTalk.Ast.Input;
const goldUserTarget = await ThingTalkUtils.parse(lastTurn.user_target, this._schemas);
if (this._parser !== null) {
const parsed : PredictionResult = await this._parser.sendUtterance(lastTurn.user, contextCode, contextEntities, {
tokenized: false,
skip_typechecking: true
});

const candidates = await ThingTalkUtils.parseAllPredictions(parsed.candidates, parsed.entities, {
thingpediaClient: this._tpClient,
schemaRetriever: this._schemas,
loadMetadata: true
}) as ThingTalk.Ast.DialogueState[];

if (candidates.length > 0) {
userTarget = candidates[0];
} else {
console.log(`No valid candidate parses for this command. Top candidate was ${parsed.candidates[0].code.join(' ')}. Using the gold UT`);
userTarget = goldUserTarget;
}
const normalizedUserTarget : string = ThingTalkUtils.serializePrediction(userTarget, parsed.tokens, parsed.entities, {
locale: this._locale,
ignoreSentence: true
}).join(' ');
const normalizedGoldUserTarget : string = ThingTalkUtils.serializePrediction(goldUserTarget, parsed.tokens, parsed.entities, {
locale: this._locale,
ignoreSentence: true
}).join(' ');

// console.log('normalizedUserTarget = ', normalizedUserTarget)
// console.log('normalizedGoldUserTarget = ', normalizedGoldUserTarget)

if (normalizedUserTarget === normalizedGoldUserTarget && this._outputMistakesOnly) {
// don't push anything
return;
}
dlg[dlg.length-1].user_target = normalizedUserTarget;

} else {
userTarget = goldUserTarget;
}
assert(userTarget instanceof ThingTalk.Ast.DialogueState);
state = ThingTalkUtils.computeNewState(state, userTarget, 'user');

Expand All @@ -125,23 +201,70 @@ class SimulatorStream extends Stream.Transform {
user_target: ''
};

const policyResult = await this._dialoguePolicy.chooseAction(state);
if (!policyResult)
throw new Error(`Dialogue policy error: no reply for dialogue ${dlg.id}`);
let policyResult;
try {
policyResult = await this._dialoguePolicy.chooseAction(state);
} catch(error) {
console.log(`Error while choosing action: ${error.message}. skipping.`);
return;
}
if (!policyResult) {
// throw new Error(`Dialogue policy error: no reply for dialogue ${dlg.id}`);
console.log(`Dialogue policy error: no reply for dialogue ${dlg.id}. skipping.`);
return;
}
const [dialogueStateAfterAgent, , utterance] = policyResult;

const prediction = ThingTalkUtils.computePrediction(state, dialogueStateAfterAgent, 'agent');
newTurn.agent = utterance;
newTurn.agent_target = prediction.prettyprint();

return {
this.push({
id: dlg.id,
turns: dlg.concat([newTurn])
};
});
}

_transform(dlg : ParsedDialogue, encoding : BufferEncoding, callback : (err : Error|null, dlg ?: DialogueExample) => void) {
this._run(dlg).then((dlg) => callback(null, dlg), callback);
this._run(dlg).then(() => callback(null), callback);
}

/**
 * End-of-stream hook. No extra work is done here: the completion
 * callback is invoked immediately.
 *
 * @param callback invoked to signal that flushing is finished
 */
_flush(callback : () => void) {
callback();
}
}

class DialogueToPartialDialoguesStream extends Stream.Transform {

/**
 * Create the transform with `objectMode` enabled, so that chunks flowing
 * through it are JavaScript objects rather than Buffers or strings.
 */
constructor() {
super({ objectMode : true });
}

/**
 * Duplicate a list of dialogue turns.
 *
 * A brand new object is allocated for every turn and the six turn fields
 * (`context`, `agent`, `agent_target`, `intermediate_context`, `user`,
 * `user_target`) are carried over as-is. Reassigning a field on a returned
 * turn therefore does not affect the corresponding input turn.
 *
 * @param turns the turns to duplicate
 * @returns a new plain array with one newly allocated turn per input turn, in the same order
 */
private _copyDialogueTurns(turns : DialogueTurn[]) : DialogueTurn[] {
    return Array.from(turns, (turn) => {
        const { context, agent, agent_target, intermediate_context, user, user_target } = turn;
        return { context, agent, agent_target, intermediate_context, user, user_target };
    });
}

/**
 * Expand one dialogue into all of its non-empty prefixes and push each one downstream.
 *
 * For a dialogue with N turns, N partial dialogues are pushed: the first holds
 * turn 1 only, the second turns 1-2, and so on up to the full dialogue. Each
 * partial dialogue is given the id `<original id>-turn_<number of turns>`.
 *
 * @param dlg the parsed dialogue to expand
 */
async _run(dlg : ParsedDialogue) : Promise<void> {
    for (let turnCount = 1; turnCount <= dlg.length; turnCount++) {
        // every prefix gets its own turn objects, so downstream stages
        // may rewrite turns of one partial dialogue without touching the others
        const prefix = dlg.slice(0, turnCount);
        const partial = this._copyDialogueTurns(prefix);
        const partialId = dlg.id + '-turn_' + turnCount;
        (partial as ParsedDialogue).id = partialId;
        this.push(partial);
    }
}

/**
 * Stream hook called once per incoming dialogue.
 *
 * Delegates to `_run`, which pushes its outputs itself; the callback is
 * therefore invoked with no chunk on success, or with the rejection reason
 * if `_run` fails.
 *
 * @param dlg the incoming parsed dialogue
 * @param encoding not used by this method
 * @param callback completion callback supplied by the stream machinery
 */
_transform(dlg : ParsedDialogue, encoding : BufferEncoding, callback : (err : Error|null, dlgs ?: ParsedDialogue) => void) {
    const signalDone = () => callback(null);
    const pending = this._run(dlg);
    pending.then(signalDone, callback);
}

_flush(callback : () => void) {
Expand Down Expand Up @@ -176,11 +299,32 @@ export async function execute(args : any) {
debug: false
});

await StreamUtils.waitFinish(
readAllLines(args.input_file, '====')
.pipe(new DialogueParser())
.pipe(new SimulatorStream(policy, simulator, schemas))
.pipe(new DialogueSerializer())
.pipe(args.output)
);
let parser = null;
if (args.nlu_server){
parser = ParserClient.get(args.nlu_server, args.locale);
await parser.start();
}

if (args.all_turns) {
await StreamUtils.waitFinish(
readAllLines(args.input_file, '====')
.pipe(new DialogueParser())
.pipe(new DialogueToPartialDialoguesStream()) // convert each dialogues to many partial dialogues
.pipe(new SimulatorStream(policy, simulator, schemas, parser, tpClient, args.output_mistakes_only, args.locale))
.pipe(new DialogueSerializer())
.pipe(args.output)
);
} else {
await StreamUtils.waitFinish(
readAllLines(args.input_file, '====')
.pipe(new DialogueParser())
.pipe(new SimulatorStream(policy, simulator, schemas, parser, tpClient, args.output_mistakes_only, args.locale))
.pipe(new DialogueSerializer())
.pipe(args.output)
);
}


if (parser !== null)
await parser.stop();
}

0 comments on commit d418bae

Please sign in to comment.