Skip to content

Commit

Permalink
fixed biorthogonal case
Browse files Browse the repository at this point in the history
  • Loading branch information
pderian-cea committed Feb 15, 2019
1 parent eff761b commit d53e49b
Showing 1 changed file with 97 additions and 83 deletions.
180 changes: 97 additions & 83 deletions lib/highorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,100 @@
Wavelets and Fluid Motion Estimation.
PhD thesis, MATISSE doctoral school, Université Rennes 1, 2012.
"""
###
# Third-party imports
import numpy as np
import pywt
import scipy.ndimage as ndimage
import scipy.optimize as optimize
###

class HighOrderRegularizer:
"""Implements high-order wavelet-based regularization terms.

def connection_coefficients(wav, order):
"""Find the connection coefficients of the wavelet at given order.
:param wav: a pywt.Wavelet;
:param order: the derivation order;
:param return: a vector of coefficients.
This is the evaluation of L2 dot-products of the form:
\int[ Phi(x) (d^(n)/dx^n){Phi}(x) ]dx
where Phi is the mother wavelet.
Written by P. DERIAN 2018-01-09.
Updated by P. DERIAN 2019-02-15: using reconstruction filter.
"""
ORDER_MAX = 6 #max order for coefficients computation.
ctol = 1e-15 # Tolerance for coefficients
etol = 1e-4 # Tolerance for eigenvalues

# Get the low-pass reconstruction filter
lo = wav.rec_lo
len_lo = len(lo)
# Create the matrix
dim = 2*len_lo - 3
matrix = np.zeros((dim, dim))
for m in range(dim):
for n in range(dim):
tmp = 0.
# for each filter value
for p, lo_p in enumerate(lo):
idx = m - 2*n + p + len_lo - 2
if (idx>=0 and idx<len_lo):
tmp += lo_p*lo[idx]
# store only if above threshold
if np.abs(tmp)>ctol:
matrix[n,m] = tmp
# Find coefficient vector: solve eigenvalue problem
eval, evec = np.linalg.eig(matrix)
# Check if any REAL eigenvalue matches our order
sigma = 1./float(2**order)
ev_found = False
for i, ev in enumerate(eval):
if (np.abs(np.real(ev)-sigma)<etol) and (np.abs(np.imag(ev)<1e-14)):
ev_found = True
break
# [TODO] warn if not found?
coeffs = None
if ev_found:
# Get the associated eigen vector
coeffs = evec[:,i]
# If the derivation order is odd, force mid value to be exactly zero
# as it should be by construction.
if order%2:
coeffs[dim//2] = 0.
# Apply Beylkin normalization
# [TODO] document
norm_factors = np.array([float(-len_lo + 2 + p) for p in range(dim)])**order
coeffs /= np.sum(norm_factors*coeffs)
tmp = (np.prod(np.arange(1, order+1)))*np.power(-1.,order)
coeffs *= tmp
return coeffs


class HighOrderRegularizerConv:
"""Implements high-order wavelet-based regularization terms,
here implemented with convolutions.
Written by P. DERIAN 2018-01-09.
"""
ORDER_MAX = 6 # Max order for coefficients computation.

def __init__(self, wav):
"""Instance constructor.
Written by P. DERIAN 2018-01-09.
Updated by P. DERIAN 2019-02-15: fixed biorthogonal case.
"""
self.mode = 'periodization'
self.wav = wav #keep reference
self.coeffs = {k: self.connection_coefficients(wav, k) for k in range(self.ORDER_MAX)}
# The main wavelet
self.wav = wav
# Get a swapped version (primal / dual filters are swapped)
# Note: if the wav is orthogonal, wavswap is in practice the same as wav.
self.wavswap = pywt.Wavelet(
name='{}_swap'.format(self.wav.name),
filter_bank=self.wav.inverse_filter_bank,
)
# Compute the connection coefficients up to order max.
self.coeffs = {k: connection_coefficients(wav, k) for k in range(self.ORDER_MAX)}
# The set of available regularizers
self.regularizers = {'l2norm': self._l2norm_gradient,
'hornschunck': self._hornschunck_gradient,
}
Expand All @@ -51,23 +123,25 @@ def evaluate(self, C1, C2, regul_type='l2norm'):
:return: value, (grad1, grad2)
Written by P. DERIAN 2018-01-09.
Updated by P. DERIAN 2019-02-15: fixed biorthogonal case.
"""
### infer levels for further decomposition
# Infer levels for further decomposition
levels = len(C1) - 1
### rebuild the fields
# Rebuild the fields
U1 = pywt.waverec2(C1, self.wav, mode=self.mode)
U2 = pywt.waverec2(C2, self.wav, mode=self.mode)
#### evaluate
# Evaluate
[grad1, grad2] = self.regularizers[regul_type](U1, U2)
# decompose gradient to complete its computation
grad1 = pywt.wavedec2(grad1, self.wav, level=levels, mode=self.mode)
grad2 = pywt.wavedec2(grad2, self.wav, level=levels, mode=self.mode)
# compute the functional value
# Decompose gradient to complete its computation
# Note: use wavswap here!
grad1 = pywt.wavedec2(grad1, self.wavswap, level=levels, mode=self.mode)
grad2 = pywt.wavedec2(grad2, self.wavswap, level=levels, mode=self.mode)
# Compute the functional value
result = 0.
for c, g in zip([C1, C2], [grad1, grad2]):
# add contribution of approx
# Add contribution of approx
result += np.dot(c[0].ravel(), g[0].ravel())
# and details
# And details
for cd, gd in zip(c[1:], g[1:]):
for cdd, gdd in zip(cd, gd):
result += np.dot(cdd.ravel(), gdd.ravel())
Expand Down Expand Up @@ -97,67 +171,6 @@ def _hornschunck_gradient(self, U1, U2):
for U in [U1, U2]]
return result

@staticmethod
def connection_coefficients(wav, order):
"""Find the connection coefficients of the wavelet at given order.
:param wav: a pywt.Wavelet;
:param order: the derivation order;
:param return: a vector of coefficients.
This is the evaluation of L2 dot-products of the form:
\int[ Phi(x) (d^(n)/dx^n){Phi}(x) ]dx
where Phi is the mother wavelet.
Written by P. DERIAN 2018-01-09.
"""
ctol = 1e-15 #tolerance for coefficients
etol = 1e-4 #tolerance for eigenvalues
### get the low-pass filter
lo = wav.dec_lo
len_lo = len(lo)
### create the matrix
dim = 2*len_lo - 3
matrix = np.zeros((dim, dim))
for m in range(dim):
for n in range(dim):
tmp = 0.
# for each filter value
for p, lo_p in enumerate(lo):
idx = m - 2*n + p + len_lo - 2
if (idx>=0 and idx<len_lo):
tmp += lo_p*lo[idx]
# store only if above threshold
if np.abs(tmp)>ctol:
matrix[n,m] = tmp
### Find coefficient vector
# solve eigen values
eval, evec = np.linalg.eig(matrix)
#print('eigs:', eval)
# check if any REAL eigenvalue matches our order
sigma = 1./float(2**order)
ev_found = False
for i, ev in enumerate(eval):
if (np.abs(np.real(ev)-sigma)<etol) and (np.abs(np.imag(ev)<1e-14)):
ev_found = True
break
# [TODO] warn if not found?
coeffs = None
if ev_found:
coeffs = evec[:,i]
#print('found:', ev_found, i, ev, coeffs)
### Process
# if the order is odd, force mid value to be exactly zero
if order%2:
coeffs[dim//2] = 0.
#print(coeffs)
# apply Beylkin normalization
norm_factors = np.array([float(-len_lo + 2 + p) for p in range(dim)])**order
coeffs /= np.sum(norm_factors*coeffs)
tmp = (np.prod(np.arange(1, order+1)))*np.power(-1.,order)
coeffs *= tmp
return coeffs

@staticmethod
def convolve_separable(x, filter1, filter2, origin1=0, origin2=0):
"""Separable convolution of x by filter1 along first axis and filter2 along second axis.
Expand All @@ -170,16 +183,17 @@ def convolve_separable(x, filter1, filter2, origin1=0, origin2=0):
tmp = ndimage.filters.convolve(x, filter1.reshape(1,-1), mode='wrap', origin=origin1)
return ndimage.filters.convolve(tmp, filter2.reshape(-1,1), mode='wrap', origin=origin2)


### Demonstrations ###

if __name__=="__main__":
###
# Standard library
import matplotlib.pyplot as pyplot
###
# Custom
import sys
sys.path.append('..')
import demo.inr as inr
###


def demo_connection_coeff():
"""
Expand All @@ -198,9 +212,9 @@ def hornschunk(U1, U2):
result += np.sum(g1**2 + g2**2)
return 0.5*result

levels = 4
wav = pywt.Wavelet('db5')
hor = HighOrderRegularizer(wav)
levels = 3
wav = pywt.Wavelet('db6')
hor = HighOrderRegularizerConv(wav)
U2, U1 = inr.readMotion('../demo/UVtruth.inr')
C1 = pywt.wavedec2(U1, hor.wav, level=levels, mode=hor.mode)
C2 = pywt.wavedec2(U2, hor.wav, level=levels, mode=hor.mode)
Expand Down

0 comments on commit d53e49b

Please sign in to comment.