In [1]:
import numpy
import os, ctypes
from scipy import integrate, LowLevelCallable
from math import exp, pi
%run "integration_utils.py"
%run "q_solver.py"

In [2]:
def pdf1(y, x):
    return 0.5 / pi * exp(-(x+1)**2/2 - y**2/2)

def pdf2(y, x):
    return 0.5 / pi * exp(-(x-1)**2/2 - y**2/2)

pdfs = [pdf1, pdf2]

In [3]:
y = numpy.matrix('1 1; 1 -1; -1 1; -1 -1', dtype=float)


lib = ctypes.CDLL(os.path.abspath('std_gaussian.so'))
lib.f.restype = ctypes.c_double
lib.f.argtypes = (ctypes.c_int, ctypes.POINTER(ctypes.c_double))
lib.g.restype = ctypes.c_double
lib.g.argtypes = (ctypes.c_int, ctypes.POINTER(ctypes.c_double), ctypes.c_void_p)
p1 = LowLevelCallable(lib.f)

In [4]:
def g_q_w_multi_source(v, y, pdfs, eps=0.693147, weight=None):
    n, m = v.shape
    if not weight:
        weight = numpy.ones(m)
    q, g = q_solver(v, weight, eps, 0.01)
    w = numpy.zeros((n, m))
    for k in range(m):
        for j in range(n):
            interval = gen_intervals(j, v[:, k].A1, y)
            integral1 = integrate_over_intervals(interval, pdfs[k])
            normpdf = lambda x1, x0: pdfs[k](x1, x0) * ((x1 - y[j,1])**2 + (x0 - y[j,0])**2)
            integral2 = integrate_over_intervals(interval, normpdf)
            w[j, k] = weight[k] * (q[j, k] - integral1)
            g += weight[k] * (integral2 - v[j, k] * integral1)
    return g, q, w

In [5]:
def subgradient_2(y, max_step, pdfs, eps=0.693147, weight=None):
    n = len(y)
    m = len(pdfs)
    v = numpy.matrix(numpy.zeros((n, m)))
    v_best = v
    g_best = - numpy.inf
    q_best = None
    for step in range(1, max_step):
        g, q, w = g_q_w_multi_source(v, y, pdfs, eps, weight)
        if g > g_best:
            v_best = v
            g_best = g
            q_best = q
        norm = numpy.linalg.norm(w)
        if norm < 1e-8:
            break
        alpha = 1 / step
        v += alpha * w / norm
        print(step, g_best, norm)
        print(v)
    return g_best, v_best, q_best

In [6]:
y = numpy.matrix('1 0; 0 0; -1 0', dtype=float)
v = numpy.matrix('0 0; 0 0; 0 0', dtype=float)

In [7]:
%%time
subgradient_2(y, 10, pdfs)

1 3.0915865953443573 0.850503562628605
[[ 0.30161681 -0.05266977]
 [ 0.49962905  0.10770447]
 [-0.80124586 -0.0550347 ]]
2 3.0915865953443573 1.3746066757235325
[[ 0.28528137 -0.29007911]
 [ 0.31962097  0.00934386]
 [-0.60490233  0.28073525]]
3 3.1009114313827384 0.6208144904223368
[[ 0.42419644 -0.28485602]
 [ 0.09929952 -0.12855304]
 [-0.52349596  0.41340906]]
4 3.2855396863607256 0.8094656158693696
[[ 0.39932777 -0.48357731]
 [ 0.09579487  0.00104826]
 [-0.49512264  0.48252905]]
5 3.2855396863607256 0.5497985790966122
[[ 0.48707205 -0.46236077]
 [-0.02067456 -0.10239176]
 [-0.46639748  0.56475253]]
6 3.3699293856916417 0.4810801277945116
[[ 0.56455086 -0.45041603]
 [-0.11576542 -0.18649332]
 [-0.44878544  0.63690935]]
7 3.3699293856916417 1.3710764091128904
[[ 0.5527711  -0.51538873]
 [-0.04036331 -0.10815972]
 [-0.5124078   0.62354846]]
8 3.39867292504499 0.4753264083973464
[[ 0.60946783 -0.50395769]
 [-0.11176006 -0.17261574]
 [-0.49770777  0.67657343]]
9 3.411808378810708 1.35307

(3.411808378810708, matrix([[ 0.59983783, -0.554126  ],
         [-0.05281135, -0.11168212],
         [-0.54702648,  0.66580812]]), array([[0.01, 0.02],
        [0.97, 0.97],
        [0.02, 0.01]]))

In [11]:
gen_intervals(0, v[:, 0].A1, y)

[(1/2, -10, 10), (10, -10, 10)]

In [12]:
gen_intervals1(0, v[:, 0].A1, y)

[(0.5, 10.0, 10.0), (10.0, -10.0, 10.0)]

In [19]:
def g_w_one_source(v, q, y, pdf):
    n = v.size
    w = numpy.zeros(n)
    g = v.dot(q)
    for j in range(n):
        interval = gen_intervals(j, v, y)
        integral1 = integrate_over_intervals(interval, pdf)
        normpdf = lambda x1, x0: pdf(x1, x0) * ((x1 - y[j,1])**2 + (x0 - y[j,0])**2)
        integral2 = integrate_over_intervals(interval, normpdf)
        w[j] = q[j] - integral1
        g += integral2 - v[j] * integral1
    return g, w

In [20]:
def subgradient_1(q, y, max_step, pdf):
    v = numpy.zeros(len(y))
    v_best = v
    g_best = - numpy.inf
    for step in range(1, max_step):
        g, w = g_w_one_source(v, q, y, pdf)
        if g > g_best:
            v_best = v
            g_best = g
        norm = numpy.linalg.norm(w)
        if norm < 1e-8:
            break
        alpha = 1 / step
        v += alpha * w / norm
        print(step, g_best, norm)
    return g_best, v_best

In [13]:
vtest = numpy.array([1.3426809, -0.26822099, -1.07445991])
vtest = v[:, 0].A1
qtest = numpy.array([0.24366897, 0.21774166, 0.53858937])

In [22]:
%load_ext line_profiler

The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler


In [23]:
%lprun -f g_w_one_source g_w_one_source(vtest, qtest, y, pdf1)

In [20]:
subgradient_1(numpy.array([0.19956162, 0.23498216, 0.56545622]), y, 1000, pdf1)

1 1.5457932976721787 0.1831579821674202
2 1.6672672025848825 0.09364635457337682
3 1.6738462074163363 0.11250419957004741
4 1.6835624144726156 0.05648923037575609
5 1.6835624144726156 0.07001157358541649
6 1.6858832518499673 0.032378146629174195
7 1.6858832518499673 0.05287920732081616
8 1.6865148535162415 0.020394446388071946
9 1.6865148535162415 0.04368656858947612
10 1.6867490990280887 0.013320907410921116
11 1.6867490990280887 0.037965932104239576
12 1.6868491063991637 0.008678545353474603
13 1.6868491063991637 0.034067090846541144
14 1.6868942384303227 0.005401900188444597
15 1.6868942384303227 0.0312403941345789
16 1.6869141635503244 0.0029663292050666167
17 1.6869141635503244 0.02909743153041522
18 1.6869216112871606 0.0010850863881335035
19 1.6869216112871606 0.02741708884223373
20 1.686922596031659 0.00041158913093546567
21 1.686922596031659 0.02512683880458417
22 1.686922596031659 0.0008097266709138054
23 1.686922596031659 0.022430580469777892
24 1.686922596031659 0.000400545

192 1.686922756605942 7.078591389181247e-05
193 1.686922756605942 0.0025950552041230685
194 1.6869227583297652 5.697015481256067e-05
195 1.6869227583297652 0.0025813881867344864
196 1.686922759656829 4.3437323870224473e-05
197 1.686922759656829 0.0025679932151228364
198 1.6869227606068602 3.019718422068537e-05
199 1.6869227606068602 0.00255365458789906
200 1.6869227606068602 2.5436686598530846e-05
201 1.6869227606068602 0.0019152863585459475
202 1.6869227606068602 0.0006809223936095786
203 1.6869227606068602 0.0016475462138279208
204 1.6869227606068602 0.0008742906107669446
205 1.6869227606068602 0.0016343168675751574
206 1.6869227606068602 0.0008628465733527549
207 1.6869227606068602 0.0016221922406022072
208 1.6869227606068602 0.000850839571483921
209 1.6869227606068602 0.0016103050074335382
210 1.6869227606068602 0.0008390612597101259
211 1.6869227606068602 0.0015986437715697686
212 1.6869227606068602 0.0008275057898827548
213 1.6869227606068602 0.001587202021096371
214 1.6869227606

375 1.6869227606068602 0.0010662468793893732
376 1.6869227606068602 0.0002988675280412807
377 1.6869227606068602 0.0010626166756984756
378 1.6869227606068602 0.000295255749971813
379 1.6869227606068602 0.001059024833697749
380 1.6869227606068602 0.00029168204266624
381 1.6869227606068602 0.0010554707485123536
382 1.6869227606068602 0.0002881458073412883
383 1.6869227606068602 0.0010519538279156646
384 1.6869227606068602 0.0002846464577052779
385 1.6869227606068602 0.0010484734920034279
386 1.6869227606068602 0.0002811834196334426
387 1.6869227606068602 0.0010450291728731022
388 1.6869227606068602 0.0002777561308555619
389 1.6869227606068602 0.0010416203143165592
390 1.6869227606068602 0.00027436404064955
391 1.6869227606068602 0.001038246371519456
392 1.6869227606068602 0.00027100660954828984
393 1.6869227606068602 0.0010349068107700409
394 1.6869227606068602 0.00026768330905085593
395 1.6869227606068602 0.0010316011091801133
396 1.6869227606068602 0.0002643936213488335
397 1.686922760

557 1.6869227606068602 0.0008427450873744342
558 1.6869227606068602 7.631773150439239e-05
559 1.6869227606068602 0.0008410981495890235
560 1.6869227606068602 7.46764311546979e-05
561 1.6869227606068602 0.0008394629646556583
562 1.6869227606068602 7.304682347454875e-05
563 1.6869227606068602 0.0008378394072449937
564 1.6869227606068602 7.142878397810053e-05
565 1.6869227606068602 0.0008362273537184053
566 1.6869227606068602 6.982218988533093e-05
567 1.6869227606068602 0.0008346266823492585
568 1.6869227606068602 6.82269202561492e-05
569 1.6869227606068602 0.0008330372727920583
570 1.6869227606068602 6.664285562913698e-05
571 1.6869227606068602 0.0008314590070728911
572 1.6869227606068602 6.50698786720675e-05
573 1.6869227606068602 0.0008298917674105706
574 1.6869227606068602 6.350787272324522e-05
575 1.6869227606068602 0.0008283354408172866
576 1.6869227606068602 6.195672487671699e-05
577 1.6869227606068602 0.0008267899086990678
578 1.6869227606068602 6.041631914666844e-05
579 1.6869227

739 1.6869227606068602 0.0005772786333891134
740 1.6869227606068602 0.00011545640962051324
741 1.6869227606068602 0.0005763425598289619
742 1.6869227606068602 0.00011452275326882717
743 1.6869227606068602 0.0005754115289351147
744 1.6869227606068602 0.00011359412009008045
745 1.6869227606068602 0.0005744855000687038
746 1.6869227606068602 0.00011267046965391728
747 1.6869227606068602 0.0005735644330276619
748 1.6869227606068602 0.00011175176196484964
749 1.6869227606068602 0.0005726482880386447
750 1.6869227606068602 0.00011083795745443092
751 1.6869227606068602 0.0005717370257535554
752 1.6869227606068602 0.00010992901697221683
753 1.6869227606068602 0.0005708306072406757
754 1.6869227606068602 0.0001090249017872419
755 1.6869227606068602 0.0005699289939825275
756 1.6869227606068602 0.00010812557357676479
757 1.6869227606068602 0.0005690321478689095
758 1.6869227606068602 0.00010723099442331294
759 1.6869227606068602 0.000568140031190563
760 1.6869227606068602 0.0001063411268078532
76

921 1.6869227606068602 0.0005087526553077202
922 1.6869227606068602 4.709013677800814e-05
923 1.6869227606068602 0.0005081498171217029
924 1.6869227606068602 4.648854838630969e-05
925 1.6869227606068602 0.00050754958762048
926 1.6869227606068602 4.588956043341569e-05
927 1.6869227606068602 0.0005069519485561425
928 1.6869227606068602 4.5293155230073175e-05
929 1.6869227606068602 0.0005063568855863678
930 1.6869227606068602 4.469931778730945e-05
931 1.6869227606068602 0.0005057643773592735
932 1.6869227606068602 4.410802879365618e-05
933 1.6869227606068602 0.0005051744152952006
934 1.6869227606068602 4.351927844148793e-05
935 1.6869227606068602 0.0005045869637988065
936 1.6869227606068602 4.293304304373729e-05
937 1.6869227606068602 0.0005040020276183199
938 1.6869227606068602 4.234934172471491e-05
939 1.6869227606068602 0.0005034194812046242
940 1.6869227606068602 4.1768174369969897e-05
941 1.6869227606068602 0.0005028392406576282
942 1.6869227606068602 4.118985771678348e-05
943 1.6869

(1.6869227606068602, array([ 1.09926014, -0.21456888, -0.88469125]))

In [82]:
%%timeit
Point2D(0.24366897, 0.21774166)

31 ms ± 999 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [32]:
%%timeit
integrate_over_intervals(interv, pdf1)

3.11 ms ± 138 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [14]:
coeff = find_coefficients(0, vtest, y)

In [100]:
%lprun -f find_intersections1 find_intersections1(coeff)

In [134]:
%timeit gen_intervals(1, numpy.array([0.24366897, 0.21774166, 0.53858937]), y)

71.1 ms ± 2.41 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [133]:
%timeit gen_intervals1(1, numpy.array([0.24366897, 0.21774166, 0.53858937]), y)

197 µs ± 4.03 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [68]:
%lprun -f g_w_one_source g_w_one_source(vtest, qtest, y, pdf1)

In [32]:
from scipy.spatial import ConvexHull

def find_intersections1(coefficients):
    res = []
    for i in range(len(coefficients)):
        L1 = coefficients[i]
        for j in range(i):
            L2 = coefficients[j]
            D  = L1[0] * L2[1] - L1[1] * L2[0]
            Dx = - L1[2] * L2[1] + L1[1] * L2[2]
            Dy = - L1[0] * L2[2] + L1[2] * L2[0]
            if D != 0:
                valid = True
                x, y = Dx / D, Dy / D
                for a, b, c in coefficients:
                    if a * x + b * y + c > 1e-8:
                        valid = False
                        break
                if valid:
                    res.append((x, y))
    res1 = numpy.zeros((len(res), 2))
    for i in range(len(res)):
        res1[i, 0] = res[i][0]
        res1[i, 1] = res[i][1]
    return res1


def gen_intervals1(j, v, y):
    coefficients = find_coefficients(j, v, y)
    filtered_points = find_intersections1(coefficients)
    if len(filtered_points) < 3:
        return None
    hull = ConvexHull(filtered_points)
    x_points = sorted(set(p[0] for p in filtered_points))
    intervals = []
    for x in x_points:
        tp = []
        for i1 in range(len(hull.vertices)):
            i2 = (i1 + 1) % len(hull.vertices)
            x1, x2 = filtered_points[i1, 0], filtered_points[i2, 0]
            y1, y2 = filtered_points[i1, 1], filtered_points[i2, 1]
            if x1 == x:
                tp.append(y1)
            elif (x - x1) * (x - x2) < 0:
                tp.append(y1 + (y2 - y1) * (x - x1) / (x2 - x1))
        tp.sort()
        intervals.append((x, tp[0], tp[-1]))
    return intervals

In [34]:
gen_intervals(0, vtest, y)

[(1/2, -10, 10), (10, -10, 10)]

In [35]:
gen_intervals1(0, vtest, y)

[(0.5, -10.0, 10.0), (10.0, -10.0, 10.0)]

In [22]:
find_intersections(coeff)

[Point2D(10, -10), Point2D(10, 10), Point2D(1/2, -10), Point2D(1/2, 10)]

In [24]:
points = find_intersections1(coeff)

In [25]:

hull = ConvexHull(points)

In [30]:
hull.vertices

array([2, 0, 1, 3], dtype=int32)

In [109]:
x_points = sorted(set(p[0] for p in points))

In [110]:
x_points

[-0.339576145, 0.487036345]

In [124]:
gen_intervals(2, numpy.array([0.24366897, 0.21774166, 0.53858937]), y)

[(-10, -10, 10), (-67915229/200000000, -10, 10)]

In [125]:
gen_intervals1(2, numpy.array([0.24366897, 0.21774166, 0.53858937]), y)

[(-10.0, 10.0, 10.0), (-0.339576145, -10.0, 10.0)]