In [373]:
import numpy as np
from pprint import PrettyPrinter

class BlockDiagonalMatrix:
    def __init__(self, blocks: np.ndarray) -> None:
        """
        Initialize the Block Diagonal Matrix.
        
        :param blocks: 2D NumPy array where each element is a NumPy array representing 
                       the diagonal elements of a diagonal block.
        """
        self.blocks = blocks
        self.n = self.blocks.shape[0]  # Number of blocks along one dimension
        self.d = self.blocks.shape[2]  # Size of each block (assuming non-empty blocks)
    
    def numpy(self) -> np.ndarray:
        """
        Convert the block diagonal matrix into a NumPy array representation.
        This method constructs the full matrix from the block diagonal representation and 
        returns it as a NumPy array. 

        :return: A NumPy array representing the full sparse matrix.
        """
        # Initialize a full matrix of zeros with the appropriate size
        full_matrix = np.zeros((self.n * self.d, self.n * self.d))

        # Fill in the diagonal blocks
        for i in range(self.n):
            for j in range(self.n):
                if np.any(self.blocks[i, j] != 0):  # Only process non-zero blocks
                    np.fill_diagonal(
                        full_matrix[i * self.d:(i + 1) * self.d, j * self.d:(j + 1) * self.d],
                        self.blocks[i, j]
                    )

        return full_matrix
    
    def __str__(self) -> str:
        """
        Constructs the full matrix from the block diagonal representation and
        returns its string representation. It fills in the diagonal blocks and converts the 
        full matrix into a string format.

        :return: A string representation of the full matrix.
        """
        return str(self.numpy())
    
    def _repr_pretty_(self, p: "PrettyPrinter", cycle: bool) -> None:
        """
        Provides a more readable representation of the matrix when used in Jupyter
        notebooks. It helps to visualize the block diagonal matrix in a user-friendly format.

        :param p: The PrettyPrinter object.
        :param cycle: Flag to indicate if there is a cycle in the object graph.
        """
        p.text(str(self) if not cycle else "...")


    def __add__(self, other: "BlockDiagonalMatrix") -> "BlockDiagonalMatrix":
        """        
        This method adds two block diagonal matrices element-wise. The result is a new
        BlockDiagonalMatrix where each block is the sum of the corresponding blocks in 
        the input matrices. It overloads the `+` operator.

        :param other: Another BlockDiagonalMatrix to add.
        :return: The sum as a new BlockDiagonalMatrix.
        """
        assert self.blocks.shape == other.blocks.shape, "Matrices must have the same dimensions."
        result_blocks = self.blocks + other.blocks
        return BlockDiagonalMatrix(result_blocks)

    def __sub__(self, other: "BlockDiagonalMatrix") -> "BlockDiagonalMatrix":
        """        
        This method subtracts two block diagonal matrices element-wise. The result is a new
        BlockDiagonalMatrix where each block is the sum of the corresponding blocks in 
        the input matrices. It overloads the `+` operator.

        :param other: Another BlockDiagonalMatrix to add.
        :return: The subtraction result as a new BlockDiagonalMatrix.
        """
        assert self.blocks.shape == other.blocks.shape, "Matrices must have the same dimensions."
        result_blocks = self.blocks - other.blocks
        return BlockDiagonalMatrix(result_blocks)
    
    def __mul__(self, other: "BlockDiagonalMatrix") -> "BlockDiagonalMatrix":
        """       
        This method performs dot product between corresponding blocks
        of the two block diagonal matrices. The result is a new BlockDiagonalMatrix where
        each block is the element-wise product of the corresponding blocks in the input matrices.
        It overloads the `*` operator.

        :param other: Another BlockDiagonalMatrix to multiply.
        :return: The product as a new BlockDiagonalMatrix.
        """
        assert self.n == other.n and self.d == other.d, "Matrices must have the same dimensions."
        result_blocks = self.blocks * other.blocks
        return BlockDiagonalMatrix(result_blocks)

    def __matmul__(self, other: "BlockDiagonalMatrix") -> "BlockDiagonalMatrix":
        """
        This method performs matrix multiplication by leveraging broadcasting and vectorized operations.
        It computes the result by multiplying corresponding blocks and summing over the common dimension.
        The resulting BlockDiagonalMatrix contains the product of the two input matrices in block diagonal form.
        It overloads the `@` operator.

        :param other: Another BlockDiagonalMatrix to multiply.
        :return: The product as a new BlockDiagonalMatrix.
        """
        assert self.n == other.n and self.d == other.d, "Matrices must have the same dimensions."

        # Expand dimensions to enable broadcasting
        left_blocks = self.blocks[:, :, np.newaxis, :]
        right_blocks = other.blocks[np.newaxis, :, :, :]

        # Element-wise multiplication and sum along the k-axis
        result_blocks = np.sum(left_blocks * right_blocks, axis=1)

        return BlockDiagonalMatrix(result_blocks)

    def inverse(self) -> "BlockDiagonalMatrix":
        """
        Efficiently compute the inverse of a Block Diagonal Matrix in divide and conquer method.
        In case of a singular matrix, it raises an exception.
        
        :return: The inverse as a new BlockDiagonalMatrix.
        """
        inverse_blocks = np.zeros(self.blocks.shape)
        
        for i in range(self.d):
            sub_arr = self.blocks[:, :, i]
            try:
                sub_arr_inv = np.linalg.inv(sub_arr)
            except np.linalg.LinAlgError as e:
                raise Exception("Matrix is singular and cannot be inverted.")
            inverse_blocks[:, :, i] = sub_arr_inv
        
        return BlockDiagonalMatrix(inverse_blocks)


In [399]:
# Example usage
blocks_A = np.array([
    [[1, 2], [3, 4], [1, 0]],
    [[1, 0], [5, 6], [1, 0]],
    [[0, 0], [1, 1], [7, 8]]
])
blocks_B = np.array([
    [[2, 3], [0, 2], [1, 0]],
    [[1, 0], [4, 5], [1, 0]],
    [[1,3], [1, 0], [6, 7]]
])

A = BlockDiagonalMatrix(blocks_A)
B = BlockDiagonalMatrix(blocks_B)

# Add matrices
print("Matrix addition:", end=" ")
print(((A + B).numpy() == (A.numpy() + B.numpy())).all())

# Subtract matrices
print("Matrix subtraction: ", end=" ")
print(((A - B).numpy() == (A.numpy() - B.numpy())).all())

# Dot product
print("Dot product: ", end=" ")
print(((A * B).numpy() == (A.numpy() * B.numpy())).all())

# Multiply matrices
print("Matrix multiplication: ", end=" ")
print(((A @ B).numpy() == (A.numpy() @ B.numpy())).all())

# Invert
print("Matrix inversion: ", end=" ")
print((A.inverse().numpy() == np.linalg.inv(A.numpy())).all())

Matrix addition: True
Matrix subtraction:  True
Dot product:  True
Matrix multiplication:  True
Matrix inversion:  True


In [355]:
X = BlockDiagonalMatrix(np.random.randint(0, 9, (15,15,11)))
Y = BlockDiagonalMatrix(np.random.randint(0, 9, (15,15,11)))

x = X.numpy()
y = Y.numpy()


In [380]:
%timeit -n20 X + Y
%timeit -n20 x + y


The slowest run took 6.71 times longer than the fastest. This could mean that an intermediate result is being cached.
8.64 μs ± 6.32 μs per loop (mean ± std. dev. of 7 runs, 20 loops each)
24.5 μs ± 7.64 μs per loop (mean ± std. dev. of 7 runs, 20 loops each)


In [398]:
%timeit -n20 X + Y
%timeit -n20 x - y


The slowest run took 10.73 times longer than the fastest. This could mean that an intermediate result is being cached.
7.98 μs ± 11.1 μs per loop (mean ± std. dev. of 7 runs, 20 loops each)
21.8 μs ± 8.17 μs per loop (mean ± std. dev. of 7 runs, 20 loops each)


In [382]:
%timeit -n20 X * Y
%timeit -n20 x * y


16.8 μs ± 9.72 μs per loop (mean ± std. dev. of 7 runs, 20 loops each)
30.2 μs ± 6.28 μs per loop (mean ± std. dev. of 7 runs, 20 loops each)


In [392]:
%timeit -n20 X @ Y
%timeit -n20 x @ y


190 μs ± 53.4 μs per loop (mean ± std. dev. of 7 runs, 20 loops each)
333 μs ± 92.6 μs per loop (mean ± std. dev. of 7 runs, 20 loops each)


In [379]:
%timeit -n20 X.inverse()
%timeit -n20 np.linalg.inv(x)


1.29 ms ± 233 μs per loop (mean ± std. dev. of 7 runs, 20 loops each)
110 ms ± 18.5 ms per loop (mean ± std. dev. of 7 runs, 20 loops each)
