diff --git a/divide_and_conquer/strassen_matrix_multiplication.py b/divide_and_conquer/strassen_matrix_multiplication.py
index f529a255d2ef..78c2e56fac07 100644
--- a/divide_and_conquer/strassen_matrix_multiplication.py
+++ b/divide_and_conquer/strassen_matrix_multiplication.py
@@ -49,18 +49,20 @@ def split_matrix(a: list) -> tuple[list, list, list, list]:
     if len(a) % 2 != 0 or len(a[0]) % 2 != 0:
         raise Exception("Odd matrices are not supported!")
 
-    matrix_length = len(a)
-    mid = matrix_length // 2
+    def extract_submatrix(rows, cols):
+        return [[a[i][j] for j in cols] for i in rows]
 
-    top_right = [[a[i][j] for j in range(mid, matrix_length)] for i in range(mid)]
-    bot_right = [
-        [a[i][j] for j in range(mid, matrix_length)] for i in range(mid, matrix_length)
-    ]
+    mid = len(a) // 2
 
-    top_left = [[a[i][j] for j in range(mid)] for i in range(mid)]
-    bot_left = [[a[i][j] for j in range(mid)] for i in range(mid, matrix_length)]
+    rows_top, rows_bot = range(mid), range(mid, len(a))
+    cols_left, cols_right = range(mid), range(mid, len(a))
 
-    return top_left, top_right, bot_left, bot_right
+    return (
+        extract_submatrix(rows_top, cols_left),  # Top-left
+        extract_submatrix(rows_top, cols_right),  # Top-right
+        extract_submatrix(rows_bot, cols_left),  # Bottom-left
+        extract_submatrix(rows_bot, cols_right),  # Bottom-right
+    )
 
 
 def matrix_dimensions(matrix: list) -> tuple[int, int]:
diff --git a/divide_and_conquer/tests/__init__.py b/divide_and_conquer/tests/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/divide_and_conquer/tests/test_strassen_matrix_multiplication.py b/divide_and_conquer/tests/test_strassen_matrix_multiplication.py
new file mode 100644
index 000000000000..d3ed399adfbd
--- /dev/null
+++ b/divide_and_conquer/tests/test_strassen_matrix_multiplication.py
@@ -0,0 +1,47 @@
+import pytest
+
+from divide_and_conquer.strassen_matrix_multiplication import split_matrix
+
+
+def test_4x4_matrix():
+    matrix = [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]]
+    expected = ([[4, 3], [2, 3]], [[2, 4], [1, 1]], [[6, 5], [8, 4]], [[4, 3], [1, 6]])
+    assert split_matrix(matrix) == expected
+
+
+def test_8x8_matrix():
+    matrix = [
+        [4, 3, 2, 4, 4, 3, 2, 4],
+        [2, 3, 1, 1, 2, 3, 1, 1],
+        [6, 5, 4, 3, 6, 5, 4, 3],
+        [8, 4, 1, 6, 8, 4, 1, 6],
+        [4, 3, 2, 4, 4, 3, 2, 4],
+        [2, 3, 1, 1, 2, 3, 1, 1],
+        [6, 5, 4, 3, 6, 5, 4, 3],
+        [8, 4, 1, 6, 8, 4, 1, 6],
+    ]
+    expected = (
+        [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]],
+        [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]],
+        [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]],
+        [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]],
+    )
+    assert split_matrix(matrix) == expected
+
+
+def test_invalid_odd_matrix():
+    matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
+    with pytest.raises(Exception, match="Odd matrices are not supported!"):
+        split_matrix(matrix)
+
+
+def test_invalid_non_square_matrix():
+    matrix = [
+        [1, 2, 3, 4],
+        [5, 6, 7, 8],
+        [9, 10, 11, 12],
+        [13, 14, 15, 16],
+        [17, 18, 19, 20],
+    ]
+    with pytest.raises(Exception, match="Odd matrices are not supported!"):
+        split_matrix(matrix)