Skip to content

Commit

Permalink
feat(connector): implement authorization code
Browse files Browse the repository at this point in the history
  • Loading branch information
dovahcrow committed Oct 1, 2020
1 parent 3019027 commit e6838ca
Show file tree
Hide file tree
Showing 8 changed files with 199 additions and 55 deletions.
194 changes: 162 additions & 32 deletions dataprep/connector/schema/defs.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,35 @@
"""Strong typed schema definition."""
from __future__ import annotations

import http.server
import random
import socketserver
import string
from base64 import b64encode
from enum import Enum
from pathlib import Path
from threading import Thread
from time import time
from typing import Any, Dict, Optional, Union

from typing import Any, Dict, List, Optional, Union
from urllib.parse import parse_qs, urlparse
import socket
import requests
from pydantic import Field

from ...utils import is_notebook
from .base import BaseDef, BaseDefT


# pylint: disable=missing-class-docstring,missing-function-docstring
# Directory containing this module; bundled static assets live next to it.
FILE_PATH: Path = Path(__file__).resolve().parent

# Raw bytes of the page written back to the browser once the OAuth2 redirect
# has been received. Read once at import time; using pathlib avoids leaking a
# module-level file-handle name.
OAUTH2_TEMPLATE: bytes = (FILE_PATH / "oauth2.html").read_bytes()


def get_random_string(length: int) -> str:
    """Return a random string of ``length`` lowercase ASCII letters.

    The value is used as the OAuth2 ``state`` parameter, which protects the
    redirect against forgery, so the letters are drawn from the operating
    system's entropy source (``random.SystemRandom``) instead of the default,
    predictable Mersenne Twister generator.

    :param length: number of characters to produce; values <= 0 give ``""``.
    :return: the generated string.
    """
    letters = string.ascii_lowercase
    rng = random.SystemRandom()
    return "".join(rng.choice(letters) for _ in range(length))


class OffsetPaginationDef(BaseDef):
Expand Down Expand Up @@ -67,9 +84,126 @@ class FieldDef(BaseDef):
FieldDefUnion = Union[FieldDef, bool, str] # Put bool before str


class OAuth2AuthorizationDef(BaseDef):
class TCPServer(socketserver.TCPServer):
    """TCP server whose listening socket is bound with ``SO_REUSEADDR``.

    Address reuse lets the local OAuth2 callback listener be started again
    right after a previous run without waiting for the old socket to time out.
    """

    # The stock ``server_bind`` sets SO_REUSEADDR when this flag is true and,
    # unlike a hand-written override, also records the address that was
    # actually bound in ``self.server_address``.
    allow_reuse_address = True


class HTTPServer(http.server.BaseHTTPRequestHandler):
    """Request handler that receives the OAuth2 redirect on localhost.

    On a GET carrying exactly one ``code`` and one ``state`` query parameter
    it replies with the success page, stores both values on the owning server
    object and asks that server to shut down.
    """

    def do_GET(self) -> None:  # pylint: disable=invalid-name
        """Handle the redirect request issued by the user's browser."""
        # pylint: disable=protected-access
        parsed = parse_qs(urlparse(self.path).query)
        codes = parsed.get("code", [])
        states = parsed.get("state", [])

        if len(codes) != 1 or len(states) != 1:
            # Unrelated requests (e.g. the browser asking for /favicon.ico) or
            # malformed redirects are rejected cleanly; the server keeps
            # waiting for the real callback instead of raising in the handler.
            self.send_error(400, "Missing or duplicated OAuth2 code/state")
            return

        self.send_response(200)
        self.send_header("Content-type", "text/html")
        self.end_headers()
        self.wfile.write(OAUTH2_TEMPLATE)

        # Hacky way to pass data out. The values must be published *before*
        # shutdown is requested, otherwise the thread blocked in
        # ``serve_forever`` could resume and read them before they exist.
        self.server._oauth2_code = codes[0]  # type: ignore
        self.server._oauth2_state = states[0]  # type: ignore

        # ``shutdown`` blocks until ``serve_forever`` returns, so it has to run
        # on another thread than the one currently serving this request.
        Thread(target=self.server.shutdown).start()

    def log_request(
        self, code: Union[str, int] = "-", size: Union[str, int] = "-"
    ) -> None:
        """Suppress the per-request access log line."""


class OAuth2AuthorizationCodeAuthorizationDef(BaseDef):
    """OAuth2 authorization using the authorization-code grant.

    The user is sent to ``auth_server_url`` in a browser; a temporary local
    HTTP listener receives the redirect carrying the authorization code, which
    is then exchanged for a bearer token at ``token_server_url``.
    """

    type: str = Field("OAuth2", const=True)
    grant_type: str = Field("AuthorizationCode", const=True)
    scopes: List[str]
    auth_server_url: str
    token_server_url: str

    def build(
        self,
        req_data: Dict[str, Any],
        params: Dict[str, Any],
        storage: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Put a bearer ``Authorization`` header into ``req_data``.

        A new token is requested when ``storage`` holds no ``access_token`` or
        its ``expires_at`` timestamp has passed; otherwise the stored token is
        reused.

        :param req_data: request description; ``req_data["headers"]`` must
            already be a mapping and receives the ``Authorization`` entry.
        :param params: must contain ``client_id`` and ``client_secret``; an
            optional ``port`` (default 9999) selects the local callback port.
        :param storage: mutable mapping used to cache the token between calls.
        :raises ValueError: if ``storage`` is ``None``.
        :raises RuntimeError: if the token endpoint returns a non-bearer token
            or the authorization step fails.
        """
        if storage is None:
            raise ValueError("storage is required for OAuth2")

        if "access_token" not in storage or storage.get("expires_at", 0) < time():
            port = params.get("port", 9999)
            code = self._auth(params["client_id"], port)

            ckey = params["client_id"]
            csecret = params["client_secret"]
            b64cred = b64encode(f"{ckey}:{csecret}".encode("ascii")).decode()

            resp: Dict[str, Any] = requests.post(
                self.token_server_url,
                headers={"Authorization": f"Basic {b64cred}"},
                data={
                    "grant_type": "authorization_code",
                    "code": code,
                    "redirect_uri": f"http://localhost:{port}/",
                },
            ).json()

            if resp["token_type"].lower() != "bearer":
                raise RuntimeError("token_type is not bearer")

            storage["access_token"] = resp["access_token"]
            if "expires_in" in resp:
                storage["expires_at"] = (
                    time() + resp["expires_in"] - 60
                )  # 60 seconds grace period to avoid clock lag
            else:
                # Drop the expiry of the previous token; keeping it would make
                # the fresh token look expired on every subsequent call.
                storage.pop("expires_at", None)

        req_data["headers"]["Authorization"] = f"Bearer {storage['access_token']}"

    def _auth(self, client_id: str, port: int = 9999) -> str:
        """Run the interactive authorization step and return the code.

        :param client_id: OAuth2 client identifier placed in the auth URL.
        :param port: local port the callback listener binds to; it is also the
            port advertised in the ``redirect_uri``.
        :return: the authorization code delivered to the local listener.
        :raises RuntimeError: if the returned ``state`` differs from the one
            sent, or no code was received.
        """
        # pylint: disable=protected-access

        state = get_random_string(23)
        scope = ",".join(self.scopes)
        authurl = (
            f"{self.auth_server_url}?"
            f"response_type=code&client_id={client_id}&"
            f"redirect_uri=http%3A%2F%2Flocalhost:{port}/&scope={scope}&"
            f"state={state}"
        )

        # Bind the listener on the requested port *before* sending the user to
        # the provider, so a fast redirect cannot arrive while nothing is
        # listening. Leaving the ``with`` block closes the server socket.
        with TCPServer(("", port), HTTPServer) as httpd:
            if is_notebook():
                from IPython.display import (  # pylint: disable=import-outside-toplevel
                    Javascript,
                    display,
                )

                display(Javascript(f"window.open('{authurl}');"))
            else:
                import webbrowser  # pylint: disable=import-outside-toplevel

                webbrowser.open_new_tab(authurl)

            httpd.serve_forever()

        # The handler stores its results as ad-hoc attributes on the server;
        # read them defensively so a missing value becomes a clear error.
        received_state = getattr(httpd, "_oauth2_state", None)
        received_code = getattr(httpd, "_oauth2_code", None)

        if received_state != state:
            raise RuntimeError("OAuth2 state does not match")

        if received_code is None:
            raise RuntimeError(
                "OAuth2 authorization code auth failed, no code acquired."
            )
        return str(received_code)


class OAuth2ClientCredentialsAuthorizationDef(BaseDef):
type: str = Field("OAuth2", const=True)
grant_type: str
grant_type: str = Field("ClientCredentials", const=True)
token_server_url: str

def build(
Expand All @@ -81,32 +215,27 @@ def build(
if storage is None:
raise ValueError("storage is required for OAuth2")

if self.grant_type == "ClientCredentials":
if "access_token" not in storage or storage.get("expires_at", 0) < time():
# Not yet authorized
ckey = params["client_id"]
csecret = params["client_secret"]
b64cred = b64encode(f"{ckey}:{csecret}".encode("ascii")).decode()
resp: Dict[str, Any] = requests.post(
self.token_server_url,
headers={"Authorization": f"Basic {b64cred}"},
data={"grant_type": "client_credentials"},
).json()
if resp["token_type"].lower() != "bearer":
raise RuntimeError("token_type is not bearer")

access_token = resp["access_token"]
storage["access_token"] = access_token
if "expires_in" in resp:
storage["expires_at"] = (
time() + resp["expires_in"] - 60
) # 60 seconds grace period to avoid clock lag

req_data["headers"]["Authorization"] = f"Bearer {storage['access_token']}"

# TODO: handle auto refresh
elif self.grant_type == "AuthorizationCode":
raise NotImplementedError
if "access_token" not in storage or storage.get("expires_at", 0) < time():
# Not yet authorized
ckey = params["client_id"]
csecret = params["client_secret"]
b64cred = b64encode(f"{ckey}:{csecret}".encode("ascii")).decode()
resp: Dict[str, Any] = requests.post(
self.token_server_url,
headers={"Authorization": f"Basic {b64cred}"},
data={"grant_type": "client_credentials"},
).json()
if resp["token_type"].lower() != "bearer":
raise RuntimeError("token_type is not bearer")

access_token = resp["access_token"]
storage["access_token"] = access_token
if "expires_in" in resp:
storage["expires_at"] = (
time() + resp["expires_in"] - 60
) # 60 seconds grace period to avoid clock lag

req_data["headers"]["Authorization"] = f"Bearer {storage['access_token']}"


class QueryParamAuthorizationDef(BaseDef):
Expand Down Expand Up @@ -156,7 +285,8 @@ def build(


AuthorizationDef = Union[
OAuth2AuthorizationDef,
OAuth2ClientCredentialsAuthorizationDef,
OAuth2AuthorizationCodeAuthorizationDef,
QueryParamAuthorizationDef,
BearerAuthorizationDef,
HeaderAuthorizationDef,
Expand Down
12 changes: 12 additions & 0 deletions dataprep/connector/schema/oauth2.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
<!DOCTYPE html>
<html lang="en">

<head>
    <meta charset="utf-8">
    <title>OAuth2 Success</title>
</head>

<body>
    <!-- The text is the fallback when the browser refuses the scripted close below. -->
    <p>OAuth2 Success. This window can be closed now.</p>
    <script>window.close();</script>
</body>

</html>
2 changes: 1 addition & 1 deletion dataprep/eda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
Text,
)
from .missing import compute_missing, plot_missing, render_missing
from .utils import is_notebook
from ..utils import is_notebook

__all__ = [
"plot_correlation",
Expand Down
2 changes: 1 addition & 1 deletion dataprep/eda/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from bokeh.embed import components
from bokeh.resources import INLINE
from jinja2 import Environment, PackageLoader
from .utils import is_notebook
from ..utils import is_notebook

output_notebook(INLINE, hide_banner=True) # for offline usage

Expand Down
2 changes: 1 addition & 1 deletion dataprep/eda/progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from dask.callbacks import Callback

from .utils import is_notebook
from ..utils import is_notebook

if is_notebook():
from tqdm.notebook import tqdm
Expand Down
2 changes: 1 addition & 1 deletion dataprep/eda/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from bokeh.resources import CDN
from jinja2 import Template

from .utils import is_notebook
from ..utils import is_notebook

INLINE_TEMPLATE = Template(
"""
Expand Down
20 changes: 1 addition & 19 deletions dataprep/eda/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""
import logging
from math import ceil
from typing import Any, Union
from typing import Union

import dask.dataframe as dd
import numpy as np
Expand All @@ -13,24 +13,6 @@
LOGGER = logging.getLogger(__name__)


def is_notebook() -> bool:
    """
    Tell whether the code is running in a jupyter notebook.

    :return: ``True`` when the class of the active IPython shell is named
        ``ZMQInteractiveShell``; ``False`` otherwise, including when IPython
        cannot be imported.
    """
    try:
        # pytype: disable=import-error
        from IPython import get_ipython  # pylint: disable=import-outside-toplevel

        # pytype: enable=import-error

        # Outside IPython ``get_ipython()`` yields ``None``, whose class name
        # simply fails the comparison.
        return get_ipython().__class__.__name__ == "ZMQInteractiveShell"
    except (NameError, ImportError):
        return False


def to_dask(df: Union[pd.DataFrame, dd.DataFrame]) -> dd.DataFrame:
"""Convert a dataframe to a dask dataframe."""
if isinstance(df, dd.DataFrame):
Expand Down
20 changes: 20 additions & 0 deletions dataprep/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Utility functions used by the whole library."""
from typing import Any


def is_notebook() -> bool:
    """
    Report whether the current process runs inside a jupyter notebook.

    :return: ``True`` if the active IPython shell's class is named
        ``ZMQInteractiveShell``, else ``False`` (also when IPython is not
        importable).
    """
    try:
        # pytype: disable=import-error
        from IPython import get_ipython  # pylint: disable=import-outside-toplevel

        # pytype: enable=import-error

        shell = get_ipython().__class__.__name__
    except (NameError, ImportError):
        return False
    # A plain comparison already yields the boolean the callers need.
    return shell == "ZMQInteractiveShell"

0 comments on commit e6838ca

Please sign in to comment.