Add type hints to altair.utils.schemapi
thewchan committed Jun 8, 2022
1 parent 0217b2e commit da41251
Showing 1 changed file with 60 additions and 52 deletions.
112 changes: 60 additions & 52 deletions altair/utils/schemapi.py
@@ -4,9 +4,13 @@
import contextlib
import inspect
import json
from __future__ import annotations
from typing import Any, Callable, DefaultDict, Dict, FrozenSet, Iterable, Iterator, List, Literal, Mapping, Optional, Sequence, Set, Tuple, Type, TypeAlias, TypeVar, Union

import jsonschema
from jsonschema.validators import RefResolver
import numpy as np
import numpy.typing as npt
import pandas as pd


@@ -16,20 +20,22 @@
# Individual schema classes can override this by setting the
# class-level _class_is_valid_at_instantiation attribute to False
DEBUG_MODE = True
GenericT = TypeVar("GenericT")
T = TypeVar("T", bound="SchemaBase")
AltairObj: TypeAlias = Union["SchemaBase", List[Any], Tuple[Any], npt.NDArray[Any], Dict[Any, Any], np.number[Any], pd.Timestamp, np.datetime64]


def enable_debug_mode():
def enable_debug_mode() -> None:
global DEBUG_MODE
DEBUG_MODE = True


def disable_debug_mode():
def disable_debug_mode() -> None:
global DEBUG_MODE
DEBUG_MODE = True


@contextlib.contextmanager
def debug_mode(arg):
def debug_mode(arg: bool) -> Iterator[Optional[bool]]:
global DEBUG_MODE
original = DEBUG_MODE
DEBUG_MODE = arg
@@ -39,9 +45,9 @@ def debug_mode(arg):
DEBUG_MODE = original


def _subclasses(cls):
def _subclasses(cls: Type[GenericT]) -> Iterable[Type[GenericT]]:
"""Breadth-first sequence of all classes which inherit from cls."""
seen = set()
seen: Set[Type[GenericT]] = set()
current_set = {cls}
while current_set:
seen |= current_set
@@ -50,7 +56,7 @@ def _subclasses(cls):
yield cls


def _todict(obj, validate, context):
def _todict(obj: AltairObj, validate: Union[bool, str], context: Optional[Dict[Any, Any]]) -> Union[AltairObj, Dict[Any, Any], str, float]:
"""Convert an object to a dict representation."""
if isinstance(obj, SchemaBase):
return obj.to_dict(validate=validate, context=context)
@@ -72,9 +78,9 @@ def _todict(obj, validate, context):
return obj


def _resolve_references(schema, root=None):
def _resolve_references(schema: Mapping[str, Any], root: Optional[Mapping[str, Any]] = None) -> Mapping[str, Any]:
"""Resolve schema references."""
resolver = jsonschema.RefResolver.from_schema(root or schema)
resolver: RefResolver = jsonschema.RefResolver.from_schema(root or schema)
while "$ref" in schema:
with resolver.resolving(schema["$ref"]) as resolved:
schema = resolved
@@ -84,12 +90,12 @@ def _resolve_references(schema, root=None):
class SchemaValidationError(jsonschema.ValidationError):
"""A wrapper for jsonschema.ValidationError with friendlier traceback"""

def __init__(self, obj, err):
def __init__(self, obj: Any, err: jsonschema.ValidationError) -> None:
super(SchemaValidationError, self).__init__(**self._get_contents(err))
self.obj = obj

@staticmethod
def _get_contents(err):
def _get_contents(err: jsonschema.ValidationError) -> Dict[str, Any]:
"""Get a dictionary with the contents of a ValidationError"""
try:
# works in jsonschema 2.3 or later
@@ -104,7 +110,7 @@ def _get_contents(err):
contents = {key: getattr(err, key) for key in spec.args[1:]}
return contents

def __str__(self):
def __str__(self) -> str:
cls = self.obj.__class__
schema_path = ["{}.{}".format(cls.__module__, cls.__name__)]
schema_path.extend(self.schema_path)
@@ -128,12 +134,12 @@ class UndefinedType(object):

__instance = None

def __new__(cls, *args, **kwargs):
def __new__(cls, *args: Any, **kwargs: Any) -> UndefinedType:
if not isinstance(cls.__instance, cls):
cls.__instance = object.__new__(cls, *args, **kwargs)
return cls.__instance

def __repr__(self):
def __repr__(self) -> Literal["Undefined"]:
return "Undefined"


@@ -147,12 +153,12 @@ class SchemaBase(object):
the _rootschema class attribute) which is used for validation.
"""

_schema = None
_rootschema = None
_schema: Optional[Mapping[str, Any]] = None
_rootschema: Optional[Mapping[str, Any]] = None
_class_is_valid_at_instantiation = True
_validator = jsonschema.Draft7Validator

def __init__(self, *args, **kwds):
def __init__(self, *args: Any, **kwds: Any) -> None:
# Two valid options for initialization, which should be handled by
# derived classes:
# - a single arg with no kwds, for, e.g. {'type': 'string'}
@@ -176,7 +182,7 @@ def __init__(self, *args, **kwds):
if DEBUG_MODE and self._class_is_valid_at_instantiation:
self.to_dict(validate=True)

def copy(self, deep=True, ignore=()):
def copy(self: T, deep: Union[bool, Sequence[Any]] = True, ignore: Sequence[Any] = ()) -> T:
"""Return a copy of the object
Parameters
@@ -191,7 +197,7 @@ def copy(self, deep=True, ignore=()):
only stored by reference.
"""

def _shallow_copy(obj):
def _shallow_copy(obj: T) -> T:
if isinstance(obj, SchemaBase):
return obj.copy(deep=False)
elif isinstance(obj, list):
@@ -201,7 +207,7 @@ def _shallow_copy(obj):
else:
return obj

def _deep_copy(obj, ignore=()):
def _deep_copy(obj: T, ignore: Sequence[Any] = ()) -> T:
if isinstance(obj, SchemaBase):
args = tuple(_deep_copy(arg) for arg in obj._args)
kwds = {
@@ -221,7 +227,7 @@ def _deep_copy(obj, ignore=()):
return obj

try:
deep = list(deep)
deep: List[Any] = list(deep)
except TypeError:
deep_is_list = False
else:
@@ -237,36 +243,36 @@ def _deep_copy(obj, ignore=()):
copy[attr] = _shallow_copy(copy._get(attr))
return copy

def _get(self, attr, default=Undefined):
def _get(self, attr: str, default: Any = Undefined) -> Any:
"""Get an attribute, returning default if not present."""
attr = self._kwds.get(attr, Undefined)
attr: Any = self._kwds.get(attr, Undefined)
if attr is Undefined:
attr = default
return attr

def __getattr__(self, attr):
def __getattr__(self, attr: str) -> Any:
# reminder: getattr is called after the normal lookups
if attr == "_kwds":
raise AttributeError()
if attr in self._kwds:
return self._kwds[attr]
else:
try:
_getattr = super(SchemaBase, self).__getattr__
_getattr: Callable[[str], Any] = super(SchemaBase, self).__getattr__
except AttributeError:
_getattr = super(SchemaBase, self).__getattribute__
return _getattr(attr)

def __setattr__(self, item, val):
def __setattr__(self, item: str, val: Any) -> None:
self._kwds[item] = val

def __getitem__(self, item):
def __getitem__(self, item: str) -> Any:
return self._kwds[item]

def __setitem__(self, item, val):
def __setitem__(self, item: str, val: Any) -> None:
self._kwds[item] = val

def __repr__(self):
def __repr__(self) -> str:
if self._kwds:
args = (
"{}: {!r}".format(key, val)
@@ -280,14 +286,14 @@ def __repr__(self):
else:
return "{}({!r})".format(self.__class__.__name__, self._args[0])

def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
return (
type(self) is type(other)
and self._args == other._args
and self._kwds == other._kwds
)

def to_dict(self, validate=True, ignore=None, context=None):
def to_dict(self, validate: Union[bool, str] = True, ignore: Optional[Sequence[str]] = None, context: Optional[Dict[Any, Any]] = None) -> Union[AltairObj, Dict[Any, Any], str, float]:
"""Return a dictionary representation of the object
Parameters
@@ -341,8 +347,8 @@ def to_dict(self, validate=True, ignore=None, context=None):
return result

def to_json(
self, validate=True, ignore=[], context={}, indent=2, sort_keys=True, **kwargs
):
self, validate: Union[bool, str] = True, ignore: Sequence[str] = (), context: Optional[Dict[Any, Any]] = None, indent: int = 2, sort_keys: bool = True, **kwargs: Any
) -> str:
"""Emit the JSON representation for this object as a string.
Parameters
@@ -370,16 +376,18 @@ def to_json(
spec : string
The JSON specification of the chart object.
"""
if not context:
context = {}
dct = self.to_dict(validate=validate, ignore=ignore, context=context)
return json.dumps(dct, indent=indent, sort_keys=sort_keys, **kwargs)

@classmethod
def _default_wrapper_classes(cls):
def _default_wrapper_classes(cls) -> Iterable[Type[SchemaBase]]:
"""Return the set of classes used within cls.from_dict()"""
return _subclasses(SchemaBase)

@classmethod
def from_dict(cls, dct, validate=True, _wrapper_classes=None):
def from_dict(cls: Type[T], dct: Mapping[str, Any], validate: bool = True, _wrapper_classes: Optional[Union[Iterable[Type[SchemaBase]], Iterable[Type[T]]]] = None) -> T:
"""Construct class from a dictionary representation
Parameters
@@ -411,7 +419,7 @@ def from_dict(cls, dct, validate=True, _wrapper_classes=None):
return converter.from_dict(dct, cls)

@classmethod
def from_json(cls, json_string, validate=True, **kwargs):
def from_json(cls: Type[T], json_string: str, validate: bool = True, **kwargs: Any) -> T:
"""Instantiate the object from a valid JSON string
Parameters
@@ -432,42 +440,42 @@ def from_json(cls, json_string, validate=True, **kwargs):
return cls.from_dict(dct, validate=validate)

@classmethod
def validate(cls, instance, schema=None):
def validate(cls: Type[T], instance: Any, schema: Optional[Mapping[str, Any]] = None) -> None:
"""
Validate the instance against the class schema in the context of the
rootschema.
"""
if schema is None:
schema = cls._schema
resolver = jsonschema.RefResolver.from_schema(cls._rootschema or cls._schema)
resolver: RefResolver = jsonschema.RefResolver.from_schema(cls._rootschema or cls._schema)
return jsonschema.validate(
instance, schema, cls=cls._validator, resolver=resolver
)

@classmethod
def resolve_references(cls, schema=None):
def resolve_references(cls, schema: Optional[Mapping[str, Any]] = None) -> Mapping[str, Any]:
"""Resolve references in the context of this object's schema or root schema."""
return _resolve_references(
schema=(schema or cls._schema),
root=(cls._rootschema or cls._schema or schema),
)

@classmethod
def validate_property(cls, name, value, schema=None):
def validate_property(cls, name: str, value: Any, schema: Optional[Mapping[str, Any]] = None) -> None:
"""
Validate a property against property schema in the context of the
rootschema
"""
value = _todict(value, validate=False, context={})
props = cls.resolve_references(schema or cls._schema).get("properties", {})
resolver = jsonschema.RefResolver.from_schema(cls._rootschema or cls._schema)
resolver: RefResolver = jsonschema.RefResolver.from_schema(cls._rootschema or cls._schema)
return jsonschema.validate(value, props.get(name, {}), resolver=resolver)

def __dir__(self):
def __dir__(self) -> List[str]:
return list(self._kwds.keys())


def _passthrough(*args, **kwds):
def _passthrough(*args: Any, **kwds: Any) -> Union[Any, Dict[str, Any]]:
return args[0] if args else kwds


@@ -481,16 +489,16 @@ class _FromDict(object):

_hash_exclude_keys = ("definitions", "title", "description", "$schema", "id")

def __init__(self, class_list):
def __init__(self, class_list: Iterable[Type[Any]]) -> None:
# Create a mapping of a schema hash to a list of matching classes
# This lets us quickly determine the correct class to construct
self.class_dict = collections.defaultdict(list)
self.class_dict: DefaultDict[Any, List[Any]] = collections.defaultdict(list)
for cls in class_list:
if cls._schema is not None:
self.class_dict[self.hash_schema(cls._schema)].append(cls)

@classmethod
def hash_schema(cls, schema, use_json=True):
def hash_schema(cls, schema: Mapping[str, Any], use_json: bool = True) -> int:
"""
Compute a python hash for a nested dictionary which
properly handles dicts, lists, sets, and tuples.
@@ -513,7 +521,7 @@ def hash_schema(cls, schema, use_json=True):
return hash(s)
else:

def _freeze(val):
def _freeze(val: Union[Dict[Any, Any], Set[Any], Sequence[Any], GenericT]) -> Union[FrozenSet[Any], Tuple[Any], GenericT]:
if isinstance(val, dict):
return frozenset((k, _freeze(v)) for k, v in val.items())
elif isinstance(val, set):
@@ -526,8 +534,8 @@ def _freeze(val):
return hash(_freeze(schema))

def from_dict(
self, dct, cls=None, schema=None, rootschema=None, default_class=_passthrough
):
self, dct: Union[Mapping[str, Any], SchemaBase], cls: Optional[Type[T]] = None, schema: Optional[Mapping[str, Any]] = None, rootschema: Optional[Mapping[str, Any]] = None, default_class: Any = _passthrough
) -> Union[T, SchemaBase]:
"""Construct an object from a dict representation"""
if (schema is None) == (cls is None):
raise ValueError("Must provide either cls or schema, but not both.")
@@ -553,7 +561,7 @@ def from_dict(
if "anyOf" in schema or "oneOf" in schema:
schemas = schema.get("anyOf", []) + schema.get("oneOf", [])
for possible_schema in schemas:
resolver = jsonschema.RefResolver.from_schema(rootschema)
resolver: RefResolver = jsonschema.RefResolver.from_schema(rootschema)
try:
jsonschema.validate(dct, possible_schema, resolver=resolver)
except jsonschema.ValidationError:
@@ -569,7 +577,7 @@ def from_dict(
if isinstance(dct, dict):
# TODO: handle schemas for additionalProperties/patternProperties
props = schema.get("properties", {})
kwds = {}
kwds: Mapping[str, Any] = {}
for key, val in dct.items():
if key in props:
val = self.from_dict(val, schema=props[key], rootschema=rootschema)
@@ -578,7 +586,7 @@ def from_dict(

elif isinstance(dct, list):
item_schema = schema.get("items", {})
dct = [
dct: List[Union[T, SchemaBase]] = [
self.from_dict(val, schema=item_schema, rootschema=rootschema)
for val in dct
]
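
For a sense of how the newly annotated API reads in practice, here is a minimal sketch. The Point class and its inline schema are illustrative assumptions for this example only; they are not part of this commit, and real Altair classes are generated from the Vega-Lite schema.

from typing import Any, Dict

from altair.utils.schemapi import SchemaBase, debug_mode


class Point(SchemaBase):
    # Hypothetical wrapper class with a hand-written schema, for illustration only.
    _schema: Dict[str, Any] = {
        "type": "object",
        "properties": {"x": {"type": "number"}, "y": {"type": "number"}},
        "required": ["x", "y"],
    }


with debug_mode(True):           # debug_mode(arg: bool) is annotated as Iterator[Optional[bool]]
    p = Point(x=1.0, y=2.0)      # validated at instantiation while DEBUG_MODE is on

d = p.to_dict(validate=True)     # dict representation, checked against _schema
q = Point.from_dict(d)           # from_dict(cls: Type[T], ...) -> T returns a Point
r = q.copy(deep=True)            # copy(self: T, ...) -> T preserves the subclass type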
