In [1]:
import io

import PIL.Image

import ipywidgets
import ipywebrtc

In [2]:
def transform_image_stream(
    stream, transform
) -> "tuple[ipywidgets.Image, ipywebrtc.ImageRecorder]":
    """Continuously apply ``transform`` to frames captured from ``stream``.

    An ``ipywebrtc.ImageRecorder`` grabs frames from ``stream``. Every time the
    recorder's image value changes, the frame is decoded with PIL, passed
    through ``transform``, re-encoded as PNG and written into the returned
    ``ipywidgets.Image``. Recording is then requested again, which forms a
    capture -> transform -> display loop.

    Args:
        stream: The media stream to record from (e.g. an
            ``ipywebrtc.CameraStream``).
        transform: Callable mapping a ``PIL.Image.Image`` to a
            ``PIL.Image.Image``.

    Returns:
        ``(output, recorder)``: the widget that shows transformed frames and
        the recorder that drives the loop.
    """
    # Record from the stream that was passed in, not from a notebook-global
    # `camera` that may not exist yet (or may be a different stream).
    recorder = ipywebrtc.ImageRecorder(stream=stream)
    output = ipywidgets.Image()

    def on_value_changed(event):
        frame_bytes = event["new"]
        # Close the decoded frame and the encode buffer once we are done.
        with PIL.Image.open(io.BytesIO(frame_bytes)) as frame:
            with io.BytesIO() as buffer:
                transform(frame).save(buffer, format="PNG")
                output.value = buffer.getvalue()

        # Request the next frame. ImageRecorder records one image per
        # `recording = True` (see the ipywebrtc docs), so this keeps the loop going.
        recorder.recording = True

    recorder.image.observe(on_value_changed, "value")
    recorder.recording = True  # kick off the first capture

    return output, recorder

In [3]:
import torch
from pystiche import demo
from torch import hub

# Pretrained weights for the network built by `demo.transformer()`.
url = "https://download.pystiche.org/models/example_transformer_ptcv.pth"
# Prefer the GPU, but fall back to the CPU so the notebook still runs without CUDA.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# map_location="cpu" lets the checkpoint load regardless of which device it was saved from.
state_dict = hub.load_state_dict_from_url(url, map_location="cpu")
transformer = demo.transformer()
transformer.load_state_dict(state_dict)
# Inference only: switch to eval mode, then move to the target device
# (both calls return the same module object).
transformer = transformer.eval().to(DEVICE)

In [4]:
import torch
from torchvision.transforms.functional import to_tensor, to_pil_image


def nst_transform(image: PIL.Image.Image) -> PIL.Image.Image:
    """Stylize a single PIL image with the pretrained ``transformer`` model.

    Args:
        image: Input frame in any PIL mode; it is converted to RGB first.

    Returns:
        The stylized frame as a PIL image.
    """
    # Run on whichever device the model lives on instead of hard-coding CUDA.
    device = next(transformer.parameters()).device
    input_tensor = to_tensor(image.convert("RGB")).to(device)
    # Inference only: don't build an autograd graph for every frame.
    with torch.no_grad():
        # transformer expects a batched image
        output_tensor = transformer(input_tensor.unsqueeze(0)).squeeze(0)
    # to_pil_image scales float tensors by 255 and casts them to uint8. Clamp
    # first so values slightly outside [0, 1] saturate instead of overflowing.
    return to_pil_image(output_tensor.clamp(0.0, 1.0).cpu())

In [5]:
# Front-facing webcam, video only, 400 px wide.
camera = ipywebrtc.CameraStream.facing_user(
    audio=False,
    constraints={"video": {"width": 400}},
)

# Wire the camera through the style-transfer function.
output, recorder = transform_image_stream(camera, nst_transform)

# Show the live feed next to the stylized feed.
ipywidgets.HBox([camera, output])

HBox(children=(CameraStream(constraints={'video': {'width': 400, 'facingMode': 'user'}, 'audio': False}), Imag…

In [6]:
# Display the ImageRecorder that feeds frames from `camera` into `nst_transform`.
recorder

ImageRecorder(image=Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01`\x00\x00\x01 \x08\x06\x00\x0…