In [5]:
import numpy as np
# import scipy.integrate as integrate
import scipy.stats as scs
from math import *

In [6]:
def initiate(k: int):
    mu_1 = np.zeros(k)
    mu_2 = np.zeros(k)
    tau_1 = np.ones(k)
    tau_2 = np.ones(k)
    lambd =np.ones(k)
    pi = np.ones(k)/k
    return mu_1,mu_2,tau_1,tau_2,lambd,pi

def A_computation(n: int, K: int,X_1: np.ndarray,X_2: np.ndarray, S: np.ndarray, previous_mu_1: np.ndarray,previous_mu_2: np.ndarray,previous_tau_1: np.ndarray,previous_tau_2: np.ndarray,previous_lambda: np.ndarray,previous_pi: np.ndarray):
    A= np.zeros(K)
    for k in range(K):
        p_X_1 = scs.norm(previous_mu_1[k],1/previous_tau_1[k]).pdf(X_1[n])
        p_X_2 = scs.norm(previous_mu_2[k],1/previous_tau_2[k]).pdf(X_2[n])
        p_S = scs.poisson(previous_lambda[k]).pmf(S[n])
        p_Z = previous_pi[k]
        A[k]=p_X_1*p_X_2*p_S*p_Z
    return A/np.sum(A) #normalize

def Compare(previous_mu_1: np.ndarray,previous_mu_2: np.ndarray,previous_tau_1: np.ndarray,previous_tau_2: np.ndarray,previous_lambda: np.ndarray,previous_pi: np.ndarray,mu_1: np.ndarray,mu_2: np.ndarray,tau_1: np.ndarray,tau_2: np.ndarray,lambd: np.ndarray,pi):
    res = np.sum(np.power(previous_mu_1-mu_1,2))
    res += np.sum(np.power(previous_mu_2-mu_2,2))
    res += np.sum(np.power(previous_tau_1-tau_1,2))
    res += np.sum(np.power(previous_tau_2-tau_2,2))
    res += np.sum(np.power(previous_lambda-lambd,2))
    res += np.sum(np.power(previous_pi-pi,2))
    return res

def EM(threshold: float, K: int, X_1: np.ndarray,X_2: np.ndarray, S: np.ndarray):
    N= len(X_1)
    mu_1,mu_2,tau_1,tau_2,lambd,pi = initiate(K)
    previous_mu_1,previous_mu_2,previous_tau_1,previous_tau_2,previous_lambda,previous_pi = mu_1+1,mu_2,tau_1*3,tau_2*3,lambd,pi*3
    loop=0
    while(Compare(previous_mu_1,previous_mu_2,previous_tau_1,previous_tau_2,previous_lambda,previous_pi,mu_1,mu_2,tau_1,tau_2,lambd,pi)>threshold):
        loop+=1
        previous_mu_1,previous_mu_2,previous_tau_1,previous_tau_2,previous_lambda,previous_pi = mu_1.copy(),mu_2.copy(),tau_1.copy(),tau_2.copy(),lambd.copy(),pi.copy()
        A = np.zeros((N,K))
        for i in range(N):
            A[i] = A_computation(i,K,X_1,X_2,S,previous_mu_1,previous_mu_2,previous_tau_1,previous_tau_2,previous_lambda,previous_pi) #expectation
        A = A.transpose()
        for k in range(K): #maximization
            pi[k]=np.sum(A[k])
            mu_1[k] = np.sum(A[k]*X_1)/np.sum(A[k])
            mu_2[k] = np.sum(A[k]*X_2)/np.sum(A[k])
            tau_1[k] = np.sum(A[k])/(np.sum(A[k]*np.power(X_1-mu_1[k],2)))
            tau_2[k] = np.sum(A[k])/(np.sum(A[k]*np.power(X_2-mu_2[k],2)))
            lambd[k] = np.sum(A[k]*S)/(np.sum(A[k]))
        pi = pi/np.sum(pi)
        print(loop,"\t", mu_1[0],"\t", mu_2[0],"\t", tau_1[0],"\t", tau_2[0],"\t", lambd[0],"\t",pi[0])
        print(Compare(previous_mu_1,previous_mu_2,previous_tau_1,previous_tau_2,previous_lambda,previous_pi,mu_1,mu_2,tau_1,tau_2,lambd,pi))
    return mu_1,mu_2,tau_1,tau_2,lambd,pi

def read_files(X_file_name: str, S_file_name: str):
    X_file ="data\\"+ X_file_name
    S_file ="data\\"+ S_file_name
    with open(X_file) as file:
        X = file.readlines()
    X_1 = np.zeros(len(X))
    X_2 = np.zeros(len(X))
    for i,line in enumerate(X):
        X_1[i]=float(line.split(" ")[0])
        X_2[i]=float(line.split("\n")[0].split(" ")[1])
    with open(S_file) as file:
        S_temp = file.readlines()
    S = np.zeros(len(S_temp))
    for i,line in enumerate(S_temp):
        S[i]=float(line.split("\n")[0])
    return X_1, X_2, S


In [7]:
X_1,X_2,S = read_files("X.txt","S.txt")
EM(0.1,5,X_1,X_2,S)

1 	 1.9039266394177017 	 1.752210583554002 	 0.22017871906957168 	 0.21886533341209005 	 5.050000000000001 	 0.19999999999999998
121.57985587470213
2 	 1.903926639417702 	 1.752210583554002 	 0.22017871906957168 	 0.21886533341209005 	 5.050000000000001 	 0.19999999999999998
2.465190328815662e-31


(array([1.90392664, 1.90392664, 1.90392664, 1.90392664, 1.90392664]),
 array([1.75221058, 1.75221058, 1.75221058, 1.75221058, 1.75221058]),
 array([0.22017872, 0.22017872, 0.22017872, 0.22017872, 0.22017872]),
 array([0.21886533, 0.21886533, 0.21886533, 0.21886533, 0.21886533]),
 array([5.05, 5.05, 5.05, 5.05, 5.05]),
 array([0.2, 0.2, 0.2, 0.2, 0.2]))

In [2]:
import plotly.graph_objects as go
import plotly.express as px

In [8]:
(px.scatter(x=X_1,y=X_2,color=S))

In [29]:
print((px.scatter(x=X_1,y=X_2,color=S))["data"])

(Scatter({
    'hovertemplate': 'x=%{x}<br>y=%{y}<br>color=%{marker.color}<extra></extra>',
    'legendgroup': '',
    'marker': {'color': array([ 9.,  8.,  2., 21.,  3.,  0.,  5.,  0.,  0.,  5.,  4.,  0., 11.,  9.,
                                2.,  3.,  0.,  0.,  4.,  1.,  0.,  2.,  5., 10.,  0., 12., 16.,  4.,
                                1., 12.,  3.,  2.,  0.,  9., 10.,  9.,  3.,  9.,  4.,  4.,  5.,  0.,
                                0., 13.,  2.,  7., 13., 10.,  5.,  2.,  4.,  3.,  0.,  0.,  9.,  0.,
                                1., 11.,  2., 14.,  2.,  1., 10.,  2.,  0., 10.,  0.,  0., 15.,  0.,
                               10.,  2.,  8., 15., 10.,  3.,  8.,  0.,  1.,  0., 13.,  0.,  2.,  1.,
                                1., 10.,  1.,  3., 13.,  0.,  8., 11.,  0.,  2., 11.,  4., 13.,  5.,
                                7.,  0.]),
               'coloraxis': 'coloraxis',
               'symbol': 'circle'},
    'mode': 'markers',
    'name': '',
    'orientation': 

In [28]:
print(go.Figure(data=go.Scatter(x=X_1,y=X_2,mode="markers", marker_color=S,marker_showscale=True)))

Figure({
    'data': [{'marker': {'color': array([ 9.,  8.,  2., 21.,  3.,  0.,  5.,  0.,  0.,  5.,  4.,  0., 11.,  9.,
                                          2.,  3.,  0.,  0.,  4.,  1.,  0.,  2.,  5., 10.,  0., 12., 16.,  4.,
                                          1., 12.,  3.,  2.,  0.,  9., 10.,  9.,  3.,  9.,  4.,  4.,  5.,  0.,
                                          0., 13.,  2.,  7., 13., 10.,  5.,  2.,  4.,  3.,  0.,  0.,  9.,  0.,
                                          1., 11.,  2., 14.,  2.,  1., 10.,  2.,  0., 10.,  0.,  0., 15.,  0.,
                                         10.,  2.,  8., 15., 10.,  3.,  8.,  0.,  1.,  0., 13.,  0.,  2.,  1.,
                                          1., 10.,  1.,  3., 13.,  0.,  8., 11.,  0.,  2., 11.,  4., 13.,  5.,
                                          7.,  0.]),
                         'showscale': True},
              'mode': 'markers',
              'type': 'scatter',
              'x': array([ 1.27058976,  1.23278615

In [33]:
go.Figure(data=go.Scatter(x=X_1,y=X_2,mode="markers", marker={"color":S,'coloraxis':'coloraxis'}))