Skip to content

Commit

Permalink
Add a low-term entities memory to the Bot Engine.
Browse files Browse the repository at this point in the history
  • Loading branch information
loristns committed Jun 3, 2018
1 parent 413ea9f commit f6fe0db
Showing 1 changed file with 23 additions and 4 deletions.
27 changes: 23 additions & 4 deletions kadot/bot_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from kadot.utils import SavedObject
from kadot.vectorizers import VectorDict
import logging
from typing import Optional, Sequence
from typing import Any, Optional, Sequence

logger = logging.getLogger(__name__)

Expand All @@ -19,6 +19,7 @@ def __init__(self, word_vectors: Optional[VectorDict] = None):
self.word_vectors = word_vectors
self.classifier = None
self.entities = {}
self.conversation_contexts = {}

self.intent_functions = {}
self.intent_samples = {}
Expand Down Expand Up @@ -57,7 +58,7 @@ def train(self):
word_vectors=self.word_vectors
)

def predict(self, text: str):
def predict(self, text: str, conversation: Optional[Any] = None):
best_intent, best_proba = '', 0

for intent, proba in self.classifier.predict(text.lower()).items():
Expand All @@ -67,10 +68,28 @@ def predict(self, text: str):
best_intent, best_proba = intent, proba

# Retrieve entities
entities = {}
extracted_entities = {}
if best_intent in self.intent_entities.keys():
for entity_name in self.intent_entities[best_intent]:
entities[entity_name] = \
extracted_entities[entity_name] = \
self.entities[entity_name].predict(text)

context_entities = {}
if conversation in self.conversation_contexts.keys():
# Retrieve other entities from the context
# (even if they are not required by the entities)
context_entities = self.conversation_contexts[conversation]

self.conversation_contexts[conversation] = extracted_entities

# Extracted entities have priority over context entities
entities = extracted_entities.copy()

for entity_name, entity_value in context_entities.items():
if entity_name in extracted_entities.keys():
if not extracted_entities[entity_name][0] and entity_value[0]:
entities[entity_name] = entity_value
else:
entities[entity_name] = entity_value

return self.intent_functions[best_intent](text, entities)

0 comments on commit f6fe0db

Please sign in to comment.