Skip to content

Commit

Permalink
Implement hot reloading with websockets
Browse files Browse the repository at this point in the history
  • Loading branch information
AA-Turner committed Apr 8, 2024
1 parent 1aa4bd3 commit 3a32d1b
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 47 deletions.
8 changes: 2 additions & 6 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ sphinx-autobuild
:target: https://opensource.org/licenses/MIT
:alt: MIT

Rebuild Sphinx documentation on changes, with live-reload in the browser.
Rebuild Sphinx documentation on changes, with hot reloading in the browser.

.. image:: ./docs/_static/demo.png
:align: center
Expand Down Expand Up @@ -167,16 +167,12 @@ __ https://github.com/sphinx-doc/sphinx-autobuild/issues/34
Acknowledgements
================

This project stands on the shoulders of giants like
Sphinx_, LiveReload_ and python-livereload_,
This project stands on the shoulders of giants,
without whom this project would not be possible.

Many thanks to everyone who has `contributed code`_ as well as
participated in `discussions on the issue tracker`_.
This project is better thanks to your contribution.

.. _Sphinx: https://sphinx-doc.org/
.. _LiveReload: https://livereload.com/
.. _python-livereload: https://github.com/lepture/python-livereload
.. _contributed code: https://github.com/sphinx-doc/sphinx-autobuild/graphs/contributors
.. _discussions on the issue tracker: https://github.com/sphinx-doc/sphinx-autobuild/issues
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "flit_core.buildapi"
# project metadata
[project]
name = "sphinx-autobuild"
description = "Rebuild Sphinx documentation on changes, with live-reload in the browser."
description = "Rebuild Sphinx documentation on changes, with hot reloading in the browser."
readme = "README.rst"
urls.Changelog = "https://github.com/sphinx-doc/sphinx-autobuild/blob/main/NEWS.rst"
urls.Documentation = "https://github.com/sphinx-doc/sphinx-autobuild#readme"
Expand Down Expand Up @@ -43,7 +43,9 @@ classifiers = [
]
dependencies = [
"sphinx",
"livereload",
"starlette>=0.35",
"uvicorn>=0.25",
"websockets>=11.0",
"colorama",
]
dynamic = ["version"]
Expand Down
2 changes: 1 addition & 1 deletion sphinx_autobuild/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""Rebuild Sphinx documentation on changes, with live-reload in the browser."""
"""Rebuild Sphinx documentation on changes, with hot reloading in the browser."""

__version__ = "2024.02.04"
29 changes: 19 additions & 10 deletions sphinx_autobuild/__main__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
"""Entrypoint for ``python -m sphinx_autobuild``."""

from sphinx_autobuild import _hacks # isort:skip # noqa

import argparse
import os
import shlex
import sys

import colorama
from livereload import Server
import uvicorn

# This isn't public API, but there aren't many better options
from sphinx.cmd.build import get_parser as sphinx_get_parser
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.routing import Mount, WebSocketRoute
from starlette.staticfiles import StaticFiles

from sphinx_autobuild import __version__
from sphinx_autobuild.build import Builder
from sphinx_autobuild.filter import IgnoreFilter
from sphinx_autobuild.middleware import JavascriptInjectorMiddleware
from sphinx_autobuild.server import RebuildServer
from sphinx_autobuild.utils import find_free_port, open_browser, show


Expand All @@ -33,7 +37,6 @@ def main():
host_name = args.host
port_num = args.port or find_free_port()
url_host = f"{host_name}:{port_num}"
server = Server()

pre_build_commands = list(map(shlex.split, args.pre_build))

Expand All @@ -43,15 +46,21 @@ def main():
pre_build_commands=pre_build_commands,
)

watch_dirs = [src_dir] + args.additional_watched_dirs
ignore_handler = IgnoreFilter(
[p for p in args.ignore + [out_dir, args.warnings_file, args.doctree_dir] if p],
args.re_ignore,
)
server.watch(src_dir, builder, ignore=ignore_handler)
for dirpath in args.additional_watched_dirs:
dirpath = os.path.realpath(dirpath)
server.watch(dirpath, builder, ignore=ignore_handler)
server.watch(out_dir, ignore=ignore_handler)
watcher = RebuildServer(watch_dirs, ignore_handler, change_callback=builder)

app = Starlette(
routes=[
WebSocketRoute("/websocket-reload", watcher, name="reload"),
Mount("/", app=StaticFiles(directory=out_dir, html=True), name="static"),
],
middleware=[Middleware(JavascriptInjectorMiddleware, ws_url=url_host)],
lifespan=watcher.lifespan,
)

if not args.no_initial_build:
builder(rebuild=False)
Expand All @@ -60,7 +69,7 @@ def main():
open_browser(url_host, args.delay)

try:
server.serve(port=port_num, host=host_name, root=out_dir)
uvicorn.run(app, host=host_name, port=port_num, log_level="warning")
except KeyboardInterrupt:
show(context="Server ceasing operations. Cheerio!")

Expand Down
28 changes: 0 additions & 28 deletions sphinx_autobuild/_hacks.py

This file was deleted.

44 changes: 44 additions & 0 deletions sphinx_autobuild/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from __future__ import annotations

from starlette.datastructures import MutableHeaders
from starlette.types import ASGIApp, Message, Receive, Scope, Send


def web_socket_script(ws_url: str) -> str:
# language=HTML
return f"""
<script>
const ws = new WebSocket("ws://{ws_url}/websocket-reload");
ws.onmessage = () => window.location.reload();
</script>
"""


class JavascriptInjectorMiddleware:
def __init__(self, app: ASGIApp, ws_url: str) -> None:
self.app = app
self.script = web_socket_script(ws_url).encode("utf-8")

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
add_script = False
if scope["type"] != "http":
await self.app(scope, receive, send)
return

async def send_wrapper(message: Message) -> None:
nonlocal add_script
if message["type"] == "http.response.start":
headers = MutableHeaders(scope=message)
if headers.get("Content-Type", "").startswith("text/html"):
add_script = True
if "Content-Length" in headers:
length = int(headers["Content-Length"]) + len(self.script)
headers["Content-Length"] = str(length)
elif message["type"] == "http.response.body":
request_complete = not message.get("more_body", False)
if add_script and request_complete:
message["body"] += self.script
await send(message)

await self.app(scope, receive, send_wrapper)
return
78 changes: 78 additions & 0 deletions sphinx_autobuild/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from __future__ import annotations

import asyncio
import os
from contextlib import AbstractAsyncContextManager, asynccontextmanager

import watchfiles
from starlette.types import Receive, Scope, Send
from starlette.websockets import WebSocket

TYPE_CHECKING = False
if TYPE_CHECKING:
from collections.abc import Callable

from sphinx_autobuild.filter import IgnoreFilter


class RebuildServer:
def __init__(
self,
paths: list[os.PathLike[str]],
ignore_filter: IgnoreFilter,
change_callback: Callable[[], None],
) -> None:
self.paths = [os.path.realpath(path, strict=True) for path in paths]
self.ignore = ignore_filter
self.change_callback = change_callback
self.flag = asyncio.Event()
self.should_exit = asyncio.Event()

@asynccontextmanager
async def lifespan(self, _app) -> AbstractAsyncContextManager[None]:
task = asyncio.create_task(self.main())
yield
self.should_exit.set()
await task
return

async def main(self) -> None:
tasks = (
asyncio.create_task(self.watch()),
asyncio.create_task(self.should_exit.wait()),
)
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
[task.cancel() for task in pending]
[task.result() for task in done]

async def watch(self) -> None:
async for _changes in watchfiles.awatch(
*self.paths,
watch_filter=lambda _, path: not self.ignore(path),
):
self.change_callback()
self.flag.set()

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
assert scope["type"] == "websocket"
ws = WebSocket(scope, receive, send)
await ws.accept()

tasks = (
asyncio.create_task(self.watch_reloads(ws)),
asyncio.create_task(self.wait_client_disconnect(ws)),
)
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
[task.cancel() for task in pending]
[task.result() for task in done]

async def watch_reloads(self, ws: WebSocket) -> None:
while True:
await self.flag.wait()
self.flag.clear()
await ws.send_text("refresh")

@staticmethod
async def wait_client_disconnect(ws: WebSocket) -> None:
async for _ in ws.iter_text():
pass

0 comments on commit 3a32d1b

Please sign in to comment.