Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Socket.IO for message transport #449

Merged
merged 11 commits into from
Feb 6, 2023
648 changes: 509 additions & 139 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pynecone/.templates/web/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
"react-markdown": "^8.0.3",
"react-plotly.js": "^2.6.0",
"react-syntax-highlighter": "^15.5.0",
"reconnecting-websocket": "^4.4.0",
"rehype-katex": "^6.0.2",
"rehype-raw": "^6.1.1",
"remark-gfm": "^3.0.1",
"remark-math": "^5.1.1",
"socket.io-client": "^4.5.4",
"victory": "^36.6.8"
}
}
2 changes: 1 addition & 1 deletion pynecone/.templates/web/pcversion.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.1.14
0.1.15
25 changes: 11 additions & 14 deletions pynecone/.templates/web/utils/state.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// State management for Pynecone web apps.
import ReconnectingWebSocket from 'reconnecting-websocket';
import io from 'socket.io-client';

// Global variable to hold the token.
let token;
Expand Down Expand Up @@ -90,7 +90,7 @@ export const applyEvent = async (event, router, socket) => {
event.token = getToken();
event.router_data = (({ pathname, query }) => ({ pathname, query }))(router);
if (socket) {
socket.send(JSON.stringify(event));
socket.emit("event", JSON.stringify(event));
}
};

Expand All @@ -109,11 +109,6 @@ export const updateState = async (state, setState, result, setResult, router, so
return;
}

// If the socket is not ready, return.
if (!socket.readyState) {
return;
}

// Set processing to true to block other events from being processed.
setResult({ ...result, processing: true });

Expand All @@ -137,23 +132,25 @@ export const updateState = async (state, setState, result, setResult, router, so
*/
export const connect = async (socket, state, setState, result, setResult, router, endpoint) => {
// Create the socket.
socket.current = new ReconnectingWebSocket(endpoint);
socket.current = io(endpoint, {
'path': '/event',
});

// Once the socket is open, hydrate the page.
socket.current.onopen = () => {
updateState(state, setState, result, setResult, router, socket.current)
}
socket.current.on('connect', () => {
updateState(state, setState, result, setResult, router, socket.current);
});

// On each received message, apply the delta and set the result.
socket.current.onmessage = function (update) {
update = JSON.parse(update.data);
socket.current.on('event', function (update) {
update = JSON.parse(update);
applyDelta(state, update.delta);
setResult({
processing: false,
state: state,
events: update.events,
});
};
});
};

/**
Expand Down
140 changes: 70 additions & 70 deletions pynecone/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

from typing import Any, Callable, Coroutine, Dict, List, Optional, Type, Union

from fastapi import FastAPI, WebSocket
from fastapi.middleware import cors
from starlette.websockets import WebSocketDisconnect
from fastapi import FastAPI
from socketio import ASGIApp, AsyncNamespace, AsyncServer

from pynecone import constants, utils
from pynecone.base import Base
Expand Down Expand Up @@ -33,6 +32,9 @@ class App(Base):
# The backend API object.
api: FastAPI = None # type: ignore

# The Socket.IO AsyncServer.
sio: AsyncServer = None

# The state class to use for the app.
state: Type[State] = DefaultState

Expand Down Expand Up @@ -64,10 +66,23 @@ def __init__(self, *args, **kwargs):
self.state_manager.setup(state=self.state)

# Set up the API.

self.api = FastAPI()
self.add_cors()
self.add_default_endpoints()

# Set up the Socket.IO AsyncServer.
self.sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")

# Create the socket app. Note event endpoint constant replaces the default 'socket.io' path.
socket_app = ASGIApp(self.sio, socketio_path=str(constants.Endpoint.EVENT))

# Create the event namespace and attach the main app. Not related to the path above.
event_namespace = EventNamespace("/event")
event_namespace.app = self

# Register the event namespace with the socket.
self.sio.register_namespace(event_namespace)

# Mount the socket app with the API.
self.api.mount("/", socket_app)

def __repr__(self) -> str:
"""Get the string representation of the app.
Expand All @@ -85,24 +100,6 @@ def __call__(self) -> FastAPI:
"""
return self.api

def add_default_endpoints(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should keep the ping endpoint

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added it back in. For future reference it can be tested like so:

    socket.current.emit('ping')
    console.log('ping')
    socket.current.on('ping', function(data) {
      console.log(data);
    });

In the console you should see:

ping
pong

"""Add the default endpoints."""
# To test the server.
self.api.get(str(constants.Endpoint.PING))(ping)

# To make state changes.
self.api.websocket(str(constants.Endpoint.EVENT))(event(app=self))

def add_cors(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we not need this anymore?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's taken care of when the AsyncServer is created on the App __init__ :

    self.sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")

There's also cors_credentials= which defaults to True already.

python-scoketio does not have CORS options like allow_methods and allow_headers but I believe it's configured to avoid cross-origin problems with browsers by allowing all headers and methods anyway.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay got it nice

"""Add CORS middleware to the app."""
self.api.add_middleware(
cors.CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)

def preprocess(self, state: State, event: Event) -> Optional[Delta]:
"""Preprocess the event.

Expand Down Expand Up @@ -327,52 +324,6 @@ def compile(self, force_compile: bool = False):
compiler.compile_components(custom_components)


async def ping() -> str:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be useful for health checks, etc.

"""Test API endpoint.

Returns:
The response.
"""
return "pong"


def event(app: App):
"""Websocket endpoint for events.

Args:
app: The app to add the endpoint to.

Returns:
The websocket endpoint.
"""

async def ws(websocket: WebSocket):
"""Create websocket endpoint.

Args:
websocket: The websocket sending events.
"""
# Accept the connection.
await websocket.accept()

# Process events until the connection is closed.
while True:
# Get the event.
try:
event = Event.parse_raw(await websocket.receive_text())
except WebSocketDisconnect:
# Close the connection.
return

# Process the event.
update = await process(app, event)

# Send the update.
await websocket.send_text(update.json())

return ws


async def process(app: App, event: Event) -> StateUpdate:
"""Process an event.

Expand Down Expand Up @@ -405,3 +356,52 @@ async def process(app: App, event: Event) -> StateUpdate:

# Return the update.
return update


class EventNamespace(AsyncNamespace):
"""The event namespace."""

# The backend API object.
app: App

def on_connect(self, sid, environ):
"""Event for when the websocket disconnects.

Args:
sid: The Socket.IO session id.
environ: The request information, including HTTP headers.
"""
pass

def on_disconnect(self, sid):
"""Event for when the websocket disconnects.

Args:
sid: The Socket.IO session id.
"""
pass

async def on_event(self, sid, data):
"""Event for receiving front-end websocket events.

Args:
sid: The Socket.IO session id.
data: The event data.
"""
# Get the event.
event = Event.parse_raw(data)

# Process the event.
update = await process(self.app, event)

# Emit the event.
await self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid)

async def on_ping(self, sid):
"""Event for testing the API endpoint.

Args:
sid: The Socket.IO session id.
"""
# Emit the test event.
await self.emit(str(constants.SocketEvent.PING), "pong", to=sid)
10 changes: 2 additions & 8 deletions pynecone/compiler/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,14 +164,8 @@ def format_state(
" if(!isReady) {{",
" return;",
" }}",
" const reconnectSocket = () => {{",
f" {SOCKET}.current.reconnect()",
" }}",
f" if (typeof {SOCKET}.current !== 'undefined') {{{{",
f" if (!{SOCKET}.current) {{{{",
f" window.addEventListener('focus', reconnectSocket)",
f" connect({SOCKET}, {{state}}, {{set_state}}, {RESULT}, {SET_RESULT}, {ROUTER}, {EVENT_ENDPOINT})",
" }}",
f" if (!{SOCKET}.current) {{{{",
f" connect({SOCKET}, {{state}}, {{set_state}}, {RESULT}, {SET_RESULT}, {ROUTER}, {EVENT_ENDPOINT})",
" }}",
" const update = async () => {{",
f" if ({RESULT}.{STATE} != null) {{{{",
Expand Down
16 changes: 15 additions & 1 deletion pynecone/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ class LogLevel(str, Enum):
class Endpoint(Enum):
"""Endpoints for the pynecone backend API."""

PING = "ping"
EVENT = "event"

def __str__(self) -> str:
Expand Down Expand Up @@ -192,6 +191,21 @@ def get_url(self) -> str:
return url


class SocketEvent(Enum):
"""Socket events sent by the pynecone backend API."""

PING = "ping"
EVENT = "event"

def __str__(self) -> str:
"""Get the string representation of the event name.

Returns:
The event name string.
"""
return str(self.value)


class RouteArgType(SimpleNamespace):
"""Type of dynamic route arg extracted from URI route."""

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ uvicorn = "^0.20.0"
rich = "^12.6.0"
redis = "^4.3.5"
httpx = "^0.23.1"
websockets = "^10.4"
python-socketio = "^5.7.2"
psutil = "^5.9.4"

[tool.poetry.dev-dependencies]
Expand Down