-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Handle context - Handle intent flags - Able to retrieve user prompt
- Loading branch information
Showing
1 changed file
with
122 additions
and
53 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,95 +1,164 @@ | ||
from kadot.classifiers import NeuralClassifier | ||
from kadot.models import CRFExtractor | ||
from kadot.tokenizers import regex_tokenizer, Tokens | ||
from kadot.utils import SavedObject | ||
from kadot.vectorizers import VectorDict | ||
import logging | ||
from typing import Any, Optional, Sequence | ||
from typing import Any, Callable, Optional, Sequence | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class ConversationNode(SavedObject): | ||
class Context(object): | ||
""" | ||
Keeps track of the data used in a conversation. | ||
""" | ||
def __init__(self, name: str): | ||
self.name = name | ||
self.age = 0 | ||
self.data = {} | ||
self.data_track = {} # Indicate if data are expired | ||
self.intent_flag = None | ||
|
||
def __init__(self, word_vectors: Optional[VectorDict] = None): | ||
def __setitem__(self, key, value): | ||
if value or key not in self.data.keys(): # No empty data in the context | ||
self.data[key] = value | ||
self.data_track[key] = self.age | ||
|
||
def __delitem__(self, key): | ||
del self.data_track[key], self.data[key] | ||
|
||
def __getitem__(self, item): | ||
if self.data_track[item] + 2 < self.age: # If data is too old | ||
self.data[item] = '' | ||
|
||
return self.data[item] | ||
|
||
def step(self): | ||
self.age += 1 | ||
|
||
|
||
class Intent(object): | ||
""" | ||
A container for bot's intents. | ||
""" | ||
|
||
def __init__(self, | ||
name: str, | ||
func: Callable[[str, Context], Any], | ||
entities: Sequence[str] = [], | ||
samples: Sequence[str] = [] | ||
): | ||
self.name = name | ||
self.run = func | ||
self.entities = entities | ||
self.samples = samples | ||
|
||
|
||
class Agent(SavedObject): | ||
|
||
def __init__(self, | ||
word_vectors: Optional[VectorDict] = None, | ||
tokenizer: Callable[..., Tokens] = regex_tokenizer | ||
): | ||
""" | ||
:param word_vectors: a VectorDict object containing the word vectors | ||
that will be used to train the classifier (optional). | ||
:param tokenizer: the word tokenizer to use. | ||
""" | ||
|
||
self.word_vectors = word_vectors | ||
self.classifier = None | ||
self.tokenizer = tokenizer | ||
self.word_vectors = word_vectors | ||
|
||
self.intents = {} | ||
self.entities = {} | ||
self.conversation_contexts = {} | ||
self.contexts = {} | ||
|
||
self.intent_functions = {} | ||
self.intent_samples = {} | ||
self.intent_entities = {} | ||
def add_entity(self, name: str, extractor: CRFExtractor): | ||
self.entities[name] = extractor | ||
|
||
def intent(self, samples: Sequence[str]): | ||
def intent(self, samples: Sequence[str], entities: Sequence[str] = []): | ||
|
||
def wrapper(intent_function): | ||
self.intent_functions[intent_function.__name__] = intent_function | ||
|
||
for sample in samples: | ||
self.intent_samples[sample.lower()] = intent_function.__name__ | ||
self.intents[intent_function.__name__] = Intent( | ||
name=intent_function.__name__, | ||
func=intent_function, | ||
entities=entities, | ||
samples=samples | ||
) | ||
|
||
return intent_function | ||
|
||
return wrapper | ||
|
||
def require_entity(self, name: str): | ||
def prompt(self, | ||
message: Any, | ||
key: str, | ||
callback: Callable[[str, Context], Any], | ||
context: Context): | ||
|
||
def wrapper(intent_function): | ||
if intent_function.__name__ in self.intent_entities.keys(): | ||
self.intent_entities[intent_function.__name__].append(name) | ||
else: | ||
self.intent_entities[intent_function.__name__] = [name] | ||
def _prompt(raw, ctx): | ||
""" | ||
An intent to retrieve the user's input and put it in the context. | ||
""" | ||
output = '' | ||
if key in self.entities.keys(): | ||
# Try to use the entity extractor | ||
output = ' '.join(self.entities[key].predict(raw)[0]) | ||
|
||
return intent_function | ||
if output: ctx[key] = output | ||
else: ctx[key] = raw | ||
|
||
return wrapper | ||
return callback(raw, ctx) | ||
|
||
self.intents['_prompt'] = Intent(name='_prompt', func=_prompt) | ||
context.intent_flag = '_prompt' | ||
|
||
def add_entity(self, name: str, recognizer: CRFExtractor): | ||
self.entities[name] = recognizer | ||
return message, context | ||
|
||
def _get_training_dataset(self): | ||
training_dataset = {} | ||
|
||
for intent in self.intents.values(): | ||
for sample in intent.samples: | ||
training_dataset[sample] = intent.name | ||
|
||
return training_dataset | ||
|
||
def train(self): | ||
self.classifier = NeuralClassifier( | ||
self.intent_samples, | ||
self._get_training_dataset(), | ||
word_vectors=self.word_vectors | ||
) | ||
|
||
def predict(self, text: str, conversation: Optional[Any] = None): | ||
if conversation in self.contexts.keys(): | ||
context = self.contexts[conversation] | ||
else: | ||
context = Context(conversation) | ||
|
||
# Retrieve the intent | ||
best_intent, best_proba = '', 0 | ||
|
||
for intent, proba in self.classifier.predict(text.lower()).items(): | ||
logger.info("{}: {}".format(intent, proba)) | ||
if context.intent_flag is None: | ||
for intent, proba in self.classifier.predict(text.lower()).items(): | ||
logger.info("{}: {}".format(intent, proba)) | ||
|
||
if proba >= best_proba: | ||
best_intent, best_proba = intent, proba | ||
if proba >= best_proba: | ||
best_intent, best_proba = intent, proba | ||
else: # Handle intent flag | ||
best_intent, best_proba = context.intent_flag, 1 | ||
logger.info("Intent flag for {}.".format(best_intent)) | ||
context.intent_flag = None | ||
|
||
# Retrieve entities | ||
extracted_entities = {} | ||
if best_intent in self.intent_entities.keys(): | ||
for entity_name in self.intent_entities[best_intent]: | ||
extracted_entities[entity_name] = \ | ||
self.entities[entity_name].predict(text) | ||
|
||
context_entities = {} | ||
if conversation in self.conversation_contexts.keys(): | ||
# Retrieve other entities from the context | ||
# (even if they are not required by the entities) | ||
context_entities = self.conversation_contexts[conversation] | ||
|
||
self.conversation_contexts[conversation] = extracted_entities | ||
|
||
# Extracted entities have priority over context entities | ||
entities = extracted_entities.copy() | ||
|
||
for entity_name, entity_value in context_entities.items(): | ||
if entity_name in extracted_entities.keys(): | ||
if not extracted_entities[entity_name][0] and entity_value[0]: | ||
entities[entity_name] = entity_value | ||
else: | ||
entities[entity_name] = entity_value | ||
|
||
return self.intent_functions[best_intent](text, entities) | ||
for entity_name in self.intents[best_intent].entities: | ||
context[entity_name] = ' '.join(self.entities[entity_name].predict(text)[0]) | ||
|
||
context.step() | ||
output, context = self.intents[best_intent].run(text, context) | ||
self.contexts[conversation] = context | ||
|
||
return output |