# Notations and Parameters

$C$: Number of channels

$H, W$: Height and Width of input image

$M$: denotes the number of kernels / filters

$h$: is the size of the kernels

$s$: stride

$N$: Maximum value to meet constraint MCHW <= N

$p$: Modulus for coefficients 

## Constraints: 
$MCHW ≤ N $

$O$ a constant -> $O = HW(MC − 1) +W(h − 1) + h − 1$ 

$ h ≤ H,W$

$ stride s > 0$

$H' = int((H − h + s) / s), W' = int((W − h + s) / s)$

Public parameters (pp): $pk, M, C,H, W, h, s$

output size = (M * C, H_out ,W_out)

## Input image

$π_{conv}^{i} \mathbb{Z}_{p}^{C×H×W} → A_{N,p)} $


$t̃ = π_{conv}^{i}(T)$ 

$t̃[cHW +iW + j] = T[c,i, j]$ 


# Kernel / filter

$π_{conv}^{w} \mathbb{Z}_{p}^{M×C×h×h} → A_{(N,p)} $

$k̃ = π_{conv}^{w}(K)$

$k̃[O − c'CHW − cHW − lW − l'] = K[c', c, l, l']$

# The convolution 
For each position $(c', i', j')$ of the output tensor $T'$

$T' = Conv2D(T,K ;s) =  T'[c', i', j'] = \sum\limits_{c \in [[C]], l,l' \in [[h]]} T[c,i's+l, j's+ l'] * K[c', c, l,l']$

With: $T[c,i's+l, j's+ l'] = t̃[cHW + (i's+l) * W + (j's+ l')]$

With: $K[c,i's+l, j's+ l'] = k̃[O - c' CHW - cHW - l * W - l_prime]$



In [1]:
import numpy as np
import torch
import torch.nn.functional as F
from scipy.signal import convolve2d

In [2]:
# Notations and Parameters
C = 3  # Number of channels
H, W = 4, 4  # Height and Width of input image
M = 2  # Number of kernels / filters

h = 2  # Size of kernels
s = 1  # Stride
N = 100  # Maximum value to meet constraint MCHW <= N
p = 29  # Modulus for coefficients (example)
modulo = 0 # Padding 

# ---------
H_out = (H - h + s) // s  # Height of output, <!> padding = 0
W_out = (W - h + s) // s  # Width of output, <!> padding = 0
O = H * W * (M * C - 1) + W * (h - 1) + h - 1

# Public parameters according to the paper
pp = {"pk": "public_key", "M": M, "C": C, "H": H, "W": W, "h": h, "s": s}

In [3]:
def pytorch_conv(T, K, modulo):
    """Conv2D with torch"""

    T_torch = torch.tensor(T, dtype=torch.float32).unsqueeze(0)  # Shape (1, C, H, W)
    K_torch = torch.tensor(K, dtype=torch.float32)  # Shape (M, C, h, h)

    T_out_torch = F.conv2d(T_torch, K_torch, stride=s).int().numpy() # Shape (M, H_out, W_out)

    return T_out_torch % modulo

def scipy_conv(T, K, H_out, W_out, modulo):
    """*Conv2D with scipy convolve2d for each kernel and channel"""

    T_out_scipy = np.zeros((M, H_out, W_out), dtype=int)

    for m in range(M):
        conv_sum = np.zeros((H_out, W_out))
        for c in range(C):
            # <!> Mirror the kernel
            conv_sum += convolve2d(T[c], np.flip(K[m, c]), mode='valid')[::s, ::s]
        T_out_scipy[m] = conv_sum % modulo

    return T_out_scipy

In [4]:
# Input image tensor T of shape (C, H, W)
T = np.array([[[1, 2, 3, 4],
               [5, 6, 7, 8],
               [9, 10, 11, 12],
               [13, 14, 15, 16]],

              [[2, 3, 4, 5],
               [6, 7, 8, 9],
               [10, 11, 12, 13],
               [14, 15, 16, 17]],

              [[1, 1, 1, 1],
               [1, 1, 1, 1],
               [1, 1, 1, 1],
               [1, 1, 1, 1]]])

# Kernel tensor K of shape (M, C, h, h)
K = np.array([[[[1, 0],
                [0, -1]],

               [[1, 1],
                [1, 0]],

               [[0, 1],
                [-1, 0]]],

              [[[1, -1],
                [1, 0]],

               [[0, 1],
                [1, -1]],

               [[1, 0],
                [0, 1]]]])

T.shape, K.shape

((3, 4, 4), (2, 3, 2, 2))

In [5]:
def pytorch_conv(T, K, modulo):
    """Convert numpy arrays to PyTorch tensors"""
    T_torch = torch.tensor(T, dtype=torch.float32).unsqueeze(0)  # Shape (1, C, H, W)
    K_torch = torch.tensor(K, dtype=torch.float32)  # Shape (M, C, h, h)

    T_out_torch = F.conv2d(T_torch, K_torch, stride=s).int().numpy() # Shape (M, H_out, W_out)

    return T_out_torch % modulo

def scipy_conv(T, K, H_out, W_out, modulo):
    """Perform convolution using scipy convolve2d for each kernel and channel"""
    T_out_scipy = np.zeros((M, H_out, W_out), dtype=int)
    for m in range(M):
        conv_sum = np.zeros((H_out, W_out))
        for c in range(C):
            conv_sum += convolve2d(T[c], np.flip(K[m, c]), mode='valid')[::s, ::s]
        T_out_scipy[m] = conv_sum % modulo

    return T_out_scipy

In [6]:
def pi_conv_i(T):
    """t̃[cHW + iW + j] = T[c,i, j]"""
    C, H, W = T.shape  

    t_tilde = {}
    detailed_t_tilde = {}

    for c in range(C):
        for i in range(H):
            for j in range(W):
                new_index = c * H * W + i * W + j
                t_tilde[new_index] = T[c, i, j]
                detailed_t_tilde[f"{(c,i,j)} -> {new_index}"] = T[c, i, j]
    return t_tilde, detailed_t_tilde


def pi_conv_w(K):
    """k̃[O - c'CHW - cHW - lW - l'] = K[c', c, l, l']$"""
    M, C, k_h, k_w = K.shape

    k_tilde = {}
    detailed_k_tilde = {}

    for c_prime in range(M):
        for c in range(C):
            for l in range(k_h):
                for l_prime in range(k_w):
                    new_index = O - c_prime * C * H * W - c * H * W - l * W - l_prime
                    k_tilde[new_index] = K[c_prime, c, l, l_prime]
                    detailed_k_tilde[f"{(c_prime, c, l, l_prime)} -> {new_index}"] = K[c_prime, c, l, l_prime]

    return k_tilde, detailed_k_tilde


def conv2d_polynomial(t_tilde, k_tilde):
    """ T' = Conv2D(T,K ;s) =  T'[c', i', j'] = \sum\limits_{c \in [[C]], l,l' \in [[h]]} T[c,i's+l, j's+ l'] * K[c', c, l,l']

        With: $T[c,i's+l, j's+ l'] = t̃[cHW + (i's+l) * W + (j's+ l')]$

        With: $K[c,i's+l, j's+ l'] = k̃[O - c' CHW - cHW - l * W - l_prime]$
    """
    T_out = np.zeros((M, H_out, W_out), dtype=int)
    for c_prime in range(M):
        for i_prime in range(H_out):
            for j_prime in range(W_out):
                value = 0
                for c in range(C):
                    for l in range(h):
                        for l_prime in range(h):
                            i = i_prime * s + l
                            j = j_prime * s + l_prime
                            t_index = c * H * W + i * W + j
                            k_index = O - c_prime * C * H * W - c * H * W - l * W - l_prime
                            value += t_tilde[t_index] * k_tilde[k_index]
                T_out[c_prime, i_prime, j_prime] = value % p
    return T_out

In [7]:
torch_conv2d_tensor = pytorch_conv(T, K, modulo=p)
scipy_conv2d_tensor = scipy_conv(T, K, H_out, W_out, modulo=p)

assert all((torch_conv2d_tensor == scipy_conv2d_tensor).flatten())

In [8]:
t_tilde, detailed_t_tilde = pi_conv_i(T)
k_tilde, detailed_k_tilde = pi_conv_w(K)

T_out = conv2d_polynomial(t_tilde, k_tilde)

torch_conv2d_tensor, T_out

(array([[[[ 6,  9, 12],
          [18, 21, 24],
          [ 1,  4,  7]],
 
         [[ 8, 10, 12],
          [16, 18, 20],
          [24, 26, 28]]]], dtype=int32),
 array([[[ 6,  9, 12],
         [18, 21, 24],
         [ 1,  4,  7]],
 
        [[ 8, 10, 12],
         [16, 18, 20],
         [24, 26, 28]]]))