In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import optim
from torch import erf
import torch.nn.functional as F

In [2]:
from sklearn.mixture import GaussianMixture
np.random.seed(0)
data = np.random.rand(100,2)
gmm_gt = GaussianMixture(n_components=3)
gmm_est = GaussianMixture(n_components=3)
gmm_gt.fit(data)
gmm_est.fit(data)

In [4]:
import numpy as np
from scipy.stats import multivariate_normal
from sklearn.mixture import GaussianMixture

np.random.seed(0)
data = np.random.rand(100,2)
gmm_gt = GaussianMixture(n_components=3)
gmm_est = GaussianMixture(n_components=3)
gmm_gt.fit(data)
gmm_est.fit(data)

def gmm_pdf(x, weights, means, covariances):

    K = len(weights)
    pdf = 0.0
    
    for k in range(K):
        # Calculate the PDF for each Gaussian component
        component_pdf = weights[k] * multivariate_normal.pdf(x, mean=means[k], cov=covariances[k])
        pdf += component_pdf

    return pdf

def gmm_kl(gmm_p, gmm_q, n_samples=1e7):
    X = gmm_p.sample(n_samples)[0]
    p_X = (gmm_pdf(X, gmm_p.weights_, gmm_p.means_, gmm_p.covariances_))
    q_X = (gmm_pdf(X, gmm_q.weights_, gmm_q.means_, gmm_q.covariances_))
    return np.mean(np.log(p_X/q_X))

In [5]:
def loss_MoG(x,y,mu1,mu2,sig1,sig2,pi):
    loss = torch.sum(torch.log(torch.sum(pi*torch.exp(-(x_t - mu1)**2/2/sig1**2)/sig1 *(erf((x_t -mu2)/(torch.sqrt(torch.tensor(2))*sig2))+1),0))) +\
           torch.sum(torch.log(torch.sum(pi*torch.exp(-(y_t - mu2)**2/2/sig2**2)/sig2 *(erf((y_t -mu1)/(torch.sqrt(torch.tensor(2))*sig1))+1),0)))
    return -loss

In [6]:
pi_true = np.array([0.2,0.3,0.5])
mu1_true =  np.array([2,10,6])
mu2_true =  np.array([1,9,5])

sig1_true = np.array([1,1,1])
sig2_true = np.array([1,1,1])

KL_list = []

for num_samples in [1e2,1e3,1e4,1e5]:


    a,b,c = [],[],[]
    K = len(pi_true)
    np.random.seed(0)
    for i in range(K):
        a_i = sig1_true[i]*np.random.randn(int(num_samples*pi_true[i]))+mu1_true[i]
        b_i = sig2_true[i]*np.random.randn(int(num_samples*pi_true[i]))+mu2_true[i]
        a,b = np.concatenate([a,a_i]),np.concatenate([b,b_i])
    a_copy = a.copy()
    b_copy = b.copy()

    a[a_copy<=b_copy] =  np.nan
    b[b_copy<a_copy] =  np.nan

    x = a[~np.isnan(a)]
    y = b[~np.isnan(b)]
    x_t,y_t = torch.Tensor(x),torch.Tensor(y)
    
    # initialization
    pi = torch.Tensor([1.0,2.0,3.0]).reshape(-1,1)
    gmm = GaussianMixture(n_components=3)
    gmm.fit(a[~np.isnan(a)].reshape(-1,1))
    mu1_init = gmm.means_
    sig1_init = gmm.covariances_.flatten()
    gmm.fit(b[~np.isnan(b)].reshape(-1,1))
    mu2_init = gmm.means_
    sig2_init = gmm.covariances_.flatten()
    
    mu1 = torch.Tensor(np.sort(mu1_init.flatten())).reshape(-1,1)
    mu2 = torch.Tensor(np.sort(mu2_init.flatten())).reshape(-1,1)
    sig1 = torch.Tensor([1.0,1.0,1.0]).reshape(-1,1)
    sig2 = torch.Tensor([1.0,1.0,1.0]).reshape(-1,1) 
    
    pi.requires_grad = True 
    mu1.requires_grad = True
    mu2.requires_grad = True
    sig1.requires_grad = True
    sig2.requires_grad = True

    learning_rate = 0.001
    optimizer = optim.Adam([pi,mu1,mu2,sig1,sig2], lr=learning_rate)
    num_iterations = 10000
    for i in range(num_iterations):
        optimizer.zero_grad() 

        output = loss_MoG(x_t,y_t,mu1,mu2,sig1,sig2,F.softmax(pi,dim=0))
        output.backward() 

        optimizer.step()  

        if i % 1000 == 0:
            print(f"Iteration {i}: loss = {output.item()},")
            
    gmm_gt.weights_ = pi_true
    gmm_gt.means_ = np.concatenate([mu1_true.reshape(-1,1),mu2_true.reshape(-1,1)],axis=1)
    for i in range(K):
        gmm_gt.covariances_[i] = np.diag([sig1_true[i],sig2_true[i]])

    gmm_est.weights_ = F.softmax(pi,dim=0).detach().numpy().flatten()
    gmm_est.means_ = np.concatenate([mu1.detach().numpy().reshape(-1,1),mu2.detach().numpy().reshape(-1,1)],axis=1)
    for i in range(K):
        gmm_est.covariances_[i] = np.diag([sig1.detach().numpy().flatten()[i],sig2.detach().numpy().flatten()[i]])
        
    KL = gmm_kl(gmm_gt,gmm_est)
#     print(mu1,mu2,sig1,sig2)
    print("KL",KL)
    KL_list.append(KL)
    

Iteration 0: loss = 164.26007080078125,
Iteration 1000: loss = 117.06965637207031,
Iteration 2000: loss = 115.56617736816406,
Iteration 3000: loss = 115.5250015258789,
Iteration 4000: loss = 115.52375793457031,
Iteration 5000: loss = 115.52374267578125,
Iteration 6000: loss = 115.52375793457031,
Iteration 7000: loss = 115.52374267578125,
Iteration 8000: loss = 115.52375793457031,
Iteration 9000: loss = 115.52375793457031,
KL 0.19728769839223517
Iteration 0: loss = 1670.328369140625,
Iteration 1000: loss = 1185.9508056640625,
Iteration 2000: loss = 1171.931396484375,
Iteration 3000: loss = 1171.665283203125,
Iteration 4000: loss = 1171.659912109375,
Iteration 5000: loss = 1171.6600341796875,
Iteration 6000: loss = 1171.659912109375,
Iteration 7000: loss = 1171.659912109375,
Iteration 8000: loss = 1171.6600341796875,
Iteration 9000: loss = 1171.659912109375,
KL 0.011774248205265205
Iteration 0: loss = 16822.39453125,
Iteration 1000: loss = 11970.08984375,
Iteration 2000: loss = 11828.937

In [12]:
for number in KL_list:
    formatted_number = "{:.2e}".format(number)
    print(formatted_number)

1.97e-01
1.18e-02
1.03e-03
6.81e-05
