In [14]:
import torch

# import pywt
import numpy as np
import os
import cv2
import argparse
import json


class HaarForward(torch.nn.Module):
    """
    Performs a 2d DWT Forward decomposition of an image using Haar Wavelets
    """
    def __init__(self, alpha):
        super(HaarForward, self).__init__()
        self.alpha = alpha

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs a 2d DWT Forward decomposition of an image using Haar Wavelets

        Arguments:
            x (torch.Tensor): input tensor of shape [b, c, h, w]

        Returns:
            out (torch.Tensor): output tensor of shape [b, c * 4, h / 2, w / 2]
        """

        ll = self.alpha * (x[:,:,0::2,0::2] + x[:,:,0::2,1::2] + x[:,:,1::2,0::2] + x[:,:,1::2,1::2])
        lh = self.alpha * (x[:,:,0::2,0::2] + x[:,:,0::2,1::2] - x[:,:,1::2,0::2] - x[:,:,1::2,1::2])
        hl = self.alpha * (x[:,:,0::2,0::2] - x[:,:,0::2,1::2] + x[:,:,1::2,0::2] - x[:,:,1::2,1::2])
        hh = self.alpha * (x[:,:,0::2,0::2] - x[:,:,0::2,1::2] - x[:,:,1::2,0::2] + x[:,:,1::2,1::2])
        return torch.cat([ll,lh,hl,hh], 1)

class BlurDetectionModel(torch.nn.Module):
    def __init__(self, threshold):
        super(BlurDetectionModel, self).__init__()
        self.threshold = threshold
        self.haar = HaarForward(0.5)

    def forward(self, image):
        # Assuming 'image' is a PyTorch tensor and 'threshold' is a float
        # Convert the PyTorch tensor to a NumPy array

        blur_value1, blur_value2 = self.blur_detect(image, self.threshold)

        # Convert the blur values back to PyTorch tensors
        blur_tensor1 = torch.tensor(blur_value1, dtype=torch.float32)
        blur_tensor2 = torch.tensor(blur_value2, dtype=torch.float32)

        return blur_tensor1, blur_tensor2
    
    def processHarrRes(self, _LL1):
        LL1, LH1, HL1, HH1 = _LL1[:,0,:,:], _LL1[:,1,:,:], _LL1[:,2,:,:], _LL1[:,3,:,:]
        LL1 = LL1.unsqueeze(0)
        LH1 = LH1.squeeze(0)
        HL1 = HL1.squeeze(0)
        HH1 = HH1.squeeze(0)
        return LL1, LH1, HL1, HH1

    def blur_detect(self,Y, threshold):
        M, N = Y.shape

        # Crop input image to be 3 divisible by 2
        Y = Y[0:int(M/16)*16, 0:int(N/16)*16]
        # add batch dimension
        Y = Y.unsqueeze(0).unsqueeze(0)
        # Step 1, compute Haar wavelet of input image
        _LL1 = self.haar(Y)
        LL1, LH1, HL1, HH1 = self.processHarrRes(_LL1)
        # Another application of 2D haar to LL1
        _LL2= self.haar(LL1) 
        LL2, LH2, HL2, HH2 = self.processHarrRes(_LL2)
        # Another application of 2D haar to LL2
        _LL3 = self.haar(LL2)
        LL3 , LH3, HL3, HH3 = self.processHarrRes(_LL3)
        
        LL1 = LL1.squeeze(0).squeeze(0)
        LL2 = LL2.squeeze(0).squeeze(0)
        LL3 = LL2.squeeze(0).squeeze(0)
        # Construct the edge map in each scale Step 2
        E1 = torch.sqrt(torch.pow(LH1, 2)+torch.pow(HL1, 2)+torch.pow(HH1, 2))
        E2 = torch.sqrt(torch.pow(LH2, 2)+torch.pow(HL2, 2)+torch.pow(HH2, 2))
        E3 = torch.sqrt(torch.pow(LH3, 2)+torch.pow(HL3, 2)+torch.pow(HH3, 2))
        
        M1, N1 = E1.shape


        # Sliding window size level 1
        sizeM1 = 8
        sizeN1 = 8
        
        # Sliding windows size level 2
        sizeM2 = int(sizeM1/2)
        sizeN2 = int(sizeN1/2)
        
        # Sliding windows size level 3
        sizeM3 = int(sizeM2/2)
        sizeN3 = int(sizeN2/2)
        
        # Number of edge maps, related to sliding windows size
        N_iter = int((M1/sizeM1)*(N1/sizeN1))
        
        Emax1 = torch.zeros((N_iter))
        Emax2 = torch.zeros((N_iter))
        Emax3 = torch.zeros((N_iter))
        
        
        count = 0
        
        # Sliding windows index of level 1
        x1 = 0
        y1 = 0
        # Sliding windows index of level 2
        x2 = 0
        y2 = 0
        # Sliding windows index of level 3
        x3 = 0
        y3 = 0
        
        # Sliding windows limit on horizontal dimension
        Y_limit = N1-sizeN1
        
        while count < N_iter:
            # Get the maximum value of slicing windows over edge maps 
            # in each level
            Emax1[count] = torch.max(E1[x1:x1+sizeM1,y1:y1+sizeN1])
            Emax2[count] = torch.max(E2[x2:x2+sizeM2,y2:y2+sizeN2])
            Emax3[count] = torch.max(E3[x3:x3+sizeM3,y3:y3+sizeN3])
            
            # if sliding windows ends horizontal direction
            # move along vertical direction and resets horizontal
            # direction
            if y1 == Y_limit:
                x1 = x1 + sizeM1
                y1 = 0
                
                x2 = x2 + sizeM2
                y2 = 0
                
                x3 = x3 + sizeM3
                y3 = 0
                
                count += 1
            
            # windows moves along horizontal dimension
            else:
                    
                y1 = y1 + sizeN1
                y2 = y2 + sizeN2
                y3 = y3 + sizeN3
                count += 1
        
        # Step 3
        EdgePoint1 = Emax1 > threshold;
        EdgePoint2 = Emax2 > threshold;
        EdgePoint3 = Emax3 > threshold;
        
        # Rule 1 Edge Pojnts
        EdgePoint = EdgePoint1 + EdgePoint2 + EdgePoint3
        
        n_edges = EdgePoint.shape[0]
        
        # Rule 2 Dirak-Structure or Astep-Structure
        DAstructure = (Emax1[EdgePoint] > Emax2[EdgePoint]) * (Emax2[EdgePoint] > Emax3[EdgePoint]);
        
        # Rule 3 Roof-Structure or Gstep-Structure
        
        RGstructure = torch.zeros((n_edges))

        for i in range(n_edges):
        
            if EdgePoint[i] == 1:
            
                if Emax1[i] < Emax2[i] and Emax2[i] < Emax3[i]:
                
                    RGstructure[i] = 1
                    
        # Rule 4 Roof-Structure
        
        RSstructure = torch.zeros((n_edges))

        for i in range(n_edges):
        
            if EdgePoint[i] == 1:
            
                if Emax2[i] > Emax1[i] and Emax2[i] > Emax3[i]:
                
                    RSstructure[i] = 1

        # Rule 5 Edge more likely to be in a blurred image 

        BlurC = torch.zeros((n_edges));

        for i in range(n_edges):
        
            if RGstructure[i] == 1 or RSstructure[i] == 1:
            
                if Emax1[i] < threshold:
                
                    BlurC[i] = 1                        
            
        # Step 6
        Per = torch.sum(DAstructure)/torch.sum(EdgePoint)
        
        # Step 7
        if (torch.sum(RGstructure) + torch.sum(RSstructure)) == 0:
            
            BlurExtent = torch.tensor(100)
        else:
            BlurExtent = torch.sum(BlurC) / (torch.sum(RGstructure) + torch.sum(RSstructure))
        
        return Per, BlurExtent



# Create an instance of the custom model
model = BlurDetectionModel(torch.tensor(0.2))

import time
# Example usage
aspect_ratios = ['1:1', '4:3', '14:9', '16:10', '16:9', '37:20', '2:1', '21:9']                
def get_resolution(aspect_ratio):
    if(aspect_ratio == '1:1'):
        return (600, 600)
    elif(aspect_ratio == '4:3'):
        return (800, 600)
    elif(aspect_ratio == '14:9'):
        return (1400, 900)
    elif(aspect_ratio == '16:10'):
        return (1280, 800)
    elif(aspect_ratio == '16:9'):
        return (1280, 720)
    elif(aspect_ratio == '37:20'):
        return (1850, 1000)
    elif(aspect_ratio == '2:1'):
        return (1200, 600)
    elif(aspect_ratio == '21:9'):
        return (2100, 900)
    else:
        return (1280, 720)
    
_st = time.time()

# for ar in aspect_ratios:
res = get_resolution(aspect_ratios[0])
input_image = torch.randn(res[0], res[1])  # Replace with your actual image dimensions
# blur_threshold = 0.5  # Replace with your desired threshold
print(res)
blur_result1, blur_result2 = model(input_image)
# input_image = torch.randn(1080, 720)  # Replace with your actual image dimensions
torchtraced_model = torch.jit.script(model)
traced_out = torchtraced_model(input_image)
print('diff', traced_out[0] - blur_result1, traced_out[1] - blur_result2)
for i in range(5):
    input_image = torch.randn(get_resolution(aspect_ratios[1]))  # Replace with your actual image dimensions
    blur_result1, blur_result2 = model(input_image)
    traced_out = torchtraced_model(input_image)

    print('diff', traced_out[0] - blur_result1, traced_out[1] - blur_result2)

torch.jit.save(torchtraced_model, 'haar_scripted_'+str(res[0])+'_'+str(res[1])+'.pt')

(600, 600)


  blur_tensor1 = torch.tensor(blur_value1, dtype=torch.float32)
  blur_tensor2 = torch.tensor(blur_value2, dtype=torch.float32)


diff tensor(0.) tensor(0.)
diff tensor(0.) tensor(0.)
diff tensor(0.) tensor(0.)
diff tensor(0.) tensor(0.)
diff tensor(0.) tensor(0.)
diff tensor(0.) tensor(0.)
