### Preparation Instructions
* Download the mnist8m dataset (LIBSVM format) from
https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/mnist8m.bz2
* Decompress the first 2 million samples and split them into two chunks of 1 million samples each. By default `split` names the chunks `xaa` and `xab`; the notebook below loads `xaa`.
```
bzip2 -dc mnist8m.bz2 | head -n 2000000 | split -l 1000000 -
```
* Run all the cells below to produce the policies h and pi and the reward predictor q, and to save the relevant parameters to `mnist_models.npz`.

In [1]:
%reload_ext autoreload
%reload_ext line_profiler
%autoreload 2

import numpy as np
from joblib import Memory
from sklearn.datasets import load_svmlight_file
from sklearn.kernel_approximation import RBFSampler
from scipy.spatial.distance import pdist
import tqdm

In [2]:
mem = Memory("./mycache")

@mem.cache
def get_data(path="xaa", n_features=784):
    """Load one svmlight chunk of mnist8m, memoized on disk by joblib.

    Parameters
    ----------
    path : str, default "xaa"
        Chunk file produced by the ``split`` command in the preparation steps.
    n_features : int, default 784
        Number of pixel columns. Passed explicitly because ``load_svmlight_file``
        otherwise infers the width from the largest feature index that occurs in
        this chunk. Zero entries are not stored in svmlight files, so if the
        highest-numbered pixels are zero throughout the chunk, the inferred width
        would be smaller. That would change the shape of every projection derived
        (and saved) below.

    Returns
    -------
    (X, y) : sparse feature matrix of shape (n_samples, n_features) and label vector.
    """
    data = load_svmlight_file(path, n_features=n_features)
    return data[0], data[1]

X, y = get_data()

________________________________________________________________________________
[Memory] Calling __main__-C%3A-Users-nikosk.REDMOND-Desktop-projects-pycharm-mope-%3Cipython-input-921d442ec4ad%3E.get_data...
get_data()
_______________________________________________________get_data - 164.7s, 2.7min


In [3]:
# Densify the sparse pixel matrix as float32 and rescale by 1/255 (presumably raw
# 0-255 intensities). Casting while still sparse avoids materializing a full dense
# float64 copy (8 bytes per entry) before the float32 one; values are identical.
x = X.astype(np.float32).toarray()
del X
x /= 255.
# Uncentered d x d second-moment matrix of the features, eigendecomposed next.
xx = x.T@x

In [4]:
# eigh returns the eigenvalues of the symmetric matrix in ascending order, so rank
# them from largest to smallest and keep the 50 leading eigenvectors (columns).
d, v = np.linalg.eigh(xx)
perm = (-d).argsort()
vv = v[:, perm[:50]]

In [5]:
# Project every sample onto the leading eigenvectors, then free the dense pixels.
xv = np.matmul(x, vv)
del x

In [6]:
class RFF:
    """Random Fourier features for the RBF kernel ``exp(-gamma * ||a - b||**2)``.

    A row ``x_i`` is mapped to ``sqrt(2 / c) * cos(x_i @ omega + b)``, where the
    entries of ``omega`` are drawn from ``N(0, 2 * gamma)`` and ``b`` from
    ``U[0, 2*pi)``. Inner products of mapped rows then approximate the kernel.

    Parameters
    ----------
    x : ndarray of shape (n, d)
        Sample that fixes the input dimension, the default feature count and
        (when ``gamma`` is None) the bandwidth.
    gamma : float, optional
        Kernel bandwidth. If None, uses ``1 / median(pairwise distance)**2``
        over up to 1000 randomly chosen rows of ``x``.
    n_components : int, optional
        Number of random features ``c``; defaults to ``round(sqrt(n))``.
    seed : int, default 90210
        Seed of the private ``RandomState`` used for every random draw.

    Raises
    ------
    ValueError
        If ``gamma`` is None and the median heuristic is undefined: fewer than
        two rows, or a median pairwise distance of zero. Previously this silently
        produced a nan/inf bandwidth and therefore nan/inf features.
    """

    def __init__(self, x, gamma=None, n_components=None, seed=90210):
        n, d = x.shape
        self.d = d
        self.c = n_components if n_components is not None else int(np.round(np.sqrt(n)))
        self.seed = seed
        self.rs = np.random.RandomState(self.seed)
        if gamma is None:
            n_samples = min(n, 1000)
            if n_samples < 2:
                raise ValueError(
                    "the median heuristic for gamma needs at least 2 rows; pass gamma explicitly")
            # NOTE: this draw advances self.rs, so omega/b below differ depending
            # on whether gamma was supplied, even for the same seed.
            perm = self.rs.permutation(n)[:n_samples]
            median_dist = np.median(pdist(x[perm]))  # pdist already returns a 1-D array
            if not median_dist > 0:
                raise ValueError(
                    "median pairwise distance is zero; pass gamma explicitly")
            gamma = median_dist ** (-2)
        print(gamma)  # the notebook output shows the chosen bandwidth
        self.gamma = gamma
        self.omega = self.rs.normal(size=(d, self.c)).astype(np.float32)
        self.omega *= np.sqrt(2 * self.gamma)
        self.b = self.rs.uniform(0, 2*np.pi, size=self.c).astype(np.float32)
        self.normalizer = np.sqrt(2./self.c)

    def __call__(self, x):
        """Return the (n, c) random-feature matrix ``sqrt(2/c) * cos(x @ omega + b)``."""
        projection = np.dot(x, self.omega) + self.b
        np.cos(projection, projection)  # in place: reuse the buffer
        projection *= self.normalizer
        return projection

# Build the random-feature map on the projected data (gamma=None -> median
# heuristic), then featurize that same data with it.
rff = RFF(xv, gamma=None, n_components=None)
rffxv = rff(xv)

0.012065845514763526


In [7]:
from sklearn.preprocessing import OneHotEncoder
# This cell overwrites `y` with its one-hot encoding, so running it twice would
# re-encode the 0/1 matrix and silently corrupt the targets. Fail loudly instead.
assert np.ndim(y) == 1, "y is already one-hot encoded; re-run the data-loading cell first"
try:
    ohe = OneHotEncoder(sparse_output=False)
except TypeError:  # scikit-learn < 1.2 only accepts the older `sparse` keyword
    ohe = OneHotEncoder(sparse=False)
ohe.fit(y.reshape(-1, 1))
y = ohe.transform(y.reshape(-1, 1))

In [8]:
# (c x c) Gram matrix of the random features; factorized next as a preconditioner.
rxx = rffxv.T.dot(rffxv)

In [9]:
from scipy.linalg import cho_factor, cho_solve
# Cholesky factor of the random-feature Gram matrix, with a tiny 1e-12 diagonal
# jitter, reused by the two fits on `rffxv` below.
rcf = cho_factor(rxx + 1e-12 * np.identity(len(rxx)))

In [10]:
def agd(x, y, c=None, max_iter=100, tol=1e-4, verbose=True):
    """Fit a softmax (multinomial logistic) model by preconditioned Nesterov descent.

    Each iteration takes a gradient step preconditioned by ``(x.T @ x)^{-1}``
    (applied through the Cholesky factor ``c``) from the extrapolated point ``u``,
    then applies Nesterov momentum. For targets whose rows sum to 1 (e.g. one-hot),
    ``x.T @ (softmax(x @ u) - y)`` is the gradient of the summed cross-entropy.

    Parameters
    ----------
    x : ndarray of shape (n, d)
        Features. Any real dtype; float32 input keeps float32 work buffers.
    y : ndarray of shape (n, k)
        Targets, presumably one-hot -- TODO confirm for other uses.
    c : tuple, optional
        ``scipy.linalg.cho_factor`` output for ``x.T @ x`` (plus jitter).
        Computed here when None.
    max_iter : int, default 100
        Maximum number of iterations.
    tol : float, default 1e-4
        Stop once ``sum(g**2) / (n*d*k) < tol`` at the extrapolated point.
    verbose : bool, default True
        Print the iteration index and that normalized squared gradient norm.

    Returns
    -------
    w : ndarray of shape (d, k)
        Weights from the last preconditioned gradient step (zeros if
        ``max_iter`` is 0).

    Raises
    ------
    ValueError
        If ``x`` has no rows.
    """
    from scipy.linalg import cho_factor, cho_solve
    nx, d = x.shape
    ny, k = y.shape
    assert nx == ny
    n = nx
    if n == 0:
        raise ValueError("agd needs at least one sample")
    # np.dot(..., out=...) requires the buffer to have exactly the product's dtype,
    # so size the buffers accordingly: float32 for float32 x (as before), and wider
    # for e.g. float64 x, which previously raised a ValueError.
    dtype = np.result_type(x.dtype, np.float32)
    if c is None:
        xx = np.dot(x.T, x)
        c = cho_factor(xx + 1e-12 * np.eye(xx.shape[0]))
    li, nextli = (1, 1)
    u = np.zeros((d, k), dtype=dtype)  # extrapolated point where the gradient is taken
    w = np.zeros((d, k), dtype=dtype)  # result of the latest gradient step
    t = np.zeros((d, k), dtype=dtype)  # momentum term gi * w_previous
    p = np.zeros((n, k), dtype=dtype)  # scores -> probabilities -> residuals, in place
    z = np.zeros(n, dtype=dtype)       # per-row max, then per-row normalizer
    g = np.zeros((d, k), dtype=dtype)  # gradient
    for i in range(max_iter):
        # Row-wise softmax of x @ u, computed in place; subtracting the row max
        # first keeps exp from overflowing.
        np.dot(x, u, out=p)
        np.max(p, axis=1, out=z)
        p -= z.reshape(-1, 1)
        np.exp(p, out=p)
        np.sum(p, axis=1, out=z)
        p /= z.reshape(-1, 1)
        p -= y
        np.dot(x.T, p, out=g)
        gnorm = np.sum(g * g) / (n * d * k)
        if verbose:
            print(i, gnorm)
        if gnorm < tol:
            break
        # NOTE(review): (li, nextli) starts at (1, 1), so gi is 0 on the first two
        # iterations and the momentum schedule appears to lag the usual Nesterov
        # sequence by one step. Kept as-is so results match the saved models.
        gi = (1 - li) / nextli
        t[:] = gi * w[:]
        # Preconditioned gradient step from u: w = u - (x.T x)^{-1} g.
        np.subtract(u, cho_solve(c, g), out=w)
        # Nesterov extrapolation: u = (1 - gi) * w_new + gi * w_old.
        np.add((1 - gi) * w, t, out=u)
        li = nextli
        nextli = (1 + np.sqrt(1 + 4 * li * li)) / 2
    return w

In [11]:
# Softmax model on the random features, reusing the precomputed Cholesky factor.
w = agd(x=rffxv, y=y, c=rcf)

0 0.8397011968
1 0.6272526336
2 0.3982569472
3 0.2001848832
4 0.0937900416
5 0.0461673056
6 0.0245917536
7 0.0141343504
8 0.0086822848
9 0.0056444348
10 0.0038504268
11 0.0027363636
12 0.0020139444
13 0.0015277179
14 0.0011897801
15 0.0009482662
16 0.0007713892
17 0.0006389983
18 0.0005379338
19 0.0004593725
20 0.000397266875
21 0.000347388725
22 0.0003067337
23 0.0002731358
24 0.000245013425
25 0.00022119805
26 0.0002008180375
27 0.0001832147125
28 0.0001678860875
29 0.0001544444625
30 0.000142587025
31 0.00013207325
32 0.0001227101875
33 0.00011434135
34 0.0001068359625
35 0.000100084975
36 9.399680625e-05


In [12]:
def agdilr(x, y, c=None, max_iter=100, tol=1e-4, verbose=True):
    """Fit k independent logistic regressions by preconditioned Nesterov descent.

    Each column of ``y`` gets its own sigmoid model; ``x.T @ (sigmoid(x @ u) - y)``
    is the gradient of the summed binary cross-entropy over all k columns. Each
    iteration takes a gradient step preconditioned by ``(x.T @ x)^{-1}``
    (applied through the Cholesky factor ``c``) from the extrapolated point ``u``,
    then applies Nesterov momentum.

    Parameters
    ----------
    x : ndarray of shape (n, d)
        Features. Any real dtype; float32 input keeps float32 work buffers.
    y : ndarray of shape (n, k)
        Targets, presumably in [0, 1] (e.g. one-hot) -- TODO confirm for other uses.
    c : tuple, optional
        ``scipy.linalg.cho_factor`` output for ``x.T @ x`` (plus jitter).
        Computed here when None.
    max_iter : int, default 100
        Maximum number of iterations.
    tol : float, default 1e-4
        Stop once ``sum(g**2) / (n*d*k) < tol`` at the extrapolated point.
    verbose : bool, default True
        Print the iteration index and that normalized squared gradient norm.

    Returns
    -------
    w : ndarray of shape (d, k)
        Weights from the last preconditioned gradient step (zeros if
        ``max_iter`` is 0).

    Raises
    ------
    ValueError
        If ``x`` has no rows.
    """
    from scipy.linalg import cho_factor, cho_solve
    from scipy.special import expit
    nx, d = x.shape
    ny, k = y.shape
    assert nx == ny
    n = nx
    if n == 0:
        raise ValueError("agdilr needs at least one sample")
    # np.dot(..., out=...) requires the buffer to have exactly the product's dtype,
    # so size the buffers accordingly: float32 for float32 x (as before), and wider
    # for e.g. float64 x, which previously raised a ValueError.
    dtype = np.result_type(x.dtype, np.float32)
    if c is None:
        xx = np.dot(x.T, x)
        c = cho_factor(xx + 1e-12 * np.eye(xx.shape[0]))
    li, nextli = (1, 1)
    u = np.zeros((d, k), dtype=dtype)  # extrapolated point where the gradient is taken
    w = np.zeros((d, k), dtype=dtype)  # result of the latest gradient step
    t = np.zeros((d, k), dtype=dtype)  # momentum term gi * w_previous
    p = np.zeros((n, k), dtype=dtype)  # scores -> probabilities -> residuals, in place
    g = np.zeros((d, k), dtype=dtype)  # gradient
    for i in range(max_iter):
        np.dot(x, u, out=p)
        expit(p, p)  # elementwise sigmoid, in place
        p -= y
        np.dot(x.T, p, out=g)
        gnorm = np.sum(g * g) / (n * d * k)
        if verbose:
            print(i, gnorm)
        if gnorm < tol:
            break
        # NOTE(review): (li, nextli) starts at (1, 1), so gi is 0 on the first two
        # iterations and the momentum schedule appears to lag the usual Nesterov
        # sequence by one step. Kept as-is so results match the saved models.
        gi = (1 - li) / nextli
        t[:] = gi * w[:]
        # Preconditioned gradient step from u: w = u - (x.T x)^{-1} g.
        np.subtract(u, cho_solve(c, g), out=w)
        # Nesterov extrapolation: u = (1 - gi) * w_new + gi * w_old.
        np.add((1 - gi) * w, t, out=u)
        li = nextli
        nextli = (1 + np.sqrt(1 + 4 * li * li)) / 2
    return w

In [13]:
# Independent per-class logistic models (reward predictor) on the random features,
# reusing the precomputed Cholesky factor.
q = agdilr(x=rffxv, y=y, c=rcf)

0 61.2215422976
1 34.9249339392
2 21.1837599744
3 12.1602449408
4 6.9052243968
5 3.9821623296
6 2.363305984
7 1.4517590016
8 0.9247705088
9 0.6106792448
10 0.4174430208
11 0.294748416
12 0.2144181632
13 0.1602550784
14 0.1227006336
15 0.0959672576
16 0.0764628672
17 0.0619050432
18 0.0508107584
19 0.0421957376
20 0.0353932768
21 0.0299428608
22 0.0255201408
23 0.0218923456
24 0.0188891968
25 0.0163838352
26 0.0142800096
27 0.0125034736
28 0.0109959824
29 0.0097112656
30 0.0086120728
31 0.0076681112
32 0.006854528
33 0.006150772
34 0.0055398092
35 0.0050074096
36 0.0045416876
37 0.0041326784
38 0.0037720372
39 0.003452746
40 0.0031689024
41 0.0029155568
42 0.0026885412
43 0.0024843476
44 0.0023000292
45 0.0021330844
46 0.0019814108
47 0.0018432186
48 0.0017169848
49 0.0016014062
50 0.0014953648
51 0.0013978998
52 0.0013081731
53 0.0012254534
54 0.0011490939
55 0.001078521
56 0.0010132271
57 0.0009527599
58 0.0008967056
59 0.000844701
60 0.0007964062
61 0.00075151815
62 0.0007097567
63 0

In [14]:
# Softmax model directly on the projected features; no factor is passed, so agd
# computes the Cholesky factor of xv.T @ xv itself.
wlin = agd(x=xv, y=y)

0 1603.230629888
1 1263.609053184
2 950.191980544
3 644.200726528
4 419.374497792
5 277.317189632
6 189.474635776
7 133.558771712
8 96.842203136
9 72.175239168
10 55.305777152
11 43.580964864
12 35.2965632
13 29.335881728
14 24.955836416
15 21.654812672
16 19.091929088
17 17.03514624
18 15.328274432
19 13.868142592
20 12.588895232
21 11.44962048
22 10.425668608
23 9.502051328
24 8.668997632
25 7.919024128
26 7.245551616
27 6.641988608
28 6.10170368
29 5.617979904
30 5.18426368
31 4.794210816
32 4.44194048
33 4.122111488
34 3.830002432
35 3.561542144
36 3.313305088
37 3.082521856
38 2.867015424
39 2.665138944
40 2.4757248
41 2.297963264
42 2.131347584
43 1.97556928
44 1.830432128
45 1.695756928
46 1.571324544
47 1.456835456
48 1.351885568
49 1.255970048
50 1.168477696
51 1.088727424
52 1.01599488
53 0.949546304
54 0.888651776
55 0.832633216
56 0.780859264
57 0.732776512
58 0.687901312
59 0.645833408
60 0.606246528
61 0.568886848
62 0.533563616
63 0.500131424
64 0.468502656
65 0.43860726

In [15]:
# Persist the projection basis, the random-feature map and the three weight matrices.
np.savez(
    'mnist_models.npz',
    v=vv,               # leading eigenvectors used for the projection
    omega=rff.omega,    # random-feature frequencies
    b=rff.b,            # random-feature phases
    z=rff.normalizer,   # random-feature scale sqrt(2 / c)
    wlin=wlin,          # softmax weights on the projected features
    w=w,                # softmax weights on the random features
    q=q,                # independent-logistic weights on the random features
)