In [1]:
import numpy as np
import math
from scipy.linalg import dft 

# Compute a full gradient

In [2]:
def full_grad(z, MASK, meas):
    """Return the full gradient at ``z`` for masked Fourier measurements.

    Parameters
    ----------
    z : array_like
        Optimization iterate; it is transformed with ``np.fft.fft2``.
    MASK : ndarray
        Multiplied elementwise with the spectrum of ``z``. Its nonzero
        entries mark the frequencies at which ``meas`` is subtracted
        (presumably a 0/1 sampling mask -- confirm with the caller).
    meas : ndarray
        Measurements, ``F(X) + w``. Only the entries at nonzero ``MASK``
        positions are read.

    Returns
    -------
    ndarray
        Complex array: the inverse 2-D FFT of the masked residual.
    """
    observed = np.nonzero(MASK)

    # Masked spectrum of the iterate; unobserved frequencies are zeroed
    # whenever MASK is zero there.
    residual = np.fft.fft2(z) * MASK

    # Subtract the data only where a measurement exists.
    residual[observed] -= meas[observed]

    return np.fft.ifft2(residual)

# Compute a stochastic gradient

In [1]:
def stoch_grad(z, MASK, meas, ind):
    """Return a stochastic (mini-batch) gradient at ``z``.

    The masked Fourier residual is formed from ``z``, ``MASK`` and ``meas``,
    multiplied elementwise by ``ind`` so that only the batch's frequencies
    contribute, and then inverse-transformed.

    Parameters
    ----------
    z : array_like
        Optimization iterate; it is transformed with ``np.fft.fft2``.
    MASK : ndarray
        Multiplied elementwise with the spectrum of ``z``. Its nonzero
        entries mark the frequencies at which ``meas`` is subtracted
        (presumably a 0/1 sampling mask -- confirm with the caller).
    meas : ndarray
        Measurements, ``F(X) + w``. Only the entries at nonzero ``MASK``
        positions are read.
    ind : ndarray
        Batch selector multiplied elementwise with the residual
        (presumably the 0/1 array returned by ``get_batch`` -- confirm
        with the caller).

    Returns
    -------
    ndarray
        Complex array: the inverse 2-D FFT of the batch-restricted residual.
    """
    # Masked residual in the Fourier domain.
    res = np.fft.fft2(z) * MASK
    index = np.nonzero(MASK)
    res[index] = res[index] - meas[index]

    # Keep only the frequencies in the current batch. No scratch buffer is
    # needed: the product allocates its own result.
    return np.fft.ifft2(ind * res)

# Get a batch of measurements to run stochastic gradients

In [4]:
def get_batch(B, H, W, N):
    """Draw a random batch of index pairs and return it as an indicator image.

    Parameters
    ----------
    B : int
        Batch size. If it exceeds the number of rows in ``N``, every row
        is selected.
    H : int
        Image height (number of rows of the returned array).
    W : int
        Image width (number of columns of the returned array).
    N : array_like
        Candidate indices, one per row; column 0 is used as the row index
        and column 1 as the column index into the ``H`` x ``W`` array.

    Returns
    -------
    ndarray
        Float array of shape ``(H, W)`` holding 1 at the selected positions
        and 0 elsewhere.
    """
    # Shuffle the candidate rows (a shuffled copy is returned) and keep
    # the first B of them.
    selected = np.random.permutation(N)[:B]
    rows = selected[:, 0]
    cols = selected[:, 1]

    indicator = np.zeros((H, W))
    indicator[rows, cols] = 1
    return indicator