Skip to content

Commit

Permalink
add filterAst
Browse files Browse the repository at this point in the history
fix
  • Loading branch information
jmhw0123 committed Feb 14, 2022
1 parent c27c0a3 commit ba2a5de
Showing 1 changed file with 137 additions and 98 deletions.
235 changes: 137 additions & 98 deletions tool/synthetic-data-sampler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ import * as utils from '../lib/utils/misc-utils';
import { serializePrediction } from '../lib/utils/thingtalk';
import { EntityUtils } from '../lib';

var path = require('path');

export interface Entity {
value : string;
display : string;
Expand Down Expand Up @@ -79,7 +81,7 @@ export function initArgparse(subparsers : argparse.SubParser) {
description: "Automatically generate samples from the canonicals"
});
parser.add_argument('-o', '--output', {
required: true,
required: false,
type: fs.createWriteStream
});
parser.add_argument('-l', '--locale', {
Expand All @@ -99,16 +101,16 @@ export function initArgparse(subparsers : argparse.SubParser) {
help: 'TSV file containing the paths to datasets for strings and entity types.'
});
parser.add_argument('--sample-size', {
default: 10,
default: 1,
help: 'Number of samples per entity or string value'
});
parser.add_argument('--devices', {
required: false,
help: `The list of devices to sample, separated by comma`
});
parser.add_argument('--paraphraser-model', {
parser.add_argument('--query', {
required: false,
help: `A path to the directory where the bart paraphraser model is saved`
help: `A specific query function to be sampled`
});
}

Expand Down Expand Up @@ -174,11 +176,11 @@ function countArgTypes(schema : Ast.FunctionDef) : Record<string, number> {
return count;
}

/**
 * Fetches the entity records stored under `name` in the parameter provider
 * and returns a sample of them drawn with `sampleEntities`.
 *
 * Returns an empty array when the provider has no records for `name`.
 *
 * NOTE(review): the two signature lines and the two `sampled` lines below look
 * like the before/after sides of a diff view; presumably only the variant that
 * takes `sampleSize` is meant to remain — confirm against the real file.
 */
async function retrieveEntitySamples(constProvider : ParameterProvider, name : string) {
async function retrieveEntitySamples(constProvider : ParameterProvider, name : string, sampleSize : number) {
const data = await constProvider.getEntity(name);
// Nothing to sample from: skip the sampler entirely.
if (data.length === 0)
return [];
const sampled = sampleEntities(1, data);
const sampled = sampleEntities(sampleSize, data);
return sampled;
}

Expand All @@ -193,20 +195,15 @@ async function retrieveStringSamples(constProvider : ParameterProvider, name : s
async function sampleConstants(functions : Record<string, Ast.FunctionDef>,
constProvider : Tp.FileParameterProvider,
device : any,
locale : any) {
locale : any,
sampleSize : number) {
const constants : Record<string, Constant[]> = {};
for (const f in functions) {
const functionDef = functions[f];
for (const argument of functionDef.iterateArguments()) {
const arg = argument.name;
const string_values = argument.getImplementationAnnotation<string>('string_values');
const entityType = getEntityType(argument.type);
// const queryCanonical = Array.isArray(argument.nl_annotations.canonical) ? argument.nl_annotations.canonical[0] : argument.nl_annotations.canonical;
// console.log(`arg: ${arg}`);
// console.log(`string_values: ${string_values}`);
// console.log(`entityType: ${entityType}`);
// console.log("canonical:");
// console.log(queryCanonical);
if (string_values) {
let samples : string[] = await retrieveStringSamples(constProvider, `org.schema:${f}_${arg}`, locale);
if (samples.length === 0)
Expand Down Expand Up @@ -248,7 +245,7 @@ async function sampleConstants(functions : Record<string, Ast.FunctionDef>,
});
}
} else if (entityType) {
const samples = await retrieveEntitySamples(constProvider, entityType);
const samples = await retrieveEntitySamples(constProvider, entityType, sampleSize);
samples.forEach((sample) => {
const key = `param:@${device}.${f}:${arg}:Entity(${entityType})`;
const obj = { key: sample.value, value: sample.value, display: sample.display };
Expand All @@ -266,8 +263,6 @@ async function sampleConstants(functions : Record<string, Ast.FunctionDef>,
});
}
}
// console.log("constants:");
// console.log(constants);
}
}
return constants;
Expand Down Expand Up @@ -357,6 +352,66 @@ function generateBaseCanonicalAnnotation(func : Ast.FunctionDef,
return canonicalAnnotation;
}

/**
 * Builds an `Ast.Program` holding a single expression statement: an invocation
 * of `func` on `device`, filtered by the atom `property <operator> value`.
 *
 * @param device   device kind passed to the `Ast.DeviceSelector`
 * @param func     name of the function to invoke (no input parameters are set)
 * @param property name of the argument the filter tests
 * @param operator filter operator, handed to `Ast.AtomBooleanExpression` as-is
 * @param value    right-hand side of the filter
 * @returns a program with no classes, no declarations and one statement
 */
function generateFilterAst(device : string, func : string, property : string, operator : string, value : any) : Ast.Program {
    const selector = new Ast.DeviceSelector(null, device, null, null);
    const call = new Ast.Invocation(null, selector, func, [], null);
    const source = new Ast.InvocationExpression(null, call, null);

    const condition = new Ast.AtomBooleanExpression(null, property, operator, value, null);
    const filteredSource = new Ast.FilterExpression(null, source, condition, null);

    // The filtered invocation is the only link of the chain.
    const chain = new Ast.ChainExpression(null, [filteredSource], null);
    const onlyStatement = new Ast.ExpressionStatement(null, chain);
    return new Ast.Program(null, [], [], [onlyStatement], {});
}


/**
 * Builds an `Ast.Program` holding a single expression statement: an invocation
 * of `func` on `device`, projected onto the single field `property`.
 *
 * @param device   device kind passed to the `Ast.DeviceSelector`
 * @param func     name of the function to invoke (no input parameters are set)
 * @param property the one field kept by the projection
 * @returns a program with no classes, no declarations and one statement
 */
function generateProjectionAst(device : string, func : string, property : string) : Ast.Program {
    const selector = new Ast.DeviceSelector(null, device, null, null);
    const call = new Ast.Invocation(null, selector, func, [], null);
    const source = new Ast.InvocationExpression(null, call, null);

    // Only the plain field list is populated; the two following list arguments stay empty.
    const projected = new Ast.ProjectionExpression(null, source, [property], [], [], null);

    const chain = new Ast.ChainExpression(null, [projected], null);
    const onlyStatement = new Ast.ExpressionStatement(null, chain);
    return new Ast.Program(null, [], [], [onlyStatement], {});
}

function toTSV(device : string, data : NewParaphraseExample[], useHeading : boolean) {
let headings : string = '';
if (useHeading)
Expand All @@ -366,118 +421,102 @@ function toTSV(device : string, data : NewParaphraseExample[], useHeading : bool
const tmp = [
id,
colValue["utterance"],
colValue["thingtalk"],
colValue["query"],
colValue["queryCanonical"],
colValue["argument"],
colValue["value"]
colValue["thingtalk"]
]
return acc.concat([tmp.join('\t')]);
}, []).join('\n');
return `${headings}${rows}`;
}

/**
 * Converts a sampled string into the `Ast.Value` and filter operator used to
 * build a filter on an argument of the given type.
 *
 * Mapping implemented here:
 *  - `Type.Entity` and `Type.String`: the text wrapped in double quotes, carried
 *    by an `Ast.Value.Entity` tagged `'tt:device'`, operator `=~`
 *  - `Type.Enum`: `Ast.Value.Enum`, operator `==`
 *  - `Type.Number`: `Ast.Value.Number(parseFloat(value))`, operator `==`
 *  - `Type.Array` of `Type.Entity`: the text rendered as `"v"^^<type>("v")`,
 *    carried by an `Ast.Value.Entity` tagged `'tt:device'`, operator `contains`
 *
 * @throws Error for any other type; for arrays the message names the element type
 */
function toThingtalkValue(value : string, type : Type) : { value : Ast.Value, op : string } {
    // Every entity-shaped value uses the same text for both the value and the display.
    const asDeviceEntity = (text : string) : Ast.Value => new Ast.Value.Entity(text, 'tt:device', text);

    if (type instanceof Type.Entity || type === Type.String)
        return { value: asDeviceEntity(`"${value}"`), op: "=~" };
    if (type instanceof Type.Enum)
        return { value: new Ast.Value.Enum(value), op: "==" };
    if (type === Type.Number)
        return { value: new Ast.Value.Number(parseFloat(value)), op: "==" };

    // For arrays, both the entity check and the error message refer to the element type.
    const reported = type instanceof Type.Array ? (type.elem as Type) : type;
    if (type instanceof Type.Array && reported instanceof Type.Entity) {
        const literal = `"${value}"^^${reported.type}("${value}")`;
        return { value: asDeviceEntity(literal), op: "contains" };
    }
    throw new Error(`Unsupported value type: ${reported}`);
}

/**
 * Ensures `args.output` holds a writable stream.
 *
 * When no output was supplied, a `test` directory is created next to the
 * file named by `args.thingpedia` and `args.output` is set to a write stream
 * for `results.tsv` inside it. An output that is already set is left untouched.
 *
 * @param args parsed command-line arguments; reads `output` and `thingpedia`,
 *             may assign `output`
 */
function checkOrCreateOutputPath(args : any) : void {
    if (args.output)
        return;

    const outputDir = path.join(path.dirname(args.thingpedia), "test");
    // `recursive: true` is a no-op when the directory already exists, so no
    // separate existence check is needed.
    fs.mkdirSync(outputDir, { recursive: true });
    // Pass the joined path as a plain argument. The previous `path=path.join(...)`
    // was an assignment expression that overwrote the `path` module with a string,
    // breaking every later `path.*` call.
    args.output = fs.createWriteStream(path.join(outputDir, "results.tsv"));
}

/**
 * Command entry point: loads the schema of the first device listed in
 * `args.devices`, obtains sample constants (from `args.constants` when given,
 * otherwise by sampling `args.parameter_datasets`), generates examples for each
 * argument of each function, attaches a serialized ThingTalk program to every
 * example, and writes the rows produced by `toTSV` to `args.output`.
 *
 * NOTE(review): this span appears to interleave the removed and added sides of
 * a diff view (e.g. `functions`, `queries`, `options` and `output` are each
 * declared twice). The comments below describe what the lines do; which side is
 * live must be confirmed against the real file.
 */
export async function execute(args : any) {
process.stdout.write("Generating samples... ");
checkOrCreateOutputPath(args);
const tpClient = new Tp.FileClient(args);
const schemaRetriever = new SchemaRetriever(tpClient, null, !args.debug);
// Only the first entry of the comma-separated device list is used.
// NOTE(review): `--devices` is declared with `required: false`, so this line
// throws when the option is omitted — confirm callers always pass it.
const device = args.devices.split(',')[0];
const locale = args.locale;
// NOTE(review): `--sample-size` declares no `type`, so a value given on the
// command line presumably arrives as a string rather than a number — confirm.
const sampleSize = args.sample_size;
const deviceClass = await schemaRetriever.getFullSchema(device);
const functions = Object.assign({}, deviceClass.queries, deviceClass.actions);
// const functions = Object.assign({}, deviceClass.queries, deviceClass.actions);
const functions = Object.assign({}, deviceClass.queries);
let sampledConstants : Record<string, Constant[]> = {};
// Constants come from a pre-built constants file when one is given;
// otherwise they are sampled from the parameter datasets.
if (args.constants) {
sampledConstants = await parseConstantFile(args.locale, args.constants);
} else {
const constProvider = new Tp.FileParameterProvider(args.parameter_datasets, args.locale);
await constProvider.load();
sampledConstants = await sampleConstants(functions, constProvider, device, locale);
sampledConstants = await sampleConstants(functions, constProvider, device, locale, sampleSize);
}
// const examples : ParaphraseExample[] = [];
const ttExamples : NewParaphraseExample[] = [];
const queries = Object.keys(deviceClass.queries).concat(Object.keys(deviceClass.actions));
const options = { locale: 'en', timezone: undefined, includeEntityValue: true };
const examples : NewParaphraseExample[] = [];
// const queries = Object.keys(deviceClass.queries).concat(Object.keys(deviceClass.actions));
const queries = Object.keys(deviceClass.queries)
const options = { locale: locale, timezone: undefined, includeEntityValue: true };
for (const fname of queries) {
const func = deviceClass.queries[fname] || deviceClass.actions[fname];
// When `--query` is given, every other function is skipped.
if (args.query && fname !== args.query)
continue;
// const func = deviceClass.queries[fname] || deviceClass.actions[fname];
const func = functions[fname];
const typeCounts = countArgTypes(func);
for (const arg of func.iterateArguments()) {
const sampleValues = retrieveSamples(deviceClass, sampledConstants, fname, arg);
// console.log(sampleValues);
const canonicalAnnotation = generateBaseCanonicalAnnotation(func, arg, typeCounts, queries, false);
const thingtalkExamples = generateExamples(func, arg, canonicalAnnotation, sampleValues);
for (const ex of thingtalkExamples) {
const newEx = ex as NewParaphraseExample;
const preprocessed = newEx.utterance;
const property = newEx.argument;
const prog = generateProjectionAst(device, fname, property);
newEx.thingtalk = serializePrediction(prog, preprocessed, EntityUtils.makeDummyEntities(preprocessed), options).join(' ');
ttExamples.push(newEx);
let program : Ast.Program;
// An example that carries a truthy value becomes a filter on its argument;
// otherwise it becomes a projection of that argument.
if (newEx.value) {
// console.log(typeToString(arg.type));
// console.log(newEx.utterance);
const { value, op } = toThingtalkValue(`${newEx.value}`, arg.type);
program = generateFilterAst(device, fname, newEx.argument, op, value);
} else {
program = generateProjectionAst(device, fname, newEx.argument);
}
newEx.thingtalk = serializePrediction(program, newEx.utterance, EntityUtils.makeDummyEntities(newEx.utterance), options).join(' ');
examples.push(newEx);
}
// examples.push(...ttExamples);
}
}
// The last dot-separated segment of the device name is passed to `toTSV`; no heading row is requested.
const output = toTSV(device.split('.').pop(), ttExamples, false);
console.log(output);
const output = toTSV(device.split('.').pop(), examples, false);
// console.log(output);
args.output.write(output);
process.stdout.write("Done!\n");
process.stdout.write(`Done!\nFile location: ${args.output.path}\n`);
}

// function generateFilterAst(device : string, func : string, property : string, operator : string, value : any) : Ast.Program {
// const invocation = new Ast.InvocationExpression(
// null,
// new Ast.Invocation(null, new Ast.DeviceSelector(null, device, null, null), func, [], null),
// null
// );
// const filter = new Ast.AtomBooleanExpression(
// null,
// property,
// operator,
// value, //
// null
// );
// const filtered = new Ast.FilterExpression(
// null,
// invocation,
// filter,
// null
// );
// const statement = new Ast.ExpressionStatement(
// null,
// new Ast.ChainExpression(null, [filtered], null)
// );
// return new Ast.Program(
// null,
// [],
// [],
// [statement],
// {}
// );
// }


/**
 * Builds an `Ast.Program` holding a single expression statement: an invocation
 * of `func` on `device`, projected onto the single field `property`.
 *
 * @param device   device kind passed to the `Ast.DeviceSelector`
 * @param func     name of the function to invoke (no input parameters are set)
 * @param property the one field kept by the projection
 * @returns a program with no classes, no declarations and one statement
 */
function generateProjectionAst(device : string, func : string, property : string) : Ast.Program {
    const selector = new Ast.DeviceSelector(null, device, null, null);
    const call = new Ast.Invocation(null, selector, func, [], null);
    const source = new Ast.InvocationExpression(null, call, null);

    // Only the plain field list is populated; the two following list arguments stay empty.
    const projected = new Ast.ProjectionExpression(null, source, [property], [], [], null);

    const chain = new Ast.ChainExpression(null, [projected], null);
    const onlyStatement = new Ast.ExpressionStatement(null, chain);
    return new Ast.Program(null, [], [], [onlyStatement], {});
}

0 comments on commit ba2a5de

Please sign in to comment.