In [None]:
import dtcwt
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
%matplotlib notebook
sns.set_style("white")

from dtcwt.coeffs import biort as _biort, qshift as _qshift
from dtcwt.defaults import DEFAULT_BIORT, DEFAULT_QSHIFT
from dtcwt.utils import appropriate_complex_type_for, asfarray

from dtcwt.numpy.lowlevel import colfilter as colf, coldfilt as cold, colifilt as coli

In [None]:
# Load the test image. The ``with`` block closes the underlying .npz archive
# once the array has been read out of it.
with np.load(os.path.join('tests', 'mandrill.npz')) as mandrill_npz:
    im = mandrill_npz['mandrill']
plt.imshow(im, cmap='gray', interpolation='none');

In [None]:
# Start from a clean default graph so that re-running the notebook does not
# accumulate duplicate ops/variables. Then show the graph's global variables,
# which should be empty right after the reset.
tf.reset_default_graph()
g = tf.get_default_graph()
g.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

In [None]:
class Pyramid(object):
    """Transform-domain representation of a signal, held in TensorFlow variables.

    Follows the ``dtcwt.Pyramid`` interface: backends may store transform-domain
    signals in any class that exposes these attributes, and an inverse transform
    should accept any such object.

    .. py:attribute:: lowpass
        Non-trainable float32 ``tf.Variable`` with the coarsest-scale lowpass signal.
    .. py:attribute:: highpasses
        Tuple (finest to coarsest scale) of non-trainable complex64
        ``tf.Variable`` subband coefficients. Entries passed as ``None`` stay ``None``.
    .. py:attribute:: scales
        *(optional)* Tuple (finest to coarsest scale) of non-trainable float32
        ``tf.Variable`` lowpass signals, or ``None`` when not supplied. The
        inverse does not need it.

    NOTE(review): because the values are wrapped in ``tf.Variable``, they are
    only available after the variables' initializer has run (with any
    placeholders they depend on fed).
    """
    def __init__(self, lowpass, highpasses, scales=None):
        def frozen(value, dtype):
            # A non-trainable variable of the requested dtype.
            return tf.Variable(value, trainable=False, dtype=dtype)

        self.lowpass = frozen(lowpass, tf.float32)
        self.highpasses = tuple(
            None if band is None else frozen(band, tf.complex64)
            for band in highpasses)
        if scales is None:
            self.scales = None
        else:
            self.scales = tuple(frozen(level, tf.float32) for level in scales)

In [None]:
from __future__ import absolute_import, division

# NOTE(review): ``__all__`` only affects ``from <module> import *``, so it has no
# effect inside a notebook cell. Presumably it was carried over from the module
# this cell was copied from (dtcwt.numpy.lowlevel?); confirm it is still wanted.
__all__ = [ 'colfilter', 'colifilt', 'coldfilt', ]

import numpy as np
from six.moves import xrange
from dtcwt.utils import as_column_vector, asfarray, appropriate_complex_type_for, reflect

def _centered(arr, newsize):
    """Return the central window of *arr* with shape *newsize* (adapted from scipy).

    Along each axis the window starts ``(current - new) // 2`` elements in. Any
    odd surplus is therefore trimmed from the end rather than the start.
    """
    target = np.asanyarray(newsize)
    starts = (np.array(arr.shape) - target) // 2
    stops = starts + target
    window = tuple(slice(lo, hi) for lo, hi in zip(starts, stops))
    return arr[window]

# This is to allow easy replacement of these later with, possibly, GPU versions
# Module-level aliases for the real FFT pair. A faster (possibly GPU) backend
# can later be swapped in by rebinding just these two names.
_rfft, _irfft = np.fft.rfft, np.fft.irfft

def _column_convolve(X, h):
    """Convolve every column of *X* with the 1-D filter *h*, keeping only the 'valid' rows.

    The result has ``abs(X.shape[0] - len(h)) + 1`` rows: the rows that do not
    depend on implicit zero padding. *h* is flattened and cast to ``X.dtype``
    before use, so the output dtype is always that of *X*. The filter is
    assumed to be short, so the convolution is done directly (one shifted,
    scaled copy of *X* per tap) rather than through an FFT.
    """
    n_rows = X.shape[0]
    taps = h.flatten().astype(X.dtype)
    n_taps = taps.shape[0]

    full_shape = np.asanyarray(X.shape)
    full_shape[0] = n_rows + n_taps - 1
    acc = np.zeros(full_shape, dtype=X.dtype)
    for offset, tap in enumerate(taps):
        acc[offset:(offset + n_rows), ...] += X * tap

    valid_shape = full_shape.copy()
    valid_shape[0] = abs(n_rows - n_taps) + 1
    return _centered(acc, valid_shape)

def colfilter2(X, h):
    """Filter the columns of image *X* using filter vector *h*, without decimation.

    If len(h) is odd, each output sample is aligned with each input sample
    and *Y* is the same size as *X*.  If len(h) is even, each output sample is
    aligned with the mid point of each pair of input samples, and Y.shape =
    X.shape + [1 0].

    :param X: a 2-D image whose columns are to be filtered
    :param h: the filter coefficients.
    :returns Y: the filtered image.

    .. codeauthor:: Rich Wareham <rjw57@cantab.net>, August 2013
    .. codeauthor:: Cian Shaffrey, Cambridge University, August 2000
    .. codeauthor:: Nick Kingsbury, Cambridge University, August 2000
    """
    # Interpret all inputs as arrays
    X = asfarray(X)
    h = as_column_vector(h)

    r, _ = X.shape          # unpacking also enforces a 2-D input
    m = h.shape[0]
    m2 = m // 2             # integer half filter length

    # Symmetrically extend with repeat of end samples.
    # Use 'reflect' so r < m2 works OK.
    # (builtin ``int`` instead of the ``np.int`` alias removed in NumPy 1.24)
    xe = reflect(np.arange(-m2, r + m2, dtype=int), -0.5, r - 0.5)

    # Perform filtering on the columns of the extended matrix X(xe,:), keeping
    # only the 'valid' output samples, so Y is the same size as X if m is odd.
    Y = _column_convolve(X[xe, :], h)

    # Return the filtered image (previously the padded input was returned by mistake).
    return Y

In [None]:
def colfilter(X, h):
    """Filter the columns of a batch of images *X* using filter vector *h*, without decimation.

    Filtering runs along axis 1 (the row index) of ``X``, so each image column
    is filtered. If len(h) is odd, each output sample is aligned with each
    input sample and *Y* is the same size as *X*.  If len(h) is even, each
    output sample is aligned with the mid point of each pair of input samples,
    and Y.shape = X.shape + [0 1 0] (one extra row per image).

    NOTE: ``tf.nn.conv2d`` computes a cross-correlation. *h* must already be
    time-reversed if a true convolution is wanted (this notebook reverses the
    qshift filters before calling; symmetric filters need no reversal).

    :param X: a tensor of shape [batch, rows, cols] whose columns are to be filtered
    :param h: a 1-D (or [n, 1]) filter tensor with a statically known length n,
        of the same dtype as *X*
    :returns Y: the filtered images, shape [batch, rows', cols]

    .. codeauthor:: Rich Wareham <rjw57@cantab.net>, August 2013
    .. codeauthor:: Cian Shaffrey, Cambridge University, August 2000
    .. codeauthor:: Nick Kingsbury, Cambridge University, August 2000
    """
    m = h.get_shape().as_list()[0]
    m2 = m // 2

    # Symmetrically extend each column with repeated end samples: pad only
    # axis 1 (the rows) of the [batch, rows, cols] tensor.
    X = tf.pad(X, [[0, 0], [m2, m2], [0, 0]], 'SYMMETRIC')

    # conv2d needs X as [batch, rows, cols, in_channels] and h as
    # [f_rows, f_cols, in_channels, out_channels]. expand_dims works even
    # when the spatial sizes are not statically known.
    h = tf.reshape(h, [-1, 1, 1, 1])
    X = tf.expand_dims(X, axis=-1)

    y = tf.nn.conv2d(X, h, strides=[1, 1, 1, 1], padding='VALID')
    # Drop the singleton channel dimension.
    return tf.squeeze(y, axis=[3])

def rowfilter(X, h):
    """Filter the rows of a batch of images *X* using filter vector *h*, without decimation.

    Filtering runs along axis 2 (the column index) of ``X``, so each image row
    is filtered. If len(h) is odd, each output sample is aligned with each
    input sample and *Y* is the same size as *X*.  If len(h) is even, each
    output sample is aligned with the mid point of each pair of input samples,
    and Y.shape = X.shape + [0 0 1] (one extra column per image).

    NOTE: ``tf.nn.conv2d`` computes a cross-correlation, so *h* must already be
    time-reversed if a true convolution is wanted.

    :param X: a tensor of shape [batch, rows, cols] whose rows are to be filtered
    :param h: a 1-D (or [n, 1]) filter tensor with a statically known length n,
        of the same dtype as *X*
    :returns Y: the filtered images, shape [batch, rows, cols']

    .. codeauthor:: Rich Wareham <rjw57@cantab.net>, August 2013
    .. codeauthor:: Cian Shaffrey, Cambridge University, August 2000
    .. codeauthor:: Nick Kingsbury, Cambridge University, August 2000
    """
    m = h.get_shape().as_list()[0]
    m2 = m // 2

    # Symmetrically extend each row with repeated end samples: pad only
    # axis 2 (the columns) of the [batch, rows, cols] tensor.
    X = tf.pad(X, [[0, 0], [0, 0], [m2, m2]], 'SYMMETRIC')

    # conv2d needs X as [batch, rows, cols, in_channels] and h as
    # [f_rows, f_cols, in_channels, out_channels]; here the filter lies along f_cols.
    h = tf.reshape(h, [1, -1, 1, 1])
    X = tf.expand_dims(X, axis=-1)

    y = tf.nn.conv2d(X, h, strides=[1, 1, 1, 1], padding='VALID')
    # Drop the singleton channel dimension.
    return tf.squeeze(y, axis=[3])

In [None]:
# NOTE(review): ``Transform2d`` is defined in a later cell. This cell therefore
# only works after that cell has been run, and it fails on Restart & Run All.
f = Transform2d()
# Time-reversed qshift[0] as a float32 constant. conv2d cross-correlates, so the
# reversal makes the TF filters compute a true convolution. Despite the name
# ``h1o``, this is qshift[0] (unpacked as h0a in Transform2d.forward).
h1o = tf.constant(f.qshift[0][::-1], dtype=tf.float32)
# Batch of 512x512 images, fed at run time.
i = tf.placeholder(tf.float32, shape=[None, 512, 512])

In [None]:
# Compare both TensorFlow filters with the NumPy reference ``colf``.
# colfilter(i) filters the columns of i directly. rowfilter applied to the
# transposed image filters those same columns, so its output is the
# transpose of the reference.
im_hat = colf(im, f.qshift[0].astype('float32'))
y1 = colfilter(i, h1o)
y2 = rowfilter(tf.transpose(i, perm=[0, 2, 1]), h1o)
with tf.Session() as sess:
    # Take the single image out of each batched result.
    im_hat1, im_hat2 = (k[0] for k in sess.run([y1, y2], feed_dict={i: [im]}))

np.testing.assert_array_almost_equal(im_hat, im_hat1, decimal=4)
np.testing.assert_array_almost_equal(im_hat, im_hat2.T, decimal=4)

In [None]:
# Side-by-side view of the two TensorFlow results. ``im_hat2`` was computed on
# the transposed image, so transpose it back to show both in the same orientation.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
axes[0].imshow(im_hat1, cmap='gray', interpolation='none')
axes[0].set_title('colfilter(X)')
axes[1].imshow(im_hat2.T, cmap='gray', interpolation='none')
axes[1].set_title('rowfilter(X.T).T')
fig.tight_layout()

In [None]:
def q2c(y):
    """
    Convert from quads in y to complex numbers in z.

    The quads are taken from the last two axes of *y*, so any leading (e.g.
    batch) axes are carried through unchanged. For each 2x2 quad

        a----b
        |    |
        |    |
        c----d

    two complex sub-images are formed, p = (a + jb)/sqrt(2) and
    q = (d - jc)/sqrt(2). They are returned as the highpass pair (p - q, p + q).

    :param y: real tensor of shape [..., rows, cols] (rows/cols presumably even)
    :returns: tuple (p - q, p + q) of complex tensors of shape [..., rows/2, cols/2]
    """
    # Ellipsis indexing slices only the two trailing (spatial) axes, so both a
    # single [rows, cols] image and a [batch, rows, cols] batch work.
    a = y[..., 0::2, 0::2]
    b = y[..., 0::2, 1::2]
    c = y[..., 1::2, 0::2]
    d = y[..., 1::2, 1::2]

    # A plain Python float combines with a tensor of any float dtype.
    root2 = float(np.sqrt(2))
    p = tf.complex(a / root2, b / root2)     # p = (a + jb) / sqrt(2)
    q = tf.complex(d / root2, -c / root2)    # q = (d - jc) / sqrt(2)

    # Form the 2 highpasses in z.
    return (p - q, p + q)

In [None]:
class Transform2d(object):
    """
    An implementation of the 2D DT-CWT via TensorFlow. *biort* and *qshift* are the
    wavelets which parameterise the transform.
    If *biort* or *qshift* are strings, they are used as an argument to the
    :py:func:`dtcwt.coeffs.biort` or :py:func:`dtcwt.coeffs.qshift` functions.
    Otherwise, they are interpreted as tuples of vectors giving filter
    coefficients. In the *biort* case, this should be (h0o, g0o, h1o, g1o). In
    the *qshift* case, this should be (h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b).

    NOTE: only the first (biorthogonal) level of :py:meth:`forward` is
    implemented so far.
    """
    # Filter names, in the order they appear in the coefficient tuples.
    # (g2a/g2b of a 12-component qshift are not needed by forward().)
    _BIORT_NAMES = ('h0o', 'g0o', 'h1o', 'g1o', 'h2o', 'g2o')
    _QSHIFT_NAMES = ('h0a', 'h0b', 'g0a', 'g0b', 'h1a', 'h1b', 'g1a', 'g1b',
                     'h2a', 'h2b')

    def __init__(self, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT):
        # Load bi-orthogonal wavelets
        try:
            self.biort = _biort(biort)
        except TypeError:
            self.biort = biort

        # Load quarter sample shift wavelets
        try:
            self.qshift = _qshift(qshift)
        except TypeError:
            self.qshift = qshift

    @staticmethod
    def _make_filters(coeffs, names, reverse):
        """Wrap *coeffs* as non-trainable float32 variables named ``dtcwt/<name>``.

        One variable is created per entry of *names* (extra coefficients are
        ignored). When *reverse* is true each filter is time-reversed first.
        Returns a dict mapping name -> variable.
        """
        filters = {}
        for name, coeff in zip(names, coeffs):
            values = coeff[::-1] if reverse else coeff
            # float32 so the filters match the float32 input images:
            # tf.nn.conv2d does not mix dtypes.
            filters[name] = tf.Variable(np.asarray(values, dtype=np.float32),
                                        trainable=False, name='dtcwt/' + name)
        return filters

    @staticmethod
    def _q2c_batch(y):
        """Apply :py:func:`q2c` to each [rows, cols] image of a [batch, rows, cols] tensor."""
        return tf.map_fn(q2c, y, dtype=(tf.complex64, tf.complex64))

    def forward(self, X, nlevels=3, include_scale=False):
        """Perform a *n*-level DTCWT-2D decomposition on a batch of images *X*.

        :param X: float32 tensor of shape [batch, rows, cols] with static rows/cols
        :param nlevels: Number of levels of wavelet decomposition
        :param include_scale: if true, also return the lowpass image of each computed level
        :returns: A :py:class:`Pyramid` compatible object representing the transform-domain signal
        :raises ValueError: if *nlevels* is negative or the wavelets have an
            unsupported number of components

        NOTE(review): only level 1 is implemented. For ``nlevels > 1`` the
        highpasses of levels 2..nlevels are ``None`` and the lowpass is the
        level-1 lowpass.

        .. codeauthor:: Rich Wareham <rjw57@cantab.net>, Aug 2013
        .. codeauthor:: Nick Kingsbury, Cambridge University, Sept 2001
        .. codeauthor:: Cian Shaffrey, Cambridge University, Sept 2001
        """
        if nlevels < 0:
            raise ValueError('nlevels must be non-negative.')

        # If biort has 6 elements instead of 4, then it's a modified
        # rotationally symmetric wavelet. Biorthogonal filters are symmetric,
        # so they need no reversal for conv2d's cross-correlation.
        if len(self.biort) == 4:
            biort_names = self._BIORT_NAMES[:4]
        elif len(self.biort) == 6:
            biort_names = self._BIORT_NAMES
        else:
            raise ValueError('Biort wavelet must have 6 or 4 components.')
        biort = self._make_filters(self.biort, biort_names, reverse=False)

        # If qshift has 12 elements instead of 8, then it's a modified
        # rotationally symmetric wavelet. tf.nn.conv2d is really
        # cross-correlation, so the (asymmetric) qshift filters are reversed.
        if len(self.qshift) == 8:
            qshift_names = self._QSHIFT_NAMES[:8]
        elif len(self.qshift) == 12:
            qshift_names = self._QSHIFT_NAMES
        else:
            raise ValueError('Qshift wavelet must have 12 or 8 components.')
        # NOTE(review): unused until levels >= 2 are implemented.
        qshift = self._make_filters(self.qshift, qshift_names, reverse=True)

        h0o, h1o = biort['h0o'], biort['h1o']
        h2o = biort.get('h2o')   # only present for the 6-component biort

        # Static spatial size of the input images.
        original_size = X.get_shape().as_list()[1:3]

        # If the image has an odd size, duplicate the bottom row and/or the
        # rightmost column so that both dimensions are even.
        if original_size[0] % 2 != 0:
            bottom_row = tf.slice(X, [0, original_size[0] - 1, 0], [-1, 1, -1])
            X = tf.concat([X, bottom_row], axis=1)
        if original_size[1] % 2 != 0:
            right_column = tf.slice(X, [0, 0, original_size[1] - 1], [-1, -1, 1])
            X = tf.concat([X, right_column], axis=2)

        if nlevels == 0:
            if include_scale:
                return Pyramid(X, (), ())
            else:
                return Pyramid(X, ())

        Yh = [None] * nlevels
        Yscale = [None] * nlevels if include_scale else None

        def filt_then_swap(Z, h):
            # Filter the columns of Z, then swap rows/cols so a second call
            # filters what were originally the rows.
            return tf.transpose(colfilter(Z, h), perm=[0, 2, 1])

        # Level 1: odd top-level filters on columns...
        Lo = filt_then_swap(X, h0o)
        Hi = filt_then_swap(X, h1o)
        if h2o is not None:
            Ba = filt_then_swap(X, h2o)

        # ...then on rows.
        LoLo = filt_then_swap(Lo, h0o)
        horiz = self._q2c_batch(filt_then_swap(Hi, h0o))
        vert = self._q2c_batch(filt_then_swap(Lo, h1o))
        if h2o is not None:
            diag = self._q2c_batch(filt_then_swap(Ba, h2o))
        else:
            diag = self._q2c_batch(filt_then_swap(Hi, h1o))

        # Subbands along the last axis, in the same order as the NumPy
        # reference code: 0/5 horizontal, 1/4 diagonal, 2/3 vertical pairs.
        Yh[0] = tf.stack([horiz[0], diag[0], vert[0], vert[1], diag[1], horiz[1]],
                         axis=-1)
        if include_scale:
            Yscale[0] = LoLo

        # TODO: levels 2..nlevels (even qshift filters) are not implemented yet.

        if include_scale:
            # Pyramid cannot wrap None scales, so pass only the computed ones.
            return Pyramid(LoLo, Yh, [s for s in Yscale if s is not None])
        else:
            return Pyramid(LoLo, Yh, ())
        

In [None]:
# Build the forward-transform graph for the placeholder batch ``i``
# (explicitly spelling out forward()'s default arguments).
P = f.forward(i, nlevels=3, include_scale=False)

In [None]:
# Evaluate the pyramid's lowpass for the test image and show its shape.
# A fresh session is needed because the one from the comparison cell was
# closed by its ``with`` block. The Pyramid's variables must also be
# initialised; their initial values depend on the placeholder ``i``, so the
# feed is needed for the initializer too. ``im`` is left untouched.
feed = {i: [im]}
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer(), feed_dict=feed)
    lowpass = sess.run(P.lowpass, feed_dict=feed)
lowpass.shape

In [None]:
      # NOTE(review): this cell was a mis-indented fragment of the NumPy forward()
      # loop for levels 2..nlevels. It raised IndentationError, used undefined
      # names and had a 'return' outside a function. It is kept commented out as
      # a porting reference until those levels are implemented in TensorFlow.
      # for level in xrange(1, nlevels):
      #       row_size, col_size = LoLo.shape
      #       if row_size % 4 != 0:
      #           # Extend by 2 rows if no. of rows of LoLo are not divisable by 4
      #           LoLo = np.vstack((LoLo[:1,:], LoLo, LoLo[-1:,:]))
      #
      #       if col_size % 4 != 0:
      #           # Extend by 2 cols if no. of cols of LoLo are not divisable by 4
      #           LoLo = np.hstack((LoLo[:,:1], LoLo, LoLo[:,-1:]))
      #
      #       # Do even Qshift filters on rows.
      #       Lo = coldfilt(LoLo,h0b,h0a).T
      #       Hi = coldfilt(LoLo,h1b,h1a).T
      #       if len(self.qshift) >= 12:
      #           Ba = coldfilt(LoLo,h2b,h2a).T
      #
      #       # Do even Qshift filters on columns.
      #       LoLo = coldfilt(Lo,h0b,h0a).T
      #
      #       Yh[level] = np.zeros((LoLo.shape[0]>>1, LoLo.shape[1]>>1, 6), dtype=complex_dtype)
      #       Yh[level][:,:,0:6:5] = q2c(coldfilt(Hi,h0b,h0a).T)  # Horizontal
      #       Yh[level][:,:,2:4:1] = q2c(coldfilt(Lo,h1b,h1a).T)  # Vertical
      #       if len(self.qshift) >= 12:
      #           Yh[level][:,:,1:5:3] = q2c(coldfilt(Ba,h2b,h2a).T)  # Diagonal
      #       else:
      #           Yh[level][:,:,1:5:3] = q2c(coldfilt(Hi,h1b,h1a).T)  # Diagonal
      #
      #       if include_scale:
      #           Yscale[level] = LoLo
      #
      #   return X

In [None]:
"""
        if len(X.shape) >= 3:
            raise ValueError('The entered image is {0}, please enter each image slice separately.'.
                    format('x'.join(list(str(s) for s in X.shape))))

        
        for level in xrange(1, nlevels):
            row_size, col_size = LoLo.shape
            if row_size % 4 != 0:
                # Extend by 2 rows if no. of rows of LoLo are not divisable by 4
                LoLo = np.vstack((LoLo[:1,:], LoLo, LoLo[-1:,:]))

            if col_size % 4 != 0:
                # Extend by 2 cols if no. of cols of LoLo are not divisable by 4
                LoLo = np.hstack((LoLo[:,:1], LoLo, LoLo[:,-1:]))

            # Do even Qshift filters on rows.
            Lo = coldfilt(LoLo,h0b,h0a).T
            Hi = coldfilt(LoLo,h1b,h1a).T
            if len(self.qshift) >= 12:
                Ba = coldfilt(LoLo,h2b,h2a).T

            # Do even Qshift filters on columns.
            LoLo = coldfilt(Lo,h0b,h0a).T

            Yh[level] = np.zeros((LoLo.shape[0]>>1, LoLo.shape[1]>>1, 6), dtype=complex_dtype)
            Yh[level][:,:,0:6:5] = q2c(coldfilt(Hi,h0b,h0a).T)  # Horizontal
            Yh[level][:,:,2:4:1] = q2c(coldfilt(Lo,h1b,h1a).T)  # Vertical
            if len(self.qshift) >= 12:
                Yh[level][:,:,1:5:3] = q2c(coldfilt(Ba,h2b,h2a).T)  # Diagonal   
            else:
                Yh[level][:,:,1:5:3] = q2c(coldfilt(Hi,h1b,h1a).T)  # Diagonal   

            if include_scale:
                Yscale[level] = LoLo

        Yl = LoLo

        if initial_row_extend == 1 and initial_col_extend == 1:
            logging.warn('The image entered is now a {0} NOT a {1}.'.format(
                'x'.join(list(str(s) for s in extended_size)),
                'x'.join(list(str(s) for s in original_size))))
            logging.warn(
                'The bottom row and rightmost column have been duplicated, prior to decomposition.')

        if initial_row_extend == 1 and initial_col_extend == 0:
            logging.warn('The image entered is now a {0} NOT a {1}.'.format(
                'x'.join(list(str(s) for s in extended_size)),
                'x'.join(list(str(s) for s in original_size))))
            logging.warn(
                'The bottom row has been duplicated, prior to decomposition.')

        if initial_row_extend == 0 and initial_col_extend == 1:
            logging.warn('The image entered is now a {0} NOT a {1}.'.format(
                'x'.join(list(str(s) for s in extended_size)),
                'x'.join(list(str(s) for s in original_size))))
            logging.warn(
                'The rightmost column has been duplicated, prior to decomposition.')

        if include_scale:
            return Pyramid(Yl, tuple(Yh), tuple(Yscale))
        else:
            return Pyramid(Yl, tuple(Yh))
"""

# TODO: port ``Transform2d.inverse`` to TensorFlow. Its NumPy reference body is
# kept in the string literal that follows. The method header and docstring are
# kept here as comments. As a bare, indented docstring at the top level of a
# cell, they raised an IndentationError that stopped this whole cell (including
# ``c2q`` below) from running.
#
#    def inverse(self, pyramid, gain_mask=None):
#        """Perform an *n*-level dual-tree complex wavelet (DTCWT) 2D
#        reconstruction.
#        :param pyramid: A :py:class:`dtcwt.Pyramid`-like class holding the transform domain representation to invert.
#        :param gain_mask: Gain to be applied to each subband.
#        :returns: A numpy-array compatible instance with the reconstruction.
#        The (*d*, *l*)-th element of *gain_mask* is gain for subband with direction
#        *d* at level *l*. If gain_mask[d,l] == 0, no computation is performed for
#        band (d,l). Default *gain_mask* is all ones. Note that both *d* and *l* are
#        zero-indexed.
#        .. codeauthor:: Rich Wareham <rjw57@cantab.net>, Aug 2013
#        .. codeauthor:: Nick Kingsbury, Cambridge University, May 2002
#        .. codeauthor:: Cian Shaffrey, Cambridge University, May 2002
#        """
"""
        Yl = pyramid.lowpass
        Yh = pyramid.highpasses

        a = len(Yh) # No of levels.

        if gain_mask is None:
            gain_mask = np.ones((6,a)) # Default gain_mask.

        gain_mask = np.array(gain_mask)

        # If biort has 6 elements instead of 4, then it's a modified
        # rotationally symmetric wavelet
        # FIXME: there's probably a nicer way to do this
        if len(self.biort) == 4:
            h0o, g0o, h1o, g1o = self.biort
        elif len(self.biort) == 6:
            h0o, g0o, h1o, g1o, h2o, g2o = self.biort
        else:
            raise ValueError('Biort wavelet must have 6 or 4 components.')

        # If qshift has 12 elements instead of 8, then it's a modified
        # rotationally symmetric wavelet
        # FIXME: there's probably a nicer way to do this
        if len(self.qshift) == 8:
            h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = self.qshift
        elif len(self.qshift) == 12:
            h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b, h2a, h2b, g2a, g2b = self.qshift
        else:
            raise ValueError('Qshift wavelet must have 12 or 8 components.')

        current_level = a
        Z = Yl

        while current_level >= 2: # this ensures that for level 1 we never do the following
            lh = c2q(Yh[current_level-1][:,:,[0, 5]], gain_mask[[0, 5], current_level-1])
            hl = c2q(Yh[current_level-1][:,:,[2, 3]], gain_mask[[2, 3], current_level-1])
            hh = c2q(Yh[current_level-1][:,:,[1, 4]], gain_mask[[1, 4], current_level-1])

            # Do even Qshift filters on columns.
            y1 = colifilt(Z,g0b,g0a) + colifilt(lh,g1b,g1a)

            if len(self.qshift) >= 12:
                y2 = colifilt(hl,g0b,g0a)
                y2bp = colifilt(hh,g2b,g2a)

                # Do even Qshift filters on rows.
                Z = (colifilt(y1.T,g0b,g0a) + colifilt(y2.T,g1b,g1a) + colifilt(y2bp.T, g2b, g2a)).T
            else:
                y2 = colifilt(hl,g0b,g0a) + colifilt(hh,g1b,g1a)

                # Do even Qshift filters on rows.
                Z = (colifilt(y1.T,g0b,g0a) + colifilt(y2.T,g1b,g1a)).T

            # Check size of Z and crop as required
            [row_size, col_size] = Z.shape
            S = 2*np.array(Yh[current_level-2].shape)
            if row_size != S[0]:    # check to see if this result needs to be cropped for the rows
                Z = Z[1:-1,:]
            if col_size != S[1]:    # check to see if this result needs to be cropped for the cols
                Z = Z[:,1:-1]

            if np.any(np.array(Z.shape) != S[:2]):
                raise ValueError('Sizes of highpasses are not valid for DTWAVEIFM2')
            
            current_level = current_level - 1

        if current_level == 1:
            lh = c2q(Yh[current_level-1][:,:,[0, 5]],gain_mask[[0, 5],current_level-1])
            hl = c2q(Yh[current_level-1][:,:,[2, 3]],gain_mask[[2, 3],current_level-1])
            hh = c2q(Yh[current_level-1][:,:,[1, 4]],gain_mask[[1, 4],current_level-1])

            # Do odd top-level filters on columns.
            y1 = colfilter(Z,g0o) + colfilter(lh,g1o)

            if len(self.biort) >= 6:
                y2 = colfilter(hl,g0o)
                y2bp = colfilter(hh,g2o)

                # Do odd top-level filters on rows.
                Z = (colfilter(y1.T,g0o) + colfilter(y2.T,g1o) + colfilter(y2bp.T, g2o)).T
            else:
                y2 = colfilter(hl,g0o) + colfilter(hh,g1o)

                # Do odd top-level filters on rows.
                Z = (colfilter(y1.T,g0o) + colfilter(y2.T,g1o)).T

        return Z
"""
#==========================================================================================
#                       **********    INTERNAL FUNCTIONS    **********
#==========================================================================================



def c2q(w, gain):
    """
    Scale by *gain*, then convert from complex w[:, :, 0:2] to real quad-numbers.

    The real and imaginary parts of the two highpass subbands are spread over
    the four corners of each 2x2 quad of the output:

     A----B     Re   Im of w[:, :, 0]
     |    |
     |    |
     C----D     Re   Im of w[:, :, 1]

    :param w: complex array of shape [rows, cols, k] with k >= 2; only the
        first two planes along the last axis are used
    :param gain: indexable gain; gain[0] scales w[:, :, 0] and gain[1] scales w[:, :, 1]
    :returns: real array of shape [2*rows, 2*cols] with dtype ``w.real.dtype``
    """
    rows, cols = w.shape[0], w.shape[1]
    scale = np.sqrt(0.5) * gain
    first = w[:, :, 0] * scale[0]
    second = w[:, :, 1] * scale[1]
    total = first + second        # P
    difference = first - second   # Q

    quads = np.zeros((2 * rows, 2 * cols), dtype=w.real.dtype)
    # Recover each of the 4 corners of the quads.
    quads[0::2, 0::2] = total.real         # a = (A+C)*sc
    quads[0::2, 1::2] = total.imag         # b = (B+D)*sc
    quads[1::2, 0::2] = difference.imag    # c = (B-D)*sc
    quads[1::2, 1::2] = -difference.real   # d = (C-A)*sc
    return quads

In [None]:
def _centered(arr, newsize):
    """Crop *arr* to its central region of shape *newsize* (scipy-style helper).

    Each axis keeps ``newsize[k]`` elements starting ``(arr.shape[k] - newsize[k]) // 2``
    elements in.

    NOTE(review): an identical helper is also defined in an earlier cell; keep
    the two in sync or drop one of them.
    """
    wanted = np.asanyarray(newsize)
    begin = (np.array(arr.shape) - wanted) // 2
    end = begin + wanted
    index = []
    for axis in range(len(end)):
        index.append(slice(begin[axis], end[axis]))
    return arr[tuple(index)]

# This is to allow easy replacement of these later with, possibly, GPU versions
# Rebindable aliases for the real FFT pair (e.g. to swap in a GPU backend later).
# NOTE(review): the same aliases are also bound in an earlier cell.
from numpy.fft import rfft as _rfft, irfft as _irfft

def _column_convolve(X, h):
    """Convolve each column of *X* with the 1-D filter *h*, returning only the 'valid' rows.

    The 'valid' rows are those unaffected by implicit zero padding; there are
    ``abs(X.shape[0] - len(h)) + 1`` of them. *h* is flattened and cast to
    ``X.dtype`` before use, so the output always has the dtype of *X*. The
    filter is assumed to be short, so direct convolution is used rather than FFT.

    NOTE(review): this duplicates a helper defined in an earlier cell.
    """
    n_rows = X.shape[0]
    taps = h.flatten().astype(X.dtype)
    n_taps = taps.shape[0]

    # Full convolution: accumulate one shifted, scaled copy of X per tap.
    full_shape = np.asanyarray(X.shape)
    full_shape[0] = n_rows + n_taps - 1
    out = np.zeros(full_shape, dtype=X.dtype)
    for idx, tap in enumerate(taps):
        out[idx:idx + n_rows, ...] += X * tap

    # Keep the centred 'valid' rows; every other axis is kept whole.
    n_valid = abs(n_rows - n_taps) + 1
    start = (full_shape[0] - n_valid) // 2
    return out[start:start + n_valid, ...]

def coldfilt(X, ha, hb):
    """Filter the columns of image X using the two filters ha and hb =
    reverse(ha).  ha operates on the odd samples of X and hb on the even
    samples.  Both filters should be even length, and h should be approx linear
    phase with a quarter sample advance from its mid pt (i.e. :math:`|h(m/2)| >
    |h(m/2 + 1)|`).

    .. code-block:: text

                          ext        top edge                     bottom edge       ext
        Level 1:        !               |               !               |               !
        odd filt on .    b   b   b   b   a   a   a   a   a   a   a   a   b   b   b   b
        odd filt on .      a   a   a   a   b   b   b   b   b   b   b   b   a   a   a   a
        Level 2:        !               |               !               |               !
        +q filt on x      b       b       a       a       a       a       b       b
        -q filt on o          a       a       b       b       b       b       a       a

    The output is decimated by two from the input sample rate, and the results
    from the two filters, Ya and Yb, are interleaved to give Y.  Symmetric
    extension with repeated end samples is used on the composite X columns
    before each filter is applied.

    Raises ValueError if the number of rows in X is not a multiple of 4, the
    length of ha does not match hb or the lengths of ha or hb are non-even.

    .. codeauthor:: Rich Wareham <rjw57@cantab.net>, August 2013
    .. codeauthor:: Cian Shaffrey, Cambridge University, August 2000
    .. codeauthor:: Nick Kingsbury, Cambridge University, August 2000
    """
    # Make sure all inputs are arrays
    X = asfarray(X)
    ha = asfarray(ha)
    hb = asfarray(hb)

    r, c = X.shape
    if r % 4 != 0:
        raise ValueError('No. of rows in X must be a multiple of 4')
    if ha.shape != hb.shape:
        raise ValueError('Shapes of ha and hb must be the same')
    if ha.shape[0] % 2 != 0:
        raise ValueError('Lengths of ha and hb must be even')

    m = ha.shape[0]

    # Row indices of X symmetrically extended by m samples at each end
    # (end samples repeated).
    xe = reflect(np.arange(-m, r + m), -0.5, r - 0.5)

    # Split each filter into its taps at even and odd 0-based indices
    # ('odd'/'even' below follow the original 1-based naming).
    ha_odd, ha_even = as_column_vector(ha[0:m:2]), as_column_vector(ha[1:m:2])
    hb_odd, hb_even = as_column_vector(hb[0:m:2]), as_column_vector(hb[1:m:2])

    # Every 4th extended sample, starting at index 5.
    t = np.arange(5, r + 2 * m - 2, 4)
    half_rows = r // 2
    Y = np.zeros((half_rows, c), dtype=X.dtype)

    # The sign of sum(ha*hb) decides whether the 'a' results go to the even
    # or the odd output rows.
    even_rows, odd_rows = slice(0, half_rows, 2), slice(1, half_rows, 2)
    if np.sum(ha * hb) > 0:
        rows_a, rows_b = even_rows, odd_rows
    else:
        rows_a, rows_b = odd_rows, even_rows

    # Perform filtering on columns of extended matrix X(xe,:) in 4 ways.
    Y[rows_a, :] = (_column_convolve(X[xe[t - 1], :], ha_odd)
                    + _column_convolve(X[xe[t - 3], :], ha_even))
    Y[rows_b, :] = (_column_convolve(X[xe[t], :], hb_odd)
                    + _column_convolve(X[xe[t - 2], :], hb_even))
    return Y

def colifilt(X, ha, hb):
    """ Filter the columns of image X using the two filters ha and hb =
    reverse(ha).  ha operates on the odd samples of X and hb on the even
    samples.  Both filters should be even length, and h should be approx linear
    phase with a quarter sample advance from its mid pt (i.e `:math:`|h(m/2)| >
    |h(m/2 + 1)|`).

    .. code-block:: text

                          ext       left edge                      right edge       ext
        Level 2:        !               |               !               |               !
        +q filt on x      b       b       a       a       a       a       b       b
        -q filt on o          a       a       b       b       b       b       a       a
        Level 1:        !               |               !               |               !
        odd filt on .    b   b   b   b   a   a   a   a   a   a   a   a   b   b   b   b
        odd filt on .      a   a   a   a   b   b   b   b   b   b   b   b   a   a   a   a

    The output is interpolated by two from the input sample rate and the
    results from the two filters, Ya and Yb, are interleaved to give Y.
    Symmetric extension with repeated end samples is used on the composite X
    columns before each filter is applied.

    Raises ValueError if the number of rows in X is odd, the shapes of ha and
    hb differ or their lengths are odd. An all-zero X yields an all-zero Y of
    shape (2*rows, cols) without any filtering.

    .. codeauthor:: Rich Wareham <rjw57@cantab.net>, August 2013
    .. codeauthor:: Cian Shaffrey, Cambridge University, August 2000
    .. codeauthor:: Nick Kingsbury, Cambridge University, August 2000
    """
    # Make sure all inputs are arrays
    X = asfarray(X)
    ha = asfarray(ha)
    hb = asfarray(hb)

    r, c = X.shape
    if r % 2 != 0:
        raise ValueError('No. of rows in X must be a multiple of 2')
    if ha.shape != hb.shape:
        raise ValueError('Shapes of ha and hb must be the same')
    if ha.shape[0] % 2 != 0:
        raise ValueError('Lengths of ha and hb must be even')

    m = ha.shape[0]
    m2 = m // 2   # exact, since m is even

    Y = np.zeros((r*2, c), dtype=X.dtype)
    # An all-zero input interpolates to an all-zero output. (The previous test,
    # np.any(np.nonzero(X)[0]), wrongly treated an X whose only nonzero
    # entries were in row 0 as all-zero.)
    if not np.any(X):
        return Y

    # Row indices of X symmetrically extended by m2 samples at each end, with
    # repeated end samples. Use 'reflect' so r < m2 works OK. (builtin ``int``
    # replaces the ``np.int`` alias removed in NumPy 1.24.)
    xe = reflect(np.arange(-m2, r + m2, dtype=int), -0.5, r - 0.5)

    # Select odd and even samples from ha and hb. Note that due to 0-indexing
    # 'odd' and 'even' are not perhaps what you might expect them to be.
    hao = as_column_vector(ha[0:m:2])
    hae = as_column_vector(ha[1:m:2])
    hbo = as_column_vector(hb[0:m:2])
    hbe = as_column_vector(hb[1:m:2])

    if m2 % 2 == 0:
        # m/2 is even, so set up t to start on d samples.
        t = np.arange(3, r + m, 2)
    else:
        # m/2 is odd, so set up t to start on b samples.
        t = np.arange(2, r + m - 1, 2)

    # The sign of sum(ha*hb) decides which of t, t-1 feeds the 'a' filters.
    if np.sum(ha*hb) > 0:
        ta, tb = t, t - 1
    else:
        ta, tb = t - 1, t

    # Output rows are produced in groups of four.
    s = np.arange(0, r*2, 4)

    if m2 % 2 == 0:
        Y[s,:]   = _column_convolve(X[xe[tb-2],:], hae)
        Y[s+1,:] = _column_convolve(X[xe[ta-2],:], hbe)
        Y[s+2,:] = _column_convolve(X[xe[tb  ],:], hao)
        Y[s+3,:] = _column_convolve(X[xe[ta  ],:], hbo)
    else:
        Y[s,:]   = _column_convolve(X[xe[tb],:], hao)
        Y[s+1,:] = _column_convolve(X[xe[ta],:], hbo)
        Y[s+2,:] = _column_convolve(X[xe[tb],:], hae)
        Y[s+3,:] = _column_convolve(X[xe[ta],:], hbe)

    return Y