Skip to content

Commit

Permalink
fix: replace protocol with direct subclass (#2029)
Browse files Browse the repository at this point in the history
  • Loading branch information
agoose77 committed Dec 22, 2022
1 parent fff5e7c commit c15c99a
Showing 1 changed file with 7 additions and 24 deletions.
31 changes: 7 additions & 24 deletions src/awkward/_backends.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from abc import abstractmethod
from abc import ABC, abstractmethod

import awkward_cpp

Expand All @@ -18,18 +18,7 @@
nplike_of,
)
from awkward._typetracer import NoKernel, TypeTracer
from awkward.typing import (
Any,
Callable,
Final,
Protocol,
Self,
Tuple,
TypeAlias,
TypeVar,
Unpack,
runtime_checkable,
)
from awkward.typing import Callable, Final, Tuple, TypeAlias, TypeVar, Unpack

np = NumpyMetadata.instance()

Expand All @@ -39,8 +28,7 @@
KernelType: TypeAlias = Callable[..., None]


@runtime_checkable
class Backend(Protocol[T]):
class Backend(Singleton, ABC):
name: str

@property
Expand All @@ -53,16 +41,11 @@ def nplike(self) -> NumpyLike:
def index_nplike(self) -> NumpyLike:
raise ak._errors.wrap_error(NotImplementedError)

@classmethod
@abstractmethod
def instance(cls) -> Self:
raise ak._errors.wrap_error(NotImplementedError)

def __getitem__(self, key: KernelKeyType) -> KernelType:
raise ak._errors.wrap_error(NotImplementedError)


class NumpyBackend(Singleton, Backend[Any]):
class NumpyBackend(Backend):
name: Final[str] = "cpu"

_numpy: Numpy
Expand All @@ -82,7 +65,7 @@ def __getitem__(self, index: KernelKeyType) -> NumpyKernel:
return NumpyKernel(awkward_cpp.cpu_kernels.kernel[index], index)


class CupyBackend(Singleton, Backend[Any]):
class CupyBackend(Backend):
name: Final[str] = "cuda"

_cupy: Cupy
Expand Down Expand Up @@ -112,7 +95,7 @@ def __getitem__(self, index: KernelKeyType) -> CupyKernel | NumpyKernel:
)


class JaxBackend(Singleton, Backend[Any]):
class JaxBackend(Backend):
name: Final[str] = "jax"

_jax: Jax
Expand All @@ -135,7 +118,7 @@ def __getitem__(self, index: KernelKeyType) -> JaxKernel:
return JaxKernel(awkward_cpp.cpu_kernels.kernel[index], index)


class TypeTracerBackend(Singleton, Backend[Any]):
class TypeTracerBackend(Backend):
name: Final[str] = "typetracer"

_typetracer: TypeTracer
Expand Down

0 comments on commit c15c99a

Please sign in to comment.