In [1]:
# Re-import edited modules (e.g. the project package qtmpy) automatically
# before each cell runs, so library changes are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [2]:
# ipyparallel drives the MPI engines used below; print its version so the
# recorded outputs can be tied to a specific release.
import ipyparallel as ipp
print(ipp.version_info)

(8, 6, 1)


In [3]:
# Number of engines (MPI processes) to launch.
nproc = 4
# Engines are started with the MPI launcher. shutdown_atexit=False means the
# cluster is not shut down automatically when this kernel exits, so it has to
# be stopped explicitly (see the shutdown cell near the end of the notebook).
cluster = ipp.Cluster(engines='mpi', n=nproc, shutdown_atexit=False)
print(cluster)
# Start controller + engines and connect a client; activate=True registers
# this client for the %px / %autopx magics used further down.
client = cluster.start_and_connect_sync(activate=True)
# DirectView over all engines. NOTE(review): not used in the visible cells.
view = client[:]
# Displayed result: the ids of the connected engines.
client.ids

<Cluster(cluster_id='1690102733-v0s9', profile='default')>
Starting 4 engines with <class 'ipyparallel.cluster.launcher.MPIEngineSetLauncher'>
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:05<00:00,  1.49s/engine]


[0, 1, 2, 3]

In [4]:
# List the clusters known for this profile and whether they are running.
!ipcluster list

PROFILE          CLUSTER ID                       RUNNING ENGINES LAUNCHER
default          1690102733-v0s9                  True          4 MPI


In [5]:
%autopx --block --group-outputs=engines
# ^ From here on every cell is executed on all engines (blocking), with the
#   output grouped per engine, until %autopx is toggled off again.
# NOTE(review): the closing toggle later spells the option `engine`, not
#   `engines`; confirm which value ipyparallel expects (see `%px?`).

%autopx enabled


In [6]:
import numpy as np
from mpi4py import MPI
from time import sleep

from qtmpy.mpi.utils import scatter_range, scatter_len

# World communicator shared by the MPI engines; each engine records the
# communicator size and its own rank for use in the helpers below.
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

def print_msg(msg: str):
    """Print ``msg`` tagged with this process's ``rank/size`` (module globals)."""
    line = "{}/{}: {}".format(rank, size, msg)
    print(line)

def print_seq(msg, delay: float = 0.1):
    """Print ``msg`` after a blank line, staggered in time by MPI rank.

    Rank ``r`` sleeps ``r * delay`` seconds, prints, then sleeps
    ``(size - r) * delay`` seconds, so every rank spends ``size * delay``
    seconds in the call and lower ranks print first. This is only a timing
    heuristic for readable output, not a synchronization guarantee.
    Reads the module-level ``rank`` and ``size``.

    Parameters
    ----------
    msg : object
        Anything ``print`` accepts (printed via ``str``).
    delay : float, optional
        Stagger per rank in seconds (default 0.1).
    """
    sleep(rank * delay)
    # Flush right away: with buffered stdout the text could otherwise be
    # emitted after later ranks' output, defeating the stagger.
    print("", msg, sep="\n", flush=True)
    sleep((size - rank) * delay)

# Global problem size: nx points along the x-axis and ns sticks.
nx, ns = 5, 9

# Every rank builds the same global array...
glob_shape = (nx, ns)
glob_array = np.arange(nx * ns).reshape(glob_shape)

# ...and keeps only the block of sticks (columns) that scatter_range assigns to it.
loc_range = scatter_range(range(ns), size, rank)
loc_start, loc_stop = loc_range.start, loc_range.stop
loc_shape = (nx, loc_stop - loc_start)

# Contiguous float64 copy of this rank's columns.
loc_array = np.array(glob_array[:, loc_start:loc_stop], dtype=np.float64)
print_seq(loc_array)

[stdout:0] 
[[ 0.  1.  2.]
 [ 9. 10. 11.]
 [18. 19. 20.]
 [27. 28. 29.]
 [36. 37. 38.]]


[stdout:1] 
[[ 3.  4.]
 [12. 13.]
 [21. 22.]
 [30. 31.]
 [39. 40.]]


[stdout:2] 
[[ 5.  6.]
 [14. 15.]
 [23. 24.]
 [32. 33.]
 [41. 42.]]


[stdout:3] 
[[ 7.  8.]
 [16. 17.]
 [25. 26.]
 [34. 35.]
 [43. 44.]]


In [16]:
def mpi_transpose(comm, inp, out):
    """Distributed transpose of a 2-D array whose columns are scattered over ``comm``.

    The global array has shape ``(nrows, ncols)`` with ``nrows = inp.shape[0]``
    and ``ncols = out.shape[0]``. On entry each rank holds a full-height block
    of columns, ``inp`` of shape ``(nrows, scatter_len(ncols, size, rank))``.
    On exit ``out`` holds this rank's block of rows, transposed, with shape
    ``(ncols, scatter_len(nrows, size, rank))``. Blocks are assigned to ranks
    in rank order according to ``scatter_len``.

    Parameters
    ----------
    comm : mpi4py communicator
        Communicator over which rows/columns are distributed.
    inp : numpy.ndarray
        Local column block. It need not be contiguous; a contiguous copy is
        sent when required.
    out : numpy.ndarray
        Pre-allocated result, filled in place. It need not be contiguous; data
        is received in ``inp``'s dtype and cast into ``out`` (same-kind casting).

    Raises
    ------
    ValueError
        If ``inp`` or ``out`` does not have the local shape implied by the
        global shape. The checks run before any communication. If only some
        ranks fail them, the remaining ranks will block in ``Alltoallv``.
    """
    nrows, ncols = inp.shape[0], out.shape[0]
    size, rank = comm.Get_size(), comm.Get_rank()

    # Block lengths of every rank along each global axis, and this rank's share.
    row_lens = np.asarray(scatter_len(nrows, size))
    col_lens = np.asarray(scatter_len(ncols, size))
    loc_nrows = scatter_len(nrows, size, rank)
    loc_ncols = scatter_len(ncols, size, rank)

    if inp.shape != (nrows, loc_ncols):
        raise ValueError(
            f"inp has shape {inp.shape}, expected {(nrows, loc_ncols)}"
        )
    if out.shape != (ncols, loc_nrows):
        raise ValueError(
            f"out has shape {out.shape}, expected {(ncols, loc_nrows)}"
        )

    # Rank j needs row_lens[j] rows of our columns. In a C-ordered buffer those
    # rows are a consecutive run of row_lens[j] * loc_ncols elements.
    sendbuf = np.ascontiguousarray(inp)
    sendcount = row_lens * loc_ncols
    # From rank j we get our loc_nrows rows of its col_lens[j] columns.
    recvcount = col_lens * loc_nrows
    # Use a dedicated receive buffer with the send dtype, so the MPI datatypes
    # on both sides agree. Receiving straight into `out` would make the
    # unpacking below read from memory it is writing to.
    recvbuf = np.empty(int(recvcount.sum()), dtype=sendbuf.dtype)

    comm.Alltoallv((sendbuf, sendcount), (recvbuf, recvcount))

    # Chunk j is a C-ordered (loc_nrows, col_lens[j]) block. Transposing and
    # stacking the blocks yields the (ncols, loc_nrows) result. Explicit shapes
    # (rather than -1) keep this valid when loc_nrows == 0.
    chunks = [
        chunk.reshape(loc_nrows, ncols_j).T
        for chunk, ncols_j in zip(
            np.split(recvbuf, np.cumsum(recvcount)[:-1]), col_lens
        )
    ]
    np.concatenate(chunks, axis=0, out=out)

# Transpose the column-distributed loc_array. Afterwards each rank holds its
# block of rows of glob_array, transposed, with shape (ns, local row count).
inp = loc_array
out = np.empty((glob_shape[1], scatter_len(glob_shape[0], size, rank)))

mpi_transpose(comm, inp, out)

# Every rank has the full glob_array, so the result can be verified locally.
row_range = scatter_range(range(glob_shape[0]), size, rank)
assert np.array_equal(out, glob_array[row_range.start:row_range.stop].T), \
    "mpi_transpose result does not match glob_array"
print_seq(out)

[stdout:0] 
[[ 0.  9.]
 [ 1. 10.]
 [ 2. 11.]
 [ 3. 12.]
 [ 4. 13.]
 [ 5. 14.]
 [ 6. 15.]
 [ 7. 16.]
 [ 8. 17.]]


[stdout:1] 
[[18.]
 [19.]
 [20.]
 [21.]
 [22.]
 [23.]
 [24.]
 [25.]
 [26.]]


[stdout:3] 
[[36.]
 [37.]
 [38.]
 [39.]
 [40.]
 [41.]
 [42.]
 [43.]
 [44.]]


[stdout:2] 
[[27.]
 [28.]
 [29.]
 [30.]
 [31.]
 [32.]
 [33.]
 [34.]
 [35.]]


In [47]:
# Manual forward transpose: exchange blocks with Alltoallv, then reorder them.
sendbuf = loc_array
# Rank j receives scatter_len(nx, size)[j] rows of our local sticks.
sendcount = scatter_len(nx, size) * scatter_len(ns, size, rank)
# From rank j we receive our local x rows of its scatter_len(ns, size)[j] sticks.
recvcount = scatter_len(ns, size) * scatter_len(nx, size, rank)
recvbuf = np.empty(np.sum(recvcount), dtype=sendbuf.dtype)

comm.Alltoallv((sendbuf, sendcount), (recvbuf, recvcount))

# Chunk j is a C-ordered (loc_nx, ns_j) block; transpose each one and stack
# them. Explicit shapes (instead of -1) keep this valid on ranks with no x
# rows, and building a new array avoids writing into the buffer being read.
loc_nx = scatter_len(nx, size, rank)
chunks = [
    chunk.reshape(loc_nx, ns_j).T
    for chunk, ns_j in zip(
        np.split(recvbuf, np.cumsum(recvcount)[:-1]), scatter_len(ns, size)
    )
]
recvbuf = np.concatenate(chunks, axis=0)
print_seq(recvbuf)

[stdout:0] 
[[ 0.  9.]
 [ 1. 10.]
 [ 2. 11.]
 [ 3. 12.]
 [ 4. 13.]
 [ 5. 14.]
 [ 6. 15.]
 [ 7. 16.]
 [ 8. 17.]]


[stdout:1] 
[[18.]
 [19.]
 [20.]
 [21.]
 [22.]
 [23.]
 [24.]
 [25.]
 [26.]]


[stdout:3] 
[[36.]
 [37.]
 [38.]
 [39.]
 [40.]
 [41.]
 [42.]
 [43.]
 [44.]]


[stdout:2] 
[[27.]
 [28.]
 [29.]
 [30.]
 [31.]
 [32.]
 [33.]
 [34.]
 [35.]]


In [48]:
# Manual inverse transpose: send the transposed blocks back with the count
# arrays swapped, and rebuild the original column-distributed layout.
recvbuf = recvbuf.ravel()
sendbuf = np.empty(np.sum(sendcount), dtype=recvbuf.dtype)
comm.Alltoallv((recvbuf, recvcount), (sendbuf, sendcount))

# Chunk j is a C-ordered (loc_ns, nx_j) block from rank j. Explicit shapes keep
# this valid on ranks that own no sticks, and a new array avoids aliasing.
loc_ns = scatter_len(ns, size, rank)
chunks = [
    chunk.reshape(loc_ns, nx_j).T
    for chunk, nx_j in zip(
        np.split(sendbuf, np.cumsum(sendcount)[:-1]), scatter_len(nx, size)
    )
]
sendbuf = np.concatenate(chunks, axis=0)
# The round trip must reproduce this rank's original block exactly.
assert np.array_equal(sendbuf, loc_array), \
    "round-trip transpose did not recover loc_array"
print_seq(sendbuf)

[stdout:0] 
[[ 0.  1.  2.]
 [ 9. 10. 11.]
 [18. 19. 20.]
 [27. 28. 29.]
 [36. 37. 38.]]


[stdout:1] 
[[ 3.  4.]
 [12. 13.]
 [21. 22.]
 [30. 31.]
 [39. 40.]]


[stdout:2] 
[[ 5.  6.]
 [14. 15.]
 [23. 24.]
 [32. 33.]
 [41. 42.]]


[stdout:3] 
[[ 7.  8.]
 [16. 17.]
 [25. 26.]
 [34. 35.]
 [43. 44.]]


In [None]:
# NOTE(review): superseded, commented-out draft of the distributed transpose.
# It calls `distribute_len`, which is not imported in this notebook (presumably
# an older name for `scatter_len` — TODO confirm). Consider deleting this cell
# once it is no longer needed for reference.
# sendbuf = loc_array

# # recvbuf = np.empty(
# #     (distribute_len(glob_shape[0], size, rank), glob_shape[1])
# # )

# recvbuf = np.empty(
#     (glob_shape[1], distribute_len(glob_shape[0], size, rank))
# )

# send_counts = distribute_len(glob_shape[0], size) * distribute_len(glob_shape[1], size, rank)
# recv_counts = distribute_len(glob_shape[0], size, rank) * distribute_len(glob_shape[1], size)

# recvbuf = np.empty(np.sum(recv_counts))

# comm.Alltoallv((sendbuf, send_counts), (recvbuf, recv_counts))
# recvbuf = recvbuf.reshape(-1, distribute_len(glob_shape[0], size, rank))
# # print(recvbuf)
# chunksizes = distribute_len(recvbuf.shape[0], size)
# # print(chunksizes, np.cumsum(chunksizes))
# chunks = tuple(chunk.ravel().reshape(-1, chunksizes[ichunk]) for ichunk, chunk in
#           enumerate(np.split(recvbuf, np.cumsum(chunksizes), axis=0)[:-1])
# )
# # print(chunks)
# out_arr = np.concatenate(chunks, axis=1)
# print_seq(out_arr)

In [17]:
%autopx --block --group-outputs=engine
# ^ Toggle %autopx off: subsequent cells run in the local kernel again. Keep
#   the magic as the first line so the toggle is recognised while autopx is on.
# NOTE(review): the enabling cell used `--group-outputs=engines`; confirm
#   which spelling ipyparallel expects.

%autopx disabled


In [18]:
# Shut down the whole cluster started at the top of the notebook, engines
# included. stop_controller_sync() stops only the controller. The cluster
# was created with shutdown_atexit=False, so nothing stops it automatically
# when this kernel exits.
cluster.stop_cluster_sync()
# Release the client's connections to the (now stopped) controller.
client.close()

Stopping controller
Controller stopped: {'exit_code': 0, 'pid': 17861, 'identifier': 'ipcontroller-1690102733-v0s9-17846'}


In [20]:
# List clusters again to confirm this notebook's cluster is no longer running.
!ipcluster list

PROFILE          CLUSTER ID                       RUNNING ENGINES LAUNCHER
default          1690102733-v0s9                  False         4 MPI
default          1690104384-b2sd                  False         4 MPI
engine set stopped 1690102734: {'exit_code': 0, 'pid': 17890, 'identifier': 'ipengine-1690102733-v0s9-1690102734-17846'}
