# Pydantic-based executors
The [PEP 3148](https://peps.python.org/pep-3148/) executor standard (implemented in `concurrent.futures`) defines a common interface for executor objects (`submit`, `map`, `shutdown`). This lets us wrap any compliant executor and manage its execution context uniformly. Pydantic validators let us validate executor initialization and execution arguments dynamically, based on signature inspection.

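Before you start, make sure you're using Pydantic >= 1.9.0; version 1.8 has several bugs in JSON encoder propagation. The code below is written against the Pydantic v1 API (`root_validator`, `GenericModel`, `Config.json_encoders`) and will not run unmodified on Pydantic 2.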

In [1]:
# imports
import contextlib
import copy
import inspect
import logging
from concurrent.futures import Future, ThreadPoolExecutor
from importlib import import_module
from typing import Any, Callable, Dict, Generic, Iterable, Optional, TypeVar, Tuple
from types import FunctionType, MethodType
from pydantic import BaseModel, Field, root_validator, validate_arguments, validator, ValidationError, Extra
from pydantic.generics import GenericModel

logger = logging.getLogger(__name__)


## Generics

Because the executor classes take many forms, we use Pydantic's generic models to compose classes parameterized by executor type. This requires a placeholder `TypeVar`. Here it is named `ObjType`, because the executor classes use a generalizable loading approach that could be extended to objects in general.

In [2]:
ObjType = TypeVar("ObjType")
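Pydantic's `GenericModel` performs this class composition for us. To make the idea concrete without depending on Pydantic, here is a minimal stdlib-only sketch, not Pydantic's actual mechanism: a hypothetical `TypedLoader` class whose `__class_getitem__` builds (and caches) one subclass per type parameter, so that `TypedLoader[dict]` behaves roughly like the `ObjLoader[ThreadPoolExecutor]` classes used later.

```python
from typing import Dict


class TypedLoader:
    """Hypothetical stdlib analogue of a generic loader: one subclass per object type."""

    object_type: type = object
    _subclasses: Dict[type, type] = {}

    def __class_getitem__(cls, obj_type: type) -> type:
        # Compose the parameterized subclass once and reuse it afterwards,
        # so TypedLoader[dict] is TypedLoader[dict] holds.
        if obj_type not in cls._subclasses:
            name = f"{cls.__name__}[{obj_type.__name__}]"
            cls._subclasses[obj_type] = type(name, (cls,), {"object_type": obj_type})
        return cls._subclasses[obj_type]

    def __init__(self, **kwargs):
        # Store init kwargs now; the object itself is created later by load()
        self.kwargs = kwargs

    def load(self):
        # Instantiate the concrete type this class was parameterized with
        return self.object_type(**self.kwargs)


DictLoader = TypedLoader[dict]
print(DictLoader.__name__)              # TypedLoader[dict]
print(DictLoader(a=1, b=2).load())      # {'a': 1, 'b': 2}
print(TypedLoader[dict] is DictLoader)  # True
```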

## JSON Encoders

Pydantic does not propagate JSON encoders to child classes, so we define a common set of encoders and pass it to each model explicitly:

In [3]:
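import json

def is_jsonable(x):
    """Return True if ``x`` can be serialized by the standard json module."""
    try:
        json.dumps(x)
        return True
    except (TypeError, ValueError):
        # TypeError: unsupported type; ValueError: e.g. circular reference
        return False


class CallableKwargs(dict):
    """Kwargs container whose ``clean`` drops values that are not JSON-serializable."""

    def clean(self):
        return {key: value for key, value in self.items() if is_jsonable(value)}


class CallableArgs(tuple):
    """Args container whose ``clean`` replaces non-JSON-serializable values with None."""

    def clean(self):
        return tuple(value if is_jsonable(value) else None for value in self)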
    

In [4]:
tpe = ThreadPoolExecutor()
kwargs = CallableKwargs({"hey": 5, "tpe": tpe})

In [5]:
kwargs.clean()

{'hey': 5}

In [6]:
args = CallableArgs((tpe, 5, ))
args.clean()

(None, 5)

In [7]:
JSON_ENCODERS = {
    # add callable types to the encoders
    CallableArgs: lambda x: x.clean(),
    CallableKwargs: lambda x: x.clean(),
    # function/method type distinguished for class members and not recognized as callables
    FunctionType: lambda x: f"{x.__module__}:{x.__qualname__}",
    MethodType: lambda x: f"{x.__module__}:{x.__qualname__}",
    Callable: lambda x: f"{x.__module__}:{type(x).__qualname__}",
    type: lambda x: f"{x.__module__}:{x.__name__}",
    # placeholder for instances of ObjType; swapped for the concrete type by the model validators
    ObjType: lambda x: f"{x.__module__}:{x.__class__.__qualname__}",
}
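As we understand Pydantic v1, a model's `json_encoders` are looked up by walking the class hierarchy (`__mro__`) of each value the default encoder cannot handle. The sketch below reproduces that dispatch with only the standard library, via `json.dumps(default=...)`; `make_default`, `ENCODERS`, and `add` are hypothetical names for illustration.

```python
import json
from types import FunctionType
from typing import Any, Callable, Dict


def make_default(encoders: Dict[type, Callable[[Any], Any]]) -> Callable[[Any], Any]:
    """Build a json.dumps ``default=`` hook that dispatches on the value's MRO."""

    def default(obj: Any) -> Any:
        # Walk the class hierarchy, most specific first, skipping ``object``
        for base in type(obj).__mro__[:-1]:
            encoder = encoders.get(base)
            if encoder is not None:
                return encoder(obj)
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    return default


ENCODERS = {
    FunctionType: lambda f: f"{f.__module__}:{f.__qualname__}",
    type: lambda t: f"{t.__module__}:{t.__name__}",
}


def add(x, y):
    return x + y


encoded = json.dumps({"callable": add, "type": dict}, default=make_default(ENCODERS))
print(encoded)  # {"callable": "__main__:add", "type": "builtins:dict"} when run as a script
```

Under this kind of lookup a `TypeVar` key such as `ObjType` can never match, because a `TypeVar` does not appear in any class's MRO. That is consistent with the loaders below replacing the `ObjType` entry with the concrete type.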

## Utility functions for validating signatures and getting callables from strings

Central to generalizability across executors is the ability to validate args/kwargs against a callable's signature (here, the executor class) and to load callables from their `module:qualname` string representation.

In [8]:
def get_callable_from_string(callable: str, bind: Any = None) -> Callable:
    """Get callable from a string. In the case that the callable points to a bound method,
    the function returns a callable taking the bind instance as the first arg.

    Args:
        callable: String representation of callable abiding convention
             __module__:callable
        bind: Class to bind as self

    Returns:
        Callable
    """
    callable_split = callable.rsplit(":")

    if len(callable_split) != 2:
        raise ValueError(f"Improperly formatted callable string: {callable_split}")

    module_name, callable_name = callable_split

    try:
        module = import_module(module_name)
    except ModuleNotFoundError as err:
        logger.error("Unable to import module %s", module_name)
        raise err

    # construct partial in case of bound method
    if "." in callable_name:
        bound_class, callable_name = callable_name.rsplit(".", 1)

        try:
            bound_class = getattr(module, bound_class)
        except Exception as e:
            logger.error("Unable to get %s from %s", bound_class, module_name)
            raise e

        # require right partial for assembly of callable
        # https://funcy.readthedocs.io/en/stable/funcs.html#rpartial
        def rpartial(func, *args):
            return lambda *a: func(*(a + args))

        callable = getattr(bound_class, callable_name)
        params = inspect.signature(callable).parameters

        # check bindings
        is_bound = params.get("self", None) is not None
        if not is_bound and bind is not None:
            raise ValueError(f"Cannot bind {callable_name} to {bind!r}.")

        # bound: validate the instance type before building the callable
        if bind is not None:
            if not isinstance(bind, (bound_class,)):
                raise ValueError(
                    f"Provided bind {bind!r} is not instance of {bound_class.__qualname__}"
                )

        if is_bound and isinstance(callable, (FunctionType,)) and bind is None:
            callable = rpartial(getattr, callable_name)

        elif is_bound and isinstance(callable, (FunctionType,)) and bind is not None:
            callable = getattr(bind, callable_name)

    else:
        if bind is not None:
            raise ValueError(f"Cannot bind {callable_name} to {type(bind)}.")

        try:
            callable = getattr(module, callable_name)
        except Exception as e:
            logger.error("Unable to get %s from %s", callable_name, module_name)
            raise e

    return callable


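def validate_and_compose_signature(callable: Callable, *args, **kwargs):
    """Validate args/kwargs against the signature of ``callable`` and compose them
    into keyword form.

    Returns:
        Tuple of (positional-or-keyword parameters as a dict, values captured by a
        ``**kwargs`` parameter or None, values captured by a ``*args`` parameter or
        None). Keyword-only parameters are not returned.
    """
    # partial bind raises TypeError if the arguments do not fit the signature
    signature = inspect.signature(callable)
    bound_args = signature.bind_partial(*args, **kwargs)

    sig_pos_or_kw = {}
    sig_kw_only = None  # extra keywords captured by **kwargs, if any
    sig_args_only = None  # extra positionals captured by *args, if any

    # go parameter by parameter and assemble kwargs
    for param in signature.parameters.values():
        if param.kind == param.POSITIONAL_OR_KEYWORD:
            # use the bound value if given, else the declared default (or None)
            default = param.default if param.default is not param.empty else None
            sig_pos_or_kw[param.name] = bound_args.arguments.get(param.name, default)
        elif param.kind == param.VAR_KEYWORD:
            sig_kw_only = bound_args.arguments.get(param.name)
        elif param.kind == param.VAR_POSITIONAL:
            sig_args_only = bound_args.arguments.get(param.name)

    return sig_pos_or_kw, sig_kw_only, sig_args_only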
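The `module:qualname` convention used by the encoders and by `get_callable_from_string` round-trips for any importable, module-level object. Here is a compact stdlib sketch of both directions; `to_path` and `resolve` are hypothetical names, and `resolve` follows qualnames nested more than one level deep by applying `getattr` repeatedly.

```python
import json
from concurrent.futures import ThreadPoolExecutor
from functools import reduce
from importlib import import_module
from typing import Any


def to_path(obj: Any) -> str:
    """Encode a class or function as 'module:qualname'."""
    return f"{obj.__module__}:{obj.__qualname__}"


def resolve(path: str) -> Any:
    """Decode 'module:qualname', following dotted qualnames attribute by attribute."""
    module_name, sep, qualname = path.partition(":")
    if not sep or not module_name or not qualname:
        raise ValueError(f"Expected 'module:qualname', got {path!r}")
    module = import_module(module_name)
    return reduce(getattr, qualname.split("."), module)


print(to_path(ThreadPoolExecutor))  # concurrent.futures.thread:ThreadPoolExecutor
# a function defined on a class resolves to the same function object
print(resolve("json.decoder:JSONDecoder.decode") is json.JSONDecoder.decode)  # True
```

Objects defined inside functions carry `<locals>` in their qualname and cannot be resolved this way.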

In [9]:
def test_fn(x, y=4, *args, m, **kwargs):
    return x

validate_and_compose_signature(test_fn, y=5, x=2, hi=4)

({'x': 2, 'y': 5}, {'hi': 4}, None)
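The validation in `validate_and_compose_signature` comes from `inspect.Signature.bind_partial`, which accepts any subset of the arguments but raises `TypeError` when an argument cannot be matched to a parameter. A short stdlib illustration with a hypothetical `area` function:

```python
import inspect


def area(width, height=1, *, unit="m"):
    return width * height


sig = inspect.signature(area)

# a partial binding may omit required parameters such as ``width``
bound = sig.bind_partial(height=3)
print(dict(bound.arguments))  # {'height': 3}

# arguments that match no parameter are rejected with TypeError
errors = []
for call_args, call_kwargs in [((), {"depth": 2}), ((1, 2, 3), {})]:
    try:
        sig.bind_partial(*call_args, **call_kwargs)
    except TypeError as err:
        errors.append(str(err))

print(errors)  # one message per rejected call
```

In a Pydantic v1 validator, a `TypeError` raised this way is reported as a `ValidationError`, which is what makes the signature binding act as kwarg validation.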

## Representing callables as Pydantic models
Representing callables as Pydantic models lets us use both Pydantic's JSON serialization and its validation hooks: kwargs are validated against the callable's signature when the model is created, while the call itself (or the instantiation, for a class) can be delayed. With `CallableModel`, we can provide initialization kwargs for a to-be-instantiated-later object and still get kwarg validation up front.

In [10]:
class CallableModel(BaseModel):
    callable: Callable
  #  args: Optional[CallableArgs]
    kwargs: Optional[CallableKwargs]

    class Config:
        arbitrary_types_allowed = True
        json_encoders = JSON_ENCODERS
        extra = Extra.forbid

    @root_validator(pre=True)
    def validate_all(cls, values):
        callable = values.pop("callable")

        if not isinstance(
            callable,
            (
                str,
                Callable,
            ),
        ):
            raise ValueError(
                f"Callable must be an object or a string. Provided {type(callable)}"
            )

        # parse string to callable
        if isinstance(callable, (str,)):

            # for function loading
            if "bind" in values:
                callable = get_callable_from_string(callable, bind=values["bind"])

            else:
                callable = get_callable_from_string(callable)
        
        values["callable"] = callable

        # for reloading: positional args are folded into kwargs by signature binding
        args = values.pop("args", None) or ()
        kwargs = values.get("kwargs") or {}

        # ignore kwarg-only and arg-only args for now
        values["kwargs"], _, _ = validate_and_compose_signature(callable, *args, **kwargs)

        return values

    def __call__(self, *args, **kwargs):
        # copy stored kwargs so calls never mutate the model
        fn_kwargs = copy.copy(self.kwargs)

        # positional args replace the leading stored kwargs (stored in signature order)
        for key in list(self.kwargs)[: len(args)]:
            fn_kwargs.pop(key)

        # explicitly passed kwargs override stored ones
        fn_kwargs.update(kwargs)

        return self.callable(*args, **fn_kwargs)
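`CallableModel.__call__` drops the leading stored kwargs that are replaced by positional arguments. `functools.partial` shows why this is needed: it keeps stored keyword arguments, so passing the same parameter positionally collides. The sketch below contrasts the two; `call_with_overrides` is a hypothetical helper mirroring the override logic above, not part of the notebook.

```python
import functools


def add(x: int, y: int = 5):
    return x + y


stored_kwargs = {"x": 1, "y": 3}  # kept in signature order

# functools.partial keeps stored keywords, so a positional x collides with x=1
partial_fn = functools.partial(add, **stored_kwargs)
collision_error = None
try:
    partial_fn(10)
except TypeError as err:
    collision_error = err
print(collision_error)  # ... got multiple values for argument 'x'


def call_with_overrides(fn, stored, *args, **kwargs):
    """Hypothetical helper: positional args replace the leading stored kwargs."""
    fn_kwargs = dict(stored)
    for key in list(stored)[: len(args)]:
        fn_kwargs.pop(key)
    # call-time kwargs override stored ones
    fn_kwargs.update(kwargs)
    return fn(*args, **fn_kwargs)


print(call_with_overrides(add, stored_kwargs))       # 4
print(call_with_overrides(add, stored_kwargs, 10))   # 13
print(call_with_overrides(add, stored_kwargs, y=0))  # 1
```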

Let's test `CallableModel` on an example function and class:

In [11]:
def test_function(x: int, y: int = 5):
    return x + y


class TestClass:
    def __init__(self, x, y):
        self.x = x
        self.y = y

In [12]:
fn = CallableModel(callable=test_function, kwargs={"x":1, "y":3})

fn.kwargs

{'x': 1, 'y': 3}

In [13]:
fn = CallableModel(callable=test_function, args=(1,3,))

fn.kwargs

{'x': 1, 'y': 3}

In [14]:
fn = CallableModel(callable=test_function, args=(1,), kwargs={"y":3})

fn.kwargs

{'x': 1, 'y': 3}

In [15]:
# dict rep
fn_dict = fn.dict()
fn_dict

{'callable': <function __main__.test_function(x: int, y: int = 5)>,
 'kwargs': {'x': 1, 'y': 3}}

In [16]:
# load from dict
fn_from_dict = CallableModel(**fn.dict()) 
fn_from_dict()

4

In [17]:
# json representation
fn.json() 

'{"callable": "__main__:test_function", "kwargs": {"x": 1, "y": 3}}'

In [18]:
# callable from json
fn_from_json = CallableModel.parse_raw(fn.json())
fn_from_json()

4

## With Classes

In [19]:
# store class kwargs now; the class is instantiated when the model is called
parameterized_class = CallableModel(callable=TestClass, kwargs={"x":1, "y":3})
test_class_obj = parameterized_class()
assert isinstance(test_class_obj, (TestClass,))

In [20]:
# dict rep
parameterized_class_dict = parameterized_class.dict()
parameterized_class_dict

{'callable': __main__.TestClass, 'kwargs': {'x': 1, 'y': 3}}

In [21]:
# from dict
parameterized_class_from_dict = CallableModel(**parameterized_class_dict)
parameterized_class_from_dict

CallableModel(callable=<class '__main__.TestClass'>, kwargs={'x': 1, 'y': 3})

In [22]:
parameterized_class_from_dict_obj = parameterized_class_from_dict()
assert isinstance(parameterized_class_from_dict_obj, (TestClass,))

In [23]:
# json
parameterized_class_json = parameterized_class.json()
parameterized_class_json

'{"callable": "__main__:TestClass", "kwargs": {"x": 1, "y": 3}}'

In [24]:
parameterized_class_from_json = CallableModel.parse_raw(parameterized_class_json)
test_class_obj = parameterized_class_from_json()
assert isinstance(test_class_obj, (TestClass,))

We can use `CallableModel` to construct a dynamic object loader. The generic type lets us reuse the same loader for any executor: `ObjLoader[ThreadPoolExecutor]` composes an entirely new class, specific to `ThreadPoolExecutor`.

In [25]:
class ObjLoader(
    GenericModel,
    Generic[ObjType],
    arbitrary_types_allowed=True,
    json_encoders=JSON_ENCODERS,
):
    object: Optional[ObjType]
    loader: CallableModel = None
    object_type: Optional[type]

    @root_validator(pre=True)
    def validate_all(cls, values):
        # inspect class init signature
        obj_type = cls.__fields__["object"].type_
        
        # adjust for re init from json
        if "loader" not in values:
            loader = CallableModel(callable=obj_type, **values)

        else:
            # validate loader callable is same as obj type
            if values["loader"].get("callable") is not None:
                # unparameterized callable will handle parsing
                callable = CallableModel(
                    callable=values["loader"]["callable"]
                )
                
                if callable.callable is not obj_type:
                    raise ValueError(
                        f"Provided loader of type {callable.callable.__name__}. "
                        f"ObjLoader parameterized for {obj_type}"
                    )

                # opt for obj type
                values["loader"].pop("callable")

            # re-init drop callable from loader vals to use new instance
            loader = CallableModel(callable=obj_type, **values["loader"])

        # update the class json encoders. Will only execute on initial type construction
        if obj_type not in cls.__config__.json_encoders:
            cls.__config__.json_encoders[obj_type] = cls.__config__.json_encoders.pop(
                ObjType
            )
        return {"object_type": obj_type, "loader": loader}

    def load(self, store: bool = False):
        # store object reference on loader
        if store:
            self.object = self.loader()
            return self.object

        # return loaded object w/o storing
        else:
            return self.loader()

Let's test the object loader on our `TestClass`:

In [26]:
# create type
TestClassLoader = ObjLoader[TestClass]

obj_loader = TestClassLoader(kwargs={"x":1, "y":3})
loaded = obj_loader.load()
loaded

<__main__.TestClass at 0x118377a00>

We can do the same for an existing class such as `ThreadPoolExecutor`:

In [27]:
# create type
TPELoader = ObjLoader[ThreadPoolExecutor]

tpe_loader = TPELoader(kwargs={"max_workers": 1})
tpe = tpe_loader.load()

# round-trip the loader through JSON
tpe_loader_json = tpe_loader.json()
tpe_loader_from_json = TPELoader.parse_raw(tpe_loader_json)

# shut down the executor created by load()
tpe.shutdown()


## Executors
The previous classes demonstrate the generic utilities. The executors that follow build on them to parameterize generic executors complying with the PEP 3148 standard. The names of the `submit`, `map`, and `shutdown` methods are stored as fields, so executors whose method names deviate from the standard can still be wrapped. `BaseExecutor` below defines the fields and methods common to all executors.
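The method-name indirection can be illustrated with the standard library alone. In this sketch, a hypothetical `METHOD_NAMES` mapping plays the role of the `submit_callable`, `map_callable`, and `shutdown_callable` fields, and a hypothetical `check_executor_type` fails early if a type lacks one of them. `ThreadPoolExecutor` is used because it implements PEP 3148.

```python
from concurrent.futures import ThreadPoolExecutor

# Method names stored as data, analogous to submit_callable / map_callable / shutdown_callable
METHOD_NAMES = {"submit": "submit", "map": "map", "shutdown": "shutdown"}


def check_executor_type(executor_type: type, names: dict = METHOD_NAMES) -> None:
    """Hypothetical check: raise if the type lacks any configured method."""
    missing = [n for n in names.values() if not callable(getattr(executor_type, n, None))]
    if missing:
        raise ValueError(f"{executor_type.__name__} is missing methods: {missing}")


def add(x, y):
    return x + y


check_executor_type(ThreadPoolExecutor)  # passes: the PEP 3148 methods exist

missing_error = None
try:
    check_executor_type(dict)
except ValueError as err:
    missing_error = err

executor = ThreadPoolExecutor(max_workers=2)
try:
    future = getattr(executor, METHOD_NAMES["submit"])(add, 1, 8)
    # Executor.map(fn, *iterables) zips the iterables: add(1, 4), add(3, 4)
    mapped = list(getattr(executor, METHOD_NAMES["map"])(add, [1, 3], [4, 4]))
finally:
    getattr(executor, METHOD_NAMES["shutdown"])()

print(future.result(), mapped)  # 9 [5, 7]
```

Unlike the `getattr` checks in `BaseExecutor`, this sketch also verifies that each attribute is callable; which check is stricter than needed is a design choice.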

In [28]:
# COMMON BASE FOR EXECUTORS
class BaseExecutor(
    GenericModel,
    Generic[ObjType],
    arbitrary_types_allowed=True,
    json_encoders=JSON_ENCODERS,
):
    # executor_type must comply with https://peps.python.org/pep-3148/ standard
    loader: Optional[ObjLoader[ObjType]] # loader of executor type

    # Utility field excluded from dict/JSON representations. The typing library has open
    # issues on accessing a generic's type parameter from within the class, so we track it
    # here in case it is needed later.
    executor_type: type = Field(None, exclude=True)
    submit_callable: str = "submit"
    map_callable: str = "map"
    shutdown_callable: str = "shutdown"

    # executor will not be explicitly serialized, but loaded using loader with class
    # and kwargs
    executor: Optional[ObjType]

    @root_validator(pre=True)
    def validate_all(cls, values):
        executor_type = cls.__fields__["executor"].type_ # introspect fields to get type

        # check if executor provided
        executor = values.get("executor")
        if executor is not None:
            values.pop("executor")
        
        # VALIDATE SUBMIT CALLABLE AGAINST EXECUTOR TYPE
        if "submit_callable" not in values:
            # use default
            submit_callable = cls.__fields__["submit_callable"].default
        else:
            submit_callable = values.pop("submit_callable")

        try:
            getattr(executor_type, submit_callable)
        except AttributeError:
            raise ValueError(
                f"Executor type {executor_type.__name__} has no submit method {submit_callable}."
            )

        # VALIDATE MAP CALLABLE AGAINST EXECUTOR TYPE
        if not values.get("map_callable"):
            # use default
            map_callable = cls.__fields__["map_callable"].default
        else:
            map_callable = values.pop("map_callable")

        try:
            getattr(executor_type, map_callable)
        except AttributeError:
            raise ValueError(
                f"Executor type {executor_type.__name__} has no map method {map_callable}."
            )

        # VALIDATE SHUTDOWN CALLABLE AGAINST EXECUTOR TYPE
        if not values.get("shutdown_callable"):
            # use default
            shutdown_callable = cls.__fields__["shutdown_callable"].default
        else:
            shutdown_callable = values.pop("shutdown_callable")

        try:
            getattr(executor_type, shutdown_callable)
        except AttributeError:
            raise ValueError(
                f"Executor type {executor_type.__name__} has no shutdown method {shutdown_callable}."
            )

        # Compose loader utility
        if values.get("loader") is not None:
            loader_values = values.get("loader")
            loader = ObjLoader[executor_type](**loader_values)

        else:
            # maintain reference to original object
            loader_values = copy.copy(values)

            # if executor in values, need to remove
            if "executor" in loader_values:
                loader_values.pop("executor")

            loader = ObjLoader[executor_type](**loader_values)

        # update encoders
        # update the class json encoders. Will only execute on initial type construction
        if executor_type not in cls.__config__.json_encoders:
            cls.__config__.json_encoders[
                executor_type
            ] = cls.__config__.json_encoders.pop(ObjType)

        return {
            "executor_type": executor_type,
            "submit_callable": submit_callable,
            "shutdown_callable": shutdown_callable,
            "map_callable": map_callable,
            "loader": loader,
            "executor": executor,
        }

    def shutdown(self) -> None:
        shutdown_fn = getattr(self.executor, self.shutdown_callable)
        shutdown_fn()

## Normal, ContextExecutor
Now we subclass `BaseExecutor` to create two executors: `NormalExecutor` and `ContextExecutor`. If the user would like to create a persistent executor passed to the Evaluator, they would use `NormalExecutor`. `ContextExecutor` provides a context manager that creates an executor instance on demand during execution and shuts it down afterwards.
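The difference between the two executors is who owns the executor's lifecycle. Here is a minimal sketch of the context-managed pattern with a hypothetical `executor_context` helper: the executor is created on entry and always shut down on exit. Because `ThreadPoolExecutor.shutdown(wait=True)` waits for pending work, a future submitted inside the block is complete afterwards.

```python
import contextlib
from concurrent.futures import ThreadPoolExecutor


@contextlib.contextmanager
def executor_context(factory, **kwargs):
    """Hypothetical helper: create an executor on entry, always shut it down on exit."""
    executor = factory(**kwargs)
    try:
        yield executor
    finally:
        # wait=True blocks until all submitted work has finished
        executor.shutdown(wait=True)


def add(x, y):
    return x + y


with executor_context(ThreadPoolExecutor, max_workers=1) as ex:
    future = ex.submit(add, 1, 8)

print(future.done(), future.result())  # True 9

# the executor cannot accept new work once the context has exited
post_shutdown_error = None
try:
    ex.submit(add, 2, 2)
except RuntimeError as err:
    post_shutdown_error = err
```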

In [29]:
# NormalExecutor with no context handling on submission and executor persistence
class NormalExecutor(
    BaseExecutor[ObjType],
    Generic[ObjType],
    arbitrary_types_allowed=True,
    json_encoders=JSON_ENCODERS,
):

    @validator("executor", always=True)
    def validate_executor(cls, v, values):

        if v is None:
            v = values["loader"].load()

        # if not None, validate against executor type
        else:
            if not isinstance(v, (values["executor_type"],)):
                raise ValueError(
                    f"Provided executor is not instance of {values['executor_type'].__name__}"
                )

        return v

    def submit(self, fn, *args, **kwargs) -> Future:
        submit_fn = getattr(self.executor, self.submit_callable)
        return submit_fn(fn, *args, **kwargs)

    def map(self, fn, iterables: Iterable) -> Iterable[Any]:
        # Executor.map(fn, *iterables) returns an iterator of results, not futures
        map_fn = getattr(self.executor, self.map_callable)
        return map_fn(fn, *iterables)

Create some `NormalExecutor`s (these must be shut down manually):

In [30]:
# ThreadPool
# create type
NormTPExecutor = NormalExecutor[ThreadPoolExecutor]

tpe_exec = NormTPExecutor(kwargs={"max_workers":1})
# submit
tpe_exec.submit(fn=test_function, x=1, y=8)

<Future at 0x118350af0 state=finished returned int>

In [31]:
# map
tpe_exec.map(test_function, ((1, 4), (3, 4)))

<generator object Executor.map.<locals>.result_iterator at 0x11837ac10>

In [32]:
tpe_exec.shutdown()

In [33]:
# Dask
from distributed import Client
from distributed.cfexecutor import ClientExecutor

# Using an existing executor
client = Client(silence_logs=logging.ERROR)
executor = client.get_executor()

# create type
NormalDaskExecutor = NormalExecutor[type(executor)]

dask_executor = NormalDaskExecutor(executor=executor)
dask_executor.submit(fn=test_function, x=1, y=8)



<Future at 0x12afa2f70 state=pending>

In [34]:
res = dask_executor.map(test_function, ((1, 4), (3, 4)))

In [35]:
for r in res:
    print(r)

4
8


In [36]:
dask_executor_json = dask_executor.json()
dask_executor_json

'{"loader": {"object": null, "loader": {"callable": "distributed.cfexecutor:ClientExecutor", "kwargs": {"client": null}}, "object_type": "distributed.cfexecutor:ClientExecutor"}, "submit_callable": "submit", "map_callable": "map", "shutdown_callable": "shutdown", "executor": "distributed.cfexecutor:ClientExecutor"}'

In [37]:
dask_executor.shutdown()

In [38]:
# this raises an error because the client is not passed...
# dask_executor_from_json = NormalDaskExecutor.parse_raw(dask_executor_json)

Context managers handle shutdown for us:

In [39]:
# ContextExecutor with context handling on submission and no executor persistence
class ContextExecutor(
    BaseExecutor[ObjType],
    Generic[ObjType],
    arbitrary_types_allowed=True,
    json_encoders=JSON_ENCODERS,
):
    @contextlib.contextmanager
    def context(self):
        # load outside the try block so a failed load is not masked by shutdown()
        self.executor = self.loader.load()
        try:
            yield self.executor

        finally:
            self.shutdown()
            self.executor = None

    def submit(self, fn, *args, **kwargs) -> Future:
        with self.context() as ctxt:
            submit_fn = getattr(ctxt, self.submit_callable)
            return submit_fn(fn, *args, **kwargs)

    def map(self, fn, iterables: Iterable) -> Iterable[Any]:
        # unpack iterables to match NormalExecutor.map and Executor.map(fn, *iterables)
        with self.context() as ctxt:
            map_fn = getattr(ctxt, self.map_callable)
            return map_fn(fn, *iterables)


Create some `ContextExecutor`s:

In [41]:
# ThreadPoolExecutor
# create type

ContextTPExecutor = ContextExecutor[ThreadPoolExecutor]

context_exec = ContextTPExecutor(kwargs={"max_workers":1})
context_exec.submit(fn=test_function, x=1, y=8)

<Future at 0x118357700 state=finished returned int>

In [42]:
context_exec.map(test_function, ((1, 4), (3, 4)))

<generator object Executor.map.<locals>.result_iterator at 0x12bb2feb0>

In [43]:
context_exec_json = context_exec.json()
context_exec_json

'{"loader": {"object": null, "loader": {"callable": "concurrent.futures.thread:ThreadPoolExecutor", "kwargs": {"max_workers": 1, "thread_name_prefix": "", "initializer": null, "initargs": []}}, "object_type": "concurrent.futures.thread:ThreadPoolExecutor"}, "submit_callable": "submit", "map_callable": "map", "shutdown_callable": "shutdown", "executor": null}'

In [44]:
context_exec_from_json = ContextTPExecutor.parse_raw(context_exec_json)
context_exec_from_json.submit(fn=test_function, x=1, y=8)

<Future at 0x118355310 state=finished returned int>

In [45]:
context_exec_from_json.map(test_function, ((1, 4), (3, 4)))

<generator object Executor.map.<locals>.result_iterator at 0x12bb2f040>

Some executors are generated by clients that manage sessions. These will require gathering results before shutdown.
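One way to gather results before shutdown is to materialize them while the executor is still open. In this sketch, `ThreadPoolExecutor` stands in for a session-backed executor (the pattern has not been checked against Dask here), and `map_and_gather` is a hypothetical helper: `list()` drains the lazy iterator returned by `Executor.map` before the `with` block calls `shutdown()`.

```python
from concurrent.futures import ThreadPoolExecutor


def add(x, y):
    return x + y


def map_and_gather(executor_factory, fn, *iterables, **executor_kwargs):
    """Hypothetical helper: collect all map results while the executor is alive."""
    with executor_factory(**executor_kwargs) as executor:
        # list() consumes the result iterator before __exit__ calls shutdown()
        return list(executor.map(fn, *iterables))


results = map_and_gather(ThreadPoolExecutor, add, (1, 3), (4, 4), max_workers=2)
print(results)  # [5, 7]
```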