# Background on processes and Markov Chains

The [transportation problem](https://en.wikipedia.org/wiki/Transportation_theory_(mathematics)) asks for the minimal cost of moving mass, represented as a probability measure, into another arrangement or distribution, i.e. another probability measure. Probability measures arise naturally from many sources. Instead of considering a single measure, one can consider a family of measures indexed by time, which leads to the notion of a stochastic process. Among the simplest processes are [Markov chains](https://en.wikipedia.org/wiki/Markov_chain), and in particular those indexed by the non-negative integers, i.e. [discrete-time Markov chains (DTMC)](https://en.wikipedia.org/wiki/Discrete-time_Markov_chain). For clarity, we state and unpack the following standard definition.

````{prf:definition} Discrete Time Markov Chain
:label: dtmc

A discrete-time Markov chain is a sequence of random variables $X_1, X_2, X_3, \ldots$ with the Markov property, namely that the probability of moving to the next state depends only on the present state and not on the previous states:

$\Pr(X_{n+1}=x\mid X_1=x_1, X_2=x_2, \ldots, X_n=x_n) = \Pr(X_{n+1}=x\mid X_n=x_n)$, if both conditional probabilities are well defined, that is, if $\Pr(X_1=x_1,\ldots,X_n=x_n)>0$.

The possible values of $X_i$ form a countable set $S$ called the state space of the chain.
````

Let's illustrate this definition by choosing the state space $S = \{0,1\}$ and assuming time-homogeneity, i.e. $\Pr(X_{n+1}=x\mid X_n=y) = \Pr(X_n=x\mid X_{n-1}=y)$ for all $n$. In practice, this means we can collect the transition probabilities in a single matrix whose $(i, j)$ entry is the probability of moving from state $j$ to state $i$:
```{math}
:label: chain_1
\begin{aligned}
\Pr(X_{n+1} = i \mid X_n = j) &= 
\begin{bmatrix}
\Pr(X_{n+1} = 0 \mid X_n = 0) & \Pr(X_{n+1} = 0 \mid X_n = 1) \\
\Pr(X_{n+1} = 1 \mid X_n = 0) & \Pr(X_{n+1} = 1 \mid X_n = 1)
\end{bmatrix}
=
\begin{bmatrix}
1/2 & 1/2 \\
1/2 & 1/2
\end{bmatrix}
\end{aligned}
```
Here we have selected some probabilities for illustration purposes.
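As a minimal NumPy sketch of this convention (the variable names are ours), entry `P1[i, j]` holds $\Pr(X_{n+1} = i \mid X_n = j)$, so each column sums to one and a column vector of state probabilities $p_n$ is propagated by $p_{n+1} = P p_n$:

```python
import numpy as np

# Column convention: P1[i, j] = Pr(X_{n+1} = i | X_n = j)
P1 = np.array([[0.5, 0.5],
               [0.5, 0.5]])

# Each column is a conditional distribution, so the columns sum to one
assert np.allclose(P1.sum(axis=0), 1.0)

# Start deterministically in state 0 and take one step: p_1 = P1 @ p_0
p0 = np.array([1.0, 0.0])
p1 = P1 @ p0
print(p1)  # [0.5 0.5]
```

Starting in state 0, a single step already gives the uniform distribution over $\{0, 1\}$.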


In this context, a key observation can be made about optimal transport for stationary Markov chains: optimal transport between stationary distributions alone does not distinguish chains with different dynamics. The following example is given in Section 2 of [Optimal Transport for Stationary Markov Chains via Policy Iteration](https://jmlr.csail.mit.edu/papers/v23/21-0519.html). In addition to the chain specified in {eq}`chain_1`, consider the following chain:

```{math}
:label: chain_2
\begin{aligned}
\Pr(X_{n+1} = i \mid X_n = j) &= 
\begin{bmatrix}
\Pr(X_{n+1} = 0 \mid X_n = 0) & \Pr(X_{n+1} = 0 \mid X_n = 1) \\
\Pr(X_{n+1} = 1 \mid X_n = 0) & \Pr(X_{n+1} = 1 \mid X_n = 1)
\end{bmatrix}
=
\begin{bmatrix}
0 & 1 \\
1 & 0
\end{bmatrix}
\end{aligned}
```


In both cases one can readily verify the following relations, which show that the two chains have the same stationary distribution $(1/2, 1/2)$:

```{math}
:label: stationary_dist
\begin{aligned}
\begin{bmatrix}
1/2 & 1/2 \\
1/2 & 1/2
\end{bmatrix}
\begin{bmatrix}
1/2 \\
1/2
\end{bmatrix}
= 
\begin{bmatrix}
1/2 \\
1/2
\end{bmatrix},
\ \
\begin{bmatrix}
0 & 1 \\
1 & 0
\end{bmatrix}
\begin{bmatrix}
1/2 \\
1/2
\end{bmatrix}
= 
\begin{bmatrix}
1/2 \\
1/2
\end{bmatrix}
\end{aligned}
```
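The same check can be done numerically. This is a minimal sketch assuming NumPy and the column convention above. `stationary_distribution` is a helper name of ours that takes the eigenvector of $P$ for the eigenvalue closest to 1 and rescales it to sum to one. That approach is valid when this eigenvalue is simple, which is the case for both chains here (chain 2 has eigenvalues $1$ and $-1$):

```python
import numpy as np

def stationary_distribution(P: np.ndarray) -> np.ndarray:
    """Stationary distribution pi of a column-stochastic P, i.e. P @ pi = pi.

    Assumes the eigenvalue 1 of P is simple (e.g. an irreducible chain).
    """
    eigvals, eigvecs = np.linalg.eig(P)
    k = np.argmin(np.abs(eigvals - 1.0))  # eigenvalue closest to 1
    pi = np.real(eigvecs[:, k])
    return pi / pi.sum()                  # normalize to a probability vector

P1 = np.array([[0.5, 0.5], [0.5, 0.5]])  # chain_1
P2 = np.array([[0.0, 1.0], [1.0, 0.0]])  # chain_2

pi1 = stationary_distribution(P1)
pi2 = stationary_distribution(P2)
print(pi1, pi2)
```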

Thus, if we compare these chains by solving the transportation problem between their stationary distributions, with any cost satisfying $c(x,x) = 0$, we find a transport plan of zero cost: leave every state in place. As the authors note, the first chain is IID while the second is deterministic given its current state. Put practically, a fair coin flip decides each state of the first chain, while a fixed rule decides each state of the second. It would be desirable to have a way to separate these situations, and this is precisely what [Optimal Transport for Stationary Markov Chains via Policy Iteration](https://jmlr.csail.mit.edu/papers/v23/21-0519.html) does.
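To make the difference concrete, here is a small simulation sketch. It uses NumPy only, and the helper `simulate`, the seed and the path length are our own choices. Both chains spend about half their time in state 1, yet the first switches state on roughly half of its steps while the second switches on every step. The last lines show the zero-cost plan for the 0-1 cost $c(x,y) = \mathbb{1}[x \neq y]$, which keeps all mass in place:

```python
import numpy as np

def simulate(P: np.ndarray, x0: int, n_steps: int, rng: np.random.Generator) -> np.ndarray:
    """Sample a path X_0, ..., X_{n_steps} of a chain with column-stochastic P."""
    xs = np.empty(n_steps + 1, dtype=int)
    xs[0] = x0
    for n in range(n_steps):
        # Column xs[n] of P is the distribution of the next state
        xs[n + 1] = rng.choice(P.shape[0], p=P[:, xs[n]])
    return xs

P1 = np.array([[0.5, 0.5], [0.5, 0.5]])  # chain_1: fair coin flips
P2 = np.array([[0.0, 1.0], [1.0, 0.0]])  # chain_2: deterministic switching

rng = np.random.default_rng(0)
x1 = simulate(P1, 0, 5_000, rng)
x2 = simulate(P2, 0, 5_000, rng)

# Fraction of time spent in state 1: close to 1/2 for both chains
print(x1.mean(), x2.mean())

# Fraction of steps that change state: about 1/2 for chain_1, exactly 1 for chain_2
switch1 = np.mean(x1[1:] != x1[:-1])
switch2 = np.mean(x2[1:] != x2[:-1])
print(switch1, switch2)

# OT between the stationary distributions (1/2, 1/2) and (1/2, 1/2)
# under the 0-1 cost: the diagonal plan keeps mass in place at zero cost
plan = np.diag([0.5, 0.5])
cost = 1.0 - np.eye(2)
print(np.sum(plan * cost))
```

The marginal statistics cannot tell the two chains apart, but the one-step statistic can. This is the kind of information a comparison based on transitions should retain.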

## Reformulating the transition coupling problem as an MDP

Following [Optimal Transport for Stationary Markov Chains via Policy Iteration](https://jmlr.csail.mit.edu/papers/v23/21-0519.html), we develop and implement their algorithm to resolve the problem demonstrated above. Their approach is to couple the transition matrices of the two chains, not just their stationary distributions, so that the coupling also captures how each chain moves from one state to the next. Such couplings are called *transition couplings* because, as their name suggests, they couple transition matrices.

````{prf:definition} Transition Couplings
:label: transiton_coupling

Let $P$ and $Q$ be transition matrices on finite state spaces $X$ and $Y$, respectively.
A transition matrix $R \in [0, 1]^{|X||Y| \times |X||Y|}$ on $X \times Y$ is a *transition coupling* of $P$ and $Q$
if for every paired state $(x,y) \in X \times Y$, the distribution $R((x,y),\cdot)$ is a coupling
of the distributions $P(x, \cdot)$ and $Q(y,\cdot)$, formally $R((x,y),\cdot) \in \Pi(P(x,\cdot), Q(y,\cdot))$, where $\Pi(\mu, \nu)$ denotes the set of couplings of $\mu$ and $\nu$.
Let $\Pi_{\text{TC}}(P,Q)$ denote the set of all transition couplings of $P$ and $Q$.
````
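The definition can be read as a set of marginal constraints, one per row of $R$. The sketch below checks these constraints numerically. It assumes NumPy, row-stochastic matrices (each row $P(x,\cdot)$ is a distribution) and the lexicographic ordering $(x,y) \mapsto x\,|Y| + y$ of paired states. The function `is_transition_coupling` and the matrices in the usage are illustrative and are not taken from the paper:

```python
import numpy as np

def is_transition_coupling(R: np.ndarray, P: np.ndarray, Q: np.ndarray, atol: float = 1e-10) -> bool:
    """Return True if R is a transition coupling of P and Q.

    Assumes row-stochastic P, Q, R and orders the pairs (x, y)
    lexicographically, i.e. (x, y) -> x * nY + y.
    """
    nX, nY = P.shape[0], Q.shape[0]
    if R.shape != (nX * nY, nX * nY) or np.any(R < -atol):
        return False
    for x in range(nX):
        for y in range(nY):
            # Row (x, y) of R viewed as a joint table over (x', y')
            joint = R[x * nY + y].reshape(nX, nY)
            # Its marginals must be P(x, .) and Q(y, .)
            if not (np.allclose(joint.sum(axis=1), P[x], atol=atol)
                    and np.allclose(joint.sum(axis=0), Q[y], atol=atol)):
                return False
    return True

P = np.array([[0.5, 0.5], [0.5, 0.5]])  # i.i.d. fair coin
Q = np.array([[0.0, 1.0], [1.0, 0.0]])  # deterministic switch

# Independent coupling: R((x, y), (x', y')) = P(x, x') Q(y, y')
R_indep = np.kron(P, Q)
print(is_transition_coupling(R_indep, P, Q))  # True

# A stochastic matrix whose rows all equal (0, 1/2, 0, 1/2) is not one:
# from y = 1 it moves to y' = 1, but Q(1, .) puts all mass on y' = 0
R_bad = np.tile([0.0, 0.5, 0.0, 0.5], (4, 1))
print(is_transition_coupling(R_bad, P, Q))  # False
```

Being a stochastic matrix on $X \times Y$ is therefore not enough. Each row must reproduce the correct pair of one-step distributions.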

Let's build on our simplest example from above, where $X = Y = \{0, 1\}$ and

```{math}
\begin{aligned}
P = 
\begin{bmatrix}
1/2 & 1/2 \\
1/2 & 1/2
\end{bmatrix}, \quad
Q = 
\begin{bmatrix}
0 & 1 \\
1 & 0
\end{bmatrix}.
\end{aligned}
```

Let $Z = X \times Y = \{(0,0), (0,1), (1, 0), (1,1)\}$, ordered as listed. A transition coupling is a transition matrix $R$ on $Z$, which we write row by row: row $(x,y)$ is the distribution $R((x,y),\cdot)$. The simplest choice is the independent coupling $R((x,y),(x',y')) = P(x,x')\,Q(y,y')$, which gives

```{math}
\begin{aligned}
R = \Pr\big((x',y') \mid (x,y)\big)
= \begin{bmatrix}
1/2\cdot 0 & 1/2 \cdot 1 & 1/2 \cdot 0 & 1/2 \cdot 1 \\
1/2\cdot 1 & 1/2 \cdot 0 & 1/2 \cdot 1 & 1/2 \cdot 0 \\
1/2\cdot 0 & 1/2 \cdot 1 & 1/2 \cdot 0 & 1/2 \cdot 1 \\
1/2\cdot 1 & 1/2 \cdot 0 & 1/2 \cdot 1 & 1/2 \cdot 0
\end{bmatrix}
= \begin{bmatrix}
0 & 1/2 &  0 & 1/2  \\
1/2 & 0 &  1/2 & 0  \\
0 & 1/2 &  0 & 1/2  \\
1/2 & 0 &  1/2 & 0 
\end{bmatrix}.
\end{aligned}
```

Equivalently, $R$ is the Kronecker product $P \otimes Q$: row $(x,y)$ is the outer product of $P(x,\cdot)$ and $Q(y,\cdot)$, flattened in the order of $Z$. One could also store $R$ as a four-index tensor $R[x,y,x',y']$, but the matrix form keeps the row interpretation.

The product of the stationary distributions of $P$ and $Q$, i.e. the uniform distribution on $Z$, is stationary for $R$. Since $R$ is written row-wise, stationarity reads $\pi R = \pi$ for a row vector $\pi$, and here every column of $R$ sums to one:

```{math}
\begin{aligned}
\begin{bmatrix} 1/4 & 1/4 & 1/4 & 1/4 \end{bmatrix}
\begin{bmatrix}
0 & 1/2 &  0 & 1/2  \\
1/2 & 0 &  1/2 & 0  \\
0 & 1/2 &  0 & 1/2  \\
1/2 & 0 &  1/2 & 0 
\end{bmatrix}
= \begin{bmatrix} 1/4 & 1/4 & 1/4 & 1/4 \end{bmatrix}.
\end{aligned}
```
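The stationarity check can also be done numerically. This is a minimal sketch assuming NumPy, the row convention and the ordering of $Z$ used above. A row vector $\pi$ with $\pi R = \pi$ is an eigenvector of $R^\top$ for eigenvalue 1. Summing $\pi$ over $y$ (respectively $x$) then recovers the stationary distribution $(1/2, 1/2)$ of each chain:

```python
import numpy as np

P = np.array([[0.5, 0.5], [0.5, 0.5]])
Q = np.array([[0.0, 1.0], [1.0, 0.0]])
R = np.kron(P, Q)  # independent coupling, rows indexed by (x, y) -> 2 * x + y

# Row convention: pi is stationary if pi @ R == pi, i.e. pi is an
# eigenvector of R.T for the eigenvalue 1
eigvals, eigvecs = np.linalg.eig(R.T)
k = np.argmin(np.abs(eigvals - 1.0))
pi = np.real(eigvecs[:, k])
pi = pi / pi.sum()
print(pi)

# Summing out y (resp. x) recovers the stationary distributions of P and Q
joint = pi.reshape(2, 2)
print(joint.sum(axis=1), joint.sum(axis=0))
```

We use an eigendecomposition rather than iterating $\pi \leftarrow \pi R$. This $R$ has eigenvalue $-1$ because the coupled chain alternates between $\{(0,0),(1,0)\}$ and $\{(0,1),(1,1)\}$. As a result, power iteration started from a point mass would oscillate instead of converging.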

## Algorithm and implementation

In the following we provide a simple implementation of the OTC algorithm using the Python numerical stack, mainly NumPy and the Python Optimal Transport (POT) library.


```python
# dependencies
import numpy as np
import ot  # POT
```

```python
def exact_tce(R: np.ndarray, c: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Function that implements the ExactTCE algorithm (TCE == transition coupling evaluation)

    Algorithm 1a from:
    https://jmlr.csail.mit.edu/papers/volume23/21-0519/21-0519.pdf

    Solves the block system
        (I - R) g     = 0
        g + (I - R) h = c
        h + (I - R) w = 0
    for the gain g, the bias h and an auxiliary vector w.
    """
    assert R.shape[0] == R.shape[1]
    n = R.shape[0]
    I = np.eye(n)
    Zeros = np.zeros_like(R)
    M = np.block([[I - R, Zeros, Zeros],
                  [I, I - R, Zeros],
                  [Zeros, I, I - R]])
    rhs = np.block([[np.zeros((n, 1))], [c], [np.zeros((n, 1))]])
    # I - R is singular for any stochastic R (R always has eigenvalue 1), so M is
    # singular and np.linalg.solve raises "LinAlgError: Singular matrix".
    # The system is still consistent, so we take a least-squares solution.
    x, *_ = np.linalg.lstsq(M, rhs, rcond=None)
    g, h, w = x[0:n], x[n:2*n], x[2*n:]
    return g, h, w
```

```python
# test exact_tce

def example_R():
    # A rank-one stochastic matrix: every row is (0, 1/2, 0, 1/2)
    zeros_col = np.zeros(4).reshape(-1, 1)
    halves_col = 1/2*np.ones(4).reshape(-1, 1)
    R = np.block([zeros_col, halves_col, zeros_col, halves_col])
    return R

def test_exact_tce():
    c = np.array([0, 1, 1, 0]).reshape(-1, 1)
    R = example_R()
    g, h, w = exact_tce(R=R, c=c)
    # The returned vectors must satisfy all three block equations
    I = np.eye(R.shape[0])
    assert np.allclose((I - R) @ g, 0)
    assert np.allclose(g + (I - R) @ h, c)
    assert np.allclose(h + (I - R) @ w, 0)

test_exact_tce()
```

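The block matrix `M` in `exact_tce` is block lower triangular with diagonal blocks $I - R$, so it is singular exactly when $I - R$ is. For any stochastic matrix $R$ the rows sum to one, so $(I - R)\mathbf{1} = 0$ and $I - R$ is always singular. The following cells inspect this for `example_R`.

```python
example_R()
```

```
array([[0. , 0.5, 0. , 0.5],
       [0. , 0.5, 0. , 0.5],
       [0. , 0.5, 0. , 0.5],
       [0. , 0.5, 0. , 0.5]])
```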

```python
np.linalg.lstsq((np.eye(4) - example_R()), np.zeros(4), rcond=None)
```

```
(array([0., 0., 0., 0.]),
 array([], dtype=float64),
 3,
 array([1.41421356e+00, 1.00000000e+00, 1.00000000e+00, 7.37425190e-17]))
```

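```python
np.linalg.eig(np.eye(4) - example_R())
```

```
EigResult(eigenvalues=array([1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.11022302e-16]), eigenvectors=array([[ 1.        ,  0.        ,  0.40824829,  0.5       ],
       [ 0.        ,  0.        ,  0.57735027,  0.5       ],
       [ 0.        ,  1.        ,  0.40824829,  0.5       ],
       [ 0.        ,  0.        , -0.57735027,  0.5       ]]))
```

Up to rounding, $I - R$ has eigenvalue 0 (the entry of order $10^{-16}$), with eigenvector proportional to $(1, 1, 1, 1)$. Consistently, least squares reports rank 3.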
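Putting the pieces together, here is a self-contained sketch that evaluates the independent transition coupling of the running example with the ExactTCE linear system solved by least squares. The helper name `exact_tce_lstsq` is ours, and the 0-1 cost $c(x,y) = \mathbb{1}[x \neq y]$ is our choice. Each row $Q(y,\cdot)$ is a point mass, so the only coupling of $P(x,\cdot)$ and $Q(y,\cdot)$ is their product. The independent coupling is therefore the only transition coupling of $P$ and $Q$, and under the paper's OTC formulation its average cost is the optimal value:

```python
import numpy as np

def exact_tce_lstsq(R: np.ndarray, c: np.ndarray):
    """Hypothetical helper: ExactTCE block system solved by least squares.

    Solves (I - R) g = 0, g + (I - R) h = c, h + (I - R) w = 0.
    """
    n = R.shape[0]
    I, Z = np.eye(n), np.zeros((n, n))
    M = np.block([[I - R, Z, Z],
                  [I, I - R, Z],
                  [Z, I, I - R]])
    rhs = np.concatenate([np.zeros(n), c, np.zeros(n)])
    # M is singular but the system is consistent, so lstsq returns a solution
    x, *_ = np.linalg.lstsq(M, rhs, rcond=None)
    return x[:n], x[n:2 * n], x[2 * n:]

P = np.array([[0.5, 0.5], [0.5, 0.5]])  # i.i.d. fair coin
Q = np.array([[0.0, 1.0], [1.0, 0.0]])  # deterministic switch
R = np.kron(P, Q)                       # independent transition coupling
c = np.array([0.0, 1.0, 1.0, 0.0])      # 0-1 cost on Z = {(0,0),(0,1),(1,0),(1,1)}

g, h, w = exact_tce_lstsq(R, c)
print(np.round(g, 6))  # gain: average cost from each starting pair
print(np.round(h, 6))  # bias

# Cross-check: the gain equals the expected cost under the stationary
# distribution of R, which is uniform on Z here
pi = np.full(4, 0.25)
print(pi @ c)
```

The gain $g$ should come out as $1/2$ in every state, matching $\sum_z \pi(z)\,c(z)$ for the uniform stationary distribution $\pi$ of $R$. Under the 0-1 cost, the two chains are thus at transition coupling cost $1/2$ from each other, whereas transporting their stationary distributions costs 0. Least squares is used because the block matrix is singular. Since the system is consistent and $g$ and $h$ are determined uniquely, they do not depend on which solution is returned. Only $w$ does.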