Skip to content

Commit

Permalink
Introduce SessionState.py to hold answer object over executions and s…
Browse files Browse the repository at this point in the history
…end it back to frontend with st.experimental_rerun()
  • Loading branch information
whitphx committed Jan 24, 2021
1 parent a6f7cc0 commit aa2ab49
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 9 deletions.
117 changes: 117 additions & 0 deletions tiny_streamlit_webrtc/SessionState.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""Hack to add per-session state to Streamlit.
Usage
-----
>>> import SessionState
>>>
>>> session_state = SessionState.get(user_name='', favorite_color='black')
>>> session_state.user_name
''
>>> session_state.user_name = 'Mary'
>>> session_state.favorite_color
'black'
Since you set user_name above, next time your script runs this will be the
result:
>>> session_state = get(user_name='', favorite_color='black')
>>> session_state.user_name
'Mary'
"""
try:
import streamlit.ReportThread as ReportThread
from streamlit.server.Server import Server
except Exception:
# Streamlit >= 0.65.0
import streamlit.report_thread as ReportThread
from streamlit.server.server import Server


class SessionState(object):
def __init__(self, **kwargs):
"""A new SessionState object.
Parameters
----------
**kwargs : any
Default values for the session state.
Example
-------
>>> session_state = SessionState(user_name='', favorite_color='black')
>>> session_state.user_name = 'Mary'
''
>>> session_state.favorite_color
'black'
"""
for key, val in kwargs.items():
setattr(self, key, val)


def get(**kwargs):
"""Gets a SessionState object for the current session.
Creates a new object if necessary.
Parameters
----------
**kwargs : any
Default values you want to add to the session state, if we're creating a
new one.
Example
-------
>>> session_state = get(user_name='', favorite_color='black')
>>> session_state.user_name
''
>>> session_state.user_name = 'Mary'
>>> session_state.favorite_color
'black'
Since you set user_name above, next time your script runs this will be the
result:
>>> session_state = get(user_name='', favorite_color='black')
>>> session_state.user_name
'Mary'
"""
# Hack to get the session object from Streamlit.

ctx = ReportThread.get_report_ctx()

this_session = None

current_server = Server.get_current()
if hasattr(current_server, '_session_infos'):
# Streamlit < 0.56
session_infos = Server.get_current()._session_infos.values()
else:
session_infos = Server.get_current()._session_info_by_id.values()

for session_info in session_infos:
s = session_info.session
if (
# Streamlit < 0.54.0
(hasattr(s, '_main_dg') and s._main_dg == ctx.main_dg)
or
# Streamlit >= 0.54.0
(not hasattr(s, '_main_dg') and s.enqueue == ctx.enqueue)
or
# Streamlit >= 0.65.2
(not hasattr(s, '_main_dg') and s._uploaded_file_mgr == ctx.uploaded_file_mgr)
):
this_session = s

if this_session is None:
raise RuntimeError(
"Oh noes. Couldn't get your Streamlit Session object. "
'Are you doing something fancy with threads?')

# Got the session object! Now let's attach some state into it.

if not hasattr(this_session, '_custom_session_state'):
this_session._custom_session_state = SessionState(**kwargs)

return this_session._custom_session_state
36 changes: 27 additions & 9 deletions tiny_streamlit_webrtc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import streamlit.components.v1 as components
from aiortc import RTCPeerConnection, RTCSessionDescription

import SessionState

logger = logging.getLogger(__name__)

_RELEASE = False
Expand All @@ -19,6 +21,9 @@
_component_func = components.declare_component("tiny_streamlit_webrtc", path=build_dir)


session_state = SessionState.get(answer=None)


async def process_offer(offer: RTCSessionDescription) -> RTCPeerConnection:
pc = RTCPeerConnection()

Expand All @@ -37,23 +42,36 @@ def on_track(track):
return pc


def tiny_streamlit_webrtc(key=None):
component_value = _component_func(key=key, default=None)
def tiny_streamlit_webrtc(key):
answer = session_state.answer
if answer:
answer_dict = {
"sdp": answer.sdp,
"type": answer.type,
}
else:
answer_dict = None

component_value = _component_func(key=key, answer=answer_dict, default=None)

if component_value:
offer_json = component_value["offerJson"]

# Debug
st.write(offer_json)

offer = RTCSessionDescription(sdp=offer_json["sdp"], type=offer_json["type"])
# To prevent an infinite loop, check whether `answer` already exists or not.
if not answer:
offer = RTCSessionDescription(sdp=offer_json["sdp"], type=offer_json["type"])

pc = asyncio.run(process_offer(offer))
logger.info("process_offer() is completed and RTCPeerConnection is set up: %s", pc)
pc = asyncio.run(process_offer(offer))
logger.info("process_offer() is completed and RTCPeerConnection is set up: %s", pc)

# Debug
st.write(pc.localDescription)
# TODO: How to send back the answer to frontend?
# Debug
st.write(pc.localDescription)

session_state.answer = pc.localDescription
st.experimental_rerun()


return component_value
Expand All @@ -62,4 +80,4 @@ def tiny_streamlit_webrtc(key=None):
if not _RELEASE:
import streamlit as st

tiny_streamlit_webrtc()
tiny_streamlit_webrtc(key='foo')
4 changes: 4 additions & 0 deletions tiny_streamlit_webrtc/frontend/src/TinyWebrtc.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ class TinyWebrtc extends StreamlitComponentBase<State> {
this.videoRef = React.createRef()
}

public componentDidUpdate = () => {
console.log("Answer: ", this.props.args["answer"])
}

public render = (): ReactNode => {
return (
<div>
Expand Down

0 comments on commit aa2ab49

Please sign in to comment.