## PTWT

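1. Use ptwt to compute a second-order (level-2) 2D wavelet decomposition of an image
2. Reconstruct using all the coefficients (both levels) – the size stays the same
3. Reconstruct using only the level-2 subbands – the result is downsampled by a factor of 2
4. Reconstruct with upsampled detail subbands (the strategy from the wavelet-pooling paper) – the result is upsampled by a factor of 2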

In [1]:
import ptwt, pywt, torch
import torch.nn as nn
import numpy as np
import scipy.misc
from matplotlib import pyplot as plt
# Load SciPy's sample "face" RGB image and move the channel axis first
# (H, W, 3) -> (3, H, W), cast to float64 so the transforms run in double precision.
# NOTE(review): scipy.misc.face is deprecated in recent SciPy releases (its
# replacement is scipy.datasets.face, which needs the optional `pooch` package) —
# confirm against the installed SciPy version.
face = np.transpose(scipy.misc.face(),
                    [2, 0, 1]).astype(np.float64)
pytorch_face = torch.tensor(face)
# Level-2 Haar decomposition with zero ("constant") padding. ptwt presumably treats
# the leading (channel) axis as its batch axis — see the "cx1xhxw" notes below.
# coefficients[0] is the level-2 approximation and coefficients[1] the level-2
# detail subbands (used that way in the cells below).
coefficients = ptwt.wavedec2(pytorch_face, pywt.Wavelet("haar"),
                             level=2, mode="constant")

In [None]:
def plot_(x):
    """Print a tensor's shape and display it as an RGB image.

    x - tensor holding one channel-first image on a 0-255 scale; after
        squeezing size-1 axes it must be (3, H, W) (e.g. the (3, 1, H, W)
        output of ptwt.waverec2 on the face image).
    """
    print(x.shape)
    # detach() so tensors that require grad (e.g. pooling-layer outputs) can be
    # converted; cpu() so tensors living on a GPU work too.
    image = x.detach().cpu().numpy().squeeze().transpose(1, 2, 0) / 255
    # Reconstructions (especially from upsampled subbands) can fall outside
    # [0, 1]; clip explicitly instead of relying on imshow's clipping warning.
    plt.imshow(image.clip(0, 1))
    plt.show()

In [None]:
# Using both first and second order subbands - size remains same
# Every coefficient of the level-2 decomposition is passed back, so this should
# reproduce the original image (Haar allows perfect reconstruction).
reconstruction = ptwt.waverec2(coefficients, pywt.Wavelet("haar"))
plot_(reconstruction)

In [None]:
# Reconstructing using second level subbands. Leads to downsampling by a factor of 2
# Keep only the level-2 approximation and the level-2 detail tuple; inverting that
# single level gives a half-resolution image.
c = list(coefficients[:2])
reconstruction_ = ptwt.waverec2(c, pywt.Wavelet("haar"))
plot_(reconstruction_)

In [None]:
# Reconstructing using strategy in paper - upsampled by a factor of 2
# DWT(I) = (LL, (LH, HL, HH))
# I (u2) = (LL, (LH, HL, HH), (LH(u2), HL(u2), HH(u2)))

# Use a separate name so the level-2 `coefficients` from the first cell are not
# overwritten: re-running the earlier reconstruction cells keeps working.
coefficients_l1 = ptwt.wavedec2(pytorch_face, pywt.Wavelet("haar"),
                                level=1, mode="constant")
upsample_ = nn.Upsample(scale_factor=2, mode='nearest')
# Copy the coefficient list so appending the synthetic finer level does not
# mutate coefficients_l1 (the original `c = coefficients` only aliased it).
c = list(coefficients_l1)
# ptwt documents each detail level as a (LH, HL, HH) tuple, so build a tuple
# (the original `(c_2)` was still a list — the parentheses do not make a tuple).
c.append(tuple(upsample_(coef) for coef in coefficients_l1[1]))

reconstruction_ = ptwt.waverec2(c, pywt.Wavelet("haar"))
plot_(reconstruction_)

# Wavelet Pooling
##### Implementing wavelet pooling as a subclass of `torch.nn.Module`, with the paper's backpropagation rule implemented via a full backward hook

In [None]:
import torch
import numpy as np
import ptwt, pywt
import torch.nn as nn

In [None]:
class WaveletPooling(nn.Module):
    """Wavelet pooling layer that roughly halves height and width.

    Each (batch, channel) image is decomposed to level 2 with ptwt. The level-1
    detail subbands are dropped, and one inverse level is applied to the
    remaining [LL2, (level-2 details)] coefficients.

    wavelet - name of a pywt wavelet, e.g. 'haar'.
    """
    def __init__(self, wavelet):
        super(WaveletPooling,self).__init__()
        # Used by the backward hook to upsample detail subbands by 2.
        self.upsample_ = nn.Upsample(scale_factor=2, mode='nearest')
        self.wavelet = wavelet

    def forward(self, x):
        """Pool a (batch, channels, height, width) tensor.

        Returns a (batch, channels, h_out, w_out) tensor with h_out/w_out about
        half the input size (zero padding rounds odd sizes up).
        """
        b, c, h, w = x.shape
        wavelet = pywt.Wavelet(self.wavelet)
        # ptwt's 2D transforms take one leading batch dimension, so fold the
        # channels into it rather than looping over the batch in Python.
        coefficients = ptwt.wavedec2(x.reshape(b * c, h, w), wavelet,
                                     level=2, mode="constant")
        # Invert a single level from LL2 + level-2 details only; discarding the
        # level-1 details is what makes this a downsampling (pooling) step.
        pooled = ptwt.waverec2([coefficients[0], coefficients[1]], wavelet)
        # Unfold back to (batch, channels, h_out, w_out); equivalent to the
        # original per-sample permute + cat.
        return pooled.reshape(b, c, *pooled.shape[-2:])

In [None]:
def wavelet_pooling_hook(module, inp, out):
    '''Full backward hook implementing the wavelet-pooling gradient from the paper.

    Register with ``register_full_backward_hook``. PyTorch then calls it as
    ``hook(module, grad_input, grad_output)``:

    module - the WaveletPooling layer (its ``upsample_`` is used below)
    inp    - grad_input: gradients w.r.t. the layer's inputs, i.e. what this hook
             replaces; only the shape of ``inp[0]`` is read
    out    - grad_output: gradients w.r.t. the layer's outputs, i.e. the gradient
             arriving from the next layer; ``out[0]`` is (B, C, h, w)

    Returns a 1-tuple holding the replacement gradient for the layer input
    (shaped like ``inp[0]``), or None when the input needs no gradient.
    '''
    if inp[0] is None:
        # The layer input does not require grad: nothing to replace.
        return None

    grad_out = out[0]
    b, c, h, w = grad_out.shape
    wavelet = pywt.Wavelet("haar")

    ## 1. 1st order DWT of the incoming gradient. Fold channels into ptwt's
    #     batch dimension (no Python loop, and no squeeze that would collapse
    #     size-1 channel or spatial dimensions).
    coefficients = list(ptwt.wavedec2(grad_out.reshape(b * c, h, w), wavelet,
                                      level=1, mode="constant"))
    ## 2. Upsample the detail subbands (LH, HL, HH) by 2 to synthesise an extra,
    #     finer level; ptwt expects each detail level as a tuple.
    coefficients.append(tuple(module.upsample_(band) for band in coefficients[1]))
    ## 3. Two-level IDWT -> about twice the spatial size of grad_out.
    grad_in = ptwt.waverec2(coefficients, wavelet)
    grad_in = grad_in.reshape(b, c, *grad_in.shape[-2:])

    # The reconstruction need not match the input size exactly (e.g. for sizes
    # that are not multiples of 4); autograd requires an exact match, so crop,
    # and zero-pad if it is ever smaller.
    in_h, in_w = inp[0].shape[-2:]
    grad_in = grad_in[..., :in_h, :in_w]
    pad_h = in_h - grad_in.shape[-2]
    pad_w = in_w - grad_in.shape[-1]
    if pad_h > 0 or pad_w > 0:
        grad_in = torch.nn.functional.pad(grad_in, (0, pad_w, 0, pad_h))

    # PyTorch documents the hook's return value as a tuple of gradients.
    return (grad_in,)

In [None]:
# Test model
class Model(nn.Module):
    """Tiny test network: 1x1 convolution to 3 channels -> wavelet pooling -> ReLU.

    c - number of input channels the 1x1 convolution expects.
    """

    def __init__(self, c):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=c, out_channels=3, kernel_size=1, stride=1)
        self.pool = WaveletPooling('haar')
        self.relu = nn.ReLU()

    def forward(self, x):
        features = self.conv1(x)
        pooled = self.pool(features)
        return self.relu(pooled)

In [None]:
def hook_fn(m,i,o):
    """Full backward hook that prints and displays the gradient w.r.t. a module's output.

    m - the hooked module (unused); i - grad_input (unused);
    o - grad_output tuple. o[0] is shown as an image, so it presumably
        squeezes to (3, H, W) — TODO confirm for other layers.
    """
    print('gradient out at CNN layer ...')
    grad = o[0]
    print(grad.shape)
    image = grad.squeeze().permute(1, 2, 0).detach().numpy()
    plt.imshow(image / 255)
    plt.show()

In [None]:
import scipy.misc
import torch.optim as optim
from matplotlib import pyplot as plt

In [None]:
# End-to-end run of the hooked test model on the face image.
# NOTE(review): WaveletPooling and wavelet_pooling_hook are redefined further down;
# which definitions this cell uses depends on execution order.
c = 3
b = 1
# Face image as a (1, 3, H, W) float64 leaf tensor; the trailing comment is an
# alternative small test input.
input_ = torch.tensor(np.transpose(scipy.misc.face(), [2, 0, 1]).astype(np.float64))[None,:,:,:] #torch.ones(b,c,4,4)
input_.requires_grad = True
b,c,h,w = input_.size()

m = Model(c)
# Replace the pooling layer's input gradient with the paper's rule.
m.pool.register_full_backward_hook(wavelet_pooling_hook)
# Display the gradient arriving at the conv layer's output.
m.conv1.register_full_backward_hook(hook_fn)
# The model's weights are float32, so cast the float64 input; the cast is a
# differentiable op, so gradients still reach input_.
output_ = m(input_.float())

# Loss = 0.5 * sum(output^2), so the gradient w.r.t. output_ is output_ itself.
(torch.mul(output_, output_).sum()/2).backward()

In [None]:
# Show the input image (0-255 values, hence /255 for imshow).
img_input = input_.squeeze().permute(1,2,0).detach().numpy()
print(img_input.shape)
plt.imshow(img_input/255)
plt.show()

# Show the model output (ReLU of the pooled 1x1-conv features) as an RGB image.
# NOTE(review): conv1 is randomly initialised, so the colours are not meaningful,
# and there is no plt.show() here — Jupyter renders the figure as the cell's output.
img_output = output_.squeeze().permute(1,2,0).detach().numpy()
print(img_output.shape)
plt.imshow(img_output/255)

## Analysing the output shape of the pooling layer

In [None]:
# Placeholder cell: the stray expression `dd` that lived here is an undefined name,
# so it raised NameError and stopped "Restart & Run All" before the cells below.

In [29]:
class WaveletPooling(nn.Module):
    """Wavelet pooling layer (shape-analysis variant) that roughly halves H and W.

    Odd spatial dimensions are first trimmed by one row/column. Each
    (batch, channel) image is then decomposed to level 2 with ptwt, the level-1
    detail subbands are dropped, and one inverse level is applied to
    [LL2, (level-2 details)].

    wavelet - name of a pywt wavelet, e.g. 'haar'.
    """
    def __init__(self, wavelet):
        super(WaveletPooling,self).__init__()
        # Used by the backward hook to upsample detail subbands by 2.
        self.upsample_ = nn.Upsample(scale_factor=2, mode='nearest')
        self.wavelet = wavelet

    def forward(self, x):
        """Pool a (batch, channels, height, width) tensor to (batch, channels, h_out, w_out)."""
        b, c, h, w = x.shape
        wavelet = pywt.Wavelet(self.wavelet)
        # Trim a trailing row/column only when that dimension is odd. Always
        # dropping it threw away real data on even sizes without changing the
        # output shape there (zero padding rounds the size up again anyway).
        x = x[:, :, :h - h % 2, :w - w % 2]
        h, w = x.shape[-2:]
        # ptwt's 2D transforms take one leading batch dimension, so fold the
        # channels into it rather than looping over the batch in Python.
        coefficients = ptwt.wavedec2(x.reshape(b * c, h, w), wavelet,
                                     level=2, mode="constant")
        # Invert a single level from LL2 + level-2 details only (pooling step).
        pooled = ptwt.waverec2([coefficients[0], coefficients[1]], wavelet)
        # Unfold back to (batch, channels, h_out, w_out).
        return pooled.reshape(b, c, *pooled.shape[-2:])

In [30]:
def wavelet_pooling_hook(module, inp, out):
    '''Full backward hook implementing the wavelet-pooling gradient (shape-analysis variant).

    Register with ``register_full_backward_hook``. PyTorch then calls it as
    ``hook(module, grad_input, grad_output)``:

    module - the WaveletPooling layer (its ``upsample_`` is used below)
    inp    - grad_input: gradients w.r.t. the layer's inputs, i.e. what this hook
             replaces; only the shape of ``inp[0]`` is read
    out    - grad_output: gradients w.r.t. the layer's outputs; ``out[0]`` is (B, C, h, w)

    Prints the input-gradient, output-gradient and result shapes. Returns a
    1-tuple with the replacement gradient (shaped like ``inp[0]``), or None
    when the input needs no gradient.
    '''
    if inp[0] is None:
        # The layer input does not require grad: nothing to replace.
        return None

    grad_out = out[0]
    b, c, h, w = grad_out.shape
    wavelet = pywt.Wavelet("haar")

    ## 1. 1st order DWT, folding channels into ptwt's batch dimension (no squeeze,
    #     which would collapse size-1 dims). The padding mode is deliberately left
    #     at ptwt's default here (mode="constant" was disabled in this variant).
    coefficients = list(ptwt.wavedec2(grad_out.reshape(b * c, h, w), wavelet,
                                      level=1))
    ## 2. Upsample LH, HL, HH by 2 as an extra finer level; ptwt expects a tuple.
    coefficients.append(tuple(module.upsample_(band) for band in coefficients[1]))
    ## 3. Two-level IDWT.
    grad_in = ptwt.waverec2(coefficients, wavelet)
    grad_in = grad_in.reshape(b, c, *grad_in.shape[-2:])

    # Match the input size exactly. The previous `4 - size % 4` crop gave the
    # wrong size when height/width % 4 == 1 (e.g. 17 -> 13), because the
    # reconstruction is then smaller than the input, so crop or zero-pad instead.
    in_h, in_w = inp[0].shape[-2:]
    grad_in = grad_in[..., :in_h, :in_w]
    pad_h = in_h - grad_in.shape[-2]
    pad_w = in_w - grad_in.shape[-1]
    if pad_h > 0 or pad_w > 0:
        grad_in = torch.nn.functional.pad(grad_in, (0, pad_w, 0, pad_h))

    # Shape trace for the output-shape analysis in this section.
    print(inp[0].shape, out[0].shape, grad_in.shape)

    # PyTorch documents the hook's return value as a tuple of gradients.
    return (grad_in,)

In [31]:
class ModelPooling(nn.Module):
    """Applies one pooling layer ``N`` times in a row, to inspect output shapes.

    N       - number of times the (single, shared) pooling layer is applied.
    pooling - 'wavelet' for WaveletPooling('haar') or 'maxpool' for
              nn.MaxPool2d(3, 2). The old default 'waveler' was a typo that
              silently selected max pooling; unknown names now raise ValueError.
    verbose - print the tensor shape before and after every pooling step
              (True keeps the original behaviour).
    """
    def __init__(self, N = 1, pooling = 'wavelet', verbose = True):
        super(ModelPooling, self).__init__()
        if pooling == 'wavelet':
            self.pool = WaveletPooling('haar')
        elif pooling == 'maxpool':
            self.pool = nn.MaxPool2d(3,2)
        else:
            raise ValueError(
                f"unknown pooling {pooling!r}; expected 'wavelet' or 'maxpool'")
        self.N = N
        self.verbose = verbose

    def forward(self, x):
        if self.verbose:
            print(x.shape)
        for _ in range(self.N):
            x = self.pool(x)
            if self.verbose:
                print(x.shape)
        return x

In [32]:
c = 3
b = 1
h = 288  # input height
w = 64   # input width
# Seed so the random input (and hence any printed values) is reproducible.
torch.manual_seed(0)
# input_ = torch.tensor(np.transpose(scipy.misc.face(), [2, 0, 1]).astype(np.float64))[None,:,:,:] #torch.ones(b,c,4,4)
# Previously created as randn(b, c, w, h) with w=288, h=64, i.e. the names were
# swapped relative to the (b, c, h, w) unpacking below; the shape is unchanged.
input_ = torch.randn(b, c, h, w)
input_.requires_grad = True
b, c, h, w = input_.size()


print('Forward pass ....')
# pooling = "maxpool"
pooling = "wavelet"
m = ModelPooling(N = 5, pooling=pooling)
# With the hook disabled, autograd differentiates through ptwt directly instead
# of using the paper's gradient; uncomment to exercise wavelet_pooling_hook.
# if pooling == "wavelet":
#     m.pool.register_full_backward_hook(wavelet_pooling_hook)
output_ = m(input_.float())


print('\nBackward pass ....')
# Loss = 0.5 * sum(output^2).
(torch.mul(output_, output_).sum()/2).backward()

Forward pass ....
torch.Size([1, 3, 288, 64])
torch.Size([1, 3, 144, 32])
torch.Size([1, 3, 72, 16])
torch.Size([1, 3, 36, 8])
torch.Size([1, 3, 18, 4])
torch.Size([1, 3, 10, 2])

Backward pass ....
