In [1]:
import numpy as np
import matplotlib.pyplot as plt
import scipy as sc
from sklearn.preprocessing import normalize
import mat73
import scipy.sparse as sparse
import cvxpy as cp
import mosek
import pywt

In [2]:
data = mat73.loadmat('./data/data3_440.mat')

A = data['trainimages'][::800,:] # had to downsample images to make CVXPY not eat up my whole RAM
b = data['testimages'][::800,:] # had to downsample images to make CVXPY not eat up my whole RAM

train_id = data['trainids']
test_id = data['testids']

In [3]:
# 2-Norm
def norm2(x):
    return np.linalg.norm(x,2)

# Normalized Error
def normalized_error(x,x_pre):
    return norm2(x-x_pre)/norm2(x)

# Modified FISTA
def FISTAmod(B,b,w,gamma,p,q,r,n,m):

    # Forward Backward Operator with proxy
    def gradF():
        # print('gradF start')
        grad = B.T.dot(B.dot(y)-b)
        # print('gradF done')
        return grad
    
    def expr():
        # print('expr start')
        expr1 = y-gamma*gradF()
        # print('expr done')
        return expr1

    def fbo():
        expr1 = expr()
        # print('proxy start')
        x_hat = cp.Variable((n+m,1))
        prob = cp.Problem(cp.Minimize(0.5*cp.norm(x_hat-expr1,2)**2 + lambd*(np.linalg.norm(x,1))))
        prob.solve()
        prox = x_hat.value
        # print('proxy done')
        return prox
    
    tolerance = 10e-6
    lambd = 1
    t = 1
    x = w
    y = w
    count = 0
    max_iter = 10
    while count<max_iter:
        print('iteration: ',count)
        x_old = x
        x = fbo()
        t_old = t
        t = (p+np.sqrt(q+r*t_old**2))/2
        a = (t_old - 1)/t
        y = x+a*(x-x_old)
        res = np.linalg.norm(x_old-x,'fro')
        if res/n<tolerance:
            break
        count+=1
    x_sparse = x
    return(x_sparse)

In [4]:
# parameters
p = 0.1
q = 0.1
r = 1
m = A.shape[0]
n = A.shape[1]
A1 = A.T
b1 = b.T
gamma = 1/np.linalg.norm(A)**2
x_0 = np.zeros((n,1))
e = np.random.normal(0,1,(m,1))
e = normalize(e, axis=0, norm='max')
w = np.vstack((x_0,e))
B = sparse.hstack((A1,sparse.eye(n)))

In [None]:
total_err = [] 
for i in range(b.shape[1]):
    print('test_img: ',i)
    w_hat = FISTAmod(B,b1[i],w,gamma,p,q,r,n,m)
    total_err.append(normalized_error(w,w_hat))

test_img:  0
iteration:  0
iteration:  1
iteration:  2
test_img:  1
iteration:  0
iteration:  1
iteration:  2
test_img:  2
iteration:  0
iteration:  1
iteration:  2
test_img:  3
iteration:  0
iteration:  1
iteration:  2
test_img:  4
iteration:  0


In [None]:
# CVXPY takes way too long on full images and I have to downsample images significantly to make it work. 