In [19]:
from qiskit import QuantumCircuit, QuantumRegister
#from qiskit_aer import AerSimulator
from qiskit.quantum_info import Statevector, Operator
import torch
#import torch.nn.functional as F
import numpy as np
import math
from common import SoftThresholding, find_min_power

In [20]:
# https://github.com/vittpall/WT-Convolution/blob/haar_wavelet/ImageNet1K/WTHaarResNet50x3/layers/WTHaar.py
# Code below is adapted from the file above; credit to its original authors.

In [21]:
def next_power_of_2(n: int) -> int:
    """Return the smallest power of 2 that is >= ``n``.

    Powers of 2 map to themselves (32 -> 32), other values round up
    (28 -> 32). Inputs below 1 return 1, the smallest power of 2; without this
    guard ``n = 0`` would yield 2 and negative ``n`` an arbitrary power.
    """
    if n <= 1:
        return 1
    # (n - 1).bit_length() is the number of bits needed for n - 1, so shifting
    # 1 by it gives the first power of 2 that is >= n.
    return 1 << (n - 1).bit_length()

# input int and output np.ndarray
def haar_matrix_builder(n: int) -> np.ndarray:

    # base case and dim check
    assert n > 0 and (n & (n - 1)) == 0, "n must be power of 2"
    if n == 1:
        return np.array([[1.0]], dtype=float)

    half = n // 2
    #print(type(half))
    norm = 1.0 / np.sqrt(2.0)

    # build the a haar level
    B = np.zeros((n, n), dtype=float)
    for k in range(half):
        col1 = 2 * k
        col2 = 2 * k + 1
        B[k, col1] = norm
        B[k, col2] = norm
        B[half+k, col1] = norm
        B[half+k, col2] = -norm

    # recursive call to build haar levels
    W_half = haar_matrix_builder(half)

    # after hitting base case multiply all haar levels
    R = np.block([
        [W_half, np.zeros((half, half), dtype=float)],
        [np.zeros((half, half), dtype=float), np.eye(half, dtype=float)]
    ])
    return R @ B

def haar_transform_2d_classical(x: np.ndarray, inverse=False) -> np.ndarray:
    """Apply the separable 2D Haar transform over the last two axes of ``x``.

    The transforms are:

    * forward: ``WH @ x @ WW.T``;
    * inverse: ``WH.T @ x @ WW``.

    Here ``WH``/``WW`` come from ``haar_matrix_builder``. Those matrices are
    orthonormal, so the transpose is the inverse.

    Args:
        x: Array of shape ``(H, W)`` or ``(..., H, W)``. ``H`` and ``W`` must be
            powers of 2. Leading axes (e.g. batch, channel) are broadcast over.
        inverse: If True, apply the inverse transform.

    Returns:
        Float array with the same shape as ``x``.

    Raises:
        ValueError: if ``x`` has fewer than 2 dimensions.
        AssertionError: if ``H`` or ``W`` is not a power of 2.
    """
    x = np.asarray(x)
    if x.ndim < 2:
        raise ValueError(f"expected an array with at least 2 dims, got shape {x.shape}")

    H, W = x.shape[-2:]
    WH = haar_matrix_builder(H)
    WW = haar_matrix_builder(W)

    # inverse if true: orthonormal matrices are inverted by transposing
    if inverse:
        WH, WW = WH.T, WW.T

    # matmul broadcasts the (H, H) and (W, W) matrices over any leading axes
    return WH @ x @ WW.T

In [22]:
def haar_transform_2d_qiskit(x, inverse=False):
    """Compute the 2D Haar transform of one image by simulating a quantum circuit.

    Steps:

    1. Zero-pad the image to power-of-2 sides.
    2. L2-normalize it and amplitude-encode it into ``n_row + n_col`` qubits.
    3. Apply the column Haar matrix as a unitary on the low ``n_col`` qubits,
       and the row Haar matrix on the high ``n_row`` qubits.
    4. Read the final statevector back as the transformed, still normalized,
       image.

    Qiskit orders basis states little-endian, so this qubit assignment matches
    the row-major flattening ``index = row * W_pad + col``.

    Args:
        x: Array of shape ``(1, 1, H, W)`` (batch and channel must both be 1).
        inverse: If True, apply the inverse (transposed) Haar matrices.

    Returns:
        Tuple ``(Yq, Yc, qc)``:

        * ``Yq``: complex array of shape ``(H_pad, W_pad)`` taken from the
          simulated statevector. It is the transform of the padded image
          divided by that image's L2 norm.
        * ``Yc``: float array of shape ``(H_pad, W_pad)``. It is the classical
          reference for the same (forward or inverse) transform, with the same
          normalization.
        * ``qc``: the ``QuantumCircuit`` that was simulated.

    Raises:
        AssertionError: if batch or channel size is not 1.
        ValueError: if ``x`` is all zeros (it cannot be amplitude-encoded).
    """
    # assert checks for batch and channel = 1
    B, C, H, W = x.shape
    assert B == 1 and C == 1, "function assumes B and C are = to 1"

    # pad h and w to the next power of 2
    H_pad_size = next_power_of_2(H)
    W_pad_size = next_power_of_2(W)
    X = np.zeros((H_pad_size, W_pad_size), dtype=float)
    X[:H, :W] = x[0, 0]  # new padded image

    psi = X.reshape(-1)  # row-major flatten: index = row * W_pad_size + col
    norm = np.linalg.norm(psi)
    if norm == 0:
        raise ValueError("Input is all zeros; cannot normalize for amplitude encoding.")
    amps = psi / norm  # unit-norm amplitudes for state preparation

    # find num qubits
    n_row = int(math.log2(H_pad_size))
    n_col = int(math.log2(W_pad_size))
    n_qubits = n_row + n_col

    # Little-endian: low qubits hold the column index, high qubits the row index
    col_qubits = list(range(n_col))
    row_qubits = list(range(n_col, n_col + n_row))

    # haar matrices; orthonormal, so the transpose is the inverse
    transform_H = haar_matrix_builder(H_pad_size)
    transform_W = haar_matrix_builder(W_pad_size)
    if inverse:
        transform_H = transform_H.T
        transform_W = transform_W.T

    # turn haar transform into operators
    U_row = Operator(transform_H)
    U_col = Operator(transform_W)

    # create q-circuit and encode amplitudes
    qc = QuantumCircuit(n_qubits)
    qc.initialize(amps, list(range(n_qubits)))

    # put the haar transforms into the circuit
    qc.append(U_col, col_qubits)
    qc.append(U_row, row_qubits)

    # extract amplitudes after applying haar
    sv = Statevector.from_instruction(qc).data

    # Reshape back to an image. The classical reference must honor `inverse`
    # too, otherwise Yq and Yc describe different transforms when inverse=True.
    Yq = sv.reshape(H_pad_size, W_pad_size)
    Yc = haar_transform_2d_classical(X, inverse=inverse) / norm

    return Yq, Yc, qc

In [23]:
# Adapted from the original HWTConv2D layer: the Haar transform is now computed with haar_transform_2d_qiskit.
class HWTConv2D(torch.nn.Module):
    """2D Haar Wavelet Conv layer - HWT-MA-Net

    Forward pipeline:

    1. Zero-pad the input to ``(height_pad, width_pad)``.
    2. Take the 2D Haar transform.
    3. For each pod: scale elementwise by ``v[i]``, mix channels with a
       bias-free 1x1 conv, then soft-threshold.
    4. Sum the pod outputs and apply the inverse Haar transform.
    5. Crop back to ``(height, width)``; if ``residual``, add the input.

    The Haar transforms are evaluated with the simulated circuit in
    ``haar_transform_2d_qiskit`` (see ``_haar_2d``).
    """

    def __init__(self, height, width, in_channels, out_channels, pods=1, residual=True):
        super().__init__()
        self.height = height
        self.width = width
        # NOTE(review): the Haar matrices require power-of-2 sizes; this assumes
        # find_min_power returns a power of 2 — confirm.
        self.height_pad = find_min_power(self.height)
        self.width_pad = find_min_power(self.width)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.pods = pods

        # 1x1 Convolution, Channel mixing per pod
        self.conv = torch.nn.ModuleList([torch.nn.Conv2d(in_channels, out_channels, 1, bias=False) for _ in range(self.pods)])

        # Scaling parameters per pod; shape (1, 1, H_pad, W_pad) broadcasts over batch and channels
        self.v = torch.nn.ParameterList([
            torch.nn.Parameter(torch.rand(1, 1, self.height_pad, self.width_pad))
            for _ in range(pods)
        ])

        # Soft-thresholding per pod
        self.ST = torch.nn.ModuleList([
            SoftThresholding((self.height_pad, self.width_pad)) for _ in range(pods)
        ])
        # NOTE(review): the residual `y + x` in forward requires out_channels == in_channels.
        self.residual = residual

    @staticmethod
    def _haar_2d(t, inverse=False):
        """2D Haar transform of every (batch, channel) slice of ``t`` via the Qiskit simulation.

        ``haar_transform_2d_qiskit`` handles a single (1, 1, H, W) image and
        returns a tuple of unit-norm amplitudes. Each slice's result is
        therefore rescaled by that slice's L2 norm to recover the unnormalized
        transform.

        All-zero slices cannot be amplitude-encoded. Because the transform is
        linear, they are mapped to zeros instead of raising.

        The returned values come from the simulation, but gradients flow
        through the equivalent dense linear map ``WH @ t @ WW.T``
        (straight-through). The NumPy/Qiskit path is invisible to autograd, so
        without this the trainable parameters would never receive gradients.

        Args:
            t: Tensor of shape (B, C, H, W) with power-of-2 H and W.
            inverse: If True, apply the inverse transform.

        Returns:
            Tensor of the same shape, dtype and device as ``t``.
        """
        H, W = t.shape[-2:]
        WH = torch.as_tensor(haar_matrix_builder(H), dtype=t.dtype, device=t.device)
        WW = torch.as_tensor(haar_matrix_builder(W), dtype=t.dtype, device=t.device)
        if inverse:
            WH, WW = WH.T, WW.T
        linear = WH @ t @ WW.T  # differentiable reference, broadcast over (B, C)

        data = t.detach().cpu().numpy().astype(float)
        simulated = np.zeros_like(data)
        for b in range(data.shape[0]):
            for c in range(data.shape[1]):
                image = data[b:b + 1, c:c + 1]  # keep the (1, 1, H, W) shape the helper expects
                scale = np.linalg.norm(image)
                if scale == 0:
                    continue  # transform of a zero image is zero
                yq, _, _ = haar_transform_2d_qiskit(image, inverse=inverse)
                simulated[b, c] = np.real(yq) * scale
        simulated = torch.as_tensor(simulated, dtype=t.dtype, device=t.device)

        # Forward value: simulated result. Backward: gradient of the linear map.
        return linear + (simulated - linear).detach()

    def forward(self, x):
        _, _, height, width = x.shape
        if height != self.height or width != self.width:
            raise ValueError(f'({height}, {width})!=({self.height}, {self.width})')

        # Pad to power of 2 (zeros on the bottom/right)
        pad_h = self.height_pad - height
        pad_w = self.width_pad - width
        if pad_h > 0 or pad_w > 0:
            f0 = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h))
        else:
            f0 = x

        # 2D Haar Transform
        f1 = self._haar_2d(f0)

        outputs = []
        for i in range(self.pods):
            f3 = f1 * self.v[i]  # Scaling
            f4 = self.conv[i](f3)  # 1x1 Conv
            f5 = self.ST[i](f4)  # Soft-threshold
            outputs.append(f5)

        # Sum paths
        f6 = torch.stack(outputs, dim=0).sum(dim=0)

        # Inverse Haar
        f7 = self._haar_2d(f6, inverse=True)

        # Crop
        y = f7[:, :, :height, :width]

        if self.residual:
            y = y + x
        return y

In [24]:
if __name__ == "__main__":
    # Sanity check: a random 28x28 image is zero-padded to 32x32 (5 + 5 = 10
    # qubits). Compare the simulated circuit against the classical transform.
    np.random.seed(0)
    X = np.random.rand(1, 1, 28, 28).astype(float)

    Yq, Yc, qc = haar_transform_2d_qiskit(X, inverse=False)

    # Real input and real orthogonal gates: only the real part is meaningful
    Yq_real = np.real(Yq)

    # error between quantum and classical implementation of haar transform
    err = np.max(np.abs(Yq_real - Yc))
    print("max abs error:", err)

    # print the top-left n_patch x n_patch corner of each result
    n_patch = 4
    print("Yq (real)\n", Yq_real[:n_patch, :n_patch])
    print("Yc\n", Yc[:n_patch, :n_patch])

max abs error: 3.3306690738754696e-16
Yq (real)
 [[ 0.75362786  0.09671871  0.00076583  0.14984047]
 [ 0.10176302  0.00197198  0.00101918  0.03519545]
 [ 0.00943346 -0.00157904 -0.00747259 -0.00415681]
 [ 0.1660058   0.06420613  0.02101443  0.01952859]]
Yc
 [[ 0.75362786  0.09671871  0.00076583  0.14984047]
 [ 0.10176302  0.00197198  0.00101918  0.03519545]
 [ 0.00943346 -0.00157904 -0.00747259 -0.00415681]
 [ 0.1660058   0.06420613  0.02101443  0.01952859]]
