Skip to content

Commit

Permalink
feat(data_connector): full text search _q to be a universal parameter
Browse files Browse the repository at this point in the history
closes #183

closes #183
  • Loading branch information
pallavibharadwaj authored and dovahcrow committed Sep 28, 2020
1 parent 7086c36 commit 947584a
Show file tree
Hide file tree
Showing 8 changed files with 756 additions and 60 deletions.
112 changes: 85 additions & 27 deletions dataprep/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
import sys
from asyncio import as_completed
from pathlib import Path
from typing import Any, Awaitable, Dict, List, Optional, Union, Tuple
from aiohttp.client_reqrep import ClientResponse
from jsonpath_ng import parse as jparse
from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
from warnings import warn

import pandas as pd
from aiohttp import ClientSession
from aiohttp.client_reqrep import ClientResponse
from jinja2 import Environment, StrictUndefined, Template, UndefinedError
from jsonpath_ng import parse as jparse

from .config_manager import config_directory, ensure_config
from .errors import InvalidParameterError, RequestError, UniversalParameterOverridden
Expand All @@ -21,10 +23,10 @@
ConfigDef,
FieldDefUnion,
OffsetPaginationDef,
SeekPaginationDef,
PagePaginationDef,
TokenPaginationDef,
SeekPaginationDef,
TokenLocation,
TokenPaginationDef,
)
from .throttler import OrderedThrottler, ThrottleSession

Expand Down Expand Up @@ -108,6 +110,7 @@ async def query( # pylint: disable=too-many-locals
self,
table: str,
*,
_q: Optional[str] = None,
_auth: Optional[Dict[str, Any]] = None,
_count: Optional[int] = None,
**where: Any,
Expand All @@ -119,6 +122,8 @@ async def query( # pylint: disable=too-many-locals
----------
table
The table name.
_q: Optional[str] = None
Search string to be matched in the response.
_auth: Optional[Dict[str, Any]] = None
The parameters for authentication. Usually the authentication parameters
should be defined when instantiating the Connector. In case some tables have different
Expand All @@ -134,12 +139,13 @@ async def query( # pylint: disable=too-many-locals
if key not in allowed_params:
raise InvalidParameterError(key)

return await self._query_imp(table, where, _auth=_auth, _count=_count)
return await self._query_imp(table, where, _auth=_auth, _q=_q, _count=_count)

@property
def table_names(self) -> List[str]:
"""
Return all the names of the available tables in a list.
Note
----
We abstract each website as a database containing several tables.
Expand All @@ -148,9 +154,8 @@ def table_names(self) -> List[str]:
return list(self._impdb.tables.keys())

def info(self) -> None:
"""
Show the basic information and provide guidance for users to issue queries.
"""
"""Show the basic information and provide guidance for users
to issue queries."""

# get info
tbs: Dict[str, Any] = {}
Expand Down Expand Up @@ -216,6 +221,7 @@ async def _query_imp( # pylint: disable=too-many-locals,too-many-branches,too-m
*,
_auth: Optional[Dict[str, Any]] = None,
_count: Optional[int] = None,
_q: Optional[str] = None,
) -> pd.DataFrame:
if table not in self._impdb.tables:
raise ValueError(f"No such table {table} in {self._impdb.name}")
Expand All @@ -238,7 +244,12 @@ async def _query_imp( # pylint: disable=too-many-locals,too-many-branches,too-m

if reqconf.pagination is None or _count is None:
df = await self._fetch(
itable, kwargs, _client=client, _throttler=throttler, _auth=_auth,
itable,
kwargs,
_client=client,
_throttler=throttler,
_auth=_auth,
_q=_q,
)
return df

Expand All @@ -263,6 +274,7 @@ async def _query_imp( # pylint: disable=too-many-locals,too-many-branches,too-m
_throttler=throttler,
_page=i,
_auth=_auth,
_q=_q,
_limit=count,
_anchor=last_id - 1,
)
Expand All @@ -274,7 +286,7 @@ async def _query_imp( # pylint: disable=too-many-locals,too-many-branches,too-m
# The API returns empty for this page, maybe we've reached the end
break

cid = df.columns.get_loc(pagdef.seek_id)
cid = df.columns.get_loc(pagdef.seek_id) # type: ignore
last_id = int(df.iloc[-1, cid]) - 1 # type: ignore

dfs.append(df)
Expand All @@ -291,6 +303,7 @@ async def _query_imp( # pylint: disable=too-many-locals,too-many-branches,too-m
_throttler=throttler,
_page=i,
_auth=_auth,
_q=_q,
_limit=count,
_anchor=next_token,
_raw=True,
Expand Down Expand Up @@ -326,6 +339,7 @@ async def _query_imp( # pylint: disable=too-many-locals,too-many-branches,too-m
_page=i,
_allowed_page=allowed_page,
_auth=_auth,
_q=_q,
_limit=count,
_anchor=anchor,
)
Expand Down Expand Up @@ -355,6 +369,7 @@ async def _fetch( # pylint: disable=too-many-locals,too-many-branches,too-many-
_limit: Optional[int] = None,
_anchor: Optional[Any] = None,
_auth: Optional[Dict[str, Any]] = None,
_q: Optional[str] = None,
_raw: bool = False,
) -> Union[Optional[pd.DataFrame], Tuple[Optional[pd.DataFrame], ClientResponse]]:

Expand All @@ -371,12 +386,6 @@ async def _fetch( # pylint: disable=too-many-locals,too-many-branches,too-many-
if reqdef.authorization is not None:
reqdef.authorization.build(req_data, _auth or self._auth, self._storage)

for key in ["headers", "params", "cookies"]:
field_def = getattr(reqdef, key, None)
if field_def is not None:
instantiated_fields = populate_field(field_def, self._jenv, merged_vars)
req_data[key].update(**instantiated_fields)

if reqdef.body is not None:
# TODO: do we support binary body?
instantiated_fields = populate_field(
Expand Down Expand Up @@ -414,6 +423,39 @@ async def _fetch( # pylint: disable=too-many-locals,too-many-branches,too-many-
if _anchor is not None:
req_data["params"][anchor] = _anchor

if _q is not None:
if reqdef.search is None:
raise ValueError(
"_q specified but the API does not support custom search."
)

searchdef = reqdef.search
search_key = searchdef.key

if search_key in req_data["params"]:
raise UniversalParameterOverridden(search_key, "_q")
req_data["params"][search_key] = _q

# Instantiate the configured header/param/cookie fields and merge them into
# the outgoing request. Entries already present in ``req_data[key]`` are
# replaced by the instantiated value, so warn about each replacement before
# ``update`` applies it.
for key in ["headers", "params", "cookies"]:
    field_def = getattr(reqdef, key, None)
    if field_def is not None:
        instantiated_fields = populate_field(
            field_def, self._jenv, merged_vars,
        )
        for ikey in instantiated_fields:
            if ikey in req_data[key]:
                # Both literals need the ``f`` prefix: without it on the
                # second one the placeholders were emitted verbatim.
                warn(
                    f"Query parameter {ikey}={req_data[key][ikey]}"
                    f" is overriden by {ikey}={instantiated_fields[ikey]}",
                    RuntimeWarning,
                )
        req_data[key].update(**instantiated_fields)

for key in ["headers", "params", "cookies"]:
field_def = getattr(reqdef, key, None)
if field_def is not None:
validate_fields(field_def, req_data[key])

await _throttler.acquire(_page)

if _allowed_page is not None and int(_allowed_page) <= _page:
Expand Down Expand Up @@ -445,21 +487,37 @@ async def _fetch( # pylint: disable=too-many-locals,too-many-branches,too-many-
return df


def validate_fields(fields: Dict[str, FieldDefUnion], data: Dict[str, Any]) -> None:
    """Verify that every required field has an entry in ``data``.

    Parameters
    ----------
    fields
        Field definitions keyed by field name. A ``bool`` definition is the
        "required" flag itself, a ``str`` definition is skipped, and any other
        definition is read through its ``to_key``, ``from_key`` and
        ``required`` attributes.
    data
        The populated request section to check.

    Raises
    ------
    KeyError
        If a required field's target key is missing from ``data``. The message
        names the source key (``from_key``) of the field.
    """

    for name, spec in fields.items():
        if isinstance(spec, str):
            # String definitions carry no "required" flag to check.
            continue

        if isinstance(spec, bool):
            target, source, mandatory = name, name, spec
        else:
            target = spec.to_key or name
            source = spec.from_key or name
            mandatory = spec.required

        if mandatory and target not in data:
            raise KeyError(f"'{source}' is required but not provided")


def populate_field( # pylint: disable=too-many-branches
fields: Dict[str, FieldDefUnion], jenv: Environment, params: Dict[str, Any]
fields: Dict[str, FieldDefUnion], jenv: Environment, params: Dict[str, Any],
) -> Dict[str, str]:
"""Populate a dict based on the fields definition and provided vars."""

ret: Dict[str, str] = {}

for key, def_ in fields.items():
from_key, to_key = key, key

if isinstance(def_, bool):
required = def_
value = params.get(from_key)
if value is None and required:
raise KeyError(from_key)
remove_if_empty = False
elif isinstance(def_, str):
# is a template
Expand All @@ -473,10 +531,7 @@ def populate_field( # pylint: disable=too-many-branches
from_key = def_.from_key or from_key

if template is None:
required = def_.required
value = params.get(from_key)
if value is None and required:
raise KeyError(from_key)
else:
tmplt = jenv.from_string(template)
try:
Expand All @@ -486,9 +541,12 @@ def populate_field( # pylint: disable=too-many-branches

if value is not None:
str_value = str(value)
if not (remove_if_empty and not str_value):
if not remove_if_empty or str_value:
if to_key in ret:
print(f"Param {key} conflicting with {to_key}", file=sys.stderr)
warn(
f"{to_key}={ret[to_key]} overriden by {to_key}={str_value}",
RuntimeWarning,
)
ret[to_key] = str_value
continue
return ret
3 changes: 1 addition & 2 deletions dataprep/connector/generator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""ConfigGenerator"""
from .generator import ConfigGenerator
from .ui import ConfigGeneratorUI

__all__ = ["ConfigGenerator", "ConfigGeneratorUI"]
__all__ = ["ConfigGenerator"]
16 changes: 8 additions & 8 deletions dataprep/connector/generator/generator.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
"""This module implements the generation of connector configuration files."""

from dataprep.connector.schema.base import BaseDef
from pathlib import Path
from typing import Any, Dict, Optional, Union
from urllib.parse import parse_qs, urlparse, urlunparse

import requests
from dataprep.connector.schema.base import BaseDef

from ..schema import (
AuthorizationDef,
ConfigDef,
PaginationDef,
)
from ..schema import AuthorizationDef, ConfigDef, PaginationDef
from .state import ConfigState
from .table import gen_schema_from_path, search_table_path

Expand Down Expand Up @@ -43,9 +39,9 @@ def __init__(self, config: Optional[Dict[str, Any]] = None) -> None:
self.config = ConfigState(ConfigDef(**config))
self.storage = {}

def add_example(
def add_example( # pylint: disable=too-many-locals
self, example: Dict[str, Any]
) -> None: # pylint: disable=too-many-locals
) -> None:
"""Add an example to the generator. The example
should be in the dictionary format.
Expand Down Expand Up @@ -161,8 +157,12 @@ def _create_config(req: Dict[str, Any]) -> ConfigDef:


class AuthUnion(BaseDef):
    """Wrapper model used when parsing an authorization definition.

    It has a single field, ``val``, typed as ``AuthorizationDef``.
    """

    val: AuthorizationDef


class PageUnion(BaseDef):
    """Wrapper model used when parsing a pagination definition.

    It has a single field, ``val``, typed as ``PaginationDef``.
    """

    val: PaginationDef
11 changes: 9 additions & 2 deletions dataprep/connector/schema/defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,16 +174,23 @@ class Method(str, Enum):
PUT = "PUT"


class SearchDef(BaseDef):
    # Describes full-text search support for a request.
    # ``key`` is the name of the request parameter under which the universal
    # ``_q`` search string is sent.
    key: str


class RequestDef(BaseDef):
    # Each field is declared exactly once; ``authorization`` and
    # ``pagination`` previously appeared twice in this class body.

    # Target endpoint and HTTP method.
    url: str
    method: Method

    # Field definitions for the individual sections of the request.
    headers: Optional[Dict[str, FieldDefUnion]]
    params: Dict[str, FieldDefUnion]
    body: Optional[BodyDef]
    cookies: Optional[Dict[str, FieldDefUnion]]

    # Optional features layered on top of the plain request.
    # ``search`` is required for the universal ``_q`` parameter to be accepted.
    authorization: Optional[AuthorizationDef]
    pagination: Optional[PaginationDef]
    search: Optional[SearchDef]


class SchemaFieldDef(BaseDef):
target: str
Expand Down
4 changes: 2 additions & 2 deletions examples/DataConnector_DBLP.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
"outputs": [],
"source": [
"from dataprep.connector import Connector\n",
"dc = Connector(\"./DataConnectorConfigs/DBLP\")"
"dc = Connector('dblp')"
]
},
{
Expand Down Expand Up @@ -167,7 +167,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
"version": "3.7.4"
}
},
"nbformat": 4,
Expand Down
8 changes: 5 additions & 3 deletions examples/DataConnector_Twitter.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@
"### Connector.info\n",
"The info method gives information and guidelines of using the connector. There are 3 sections in the response and they are table, parameters and examples.\n",
">1. Table - The table(s) being accessed.\n",
">2. Parameters - Identifies which parameters can be used to call the method. For Twitter, q is a required parameter that acts as a filter. \n",
">2. Parameters - Identifies which parameters can be used to call the method. For Twitter, _q is a required parameter that acts as a filter. \n",
">3. Examples - Shows how you can call the methods in the Connector class."
]
},
Expand Down Expand Up @@ -249,7 +249,9 @@
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
Expand Down Expand Up @@ -848,7 +850,7 @@
}
],
"source": [
"df = dc.query(\"tweets\", q=\"covid-19\", count=50)\n",
"df = dc.query(\"tweets\", _q=\"covid-19\", count=50)\n",
"df"
]
},
Expand Down

0 comments on commit 947584a

Please sign in to comment.