<a href="https://colab.research.google.com/github/vnicula/ai-research-code/blob/master/algorithms/Copy_of_Explore_factorizations.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Loading factorizations found by AlphaTensor, checking them, and visualizing how their rank-one terms recombine into the matrix multiplication tensor.

- Copyright 2022 DeepMind Technologies Limited
- All software is licensed under the Apache License, Version 2.0 (Apache 2.0); you may not use this file except in compliance with the Apache 2.0 license. You may obtain a copy of the Apache 2.0 license at: https://www.apache.org/licenses/LICENSE-2.0
- All other materials are licensed under the Creative Commons Attribution 4.0 International License (CC-BY).  You may obtain a copy of the CC-BY license at: https://creativecommons.org/licenses/by/4.0/legalcode
- Unless required by applicable law or agreed to in writing, all software and materials distributed here under the Apache 2.0 or CC-BY licenses are distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the licenses for the specific language governing permissions and limitations under those licenses.
- This is not an official Google product.

In [None]:
import numpy as np
from google.colab import files

Upload one of the two files provided in the same folder: `factorization_r.npz` (algorithms in standard arithmetic) or `factorization_f2.npz` (algorithms in arithmetic modulo 2).

In [None]:
# files.upload() opens a file picker in Colab; only the first chosen file is used.
uploaded = files.upload()
filename = [*uploaded][0]
# The (u, v, w) triples have different shapes per key, so they are presumably
# stored as object arrays, which np.load only reads with allow_pickle=True.
# NOTE(review): unpickling executes arbitrary code -- only load trusted files.
with open(filename, mode='rb') as fh:
  factorizations = dict(np.load(fh, allow_pickle=True))

In [None]:
# List every factorization in the file together with its rank (number of
# rank-one terms, i.e. the number of columns shared by u, v and w).
for key, (u, v, w) in factorizations.items():
  rank = u.shape[-1]
  assert rank == v.shape[-1] and rank == w.shape[-1]
  print(f'{key}: rank={rank}')

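Please note that, as provided, the factorizations decompose the *symmetrized* version of the matrix multiplication tensor, representing the bilinear operation $\mathbf{A}, \mathbf{B} \mapsto (\mathbf{A} \cdot \mathbf{B})^T$. Concretely, for $\mathbf{A}$ of size $a \times b$ and $\mathbf{B}$ of size $b \times c$, the tensor $T_{a,b,c}$ built below has a 1 at index $(ib + j,\; jc + k,\; ka + i)$ for every $0 \le i < a$, $0 \le j < b$, $0 \le k < c$, and 0 elsewhere. This is standard in the literature, and factorizations can easily be converted between the symmetrized and non-symmetrized versions.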

In [None]:
def get_mamu_tensor_rectangular(a: int, b: int, c: int) -> np.ndarray:
  """Returns the symmetrized matrix multiplication tensor T_{a, b, c}.

  The result is an int32 array of shape (a*b, b*c, c*a) that is 1 at
  (i*b + j, j*c + k, k*a + i) for every i < a, j < b, k < c and 0 elsewhere.
  The last index pairs (k, i) rather than (i, k), which is what makes the
  tensor describe (A @ B)^T instead of A @ B.
  """
  tensor = np.zeros((a * b, b * c, c * a), dtype=np.int32)
  # All (i, j, k) index triples at once; each maps to a distinct cell, so one
  # fancy-indexed assignment sets every 1 without a Python-level loop.
  i, j, k = np.indices((a, b, c))
  tensor[i * b + j, j * c + k, k * a + i] = 1
  return tensor


# Check that the rank-one terms of the 2x2x3 factorization sum back to
# T_{2,2,3}: first over the integers, and failing that modulo 2 (for
# factorizations found in F2 arithmetic).
tensor = get_mamu_tensor_rectangular(2, 2, 3)
u, v, w = factorizations['2,2,3']
reconstruction = np.einsum('ir,jr,kr->ijk', u, v, w)
if np.array_equal(tensor, reconstruction):
  verdict = 'Factorization is correct in R (standard arithmetic).'
elif np.array_equal(tensor, reconstruction % 2):
  verdict = 'Factorization is correct in F2 (modular arithmetic).'
else:
  verdict = 'Factorization is incorrect.'
print(verdict)

In [None]:
# u is a 2-D factor matrix: column r holds the u-factor of rank-one term r.
print("u factors (one column per rank-one term):\n", u)

In [None]:
# v is a 2-D factor matrix: column r holds the v-factor of rank-one term r.
print("v factors (one column per rank-one term):\n", v)

In [None]:
# w is a 2-D factor matrix: column r holds the w-factor of rank-one term r.
print("w factors (one column per rank-one term):\n", w)

In [None]:
# Dense symmetrized matrix multiplication tensor that u, v, w should reconstruct.
print("matrix multiplication tensor:\n", tensor)

In [None]:
# Rank-one terms built from the first two columns of u, v, w, kept as globals
# for interactive inspection.
u0, v0, w0 = u[:, 0], v[:, 0], w[:, 0]
c0 = np.einsum('i,j,k->ijk', u0, v0, w0)

u1, v1, w1 = u[:, 1], v[:, 1], w[:, 1]
c1 = np.einsum('i,j,k->ijk', u1, v1, w1)

# Dump the u, v and w factor of every rank-one term, one term after another.
for i in range(u.shape[1]):
  for factor in (u, v, w):
    print(factor[:, i])

In [None]:
import matplotlib.colors
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from matplotlib import cm
import numpy as np


def explode(data):
  """Spread a 3-D array onto a grid twice as fine, with empty cells between.

  The result has shape 2 * data.shape - 1 and the same dtype as ``data``.
  Input element (x, y, z) lands at (2x, 2y, 2z); every position with an odd
  index on any axis holds the dtype's zero value (0, False, or '').
  """
  exploded = np.zeros(2 * np.asarray(data.shape) - 1, dtype=data.dtype)
  exploded[::2, ::2, ::2] = data
  return exploded

def plot_tensor(n_voxels, ind=0, save=False, A=False):
  """Draw a 3-D tensor as a grid of coloured, slightly separated voxels.

  Every cell of ``n_voxels`` is drawn as a cube. Zero entries are fully
  transparent. Entries equal to 1, 2 or greater than 2 get increasingly dark
  blue shades. Entries equal to -1, -2 or less than -2 get red, orange and
  maroon. Any other non-zero value (e.g. 0.5) is drawn in yellow.

  Args:
    n_voxels: 3-D array of tensor entries to draw.
    ind: Frame index; only used to name the PNG when ``save`` is true.
    save: If true, write the figure to ``frame_{ind}.png`` instead of showing
      it interactively.
    A: If true, label the axes A/B/C; otherwise label them U/V/W.
      ``plot_solution`` passes true only for the initial tensor.
  """
  # Start with yellow for any non-zero entry and transparent for zeros, then
  # override the integer magnitudes of interest. np.where picks a fixed-width
  # string dtype ('<U9'), so every colour assigned below must be at most 9
  # characters ('#RRGGBBAA'); a longer one would be silently truncated.
  facecolors = np.where(n_voxels, '#FFD65DA0', '#EEEEFF00')
  facecolors[n_voxels > 2] = '#00004d70'
  facecolors[n_voxels == 2] = '#2e5dcc70'
  facecolors[n_voxels == 1] = '#5DADE270'
  facecolors[n_voxels == -1] = '#CD615570'
  facecolors[n_voxels == -2] = '#ff661f70'
  facecolors[n_voxels < -2] = '#80000070'
  facecolors[n_voxels == 0] = '#EEEEFF00'

  edgecolors = np.where(n_voxels != 0, '#DDDDDE30', '#CDDDDD3F')
  # Draw every cell (zeros are merely transparent) so the full grid is visible.
  # Axes3D.voxels expects a boolean mask.
  filled = np.ones(n_voxels.shape, dtype=bool)

  # Upscale the voxel image, leaving an empty (unfilled) cell between neighbours.
  filled_2 = explode(filled)
  fcolors_2 = explode(facecolors)
  ecolors_2 = explode(edgecolors)

  # Corner coordinates of the exploded grid, integer-divided by 2, map back to
  # original-voxel coordinates. Nudging even (start) corners by +0.03 and odd
  # (end) corners by +0.97 makes each cube span [n + 0.03, n + 0.97]. The
  # unfilled spacer cells then shrink to thin gaps.
  x, y, z = np.indices(np.array(filled_2.shape) + 1).astype(float) // 2
  x[0::2, :, :] += 0.03
  y[:, 0::2, :] += 0.03
  z[:, :, 0::2] += 0.03
  x[1::2, :, :] += 0.97
  y[:, 1::2, :] += 0.97
  z[:, :, 1::2] += 0.97

  fig = plt.figure(figsize=(11, 11))
  ax = fig.add_subplot(projection='3d')
  ax.grid(False)
  # One tick per original voxel, but no tick labels.
  for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
    axis.set_major_locator(ticker.MultipleLocator(1))
    axis.set_major_formatter(ticker.NullFormatter())

  xlabel, ylabel, zlabel = ('A', 'B', 'C') if A else ('U', 'V', 'W')
  ax.set_xlabel(xlabel)
  ax.set_ylabel(ylabel)
  ax.set_zlabel(zlabel)

  ax.voxels(x, y, z, filled_2, facecolors=fcolors_2, edgecolors=ecolors_2)

  if save:
    fig.savefig(f'frame_{ind}.png')
  else:
    plt.show()
  plt.close(fig)

In [None]:
def plot_solution(mamu, u, v, w, save=False):
  """Visualise a factorization by peeling its rank-one terms off the tensor.

  For each rank-one term r (column r of u, v and w) this prints the step
  number and draws the current residual tensor (frame 2*r). It then draws the
  term itself (frame 2*r + 1) and subtracts the term from the residual. A last
  frame (2*rank) shows what is left. That remainder is all zeros for a
  factorization that is exact over the integers. For an F2 factorization it
  is only zero modulo 2.

  With save=True, frames frame_0.png ... frame_{2*rank}.png are written.
  """
  residual = mamu
  rank = len(u[0])
  for r in range(rank):
    print("Step %d" % r)
    plot_tensor(residual, 2 * r, save, r == 0)
    term = np.einsum('i,j,k->ijk', u[:, r], v[:, r], w[:, r])
    plot_tensor(term, 2 * r + 1, save)
    # Deliberately not in-place: on the first step residual *is* mamu.
    residual = residual - term
  plot_tensor(residual, 2 * rank, save)


In [None]:
# Step through the 2x2x3 factorization. Pass save=True to write the
# frame_*.png files that the GIF cell below reads.
u, v, w = factorizations['2,2,3']
tensor = get_mamu_tensor_rectangular(2, 2, 3)
plot_solution(tensor, u, v, w, save=False)

In [None]:
import imageio
import os

# plot_solution(..., save=True) writes 2 * rank + 1 frames: per step, the
# current residual and the rank-one term, plus the final residual. Include all
# of them, not only the first `rank`.
frame_files = [f"frame_{i}.png" for i in range(2 * len(u[0]) + 1)]
# Check up front so a missing frame does not leave a truncated GIF behind.
missing = [name for name in frame_files if not os.path.exists(name)]
if missing:
  raise FileNotFoundError(
      f'Missing frames {missing}; run plot_solution(..., save=True) first.')
with imageio.get_writer('myframes.gif', mode='I') as writer:
  for filename in frame_files:
    image = imageio.imread(filename)
    writer.append_data(image)
print('Gif saved\n')


In [None]:
# Delete any frame_*.png files written by plot_solution(..., save=True):
# frames 0 .. 2 * rank. Missing files are skipped.
for frame in range(2 * len(u[0]) + 1):
  frame_path = f"frame_{frame}.png"
  if os.path.exists(frame_path):
    os.remove(frame_path)