In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
from scipy.linalg import circulant

from lightonml.projections.sklearn import OPUMap

In [None]:
def cconj(A):
    """Return the conjugate (Hermitian) transpose of ``A`` as a new array."""
    transposed = A.T
    return np.conjugate(transposed)

In [None]:
def compute_FFT_inverse(x1_first_column, x2_first_column, Y, rcond=1e-6):
    """Recover A from ``Y = A @ [C1, C2]`` where C1, C2 are circulant matrices.

    ``C1`` / ``C2`` are the circulant matrices whose first columns are
    ``x1_first_column`` / ``x2_first_column``. Each half of ``Y`` yields an
    independent estimate ``Y_k @ inv(C_k)``; the two estimates are averaged
    (the 0.5 factors). The inverses are applied in the Fourier domain: a
    circulant matrix is diagonalised by the DFT, so
    ``inv(C)^H v = ifft(conj(1 / fft(c)) * fft(v))``.

    Parameters
    ----------
    x1_first_column, x2_first_column : array_like of length n
        First columns of the two circulant blocks.
    Y : ndarray, shape (M, 2n)
        Columns ``[:n]`` correspond to C1 and ``[n:]`` to C2.
    rcond : float, optional
        A circulant whose smallest |eigenvalue| is <= ``rcond`` times its
        largest is reported as singular/ill-conditioned.

    Returns
    -------
    ndarray of complex64, shape (M, n)

    Warns
    -----
    RuntimeWarning
        If either circulant is singular or ill-conditioned; the result is
        then dominated by the inverted near-zero eigenvalues.
    """
    import warnings

    t0 = time.time()

    x1_first_column = np.asarray(x1_first_column).reshape([-1, 1])
    x2_first_column = np.asarray(x2_first_column).reshape([-1, 1])

    # Eigenvalues of the circulants.
    lam_1 = np.fft.fft(x1_first_column, axis=0).astype('complex64')
    lam_2 = np.fft.fft(x2_first_column, axis=0).astype('complex64')

    # Integer-valued columns can have eigenvalues that are exactly zero in
    # exact arithmetic but come out of the FFT as ~1e-13. Inverting them gives
    # huge values with no numpy warning, so check conditioning explicitly.
    for name, lam in (('x1_first_column', lam_1), ('x2_first_column', lam_2)):
        mags = np.abs(lam)
        if mags.size and mags.min() <= rcond * mags.max():
            warnings.warn(
                f'circulant built from {name} is singular or ill-conditioned '
                f'(min |eigenvalue| = {mags.min():.3g}, max = {mags.max():.3g}); '
                'the recovered matrix will be unreliable.',
                RuntimeWarning,
                stacklevel=2,
            )

    # 0.5 averages the two independent estimates.
    d1 = (0.5 * 1 / lam_1).astype('complex64')
    d2 = (0.5 * 1 / lam_2).astype('complex64')

    N = int(Y.shape[1] / 2)
    # Conjugate transposes of the two halves of Y.
    Y1_star = np.conj(Y[:, :N].astype('complex64').T).astype('complex64')
    Y2_star = np.conj(Y[:, N:].astype('complex64').T).astype('complex64')

    first_part = np.conj(d1).reshape([-1, 1]) * np.fft.fft(Y1_star, axis=0).astype('complex64')
    second_part = np.conj(d2).reshape([-1, 1]) * np.fft.fft(Y2_star, axis=0).astype('complex64')

    # The ifft yields A^H; conjugate-transpose back to A.
    A = np.conj(np.fft.ifft(first_part + second_part, axis=0).astype('complex64')).T

    print ('FFT: ', time.time() - t0)

    return A.astype('complex64')
    
def build_half_X(N, circ_N, nonzero_ind):
    """Build an N x 2*circ_N matrix made of two random binary circulant blocks.

    Two first columns are drawn (in order) uniformly from {0, 1}. Their
    circulant matrices are placed side by side in the rows ``nonzero_ind``.
    All other rows stay zero. ``nonzero_ind`` must select ``circ_N`` rows.

    Returns
    -------
    (X, x1_first_column, x2_first_column)
    """
    probs = [0.5, 0.5]
    # Draw both first columns up front, in the same order as before.
    columns = [np.random.choice([0, 1], size=circ_N, p=probs) for _ in range(2)]

    half = np.zeros([N, 2 * circ_N])
    for block, col in enumerate(columns):
        half[nonzero_ind, block * circ_N:(block + 1) * circ_N] = circulant(col)

    return half, columns[0], columns[1]

def build_X(N, circ_N):
    """Build the N x 4*circ_N probe matrix made of two overlapping halves.

    ``2*(N - circ_N)`` distinct row indices are drawn. The first half of them
    is zeroed in the first half-matrix, and the second half in the second.
    Each half-matrix therefore has ``circ_N`` non-zero rows, filled by
    ``build_half_X``. The rows drawn for neither half (``ind_common``) are
    non-zero in both halves. Requires ``circ_N >= N / 2``.

    Returns
    -------
    X, first_ind (sorted), second_ind (sorted), ind_common, x1, x2, x3, x4
        ``x1..x4`` are the circulant first columns from the two halves.
    """
    all_rows = np.arange(N)
    dropped = np.random.choice(N, size=2*(N-circ_N), replace=False)
    split = int(0.5*len(dropped))
    dropped_first = dropped[:split]
    dropped_second = dropped[split:]

    # setdiff1d returns the sorted rows of all_rows absent from the second argument.
    shared_rows = np.setdiff1d(all_rows, dropped)
    kept_first = np.setdiff1d(all_rows, dropped_first)
    kept_second = np.setdiff1d(all_rows, dropped_second)

    half_one, x1, x2 = build_half_X(N, circ_N, kept_first)
    half_two, x3, x4 = build_half_X(N, circ_N, kept_second)

    X = np.hstack((half_one, half_two))

    return X, np.sort(dropped_first), np.sort(dropped_second), shared_rows, x1, x2, x3, x4

In [None]:
def make_anchors(X, number_of_anchors):
    """Create a nested chain of binary anchor vectors covering X's support.

    The first anchor is the column support of ``X`` (columns with a positive
    sum) plus random extra ones (probability 0.15 per column). Each later
    anchor adds further random ones to the previous anchor. After
    thresholding, the rows are returned in reverse order, so row 0 is the
    largest superset and every row contains all ones of the rows below it.
    """
    n_cols = X.shape[1]
    probs = [0.85, 0.15]

    support = np.sum(X, axis=0)
    support = np.where(support > 0, 1, support)

    chain = np.zeros([number_of_anchors, n_cols])
    chain[0] = support + np.random.choice([0, 1], size=n_cols, p=probs)
    for row in range(1, number_of_anchors):
        chain[row] = chain[row - 1] + np.random.choice([0, 1], size=n_cols, p=probs)

    chain = np.where(chain > 0, 1, chain)
    return chain[::-1]  # largest anchor first

def interfere_with_anchors(n, x, anchors):
    """Stack all pairwise difference vectors among [x, anchors..., 0].

    Row order:
      * ``anchors - x`` (one row per anchor);
      * ``x`` itself (x minus the zero vector, written so it stays non-negative);
      * for each anchor i, in order, ``anchor_i - anchor_j`` for every later
        anchor j, then ``anchor_i - 0``.

    This is the row-major order of the strict upper triangle of the pairwise
    table over the points [x, anchor_1, ..., anchor_m, 0]. ``n`` is the
    length of the zero vector.
    """
    padded = np.vstack((anchors, np.zeros(n)))  # the zero vector acts as the last anchor
    pieces = [anchors - x, x]
    for row, anchor in enumerate(padded[:-1]):
        pieces.append(anchor - padded[row + 1:])
    return np.vstack(pieces)

def interfere_with_anchors_simplified(n, x, anchors):
    """Return the rows ``anchors - x`` followed by ``x`` (x minus the zero vector).

    ``n`` is not used here; it mirrors the signature of ``interfere_with_anchors``.
    """
    offsets = anchors - x
    # np.vstack is concatenate-along-axis-0 of the atleast_2d inputs.
    return np.concatenate((np.atleast_2d(offsets), np.atleast_2d(x)), axis=0)

def get_OPU_measurements(opu_input, num_rand_proj, *, exposure_us=110, verbose_level=1):
    """Project every row of ``opu_input`` through the OPU using lightonml's ``OPUMap``.

    Parameters
    ----------
    opu_input : ndarray
        Input rows; cast to uint8 before being sent to the device.
    num_rand_proj : int
        Number of output components (``n_components`` of ``OPUMap``).
    exposure_us : int, optional
        Value written to ``mapping.opu.device.exposure_us`` (presumably a
        camera exposure in microseconds, per the name -- confirm against the
        lightonml docs). Previously hard-coded to 110.
    verbose_level : int, optional
        Passed to ``OPUMap``; previously hard-coded to 1.

    Returns
    -------
    The output of ``OPUMap.transform``.
    """
    mapping = OPUMap(n_components=num_rand_proj, verbose_level=verbose_level)
    mapping.opu.device.exposure_us = exposure_us
    return mapping.transform(opu_input.astype('uint8'))

def make_D_ensembles(y, number_of_anchors):
    """Build one symmetric (m+2) x (m+2) matrix per column of ``y``.

    For column i, the first (m+2)(m+1)/2 entries of ``y[:, i]`` fill the
    strict upper triangle in ``np.triu_indices`` order. They are then
    mirrored into the lower triangle, and the diagonal stays zero.

    Returns
    -------
    ndarray of float, shape (y.shape[1], m+2, m+2)
    """
    size = number_of_anchors + 2
    n_pairs = size * (size - 1) // 2
    upper = np.triu_indices(size, k=1)

    stacked = np.zeros([y.shape[1], size, size])
    for trial, column in enumerate(y[:n_pairs].T):
        stacked[trial][upper] = column
        stacked[trial] = stacked[trial] + stacked[trial].T
    return stacked

def do_MDS(D, number_of_anchors):
    """Embed m+2 points in the complex plane from their squared distance matrix.

    This is classical MDS: double-centre ``D`` to get a Gram matrix, then use
    the top two singular directions as planar coordinates. The result is
    translated so that the last point (index m+1) sits at the origin.

    Returns
    -------
    complex ndarray of length m+2
    """
    size = number_of_anchors + 2
    centering = np.eye(size) - np.full((size, size), 1. / size)
    gram = -0.5 * np.dot(np.dot(centering, D), centering)

    _, sing_vals, vt = np.linalg.svd(gram)
    coords = np.dot(np.diag(np.sqrt(sing_vals[:2])), vt[:2, :])
    embedded = coords[0, :] + 1j * coords[1, :]

    # Move the last point (index m+1) to the origin.
    return embedded - embedded[size - 1]

def ortho_procrustes(fixed, modify):
    """Align the complex points ``modify`` onto ``fixed`` with an orthogonal map plus a shift.

    The transform is fitted on every point except the first (``[1:]``) of
    both arrays. It is the orthogonal Procrustes solution after centring, and
    ``R`` may include a reflection. The fitted transform is then applied to
    ALL points of ``modify``. Returns the aligned points as a complex array.
    """
    target = np.vstack((np.real(fixed[1:]), np.imag(fixed[1:])))
    source_all = np.vstack((np.real(modify), np.imag(modify)))
    source = source_all.copy()[:, 1:]

    target_center = np.mean(target, axis=1).reshape([-1, 1])
    target -= target_center
    source_center = np.mean(source, axis=1).reshape([-1, 1])
    source -= source_center

    u, _, vh = np.linalg.svd(target @ source.T)
    rotation = u @ vh

    aligned = rotation @ (source_all - source_center) + target_center
    return aligned[0] + 1j * aligned[1]

def compute_anchors_pinv_norms(anchor_positions):
    """Precompute the least-squares multilateration operator for a set of anchors.

    With anchors ``a_k`` (complex, read as 2-D points) the design matrix is
    ``T = [-Re a, -Im a, 0.5]``. Then ``T @ [px, py, |p|^2] = 0.5 * (|p - a|^2 - |a|^2)``,
    so ``pinv(T) @ (0.5 * (d - |a|^2))`` recovers ``(px, py, |p|^2)`` from the
    squared distances ``d``.

    Returns
    -------
    (Tpinv, anchor_square_norm)
        ``Tpinv`` has shape (3, k); ``anchor_square_norm`` has shape (k, 1).
    """
    flat = anchor_positions.reshape([1, -1])[0]
    sq_norms = (np.abs(flat) ** 2).reshape([-1, 1])
    design = np.column_stack((-np.real(flat), -np.imag(flat), 0.5 * np.ones(flat.shape[0])))
    return np.linalg.pinv(design), sq_norms

def _build_half_opu_input(N, num_anchors, X_part, anchors):
    """Stack the anchor-geometry rows and per-signal interference rows for one half."""
    # Pairwise differences among [X_part[0], anchors..., 0].
    anchor_rows = interfere_with_anchors(N, X_part[0], anchors)
    anchors_input_size = anchor_rows.shape[0]
    rows_per_signal = num_anchors + 1  # one row per anchor plus one for the zero point

    out = np.zeros([anchors_input_size + rows_per_signal * X_part.shape[0], N]).astype('uint8')
    out[:anchors_input_size] = anchor_rows
    for i in tqdm(range(X_part.shape[0])):
        start = anchors_input_size + i * rows_per_signal
        out[start:start + rows_per_signal] = interfere_with_anchors_simplified(N, X_part[i], anchors)
    return out, anchors_input_size


def get_and_save_opu_input(M, N, num_signals, num_anchors, num_data, X, anchors1_ind, anchors2_ind):
    """Assemble the full uint8 OPU input for both halves plus the data rows.

    Nothing is written to disk despite the name. Row layout of the result:
      1. half 1: anchor-geometry rows, then ``num_anchors + 1`` rows per signal
         in ``X[anchors1_ind]``;
      2. half 2: the same for ``X[anchors2_ind]``;
      3. the last ``num_data`` rows of ``X``.

    ``M`` and ``num_signals`` are unused here; they are kept for interface
    compatibility.

    Returns
    -------
    (opu_input, measurement_split, anchors_input_size)
        ``measurement_split`` is the number of half-1 rows. ``anchors_input_size``
        is the number of anchor-geometry rows per half (it depends only on
        ``num_anchors``, so it is the same for both halves).
    """
    X_a1 = X[anchors1_ind]
    X_a2 = X[anchors2_ind]
    # X[-num_data:] would return every row when num_data == 0.
    data = X[X.shape[0] - num_data:]

    # Both anchor sets are drawn before any rows are built (keeps the RNG order).
    anchors1 = make_anchors(X_a1, num_anchors)
    anchors2 = make_anchors(X_a2, num_anchors)

    opu_input1, _ = _build_half_opu_input(N, num_anchors, X_a1, anchors1)
    opu_input2, anchors_input_size = _build_half_opu_input(N, num_anchors, X_a2, anchors2)

    measurement_split = opu_input1.shape[0]
    opu_input = np.vstack((opu_input1, opu_input2, data)).astype('uint8')

    return opu_input, measurement_split, anchors_input_size

def get_data(M, opu_input, measurement_split, num_data=None):
    """Measure ``opu_input`` on the OPU and split the output into its three parts.

    The input rows are laid out as
    ``[half-1 rows (measurement_split), half-2 rows, data rows]``.

    Parameters
    ----------
    M : int
        Number of OPU output components.
    opu_input : ndarray
        Rows to project.
    measurement_split : int
        Number of half-1 rows. Half 2 is assumed to have the same count.
    num_data : int, optional
        Number of trailing data rows. When omitted, every row after the first
        ``2 * measurement_split`` rows is data. The old code read a global
        ``num_data``, which failed outside the notebook and returned all rows
        when it was 0.

    Returns
    -------
    (y_half1, y_half2, y_data)
    """
    print('Getting OPU data')
    y_quant = get_OPU_measurements(opu_input, M)
    print('Got OPU data')

    second_end = 2 * measurement_split
    if num_data is None:
        data_start = second_end
    else:
        data_start = y_quant.shape[0] - num_data

    return y_quant[:measurement_split], y_quant[measurement_split:second_end], y_quant[data_start:]

def _localize_half(y_half, M, num_signals, num_anchors, anchors_input_size, num_points):
    """Place every signal of one half in the complex plane, for each of the M outputs.

    For output column i:
      * the first ``anchors_input_size`` measurements fill the upper triangle
        of D, which do_MDS embeds as squared distances among
        [x_0, anchors..., 0];
      * the remaining measurements are, per signal, its values against the
        m anchors and the zero point. They are converted to a planar position
        by least squares with ``compute_anchors_pinv_norms``.
    """
    dim = num_anchors + 2
    upper = np.triu_indices(dim, k=1)
    out = np.zeros([M, num_points]).astype('complex64')
    for i in tqdm(range(M)):
        # Use the measurement dtype. The old code forced uint8, which would
        # silently wrap/truncate non-uint8 OPU outputs (NOTE(review): outputs
        # are presumably uint8 on the current hardware, in which case nothing changes).
        D = np.zeros([dim, dim], dtype=y_half.dtype)
        D[upper] = y_half[0:anchors_input_size, i]
        D += D.T  # upper/lower triangles are disjoint, so no overflow
        anchor_positions = do_MDS(D, num_anchors)

        # Rows are grouped per signal in the input, hence reshape then transpose.
        dist_to_anchors = y_half[anchors_input_size:, i].reshape([num_signals, num_anchors + 1])
        # Drop x_0 (index 0): the remaining points are the anchors plus the zero point.
        Tpinv, anchor_norms = compute_anchors_pinv_norms(anchor_positions[1:])
        B = 0.5 * (dist_to_anchors.T - anchor_norms)
        results = Tpinv[:2] @ B  # first two unknowns are the planar coordinates
        out[i] = results[0] + 1j * results[1]
    return out


def do_MPR(M, N, num_signals, num_anchors, num_data, X, anchors1_ind, anchors2_ind, ind_common):
    """Acquire OPU measurements and localize the signals of both halves.

    ``ind_common`` is unused here; it is kept for interface compatibility.

    Returns
    -------
    (results_a1, results_a2, y_quant_data, time_taken)
        ``results_a*`` are complex64 arrays of shape (M, len(anchors*_ind)).
        ``time_taken`` covers only the localization, not the OPU acquisition.
    """
    opu_input, measurement_split, anchors_input_size = get_and_save_opu_input(
        M, N, num_signals, num_anchors, num_data, X, anchors1_ind, anchors2_ind)

    y_quant_a1, y_quant_a2, y_quant_data = get_data(M, opu_input, measurement_split)
    del opu_input

    print ('STARTING TIMING')
    t0 = time.time()

    results_a1 = _localize_half(y_quant_a1, M, num_signals, num_anchors,
                                anchors_input_size, len(anchors1_ind))
    del y_quant_a1

    results_a2 = _localize_half(y_quant_a2, M, num_signals, num_anchors,
                                anchors_input_size, len(anchors2_ind))
    del y_quant_a2

    time_taken = time.time() - t0

    return results_a1, results_a2, y_quant_data.copy(), time_taken

def stitch_measurements(N, M, Y1_0, Y2_0, anchors1_ind, anchors2_ind, ind_common, x1, x2, x3, x4, circ_N):
    """Merge the two half-recoveries into one M x N complex matrix.

    Steps:
      1. Each half is recovered with ``compute_FFT_inverse`` into the columns
         NOT listed in its ``anchors*_ind`` (the columns zeroed for that half).
      2. On the shared columns ``ind_common``, each row decides whether A2 or
         conj(A2) agrees with A1 up to a global phase, and applies that phase.
      3. Columns ``anchors1_ind`` (missing from A1) are filled from the aligned A2.

    ``circ_N`` is unused; it is kept for interface compatibility.

    Raises
    ------
    ValueError
        If ``ind_common`` is empty, because the relative phase cannot be estimated.
    """
    if len(ind_common) == 0:
        raise ValueError('ind_common is empty: the two halves share no columns, '
                         'so their relative phase cannot be estimated')

    A1 = np.zeros([M, N]).astype('complex64')
    A2 = np.zeros([M, N]).astype('complex64')

    A1_ind = np.arange(N)[~np.isin(np.arange(N), anchors1_ind)]
    A2_ind = np.arange(N)[~np.isin(np.arange(N), anchors2_ind)]

    A1[:, A1_ind] = compute_FFT_inverse(x1, x2, Y1_0)
    A2[:, A2_ind] = compute_FFT_inverse(x3, x4, Y2_0)

    common1 = A1[:, ind_common]
    common2 = A2[:, ind_common]

    # Per-entry phase offsets as unit phasors for the two hypotheses
    # A1 ~ e^{i phi} A2 and A1 ~ e^{i phi} conj(A2).
    # angle(a * conj(b)) == angle(a / b), but it avoids dividing by zero.
    phasor_plain = np.exp(1j * np.angle(common1 * np.conj(common2)))
    phasor_conj = np.exp(1j * np.angle(common1 * common2))
    mean_plain = phasor_plain.mean(axis=1)
    mean_conj = phasor_conj.mean(axis=1)

    # Circular statistics: a mean-phasor magnitude near 1 means tightly
    # clustered phases wherever they sit on the circle. A linear mean/std of
    # wrapped angles breaks when the offset is near +/-pi.
    use_plain = np.abs(mean_plain) > np.abs(mean_conj)
    use_conj = np.invert(use_plain)
    A2[use_conj] = np.conj(A2[use_conj])

    mean_phasor = np.where(use_plain, mean_plain, mean_conj)
    phases = np.exp(1j * np.angle(mean_phasor)).astype('complex64').reshape([M, 1])
    A2 = phases * A2

    # Columns zeroed in half 1 are taken from the aligned half 2.
    A1[:, anchors1_ind] = A2[:, anchors1_ind]

    return A1

def get_A(M, N, circ_N, num_signals, num_anchors, num_data, root_folder):
    """Run the full pipeline: build probes, measure, localize, stitch, plot.

    Parameters
    ----------
    root_folder : str or path-like
        Directory for ``recovered_histogram.pdf`` (joined with
        ``os.path.join``, so a trailing separator is optional).

    Returns
    -------
    (A0_hat, x, y)
        ``A0_hat`` is the recovered M x N complex matrix. ``x`` holds the data
        signals, shape (N, num_data). ``y`` holds their OPU measurements,
        shape (M, num_data).
    """
    import os

    X, first_ind, second_ind, ind_common, x1, x2, x3, x4 = build_X(N, circ_N)

    ### CHOOSE THE TYPE OF IMAGE
    data = np.random.choice([0, 1], size=[N, num_data], p=[0.5, 0.5])
    # data = np.load('images/uiuc_bw.npy').reshape([N, num_data])
    # data = np.load('images/uiuc_bw_64x64.npy').reshape([N, num_data])
    # data = np.load('images/uiuc_bw_96x96.npy').reshape([N, num_data])
    # data = np.load('images/uiuc_bw_128x128.npy').reshape([N, num_data])

    X = np.hstack((X, data)).astype('float32')

    anchors1_ind = np.arange(num_signals)
    anchors2_ind = np.arange(num_signals, 2*num_signals)

    Y1_ind, Y2_ind, y_quant_data, time_taken1 = do_MPR(
        M, N, num_signals, num_anchors, num_data, X.T, anchors1_ind, anchors2_ind, ind_common)

    print('Done MPR')

    del X

    t0 = time.time()
    A0_hat = stitch_measurements(N, M, Y1_ind, Y2_ind, first_ind, second_ind, ind_common,
                                 x1, x2, x3, x4, circ_N)
    total_time = time_taken1 + (time.time() - t0)
    print('TOTAL TIME: ', total_time)

    # Histogram of the real and imaginary parts together. plt.hist with
    # bins=100 computes the same edges the old np.histogram pre-pass did.
    values = np.hstack((np.real(A0_hat).flatten(), np.imag(A0_hat).flatten()))
    plt.figure(figsize=(6, 3))
    plt.subplot(1, 1, 1)
    plt.title('Recovered TM histogram')
    plt.hist(values, bins=100)
    plt.tight_layout()
    plt.savefig(os.path.join(root_folder, 'recovered_histogram.pdf'))

    # y_quant_data has one row per data signal (num_data, M). Transposing
    # (not reshaping) keeps each signal's measurements together in one column.
    return A0_hat, data.reshape([-1, num_data]), np.asarray(y_quant_data).T
    

In [None]:
np.random.seed(1)  # seed the global NumPy RNG so the random probes/anchors/data are reproducible
    
# Input dimension (32*32).
N = 32**2
# Number of OPU output components (passed down as M to the measurement step).
M = int(32*N)
# Size of each circulant block. build_X draws 2*(N - circ_N) distinct rows
# out of N, so this must be at least N/2.
circ_N = int(0.5*1.5*N) # for 32x32 and 64x64 inputs
# circ_N = int(0.5*1.125*N) # for 96x96 inputs
# circ_N = int(0.5*1.03125*N) # for 128x128 inputs
# Probe signals per half: each half-matrix from build_X has 2*circ_N columns.
num_signals = 2*circ_N
# Anchors per half used for the MDS localization.
num_anchors = 20
# num_anchors = 15 # for 128x128 inputs
# Number of data signals appended after the probes.
num_data = 1

print ('N = ', N)
print ('M = ', M)
print ('num_signals = ', num_signals)
print ('num_anchors = ', num_anchors)

# Output prefix; filenames are appended by plain string concatenation,
# so a non-empty value should end with a path separator.
root_folder = ''

A_hat, x, y = get_A(M, N, circ_N, num_signals, num_anchors, num_data, root_folder)

print ('Got A')
np.save(root_folder + 'A.npy', A_hat)
np.save(root_folder + 'x_true.npy', x)
np.save(root_folder + 'y.npy', y)
