In [8]:
import numpy
import os, ctypes
from scipy import integrate, LowLevelCallable
%run "integration_utils.py"

In [9]:
# Discrete measure: support points `y` (one per row) with weights `q`,
# which sum to 1.
# NOTE(review): numpy.matrix is discouraged in favour of 2-D ndarrays; it is
# kept here because `y` is passed to gen_intervals (integration_utils.py),
# which may rely on matrix indexing semantics — confirm before switching.
y = numpy.matrix('1 1; 1 -1; -1 1; -1 -1', dtype=float)
q = numpy.array([0.5, 0.2, 0.2, 0.1])

# Compiled integrands (presumably a standard Gaussian density, per the library
# name). Declare the C signatures explicitly:
#   double f(int n, double *x)
#   double g(int n, double *x, void *user_data)
# These are the forms scipy.integrate accepts for LowLevelCallable integrands.
lib = ctypes.CDLL(os.path.abspath('std_gaussian.so'))
c_double_p = ctypes.POINTER(ctypes.c_double)
lib.f.restype = ctypes.c_double
lib.f.argtypes = (ctypes.c_int, c_double_p)
p1 = LowLevelCallable(lib.f)
lib.g.restype = ctypes.c_double
lib.g.argtypes = (ctypes.c_int, c_double_p, ctypes.c_void_p)

# One callable per support point: the point's two coordinates are handed to
# `g` through its void* user_data argument.
# NOTE(review): each `coords` buffer must outlive `p3`. This relies on the cast
# object and the LowLevelCallable holding references to it — confirm.
DoublePair = ctypes.c_double * 2
p3 = []
for j in range(len(y)):
    coords = DoublePair(y[j, 0], y[j, 1])
    user_data = ctypes.cast(ctypes.pointer(coords), ctypes.c_void_p)
    p3.append(LowLevelCallable(lib.g, user_data))

In [10]:
def g_w_one_source(v, q, y, pdf, normpdf):
    """Evaluate the objective ``g`` and the vector ``w`` at the point ``v``.

    For every index ``j`` in ``range(v.size)``, the region returned by
    ``gen_intervals(j, v, y)`` is integrated twice with
    ``integrate_over_intervals``. The region is presumably the part of the
    domain assigned to point ``j``; confirm in integration_utils.py.

    * ``mass_j``: the integral of ``pdf`` over region ``j``.
    * ``moment_j``: the integral of ``normpdf[j]`` over region ``j``.

    Returns
    -------
    g : float
        ``v . q + sum_j (moment_j - v[j] * mass_j)``.
    w : numpy.ndarray
        ``w[j] = q[j] - mass_j``. This is the step direction consumed by
        ``subgradient_1``.
    """
    size = v.size
    w_vec = numpy.zeros(size)
    g_val = numpy.dot(v, q)
    for idx in range(size):
        region = gen_intervals(idx, v, y)
        mass = integrate_over_intervals(region, pdf)
        moment = integrate_over_intervals(region, normpdf[idx])
        w_vec[idx] = q[idx] - mass
        g_val += moment - v[idx] * mass
    return g_val, w_vec

In [11]:
def subgradient_1(q, y, max_step, pdf, normpdf, tol=1e-8):
    """Maximise ``g`` from ``g_w_one_source`` by normalised subgradient ascent.

    Starting from ``v = 0``, each iteration does three things:

    * evaluate ``g_w_one_source`` at the current ``v``;
    * record the best ``g`` seen so far, together with a copy of that ``v``;
    * move ``v`` by a step of length ``1 / step`` along ``w / ||w||``.

    Parameters
    ----------
    q : array_like
        Weights of the points; forwarded to ``g_w_one_source``.
    y : sequence
        Points (presumably one per row); ``len(y)`` fixes the size of ``v``.
    max_step : int
        Exclusive upper bound on the step counter. At most ``max_step - 1``
        iterations are performed.
    pdf, normpdf
        Integrands forwarded unchanged to ``g_w_one_source``.
    tol : float, optional
        Stop as soon as the Euclidean norm of ``w`` falls below this value.
        The default is 1e-8.

    Returns
    -------
    g_best : float
        Largest ``g`` observed, or ``-inf`` if no iteration ran.
    v_best : numpy.ndarray
        Independent copy of the ``v`` at which ``g_best`` was observed.

    Notes
    -----
    Prints ``step, g_best, ||w||`` after every step that is actually taken.
    """
    v = numpy.zeros(len(y))
    # Store copies: `v` is updated in place below, so a bare reference would
    # make `v_best` silently track the latest iterate instead of the best one.
    v_best = v.copy()
    g_best = -numpy.inf
    for step in range(1, max_step):
        g, w = g_w_one_source(v, q, y, pdf, normpdf)
        if g > g_best:
            v_best = v.copy()
            g_best = g
        norm = numpy.linalg.norm(w)
        if norm < tol:
            # (Near-)zero direction: no meaningful normalised step exists.
            break
        # Diminishing, non-summable step size 1/step.
        alpha = 1 / step
        v += alpha * w / norm
        print(step, g_best, norm)
    return g_best, v_best

In [14]:
%prun subgradient_1(q, y, 20, p1, p3)

1 0.8084617567885385 0.30000000000000004
2 1.0506394672302237 0.18554490033301
3 1.1294821099256276 0.1306433520855199
4 1.1672289852279216 0.09640029308135561
5 1.1882958871641554 0.07251941775094317
6 1.2009937799259633 0.0547180362372859
7 1.2089393570182305 0.04079825281729026
8 1.2139526120640993 0.029495984554884343
9 1.2170441350088117 0.020038483433008916
10 1.2188179426201458 0.011937856313180756
11 1.2196569691913752 0.00487894193817128
12 1.2198160217609908 0.0013498893165808936
13 1.2198160217609908 0.004353249074084919
14 1.2198232621536684 0.0009219216532886724
15 1.2198232621536684 0.004184782834895293
16 1.2198232621536684 0.0028774917579033484
17 1.2198232621536684 0.008919506327937735
18 1.2198232621536684 0.0022471049057198084
19 1.2198232621536684 0.008295139147164142
 