# Compression of Weights

In this notebook we use the `tensorly` library to compress the weights of neural networks, starting with the *PARAFAC* (CP) and *Tucker* decompositions. The notebook takes as input a set of weights in NumPy format, stored as a dictionary that maps each parameter name to its tensor, and outputs a dictionary in the same format.
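To see where the compression comes from, one can count the stored numbers. For an $I_1 \times \dots \times I_N$ tensor, a rank-$R$ CP (PARAFAC) model keeps $N$ factor matrices of size $I_n \times R$ (tensorly's `parafac` also returns $R$ component weights). A Tucker model with multilinear rank $(R_1, \dots, R_N)$ keeps an $R_1 \times \dots \times R_N$ core plus $N$ factor matrices of size $I_n \times R_n$. Here is a small sketch of this arithmetic, using an illustrative convolution-kernel shape that does not come from any particular network:

```python
from math import prod


def dense_params(shape):
    # A dense tensor stores one number per entry
    return prod(shape)


def cp_params(shape, rank):
    # One (I_n x R) factor matrix per mode, plus R component weights
    return sum(dim * rank for dim in shape) + rank


def tucker_params(shape, ranks):
    # An (R_1 x ... x R_N) core plus one (I_n x R_n) factor matrix per mode
    return prod(ranks) + sum(dim * r for dim, r in zip(shape, ranks))


# Illustrative 4-D kernel: 64 output channels, 32 input channels, 3x3 window
shape = (64, 32, 3, 3)
print(dense_params(shape))                 # 18432
print(cp_params(shape, 2))                 # 206
print(tucker_params(shape, (2, 2, 2, 2)))  # 220
```

The cells below write back the dense reconstructions, so the saved file keeps its original size. These savings only materialize if the factors themselves are stored instead.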

In [1]:
#@title Install Dependencies
#@markdown Run this cell to install all the needed dependencies.
!pip -q install tensorly

In [2]:
#@title Import Modules
#@markdown We import the decomposition functions, together with the matching
#@markdown reconstruction functions, used to reduce the rank of the tensors.
import joblib
from tensorly.decomposition import parafac, tucker
from tensorly.cp_tensor import cp_to_tensor
from tensorly.tucker_tensor import tucker_to_tensor
from google.colab import files

# each method maps to its (decomposition, reconstruction) pair of functions
methods = {'parafac': (parafac, cp_to_tensor),
           'tucker': (tucker, tucker_to_tensor)
           }
method = "parafac" #@param ['parafac', 'tucker'] {type: "string"}
rank = 2 #@param {type: "integer"}
init = "random" #@param ['random', 'svd'] {type: "string"}
decomp, recons = methods[method]
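Each entry of `methods` pairs a decomposition with the function that turns its factors back into a dense tensor of the original shape. To illustrate the Tucker pair, here is a minimal NumPy sketch of both halves. It uses truncated HOSVD (the leading left singular vectors of each mode unfolding, then a projection onto them). This is a simpler, non-iterative method than tensorly's `tucker`, which refines the factors iteratively, so the numbers will not match exactly. The helper names are our own and are not part of tensorly's API.

```python
import numpy as np


def unfold(tensor, mode):
    # Mode-n unfolding: rows indexed by `mode`, columns by the other modes
    return np.moveaxis(tensor, mode, 0).reshape(tensor.shape[mode], -1)


def mode_dot(tensor, matrix, mode):
    # Multiply `tensor` along axis `mode` by `matrix` (shape: new_dim x old_dim)
    return np.moveaxis(np.tensordot(matrix, tensor, axes=(1, mode)), 0, mode)


def truncated_hosvd(tensor, ranks):
    # Factor n: leading left singular vectors of the mode-n unfolding
    factors = []
    for mode, r in enumerate(ranks):
        u, _, _ = np.linalg.svd(unfold(tensor, mode), full_matrices=False)
        factors.append(u[:, :r])
    # Core: project the tensor onto each factor subspace (U_n^T along mode n)
    core = tensor
    for mode, u in enumerate(factors):
        core = mode_dot(core, u.T, mode)
    return core, factors


def tucker_to_dense(core, factors):
    # Reconstruction: multiply the core by U_n along every mode n
    out = core
    for mode, u in enumerate(factors):
        out = mode_dot(out, u, mode)
    return out


# Usage: a tensor with exact multilinear rank (2, 3, 2) is recovered exactly
rng = np.random.default_rng(0)
core_true = rng.standard_normal((2, 3, 2))
factors_true = [rng.standard_normal((dim, r)) for dim, r in [(5, 2), (6, 3), (4, 2)]]
tensor = tucker_to_dense(core_true, factors_true)

core, factors = truncated_hosvd(tensor, (2, 3, 2))
approx = tucker_to_dense(core, factors)
print(core.shape)  # (2, 3, 2)
print(np.linalg.norm(tensor - approx) / np.linalg.norm(tensor))  # zero up to rounding
```

With smaller ranks than the true ones, the same code returns the projection onto the leading singular subspaces. This is a lossy approximation, and it is the kind of trade-off the `rank` parameter controls.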

In [3]:
#@title Upload the Weights
#@markdown We first upload a `joblib` file containing a dictionary that maps
#@markdown each name to its tensor; only the first uploaded file is loaded.
upload = files.upload()
tensors_dict = joblib.load(list(upload.keys())[0])
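The decomposition cell only processes entries that have more than two dimensions and whose name contains `"weight"`. Biases and 2-D fully connected matrices are therefore left untouched. A small sketch with a hypothetical toy dictionary (the layer names and shapes are made up) shows the expected input format and which entries would be selected:

```python
import numpy as np

# Hypothetical toy weights in the expected name -> array format
rng = np.random.default_rng(0)
toy_weights = {
    "conv1.weight": rng.standard_normal((16, 3, 3, 3)).astype(np.float32),
    "conv1.bias": rng.standard_normal(16).astype(np.float32),
    "fc.weight": rng.standard_normal((10, 16)).astype(np.float32),
}


def is_selected(name, tensor):
    # Same rule as the decomposition cell: more than 2 dims and a "weight" entry
    return tensor.ndim > 2 and "weight" in name


selected = [name for name, tensor in toy_weights.items() if is_selected(name, tensor)]
print(selected)  # ['conv1.weight']

# To try the notebook on this toy data, one could save it with
# joblib.dump(toy_weights, "toy_weights.joblib") and upload that file.
```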

In [None]:
#@title Compute the Decomposition
#@markdown Compute the decomposition and reconstruction of the network weights.
#@markdown Since this is a **memory-consuming** operation, `limit` caps the
#@markdown number of tensors processed so that the computation can finish
#@markdown **without crashing the notebook** (`limit` $\le 0$ disables the cap).
limit = -1 #@param {type: "integer"}
scanned_tensors = 0
for name, tensor in tensors_dict.items():

  # stop once `limit` tensors have been processed (limit <= 0: no cap)
  if limit > 0 and scanned_tensors >= limit: break

  if tensor.ndim > 2 and "weight" in name:

    print(f'Scanning {name} of shape {list(tensor.shape)}...')

    # decomposition
    tensor_decomp = decomp(tensor, rank, init=init)

    # recompose the tensor
    tensor_recons = recons(tensor_decomp)

    # substitute the original tensor with its low-rank reconstruction
    tensors_dict[name] = tensor_recons

    # counter
    scanned_tensors += 1

  else:

    print(f'> Skipping {name}...')

In [None]:
#@title Download the Tensors
#@markdown After the decomposition, we save the modified dictionary (selected
#@markdown weights replaced by their low-rank reconstructions) and download it.
weights_output = "weights_modified.joblib" #@param {type: "string"}
path = joblib.dump(tensors_dict, weights_output)
print(f'Weights saved to {path[0]}. Beginning download...')
files.download(path[0])
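The output is meant to be a drop-in replacement for the input dictionary. It can therefore be worth checking that no key or shape changed, and whether any dtype did, since a decomposition routine may compute in a different floating-point precision than the stored weights. The decomposition cell overwrites `tensors_dict` in place, so a shallow copy such as `original = dict(tensors_dict)` would have to be taken before running it. A small sketch of such a check (`format_mismatches` is a hypothetical helper, not part of any library):

```python
import numpy as np


def format_mismatches(original, modified):
    # Report differences in keys, shapes and dtypes between two name -> array dicts
    problems = []
    if original.keys() != modified.keys():
        problems.append("key sets differ")
    for name in sorted(original.keys() & modified.keys()):
        before, after = np.asarray(original[name]), np.asarray(modified[name])
        if before.shape != after.shape:
            problems.append(f"{name}: shape {before.shape} -> {after.shape}")
        elif before.dtype != after.dtype:
            problems.append(f"{name}: dtype {before.dtype} -> {after.dtype}")
    return problems


# Usage with toy data: the replacement keeps the shape but is float64
original = {"conv.weight": np.zeros((4, 3, 3, 3), dtype=np.float32)}
modified = {"conv.weight": np.zeros((4, 3, 3, 3))}
print(format_mismatches(original, modified))
# ['conv.weight: dtype float32 -> float64']
```

To fix a dtype mismatch before saving, use `modified[name].astype(original[name].dtype)`.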