Skip to content

Commit

Permalink
Merge 47c3b64 into 7615f93
Browse files Browse the repository at this point in the history
  • Loading branch information
spyoungtech committed Aug 23, 2023
2 parents 7615f93 + 47c3b64 commit d7fb1bb
Show file tree
Hide file tree
Showing 10 changed files with 350 additions and 13 deletions.
40 changes: 39 additions & 1 deletion ahk/_async/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import tempfile
import time
import warnings
from functools import partial
from typing import Any
from typing import Awaitable
from typing import Callable
Expand Down Expand Up @@ -34,6 +35,7 @@
else:
from typing import TypeAlias

from ..extensions import Extension, _extension_method_registry, _ExtensionMethodRegistry
from ..keys import Key
from .transport import AsyncDaemonProcessTransport
from .transport import AsyncFutureResult
Expand Down Expand Up @@ -135,13 +137,49 @@ def __init__(
TransportClass: Optional[Type[AsyncTransport]] = None,
directives: Optional[list[Directive | Type[Directive]]] = None,
executable_path: str = '',
extensions: list[Extension] | None | Literal['auto'] = None,
):
self._extension_registry: _ExtensionMethodRegistry
self._extensions: list[Extension]
if extensions == 'auto':
is_async = False
is_async = True # unasync: remove
if is_async:
extensions = list(
set(entry.extension for name, entry in _extension_method_registry.async_methods.items())
)
else:
extensions = list(
set(entry.extension for name, entry in _extension_method_registry.sync_methods.items())
)
self._extension_registry = _extension_method_registry
self._extensions = extensions
else:
self._extensions = extensions or []
self._extension_registry = _ExtensionMethodRegistry(sync_methods={}, async_methods={})
for ext in self._extensions:
self._extension_registry.merge(ext._extension_method_registry)

if TransportClass is None:
TransportClass = AsyncDaemonProcessTransport
assert TransportClass is not None
transport = TransportClass(executable_path=executable_path, directives=directives)
transport = TransportClass(executable_path=executable_path, directives=directives, extensions=extensions)
self._transport: AsyncTransport = transport

def __getattr__(self, name: str) -> Callable[..., Any]:
is_async = False
is_async = True # unasync: remove
if is_async:
if name in self._extension_registry.async_methods:
method = self._extension_registry.async_methods[name].method
return partial(method, self)
else:
if name in self._extension_registry.sync_methods:
method = self._extension_registry.sync_methods[name].method
return partial(method, self)

raise AttributeError(f'{self.__class__.__name__!r} object has no attribute {name!r}')

def add_hotkey(
self, keyname: str, callback: Callable[[], Any], ex_handler: Optional[Callable[[str, Exception], Any]] = None
) -> None:
Expand Down
7 changes: 6 additions & 1 deletion ahk/_async/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

import jinja2

from ahk.extensions import Extension
from ahk._hotkey import ThreadedHotkeyTransport, Hotkey, Hotstring
from ahk.message import RequestMessage
from ahk.message import ResponseMessage
Expand Down Expand Up @@ -656,7 +657,9 @@ def __init__(
directives: Optional[list[Directive | Type[Directive]]] = None,
jinja_loader: Optional[jinja2.BaseLoader] = None,
template: Optional[jinja2.Template] = None,
extensions: list[Extension] | None = None,
):
self._extensions = extensions or []
self._proc: Optional[AsyncAHKProcess]
self._proc = None
self._temp_script: Optional[str] = None
Expand Down Expand Up @@ -711,7 +714,9 @@ def _render_script(self, template: Optional[jinja2.Template] = None, **kwargs: A
template = self._template
kwargs['daemon'] = self.__template
message_types = {str(tom, 'utf-8'): c.__name__.upper() for tom, c in _message_registry.items()}
return template.render(directives=self._directives, message_types=message_types, **kwargs)
return template.render(
directives=self._directives, message_types=message_types, extensions=self._extensions, **kwargs
)

@property
def lock(self) -> Any:
Expand Down
14 changes: 14 additions & 0 deletions ahk/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,16 @@
#NoEnv
#Persistent
#SingleInstance Off
{% block extension_directives %}
; BEGIN extension includes
{% for ext in extensions %}
{% for inc in ext.includes %}
{{ inc }}
{% endfor %}
{% endfor %}
; END extension includes
{% endblock extension_directives %}
; BEGIN user-defined directives
{% block user_directives %}
{% for directive in directives %}
Expand Down Expand Up @@ -2833,7 +2842,12 @@
return decoded_commands
}
; BEGIN extension scripts
{% for ext in extensions %}
{{ ext.script_text }}
{% endfor %}
; END extension scripts
{% block before_autoexecute %}
{% endblock before_autoexecute %}
Expand Down
34 changes: 33 additions & 1 deletion ahk/_sync/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import tempfile
import time
import warnings
from functools import partial
from typing import Any
from typing import Awaitable
from typing import Callable
Expand Down Expand Up @@ -34,6 +35,7 @@
else:
from typing import TypeAlias

from ..extensions import Extension, _extension_method_registry, _ExtensionMethodRegistry
from ..keys import Key
from .transport import DaemonProcessTransport
from .transport import FutureResult
Expand Down Expand Up @@ -131,13 +133,43 @@ def __init__(
TransportClass: Optional[Type[Transport]] = None,
directives: Optional[list[Directive | Type[Directive]]] = None,
executable_path: str = '',
extensions: list[Extension] | None | Literal['auto'] = None,
):
self._extension_registry: _ExtensionMethodRegistry
self._extensions: list[Extension]
if extensions == 'auto':
is_async = False
if is_async:
extensions = list(set(entry.extension for name, entry in _extension_method_registry.async_methods.items()))
else:
extensions = list(set(entry.extension for name, entry in _extension_method_registry.sync_methods.items()))
self._extension_registry = _extension_method_registry
self._extensions = extensions
else:
self._extensions = extensions or []
self._extension_registry = _ExtensionMethodRegistry(sync_methods={}, async_methods={})
for ext in self._extensions:
self._extension_registry.merge(ext._extension_method_registry)

if TransportClass is None:
TransportClass = DaemonProcessTransport
assert TransportClass is not None
transport = TransportClass(executable_path=executable_path, directives=directives)
transport = TransportClass(executable_path=executable_path, directives=directives, extensions=extensions)
self._transport: Transport = transport

def __getattr__(self, name: str) -> Callable[..., Any]:
is_async = False
if is_async:
if name in self._extension_registry.async_methods:
method = self._extension_registry.async_methods[name].method
return partial(method, self)
else:
if name in self._extension_registry.sync_methods:
method = self._extension_registry.sync_methods[name].method
return partial(method, self)

raise AttributeError(f'{self.__class__.__name__!r} object has no attribute {name!r}')

def add_hotkey(
self, keyname: str, callback: Callable[[], Any], ex_handler: Optional[Callable[[str, Exception], Any]] = None
) -> None:
Expand Down
5 changes: 4 additions & 1 deletion ahk/_sync/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

import jinja2

from ahk.extensions import Extension
from ahk._hotkey import ThreadedHotkeyTransport, Hotkey, Hotstring
from ahk.message import RequestMessage
from ahk.message import ResponseMessage
Expand Down Expand Up @@ -630,7 +631,9 @@ def __init__(
directives: Optional[list[Directive | Type[Directive]]] = None,
jinja_loader: Optional[jinja2.BaseLoader] = None,
template: Optional[jinja2.Template] = None,
extensions: list[Extension] | None = None
):
self._extensions = extensions or []
self._proc: Optional[SyncAHKProcess]
self._proc = None
self._temp_script: Optional[str] = None
Expand Down Expand Up @@ -684,7 +687,7 @@ def _render_script(self, template: Optional[jinja2.Template] = None, **kwargs: A
template = self._template
kwargs['daemon'] = self.__template
message_types = {str(tom, 'utf-8'): c.__name__.upper() for tom, c in _message_registry.items()}
return template.render(directives=self._directives, message_types=message_types, **kwargs)
return template.render(directives=self._directives, message_types=message_types, extensions=self._extensions, **kwargs)

@property
def lock(self) -> Any:
Expand Down
12 changes: 3 additions & 9 deletions ahk/_sync/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,11 @@ def __hash__(self) -> int:
return hash(self._ahk_id)

def close(self) -> None:
self._engine.win_close(
title=f'ahk_id {self._ahk_id}', detect_hidden_windows=True, title_match_mode=(1, 'Fast')
)
self._engine.win_close(title=f'ahk_id {self._ahk_id}', detect_hidden_windows=True, title_match_mode=(1, 'Fast'))
return None

def kill(self) -> None:
self._engine.win_kill(
title=f'ahk_id {self._ahk_id}', detect_hidden_windows=True, title_match_mode=(1, 'Fast')
)
self._engine.win_kill(title=f'ahk_id {self._ahk_id}', detect_hidden_windows=True, title_match_mode=(1, 'Fast'))

def exists(self) -> bool:
return self._engine.win_exists(
Expand Down Expand Up @@ -591,9 +587,7 @@ def set_transparent(
blocking=blocking,
)

def set_trans_color(
self, color: Union[int, str], *, blocking: bool = True
) -> Union[None, FutureResult[None]]:
def set_trans_color(self, color: Union[int, str], *, blocking: bool = True) -> Union[None, FutureResult[None]]:
return self._engine.win_set_trans_color(
color=color,
title=f'ahk_id {self._ahk_id}',
Expand Down
104 changes: 104 additions & 0 deletions ahk/extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from __future__ import annotations

import asyncio
import sys
import warnings
from dataclasses import dataclass
from typing import Any
from typing import Callable
from typing import TypeVar

if sys.version_info < (3, 10):
from typing_extensions import ParamSpec
else:
from typing import ParamSpec

from .directives import Include


@dataclass
class _ExtensionEntry:
extension: Extension
method: Callable[..., Any]


T = TypeVar('T')
P = ParamSpec('P')


@dataclass
class _ExtensionMethodRegistry:
sync_methods: dict[str, _ExtensionEntry]
async_methods: dict[str, _ExtensionEntry]

def register(self, ext: Extension, f: Callable[P, T]) -> Callable[P, T]:
if asyncio.iscoroutinefunction(f):
if f.__name__ in self.async_methods:
warnings.warn(
f'Method of name {f.__name__!r} has already been registered. '
f'Previously registered method {self.async_methods[f.__name__].method!r} '
f'will be overridden by {f!r}'
)
self.async_methods[f.__name__] = _ExtensionEntry(extension=ext, method=f)
else:
if f.__name__ in self.sync_methods:
warnings.warn(
f'Method of name {f.__name__!r} has already been registered. '
f'Previously registered method {self.sync_methods[f.__name__].method!r} '
f'will be overridden by {f!r}'
)
self.sync_methods[f.__name__] = _ExtensionEntry(extension=ext, method=f)
return f

def merge(self, other: _ExtensionMethodRegistry) -> None:
for fname, entry in other.async_methods.items():
async_method = entry.method
if async_method.__name__ in self.async_methods:
warnings.warn(
f'Method of name {async_method.__name__!r} has already been registered. '
f'Previously registered method {self.async_methods[async_method.__name__].method!r} '
f'will be overridden by {async_method!r}'
)
self.async_methods[async_method.__name__] = entry
for fname, entry in other.sync_methods.items():
method = entry.method
if method.__name__ in self.sync_methods:
warnings.warn(
f'Method of name {method.__name__!r} has already been registered. '
f'Previously registered method {self.sync_methods[method.__name__].method!r} '
f'will be overridden by {method!r}'
)
self.sync_methods[method.__name__] = entry


_extension_method_registry = _ExtensionMethodRegistry(sync_methods={}, async_methods={})


class Extension:
def __init__(
self,
includes: list[str] | None = None,
script_text: str | None = None,
# template: str | Template | None = None
):
self._text: str = script_text or ''
# self._template: str | Template | None = template
self._includes: list[str] = includes or []
self._extension_method_registry = _ExtensionMethodRegistry(sync_methods={}, async_methods={})

@property
def script_text(self) -> str:
return self._text

@script_text.setter
def script_text(self, new_script: str) -> None:
self._text = new_script

@property
def includes(self) -> list[Include]:
return [Include(inc) for inc in self._includes]

def register(self, f: Callable[P, T]) -> Callable[P, T]:
self._extension_method_registry.register(self, f)
_extension_method_registry.register(self, f)
return f
14 changes: 14 additions & 0 deletions ahk/templates/daemon.ahk
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,16 @@
#NoEnv
#Persistent
#SingleInstance Off
{% block extension_directives %}
; BEGIN extension includes
{% for ext in extensions %}
{% for inc in ext.includes %}
{{ inc }}

{% endfor %}
{% endfor %}
; END extension includes
{% endblock extension_directives %}
; BEGIN user-defined directives
{% block user_directives %}
{% for directive in directives %}
Expand Down Expand Up @@ -2830,7 +2839,12 @@ CommandArrayFromQuery(ByRef text) {
return decoded_commands
}

; BEGIN extension scripts
{% for ext in extensions %}
{{ ext.script_text }}

{% endfor %}
; END extension scripts
{% block before_autoexecute %}
{% endblock before_autoexecute %}

Expand Down
Loading

0 comments on commit d7fb1bb

Please sign in to comment.