Skip to content

Commit

Permalink
Merge pull request #283 from stanford-oval/wip/entity-linking
Browse files Browse the repository at this point in the history
Pass thingtalk code to nlp server
  • Loading branch information
Mehrad0711 committed Aug 30, 2020
2 parents a182bbf + 7261303 commit fa8baa1
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 14 deletions.
8 changes: 7 additions & 1 deletion lib/dataset-tools/evaluation/sentence_evaluator.js
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class SentenceEvaluator {
this._locale = options.locale;
this._tokenized = options.tokenized;
this._debug = options.debug;
this._oracle = options.oracle;
this._tokenizer = tokenizer;
this._target = target;

Expand All @@ -91,6 +92,7 @@ class SentenceEvaluator {
this._preprocessed = ex.preprocessed;
this._targetPrograms = ex.target_code;
this._predictions = ex.predictions;

}

_hasNumeric(code) {
Expand Down Expand Up @@ -194,7 +196,11 @@ class SentenceEvaluator {
predictions = this._predictions;
} else {
try {
let answer = undefined;
if (this._oracle)
answer = firstTargetCode;
const parsed = await this._parser.sendUtterance(this._preprocessed, contextCode, contextEntities, {
answer: answer,
tokenized: this._tokenized,
skip_typechecking: true
});
Expand Down Expand Up @@ -298,6 +304,7 @@ class SentenceEvaluator {

if (first && this._debug && result_string !== 'ok')
console.log(`${this._id}\t${result_string}\t${this._preprocessed}\t${target}\t${normalizedCode}`);

first = false;
ok = ok || beam_ok;
ok_without_param = ok_without_param || beam_ok_without_param;
Expand Down Expand Up @@ -329,7 +336,6 @@ class SentenceEvaluatorStream extends Stream.Transform {

_transform(ex, encoding, callback) {
const evaluator = new SentenceEvaluator(this._parser, this._options, this._tokenizer, this._target, ex);

evaluator.evaluate().then((result) => callback(null, result), (err) => callback(err));
}

Expand Down
10 changes: 7 additions & 3 deletions lib/prediction/localparserclient.js
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ module.exports = class LocalParserClient {
entities = tokenized.entities;
}

let answer = undefined;
if (options.answer)
answer = options.answer;

let result = null;
let exact = null;

Expand Down Expand Up @@ -107,9 +111,9 @@ module.exports = class LocalParserClient {
} else {
let candidates;
if (contextCode)
candidates = await this._predictor.predict(contextCode.join(' '), tokens.join(' '), NLU_TASK);
candidates = await this._predictor.predict(contextCode.join(' '), tokens.join(' '), answer, NLU_TASK);
else
candidates = await this._predictor.predict(tokens.join(' '), undefined, SEMANTIC_PARSING_TASK);
candidates = await this._predictor.predict(tokens.join(' '), undefined, answer, SEMANTIC_PARSING_TASK);
result = candidates.map((c) => {
return {
code: c.answer.split(' '),
Expand Down Expand Up @@ -152,7 +156,7 @@ module.exports = class LocalParserClient {
}

async generateUtterance(contextCode, contextEntities, targetAct) {
let candidates = this._predictor.predict(contextCode.join(' ') + ' ' + targetAct.join(' '), NLG_QUESTION, NLG_TASK);
let candidates = this._predictor.predict(contextCode.join(' ') + ' ' + targetAct.join(' '), NLG_QUESTION, undefined, NLG_TASK);
candidates = candidates.map((cand) => {
return {
answer: this._langPack.postprocessNLG(cand.answer, contextEntities),
Expand Down
14 changes: 8 additions & 6 deletions lib/prediction/predictor.js
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ class Worker extends events.EventEmitter {
];
if (process.env.GENIENLP_EMBEDDINGS)
args.push('--embeddings', process.env.GENIENLP_EMBEDDINGS);
if (process.env.GENIENLP_DATABASE)
args.push('--database', process.env.GENIENLP_DATABASE);

this._child = child_process.spawn('genienlp', args, {
stdio: ['pipe', 'pipe', 'inherit']
Expand Down Expand Up @@ -107,7 +109,7 @@ class Worker extends events.EventEmitter {
this._requests.clear();
}

request(task, context, question) {
request(task, context, question, answer) {
const id = this._nextId ++;

let resolve, reject;
Expand All @@ -119,7 +121,7 @@ class Worker extends events.EventEmitter {

assert(typeof context === 'string');
assert(typeof question === 'string');
this._stream.write({ id, context, question, task });
this._stream.write({ id, context, question, answer, task });
return promise;
}
}
Expand Down Expand Up @@ -189,20 +191,20 @@ module.exports = class Predictor {
return worker;
}

predict(context, question = DEFAULT_QUESTION, task = 'almond') {
predict(context, question = DEFAULT_QUESTION, answer, task = 'almond') {
// first pick a worker that is free
for (let worker of this._workers) {
if (worker.ok && !worker.busy)
return worker.request(task, context, question);
return worker.request(task, context, question, answer);
}

// failing that, pick any worker that is alive
for (let worker of this._workers) {
if (worker.ok)
return worker.request(task, context, question);
return worker.request(task, context, question, answer);
}

// failing that, spawn a new worker
return this._startWorker().request(task, context, question);
return this._startWorker().request(task, context, question, answer);
}
};
2 changes: 2 additions & 0 deletions lib/prediction/remoteparserclient.js
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ module.exports = class RemoteParserClient {
data.expect = String(options.expect);
if (options.choices)
data.choices = options.choices;
if (options.answer)
data.answer = options.answer;

const response = await Tp.Helpers.Http.post(`${this._baseUrl}/query`, JSON.stringify(data), {
dataContentType: 'application/json' //'
Expand Down
2 changes: 1 addition & 1 deletion tool/evaluate-dialog.js
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ module.exports = {
});
parser.add_argument('--url', {
required: false,
help: "URL of the server to evaluate. Use a file:// URL pointing to a model directory to evaluate using a local instance of decanlp",
help: "URL of the server to evaluate. Use a file:// URL pointing to a model directory to evaluate using a local instance of genienlp",
default: 'http://127.0.0.1:8400',
});
parser.add_argument('--tokenized', {
Expand Down
10 changes: 8 additions & 2 deletions tool/evaluate-server.js
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,12 @@ module.exports = {
required: false,
type: Number,
default: 0,
help: 'Collapse all examples of complexity smaller or equal to this'
help: 'Collapse all examples of complexity smaller or equal to this',
});
parser.add_argument('--oracle', {
action: 'store_true',
help: 'Indicates evaluation of an oracle model where ThingTalk code should be passed to the genienlp server',
default: false
});
},

Expand All @@ -181,7 +186,8 @@ module.exports = {
thingpediaClient: tpClient,
tokenized: args.tokenized,
debug: args.debug,
complexityMetric: args.complexity_metric
complexityMetric: args.complexity_metric,
oracle: args.oracle
}))
.pipe(new CollectSentenceStatistics({
minComplexity: args.min_complexity,
Expand Down
2 changes: 1 addition & 1 deletion tool/server.js
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ module.exports = {
});
parser.add_argument('--nlg-model', {
required: false,
help: "Path to the NLU model, pointing to a model directory.",
help: "Path to the NLG model, pointing to a model directory.",
});
parser.add_argument('--thingpedia', {
required: true,
Expand Down

0 comments on commit fa8baa1

Please sign in to comment.