Skip to content

Commit

Permalink
Train endpoint for domain, and added retraining of domain after updat…
Browse files Browse the repository at this point in the history
…es in agent, domain and entitites
  • Loading branch information
dcalvom committed Dec 11, 2017
1 parent b9b289d commit 7442c72
Show file tree
Hide file tree
Showing 22 changed files with 216 additions and 31 deletions.
5 changes: 3 additions & 2 deletions api/modules/agent/controllers/import.agent.controller.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ const Boom = require('boom');
const Flat = require('flat');
const _ = require('lodash');
const IntentTools = require('../../intent/tools');
const DomainTools = require('../../domain/tools');

module.exports = (request, reply) => {

Expand Down Expand Up @@ -305,8 +306,8 @@ module.exports = (request, reply) => {
domainResult.intents = resultIntents;

Async.waterfall([
Async.apply(IntentTools.retrainModelTool, server, rasa, agentResult.agentName, domainResult.domainName, domainResult.id),
Async.apply(IntentTools.retrainDomainRecognizerTool, server, redis, rasa, agentResult.agentName, agentResult.id)
Async.apply(DomainTools.retrainModelTool, server, rasa, agentResult.agentName, domainResult.domainName, domainResult.id),
Async.apply(DomainTools.retrainDomainRecognizerTool, server, redis, rasa, agentResult.agentName, agentResult.id)
], (errTraining) => {

if (errTraining){
Expand Down
44 changes: 42 additions & 2 deletions api/modules/agent/controllers/updateById.agent.controller.js
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,48 @@ module.exports = (request, reply) => {
return reply(err, null);
}
if (requiresRetrain){
//call retrain here
Async.waterfall([
(callbackGetDomains) => {

server.inject(`/agent/${agentId}/domain`, (res) => {

if (res.statusCode !== 200){
const error = Boom.create(res.statusCode, 'An error ocurred getting the domains of the agent to train them');
return callbackGetDomains(error, null);
}
return callbackGetDomains(null, res.result);
});
},
(domains, callbackTrainEachDomain) => {

Async.eachLimit(domains, 1, (domain, callbackMapOfDomain) => {

server.inject(`/domain/${domain.id}/train`, (res) => {

if (res.statusCode !== 200){
const error = Boom.create(res.statusCode, `An error ocurred training the domain ${domain.domain}`);
return callbackMapOfDomain(error);
}
return callbackMapOfDomain(null);
});
}, (err) => {

if (err){
return callbackTrainEachDomain(err);
}
return callbackTrainEachDomain(null);
});
}
], (errTraining) => {

if (errTraining){
return reply(errTraining);
}
return reply(result);
});
}
else {
return reply(result);
}
return reply(result);
});
};
10 changes: 10 additions & 0 deletions api/modules/domain/config/domain.route.js
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,16 @@ const DomainRoutes = [
validate: DomainValidator.findIntentsByDomainId,
handler: DomainController.findIntentsByDomainId
}
},
{
method: 'GET',
path: '/domain/{id}/train',
config: {
description: 'Train the specified domain',
tags: ['api'],
validate: DomainValidator.train,
handler: DomainController.train
}
}
];

Expand Down
11 changes: 10 additions & 1 deletion api/modules/domain/config/domain.validator.js
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class DomainValidate {
params: (() => {

return {
id: DomainSchema.id.required().description('Id of the agent')
id: DomainSchema.id.required().description('Id of the domain')
};
})(),
query: (() => {
Expand All @@ -81,6 +81,15 @@ class DomainValidate {
})()
};

this.train = {
params: (() => {

return {
id: DomainSchema.id.required().description('Id of the domain')
};
})()
};

}
}

Expand Down
5 changes: 4 additions & 1 deletion api/modules/domain/controllers/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ const UpdateByIdController = require('./updateById.domain.controller');
const DeleteByIdController = require('./deleteById.domain.controller');
const FindEntitiesByDomainIdController = require('./findEntitiesByDomainId.domain.controller');
const FindIntentsByDomainIdController = require('./findIntentsByDomainId.domain.controller');
const TrainController = require('./train.domain.controller');

const DomainController = {

Expand All @@ -18,7 +19,9 @@ const DomainController = {

findEntitiesByDomainId: FindEntitiesByDomainIdController,

findIntentsByDomainId: FindIntentsByDomainIdController
findIntentsByDomainId: FindIntentsByDomainIdController,

train: TrainController
};

module.exports = DomainController;
94 changes: 94 additions & 0 deletions api/modules/domain/controllers/train.domain.controller.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
'use strict';
const Async = require('async');
const Boom = require('boom');
const DomainTools = require('../tools');

module.exports = (request, reply) => {

let agentId = null;
let domain = null;
const domainId = request.params.id;
const server = request.server;
const redis = server.app.redis;
const rasa = server.app.rasa;

Async.waterfall([
(callback) => {

server.inject(`/domain/${domainId}`, (res) => {

if (res.statusCode !== 200){
const error = Boom.create(res.statusCode, 'An error ocurred getting the domain');
return callback(error, null);
}
domain = res.result;
return callback(null);
});
},
(callback) => {

redis.zscore('agents', domain.agent, (err, id) => {

if (err){
const error = Boom.badImplementation('An error ocurred checking if the agent exists.');
return callback(error);
}
if (id){
agentId = id;
return callback(null);
}
const error = Boom.badRequest(`The agent ${domain.agent} of the specified domain doesn't exist`);
return callback(error);
});
},
(callback) => {

redis.zscore('agents', domain.agent, (err, id) => {

if (err){
const error = Boom.badImplementation('An error ocurred checking if the agent exists.');
return callback(error);
}
if (id){
agentId = id;
return callback(null);
}
const error = Boom.badRequest(`The agent ${domain.agent} of the specified domain doesn't exist`);
return callback(error);
});
},
(callback) => {

DomainTools.retrainModelTool(server, rasa, domain.agent, domain.domainName, domainId, (err) => {

if (err){
return callback(err);
}
return callback(null);
});
},
(callback) => {

DomainTools.retrainDomainRecognizerTool(server, redis, rasa, domain.agent, agentId, (err) => {

if (err){
return callback(err);
}
return callback(null);
});
}
], (err) => {

if (err){
return reply(err);
}
server.inject(`/domain/${domainId}`, (res) => {

if (res.statusCode !== 200){
const error = Boom.create(res.statusCode, 'An error ocurred getting the domain after training');
return callback(error, null);
}
return reply(res.result);
});
});
};
13 changes: 11 additions & 2 deletions api/modules/domain/controllers/updateById.domain.controller.js
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,17 @@ module.exports = (request, reply) => {
return reply(err, null);
}
if (requiresRetrain){
//call retrain here
server.inject(`/domain/${domainId}/train`, (res) => {

if (res.statusCode !== 200){
const error = Boom.create(res.statusCode, `An error ocurred retraining the domain ${result.domain} after the update`);
reply(error);
}
reply(res.result);
});
}
else {
return reply(result);
}
return reply(result);
});
};
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
'use strict';

const _ = require('lodash');
const GetAgentData = require('./getAgentData.intent.tool');
const GetEntitiesCombinations = require('./getEntitiesCombinations.intent.tool');
const GetAgentData = require('./getAgentData.domain.tool');
const GetEntitiesCombinations = require('./getEntitiesCombinations.domain.tool');

const buildDomainRecognitionTrainingData = (server, agentId, cb) => {

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
'use strict';

const _ = require('lodash');
const GetDomainData = require('./getDomainData.intent.tool');
const GetEntitiesCombinations = require('./getEntitiesCombinations.intent.tool');
const GetDomainData = require('./getDomainData.domain.tool');
const GetEntitiesCombinations = require('./getEntitiesCombinations.domain.tool');

const buildTrainingData = (server, domainId, callback) => {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

const _ = require('lodash');

const CartesianProduct = require('./cartesianProduct.intent.tool');
const CartesianProduct = require('./cartesianProduct.domain.tool');

const getCombinationOfEntities = (entities, intents) => {

Expand Down
12 changes: 12 additions & 0 deletions api/modules/domain/tools/index.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
'use strict';
const RetrainModelTool = require('./retrainModel.domain.tool');
const RetrainDomainRecognizerTool = require('./retrainDomainRecognizer.domain.tool');

const DomainTools = {

retrainModelTool: RetrainModelTool,

retrainDomainRecognizerTool: RetrainDomainRecognizerTool
};

module.exports = DomainTools;
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
const Wreck = require('wreck');
const Boom = require('boom');

const BuildDomainRecognitionTrainingData = require('./buildDomainRecognitionTrainingData.intent.tool');
const BuildDomainRecognitionTrainingData = require('./buildDomainRecognitionTrainingData.domain.tool');

const retrainDomainRecognizer = (server, redis, rasa, agentName, agentId, cb) => {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
const Wreck = require('wreck');
const Boom = require('boom');
const Guid = require('guid');
const BuildTrainingData = require('./buildTrainingData.intent.tool');
const BuildTrainingData = require('./buildTrainingData.domain.tool');

const retrainModel = (server, rasa, agentName, domainName, domainId, callback) => {

Expand Down
14 changes: 12 additions & 2 deletions api/modules/entity/controllers/updateById.entity.controller.js
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,19 @@ module.exports = (request, reply) => {
}
return callbackUpdateIntentsAndScenarios(null);
});
},
(callbackRetrainDomains) => {

server.inject(`/domain/${domain}/train`, (res) => {

if (res.statusCode !== 200){
const error = Boom.create(res.statusCode, `An error ocurred training the domain ${domain}`);
return callbackMapOfDomains(error);
}
return callbackMapOfDomains(null);
});
}
], (err, result) => {
], (err) => {

if (err){
return callbackMapOfDomains(err);
Expand All @@ -306,7 +317,6 @@ module.exports = (request, reply) => {
if (err){
return reply(err);
}
//call retrain here
return reply(updatedEntity);
});
}
Expand Down
5 changes: 3 additions & 2 deletions api/modules/intent/controllers/add.intent.controller.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ const Async = require('async');
const Boom = require('boom');
const Flat = require('flat');
const IntentTools = require('../tools');
const DomainTools = require('../../domain/tools');

module.exports = (request, reply) => {

Expand Down Expand Up @@ -132,8 +133,8 @@ module.exports = (request, reply) => {
(cb) => {

Async.waterfall([
Async.apply(IntentTools.retrainModelTool, server, rasa, resultIntent.agent, resultIntent.domain, domainId),
Async.apply(IntentTools.retrainDomainRecognizerTool, server, redis, rasa, resultIntent.agent, agentId)
Async.apply(DomainTools.retrainModelTool, server, rasa, resultIntent.agent, resultIntent.domain, domainId),
Async.apply(DomainTools.retrainDomainRecognizerTool, server, redis, rasa, resultIntent.agent, agentId)
], (err) => {

if (err){
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
const Async = require('async');
const Boom = require('boom');
const IntentTools = require('../tools');
const DomainTools = require('../../domain/tools');

module.exports = (request, reply) => {

Expand Down Expand Up @@ -119,8 +120,8 @@ module.exports = (request, reply) => {
(callback) => {

Async.waterfall([
Async.apply(IntentTools.retrainModelTool, server, rasa, intent.agent, intent.domain, domainId),
Async.apply(IntentTools.retrainDomainRecognizerTool, server, redis, rasa, intent.agent, agentId)
Async.apply(DomainTools.retrainModelTool, server, rasa, intent.agent, intent.domain, domainId),
Async.apply(DomainTools.retrainDomainRecognizerTool, server, redis, rasa, intent.agent, agentId)
], (err) => {

if (err){
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ const Async = require('async');
const Boom = require('boom');
const Flat = require('flat');
const IntentTools = require('../tools');
const DomainTools = require('../../domain/tools');
const _ = require('lodash');

const updateDataFunction = (redis, server, rasa, intentId, currentIntent, updateData, agentId, domainId, cb) => {
Expand Down Expand Up @@ -44,8 +45,8 @@ const updateDataFunction = (redis, server, rasa, intentId, currentIntent, update
(callback) => {

Async.waterfall([
Async.apply(IntentTools.retrainModelTool, server, rasa, resultIntent.agent, resultIntent.domain, domainId),
Async.apply(IntentTools.retrainDomainRecognizerTool, server, redis, rasa, resultIntent.agent, agentId)
Async.apply(DomainTools.retrainModelTool, server, rasa, resultIntent.agent, resultIntent.domain, domainId),
Async.apply(DomainTools.retrainDomainRecognizerTool, server, redis, rasa, resultIntent.agent, agentId)
], (err) => {

if (err){
Expand Down

0 comments on commit 7442c72

Please sign in to comment.