In [2]:
import itertools
import numpy as np
import matplotlib.pylab as plt
import seaborn as sns
import time

from scipy.stats import binom, norm

In [3]:
K_true = 3
M_true = 4

R_true = np.array([[0.1, 0.3, 0.4, 0.2],
              [0.2, 0.6, 0.1, 0.1],
              [0.5, 0.2, 0.1, 0.2]])

P_true = np.array([0.25, 0.35, 0.4])
Q_true = np.array([0.5, 0.9, 0.5])


X_param_true = np.array([[-5, 1.5],
                    [2, 1.6],
                    [7, 2.2]])


Y_param_true = np.array([[[-1, 0.5], [5, 1.5], [9, 0.5], [12, 0.3]],
                    [[-2, 0.5], [8, 0.9], [12, 0.8], [16, 0.8]],
                    [[-3, 1.5], [10, 0.8], [14, 1.2], [20, 0.7]]])

In [4]:
def generate_data(R, P, Q, X_param, Y_param):
    K = P.shape[0]
    M = R.shape[1]
    
    # Choose idx of mixture on X
    i = np.random.choice(K, p=P)

    # Choose amount of latent variables in this mixture
    eta = 1 + np.random.binomial(n=M-1, p=Q[i])

    # Choose subset of size eta
    idx_subset = np.random.choice(M, replace=False, size=eta, p=R[i])

    # Choose mixture latent variable
    l = np.random.choice(a=idx_subset)

    x = np.random.normal(X_param[i][0], np.sqrt(X_param[i][1]))
    y = np.random.normal(Y_param[i, l][0], np.sqrt(Y_param[i, l][1]))

    return x, y, i*M + l

In [5]:
def generate_dataset(n_samples, R, P, Q, X_param, Y_param, random_state=42):
    np.random.seed(random_state)

    data = []
    labels = []

    for _ in range(n_samples):
        x, y, label = generate_data(R, P, Q, X_param, Y_param)
        data.append(np.array([x, y]))
        labels.append(label)

    return np.array(data), np.array(labels)

In [119]:
class EMModel:
    def __init__(self):
        self.bits_cache = None
        self.denominator_pi = None

        self.K = None
        self.M = None

        self.R = None
        self.P = None
        self.Q = None
        self.X_param = None
        self.Y_param = None

    def fit(self, data, R, P, Q, X_param, Y_param, epochs):
        n = data.shape[0]
        self.K = P.shape[0]
        self.M = R.shape[1]
        self.R = np.copy(R)
        self.P = np.copy(P)
        self.Q = np.copy(Q)
        self.X_param = np.copy(X_param)
        self.Y_param = np.copy(Y_param)

        self.bits_cache = [list(itertools.combinations(range(self.M), j)) for j in range(1, self.M + 1)]
        
        resp_shape = (n, self.K, self.M, 2**self.M - 1, self.M)
        responsibilities = np.zeros(resp_shape)

        start_time = time.time()

        for epoch in range(epochs):
            print("Epoch:", epoch)
            # E-Step
            # Calculate responsibilities 

            for t, (x, y) in enumerate(data):
                for i in range(self.K):
                    for j in range(0, self.M):
                        for s in self.bits_cache[j]:
                            s_idx = np.sum(2 ** np.array(s)) - 1
                            for l in s:
                                responsibilities[t, i, j, s_idx, l] = self.calculate_responsb_(x, y, i, j, s, l)
               
                responsibilities[t, :, :, :, :] /= np.sum(responsibilities[t, :, :, :, :])



            print("E-Step done")
            end_time = time.time()
            print(f"Time elapsed: {end_time - start_time} seconds")
             
            # M-Step
            # Update parameters

            # Update P
            for i in range(self.K):
                self.P[i] = np.sum(responsibilities[:, i, :, :, :]) / n

            # Update Q
            for i in range(self.K):
                res = 0
                for s in self.bits_cache[0]:
                    s_idx = np.sum(2 ** np.array(s)) - 1
                    for l in s:
                        res += np.sum(responsibilities[:, i, 0, s_idx, l]) 
                
                q_i_1 = res / n

                self.Q[i] = 1 - q_i_1**(1/(self.M - 1))

            # Update R
            for i in range(self.K):
                denominator = 2 * np.sum(responsibilities[:, i, :, :, :])
                for l in range(self.M):
                    res = 0

                    for t in range(n):
                        for j in range(self.M):
                            for s in self.bits_cache[j]:
                                s_idx = np.sum(2 ** np.array(s)) - 1
                                indicator = l in s
                                res += responsibilities[t, i, j, s_idx, l] * (1 + indicator)
                
                    self.R[i, l] = res / denominator
            

            
            # Update mean for X
            for i in range(self.K):
                res = 0
                for t in range(n):
                    res += np.sum(responsibilities[t, i, :, :, :]) * data[t, 0]
                self.X_param[i, 0] = res / np.sum(responsibilities[:, i, :, :, :])
            
            # Update variance for X
            for i in range(self.K):
                res = 0
                for t in range(n):
                    res += np.sum(responsibilities[t, i, :, :, :]) * (data[t, 0] - self.X_param[i, 0])**2
                self.X_param[i, 1] = res / np.sum(responsibilities[:, i, :, :, :])
            
            
            # Update mean for Y
            for i in range(self.K):
                for l in range(self.M):
                    res = 0
                    for t in range(n):
                        res += data[t, 1] * np.sum(responsibilities[t, i, :, :, l])

                    self.Y_param[i, l, 0] = res / np.sum(responsibilities[:, i, :, :, l])
            
            # Update variance for Y
            for i in range(self.K):
                for l in range(self.M):
                    res = 0
                    for t in range(n):
                        res += np.sum(responsibilities[t, i, :, :, l]) * (data[t, 1] - self.Y_param[i, l, 0])**2

                    self.Y_param[i, l, 1] = res / np.sum(responsibilities[:, i, :, :, l])
            
            print("M-step done")

            print("P:\n", self.P)
            print("Q:\n", self.Q)
            print("R:\n", self.R)
            print("X_param\n", self.X_param)
            print("Y_param\n", self.Y_param)

        
        return responsibilities
        
    
    def set_prob_(self, i, j, s):
        if self.denominator_pi is None:
            self.denominator_pi = 0
            
            # Retrive all bits with sum = j
            for bits in self.bits_cache[j]:
                self.denominator_pi += self.R[i][bits]

        # s - indexes that coresspond to 1's bit
        return np.prod(self.R[i][[s]]) / self.denominator_pi


    def calculate_pi_(self, i, j, s, l):
        return self.P[i] * binom.pmf(k=j, n=self.M-1, p=self.Q[i]) * self.set_prob_(i, j, s) * self.R[i, l] / np.sum(self.R[i][[s]])


    def calculate_responsb_(self, x, y, i, j, s, l):
        return self.calculate_pi_(i, j, s, l) * norm.pdf(x, loc=self.X_param[i][0], scale=np.sqrt(self.X_param[i][1])) * \
        norm.pdf(y, loc=self.Y_param[i,l][0], scale=np.sqrt(self.Y_param[i,l][1]))

In [270]:
def generate_est_parameters(data, random_state):
    np.random.seed(random_state)

    # Estimate R
    R_est = np.ones(shape=(K_true, M_true)) / M_true
    noise = np.random.dirichlet(np.ones(M_true), size=K_true)
    R_est += noise * 0.05
    R_est /= np.sum(R_est, axis=1).reshape(-1, 1)


    # Estimate P
    P_est = np.full(K_true, 1/K_true)
    noise = np.random.dirichlet(np.ones(K_true))
    P_est += noise * 0.05
    P_est /= np.sum(P_est)
    
    Q_est = np.random.uniform(size=K_true)

    X_param_est = np.zeros((K_true, 2))
    Y_param_est = np.zeros((K_true, M_true, 2))

    min_x = np.min(data[:, 0])
    max_x = np.max(data[:, 0])
    intervals = np.linspace(min_x, max_x, K_true + 1)

    for i in range(0, K_true):
        X_param_est[i][0] = np.random.uniform(intervals[i], intervals[i+1])
        X_param_est[i][1] = np.random.uniform(0, (intervals[i+1] - intervals[i])/6)
    
    min_y = np.min(data[:, 1])
    max_y = np.max(data[:, 1])

    intervals = np.linspace(min_y, max_y, M_true + 1)

    for i in range(0, K_true):
        for j in range(0, M_true):
            Y_param_est[i, j, 0] = np.random.uniform(intervals[j], intervals[j+1])
            Y_param_est[i, j, 1] = np.random.uniform(0, (intervals[j+1] - intervals[j])/6)

    return R_est, P_est, Q_est, X_param_est, Y_param_est

In [271]:
data, labels = generate_dataset(400, R_true, P_true, Q_true, X_param_true, Y_param_true)

In [285]:
R_est, P_est, Q_est, X_param_est, Y_param_est = generate_est_parameters(data, 12)

In [286]:
P_est

array([0.34607041, 0.33644686, 0.31748274])

In [287]:
R_est

array([[0.24118152, 0.26292995, 0.2437284 , 0.25216014],
       [0.23823884, 0.2626469 , 0.26068655, 0.23842771],
       [0.27095315, 0.23963695, 0.24158262, 0.24782728]])

In [288]:
X_param_est

array([[-2.73155006,  0.17902626],
       [ 3.92810946,  0.02318054],
       [ 6.40536343,  0.12951932]])

In [289]:
Y_param_est

array([[[-2.97002819,  0.74879492],
        [ 4.80058367,  0.91017986],
        [10.27630774,  0.81757219],
        [19.73109684,  0.36530106]],

       [[-2.80442336,  1.09071716],
        [ 5.82668238,  1.05977681],
        [13.473919  ,  0.92003918],
        [17.75064571,  0.50329304]],

       [[-2.36291591,  1.10976463],
        [ 2.83563219,  1.07347506],
        [11.14390155,  0.47289756],
        [18.12875126,  0.41677101]]])

In [290]:
Q_est

array([0.52122603, 0.55203763, 0.48537741])

In [291]:
model = EMModel()
res = model.fit(data, R_est, P_est, Q_est, X_param_est, Y_param_est, epochs=10)

Epoch: 0
E-Step done
Time elapsed: 12.886703491210938 seconds
M-step done
P:
 [0.42741514 0.18040733 0.39217753]
Q:
 [0.44093656 0.59319241 0.44004539]
R:
 [[0.20349732 0.20969775 0.55775397 0.02905095]
 [0.23654084 0.29729501 0.2906349  0.17552925]
 [0.330983   0.00206482 0.35811079 0.30884138]]
X_param
 [[-2.48723187 10.21864765]
 [ 3.1252725   0.53204343]
 [ 7.10960365  1.99045186]]
Y_param
 [[[-1.43184612  0.7628129 ]
  [ 5.78939923  1.81927303]
  [10.92352848  4.85789698]
  [16.81807766  0.13779922]]

 [[-2.58998556  0.35185   ]
  [ 8.22610165  0.7042893 ]
  [12.30527378  3.88875518]
  [16.82484673  1.26917007]]

 [[-2.85322969  1.13685379]
  [ 7.22637911  5.37947575]
  [11.30578978  3.09280046]
  [19.22501656  3.52454569]]]
Epoch: 1
E-Step done
Time elapsed: 26.340823650360107 seconds
M-step done
P:
 [0.42828031 0.17236021 0.39935948]
Q:
 [0.37726731 0.61389523 0.40079434]
R:
 [[1.93177941e-01 1.97057723e-01 5.89552199e-01 2.02121365e-02]
 [2.62254818e-01 2.89723659e-01 2.4409586

In [292]:
P_true

array([0.25, 0.35, 0.4 ])

In [293]:
model.P

array([0.57055803, 0.05941714, 0.37002483])

In [294]:
Q_true

array([0.5, 0.9, 0.5])

In [295]:
model.Q

array([0.23621955, 0.7526451 , 0.39298445])

In [296]:
R_true

array([[0.1, 0.3, 0.4, 0.2],
       [0.2, 0.6, 0.1, 0.1],
       [0.5, 0.2, 0.1, 0.2]])

In [297]:
model.R

array([[1.94066544e-01, 1.41905693e-01, 6.43312288e-01, 2.07154753e-02],
       [2.81250097e-01, 9.49371012e-02, 1.74988280e-04, 6.23637814e-01],
       [3.57978864e-01, 3.19053666e-04, 3.68211691e-01, 2.73490390e-01]])

In [298]:
np.sum(model.R, axis=1)

array([1., 1., 1.])

In [299]:
X_param_true

array([[-5. ,  1.5],
       [ 2. ,  1.6],
       [ 7. ,  2.2]])

In [300]:
model.X_param

array([[-0.76778529, 17.16982934],
       [ 2.41025697,  0.59502389],
       [ 6.98283891,  2.48073609]])

In [301]:
Y_param_true

array([[[-1. ,  0.5],
        [ 5. ,  1.5],
        [ 9. ,  0.5],
        [12. ,  0.3]],

       [[-2. ,  0.5],
        [ 8. ,  0.9],
        [12. ,  0.8],
        [16. ,  0.8]],

       [[-3. ,  1.5],
        [10. ,  0.8],
        [14. ,  1.2],
        [20. ,  0.7]]])

In [302]:
model.Y_param

array([[[-1.50893849e+00,  7.37569873e-01],
        [ 6.29685080e+00,  2.98052066e+00],
        [ 1.03873550e+01,  5.76609118e+00],
        [ 1.71003779e+01,  1.49607252e-02]],

       [[-2.73698529e+00,  1.63420233e-01],
        [ 8.09501091e+00,  1.03319343e-01],
        [ 1.16602585e+01,  2.85348148e+00],
        [ 1.61162333e+01,  2.21852045e-01]],

       [[-2.96262551e+00,  9.51677913e-01],
        [ 1.00177475e+01,  1.12572260e-01],
        [ 1.23034956e+01,  5.54252835e+00],
        [ 2.00396314e+01,  6.08391680e-01]]])