In [1]:
import numpy as np
import math
from scipy.linalg import dft 

In [2]:
def full_grad(z, blur, meas):
    """Full gradient of the data-fidelity term 0.5 * ||b (*) z - y||^2 at ``z``.

    ``z``, ``blur`` and ``meas`` are flattened in row-major order and handled
    as 1-D signals; ``(*)`` denotes the 1-D circular convolution of those
    flattened vectors, evaluated through the FFT.

    Parameters
    ----------
    z : array_like, 2-D (H, W)
        Current optimization iterate.
    blur : array_like
        Blurring kernel ``b``; must flatten to the same length as ``z``.
    meas : array_like
        Measurements ``y``; must flatten to the same length as ``z``.

    Returns
    -------
    ndarray, shape (H, W)
        ``A^H (A z - y)`` where ``A`` is circular convolution with ``blur``.
        The value is taken directly from ``np.fft.ifft`` and is therefore
        complex-valued; for real inputs the imaginary part is round-off only.
    """
    z = np.asarray(z)
    height, width = z.shape[:2]  # height and width of the iterate

    # Spectrum of the kernel, reused by the forward and the adjoint operator.
    kernel_hat = np.fft.fft(np.asarray(blur).ravel())

    # Forward model A z = b (*) z, then the residual A z - y.
    blurred = np.fft.ifft(kernel_hat * np.fft.fft(z.ravel()))
    residual = blurred - np.asarray(meas).ravel()

    # Back-projection with the adjoint A^H. The adjoint of a circular
    # convolution is a circular correlation, i.e. multiplication by the
    # *conjugate* spectrum. Multiplying by the plain spectrum would blur the
    # residual a second time, which only coincides with the gradient when the
    # flattened kernel is symmetric about index 0.
    grad = np.fft.ifft(np.conj(kernel_hat) * np.fft.fft(residual))

    # Hand the gradient back with the shape of the iterate.
    return np.reshape(grad, [height, width])

In [3]:
def stoch_grad(z, blur, meas, ind):
    """Stochastic gradient of 0.5 * ||b (*) z - y||^2 at ``z`` on a subset of samples.

    Works like the full gradient, except that the residual ``b (*) z - y`` is
    multiplied element-wise by ``ind`` before it is back-projected, so only
    the selected measurements contribute. No rescaling by the batch size is
    applied.

    ``z``, ``blur``, ``meas`` and ``ind`` are flattened in row-major order and
    handled as 1-D signals; ``(*)`` denotes the 1-D circular convolution of
    those flattened vectors, evaluated through the FFT.

    Parameters
    ----------
    z : array_like, 2-D (H, W)
        Current optimization iterate.
    blur : array_like
        Blurring kernel ``b``; must flatten to the same length as ``z``.
    meas : array_like
        Measurements ``y``; must flatten to the same length as ``z``.
    ind : array_like
        Per-measurement weights multiplied into the residual, e.g. the 0/1
        mask produced by ``get_batch``; must flatten to the same length as ``z``.

    Returns
    -------
    ndarray, shape (H, W)
        ``A^H (ind * (A z - y))`` where ``A`` is circular convolution with
        ``blur``. The value is taken directly from ``np.fft.ifft`` and is
        therefore complex-valued; for real inputs the imaginary part is
        round-off only.
    """
    z = np.asarray(z)
    height, width = z.shape[:2]  # height and width of the iterate

    # Spectrum of the kernel, reused by the forward and the adjoint operator.
    kernel_hat = np.fft.fft(np.asarray(blur).ravel())

    # Forward model A z = b (*) z, then the residual A z - y.
    blurred = np.fft.ifft(kernel_hat * np.fft.fft(z.ravel()))
    residual = blurred - np.asarray(meas).ravel()

    # Keep only the residual entries selected by the batch mask.
    residual = residual * np.asarray(ind).ravel()

    # Back-projection with the adjoint A^H. The adjoint of a circular
    # convolution is a circular correlation, i.e. multiplication by the
    # *conjugate* spectrum. Multiplying by the plain spectrum would blur the
    # residual a second time, which only coincides with the gradient when the
    # flattened kernel is symmetric about index 0.
    grad = np.fft.ifft(np.conj(kernel_hat) * np.fft.fft(residual))

    # Hand the gradient back with the shape of the iterate.
    return np.reshape(grad, [height, width])

In [4]:
def get_batch(B, H, W, N):
    """Draw a random mini-batch of pixel positions and return it as a mask.

    Parameters
    ----------
    B : int
        Batch size, i.e. how many rows of ``N`` are kept.
    H : int
        Image height (number of mask rows).
    W : int
        Image width (number of mask columns).
    N : array_like, 2-D
        Candidate positions, one ``(row, col)`` pair per row.

    Returns
    -------
    ndarray, shape (H, W)
        Array of zeros with ones at the sampled positions.
    """
    # Shuffle the candidate (row, col) pairs along the first axis.
    shuffled = np.random.permutation(N)

    # The first B shuffled pairs form the batch.
    rows = shuffled[:B, 0]
    cols = shuffled[:B, 1]

    # Mark the chosen pixels in an otherwise empty H x W mask.
    mask = np.zeros((H, W))
    mask[rows, cols] = 1
    return mask

In [None]:
# ## Test stoch grad
# x = np.random.randn(128,128)
# ind = get_batch(1, 128, 128, np.transpose(np.nonzero(x)))
# blur = np.random.randn(128, 128)
# y = np.multiply(x, blur)
# print(x.shape, ind.shape, blur.shape, y.shape)
# test = stoch_grad(x, blur, y, ind)
# print(test.shape)