# Faster transfers via torch.Tensor shared memory

In [1]:
import time
import numpy as np
import torch
import torch.multiprocessing as mp

def benchmark(func, n=10):
    """Call ``func`` ``n`` times and summarize how long each call took.

    Args:
        func: zero-argument callable to time.
        n: number of timed calls; must be at least 1.

    Returns:
        ``(mean, std)`` of the per-call durations in seconds, as plain Python floats.

    Raises:
        ValueError: if ``n`` is smaller than 1 (there would be nothing to average).
    """
    if n < 1:
        raise ValueError(f'n must be >= 1, got {n}')
    times = []
    for _ in range(n):
        # perf_counter is monotonic and high-resolution; time.time() can jump with
        # system clock adjustments and is coarse on some platforms, which distorts
        # millisecond-scale measurements like the ones printed in this notebook.
        start = time.perf_counter()
        func()
        times.append(time.perf_counter() - start)
    return float(np.mean(times)), float(np.std(times))

# Benchmark payloads: standard-normal float64 arrays of shape (N, 128, 128, 8),
# i.e. N * 128*128*8*8 = N * 1,048,576 bytes each.
arrays = [
    np.random.normal(loc=0.0, scale=1.0, size=(N, 128, 128, 8))
    for N in [1, 10, 100, 500, 1000]
]
# Report each payload's size in (decimal) megabytes.
for i, arr in enumerate(arrays):
    print(f'Arr{i+1} Size: {arr.nbytes/1000/1000:.0f}MB')


def local_roundtrip(send_queue, arr):
    """Put ``arr`` on ``send_queue`` and immediately take an item back off it.

    Everything happens in the calling process. The received item's first element is
    compared with ``arr``'s as a cheap sanity check; both must be indexable with four
    indices.
    """
    send_queue.put(arr)
    received = send_queue.get()
    first = (0, 0, 0, 0)
    assert arr[first] == received[first]
    
def remote_roundtrip(send_queue, recv_queue, arr):
    """Send ``arr`` out on ``send_queue`` and block until an item arrives on ``recv_queue``.

    Meant to be paired with ``worker_fn``, which echoes items from one queue to the
    other. The first elements (four-index access) are compared as a sanity check.
    """
    send_queue.put(arr)
    echoed = recv_queue.get()
    origin = (0,) * 4
    assert arr[origin] == echoed[origin]

def worker_fn(send_queue, recv_queue):
    """Echo loop: forward every item from ``send_queue`` to ``recv_queue``.

    Stops as soon as it receives the object ``False`` itself (identity check), which is
    not forwarded. Other falsy values such as ``0`` or ``None`` are echoed normally.
    """
    keep_running = True
    while keep_running:
        item = send_queue.get()
        keep_running = item is not False
        if keep_running:
            recv_queue.put(item)

def create_process(send_queue, recv_queue):
    """Start a child process running ``worker_fn(send_queue, recv_queue)`` and return it."""
    worker = mp.Process(args=(send_queue, recv_queue), target=worker_fn)
    worker.start()
    return worker

def join_process(send_queue, p):
    """Send the stop sentinel on ``send_queue``, then wait for process ``p`` to exit."""
    stop_signal = False  # worker_fn stops when it receives exactly this object
    send_queue.put(stop_signal)
    p.join()
    
def benchmark_local():
    """Time an in-process put/get roundtrip for every array in the global ``arrays``.

    A fresh ``mp.Queue`` is created per array. One line per array is printed with the
    mean and standard deviation of the roundtrip time in milliseconds.
    """
    for i, arr in enumerate(arrays):
        q = mp.Queue()

        def roundtrip():
            local_roundtrip(q, arr)

        mean, std = benchmark(roundtrip)
        print(f'Arr{i+1}: {1000*mean:.0f}ms +- {1000*std:.0f}')

def benchmark_remote():
    """Time a roundtrip through an echo worker process for every array in ``arrays``.

    For each array, a fresh pair of queues and a worker process (``create_process``)
    are set up. One timed call sends the array to the worker and waits for it to come
    back. Prints the mean and standard deviation in milliseconds.
    """
    for i, arr in enumerate(arrays):
        send_queue = mp.Queue()
        recv_queue = mp.Queue()
        p = create_process(send_queue, recv_queue)
        try:
            mean, std = benchmark(lambda: remote_roundtrip(send_queue, recv_queue, arr))
            print(f'Arr{i+1}: {1000*mean:.0f}ms +- {1000*std:.0f}')
        finally:
            # Always send the stop sentinel and reap the worker, even if a roundtrip
            # fails or is interrupted. Otherwise worker_fn keeps blocking on get()
            # and the child process is left running.
            join_process(send_queue, p)

Arr1 Size: 1MB
Arr2 Size: 10MB
Arr3 Size: 105MB
Arr4 Size: 524MB
Arr5 Size: 1049MB


### Current serialization performance (local roundtrip)

In [2]:
# Baseline before any custom np.ndarray reducer is registered: put/get roundtrip
# on a single queue inside this process.
benchmark_local()

Arr1: 2ms +- 1
Arr2: 26ms +- 4
Arr3: 433ms +- 3
Arr4: 2294ms +- 79
Arr5: 4575ms +- 97


### Current serialization performance (remote roundtrip)

In [3]:
# Baseline before any custom np.ndarray reducer is registered: each call sends the
# array to an echo worker process and receives it back.
benchmark_remote()

Arr1: 3ms +- 3
Arr2: 59ms +- 5
Arr3: 872ms +- 10
Arr4: 4487ms +- 113
Arr5: 8935ms +- 148


## Proposed Solution (Proof of Concept)

**It's important that we recreate any Queues afterward for our new reduction method to be registered**

In [4]:
from multiprocessing.reduction import ForkingPickler

def rebuild_ndarray(tensor, dtype, shape=None):
    """Rebuild an ndarray from an int8 ``tensor`` that holds its raw bytes.

    Args:
        tensor: int8 torch.Tensor with the array's bytes in C order.
        dtype: numpy dtype to reinterpret the bytes as.
        shape: target shape. ``None`` keeps whatever shape the dtype view produces,
            which is the behavior of the original two-argument form.

    Returns:
        An ndarray viewing the tensor's memory (no copy).
    """
    arr = tensor.numpy().view(dtype)
    return arr if shape is None else arr.reshape(shape)


def reduce_ndarray(arr):
    """Pickle reducer for ``np.ndarray`` that ships the data as an int8 torch.Tensor.

    Pickling the tensor goes through torch's tensor reducer, which copies the bytes to
    shared memory instead of serializing them into the pickle stream.

    Returns:
        A ``(rebuild_ndarray, args)`` reduce tuple, or numpy's own ``__reduce__`` result
        for object arrays.
    """
    if arr.dtype.hasobject:
        # Object arrays hold Python references, not raw bytes; they cannot be viewed as
        # int8 or wrapped in a tensor, so keep numpy's regular pickling for them.
        return arr.__reduce__()
    # Make the data C-contiguous (copying only if needed) and flatten it, so the int8 view
    # works for any memory layout (e.g. transposed arrays) and for 0-d arrays.
    flat = np.ascontiguousarray(arr).reshape(-1)
    # always interpret as raw bytes to support stuff like np.datetime64 as well
    tensor = torch.as_tensor(flat.view(np.int8))
    return (rebuild_ndarray, (tensor, arr.dtype, arr.shape))


ForkingPickler.register(np.ndarray, reduce_ndarray)

### Proposed Performance (local roundtrip)

In [5]:
# Same in-process benchmark, now with the tensor-based np.ndarray reducer registered.
benchmark_local()

Arr1: 1ms +- 1
Arr2: 4ms +- 1
Arr3: 47ms +- 4
Arr4: 242ms +- 8
Arr5: 435ms +- 22


### Proposed Performance (remote roundtrip)

In [6]:
# Same echo-worker benchmark, now with the tensor-based np.ndarray reducer registered.
benchmark_remote()

Arr1: 15ms +- 18
Arr2: 23ms +- 15
Arr3: 189ms +- 19
Arr4: 565ms +- 34
Arr5: 940ms +- 32


## Actually sharing memory

Notice that the current solution **always** results in a copy to shared memory when pickling. While this significantly speeds up transmissions, it still results in duplicated memory usage (assuming that the same array is either kept in the main process or shared across multiple workers).

While we can easily create a numpy array that is a view of a torch.Tensor in shared memory (`arr = torch.as_tensor(arr).share_memory_().numpy()`), serializing such an array would still copy it again, even though it already lives in shared memory. This is because torch only looks at a Tensor's storage (how it was allocated), not at the actual memory address, to determine whether it is already in shared memory.

Luckily for us, torch and numpy keep track of who actually owns the memory, so `np.ndarray.base` will point to the original `torch.Tensor` on which we called `.numpy()`. Unfortunately for us, in the case of slices there is one more level of indirection: the slice's `.base` is the array returned by `.numpy()`, whose own `.base` is the tensor.

The following code demonstrates this approach:

In [7]:
def rebuild_ndarray(tensor, metainfo):
    """Recreate an ndarray as a view into the memory backing ``tensor``.

    Args:
        tensor: torch.Tensor whose memory holds the array's bytes. It must be contiguous,
            because ``tensor.numpy()`` is used as a flat buffer.
        metainfo: ``(offset, shape, strides, dtype)``. ``offset`` is the byte offset of
            the array's first element from the start of ``tensor``'s data; ``strides``
            may be ``None`` for a C-contiguous layout.

    Returns:
        An ``np.ndarray`` built on ``tensor.numpy()`` (no copy). Its ``.base`` chain
        leads back to ``tensor``, so ``reduce_ndarray`` can find the tensor again when
        the array is re-sent.
    """
    offset, shape, strides, dtype = metainfo
    buffer = tensor.numpy()
    return np.ndarray(buffer=buffer, offset=offset, shape=shape, strides=strides, dtype=dtype)


def _owning_tensor(arr):
    """Return the torch.Tensor at the end of ``arr``'s ``.base`` chain, or None."""
    base = arr.base
    while type(base) is np.ndarray and base.base is not None:  # only follow pure np.ndarray's for now
        base = base.base
    return base if isinstance(base, torch.Tensor) else None


def reduce_ndarray(arr: np.ndarray):
    """Pickle reducer for ``np.ndarray`` that transfers the data through a torch.Tensor.

    If ``arr`` is a view into a contiguous torch.Tensor (e.g. the result of
    ``share_memory`` or of ``rebuild_ndarray``), that tensor is pickled together with
    ``arr``'s location inside it. Memory that is already shared is therefore not copied
    again. Otherwise ``arr``'s bytes are wrapped in a fresh int8 tensor, which is copied
    to shared memory when pickled.

    Returns:
        A ``(rebuild_ndarray, args)`` reduce tuple, or numpy's own ``__reduce__`` result
        for object arrays.
    """
    if arr.dtype.hasobject:
        # Object arrays hold Python references, not raw bytes; they cannot be viewed as
        # int8 or wrapped in a tensor, so keep numpy's regular pickling for them.
        return arr.__reduce__()

    tensor = _owning_tensor(arr)
    if tensor is not None and tensor.is_contiguous():
        # Byte distance from the start of the tensor's data to arr's first element.
        # arr lies inside the tensor's memory, so this is never negative.
        offset = arr.__array_interface__['data'][0] - tensor.data_ptr()
        # Pass the dtype object rather than its typestr so structured dtypes keep their fields.
        return (rebuild_ndarray, (tensor, (offset, arr.shape, arr.strides, arr.dtype)))

    # Make the data C-contiguous (copying only if needed) and reinterpret it as raw bytes.
    # This supports any memory layout, 0-d arrays, and dtypes torch does not know
    # (np.datetime64, structured dtypes). The layout is now plain C order, so no strides
    # are sent.
    flat = np.ascontiguousarray(arr).reshape(-1)
    tensor = torch.as_tensor(flat.view(np.int8))
    return (rebuild_ndarray, (tensor, (0, arr.shape, None, arr.dtype)))


def share_memory(arr: np.ndarray) -> np.ndarray:
    """Copy ``arr`` into shared memory and return it with the same dtype and shape.

    The result is a view of a shared torch.Tensor, so ``reduce_ndarray`` sends it to
    other processes without copying it again.
    """
    flat = np.ascontiguousarray(arr).reshape(-1)
    tensor = torch.as_tensor(flat.view(np.int8)).share_memory_()
    # Undo the byte reinterpretation so callers get back arr's dtype and shape.
    return tensor.numpy().view(arr.dtype).reshape(arr.shape)


ForkingPickler.register(np.ndarray, reduce_ndarray)

### Remote roundtrip (assuming memory is already shared)

In [8]:
for i, arr in enumerate(arrays):
    send_queue = mp.Queue()
    recv_queue = mp.Queue()
    p = create_process(send_queue, recv_queue)
    # One-off copy into shared memory outside the timed region; afterwards arr is a
    # view of a shared tensor.
    arr = share_memory(arr)
    try:
        mean, std = benchmark(lambda: remote_roundtrip(send_queue, recv_queue, arr))
        print(f'Arr{i+1}: {1000*mean:.0f}ms +- {1000*std:.0f}')
    finally:
        # Always stop and reap the worker, even if a roundtrip fails; otherwise it
        # keeps blocking on get() and the child process is left running.
        join_process(send_queue, p)

Arr1: 4ms +- 8
Arr2: 3ms +- 5
Arr3: 3ms +- 6
Arr4: 2ms +- 3
Arr5: 2ms +- 3


### Remote roundtrip (sharing done during first serialization)
-> The time is now roughly halved, because the array only needs to be copied to shared memory during the first serialization; the trip back is free.

In [9]:
# Identical measurement to benchmark_remote() (fresh queues and echo worker per array).
# Only the registered np.ndarray reducer has changed, so reuse it instead of
# duplicating its loop.
benchmark_remote()

Arr1: 3ms +- 4
Arr2: 8ms +- 2
Arr3: 55ms +- 6
Arr4: 256ms +- 23
Arr5: 411ms +- 39
