Efficient implementation of Dask's `matmul` from https://github.com/dask/dask/pull/7000.
This is included here so we don't have to rebuild Dask.

In [1]:
import numpy as np
from dask.array.core import (
    asanyarray,
    blockwise,
    is_scalar_for_elemwise,
)

In [3]:
def result_type(*args):
    """Return the NumPy result type for a mix of scalars and arrays.

    Every argument for which ``is_scalar_for_elemwise`` is true is handed to
    ``np.result_type`` unchanged; every other argument is represented by its
    ``dtype`` attribute instead.
    """
    promoted = []
    for arg in args:
        if is_scalar_for_elemwise(arg):
            promoted.append(arg)
        else:
            promoted.append(arg.dtype)
    return np.result_type(*promoted)

def _matmul(a, b):
    """Multiply two blocks and append a length-1 axis to the product.

    ``np.matmul`` already performs the contraction, but the ``blockwise``
    call that uses this function expects every index back, so one dummy
    dimension is added at the end of the result.
    """
    append_axis = (Ellipsis, np.newaxis)
    return np.matmul(a, b)[append_axis]


def matmul(a, b):
    """Matrix product of two arrays, built from ``blockwise`` plus a reduction.

    Both operands are converted with ``asanyarray``.  A 1-D ``a`` is
    temporarily promoted to a row matrix and a 1-D ``b`` to a column matrix;
    the axis added by the promotion is removed again from the result.  If the
    operands still differ in rank, the lower-rank one is given leading
    length-1 axes so that both have the same number of dimensions.

    The block products are computed by ``_matmul`` without contracting inside
    ``blockwise``; the partial products along the contraction axis are summed
    afterwards (see https://github.com/dask/dask/issues/6874).

    Parameters
    ----------
    a, b : array-like
        Operands, anything accepted by ``asanyarray``.  Neither may be 0-d.

    Returns
    -------
    out
        The product array, with the axes added for 1-D operands removed.

    Raises
    ------
    ValueError
        If either operand is 0-dimensional, or if the last dimension of ``a``
        and the second-to-last dimension of ``b`` (after the 1-D promotion
        described above) have different, known sizes.
    """
    a = asanyarray(a)
    b = asanyarray(b)

    if a.ndim == 0 or b.ndim == 0:
        raise ValueError("`matmul` does not support scalars.")

    # Promote 1-D operands to matrices; remember it so the extra axis can be
    # dropped from the result at the end.
    a_is_1d = False
    if a.ndim == 1:
        a_is_1d = True
        a = a[np.newaxis, :]

    b_is_1d = False
    if b.ndim == 1:
        b_is_1d = True
        b = b[:, np.newaxis]

    # np.matmul requires the contracted dimensions to agree.  Check it up
    # front so a mismatch fails here with a clear message instead of
    # surfacing later from blockwise or from a chunk-level np.matmul call.
    # A NaN size compares unequal to everything (including itself), so the
    # check is skipped in that case.
    # NOTE(review): assumes NaN is how an unknown dimension size is
    # represented here -- confirm against the dask version in use.
    a_contract = a.shape[-1]
    b_contract = b.shape[-2]
    sizes_known = not (np.isnan(a_contract) or np.isnan(b_contract))
    if sizes_known and a_contract != b_contract:
        raise ValueError(
            "`matmul` contraction dimension mismatch: last dimension of `a` "
            "has size {} but second-to-last dimension of `b` (or its only "
            "dimension, if 1-D) has size {}.".format(a_contract, b_contract)
        )

    # Give the lower-rank operand leading length-1 axes so both operands have
    # the same number of dimensions.
    if a.ndim < b.ndim:
        a = a[(b.ndim - a.ndim) * (np.newaxis,)]
    elif a.ndim > b.ndim:
        b = b[(a.ndim - b.ndim) * (np.newaxis,)]

    # out_ind includes all dimensions to prevent contraction
    # in the blockwise below
    out_ind = tuple(range(a.ndim + 1))
    # lhs_ind includes `a`/LHS dimensions
    lhs_ind = tuple(range(a.ndim))
    # on `b`/RHS everything above 2nd dimension, is the same
    # as `a`, -2 dimension is "contracted" with the last dimension
    # of `a`, last dimension of `b` is `b` specific
    rhs_ind = tuple(range(a.ndim - 2)) + (lhs_ind[-1], a.ndim)

    out = blockwise(
        _matmul,
        out_ind,
        a,
        lhs_ind,
        b,
        rhs_ind,
        # Chunk length along the contraction index is declared as 1.
        adjust_chunks={lhs_ind[-1]: 1},
        dtype=result_type(a, b),
        concatenate=False,
    )

    # Because contraction + concatenate in blockwise leads to high
    # memory footprints, we want to avoid them. Instead we will perform
    # blockwise (without contraction) followed by reduction. More about
    # this issue: https://github.com/dask/dask/issues/6874

    # When we perform reduction, we need to worry about the last 2 dimensions
    # which hold the matrices, some care is required to handle chunking in
    # that space.
    # NOTE: `min` of two chunk tuples compares them lexicographically; the
    # largest chunk of the selected tuple is then compared with the full
    # length of the contraction axis.
    contraction_dimension_is_chunked = (
        max(min(a.chunks[-1], b.chunks[-2])) < a.shape[-1]
    )
    b_last_dim_max_chunk = max(b.chunks[-1])
    if contraction_dimension_is_chunked or b_last_dim_max_chunk < b.shape[-1]:
        if b_last_dim_max_chunk > 1:
            # This is the case when both contraction and last dimension axes
            # are chunked
            out = out.reshape(out.shape[:-1] + (1, -1))
            out = out.sum(axis=-3)
            out = out.reshape(out.shape[:-2] + (b.shape[-1],))
        else:
            # Contraction axis is chunked
            out = out.sum(axis=-2)
    else:
        # Neither contraction nor last dimension axes are chunked, we
        # remove the dummy dimension without reduction
        out = out.reshape(out.shape[:-2] + (b.shape[-1],))

    # Drop the axes that were added when promoting 1-D operands.
    if a_is_1d:
        out = out[..., 0, :]
    if b_is_1d:
        out = out[..., 0]

    return out