In [74]:
%env MXNET_CPU_WORKER_NTHREADS=4
import scipy.io as sio
from scipy.sparse.linalg import svds, eigs
import warnings
import numpy as np
import matplotlib.pyplot as plt
warnings.filterwarnings('ignore')
import mxnet as mx
from mxnet import nd, autograd, gluon
mx.random.seed(1)

env: MXNET_CPU_WORKER_NTHREADS=4


In [75]:
def load_data(file_name, train, tune):
    #loads the .mat file in file_name
    #returns a dictionary data, containing test, train and tune
    #the original data file is a dict with key 'X', containing the data
    #training is X[:,0:train-1], tune is X[:,train:train+tune-1]
    #test is the rest
    data_dict = sio.loadmat(file_name)
    data_matr = data_dict['X']
    data_matr = np.array(data_matr)
    (d,n) = data_matr.shape
    if (n <= train):
        sys.exit("train >= number of points")
    data_train = data_matr[:,0:train]
    data_train /= np.max(np.abs(data_train))
    if (n <= train+tune):
        sys.exit("train+tune >= number of points")
    data_tune = data_matr[:,train:train+tune]
    data_tune /= np.max(np.abs(data_tune))
    data_test = data_matr[:,train+tune:]
    data_test /= np.max(np.abs(data_test))
    data = {"train":data_train,"tune":data_tune,"test":data_test}
    return data

In [76]:
data = load_data('./data/xsmnist.mat', 8000, 1999)
data['test'] = data['tune']

In [77]:
data.keys()
data['test'].shape

(196, 1999)

In [78]:
data_ctx = mx.cpu()
model_ctx = mx.cpu()

In [79]:
(d,n) = data['train'].shape
k = 1
u, s, vt = svds(np.dot(data['test'],data['test'].T)/(data['test'].shape)[1],k)
obj_val = sum(s)
print (obj_val)
batch_size = 1;
train_data = mx.gluon.data.DataLoader(data['train'].T, batch_size, shuffle=True)
test_data = mx.gluon.data.DataLoader(data['test'].T, batch_size, shuffle=False)
#train_data = mx.nd.array(data['train'])
#test_data = mx.nd.array(data['test'])

1.30170732793


In [80]:
U = nd.random_normal(shape = (d,k), ctx=model_ctx)
Q,L = nd.linalg.gelqf(U.T)
U = Q.T
U.attach_grad()
def proj_matr(x_loading):
    return nd.dot(U,x_loading)

def net(X, U_matr):
    #return U
    x_loading = nd.dot(U_matr.T, X)
    return proj_matr(x_loading)

def eval_loss(yhat,y):
    return -0.5*nd.dot(yhat,y)
    #yhat = proj_matr(nd.dot(U.T,y))
    #return -0.5*nd.dot(yhat.T,y)

def SGD(U_matr, eta, compare, do_qr = False):
    #print nd.norm(compare+U_matr.grad)
    #print (eta)
    U_matr[:] = U_matr - eta*U_matr.grad
    if(do_qr):
        Q,L = nd.linalg.gelqf(U_matr.T)
        U_matr = Q.T
        
'''def evaluate_accuracy(data_iterator, net, U_matr, true_obj):
    curr_obj = 0.0
    Q,L = nd.linalg.gelqf(U_matr.T)
    U_matr = Q.T
    for i,data in enumerate(data_iterator):
        data = data.as_in_context(model_ctx).astype(dtype='float32')
        proj_data = net(data.T, U_matr)
        curr_obj += nd.norm(nd.dot(U_matr.T,data.T))*nd.norm(nd.dot(U_matr.T,data.T))
    return true_obj - curr_obj/(1.0*(i+1))'''

def evaluate_accuracy(data, U_matr, true_obj):
    Q,L = nd.linalg.gelqf(U_matr.T)
    U_matr = Q.T
    data = nd.array(data)
    obj_sqrt = nd.norm(nd.dot(U_matr.T,data))
    return true_obj - obj_sqrt*obj_sqrt/(data.shape)[1]

epochs = 10
learning_rate = .05

for e in range(epochs):
    #cumulative_loss = 0
    for i, data_t in enumerate(train_data):
        data_t = data_t.as_in_context(model_ctx).astype(dtype='float32')
        with autograd.record():
            output = net(data_t.T, U)
            loss = eval_loss(output.T,data_t.T)
        loss.backward()
        compare = nd.dot(data_t.T,nd.dot(data_t,U))
        SGD(U, learning_rate/(i+1),compare)
        #print(nd.norm(nd.dot(U.T,nd.array(u))))
        #cumulative_loss += nd.sum(loss).asscalar()
        #if(i%100==0):
            #print("Current objective: %s", evaluate_accuracy(data['train'], U, obj_val))
        
    test_accuracy = evaluate_accuracy(data['test'], U, obj_val)
    train_accuracy = evaluate_accuracy(data['train'], U, obj_val)
    print("Epoch %s. Train_err %s, Test_err %s" % (e, train_accuracy, test_accuracy))

Epoch 0. Train_err 
[ 1.177549]
<NDArray 1 @cpu(0)>, Test_err 
[ 1.17804921]
<NDArray 1 @cpu(0)>
Epoch 1. Train_err 
[ 1.04018283]
<NDArray 1 @cpu(0)>, Test_err 
[ 1.04080641]
<NDArray 1 @cpu(0)>
Epoch 2. Train_err 
[ 0.79537845]
<NDArray 1 @cpu(0)>, Test_err 
[ 0.79587299]
<NDArray 1 @cpu(0)>
Epoch 3. Train_err 
[ 0.61238325]
<NDArray 1 @cpu(0)>, Test_err 
[ 0.61175948]
<NDArray 1 @cpu(0)>
Epoch 4. Train_err 
[ 0.39780205]
<NDArray 1 @cpu(0)>, Test_err 
[ 0.39197528]
<NDArray 1 @cpu(0)>
Epoch 5. Train_err 
[ 0.31499761]
<NDArray 1 @cpu(0)>, Test_err 
[ 0.31334871]
<NDArray 1 @cpu(0)>
Epoch 6. Train_err 
[ 0.1942122]
<NDArray 1 @cpu(0)>, Test_err 
[ 0.19247532]
<NDArray 1 @cpu(0)>
Epoch 7. Train_err 
[ 0.13182962]
<NDArray 1 @cpu(0)>, Test_err 
[ 0.12810647]
<NDArray 1 @cpu(0)>
Epoch 8. Train_err 
[ 0.08859444]
<NDArray 1 @cpu(0)>, Test_err 
[ 0.07792294]
<NDArray 1 @cpu(0)>


KeyboardInterrupt: 