In [1]:
import numpy
import os, ctypes
from scipy import integrate, LowLevelCallable
%load_ext line_profiler
%run "integration_utils.py"
%run "q_solver.py"

In [2]:
centeroids = numpy.array([[-1, 0], [1, 0]])

def pdf1(y, x):
    return 0.5 / pi * exp(-(x+1)**2/2 - y**2/2)

def pdf2(y, x):
    return 0.5 / pi * exp(-(x-1)**2/2 - y**2/2)

pdfs_t = [pdf1, pdf2]

In [17]:
lib = ctypes.CDLL(os.path.abspath('std_gaussian.so'))
lib.f.restype = ctypes.c_double
lib.f.argtypes = (ctypes.c_int, ctypes.POINTER(ctypes.c_double), ctypes.c_void_p)
lib.g.restype = ctypes.c_double
lib.g.argtypes = (ctypes.c_int, ctypes.POINTER(ctypes.c_double), ctypes.c_void_p)
lib.gx.restype = ctypes.c_double
lib.gx.argtypes = (ctypes.c_int, ctypes.POINTER(ctypes.c_double), ctypes.c_void_p)
lib.gy.restype = ctypes.c_double
lib.gy.argtypes = (ctypes.c_int, ctypes.POINTER(ctypes.c_double), ctypes.c_void_p)



def gen_pdf(label):
    user_data = ctypes.cast(
        ctypes.pointer((ctypes.c_double*2)(centeroids[label, 0], centeroids[label, 1])), 
        ctypes.c_void_p
    )
    return LowLevelCallable(lib.f, user_data)


def gen_norm_pdf(label, yj):
    user_data = ctypes.cast(
        ctypes.pointer((ctypes.c_double*4)(centeroids[label, 0], centeroids[label, 1], yj[0], yj[1])), 
        ctypes.c_void_p
    )
    return LowLevelCallable(lib.g, user_data)


def gen_x_pdf(label):
    user_data = ctypes.cast(
        ctypes.pointer((ctypes.c_double*2)(centeroids[label, 0], centeroids[label, 1])), 
        ctypes.c_void_p
    )
    return LowLevelCallable(lib.gx, user_data)


def gen_y_pdf(label):
    user_data = ctypes.cast(
        ctypes.pointer((ctypes.c_double*2)(centeroids[label, 0], centeroids[label, 1])), 
        ctypes.c_void_p
    )
    return LowLevelCallable(lib.gy, user_data)


def g_w_one_source(v, q, y, k):
    n = v.size
    w = numpy.zeros(n)
    g = v.dot(q)
    for j in range(n):
        interval = gen_intervals(j, v, y)
        pdf = gen_pdf(k)
        integral1 = integrate_over_intervals(interval, pdf)
        normpdf = gen_norm_pdf(k, y[j])
        integral2 = integrate_over_intervals(interval, normpdf)
        w[j] = q[j] - integral1
        g += integral2 - v[j] * integral1
    return g, w


def q_from_v(v, y):
    n, m = v.shape
    q = numpy.zeros(v.shape)
    for k in range(m):
        w = g_w_one_source(v[:, k], numpy.zeros(n), y, k)[1]
        for j in range(n):
            q[j, k] = - w[j]
    return q


def subgradient_1(q, y, max_step, k):
    v = numpy.zeros(len(y))
    v_best = v
    g_best = - numpy.inf
    for step in range(1, max_step):
        g, w = g_w_one_source(v, q, y, k)
        if g > g_best:
            v_best = v
            g_best = g
        norm = numpy.linalg.norm(w)
        if norm < 1e-8:
            break
        alpha = 1 / step
        v += alpha * w / norm
        print(step, g_best, norm)
    return g_best, v_best


def g_w_multi_source(v, y, eps=0.693147, weight=None):
    n, m = v.shape
    if not weight:
        weight = numpy.ones(m)
    q, g = q_solver(v, weight, eps)
    w = numpy.zeros((n, m))
    for k in range(m):
        for j in range(n):
            interval = gen_intervals(j, v[:, k], y)
            pdf = gen_pdf(k)
            integral1 = integrate_over_intervals(interval, pdf)
            normpdf = gen_norm_pdf(k, y[j])
            integral2 = integrate_over_intervals(interval, normpdf)
            w[j, k] = weight[k] * (q[j, k] - integral1)
            g += weight[k] * (integral2 - v[j, k] * integral1)
    return g, w


def subgradient_2(y, max_step, eps=0.693147, weight=None, log=False):
    n = len(y)
    m = len(centeroids)
    v = numpy.zeros((n, m))
    v_best = v
    g_best = - numpy.inf
    for step in range(1, max_step + 1):
        g, w = g_w_multi_source(v, y, eps, weight)
        if g > g_best:
            v_best = v
            g_best = g
        norm = numpy.linalg.norm(w)
        if norm < 1e-8:
            break
        alpha = 1 / step
        v += alpha * w / norm
        if log:
            print(step, g_best, norm)
            print(v)
    return g_best, v_best


def next_y_from_v(v, y, weight=None):
    n, m = v.shape
    res = numpy.zeros(y.shape)
    if not weight:
        weight = numpy.ones(m)
    for j in range(n):
        integral_x = numpy.zeros(m)
        integral_y = numpy.zeros(m)
        integral_1 = numpy.zeros(m)
        for k in range(m):
            interval = gen_intervals(j, v[:, k], y)
            integral_x[k] = integrate_over_intervals(interval, gen_x_pdf(k))
            integral_y[k] = integrate_over_intervals(interval, gen_y_pdf(k))
            integral_1[k] = integrate_over_intervals(interval, gen_pdf(k))
        res[j, 0] = integral_x.dot(weight) / integral_1.dot(weight)
        res[j, 1] = integral_y.dot(weight) / integral_1.dot(weight)
    return res

In [28]:
history1 = history

In [38]:
y = numpy.array([[1, 1], [0, 1], [-1, 2]], dtype='float')

In [39]:
history = []

def lloyd_max(y0, max_step, eps=0.693147, weight=None):
    n = len(y0)
    m = len(centeroids)
    y = y0
    for step in range(max_step):
        subgradient_step = 100000 if step == max_step - 1 else 10000
        g, v = subgradient_2(y, subgradient_step, eps, weight)
        y = next_y_from_v(v, y, weight)
        print(step, g)
        print(y)
        history.append((g, y))
    return y

In [40]:
lloyd_max(y, 100)

0 6.487522967259352
[[ 1.28425596 -0.00671767]
 [-0.61842594 -0.24322852]
 [-1.74071195  1.11354887]]
1 3.3126502457289506
[[ 1.21295435  0.06071151]
 [-0.70463359 -0.48054454]
 [-1.55020801  0.98703233]]
2 2.9163596623662222
[[ 1.21365071  0.36940155]
 [-0.74407589 -0.67848475]
 [-1.45001673  0.90196481]]
3 2.81942850889651
[[ 1.23705294  0.39171237]
 [-0.6707907  -0.82204504]
 [-1.38820274  0.84529245]]
4 2.8819265462206127
[[ 1.22994474  0.35434714]
 [-0.60198791 -0.89002206]
 [-1.33096825  0.7802762 ]]
5 2.8978933771836837
[[ 1.23001783  0.33275429]
 [-0.5608728  -0.93729109]
 [-1.29443065  0.73167446]]
6 2.911063870199581
[[ 1.2285137   0.32004228]
 [-0.52676796 -0.97102416]
 [-1.27833098  0.69618009]]
7 2.9257918023674407
[[ 1.22756923  0.31190887]
 [-0.50057029 -0.99429292]
 [-1.26680396  0.67024446]]
8 2.938190651049596
[[ 1.22629338  0.30676486]
 [-0.47957816 -1.01106057]
 [-1.25831411  0.65133474]]
9 2.9477051129283933
[[ 1.22586191  0.30326178]
 [-0.46277573 -1.02331456]
 [-

KeyboardInterrupt: 

In [16]:
next_y_from_v(v, y)

array([[ 1.10723877, -0.51130627],
       [ 0.93766535,  1.20647982],
       [-1.10803701, -0.10171452]])

In [37]:
%%time
g_best, v = subgradient_2(y, 1000)

CPU times: user 19.2 s, sys: 156 ms, total: 19.3 s
Wall time: 19.8 s


In [12]:
v1 = v[:, 0]
v2 = v[:, 1]

In [15]:
g_w_one_source(v1, -numpy.array([0, 0, 0]), y, 0)

(2.1151445627892356, array([-0.16818423, -0.11243459, -0.71938118]))

In [105]:
g_w_one_source(v2, -numpy.array([-0.43552972, -0.24347129, -0.32099899]), y, 1)

(1.556305804627386, array([-4.83802998e-11,  3.68660216e-09, -3.63028685e-09]))

In [38]:
vtest = numpy.array([0.68320418, 0.50182977, -1.18503395])
# vtest = v[:, 0].A1
qtest = numpy.array([-0.25212491, -0.21446891, -0.53340619])

In [42]:
g_w_one_source(vtest, numpy.zeros(3), y, 0)

(2.1174017510681487, array([-0.32121089, -0.04052876, -0.63826035]))

In [13]:
g_w_one_source(vtest, qtest, y, pdf1)

0.4259433582073181 5.883275666020487
0.3067293059967917 0.236169608973349
1.0387099487136147 1.0387099487136147


(1.7713826049706172,
 array([-4.19964724e-09,  2.73372172e-09,  1.46592682e-09]))

In [109]:
subgradient_1(-numpy.array([-0.21464961, -0.11449747, -0.67085292]), y, 1000, 0)

1 1.5732428797134006 0.1792966575672009
2 1.6935196270822575 0.0584391236477842
3 1.7072320072227436 0.017823717691813266
4 1.7072320072227436 0.03571461744681597
5 1.7081490402287345 0.005880343673818787
6 1.7081490402287345 0.027044946239934912
7 1.7082450187763565 0.0013429000460139563
8 1.7082450187763565 0.016800366453745764
9 1.7082450187763565 0.0037471161598532873
10 1.7082450187763565 0.013113280160898019
11 1.7082450187763565 0.0036426038675776338
12 1.7082450187763565 0.011501626566711854
13 1.7082450187763565 0.0025931266837296444
14 1.7082450187763565 0.010373441764770284
15 1.7082450187763565 0.0017089536983449886
16 1.7082450187763565 0.009541983970256655
17 1.7082495707073937 0.0010414646935942048
18 1.7082495707073937 0.008846335480269479
19 1.7082515946727967 0.0006059879771790155
20 1.7082515946727967 0.007826129871014966
21 1.7082515946727967 0.0009127259276010853
22 1.7082515946727967 0.005786805529145124
23 1.7082515946727967 0.0017273443344876008
24 1.70825159467

191 1.7082527837124273 0.000173631454945469
192 1.7082527837124273 0.0007181830372778312
193 1.7082527837124273 0.0001689830119292761
194 1.7082527837124273 0.0007135898462543811
195 1.7082527837124273 0.00016443020104791694
196 1.7082527837124273 0.00070909059608768
197 1.7082527837124273 0.00015997010122945708
198 1.7082527837124273 0.0007046824339990139
199 1.7082527837124273 0.00015559990915000522
200 1.7082527837124273 0.0007003626215993668
201 1.7082527837124273 0.00015131693338559473
202 1.7082527837124273 0.0006961285291554232
203 1.7082527837124273 0.00014711858887070525
204 1.7082527837124273 0.0006919776302855339
205 1.7082527837124273 0.00014300239170344985
206 1.7082527837124273 0.0006879074969122621
207 1.7082527837124273 0.00013896595423971907
208 1.7082527837124273 0.00068391579453964
209 1.7082527837124273 0.00013500698047895049
210 1.7082527837124273 0.0006800002777794175
211 1.7082527837124273 0.00013112326170682833
212 1.7082527837124273 0.0006761587861458829
213 1.

378 1.7082527893294903 0.0003766422868604484
379 1.7082527893294903 7.407837072946788e-05
380 1.7082527893294903 0.0003754535142142279
381 1.7082527893294903 7.289493207400681e-05
382 1.7082527893294903 0.0003742772032816535
383 1.7082527893294903 7.172387157229615e-05
384 1.7082527893294903 0.0003731131591333798
385 1.7082527893294903 7.056499603098283e-05
386 1.7082527893294903 0.0003719611908845771
387 1.7082527893294903 6.941811625756972e-05
388 1.7082527893294903 0.0003708211115908017
389 1.7082527893294903 6.828304695428292e-05
390 1.7082527893294903 0.00036969273814598365
391 1.7082527893294903 6.715960662166845e-05
392 1.7082527893294903 0.00036857589118581395
393 1.7082527893294903 6.604761745950543e-05
394 1.7082527893294903 0.0003674703949903074
395 1.7082527893294903 6.494690527455287e-05
396 1.7082527893294903 0.00036637607739467456
397 1.7082527893294903 6.38572993877906e-05
398 1.7082527893294903 0.0003652927696978158
399 1.7082527893294903 6.277863254810401e-05
400 1.70

568 1.7082527900641717 0.00024427413310315014
569 1.7082527900641717 5.570697191944292e-05
570 1.7082527900641717 0.0002437470624751666
571 1.7082527900641717 5.518147779563923e-05
572 1.7082527900641717 0.00024322368039472636
573 1.7082527900641717 5.465965571140551e-05
574 1.7082527900641717 0.0002427039482770259
575 1.7082527900641717 5.4141467311794935e-05
576 1.7082527900641717 0.00024218782807361154
577 1.7082527900641717 5.362687477411559e-05
578 1.7082527900641717 0.00024167528226294366
579 1.7082527900641717 5.3115840798389194e-05
580 1.7082527900641717 0.00024116627384139288
581 1.7082527900641717 5.2608328599439226e-05
582 1.7082527900641717 0.00024066076631457645
583 1.7082527900641717 5.2104301896737564e-05
584 1.7082527900641717 0.0002401587236882256
585 1.7082527900641717 5.160372490645991e-05
586 1.7082527900641717 0.00023966011045984038
587 1.7082527900641717 5.110656233340403e-05
588 1.7082527900641717 0.00023916489161010813
589 1.7082527900641717 5.0612779361773986e-

783 1.7082527905130758 4.39004276379824e-05
784 1.7082527905130758 0.00017365168200221022
785 1.7082527905130758 4.369108120308626e-05
786 1.7082527905130758 0.00017336169946791024
787 1.7082527905130758 4.342913543023084e-05
788 1.7082527905130758 0.0001730836098254794
789 1.7082527905130758 4.315739722854055e-05
790 1.7082527905130758 0.00017280893497969335
791 1.7082527905130758 4.2884623403779016e-05
792 1.7082527905130758 0.0001725361882176765
793 1.7082527905130758 4.261272825296388e-05
794 1.7082527905130758 0.00017226488741107763
795 1.7082527905130758 4.2342081842279504e-05
796 1.7082527905130758 0.0001719949896304121
797 1.7082527905130758 4.207277287073238e-05
798 1.7082527905130758 0.00017172644344540225
799 1.7082527905130758 4.180480568531725e-05
800 1.7082527905130758 0.00017145924550743025
801 1.7082527905130758 4.1538177074101964e-05
802 1.7082527905130758 0.0001711933795158993
803 1.7082527905130758 4.1272876872367225e-05
804 1.7082527905130758 0.0001709288378526934
8

971 1.7082527905130758 2.289159958235571e-05
972 1.7082527905130758 0.00015259549760656815
973 1.7082527905130758 2.2711029861730737e-05
974 1.7082527905130758 0.0001524153538188872
975 1.7082527905130758 2.2531201375591846e-05
976 1.7082527905130758 0.0001522359486475131
977 1.7082527905130758 2.235210956909192e-05
978 1.7082527905130758 0.0001520572775587679
979 1.7082527905130758 2.2173749924787724e-05
980 1.7082527905130758 0.00015187933605639555
981 1.7082527905130758 2.1996117962321545e-05
982 1.7082527905130758 0.00015170211968059836
983 1.7082527905130758 2.1819209236862356e-05
984 1.7082527905130758 0.00015152562400826724
985 1.7082527905130758 2.164301934081487e-05
986 1.7082527905130758 0.00015134984465163648
987 1.7082527905130758 2.1467543901783807e-05
988 1.7082527905130758 0.00015117477725866844
989 1.7082527905130758 2.129277858273573e-05
990 1.7082527905130758 0.0001510004175130252
991 1.7082527905130758 2.1118719081680712e-05
992 1.7082527905130758 0.00015082676113227

(1.7082527905130758, array([ 0.68249648,  0.50207506, -1.18456981]))

In [82]:
%%timeit
Point2D(0.24366897, 0.21774166)

31 ms ± 999 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [32]:
%%timeit
integrate_over_intervals(interv, pdf1)

3.11 ms ± 138 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [14]:
coeff = find_coefficients(0, vtest, y)

In [91]:
%lprun -f g_q_w_multi_source subgradient_2(y, 100, pdfs)

1 2.8872918302854838 1.0415520304813857
[[ 0.21470768  0.08152096]
 [ 0.56462887  0.05347567]
 [-0.77933655 -0.13499663]]
2 2.8872918302854838 1.1393321601536315
[[ 0.15242981 -0.17949667]
 [ 0.50798248 -0.06758496]
 [-0.66041229  0.24708163]]
3 2.9883470931661695 0.43379262818778164
[[ 0.30615355 -0.08731889]
 [ 0.41304722 -0.29011962]
 [-0.71920077  0.37743852]]
4 3.022955952401207 0.795598852309584
[[ 0.25786571 -0.27252382]
 [ 0.48250907 -0.15602197]
 [-0.74037478  0.42854579]]
5 3.060420553728194 0.4195443273299524
[[ 0.34649232 -0.20896512]
 [ 0.42535431 -0.29227289]
 [-0.77184663  0.50123801]]
6 3.1110685613179285 0.764526887339136
[[ 0.31169307 -0.3312537 ]
 [ 0.47338955 -0.20238431]
 [-0.78508262  0.53363802]]
7 3.1110685613179285 0.41373398594846844
[[ 0.3732698  -0.2818729 ]
 [ 0.43253594 -0.30076429]
 [-0.80580574  0.5826372 ]]
8 3.1516932590508677 0.38431014095756144
[[ 0.42845882 -0.24207875]
 [ 0.39567222 -0.38588812]
 [-0.82413105  0.62796687]]
9 3.1516932590508677 0.77

72 3.2550703365289686 0.70927262180728
[[ 0.57310967 -0.58353184]
 [ 0.42396411 -0.50342401]
 [-0.99707378  1.08695585]]
73 3.2556600209637394 0.37306018174521466
[[ 0.57819218 -0.57695271]
 [ 0.41995876 -0.51302587]
 [-0.99815094  1.08997858]]
74 3.2556600209637394 0.7087950769042602
[[ 0.57446099 -0.58629068]
 [ 0.42424369 -0.50525612]
 [-0.99870468  1.0915468 ]]
75 3.2562247295611577 0.37308335025707573
[[ 0.57939955 -0.57986813]
 [ 0.42034537 -0.51460074]
 [-0.99974491  1.09446887]]
76 3.2562247295611577 0.7083369124238003
[[ 0.57576039 -0.58895517]
 [ 0.42451997 -0.50703139]
 [-1.00028036  1.09598656]]
77 3.2567661536474577 0.3731067078860149
[[ 0.58056278 -0.58268173]
 [ 0.42072309 -0.51613209]
 [-1.00128587  1.09881381]]
78 3.2567661536474577 0.7078968525518793
[[ 0.57701115 -0.59153085]
 [ 0.42479294 -0.50875297]
 [-1.00180409  1.10028382]]
79 3.257285829440118 0.37313019744783643
[[ 0.58168453 -0.58539956]
 [ 0.42109235 -0.51762208]
 [-1.00277688  1.10302164]]
80 3.25728582944

In [134]:
%timeit gen_intervals(1, numpy.array([0.24366897, 0.21774166, 0.53858937]), y)

71.1 ms ± 2.41 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [133]:
%timeit gen_intervals1(1, numpy.array([0.24366897, 0.21774166, 0.53858937]), y)

197 µs ± 4.03 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [68]:
%lprun -f g_w_one_source g_w_one_source(vtest, qtest, y, pdf1)

In [43]:
ii = gen_intervals(1, numpy.array([0.24366897, 0.21774166, 0.53858937]), y)

In [84]:
%timeit integrate_over_intervals(ii, pdf1)

1.45 ms ± 29.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [88]:
%timeit integrate_over_intervals(ii, pdfs[0])

692 µs ± 10.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [87]:
%timeit integrate_over_intervals(ii, pdf1c)

828 µs ± 33.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
