# Searching for faster matrix multiplication algorithms with quantum annealing

This notebook aims to implement a quantum-annealing-based algorithm for discovering matrix multiplication algorithms. The work is inspired by the AlphaTensor paper "Discovering faster matrix multiplication algorithms with reinforcement learning": https://doi.org/10.1038/s41586-022-05172-4.

In [1]:
import dimod
from dimod.generators.constraints import combinations
from dwave.system import LeapHybridSampler
from hybrid.reference import KerberosSampler

import numpy as np
from numpy.random import shuffle
import json
import itertools
import os
import math
import random
import matplotlib.pyplot as plt

from utils import *

# Absolute form of the relative file name "main.ipynb"; os.path.abspath
# resolves it against the current working directory.
notebook_path = os.path.abspath("main.ipynb")

In [2]:
# Side length of the square matrices being multiplied.
dim = 2
# Number of entries in a single dim x dim matrix.
n_elem = dim * dim

In [3]:
# Eight index triples [a, b, c], each index in 0..3.
# NOTE(review): presumably row-major positions into the flattened 2x2
# matrices A, B and C = A @ B, one triple per scalar product of the
# schoolbook algorithm -- confirm against the intended encoding.
original_multiplication = [
    [0, 0, 0],
    [1, 2, 0],
    [0, 1, 1],
    [1, 3, 1],
    [2, 0, 2],
    [3, 2, 2],
    [2, 1, 3],
    [3, 3, 3],
]

In [4]:
# Each (u, v, w) triple of length-4 vectors is expanded into the 4x4x4
# outer product u (x) v (x) w through two nested tensordot calls with axes=0.
strassen_tensors = [
    np.tensordot(u, np.tensordot(v, w, axes=0), axes=0)
    for u, v, w in (
        ([0, 0, 0, 1], [-1, 0, 1, 0], [1, 0, 1, 0]),
        ([1, 1, 0, 0], [0, 0, 0, 1], [-1, 1, 0, 0]),
        ([-1, 0, 1, 0], [1, 1, 0, 0], [0, 0, 0, 1]),
        # Further factor triples, currently disabled:
        # ([1, 0, 0, 1], [1, 0, 0, 1], [1, 0, 0, 1]),
        # ([0, 1, 0, -1], [0, 0, 1, 1], [1, 0, 0, 0]),
        # ([1, 0, 0, 0], [0, 1, 0, -1], [0, 1, 0, 1]),
        # ([0, 0, 1, 1], [1, 0, 0, 0], [0, 0, 1, -1]),
    )
]

In [5]:
# Build the starting tensor for the search: take the tensor returned by
# get_standard_tensor(dim), subtract every entry of strassen_tensors from it
# and reduce the result modulo 2.
# NOTE(review): get_standard_tensor is not defined in this notebook;
# presumably it comes from the wildcard import of utils -- confirm.
#print("Matrix operations encoded to tensor: ", initial_tensor)
initial_tensor = get_standard_tensor(dim)
#print(initial_tensor)
#print(np.count_nonzero(initial_tensor.flatten()))
#print("----------------------------------------")
# The subtraction is in place, so the object returned by get_standard_tensor
# is itself modified here.
for t in strassen_tensors:
    initial_tensor -= t
    #print()
    #print(initial_tensor % 2)
    #print(np.count_nonzero(initial_tensor.flatten()))
# "% 2" creates a new object; from here on initial_tensor holds the entries
# modulo 2.
initial_tensor = initial_tensor % 2
#print("After substracting the tensors (which describe the matrix multiplication operations) we obtain zero tensor: ", initial_tensor)
#initial_tensor = get_standard_tensor(dim)

[[[1. 0. 0. 0.]
  [0. 1. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]]

 [[0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [1. 0. 0. 0.]
  [0. 1. 0. 0.]]

 [[0. 0. 1. 0.]
  [0. 0. 0. 1.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]]

 [[0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 1. 0.]
  [0. 0. 0. 1.]]]
8
----------------------------------------

[[[1. 0. 0. 0.]
  [0. 1. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]]

 [[0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [1. 0. 0. 0.]
  [0. 1. 0. 0.]]

 [[0. 0. 1. 0.]
  [0. 0. 0. 1.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]]

 [[1. 0. 1. 0.]
  [0. 0. 0. 0.]
  [1. 0. 0. 0.]
  [0. 0. 0. 1.]]]
10

[[[1. 0. 0. 0.]
  [0. 1. 0. 0.]
  [0. 0. 0. 0.]
  [1. 1. 0. 0.]]

 [[0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [1. 0. 0. 0.]
  [1. 0. 0. 0.]]

 [[0. 0. 1. 0.]
  [0. 0. 0. 1.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]]

 [[1. 0. 1. 0.]
  [0. 0. 0. 0.]
  [1. 0. 0. 0.]
  [0. 0. 0. 1.]]]
12

[[[1. 0. 0. 1.]
  [0. 1. 0. 1.]
  [0. 0. 0. 0.]
  [1. 1. 0. 0.]]

 [[0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [1. 0. 0. 0.]
  [1. 0. 0. 0.]]

 [[0. 0. 1. 1.]
  [0. 0.

In [6]:
# Random 4x4x4 tensor of floats containing exactly `num_of_ones` ones and
# zeros everywhere else.
num_of_ones = 14
max_tensor_init = np.concatenate((np.zeros(64 - num_of_ones), np.ones(num_of_ones)))
# shuffle permutes the 64 flat entries in place.
shuffle(max_tensor_init)
max_tensor = max_tensor_init.reshape((4, 4, 4))

In [7]:
def construct_uvw_tensor(dim):
    """Enumerate all index triples, grouped three different ways.

    Every (x, y, z) with each index in ``range(dim**2)`` is visited in
    lexicographic order (z varies fastest) and emitted in three shapes.

    Args:
        dim: the index range of each axis is ``dim**2``.

    Returns:
        A tuple of three equally long lists whose entries at the same
        position describe the same (x, y, z):
        ``((x, y), z)``, ``((x, z), y)`` and ``((y, z), x)``.
    """
    axis = range(dim**2)
    triples = [(x, y, z) for x in axis for y in axis for z in axis]
    grouped_xy = [((x, y), z) for x, y, z in triples]
    grouped_xz = [((x, z), y) for x, y, z in triples]
    grouped_yz = [((y, z), x) for x, y, z in triples]
    return grouped_xy, grouped_xz, grouped_yz

In [8]:
def towards_origo(initial_tensor, dim):
    """Build a binary quadratic model from the nonzero pattern of a tensor.

    Variables are labelled ``"x<i>"``, ``"y<j>"`` and ``"z<k>"``; a pair of
    them, written as a tuple such as ``("x0", "y1")``, is used as the label
    of an additional variable. For every index triple (a, b, c) produced by
    ``construct_uvw_tensor(dim)`` the three pairings (x, y)|z, (x, z)|y and
    (y, z)|x are processed in that order, each over all triples:

    * a coupling between the pair variable and the remaining variable is
      recorded with coefficient -1 when ``initial_tensor[a][b][c] != 0`` and
      +1 otherwise;
    * every nonzero entry adds 1 to the offset (once per pairing);
    * the first time a pair is met on a nonzero entry, the expansion of
      ``(2*p - u - v)**2 = 4p + u + v - 4pu - 4pv + 2uv`` is written into the
      model for that pair ``p = (u, v)``.

    Args:
        initial_tensor: object indexed as ``initial_tensor[a][b][c]`` with
            every index in ``range(dim**2)``.
        dim: forwarded to ``construct_uvw_tensor``.

    Returns:
        A ``dimod.BinaryQuadraticModel`` with vartype ``dimod.BINARY``.
    """
    linear = {}
    quadratic = {}
    offset = 0.0

    def add_term(penalised, first, second, third):
        # Handle one tensor entry for the pairing (first, second) | third.
        # NOTE(review): append_linear_safe / append_quadratic_safe are defined
        # outside this block; presumably they add to an existing coefficient
        # instead of overwriting it -- confirm in utils.
        nonlocal offset
        pair = (first, second)
        coeff = 1
        if penalised:
            offset += 1
            coeff = -1
            # Written once per pair:
            # (2*p - u - v)^2 = 4p + u + v - 4pu - 4pv + 2uv
            if pair not in linear:
                linear[pair] = 4
                append_linear_safe(first, 1, linear)
                append_linear_safe(second, 1, linear)
                quadratic[(pair, first)] = -4
                quadratic[(pair, second)] = -4
                quadratic[pair] = 2
        append_quadratic_safe((pair, third), coeff, quadratic)

    indices_1, indices_2, indices_3 = construct_uvw_tensor(dim)

    # ((x, y), z)
    for (a, b), c in indices_1:
        add_term(initial_tensor[a][b][c] != 0,
                 "x" + str(a), "y" + str(b), "z" + str(c))

    # ((x, z), y)
    for (a, c), b in indices_2:
        add_term(initial_tensor[a][b][c] != 0,
                 "x" + str(a), "z" + str(c), "y" + str(b))

    # ((y, z), x)
    for (b, c), a in indices_3:
        add_term(initial_tensor[a][b][c] != 0,
                 "y" + str(b), "z" + str(c), "x" + str(a))

    return dimod.BinaryQuadraticModel(linear, quadratic, offset, dimod.BINARY)

In [9]:
def _standard_matmul_tensor(dim):
    """Return the 0/1 tensor of the schoolbook dim x dim matrix product.

    With row-major flattening of A, B and C = A @ B, the entry at
    ``[i*dim + k][k*dim + j][i*dim + j]`` is 1 for every i, j, k in
    ``range(dim)`` and all other entries are 0. For ``dim == 2`` this is
    exactly the 4x4x4 array that used to be hard-coded in towards_standard.
    """
    side = dim**2
    tensor = np.zeros((side, side, side), dtype=int)
    for i in range(dim):
        for j in range(dim):
            for k in range(dim):
                tensor[i * dim + k][k * dim + j][i * dim + j] = 1
    return tensor


def towards_standard(initial_tensor, dim):
    """Build a binary quadratic model from where a tensor differs from the standard one.

    Works like a comparison against the schoolbook matrix multiplication
    tensor for ``dim x dim`` matrices: an entry is "penalised" when
    ``initial_tensor[a][b][c]`` differs from the standard tensor at the same
    position.

    Variables are labelled ``"x<i>"``, ``"y<j>"`` and ``"z<k>"``; a pair of
    them, written as a tuple such as ``("x0", "y1")``, is used as the label
    of an additional variable. For every index triple (a, b, c) produced by
    ``construct_uvw_tensor(dim)`` the three pairings (x, y)|z, (x, z)|y and
    (y, z)|x are processed in that order, each over all triples:

    * a coupling between the pair variable and the remaining variable is
      recorded with coefficient -1 for a penalised entry and +1 otherwise;
    * every penalised entry adds 1 to the offset (once per pairing);
    * the first time a pair is met on a penalised entry, the expansion of
      ``(2*p - u - v)**2 = 4p + u + v - 4pu - 4pv + 2uv`` is written into the
      model for that pair ``p = (u, v)``.

    Args:
        initial_tensor: object indexed as ``initial_tensor[a][b][c]`` with
            every index in ``range(dim**2)``.
        dim: matrix side length; sizes the standard tensor and is forwarded
            to ``construct_uvw_tensor``.

    Returns:
        A ``dimod.BinaryQuadraticModel`` with vartype ``dimod.BINARY``.
    """
    # Derived from dim instead of a literal 4x4x4 array, so the comparison no
    # longer indexes out of range when dim != 2.
    standard = _standard_matmul_tensor(dim)
    linear = {}
    quadratic = {}
    offset = 0.0

    def add_term(penalised, first, second, third):
        # Handle one tensor entry for the pairing (first, second) | third.
        # NOTE(review): append_linear_safe / append_quadratic_safe are defined
        # outside this block; presumably they add to an existing coefficient
        # instead of overwriting it -- confirm in utils.
        nonlocal offset
        pair = (first, second)
        coeff = 1
        if penalised:
            offset += 1
            coeff = -1
            # Written once per pair:
            # (2*p - u - v)^2 = 4p + u + v - 4pu - 4pv + 2uv
            if pair not in linear:
                linear[pair] = 4
                append_linear_safe(first, 1, linear)
                append_linear_safe(second, 1, linear)
                quadratic[(pair, first)] = -4
                quadratic[(pair, second)] = -4
                quadratic[pair] = 2
        append_quadratic_safe((pair, third), coeff, quadratic)

    indices_1, indices_2, indices_3 = construct_uvw_tensor(dim)

    # ((x, y), z)
    for (a, b), c in indices_1:
        add_term(initial_tensor[a][b][c] != standard[a][b][c],
                 "x" + str(a), "y" + str(b), "z" + str(c))

    # ((x, z), y)
    for (a, c), b in indices_2:
        add_term(initial_tensor[a][b][c] != standard[a][b][c],
                 "x" + str(a), "z" + str(c), "y" + str(b))

    # ((y, z), x)
    for (b, c), a in indices_3:
        add_term(initial_tensor[a][b][c] != standard[a][b][c],
                 "y" + str(b), "z" + str(c), "x" + str(a))

    return dimod.BinaryQuadraticModel(linear, quadratic, offset, dimod.BINARY)

In [10]:
def process_result(sample, dim):
    """Split a sample into its x, y and z value vectors.

    Args:
        sample: mapping that contains the keys ``"x<i>"``, ``"y<i>"`` and
            ``"z<i>"`` for every ``i`` in ``range(dim**2)``; other keys are
            ignored.
        dim: the number of entries read per vector is ``dim**2``.

    Returns:
        Three numpy arrays (x, y, z) in index order. When every value read
        equals 0, or every value read equals 1, three empty lists are
        returned instead so the caller can detect it via ``len``.
    """
    rows = [
        (sample["x" + str(k)], sample["y" + str(k)], sample["z" + str(k)])
        for k in range(dim**2)
    ]
    values = [value for row in rows for value in row]

    # All-zero and all-one assignments are rejected.
    if all(value == 0 for value in values) or all(value == 1 for value in values):
        return [], [], []

    xs, ys, zs = zip(*rows)
    return np.array(xs), np.array(ys), np.array(zs)

In [11]:
def search_matrix_multiplication(initial_tensor, limit, file_name):
    """Iteratively reduce two tensors with sampled rank-one terms and log each step.

    Two tensors are tracked: ``tensor1`` starts as ``initial_tensor`` and is
    driven by ``towards_standard``; ``tensor2`` starts as ``initial_tensor``
    minus the outer product of ``[1, 0, 0, 1]`` with itself three times and is
    driven by ``towards_origo``. In each iteration both models are solved, the
    samples are split into (x, y, z) vectors by ``process_result`` and, if
    both samples were accepted, the outer product x (x) y (x) z is subtracted
    from the corresponding tensor modulo 2. The loop stops early (printing
    "End") once ``tensor2`` has no nonzero entry; ``tensor1`` is not part of
    the stopping test.

    NOTE: the matrix dimension 2 is hard-coded in the model-building calls
    and in the length-4 seed vector.
    NOTE(review): ``solve_bqm_in_leap`` is defined outside this block; its
    return value is only used through ``[0]``, presumably the best sample --
    confirm in utils.

    Args:
        initial_tensor: starting tensor; must support subtraction of a
            4x4x4 numpy array.
        limit: maximum number of iterations.
        file_name: the log is written to ``results//<file_name>.txt``; the
            file is truncated at the start and appended to per accepted step.

    Returns:
        None.
    """
    log_path = "results//" + file_name + ".txt"
    seed = [1, 0, 0, 1]

    tensor1 = initial_tensor
    tensor2 = initial_tensor - np.tensordot(seed, np.tensordot(seed, seed, axes=0), axes=0)

    # Context managers guarantee the handle is closed even if a write fails.
    with open(log_path, "w") as log:
        log.write(str(initial_tensor) + "\n")

    for _ in range(limit):
        bqm1 = towards_standard(tensor1, 2)
        bqm2 = towards_origo(tensor2, 2)
        sample1 = solve_bqm_in_leap(bqm1, sampler="Greedy")[0]
        sample2 = solve_bqm_in_leap(bqm2, sampler="Greedy")[0]
        x1, y1, z1 = process_result(sample1, 2)
        x2, y2, z2 = process_result(sample2, 2)

        # process_result returns empty lists for rejected samples; a step is
        # only applied when both samples were accepted.
        if len(x1) > 0 and len(x2) > 0:
            tensor1 = (tensor1 - np.tensordot(x1, np.tensordot(y1, z1, axes=0), axes=0)) % 2
            tensor2 = (tensor2 - np.tensordot(x2, np.tensordot(y2, z2, axes=0), axes=0)) % 2
            with open(log_path, "a") as log:
                log.write(str(x1) + " " + str(y1) + " " + str(z1) + "\n")
                log.write(str(tensor1) + "\n")
                log.write(str(x2) + " " + str(y2) + " " + str(z2) + "\n")
                log.write(str(tensor2) + "\n")
            if np.count_nonzero(tensor2.flatten()) == 0:
                print("End")
                break

In [12]:
# Run at most 3 search iterations starting from initial_tensor; the log is
# written to results//strassen.txt.
search_matrix_multiplication(initial_tensor, 3, "strassen")

Energy:  30.0
Energy:  30.0
Energy:  15.0
Energy:  15.0
Energy:  0.0
Energy:  0.0
End
