In [None]:
from jupyter_client import MultiKernelManager
from tqdm.notebook import tqdm

In [None]:
def get_output(client, msg_id, timeout=10):
    """Return the first stdout text emitted in response to one request.

    Reads messages from the client's IOPub channel one at a time and
    discards everything except a ``stream`` message on ``stdout`` whose
    parent request is ``msg_id``.

    Args:
        client: Object exposing ``get_iopub_msg(timeout=...)`` that returns
            Jupyter message dicts and raises ``queue.Empty`` when no message
            arrives in time.
        msg_id: Identifier of the request whose output is wanted.
        timeout: Seconds to wait for each individual IOPub message (not for
            the call as a whole). Defaults to 10.

    Returns:
        The whitespace-stripped text of the first matching message, or
        ``None`` if waiting for the next message timed out first.
    """
    from queue import Empty

    try:
        while True:
            message = client.get_iopub_msg(timeout=timeout)
            # Messages that were not triggered by a request carry an empty
            # parent header, so look the id up instead of indexing.
            parent_header = message.get("parent_header") or {}
            if parent_header.get("msg_id") != msg_id:
                continue
            if message["header"]["msg_type"] != "stream":
                continue
            content = message["content"]
            if content["name"] != "stdout":
                continue
            return content["text"].strip()
    except Empty:
        # Nothing more arrived within `timeout`: report "no output".
        return None

In [None]:
def get_result(client):
    """Ask the kernel to print its ``result`` variable and return the output.

    Sends ``print(result)`` through ``client`` and hands the id returned by
    ``client.execute`` to ``get_output``, whose return value is passed back
    unchanged.
    """
    return get_output(client, client.execute("print(result)"))

def get_duration(client):
    """Ask the kernel to print its ``duration`` variable and return the output.

    Sends ``print(duration)`` through ``client`` and hands the id returned by
    ``client.execute`` to ``get_output``, whose return value is passed back
    unchanged.
    """
    return get_output(client, client.execute("print(duration)"))

In [None]:
from pathlib import Path
# Source text of the benchmark script that is sent to each kernel.
# Decode explicitly as UTF-8 rather than with the platform's locale default.
code = Path("benchmarks/roberta_finetuning.py").read_text(encoding="utf-8")

In [None]:
# Manager object through which the kernels are started and later shut down.
mkm = MultiKernelManager()

In [None]:
# Number of independent kernels (one benchmark run per kernel).
num_kernels = 11
# start_kernel() returns the new kernel's id; get_kernel() resolves it to its manager.
kernels = [mkm.get_kernel(mkm.start_kernel()) for _ in range(num_kernels)]
clients = [kernel.client() for kernel in kernels]

# Open each client's channels and wait for the kernel to answer before any
# code is sent. IOPub is publish/subscribe: output published before the client
# has subscribed is not delivered, so connecting up front avoids losing the
# first printed value.
for client in clients:
    client.start_channels()
    # Fail loudly instead of hanging forever if a kernel never comes up.
    client.wait_for_ready(timeout=60)

In [None]:
# Show the process id of each kernel.
# NOTE(review): assumes each kernel's provisioner exposes a local `process` handle — confirm.
[kernel.provisioner.process.pid for kernel in kernels]

In [None]:
# One boolean per kernel: True when `poll()` returns None.
# NOTE(review): presumably a Popen-style handle, where None means "still running" — confirm.
[kernel.provisioner.process.poll() is None for kernel in kernels]

In [None]:
# Run the benchmark script on every kernel, one after another, and collect
# the printed `result` and `duration` of each run (text, or None when nothing
# was captured).
results = []
durations = []
for client in tqdm(clients):
    # Blocks until the kernel has replied to the execute request.
    reply = client.execute(code, reply=True)
    reply_content = reply["content"]
    if reply_content["status"] != "ok":
        # Surface the failure instead of letting it show up only as a silent
        # None below; keep going so the lists stay aligned with `clients`.
        tqdm.write(
            f"benchmark failed ({reply_content['status']}): "
            f"{reply_content.get('ename')}: {reply_content.get('evalue')}"
        )
    result = get_result(client)
    duration = get_duration(client)
    results.append(result)
    durations.append(duration)
    tqdm.write(f"{result} {duration}")

In [None]:
# Display the collected per-kernel results.
results

In [None]:
# Close every client's channels first so their sockets are released, then
# stop all kernels owned by the manager.
for client in clients:
    client.stop_channels()
mkm.shutdown_all()

In [None]:
# Debugging aid: uncomment to check that every kernel and client is still alive.
# len(clients)
# for kernel in kernels:
#     print(kernel.is_alive())
# for client in clients:
#     print(client.is_alive())