In [15]:
import numpy as np
import math
from scipy.linalg import dft 

In [30]:
def full_grad(z, meas, S):
    """Evaluate S^T (S z - meas) at the iterate z.

    For real-valued inputs this is the gradient of the least-squares
    objective 0.5 * ||S z - meas||^2 with respect to z.

    Input:
        - z, optimization iterate
        - meas, observed measurements
        - S, forward operator mapping an iterate to predicted measurements;
          it must provide `.dot` and `.T` (e.g. a NumPy array)
    Output:
        - Full gradient at z
    """
    # Mismatch between the predicted and the observed measurements.
    residual = S.dot(z) - meas
    # NOTE(review): `.T` is a plain transpose, not a conjugate transpose.
    # If S can be complex, confirm that this is what is intended.
    S_transposed = S.T
    return S_transposed.dot(residual)

In [17]:
def stoch_grad(z, meas, M1, M2):
    """Evaluate M2 (M1 z - meas) at the iterate z.

    Input:
        - z, optimization iterate
        - meas, observed measurements; anything that can be subtracted
          from `M1.dot(z)` (NumPy array, list, scalar)
        - M1, precomputed matrix applied to z to predict the measurements
        - M2, precomputed matrix applied to the residual
          (presumably M1 and M2 encode the sampled batch, since there is
          no separate batch argument -- confirm against the caller)
    Output:
        - Gradient estimate M2.dot(M1.dot(z) - meas)
    """
    # Residual between predicted and observed measurements.
    residual = M1.dot(z) - meas
    # Map the residual back through M2 to obtain the gradient estimate.
    return M2.dot(residual)

In [18]:
def get_batch(B, H, W, N, rng=None):
    """Draw a random batch of pixel locations and return it as a 0/1 mask.

    Input:
        - B, batch size; if B exceeds the number of rows in N, every row
          of N is used
        - H, image height
        - W, image width
        - N, array of candidate (row, col) index pairs, one pair per row
          (e.g. `np.transpose(np.nonzero(x))`); must be integer-valued
          because its columns are used to index the mask
        - rng, optional random source exposing `permutation`
          (e.g. `np.random.default_rng(seed)`); when None, NumPy's global
          random state is used, which matches the previous behavior
    Output:
        - (H, W) float array with 1 at the selected locations, 0 elsewhere
    """
    # Shuffle the candidate pairs along the first axis (rows stay intact).
    if rng is None:
        shuffled = np.random.permutation(N)
    else:
        shuffled = rng.permutation(N)

    # Keep the first B shuffled pairs as the batch.
    chosen = shuffled[:B]

    # Mark the selected (row, col) locations in an otherwise empty mask.
    batch = np.zeros([H, W])
    batch[chosen[:, 0], chosen[:, 1]] = 1
    return batch

In [2]:
# ## Test stoch_grad (disabled smoke test: only prints array shapes)
# x = np.random.randn(128,128)
# ind = get_batch(1, 128, 128, np.transpose(np.nonzero(x)))
# blur = np.random.randn(128, 128)
# y = np.multiply(x, blur)
# print(x.shape, ind.shape, blur.shape, y.shape)
# test = stoch_grad(x, blur, y, ind)
# print(test.shape)