In [26]:
import math
import os
import pickle
import threading
import time
from multiprocessing import Process

import numpy as np
import torch
from scipy.interpolate import interp1d

# Root directory of the turbulence-simulator package. The precomputed integral
# pickles are read from <PACKAGEDIR>/precomputed/fij_ints/. Set the
# TURBSIM_PACKAGEDIR environment variable to run on another machine; the
# default is the original hard-coded location.
PACKAGEDIR = os.environ.get(
    'TURBSIM_PACKAGEDIR',
    '/home/xingguang/Documents/turb/TurbulenceSim_P2S/turb-sim-git-v2-test/turbSimv4')

params = {'D': 0.2, # Aperture diameter (meters)
          'wvl': 550e-9, # Wavelength (meters)
          'f': 0.2034, # Focal length (meters)
          'L': 1000, # Propagation distance (meters)
          'num_zern': 36, # Number of Zernikes to be used (36 only)
          'resize_bases': 2, # resizing the blur kernels (1 or 2 seem to work best)
          'width': 4.5, # width of image in meters
          'Cn2' : 1e-14 * torch.ones(1), # turbulence strength (just change the first number, otherwise ask)
          'pad_ratio' : 0, # how much did you pad the input images?
          'epochs' : 10, # how many loops of the sim? Total output images will be epochs * batch_size
          'output_path' : '.',
          'batch_size' : 1, # batch size on the input images
          'temp_corr' : 0.95, # a value in (0, 1), ~1 is a lot of correlation, ~0 is none
          'img_shape': (512, 512), # (rows, cols) of the generated grid
          }

# The call `a = z_space(params)` that used to sit here ran before z_space was
# defined further down this cell, so a fresh kernel failed with NameError.
# The timing cell below performs the same call once the definitions exist.

def logical_xor(str1, str2):
    """Return True when exactly one of the two arguments is truthy."""
    first, second = bool(str1), bool(str2)
    return first != second


def z_space(params):
    """
    Build the per-Zernike spatial correlation maps for one parameter set.

    For every Noll index i in 2..num_zern this calls corr_zern(i, i, ...) on a
    grid of radial distances and stores the result in slice i-1 of the output.
    Slice 0 is never written and stays zero.

    :param params: dict with keys 'width', 'Cn2', 'L', 'wvl', 'D', 'num_zern'
        and 'img_shape' (rows, cols). 'Cn2' may be an indexable sequence or
        tensor, or a single number.
    :return: (zern_space, r0) where zern_space has shape (rows, cols, num_zern)
        and r0 is the second value returned by the last corr_zern call.
    """
    width = params['width']
    Cn2 = params['Cn2']
    L = params['L']
    wvl = params['wvl']
    D = params['D']
    num_zern = params['num_zern']

    N1, N2 = params['img_shape'][1], params['img_shape'][0]  # cols, rows
    zern_space = torch.zeros(N2, N1, num_zern)
    dx = width / N1

    # A scalar Cn2 is wrapped in a one-element tensor: corr_zern indexes
    # Cn2[l], which a bare float does not support.
    if isinstance(Cn2, (int, float)):
        Cn2 = torch.tensor([float(Cn2)])
    M = len(Cn2)

    xx, yy = torch.meshgrid(torch.linspace(-N1/2, N1/2, N1) * dx,
                            torch.linspace(-N2/2, N2/2, N2) * dx, indexing='xy')
    s_arr = torch.sqrt(xx**2 + yy**2)

    # Loop bound follows num_zern (it was a hard-coded 37), so the index i-1
    # always fits the last axis of zern_space. This matches the class-based
    # implementation, which iterates range(2, Z + 1).
    for i in torch.arange(2, num_zern + 1):
        zern_space[:, :, i-1], r0 = corr_zern(i, i, s_arr, M, Cn2, L, wvl, D)
    return zern_space, r0


def tak_indicator(i, j, mi, mj):
    """
    Classify the index pair (i, j) with azimuthal orders (mi, mj) into a case 1-5.

    1/2: both orders non-zero, with i + j even / odd.
    3/4: selected by which order is zero and the parity of the other index.
    5:   both orders zero and neither of cases 3/4 applied.
    Returns None if no case matches.
    """
    both_nonzero = (mi != 0) and (mj != 0)
    if both_nonzero:
        return 1 if (i + j) % 2 == 0 else 2

    mi_zero, mj_zero = mi == 0, mj == 0
    i_even, j_even = i % 2 == 0, j % 2 == 0

    # Exclusive-or of the two "zero order paired with an even index" tests.
    if bool(mi_zero and j_even) != bool(mj_zero and i_even):
        return 3
    # Same test with odd indices.
    if bool(mi_zero and not j_even) != bool(mj_zero and not i_even):
        return 4
    if mi_zero and mj_zero:
        return 5


def _load_fij_integral(which, smax, i, j):
    """
    Unpickle one precomputed integral callable.

    Reads <PACKAGEDIR>/precomputed/fij_ints/int<which>_smax_<smax>_<i>_<j>.

    NOTE: pickle.load can execute arbitrary code; PACKAGEDIR must only point
    at trusted files.
    """
    fname = 'precomputed/fij_ints/int{}_smax_{}_{}_{}'.format(which, smax, i, j)
    with open(os.path.join(PACKAGEDIR, fname), 'rb') as file_handle:
        return pickle.load(file_handle)


def fij(i, j, s_arr):
    """
    Evaluate the f_ij term for Noll indices i and j on the grid s_arr.

    The case returned by tak_indicator selects which precomputed integral
    callables are loaded and how they are combined with cos/sin of the polar
    angle of each grid position (measured from the array centre).

    :param i: first Noll index
    :param j: second Noll index
    :param s_arr: 1-D or 2-D array of arguments for the precomputed callables.
        It is not modified; values are clipped to 0.9 * smax on a copy.
    :return: the combined term, or None when tak_indicator yields no case 1-5.
    """
    # Tip/tilt pairs use the tables tabulated out to smax = 100.
    smax = 100 if (i == 2 or i == 3) and (j == 2 or j == 3) else 10

    if len(s_arr.shape) == 1:
        # 0-dim tensor rather than the int 0, so torch.cos below also works
        # when the azimuthal orders are plain Python ints.
        thetamat = torch.zeros(())
    else:
        cols, rows = torch.arange(0, s_arr.shape[1]), torch.arange(0, s_arr.shape[0])
        x, y = torch.meshgrid(cols, rows, indexing='xy')
        xhat = x - s_arr.shape[1] / 2
        yhat = y - s_arr.shape[0] / 2
        thetamat = torch.atan2(yhat, xhat)

    ni, mi = nollToZernInd(i)
    nj, mj = nollToZernInd(j)
    nplus, mplus, mminus = ni + nj, mi + mj, mi - mj

    h = tak_indicator(i, j, mi, mj)

    # Clip on a fresh tensor instead of writing into the caller's array.
    s_clipped = torch.clamp(torch.as_tensor(s_arr), max=smax * 0.9)

    if h == 1:
        func1 = _load_fij_integral(1, smax, i, j)
        func2 = _load_fij_integral(2, smax, i, j)
        mult = -1 if (i % 2 == 0 and j % 2 == 0) else 1
        return mult * (-1)**((nplus - mminus)/2) * torch.cos(mplus * thetamat) * \
            func1(s_clipped) + (-1)**((nplus + 2 * mi + np.abs(mminus))/2) * \
            torch.cos(mminus * thetamat) * func2(s_clipped)

    if h == 2:
        func1 = _load_fij_integral(1, smax, i, j)
        func2 = _load_fij_integral(2, smax, i, j)
        return (-1)**((nplus - mminus)/2) * torch.sin(mplus * thetamat) * \
            func1(s_clipped) + (-1)**((nplus + 2 * mi + np.abs(mminus))/2) * \
            torch.sin(mminus * thetamat) * func2(s_clipped)

    if h == 3:
        func1 = _load_fij_integral(1, smax, i, j)
        return (-1)**((nplus - mminus)/2) * np.sqrt(2) * torch.cos(mplus * thetamat) * \
            func1(s_clipped)

    if h == 4:
        func1 = _load_fij_integral(1, smax, i, j)
        return (-1)**((nplus - mminus)/2) * np.sqrt(2) * torch.sin(mplus * thetamat) * \
            func1(s_clipped)

    if h == 5:
        func1 = _load_fij_integral(1, smax, i, j)
        return (-1)**((nplus - mminus)/2) * func1(s_clipped)
            

def corr_zern(i, j, s_arr, M, Cn2, L, wvl, D):
    """
    Accumulate the Cn2-weighted f_ij term over M segments and scale it.

    :param i: first Noll index
    :param j: second Noll index
    :param s_arr: grid of distances; each segment passes a rescaled copy to fij
    :param M: number of segments (entries of Cn2 that are used)
    :param Cn2: indexable sequence/tensor with one value per segment
    :param L: propagation distance
    :param wvl: wavelength
    :param D: aperture diameter
    :return: (c_out, r0) - the scaled correlation map and the accumulated
        r0 sum raised to the power -3/5.
    """
    tilt_coeff_final = (wvl / (D / 2)) ** 2
    our_coeff_final = 0.0096932 * (2*np.pi/wvl) ** 2 * 2 ** (14 / 3) * np.pi ** (8 / 3) * (D / 2) ** (5 / 3) / np.pi ** 2
    ni, _ = nollToZernInd(i)
    nj, _ = nollToZernInd(j)

    c_out = s_arr * 0
    r0 = 0

    for l in range(M):
        m = l + 1
        temp = fij(i, j, s_arr * m / (D*(M + 1 - m)))
        cn2_i = Cn2[l]
        # NOTE(review): the weight (l/M) is zero for the first segment, so a
        # single-segment profile gives r0 == 0 here and an infinite value
        # after the -3/5 power below. Confirm the intended weighting.
        r0 += (0.423 * (2*math.pi / wvl)**2 * cn2_i * L / M) * (l/M) ** (5/3)
        c_out += temp * cn2_i * ((M + 1 - m) / (M + 1)) ** (5 / 3)

    # Normalisation uses both radial orders. It was (ni+1)*(ni+1), which
    # ignored j; the value is unchanged when i == j.
    c_out = our_coeff_final * c_out * (L / M) * tilt_coeff_final * math.sqrt((ni + 1) * (nj + 1))
    return c_out, r0**(-3/5)


def nollToZernInd(j):
    """
    Convert a Noll index "j" into its (n, m) position in the Zernike pyramid.
    Authors: Tim van Werkhoven, Jason Saredy
    See: https://github.com/tvwerkhoven/libtim-py/blob/master/libtim/zern.py

    :param j: Noll index, starting at 1
    :return: (n, m) - the row n and the signed column m
    :raises ValueError: if j is 0
    """
    if j == 0:
        raise ValueError("Noll indices start at 1, 0 is invalid.")

    # Walk down the pyramid: row n holds n + 1 entries.
    radial_order = 0
    remainder = j - 1
    while remainder > radial_order:
        radial_order += 1
        remainder -= radial_order

    parity_offset = (radial_order + 1) % 2
    magnitude = (radial_order % 2) + 2 * int((remainder + parity_offset) / 2.0)
    sign = (-1) ** j

    return radial_order, sign * magnitude


def nollCovMat(Z):
    """
    This function generates the covariance matrix for a single point source. See the associated paper for details on
    the matrix itself.
    :param Z: Number of Zernike basis functions/coefficients, determines the size of the matrix.
    :return: (Z, Z) torch tensor. Entry [0, 0] is set to 1; entries whose index pair fails the selection rule below
    are 0.
    """
    C = torch.zeros((Z, Z))
    # (n, m) for Noll indices 1..Z, stored at positions 0..Z-1.
    ijs = [nollToZernInd(k + 1) for k in range(Z)]
    for i, (ni, mi) in enumerate(ijs):  # was ijs(i): calling a list raises TypeError
        for j, (nj, mj) in enumerate(ijs):
            # Non-zero only for equal |m| and, unless one m is 0, equal index parity.
            if (abs(mi) == abs(mj)) and (np.mod(i + j, 2) == 0 or (mi == 0) or (mj == 0)):
                kzz = 2.2698 * (-1)**((ni + nj - 2*mi)/2) * math.sqrt(ni + 1) * math.sqrt(nj + 1)
                den = math.gamma((-ni + nj + 17.0/3.0)/2.0) * math.gamma((ni - nj + 17.0/3.0)/2.0) * \
                      math.gamma((ni + nj + 23.0/3.0)/2.0)
                C[i, j] = kzz * math.gamma((ni + nj - 5/3)/2) / den
    C[0, 0] = 1
    return C


'''
========================================================================
Below are some functions that are not used directly in dfp2s, but may
be utilized elsewhere
========================================================================
'''


def genZernikeCoeff(num_zern, D_r0, **kwargs):
    '''
    Draw random Zernike coefficient vectors from the covariance given by nollCovMat().

    The first row and column of the covariance are dropped before the Cholesky
    factorisation. The D/r0 weighting is applied to the standard-normal draws,
    not to the covariance matrix.
    :param num_zern: Number of Zernike basis functions/coefficients. Should be a number that a pyramid row ends at,
    for example [1, 3, 6, 10, 15, 21, 28, 36]
    :param D_r0: ratio raised to the 5/6 power and used to scale the random draws
    :param num: (keyword, default 1) number of coefficient vectors to draw
    :return: array of shape (num_zern - 1, num)
    '''
    n_draws = kwargs.get('num', 1)
    covariance = nollCovMat(num_zern)
    chol_factor = np.linalg.cholesky(covariance[1:, 1:])

    scale = D_r0 ** (5.0/6.0)
    draws = np.random.randn(int(num_zern) - 1, n_draws) * scale

    return np.matmul(chol_factor, draws)


def zernikeGen(N, coeff, **kwargs):
    """
    Generate the Zernike phase representation on an N x N grid over [-1, 1]^2.

    This implementation uses Noll's indices. 1 -> (0,0), 2 -> (1,1), 3 -> (1, -1), 4 -> (2,0), 5 -> (2, -2), etc.
    :param N: number of samples along each axis
    :param coeff: array of coefficients; entry k weights the polynomial with Noll index k + 1
    :return: array of shape (N, N, coeff.size), one weighted polynomial per slice
    """
    num_coeff = coeff.size

    # Square sampling grid; both axes share the same coordinates.
    axis = np.linspace(-1, 1, N, endpoint=True)
    x_grid, y_grid = np.meshgrid(axis, axis)

    zern_out = np.zeros((N, N, num_coeff))
    for idx in range(num_coeff):
        zern_out[:, :, idx] = coeff[idx] * genZernPoly(idx + 1, x_grid, y_grid)

    return zern_out


def genZernPoly(index, x_grid, y_grid):
    """
    Evaluate the Zernike polynomial with Noll index `index` on the given grid.

    The radial part is scaled by sqrt(n + 1). For non-zero azimuthal order it
    is also multiplied by sqrt(2) * cos(m * theta) for even `index`, or
    sqrt(2) * sin(m * theta) for odd `index`, with theta = arctan2(y, x).
    :param index: Noll index
    :param x_grid: x coordinates
    :param y_grid: y coordinates
    :return: array of polynomial values
    """
    order, freq = nollToZernInd(index)
    scaled_radial = np.sqrt(order + 1) * radialZernike(x_grid, y_grid, (order, freq))
    freq = np.abs(freq)

    if freq == 0:
        return scaled_radial

    azimuth = np.arctan2(y_grid, x_grid)
    parity = np.mod(index, 2)
    if parity == 0:
        return np.multiply(scaled_radial, np.sqrt(2) * np.cos(freq * azimuth))
    if parity == 1:
        return np.multiply(scaled_radial, np.sqrt(2) * np.sin(freq * azimuth))

def radialZernike(x_grid, y_grid, z_ind):
    """
    Evaluate the radial Zernike polynomial R_n^|m| at rho = sqrt(x^2 + y^2).

    :param x_grid: x coordinates
    :param y_grid: y coordinates (same shape as x_grid)
    :param z_ind: (n, m) pair; only |m| is used
    :return: float array with the shape of rho
    :raises ValueError: if n - |m| is odd (no radial polynomial exists)
    """
    rho = np.sqrt(x_grid ** 2 + y_grid ** 2)
    radial = np.zeros(rho.shape)
    n = int(z_ind[0])
    m = int(np.abs(z_ind[1]))

    if (n - m) % 2:
        raise ValueError("n - |m| must be even, got n={}, m={}".format(n, m))

    # Integer halves: math.factorial rejects floats, and the np.math alias the
    # original relied on no longer exists in current NumPy.
    half_sum, half_diff = (n + m) // 2, (n - m) // 2

    for k in range(half_diff + 1):
        coeff = (-1) ** k * math.factorial(n - k) / (
            math.factorial(k) * math.factorial(half_sum - k) * math.factorial(half_diff - k))
        radial += coeff * rho ** (n - 2*k)

    return radial

In [28]:
# Wall-clock cost of ten back-to-back runs of the function-based z_space.
# perf_counter is monotonic, so the interval cannot be skewed by clock changes.
start = time.perf_counter()
for i in range(10):
    a = z_space(params)
print(time.perf_counter() - start)

  func1 = pickle.load(file_handle)
  func2 = pickle.load(file_handle)
  func1 = pickle.load(file_handle)


5.24432897567749


In [102]:
class z_space:
    """
    Class-based generator of the per-Zernike spatial correlation maps.

    The grids, scale factors and precomputed integral callables are prepared
    once in __init__, so repeated generate() calls do no file I/O.

    :param params: dict with keys 'width', 'Cn2', 'L', 'wvl', 'D', 'num_zern'
        and 'img_shape' (rows, cols). 'Cn2' may be a sequence/tensor or a
        single number.
    :param precompute_path: directory containing 'precomputed/fij_ints/'.
        Tables that do not exist are skipped at load time.
    """

    def __init__(self, params, precompute_path):
        self.width = params['width']
        self.Cn2 = params['Cn2']
        self.L = params['L']
        self.wvl = params['wvl']
        self.D = params['D']
        self.Z = params['num_zern']
        N1, N2 = params['img_shape'][1], params['img_shape'][0]  # cols, rows
        self.zern_space = torch.zeros(N2, N1, self.Z)
        dx = self.width/N1
        self.tilt_coeff_final = (self.wvl / (self.D / 2)) ** 2
        self.our_coeff_final = 0.0096932 * (2*np.pi/self.wvl)**2 * 2**(14/3) * \
                                np.pi**(8/3) * (self.D/2)**(5/3) / np.pi**2

        # A scalar Cn2 becomes a one-element tensor. A plain list of floats
        # made _get_r0 raise ZeroDivisionError (0.0 ** negative power) for a
        # single layer; the tensor form behaves like the tensor input path.
        if isinstance(self.Cn2, (int, float)):
            self.Cn2 = torch.tensor([float(self.Cn2)])
        self.M = len(self.Cn2)

        xx, yy = torch.meshgrid(torch.linspace(-N1/2, N1/2, N1)*dx,
                                torch.linspace(-N2/2, N2/2, N2)*dx, indexing='xy')
        self.s_arr = torch.sqrt(xx**2 + yy**2)

        s1, s2 = torch.arange(0, N1), torch.arange(0, N2)
        x, y = torch.meshgrid(s1, s2, indexing='xy')
        # Polar angle about the array centre: atan2(y offset, x offset), the
        # same argument order as fij(). The arguments were previously swapped.
        self.thetamat = torch.atan2((y - N2 / 2), (x - N1 / 2))

        self.func1_10, self.func2_10 = self._load_func(precompute_path, 10)
        self.func1_100, self.func2_100 = self._load_func(precompute_path, 100)

    def _load_func(self, path, smax):
        """
        Load the diagonal (i, i) integral callables for one smax value.

        NOTE: pickle.load can execute arbitrary code; `path` must be trusted.
        :return: (func1, func2) dicts keyed by Noll index; missing files are
            left out of the dicts.
        """
        func1, func2 = {}, {}
        for i in range(2, self.Z+1):
            for which, table in ((1, func1), (2, func2)):
                fpath = os.path.join(
                    path, 'precomputed/fij_ints/int{}_smax_{}_{}_{}'.format(which, smax, i, i))
                if os.path.exists(fpath):
                    # Context manager closes the handle (it was leaked before).
                    with open(fpath, 'rb') as file_handle:
                        table[i] = pickle.load(file_handle)
        return func1, func2

    def _fij(self, i, s_arr):
        """
        Evaluate the f_ii term for Noll index i on s_arr.

        :raises RuntimeError: if tak_indicator returns a case other than 1 or 5
        :raises KeyError: if the required precomputed table was not loaded
        """
        ni, mi = nollToZernInd(i)
        h = tak_indicator(i, i, mi, mi) # only 1 or 5 will be used here
        if h not in (1, 5):
            # Explicit exception of the type the old bare `raise` produced.
            raise RuntimeError("wrong h: {}".format(h))

        if i == 2 or i == 3:
            smax, table1, table2 = 100, self.func1_100, self.func2_100
        else:
            smax, table1, table2 = 10, self.func1_10, self.func2_10
        s_clipped = torch.clamp(s_arr, max=smax*0.9)

        if h == 1:
            func1, func2 = table1[i], table2[i]
            mult = 1 if i % 2 else -1
            # Simplified from the general (i, j) expression in fij(): with
            # i == j, nplus = 2*ni, mplus = 2*mi and mminus = 0.
            return mult * (-1)**ni * torch.cos(mi * 2 * self.thetamat) * func1(s_clipped) \
                        + (-1)**(ni + mi) * func2(s_clipped)

        func1 = table1[i]
        return (-1)**ni * torch.tensor(func1(s_clipped))

    def _corr_zern(self, i):
        """Compute the correlation map for Noll index i and store it in slice i-1 of zern_space."""
        ni, mi = nollToZernInd(i)

        c_out = torch.zeros_like(self.s_arr)
        for l in range(self.M):
            temp = self._fij(i, self.s_arr * (l+1) / (self.D*(self.M-l)))
            c_out += temp * self.Cn2[l] * ((self.M - l) / (self.M + 1)) ** (5 / 3)
        c_out = self.our_coeff_final * c_out * (self.L/self.M) * self.tilt_coeff_final * math.sqrt((ni+1) *(ni+1))
        self.zern_space[:,:,i-1] = c_out

    def _get_r0(self):
        """Return the accumulated Cn2 sum raised to the power -3/5."""
        r0 = 0
        for l in range(self.M):
            # NOTE(review): the weight (l/M) is zero for the first layer, so a
            # single-layer profile yields an infinite result. Confirm intent.
            r0 += (0.423 * (2*math.pi/self.wvl)**2 * self.Cn2[l] * self.L / self.M) * (l/self.M) ** (5/3)
        return r0**(-3/5)

    def generate(self, multithread=True):
        """
        Fill zern_space for Noll indices 2..Z and return it with r0.

        :param multithread: when True, each index is computed in its own
            thread; each thread writes a distinct slice of zern_space.
        :return: (zern_space, r0). zern_space is the instance's own tensor, so
            later generate() calls overwrite the same object.
        """
        if multithread:
            threads = list()
            for index in range(2, self.Z+1):
                x = threading.Thread(target=self._corr_zern, args=(index,))
                threads.append(x)
                x.start()
            for thread in threads:
                thread.join()
        else:
            for i in range(2, self.Z+1):
                self._corr_zern(i)
        return self.zern_space, self._get_r0()

In [104]:
# Build the class-based generator once (this loads the precomputed tables),
# then time ten generate() calls. perf_counter is monotonic, so the interval
# cannot be skewed by clock changes.
Zsp = z_space(params, PACKAGEDIR)
start = time.perf_counter()
for i in range(10):
    a = Zsp.generate()
print(time.perf_counter() - start)

  func1[i] = pickle.load(open(path_1, 'rb'))
  func2[i] = pickle.load(open(path_2, 'rb'))


2.954676866531372


In [None]:

# NOTE(review): scratch expression in an unexecuted cell. In the cells visible
# here `s_arr` and `D` exist only as function locals / instance attributes, so
# this would presumably raise NameError if run as-is - confirm and remove.
s_arr * 1/D