In [1]:
## Gaussian case gradient descent

In [2]:
import torch
from torch import optim
from torch import erf
import numpy as np
import numpy as np
## Ground-truth parameters for 5 two-dimensional Gaussians with diagonal covariances.
mus, covs = [], []
count = 0  # not read by any visible cell; kept so the kernel namespace is unchanged
for i in range(5):
    # Reseed the global NumPy RNG per component so that each (mean, covariance)
    # pair is reproducible independently of the others.
    np.random.seed(i)
    mean = np.random.rand(2)                 # mean entries drawn uniformly from [0, 1)
    variances = 0.5 + np.random.rand(2)      # variances drawn uniformly from [0.5, 1.5)
    covariance_matrix = np.diag(variances)   # zero off-diagonals: independent coordinates
    mus += [mean]
    covs += [covariance_matrix]
    

In [3]:
# Display the ground-truth covariance matrices generated in the previous cell.
covs

[array([[1.10276338, 0.        ],
        [0.        , 1.04488318]]),
 array([[0.50011437, 0.        ],
        [0.        , 0.80233257]]),
 array([[1.04966248, 0.        ],
        [0.        , 0.93532239]]),
 array([[0.79090474, 0.        ],
        [0.        , 1.01082761]]),
 array([[1.47268436, 0.        ],
        [0.        , 1.21481599]])]

In [4]:
def neglogLikelihood(x,y,mu1,mu2,sig1,sig2):
    """Negative log-likelihood (up to an additive constant) of winner-only Gaussian data.

    Every element of ``x`` contributes

        (x - mu1)**2 / (2*sig1**2) + log(sig1) - log(1 + erf((x - mu2) / (sqrt(2)*sig2)))

    and every element of ``y`` contributes the same expression with the roles of
    (mu1, sig1) and (mu2, sig2) swapped.  Because 1 + erf(z / sqrt(2)) == 2 * Phi(z),
    with Phi the standard normal CDF, each term is a Gaussian log-density for the
    observed value plus the log-probability that the other coordinate lies below it.
    In this notebook ``x`` holds first coordinates of pairs where the first
    coordinate is the larger one and ``y`` holds the remaining second coordinates.

    The CDF term is evaluated as ``torch.special.log_ndtr(z) + log(2)``.  That is
    algebraically identical to ``log(erf(z / sqrt(2)) + 1)``, so loss values are
    unchanged, but it stays finite in the lower tail: in float32 ``erf(.) + 1``
    rounds to exactly 0 a few standard deviations below the mean, which makes the
    naive form return ``inf`` and produces non-finite gradients.

    Args:
        x, y: tensors of observations (any shapes; each is summed over all elements).
        mu1, mu2: location parameters (tensors broadcastable against ``x`` / ``y``).
        sig1, sig2: scale parameters as tensors; they must be positive, since
            ``torch.log`` of a non-positive value is ``nan`` / ``-inf``.

    Returns:
        A scalar tensor: the sum of all per-observation terms of ``x`` and ``y``.
    """
    import math  # local import: the notebook's import cell does not provide `math`

    log_two = math.log(2.0)

    def _log_one_plus_erf(z):
        # log(1 + erf(z / sqrt(2))) == log(2 * Phi(z)) == log(2) + log Phi(z)
        return torch.special.log_ndtr(z) + log_two

    loss_x = torch.sum(
        (x - mu1)**2 / (2*sig1**2) - _log_one_plus_erf((x - mu2) / sig2) + torch.log(sig1)
    )
    loss_y = torch.sum(
        (y - mu2)**2 / (2*sig2**2) - _log_one_plus_erf((y - mu1) / sig1) + torch.log(sig2)
    )
    return loss_x + loss_y

In [5]:
from scipy.linalg import inv, det
def KL(mu1,Sigma1,mu2,Sigma2):
    """Kullback-Leibler divergence KL( N(mu1, Sigma1) || N(mu2, Sigma2) ).

    Implements the closed form

        0.5 * ( log(det Sigma2 / det Sigma1) - d
                + trace(Sigma2^{-1} Sigma1)
                + (mu2 - mu1)^T Sigma2^{-1} (mu2 - mu1) ).

    The log-determinant ratio is taken from ``np.linalg.slogdet`` so it does not
    under- or overflow when the determinants themselves are not representable
    (large ``d`` or badly scaled covariances), and ``Sigma2^{-1}`` is applied via
    ``np.linalg.solve`` instead of forming an explicit inverse.

    Args:
        mu1, mu2: length-``d`` mean vectors (arrays or plain sequences).
        Sigma1, Sigma2: ``d x d`` covariance matrices (arrays or nested sequences).
            They are assumed to be valid positive-definite covariances; the sign
            of the determinants is not checked.

    Returns:
        The divergence as a NumPy float scalar.

    Raises:
        numpy.linalg.LinAlgError: if ``Sigma2`` is singular.
    """
    mu1 = np.asarray(mu1, dtype=float)
    mu2 = np.asarray(mu2, dtype=float)
    Sigma1 = np.asarray(Sigma1, dtype=float)
    Sigma2 = np.asarray(Sigma2, dtype=float)

    d = len(mu1)
    _, logdet_Sigma1 = np.linalg.slogdet(Sigma1)
    _, logdet_Sigma2 = np.linalg.slogdet(Sigma2)

    mean_shift = mu2 - mu1
    trace_term = np.trace(np.linalg.solve(Sigma2, Sigma1))          # tr(Sigma2^{-1} Sigma1)
    mahalanobis_term = mean_shift @ np.linalg.solve(Sigma2, mean_shift)

    kl_divergence = 0.5 * (logdet_Sigma2 - logdet_Sigma1 - d + trace_term + mahalanobis_term)
    return kl_divergence

In [6]:
import time

# Flat result lists, appended in (sample size, dataset, seed) nesting order.
# The summary cells reshape them to (5, 5, 5), so the loop order must not change.
KL_rec = []
time_rec = []

# Optimiser settings are identical for every run, so they are set once here.
learning_rate = 1e-2
num_iterations = 3000

for num_samples in [1000,3000,5000,7000,9000]:
    for q in range(5):
        for seed in range(5):
            # Draw paired samples from ground-truth Gaussian q with a fixed seed.
            np.random.seed(seed)
            samples = np.random.multivariate_normal(mus[q], covs[q], num_samples)

            a = samples[:,0]
            b = samples[:,1]

            # Keep only the larger coordinate of each pair: x collects first
            # coordinates with a > b, y collects second coordinates otherwise
            # (ties go to y).
            first_is_larger = a > b
            x = a[first_is_larger]
            y = b[~first_is_larger]

            x_t,y_t = torch.Tensor(x),torch.Tensor(y)
            # All four parameters start at 1.0 and are optimised jointly.
            mu1,mu2,sig1,sig2 = (torch.tensor(1.0, requires_grad=True) for _ in range(4))

            optimizer = optim.Adam([mu1,mu2,sig1,sig2], lr=learning_rate)

            # perf_counter is monotonic and high-resolution; the wall clock can
            # jump (e.g. on clock adjustments) and distort measured durations.
            st = time.perf_counter()
            for i in range(num_iterations):
                optimizer.zero_grad()

                output = neglogLikelihood(x_t,y_t,mu1,mu2,sig1,sig2)
                output.backward()

                optimizer.step()

                if i % 250 == 0:
                    print(f"Iteration {i}: loss = {output.item()},")
            end = time.perf_counter()

            # Fitted mean vector and diagonal covariance (variances are sig**2).
            mu_e = np.array([mu1.item(),mu2.item()])
            cov_e = np.array([[(sig1**2).item(),0],[0,(sig2**2).item()]])

            # Evaluate the divergence once and reuse it for recording and printing.
            kl_value = KL(mus[q],covs[q],mu_e,cov_e)
            elapsed = end - st

            KL_rec.append(kl_value)
            time_rec.append(elapsed)
            print("sample size", num_samples, "dataset:",q)
            print("KL:", kl_value)
            print("time:",elapsed)
    

Iteration 0: loss = 460.593505859375,
Iteration 250: loss = 343.786376953125,
Iteration 500: loss = 343.7863464355469,
Iteration 750: loss = 343.78631591796875,
Iteration 1000: loss = 343.786376953125,
Iteration 1250: loss = 343.786376953125,
Iteration 1500: loss = 343.786376953125,
Iteration 1750: loss = 343.78631591796875,
Iteration 2000: loss = 343.78643798828125,
Iteration 2250: loss = 343.786376953125,
Iteration 2500: loss = 343.78643798828125,
Iteration 2750: loss = 343.786376953125,
sample size 1000 dataset: 0
KL: 0.0013953347384053757
time: 3.5672402381896973
Iteration 0: loss = 407.57415771484375,
Iteration 250: loss = 336.431396484375,
Iteration 500: loss = 336.431396484375,
Iteration 750: loss = 336.431396484375,
Iteration 1000: loss = 336.431396484375,
Iteration 1250: loss = 336.431396484375,
Iteration 1500: loss = 336.4314270019531,
Iteration 1750: loss = 336.431396484375,
Iteration 2000: loss = 336.43243408203125,
Iteration 2250: loss = 336.431396484375,
Iteration 2500: l

Iteration 1000: loss = 271.75836181640625,
Iteration 1250: loss = 271.7583923339844,
Iteration 1500: loss = 271.7583923339844,
Iteration 1750: loss = 271.7583923339844,
Iteration 2000: loss = 271.7583923339844,
Iteration 2250: loss = 271.7583923339844,
Iteration 2500: loss = 271.75836181640625,
Iteration 2750: loss = 271.75836181640625,
sample size 1000 dataset: 2
KL: 0.0029917869839122457
time: 3.54911470413208
Iteration 0: loss = 402.09381103515625,
Iteration 250: loss = 255.12173461914062,
Iteration 500: loss = 255.12173461914062,
Iteration 750: loss = 255.12173461914062,
Iteration 1000: loss = 255.1217498779297,
Iteration 1250: loss = 255.1217498779297,
Iteration 1500: loss = 255.12176513671875,
Iteration 1750: loss = 255.12176513671875,
Iteration 2000: loss = 255.1217498779297,
Iteration 2250: loss = 255.1217498779297,
Iteration 2500: loss = 255.1217498779297,
Iteration 2750: loss = 255.1217498779297,
sample size 1000 dataset: 3
KL: 0.0019269608356309038
time: 3.5535824298858643
I

Iteration 1750: loss = 1008.4708251953125,
Iteration 2000: loss = 1008.470947265625,
Iteration 2250: loss = 1008.470947265625,
Iteration 2500: loss = 1008.4708862304688,
Iteration 2750: loss = 1008.470703125,
sample size 3000 dataset: 0
KL: 0.0013233115180036202
time: 3.666642904281616
Iteration 0: loss = 1212.037841796875,
Iteration 250: loss = 938.33251953125,
Iteration 500: loss = 938.33251953125,
Iteration 750: loss = 938.3325805664062,
Iteration 1000: loss = 938.33251953125,
Iteration 1250: loss = 938.3325805664062,
Iteration 1500: loss = 938.3325805664062,
Iteration 1750: loss = 938.33251953125,
Iteration 2000: loss = 938.33251953125,
Iteration 2250: loss = 938.33251953125,
Iteration 2500: loss = 938.33251953125,
Iteration 2750: loss = 938.33251953125,
sample size 3000 dataset: 0
KL: 0.0017709992932392242
time: 3.699453115463257
Iteration 0: loss = 1075.98486328125,
Iteration 250: loss = 203.17552185058594,
Iteration 500: loss = 203.1754608154297,
Iteration 750: loss = 203.175460

sample size 3000 dataset: 3
KL: 0.0039919170223064425
time: 3.6824142932891846
Iteration 0: loss = 1178.478759765625,
Iteration 250: loss = 738.077880859375,
Iteration 500: loss = 738.077880859375,
Iteration 750: loss = 738.077880859375,
Iteration 1000: loss = 738.077880859375,
Iteration 1250: loss = 738.077880859375,
Iteration 1500: loss = 738.077880859375,
Iteration 1750: loss = 738.077880859375,
Iteration 2000: loss = 738.077880859375,
Iteration 2250: loss = 738.077880859375,
Iteration 2500: loss = 738.077880859375,
Iteration 2750: loss = 738.077880859375,
sample size 3000 dataset: 3
KL: 0.001531821190430935
time: 3.677154302597046
Iteration 0: loss = 1070.8487548828125,
Iteration 250: loss = 678.4505615234375,
Iteration 500: loss = 678.4505004882812,
Iteration 750: loss = 678.4505615234375,
Iteration 1000: loss = 678.4505615234375,
Iteration 1250: loss = 678.4505615234375,
Iteration 1500: loss = 678.4505004882812,
Iteration 1750: loss = 678.4505615234375,
Iteration 2000: loss = 678

Iteration 750: loss = 443.15997314453125,
Iteration 1000: loss = 443.15997314453125,
Iteration 1250: loss = 443.1600341796875,
Iteration 1500: loss = 443.1600646972656,
Iteration 1750: loss = 443.1600341796875,
Iteration 2000: loss = 443.1600341796875,
Iteration 2250: loss = 443.1600341796875,
Iteration 2500: loss = 443.1600341796875,
Iteration 2750: loss = 443.1600646972656,
sample size 5000 dataset: 1
KL: 0.00241624843078222
time: 3.779120922088623
Iteration 0: loss = 1827.072021484375,
Iteration 250: loss = 373.86004638671875,
Iteration 500: loss = 373.8599853515625,
Iteration 750: loss = 373.860107421875,
Iteration 1000: loss = 373.860107421875,
Iteration 1250: loss = 373.860107421875,
Iteration 1500: loss = 373.860107421875,
Iteration 1750: loss = 373.86004638671875,
Iteration 2000: loss = 373.860107421875,
Iteration 2250: loss = 373.860107421875,
Iteration 2500: loss = 373.860107421875,
Iteration 2750: loss = 373.860107421875,
sample size 5000 dataset: 1
KL: 0.0004430696887217362

Iteration 2000: loss = 2192.83935546875,
Iteration 2250: loss = 2192.83935546875,
Iteration 2500: loss = 2192.839599609375,
Iteration 2750: loss = 2192.83935546875,
sample size 5000 dataset: 4
KL: 0.0005770164912868799
time: 3.74765944480896
Iteration 0: loss = 2733.06298828125,
Iteration 250: loss = 2261.3916015625,
Iteration 500: loss = 2261.3916015625,
Iteration 750: loss = 2261.3916015625,
Iteration 1000: loss = 2261.3916015625,
Iteration 1250: loss = 2261.3916015625,
Iteration 1500: loss = 2261.391357421875,
Iteration 1750: loss = 2261.3916015625,
Iteration 2000: loss = 2261.3916015625,
Iteration 2250: loss = 2261.3916015625,
Iteration 2500: loss = 2261.3916015625,
Iteration 2750: loss = 2261.3916015625,
sample size 5000 dataset: 4
KL: 0.0023998085301866023
time: 3.7478113174438477
Iteration 0: loss = 2598.90283203125,
Iteration 250: loss = 2195.510009765625,
Iteration 500: loss = 2195.509765625,
Iteration 750: loss = 2195.509765625,
Iteration 1000: loss = 2195.510009765625,
Itera

Iteration 750: loss = 1988.66162109375,
Iteration 1000: loss = 1988.66162109375,
Iteration 1250: loss = 1988.661376953125,
Iteration 1500: loss = 1988.6611328125,
Iteration 1750: loss = 1988.6612548828125,
Iteration 2000: loss = 1988.6611328125,
Iteration 2250: loss = 1988.6611328125,
Iteration 2500: loss = 1988.6611328125,
Iteration 2750: loss = 1988.6612548828125,
sample size 7000 dataset: 2
KL: 0.0008003655584560192
time: 3.846853494644165
Iteration 0: loss = 5467.00439453125,
Iteration 250: loss = 2085.2265625,
Iteration 500: loss = 2085.13720703125,
Iteration 750: loss = 2085.13720703125,
Iteration 1000: loss = 2085.13720703125,
Iteration 1250: loss = 2085.13720703125,
Iteration 1500: loss = 2085.13720703125,
Iteration 1750: loss = 2085.13720703125,
Iteration 2000: loss = 2085.13720703125,
Iteration 2250: loss = 2085.13720703125,
Iteration 2500: loss = 2085.13720703125,
Iteration 2750: loss = 2085.13720703125,
sample size 7000 dataset: 2
KL: 0.0012131372584311922
time: 3.843500375

Iteration 2750: loss = 2988.58349609375,
sample size 9000 dataset: 0
KL: 0.00021568382581003102
time: 3.9464213848114014
Iteration 0: loss = 3818.659912109375,
Iteration 250: loss = 3027.72119140625,
Iteration 500: loss = 3027.720703125,
Iteration 750: loss = 3027.720703125,
Iteration 1000: loss = 3027.720947265625,
Iteration 1250: loss = 3027.72119140625,
Iteration 1500: loss = 3027.720947265625,
Iteration 1750: loss = 3027.720947265625,
Iteration 2000: loss = 3027.7265625,
Iteration 2250: loss = 3027.720703125,
Iteration 2500: loss = 3027.725830078125,
Iteration 2750: loss = 3027.720947265625,
sample size 9000 dataset: 0
KL: 0.00031679128845990106
time: 3.940060615539551
Iteration 0: loss = 4044.02587890625,
Iteration 250: loss = 3151.116455078125,
Iteration 500: loss = 3151.1162109375,
Iteration 750: loss = 3151.116455078125,
Iteration 1000: loss = 3151.1162109375,
Iteration 1250: loss = 3151.1162109375,
Iteration 1500: loss = 3151.1162109375,
Iteration 1750: loss = 3151.1162109375,

Iteration 1250: loss = 2192.178466796875,
Iteration 1500: loss = 2192.1787109375,
Iteration 1750: loss = 2192.178466796875,
Iteration 2000: loss = 2192.17822265625,
Iteration 2250: loss = 2192.178466796875,
Iteration 2500: loss = 2192.178466796875,
Iteration 2750: loss = 2192.178466796875,
sample size 9000 dataset: 3
KL: 0.00017262260867378852
time: 3.943803071975708
Iteration 0: loss = 3357.61767578125,
Iteration 250: loss = 2254.65478515625,
Iteration 500: loss = 2254.65478515625,
Iteration 750: loss = 2254.65478515625,
Iteration 1000: loss = 2254.65478515625,
Iteration 1250: loss = 2254.65478515625,
Iteration 1500: loss = 2254.65478515625,
Iteration 1750: loss = 2254.65478515625,
Iteration 2000: loss = 2254.65478515625,
Iteration 2250: loss = 2254.65478515625,
Iteration 2500: loss = 2254.65478515625,
Iteration 2750: loss = 2254.65478515625,
sample size 9000 dataset: 3
KL: 0.0003207458191275753
time: 3.9450180530548096
Iteration 0: loss = 3558.65087890625,
Iteration 250: loss = 2362.

In [7]:
# Mean KL per (sample size, dataset): KL_rec is laid out as
# (sample size, dataset, seed), so averaging the last axis averages over seeds.
np.array(KL_rec).reshape(5, 5, 5).mean(axis=2)

array([[0.00405247, 0.00387426, 0.00397652, 0.00317777, 0.00476907],
       [0.0015281 , 0.00247286, 0.00217112, 0.00205276, 0.00226237],
       [0.00078179, 0.00099261, 0.00105275, 0.0009349 , 0.0009475 ],
       [0.00046754, 0.00053862, 0.00054751, 0.00059379, 0.00051594],
       [0.00031726, 0.00032315, 0.00030484, 0.00032362, 0.00031001]])

In [8]:
# Standard deviation of KL across the 5 seeds (population form, ddof=0)
# for every (sample size, dataset) combination.
np.array(KL_rec).reshape(5, 5, 5).std(axis=2)

array([[1.96057316e-03, 1.04213716e-03, 9.88483977e-04, 1.12499706e-03,
        1.71759038e-03],
       [9.26479607e-04, 1.84274710e-03, 1.49612305e-03, 1.29818191e-03,
        1.55229224e-03],
       [4.25111235e-04, 7.21461166e-04, 9.54781335e-04, 6.01966176e-04,
        7.28291586e-04],
       [2.47195341e-04, 2.54260170e-04, 4.00193968e-04, 3.09668989e-04,
        3.00777713e-04],
       [5.55204333e-05, 1.05729785e-04, 1.06777609e-04, 9.39559287e-05,
        8.11926576e-05]])

In [9]:
# Mean optimisation time in seconds per (sample size, dataset), averaged over seeds.
np.array(time_rec).reshape(5, 5, 5).mean(axis=2)

array([[3.55941162, 3.55248065, 3.55682321, 3.56169257, 3.56783524],
       [3.6638618 , 3.66067019, 3.68013258, 3.67840796, 3.69048057],
       [3.77580338, 3.77507691, 3.77587676, 3.76283202, 3.74994574],
       [3.83133883, 3.82894883, 3.85202417, 3.81952047, 3.85225816],
       [3.93741198, 3.9184937 , 3.9437242 , 3.94562883, 3.94038901]])

In [10]:
# Seed-averaged KL table: rows are sample sizes, columns are datasets.
KLmatrix = np.array(KL_rec).reshape(5, 5, 5).mean(axis=2)

# Print floats in two-decimal scientific notation.
# NOTE: set_printoptions is global, so it also affects every later array print in this kernel.
np.set_printoptions(precision=2, suppress=True, formatter={'float_kind': '{:.2e}'.format})
print(KLmatrix)


[[4.05e-03 3.87e-03 3.98e-03 3.18e-03 4.77e-03]
 [1.53e-03 2.47e-03 2.17e-03 2.05e-03 2.26e-03]
 [7.82e-04 9.93e-04 1.05e-03 9.35e-04 9.48e-04]
 [4.68e-04 5.39e-04 5.48e-04 5.94e-04 5.16e-04]
 [3.17e-04 3.23e-04 3.05e-04 3.24e-04 3.10e-04]]
