In [1]:
import math
import numpy as np
from torch import nn
import torch

from tltorch.functional import factorized_linear
from tltorch.factorized_tensors import TensorizedTensor

from tltorch.factorized_layers.factorized_linear import FactorizedLinear

import tensorly as tl



For a **linear layer**  
$\mathbf{y = Wx + b}$  
where  
$\mathbf{y} \in \mathbb{R}^M$  
$\mathbf{W} \in \mathbb{R}^{M \times N}$  
$\mathbf{x} \in \mathbb{R}^N$  
$\mathbf{b} \in \mathbb{R}^M$


its **tensorized linear layer** is  
$\mathcal{Y = WX + B}$  
where  
$\mathcal{Y} \in \mathbb{R}^{m_1 \times m_2 \times \cdots m_{d_M}}$  
$\mathcal{W} \in \mathbb{R}^{m_1 \times m_2 \times \cdots m_{d_M} \times n_1 \times n_2 \times \cdots n_{d_N}}$  
$\mathcal{X} \in \mathbb{R}^{n_1 \times n_2 \times \cdots n_{d_N}}$  
$\mathcal{B} \in \mathbb{R}^{m_1 \times m_2 \times \cdots m_{d_M}}$  
$M = \prod_{k=1}^{d_M} m_k$  
$N = \prod_{k=1}^{d_N} n_k$  

In [2]:
fl = FactorizedLinear(in_tensorized_features=(3, 2, 2, 4), out_tensorized_features=(2, 2, 4), bias=True, factorization='cp', rank=10, n_layers=1, factorized_forward=False)

In this example    
`out_features` $M=16$ and `out_tensorized_features` $(m_1, m_2, m_3) = (2, 2, 4)$  
`in_features` $N=48$ and `in_tensorized_features` $(n_1, n_2, n_3, n_4) = (3, 2, 2, 4)$  


Therefore  
$\mathbf{W}$ has the `weight_shape` of $(M, N) = (16, 48)$  
$\mathcal{W}$ has the `tensorized_shape` of $(m_1, m_2, m_3, n_1, n_2, n_3, n_4) = (2, 2, 4, 3, 2, 2, 4)$  with the the `order` of 7  

In [3]:
print('out_features: {}'.format(fl.out_features))
print('in_features: {}'.format(fl.in_features))
print('out_tensorized_features: {}'.format(fl.out_tensorized_features))
print('in_tensorized_features: {}'.format(fl.in_tensorized_features))
print('weight_shape: {}'.format(fl.weight_shape))
print('tensorized_shape: {}'.format(fl.tensorized_shape))
print('order: {}'.format(fl.weight.order))

out_features: 16
in_features: 48
out_tensorized_features: (2, 2, 4)
in_tensorized_features: (3, 2, 2, 4)
weight_shape: (16, 48)
tensorized_shape: (2, 2, 4, 3, 2, 2, 4)
order: 7


$\mathcal{W}$ is **factorized** into `rank` $R=10$ `factors` = [`out_factors`, `in_factors`] using **CP-decomposition**

This can be expressed as  
$\mathcal{W} = \sum_{r=1}^{R} \mathbf{gm}_1[:,r] \otimes \mathbf{gm}_2[:,r] \otimes \mathbf{gm}_3[:,r] \otimes \mathbf{gn}_1[:,r] \otimes \mathbf{gn}_2[:,r] \otimes \mathbf{gn}_3[:,r] \otimes \mathbf{gn}_4[:,r]$  
where  
`out_factors` $\mathbf{gm}_k \in \mathbb{R}^{m_k \times R}\ \forall k \in [1,2,...,d_M]$  
`in_factors` $\mathbf{gn}_k \in \mathbb{R}^{n_k \times R}\ \forall k \in [1,2,...,d_N]$  

In [4]:
print('decomposition: {}'.format(fl.weight.name))
print('rank: {}'.format(fl.rank))
print('factors: {}'.format(fl.weight.factors))
out_factors = fl.weight.factors[:len(fl.out_tensorized_features)]
print("out_factors: {}".format(out_factors))
in_factors = fl.weight.factors[len(fl.out_tensorized_features):]
print("in_factors: {}".format(in_factors))

decomposition: CP
rank: 10
factors: FactorList(
    (factor_0): Parameter containing: [torch.FloatTensor of size 2x10]
    (factor_1): Parameter containing: [torch.FloatTensor of size 2x10]
    (factor_2): Parameter containing: [torch.FloatTensor of size 4x10]
    (factor_3): Parameter containing: [torch.FloatTensor of size 3x10]
    (factor_4): Parameter containing: [torch.FloatTensor of size 2x10]
    (factor_5): Parameter containing: [torch.FloatTensor of size 2x10]
    (factor_6): Parameter containing: [torch.FloatTensor of size 4x10]
)
out_factors: FactorList(
    (factor_0): Parameter containing: [torch.FloatTensor of size 2x10]
    (factor_1): Parameter containing: [torch.FloatTensor of size 2x10]
    (factor_2): Parameter containing: [torch.FloatTensor of size 4x10]
)
in_factors: FactorList(
    (factor_0): Parameter containing: [torch.FloatTensor of size 3x10]
    (factor_1): Parameter containing: [torch.FloatTensor of size 2x10]
    (factor_2): Parameter containing: [torch.Fl

The original **tltorch** `FactorizedLinear` implementation reconstructs the $\mathbf{W}$ from the `factors` and use regular `Linear` layer during the forward propagation

In [5]:
vector_input = torch.rand(size=(fl.in_features,))
regular_forward_output = fl.forward(vector_input)
print('Regular Forward Propagation Output:\n{}'.format(regular_forward_output))

Regular Forward Propagation Output:
tensor([ 0.0068,  0.1732,  0.0455, -0.1927,  0.2168,  0.0134,  0.2249,  0.7036,
        -0.1213,  0.3807, -0.6335, -1.5395, -0.0290, -0.1962, -0.4274,  1.3908],
       grad_fn=<AddBackward0>)




In order to do **tensorized forward propagation** and **factorized forward propagation**, we need to tensorize $\mathbf{x}$ into $\mathcal{X}$ and $\mathbf{b}$ into $\mathcal{B}$ using the same **bijective mapping functions** that tensorize $\mathbf{W}$ into $\mathcal{W}$  

`out_index_to_tensorized_out_index` $\mathbf{f}_i: \mathbb{Z}_+ \rightarrow \mathbb{Z}_+^{d_M}$ is a function that transforms `out_index` $p \in \{1,2,...,M\}$ into `tensorized_out_index` $\mathbf{f}_i(p)=[i_1(p),i_2(p),...,i_M(p)]$ 

In [6]:
out_indices = torch.arange(fl.out_features)
tensorized_out_indices = tl.base.vec_to_tensor(vec=out_indices, shape=fl.out_tensorized_features)

def out_index_to_tensorized_out_index(out_index):
    tensorized_out_index = (tensorized_out_indices == out_index).nonzero().squeeze().tolist()
    tensorized_out_index = tuple(tensorized_out_index)
    return tensorized_out_index

out_index = torch.randint(low=0, high=fl.out_features, size=(1,)).item()
tensorized_out_index = out_index_to_tensorized_out_index(out_index=out_index)
print("Tensorized Out Index for {}: {}".format(out_index, tensorized_out_index))
out_index_check = tensorized_out_indices[tensorized_out_index]
print("Out Index for {}: {}".format(tensorized_out_index, out_index_check))

Tensorized Out Index for 13: (1, 1, 1)
Out Index for (1, 1, 1): 13


`in_index_to_tensorized_in_index` $\mathbf{f}_j: \mathbb{Z}_+ \rightarrow \mathbb{Z}_+^{d_N}$ is a function that transforms `in_index` $q \in \{1,2,...,N\}$ into `tensorized_in_index` $\mathbf{f}_j(q)=[j_1(q),j_2(q),...,j_N(q)]$  

In [7]:
in_indices = torch.arange(fl.in_features)
tensorized_in_indices = tl.base.vec_to_tensor(vec=in_indices, shape=fl.in_tensorized_features)

def in_index_to_tensorized_in_index(in_index):
    tensorized_in_index = (tensorized_in_indices == in_index).nonzero().squeeze().tolist()
    tensorized_in_index = tuple(tensorized_in_index)
    return tensorized_in_index


in_index = torch.randint(low=0, high=fl.in_features, size=(1,)).item()
tensorized_in_index = in_index_to_tensorized_in_index(in_index=in_index)
print("Tensorized In Index for {}: {}".format(in_index, tensorized_in_index))
in_index_check = tensorized_in_indices[tensorized_in_index]
print("In Index for {}: {}".format(tensorized_in_index, in_index_check))

Tensorized In Index for 23: (1, 0, 1, 3)
In Index for (1, 0, 1, 3): 23


A reality check that $\mathbf{W}$ and $\mathcal{W}$ have equal elements

In [8]:
matrix_weight = fl.weight.to_matrix()
tensor_weight = fl.weight.to_tensor()
print('Matrix form element sum: {}'.format(matrix_weight.sum().item()))
print('Tensor form element sum: {}'.format(tensor_weight.sum().item()))

Matrix form element sum: -1.3805129528045654
Tensor form element sum: -1.3805129528045654


Another reality check that $\mathbf{W}(p,q)$ equals $\mathcal{W}(\mathbf{f}_i(p),\mathbf{f}_j(q))$  

In [9]:
matrix_index = (out_index, in_index)
print('Matrix form element at (out_index, in_index): {}'.format(matrix_weight[matrix_index].item()))
tensorized_index = tensorized_out_index + tensorized_in_index
print('Tensor form element at (tensorized_out_index, tensorized_in_index): {}'.format(tensor_weight[tensorized_index].item()))

Matrix form element at (out_index, in_index): -0.01583530567586422
Tensor form element at (tensorized_out_index, tensorized_in_index): -0.01583530567586422


Final reality check that $\mathbf{x}(p,q)$ equals $\mathcal{X}(\mathbf{f}_i(p),\mathbf{f}_j(q))$ and $\mathbf{b}(p,q)$ equals $\mathcal{B}(\mathbf{f}_i(p),\mathbf{f}_j(q))$  

In [10]:
tensorized_input = tl.base.vec_to_tensor(vec=vector_input, shape=fl.in_tensorized_features)
vector_bias = fl.bias
tensorized_bias = tl.base.vec_to_tensor(vec=vector_bias, shape=fl.out_tensorized_features)
print('Vector form input element at (in_index): {}'.format(tensorized_input[tensorized_in_index]))
print('Tensor form input element at (tensorized_in_index): {}'.format(tensorized_input[tensorized_in_index]))
print('Vector form bias element at (out_index): {}'.format(tensorized_bias[tensorized_out_index]))
print('Tensor form bias element at (tensorized_out_index): {}'.format(tensorized_bias[tensorized_out_index]))

Vector form input element at (in_index): 0.02932971715927124
Tensor form input element at (tensorized_in_index): 0.02932971715927124
Vector form bias element at (out_index): -0.04427202790975571
Tensor form bias element at (tensorized_out_index): -0.04427202790975571


**Tensorized forward propagation** is  
$\mathcal{Y}(\mathbf{f}_i(p)) = \sum_{q=1}^N \mathcal{W}(\mathbf{f}_i(p),\mathbf{f}_j(q)) \mathcal{X}(\mathbf{f}_j(q)) + \mathcal{B}(\mathbf{f}_i(p))$  

In [11]:
dims = torch.arange(fl.weight.order)
in_dims = tuple(dims[len(fl.out_tensorized_features):].tolist())
tensorized_forward_output = (tensor_weight * tensorized_input).sum(dim=in_dims) + tensorized_bias
print('Tensorized Forward Propagation Output:\n{}'.format(tensorized_forward_output))
tensorized_regular_forward_output = tl.base.vec_to_tensor(vec=regular_forward_output, shape=fl.out_tensorized_features)
print('Tensorized Regular Forward Output:\n{}'.format(tensorized_regular_forward_output))

Tensorized Forward Propagation Output:
tensor([[[ 0.0068,  0.1732,  0.0455, -0.1927],
         [ 0.2168,  0.0134,  0.2249,  0.7036]],

        [[-0.1213,  0.3807, -0.6335, -1.5395],
         [-0.0290, -0.1962, -0.4274,  1.3908]]], grad_fn=<AddBackward0>)
Tensorized Regular Forward Output:
tensor([[[ 0.0068,  0.1732,  0.0455, -0.1927],
         [ 0.2168,  0.0134,  0.2249,  0.7036]],

        [[-0.1213,  0.3807, -0.6335, -1.5395],
         [-0.0290, -0.1962, -0.4274,  1.3908]]],
       grad_fn=<ReshapeAliasBackward0>)


We can reduce the use of mapping functions with the following equation  

$\mathbf{y}(p) = \sum_{q=1}^N \mathcal{W}(\mathbf{f}_i(p),\mathbf{f}_j(q)) \mathbf{x}(q) + \mathbf{b}(p)$  

This can be re-written in **factorized forward propagation** as  
$\mathbf{y}(p) = \sum_{q=1}^N \left( \sum_{r=1}^R \left( \prod_{k=1}^{d_M} \mathbf{gm}_{k,r}(i_k(p)) \prod_{k=1}^{d_N} \mathbf{gm}_{k,r}(j_k(q)) \right) \right) \mathbf{x}(q) + \mathbf{b}(p)$  

In [12]:
factorized_forward_output = torch.zeros(size=(fl.out_features,))
for out_index in range(fl.out_features):
    for in_index in range(fl.in_features):
        out = 1
        tensorized_out_index = out_index_to_tensorized_out_index(out_index=out_index)
        for (factor, index) in zip(out_factors, tensorized_out_index):
            out *= factor[index]
        tensorized_in_index = in_index_to_tensorized_in_index(in_index=in_index)
        for (factor, index) in zip(in_factors, tensorized_in_index):
            out *= factor[index]
        out = out.sum()
        out *= vector_input[in_index]
        factorized_forward_output[out_index] += out
    factorized_forward_output[out_index] += vector_bias[out_index]

In [13]:
print('Regular Foward Output:\n{}'.format(regular_forward_output))
print('Factorized Forward Output:\n{}'.format(factorized_forward_output))

Regular Foward Output:
tensor([ 0.0068,  0.1732,  0.0455, -0.1927,  0.2168,  0.0134,  0.2249,  0.7036,
        -0.1213,  0.3807, -0.6335, -1.5395, -0.0290, -0.1962, -0.4274,  1.3908],
       grad_fn=<AddBackward0>)
Factorized Forward Output:
tensor([ 0.0068,  0.1732,  0.0455, -0.1927,  0.2168,  0.0134,  0.2249,  0.7036,
        -0.1213,  0.3807, -0.6335, -1.5395, -0.0290, -0.1962, -0.4274,  1.3908],
       grad_fn=<CopySlices>)


In [14]:
fl = FactorizedLinear(in_tensorized_features=(3, 2, 2, 4), out_tensorized_features=(2, 2, 4), bias=True, factorization='cp', rank=10, n_layers=1, factorized_forward=True)

In [15]:
fl.forward(vector_input)

tensor([ 0.1718, -0.0702, -0.1124, -0.2579, -0.1783, -0.0350,  0.1880,  0.0045,
        -0.0248,  0.0566, -0.2910,  0.0991,  0.0888, -0.2756, -0.0525, -0.3440],
       grad_fn=<CopySlices>)