In [1]:
class Matrix():
    def __init__(self, arr):
        """
        Initialise the matrix object with a multidimensional list (or just a list if a vector)\
        arr: List
        eg. 
        matrix = Matrix([[1,2],
                         [3,4]])
        """
        self.data = arr
    def flatten(self):
        """
        Flatten the multidimensional array and return it as a list
        returns: List
        eg.
        matrix.flatten() #[1,2,3,4]
        """
        flat_arr = [i for o in self.data for i in o]
        return Matrix(flat_arr)
    def scalar_multiply(self, scalar):
        """
        Multiply the entire matrix by a scalar and return it as a list
        scalar: int
        returns: Matrix
        eg.
        matrix.scalar_multiply(2) #[[2,4], 
                                    [6,8]]
        """
        outp_list = []
        for i in self.data:
            tmp_list = []
            for j in i:
                tmp_list.append(j * scalar)
            outp_list.append(tmp_list)
        return Matrix(outp_list)
                
    def location(self, row, column):
        """
        Retrieve a value at a specific integer row and column in the matrix. 
        row: int
        column: int
        returns: int
        eg.
        matrix.location(0, 0) #returns 1 
        matrix.location(0, 1) #returns 2
        matrix.location(1, 0) #returns 3
        matrix.location(1, 1) #returns 4
        """
        return self.data[row][column]
    def max_along_axis(self, axis):
        """
        Retrieve the max value along an axis
        axis: int
        returns: Matrix
        eg. 
        matrix.max_long_axis(0) #maximum in each row is [2, 4]
        matrix.max_long_axis(1) #maximum in each column is [3, 4]
        """
        if axis == 0:
            return Matrix([max(o) for o in self.data])
        elif axis == 1:
            row_len = len(self.data[0])
            max_dict = {i:[] for i in range(row_len)}
            for row in self.data:
                for idx, item in enumerate(row):
                    max_dict[idx].append(item)
        return Matrix([max(o) for o in max_dict.values()])

    def slice_axis(self, axis, idx):
        """
        Return an entire axis 
        axis: int
        idx: int
        returns: Matrix
        eg. 
        #0th dimension, 0th slice 
        matrix.get_axis(0, 0) #[1,2]
        #0th dimension, 1st slice 
        matrix.get_axis(0, 1) #[3,4]
        #1st dimension, 0th slice 
        matrix.get_axis(1, 0) #[[1],
                                [3]]
        #1st dimension, 1st slice 
        matrix.get_axis(1, 1) #[[2],
                                [4]]
        """
        if axis == 0:
            for i, o in enumerate(self.data):
                if i == idx:
                    return Matrix(o)
        else:
            outp_list = []
            for row in self.data:
                outp_list.append([element for i, element in enumerate(row) if i == idx])
            return Matrix(outp_list)
            
    def reduce_sum(self, axis):
        """
        Sum up the values across a given axis and returns a list 
        axis: int
        returns: Matrix
        eg.
        matrix.reduce_sum(0) # sums up the values across the row [3, 7]
        matrix.reduce_sum(1) # sum up the values across the column [4, 6]
        """
        if axis == 0:
            return Matrix([sum(o) for o in self.data]) 
        elif axis == 1:
            row_len = len(self.data[0])
            max_dict = {i:[] for i in range(row_len)}
            for row in self.data:
                for idx, item in enumerate(row):
                    max_dict[idx].append(item)
        return Matrix([sum(o) for o in max_dict.values()])
    
    def reduce_mean(self, axis):
        """
        Calculate the mean across a given axis and returns a list
        axis: int
        returns: Matrix
        eg.
        matrix.reduce_mean(0) # mean across rows [3/2, 7 /2] = [1.5, 3.5]
        matrix.reduce_mean(1) # mean across the rows is [4/2, 6/2] = [2, 3]
        """
        dim_at_axis = self.get_len_at_depth(axis)
        sum_arr = self.reduce_sum(axis)
        return Matrix([o / dim_at_axis for o in sum_arr.data])
        
    def get_len_at_depth(self, depth):
        """
        Get the length of the data array at a given depth. Assumes the list is rectangular (ie. n x m)
        depth: int
        returns: int
        """
        curr_arr = self.data
        for i in range(depth):
            if not isinstance(curr_arr[0], list):
                return 0
            curr_arr = curr_arr[0]
        return len(curr_arr)
    
    def transpose(self):
        """
        Transposes a matrix (ie. reflects the matrix along the diagonal)
        returns: Matrix
        """
        flat_array = [i for o in self.data for i in o]
        dim = len(self.data)
        
        row_len = len(self.data[0]) #this only works because we're assuming 2 dimensions
        max_dict = {i:[] for i in range(row_len)}
        for idx,item in enumerate(flat_array):
            max_dict[idx % row_len].append(item)
        outp_list = [o for o in max_dict.values()]
        return Matrix(outp_list)
    
    def __str__(self):
        return str(self.data)
    
    def __repr__(self):
        outp_str = ""
        for row in self.data:
            outp_str += f"{row} \n"
        return outp_str

In [2]:
m = Matrix([[1,2,3], [4,5,6]])

In [3]:
m.flatten()

1 
2 
3 
4 
5 
6 

In [4]:
m.scalar_multiply(2)

[2, 4, 6] 
[8, 10, 12] 

In [5]:
m.location(0,2)

3

In [6]:
m.max_along_axis(0)

3 
6 

In [7]:
m.max_along_axis(1)

4 
5 
6 

In [8]:
m.slice_axis(0,0)

1 
2 
3 

In [9]:
m.slice_axis(1,2)

[3] 
[6] 

In [10]:
m.reduce_sum(0)

6 
15 

In [11]:
m.reduce_sum(1).data

[5, 7, 9]

In [12]:
m.reduce_mean(0).data

[3.0, 7.5]

In [13]:
m.reduce_mean(1).data

[1.6666666666666667, 2.3333333333333335, 3.0]

In [14]:
m.transpose()

[1, 4] 
[2, 5] 
[3, 6] 

defining __str__ let's us decide what shows up when we run `print(class_instance)`. Defining __repr__ let's us decide what happens when we just leave the variable on the stack `m`. In the above, notice difference between trying to get .data VS without the .data.

In [15]:
m #from __repr__

[1, 2, 3] 
[4, 5, 6] 

In [16]:
print(m) #from __str__

[[1, 2, 3], [4, 5, 6]]
