Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -205,3 +205,9 @@ cython_debug/
marimo/_static/
marimo/_lsp/
__marimo__/

# Lcov
coverage.lcov
lcov.info
htmlcov/
.coverage/
10 changes: 3 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,7 @@ FloatTensor["features/2"] # Half the features dimension

- `min(a,b)` Minimum of two expressions
- `max(a,b)` Maximum of two expressions

> [!WARNING]
> While nested function calls like `min(max(a,b),c)` are supported,
> combining function calls with other operators in the same expression
> (e.g., `min(1,batch)+max(2,channels)`) is not supported to simplify parsing.
- `isqrt(a)` Integer (floor) square root of a symbol or expression

### Symbolic Dimensions

Expand Down Expand Up @@ -282,10 +278,10 @@ def free_function(tensor: FloatTensor["batch dim1"]) -> None:

## Limitations

- In the current implementation, _every_ call will be checked, which may or may not be slow depending on how big the context is (it shouldn't be that slow).
- In the current implementation, _every_ call will be checked, the performance overhead on most systems should be negligible (OTOO microseconds).
- Pydantic default values are not checked.
- Only symbolic, literal, and expressions are allowed for dimension specifiers, f-string syntax from `jaxtyping` is not supported.
- Only torch tensors and numpy arrays are supported for now.
- Static checking is not supported, only runtime checks, though some errors will be caught statically by construction.
- Static shape checking is not supported, DLType only performs runtime checks, though some expression errors will be caught statically by construction if symbolic (i.e. non-string) shapes are used.
- DLType does not support checkking inside unbounded container types (i.e. `list[TensorTypeBase]`) for performance reasons.
- DLType does not support unions, but does support optionals.
20 changes: 8 additions & 12 deletions dltype/_lib/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,13 @@ class DLTypeAnnotation(NamedTuple):
def from_hint(
cls,
hint: type | None,
name: str,
*,
optional: bool = False,
) -> tuple[DLTypeAnnotation | None, ...]:
"""Create a new _DLTypeAnnotation from a type hint."""
if hint is None:
warnings.warn(f"[{name}] is missing a DLType hint", category=UserWarning, stacklevel=3)
return (None,)

_logger.debug("Creating DLType from hint %r", hint)
Expand All @@ -77,11 +79,11 @@ def from_hint(
raise TypeError(msg)

# Recursively process the non-None type with optional=True
return cls.from_hint(non_none_types[0], optional=True)
return cls.from_hint(non_none_types[0], name, optional=True)

# tuple handling special case
if origin is tuple:
return tuple(itertools.chain(*[cls.from_hint(inner_hint) for inner_hint in args]))
return tuple(itertools.chain(*[cls.from_hint(inner_hint, name) for inner_hint in args]))

# Only process Annotated types
if origin is not Annotated:
Expand Down Expand Up @@ -135,7 +137,7 @@ def _maybe_get_type_hints(
return existing_hints
try:
return {
name: DLTypeAnnotation.from_hint(hint)
name: DLTypeAnnotation.from_hint(hint, name)
for name, hint in get_type_hints(func, include_extras=True).items()
}
except NameError:
Expand Down Expand Up @@ -209,7 +211,7 @@ def _inner_dltyped(func: Callable[P, R]) -> Callable[P, R]: # noqa: C901, PLR09

@wraps(func)
@_dependency_utilities.torch_jit_unused # pyright: ignore[reportUnknownMemberType]
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: # noqa: C901, PLR0912
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: # noqa: C901
__tracebackhide__ = not _constants.DEBUG_MODE
nonlocal signature
nonlocal dltype_hints
Expand Down Expand Up @@ -266,12 +268,6 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: # noqa: C901, PLR0912
_resolve_value(tensor, maybe_annotation),
_resolve_types(maybe_annotation),
)
elif any(isinstance(actual_args[name], T) for T in _dtypes.SUPPORTED_TENSOR_TYPES):
warnings.warn(
f"[argument={name}] is missing a DLType hint",
UserWarning,
stacklevel=2,
)
else:
_logger.debug("No DLType hint for %r", name)

Expand Down Expand Up @@ -331,7 +327,7 @@ def _inner_dltyped_namedtuple(cls: type[NT]) -> type[NT]:
for field_name in cls._fields:
if field_name in field_hints:
hint = field_hints[field_name]
dltype_fields[field_name] = DLTypeAnnotation.from_hint(hint)
dltype_fields[field_name] = DLTypeAnnotation.from_hint(hint, field_name)

# If no fields need validation, return the original class
if not dltype_fields:
Expand Down Expand Up @@ -395,7 +391,7 @@ def _inner_dltyped_dataclass(cls: type[DataclassT]) -> type[DataclassT]:
original_init = cls.__init__
# Get field annotations
field_hints = get_type_hints(cls, include_extras=True)
dltype_hints = {name: DLTypeAnnotation.from_hint(hint) for name, hint in field_hints.items()}
dltype_hints = {name: DLTypeAnnotation.from_hint(hint, name) for name, hint in field_hints.items()}

def new_init(self: DataclassT, *args: Any, **kwargs: Any) -> None: # noqa: ANN401
"""A new __init__ method that validates the fields after initialization."""
Expand Down
Loading