From 7442c722a7202cb8567f12c3e2f9f4e8ea3dc076 Mon Sep 17 00:00:00 2001 From: Daniel Calvo Date: Mon, 11 Dec 2017 11:26:59 -0600 Subject: [PATCH] Train endpoint for domain, and added retraining of domain after updates in agent, domain and entitites --- .../controllers/import.agent.controller.js | 5 +- .../updateById.agent.controller.js | 44 ++++++++- api/modules/domain/config/domain.route.js | 10 ++ api/modules/domain/config/domain.validator.js | 11 ++- api/modules/domain/controllers/index.js | 5 +- .../controllers/train.domain.controller.js | 94 +++++++++++++++++++ .../updateById.domain.controller.js | 13 ++- ...ainRecognitionTrainingData.domain.tool.js} | 4 +- .../tools/buildTrainingData.domain.tool.js} | 4 +- .../tools/cartesianProduct.domain.tool.js} | 0 .../tools/getAgentData.domain.tool.js} | 0 .../tools/getDomainData.domain.tool.js} | 0 .../getEntitiesCombinations.domain.tool.js} | 2 +- api/modules/domain/tools/index.js | 12 +++ .../retrainDomainRecognizer.domain.tool.js} | 2 +- .../tools/retrainModel.domain.tool.js} | 2 +- .../updateById.entity.controller.js | 14 ++- .../controllers/add.intent.controller.js | 5 +- .../deleteById.intent.controller.js | 5 +- .../updateById.intent.controller.js | 5 +- api/modules/intent/tools/index.js | 8 +- api/test-data/samson.import.json | 2 +- 22 files changed, 216 insertions(+), 31 deletions(-) create mode 100644 api/modules/domain/controllers/train.domain.controller.js rename api/modules/{intent/tools/buildDomainRecognitionTrainingData.intent.tool.js => domain/tools/buildDomainRecognitionTrainingData.domain.tool.js} (98%) rename api/modules/{intent/tools/buildTrainingData.intent.tool.js => domain/tools/buildTrainingData.domain.tool.js} (98%) rename api/modules/{intent/tools/cartesianProduct.intent.tool.js => domain/tools/cartesianProduct.domain.tool.js} (100%) rename api/modules/{intent/tools/getAgentData.intent.tool.js => domain/tools/getAgentData.domain.tool.js} (100%) rename api/modules/{intent/tools/getDomainData.intent.tool.js => domain/tools/getDomainData.domain.tool.js} (100%) rename api/modules/{intent/tools/getEntitiesCombinations.intent.tool.js => domain/tools/getEntitiesCombinations.domain.tool.js} (96%) create mode 100644 api/modules/domain/tools/index.js rename api/modules/{intent/tools/retrainDomainRecognizer.intent.tool.js => domain/tools/retrainDomainRecognizer.domain.tool.js} (96%) rename api/modules/{intent/tools/retrainModel.intent.tool.js => domain/tools/retrainModel.domain.tool.js} (96%) diff --git a/api/modules/agent/controllers/import.agent.controller.js b/api/modules/agent/controllers/import.agent.controller.js index 480c13474..c92c38998 100644 --- a/api/modules/agent/controllers/import.agent.controller.js +++ b/api/modules/agent/controllers/import.agent.controller.js @@ -4,6 +4,7 @@ const Boom = require('boom'); const Flat = require('flat'); const _ = require('lodash'); const IntentTools = require('../../intent/tools'); +const DomainTools = require('../../domain/tools'); module.exports = (request, reply) => { @@ -305,8 +306,8 @@ module.exports = (request, reply) => { domainResult.intents = resultIntents; Async.waterfall([ - Async.apply(IntentTools.retrainModelTool, server, rasa, agentResult.agentName, domainResult.domainName, domainResult.id), - Async.apply(IntentTools.retrainDomainRecognizerTool, server, redis, rasa, agentResult.agentName, agentResult.id) + Async.apply(DomainTools.retrainModelTool, server, rasa, agentResult.agentName, domainResult.domainName, domainResult.id), + Async.apply(DomainTools.retrainDomainRecognizerTool, server, redis, rasa, agentResult.agentName, agentResult.id) ], (errTraining) => { if (errTraining){ diff --git a/api/modules/agent/controllers/updateById.agent.controller.js b/api/modules/agent/controllers/updateById.agent.controller.js index 970077dfa..16bd21285 100644 --- a/api/modules/agent/controllers/updateById.agent.controller.js +++ b/api/modules/agent/controllers/updateById.agent.controller.js @@ -295,8 +295,48 @@ module.exports = (request, reply) => { return reply(err, null); } if (requiresRetrain){ - //call retrain here + Async.waterfall([ + (callbackGetDomains) => { + + server.inject(`/agent/${agentId}/domain`, (res) => { + + if (res.statusCode !== 200){ + const error = Boom.create(res.statusCode, 'An error ocurred getting the domains of the agent to train them'); + return callbackGetDomains(error, null); + } + return callbackGetDomains(null, res.result); + }); + }, + (domains, callbackTrainEachDomain) => { + + Async.eachLimit(domains, 1, (domain, callbackMapOfDomain) => { + + server.inject(`/domain/${domain.id}/train`, (res) => { + + if (res.statusCode !== 200){ + const error = Boom.create(res.statusCode, `An error ocurred training the domain ${domain.domain}`); + return callbackMapOfDomain(error); + } + return callbackMapOfDomain(null); + }); + }, (err) => { + + if (err){ + return callbackTrainEachDomain(err); + } + return callbackTrainEachDomain(null); + }); + } + ], (errTraining) => { + + if (errTraining){ + return reply(errTraining); + } + return reply(result); + }); + } + else { + return reply(result); } - return reply(result); }); }; diff --git a/api/modules/domain/config/domain.route.js b/api/modules/domain/config/domain.route.js index f5f0b22d0..62fb0013b 100644 --- a/api/modules/domain/config/domain.route.js +++ b/api/modules/domain/config/domain.route.js @@ -62,6 +62,16 @@ const DomainRoutes = [ validate: DomainValidator.findIntentsByDomainId, handler: DomainController.findIntentsByDomainId } + }, + { + method: 'GET', + path: '/domain/{id}/train', + config: { + description: 'Train the specified domain', + tags: ['api'], + validate: DomainValidator.train, + handler: DomainController.train + } } ]; diff --git a/api/modules/domain/config/domain.validator.js b/api/modules/domain/config/domain.validator.js index 1b38af3e0..8e1e86d48 100644 --- a/api/modules/domain/config/domain.validator.js +++ b/api/modules/domain/config/domain.validator.js @@ -69,7 +69,7 @@ class DomainValidate { params: (() => { return { - id: DomainSchema.id.required().description('Id of the agent') + id: DomainSchema.id.required().description('Id of the domain') }; })(), query: (() => { @@ -81,6 +81,15 @@ class DomainValidate { })() }; + this.train = { + params: (() => { + + return { + id: DomainSchema.id.required().description('Id of the domain') + }; + })() + }; + } } diff --git a/api/modules/domain/controllers/index.js b/api/modules/domain/controllers/index.js index b41fb2f96..59bb96d5a 100644 --- a/api/modules/domain/controllers/index.js +++ b/api/modules/domain/controllers/index.js @@ -5,6 +5,7 @@ const UpdateByIdController = require('./updateById.domain.controller'); const DeleteByIdController = require('./deleteById.domain.controller'); const FindEntitiesByDomainIdController = require('./findEntitiesByDomainId.domain.controller'); const FindIntentsByDomainIdController = require('./findIntentsByDomainId.domain.controller'); +const TrainController = require('./train.domain.controller'); const DomainController = { @@ -18,7 +19,9 @@ const DomainController = { findEntitiesByDomainId: FindEntitiesByDomainIdController, - findIntentsByDomainId: FindIntentsByDomainIdController + findIntentsByDomainId: FindIntentsByDomainIdController, + + train: TrainController }; module.exports = DomainController; diff --git a/api/modules/domain/controllers/train.domain.controller.js b/api/modules/domain/controllers/train.domain.controller.js new file mode 100644 index 000000000..786036d4f --- /dev/null +++ b/api/modules/domain/controllers/train.domain.controller.js @@ -0,0 +1,94 @@ +'use strict'; +const Async = require('async'); +const Boom = require('boom'); +const DomainTools = require('../tools'); + +module.exports = (request, reply) => { + + let agentId = null; + let domain = null; + const domainId = request.params.id; + const server = request.server; + const redis = server.app.redis; + const rasa = server.app.rasa; + + Async.waterfall([ + (callback) => { + + server.inject(`/domain/${domainId}`, (res) => { + + if (res.statusCode !== 200){ + const error = Boom.create(res.statusCode, 'An error ocurred getting the domain'); + return callback(error, null); + } + domain = res.result; + return callback(null); + }); + }, + (callback) => { + + redis.zscore('agents', domain.agent, (err, id) => { + + if (err){ + const error = Boom.badImplementation('An error ocurred checking if the agent exists.'); + return callback(error); + } + if (id){ + agentId = id; + return callback(null); + } + const error = Boom.badRequest(`The agent ${domain.agent} of the specified domain doesn't exist`); + return callback(error); + }); + }, + (callback) => { + + redis.zscore('agents', domain.agent, (err, id) => { + + if (err){ + const error = Boom.badImplementation('An error ocurred checking if the agent exists.'); + return callback(error); + } + if (id){ + agentId = id; + return callback(null); + } + const error = Boom.badRequest(`The agent ${domain.agent} of the specified domain doesn't exist`); + return callback(error); + }); + }, + (callback) => { + + DomainTools.retrainModelTool(server, rasa, domain.agent, domain.domainName, domainId, (err) => { + + if (err){ + return callback(err); + } + return callback(null); + }); + }, + (callback) => { + + DomainTools.retrainDomainRecognizerTool(server, redis, rasa, domain.agent, agentId, (err) => { + + if (err){ + return callback(err); + } + return callback(null); + }); + } + ], (err) => { + + if (err){ + return reply(err); + } + server.inject(`/domain/${domainId}`, (res) => { + + if (res.statusCode !== 200){ + const error = Boom.create(res.statusCode, 'An error ocurred getting the domain after training'); + return callback(error, null); + } + return reply(res.result); + }); + }); +}; diff --git a/api/modules/domain/controllers/updateById.domain.controller.js b/api/modules/domain/controllers/updateById.domain.controller.js index 9ff89fef3..388e1aca9 100644 --- a/api/modules/domain/controllers/updateById.domain.controller.js +++ b/api/modules/domain/controllers/updateById.domain.controller.js @@ -200,8 +200,17 @@ module.exports = (request, reply) => { return reply(err, null); } if (requiresRetrain){ - //call retrain here + server.inject(`/domain/${domainId}/train`, (res) => { + + if (res.statusCode !== 200){ + const error = Boom.create(res.statusCode, `An error ocurred retraining the domain ${result.domain} after the update`); + reply(error); + } + reply(res.result); + }); + } + else { + return reply(result); } - return reply(result); }); }; diff --git a/api/modules/intent/tools/buildDomainRecognitionTrainingData.intent.tool.js b/api/modules/domain/tools/buildDomainRecognitionTrainingData.domain.tool.js similarity index 98% rename from api/modules/intent/tools/buildDomainRecognitionTrainingData.intent.tool.js rename to api/modules/domain/tools/buildDomainRecognitionTrainingData.domain.tool.js index 45d6a8c5f..e80434ade 100644 --- a/api/modules/intent/tools/buildDomainRecognitionTrainingData.intent.tool.js +++ b/api/modules/domain/tools/buildDomainRecognitionTrainingData.domain.tool.js @@ -1,8 +1,8 @@ 'use strict'; const _ = require('lodash'); -const GetAgentData = require('./getAgentData.intent.tool'); -const GetEntitiesCombinations = require('./getEntitiesCombinations.intent.tool'); +const GetAgentData = require('./getAgentData.domain.tool'); +const GetEntitiesCombinations = require('./getEntitiesCombinations.domain.tool'); const buildDomainRecognitionTrainingData = (server, agentId, cb) => { diff --git a/api/modules/intent/tools/buildTrainingData.intent.tool.js b/api/modules/domain/tools/buildTrainingData.domain.tool.js similarity index 98% rename from api/modules/intent/tools/buildTrainingData.intent.tool.js rename to api/modules/domain/tools/buildTrainingData.domain.tool.js index 6aac34ec6..10db5a3ad 100644 --- a/api/modules/intent/tools/buildTrainingData.intent.tool.js +++ b/api/modules/domain/tools/buildTrainingData.domain.tool.js @@ -1,8 +1,8 @@ 'use strict'; const _ = require('lodash'); -const GetDomainData = require('./getDomainData.intent.tool'); -const GetEntitiesCombinations = require('./getEntitiesCombinations.intent.tool'); +const GetDomainData = require('./getDomainData.domain.tool'); +const GetEntitiesCombinations = require('./getEntitiesCombinations.domain.tool'); const buildTrainingData = (server, domainId, callback) => { diff --git a/api/modules/intent/tools/cartesianProduct.intent.tool.js b/api/modules/domain/tools/cartesianProduct.domain.tool.js similarity index 100% rename from api/modules/intent/tools/cartesianProduct.intent.tool.js rename to api/modules/domain/tools/cartesianProduct.domain.tool.js diff --git a/api/modules/intent/tools/getAgentData.intent.tool.js b/api/modules/domain/tools/getAgentData.domain.tool.js similarity index 100% rename from api/modules/intent/tools/getAgentData.intent.tool.js rename to api/modules/domain/tools/getAgentData.domain.tool.js diff --git a/api/modules/intent/tools/getDomainData.intent.tool.js b/api/modules/domain/tools/getDomainData.domain.tool.js similarity index 100% rename from api/modules/intent/tools/getDomainData.intent.tool.js rename to api/modules/domain/tools/getDomainData.domain.tool.js diff --git a/api/modules/intent/tools/getEntitiesCombinations.intent.tool.js b/api/modules/domain/tools/getEntitiesCombinations.domain.tool.js similarity index 96% rename from api/modules/intent/tools/getEntitiesCombinations.intent.tool.js rename to api/modules/domain/tools/getEntitiesCombinations.domain.tool.js index 56f024dca..133c2c5bf 100644 --- a/api/modules/intent/tools/getEntitiesCombinations.intent.tool.js +++ b/api/modules/domain/tools/getEntitiesCombinations.domain.tool.js @@ -2,7 +2,7 @@ const _ = require('lodash'); -const CartesianProduct = require('./cartesianProduct.intent.tool'); +const CartesianProduct = require('./cartesianProduct.domain.tool'); const getCombinationOfEntities = (entities, intents) => { diff --git a/api/modules/domain/tools/index.js b/api/modules/domain/tools/index.js new file mode 100644 index 000000000..cd0e5d234 --- /dev/null +++ b/api/modules/domain/tools/index.js @@ -0,0 +1,12 @@ +'use strict'; +const RetrainModelTool = require('./retrainModel.domain.tool'); +const RetrainDomainRecognizerTool = require('./retrainDomainRecognizer.domain.tool'); + +const DomainTools = { + + retrainModelTool: RetrainModelTool, + + retrainDomainRecognizerTool: RetrainDomainRecognizerTool +}; + +module.exports = DomainTools; diff --git a/api/modules/intent/tools/retrainDomainRecognizer.intent.tool.js b/api/modules/domain/tools/retrainDomainRecognizer.domain.tool.js similarity index 96% rename from api/modules/intent/tools/retrainDomainRecognizer.intent.tool.js rename to api/modules/domain/tools/retrainDomainRecognizer.domain.tool.js index da1e701d6..f76ab2bfa 100644 --- a/api/modules/intent/tools/retrainDomainRecognizer.intent.tool.js +++ b/api/modules/domain/tools/retrainDomainRecognizer.domain.tool.js @@ -3,7 +3,7 @@ const Wreck = require('wreck'); const Boom = require('boom'); -const BuildDomainRecognitionTrainingData = require('./buildDomainRecognitionTrainingData.intent.tool'); +const BuildDomainRecognitionTrainingData = require('./buildDomainRecognitionTrainingData.domain.tool'); const retrainDomainRecognizer = (server, redis, rasa, agentName, agentId, cb) => { diff --git a/api/modules/intent/tools/retrainModel.intent.tool.js b/api/modules/domain/tools/retrainModel.domain.tool.js similarity index 96% rename from api/modules/intent/tools/retrainModel.intent.tool.js rename to api/modules/domain/tools/retrainModel.domain.tool.js index 6767eca2a..70e4cfc69 100644 --- a/api/modules/intent/tools/retrainModel.intent.tool.js +++ b/api/modules/domain/tools/retrainModel.domain.tool.js @@ -3,7 +3,7 @@ const Wreck = require('wreck'); const Boom = require('boom'); const Guid = require('guid'); -const BuildTrainingData = require('./buildTrainingData.intent.tool'); +const BuildTrainingData = require('./buildTrainingData.domain.tool'); const retrainModel = (server, rasa, agentName, domainName, domainId, callback) => { diff --git a/api/modules/entity/controllers/updateById.entity.controller.js b/api/modules/entity/controllers/updateById.entity.controller.js index 32b7c42f3..950ed6a08 100644 --- a/api/modules/entity/controllers/updateById.entity.controller.js +++ b/api/modules/entity/controllers/updateById.entity.controller.js @@ -285,8 +285,19 @@ module.exports = (request, reply) => { } return callbackUpdateIntentsAndScenarios(null); }); + }, + (callbackRetrainDomains) => { + + server.inject(`/domain/${domain}/train`, (res) => { + + if (res.statusCode !== 200){ + const error = Boom.create(res.statusCode, `An error ocurred training the domain ${domain}`); + return callbackMapOfDomains(error); + } + return callbackMapOfDomains(null); + }); } - ], (err, result) => { + ], (err) => { if (err){ return callbackMapOfDomains(err); @@ -306,7 +317,6 @@ module.exports = (request, reply) => { if (err){ return reply(err); } - //call retrain here return reply(updatedEntity); }); } diff --git a/api/modules/intent/controllers/add.intent.controller.js b/api/modules/intent/controllers/add.intent.controller.js index 69027f966..d5b75d474 100644 --- a/api/modules/intent/controllers/add.intent.controller.js +++ b/api/modules/intent/controllers/add.intent.controller.js @@ -3,6 +3,7 @@ const Async = require('async'); const Boom = require('boom'); const Flat = require('flat'); const IntentTools = require('../tools'); +const DomainTools = require('../../domain/tools'); module.exports = (request, reply) => { @@ -132,8 +133,8 @@ module.exports = (request, reply) => { (cb) => { Async.waterfall([ - Async.apply(IntentTools.retrainModelTool, server, rasa, resultIntent.agent, resultIntent.domain, domainId), - Async.apply(IntentTools.retrainDomainRecognizerTool, server, redis, rasa, resultIntent.agent, agentId) + Async.apply(DomainTools.retrainModelTool, server, rasa, resultIntent.agent, resultIntent.domain, domainId), + Async.apply(DomainTools.retrainDomainRecognizerTool, server, redis, rasa, resultIntent.agent, agentId) ], (err) => { if (err){ diff --git a/api/modules/intent/controllers/deleteById.intent.controller.js b/api/modules/intent/controllers/deleteById.intent.controller.js index bbb4a6e05..2bf0402f4 100644 --- a/api/modules/intent/controllers/deleteById.intent.controller.js +++ b/api/modules/intent/controllers/deleteById.intent.controller.js @@ -2,6 +2,7 @@ const Async = require('async'); const Boom = require('boom'); const IntentTools = require('../tools'); +const DomainTools = require('../../domain/tools'); module.exports = (request, reply) => { @@ -119,8 +120,8 @@ module.exports = (request, reply) => { (callback) => { Async.waterfall([ - Async.apply(IntentTools.retrainModelTool, server, rasa, intent.agent, intent.domain, domainId), - Async.apply(IntentTools.retrainDomainRecognizerTool, server, redis, rasa, intent.agent, agentId) + Async.apply(DomainTools.retrainModelTool, server, rasa, intent.agent, intent.domain, domainId), + Async.apply(DomainTools.retrainDomainRecognizerTool, server, redis, rasa, intent.agent, agentId) ], (err) => { if (err){ diff --git a/api/modules/intent/controllers/updateById.intent.controller.js b/api/modules/intent/controllers/updateById.intent.controller.js index de3c1f806..d22ca0cf8 100644 --- a/api/modules/intent/controllers/updateById.intent.controller.js +++ b/api/modules/intent/controllers/updateById.intent.controller.js @@ -3,6 +3,7 @@ const Async = require('async'); const Boom = require('boom'); const Flat = require('flat'); const IntentTools = require('../tools'); +const DomainTools = require('../../domain/tools'); const _ = require('lodash'); const updateDataFunction = (redis, server, rasa, intentId, currentIntent, updateData, agentId, domainId, cb) => { @@ -44,8 +45,8 @@ const updateDataFunction = (redis, server, rasa, intentId, currentIntent, update (callback) => { Async.waterfall([ - Async.apply(IntentTools.retrainModelTool, server, rasa, resultIntent.agent, resultIntent.domain, domainId), - Async.apply(IntentTools.retrainDomainRecognizerTool, server, redis, rasa, resultIntent.agent, agentId) + Async.apply(DomainTools.retrainModelTool, server, rasa, resultIntent.agent, resultIntent.domain, domainId), + Async.apply(DomainTools.retrainDomainRecognizerTool, server, redis, rasa, resultIntent.agent, agentId) ], (err) => { if (err){ diff --git a/api/modules/intent/tools/index.js b/api/modules/intent/tools/index.js index d24de01df..d8132a57d 100644 --- a/api/modules/intent/tools/index.js +++ b/api/modules/intent/tools/index.js @@ -2,8 +2,6 @@ const ValidateEntitiesTool = require('./validateEntities.intent.tool'); const ValidateEntitiesScenarioTool = require('./validateEntities.scenario.tool'); const UpdateEntitiesDomainTool = require('./updateEntitiesDomain.intent.tool'); -const RetrainModelTool = require('./retrainModel.intent.tool'); -const RetrainDomainRecognizerTool = require('./retrainDomainRecognizer.intent.tool'); const IntentTools = { @@ -11,11 +9,7 @@ const IntentTools = { validateEntitiesScenarioTool: ValidateEntitiesScenarioTool, - updateEntitiesDomainTool: UpdateEntitiesDomainTool, - - retrainModelTool: RetrainModelTool, - - retrainDomainRecognizerTool: RetrainDomainRecognizerTool + updateEntitiesDomainTool: UpdateEntitiesDomainTool }; module.exports = IntentTools; diff --git a/api/test-data/samson.import.json b/api/test-data/samson.import.json index c70a418b4..af7169eac 100644 --- a/api/test-data/samson.import.json +++ b/api/test-data/samson.import.json @@ -1,5 +1,5 @@ { - "agentName": "Imported samson", + "agentName": "Samson", "webhookUrl": "http://localhost:3000", "useWebhookFallback": "false", "fallbackResponses": [