In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt

from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

In [None]:
np.random.seed(42)  # results will be reproducible, "it works on my machine" won't make a good excuse :)

In [None]:
# pooling function [kids' code, you can try and improve it :)]
def pooling(x, kernel=2, pool='avg'):
    """
    Downsample a 2-D array with non-overlapping kernel x kernel windows.

    Parameters:
    - x: 2-D array of shape (m, n)
    - kernel: side length k of the square pooling window
    - pool: 'avg' to take the mean of each window, 'max' to take its maximum

    Returns:
    - array of shape (m // k, n // k); trailing rows/columns that do not fill
      a complete window are ignored

    Raises:
    - ValueError: if pool is neither 'avg' nor 'max' (previously an unknown
      value silently produced an all-zero output)
    """
    reducers = {'avg': np.mean, 'max': np.max}
    if pool not in reducers:
        raise ValueError(f"Unknown pool type {pool!r}: expected 'avg' or 'max'")
    reduce_window = reducers[pool]  # resolved once instead of on every window

    m, n = x.shape
    k = kernel

    y = np.zeros((m // k, n // k))
    for i in range(y.shape[0]):
        for j in range(y.shape[1]):
            # window (i, j) covers rows i*k..(i+1)*k-1 and columns j*k..(j+1)*k-1
            y[i, j] = reduce_window(x[i*k:(i+1)*k, j*k:(j+1)*k])

    return y

In [None]:
digit = 0  # TODO: replace this with the last digit of your student ID
kernel = 2  # change this to 3 or 4 if you run into numerical stability problems

In [None]:
datapath = 'data'

# Load the MNIST test split. download=True fetches the files into `datapath`
# only when they are missing, so the cell also runs on a fresh checkout.
data_test = MNIST(
    root = datapath, 
    train = False, 
    transform = ToTensor(),
    download = True,
)

# Keep only the images of the selected digit and shrink each 28x28 image with
# average pooling over kernel x kernel windows.
x_data = np.asarray([
    pooling(image.numpy().reshape(28, 28), kernel=kernel, pool='avg')
    for image, label in data_test
    if label == digit
])

# Binarise: pooled intensities above 0.5 become 1, everything else 0.
x_data[x_data > .5] = 1
x_data[x_data <= .5] = 0

# Flatten each pooled image into a vector of length (28 // kernel) ** 2.
flatsize = (28//kernel)**2
x_data = x_data.reshape(-1, flatsize)

In [None]:
# sanity check: (number of selected images, flatsize) after the reshape above
x_data.shape

## TODO: Complete the code below

Now it's your time to shine! 
You need to write the code to run the EM algorithm for a Bernoulli mixture and to sample from it.

You can (1) do it from scratch or (2) use the skeleton of the class I have provided here.
If you go for the 2nd option, you should write the code for
- the static method `compute_log_bernoulli`, which returns the log-probabilities for multivariate Bernoulli distributions:
\begin{equation*}
\log p(x|\mu) = \sum_{j=1}^{d} \left[ x_{j} \log \mu_j + (1 - x_j) \log (1 - \mu_j) \right]
\end{equation*}
These allow you to compute the responsibilities with better numerical stability. Static methods don't take `self` as input and are essentially equivalent to normal functions. 
- the `_em_step` method, which updates the parameters `self.pi` and `self.mu` of the Bernoulli mixture based on the input observations `x`.
- the `sample` method, which returns `n` data points sampled from the Bernoulli mixture 

In [None]:
class BMM:
    """
    Bernoulli mixture model fitted with the EM algorithm.

    This is a skeleton: the sections marked TODO are left to be completed.

    Parameters: 
    - n_components: number of Bernoulli vectors in the mixture (K)
    - pi: initial array of priors/weights for each Bernoulli vector [shape: (K,)]
          If None, the initial array is defined automatically based on the input of fit()
    - mu: initial array of Bernoulli parameters, where each column refers to a different Bernoulli vector [shape: (d, K)].
          If None, the initial array is defined automatically based on the input of fit()
    - alpha: stability parameter. When mu is initialised automatically, its entries
          are drawn from the interval [alpha, 1), so no parameter starts at 0.
    """

    def __init__(self, n_components=2, pi=None, mu=None, alpha=0.001) -> None:
        self.n_components = n_components
        self.pi = pi
        self.mu = mu
        self.alpha = alpha

    def _init_params(self, x):
        """
        Method: _init_params

        Initializes the parameters that were not supplied to the constructor:
        - pi becomes the uniform vector 1/K
        - mu becomes a random (d, K) array with entries in [alpha, 1),
          where d = x.shape[1]
        Parameters that are already set are left untouched.
        """
        if self.pi is None:
            self.pi = np.ones(self.n_components)/self.n_components
        if self.mu is None:
            self.mu = (1-self.alpha)*np.random.rand(x.shape[1], self.n_components)+self.alpha

    def get_params(self) -> dict:
        """
        Method: get_params

        Returns a dictionary with copies of all the BMM's parameters
        (keys: 'pi' and 'mu'), so later updates do not alter the snapshot.
        """
        return {
            'pi': self.pi.copy(),
            'mu': self.mu.copy(),
        }

    @staticmethod
    def compute_log_bernoulli(x, mu):
        """
        Method: compute_log_bernoulli

        Meant to return the log-probabilities of each row of x under each
        multivariate Bernoulli (column of mu):
            log p(x|mu) = sum_j x_j log(mu_j) + (1 - x_j) log(1 - mu_j)

        Parameters:
        - x: observations [shape (n, d)]
        - mu: Bernoulli parameters [shape (d, K)]

        Returns an array of shape (n, K). Until the TODO below is completed
        the array contains only zeros.
        """
        log_prob = np.zeros((x.shape[0], mu.shape[1]))  # allocate array of log-likelihoods
        
        # ++++++++++++++++++++++++
        # TODO: Add your code here
        # ++++++++++++++++++++++++

        return log_prob

    def _em_step(self, x):
        """
        Method: _em_step

        Performs a single EM step based on the input x, updating self.pi and
        self.mu in place. Returns nothing. Until the TODO below is completed
        the parameters are left unchanged.
        """

        # ++++++++++++++++++++++++
        # TODO: Add your code here
        # ++++++++++++++++++++++++

        return

    def fit(self, x, eps=1, max_iters=100, verbose=False):
        """
        Method: fit

        Fits the Bernoulli mixture by repeating the EM steps and updating the
        parameters self.pi and self.mu until convergence is reached or
        max_iters iterations have been performed. At least one EM step is
        always executed.

        Convergence means that, for every parameter, the RMSE between its
        value before and after the EM step is at most eps.

        Parameters:
        - x: Input data [shape (n, d)]
        - eps: Convergence parameter, you can try and play with it
        - max_iters: Maximum number of iterations
        - verbose: if True, prints the parameters variation at each iteration

        Returns the fitted model itself, so calls can be chained.
        """
        self._init_params(x)

        converged = False
        num_iters = 0

        while not converged:

            params_old = self.get_params()
            self._em_step(x)  # perform EM step
            params_new = self.get_params()

            # assume convergence, then revoke it if any parameter moved too much
            converged = True

            for name in params_old.keys():
                # compute RMSE between old and new params
                delta = np.sqrt(np.mean((params_new[name] - params_old[name])**2))
                if verbose:
                    print(f"Variation of {name} at iter {num_iters+1:03d}: {delta}")
                if delta > eps:
                    converged = False

            num_iters += 1

            # hard stop: without this break a non-converging run never ends
            if not converged and num_iters >= max_iters:
                if verbose:
                    print("Maximum number of iterations reached: stop fitting.")
                break

        return self

    def sample(self, n):
        """
        Method: sample

        Returns n datapoints sampled from the Bernoulli mixture [output shape: (n, d)].
        Until the TODO below is completed the returned array contains only zeros.
        """
        # the shape must be passed as a single tuple; a second positional
        # argument would be interpreted as the dtype
        samples = np.zeros((n, self.mu.shape[0]))
        # ++++++++++++++++++++++++
        # TODO: Add your code here
        # ++++++++++++++++++++++++
        return samples

In [None]:
# Fit a 10-component Bernoulli mixture on the binarised images (see BMM.fit for eps / max_iters / verbose)
bmm = BMM(n_components=10).fit(x_data, eps=0.1, max_iters=10, verbose=True)

In [None]:
# sample 10 images (flattened): each row of x_sample is one generated image
x_sample = bmm.sample(n=10)

In [None]:
# Folder that collects the generated images for the selected digit.
# exist_ok=True makes the creation idempotent, so the cell can be re-run safely.
results_dir = f'results/{digit}'
os.makedirs(results_dir, exist_ok=True)

# plot the generated images and save them
for i, x in enumerate(x_sample):
    # undo the flattening: back to a (28 // kernel) x (28 // kernel) image
    image = x.reshape(28//kernel, 28//kernel)
    fig, ax = plt.subplots()
    ax.imshow(image, cmap='binary')
    fig.savefig(f'{results_dir}/example_{i+1:03d}.png')
    # figures are left open on purpose so the notebook still displays them inline
    plt.draw()