# Implementing lattice multiplication in Python

Lattice multiplication is an algorithm that is sometimes taught (at least it was when I was a child in England, before cell phones were readily available) as a way to multiply large numbers manually (i.e., in one's head or with pen and paper). It is fairly simple to use, especially when it is drawn out. Here, we will implement the algorithm in Python.

## Lattice multiplication algorithm steps

Before we can implement the algorithm in Python, we first need to identify its discrete steps and consider how each can be achieved. As discussed in the lecture slides, the steps of lattice multiplication are as follows:

1. For two integers, N and M, construct a matrix with one row for each digit of N and one column for each digit of M, where each cell of the matrix is divided diagonally into two halves
2. For each cell of the matrix, multiply the corresponding digits of N and M and store the result as two digits (i.e., if the result is less than 10, add a leading 0 to make it two digits)
3. Sum the numbers along each diagonal of the matrix
4. Working from right to left, process the sums. If a sum is 10 or greater, keep its ones digit and carry its tens digit by adding it to the next sum

We can see these steps at work if we go through an illustrated example. Let's work through the process to multiply 12 * 345.

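First, we construct a matrix whose dimensions are the number of digits in each number (here, 2 by 3), with the digits of one number along the top and the digits of the other down the side.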

![starting grid](images/empty_grid.png)

Next, we fill in each cell with the product of the corresponding digits of our two integers. The "tens" digit of the product goes in the upper half of the cell and the "ones" digit goes in the lower half; a product less than 10 gets a 0 in the upper half.

![filled grid](images/filled_grid.png)
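In Python, the built-in `divmod()` returns a quotient and remainder together, so `divmod(product, 10)` gives exactly the (upper, lower) halves of a cell, including the leading 0 for single-digit products. This is a small illustrative sketch; `split_cell` is a hypothetical helper name, and the implementation later in this notebook pads with a leading 0 instead.

```python
def split_cell(a: int, b: int) -> tuple[int, int]:
    """Return the (tens, ones) halves of a lattice cell for digits a and b."""
    # divmod(x, 10) == (x // 10, x % 10), so a single-digit product gets a 0 "tens" digit
    return divmod(a * b, 10)


print(split_cell(2, 5))  # (1, 0): 10 puts a 1 in the upper half
print(split_cell(1, 3))  # (0, 3): single-digit product padded with 0
```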

Finally, the numbers along each diagonal are summed, starting with the bottom right corner and working right to left. When a sum is 10 or greater, its "tens" digit is carried to the next diagonal. That gives the following result for our grid:

![finished grid](images/finished_grid.png)
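Why does carrying give the right answer? Each diagonal collects digits that share the same place value, so the diagonal sums, weighted by powers of ten, already add up to the product before any carrying happens; carrying only rewrites that total so each position holds a single digit. A quick check, using the diagonal sums for 12 * 345 read off the filled grid (0, 3, 10, 14, 0, from the top-left diagonal to the bottom-right one):

```python
# Diagonal sums for 12 * 345, ordered from the top-left diagonal to the bottom-right one
diagonal_sums = [0, 3, 10, 14, 0]

# The rightmost diagonal is the ones place, the next one the tens place, and so on
total = sum(s * 10**place for place, s in enumerate(reversed(diagonal_sums)))
print(total)              # 4140
print(total == 12 * 345)  # True
```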

## Python implementation

Let's now work through the steps of the algorithm and write Python code to perform each one.

### Matrix construction

First, we must construct our empty matrix. The matrix should have one row for each digit in one number and one column for each digit in the other number. It doesn't matter which number is used for which dimension. Let's use the first number to determine the number of rows and the second for the columns.

Since the first step of constructing our matrix is counting the digits in each number, we need a way to get that information. We can use our iterint approach from the week when we covered functions. That looks something like this:

```python
def int_digits(num: int) -> list[int]:
    # Convert the number to a string and turn each character back into an int
    digits = []
    for n in str(num):
        digits.append(int(n))
    return digits
```

We can now get the digits of an integer in a `list` that can be iterated over. More importantly for our present task though, `list` has a `__len__()` method so we can get its length with `len()`.

```python
x = int_digits(12345)
print(x)
print(len(x))
```

```
[1, 2, 3, 4, 5]
5
```
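Converting to a string is the simplest way to get the digits. For comparison, here is a sketch of a purely arithmetic version (the name `int_digits_arith` is hypothetical) that peels off the last digit with `divmod()` until nothing is left. Like `int_digits()`, it assumes a non-negative integer.

```python
def int_digits_arith(num: int) -> list[int]:
    """Return the digits of a non-negative integer, most significant first."""
    if num == 0:
        return [0]  # the loop below would otherwise return an empty list
    digits = []
    while num > 0:
        num, last = divmod(num, 10)  # strip off the ones digit
        digits.append(last)
    digits.reverse()  # digits were collected least significant first
    return digits


print(int_digits_arith(12345))  # [1, 2, 3, 4, 5]
```

The `0` special case matters for this algorithm: `fill_matrix()` calls `int_digits()` on products such as `0 * 7`, which must give `[0]` rather than an empty list.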


Now we can write a function to construct a matrix with dimensions determined by the number of digits in two numbers. Let's stick with the numbers used for the above example: 12 * 345, which will create a matrix of two rows and three columns.

A matrix is simply a two-dimensional array of elements. In pure Python, that is typically a `list` of `list`s: the outer `list` holds the rows of the matrix, and each inner `list` holds the values in one row, one per column. For example, `[[1, 2], [3, 4]]` would represent the matrix

1 2\
3 4
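With a `list` of `list`s, the first index picks the row and the second picks the position within that row (the column). A small illustration:

```python
m = [[1, 2], [3, 4]]

print(m[1])               # [3, 4]: the second row
print(m[1][0])            # 3: second row, first column
print(len(m), len(m[0]))  # 2 2: number of rows, number of columns

# A column is not stored as its own list, so it has to be gathered from every row
first_column = [row[0] for row in m]
print(first_column)       # [1, 3]
```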

We could also use NumPy or pandas to create a matrix, as they are designed for that sort of data structure. We'll stick with pure Python here.

To add a bit more complexity, we actually need to store a list of two digits in each cell of our matrix. That results in a `list` of `list`s of `list`s (i.e., a three-dimensional array). This is simply because we need to be able to refer to the first and second digit separately. We can't store the two digits as a single `int` because `int` does not support indexing. Our three-dimensional array is really just a two-dimensional matrix in which each cell contains two numbers.
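To make the three levels concrete, here is the filled 2 by 3 matrix for 12 * 345 written out by hand, as it should look once the cells are filled. The first index picks the row, the second the column, and the third the half of the cell: 0 for the upper "tens" digit and 1 for the lower "ones" digit.

```python
# Filled lattice for 12 * 345: rows are the digits of 12, columns the digits of 345
filled = [
    [[0, 3], [0, 4], [0, 5]],  # row for digit 1: 1*3, 1*4, 1*5
    [[0, 6], [0, 8], [1, 0]],  # row for digit 2: 2*3, 2*4, 2*5
]

cell = filled[1][2]  # row 1, column 2 holds the product 2 * 5
print(cell)                              # [1, 0]
print(filled[1][2][0], filled[1][2][1])  # 1 0: upper and lower halves

tens, ones = cell  # a two-element list can be unpacked directly
print(10 * tens + ones)                  # 10
```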


```python
def make_matrix(number_a: int, number_b: int) -> list[list[list[int]]]:
    # First get a list of the digits
    digits_a = int_digits(number_a)
    digits_b = int_digits(number_b)

    # Build the empty matrix. We don't use the row and col variables,
    # but the names make it clear what each loop iterates over.
    matrix = []
    for row in digits_a:
        mat_row = []
        for col in digits_b:
            # append the empty list that will contain our two digits later
            mat_row.append([])
        matrix.append(mat_row)

    return matrix

num_a = 12
num_b = 345
matrix = make_matrix(num_a, num_b)

# confirm it has the right dimensions (2 rows, 3 columns)
for row in matrix:
    print(row)
```

```
[[], [], []]
[[], [], []]
```
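A shorter way to build the same empty matrix is a nested list comprehension. It is tempting to use `*` replication instead, but that repeats references to the *same* inner lists, so filling one cell would appear to fill others too. A sketch of both, using the 2 by 3 shape from above:

```python
rows, cols = 2, 3

# Each comprehension iteration creates a brand-new empty list for every cell
good = [[[] for _ in range(cols)] for _ in range(rows)]

# Pitfall: both rows are the same list object, and every cell is the same empty list
bad = [[[]] * cols] * rows

good[0][0].append(1)
bad[0][0].append(1)

print(good)  # [[[1], [], []], [[], [], []]]
print(bad)   # [[[1], [1], [1]], [[1], [1], [1]]]
```

`make_matrix()` avoids the problem because it appends a fresh `[]` on every iteration.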


### Matrix filling

Now that we have a function to make a matrix, the next step is to fill it in with the result of each multiplication. Note that `fill_matrix()` modifies the matrix it is given in place (and also returns it).

```python
def fill_matrix(number_a: int, number_b: int, matrix: list[list[list[int]]]) -> list[list[list[int]]]:
    # First get a list of the digits
    digits_a = int_digits(number_a)
    digits_b = int_digits(number_b)

    # Multiply all pairs of digits between the two numbers and store the results in the corresponding cells
    for i, a in enumerate(digits_a):
        for j, b in enumerate(digits_b):
            product = a * b
            digits_prod = int_digits(product)

            # Before we can add it to the matrix, we need to deal with the digits again
            if len(digits_prod) < 2:
                # single-digit products get a 0 added as the "tens" digit
                digits_prod = [0] + digits_prod

            # replace the empty list in this cell with the digits of the product
            matrix[i][j] = digits_prod

    return matrix


filled_matrix = fill_matrix(num_a, num_b, matrix)

for row in filled_matrix:
    print(row)
```

```
[[0, 3], [0, 4], [0, 5]]
[[0, 6], [0, 8], [1, 0]]
```


### Bonus: matrix viewing

Now we have our filled matrix with all of the values added. You can compare this to the example at the beginning of this notebook to see that we have recreated the matrix shown there, although it's a bit hard to see without the diagonals drawn in. I've included a function to display the matrix a bit more clearly. We won't go through it here, but hopefully it will help with following the steps using our data. It should work for any pair of positive integers. It takes a matrix as well as an optional list of digits from the next step we'll go over.

```python
def show_matrix(matrix: list[list[list[int]]], components: list[int] = None) -> None:
    mat_width = len(matrix[0])
    if components:
        h_delim = "  "
    else:
        h_delim = ""

    # build top line
    h_delim += "".join(['+'] + ["-------+"]*mat_width)
    outlines = [h_delim]

    # build main matrix
    for n, row in enumerate(matrix):
        if components:
            num = components[n]
            line = [f"{num}{' ' * (2-len(str(num)))}|"]
        else:
            line = ["|"]
        line += [f" {i[0]} / {i[1]} |" for i in row]
        outlines.append("".join(line))
        outlines.append(h_delim)

    # Add components along bottom if given
    if components:
        bottom_line = "      "
        bottom_line += "".join([
            f"{i}"+ " " * (8-len(str(i))) for i in components[len(matrix):]])
        outlines.append(bottom_line)

    print("\n".join(outlines))

show_matrix(filled_matrix)
```

```
+-------+-------+-------+
| 0 / 3 | 0 / 4 | 0 / 5 |
+-------+-------+-------+
| 0 / 6 | 0 / 8 | 1 / 0 |
+-------+-------+-------+
```


Not perfect, but a bit easier to see the diagonal to help us with the next step. We'll use this again once we've added the sums of each diagonal. I've called those "components" in the function because they are later put together to form the solution to our multiplication. I couldn't think of a better name, but if you can then feel free to change it.

### Summing diagonals

This part might be the trickiest to get your head around. However, once you see the solution it will hopefully become clear. If you like a puzzle, this is a point at which you could stop reading and try to figure it out for yourself. To assist with that, I'll state the problem posed by the next step and then indicate when the solution is going to be stated.

**The problem**

If you look back at the example at the start of this document, you will see that we need to sum each diagonal of the matrix. In that example, only 4 digits are shown. However, 5 sums are actually calculated; the top-left diagonal consists only of the 0 in the upper half of the top-left cell, so that sum isn't included in the final result. We can see the positions of the 5 values we will produce (one for each row and each column) by passing `show_matrix()` a list of 0s:


```python
show_matrix(filled_matrix, [0]*5)
```

```
  +-------+-------+-------+
0 | 0 / 3 | 0 / 4 | 0 / 5 |
  +-------+-------+-------+
0 | 0 / 6 | 0 / 8 | 1 / 0 |
  +-------+-------+-------+
      0       0       0       
```


So, if there are 5 diagonals, how do we identify which cells of our three-dimensional matrix should be added to each of those 5 numbers? We need to come up with a programmatic approach that would work for any size matrix (Note that each cell will always contain two values).

If you wish to try it yourself first, stop reading here.

**A solution**

If you read across row 0 (i.e., the row at index 0 of the matrix) and consider which of the 5 sums each digit will be added to, you start to see a pattern. Let's refer to the rows, the columns, and the two digits within each cell by their indices, and number the sums 0 to 4 starting from the top left. In row 0, column 0, digit 0 is added to sum 0 and digit 1 is added to sum 1. In column 1, digit 0 is added to sum 1 and digit 1 is added to sum 2. And so on.

Row 1 confirms the pattern: in row 1, column 0, digit 0 is added to sum 1 and digit 1 is added to sum 2. In row 1, column 1, digit 0 is added to sum 2 and digit 1 is added to sum 3.

The pattern is that the row, column, and within-cell indices of each digit add up to the index of the sum it belongs to. Therefore, if we iterate over the rows, columns, and cells while keeping track of each index, we can add every digit in our matrix to the appropriate sum.
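To see the pattern for a whole 2 by 3 matrix at once, the short sketch below groups every (row, column, half) position by the sum of its indices. Each group is one diagonal, numbered from the top left.

```python
from collections import defaultdict

rows, cols = 2, 3

# Map each diagonal index to the (row, column, half) positions that belong to it
diagonals = defaultdict(list)
for i in range(rows):
    for j in range(cols):
        for k in range(2):  # k = 0 is the upper (tens) half, k = 1 the lower (ones) half
            diagonals[i + j + k].append((i, j, k))

for d in sorted(diagonals):
    print(d, diagonals[d])
```

There are `rows + cols` = 5 diagonals, and the first and last each contain a single half-cell, which is why the corner sums are so small.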

```python
def read_matrix(matrix: list[list[list[int]]]) -> list[int]:
    num_output_digits = len(matrix) + len(matrix[0])  # number of rows + number of columns
    out_digits = [0 for _ in range(num_output_digits)]  # start each sum at 0
    for i in range(len(matrix)):  # row index
        for j in range(len(matrix[0])):  # column index
            for k in range(2):  # index within the cell (0 = upper, 1 = lower)
                out_digits[i + j + k] += matrix[i][j][k]

    return out_digits

components = read_matrix(filled_matrix)
show_matrix(filled_matrix, components)
```

```
  +-------+-------+-------+
0 | 0 / 3 | 0 / 4 | 0 / 5 |
  +-------+-------+-------+
3 | 0 / 6 | 0 / 8 | 1 / 0 |
  +-------+-------+-------+
      10      14      0       
```


This matrix looks very similar to the one at the start of this notebook. The only difference now is that some of the sums have two digits, which need to be dealt with. Working from right to left, we keep the ones digit of each sum and carry the rest to the previous sum in the list (i.e., the next diagonal to the left). Python's built-in `divmod()` gives us both parts in one step: `divmod(14, 10)` returns `(1, 4)`, the value to carry and the digit to keep. Unlike checking for exactly two digits, this also copes with sums of 100 or more, which can occur when multiplying long numbers.

```python
def carry_the_one(numbers: list[int]) -> list[int]:
    numbers_copy = numbers.copy()  # make a copy so we can edit a local version of the input list
    out_digits = [0 for _ in numbers_copy]
    for i in reversed(range(len(numbers_copy))):  # start with the number on the right and work left
        # split the sum into the amount to carry and the digit to keep
        carry, ones = divmod(numbers_copy[i], 10)
        out_digits[i] = ones
        if i > 0:
            # carry to the next sum on the left; nothing is lost at i == 0 because
            # a product never has more digits than its two factors combined
            numbers_copy[i - 1] += carry
    return out_digits

carried = carry_the_one(components)

print(carried)
```

```
[0, 4, 1, 4, 0]
```
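How large can a diagonal sum get? For long inputs it can exceed 99, so the carry step has to handle sums with more than two digits. The sketch below computes the diagonal sums directly (mirroring `read_matrix()` but self-contained; `diagonal_sums` is a hypothetical name) for the product of two 12-digit numbers made of nines, where every cell holds 8 / 1.

```python
def diagonal_sums(a: int, b: int) -> list[int]:
    """Diagonal sums of the lattice for a * b, before any carrying."""
    digits_a = [int(ch) for ch in str(a)]
    digits_b = [int(ch) for ch in str(b)]
    sums = [0] * (len(digits_a) + len(digits_b))
    for i, x in enumerate(digits_a):
        for j, y in enumerate(digits_b):
            tens, ones = divmod(x * y, 10)
            sums[i + j] += tens      # upper half of the cell
            sums[i + j + 1] += ones  # lower half of the cell
    return sums


nines = int("9" * 12)
print(max(diagonal_sums(nines, nines)))  # 107
print(max(diagonal_sums(12, 345)))       # 14
```

A sum like 107 has to keep its 7 and pass 10 along to the next diagonal, which is exactly what `divmod(107, 10)` returns as `(10, 7)`.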


We can also view those numbers on our matrix to confirm that it looks very similar to the example image.

```python
show_matrix(filled_matrix, carried)
```

```
  +-------+-------+-------+
0 | 0 / 3 | 0 / 4 | 0 / 5 |
  +-------+-------+-------+
4 | 0 / 6 | 0 / 8 | 1 / 0 |
  +-------+-------+-------+
      1       4       0       
```


Now, getting the final answer is just a matter of wrangling the Python object we've been using to hold the individual digits. That is, we need to stitch our `list` of `int`s together into a single `int`. We can do that by performing the opposite of the operations used in `int_digits()`.

```python
def collapse_int_list(numbers: list[int]) -> int:
    str_nums = [str(n) for n in numbers]  # turn each digit back into a string
    int_result = int("".join(str_nums))   # join them and convert; leading zeros are dropped
    return int_result

result = collapse_int_list(carried)
print(result)
```

```
4140
```
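Joining strings is a clear way to reverse `int_digits()`. An arithmetic alternative (sketched below with the hypothetical name `collapse_digits_arith`) builds the number one digit at a time by shifting what it has so far one place left, i.e., multiplying by 10, and then adding the next digit.

```python
def collapse_digits_arith(digits: list[int]) -> int:
    """Combine single digits (most significant first) into one integer."""
    result = 0
    for d in digits:
        result = result * 10 + d  # shift left one place, then add the new ones digit
    return result


print(collapse_digits_arith([0, 4, 1, 4, 0]))  # 4140
```

Leading zeros need no special handling here, because `0 * 10 + 0` is still 0.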


Finished! 

Just to be sure, let's wrap these steps in a single function and test it with a few different numbers to make sure it doesn't only work with the data we used to write it. To be clear, we haven't written this implementation to handle negative or non-integer numbers, so we'll only test positive integers.

```python
def lattice(a: int, b: int) -> int:
    matrix = make_matrix(a, b)
    filled = fill_matrix(a, b, matrix)
    components = read_matrix(filled)
    carried = carry_the_one(components)
    result = collapse_int_list(carried)

    return result


print(lattice(99, 88) == 99 * 88)
print(lattice(12345, 678910) == 12345 * 678910)
print(lattice(112233, 998877) == 112233 * 998877)
```

```
True
True
True
```
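Three cases are a good start, but a randomized comparison against Python's own `*` covers many more. The sketch below is self-contained, so it re-implements the pipeline compactly as a hypothetical `lattice_compact()` (using `divmod()` for splitting and carrying rather than the string-based helpers above) and checks it on random positive integers of varying length; the seed and ranges are arbitrary choices.

```python
import random


def lattice_compact(a: int, b: int) -> int:
    """Lattice multiplication of two positive integers, condensed into one function."""
    digits_a = [int(ch) for ch in str(a)]
    digits_b = [int(ch) for ch in str(b)]

    # Fill the lattice and sum its diagonals in one pass (index = row + column + half)
    sums = [0] * (len(digits_a) + len(digits_b))
    for i, x in enumerate(digits_a):
        for j, y in enumerate(digits_b):
            tens, ones = divmod(x * y, 10)
            sums[i + j] += tens
            sums[i + j + 1] += ones

    # Carry from right to left, keeping one digit per position
    carry = 0
    for p in reversed(range(len(sums))):
        carry, sums[p] = divmod(sums[p] + carry, 10)

    # Stitch the digits back together, most significant first
    result = 0
    for d in sums:
        result = result * 10 + d
    return result


rng = random.Random(0)  # fixed seed so the run is reproducible
for _ in range(1000):
    a = rng.randint(1, 10 ** rng.randint(1, 20))
    b = rng.randint(1, 10 ** rng.randint(1, 20))
    assert lattice_compact(a, b) == a * b, (a, b)
print("all random tests passed")
```

Inputs of up to 20 digits are long enough to produce diagonal sums above 99, so this also exercises multi-digit carries.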


Looks good!

This algorithm is a little simpler than the Needleman-Wunsch algorithm you will be implementing. However, some of the approaches used here will be helpful in thinking about how to structure your data and how to trace back through the matrix once it has been filled in.