In [1]:
import torch
from torch import optim
from torch import erf
import numpy as np
# Ground-truth parameters for five bivariate Gaussians (one per index 0..4).
# Each distribution is generated under its own NumPy seed so the set is
# reproducible regardless of what ran before this cell.
mus, covs = [], []
count = 0  # not used by the visible cells; kept so the notebook namespace is unchanged
for k in range(5):
    np.random.seed(k)
    # Draw order matters for reproducibility: mean, then variances, then covariance.
    mu_true = np.random.rand(2)
    variances = 0.5 + np.random.rand(2)
    # |covariance| < sqrt(var1 * var2) keeps the matrix positive definite.
    cov_bound = np.sqrt(np.prod(variances))
    off_diag = np.random.uniform(-cov_bound, cov_bound)
    cov_true = np.diag(variances)
    cov_true[0, 1] = cov_true[1, 0] = off_diag
    mus.append(mu_true)
    covs.append(cov_true)

In [2]:
# Display the ground-truth covariance matrices generated in the previous cell.
covs

[array([[ 1.10276338, -0.16390295],
        [-0.16390295,  1.04488318]]),
 array([[ 0.50011437, -0.44752449],
        [-0.44752449,  0.80233257]]),
 array([[ 1.04966248, -0.15780625],
        [-0.15780625,  0.93532239]]),
 array([[0.79090474, 0.70269127],
        [0.70269127, 1.01082761]]),
 array([[1.47268436, 0.52894445],
        [0.52894445, 1.21481599]])]

In [3]:
def neglogLikelihood(x,y,mu1,mu2,sig1,sig2,rho):
    """Negative log-likelihood (up to an additive constant) of a bivariate Gaussian
    when only one coordinate of each sample is supplied.

    For every element of ``x`` the contribution is the Gaussian negative
    log-density under ``N(mu1, sig1**2)`` minus ``log(1 + erf(w))``, where the
    erf term equals ``2 * Phi(z)`` with
    ``z = (sig1*(x-mu2) - rho/sig1*(x-mu1)) / sqrt(sig1**2*sig2**2 - rho**2)``.
    Elements of ``y`` are handled symmetrically with the roles of the two
    coordinates swapped. In this notebook ``x`` holds first-coordinate values
    of samples whose first coordinate is the larger one, and ``y`` holds the
    second-coordinate values of the remaining samples.

    Parameters
    ----------
    x, y : torch.Tensor
        Observed values for the first / second coordinate (either may be empty).
    mu1, mu2 : torch.Tensor
        Means of the two coordinates.
    sig1, sig2 : torch.Tensor
        Standard deviations; must be positive because ``torch.log`` is applied.
    rho : torch.Tensor
        Covariance between the coordinates (off-diagonal entry of the
        covariance matrix, not the correlation coefficient).

    Returns
    -------
    torch.Tensor
        Scalar tensor: the summed loss over all elements of ``x`` and ``y``.
    """
    import math

    eps = 1e-7

    # Determinant of the covariance matrix. Clamp it so an optimiser step that
    # leaves the valid region yields a finite loss instead of sqrt(negative) = NaN.
    det = torch.clamp(sig1**2 * sig2**2 - rho**2, min=eps)
    sqrt_det = torch.sqrt(det)

    # Standardised arguments of the conditional normal CDF.
    z1 = (sig1 * (x - mu2) - rho / sig1 * (x - mu1)) / sqrt_det
    z2 = (sig2 * (y - mu1) - rho / sig2 * (y - mu2)) / sqrt_det

    # log(1 + erf(z / sqrt(2))) == log(2) + log Phi(z). log_ndtr evaluates
    # log Phi(z) stably, whereas erf(...) + 1 rounds to 0 for moderately
    # negative z in float32 and turns the loss into inf.
    log2 = math.log(2.0)
    loss1 = torch.log(sig1) + 0.5 * (x - mu1)**2 / sig1**2 - (log2 + torch.special.log_ndtr(z1))
    loss2 = torch.log(sig2) + 0.5 * (y - mu2)**2 / sig2**2 - (log2 + torch.special.log_ndtr(z2))

    loss = torch.sum(loss1)+torch.sum(loss2)

    return loss


In [4]:
from scipy.linalg import inv, det
def KL(mu1,Sigma1,mu2,Sigma2):
    """Closed-form KL divergence KL(N(mu1, Sigma1) || N(mu2, Sigma2)).

    Parameters
    ----------
    mu1, mu2 : array-like supporting subtraction (e.g. np.ndarray)
        Mean vectors of the first and second Gaussian.
    Sigma1, Sigma2 : square matrices
        Covariance matrices of the first and second Gaussian.

    Returns
    -------
    float
        0.5 * (log(|Sigma2|/|Sigma1|) - d + tr(Sigma2^-1 Sigma1)
               + (mu2-mu1)^T Sigma2^-1 (mu2-mu1)), with d = len(mu1).
    """
    dim = len(mu1)
    det_first = det(Sigma1)
    det_second = det(Sigma2)
    precision_second = inv(Sigma2)

    delta = mu2 - mu1
    log_det_ratio = np.log(det_second / det_first)
    trace_term = np.trace(precision_second @ Sigma1)
    mahalanobis = delta.T @ precision_second @ delta

    return 0.5 * (log_det_ratio - dim + trace_term + mahalanobis)

In [11]:
import time

# Experiment: for each sample size, each ground-truth Gaussian (index q) and
# each of five seeds, draw samples, keep only the larger coordinate of every
# pair, fit (mu1, mu2, sig1, sig2, rho) by minimising neglogLikelihood with
# Adam, and record the KL divergence from the truth plus the fitting time.
# Results are appended in (sample size, dataset, seed) order, which is the
# layout the later reshape(5, 5, 5) summaries rely on.
# NOTE: per the recorded output each fit takes roughly 25 s, and there are
# 5 * 5 * 5 = 125 fits.
KL_rec = []
time_rec = []

# Optimiser settings shared by every run.
learning_rate = 5e-4
num_iterations = 15000

for num_samples in [1000,3000,5000,7000,9000]:
    for q in range(5):
        for seed in range(5):
            np.random.seed(seed)
            samples = np.random.multivariate_normal(mus[q], covs[q], num_samples)

            a = samples[:,0]
            b = samples[:,1]

            # x: first coordinate where it is strictly the larger one;
            # y: second coordinate for all remaining samples (ties go to y).
            first_is_max = a > b
            x = a[first_is_max]
            y = b[~first_is_max]
            x_t,y_t = torch.Tensor(x),torch.Tensor(y)

            # Initialise every run from the same starting point.
            mu1 = torch.tensor(1.0, requires_grad=True)
            mu2 = torch.tensor(1.0, requires_grad=True)
            sig1 = torch.tensor(1.0, requires_grad=True)
            sig2 = torch.tensor(1.0, requires_grad=True)
            rho = torch.tensor(0.0, requires_grad=True)

            optimizer = optim.Adam([mu1,mu2,sig1,sig2,rho], lr=learning_rate)

            # perf_counter is monotonic, so the elapsed time cannot be skewed
            # by system clock adjustments during a long fit.
            st = time.perf_counter()

            for i in range(num_iterations):
                optimizer.zero_grad()

                output = neglogLikelihood(x_t,y_t,mu1,mu2,sig1,sig2,rho)
                output.backward()

                optimizer.step()

                if i % 1000 == 0:
                    print(f"Iteration {i}: loss = {output.item()},")

            # Assemble the estimate; rho is the covariance (off-diagonal entry).
            mu_e = np.array([mu1.item(),mu2.item()])
            cov_e = np.array([[(sig1**2).item(),rho.item()],[rho.item(),(sig2**2).item()]])

            end = time.perf_counter()

            # Evaluate the divergence once and reuse it for both record and log.
            kl_value = KL(mus[q],covs[q],mu_e,cov_e)
            KL_rec.append(kl_value)
            time_rec.append(end - st)
            print("sample size", num_samples, "dataset:",q)
            print("KL:", kl_value)
            print("time:",end - st)
    

Iteration 0: loss = 350.68475341796875,
Iteration 1000: loss = 270.4920654296875,
Iteration 2000: loss = 270.28851318359375,
Iteration 3000: loss = 270.00665283203125,
Iteration 4000: loss = 269.6132507324219,
Iteration 5000: loss = 269.2387390136719,
Iteration 6000: loss = 269.06976318359375,
Iteration 7000: loss = 269.04754638671875,
Iteration 8000: loss = 269.0470886230469,
Iteration 9000: loss = 269.04705810546875,
Iteration 10000: loss = 269.0470886230469,
Iteration 11000: loss = 269.0470886230469,
Iteration 12000: loss = 269.0470886230469,
Iteration 13000: loss = 269.04705810546875,
Iteration 14000: loss = 269.0470886230469,
sample size 1000 dataset: 0
KL: 0.002800854352296275
time: 24.586008548736572
Iteration 0: loss = 361.9060974121094,
Iteration 1000: loss = 299.08929443359375,
Iteration 2000: loss = 298.953369140625,
Iteration 3000: loss = 298.80633544921875,
Iteration 4000: loss = 298.6278076171875,
Iteration 5000: loss = 298.48193359375,
Iteration 6000: loss = 298.42303466

Iteration 7000: loss = 272.45697021484375,
Iteration 8000: loss = 272.0450744628906,
Iteration 9000: loss = 271.8448486328125,
Iteration 10000: loss = 271.80841064453125,
Iteration 11000: loss = 271.8070373535156,
Iteration 12000: loss = 271.8070068359375,
Iteration 13000: loss = 271.8070068359375,
Iteration 14000: loss = 271.8070373535156,
sample size 1000 dataset: 2
KL: 0.030171187490370437
time: 24.581180095672607
Iteration 0: loss = 733.2820434570312,
Iteration 1000: loss = 263.34039306640625,
Iteration 2000: loss = 247.25732421875,
Iteration 3000: loss = 246.95687866210938,
Iteration 4000: loss = 246.59979248046875,
Iteration 5000: loss = 246.01214599609375,
Iteration 6000: loss = 245.04742431640625,
Iteration 7000: loss = 243.47756958007812,
Iteration 8000: loss = 241.29977416992188,
Iteration 9000: loss = 239.72592163085938,
Iteration 10000: loss = 239.41824340820312,
Iteration 11000: loss = 239.4098663330078,
Iteration 12000: loss = 239.40985107421875,
Iteration 13000: loss = 2

Iteration 1000: loss = 512.9376220703125,
Iteration 2000: loss = 509.77734375,
Iteration 3000: loss = 509.7655029296875,
Iteration 4000: loss = 509.7640075683594,
Iteration 5000: loss = 509.7619323730469,
Iteration 6000: loss = 509.7596435546875,
Iteration 7000: loss = 509.75750732421875,
Iteration 8000: loss = 509.7563781738281,
Iteration 9000: loss = 509.7559814453125,
Iteration 10000: loss = 509.7559509277344,
Iteration 11000: loss = 509.7559814453125,
Iteration 12000: loss = 509.75592041015625,
Iteration 13000: loss = 509.75592041015625,
Iteration 14000: loss = 509.7559509277344,
sample size 1000 dataset: 4
KL: 0.022548417248061402
time: 24.510227918624878
Iteration 0: loss = 651.1151123046875,
Iteration 1000: loss = 487.6683349609375,
Iteration 2000: loss = 477.7252197265625,
Iteration 3000: loss = 477.56292724609375,
Iteration 4000: loss = 477.5084228515625,
Iteration 5000: loss = 477.42828369140625,
Iteration 6000: loss = 477.3177795410156,
Iteration 7000: loss = 477.17749023437

Iteration 11000: loss = -447.48956298828125,
Iteration 12000: loss = -447.48944091796875,
Iteration 13000: loss = -447.4895324707031,
Iteration 14000: loss = -447.489501953125,
sample size 3000 dataset: 1
KL: 0.014259254548679014
time: 24.89289426803589
Iteration 0: loss = 1941.40283203125,
Iteration 1000: loss = 726.0023803710938,
Iteration 2000: loss = 697.2880249023438,
Iteration 3000: loss = 696.7861328125,
Iteration 4000: loss = 696.101318359375,
Iteration 5000: loss = 695.0242919921875,
Iteration 6000: loss = 693.40478515625,
Iteration 7000: loss = 691.1729736328125,
Iteration 8000: loss = 688.75341796875,
Iteration 9000: loss = 687.322998046875,
Iteration 10000: loss = 687.0533447265625,
Iteration 11000: loss = 687.0447387695312,
Iteration 12000: loss = 687.044677734375,
Iteration 13000: loss = 687.0447998046875,
Iteration 14000: loss = 687.044677734375,
sample size 3000 dataset: 2
KL: 0.003400369217362171
time: 24.80709743499756
Iteration 0: loss = 2052.88916015625,
Iteration 1

Iteration 6000: loss = 1553.9111328125,
Iteration 7000: loss = 1553.90087890625,
Iteration 8000: loss = 1553.888916015625,
Iteration 9000: loss = 1553.8779296875,
Iteration 10000: loss = 1553.87158203125,
Iteration 11000: loss = 1553.86962890625,
Iteration 12000: loss = 1553.86962890625,
Iteration 13000: loss = 1553.869384765625,
Iteration 14000: loss = 1553.8695068359375,
sample size 3000 dataset: 4
KL: 0.0283788734174696
time: 24.838602304458618
Iteration 0: loss = 2189.242431640625,
Iteration 1000: loss = 1586.6483154296875,
Iteration 2000: loss = 1542.079345703125,
Iteration 3000: loss = 1540.8511962890625,
Iteration 4000: loss = 1540.772705078125,
Iteration 5000: loss = 1540.65771484375,
Iteration 6000: loss = 1540.4906005859375,
Iteration 7000: loss = 1540.2706298828125,
Iteration 8000: loss = 1540.0283203125,
Iteration 9000: loss = 1539.836181640625,
Iteration 10000: loss = 1539.7490234375,
Iteration 11000: loss = 1539.733642578125,
Iteration 12000: loss = 1539.73291015625,
Iter

Iteration 1000: loss = -675.7640380859375,
Iteration 2000: loss = -709.9598999023438,
Iteration 3000: loss = -719.065185546875,
Iteration 4000: loss = -719.5358276367188,
Iteration 5000: loss = -719.538330078125,
Iteration 6000: loss = -719.538330078125,
Iteration 7000: loss = -719.538330078125,
Iteration 8000: loss = -719.538330078125,
Iteration 9000: loss = -719.5384521484375,
Iteration 10000: loss = -719.5383911132812,
Iteration 11000: loss = -719.538330078125,
Iteration 12000: loss = -719.5382080078125,
Iteration 13000: loss = -719.538330078125,
Iteration 14000: loss = -719.5383911132812,
sample size 5000 dataset: 1
KL: 0.0018848060136436632
time: 25.260265588760376
Iteration 0: loss = 695.195068359375,
Iteration 1000: loss = -732.0653076171875,
Iteration 2000: loss = -753.4146118164062,
Iteration 3000: loss = -758.8843994140625,
Iteration 4000: loss = -759.2648315429688,
Iteration 5000: loss = -759.267822265625,
Iteration 6000: loss = -759.2680053710938,
Iteration 7000: loss = -75

Iteration 13000: loss = 1832.524658203125,
Iteration 14000: loss = 1832.524658203125,
sample size 5000 dataset: 3
KL: 0.005272489508927602
time: 25.244797945022583
Iteration 0: loss = 3467.5419921875,
Iteration 1000: loss = 2572.687255859375,
Iteration 2000: loss = 2521.96484375,
Iteration 3000: loss = 2521.06298828125,
Iteration 4000: loss = 2520.8876953125,
Iteration 5000: loss = 2520.635009765625,
Iteration 6000: loss = 2520.302734375,
Iteration 7000: loss = 2519.9228515625,
Iteration 8000: loss = 2519.5732421875,
Iteration 9000: loss = 2519.3466796875,
Iteration 10000: loss = 2519.2607421875,
Iteration 11000: loss = 2519.248046875,
Iteration 12000: loss = 2519.24755859375,
Iteration 13000: loss = 2519.24755859375,
Iteration 14000: loss = 2519.24755859375,
sample size 5000 dataset: 4
KL: 0.20214898366307654
time: 25.214844942092896
Iteration 0: loss = 3582.367431640625,
Iteration 1000: loss = 2644.57568359375,
Iteration 2000: loss = 2579.100830078125,
Iteration 3000: loss = 2577.552

Iteration 10000: loss = -960.9517211914062,
Iteration 11000: loss = -960.95166015625,
Iteration 12000: loss = -960.95166015625,
Iteration 13000: loss = -960.9514770507812,
Iteration 14000: loss = -960.9517822265625,
sample size 7000 dataset: 1
KL: 0.0014870001183888237
time: 25.703140258789062
Iteration 0: loss = 1025.1768798828125,
Iteration 1000: loss = -911.1903076171875,
Iteration 2000: loss = -959.1634521484375,
Iteration 3000: loss = -974.6278686523438,
Iteration 4000: loss = -975.7517700195312,
Iteration 5000: loss = -975.7623291015625,
Iteration 6000: loss = -975.76220703125,
Iteration 7000: loss = -975.7621459960938,
Iteration 8000: loss = -975.76220703125,
Iteration 9000: loss = -975.7623291015625,
Iteration 10000: loss = -975.7622680664062,
Iteration 11000: loss = -975.7620849609375,
Iteration 12000: loss = -975.7622680664062,
Iteration 13000: loss = -975.7623291015625,
Iteration 14000: loss = -975.7621459960938,
sample size 7000 dataset: 1
KL: 0.0009665382547135663
time: 25

Iteration 4000: loss = 2653.679931640625,
Iteration 5000: loss = 2653.43896484375,
Iteration 6000: loss = 2653.1005859375,
Iteration 7000: loss = 2652.680419921875,
Iteration 8000: loss = 2652.26025390625,
Iteration 9000: loss = 2651.9765625,
Iteration 10000: loss = 2651.880859375,
Iteration 11000: loss = 2651.8701171875,
Iteration 12000: loss = 2651.8701171875,
Iteration 13000: loss = 2651.8701171875,
Iteration 14000: loss = 2651.86962890625,
sample size 7000 dataset: 3
KL: 0.06861749160916754
time: 25.70235776901245
Iteration 0: loss = 5508.8642578125,
Iteration 1000: loss = 2880.207763671875,
Iteration 2000: loss = 2635.12060546875,
Iteration 3000: loss = 2631.384765625,
Iteration 4000: loss = 2631.353271484375,
Iteration 5000: loss = 2631.318359375,
Iteration 6000: loss = 2631.26904296875,
Iteration 7000: loss = 2631.205078125,
Iteration 8000: loss = 2631.13525390625,
Iteration 9000: loss = 2631.07666015625,
Iteration 10000: loss = 2631.04541015625,
Iteration 11000: loss = 2631.037

Iteration 4000: loss = -1259.1285400390625,
Iteration 5000: loss = -1259.147216796875,
Iteration 6000: loss = -1259.147705078125,
Iteration 7000: loss = -1259.1474609375,
Iteration 8000: loss = -1259.1474609375,
Iteration 9000: loss = -1259.1478271484375,
Iteration 10000: loss = -1259.1478271484375,
Iteration 11000: loss = -1259.1474609375,
Iteration 12000: loss = -1259.1475830078125,
Iteration 13000: loss = -1259.1474609375,
Iteration 14000: loss = -1259.1473388671875,
sample size 9000 dataset: 1
KL: 0.0013454517656867507
time: 26.064138412475586
Iteration 0: loss = 1271.00732421875,
Iteration 1000: loss = -1134.9271240234375,
Iteration 2000: loss = -1199.503662109375,
Iteration 3000: loss = -1220.60986328125,
Iteration 4000: loss = -1222.1082763671875,
Iteration 5000: loss = -1222.122314453125,
Iteration 6000: loss = -1222.1226806640625,
Iteration 7000: loss = -1222.1226806640625,
Iteration 8000: loss = -1222.122314453125,
Iteration 9000: loss = -1222.12255859375,
Iteration 10000: lo

Iteration 1000: loss = 3983.29833984375,
Iteration 2000: loss = 3602.302734375,
Iteration 3000: loss = 3594.21826171875,
Iteration 4000: loss = 3594.10009765625,
Iteration 5000: loss = 3593.966796875,
Iteration 6000: loss = 3593.764892578125,
Iteration 7000: loss = 3593.477783203125,
Iteration 8000: loss = 3593.11083984375,
Iteration 9000: loss = 3592.7255859375,
Iteration 10000: loss = 3592.44580078125,
Iteration 11000: loss = 3592.34228515625,
Iteration 12000: loss = 3592.330078125,
Iteration 13000: loss = 3592.32958984375,
Iteration 14000: loss = 3592.32958984375,
sample size 9000 dataset: 3
KL: 0.07036204174311972
time: 26.073105573654175
Iteration 0: loss = 6940.5126953125,
Iteration 1000: loss = 3658.235595703125,
Iteration 2000: loss = 3389.7255859375,
Iteration 3000: loss = 3385.87890625,
Iteration 4000: loss = 3385.69677734375,
Iteration 5000: loss = 3385.431640625,
Iteration 6000: loss = 3385.05810546875,
Iteration 7000: loss = 3384.5927734375,
Iteration 8000: loss = 3384.126

In [12]:
# Print floats in two-decimal scientific notation for the summary tables.
np.set_printoptions(precision=2, suppress=True, formatter={'float_kind': '{:.2e}'.format})

# KL_rec was filled in (sample size, dataset, seed) order; average over the
# seed axis to get one row per sample size and one column per dataset.
kl_cube = np.array(KL_rec).reshape(5, 5, 5)
KLmatrix = kl_cube.mean(axis=2)
print(KLmatrix)


[[1.50e-02 1.49e-02 2.33e-02 3.20e-01 4.71e+01]
 [1.18e-02 6.08e-03 7.80e-03 6.89e-01 8.26e-02]
 [1.57e-02 3.94e-03 1.53e-02 6.04e-01 8.21e-02]
 [3.45e-03 2.32e-03 5.42e-03 7.91e-02 4.33e-02]
 [2.99e-03 1.73e-03 5.91e-03 4.22e-02 4.69e-02]]


In [13]:
# Mean fitting time in seconds over the seed axis:
# rows are sample sizes, columns are datasets.
np.array(time_rec).reshape(5, 5, 5).mean(axis=2)

array([[2.46e+01, 2.46e+01, 2.46e+01, 2.46e+01, 2.45e+01],
       [2.49e+01, 2.49e+01, 2.48e+01, 2.48e+01, 2.48e+01],
       [2.53e+01, 2.53e+01, 2.52e+01, 2.52e+01, 2.52e+01],
       [2.56e+01, 2.57e+01, 2.56e+01, 2.57e+01, 2.56e+01],
       [2.61e+01, 2.61e+01, 2.60e+01, 2.61e+01, 2.60e+01]])