diff --git a/test/data/en-US/mock-thingpedia/constants.tsv b/test/data/en-US/mock-thingpedia/constants.tsv new file mode 100644 index 000000000..d19465beb --- /dev/null +++ b/test/data/en-US/mock-thingpedia/constants.tsv @@ -0,0 +1,2 @@ +param:@mock.device.packages:items:Entity(mock.device:items) apple Apple +param:@mock.device.packages:items:Entity(mock.device:items) banana Banana diff --git a/test/data/en-US/mock-thingpedia/mock.device/manifest.tt b/test/data/en-US/mock-thingpedia/mock.device/manifest.tt new file mode 100644 index 000000000..60cdd3443 --- /dev/null +++ b/test/data/en-US/mock-thingpedia/mock.device/manifest.tt @@ -0,0 +1,53 @@ +class @mock.device +#_[thingpedia_name="mock"] +#_[thingpedia_description="mock"] +#_[canonical="mock"] +#[license="BSD-3-Clause"] +#[license_gplcompatible=true] +#[subcategory="service"] +{ + import loader from @org.thingpedia.v2(); + import config from @org.thingpedia.config.none(); + + entity items #_[description="Packaged items"]; + + query setting(out state : Enum(off, on) + #_[canonical=["status"]], + out value : Number + #_[canonical=["value"]]) + #_[canonical="setting"] + #[minimal_projection=["state", "value"]]; + + query person(out name : String + #_[canonical=["name"]] + #[string_values="tt:person_first_name"] + #[filterable=false]) + #_[canonical=["person"]] + #[minimal_projection=[]]; + + query machine(out speed : Measure(mph) + #_[canonical=["speed"]]) + #_[canonical=["machine"]] + #[minimal_projection=[]]; + + query website(out url : Entity(tt:url) + #[filterable=false] + #_[canonical="link"]) + #_[canonical=["website"]] + #[minimal_projection=[]]; + + query packages(out fruits : Array(Entity(mock.device:fruits)) + #_[canonical=["items"]]) + #_[canonical=["packages"]] + #[minimal_projection=[]]; + + query base_station(out geo : Location + #_[canonical=["location"]]) + #_[canonical=["base station"]] + #[minimal_projection=[]]; + + query contact(out phone : Entity(tt:phone_number) + #_[canonical=["phone number"]]) + #_[canonical=["customer support"]] + #[minimal_projection=[]]; +} \ No newline at end of file diff --git a/test/data/en-US/mock-thingpedia/test/test.tsv b/test/data/en-US/mock-thingpedia/test/test.tsv new file mode 100644 index 000000000..e69de29bb diff --git a/test/unit/test_sample_synthetic_data.js b/test/unit/test_sample_synthetic_data.js new file mode 100644 index 000000000..34cddbe11 --- /dev/null +++ b/test/unit/test_sample_synthetic_data.js @@ -0,0 +1,126 @@ +// -*- mode: js; indent-tabs-mode: nil; js-basic-offset: 4 -*- +// +// This file is part of ThingTalk +// +// Copyright 2017-2020 The Board of Trustees of the Leland Stanford Junior University +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Author: Jake Wu + +import assert from 'assert'; +import * as ThingTalk from 'thingtalk'; +import * as Tp from 'thingpedia'; +import * as I18n from '../../lib/i18n'; +import * as Path from 'path'; +import sampler from '../../tool/sample-synthetic-data'; +import { ArgumentParser } from 'argparse'; +// const Type = ThingTalk.Type; + +const TEST_CASES = [ + + // query, utterance, thingtalk + ['setting', 'what is the status of the setting ?', '[ state ] of @mock.device . setting ( ) ;'], + ['setting', `which setting has status {0} ?`, '@mock.device . setting ( ) filter state == enum {0} ;'], + ['person', 'what is the name of the person ?', '[ name ] of @mock.device . person ( ) ;'], + ['machine', 'what is the speed of the machine ?', '[ speed ] of @mock.device . machine ( ) ;'], + ['machine', 'which machine has speed {0} metre per second ?', '@mock.device . machine ( ) filter speed == {0} mps ;'], + ['website', "what is the website 's link ?", '[ url ] of @mock.device . website ( ) ;'], + ['packages', 'what items does the packages have ?', '[ fruits ] of @mock.device . packages ( ) ;'], + ['base_station', 'what is the location of the base station ?', '[ geo ] of @mock.device . base_station ( ) ;'], + ['base_station', 'show me a base station with location {0} .', '@mock.device . base_station ( ) filter geo == new Location ( " {0} " ) ;'], + ['contact', "what phone number does the customer support have ?", '[ phone ] of @mock.device . contact ( ) ;'] +]; + +String.prototype.format = function() { + const args = arguments; + return this.replace(/{([0-9]+)}/g, (match, index) => { + return typeof args[index] === 'undefined' ? match : args[index]; + }); +}; + +function initArgparse() { + const parser = new ArgumentParser({ + description: 'Unit test synthetic data sampler argparser', + add_help: true + }); + parser.add_argument('-l', '--locale', { + default: 'en-US', + help: `BGP 47 locale tag of the natural language being processed (defaults to en-US).` + }); + parser.add_argument('-c', '--constants', { + required: false, + default: Path.resolve(Path.dirname(module.filename), '../data/en-US/mock-thingpedia/constants.tsv'), + help: 'TSV file containing sampled constant values to be used.' + }); + parser.add_argument('-d', '--device', { + required: false, + default: 'mock.device', + help: `The name of the device to be synthesized.` + }); + parser.add_argument('-s', '--sampleSize', { + required: false, + default: 1, + help: `The number of samples to be synthesized per annotation.` + }); + parser.add_argument('-f', '--function', { + required: false, + help: `A specific function to be sampled.` + }); + return parser; +} + +export default async function main() { + let anyFailed = false; + const parser = initArgparse(); + const args = parser.parse_args(); + const tpClient = new Tp.FileClient({ + locale: 'en', + thingpedia: Path.resolve(Path.dirname(module.filename), '../data/en-US/mock-thingpedia/mock.device/manifest.tt') + }); + const schemaRetriever = new ThingTalk.SchemaRetriever(tpClient, null, true); + const deviceClass = await schemaRetriever.getFullSchema(args.device); + const baseTokenizer = I18n.get(args.locale).getTokenizer(); + for (let [query, utterance, thingtalk] of TEST_CASES) { + args.function = query; + const ret = await sampler(deviceClass, baseTokenizer, args); + const item = ret.filter((x) => { + if (typeof x.value !== undefined) + utterance = utterance.format(x.value); + return x.utterance.toLowerCase() === utterance.toLowerCase(); + }); + try { + assert(item.length === 1); + if (typeof item[0].value !== undefined) { + utterance = utterance.format(item[0].value); + thingtalk = thingtalk.format(item[0].value); + } + assert.deepStrictEqual(item[0].query, query); + assert.deepStrictEqual(item[0].utterance.toLowerCase(), utterance.toLowerCase()); + assert.deepStrictEqual(item[0].thingtalk.toLowerCase(), thingtalk.toLowerCase()); + } catch(e) { + console.error(`Test case "${query}" failed`); + console.error(`${item[0].utterance} :: ${utterance}`); + console.error(`${item[0].thingtalk} :: ${thingtalk}`); + console.error(e); + anyFailed = true; + } + } + if (anyFailed) + throw new Error('Some test failed'); + else + process.stdout.write('{0}/{1} Passed!\n'.format(TEST_CASES.length, TEST_CASES.length)); +} + +if (!module.parent) + main(); diff --git a/tool/autoqa/wikidata/utils.ts b/tool/autoqa/wikidata/utils.ts index a63f5213c..f10fcb8ae 100644 --- a/tool/autoqa/wikidata/utils.ts +++ b/tool/autoqa/wikidata/utils.ts @@ -33,7 +33,7 @@ const Type = ThingTalk.Type; const _cache = new Map(); -const WikidataUnitToTTUnit : Record = { +export const WikidataUnitToTTUnit : Record = { // time 'millisecond': 'ms', 'second': 's', diff --git a/tool/sample-synthetic-data.ts b/tool/sample-synthetic-data.ts index 9f6af2408..a0f0678c6 100644 --- a/tool/sample-synthetic-data.ts +++ b/tool/sample-synthetic-data.ts @@ -23,10 +23,12 @@ import * as fs from 'fs'; import * as Tp from 'thingpedia'; import * as I18n from '../lib/i18n'; import * as utils from '../lib/utils/misc-utils'; -import { Ast, Type, SchemaRetriever } from 'thingtalk'; -import { ParaphraseExample, generateExamples as generateQueryExamples } from './autoqa/lib/canonical-example-constructor'; +import { Ast, Type, SchemaRetriever, Builtin } from 'thingtalk'; +import { ParaphraseExample } from './autoqa/lib/canonical-example-constructor'; +// import { generateExamples as generateQueryExamples } from './autoqa/lib/canonical-example-constructor'; import { parseConstantFile } from './lib/constant-file'; import { getElementType } from './autoqa/wikidata/utils'; +import { WikidataUnitToTTUnit } from './autoqa/wikidata/utils'; import { makeLookupKeys } from '../lib/dataset-tools/mturk/sample-utils'; import { PARTS_OF_SPEECH, @@ -38,6 +40,8 @@ import genBaseCanonical from './autoqa/lib/base-canonical-generator'; import { serializePrediction } from '../lib/utils/thingtalk'; import { EntityUtils } from '../lib'; import Path = require('path'); +import { Temporal } from '@js-temporal/polyfill'; +// import { ResultGenerator } from '../lib/dialogue-agent/simulator/simulation_exec_environment'; interface NewParaphraseExample extends ParaphraseExample { thingtalk : string @@ -50,6 +54,209 @@ interface Constant { unit ?: string; } +const CITIES = [ + 'New York', + 'Los Angeles', + 'Chicago', + 'Houston', + 'Phoenix', + 'San Antonio', + 'Philadelphia', + 'San Diego', + 'Dallas', + 'Austin', + 'San Jose', + 'Fort Worth', + 'Jacksonville', + 'Charlotte', + 'Columbus', + 'Indianapolis', + 'San Francisco', + 'Seattle', + 'Denver', + 'Washington', + 'Boston', + 'El Paso', + 'Nashville', + 'Oklahoma City', + 'Las Vegas', + 'Portland', + 'Detroit', + 'Memphis', + 'Louisville', + 'Milwaukee', + 'Baltimore', + 'Albuquerque', + 'Tucson', + 'Mesa', + 'Fresno', + 'Atlanta', + 'Sacramento', + 'Kansas City', + 'Colorado Springs', + 'Raleigh', + 'Miami', + 'Omaha', + 'Long Beach', + 'Virginia Beach', + 'Oakland', + 'Minneapolis', + 'Tampa', + 'Tulsa', + 'Arlington', + 'Aurora', + 'Wichita', + 'Bakersfield', + 'New Orleans', + 'Cleveland', + 'Henderson', + 'Anaheim', + 'Honolulu', + 'Riverside', + 'Santa Ana', + 'Corpus Christi', + 'Lexington', + 'San Juan', + 'Stockton', + 'St. Paul', + 'Cincinnati', + 'Irvine', + 'Greensboro', + 'Pittsburgh', + 'Lincoln', + 'Durham', + 'Orlando', + 'St. Louis', + 'Chula Vista', + 'Plano', + 'Newark', + 'Anchorage', + 'Fort Wayne', + 'Chandler', + 'Reno', + 'North Las Vegas', + 'Scottsdale', + 'St. Petersburg', + 'Laredo', + 'Gilbert', + 'Toledo', + 'Lubbock', + 'Madison', + 'Glendale', + 'Jersey City', + 'Buffalo', + 'Chesapeake', + 'Winston-Salem', + 'Fremont', + 'Norfolk', + 'Frisco', + 'Paradise', + 'Irving', + 'Garland', + 'Richmond', + 'Arlington', + 'Boise', + 'Spokane', + 'Hialeah', + 'Moreno Valley', + 'Tacoma', + 'Port St. Lucie', + 'McKinney', + 'Fontana', + 'Modesto', + 'Fayetteville', + 'Baton Rouge', + 'San Bernardino', + 'Santa Clarita', + 'Cape Coral', + 'Des Moines', + 'Tempe', + 'Huntsville', + 'Oxnard', + 'Spring Valley', + 'Birmingham', + 'Rochester', + 'Overland Park', + 'Grand Rapids', + 'Yonkers', + 'Salt Lake City', + 'Columbus', + 'Augusta', + 'Amarillo', + 'Tallahassee', + 'Ontario', + 'Montgomery', + 'Little Rock', + 'Akron', + 'Huntington Beach', + 'Grand Prairie', + 'Glendale', + 'Sioux Falls', + 'Sunrise Manor', + 'Aurora', + 'Vancouver', + 'Knoxville', + 'Peoria', + 'Mobile', + 'Chattanooga', + 'Worcester', + 'Brownsville', + 'Fort Lauderdale', + 'Newport News', + 'Elk Grove', + 'Providence', + 'Shreveport', + 'Salem', + 'Pembroke Pines', + 'Eugene', + 'Rancho Cucamonga', + 'Cary', + 'Santa Rosa', + 'Fort Collins', + 'Oceanside', + 'Corona', + 'Enterprise', + 'Garden Grove', + 'Springfield', + 'Clarksville', + 'Murfreesboro', + 'Lakewood', + 'Bayamon', + 'Killeen', + 'Alexandria', + 'Midland', + 'Hayward', + 'Hollywood', + 'Salinas', + 'Lancaster', + 'Macon', + 'Surprise', + 'Kansas City', + 'Sunnyvale', + 'Palmdale', + 'Bellevue', + 'Springfield', + 'Denton', + 'Jackson', + 'Escondido', + 'Pomona', + 'Naperville', + 'Roseville', + 'Thornton', + 'Round Rock', + 'Pasadena', + 'Joliet', + 'Carrollton', + 'McAllen', + 'Paterson', + 'Rockford', + 'Waco', + 'Bridgeport', + 'Miramar', + 'Olathe', + 'Metairie' +]; + function typeToString(type : Type) : string { const elemType = getElementType(type); if (elemType instanceof Type.Entity) @@ -93,30 +300,91 @@ function parseConstantKeys(classDef : Ast.ClassDef, return sampleConstants; } -function retrieveSampleValues(classDef : Ast.ClassDef, - sampleMeta : Record, - fname : string, - arg : Ast.ArgumentDef) : string[] { - if (arg.type instanceof Type.Enum) - return arg.type.entries!.slice(0, 10).map(utils.clean); - const sampleConstants = parseConstantKeys(classDef, sampleMeta, fname, arg); - return sampleConstants.map((v) => { - if ((arg.type === Type.String) || - (arg.type instanceof Type.Array && ((arg.type.elem as Type) === Type.String))) +function randomInt(low : number, high : number, rng : () => number) : number { + return Math.round(low + (high - low) * rng()); +} + +function makeJSDate(year : number, month : number, day : number) : Date { + const timezone = Temporal.Now.timeZone().id; + const datetz = Temporal.ZonedDateTime.from({ + timeZone: timezone, + year, month, day + }); + return new Date(datetz.epochMilliseconds); +} + +function generateRandomIntArray(max : number, sampleSize : number) { + return Array.from({ length : sampleSize }, () => randomInt(0, max, Math.random)); +} + +function generateDateArray(timezone : string, sampleSize : number) { + const _getDates = function(startDate : Date, period : number) { + const dates = []; + let i = 0; + const d = new Date(startDate); + while (i++ < period) { + dates.push(new Date(d)); + d.setDate(d.getDate() + 1); + } + return dates; + }; + const today = Temporal.Now.zonedDateTime('iso8601', timezone).withPlainTime({ hour: 0, minute: 0, second: 0 }); + const startDate = new Date(today.epochMilliseconds); + return _getDates(startDate, sampleSize); +} + +function generateTimeArray(sampleSize : number) { + const times = []; + for (let i=0; i 0.5 - Math.random()).slice(0, sampleSize); + else + return CITIES; +} + +async function retrieveSampleValues(classDef : Ast.ClassDef, + sampleMeta : Record, + fname : string, + argDef : Ast.ArgumentDef, + sampleSize : number) : Promise { + if (argDef.type instanceof Type.Enum) + return argDef.type.entries!.slice(0, sampleSize); + if ((argDef.type instanceof Type.Measure) || (argDef.type === Type.Currency)) + return generateRandomIntArray(100, sampleSize).map(String); + if (argDef.type === Type.Date) { + const timezone = Temporal.Now.timeZone().id; + return generateDateArray(timezone, sampleSize).map((x) => x.toISOString().substring(0,10)); + } + if (argDef.type === Type.Time) + return generateTimeArray(sampleSize).map((x) => x.toString()); + if (argDef.type === Type.Location) + return generateLocationArray(sampleSize); + const sampleConstants = parseConstantKeys(classDef, sampleMeta, fname, argDef); + const ret = sampleConstants.map((v) => { + if ((argDef.type === Type.String) || + (argDef.type instanceof Type.Array && ((argDef.type.elem as Type) === Type.String))) return v.value; return v.display; }); + return ret.length > sampleSize ? ret.slice(0, sampleSize) : ret; } function toThingtalkValue(classDef : Ast.ClassDef, sampleMeta : Record, fname : string, - arg : Ast.ArgumentDef, - value : string) : { value : Ast.Value, op : string } { + argDef : Ast.ArgumentDef, + value : string) : { value : Ast.Value; op : string; } { value = value.toLowerCase(); - let type = arg.type; + let type = argDef.type; if (type instanceof Type.Entity) { - const sampleConstants = parseConstantKeys(classDef, sampleMeta, fname, arg); + const sampleConstants = parseConstantKeys(classDef, sampleMeta, fname, argDef); const kv = sampleConstants.find((item) => item.value.display?.toLowerCase() === value); const ttValue = kv?.value.value.toLowerCase(); const ttDisplay = kv?.value.display.toLowerCase(); @@ -128,10 +396,34 @@ function toThingtalkValue(classDef : Ast.ClassDef, return { value: new Ast.Value.String(value), op: "=~" }; if (type === Type.Number) return { value: new Ast.Value.Number(parseFloat(value)), op: "==" }; + if (type instanceof Type.Measure) + return { value: new Ast.Value.Measure(parseFloat(value), type.unit), op: "==" }; + if (type === Type.Currency) + // TODO: check code? + return { value: new Ast.Value.Currency(parseFloat(value), 'usd'), op: "==" }; + if (type === Type.Date) { + const [y, m, d] = value.split('-').map(Number); + return { value: new Ast.Value.Date(makeJSDate(y, m, d)), op: "==" }; + } + if (type === Type.Time) { + const [h, m] = value.split(':').map(Number); + return { value: new Ast.Value.Time(new Ast.Time.Absolute(h, m, 0)), op: "==" }; + } + if (type === Type.Location) { + // const location = await tpClient.lookupLocation(value).then( + // (loc) => { + // console.log(loc); + // return loc.filter((val: { address: { country_code: string; }; }) => val.address.country_code.toLowerCase() === 'us' + // )[0]} + // ); + // const newLocation = new Ast.Location.Absolute(location.latitude, location.longitude, location.display); + const newLocation = new Ast.Location.Unresolved(value); + return { value: new Ast.Value.Location(newLocation), op: "==" }; + } if (type instanceof Type.Array) { type = type.elem as Type; if (type instanceof Type.Entity) { - const sampleConstants = parseConstantKeys(classDef, sampleMeta, fname, arg); + const sampleConstants = parseConstantKeys(classDef, sampleMeta, fname, argDef); const kv = sampleConstants.find((item) => item.value.display?.toLowerCase() === value); const ttValue = kv?.value.value.toLowerCase(); const ttDisplay = kv?.value.display.toLowerCase(); @@ -161,6 +453,114 @@ function toTSV(device : string, data : NewParaphraseExample[], useHeading : bool return `${headings}${rows}`; } +function generateQueryExamples(query : Ast.FunctionDef, + arg : Ast.ArgumentDef, + baseCanonicalAnnotation : CanonicalAnnotation, + sampleValues : string[]) : ParaphraseExample[] { + const examples : ParaphraseExample[] = []; + const queryCanonical = Array.isArray(query.nl_annotations.canonical) ? query.nl_annotations.canonical[0] : query.nl_annotations.canonical; + for (const [pos, canonicals] of Object.entries(baseCanonicalAnnotation)) { + if (!PARTS_OF_SPEECH.includes(pos)) + continue; + for (let canonical of canonicals) { + if (PROJECTION_PARTS_OF_SPEECH.includes(pos)) { + examples.push(...generateExamplesByPOS(query, queryCanonical, arg, canonical, pos)); + } else { + for (const value of sampleValues) { + canonical = canonical.replace(/\$\{value.*/i, '#'); + examples.push(...generateExamplesByPOS(query, queryCanonical, arg, canonical, pos, value)); + } + } + } + } + return examples; +} + +function isHumanType(type : Type) { + if (type instanceof Type.Entity) { + if (type.type === 'human') + return true; + } + return false; +} + +function generateExamplesByPOS(query : Ast.FunctionDef, + queryCanonical : string, + argument : Ast.ArgumentDef, + argumentCanonical : string, + pos : string, + value ?: string|boolean) : ParaphraseExample[] { + function example(utterance : string) : ParaphraseExample { + return { query: query.name, queryCanonical, argument: argument.name, utterance, value, paraphrases : [] }; + } + const interrogativePronoun = isHumanType(argument.type) ? 'who' : `which ${queryCanonical}`; + if (!PROJECTION_PARTS_OF_SPEECH.includes(pos)) { + if (!argumentCanonical.includes('#')) { + if (argument.type instanceof Type.Measure) { + const argType = argument.type; + const unitName = Object.keys(WikidataUnitToTTUnit).find( + (key) => WikidataUnitToTTUnit[key].toLowerCase() === argType.unit.toString().toLowerCase() + ); + argumentCanonical = argumentCanonical + ` # ${unitName}`; + } else { + argumentCanonical = argumentCanonical + ' #'; + } + } + } + const predicate = typeof value === 'string' ? argumentCanonical.replace('#', value) : argumentCanonical; + switch (pos) { + case 'base': + return [ + example(`What is the ${argumentCanonical} of the ${queryCanonical}?`), + example(`What is the ${queryCanonical} 's ${argumentCanonical}?`), + example(`What ${argumentCanonical} does the ${queryCanonical} have?`) + ]; + case 'property': + case 'property_true': + case 'property_false': + return [ + example(`Show me a ${queryCanonical} with ${predicate}.`), + example(`${interrogativePronoun} has ${predicate}?`) + ]; + case 'verb': + case 'verb_true': + case 'verb_false': + return [ + example(`Show me a ${queryCanonical} that ${predicate}.`), + example(`${interrogativePronoun} ${predicate}?`) + ]; + case 'passive_verb': + case 'passive_verb_true': + case 'passive_verb_false': + case 'preposition': + case 'preposition_true': + case 'preposition_false': + return [ + example(`Show me a ${queryCanonical} ${predicate}.`), + example(`${interrogativePronoun} is ${predicate}?`) + ]; + case 'reverse_property': + case 'reverse_property_true': + case 'reverse_property_false': + return [ + example(`${interrogativePronoun} is a ${predicate}?`) + ]; + case 'adjective': + case 'adjective_true': + case 'adjective_false': + return [ + example(`Show me a ${predicate} ${queryCanonical}.`), + example(`${interrogativePronoun} is ${predicate}?`) + ]; + case 'reverse_verb': + return [ + example(`${interrogativePronoun} ${predicate} the ${queryCanonical}?`) + ]; + default: + return []; + } +} + function generateBaseCanonicalAnnotation(func : Ast.FunctionDef, arg : Ast.ArgumentDef, typeCounts : Record, @@ -177,7 +577,7 @@ function generateBaseCanonicalAnnotation(func : Ast.FunctionDef, canonicalAnnotation.base = existingCanonical; else if (typeof existingCanonical === 'object') Object.assign(canonicalAnnotation, existingCanonical); - } + } // remove function name in arg name, normally it's repetitive for (const [key, value] of Object.entries(canonicalAnnotation)) { @@ -198,9 +598,19 @@ function generateBaseCanonicalAnnotation(func : Ast.FunctionDef, if (typestr && typeCounts[typestr] === 1) { // if an entity is unique, allow dropping the property name entirely // FIXME: consider type hierarchy, or probably drop it entirely - if (canonicalAnnotation.property && !queries.includes(typestr.substring(typestr.indexOf(':') + 1))) { - if (!canonicalAnnotation.property.includes('#')) - canonicalAnnotation.property.push('#'); + // if (canonicalAnnotation.property && !queries.includes(typestr.substring(typestr.indexOf(':') + 1))) { + // if (!canonicalAnnotation.property.includes('#') && + // !((arg.type instanceof Type.Measure) || (arg.type === Type.Location))) + // canonicalAnnotation.property.push('#'); + // } + + // if property is missing, use the type information + if (!('base' in canonicalAnnotation)) { + if (typestr.startsWith('Measure')) { + const base = func.name.toLowerCase(); + canonicalAnnotation['base'] = [base]; + canonicalAnnotation['property'] = [base]; + } } // if property is missing, use the type information @@ -308,7 +718,6 @@ function generateActionExamplesByPOS(action : Ast.FunctionDef, default: return []; } - } function generateFilterAst(device : string, @@ -420,19 +829,20 @@ export function initArgparse(subparsers : argparse.SubParser) { required: true, help: `The name of the device to be synthesized.` }); + parser.add_argument('-s', '--sampleSize', { + required: false, + default: 1, + help: `The number of samples to be synthesized per annotation.` + }); parser.add_argument('-f', '--function', { required: false, - help: `A specific function to be sampled` + help: `A specific function to be sampled.` }); } -export async function execute(args : any) { - process.stdout.write("Generating samples... "); - checkOutputPath(args); - const tpClient = new Tp.FileClient(args); - const schemaRetriever = new SchemaRetriever(tpClient, null, false); - const deviceClass = await schemaRetriever.getFullSchema(args.device); - const baseTokenizer : I18n.BaseTokenizer = I18n.get(args.locale).getTokenizer(); +export default async function sampler(deviceClass : Ast.ClassDef, + baseTokenizer : I18n.BaseTokenizer, + args : any) { const functionNames = Object.keys(deviceClass.queries).concat(Object.keys(deviceClass.actions)); const sampleMeta = await parseConstantFile(args.locale, args.constants); const utteranceThingtalkPairs : NewParaphraseExample[] = []; @@ -442,28 +852,32 @@ export async function execute(args : any) { continue; const func = deviceClass.queries[fname] || deviceClass.actions[fname]; const typeCounts = countArgTypes(func); - for (const arg of func.iterateArguments()) { - const sampleValues = retrieveSampleValues(deviceClass, sampleMeta, fname, arg); - const canonicalAnnotation = generateBaseCanonicalAnnotation(func, arg, typeCounts, functionNames, false); + for (const argDef of func.iterateArguments()) { + // if (argDef.direction !== Ast.ArgDirection.OUT) + // continue; + if (argDef.name.indexOf('.') >= 0) + continue; + const sampleValues = await retrieveSampleValues(deviceClass, sampleMeta, fname, argDef, args.sampleSize); + const canonicalAnnotation = generateBaseCanonicalAnnotation(func, argDef, typeCounts, functionNames, false); let utteranceExamples : ParaphraseExample[]; if (deviceClass.actions[fname]) - utteranceExamples = generateActionExamples(func, arg, canonicalAnnotation, sampleValues); + utteranceExamples = generateActionExamples(func, argDef, canonicalAnnotation, sampleValues); else - utteranceExamples = generateQueryExamples(func, arg, canonicalAnnotation, sampleValues); + utteranceExamples = generateQueryExamples(func, argDef, canonicalAnnotation, sampleValues); for (const ex of utteranceExamples) { const example = ex as NewParaphraseExample; const prepUtterance = baseTokenizer.tokenize(example.utterance).tokens.join(' '); let program : Ast.Program; if (deviceClass.actions[fname]) { if (example.value) { - const { value, } = toThingtalkValue(deviceClass, sampleMeta, fname, arg, `${example.value}`); + const { value, } = toThingtalkValue(deviceClass, sampleMeta, fname, argDef, `${example.value}`); program = generateActionAst(fname, example.argument, value); } else { continue; } } else { if (example.value) { - const { value, op } = toThingtalkValue(deviceClass, sampleMeta, fname, arg, `${example.value}`); + const { value, op } = toThingtalkValue(deviceClass, sampleMeta, fname, argDef, `${example.value}`); program = generateFilterAst(args.device, fname, example.argument, op, value); } else { program = generateProjectionAst(args.device, fname, example.argument); @@ -475,13 +889,25 @@ export async function execute(args : any) { } catch(err) { console.log(prepUtterance); console.log(program.prettyprint()); + console.log(example); throw err; } - example.utterance = prepUtterance; + example.utterance = prepUtterance.replace(/_/g, ' '); utteranceThingtalkPairs.push(example); } } } + return utteranceThingtalkPairs; +} + +export async function execute(args : any) { + process.stdout.write("Generating samples... "); + checkOutputPath(args); + const tpClient = new Tp.FileClient(args); + const schemaRetriever = new SchemaRetriever(tpClient, null, false); + const deviceClass = await schemaRetriever.getFullSchema(args.device); + const baseTokenizer : I18n.BaseTokenizer = I18n.get(args.locale).getTokenizer(); + const utteranceThingtalkPairs = await sampler(deviceClass, baseTokenizer, args); const output = toTSV(args.device, utteranceThingtalkPairs, false); // console.log(output); args.output.write(output);