In [2]:
from jupyter_client import MultiKernelManager
from tqdm.notebook import tqdm

In [3]:
def get_output(client, msg_id, timeout=10):
    """Collect the stdout text produced by one execution request on a kernel.

    Reads messages from the client's IOPub channel. It ignores any message
    whose parent is not ``msg_id``, such as leftover output from earlier
    executions, and concatenates the ``stdout`` stream chunks that belong to
    ``msg_id``. Reading stops once the kernel reports ``status: idle`` for that
    request, because the execution has then finished and no more output follows.

    Parameters
    ----------
    client : kernel client
        Object whose IOPub channel is read via ``get_iopub_msg(timeout=...)``.
    msg_id : str
        The ``msg_id`` returned by ``client.execute(...)``.
    timeout : float, optional
        Seconds to wait for each IOPub message before giving up (default 10).

    Returns
    -------
    str or None
        The stripped stdout text. Returns ``None`` if the request printed
        nothing to stdout (for example because it raised) or if no message
        arrived within ``timeout``.
    """
    from queue import Empty

    chunks = []
    try:
        while True:
            message = client.get_iopub_msg(timeout=timeout)
            # Some IOPub messages carry an empty parent_header; .get avoids a KeyError.
            if message.get("parent_header", {}).get("msg_id") != msg_id:
                continue
            msg_type = message["header"]["msg_type"]
            content = message["content"]
            if msg_type == "stream" and content.get("name") == "stdout":
                # Output can be split across several stream messages; keep all of it.
                chunks.append(content["text"])
            elif msg_type == "status" and content.get("execution_state") == "idle":
                # The request has finished, so stop instead of waiting for the timeout.
                break
    except Empty:
        # Nothing arrived within `timeout`; return whatever was collected so far.
        pass
    if not chunks:
        return None
    return "".join(chunks).strip()

In [4]:
def _print_and_capture(client, expression):
    """Execute ``print(<expression>)`` on the kernel and return what it printed.

    Returns the stripped stdout text as produced by ``get_output``, or ``None``
    if the kernel printed nothing.
    """
    msg_id = client.execute(f"print({expression})")
    return get_output(client, msg_id)


def get_result(client):
    """Return the kernel's ``result`` variable as printed text, or ``None``."""
    return _print_and_capture(client, "result")


def get_duration(client):
    """Return the kernel's ``duration`` variable as printed text, or ``None``."""
    return _print_and_capture(client, "duration")

In [5]:
from pathlib import Path

# Source of the benchmark script that every kernel executes. Python source is
# UTF-8 by default, so decode it explicitly instead of relying on the locale.
code = Path("benchmarks/roberta_finetuning.py").read_text(encoding="utf-8")

In [6]:
# A single manager starts every kernel below; `mkm.shutdown_all()` is called on it at the end.
mkm = MultiKernelManager()

In [7]:
# Number of kernels to launch; each one runs a copy of the benchmark.
num_kernels = 11
kernels = [mkm.get_kernel(mkm.start_kernel()) for _ in range(num_kernels)]
clients = [kernel.client() for kernel in kernels]
# Open the ZMQ channels up front and wait until each kernel answers. This way
# the IOPub subscription exists before any code is sent (a PUB/SUB socket
# drops messages published before the subscriber connects), and no request
# reaches a kernel that is still starting up.
for client in clients:
    client.start_channels()
    client.wait_for_ready(timeout=60)

In [9]:
[kernel_manager.provisioner.process.pid for kernel_manager in kernels]  # OS process ID of each launched kernel

[119552,
 119555,
 119556,
 119557,
 119558,
 119559,
 119560,
 119561,
 119562,
 119563,
 119564]

In [16]:
[kernel_manager.provisioner.process.poll() is None for kernel_manager in kernels]  # True while that kernel's process has not exited

[True, True, True, True, True, True, True, True, True, True, True]

In [7]:
# Run the benchmark on each kernel in turn (execute(..., reply=True) blocks until
# the run finishes), then read back the `result` and `duration` variables.
results = []
durations = []
for client in tqdm(clients):
    reply = client.execute(code, reply=True)
    reply_content = reply["content"]
    if reply_content["status"] != "ok":
        # Report exceptions raised by the benchmark instead of hiding them
        # behind whatever `result`/`duration` happen to print afterwards.
        tqdm.write(
            f"benchmark failed: {reply_content.get('ename')}: {reply_content.get('evalue')}"
        )
    result = get_result(client)
    duration = get_duration(client)
    results.append(result)
    durations.append(duration)
    tqdm.write(f"{result} {duration}")

  0%|          | 0/11 [00:00<?, ?it/s]

0.475 10.826914548873901
0.25 11.057157039642334
0.425 10.537917852401733
0.225 10.454579830169678
0.675 10.310715675354004
0.7 10.794390678405762
0.225 10.732997179031372
0.225 10.505618572235107
out of memory 7.129721403121948
out of memory 5.981494903564453
out of memory 7.051535367965698


In [8]:
# One entry per kernel: the text that kernel printed for `result` (a string, or None if nothing was printed).
results

['0.475',
 '0.25',
 '0.425',
 '0.225',
 '0.675',
 '0.7',
 '0.225',
 '0.225',
 'out of memory',
 'out of memory',
 'out of memory']

In [10]:
# Close each client's ZMQ sockets, then shut down every kernel the manager started.
for client in clients:
    client.stop_channels()
mkm.shutdown_all()

: 

In [10]:
# len(clients)
# for kernel in kernels:
#     print(kernel.is_alive())
# for client in clients:
#     print(client.is_alive())