In [1]:
import numpy as np
import math
from scipy.linalg import dft 

In [107]:
def full_grad(z, MASK, meas):
    """Full gradient of the masked Fourier least-squares objective at z.

    Computes real(ifft2(R)), where R = fft2(z) * MASK with `meas` subtracted
    at every location where MASK is nonzero. For a 0/1 MASK this equals
    real(ifft2(MASK * (fft2(z) - meas))). That is 1 / (H * W) times the
    gradient, w.r.t. a real z, of 0.5 * ||MASK * (fft2(z) - meas)||^2. The
    factor comes from numpy's unnormalized fft2 / normalized ifft2 pair.

    Input:
        z     -- optimization iterate (fft2 acts on the last two axes)
        MASK  -- mask of observed Fourier locations, same shape as fft2(z)
        meas  -- measurements F(X) + w, same shape as MASK; entries where
                 MASK is zero are never read
    Output:
        Full gradient at z: a real array with the shape of fft2(z).
    """
    residual = np.fft.fft2(z) * MASK
    observed = np.nonzero(MASK)
    # Subtract only at observed frequencies, so unobserved entries of `meas`
    # (e.g. NaN placeholders) cannot leak into the gradient.
    residual[observed] = residual[observed] - meas[observed]
    return np.real(np.fft.ifft2(residual))

In [108]:
def stoch_grad(z, IND, meas):
    """Stochastic (mini-batch) gradient at z.

    Computes real(ifft2(IND * (fft2(z) - meas))): the Fourier residual,
    restricted to the locations selected by IND.

    Input:
        z     -- optimization iterate (fft2 acts on the last two axes)
        IND   -- batch indicator (e.g. a 0/1 array), broadcast against fft2(z)
        meas  -- measurements F(X) + w
    Output:
        Stochastic gradient at z for the measurements selected by IND,
        as a real array.
    """
    residual = IND * (np.fft.fft2(z) - meas)
    return np.real(np.fft.ifft2(residual))

In [109]:
def get_batch(B, MASK, rng=None):
    """Draw a random mini-batch of observed locations.

    Picks B distinct positions, uniformly at random without replacement,
    among the entries of MASK that equal 1.

    Input:
        B     -- batch size (number of locations to select)
        MASK  -- array marking observed locations with 1
        rng   -- optional np.random.Generator / RandomState used for the
                 draw; when None, numpy's global random state is used, so
                 np.random.seed controls the result
    Output:
        Integer 0/1 array shaped like MASK with exactly B ones, all placed
        where MASK == 1.
    Raises:
        ValueError -- if B is negative or exceeds the number of entries
                      of MASK equal to 1.
    """
    if rng is None:
        rng = np.random

    # Flat (C-order) indices of the observed locations.
    observed = np.flatnonzero(np.ravel(MASK) == 1)
    if not 0 <= B <= observed.size:
        raise ValueError(
            f"batch size {B} must be between 0 and the number of observed "
            f"locations ({observed.size})"
        )

    chosen = rng.choice(observed, B, replace=False)
    batch = np.zeros(np.size(MASK), dtype=int)
    batch[chosen] = 1
    return batch.reshape(np.shape(MASK))

In [110]:
def test_get_batch():
    """Smoke-test get_batch on a random 8x8 mask.

    Prints the number of observed entries, the mask and the sampled batch.
    Then checks that the batch has the mask's shape and exactly
    `batch_size` ones, all at observed locations.
    """
    batch_size = 4
    prob = .5  # probability that a given location is observed
    mask = np.random.choice([0, 1], size=(8, 8), p=[1 - prob, prob])  # generate random mask
    print(np.count_nonzero(mask))
    batch_mask = get_batch(batch_size, mask)
    print(mask, batch_mask)

    assert batch_mask.shape == mask.shape
    assert np.count_nonzero(batch_mask) == batch_size
    # Every sampled location must be an observed one.
    assert np.all(batch_mask[mask == 0] == 0)

In [112]:
def test_grads():
    """Check stoch_grad against full_grad on random 32x32 data.

    1. A batch covering every observed location must reproduce the full
       gradient exactly.
    2. stoch_grad is linear in its indicator, so the gradient of a random
       batch plus the gradient of its complement within the mask must also
       equal the full gradient.
    """
    X = np.random.randn(32, 32)
    Y = np.random.randn(32, 32)

    batchsize = 4
    prob = .5
    ## Make measurements
    mask = np.random.choice([0, 1], size=(32, 32), p=[1 - prob, prob])  # generate random mask

    ind = get_batch(np.count_nonzero(mask), mask)  # selects every observed location
    ind1 = get_batch(batchsize, mask)

    full = full_grad(X, mask, Y)
    stoch_full = stoch_grad(X, ind, Y)
    # Report a scalar summary instead of dumping two 32x32 arrays.
    print("max |full - stoch_full|:", np.max(np.abs(full - stoch_full)))
    assert np.allclose(full, stoch_full)

    # ind1 is a subset of ind, so (ind - ind1) is its complement within the mask.
    stoch_split = stoch_grad(X, ind1, Y) + stoch_grad(X, ind - ind1, Y)
    assert np.allclose(full, stoch_split)