In [1]:
import itertools
import numpy as np
import matplotlib.pylab as plt
import seaborn as sns
import time

from scipy.stats import binom, norm
from math import comb

In [2]:
K_true = 3
M_true = 4

R_true = np.array([[0.1, 0.3, 0.4, 0.2],
                   [0.2, 0.6, 0.1, 0.1],
                   [0.5, 0.2, 0.1, 0.2]])

P_true = np.array([0.25, 0.35, 0.4])
Q_true = np.array([0.1, 0.2, 0.6])


X_param_true = np.array([[-5, 1.5],
                         [2, 1.6],
                         [7, 2.2]])


Y_param_true = np.array([[[-1, 0.5], [5, 1.5], [9, 0.5], [12, 0.8]],
                         [[-2, 0.5], [8, 0.9], [12, 0.8], [16, 0.8]],
                         [[-2, 0.5], [7, 0.9], [13, 0.8], [17, 0.8]]])

In [3]:
def generate_data(R, P, Q, X_param, Y_param):
    K = P.shape[0]
    M = R.shape[1]
    
    # Choose idx of mixture on X
    i = np.random.choice(K, p=P)

    # Choose amount of latent variables in this mixture
    eta = 1 + np.random.binomial(n=M-1, p=Q[i])

    # Choose subset of size eta
    idx_subset = np.random.choice(M, replace=False, size=eta, p=R[i])

    # Choose mixture latent variable
    l = np.random.choice(a=idx_subset, p=R_true[i][idx_subset]/np.sum(R_true[i][idx_subset]))

    x = np.random.normal(X_param[i][0], np.sqrt(X_param[i][1]))
    y = np.random.normal(Y_param[i, l][0], np.sqrt(Y_param[i, l][1]))

    return x, y, i*M + l

In [4]:
def generate_dataset(n_samples, R, P, Q, X_param, Y_param, random_state=42):
    np.random.seed(random_state)

    data = []
    labels = []

    for _ in range(n_samples):
        x, y, label = generate_data(R, P, Q, X_param, Y_param)
        data.append(np.array([x, y]))
        labels.append(label)

    return np.array(data), np.array(labels)

In [5]:
class EMModel:
    def __init__(self):
        self.bits_cache = None
        self.denominator_pi = None

        self.K = None
        self.M = None

        self.R = None
        self.P = None
        # Q - mmatrix
        self.Q = None
        self.X_param = None
        self.Y_param = None

        self.current_likelihood_ = None

    def fit(self, data, R, P, Q, X_param, Y_param, epochs):
        n = data.shape[0]
        self.K = P.shape[0]
        self.M = R.shape[1]
        self.R = np.copy(R)
        self.P = np.copy(P)
        self.Q = np.copy(Q)
        self.X_param = np.copy(X_param)
        self.Y_param = np.copy(Y_param)

        self.bits_cache = [list(itertools.combinations(range(self.M), j)) for j in range(1, self.M + 1)]
        
        resp_shape = (n, self.K, self.M, 2**self.M - 1, self.M)
        responsibilities = np.zeros(resp_shape)

        start_time = time.time()

        self.current_likelihood_ = np.zeros(epochs)
        for epoch in range(epochs):
            print("Epoch:", epoch)
            # E-Step
            # Calculate responsibilities 

            for t, (x, y) in enumerate(data):
                for i in range(self.K):
                    for j in range(0, self.M):
                        for s in self.bits_cache[j]:
                            s_idx = np.sum(2 ** np.array(s)) - 1
                            
                            for l in s:
                                responsibilities[t, i, j, s_idx, l] = self.calculate_responsb_(x, y, i, j, s, l)
                

                s_t = np.sum(responsibilities[t, :, :, :, :])
                responsibilities[t, :, :, :, :] /= s_t
                self.current_likelihood_[epoch] += np.log(s_t)
            
            print("E-Step done")
            end_time = time.time()
            print(f"Time elapsed: {end_time - start_time} seconds")
             
            # M-Step
            # Update parameters

            # Update P
            for i in range(self.K):
                self.P[i] = np.sum(responsibilities[:, i, :, :, :]) / n

            # Update Q
            for i in range(self.K):
                for j in range(self.M):
                    res = 0
                    for s in self.bits_cache[j]:
                        s_idx = np.sum(2 ** np.array(s)) - 1
                        for l in s:
                            res += np.sum(responsibilities[:, i, j, s_idx, l])
                
                    self.Q[i, j] = res / np.sum(responsibilities[:, i, :, :, :])

            
            # Update R
            for i in range(self.K):
                for l in range(self.M):
                    res = 0

                    for j in range(self.M):

                        numerator_a = 0
                        denominator_a = 0
                        for s in self.bits_cache[j]:
                            s_idx = np.sum(2 ** np.array(s)) - 1
                            indicator = l in s

                            numerator_a += indicator * np.prod(self.R[i][[s]])
                            denominator_a += np.prod(self.R[i][[s]])

                        denominator_b = 0
                        for s in self.bits_cache[j]:
                            s_idx = np.sum(2 ** np.array(s)) - 1
                            
                            for k in range(self.M):
                                if k!= l:
                                    denominator_b += (k in s) * self.R[i, k]
                        
                        denominator_b = 1 + denominator_b / self.R[i, l]


                        tmp = 0
                        #### Final
                        for s in self.bits_cache[j]:
                            if l in s:
                                s_idx = np.sum(2 ** np.array(s)) - 1
                                tmp += np.sum(responsibilities[:, i, j, s_idx, l])
    
                        res += tmp * (2 - numerator_a/denominator_a - 1/denominator_b)

                    self.R[i, l] = res
                
                self.R[i] /= np.sum(self.R[i])

            
            # Update mean for X
            for i in range(self.K):
                res = 0
                for t in range(n):
                    res += np.sum(responsibilities[t, i, :, :, :]) * data[t, 0]
                self.X_param[i, 0] = res / np.sum(responsibilities[:, i, :, :, :])
            
            # Update variance for X
            for i in range(self.K):
                res = 0
                for t in range(n):
                    res += np.sum(responsibilities[t, i, :, :, :]) * (data[t, 0] - self.X_param[i, 0])**2
                self.X_param[i, 1] = res / np.sum(responsibilities[:, i, :, :, :])
            
            
            # Update mean for Y
            for i in range(self.K):
                for l in range(self.M):
                    res = 0
                    for t in range(n):
                        res += data[t, 1] * np.sum(responsibilities[t, i, :, :, l])

                    self.Y_param[i, l, 0] = res / np.sum(responsibilities[:, i, :, :, l])
            
            # Update variance for Y
            for i in range(self.K):
                for l in range(self.M):
                    # max_val = -np.inf

                    res = 0
                    for t in range(n):
                        # if np.sum(responsibilities[t, i, :, :, l]) > max_val:
                        #     max_val = np.sum(responsibilities[t, i, :, :, l])

                        res += np.sum(responsibilities[t, i, :, :, l]) * (data[t, 1] - self.Y_param[i, l, 0])**2

                    self.Y_param[i, l, 1] = (1/self.K + res) / (6 + np.sum(responsibilities[:, i, :, :, l]))
                    # self.Y_param[i, l, 1] = res / (np.sum(responsibilities[:, i, :, :, l]))
            
            print("M-step done")

            print("Current log likelihood:", self.current_likelihood_[epoch])
            print("P:\n", self.P)
            print("Q:\n", self.Q)
            print("R:\n", self.R)
            print("X_param\n", self.X_param)
            print("Y_param\n", self.Y_param)
        
    
    def set_prob_(self, i, j, s):
        denominator = 0
        # Retrive all bits with sum = j
        for bits in self.bits_cache[j]:
            denominator += np.prod(self.R[i][[bits]])

        # s - indexes that coresspond to 1's bit
        return np.prod(self.R[i][[s]]) / denominator


    def calculate_pi_(self, i, j, s, l):
        s_l = l in s
        return self.P[i] * self.Q[i, j] * self.set_prob_(i, j, s) * s_l * self.R[i, l] / np.sum(self.R[i][[s]])


    def calculate_responsb_(self, x, y, i, j, s, l):
        return self.calculate_pi_(i, j, s, l) * \
            norm.pdf(x, loc=self.X_param[i,0], scale=np.sqrt(self.X_param[i,1])) * \
            norm.pdf(y, loc=self.Y_param[i,l][0], scale=np.sqrt(self.Y_param[i,l][1]))

In [6]:
def generate_est_parameters(data, random_state):
    np.random.seed(random_state)

    # Estimate R
    R_est = np.ones(shape=(K_true, M_true)) / M_true
    noise = np.random.dirichlet(np.ones(M_true), size=K_true)
    R_est += noise * 0.05
    R_est /= np.sum(R_est, axis=1).reshape(-1, 1)


    # Estimate P
    P_est = np.full(K_true, 1/K_true)
    noise = np.random.dirichlet(np.ones(K_true))
    P_est += noise * 0.05
    P_est /= np.sum(P_est)

    # Estimate Q
    # Q_est = np.zeros(shape=(K_true, M_true))
    # for i in range(K_true):
    #     for j in range(M_true):
    #         Q_est[i, j] = binom.pmf(k=j, n=M_true-1, p=0.5)
    
    # Estimate Q
    Q_est = np.ones(shape=(K_true, M_true)) / M_true
    noise = np.random.dirichlet(np.ones(M_true), size=K_true)
    Q_est += noise * 0.05
    Q_est /= np.sum(Q_est, axis=1).reshape(-1, 1)

    
    # Estimate X, Y
    X_param_est = np.zeros((K_true, 2))
    Y_param_est = np.zeros((K_true, M_true, 2))

    min_x = np.min(data[:, 0])
    max_x = np.max(data[:, 0])
    intervals = np.linspace(min_x, max_x, K_true + 1)

    for i in range(0, K_true):
        X_param_est[i][0] = np.random.uniform(intervals[i], intervals[i+1])
        X_param_est[i][1] = np.random.uniform(0, (intervals[i+1] - intervals[i])/6)
    
    min_y = np.min(data[:, 1])
    max_y = np.max(data[:, 1])

    intervals = np.linspace(min_y, max_y, M_true + 1)

    for i in range(0, K_true):
        for j in range(0, M_true):
            Y_param_est[i, j, 0] = np.random.uniform(intervals[j], intervals[j+1])
            Y_param_est[i, j, 1] = np.random.uniform(0, (intervals[j+1] - intervals[j])/6)

    return R_est, P_est, Q_est, X_param_est, Y_param_est

In [7]:
def q_to_matrix(Q):
    # Suppose Q is 1d array

    Q_matrix = np.zeros((K_true, M_true))
    
    for i in range(0, K_true):
        for j in range(1, M_true + 1):
            Q_matrix[i, j-1] = binom.pmf(k=j-1, n=M_true-1, p=Q[i])
    
    return Q_matrix

In [8]:
data, labels = generate_dataset(500, R_true, P_true, Q_true, X_param_true, Y_param_true)

In [9]:
R_est, P_est, Q_est, X_param_est, Y_param_est = generate_est_parameters(data, 15)

In [10]:
P_est

array([0.32970847, 0.33689972, 0.3333918 ])

In [11]:
R_est

array([[0.2728179 , 0.24171782, 0.23912256, 0.24634173],
       [0.24659182, 0.2580095 , 0.24772676, 0.24767192],
       [0.23985341, 0.24236182, 0.2751384 , 0.24264637]])

In [12]:
Q_est

array([[0.26027926, 0.25526799, 0.2425835 , 0.24186925],
       [0.23932514, 0.27404927, 0.24066296, 0.24596264],
       [0.24508304, 0.24585991, 0.2696991 , 0.23935795]])

In [13]:
X_param_est

array([[-5.46618745,  0.15695443],
       [ 1.67387439,  0.33152539],
       [ 8.32257252,  0.7380382 ]])

In [14]:
Y_param_est

array([[[-1.93441212,  0.6112717 ],
        [ 5.59134063,  0.41206313],
        [10.06840928,  0.38305901],
        [15.97262934,  0.07816401]],

       [[-3.7446108 ,  0.07685399],
        [ 5.90306767,  0.02860917],
        [10.63502052,  0.13013352],
        [13.76249849,  0.62743195]],

       [[-1.45363521,  0.95623058],
        [ 3.74899167,  0.19174012],
        [11.72146488,  0.34447894],
        [18.12043413,  0.77687024]]])

In [15]:
model = EMModel()
res = model.fit(data, R_est, P_est, Q_est, X_param_est, Y_param_est, epochs=25)

Epoch: 0
E-Step done
Time elapsed: 12.51832628250122 seconds
M-step done
Current log likelihood: -7730.453079544016
P:
 [0.25200134 0.33515476 0.4128439 ]
Q:
 [[0.26152458 0.25355561 0.24189332 0.24302649]
 [0.23945457 0.27385576 0.24059401 0.24609566]
 [0.24556348 0.24520713 0.26940222 0.23982716]]
R:
 [[4.31716185e-02 3.25486574e-01 6.31341803e-01 4.37783790e-09]
 [9.18319510e-02 1.41347903e-01 3.30510185e-01 4.36309961e-01]
 [4.79589158e-01 8.49129386e-02 2.21879421e-01 2.13618483e-01]]
X_param
 [[-5.08913354  1.53303962]
 [ 2.02922127  1.42999143]
 [ 6.69008106  3.3858219 ]]
Y_param
 [[[-0.93087764  0.42065367]
  [ 5.29653645  1.61735351]
  [10.28396489  2.78582541]
  [13.71124665  0.05555555]]

 [[-2.28268722  0.14698306]
  [ 6.87147595  0.08453634]
  [ 8.80685232  0.51921673]
  [11.5226655  13.2344253 ]]

 [[-2.07092755  0.4836534 ]
  [ 6.65414322  0.12244769]
  [ 9.604084    5.73239899]
  [16.93453461  1.08595112]]]
Epoch: 1
E-Step done
Time elapsed: 24.812299251556396 seconds
M