Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🔧 Run tests against py38 #480

Merged
merged 14 commits into from
Jul 2, 2023
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.8","3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v3
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ If you're a Authx user, you probably want either Authx V0.9 [Documentation](http

## Features 🔧

- [x] Support Python 3.9+.
- [x] Support Python 3.8+.
- [x] Multiple customizable authentication backend:
- [x] JWT authentication backend included
- [x] JWT encoding/decoding for application authentication
Expand Down
20 changes: 10 additions & 10 deletions authx/_internal/_logger.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,34 @@
import logging
import sys
import traceback
from typing import Optional

log = logging.getLogger("authx")


def get_logger():
def get_logger() -> logging.Logger:
return log


def set_log_level(level: str):
def set_log_level(level: str) -> logging.Logger:
log.setLevel(level)
return log


def log_debug(msg: str, loc: str = None, method: str = None):
def log_debug(msg: str, loc: Optional[str] = None, method: Optional[str] = None) -> None:
log.debug(msg=_build_log_msg(msg=msg, loc=loc, method=method))


def log_info(msg: str, loc: str = None, method: str = None):
def log_info(msg: str, loc: Optional[str] = None, method: Optional[str] = None) -> None:
log.info(msg=_build_log_msg(msg=msg, loc=loc, method=method))


def log_error(msg: str, loc: str = None, method: str = None, e: Exception = None):
def log_error(msg: str, loc: Optional[str] = None, method: Optional[str] = None, e: Optional[Exception] = None) -> None:
log.error(msg=_build_log_msg(msg=msg, loc=loc, method=method))
log.error(msg=f"{traceback.print_exc()}")
return
log.error(f"{traceback.format_exc()}")


def _build_log_msg(msg: str, loc: str = None, method: str = None):
def _build_log_msg(msg: str, loc: Optional[str] = None, method: Optional[str] = None) -> str:
log_str = f"{msg}"
if loc:
log_str = f"[{loc}] {log_str}"
Expand All @@ -38,7 +40,5 @@ def _build_log_msg(msg: str, loc: str = None, method: str = None):


logging.basicConfig()
default_factory = logging.getLogRecordFactory()
log = logging.getLogger("authx")
logging.StreamHandler(sys.stdout)
log.setLevel(logging.DEBUG)
72 changes: 32 additions & 40 deletions authx/_internal/_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import datetime as date
import datetime as dt
import uuid
from datetime import datetime, timedelta
from datetime import timezone as tz
from typing import Union

import pytz
from dateutil import parser as dateutil_parser
from dateutil.relativedelta import relativedelta
from pytz import timezone
from pytz import BaseTzInfo, timezone

from authx.types import Numeric

Expand All @@ -26,38 +27,23 @@
utc = timezone("UTC")


def get_now() -> date.datetime:
"""Returns the current UTC datetime

Returns:
datetime.datetime: Current datetime (UTC)
"""
return date.datetime.now(tz=date.timezone.utc)
def get_now() -> dt.datetime:
return dt.datetime.now(tz=dt.timezone.utc)


def get_now_ts() -> Numeric:
"""Returns the current UTC datetime as timestamp (float)

Returns:
Numeric: Current datetime (UTC)
"""
return get_now().timestamp()


def get_uuid() -> str:
"""Generates a Universe Unique Identifier v4 (UUIDv4)

Returns:
str: unique identifier
"""
return str(uuid.uuid4())


def time_diff(dt1: datetime, dt2: datetime) -> relativedelta:
return relativedelta(dt1, dt2)


def to_UTC(event_timestamp: datetime, tz: pytz.timezone = utc):
def to_UTC(event_timestamp: Union[datetime, str], tz: pytz.timezone = utc) -> datetime: # type: ignore
if isinstance(event_timestamp, datetime):
dt = event_timestamp
else:
Expand All @@ -66,98 +52,104 @@ def to_UTC(event_timestamp: datetime, tz: pytz.timezone = utc):
return dt.astimezone(tz)


def to_UTC_without_tz(event_timestamp: str, format: str = "%Y-%m-%d %H:%M:%S.%f"):
def to_UTC_without_tz(event_timestamp: str, format: str = "%Y-%m-%d %H:%M:%S.%f") -> str:
dt = datetime.strptime(event_timestamp, format)
return dt.astimezone(tz.utc).strftime(format)


def beginning_of_day(dt: datetime):
def beginning_of_day(dt: datetime) -> datetime:
dt = dt.replace(minute=0, hour=0, second=0, microsecond=0)
return dt


def end_of_day(dt: datetime):
def end_of_day(dt: datetime) -> datetime:
dt = dt.replace(minute=59, hour=23, second=59, microsecond=999999)
return dt


def minutes_ago(dt: datetime, days: int = 0, hours: int = 0, minutes: int = 1, seconds: int = 0):
def minutes_ago(dt: datetime, days: int = 0, hours: int = 0, minutes: int = 1, seconds: int = 0) -> datetime:
return dt - timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)


def minutes_after(dt: datetime, days: int = 0, hours: int = 0, minutes: int = 1, seconds: int = 0):
def minutes_after(dt: datetime, days: int = 0, hours: int = 0, minutes: int = 1, seconds: int = 0) -> datetime:
return dt + timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)


def hours_ago(dt: datetime, days: int = 0, hours: int = 1, minutes: int = 0, seconds: int = 0):
def hours_ago(dt: datetime, days: int = 0, hours: int = 1, minutes: int = 0, seconds: int = 0) -> datetime:
return dt - timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)


def days_ago(dt: datetime, days: int = 1, hours: int = 0, minutes: int = 0, seconds: int = 0):
def days_ago(dt: datetime, days: int = 1, hours: int = 0, minutes: int = 0, seconds: int = 0) -> datetime:
past = dt - timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)
if dt.tzinfo:
past = past.replace(tzinfo=dt.tzinfo)
return past


def months_ago(dt: datetime, months: int = 1):
def months_ago(dt: datetime, months: int = 1) -> datetime:
return dt - relativedelta(months=months)


def months_after(dt: datetime, months: int = 1):
def months_after(dt: datetime, months: int = 1) -> datetime:
return dt + relativedelta(months=months)


def years_ago(dt: datetime, years: int = 1):
def years_ago(dt: datetime, years: int = 1) -> datetime:
past = dt - relativedelta(years=years)
if dt.tzinfo:
past = past.replace(tzinfo=past.tzinfo)
return past


def days_after(dt: datetime, days: int = 1, hours: int = 0, minutes: int = 0, seconds: int = 0):
def days_after(dt: datetime, days: int = 1, hours: int = 0, minutes: int = 0, seconds: int = 0) -> datetime:
future = dt + timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)
if dt.tzinfo:
future = future.replace(tzinfo=dt.tzinfo)
return future


def is_today(dt: datetime):
def is_today(dt: datetime) -> bool:
return dt.astimezone(utc).day == datetime.now().astimezone(utc).day


def is_yesterday(dt: datetime):
def is_yesterday(dt: datetime) -> bool:
return dt.astimezone(utc).day == days_ago(datetime.now().astimezone(utc)).day


def is_tomorrow(dt: datetime):
def is_tomorrow(dt: datetime) -> bool:
return dt.astimezone(utc).day == days_after(datetime.now().astimezone(utc)).day


def IST_time():
def IST_time() -> datetime:
return datetime.now().astimezone(utc)


def tz_now(tz: pytz = utc):
def tz_now(tz: BaseTzInfo = utc) -> datetime:
dt = datetime.utcnow()
return dt.replace(tzinfo=tz)


def tz_from_iso(dt: str, to_tz: pytz = utc, format="%Y-%m-%dT%H:%M:%S.%f%z") -> datetime:
def tz_from_iso(dt: str, to_tz: BaseTzInfo = utc, format: str = "%Y-%m-%dT%H:%M:%S.%f%z") -> datetime:
date_time = datetime.strptime(dt, format)
return date_time.astimezone(to_tz)


def start_of_week(dt: str, to_tz: pytz = utc) -> datetime:
def start_of_week(dt: Union[str, datetime], to_tz: BaseTzInfo = utc) -> datetime:
if isinstance(dt, str):
dt = datetime.strptime(dt, "%Y-%m-%d")
day_of_the_week = dt.weekday()
return days_ago(dt=dt, days=day_of_the_week)


def end_of_week(dt: str, to_tz: pytz = utc) -> datetime:
def end_of_week(dt: Union[str, datetime], to_tz: BaseTzInfo = utc) -> datetime:
if isinstance(dt, str):
dt = datetime.strptime(dt, "%Y-%m-%d")
_start_of_week = start_of_week(dt=dt, to_tz=to_tz)
return days_after(dt=_start_of_week, days=6)


def end_of_last_week(dt: str, to_tz: pytz = utc):
def end_of_last_week(dt: Union[str, datetime], to_tz: BaseTzInfo = utc) -> datetime:
if isinstance(dt, str):
dt = datetime.strptime(dt, "%Y-%m-%d")
_end_of_current_week = end_of_week(dt=dt, to_tz=to_tz)
return days_ago(dt=_end_of_current_week, days=7)
4 changes: 2 additions & 2 deletions authx/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ class AuthXDependency(Generic[T]):
def __init__(
self,
_from: "AuthX[T]",
request: Request = None,
response: Response = None,
request: Optional[Request] = None,
response: Optional[Response] = None,
) -> None:
self._response = response
self._request = request
Expand Down
30 changes: 16 additions & 14 deletions authx/external/Oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
logger = logging.getLogger(__name__)


def _get_keys(url_or_keys):
def _get_keys(url_or_keys: typing.Union[str, typing.Any]) -> typing.Any:
if not isinstance(url_or_keys, str) or not url_or_keys.startswith("https://"):
return url_or_keys
logger.info("Getting jwk from %s...", url_or_keys)
with urllib.request.urlopen(url_or_keys) as f:
return json.loads(f.read().decode())


def _validate_provider(provider_name, provider):
def _validate_provider(provider_name: str, provider: typing.Dict[str, typing.Any]) -> None:
mandatory_keys = {"issuer", "keys", "audience"}
if not mandatory_keys.issubset(set(provider)):
raise ValueError(
Expand All @@ -41,10 +41,10 @@ class MiddlewareOauth2:
def __init__(
self,
app: ASGIApp,
providers,
public_paths=None,
get_keys=None,
key_refresh_minutes=None,
providers: typing.Dict[str, typing.Dict[str, typing.Any]],
public_paths: typing.Optional[typing.Set[str]] = None,
get_keys: typing.Optional[typing.Callable[[typing.Any], typing.Any]] = None,
key_refresh_minutes: typing.Optional[typing.Union[int, typing.Dict[str, int]]] = None,
) -> None:
self._app = app
for provider in providers:
Expand All @@ -63,10 +63,10 @@ def __init__(
self._timeout = {provider: datetime.timedelta(minutes=key_refresh_minutes) for provider in providers}

# cached attribute and respective timeout
self._last_retrieval = {}
self._keys = {}
self._last_retrieval: typing.Dict[str, datetime.datetime] = {}
self._keys: typing.Dict[str, typing.Any] = {}

def _provider_claims(self, provider, token):
def _provider_claims(self, provider: str, token: str) -> typing.Any:
issuer = self._providers[provider]["issuer"]
audience = self._providers[provider]["audience"]
logger.debug(
Expand All @@ -86,7 +86,7 @@ def _provider_claims(self, provider, token):
return decoded

def claims(self, token: str) -> typing.Tuple[str, typing.Dict[str, str]]:
errors = {}
errors: typing.Dict[str, str] = {}
for provider in self._providers:
try:
return provider, self._provider_claims(provider, token)
Expand All @@ -104,7 +104,9 @@ def claims(self, token: str) -> typing.Tuple[str, typing.Dict[str, str]]:
raise InvalidToken(errors)

@staticmethod
async def _prepare_error_response(message, status_code, scope, receive, send):
async def _prepare_error_response(
message: str, status_code: int, scope: Scope, receive: Receive, send: Send
) -> None:
if scope["type"] == "http":
response = JSONResponse(
{"message": message},
Expand Down Expand Up @@ -151,7 +153,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:

return await self._app(scope, receive, send)

def _should_refresh(self, provider: str):
def _should_refresh(self, provider: str) -> bool:
if self._keys.get(provider, None) is None:
# we do not even have the key (first time) => should refresh
return True
Expand All @@ -161,11 +163,11 @@ def _should_refresh(self, provider: str):
# have the key and have timeout => check if we passed the timeout
return self._last_retrieval[provider] + self._timeout[provider] < datetime.datetime.utcnow()

def _refresh_keys(self, provider: str):
def _refresh_keys(self, provider: str) -> None:
self._keys[provider] = self._get_keys(self._providers[provider]["keys"])
self._last_retrieval[provider] = datetime.datetime.utcnow()

def _provider_keys(self, provider: str):
def _provider_keys(self, provider: str) -> typing.Any:
if self._should_refresh(provider):
self._refresh_keys(provider)
return self._keys[provider]
8 changes: 4 additions & 4 deletions authx/external/cache/expiry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable
from typing import Callable, Optional

import pytz

Expand All @@ -8,10 +8,10 @@
class HTTPExpiry:
@staticmethod
async def get_ttl(
ttl_in_seconds: int = None,
ttl_in_seconds: Optional[int] = None,
end_of_day: bool = True,
end_of_week: bool = None,
ttl_func: Callable = None,
end_of_week: Optional[bool] = None,
ttl_func: Optional[Callable] = None,
tz: pytz.timezone = utc,
) -> int:
"""Return the seconds till expiry of cache. Defaults to one day"""
Expand Down
10 changes: 7 additions & 3 deletions authx/external/cache/keys.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import Any, List
from typing import Any, List, Optional

from authx._internal import HTTPCache


class HTTPKeys:
@staticmethod
async def generate_key(key: str, config: HTTPCache, obj: Any = None, obj_attr: str = None) -> str:
async def generate_key(
key: str, config: HTTPCache, obj: Optional[Any] = None, obj_attr: Optional[str] = None
) -> str:
"""Converts a raw key passed by the user to a key with an parameter passed by the user and associates a namespace"""

_key = (
Expand All @@ -19,7 +21,9 @@ async def generate_key(key: str, config: HTTPCache, obj: Any = None, obj_attr: s
return await HTTPKeys.generate_namespaced_key(key=_key, config=config)

@staticmethod
async def generate_keys(keys: List[str], config: HTTPCache, obj: Any = None, obj_attr: str = None) -> List[str]:
async def generate_keys(
keys: List[str], config: HTTPCache, obj: Optional[Any] = None, obj_attr: Optional[str] = None
) -> List[str]:
"""Converts a list of raw keys passed by the user to a list of namespaced keys with an optional parameter if passed"""
return [await HTTPKeys.generate_key(key=k, config=config, obj=obj, obj_attr=obj_attr) for k in keys]

Expand Down
Loading