# Exercise 4 - Tensor Networks
In this exercise, we will inspect the canonical parameterization of a graphical model and calculate the normalization constant to answer inference queries.

Later, we will compare the speed of calculating the normalization constant using different orders of tensor contractions.

In the event of a persistent problem, do not hesitate to contact the course instructors at
- paul.kahlmeyer@uni-jena.de

### Submission

- Deadline of submission: 27.11.2022
- Submission on [moodle page](https://moodle.uni-jena.de/course/view.php?id=34630)

### Help
In case you cannot solve a task, you can use the saved values within the `help` directory to continue working on the other tasks:
- Load arrays with [NumPy](https://numpy.org/doc/stable/reference/generated/numpy.load.html)
```
import numpy as np
np.load('help/array_name.npy')
```
- Load functions with [Dill](https://dill.readthedocs.io/en/latest/dill.html)
```
import dill
with open('help/some_func.pkl', 'rb') as f:
    func = dill.load(f)
```

## Graphical Models
Let $p(x)$ be a multivariate categorical on the sample space $\mathcal{X}$.
In the canonical parameterization we define $p$ to be an exponentiated sum of interaction order parameters:
\begin{align}
p(x) = \exp\left(q(x)\right)\,,
\end{align}
where $q(x)$ is a sum of all possible interaction orders
\begin{align}
q(x) = \sum\limits_{k=1}^n\sum\limits_{i=(i_1,\dots,i_k)}q_i(x_{i_1}, \dots, x_{i_k})\,.
\end{align}
In graphical models, we reduce the number of parameters by setting specific interactions $q_i$ to 0.

This notation is a little confusing, so let's work through a **concrete example**.

Consider a multivariate categorical $p(x_0,x_1,x_2,x_3)$.
Furthermore, we restrict ourselves to unary and pairwise interaction orders (interactions of order >2 have been set to 0).

This means that we have unary interaction parameter vectors $q_0, q_1, q_2, q_3$ and pairwise interaction parameter matrices $q_{01}, q_{02}, q_{03}, q_{12}, q_{13}, q_{23}$.
The $q_i$ hold the (unary) interaction parameters for $x_i$ and $q_{ij}$ holds the interaction parameters for $x_i$ and $x_j$.

With these parameters, the canonical parameterization from above looks like this:
\begin{align}
q(x = [v_0, v_1, v_2, v_3]^T) &=\sum_{i=0}^3 q_i[v_i] + \sum_{i=0}^3\sum_{j=i+1}^3 q_{ij}[v_i, v_j]\\
&=q_0[v_0] + q_1[v_1] + q_2[v_2] + q_3[v_3]\\
&+q_{01}[v_0, v_1] + q_{02}[v_0, v_2] + q_{03}[v_0, v_3]\\
&+q_{12}[v_1, v_2]+q_{13}[v_1, v_3]\\
&+q_{23}[v_2, v_3]\,.
\end{align}
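To make the indexing concrete, here is a minimal sketch for a smaller model with only two variables. The parameter values and the names `q_0`, `q_1`, `q_01`, `q_toy`, and `p_unnormalized` are hypothetical and chosen for illustration only; they are not the parameters loaded in Task 1.

```python
import numpy as np

# Hypothetical toy parameters: x_0 has 2 states, x_1 has 3 states.
q_0 = np.array([0.1, -0.2])            # unary parameters for x_0
q_1 = np.array([0.0, 0.3, -0.1])       # unary parameters for x_1
q_01 = np.array([[0.5, 0.0, -0.5],     # pairwise parameters for (x_0, x_1)
                 [0.0, 0.2, 0.1]])

def q_toy(v_0: int, v_1: int) -> float:
    """Sum of the unary and pairwise interaction parameters for x = [v_0, v_1]."""
    return q_0[v_0] + q_1[v_1] + q_01[v_0, v_1]

def p_unnormalized(v_0: int, v_1: int) -> float:
    """Exponentiated parameter sum exp(q(x))."""
    return np.exp(q_toy(v_0, v_1))

# q_0[1] + q_1[2] + q_01[1, 2] = -0.2 - 0.1 + 0.1
value = q_toy(1, 2)
```

The four-variable case works the same way, with four unary lookups and six pairwise lookups per evaluation.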



### Task 1

Load $q_i$ and $q_{ij}$ from the pickled files `q_i.p` and `q_ij.p` respectively.
How large are the sample spaces for each $x_i$?

In [1]:
import dill
import numpy as np
with open('q_i.p', 'rb') as f:
    q_i = dill.load(f)
with open('q_ij.p', 'rb') as f:
    q_ij = dill.load(f)
space_sizes = [len(x) for x in q_i]
space_sizes

[15, 50, 100, 10]

## Normalization Constant

Here we have unnormalized probabilities, so we need to calculate the normalization constant first
\begin{align}
K &= \sum_{x}p(x)\\
&= \sum_{x}\exp\left(q(x)\right)\\
&= \sum_{x}\prod_{i} \exp(q_i[x_i])\prod_{j > i} \exp(q_{ij}[x_i, x_j])\\
&= \sum_{x}\prod_{i} t_i[x_i]\prod_{j > i} t_{ij}[x_i, x_j]\,,
\end{align}
where $t_i = \exp(q_i)$ and $t_{ij} = \exp(q_{ij})$ with the elementwise exponential function.

### Task 2

A straightforward way to calculate this constant is to iterate over every $x$ and sum up the $p(x)$.

Calculate $K$ using for loops.

In [3]:
import itertools
def norm_const_naive(t_i:list, t_ij:list) -> float:
    '''
    Calculates normalization constant by iterating over each x.
    
    @Params:
        t_i... unary interaction parameters (exponentiated)
        t_ij... binary interaction parameters (exponentiated)

    @Returns:
        normalization constant
    '''
    ranges = [range(len(t_i[x])) for x in range(len(t_i))] # Ranges containing indices for each feature
    xs = itertools.product(*ranges) # All possible multi-indices
    s = 0
    for x in xs:
        p = 1
        for i in range(len(t_i)):
            p *= t_i[i][x[i]]
            for j in range(len(t_i)):
                if j > i:
                    p *= t_ij[i][j][x[i]][x[j]]
        s += p
    return s

t_i = [np.exp(x) for x in q_i]
t_ij = [[np.exp(x) for x in q_inner] for q_inner in q_ij]
K = norm_const_naive(t_i, t_ij)
K

159744720.1663634

## Inference Queries

With this normalization constant, we can now actually calculate probabilities and answer inference queries.
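As a rough illustration of what such queries look like, the following sketch builds the full joint table of a hypothetical two-variable model and reads a marginal and an interval probability from it. The tensors `t_0`, `t_1`, `t_01` are random toy values, not the exercise data, and building the full table is only feasible for small sample spaces.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy tensors: x_0 has 2 states, x_1 has 3 states.
t_0 = np.exp(rng.normal(size=2))
t_1 = np.exp(rng.normal(size=3))
t_01 = np.exp(rng.normal(size=(2, 3)))

# Unnormalized joint table via broadcasting:
# joint[v_0, v_1] = t_0[v_0] * t_1[v_1] * t_01[v_0, v_1]
joint = t_0[:, None] * t_1[None, :] * t_01
K_toy = joint.sum()                # normalization constant of the toy model

p_joint = joint / K_toy            # normalized joint distribution
p_x1 = p_joint.sum(axis=0)         # marginal p(x_1): sum out x_0
p_x1_ge_1 = p_x1[1:].sum()         # p(x_1 >= 1): sum over states 1 and 2
```

Note that slicing with `[1:]` includes state 1; a strict inequality such as $p(x_1 > 1)$ would start the slice at index 2.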

### Task 3
Calculate the prior marginal 
\begin{align}
p(x_3)\,.
\end{align}

In [3]:
marginal = []
for x3_idx in range(space_sizes[2]):
    ranges = [range(space_sizes[0]), range(space_sizes[1]), [x3_idx], range(space_sizes[3])]
    xs = itertools.product(*ranges)
    s = 0
    for x in xs:
        p = 1
        for i in range(len(t_i)):
            p *= t_i[i][x[i]]
            for j in range(len(t_i)):
                if j > i:
                    p *= t_ij[i][j][x[i]][x[j]]
        s += p
    marginal.append(s)
marginal/K

array([0.00879772, 0.01169877, 0.00651817, 0.01446938, 0.00709923,
       0.01425781, 0.00753366, 0.01436575, 0.00900863, 0.01185483,
       0.00512198, 0.0087633 , 0.0055628 , 0.00806047, 0.00558564,
       0.00796673, 0.00881161, 0.00508333, 0.01033149, 0.01022891,
       0.00783624, 0.00933844, 0.00577802, 0.01019469, 0.01415153,
       0.00752854, 0.01109822, 0.00543556, 0.01140845, 0.00682193,
       0.00538809, 0.01115068, 0.00699935, 0.01391158, 0.00538041,
       0.01152469, 0.00931455, 0.01101375, 0.0160934 , 0.00721566,
       0.00951947, 0.00878677, 0.01142611, 0.00707002, 0.01209959,
       0.00803305, 0.01332485, 0.01158588, 0.00727241, 0.01215217,
       0.00861954, 0.01421548, 0.00940168, 0.01625568, 0.0118876 ,
       0.01437833, 0.00830323, 0.0146015 , 0.01115906, 0.01009373,
       0.01081118, 0.00500729, 0.00971673, 0.01183612, 0.00739292,
       0.01145389, 0.00939361, 0.00678895, 0.00870947, 0.0069429 ,
       0.01281083, 0.00868186, 0.01071913, 0.01269037, 0.00819

### Task 4

Calculate the probability 
\begin{equation}
p(x_2>20)\,.
\end{equation}

In [4]:
marginal = []
for x2_idx in range(space_sizes[1]):
    ranges = [range(space_sizes[0]), [x2_idx], range(space_sizes[2]), range(space_sizes[3])]
    xs = itertools.product(*ranges)
    s = 0
    for x in xs:
        p = 1
        for i in range(len(t_i)):
            p *= t_i[i][x[i]]
            for j in range(len(t_i)):
                if j > i:
                    p *= t_ij[i][j][x[i]][x[j]]
        s += p
    marginal.append(s)
sum((marginal/K)[20:])

0.5806563926647166

## Tensor Contraction
Calculating $K$ by iterating over every $x$ is quite slow.
Let's look at how we can speed up this calculation.

We can rewrite the calculation of $K$ as

\begin{align}
K &= \sum_{x}p(x)\\
&= \sum_{x}\prod_{i} \exp(q_i[x_i])\prod_{j > i} \exp(q_{ij}[x_i, x_j])\\
&= \sum_{x}\prod_{i} t_i[x_i]\prod_{j > i} t_{ij}[x_i, x_j]\\
&= \sum_{v_0=1}^{n_0}\sum_{v_1=1}^{n_1}\sum_{v_2=1}^{n_2}\sum_{v_3=1}^{n_3}\prod_{i} t_i[v_i]\prod_{j > i} t_{ij}[v_i, v_j]\,.
\end{align}

In this form, calculating the normalization constant boils down to a single tensor contraction. 

Since contracting tensors in numpy is implemented in C under the hood, we can expect a significant speedup.
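For intuition, here is a minimal sketch of the same idea for a hypothetical two-variable model, where the nested sums and the contraction can be compared directly. The tensors are random toy values and the names `K_loops` and `K_einsum` are introduced only for this illustration.

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical toy tensors: x_0 has 2 states, x_1 has 3 states.
t_0 = np.exp(rng.normal(size=2))
t_1 = np.exp(rng.normal(size=3))
t_01 = np.exp(rng.normal(size=(2, 3)))

# Explicit double sum over all states.
K_loops = 0.0
for v_0 in range(2):
    for v_1 in range(3):
        K_loops += t_0[v_0] * t_1[v_1] * t_01[v_0, v_1]

# The same sum as one contraction: no index after '->' means
# that every index is summed over.
K_einsum = np.einsum('i,j,ij->', t_0, t_1, t_01)
```

Both values should agree up to floating point rounding.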

### Task 5
Calculate the normalization constant with a **single** contraction using the [Einstein-Summation](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html).

For a brief introduction into `einsum`, see [here](https://ajcr.net/Basic-guide-to-einsum/) and [here](https://medium.com/ibm-data-ai/einsum-an-easy-intuitive-way-to-write-tensor-operation-9e12b8a80570).

Make sure that your result is correct by comparing it to the naive implementation.

In [37]:
np.einsum('i,j,k,l,ij,ik,il,jk,jl,kl->', *t_i, t_ij[0][1], t_ij[0][2], t_ij[0][3], t_ij[1][2], t_ij[1][3], t_ij[2][3])

159744720.16636375

### Task 6

Compare the execution times of calculating $K$ the naive way vs. using `einsum`.

In [5]:
import time

start_naive = time.time()
norm_const_naive(t_i, t_ij)
end_naive = time.time()

start_smart = time.time()
np.einsum('i,j,k,l,ij,ik,il,jk,jl,kl->', *t_i, t_ij[0][1], t_ij[0][2], t_ij[0][3], t_ij[1][2], t_ij[1][3], t_ij[2][3])
end_smart = time.time()

print(f'Naive took {end_naive-start_naive} seconds, while smart took {end_smart-start_smart} seconds.')

Naive took 3.0428404808044434 seconds, while smart took 0.012710094451904297 seconds.


## Contraction order

We see that using a tensor contraction speeds up the calculation. However, this is not the end of optimization:\
The order of contraction can be permuted, potentially reducing the number of calculations. Here we want to permute the order in which the variables are marginalized out.

For example for two variables $x_0, x_1$:
\begin{align}
K &= \sum_{v_0=1}^{n_0}\sum_{v_1=1}^{n_1} t_0[v_0]t_1[v_1]t_{01}[v_0, v_1]\\
(1) &= \sum_{v_0=1}^{n_0}t_0[v_0]\sum_{v_1=1}^{n_1}t_1[v_1]t_{01}[v_0, v_1]\\
(2) &= \sum_{v_1=1}^{n_1}t_1[v_1]\sum_{v_0=1}^{n_0}t_0[v_0]t_{01}[v_0, v_1]
\end{align}

This sum can be calculated as (1)
1. Contracting $t_{01}$ and $t_{1}$ over the index $x_1$
2. Contracting the result from 1. with $t_0$ over the index $x_0$

or (2)
1. Contracting $t_{01}$ and $t_{0}$ over the index of $x_0$
2. Contracting the result from 1. with $t_1$ over the index of $x_1$

Depending on the tensor dimensions, one calculation can be faster than the other.
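The two orders can be sketched with plain matrix-vector products. The tensors below are random toy values with hypothetical sizes 4 and 6; they only serve to show that both orders give the same constant while producing intermediates of different sizes.

```python
import numpy as np

rng = np.random.default_rng(2)

# Hypothetical toy tensors: x_0 has 4 states, x_1 has 6 states.
t_0 = np.exp(rng.normal(size=4))
t_1 = np.exp(rng.normal(size=6))
t_01 = np.exp(rng.normal(size=(4, 6)))

# Order (1): contract t_01 with t_1 over x_1, then with t_0 over x_0.
inner_1 = t_01 @ t_1        # shape (4,), indexed by v_0
K_order_1 = t_0 @ inner_1

# Order (2): contract t_01 with t_0 over x_0, then with t_1 over x_1.
inner_2 = t_0 @ t_01        # shape (6,), indexed by v_1
K_order_2 = inner_2 @ t_1
```

With only one pairwise tensor, both orders cost about the same; the difference becomes relevant once several pairwise tensors share indices and the intermediates grow.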


### Task 7

Implement the following function that contracts the tensors in a given order.

As an example for three variables, the order

```
['i', 'j', 'k']
```

with the tensor dictionary

```
tensor_dict = {
'i' : t_i,
'j' : t_j,
'k' : t_k,
'ij' : t_ij,
'ik' : t_ik,
'jk' : t_jk
}
```
will perform the following contractions

1. `tmp = np.einsum('i, ij, ik -> jk', t_i, t_ij, t_ik) # marginalize out i`
2. `tmp = np.einsum('j, jk, jk -> k', t_j, t_jk, tmp) # marginalize out j`
3. `tmp = np.einsum('k, k -> ', t_k, tmp) # marginalize out k`
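The three contraction steps listed above can be checked against a single contraction on small random tensors. The names `s_i`, `s_ij`, etc. and the sizes are hypothetical and only used for this sketch, so that they do not clash with the exercise variables.

```python
import numpy as np

rng = np.random.default_rng(3)
n_i, n_j, n_k = 3, 4, 5    # hypothetical sample space sizes

# Hypothetical toy tensors for three variables i, j, k.
s_i, s_j, s_k = rng.random(n_i), rng.random(n_j), rng.random(n_k)
s_ij = rng.random((n_i, n_j))
s_ik = rng.random((n_i, n_k))
s_jk = rng.random((n_j, n_k))

# Stepwise contraction in the order i, j, k.
tmp = np.einsum('i,ij,ik->jk', s_i, s_ij, s_ik)    # marginalize out i
tmp = np.einsum('j,jk,jk->k', s_j, s_jk, tmp)      # marginalize out j
K_stepwise = np.einsum('k,k->', s_k, tmp)          # marginalize out k

# Reference: everything in one contraction.
K_single = np.einsum('i,j,k,ij,ik,jk->', s_i, s_j, s_k, s_ij, s_ik, s_jk)
```

Each step consumes all tensors that carry the marginalized index and leaves an intermediate over the indices that are still open.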

Make sure that the results are correct and compare the times of different marginalization orders to those from Task 6.

In [27]:
def norm_const_order(order:list, tensor_dict:dict) -> float:
    '''
    Calculates the normalization constant using tensor contraction with a specific order.
    
    @Params:
        order... list of variables in the order of their marginalization
        tensor_dict... dict that stores which tensors are for which variable combination
          
    @Returns:
        normalization constant K
    
    '''
    
    result = 1  # scalar intermediate, carries no indices before the first contraction
    last_kept_indices = ''

    tensors = list(tensor_dict.keys())  # tensors that have not been contracted yet

    for index in order:
        # all remaining tensors that have this axis
        related_tensors = [key for key in tensors if index in key]
        # keep every index of the previous intermediate and of the related tensors,
        # except the one that is being marginalized out
        kept_indices = ''
        for key in [last_kept_indices] + related_tensors:
            for r in key:
                if r != index and r not in kept_indices:
                    kept_indices += r
        einsum_string = f'{",".join(related_tensors)},{last_kept_indices} -> {kept_indices}'
        result = np.einsum(einsum_string, *[tensor_dict[x] for x in related_tensors], result)
        last_kept_indices = kept_indices
        tensors = [x for x in tensors if x not in related_tensors]
    return result

dictionary = {'i': t_i[0],
              'j': t_i[1],
              'k': t_i[2],
              'l': t_i[3],
              'ij': t_ij[0][1],
              'ik': t_ij[0][2],
              'il': t_ij[0][3],
              'jk': t_ij[1][2],
              'jl': t_ij[1][3],
              'kl': t_ij[2][3]}

norm_const_order(['i', 'j', 'k', 'l'], dictionary)

159744720.16636378

In [None]:
import itertools
import time

permutations = itertools.permutations(['i', 'j', 'k', 'l'])
minimum = float('inf')  # fastest time measured so far
minimum_perm = None     # order that achieved the fastest time
for perm in permutations:
    start = time.time()
    norm_const_order(perm, dictionary)
    end = time.time()
    measured_time = end - start
    if measured_time < minimum:
        minimum = measured_time
        minimum_perm = perm

print(f'Fastest contraction order: {minimum_perm} with {minimum}s.')

NameError: name 'norm_const_order' is not defined

## Optimal contraction order

We see that the contraction order has a considerable effect on the computation times.

In fact, the problem of finding the best contraction order is generally NP-hard and an active area of research.
In Python, the package [opt_einsum](https://optimized-einsum.readthedocs.io/en/stable/) provides heuristics to find a (near-)optimal contraction order.
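As a side note, NumPy itself also offers a path search through `np.einsum_path` and the `optimize` argument of `np.einsum`. The sketch below uses it on a hypothetical three-variable toy model (the names `s_i`, `s_ij`, etc. are not part of the exercise) to show how a precomputed path can be reused. It is meant as a point of comparison, not as a replacement for `opt_einsum` in Task 8.

```python
import numpy as np

rng = np.random.default_rng(4)

# Hypothetical toy tensors for three variables i, j, k.
s_i, s_j, s_k = rng.random(3), rng.random(4), rng.random(5)
s_ij, s_ik, s_jk = rng.random((3, 4)), rng.random((3, 5)), rng.random((4, 5))

subscripts = 'i,j,k,ij,ik,jk->'
operands = (s_i, s_j, s_k, s_ij, s_ik, s_jk)

# Search for a contraction path with the greedy heuristic.
# 'path' is a list starting with the marker 'einsum_path',
# 'report' is a human-readable summary of the chosen path.
path, report = np.einsum_path(subscripts, *operands, optimize='greedy')

K_default = np.einsum(subscripts, *operands)                   # single contraction
K_optimized = np.einsum(subscripts, *operands, optimize=path)  # reuse the found path
```

Printing `report` shows the pairwise contractions in the order they are performed.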

### Task 8

Use `opt_einsum` to calculate $K$ and make sure the result is correct.
Again measure the execution time and compare to the other methods.

Note: if you are interested, you can use `opt_einsum.contract_path` to have a look at the optimal contraction order that was used.

In [10]:
from opt_einsum import contract
from opt_einsum import contract_path

start = time.time()
contract('i,j,k,l,ij,ik,il,jk,jl,kl->', *t_i, t_ij[0][1], t_ij[0][2], t_ij[0][3], t_ij[1][2], t_ij[1][3], t_ij[2][3])
end = time.time()
path = contract_path('i,j,k,l,ij,ik,il,jk,jl,kl->', *t_i, t_ij[0][1], t_ij[0][2], t_ij[0][3], t_ij[1][2], t_ij[1][3], t_ij[2][3])
print(path)
print(f'in {end-start}s.')

([(1, 8), (1, 7), (0, 4), (0, 6), (0, 3), (3, 4), (0, 2), (1, 2), (0, 1)],   Complete contraction:  i,j,k,l,ij,ik,il,jk,jl,kl->
         Naive scaling:  4
     Optimized scaling:  4
      Naive FLOP count:  7.500e+6
  Optimized FLOP count:  1.542e+6
   Theoretical speedup:  4.864e+0
  Largest intermediate:  1.500e+4 elements
--------------------------------------------------------------------------------
scaling        BLAS                current                             remaining
--------------------------------------------------------------------------------
   2              0               jl,j->jl             i,k,l,ij,ik,il,jk,kl,jl->
   2              0               kl,k->kl               i,l,ij,ik,il,jk,jl,kl->
   2              0               il,i->il                 l,ij,ik,jk,jl,kl,il->
   2              0               il,l->il                   ij,ik,jk,jl,kl,il->
   3              0             jl,ij->jli                     ik,jk,kl,il,jli->
   3              0      