In [2]:
import ctypes
import multiprocessing
import time
from multiprocessing.sharedctypes import RawArray
from typing import Type

import numpy as np
import rpyc

import labbench as lb
from labbench import paramattr as attr

# Large element count used by this experiment (equal to the default of MyDevice.size).
N = 5 * 10**8


class MyDevice(lb.Device):
    """Toy labbench device that allocates a large shared-memory float32 buffer on open.

    The buffer length is taken from the ``size`` value attribute, so
    ``MyDevice(size=4)`` allocates 4 elements rather than the module default.
    """

    # number of float32 elements in the buffer allocated by open()
    size: int = attr.value.int(default=N, min=1)
    # NOTE(review): not used by this class; presumably a labbench Device convention — confirm
    backend = None

    def open(self):
        """Allocate ``self.size`` float32 elements and expose them as the numpy array ``self.x``."""
        n = self.size
        # c_float (4 bytes) matches the np.float32 view exactly, so the array
        # spans the whole allocation; the previous c_double buffer was twice
        # as large as the float32 view that used it.
        self.shared_array = RawArray(ctypes.c_float, n)
        self.x = np.ndarray(n, dtype=np.float32, buffer=self.shared_array)
        print('open!')

    def close(self):
        print('close!')

    def big_output(self):
        """Placeholder: currently does nothing and returns None."""
        pass


class SubprocessDeviceAdapter:
    """Context manager pairing an rpyc connection with the device reached through it.

    Entering opens the device and returns it. Exiting closes the device and
    then, unconditionally, the connection.
    """

    # rpyc connection to the process hosting the device
    conn: rpyc.core.protocol.Connection
    # the device; spawn_device passes conn.root.device here
    device: lb.Device

    def __init__(self, conn, device):
        self.conn = conn
        self.device = device

    def open(self):
        """Open the device and return it. The connection is assumed to be open already."""
        self.device.open()
        return self.device

    def close(self):
        """Close the device (if any), then always close the connection."""
        try:
            if self.device is not None:
                self.device.close()
        finally:
            self.conn.close()

    def __enter__(self, *args, **kws):
        print('adapter enter')
        try:
            return self.open()
        except BaseException:
            # __exit__ is not invoked when __enter__ raises, so release the
            # connection here instead of leaking it (and the server waiting on it).
            self.conn.close()
            raise

    def __exit__(self, *args, **kws):
        print('adapter exit')
        self.close()


class DeviceService(rpyc.Service):
    """rpyc service that owns a single device instance in the server process.

    Clients reach the instance as ``conn.root.device``.
    """

    def __init__(self, radio_type: Type[lb.Device], *args, **kws):
        """Construct the hosted device as ``radio_type(*args, **kws)``.

        Args:
            radio_type: device class to instantiate (any ``lb.Device`` subclass,
                not only ``MyDevice``; ``spawn_device`` forwards arbitrary classes).
            *args, **kws: forwarded to the device constructor.
        """
        super().__init__()
        self.device = radio_type(*args, **kws)


def spawn_device(cls: Type[lb.Device], *args, **kws):
    """Return a context manager that opens ``cls(*args, **kws)`` in another process.

    A forked child process runs an rpyc ``OneShotServer`` whose ``DeviceService``
    constructs the device there. This process connects to that server and wraps
    the connection plus the remote device reference in a ``SubprocessDeviceAdapter``.

    Args:
        cls: device class to instantiate in the subprocess.
        *args, **kws: forwarded to the device constructor.

    Returns:
        SubprocessDeviceAdapter: enter it to open the device; exiting closes the
        device and the connection.

    Raises:
        ConnectionRefusedError: if the server does not accept a connection within
            ``connect_timeout`` seconds (the child process is terminated first).
    """
    connect_timeout = 10.0  # seconds to keep retrying the initial connection

    service = rpyc.utils.helpers.classpartial(DeviceService, cls, *args, **kws)
    # presumably needed so the client can reach non-`exposed_` attributes such as `device`
    conf = {'allow_all_attrs': True}
    svc = rpyc.OneShotServer(service=service, protocol_config=conf)

    # the fork start method lets the child run `svc.start` directly, without pickling `svc`
    ctx = multiprocessing.get_context('fork')
    proc = ctx.Process(target=svc.start)
    proc.start()

    # NOTE(review): the child may not be accepting connections yet on the first
    # attempt (rpyc appears to begin listening inside svc.start()), so retry
    # briefly rather than failing on that start-up race.
    deadline = time.monotonic() + connect_timeout
    while True:
        try:
            conn = rpyc.connect('localhost', svc.port)
            break
        except ConnectionRefusedError:
            if time.monotonic() >= deadline or not proc.is_alive():
                proc.terminate()
                raise
            time.sleep(0.01)

    try:
        device = conn.root.device
    except BaseException:
        # don't leave the connection open if the remote device can't be reached
        conn.close()
        raise

    return SubprocessDeviceAdapter(conn, device)


# Open MyDevice in a forked subprocess (via spawn_device) together with a labbench
# Host, make a few calls through the connection, then print the collected log messages.
with spawn_device(MyDevice, size=4) as device, lb._host.Host() as host:
    # `device` is what the adapter's __enter__ returns, i.e. conn.root.device from
    # the rpyc connection, so attribute reads and calls below cross the connection.
    device._logger.info('message!')
    print('value: ', device.size)
    # NOTE(review): if big_output() returns a non-primitive object, `x` is presumably
    # an rpyc reference that stops working once this `with` block closes the
    # connection; a later cell evaluating `x` raised "EOFError: stream has been closed".
    # Copy the data into this process inside the block if it must outlive the connection.
    x = device.big_output()
    # time a single remote call (-n1: one loop, -r1: one run)
    %timeit -n1 -r1 device.big_output()
    # NOTE(review): lb._host is a private labbench module; confirm this API is intended for use.
    logs = host.log
    for log in logs:
        print(log['message'])


adapter enter
open!


[1;30m INFO  [0m [32m2024-07-29 10:21:25,956.957[0m • [34mMyDevice():[0m message!


close!
value:  4
362 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
running in git repository at /Users/dkuester/Documents/src/flex-spectrum-sensor
git_commit_id → 'a0b4712e2659ca96197e1201306262ae1096372f'
git_remote_url → 'https://github.com/usnistgov/spectrum-sensor-edge-analysis'
git_browse_url → 'https://github.com/usnistgov/spectrum-sensor-edge-analysis/tree/a0b4712e2659ca96197e1201306262ae1096372f'
git_pending_changes → []
opened
adapter exit


In [10]:
# NOTE(review): evaluated after the earlier `with spawn_device(...)` cell closed the
# rpyc connection; the recorded output shows this raised "EOFError: stream has been closed".
x

EOFError: stream has been closed