diff --git a/bothub/api/v2/nlp/views.py b/bothub/api/v2/nlp/views.py index 193f4d49..17ecf682 100644 --- a/bothub/api/v2/nlp/views.py +++ b/bothub/api/v2/nlp/views.py @@ -53,14 +53,20 @@ class RepositoryAuthorizationTrainViewSet( def retrieve(self, request, *args, **kwargs): check_auth(request) repository_authorization = self.get_object() - current_version = repository_authorization.repository.current_version( - str(request.query_params.get("language")) - ) + repository_version = request.query_params.get("repository_version") + if repository_version: + current_version = repository_authorization.repository.get_specific_version_id( + repository_version, str(request.query_params.get("language")) + ) + else: + current_version = repository_authorization.repository.current_version( + str(request.query_params.get("language")) + ) return Response( { "ready_for_train": current_version.ready_for_train, - "current_update_id": current_version.id, + "current_version_id": current_version.id, "repository_authorization_user_id": repository_authorization.user.id, "language": current_version.language, } @@ -70,7 +76,7 @@ def retrieve(self, request, *args, **kwargs): def get_examples(self, request, **kwargs): check_auth(request) queryset = get_object_or_404( - RepositoryVersionLanguage, pk=request.query_params.get("update_id") + RepositoryVersionLanguage, pk=request.query_params.get("repository_version") ) page = self.paginate_queryset(queryset.examples) @@ -87,7 +93,7 @@ def get_examples(self, request, **kwargs): def get_examples_labels(self, request, **kwargs): check_auth(request) queryset = get_object_or_404( - RepositoryVersionLanguage, pk=request.query_params.get("update_id") + RepositoryVersionLanguage, pk=request.query_params.get("repository_version") ) page = self.paginate_queryset( @@ -107,7 +113,7 @@ def start_training(self, request, **kwargs): check_auth(request) repository = get_object_or_404( - RepositoryVersionLanguage, pk=request.data.get("update_id") + RepositoryVersionLanguage, pk=request.data.get("repository_version") ) repository.start_training( @@ -117,7 +123,7 @@ def start_training(self, request, **kwargs): return Response( { "language": repository.language, - "update_id": repository.id, + "repository_version": repository.id, "repository_uuid": str(repository.repository_version.repository.uuid), "intent": repository.intents, "algorithm": repository.algorithm, @@ -141,7 +147,7 @@ def get_entities_and_labels(self, request, **kwargs): try: examples = request.data.get("examples") label_examples_query = request.data.get("label_examples_query") - update_id = request.data.get("update_id") + update_id = request.data.get("repository_version") except ValueError: raise exceptions.NotFound() @@ -206,7 +212,7 @@ def get_entities_and_labels(self, request, **kwargs): def train_fail(self, request, **kwargs): check_auth(request) repository = get_object_or_404( - RepositoryVersionLanguage, pk=request.data.get("update_id") + RepositoryVersionLanguage, pk=request.data.get("repository_version") ) repository.train_fail() return Response({}) @@ -215,7 +221,7 @@ def train_fail(self, request, **kwargs): def training_log(self, request, **kwargs): check_auth(request) repository = get_object_or_404( - RepositoryVersionLanguage, pk=request.data.get("update_id") + RepositoryVersionLanguage, pk=request.data.get("repository_version") ) repository.training_log = request.data.get("training_log") repository.save(update_fields=["training_log"]) @@ -233,16 +239,21 @@ def retrieve(self, request, *args, **kwargs): repository = repository_authorization.repository language = request.query_params.get("language") + repository_version = request.query_params.get("repository_version") if language == "None" or language is None: language = str(repository.language) - update = repository.last_trained_update(language) + if repository_version: + update = repository.get_specific_version_id(repository_version, language) + else: + update = repository.last_trained_update(language) + try: return Response( { - "update": False if update is None else True, - "update_id": update.id, + "version": False if update is None else True, + "repository_version": update.id, "language": update.language, } ) @@ -253,7 +264,7 @@ def retrieve(self, request, *args, **kwargs): def repository_entity(self, request, **kwargs): check_auth(request) repository_update = get_object_or_404( - RepositoryVersionLanguage, pk=request.query_params.get("update_id") + RepositoryVersionLanguage, pk=request.query_params.get("repository_version") ) repository_entity = get_object_or_404( RepositoryEntity, @@ -309,7 +320,7 @@ def retrieve(self, request, *args, **kwargs): def evaluations(self, request, **kwargs): check_auth(request) repository_update = get_object_or_404( - RepositoryVersionLanguage, pk=request.query_params.get("update_id") + RepositoryVersionLanguage, pk=request.query_params.get("repository_version") ) evaluations = repository_update.repository_version.repository.evaluations( language=repository_update.language @@ -345,7 +356,7 @@ def evaluations(self, request, **kwargs): def evaluate_results(self, request, **kwargs): check_auth(request) repository_update = get_object_or_404( - RepositoryVersionLanguage, pk=request.data.get("update_id") + RepositoryVersionLanguage, pk=request.data.get("repository_version") ) intents_score = RepositoryEvaluateResultScore.objects.create( @@ -418,7 +429,7 @@ def evaluate_results_score(self, request, **kwargs): ) repository_update = get_object_or_404( - RepositoryVersionLanguage, pk=request.data.get("update_id") + RepositoryVersionLanguage, pk=request.data.get("repository_version") ) entity_score = RepositoryEvaluateResultScore.objects.create( @@ -484,8 +495,8 @@ def retrieve(self, request, *args, **kwargs): return Response( { - "update_id": update.id, - "repository_uuid": update.repository.uuid, + "version_id": update.id, + "repository_uuid": update.repository_version.repository.uuid, "bot_data": str(bot_data), "from_aws": aws, } diff --git a/bothub/api/v2/repository/serializers.py b/bothub/api/v2/repository/serializers.py index d0cb8f61..8a269687 100644 --- a/bothub/api/v2/repository/serializers.py +++ b/bothub/api/v2/repository/serializers.py @@ -579,6 +579,11 @@ def create(self, validated_data): class AnalyzeTextSerializer(serializers.Serializer): language = serializers.ChoiceField(LANGUAGE_CHOICES, required=True) text = serializers.CharField(allow_blank=False) + repository_version = serializers.IntegerField(required=False) + + +class TrainSerializer(serializers.Serializer): + repository_version = serializers.IntegerField(required=False) class EvaluateSerializer(serializers.Serializer): diff --git a/bothub/api/v2/repository/views.py b/bothub/api/v2/repository/views.py index 96612410..196a409c 100644 --- a/bothub/api/v2/repository/views.py +++ b/bothub/api/v2/repository/views.py @@ -36,7 +36,7 @@ from .permissions import RepositoryAdminManagerAuthorization from .permissions import RepositoryExamplePermission from .permissions import RepositoryPermission -from .serializers import AnalyzeTextSerializer +from .serializers import AnalyzeTextSerializer, TrainSerializer from .serializers import EvaluateSerializer from .serializers import RepositoryAuthorizationRoleSerializer from .serializers import RepositoryAuthorizationSerializer @@ -87,7 +87,7 @@ def languagesstatus(self, request, **kwargs): @action( detail=True, - methods=["GET"], + methods=["POST"], url_name="repository-train", lookup_fields=["uuid"], ) @@ -95,13 +95,15 @@ def train(self, request, **kwargs): """ Train current update using Bothub NLP service """ - if self.lookup_field not in kwargs: - return Response({}, status=403) repository = self.get_object() user_authorization = repository.get_user_authorization(request.user) + serializer = TrainSerializer(data=request.data) # pragma: no cover + serializer.is_valid(raise_exception=True) # pragma: no cover if not user_authorization.can_write: raise PermissionDenied() - request = repository.request_nlp_train(user_authorization) # pragma: no cover + request = repository.request_nlp_train( + user_authorization, serializer.data + ) # pragma: no cover if request.status_code != status.HTTP_200_OK: # pragma: no cover raise APIException( # pragma: no cover {"status_code": request.status_code}, code=request.status_code diff --git a/bothub/api/v2/tests/test_nlp.py b/bothub/api/v2/tests/test_nlp.py index 348454f5..eb185437 100644 --- a/bothub/api/v2/tests/test_nlp.py +++ b/bothub/api/v2/tests/test_nlp.py @@ -53,7 +53,7 @@ def request(self, token): "/v2/repository/nlp/authorization/train/start_training/", json.dumps( { - "update_id": self.repository_version_language.pk, + "repository_version": self.repository_version_language.pk, "by_user": self.user.pk, } ), @@ -108,7 +108,7 @@ def request(self, token): authorization_header = {"HTTP_AUTHORIZATION": "Bearer {}".format(token)} request = self.factory.post( "/v2/repository/nlp/authorization/train/train_fail/", - json.dumps({"update_id": self.repository_version_language.pk}), + json.dumps({"repository_version": self.repository_version_language.pk}), content_type="application/json", **authorization_header ) diff --git a/bothub/api/v2/tests/test_repository.py b/bothub/api/v2/tests/test_repository.py index b70d95da..43c515e8 100644 --- a/bothub/api/v2/tests/test_repository.py +++ b/bothub/api/v2/tests/test_repository.py @@ -1708,19 +1708,22 @@ def setUp(self): language=languages.LANGUAGE_EN, ) - def request(self, repository, token): + def request(self, repository, token, data): authorization_header = {"HTTP_AUTHORIZATION": "Token {}".format(token.key)} - request = self.factory.get( + request = self.factory.post( "/v2/repository/repository-info/{}/train/".format(str(repository.uuid)), + data, **authorization_header, ) - response = RepositoryViewSet.as_view({"get": "train"})(request) + response = RepositoryViewSet.as_view({"post": "train"})( + request, uuid=repository.uuid + ) response.render() content_data = json.loads(response.content) return (response, content_data) def test_permission_denied(self): - response, content_data = self.request(self.repository, self.user_token) + response, content_data = self.request(self.repository, self.user_token, {}) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) diff --git a/bothub/common/management/commands/fill_db_using_fake_data.py b/bothub/common/management/commands/fill_db_using_fake_data.py index 404f9672..9c664ec6 100644 --- a/bothub/common/management/commands/fill_db_using_fake_data.py +++ b/bothub/common/management/commands/fill_db_using_fake_data.py @@ -58,6 +58,8 @@ def handle(self, *args, **kwargs): repository_1.categories.add(categories[1]) repository_1.categories.add(categories[3]) + repository_1.current_version() + repository_2 = Repository.objects.create( owner=user, name="Repository 2", @@ -67,6 +69,8 @@ def handle(self, *args, **kwargs): repository_2.categories.add(categories[0]) repository_2.categories.add(categories[2]) + repository_2.current_version() + for x in range(3, 46): new_repository = Repository.objects.create( owner=user, @@ -75,6 +79,7 @@ def handle(self, *args, **kwargs): language=languages.LANGUAGE_EN, ) new_repository.categories.add(random.choice(categories)) + new_repository.current_version() # Examples diff --git a/bothub/common/models.py b/bothub/common/models.py index c18669d7..27afda6d 100644 --- a/bothub/common/models.py +++ b/bothub/common/models.py @@ -56,10 +56,8 @@ def publics(self): return self.filter(is_private=False) def order_by_relevance(self): - return self.annotate(examples_sum=models.Sum("versions")).order_by( - "-versions__repositoryversionlanguage__total_training_end", - "-examples_sum", - "-created_at", + return self.order_by( + "-versions__repositoryversionlanguage__total_training_end", "-created_at" ) def supported_language(self, language): @@ -218,15 +216,33 @@ def save( self.__use_name_entities = self.use_name_entities self.__use_analyze_char = self.use_analyze_char - def request_nlp_train(self, user_authorization): + def request_nlp_train(self, user_authorization, data): try: # pragma: no cover - r = requests.post( # pragma: no cover - "{}train/".format( - self.nlp_server if self.nlp_server else settings.BOTHUB_NLP_BASE_URL - ), - data={}, - headers={"Authorization": "Bearer {}".format(user_authorization.uuid)}, - ) + print(data.get("repository_version")) + if data.get("repository_version"): + r = requests.post( # pragma: no cover + "{}train/".format( + self.nlp_server + if self.nlp_server + else settings.BOTHUB_NLP_BASE_URL + ), + data={"repository_version": data.get("repository_version")}, + headers={ + "Authorization": "Bearer {}".format(user_authorization.uuid) + }, + ) + else: + r = requests.post( # pragma: no cover + "{}train/".format( + self.nlp_server + if self.nlp_server + else settings.BOTHUB_NLP_BASE_URL + ), + data={}, + headers={ + "Authorization": "Bearer {}".format(user_authorization.uuid) + }, + ) return r # pragma: no cover except requests.exceptions.ConnectionError: # pragma: no cover raise APIException( # pragma: no cover @@ -236,13 +252,34 @@ def request_nlp_train(self, user_authorization): def request_nlp_analyze(self, user_authorization, data): try: # pragma: no cover - r = requests.post( # pragma: no cover - "{}parse/".format( - self.nlp_server if self.nlp_server else settings.BOTHUB_NLP_BASE_URL - ), - data={"text": data.get("text"), "language": data.get("language")}, - headers={"Authorization": "Bearer {}".format(user_authorization.uuid)}, - ) + if data.get("repository_version"): + r = requests.post( # pragma: no cover + "{}parse/".format( + self.nlp_server + if self.nlp_server + else settings.BOTHUB_NLP_BASE_URL + ), + data={ + "text": data.get("text"), + "language": data.get("language"), + "repository_version": data.get("repository_version"), + }, + headers={ + "Authorization": "Bearer {}".format(user_authorization.uuid) + }, + ) + else: + r = requests.post( # pragma: no cover + "{}parse/".format( + self.nlp_server + if self.nlp_server + else settings.BOTHUB_NLP_BASE_URL + ), + data={"text": data.get("text"), "language": data.get("language")}, + headers={ + "Authorization": "Bearer {}".format(user_authorization.uuid) + }, + ) return r # pragma: no cover except requests.exceptions.ConnectionError: # pragma: no cover raise APIException( # pragma: no cover @@ -499,6 +536,15 @@ def get_specific_version_language(self, language=None): query = query.filter(language=language) return query.first() + def get_specific_version_id(self, repository_version, language=None): + query = RepositoryVersionLanguage.objects.filter( + repository_version__repository=self, + repository_version__pk=repository_version, + ) + if language: + query = query.filter(language=language) + return query.first() + def get_user_authorization(self, user): if user.is_anonymous: return RepositoryAuthorization(repository=self) @@ -632,50 +678,13 @@ def requirements_to_train(self): @property def ready_for_train(self): - previous_update = None if len(self.requirements_to_train) > 0: return False - if self.training_end_at: - previous = self.repository_version.repository.versions.filter( - repositoryversionlanguage__language=self.language, - last_update__gte=self.training_end_at, - created_by__isnull=False, - ).first() - else: - previous = self.repository_version.repository.versions.filter( - repositoryversionlanguage__language=self.language, - created_by__isnull=False, - ).first() - - if previous: - previous_update = previous.version_languages.filter( - language=self.language - ).first() - - if previous_update: - if ( - previous_update.algorithm - != self.repository_version.repository.algorithm - ): - return True - if ( - previous_update.use_competing_intents - is not self.repository_version.repository.use_competing_intents - ): - return True - if ( - previous_update.use_name_entities - is not self.repository_version.repository.use_name_entities - ): - return True - if ( - previous_update.use_analyze_char - is not self.repository_version.repository.use_analyze_char - ): - return True - if previous_update.failed_at: - return True + if ( + self.training_end_at is not None + ) and self.training_end_at > self.last_update: + return True if not self.added.exists() and not self.translated_added.exists(): return False