Skip to content

Commit

Permalink
Implement webrtc_thread
Browse files Browse the repository at this point in the history
  • Loading branch information
whitphx committed Jan 24, 2021
1 parent 7fbf0eb commit 093f81b
Showing 1 changed file with 39 additions and 9 deletions.
48 changes: 39 additions & 9 deletions tiny_streamlit_webrtc/__init__.py
@@ -1,6 +1,8 @@
import asyncio
import logging
import threading
import os
import queue
import streamlit.components.v1 as components
from aiortc import RTCPeerConnection, RTCSessionDescription

Expand All @@ -21,12 +23,10 @@
_component_func = components.declare_component("tiny_streamlit_webrtc", path=build_dir)


session_state = SessionState.get(answer=None)
session_state = SessionState.get(answer=None, webrtc_thread=None)


async def process_offer(offer: RTCSessionDescription) -> RTCPeerConnection:
pc = RTCPeerConnection()

async def process_offer(pc: RTCPeerConnection, offer: RTCSessionDescription) -> RTCPeerConnection:
@pc.on("track")
def on_track(track):
logger.info("Track %s received", track.kind)
Expand All @@ -42,6 +42,31 @@ def on_track(track):
return pc


def webrtc_worker(offer: RTCSessionDescription, answer_queue: queue.Queue):
pc = RTCPeerConnection()

loop = asyncio.new_event_loop()

task = loop.create_task(process_offer(pc, offer))


def done_callback(task: asyncio.Task):
pc: RTCPeerConnection = task.result()
answer_queue.put(pc.localDescription)


task.add_done_callback(done_callback)

try:
loop.run_forever()
finally:
logger.debug("Event loop %s has stopped.", loop)
loop.run_until_complete(pc.close())
loop.run_until_complete(loop.shutdown_asyncgens())
loop.close()
logger.debug("Event loop %s cleaned up.", loop)


def tiny_streamlit_webrtc(key):
answer = session_state.answer
if answer:
Expand All @@ -64,16 +89,21 @@ def tiny_streamlit_webrtc(key):
if not answer:
offer = RTCSessionDescription(sdp=offer_json["sdp"], type=offer_json["type"])

pc = asyncio.run(process_offer(offer))
logger.info("process_offer() is completed and RTCPeerConnection is set up: %s", pc)
answer_queue = queue.Queue()
webrtc_thread = threading.Thread(target=webrtc_worker, args=(offer, answer_queue))
webrtc_thread.start()
session_state.webrtc_thread = webrtc_thread

answer = answer_queue.get(timeout=10)

# Debug
st.write(pc.localDescription)
st.write(answer)

session_state.answer = pc.localDescription
logger.info("Answer: %s", answer)
session_state.answer = answer
logger.info("Rerun to send it back")
st.experimental_rerun()


return component_value


Expand Down

0 comments on commit 093f81b

Please sign in to comment.