<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Slice-specified-nodes-in-dimspec" data-toc-modified-id="Slice-specified-nodes-in-dimspec-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Slice specified nodes in dimspec</a></span></li><li><span><a href="#Test-parallelism" data-toc-modified-id="Test-parallelism-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Test parallelism</a></span><ul class="toc-item"><li><ul class="toc-item"><li><span><a href="#Example-task" data-toc-modified-id="Example-task-2.0.1"><span class="toc-item-num">2.0.1&nbsp;&nbsp;</span>Example task</a></span></li></ul></li><li><span><a href="#Use-ray" data-toc-modified-id="Use-ray-2.1"><span class="toc-item-num">2.1&nbsp;&nbsp;</span>Use ray</a></span></li><li><span><a href="#Simple-invocation" data-toc-modified-id="Simple-invocation-2.2"><span class="toc-item-num">2.2&nbsp;&nbsp;</span>Simple invocation</a></span></li></ul></li></ul></div>

In [56]:
import ray
import pyrofiler as pyrof
import numpy as np

# Slice specified nodes in dimspec

In [76]:
def _none_slice():
    return slice(None)

def _get_idx(x, idxs, slice_idx, shapes=None):
    if shapes is None:
        shapes = [2]*len(idxs)
    point = np.unravel_index(slice_idx, shapes)
    get_point = {i:p for i,p in zip(idxs, point)}
    if x in idxs:
        p = get_point[x]
        return slice(p,p+1)
    else:
        return _none_slice()

def _slices_for_idxs(idxs, *args, shapes=None, slice_idx=0):
    """Return array of slices along idxs"""
    slices = []
    for indexes in args:
        _slice = [_get_idx(x, idxs, slice_idx, shapes) for x in indexes ]
        slices.append(tuple(_slice))
    return slices
        

In [77]:
dims1 = [1,3,4 ]
dims2 = [2,4,3, 5]
contract = [dims1, dims2]

slice_among = [4, 3]
shapes = [2, 3]

test_slices = [
    _slices_for_idxs(slice_among, *contract, shapes=shapes, slice_idx=i)
    for i in range(4)
    ]
[print(x) for x in test_slices]


[(slice(None, None, None), slice(0, 1, None), slice(0, 1, None)), (slice(None, None, None), slice(0, 1, None), slice(0, 1, None), slice(None, None, None))]
[(slice(None, None, None), slice(1, 2, None), slice(0, 1, None)), (slice(None, None, None), slice(0, 1, None), slice(1, 2, None), slice(None, None, None))]
[(slice(None, None, None), slice(2, 3, None), slice(0, 1, None)), (slice(None, None, None), slice(0, 1, None), slice(2, 3, None), slice(None, None, None))]
[(slice(None, None, None), slice(0, 1, None), slice(1, 2, None)), (slice(None, None, None), slice(1, 2, None), slice(0, 1, None), slice(None, None, None))]


[None, None, None, None]

# Test parallelism
### Example task

In [66]:
def get_example_task():
    A = 14
    B, C = 7, 5
    shape1 = [2]*(A+B)
    shape2 = [2]*(A+C)
    T1 = np.random.randn(*shape1)
    T2 = np.random.randn(*shape2)
    common = list(range(A))
    idxs1 = common + list(range(A, A+B))
    idxs2 = common + list(range(A+B, A+B+C))
    return (T1, idxs1), (T2, idxs2)

x, y = get_example_task()
x[1], y[1]

([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 21, 22, 23, 24, 25])

## Use ray

In [62]:
ray.init()

Exception: Perhaps you called ray.init twice by accident? This error can be suppressed by passing in 'ignore_reinit_error=True' or by calling 'ray.shutdown()' prior to 'ray.init()'.

## Simple invocation

In [193]:

#@ray.remote
def contract(A, B):
    a, idxa = A
    b, idxb = B
    contract_idx = set(idxa) & set(idxb)
    result_idx = set(idxa + idxb)
    C = np.einsum(a,idxa, b,idxb, result_idx)
    return C

with pyrof.timing('contract'):
    C = contract(x, y)


contract : 0.4512217044830322


In [205]:
contract_idx = set(x[1]) & set(y[1])
result_idx = set(x[1] + y[1])

def sliced_contract(x, y, idxs, num):
    slices = _slices_for_idxs(idxs, x[1], y[1], slice_idx=num)
    a = x[0][slices[0]]
    b = y[0][slices[1]]
    with pyrof.timing('contract'):
        C = contract((a, x[1]), (b, y[1]))
    return C


def target_slice(result_idx, idxs, num):
    slices = _slices_for_idxs(idxs, result_idx, slice_idx=num)
    return slices

C = contract(x,y)
target_shape = C.shape
C0 = sliced_contract(x, y, [17], 0)
C1 = sliced_contract(x, y, [17], 1)

C_par = np.empty(target_shape)
s0 = target_slice(result_idx, [17], 0)
s1 = target_slice(result_idx, [17], 1)
C_par[s0[0]] = C0
C_par[s1[0]] = C1

assert np.array_equal(C, C_par)


contract : 0.6614696979522705
contract : 1.6132872104644775


In [200]:
contract_idx = set(x[1]) & set(y[1])
slices = _slices_for_idxs([1, 15], x[1], y[1])

a = x[0][slices[0]]
b = y[0][slices[1]]
print(a.shape, b.shape)

with pyrof.timing('contract'):
    C = contract((a, x[1]), (b, y[1]))
C.shape

(2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2) (2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)
contract : 0.22393345832824707


(2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)

In [64]:
x = np.zeros(10)
x_id = ray.put(x)
print(x)
f = increment.remote(x_id)
print(ray.get(f))
print(x)

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]


RayTaskError(ValueError): [36mray::__main__.increment()[39m (pid=32540, ip=130.202.136.154)
  File "python/ray/_raylet.pyx", line 452, in ray._raylet.execute_task
  File "<ipython-input-13-d161bfe5d285>", line 3, in increment
ValueError: assignment destination is read-only

ObjectID(ffffffffffffffffffffffff0100008003000000)


In [40]:
sl = [slice(0,1)]*7
xsl = x[0][tuple(sl)]
xsl.shape

(1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2)