In [1]:
# This script implements the matrix class and its arthematic operations.

In [113]:
from __future__ import annotations
import pprint
from typing import Sequence, List

In [173]:
class Matrix:
    def __init__(self, val: Sequence[Sequence[float]]):
        if not val or not val[0]:
            raise ValueError("Matrix must be non-empty")
        ncols = len(val[0])
        if any(len(row) != ncols for row in val):
            raise ValueError("All rows must have the same length")
        # defensive copy
        self.val: List[List[float]] = [list(row) for row in val]

    # ---------- Utils ----------
    @property
    def nrows(self):
        return len(self.val)
    
    @property
    def ncols(self):
        return len(self.val[0])
    
    @property
    def shape(self):
        return (self.nrows, self.ncols)

    @property
    def ndim(self):
        return 2
        
    def __repr__(self):
        rows = ', \n'.join(str(row) for row in self.val)
        return f"Matrix([\n{rows}\n])"

    def __getitem__(self, idx: int):
        '''for indexing into the matrix to return the indexed row from the matrix'''
        if idx < 0 or idx >= len(self.val):
            raise IndexError(f"Row index {idx} out of range; must be in [0, {len(self.val)-1}]")
        
        return self.val[idx]

    def __eq__(self, other):
        '''Checks equality of two matrices element wise.'''
        if not isinstance(other, Matrix): return False

        if self.shape != other.shape: return False
        
        return all(self[i][j] == other[i][j]
                  for i in range(self.nrows)
                  for j in range(self.ncols))

    def zeros(self, shape: tuple):
        '''create a matrix of zeros for shape (nrows, ncols)'''
        nrows, ncols = shape
        out = Matrix([[0 for _ in range(ncols)] for _ in range(nrows)])
        return out
    
    def ones(self, shape: tuple):
        '''create a matrix of ones for shape (nrows, ncols)'''
        nrows, ncols = shape
        out = Matrix([[0 for _ in range(ncols)] for _ in range(nrows)])
        return out

    # ---------- Core Functionality ----------
    def __add__(self, other: Matrix) -> Matrix:
        '''matrix only for now, does not accept scalars'''
        if not isinstance(other, Matrix): 
            return NotImplemented
        
        if self.shape != other.shape:
            raise ValueError(f'Cannot add matrices of different sizes, {self.shape} != {other.shape}')

        out = [[self.val[i][j] + other.val[i][j] for j in range(self.ncols)]for i in range(self.nrows)]
        return Matrix(out)
    
    def __sub__(self, other: Matrix) -> Matrix:
        '''matrix only for now, does not accept scalars'''
        if not isinstance(other, Matrix): 
            return NotImplemented
        
        if self.shape != other.shape:
            raise ValueError(f'Cannot add matrices of different sizes, {self.shape} != {other.shape}')

        out = [[self.val[i][j] - other.val[i][j] for j in range(self.ncols)] 
               for i in range(self.nrows)]
        return Matrix(out)

    def __matmul__(self, other):
        # to keep my sanity
        n = len(self.val)      # row
        m = len(self.val[0])   # col  
        p = len(other.val)     # row 
        q = len(other.val[0])  # col
        
        if not isinstance(other, Matrix): return NotImplemented 

        # check matmul conditions
        if m!=p:
            raise ValueError(f'Cannot multiply matrices of shapes [{self.shape} X {other.shape}], {self.shape[1]}!={other.shape[0]}!')
            
        res = [[0 for _ in range(q)] for _ in range(n)]

        for i in range(n):
            for j in range(q):
                for k in range(m):
                    res[i][j] += self.val[i][k] * other.val[k][j]
        
        return Matrix(res)

In [174]:
x = Matrix([[1,2,3],[4,5,6]])
y = Matrix([[1,4,3],[4,5,6]])
x == y

False

In [175]:
x + y

Matrix([
[2, 6, 6], 
[8, 10, 12]
])

In [176]:
r = Matrix([[4, 6, 8], [4, 6, 8]])
v = Matrix([[2, 4, 6], [2, 4, 6]])
r - v 

Matrix([
[2, 2, 2], 
[2, 2, 2]
])

In [177]:
x.shape, y.shape

((2, 3), (2, 3))

In [178]:
x @  y # works

ValueError: Cannot multiply matrices of shapes [(2, 3) X (2, 3)], 3!=2!

In [184]:
x = Matrix([[4, 6, 8], [4, 6, 8]])    # 2,3
y = Matrix([[1, 2], [2, 3], [3, 4]])  # 3,2

x @ y                                 # 2,2

Matrix([
[40, 58], 
[40, 58]
])

In [185]:
x.zeros((2, 3))

Matrix([
[0, 0, 0], 
[0, 0, 0]
])

In [186]:
x.ones((1, 3))

Matrix([
[0, 0, 0]
])