In [None]:
from qiskit import QuantumCircuit, QuantumRegister
#from qiskit_aer import AerSimulator
from qiskit.quantum_info import Statevector, Operator
import torch
#import torch.nn.functional as F
import numpy as np
import math
from common import SoftThresholding, find_min_power

# Print floats in fixed-point notation with 8 decimals (tiny values show as 0 instead of e-notation).
np.set_printoptions(precision=8, suppress=True)

In [None]:
def next_power_of_2(n: int) -> int:
    """Return the smallest power of two that is >= ``n``.

    Powers of two map to themselves (32 -> 32); other values round up
    (28 -> 32, 33 -> 64). For ``n <= 1`` the answer is 1 (= 2**0).

    Args:
        n: target size, e.g. an image height or width.

    Returns:
        The smallest power of two greater than or equal to ``n`` (at least 1).
    """
    if n <= 1:
        # (n - 1).bit_length() would give 1 for n == 0 (-> 2) and garbage for n < 0
        return 1
    # (n - 1).bit_length() bits are needed for n - 1, so 1 << that is the first
    # power of two strictly greater than n - 1, i.e. the first one >= n
    return 1 << (n - 1).bit_length()

def haar_matrix_builder(n: int) -> np.ndarray:
    """Build the n x n orthonormal Haar matrix (n must be a power of two).

    Built bottom-up: starting from the 1x1 identity, each doubling step creates
    one Haar level whose top half holds pairwise averages and bottom half
    pairwise differences (both scaled by 1/sqrt(2)). The previous matrix is then
    applied to the averages via blockdiag(W_prev, I) @ level.

    Args:
        n: matrix size; a positive power of two.

    Returns:
        float ndarray of shape (n, n).

    Raises:
        AssertionError: if n is not a positive power of two.
    """
    assert n > 0 and (n & (n - 1)) == 0, "n must be power of 2"

    inv_sqrt2 = 1.0 / np.sqrt(2.0)
    transform = np.array([[1.0]], dtype=float)
    size = 1
    while size < n:
        size *= 2
        half = size // 2
        pair = np.arange(half)

        # one Haar level: rows [0, half) average neighbours, rows [half, size) difference them
        level = np.zeros((size, size), dtype=float)
        level[pair, 2 * pair] = inv_sqrt2
        level[pair, 2 * pair + 1] = inv_sqrt2
        level[half + pair, 2 * pair] = inv_sqrt2
        level[half + pair, 2 * pair + 1] = -inv_sqrt2

        # recurse on the averages only: blockdiag(previous transform, identity)
        embed = np.eye(size, dtype=float)
        embed[:half, :half] = transform
        transform = embed @ level

    return transform

def haar_transform_2d_classical(x: np.ndarray, inverse=False) -> np.ndarray:
    """Separable 2D Haar transform computed with dense matrices.

    Rows are transformed by the HxH Haar matrix (left multiply) and columns by
    the WxW Haar matrix (right multiply by its transpose).

    Args:
        x: 2D array of shape (H, W); H and W must be powers of two.
        inverse: if True, apply the inverse transform using the transposed matrices.
            This relies on the Haar matrices being orthogonal.

    Returns:
        ndarray of shape (H, W) with the (inverse) Haar coefficients.
    """
    n_rows, n_cols = x.shape
    row_basis = haar_matrix_builder(n_rows)
    col_basis = haar_matrix_builder(n_cols)

    if not inverse:
        return row_basis @ x @ col_basis.T
    # inverse: W_H^T @ x @ W_W (transpose of each factor)
    return row_basis.T @ x @ col_basis

In [None]:
def apply_x_to_statevector(sv: np.ndarray, n_qubits: int, q: int) -> np.ndarray:
    """Apply a Pauli-X (bit flip) on tensor axis ``q`` of a flat statevector.

    The statevector is viewed as a C-order tensor of shape (2,)*n_qubits, so
    axis ``q`` is bit ``n_qubits - 1 - q`` of the flat index (axis 0 is the most
    significant bit). Qiskit labels qubit 0 as the least significant bit, so
    axis ``q`` here corresponds to Qiskit qubit ``n_qubits - 1 - q``.

    Args:
        sv: flat statevector of length 2**n_qubits.
        n_qubits: number of qubits.
        q: tensor axis to flip.

    Returns:
        Flat array with the |0> and |1> amplitudes along axis ``q`` swapped.
    """
    tensor = sv.reshape((2,) * n_qubits)
    # X swaps the two amplitudes along the qubit's axis == reversing that size-2 axis
    flipped = np.flip(tensor, axis=q)
    return flipped.reshape(-1)

def apply_z_to_statevector(sv: np.ndarray, n_qubits: int, q: int) -> np.ndarray:
    """Apply a Pauli-Z (phase flip) on tensor axis ``q`` of a flat statevector.

    Uses the same axis convention as ``apply_x_to_statevector``: the flat vector
    is viewed as a C-order tensor of shape (2,)*n_qubits and axis ``q`` is bit
    ``n_qubits - 1 - q`` of the flat index.

    The input array is never modified. ``reshape``/``swapaxes`` return views,
    so negating through them directly would write into the caller's ``sv``.

    Args:
        sv: flat statevector of length 2**n_qubits.
        n_qubits: number of qubits.
        q: tensor axis whose |1> amplitudes are negated.

    Returns:
        New flat array with the amplitudes where axis ``q`` == 1 multiplied by -1.
    """
    psi = np.array(sv, copy=True).reshape((2,) * n_qubits)
    # swapaxes gives a view of our private copy; negating its [1, ...] slab updates psi
    np.swapaxes(psi, q, 0)[1, ...] *= -1
    return psi.reshape(-1)

def apply_local_pauli_noise(sv: np.ndarray, n_qubits: int, p: float, rng=None) -> np.ndarray:
    """Apply independent random single-qubit Pauli errors to a statevector.

    For every qubit (tensor axis 0 .. n_qubits-1, in order) one uniform draw is
    taken from ``rng``. With probability 1-p nothing happens. Otherwise X, Y or Z
    is applied, each with probability p/3. Y is realized as i * X @ Z.

    Args:
        sv: flat statevector of length 2**n_qubits; it is copied, not modified.
        n_qubits: number of qubits.
        p: total error probability per qubit (larger p = more noise).
        rng: numpy Generator; a fresh unseeded one is created when None.

    Returns:
        A new flat statevector with the sampled Pauli errors applied.
    """
    rng = np.random.default_rng() if rng is None else rng
    noisy = sv.copy()

    # cumulative probability thresholds: [0, keep) I, [keep, x_end) X, [x_end, y_end) Y, rest Z
    keep = 1 - p
    x_end = keep + p / 3
    y_end = keep + 2 * p / 3

    for qubit in range(n_qubits):
        draw = rng.random()
        if draw < keep:
            continue
        if draw < x_end:
            noisy = apply_x_to_statevector(noisy, n_qubits, qubit)
        elif draw < y_end:
            # Y = i X Z: apply Z first, then X, then the global factor i
            z_applied = apply_z_to_statevector(noisy, n_qubits, qubit)
            noisy = 1j * apply_x_to_statevector(z_applied, n_qubits, qubit)
        else:
            noisy = apply_z_to_statevector(noisy, n_qubits, qubit)

    return noisy

def haar_noise(sv_ideal, H_pad, W_pad, n_qubits, p=0.01, trials=100, seed=0):
    """Monte-Carlo summary of local Pauli noise on a (Haar-transformed) statevector.

    Runs ``trials`` independent realizations of ``apply_local_pauli_noise`` on
    ``sv_ideal``. Each noisy vector is reshaped to the (H_pad, W_pad) image layout.

    Note: this averages statevector amplitudes, not density matrices. The signs
    and phases introduced by the sampled Pauli errors (including the factor i of
    Y) therefore enter the mean directly.

    Args:
        sv_ideal: flat statevector of length H_pad * W_pad (== 2**n_qubits).
        H_pad, W_pad: layout used to reshape each noisy statevector.
        n_qubits: number of qubits passed to the noise model.
        p: per-qubit error probability (X, Y, Z each with p/3).
        trials: number of noisy samples; must be >= 1.
        seed: seed of the single Generator shared by all trials (reproducible).

    Returns:
        (mean_Y, std_Y), both of shape (H_pad, W_pad):
        mean_Y is the element-wise mean of the complex noisy amplitudes;
        std_Y is the element-wise std of their magnitudes |amplitude|.

    Raises:
        ValueError: if trials < 1.
    """
    if trials < 1:
        # np.stack would otherwise fail with an opaque "need at least one array" error
        raise ValueError(f"trials must be >= 1, got {trials}")

    rng = np.random.default_rng(seed)
    # one noisy realization per trial, stacked along a new leading axis
    Ys = np.stack([
        apply_local_pauli_noise(sv_ideal, n_qubits, p, rng=rng).reshape(H_pad, W_pad)
        for _ in range(trials)
    ], axis=0)

    mean_Y = Ys.mean(axis=0)
    std_Y = np.abs(Ys).std(axis=0)
    return mean_Y, std_Y

In [None]:
def haar_transform_2d_qiskit(x, inverse=False, noise=0.01, trials=500, seed=0):
    """2D Haar transform of a single image via amplitude encoding in a Qiskit circuit.

    The image is zero-padded to power-of-two height/width, flattened row-major,
    L2-normalized and loaded with ``initialize``. The flat index is
    row * W_pad + col and Qiskit is little-endian, so the low-order qubits hold
    the column index and the high-order qubits hold the row index. The column
    Haar matrix is applied to the former and the row Haar matrix to the latter.

    Args:
        x: array-like of shape (1, 1, H, W); only B == 1 and C == 1 are supported.
        inverse: if True, apply the inverse (transposed) Haar matrices. The
            classical reference ``Yc`` uses the same direction.
        noise: per-qubit Pauli error probability in [0, 1] (X, Y, Z each noise/3).
        trials: number of noisy samples averaged by ``haar_noise``.
        seed: RNG seed forwarded to ``haar_noise``.

    Returns:
        (Yq_ideal, Yq_mean, Yq_std, Yc, qc):
        Yq_ideal: ideal complex amplitudes reshaped to (H_pad, W_pad);
        Yq_mean: mean complex amplitudes over the noisy trials;
        Yq_std: std of the amplitude magnitudes over the noisy trials;
        Yc: classical transform of the padded image divided by the same L2 norm,
            so it is directly comparable to Yq_ideal;
        qc: the constructed QuantumCircuit.

    Raises:
        AssertionError: if B or C != 1, or noise is outside [0, 1].
        ValueError: if the input is all zeros (cannot amplitude-encode).
    """
    # only a single image (batch 1, channel 1) is supported
    B, C, H, W = x.shape
    assert B == 1 and C == 1, "function assumes B and C are = to 1"
    assert noise <= 1 and noise >= 0, "noise range: [0,1]"

    # zero-pad height and width up to the next power of 2
    H_pad_size = next_power_of_2(H)
    W_pad_size = next_power_of_2(W)
    X = np.zeros((H_pad_size, W_pad_size), dtype=float)
    X[:H, :W] = x[0, 0]

    # amplitude encoding: row-major flatten, then L2-normalize
    psi = X.reshape(-1)
    norm = np.linalg.norm(psi)
    if norm == 0:
        raise ValueError("Input is all zeros; cannot normalize for amplitude encoding.")
    amps = psi / norm

    # qubit layout: low-order qubits = column index, high-order qubits = row index
    n_row = int(math.log2(H_pad_size))
    n_col = int(math.log2(W_pad_size))
    n_qubits = n_row + n_col
    col_qubits = list(range(n_col))
    row_qubits = list(range(n_col, n_col + n_row))

    transform_H = haar_matrix_builder(H_pad_size)
    transform_W = haar_matrix_builder(W_pad_size)
    if inverse:
        # the Haar matrices are orthonormal, so the transpose is the inverse
        transform_H = transform_H.T
        transform_W = transform_W.T

    U_row = Operator(transform_H)
    U_col = Operator(transform_W)

    qc = QuantumCircuit(n_qubits)
    qc.initialize(amps, list(range(n_qubits)))
    qc.append(U_col, col_qubits)
    qc.append(U_row, row_qubits)

    # ideal (noise-free) output amplitudes, then a Monte-Carlo noisy estimate
    sv_ideal = Statevector.from_instruction(qc).data
    Yq_ideal = sv_ideal.reshape(H_pad_size, W_pad_size)
    Yq_mean, Yq_std = haar_noise(
        sv_ideal, H_pad_size, W_pad_size, n_qubits, p=noise, trials=trials, seed=seed
    )

    # classical reference in the SAME direction as the circuit (forward or inverse),
    # scaled by 1/norm to match the normalized amplitudes
    Yc = haar_transform_2d_classical(X, inverse=inverse) / norm

    return Yq_ideal, Yq_mean, Yq_std, Yc, qc

In [None]:
# Same HWT layer as before; the only change is that the Haar-transform calls now use haar_transform_2d_qiskit.
class HWTConv2D(torch.nn.Module):
    """2D Haar Wavelet Conv layer - HWT-MA-Net

    Pipeline (see ``forward``): zero-pad to (height_pad, width_pad) -> 2D Haar
    transform -> per pod [elementwise scale by ``v`` -> 1x1 conv -> soft
    threshold] -> sum over pods -> inverse 2D Haar -> crop -> optional residual.

    NOTE(review): this layer does not run end to end as written.
    ``haar_transform_2d_qiskit`` has four problems here:
      * it returns the 5-tuple ``(Yq_ideal, Yq_mean, Yq_std, Yc, qc)`` rather than
        a tensor, so ``f1 * self.v[i]`` would fail;
      * it asserts B == 1 and C == 1;
      * it works on numpy arrays, so no autograd graph flows through it;
      * it divides its input by the input's L2 norm.
    A batched, tensor-returning (ideally differentiable) wrapper is needed
    before this layer is usable — TODO.

    Args:
        height, width: expected spatial size of the input; ``forward`` raises otherwise.
        in_channels, out_channels: channels of each per-pod 1x1 convolution.
        pods: number of parallel scale/conv/threshold paths whose outputs are summed.
        residual: if True, add the input ``x`` to the output (shapes presumably only
            match when out_channels == in_channels — confirm).
    """
    def __init__(self, height, width, in_channels, out_channels, pods=1, residual=True):
        super().__init__()
        self.height = height
        self.width = width
        # find_min_power comes from `common`; presumably the smallest power of two
        # >= the size (like next_power_of_2 here) — TODO confirm
        self.height_pad = find_min_power(self.height)
        self.width_pad = find_min_power(self.width)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.pods = pods
        
        # 1x1 Convolution, Channel mixing per pod (no bias)
        self.conv = torch.nn.ModuleList([torch.nn.Conv2d(in_channels, out_channels, 1, bias=False) for _ in range(self.pods)])
        
        # Scaling parameters per pod: one learnable weight per padded pixel,
        # shape (1, 1, height_pad, width_pad), initialized uniformly on [0, 1)
        self.v = torch.nn.ParameterList([
            torch.nn.Parameter(torch.rand(1, 1, self.height_pad, self.width_pad))
            for _ in range(pods)
        ])
        
        # Soft-thresholding per pod (SoftThresholding is defined in `common`;
        # its thresholding parameterization is not visible here)
        self.ST = torch.nn.ModuleList([
            SoftThresholding((self.height_pad, self.width_pad)) for _ in range(pods)
        ])
        self.residual = residual
        
    def forward(self, x):
        """Apply the HWT block to ``x`` of shape (B, C_in, height, width).

        Raises:
            Exception: if the input's spatial size differs from (height, width).

        NOTE(review): the two haar_transform_2d_qiskit calls below do not return
        tensors (see class docstring), so this method does not run as written.
        """
        B, C_in, height, width = x.shape
        if height != self.height or width != self.width:
            raise Exception(f'({height}, {width})!=({self.height}, {self.width})')
        
        # Pad to power of 2: (0, pad_w, 0, pad_h) zero-pads width on the right
        # and height at the bottom
        pad_h = self.height_pad - height
        pad_w = self.width_pad - width
        if pad_h > 0 or pad_w > 0:
            f0 = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h))
        else:
            f0 = x
        
        # 2D Haar Transform
        # NOTE(review): returns (Yq_ideal, Yq_mean, Yq_std, Yc, qc), asserts B == 1
        # and C_in == 1, and L2-normalizes f0 — f1 is not a tensor usable below.
        # TODO: wrap to produce a (B, C, height_pad, width_pad) tensor.
        f1 = haar_transform_2d_qiskit(f0)
        
        outputs = []
        for i in range(self.pods):
            f3 = f1 * self.v[i]  # Scaling (elementwise, broadcast over batch/channels)
            f4 = self.conv[i](f3)  # 1x1 Conv
            f5 = self.ST[i](f4)  # Soft-threshold
            outputs.append(f5)
        
        # Sum paths over pods
        f6 = torch.stack(outputs, dim=0).sum(dim=0)
        
        # Inverse Haar
        # NOTE(review): same tuple/normalization issue as the forward call; f6 is
        # built from learnable parameters, so converting it into the function's
        # numpy buffer likely fails for a tensor that requires grad — confirm.
        f7 = haar_transform_2d_qiskit(f6, inverse=True)
        
        # Crop back to the original spatial size
        y = f7[:, :, :height, :width]
        
        if self.residual:
            # requires y and x to have the same shape (out_channels == in_channels)
            y = y + x
        return y

In [None]:
if __name__ == "__main__":
    # Demo on a random 28x28 image; haar_transform_2d_qiskit zero-pads it to
    # 32x32, i.e. 5 row qubits + 5 column qubits.
    rng = np.random.default_rng(0)  # fixed seed so the demo output is reproducible
    X = rng.random((1, 1, 28, 28))

    # noise = per-qubit Pauli error probability in [0, 1] (X, Y, Z each with noise/3)
    Yq_ideal, Yq_mean, Yq_std, Yc, qc = haar_transform_2d_qiskit(X, inverse=False, noise=0.01)

    # max |quantum - classical| for the ideal statevector (expected ~0), and the
    # max deviation of the noise-averaged amplitudes from the ideal ones
    err_ideal = np.max(np.abs(np.real(Yq_ideal) - Yc))
    err_noise = np.max(np.abs(np.real(Yq_ideal) - Yq_mean))
    print("max ideal error:", err_ideal)
    print("max noisy error:", err_noise)

    # print the top-left n_patch x n_patch corner of each matrix
    n_patch = 4
    print("\nYq (ideal)\n", np.real(Yq_ideal[:n_patch, :n_patch]))
    print("\nYq (mean with noise)\n", np.real(Yq_mean[:n_patch, :n_patch]))
    print("Sensitivity to noise:\n", Yq_std[:n_patch, :n_patch])  # per-element std of |amplitude| across noisy trials
    print("\nYc\n", Yc[:n_patch, :n_patch])