Skip to content

Commit

Permalink
fix(Monitoring): Serializable log middleware (#1908)
Browse files Browse the repository at this point in the history
* fix(Monitoring): Argilla log middleware can be serialized using pickle

Also, review and format the class name

* refactor: Align http middleware to the base monitor

* fix: Monitor initialization

* tests: Fix tests

* feat: #1908 revamped usage of ASGI Middleware to allow for GET

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* feat: #1908 re-added default mappers for test usability and backward compatibility

* feat: #1908 updated asgi test

* fix-tests: converted to singular record mapper

* Update src/argilla/monitoring/asgi.py

Co-authored-by: Francisco Aranda <francis@argilla.io>

* feat: added additional tests for GET and PUT requests logging

* tests: resolved failing get request endpoint

* tests: request GET endpoint

* tests: added PUT endpoint

* tests: add monitor for GET prediction endpoint

* refactor: removed obsolete statement to simplify URL

Co-authored-by: Francisco Aranda <francisco@recogn.ai>
Co-authored-by: david <david.m.berenstein@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
4 people committed Nov 22, 2022
1 parent caeb7d4 commit 53a57f7
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 113 deletions.
134 changes: 83 additions & 51 deletions src/argilla/monitoring/asgi.py
Expand Up @@ -12,15 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import json
import logging
import re
import threading
from queue import Queue
from typing import Any, Callable, Dict, List, Optional, Tuple

from argilla.monitoring.base import BaseMonitor

try:
import starlette
except ModuleNotFoundError:
Expand All @@ -34,7 +33,6 @@
from starlette.responses import JSONResponse, Response, StreamingResponse
from starlette.types import Message, Receive

import argilla
from argilla.client.models import (
Record,
TextClassificationRecord,
Expand Down Expand Up @@ -101,67 +99,92 @@ async def cached_receive() -> Message:
return self._receive


class argillaLogHTTPMiddleware(BaseHTTPMiddleware):
class ArgillaLogHTTPMiddleware(BaseHTTPMiddleware):
"""An standard starlette middleware that enables argilla logs for http prediction requests"""

def __init__(
self,
api_endpoint: str,
dataset: str,
records_mapper: Optional[Callable[[dict, dict], Record]] = None,
records_mapper: Optional[Callable[[dict, dict], Record]],
sample_rate: float = 1.0,
log_interval: float = 1.0,
agent: Optional[str] = None,
tags: Dict[str, str] = None,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
BaseHTTPMiddleware.__init__(self, *args, **kwargs)

self._endpoint = api_endpoint
self._dataset = dataset
self._records_mapper = records_mapper or text_classification_mapper
self._queue = Queue()
self._worker_task = threading.Thread(
target=self.__worker__, name=argillaLogHTTPMiddleware.__name__, daemon=True
self._records_mapper = records_mapper
self._monitor_cfg = dict(
dataset=dataset,
sample_rate=sample_rate,
log_interval=log_interval,
agent=agent,
tags=tags,
)
self._monitor: Optional[BaseMonitor] = None

def init(self):
if self._monitor:
return
from argilla.client.api import active_api

self._monitor = BaseMonitor(
self,
api=active_api(),
**self._monitor_cfg,
)
self._worker_task.start()
self._monitor._prepare_log_data = self._prepare_argilla_data

async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
self,
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
if self._endpoint != request.url.path: # Filtering endpoint path
return await call_next(request)
self.init()

content_type = request.headers.get("Content-type", None)
if "application/json" not in content_type:
if self._endpoint != request.url.path: # Filtering endpoint path
return await call_next(request)

cached_request = CachedJsonRequest(
scope=request.scope, receive=request.receive, send=request._send
scope=request.scope,
receive=request.receive,
send=request._send,
)
# Must read body before call_next
inputs = await cached_request.json()
response: Response = await call_next(cached_request)

try:
# Must obtain input parameters from request
if cached_request.method in ["POST", "PUT"]:
content_type = request.headers.get("Content-type", None)
if content_type is None:
if "application/json" not in content_type:
return await call_next(request)
inputs = await cached_request.json()
elif cached_request.method == "GET":
inputs = cached_request.query_params._dict
else:
raise NotImplementedError(
"Only request methods POST, PUT and GET are implemented."
)

# Must obtain response from request
response: Response = await call_next(cached_request)
if (
not isinstance(response, (JSONResponse, StreamingResponse))
or response.status_code >= 400
):
return response

new_response, outputs = await self._extract_response_content(response)
self._queue.put_nowait((inputs, outputs, str(request.url)))
self._monitor.send_records(inputs=inputs, outputs=outputs)
return new_response
except Exception as ex:
_logger.error("Cannot log to argilla", exc_info=ex)
return response

def __worker__(self):
while True:
try:
inputs, outputs, url = self._queue.get()
self._log_to_argilla(inputs, outputs, url)
except Exception as ex:
# Run thread FOREVER!!!
_logger.error("Error sending records to argilla", exc_info=ex)
finally:
self._queue.task_done()
return await call_next(request)

async def _extract_response_content(
self, response: Response
Expand All @@ -183,21 +206,30 @@ async def _extract_response_content(
body = response.body
return new_response, json.loads(body)

def _log_to_argilla(
self,
inputs: List[Dict[str, Any]],
outputs: List[Dict[str, Any]],
url: str,
**tags
def _prepare_argilla_data(
self, inputs: List[Dict[str, Any]], outputs: List[Dict[str, Any]], **tags
):
records = [
record
for _inputs, _outputs in zip(inputs, outputs)
for record in [self._records_mapper(_inputs, _outputs)]
if record
]

if records:
for r in records:
r.prediction_agent = url
argilla.log(records=records, name=self._dataset, tags=tags)
# using the base monitor, we only need to provide the input data to the rg.log function
# and the monitor will handle the sample rate, queue and argilla interaction
try:
records = self._records_mapper(inputs, outputs)
assert records, ValueError(
"The records_mapper returns and empty record list."
)
if not isinstance(records, list):
records = [records]
except Exception as ex:
records = []
_logger.error(
"Cannot log to argilla. Error in records mapper.", exc_info=ex
)

for record in records:
if self._monitor.agent is not None and not record.prediction_agent:
record.prediction_agent = self._monitor.agent

return dict(
records=records or [],
name=self._dataset,
tags=tags,
)
2 changes: 1 addition & 1 deletion src/argilla/monitoring/base.py
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import atexit
import dataclasses
import logging
import random
import threading
Expand Down Expand Up @@ -221,7 +222,6 @@ def send_records(self, *args, **kwargs):

def _get_consumer_by_dataset(self, dataset: str):
if dataset not in self._consumers:
print(f"NOT FOUND {dataset}")
self._consumers[dataset] = self._create_consumer(dataset)
return self._consumers[dataset]

Expand Down

0 comments on commit 53a57f7

Please sign in to comment.