-
Notifications
You must be signed in to change notification settings - Fork 45
/
evaluate.js
92 lines (78 loc) · 2.98 KB
/
evaluate.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
// -*- mode: js; indent-tabs-mode: nil; js-basic-offset: 4 -*-
//
// This file is part of ThingEngine
//
// Copyright 2018 The Board of Trustees of the Leland Stanford Junior University
//
// Author: Giovanni Campagna <gcampagn@cs.stanford.edu>
//
// See COPYING for details
"use strict";
const byline = require('byline');
const path = require('path');
const fs = require('fs');
const Genie = require('genie-toolkit');
const ThingTalk = require('thingtalk');
const AdminThingpediaClient = require('../../util/admin-thingpedia-client');
const AbstractFS = require('../../util/abstract_fs');
const TokenizerService = require('../../util/tokenizer_service');
/**
 * Parser client that runs a Genie predictor in-process against a locally
 * stored model and tokenizes raw sentences with the local tokenizer service.
 *
 * An instance is handed to Genie.SentenceEvaluatorStream by main() below;
 * presumably that stream drives it through tokenize()/sendUtterance() —
 * confirm against Genie.
 */
class LocalParserClient {
    /**
     * @param {string} modeldir - directory holding the model passed to Genie.Predictor
     * @param {string} locale - locale forwarded to the tokenizer
     */
    constructor(modeldir, locale) {
        this._locale = locale;
        this._tokenizer = TokenizerService.getLocal();
        // NOTE(review): the trailing `1` is forwarded to Genie.Predictor unchanged;
        // presumably a worker/process count — confirm against Genie.
        this._predictor = new Genie.Predictor('local', modeldir, 1);
    }

    /** Start the underlying predictor. */
    async start() {
        await this._predictor.start();
    }

    /** Stop the underlying predictor. */
    async stop() {
        await this._predictor.stop();
    }

    /**
     * Tokenize a raw utterance and renumber its entities against the entities
     * already present in the context.
     *
     * @param {string} utterance - raw sentence
     * @param {Object} contextEntities - entities already present in the context
     * @returns {Promise<Object>} the tokenizer result (with `tokens` and `entities`),
     *          after Genie.Utils.renumberEntities has been applied to it
     */
    async tokenize(utterance, contextEntities) {
        const tokenized = await this._tokenizer.tokenize(this._locale, utterance);
        Genie.Utils.renumberEntities(tokenized, contextEntities);
        return tokenized;
    }

    /**
     * Run the predictor on an utterance.
     *
     * @param {string} utterance - raw sentence, or space-separated tokens when `tokenized` is set
     * @param {boolean} tokenized - whether `utterance` is already tokenized
     * @param {*} contextCode - context forwarded unchanged to the predictor
     * @param {Object} contextEntities - entities already present in the context
     * @returns {Promise<{tokens: string[], candidates: *, entities: Object}>}
     */
    async sendUtterance(utterance, tokenized, contextCode, contextEntities) {
        let tokens;
        let entities;
        if (tokenized) {
            // Pre-tokenized input: split on single spaces and start from a
            // shallow copy of the context entities so the caller's object is
            // never mutated.
            tokens = utterance.split(' ');
            entities = Object.assign({}, contextEntities);
        } else {
            // Reuse tokenize() rather than duplicating the tokenize+renumber steps.
            ({ tokens, entities } = await this.tokenize(utterance, contextEntities));
        }
        const candidates = await this._predictor.predict(tokens, contextCode);
        return { tokens, candidates, entities };
    }
}
module.exports = async function main(task, argv) {
task.handleKill();
const jobdir = await AbstractFS.download(task.jobDir + '/');
const datadir = path.resolve(jobdir, 'dataset');
const outputdir = path.resolve(jobdir, 'output');
const tpClient = new AdminThingpediaClient(task.language);
const schemas = new ThingTalk.SchemaRetriever(tpClient, null, true);
const parser = new LocalParserClient(outputdir, task.language);
await parser.start();
const output = fs.createReadStream(path.resolve(datadir, 'eval.tsv'))
.setEncoding('utf8')
.pipe(byline())
.pipe(new Genie.DatasetParser({
contextual: task.modelInfo.contextual,
preserveId: true,
parseMultiplePrograms: true
}))
.pipe(new Genie.SentenceEvaluatorStream(parser, schemas, true /* tokenized */, argv.debug))
.pipe(new Genie.CollectSentenceStatistics());
const result = await output.read();
await task.setMetrics(result);
await Promise.all([
parser.stop(),
TokenizerService.tearDown(),
AbstractFS.removeTemporary(jobdir)
]);
};