In [3]:
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.functional as F
#import tensorflow as tf
from Ctorch import *

# Standard data type: batches of shape NxCxHxWx2

In [None]:
def zero_pad(image, shape, position='corner'):
    """
    Extend an image to a given size by padding it with zeros.

    Parameters
    ----------
    image : real `numpy.ndarray`
        Input image. Usually 2-D, but any number of dimensions is accepted.
    shape : tuple of int
        Desired output shape; one entry per dimension of `image`.
    position : str, optional
        The position of the input image in the output one:
            * 'corner'
                top-left corner (default)
            * 'center'
                centered
        Any value other than 'center' is treated as 'corner'.

    Returns
    -------
    padded_img : real `numpy.ndarray`
        The zero-padded image, with the dtype of `image`. When `image`
        already has the requested shape it is returned as-is (no copy).

    Raises
    ------
    ValueError
        If `shape` does not have one entry per dimension of `image`,
        contains a null or negative entry, is smaller than the image
        along any axis, or (for 'center') differs from the image shape
        by an odd amount along any axis.

    """
    shape = np.asarray(shape, dtype=int)
    imshape = np.asarray(image.shape, dtype=int)

    # Compare dimension counts first so the element-wise checks below are
    # always well defined instead of relying on broadcasting.
    if shape.ndim != 1 or shape.size != imshape.size:
        raise ValueError("ZERO_PAD: image and target shape have a "
                         "different number of dimensions")

    # np.all instead of np.alltrue: the latter was removed in NumPy 2.0.
    if np.all(imshape == shape):
        return image

    if np.any(shape <= 0):
        raise ValueError("ZERO_PAD: null or negative shape given")

    dshape = shape - imshape
    if np.any(dshape < 0):
        raise ValueError("ZERO_PAD: target size smaller than source one")

    if position == 'center':
        if np.any(dshape % 2 != 0):
            raise ValueError("ZERO_PAD: source and target shapes "
                             "have different parity.")
        offsets = dshape // 2
    else:
        offsets = np.zeros_like(dshape)

    pad_img = np.zeros(tuple(int(s) for s in shape), dtype=image.dtype)

    # One slice per axis selects the destination window; this avoids
    # building full index arrays and works for any number of dimensions.
    window = tuple(slice(int(off), int(off) + int(size))
                   for off, size in zip(offsets, imshape))
    pad_img[window] = image

    return pad_img

def psf2otf(psf, shape=None):
    """
    Convert point-spread function to optical transfer function.

    Compute the Fast Fourier Transform (FFT) of the point-spread
    function (PSF) array and create the optical transfer function (OTF)
    array that is not influenced by the PSF off-centering.

    To ensure that the OTF is not altered due to PSF off-centering, the
    PSF array is post-padded (down and to the right) with zeros to match
    the dimensions given in `shape`, then circularly shifted up and to
    the left until its central pixel reaches the [0, 0] position.

    Adapted from MATLAB psf2otf function

    Parameters
    ----------
    psf : `numpy.ndarray`
        Point-spread function.
    shape : tuple of int, optional
        Shape of the output OTF. Defaults to the shape of `psf`.

    Returns
    -------
    otf : `numpy.ndarray`
        The optical transfer function with shape `shape`. It is complex,
        unless its imaginary part is within round-off error of zero, in
        which case only the real part is returned. An all-zero PSF yields
        an all-zero array of shape `shape` with the dtype of `psf`.

    """
    if shape is None:
        shape = psf.shape

    if np.all(psf == 0):
        # The result must have the requested output size, not the PSF
        # size, so that it matches what the non-zero path returns.
        return np.zeros(shape, dtype=psf.dtype)

    inshape = psf.shape
    # Pad the PSF to the output size
    psf = zero_pad(psf, shape, position='corner')

    # Circularly shift the padded PSF so that the 'center' of the
    # original PSF becomes the [0, 0] element of the array
    psf = np.roll(psf,
                  shift=tuple(-(axis_size // 2) for axis_size in inshape),
                  axis=tuple(range(len(inshape))))

    # Compute the OTF
    otf = np.fft.fft2(psf)

    # Estimate the rough number of operations involved in the FFT
    # and discard the OTF imaginary part if within roundoff error
    # (tol is expressed in multiples of the machine epsilon)
    n_ops = np.sum(psf.size * np.log2(psf.shape))
    otf = np.real_if_close(otf, tol=n_ops)

    return otf

In [9]:
def psf2otf(psf, img_shape):
    """
    TensorFlow variant: turn a point-spread function tensor into its OTF.

    The PSF is cut into four quadrants at floor(size / 2) along its first
    two axes. The quadrants are re-assembled with zero filler in between so
    that the bottom-right quadrant ends up in the top-left corner and the
    top-left quadrant in the bottom-right corner of a tensor whose first
    two axes have size img_shape[0] x img_shape[1]. The result is
    transposed to put those two axes last, cast to complex64 and passed
    through a 2-D FFT.

    Parameters
    ----------
    psf : 4-D tensor
        Indexed along four axes; per the inline comments the layout is
        [height, width, channels_in, channels_out].
    img_shape : sequence
        Only img_shape[0] and img_shape[1] are read (target height, width).

    Returns
    -------
    otf : complex64 tensor
        Per the inline comments, shaped
        [channels_in, channels_out, img_shape[0], img_shape[1]].

    Notes
    -----
    NOTE(review): this definition reuses the name `psf2otf`, so running
    this cell replaces the NumPy `psf2otf(psf, shape)` defined earlier in
    the notebook.
    NOTE(review): `tf` is not imported by the notebook's visible import
    cell (the tensorflow import there is commented out); unless it is
    provided some other way, calling this function raises NameError.
    NOTE(review): `tf.floor_div` and `tf.fft2d` look like TensorFlow 1.x
    names -- confirm the targeted TensorFlow version.
    """
    # shape and type of the point spread function(s)
    psf_shape = psf.shape;
    psf_type = psf.dtype

    # coordinates for 'cutting up' the psf tensor
    midH = tf.floor_div(psf_shape[0], 2)
    midW = tf.floor_div(psf_shape[1], 2)

    # slice the psf tensor into four parts
    top_left     = psf[:midH, :midW, :, :]
    top_right    = psf[:midH, midW:, :, :]
    bottom_left  = psf[midH:, :midW, :, :]
    bottom_right = psf[midH:, midW:, :, :]

    # prepare zeros for filler
    # (width of both fillers is the horizontal gap img_shape[1] - psf_shape[1])
    zeros_bottom = tf.zeros([psf_shape[0] - midH, img_shape[1] - psf_shape[1], psf_shape[2], psf_shape[3]], dtype=psf_type)
    zeros_top    = tf.zeros([midH, img_shape[1] - psf_shape[1], psf_shape[2], psf_shape[3]], dtype=psf_type)

    # construct top and bottom row of new tensor
    # (the quadrants are swapped diagonally, which wraps the PSF center
    # around to index [0, 0])
    top    = tf.concat([bottom_right, zeros_bottom, bottom_left], 1)
    bottom = tf.concat([top_right,    zeros_top,    top_left],    1)

    # prepare additional filler zeros and put everything together
    zeros_mid = tf.zeros([img_shape[0] - psf_shape[0], img_shape[1], psf_shape[2], psf_shape[3]], dtype=psf_type)
    pre_otf = tf.concat([top, zeros_mid, bottom], 0)
    # output shape: [img_shape[0], img_shape[1], channels_in, channels_out]

    # fast fourier transform, transposed because tensor must have shape [..., height, width] for this
    otf = tf.fft2d(tf.cast(tf.transpose(pre_otf, perm=[2,3,0,1]), tf.complex64))

    # output shape: [channels_in, channels_out, img_shape[0], img_shape[1]]
    return otf