In [9]:
import torch

# import pywt
import numpy as np
import os
import cv2
import argparse
import json


class HaarForward(torch.nn.Module):
    """
    Performs a 2d DWT Forward decomposition of an image using Haar Wavelets
    """
    def __init__(self, alpha):
        super(HaarForward, self).__init__()
        self.alpha = alpha

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs a 2d DWT Forward decomposition of an image using Haar Wavelets

        Arguments:
            x (torch.Tensor): input tensor of shape [b, c, h, w]

        Returns:
            out (torch.Tensor): output tensor of shape [b, c * 4, h / 2, w / 2]
        """

        ll = self.alpha * (x[:,:,0::2,0::2] + x[:,:,0::2,1::2] + x[:,:,1::2,0::2] + x[:,:,1::2,1::2])
        lh = self.alpha * (x[:,:,0::2,0::2] + x[:,:,0::2,1::2] - x[:,:,1::2,0::2] - x[:,:,1::2,1::2])
        hl = self.alpha * (x[:,:,0::2,0::2] - x[:,:,0::2,1::2] + x[:,:,1::2,0::2] - x[:,:,1::2,1::2])
        hh = self.alpha * (x[:,:,0::2,0::2] - x[:,:,0::2,1::2] - x[:,:,1::2,0::2] + x[:,:,1::2,1::2])
        return torch.cat([ll,lh,hl,hh], 1)


In [10]:
class BlurDetectionModel(torch.nn.Module):
    def __init__(self):
        super(BlurDetectionModel, self).__init__()
        self.haar = HaarForward(0.5)
        self.threshold = torch.tensor(35.0);
        
    def forward(self, image):
        # Assuming 'image' is a PyTorch tensor and 'threshold' is a float
        # Convert the PyTorch tensor to a NumPy array
        
        blur_value1, blur_value2 = self.blur_detect(image)
        # max_brightness = torch.max(image)
        # min_brightness = torch.max(torch.min(image), torch.tensor(10.0))
        # print(max_brightness, min_brightness)
        
        # dynamic_range = torch.round(torch.log2(max_brightness) - torch.log2(min_brightness), decimals=2)
        # Convert the blur values back to PyTorch tensors
        blur_tensor1 = torch.tensor(blur_value1, dtype=torch.float32)
        blur_tensor2 = torch.tensor(blur_value2, dtype=torch.float32)
        # dynamic_range_tensor = torch.tensor(dynamic_range, dtype=torch.float32)
        
        return blur_tensor1, blur_tensor2 #, dynamic_range_tensor
    def backward(self):
        pass

    def processHarrRes(self, _LL1):
        LL1, LH1, HL1, HH1 = _LL1[:,0,:,:], _LL1[:,1,:,:], _LL1[:,2,:,:], _LL1[:,3,:,:]
        LL1 = LL1.unsqueeze(0)
        LH1 = LH1.squeeze(0)
        HL1 = HL1.squeeze(0)
        HH1 = HH1.squeeze(0)
        return LL1, LH1, HL1, HH1

    def blur_detect(self,Y):
        M, N = Y.shape

        # Crop input image to be 3 divisible by 2
        Y = Y[0:int(M/16)*16, 0:int(N/16)*16]
        # add batch dimension
        Y = Y.unsqueeze(0).unsqueeze(0)
        # Step 1, compute Haar wavelet of input image
        _LL1 = self.haar(Y)
        LL1, LH1, HL1, HH1 = self.processHarrRes(_LL1)
        # Another application of 2D haar to LL1
        _LL2= self.haar(LL1) 
        LL2, LH2, HL2, HH2 = self.processHarrRes(_LL2)
        # Another application of 2D haar to LL2
        _LL3 = self.haar(LL2)
        LL3 , LH3, HL3, HH3 = self.processHarrRes(_LL3)
        
        LL1 = LL1.squeeze(0).squeeze(0)
        LL2 = LL2.squeeze(0).squeeze(0)
        LL3 = LL2.squeeze(0).squeeze(0)
        # Construct the edge map in each scale Step 2
        E1 = torch.sqrt(torch.pow(LH1, 2)+torch.pow(HL1, 2)+torch.pow(HH1, 2))
        E2 = torch.sqrt(torch.pow(LH2, 2)+torch.pow(HL2, 2)+torch.pow(HH2, 2))
        E3 = torch.sqrt(torch.pow(LH3, 2)+torch.pow(HL3, 2)+torch.pow(HH3, 2))
        
        M1, N1 = E1.shape


        # Sliding window size level 1
        sizeM1 = 8
        sizeN1 = 8
        
        # Sliding windows size level 2
        sizeM2 = int(sizeM1/2)
        sizeN2 = int(sizeN1/2)
        
        # Sliding windows size level 3
        sizeM3 = int(sizeM2/2)
        sizeN3 = int(sizeN2/2)
        
        # Number of edge maps, related to sliding windows size
        N_iter = int((M1/sizeM1)*(N1/sizeN1))
        
        Emax1 = torch.zeros((N_iter))
        Emax2 = torch.zeros((N_iter))
        Emax3 = torch.zeros((N_iter))
        
        
        count = 0
        
        # Sliding windows index of level 1
        x1 = 0
        y1 = 0
        # Sliding windows index of level 2
        x2 = 0
        y2 = 0
        # Sliding windows index of level 3
        x3 = 0
        y3 = 0
        
        # Sliding windows limit on horizontal dimension
        Y_limit = N1-sizeN1
        
        while count < N_iter:
            # Get the maximum value of slicing windows over edge maps 
            # in each level
            Emax1[count] = torch.max(E1[x1:x1+sizeM1,y1:y1+sizeN1])
            Emax2[count] = torch.max(E2[x2:x2+sizeM2,y2:y2+sizeN2])
            Emax3[count] = torch.max(E3[x3:x3+sizeM3,y3:y3+sizeN3])
            
            # if sliding windows ends horizontal direction
            # move along vertical direction and resets horizontal
            # direction
            if y1 == Y_limit:
                x1 = x1 + sizeM1
                y1 = 0
                
                x2 = x2 + sizeM2
                y2 = 0
                
                x3 = x3 + sizeM3
                y3 = 0
                
                count += 1
            
            # windows moves along horizontal dimension
            else:
                    
                y1 = y1 + sizeN1
                y2 = y2 + sizeN2
                y3 = y3 + sizeN3
                count += 1
        
        # Step 3
        EdgePoint1 = Emax1 > self.threshold;
        EdgePoint2 = Emax2 > self.threshold;
        EdgePoint3 = Emax3 > self.threshold;
        
        # Rule 1 Edge Pojnts
        EdgePoint = EdgePoint1 + EdgePoint2 + EdgePoint3
        
        n_edges = EdgePoint.shape[0]
        
        # Rule 2 Dirak-Structure or Astep-Structure
        DAstructure = (Emax1[EdgePoint] > Emax2[EdgePoint]) * (Emax2[EdgePoint] > Emax3[EdgePoint]);
        
        # Rule 3 Roof-Structure or Gstep-Structure
        
        RGstructure = torch.zeros((n_edges))

        for i in range(n_edges):
        
            if EdgePoint[i] == 1:
            
                if Emax1[i] < Emax2[i] and Emax2[i] < Emax3[i]:
                
                    RGstructure[i] = 1
                    
        # Rule 4 Roof-Structure
        
        RSstructure = torch.zeros((n_edges))

        for i in range(n_edges):
        
            if EdgePoint[i] == 1:
            
                if Emax2[i] > Emax1[i] and Emax2[i] > Emax3[i]:
                
                    RSstructure[i] = 1

        # Rule 5 Edge more likely to be in a blurred image 

        BlurC = torch.zeros((n_edges));

        for i in range(n_edges):
        
            if RGstructure[i] == 1 or RSstructure[i] == 1:
            
                if Emax1[i] < self.threshold:
                
                    BlurC[i] = 1                        
            
        # Step 6
        Per = torch.sum(DAstructure)/torch.sum(EdgePoint)
        
        # Step 7
        if (torch.sum(RGstructure) + torch.sum(RSstructure)) == 0:
            BlurExtent = torch.tensor(100)
        else:
            BlurExtent = torch.sum(BlurC) / (torch.sum(RGstructure) + torch.sum(RSstructure))
        
        return Per, BlurExtent



In [11]:
import time
# Example usage
aspect_ratios = ['1:1', '4:3', '14:9', '16:10', '16:9', '37:20', '2:1', '21:9']                
def get_resolution(aspect_ratio):
    if(aspect_ratio == '1:1'):
        return (600, 600)
    elif(aspect_ratio == '4:3'):
        return (800, 600)
    elif(aspect_ratio == '14:9'):
        return (1400, 900)
    elif(aspect_ratio == '16:10'):
        return (1280, 800)
    elif(aspect_ratio == '16:9'):
        return (1280, 720)
    elif(aspect_ratio == '37:20'):
        return (1850, 1000)
    elif(aspect_ratio == '2:1'):
        return (1200, 600)
    elif(aspect_ratio == '21:9'):
        return (2100, 900)
    else:
        return (1280, 720)

In [12]:
def find_images(input_dir):
    extensions = [".jpg", ".png", ".jpeg"]

    for root, dirs, files in os.walk(input_dir):
        for file in files:
            if os.path.splitext(file)[1].lower() in extensions:
                yield os.path.join(root, file)

In [13]:

# Create an instance of the custom model
model = BlurDetectionModel()
# model.eval()
model.training = False
# for ar in aspect_ratios:
# input_image = torch.randn(1080, 720)  # Replace with your actual image dimensions
torchtraced_model = torch.jit.script(model)
# torchtraced_model.eval()
# print('diff', traced_out[0] - blur_result1, traced_out[1] - blur_result2)
imageFiles = ['images/blur/IMG_007097.jpg', 'images/blur/IMG_008987.jpg']
res = find_images('images/hdr_test_case/')

for path in res:
    img = cv2.imread(path)
    Y = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)[:,:,0]
    # _x = Y
    # print(_x.shape)
    # cnt = 0
    # for i in range(0, _x.shape[0]):
        # for j in range(0, _x.shape[1]):
            # if(_x[i,j] < 1):
                # cnt = cnt + 1;    
    # print(cnt,  max(_x.flatten()), min(_x.flatten()))
    
    input_image = torch.tensor(Y, dtype=torch.float32)
    blur_result = model(input_image)
    traced_out = torchtraced_model(input_image)

    print(path, traced_out , blur_result)


  blur_tensor1 = torch.tensor(blur_value1, dtype=torch.float32)
  blur_tensor2 = torch.tensor(blur_value2, dtype=torch.float32)


images/hdr_test_case/4.jpeg (tensor(0.0036), tensor(0.4842)) (tensor(0.0036), tensor(0.4842))
images/hdr_test_case/5.jpeg (tensor(0.0068), tensor(0.3868)) (tensor(0.0068), tensor(0.3868))
images/hdr_test_case/2.jpg (tensor(0.0075), tensor(0.2967)) (tensor(0.0075), tensor(0.2967))
images/hdr_test_case/3.jpg (tensor(0.0102), tensor(0.3480)) (tensor(0.0102), tensor(0.3480))
images/hdr_test_case/1.jpg (tensor(0.0119), tensor(0.3380)) (tensor(0.0119), tensor(0.3380))


In [14]:
from torch.utils.mobile_optimizer import optimize_for_mobile

torchtraced_model.to('mps')
optimized_model = optimize_for_mobile(torchtraced_model, backend='Metal')
# optimized_model.training = False
# optimized_model.eval()
print(torch.jit.export_opnames(optimized_model))
optimized_model._save_for_lite_interpreter('haar_scripted_metal.pth')
torch.jit.save(optimized_model, 'haar_scripted_opt.pt')
torch.jit.save(torchtraced_model, 'haar_scripted.pt')

['aten::Bool.Tensor', 'aten::FloatImplicit', 'aten::Int.float', 'aten::__getitem__.t', 'aten::add.Tensor', 'aten::add.int', 'aten::cat', 'aten::copy_', 'aten::div.Tensor', 'aten::div.int', 'aten::eq.Scalar', 'aten::eq.int', 'aten::gt.Tensor', 'aten::index.Tensor', 'aten::lt.Tensor', 'aten::lt.int', 'aten::max', 'aten::mul.Scalar', 'aten::mul.Tensor', 'aten::mul.float', 'aten::mul.int', 'aten::pow.Tensor_Scalar', 'aten::select.int', 'aten::size', 'aten::slice.Tensor', 'aten::sqrt', 'aten::squeeze.dim', 'aten::sub.Tensor', 'aten::sub.int', 'aten::sum', 'aten::tensor.float', 'aten::tensor.int', 'aten::unsqueeze', 'aten::zeros']


In [15]:
optimized_model.to('mps')

RecursiveScriptModule(original_name=BlurDetectionModel)

In [16]:
import coremltools as ct 

coreml_model = ct.convert(
    torchtraced_model, convert_to='mlprogram', inputs=[ct.TensorType(shape=(600, 600))])


Support for converting Torch Script Models is experimental. If possible you should use a traced model for conversion.
Tuple detected at graph output. This will be flattened in the converted model.
Converting PyTorch Frontend ==> MIL Ops:  11%|â–ˆ         | 46/415 [00:00<00:00, 7815.05 ops/s]


ValueError: Torch var alpha.2 not found in context 

In [56]:
optimized_model.training = True