Skip to content

Commit

Permalink
Refactor how entities are generated in agent sentences (#453)
Browse files Browse the repository at this point in the history
The goal of this refactor is twofold:
- We need to generate delexicalized agent utterances at inference
  time to pass to the NLG
- We want the actual text of the entity to be generated by the I18n
  code during postprocessing, so consistent postprocessing is
  applied to both synthetic and NLG sentences, and we can use
  locale-specific code.

To apply this refactoring, the system of inference-time constants
is modified to make use of a SequentialEntityAllocator, which allocates
entity tokens (NUMBER_0, DATE_1, etc.) for each value that appears
in the context. This entity allocator is also passed to the Describe
module when that module extracts a value from the program.

At training time, we have only one SequentialEntityAllocator,
which is primed with the fake entities used during generation,
plus any entity that is generated by the simulator and observed
by Describe. At inference time, we reset the entity allocator
immediately prior to processing a new context.

The entities thus allocated are later passed to the I18n module for
postprocessing. This module gained code to automatically select
the most appropriate unit for a given size.

Fixes #447
  • Loading branch information
gcampax committed Feb 10, 2021
1 parent 21e0788 commit 316836d
Show file tree
Hide file tree
Showing 24 changed files with 600 additions and 431 deletions.
3 changes: 2 additions & 1 deletion languages/thingtalk/load-thingpedia.ts
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ export default class ThingpediaLoader {
this._ttUtils = ttUtils;
this._grammar = grammar;
this._langPack = langPack;
this._describer = new ttUtils.Describer(langPack.locale, options.timezone, options.forSide);
this._describer = new ttUtils.Describer(langPack.locale,
options.timezone, options.entityAllocator, options.forSide);

this._tpClient = options.thingpediaClient;
if (!options.schemaRetriever) {
Expand Down
20 changes: 17 additions & 3 deletions lib/dialogue-agent/abstract_dialogue_agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import assert from 'assert';
import { Ast, SchemaRetriever } from 'thingtalk';

import * as I18n from '../i18n';
import { cleanKind } from '../utils/misc-utils';
import { shouldAutoConfirmStatement } from '../utils/thingtalk';
import { contactSearch, Contact } from './entity-linking/contact_search';
Expand Down Expand Up @@ -75,11 +76,13 @@ interface ExecutionResult<PrivateStateType> {
export default abstract class AbstractDialogueAgent<PrivateStateType> {
protected _schemas : SchemaRetriever;
protected _debug : boolean;
private _langPack : I18n.LanguagePack;
locale : string;
timezone : string;

constructor(schemas : SchemaRetriever, options : AbstractDialogueAgentOptions) {
this._schemas = schemas;
this._langPack = I18n.get(options.locale);

this._debug = options.debug;
this.locale = options.locale;
Expand Down Expand Up @@ -314,7 +317,17 @@ export default abstract class AbstractDialogueAgent<PrivateStateType> {
// since dlg.locale is overwritten to be en-US, we infer the locale
// via other environment variables like LANG (language) or TZ (timezone)
if (value instanceof Ast.MeasureValue && value.unit.startsWith('default')) {
value.unit = this.getPreferredUnit(value.unit.substring('default'.length).toLowerCase());
const key = value.unit.substring('default'.length).toLowerCase();
const preference = this.getPreferredUnit(key);
if (preference)
value.unit = preference;

switch (key) {
case 'defaultTemperature':
value.unit = this._langPack.getDefaultTemperatureUnit();
default:
throw new TypeError('Unexpected default unit ' + value.unit);
}
} else if (value instanceof Ast.LocationValue && value.value instanceof Ast.UnresolvedLocation) {
slot.set(await this.lookupLocation(value.value.name, hints.previousLocations || []));
} else if (value instanceof Ast.LocationValue && value.value instanceof Ast.RelativeLocation) {
Expand Down Expand Up @@ -474,12 +487,13 @@ export default abstract class AbstractDialogueAgent<PrivateStateType> {
* Compute the user's preferred unit to use when the program specifies an ambiguous unit
* such as "degrees".
*
* @param {string} type - the type of unit to retrieve (e.g. "temperature")
* @param {string} type - the type of unit to retrieve (e.g. "temperature"), or undefined
* if the user has no preference
* @returns {string} - the preferred unit
* @abstract
* @protected
*/
protected getPreferredUnit(type : string) : string {
getPreferredUnit(type : string) : string|undefined {
throw new TypeError('Abstract method');
}
}
11 changes: 8 additions & 3 deletions lib/dialogue-agent/dialogue-loop.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import AsyncQueue from 'consumer-queue';
import { getProgramIcon } from '../utils/icons';
import { computePrediction, computeNewState, prepareContextForPrediction } from '../utils/thingtalk';
import type Engine from '../engine';
import * as I18n from '../i18n';

import ValueCategory from './value-category';
import QueueItem from './dialogue_queue';
Expand Down Expand Up @@ -58,6 +59,7 @@ export default class DialogueLoop {
engine : Engine;
private _textFormatter : TextFormatter;
private _cardFormatter : CardFormatter;
private _langPack : I18n.LanguagePack;

private _userInputQueue : AsyncQueue<UserInput>;
private _notifyQueue : AsyncQueue<QueueItem>;
Expand Down Expand Up @@ -87,6 +89,7 @@ export default class DialogueLoop {
this.conversation = conversation;
this.engine = engine;
this._prefs = engine.platform.getSharedPreferences();
this._langPack = I18n.get(engine.platform.locale);
this._textFormatter = new TextFormatter(engine.platform.locale, engine.platform.timezone, engine.schemas);
this._cardFormatter = new CardFormatter(engine.platform.locale, engine.platform.timezone, engine.schemas);
this.icon = null;
Expand Down Expand Up @@ -261,9 +264,9 @@ export default class DialogueLoop {
throw new CancellationError();
}

let expect, utterance, numResults;
let expect, utterance, entities, numResults;
if (this._useNeuralNLG()) {
[this._dialogueState, expect, , numResults] = policyResult;
[this._dialogueState, expect, , entities, numResults] = policyResult;

const policyPrediction = computeNewState(oldState, this._dialogueState, 'agent');
this.debug(`Agent act:`);
Expand All @@ -274,9 +277,11 @@ export default class DialogueLoop {

utterance = await this.conversation.generateAnswer(policyPrediction);
} else {
[this._dialogueState, expect, utterance, numResults] = policyResult;
[this._dialogueState, expect, utterance, entities, numResults] = policyResult;
}

utterance = this._langPack.postprocessNLG(utterance, entities, this._agent);

this.icon = getProgramIcon(this._dialogueState!);
await this.reply(utterance);

Expand Down
17 changes: 11 additions & 6 deletions lib/dialogue-agent/dialogue_policy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@

import assert from 'assert';
import * as Tp from 'thingpedia';
import { Ast, SchemaRetriever } from 'thingtalk';
import { Ast, SchemaRetriever, Syntax } from 'thingtalk';

import ValueCategory from './value-category';
import * as I18n from '../i18n';
import SentenceGenerator, { SentenceGeneratorOptions } from '../sentence-generator/generator';
import { AgentReplyRecord } from '../sentence-generator/types';
import * as ThingTalkUtils from '../utils/thingtalk';
import { EntityMap } from '../utils/entity-utils';

const MAX_DEPTH = 7;
const TARGET_PRUNING_SIZES = [50, 100];
Expand Down Expand Up @@ -104,13 +105,15 @@ export default class DialoguePolicy {
private _sentenceGenerator : SentenceGenerator<Ast.DialogueState|null, Ast.DialogueState, AgentReplyRecord<Ast.DialogueState>>|null;
private _generatorDevices : string[]|null;
private _generatorOptions : SentenceGeneratorOptions<Ast.DialogueState|null, Ast.DialogueState>|undefined;
private _entityAllocator : Syntax.SequentialEntityAllocator;

constructor(options : DialoguePolicyOptions) {
this._thingpedia = options.thingpedia;
this._schemas = options.schemas;
this._locale = options.locale;
this._timezone = options.timezone;
this._langPack = I18n.get(options.locale);
this._entityAllocator = new Syntax.SequentialEntityAllocator({});

this._rng = options.rng;
assert(this._rng);
Expand Down Expand Up @@ -139,16 +142,18 @@ export default class DialoguePolicy {
templateFiles: [TEMPLATE_FILE_PATH],
thingpediaClient: this._thingpedia,
schemaRetriever: this._schemas,
entityAllocator: this._entityAllocator,
onlyDevices: forDevices,
maxDepth: MAX_DEPTH,
maxConstants: 5,
targetPruningSize: TARGET_PRUNING_SIZES[0],
debug: this._debug ? 2 : 1,

contextInitializer(state, functionTable, contextTable) {
contextInitializer: (state, functionTable, contextTable) => {
// ask the target language to extract the constants from the context
this._entityAllocator.reset();
if (state !== null) {
const constants = ThingTalkUtils.extractConstants(state);
const constants = ThingTalkUtils.extractConstants(state, this._entityAllocator);
sentenceGenerator.addConstantsFromContext(constants);
}
return functionTable.context!(state, contextTable);
Expand Down Expand Up @@ -203,15 +208,15 @@ export default class DialoguePolicy {
return derivation;
}

async chooseAction(state : Ast.DialogueState|null) : Promise<[Ast.DialogueState, ValueCategory|null, string, number]|undefined> {
async chooseAction(state : Ast.DialogueState|null) : Promise<[Ast.DialogueState, ValueCategory|null, string, EntityMap, number]|undefined> {
await this._ensureGeneratorForState(state);

const derivation = this._generateDerivation(state);
if (derivation === undefined)
return derivation;

let sentence = derivation.toString();
sentence = this._langPack.postprocessSynthetic(sentence, derivation.value.state, this._rng, 'agent');
sentence = this._langPack.postprocessNLG(sentence, {});

let expect : ValueCategory|null;
if (derivation.value.end)
Expand All @@ -223,6 +228,6 @@ export default class DialoguePolicy {
if (expect === ValueCategory.RawString && !derivation.value.raw)
expect = ValueCategory.Command;

return [derivation.value.state, expect, sentence, derivation.value.numResults];
return [derivation.value.state, expect, sentence, this._entityAllocator.entities, derivation.value.numResults];
}
}
68 changes: 2 additions & 66 deletions lib/dialogue-agent/execution_dialogue_agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -388,72 +388,8 @@ export default class ExecutionDialogueAgent extends AbstractDialogueAgent<undefi
return answer;
}

protected getPreferredUnit(type : string) : string {
// const locale = dlg.locale; // this is not useful
getPreferredUnit(type : string) : string|undefined {
const pref = this._platform.getSharedPreferences();
let preferredUnit = pref.get('preferred-' + type) as string|undefined;
// e.g. defaultTemperature will get from preferred-temperature
if (preferredUnit === undefined) {
switch (type) {
case 'temperature':
preferredUnit = this._getDefaultTemperatureUnit();
break;
default:
throw new Error('Invalid default unit');
}
}
return preferredUnit;
}

private _getDefaultTemperatureUnit() : string {
// this method is quite hacky because it accounts for the fact that the locale
// is always en-US, but we don't want

let preferredUnit = 'C'; // Below code checks if we are in US
if (this._platform.type !== 'cloud' && this._platform.type !== 'android') {
const realLocale = process.env.LC_ALL || process.env.LC_MEASUREMENT || process.env.LANG || 'C';
if (realLocale.indexOf('en_US') !== -1)
preferredUnit = 'F';
} else if (this._platform.type === 'cloud') {
const realLocale = process.env.TZ || 'UTC';
// timezones obtained from http://efele.net/maps/tz/us/
const usTimeZones = [
'America/New_York',
'America/Chicago',
'America/Denver',
'America/Los_Angeles',
'America/Adak',
'America/Yakutat',
'America/Juneau',
'America/Sitka',
'America/Metlakatla',
'America/Anchrorage',
'America/Nome',
'America/Phoenix',
'America/Honolulu',
'America/Boise',
'America/Indiana/Marengo',
'America/Indiana/Vincennes',
'America/Indiana/Tell_City',
'America/Indiana/Petersburg',
'America/Indiana/Knox',
'America/Indiana/Winamac',
'America/Indiana/Vevay',
'America/Kentucky/Louisville',
'America/Indiana/Indianapolis',
'America/Kentucky/Monticello',
'America/Menominee',
'America/North_Dakota/Center',
'America/North_Dakota/New_Salem',
'America/North_Dakota/Beulah',
'America/Boise',
'America/Puerto_Rico',
'America/St_Thomas',
'America/Shiprock',
];
if (usTimeZones.indexOf(realLocale) !== -1)
preferredUnit = 'F';
}
return preferredUnit;
return pref.get('preferred-' + type) as string|undefined;
}
}
2 changes: 1 addition & 1 deletion lib/dialogue-agent/simulator/simulation_dialogue_agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ export default class SimulationDialogueAgent extends AbstractDialogueAgent<Thing
}
}

protected getPreferredUnit(type : string) : string {
getPreferredUnit(type : string) : string {
switch (type) {
case 'temperature':
if (this._interactive)
Expand Down
11 changes: 9 additions & 2 deletions lib/engine/apps/database.ts
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ export default class AppDatabase extends events.EventEmitter {
if (!description) {
// if we don't have a description already, compute one using
// the Describer
const describer = new Describer(this._platform.locale, this._platform.timezone);
const allocator = new ThingTalk.Syntax.SequentialEntityAllocator({});
const describer = new Describer(this._platform.locale, this._platform.timezone, allocator);

// retrieve the relevant primitive templates
const kinds = new Set<string>();
Expand All @@ -137,7 +138,13 @@ export default class AppDatabase extends events.EventEmitter {
// treat it as an agent sentence for purposes of postprocessing
// (which disables randomization)
// even though it is a user-side sentence (ie, it says "my")
description = langPack.postprocessNLG(langPack.postprocessSynthetic(description, program, null, 'agent'), {});
description = langPack.postprocessNLG(langPack.postprocessSynthetic(description, program, null, 'agent'), allocator.entities, {
timezone: this._platform.timezone,
getPreferredUnit: (type) => {
const pref = this._platform.getSharedPreferences();
return pref.get('preferred-' + type) as string|undefined;
}
});
}

delete options.description;
Expand Down
Loading

0 comments on commit 316836d

Please sign in to comment.