Skip to content

Commit

Permalink
feat(connector): use pydantic for schema
Browse files Browse the repository at this point in the history
  • Loading branch information
dovahcrow committed Sep 24, 2020
1 parent 500ce13 commit dff0844
Show file tree
Hide file tree
Showing 10 changed files with 810 additions and 766 deletions.
142 changes: 102 additions & 40 deletions dataprep/connector/connector.py
Expand Up @@ -6,17 +6,18 @@
import sys
from asyncio import as_completed
from pathlib import Path
from typing import Any, Awaitable, Dict, List, Optional, Union
from typing import Any, Awaitable, Dict, List, Optional, Union, cast

import pandas as pd
from aiohttp import ClientSession
from jinja2 import Environment, StrictUndefined, Template
from jinja2 import Environment, StrictUndefined, Template, UndefinedError

from ..errors import UnreachableError
from .config_manager import config_directory, ensure_config
from .errors import RequestError, UniversalParameterOverridden, InvalidParameterError
from .implicit_database import ImplicitDatabase, ImplicitTable
from .int_ref import IntRef
from .schema import ConfigDef, FieldDef
from .throttler import OrderedThrottler, ThrottleSession

INFO_TEMPLATE = Template(
Expand All @@ -38,8 +39,7 @@


class Connector:
"""
This is the main class of the connector component.
"""This is the main class of the connector component.
Initialize Connector class as the example code.
Parameters
Expand All @@ -61,8 +61,11 @@ class Connector:
"""

_impdb: ImplicitDatabase
# Varibles that used across different queries, can be overriden by query
_vars: Dict[str, Any]
_auth: Dict[str, Any]
# storage for authorization
_storage: Dict[str, Any]
_concurrency: int
_jenv: Environment

Expand All @@ -88,6 +91,7 @@ def __init__(

self._vars = kwargs
self._auth = _auth or {}
self._storage = {}
self._concurrency = _concurrency
self._jenv = Environment(undefined=StrictUndefined)
self._throttler = OrderedThrottler(_concurrency)
Expand Down Expand Up @@ -116,7 +120,7 @@ async def query( # pylint: disable=too-many-locals
**where
The additional parameters required for the query.
"""
allowed_params = self._impdb.tables[table].config["request"]["params"]
allowed_params = self._impdb.tables[table].config.request.params
for key in where:
if key not in allowed_params:
raise InvalidParameterError(key)
Expand All @@ -142,12 +146,12 @@ def info(self) -> None:
# get info
tbs: Dict[str, Any] = {}
for cur_table in self._impdb.tables:
table_config_content = self._impdb.tables[cur_table].config
table_config_content: ConfigDef = self._impdb.tables[cur_table].config
params_required = []
params_optional = []
example_query_fields = []
count = 1
for k, val in table_config_content["request"]["params"].items():
for k, val in table_config_content.request.params.items():
if isinstance(val, bool) and val:
params_required.append(k)
example_query_fields.append(f"""{k}="word{count}\"""")
Expand All @@ -167,31 +171,33 @@ def info(self) -> None:
)

def show_schema(self, table_name: str) -> pd.DataFrame:
"""
This method shows the schema of the table that will be returned,
"""This method shows the schema of the table that will be returned,
so that the user knows what information to expect.
Parameters
----------
table_name
The table name.
Returns
-------
pd.DataFrame
The returned data's schema.
Note
----
The schema is defined in the configuration file.
The user can either use the default one or change it by editing the configuration file.
"""
print(f"table: {table_name}")
table_config_content = self._impdb.tables[table_name].config
schema = table_config_content["response"]["schema"]
table_config_content: ConfigDef = self._impdb.tables[table_name].config
schema = table_config_content.response.schema_
new_schema_dict: Dict[str, List[Any]] = {}
new_schema_dict["column_name"] = []
new_schema_dict["data_type"] = []
for k in schema.keys():
new_schema_dict["column_name"].append(k)
new_schema_dict["data_type"].append(schema[k]["type"])
new_schema_dict["data_type"].append(schema[k].type)
return pd.DataFrame.from_dict(new_schema_dict)

async def _query_imp( # pylint: disable=too-many-locals,too-many-branches
Expand All @@ -206,7 +212,9 @@ async def _query_imp( # pylint: disable=too-many-locals,too-many-branches
raise ValueError(f"No such table {table} in {self._impdb.name}")

itable = self._impdb.tables[table]
if itable.pag_params is None and _count is not None:
reqconf = itable.config.request

if reqconf.pagination is None and _count is not None:
print(
f"ignoring _count since {table} has no pagination settings",
file=sys.stderr,
Expand All @@ -216,22 +224,23 @@ async def _query_imp( # pylint: disable=too-many-locals,too-many-branches
raise RuntimeError("_count should be larger than 0")

async with ClientSession() as client:

throttler = self._throttler.session()

if itable.pag_params is None or _count is None:
if reqconf.pagination is None or _count is None:
df = await self._fetch(
itable, kwargs, _client=client, _throttler=throttler, _auth=_auth,
)
return df

pag_type = itable.pag_params.type
pagdef = reqconf.pagination

# pagination begins
max_per_page = itable.pag_params.max_count
max_per_page = pagdef.max_count
total = _count
n_page = math.ceil(total / max_per_page)

if pag_type == "seek":
if pagdef.type == "seek":
last_id = 0
dfs = []
# No way to parallelize for seek type
Expand All @@ -255,10 +264,10 @@ async def _query_imp( # pylint: disable=too-many-locals,too-many-branches
# The API returns empty for this page, maybe we've reached the end
break

last_id = int(df[itable.pag_params.seek_id][len(df) - 1]) - 1
last_id = int(df[pagdef.seek_id][len(df) - 1]) - 1
dfs.append(df)

elif pag_type == "offset":
elif pagdef.type == "offset":
resps_coros = []
allowed_page = IntRef(n_page)
for i in range(n_page):
Expand Down Expand Up @@ -290,7 +299,7 @@ async def _query_imp( # pylint: disable=too-many-locals,too-many-branches

return df

async def _fetch( # pylint: disable=too-many-locals,too-many-branches
async def _fetch( # pylint: disable=too-many-locals,too-many-branches,too-many-statements
self,
table: ImplicitTable,
kwargs: Dict[str, Any],
Expand All @@ -306,50 +315,53 @@ async def _fetch( # pylint: disable=too-many-locals,too-many-branches
if (_limit is None) != (_offset is None):
raise ValueError("_limit and _offset should both be None or not None")

method = table.method
url = table.url
reqdef = table.config.request
method = reqdef.method
url = reqdef.url
req_data: Dict[str, Dict[str, Any]] = {
"headers": {},
"params": {},
"cookies": {},
}
merged_vars = {**self._vars, **kwargs}

if table.authorization is not None:
table.authorization.build(req_data, _auth or self._auth)
if reqdef.authorization is not None:
reqdef.authorization.build(req_data, _auth or self._auth, self._storage)

for key in ["headers", "params", "cookies"]:
if getattr(table, key) is not None:
instantiated_fields = getattr(table, key).populate(
self._jenv, merged_vars
)
field_def = getattr(reqdef, key, None)
if field_def is not None:
instantiated_fields = populate_field(field_def, self._jenv, merged_vars)
req_data[key].update(**instantiated_fields)

if table.body is not None:
if reqdef.body is not None:
# TODO: do we support binary body?
instantiated_fields = table.body.populate(self._jenv, merged_vars)
if table.body_ctype == "application/x-www-form-urlencoded":
instantiated_fields = populate_field(
reqdef.body.content, self._jenv, merged_vars
)
if reqdef.body.ctype == "application/x-www-form-urlencoded":
req_data["data"] = instantiated_fields
elif table.body_ctype == "application/json":
elif reqdef.body.ctype == "application/json":
req_data["json"] = instantiated_fields
else:
raise NotImplementedError(table.body_ctype)
raise NotImplementedError(reqdef.body.ctype)

if table.pag_params is not None and _limit is not None:
pag_type = table.pag_params.type
limit_key = table.pag_params.limit_key
if reqdef.pagination is not None and _limit is not None:
pagdef = reqdef.pagination
pag_type = pagdef.type
limit_key = pagdef.limit_key
if pag_type == "seek":
if table.pag_params.seek_key is None:
if pagdef.seek_key is None:
raise ValueError(
"pagination type is seek but no seek_key set in the configuration file."
)
offset_key = table.pag_params.seek_key
offset_key = pagdef.seek_key
elif pag_type == "offset":
if table.pag_params.offset_key is None:
if pagdef.offset_key is None:
raise ValueError(
"pagination type is offset but no offset_key set in the configuration file."
)
offset_key = table.pag_params.offset_key
offset_key = pagdef.offset_key
else:
raise UnreachableError()

Expand Down Expand Up @@ -387,3 +399,53 @@ async def _fetch( # pylint: disable=too-many-locals,too-many-branches
return None
else:
return df


def populate_field( # pylint: disable=too-many-branches
fields: Dict[str, FieldDef], jenv: Environment, params: Dict[str, Any]
) -> Dict[str, str]:
"""Populate a dict based on the fields definition and provided vars."""

ret: Dict[str, str] = {}

for key, def_ in fields.items():
from_key, to_key = key, key

if isinstance(def_, bool):
required = def_
value = params.get(from_key)
if value is None and required:
raise KeyError(from_key)
remove_if_empty = False
elif isinstance(def_, str):
# is a template
template: Optional[str] = def_
tmplt = jenv.from_string(cast(str, template))
value = tmplt.render(**params)
remove_if_empty = False
else:
template = def_.template
remove_if_empty = def_.remove_if_empty
to_key = def_.to_key or to_key
from_key = def_.from_key or from_key

if template is None:
required = def_.required
value = params.get(from_key)
if value is None and required:
raise KeyError(from_key)
else:
tmplt = jenv.from_string(template)
try:
value = tmplt.render(**params)
except UndefinedError:
value = "" # This empty string will be removed if `remove_if_empty` is True

if value is not None:
str_value = str(value)
if not (remove_if_empty and not str_value):
if to_key in ret:
print(f"Param {key} conflicting with {to_key}", file=sys.stderr)
ret[to_key] = str_value
continue
return ret

0 comments on commit dff0844

Please sign in to comment.