Skip to content

Commit

Permalink
Merge 021f2b5 into 5783646
Browse files Browse the repository at this point in the history
  • Loading branch information
Silei Xu committed Oct 10, 2021
2 parents 5783646 + 021f2b5 commit 87f0fa6
Show file tree
Hide file tree
Showing 58 changed files with 4,187 additions and 9,174 deletions.
2 changes: 2 additions & 0 deletions lib/dataset-tools/augmentation/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ interface DatasetAugmenterOptions {
includeQuotedExample : boolean;
cleanParameters : boolean;
requotable : boolean;
includeEntityValue : boolean;

samplingType : 'random' | 'uniform' | 'default';
subsetParamSet : [number, number];
Expand Down Expand Up @@ -100,6 +101,7 @@ export default class DatasetAugmenter extends Stream.Transform {
paraphrasingExpandFactor: this._options.paraphrasingExpandFactor,
cleanParameters: this._options.cleanParameters,
requotable: this._options.requotable,
includeEntityValue: this._options.includeEntityValue,
samplingType: this._options.samplingType,
subsetParamSet: this._options.subsetParamSet,
numAttempts: this._options.numAttempts,
Expand Down
137 changes: 82 additions & 55 deletions lib/dataset-tools/augmentation/replace_parameters.ts

Large diffs are not rendered by default.

23 changes: 20 additions & 3 deletions lib/dataset-tools/evaluation/sentence_evaluator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ type SentenceEvaluatorOptions = {
tokenized ?: boolean;
oracle ?: boolean;
complexityMetric ?: keyof typeof COMPLEXITY_METRICS;
includeEntityValue ?: boolean
ignoreEntityType ?: boolean
} & ThingTalkUtils.ParseOptions;

export interface ExampleEvaluationResult {
Expand Down Expand Up @@ -135,6 +137,8 @@ class SentenceEvaluator {
private _tokenized : boolean;
private _debug : boolean;
private _oracle : boolean;
private _includeEntityValue : boolean;
private _ignoreEntityType : boolean;
private _tokenizer : I18n.BaseTokenizer;
private _computeComplexity : ((id : string, code : string) => number)|undefined;

Expand All @@ -155,6 +159,8 @@ class SentenceEvaluator {
this._tokenized = !!options.tokenized;
this._debug = options.debug;
this._oracle = !!options.oracle;
this._includeEntityValue = !!options.includeEntityValue;
this._ignoreEntityType = !!options.ignoreEntityType;
this._tokenizer = tokenizer;

if (options.complexityMetric)
Expand Down Expand Up @@ -183,6 +189,14 @@ class SentenceEvaluator {
return false;
}

// Compares two serialized ThingTalk programs for equality.
//
// When the evaluator was created with `ignoreEntityType`, every token that starts with
// `^^` (up to the next whitespace) is rewritten to the literal `^^entity` in both programs
// before comparing, so programs that differ only in those annotations compare equal.
// Otherwise the two strings must match exactly.
private _equals(thingtalk1 : string, thingtalk2 : string) : boolean {
    if (!this._ignoreEntityType)
        return thingtalk1 === thingtalk2;

    // erase whatever follows `^^` so that only the presence of the annotation is compared
    const eraseEntityType = (code : string) : string => code.replace(/\^\^\S+/g, '^^entity');
    return eraseEntityType(thingtalk1) === eraseEntityType(thingtalk2);
}

async evaluate() : Promise<ExampleEvaluationResult|undefined> {
const result : ExampleEvaluationResult = {
id: this._id,
Expand Down Expand Up @@ -229,6 +243,7 @@ class SentenceEvaluator {
normalizedTargetCode.push(ThingTalkUtils.serializePrediction(parsed!, tokens, entities, {
locale: this._locale,
timezone: this._options.timezone,
includeEntityValue: this._includeEntityValue
}).join(' '));
} catch(e) {
// if the target_code did not parse due to missing functions in thingpedia, ignore it
Expand All @@ -251,6 +266,7 @@ class SentenceEvaluator {
normalizedTargetCode.push(ThingTalkUtils.serializePrediction(parsed!, tokens, entities, {
locale: this._locale,
timezone: this._options.timezone,
includeEntityValue: this._includeEntityValue
}).join(' '));
} catch(e) {
console.error(this._id, this._preprocessed, this._targetPrograms);
Expand Down Expand Up @@ -327,13 +343,14 @@ class SentenceEvaluator {
const normalized = ThingTalkUtils.serializePrediction(parsed, tokens, entities, {
locale: this._locale,
timezone: this._options.timezone,
ignoreSentence: true
ignoreSentence: true,
includeEntityValue: this._includeEntityValue
});
const normalizedCode = normalized.join(' ');

// check that by normalizing we did not accidentally mark wrong a program that
// was correct before
if (beamString === normalizedTargetCode[0] && normalizedCode !== normalizedTargetCode[0]) {
if (this._equals(beamString, normalizedTargetCode[0]) && !this._equals(normalizedCode, normalizedTargetCode[0])) {
console.error();
console.error('NORMALIZATION ERROR');
console.error(normalizedTargetCode[0]);
Expand All @@ -347,7 +364,7 @@ class SentenceEvaluator {
let result_string = 'ok_syntax';

for (let referenceId = 0; referenceId < this._targetPrograms.length; referenceId++) {
if (normalizedCode === normalizedTargetCode[referenceId]) {
if (this._equals(normalizedCode, normalizedTargetCode[referenceId])) {
// we have a match!

beam_ok = true;
Expand Down
8 changes: 8 additions & 0 deletions lib/i18n/english.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import { Inflectors } from 'en-inflectors';
import { Tag } from 'en-pos';
import * as lexicon from 'en-lexicon';

import { coin } from '../utils/random';
import { Phrase } from '../utils/template-string';
Expand Down Expand Up @@ -135,6 +136,13 @@ function indefiniteArticle(word : string) {
export default class EnglishLanguagePack extends DefaultLanguagePack {
protected _tokenizer : EnglishTokenizer|undefined;

// Creates the English language pack; `locale` is forwarded unchanged to the
// DefaultLanguagePack constructor.
constructor(locale : string) {
super(locale);

// the pos tagger will crash without this lexicon extension
// (this registers an entry for the word "constructor" with the tag 'NN' in en-lexicon)
// NOTE(review): presumably the crash comes from looking up the word "constructor" in a
// plain object and hitting the inherited Object.prototype.constructor — confirm
lexicon.extend({ constructor: 'NN' });
}

getTokenizer() : EnglishTokenizer {
if (this._tokenizer)
return this._tokenizer;
Expand Down
10 changes: 10 additions & 0 deletions lib/pos-parser/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,16 @@ export default class PosParser {
for (const template of this.queryTemplates[pos]) {
const match = template.match(utterance, domainCanonicals, value);
if (match && !match.includes('$domain') && match.split(' ').length - 1 < MAX_LENGTH) {
// FIXME: capture these in templates
// skip matches containing punctuation marks that always introduce a break in the utterance
if (/[,.!?:]/.test(match))
continue;
// skip reverse property that contains a pronoun
if (pos === 'reverse_property') {
const tokens = match.split(' ');
if (tokens.includes('it') || tokens.includes('that') || tokens.includes('this'))
continue;
}
if (pos === 'verb' && match.startsWith('$value ')) {
return [
{ pos, canonical: match },
Expand Down
2 changes: 1 addition & 1 deletion lib/pos-parser/nfa.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class State {
// Creates a state: takes the next value of `stateCounter` as its id, records whether it
// is an end state (`isEnd`, false by default) and starts with no transitions.
constructor(isEnd = false) {
this.id = stateCounter++;
this.isEnd = isEnd;
// NOTE(review): this assignment is immediately overwritten by the next line
this.transitions = {};
// prototype-less object: it has no inherited keys such as "constructor" or "toString",
// presumably so that tokens with those names are not mistaken for existing transitions — confirm
this.transitions = Object.create(null);
}

addTransition(token : string, to : State, capturing = false) {
Expand Down
40 changes: 24 additions & 16 deletions lib/templates/ast_manip.ts
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ function makeEdgeFilterStream(loader : ThingpediaLoader,
ptype: proj.schema!.getArgType(args[0])!,
ast: new Ast.BooleanExpression.Atom(null, args[0], op, value)
};
if (!checkFilter(proj.expression, f))
if (!checkFilter(loader, proj.expression, f))
return null;
if (!proj.schema!.is_monitorable || proj.schema!.is_list)
return null;
Expand Down Expand Up @@ -710,7 +710,7 @@ export function toChainExpression(expr : Ast.Expression) {

function makeProgram(loader : ThingpediaLoader,
rule : Ast.Expression) : Ast.Program|null {
if (!checkValidQuery(rule))
if (!loader.flags.no_soft_match_id && !checkValidQuery(rule))
return null;
const chain = toChainExpression(rule);
if (chain.first.schema!.functionType === 'stream' && loader.flags.nostream)
Expand Down Expand Up @@ -792,7 +792,7 @@ function checkComputeFilter(table : Ast.Expression, filter : Ast.ComputeBooleanE
return filter.rhs.getType().equals(vtype);
}

function checkAtomFilter(table : Ast.Expression, filter : Ast.AtomBooleanExpression) : boolean {
function checkAtomFilter(loader : ThingpediaLoader, table : Ast.Expression, filter : Ast.AtomBooleanExpression) : boolean {
const arg = table.schema!.getArgument(filter.name);
if (!arg || arg.is_input)
return false;
Expand Down Expand Up @@ -829,14 +829,20 @@ function checkAtomFilter(table : Ast.Expression, filter : Ast.AtomBooleanExpress
}

let typeMatch = false;
const valueType = filter.value.getType();
const parentTypes = valueType instanceof Type.Entity ? loader.entitySubTypeMap[valueType.type] || [] : [];
for (const type of vtypes) {
if (filter.value.getType().equals(type))
if (valueType.equals(type)) {
typeMatch = true;
break;
} else if (type instanceof Type.Entity && parentTypes.includes(type.type)) {
typeMatch = true;
break;
}
}
if (!typeMatch)
return false;


if (vtype.isNumber || vtype.isMeasure) {
let min = -Infinity;
const minArg = arg.getImplementationAnnotation<number>('min_number');
Expand All @@ -858,7 +864,7 @@ function checkAtomFilter(table : Ast.Expression, filter : Ast.AtomBooleanExpress
return true;
}

function internalCheckFilter(table : Ast.Expression, filter : Ast.BooleanExpression) : boolean {
function internalCheckFilter(loader : ThingpediaLoader, table : Ast.Expression, filter : Ast.BooleanExpression) : boolean {
while (table instanceof Ast.ProjectionExpression)
table = table.expression;

Expand All @@ -869,7 +875,7 @@ function internalCheckFilter(table : Ast.Expression, filter : Ast.BooleanExpress
if (filter instanceof Ast.AndBooleanExpression ||
filter instanceof Ast.OrBooleanExpression) {
for (const operands of filter.operands) {
if (!internalCheckFilter(table, operands))
if (!internalCheckFilter(loader, table, operands))
return false;
}
return true;
Expand All @@ -879,7 +885,7 @@ function internalCheckFilter(table : Ast.Expression, filter : Ast.BooleanExpress
return checkComputeFilter(table, filter);

if (filter instanceof Ast.AtomBooleanExpression)
return checkAtomFilter(table, filter);
return checkAtomFilter(loader, table, filter);

if (filter instanceof Ast.DontCareBooleanExpression) {
const arg = table.schema!.getArgument(filter.name);
Expand All @@ -893,10 +899,10 @@ function internalCheckFilter(table : Ast.Expression, filter : Ast.BooleanExpress
throw new Error(`Unexpected filter type ${filter}`);
}

function checkFilter(table : Ast.Expression, filter : FilterSlot|DomainIndependentFilterSlot) : boolean {
function checkFilter(loader : ThingpediaLoader, table : Ast.Expression, filter : FilterSlot|DomainIndependentFilterSlot) : boolean {
if (filter.schema !== null && !isSameFunction(table.schema!, filter.schema))
return false;
return internalCheckFilter(table, filter.ast);
return internalCheckFilter(loader, table, filter.ast);
}

function* iterateFilters(table : Ast.Expression) : Generator<[Ast.FunctionDef, Ast.BooleanExpression], void> {
Expand Down Expand Up @@ -1063,10 +1069,11 @@ function addFilterInternal(table : Ast.Expression,
return new Ast.FilterExpression(null, table, filter, schema);
}

function addFilter(table : Ast.Expression,
function addFilter(loader : ThingpediaLoader,
table : Ast.Expression,
filter : FilterSlot|DomainIndependentFilterSlot,
options : AddFilterOptions = {}) : Ast.Expression|null {
if (!checkFilter(table, filter))
if (!checkFilter(loader, table, filter))
return null;

return addFilterInternal(table, filter.ast, options);
Expand Down Expand Up @@ -1531,7 +1538,8 @@ function makeComputeExpression(table : Ast.Expression,
return new Ast.ProjectionExpression(null, table, [], [expression], [null], resolveProjection(table.schema!, [], [expression]));
}

function makeComputeFilterExpression(table : Ast.Expression,
function makeComputeFilterExpression(loader : ThingpediaLoader,
table : Ast.Expression,
operation : 'distance',
operands : Ast.Value[],
resultType : Type,
Expand All @@ -1551,10 +1559,10 @@ function makeComputeFilterExpression(table : Ast.Expression,
ptype: expression.type,
ast: new Ast.BooleanExpression.Compute(null, expression, filterOp, filterValue)
};
return addFilter(table, filter);
return addFilter(loader, table, filter);
}

function makeWithinGeoDistanceExpression(table : Ast.Expression, location : Ast.Value, filterValue : Ast.Value) : Ast.Expression|null {
function makeWithinGeoDistanceExpression(loader : ThingpediaLoader, table : Ast.Expression, location : Ast.Value, filterValue : Ast.Value) : Ast.Expression|null {
const arg = table.schema!.getArgument('geo');
if (!arg || !arg.type.isLocation)
return null;
Expand All @@ -1570,7 +1578,7 @@ function makeWithinGeoDistanceExpression(table : Ast.Expression, location : Ast.
// the distance should be at least 100 meters (if the value is small number)
if (filterValue instanceof Ast.MeasureValue && Units.transformToBaseUnit(filterValue.value, unit) < 100)
return null;
return makeComputeFilterExpression(table, 'distance', [new Ast.Value.VarRef('geo'), location], new Type.Measure('m'), '<=', filterValue);
return makeComputeFilterExpression(loader, table, 'distance', [new Ast.Value.VarRef('geo'), location], new Type.Measure('m'), '<=', filterValue);
}

function makeComputeArgMinMaxExpression(table : Ast.Expression,
Expand Down
8 changes: 4 additions & 4 deletions lib/templates/commands.genie
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ forward_when_do_rule : Ast.ChainExpression = {
// no pp
stream:stream action:complete_action => C.makeChainExpression(stream, action);
stream:stream action:complete_action 'if' filter:if_filter with { functionName = stream.functionName } => {
const newStream = C.addFilter(stream, filter);
const newStream = C.addFilter($loader, stream, filter);
if (!newStream)
return null;
return C.makeChainExpression(newStream, action);
Expand Down Expand Up @@ -384,7 +384,7 @@ explicit_when_condition : Ast.Expression = {
return null;
if ($loader.flags.turking && table.schema!.is_list)
return null;
const withFilter = C.addFilter(table, filter, { ifFilter: true });
const withFilter = C.addFilter($loader, table, filter, { ifFilter: true });
if (!withFilter)
return null;
return C.tableToStream(withFilter);
Expand All @@ -395,7 +395,7 @@ explicit_when_condition : Ast.Expression = {
return null;
if (!table.schema!.is_monitorable || table.schema!.is_list)
return null;
const withFilter = C.addFilter(table, filter, { ifFilter: true });
const withFilter = C.addFilter($loader, table, filter, { ifFilter: true });
if (!withFilter)
return null;
return C.tableToStream(withFilter);
Expand Down Expand Up @@ -450,7 +450,7 @@ monitor_command : Ast.Expression = {
return null;
if (table.schema!.is_list || !table.schema!.is_monitorable)
return null;
const withFilter = C.addFilter(table, filter);
const withFilter = C.addFilter($loader, table, filter);
if (!withFilter)
return null;
return C.tableToStream(withFilter);
Expand Down
2 changes: 1 addition & 1 deletion lib/templates/dialogue_acts/empty-search.ts
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ export function impreciseEmptySearchChangeRequest(ctx : ContextInfo,
return null;
if (answerFilter.name !== param.name)
return null;
if (!C.checkFilter(base, answerFilter))
if (!C.checkFilter(ctx.loader, base, answerFilter))
return null;

return emptySearchChangePhraseCommon(ctx, answerFilter, refineFilterForEmptySearch);
Expand Down
2 changes: 1 addition & 1 deletion lib/templates/dialogue_acts/search-questions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ function impreciseSearchQuestionAnswer(ctx : ContextInfo, answer : C.FilterSlot|

const currentStmt = ctx.current!.stmt;
const currentTable = currentStmt.expression;
if (!C.checkFilter(currentTable, answerFilter))
if (!C.checkFilter(ctx.loader, currentTable, answerFilter))
return null;

const newTable = queryRefinement(currentTable, answerFilter.ast, refineFilterToAnswerQuestion, null);
Expand Down
Loading

0 comments on commit 87f0fa6

Please sign in to comment.