In [11]:
import torch
from collections import defaultdict
import torch.nn.functional as F

def conv_bitwise_counts(
    x_pt: torch.Tensor,
    weights: torch.Tensor,
    input_bits: int = 8,
    weight_bits: int = 5,
    stride: int = 1,
    padding: int = 0,
    per_output_channel: bool = False,
):
    """
    Count bit pairs ((1,1), (1,0), (0,1), (0,0)) seen by a bit-serial convolution.

    For every multiplication ``x * w`` a 2-D convolution would perform, every
    input bit ``i`` in ``range(input_bits)`` is paired with every weight bit
    ``j`` in ``range(weight_bits)``.  Each pair is tallied, per input bit,
    under the key ``"(<input bit>,<weight bit>)"``.

    The tallies are computed in closed form rather than by visiting each
    multiplication: for a fixed input bit, the number of (1,1) pairs is
    ``sum(ones_in_x_per_tap * popcount_of_w_per_tap)``, and the remaining
    three counts follow from totals.  All arithmetic is done in int64, so the
    result is exact.

    Parameters
    ----------
    x_pt : torch.Tensor
        Input tensor of shape (B, C_in, H, W).  Values are converted to int64
        (truncating, like ``int()``) before bits are extracted; the intended
        dtype is torch.uint8.
    weights : torch.Tensor
        Weight tensor of shape (C_out, C_in, KH, KW), converted the same way.
    input_bits : int
        Number of low-order input bits to inspect (default 8).
    weight_bits : int
        Number of low-order weight bits to inspect (default 5).
    stride : int
        Convolution stride, expected to be a positive integer (default 1).
    padding : int
        Zero-padding applied to all four spatial borders when > 0 (default 0).
    per_output_channel : bool
        If True, also return per-output-channel statistics.

    Returns
    -------
    global_counts : dict[int, dict[str, int]]
        ``global_counts[i][key]`` is the number of pairs with input bit ``i``
        matching ``key``, summed over batch, output channels, output
        positions, input channels, kernel taps and weight bits.
    (global_counts, out_counts) : tuple
        Returned instead of the bare dict when ``per_output_channel=True``.
        ``out_counts[oc][i][key]`` holds the same tallies restricted to
        output channel ``oc``.

    Raises
    ------
    AssertionError
        If the channel dimensions of ``x_pt`` and ``weights`` differ.
    """

    B, C_in, H, W = x_pt.shape
    C_out, C_in_w, KH, KW = weights.shape
    assert C_in == C_in_w, "Input and weight channel dimensions must match"

    # Output spatial dimensions (standard convolution arithmetic).
    H_out = (H + 2 * padding - KH) // stride + 1
    W_out = (W + 2 * padding - KW) // stride + 1

    keys = ("(1,1)", "(1,0)", "(0,1)", "(0,0)")

    def new_bitdict():
        return dict.fromkeys(keys, 0)

    global_counts = {i: new_bitdict() for i in range(input_bits)}
    out_counts = (
        {oc: {i: new_bitdict() for i in range(input_bits)} for oc in range(C_out)}
        if per_output_channel
        else None
    )

    # With no output positions or no weight bits there is nothing to pair up;
    # the zero-initialised tables are already the answer.
    has_work = B > 0 and H_out > 0 and W_out > 0 and weight_bits > 0

    if has_work:
        if padding > 0:
            x_pt = F.pad(
                x_pt, (padding, padding, padding, padding), mode="constant", value=0
            )

        # int64 on CPU: exact integer arithmetic, well-defined shifts, and no
        # device mismatch between input and weights.
        x_int = x_pt.detach().to(device="cpu", dtype=torch.int64)
        w_int = weights.detach().to(device="cpu", dtype=torch.int64)

        # Shifts are clamped to 63: for int64 that already yields the sign
        # fill, which is what an unbounded Python-int shift would give.
        w_ones = torch.zeros_like(w_int)
        for w_bit in range(weight_bits):
            w_ones += (w_int >> min(w_bit, 63)) & 1
        w_ones_per_oc = w_ones.sum(dim=(1, 2, 3))  # (C_out,)

        n_pos = B * H_out * W_out  # output positions sharing each kernel tap
        pairs_per_channel = n_pos * C_in * KH * KW * weight_bits

        # Extent of input rows/cols touched by one kernel tap over all outputs.
        h_span = (H_out - 1) * stride + 1
        w_span = (W_out - 1) * stride + 1

        for i_bit in range(input_bits):
            plane = (x_int >> min(i_bit, 63)) & 1  # (B, C_in, H_pad, W_pad)

            # x_ones[ic, ky, kx]: how many output positions see a set input
            # bit under kernel tap (ic, ky, kx).
            x_ones = torch.zeros((C_in, KH, KW), dtype=torch.int64)
            for ky in range(KH):
                for kx in range(KW):
                    window = plane[
                        :, :, ky:ky + h_span:stride, kx:kx + w_span:stride
                    ]
                    x_ones[:, ky, kx] = window.sum(dim=(0, 2, 3))

            c11 = (w_ones * x_ones).sum(dim=(1, 2, 3))  # (C_out,)
            # Every set input bit meets weight_bits weight bits in total.
            c10 = weight_bits * x_ones.sum() - c11
            # Every set weight bit meets n_pos input bits in total.
            c01 = n_pos * w_ones_per_oc - c11
            c00 = pairs_per_channel - c11 - c10 - c01

            per_channel = torch.stack((c11, c10, c01, c00), dim=1)  # (C_out, 4)
            global_counts[i_bit] = dict(zip(keys, per_channel.sum(dim=0).tolist()))
            if per_output_channel:
                for oc, row in enumerate(per_channel.tolist()):
                    out_counts[oc][i_bit] = dict(zip(keys, row))

    return (global_counts, out_counts) if per_output_channel else global_counts

if __name__ == "__main__":
    # Minimal demo: an all-zero 1x3x7x7 input and an all-zero 4x3x3x3 kernel
    # bank, each with a single element set to 1, so the printed tallies are
    # easy to check by hand.
    x_pt = torch.zeros((1, 3, 7, 7), dtype=torch.uint8)
    weights = torch.zeros((4, 3, 3, 3), dtype=torch.uint8)
    x_pt[0, 0, 0, 0] = 1      # the only non-zero input pixel
    weights[0, 0, 0, 0] = 1   # the only non-zero weight

    global_counts, out_counts = conv_bitwise_counts(
        x_pt,
        weights,
        input_bits=8,
        weight_bits=5,
        stride=1,
        padding=0,
        per_output_channel=True,
    )

    # Each report is a heading followed by one line per input-bit "instant".
    reports = (
        ("=== Global bitwise counts per instant ===", global_counts),
        ("\n=== Per-output-channel (example: channel 0) ===", out_counts[0]),
    )
    for heading, table in reports:
        print(heading)
        for inst, counts in table.items():
            print(f"Instant {inst}: {counts}")


=== Global bitwise counts per instant ===
Instant 0: {'(1,1)': 1, '(1,0)': 19, '(0,1)': 24, '(0,0)': 13456}
Instant 1: {'(1,1)': 0, '(1,0)': 0, '(0,1)': 25, '(0,0)': 13475}
Instant 2: {'(1,1)': 0, '(1,0)': 0, '(0,1)': 25, '(0,0)': 13475}
Instant 3: {'(1,1)': 0, '(1,0)': 0, '(0,1)': 25, '(0,0)': 13475}
Instant 4: {'(1,1)': 0, '(1,0)': 0, '(0,1)': 25, '(0,0)': 13475}
Instant 5: {'(1,1)': 0, '(1,0)': 0, '(0,1)': 25, '(0,0)': 13475}
Instant 6: {'(1,1)': 0, '(1,0)': 0, '(0,1)': 25, '(0,0)': 13475}
Instant 7: {'(1,1)': 0, '(1,0)': 0, '(0,1)': 25, '(0,0)': 13475}

=== Per-output-channel (example: channel 0) ===
Instant 0: {'(1,1)': 1, '(1,0)': 4, '(0,1)': 24, '(0,0)': 3346}
Instant 1: {'(1,1)': 0, '(1,0)': 0, '(0,1)': 25, '(0,0)': 3350}
Instant 2: {'(1,1)': 0, '(1,0)': 0, '(0,1)': 25, '(0,0)': 3350}
Instant 3: {'(1,1)': 0, '(1,0)': 0, '(0,1)': 25, '(0,0)': 3350}
Instant 4: {'(1,1)': 0, '(1,0)': 0, '(0,1)': 25, '(0,0)': 3350}
Instant 5: {'(1,1)': 0, '(1,0)': 0, '(0,1)': 25, '(0,0)': 3350}
Inst