diff --git a/tiny_streamlit_webrtc/__init__.py b/tiny_streamlit_webrtc/__init__.py index 9bc7017..37f41e2 100644 --- a/tiny_streamlit_webrtc/__init__.py +++ b/tiny_streamlit_webrtc/__init__.py @@ -4,7 +4,9 @@ import os import queue import streamlit.components.v1 as components -from aiortc import RTCPeerConnection, RTCSessionDescription +import cv2 +from av import VideoFrame +from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription import SessionState @@ -26,11 +28,37 @@ session_state = SessionState.get(answer=None, webrtc_thread=None) +class VideoTransformTrack(MediaStreamTrack): + + kind = "video" + + def __init__(self, track): + super().__init__() # don't forget this! + self.track = track + + async def recv(self): + frame = await self.track.recv() + + # perform edge detection + img = frame.to_ndarray(format="bgr24") + img = cv2.cvtColor(cv2.Canny(img, 100, 200), cv2.COLOR_GRAY2BGR) + + # rebuild a VideoFrame, preserving timing information + new_frame = VideoFrame.from_ndarray(img, format="bgr24") + new_frame.pts = frame.pts + new_frame.time_base = frame.time_base + return new_frame + + async def process_offer(pc: RTCPeerConnection, offer: RTCSessionDescription) -> RTCPeerConnection: @pc.on("track") def on_track(track): logger.info("Track %s received", track.kind) - pc.addTrack(track) # Passthrough. TODO: Implement video transformation + if track.kind == "audio": + pc.addTrack(track) # Passthrough + elif track.kind == "video": + local_video = VideoTransformTrack(track) + pc.addTrack(local_video) # handle offer await pc.setRemoteDescription(offer)