In [29]:
class Matrix():
    def __init__(self, arr):
        """
        Initialise the matrix object with a multidimensional list (or just a list if a vector)\
        arr: List
        eg. 
        matrix = Matrix([[1,2],
                         [3,4]])
        """
        self.data = arr
    def flatten(self):
        """
        Flatten the multidimensional array and return it as a list
        returns: List
        eg.
        matrix.flatten() #[1,2,3,4]
        """
        if self.data == []:
        return S
        if isinstance(S[0], list):
            return flatten(S[0]) + flatten(S[1:])
        return S[:1] + flatten(S[1:])
    
    def scalar_multiply(self, scalar):
        """
        Multiply the entire matrix by a scalar and return it as a list
        scalar: int
        returns: List
        eg.
        matrix.scalar_multiply(2) #[[2,4], 
                                    [6,8]]
        """
        outp_list = []
        for i in self.data:
            tmp_list = []
            for j in i:
                tmp_list.append(j * scalar)
            outp_list.append(tmp_list)
        return Matrix(outp_list)
                
    def location(self, row, column):
        """
        Retrieve a value at a specific integer row and column in the matrix. 
        row: int
        column: int
        returns: int
        eg.
        matrix.location(0, 0) #returns 1 
        matrix.location(0, 1) #returns 2
        matrix.location(1, 0) #returns 3
        matrix.location(1, 1) #returns 4
        """
        return Matrix(self.data[row][column])
    def max_along_axis(self, axis):
        """
        Retrieve the max value along an axis
        axis: int
        returns: int
        eg. 
        matrix.max_long_axis(0) #maximum in each row is [2, 4]
        matrix.max_long_axis(1) #maximum in each column is [3, 4]
        """
        if axis == 0:
            return Matrix([max(o) for o in self.data])
        elif axis == 1:
            row_len = len(self.data[0])
            max_dict = {i:[] for i in range(row_len)}
            for row in self.data:
                for idx, item in enumerate(row):
                    max_dict[idx].append(item)
        return Matrix([max(o) for o in max_dict.values()])

    def slice_axis(self, axis, idx):
        """
        Return an entire axis 
        axis: int
        idx: int
        returns: List
        eg. 
        #0th dimension, 0th slice 
        matrix.get_axis(0, 0) #[1,2]
        #0th dimension, 1st slice 
        matrix.get_axis(0, 1) #[3,4]
        #1st dimension, 0th slice 
        matrix.get_axis(1, 0) #[[1],
                                [3]]
        #1st dimension, 1st slice 
        matrix.get_axis(1, 1) #[[2],
                                [4]]
        """
        if axis == 0:
            for i, o in enumerate(self.data):
                if i == idx:
                    return Matrix(o)
        else:
            outp_list = []
            for row in self.data:
                outp_list.append([element for i, element in enumerate(row) if i == idx])
            return Matrix(outp_list)
            
    def reduce_sum(self, axis):
        """
        Sum up the values across a given axis and returns a list 
        axis: int
        returns: List
        eg.
        matrix.reduce_sum(0) # sums up the values across the row [3, 7]
        matrix.reduce_sum(1) # sum up the values across the column [4, 6]
        """
        if axis == 0:
            return Matrix([sum(o) for o in self.data]) 
        elif axis == 1:
            row_len = len(self.data[0])
            max_dict = {i:[] for i in range(row_len)}
            for row in self.data:
                for idx, item in enumerate(row):
                    max_dict[idx].append(item)
        return Matrix([sum(o) for o in max_dict.values()])
    
    def reduce_mean(self, axis):
        """
        Calculate the mean across a given axis and returns a list
        axis: int
        returns: List
        eg.
        matrix.reduce_mean(0) # mean across rows [3/2, 7 /2] = [1.5, 3.5]
        matrix.reduce_mean(1) # mean across the rows is [4/2, 6/2] = [2, 3]
        """
        dim_at_axis = self.get_len_at_depth(axis)
        sum_arr = self.reduce_sum(axis)
        return Matrix([o / dim_at_axis for o in sum_arr.data])
        
    def get_len_at_depth(self, depth):
        """
        Get the length of the data list at a given depth
        """
        curr_arr = self.data
        for i in range(depth):
            curr_arr = curr_arr[0]
        return len(curr_arr)
    
    def get_max_depth(self):
        """
        Get the max depth of a nested list
        """
        curr_arr = self.data
        count = 0
        while isinstance(curr_arr, list):
            curr_arr = curr_arr[0]
            count += 1
        return count
            
    def __str__(self):
        return str(self.data)
    
    def __repr__(self):
        outp_str = ""
        for row in self.data:
            outp_str += f"{row} \n"
        return outp_str

In [30]:
m = Matrix([1,2,3])

In [31]:
m.get_max_depth()

1

In [32]:
m.flatten()

KeyboardInterrupt: 

In [182]:
m.scalar_multiply(2)

[2, 4, 6] 
[8, 10, 12] 

In [183]:
m.location(0,2).data

3

In [184]:
m.max_along_axis(0).data

[3, 6]

In [185]:
m.max_along_axis(1).data

[4, 5, 6]

In [186]:
m.slice_axis(0,0).data

[1, 2, 3]

In [187]:
m.slice_axis(1,2).data

[[3], [6]]

In [188]:
m.reduce_sum(0).data

[6, 15]

In [189]:
m.reduce_sum(1).data

[5, 7, 9]

In [190]:
m.reduce_mean(0).data

[3.0, 7.5]

In [191]:
m.reduce_mean(1).data

[1.6666666666666667, 2.3333333333333335, 3.0]

defining __str__ let's us decide what shows up when we run `print(class_instance)`. Defining __repr__ let's us decide what happens when we just leave the variable on the stack `m`

In [192]:
m #from __repr__

[1, 2, 3] 
[4, 5, 6] 

In [193]:
print(m) #from __str__

[[1, 2, 3], [4, 5, 6]]
