In [1]:
import os
from collections import defaultdict
from functools import reduce
from operator import add, mul

import numpy as np

In [2]:
# Seeded generator so the input matrices (and every value printed below) are
# reproducible on Restart & Run All. A is 300x20, B is 20x600, so A @ B is 300x600.
rng = np.random.default_rng(42)
A = rng.random((300, 20))
B = rng.random((20, 600))

# Baseline: NumPy matrix product

In [3]:
%%time
np.dot(A, B)

CPU times: user 0 ns, sys: 1.96 ms, total: 1.96 ms
Wall time: 723 µs


array([[6.16930888, 5.70928784, 4.40781217, ..., 5.4728335 , 4.40986864,
        4.60502551],
       [6.04394782, 5.76211914, 4.293568  , ..., 6.31096073, 3.92713956,
        5.62918246],
       [7.27500951, 6.05274759, 5.32123079, ..., 5.79766562, 5.60620197,
        5.26196859],
       ...,
       [5.16390022, 4.35032813, 3.36051932, ..., 4.22210561, 3.18490447,
        3.68696329],
       [6.3738467 , 5.87434843, 4.30928203, ..., 5.31106715, 4.29729707,
        5.2801052 ],
       [6.69297199, 5.52734044, 4.55181578, ..., 6.13393185, 4.63787979,
        5.57002845]])

# MapReduce

In [4]:
def merge_defaultdicts(d1, d2):
    """Append every list stored in ``d2`` onto the matching entry of ``d1``.

    ``d1`` is modified in place and also returned. It must create missing
    keys on access (e.g. ``defaultdict(list)``); otherwise a key present only
    in ``d2`` raises ``KeyError``.
    """
    for key in d2:
        d1[key].extend(d2[key])
    return d1

In [5]:
def matrix_mapper(matrix, r, flag):
    """Emit MapReduce key/value pairs for one operand of the product A @ B.

    Output cell (i, k) needs row i of A and column k of B, so each element is
    replicated to every output key it contributes to.

    Parameters
    ----------
    matrix : 2-D array-like
        The operand to map.
    r : int
        Replication count: the number of output columns when ``flag == 0``,
        or the number of output rows when ``flag == 1``.
    flag : {0, 1}
        0 if ``matrix`` is the left operand A: element (i, j) is sent to every
        key (i, k) with k < r. 1 if it is the right operand B: element (j, k)
        is sent to every key (i, k) with i < r.

    Returns
    -------
    defaultdict(list)
        Maps (i, k) to a list of ``(flag, j, value)`` tuples, appended in
        increasing j order; ``value`` is a Python scalar.

    Raises
    ------
    ValueError
        If ``flag`` is neither 0 nor 1.
    """
    if flag not in (0, 1):
        raise ValueError(f"flag must be 0 (left matrix) or 1 (right matrix), got {flag!r}")

    matrix = np.asarray(matrix)
    rows, cols = matrix.shape
    results = defaultdict(list)
    # np.ndindex walks the indices in logical row-major order regardless of the
    # array's memory layout. np.nditer's default order='K' follows memory order
    # and reverses negative-stride axes, which would append values in a
    # layout-dependent j order.
    for row, col in np.ndindex(rows, cols):
        value = matrix.item(row, col)  # Python scalar, like the 0-d .item() it replaces
        if flag == 0:
            for k in range(r):
                results[(row, k)].append((flag, col, value))
        else:
            for i in range(r):
                results[(i, col)].append((flag, row, value))
    return results

In [6]:
def mapper(matrix_A, matrix_B):
    """Map phase of the MapReduce product ``matrix_A @ matrix_B``.

    Returns a dict mapping every output cell (i, k) to the ``(flag, j, value)``
    tuples taken from row i of ``matrix_A`` (flag 0) and column k of
    ``matrix_B`` (flag 1).

    Raises
    ------
    ValueError
        If ``matrix_A.shape[1] != matrix_B.shape[0]``. Without this check the
        map phase runs to completion and the reduce phase returns a
        meaningless product instead of an error.
    """
    if matrix_A.shape[1] != matrix_B.shape[0]:
        raise ValueError(
            f"inner dimensions differ: matrix_A is {matrix_A.shape}, "
            f"matrix_B is {matrix_B.shape}"
        )
    left = matrix_mapper(matrix_A, matrix_B.shape[1], 0)   # replicate A across output columns
    right = matrix_mapper(matrix_B, matrix_A.shape[0], 1)  # replicate B across output rows
    return merge_defaultdicts(left, right)

In [7]:
def reducer(record):
    """Reduce phase: compute one cell of C = A @ B from its mapped values.

    Parameters
    ----------
    record : tuple
        ``(idx, elements)`` as produced by ``mapper(...).items()``. ``idx`` is
        the output cell ``(i, k)``; ``elements`` is a list of
        ``(flag, j, value)`` tuples where flag 0 marks A[i, j] and flag 1
        marks B[j, k].

    Returns
    -------
    tuple
        ``(idx, sum_j A[i, j] * B[j, k])``. The sum is 0 when no A and B
        entries share an index j (e.g. an empty element list).
    """
    idx, elements = record
    # Index B's contributions by j so each A[i, j] is multiplied by the
    # matching B[j, k] explicitly, rather than assuming the A and B entries
    # arrive in the same j order.
    b_by_j = {j: value for flag, j, value in elements if flag == 1}
    products = (
        value * b_by_j[j]  # Multiplication
        for flag, j, value in elements
        if flag == 0 and j in b_by_j
    )
    return idx, sum(products)  # Sum; start value 0 makes the empty case well-defined

In [8]:
%%time
list(map(reducer, mapper(A, B).items()))

CPU times: user 6.15 s, sys: 524 ms, total: 6.68 s
Wall time: 6.33 s


[((0, 0), 6.1693088793088675),
 ((0, 1), 5.709287837242178),
 ((0, 2), 4.4078121654729205),
 ((0, 3), 4.0101206653876025),
 ((0, 4), 3.8225004736072012),
 ((0, 5), 4.380290233516659),
 ((0, 6), 6.130258323226468),
 ((0, 7), 5.706815003100942),
 ((0, 8), 6.058521574558042),
 ((0, 9), 5.4964875608472115),
 ((0, 10), 5.585623522681542),
 ((0, 11), 5.174443294214784),
 ((0, 12), 5.319704778708427),
 ((0, 13), 4.984366253947198),
 ((0, 14), 6.121030197757768),
 ((0, 15), 4.189556878955895),
 ((0, 16), 4.6114070628025345),
 ((0, 17), 4.546543329628882),
 ((0, 18), 4.6123718581490865),
 ((0, 19), 4.162287355958404),
 ((0, 20), 5.9666088944513245),
 ((0, 21), 5.999503218365978),
 ((0, 22), 4.522104173408516),
 ((0, 23), 3.5713729154168394),
 ((0, 24), 5.707838713853954),
 ((0, 25), 5.099466215693016),
 ((0, 26), 5.21017501209063),
 ((0, 27), 6.850462613209186),
 ((0, 28), 4.866546754922108),
 ((0, 29), 4.616792102074751),
 ((0, 30), 5.197360337213703),
 ((0, 31), 5.150853891448268),
 ((0, 32),

# Multiprocessing MapReduce (reduce phase run in a process pool)

In [9]:
from multiprocessing import Pool

In [10]:
# One worker per available CPU; os.cpu_count() can return None, hence the fallback.
num_processes = os.cpu_count() or 1

# Terminate a pool left over from an earlier run of this cell. Otherwise every
# re-run would leak another set of worker processes.
_previous_pool = globals().get("pool")
if _previous_pool is not None:
    _previous_pool.terminate()

# NOTE(review): workers must be able to unpickle `reducer`. That works with the
# 'fork' start method, but under 'spawn' (the Windows/macOS default) functions
# defined in a notebook usually cannot be found by the workers — consider
# moving mapper/reducer into a .py module if this fails there.
pool = Pool(num_processes)

In [11]:
%%time 
pool.map(reducer, mapper(A, B).items())

CPU times: user 6.56 s, sys: 364 ms, total: 6.92 s
Wall time: 6.9 s


[((0, 0), 6.1693088793088675),
 ((0, 1), 5.709287837242178),
 ((0, 2), 4.4078121654729205),
 ((0, 3), 4.0101206653876025),
 ((0, 4), 3.8225004736072012),
 ((0, 5), 4.380290233516659),
 ((0, 6), 6.130258323226468),
 ((0, 7), 5.706815003100942),
 ((0, 8), 6.058521574558042),
 ((0, 9), 5.4964875608472115),
 ((0, 10), 5.585623522681542),
 ((0, 11), 5.174443294214784),
 ((0, 12), 5.319704778708427),
 ((0, 13), 4.984366253947198),
 ((0, 14), 6.121030197757768),
 ((0, 15), 4.189556878955895),
 ((0, 16), 4.6114070628025345),
 ((0, 17), 4.546543329628882),
 ((0, 18), 4.6123718581490865),
 ((0, 19), 4.162287355958404),
 ((0, 20), 5.9666088944513245),
 ((0, 21), 5.999503218365978),
 ((0, 22), 4.522104173408516),
 ((0, 23), 3.5713729154168394),
 ((0, 24), 5.707838713853954),
 ((0, 25), 5.099466215693016),
 ((0, 26), 5.21017501209063),
 ((0, 27), 6.850462613209186),
 ((0, 28), 4.866546754922108),
 ((0, 29), 4.616792102074751),
 ((0, 30), 5.197360337213703),
 ((0, 31), 5.150853891448268),
 ((0, 32),