In [None]:
import numpy as np
import scipy.stats
import matplotlib.pyplot as plt

## Tutorial 12
### Problem 1: Gibbs Sampling for a MIMO Channel
We consider a point-to-point MIMO system with $N_t$ transmit and $N_r$ receive antennas over a frequency-flat fading channel $H$ of size $N_r \times N_t$. The corresponding channel model is

$$\underline{Y} = H\underline{X} + \underline{N}$$
where $\underline{X} \sim \mathcal{N}(\underline{0},\sigma_{x}^2I)$ and $\underline{N} \sim \mathcal{N}(\underline{0},\sigma_{n}^2I)$. The entries of $H$ are
assumed to be i.i.d. zero-mean, unit-variance Gaussian.


**a)** Implement the system model so that you can obtain a received vector $\underline{y}$.

In [None]:
# Parameters
Nt = 4  # number of transmit antennas
Nr = 2  # number of receive antennas
sigma2x = 2  # variance of transmit signal
sigma2n = 3  # variance of noise

# Seed NumPy's legacy global RNG once. Later cells draw from np.random.*
# (e.g. the Gibbs sampler's np.random.randn), so this makes
# "Restart & Run All" reproduce the same channel, data and samples.
SEED = 0
np.random.seed(SEED)

In [None]:
# System model: y = H x + n, all vectors stored as column vectors
H = np.random.randn(Nr, Nt)                    # channel, Nr x Nt, iid N(0, 1) entries
x = np.sqrt(sigma2x) * np.random.randn(Nt, 1)  # transmit vector ~ N(0, sigma2x I)
n = np.sqrt(sigma2n) * np.random.randn(Nr, 1)  # noise vector ~ N(0, sigma2n I)
y = H @ x + n                                  # received vector, Nr x 1

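**b)** Implement the closed-form expression for the minimum mean-square error (MMSE) estimate $\hat{\underline{x}} = C_x H^T (H C_x H^T + C_n)^{-1}\underline{y}$, with $C_x = \sigma_x^2 I$ and $C_n = \sigma_n^2 I$.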

In [None]:
# Closed form solution: MMSE estimate  xhat = Cx H^T Cy^{-1} y,
# where Cy = H Cx H^T + Cn is the covariance of the received vector.
Cx = sigma2x * np.eye(Nt)  # covariance of x
Cn = sigma2n * np.eye(Nr)  # covariance of n
Cy = H @ Cx @ H.T + Cn     # covariance of y
# Solve Cy z = y rather than forming Cy^{-1} explicitly: cheaper and
# numerically more accurate than an explicit inverse.
xhat = Cx @ H.T @ np.linalg.solve(Cy, y)
print(f'Closed form solution:\n {xhat}')

**d)** Implement a Gibbs sampler using the previously derived expressions.

In [None]:
def gibbs_sampling(sample_fun, y, x0, num_samples):
    """
    Implements the Gibbs sampling procedure (systematic scan).

    Each sweep visits the components k = 0, ..., Nt-1 in order, replaces
    component k of the current state by a draw from its conditional
    distribution given all other components, and then stores the updated
    state as one column of the output.

    The parameter sample_fun is a function handle that
    implements the sampling from the individual conditional
    component distributions. Its signature is:

    def sample_fun(k, x0): pass

    where k is the index of the component to draw and x0 is the current
    state vector (same shape as the x0 passed here) at which all other
    components are held fixed. It must return a single value: a Python
    scalar or any size-1 array, e.g. shape (1, 1).

    Parameters
    ----------
    sample_fun : callable
        Conditional sampler, see above.
    y : array_like
        Received vector. Not used directly here (sample_fun is expected to
        capture it); kept for interface compatibility.
    x0 : ndarray, shape (Nt,) or (Nt, 1)
        Initial state of the chain. It is copied and never modified.
    num_samples : int
        Number of sweeps, i.e. number of stored samples (0 is allowed).

    Returns a matrix x of size Nt x num_samples that contains the samples
    returned by the sample_fun(); column s is the state after sweep s.
    """
    Nt = x0.shape[0]
    x = np.empty((Nt, num_samples), dtype=float)

    # Private float copy: the chain state is updated in place, and the
    # caller's x0 must not be changed as a side effect.
    state = np.array(x0, dtype=float)

    for s in range(num_samples):
        for k in range(Nt):
            # Normalise the draw (scalar or size-1 array) to a plain float
            # before storing it into the scalar slot of component k.
            state[k] = np.asarray(sample_fun(k, state)).item()
        x[:, s] = state.reshape(-1)

    return x

def sampling_funs(y, H, k, x0, var_x=None, var_n=None):
    """
    Draws component k of x from its conditional distribution
    p(x_k | x_{-k}, y), where x_{-k} are the other components taken from x0.

    For y = H x + n with x ~ N(0, var_x I) and n ~ N(0, var_n I), write
    y = h_k x_k + H_{-k} x_{-k} + n and let r = y - H_{-k} x_{-k}. Since the
    prior covariance is diagonal, x_k | x_{-k}, y is Gaussian with

        sigma2 = (h_k^T h_k / var_n + 1 / var_x)^(-1)
        mu     = sigma2 * h_k^T r / var_n

    (the intermediate steps are derived in the lecture notes).

    Parameters
    ----------
    y : ndarray, shape (Nr,) or (Nr, 1)
        Received vector.
    H : ndarray, shape (Nr, Nt)
        Channel matrix.
    k : int
        Index of the component to draw.
    x0 : ndarray, shape (Nt,) or (Nt, 1)
        Current state; only the components other than k are used.
    var_x, var_n : float, optional
        Signal and noise variances. When None, the notebook globals
        sigma2x and sigma2n are used.

    Returns
    -------
    ndarray, shape (1, 1)
        One draw from N(mu, sigma2).
    """
    var_x = sigma2x if var_x is None else var_x
    var_n = sigma2n if var_n is None else var_n

    y_vec = np.asarray(y, dtype=float).reshape(-1)
    H = np.asarray(H, dtype=float)
    state = np.asarray(x0, dtype=float).reshape(-1)

    h_k = H[:, k]
    # Residual after removing every component except k: y - H_{-k} x_{-k}.
    r = y_vec - H @ state + h_k * state[k]

    sigma2 = 1.0 / (h_k @ h_k / var_n + 1.0 / var_x)
    mu = sigma2 * (h_k @ r) / var_n

    return np.random.randn(1,1) * np.sqrt(sigma2) + mu

In [None]:
# Approximate xhat through Gibbs sampling: start the chain at a random point
# and collect num_samples sweeps.
# NOTE: this rebinds `x` (the transmit vector from the system-model cell)
# to the Nt x num_samples matrix of Gibbs samples.
x0 = np.random.randn(Nt,1)
num_samples = 1000
conditional_draw = lambda k, state: sampling_funs(y, H, k, state)
x = gibbs_sampling(conditional_draw, y, x0, num_samples)

In [None]:
# Exact MMSE estimate vs. the Gibbs sample mean (averaged along the sample axis).
print(f'Closed form solution:\n {xhat}')
print(f'Gibbs sampling approximation:\n {x.mean(axis=1).reshape(xhat.shape)}')

**e)** Compare the quality of the Gibbs sampler for different numbers of samples.

In [None]:
# Empirical E[||xhat - xhat_gibbs||^2] as a function of the number of samples,
# estimated by averaging over independent Gibbs chains.
num_samples = 10 ** np.array([1, 2, 3, 4])
trials = 10

xhat_tmp = np.zeros((Nt, trials))  # one Gibbs estimate per trial (column)
mse = np.zeros(len(num_samples))

for i, n_samp in enumerate(num_samples):
    print('Working on sample size: {}\n'.format(num_samples[i]))
    for j in range(trials):
        print(' {}'.format(j), end =" ")
        # Independent chain with its own random start; a local name is used so
        # the global x0 from the earlier cell is not overwritten.
        x_init = np.random.randn(Nt, 1)
        samples = gibbs_sampling(lambda k, x0: sampling_funs(y, H, k, x0), y, x_init, n_samp)
        xhat_tmp[:, j] = samples.mean(axis=1)
    print()

    # Squared error of each trial's estimate w.r.t. the closed form, averaged over trials.
    sq_err = np.sum((xhat_tmp - xhat.reshape(-1, 1)) ** 2, axis=0)
    mse[i] = np.mean(sq_err)

In [None]:
# MSE of the Gibbs estimate vs. number of samples on log-log axes.
fig, ax = plt.subplots()
ax.loglog(num_samples, mse, marker='o')
ax.set_xlabel('Number of samples')
ax.set_ylabel('E[||x_mmse - x_mmse_gibbs||^2]')
ax.set_title('Gibbs sampler: MSE w.r.t. closed-form MMSE estimate')
ax.grid(True, which='both')
plt.show()