In [11]:
import numpy as np
from pynq import Overlay
from pynq import allocate
from time import perf_counter

# Bitstream to load; the hardware hand-off file shares its path, with a
# .hwh extension in place of .bit.
BITFILE = "./srcnn.bit"
HWHFILE = BITFILE.replace(".bit", ".hwh")

# Host-side element types mirroring the accelerator's C typedefs
# (ftmap_t for feature maps, param_t for weights and biases).
DT_FTMAP = DT_PARAM = np.float32

# Compile-time constants of the design. They are used further down as the
# dimensions of the buffers handed to the kernel:
#   N*  -> leading (channel) dimensions, F* -> square kernel sizes,
#   H/W -> feature-map height and width.
N0 = 1
N1 = 64
N2 = 32
N3 = 1
F1 = 9
F2 = 1
F3 = 5
H = W = 255

# Load the overlay described by BITFILE.
ol = Overlay(BITFILE)

# List the IP blocks the overlay exposes, then grab the SRCNN kernel.
ip = ol.ip_dict
print([*ip.keys()])
srcnn = ol.srcnn_0

# Dump the kernel's register map so the argument register names are visible.
print(srcnn.register_map)



##########################################################################################
# Utility functions:
##########################################################################################
from pathlib import Path
import numpy as np

def _ensure_size(fname: str | Path, nelems: int, itemsize: int):
    nbytes = Path(fname).stat().st_size
    if nbytes != nelems * itemsize:
        raise ValueError(
            f"{fname}: size mismatch. File has {nbytes} bytes, "
            f"expected {nelems*itemsize} for the requested shape."
        )

def read_image_u8(fname, shape, out_dtype=np.float32, normalize=True):
    """
    Read a raw uint8 image file and return it as an array of ``out_dtype``.

    Intended to mirror the C-side load_image(): one byte per element,
    optionally scaled from [0, 255] to [0, 1].

    Parameters
    ----------
    fname : path-like
        Raw binary file holding exactly ``prod(shape)`` bytes.
    shape : tuple of int
        Target shape, e.g. (N0, H, W) or (H, W).
    out_dtype : dtype-like
        Element type of the result. Anything ``np.dtype`` accepts works
        (``np.float32``, ``"float32"``, ``np.dtype("float32")``, ...).
    normalize : bool
        When True, divide by 255 so values land in [0, 1].

    Raises
    ------
    ValueError
        If the file size does not match ``shape`` (from ``_ensure_size``).
    """
    nelems = int(np.prod(shape))
    _ensure_size(fname, nelems, 1)
    a = np.fromfile(fname, dtype=np.uint8, count=nelems).reshape(shape)
    if normalize:
        # Build the divisor from the scalar type behind out_dtype so the
        # division stays in that precision. Calling out_dtype(...) directly
        # only works for scalar classes and fails for strings/np.dtype.
        scale = np.dtype(out_dtype).type(255.0)
        return a.astype(out_dtype) / scale
    return a.astype(out_dtype)

def read_float_bin(fname, shape, endian="<", out_dtype=np.float32):
    """
    Load a raw file of 32-bit IEEE floats into an array of the given shape.

    Intended to mirror the C-side load_param()/load_ftmap() helpers.

    fname     : file containing exactly prod(shape) 4-byte floats.
    shape     : shape of the returned array.
    endian    : byte order of the file, '<' (little) or '>' (big).
    out_dtype : dtype of the returned array; no copy is made when the data
                already has that dtype.

    Raises ValueError (via _ensure_size) if the file size does not match.
    """
    count = int(np.prod(shape))
    _ensure_size(fname, count, 4)
    flat = np.fromfile(fname, dtype=np.dtype(endian + "f4"), count=count)
    return flat.reshape(shape).astype(out_dtype, copy=False)

def write_image_u8(fname, img_01):
    """
    Quantize a [0, 1] image to uint8 and dump the raw bytes to ``fname``.

    Intended to mirror the C-side write_bin(): values are scaled by 255,
    rounded to the nearest integer (np.rint), clamped to [0, 255] and
    written with ndarray.tofile (no header).
    """
    scaled = img_01 * 255.0
    quantized = np.clip(np.rint(scaled), 0, 255)
    quantized.astype(np.uint8).tofile(fname)

def mse(a, b):
    """Mean squared error of two array-likes, evaluated in float64."""
    delta = np.asarray(a, dtype=np.float64) - np.asarray(b, dtype=np.float64)
    return np.mean(np.square(delta))

def psnr_from_mse(mse_val, peak=1.0):
    """
    Convert a mean squared error into PSNR in dB.

    peak is the maximum signal value: 1.0 for [0, 1] floats, 255.0 for
    uint8-scale data. A zero error yields +inf.
    """
    if mse_val != 0:
        return 10.0 * np.log10((peak * peak) / mse_val)
    return float("inf")

def psnr_float_01(img1, img2, peak=1.0):
    """
    PSNR in dB between two images, computed entirely in float64.

    peak is the maximum signal value (1.0 when the images live in [0, 1]).
    Identical images yield +inf.
    """
    first = np.asarray(img1, dtype=np.float64)
    second = np.asarray(img2, dtype=np.float64)
    delta = first - second
    mean_sq = np.mean(delta * delta)
    return float("inf") if mean_sq == 0 else 10.0 * np.log10((peak * peak) / mean_sq)

def psnr_u8_like_c(img1, img2):
    """
    PSNR in dB computed on the uint8-quantized images, matching the C code:

      rmse = sqrt(mean( (uint8(img1*255) - uint8(img2*255))^2 ))
      psnr = 20*log10(255/rmse)

    Both inputs are array-likes expected to hold values in [0, 1]. They are
    scaled by 255 and cast to uint8 (a truncating cast, no rounding and no
    clipping). Returns +inf when the quantized images are identical.
    """
    a = (np.asarray(img1) * 255.0).astype(np.uint8, copy=False)
    b = (np.asarray(img2) * 255.0).astype(np.uint8, copy=False)

    # Widen to int32 before subtracting AND squaring. uint8 would wrap on the
    # subtraction, and int16 is still too narrow for the square: any
    # |diff| >= 182 gives diff*diff > 32767, which wraps negative and turns
    # the result into NaN or a bogus PSNR.
    diff = a.astype(np.int32) - b.astype(np.int32)
    rmse = np.sqrt(np.mean(diff * diff))
    if rmse == 0:
        return float("inf")
    return 20.0 * np.log10(255.0 / rmse)


def set_ptr64(ip, base_name, addr):
    """
    Write a 64-bit address into the register(s) of ``ip`` named ``base_name``.

    Three layouts are tried, in this order:
      1. a single register map attribute called ``base_name`` - receives the
         full integer address;
      2. a pair ``base_name_1`` / ``base_name_2`` - receive the low and high
         32-bit words respectively;
      3. otherwise, raw ``ip.write`` calls at the ``.offset`` of
         ``base_name_1`` (low word) and that offset + 4 (high word).

    NOTE(review): path 3 is only reached when at least one of the two halves
    is missing; if ``base_name_1`` itself is absent the getattr raises
    AttributeError.
    """
    regs = ip.register_map
    value = int(addr)
    low_word = value & 0xFFFFFFFF
    high_word = (value >> 32) & 0xFFFFFFFF
    lo_name = base_name + "_1"
    hi_name = base_name + "_2"

    if hasattr(regs, base_name):
        # Combined 64-bit field exposed under the plain argument name.
        setattr(regs, base_name, value)
    elif hasattr(regs, lo_name) and hasattr(regs, hi_name):
        # Usual layout: two 32-bit halves.
        setattr(regs, lo_name, low_word)
        setattr(regs, hi_name, high_word)
    else:
        # Last resort: address the two words by offset.
        base_offset = getattr(regs, lo_name).offset
        ip.write(base_offset, low_word)
        ip.write(base_offset + 4, high_word)



['srcnn_0', 'zynq_ultra_ps_e_0']
RegisterMap {
  CTRL = Register(AP_START=0, AP_DONE=0, AP_IDLE=1, AP_READY=0, RESERVED_1=0, AUTO_RESTART=0, RESERVED_2=0, INTERRUPT=0, RESERVED_3=0),
  GIER = Register(Enable=0, RESERVED=0),
  IP_IER = Register(CHAN0_INT_EN=0, CHAN1_INT_EN=0, RESERVED_0=0),
  IP_ISR = Register(CHAN0_INT_ST=0, CHAN1_INT_ST=0, RESERVED_0=0),
  input_ftmap_1 = Register(input_ftmap=write-only),
  input_ftmap_2 = Register(input_ftmap=write-only),
  conv1_weights_1 = Register(conv1_weights=write-only),
  conv1_weights_2 = Register(conv1_weights=write-only),
  conv1_biases_1 = Register(conv1_biases=write-only),
  conv1_biases_2 = Register(conv1_biases=write-only),
  conv2_weights_1 = Register(conv2_weights=write-only),
  conv2_weights_2 = Register(conv2_weights=write-only),
  conv2_biases_1 = Register(conv2_biases=write-only),
  conv2_biases_2 = Register(conv2_biases=write-only),
  conv3_weights_1 = Register(conv3_weights=write-only),
  conv3_weights_2 = Register(conv3_weights

In [15]:
# End-to-end run of the SRCNN kernel: allocate buffers, load the image and
# weights, program the argument registers, start the kernel, wait for it to
# finish, then compare the output against reference files.

# Allocate physically-contiguous DDR buffers the PL can DMA from/to.
# One buffer per kernel argument; shapes come from the compile-time constants.
input_ftmap  = allocate((N0, H, W),           dtype=DT_FTMAP)   # image
output_ftmap = allocate((N3, H, W),           dtype=DT_FTMAP)   # result
conv1_w      = allocate((N1, N0, F1, F1),     dtype=DT_PARAM)
conv1_b      = allocate((N1,),                dtype=DT_PARAM)
conv2_w      = allocate((N2, N1, F2, F2),     dtype=DT_PARAM)
conv2_b      = allocate((N2,),                dtype=DT_PARAM)
conv3_w      = allocate((N3, N2, F3, F3),     dtype=DT_PARAM)
conv3_b      = allocate((N3,),                dtype=DT_PARAM)

# Input image: raw uint8 file, normalized to [0,1] float32 by read_image_u8.
img_host       = read_image_u8("./set5/butterfly_3x_LR_u8.bin",       (N0, H, W))         # uint8→[0,1] float32


# Trained SRCNN weights and biases, stored as raw little-endian float32.
w1_host        = read_float_bin("./weights/conv1_weights_3x_flp.bin", (N1, N0, F1, F1))   # float32
b1_host        = read_float_bin("./weights/conv1_biases_3x_flp.bin",  (N1,))
w2_host        = read_float_bin("./weights/conv2_weights_3x_flp.bin", (N2, N1, F2, F2))
b2_host        = read_float_bin("./weights/conv2_biases_3x_flp.bin",  (N2,))
w3_host        = read_float_bin("./weights/conv3_weights_3x_flp.bin", (N3, N2, F3, F3))
b3_host        = read_float_bin("./weights/conv3_biases_3x_flp.bin",  (N3,))

# Copy each host array into its buffer, then flush that buffer. The flush is
# expected to make the host-side writes visible to the PL before it starts.
input_ftmap[:] = img_host;  input_ftmap.flush()
conv1_w[:]     = w1_host;   conv1_w.flush()
conv1_b[:]     = b1_host;   conv1_b.flush()
conv2_w[:]     = w2_host;   conv2_w.flush()
conv2_b[:]     = b2_host;   conv2_b.flush()
conv3_w[:]     = w3_host;   conv3_w.flush()
conv3_b[:]     = b3_host;   conv3_b.flush()



# ------------------------- Program kernel args --------------------------
# NOTE(review): `rm` is only referenced by the commented-out lines below;
# the live code programs the registers through set_ptr64 instead.
rm = srcnn.register_map  # has fields named after your top-level args

# # Most PYNQ builds expose 64-bit address fields with the same names as your ports.
# # You can either set by name (preferred) or by raw offsets. Setting by name:
# rm.input_ftmap      = input_ftmap.physical_address
# rm.conv1_weights    = conv1_w.physical_address
# rm.conv1_biases     = conv1_b.physical_address
# rm.conv2_weights    = conv2_w.physical_address
# rm.conv2_biases     = conv2_b.physical_address
# rm.conv3_weights    = conv3_w.physical_address
# rm.conv3_biases     = conv3_b.physical_address
# rm.output_ftmap     = output_ftmap.physical_address

# Hand each buffer's physical address to the matching argument register.
set_ptr64(srcnn, "input_ftmap",   input_ftmap.physical_address)
set_ptr64(srcnn, "conv1_weights", conv1_w.physical_address)
set_ptr64(srcnn, "conv1_biases",  conv1_b.physical_address)
set_ptr64(srcnn, "conv2_weights", conv2_w.physical_address)
set_ptr64(srcnn, "conv2_biases",  conv2_b.physical_address)
set_ptr64(srcnn, "conv3_weights", conv3_w.physical_address)
set_ptr64(srcnn, "conv3_biases",  conv3_b.physical_address)
set_ptr64(srcnn, "output_ftmap",  output_ftmap.physical_address)



# If your hwh names differ (e.g., suffixed with _1), print(rm) and set those names instead.

# srcnn.register_map = rm  # write back

# ----------------------------- Run kernel ------------------------------
# The measured interval covers the start write plus the Python polling loop.
t0 = perf_counter()
srcnn.write(0x00, 1)              # ap_start = 1 (bit0). PYNQ also supports srcnn.start()
# srcnn.start()

# Wait for ap_done (bit1). Simple poll; you can also use interrupts.
# NOTE(review): this busy-waits with no timeout, so the cell spins forever
# if the done bit never gets set.
while (srcnn.read(0x00) & 0x2) == 0:
    pass
t1 = perf_counter()
print(f"Kernel time: {(t1 - t0)*1e3:.2f} ms")

# Result coherency: invalidate caches for buffers the PL wrote
output_ftmap.invalidate()

# ----------------------------- Use results -----------------------------
print("Output stats:", float(output_ftmap.min()), float(output_ftmap.max()))
# e.g., visualize a 9x9 crop around the image center
print(output_ftmap[0, H//2-4:H//2+5, W//2-4:W//2+5])


# Golden reference output (float32 file) for the same input
ref  = read_float_bin("./set5/butterfly_3x_GR_flp.bin",       (N3, H, W))         # float32
# Ground truth HR (not GR!), read from a uint8 file and normalized to [0,1]
ref_GT  = read_image_u8("./set5/butterfly_3x_GT_u8.bin",       (N3, H, W))         # uint8

# MSE is taken against the golden reference (ref), while the PSNR below is
# taken against the ground-truth image (ref_GT) - two different baselines.
m = mse(output_ftmap[0], ref[0])
print("MSE:", m)
# print("PSNR (dB):", psnr_float_01(output_ftmap[0], ref[0]))
print("PSNR (dB):", psnr_u8_like_c(output_ftmap[0], ref_GT[0]))

Kernel time: 77647.21 ms
Output stats: 0.08662913739681244 0.977342426776886
[[0.7266981  0.7240178  0.72445774 0.7226668  0.7225674  0.7230912
  0.7208     0.71854    0.71898794]
 [0.72910124 0.72463864 0.7231202  0.7249268  0.7271924  0.7239278
  0.71885383 0.7138936  0.72397584]
 [0.7287495  0.72689694 0.7266812  0.72745234 0.724705   0.7246985
  0.7230487  0.71978873 0.72489774]
 [0.7310186  0.7300781  0.7348327  0.73387605 0.7303059  0.72948897
  0.7293809  0.7257134  0.72433853]
 [0.7308904  0.7304646  0.73089266 0.73061395 0.7329797  0.73287374
  0.73143315 0.7264671  0.7260677 ]
 [0.73488337 0.7344304  0.7329245  0.73406315 0.7327407  0.7316062
  0.72867894 0.7248599  0.72475165]
 [0.736313   0.73322266 0.7322284  0.73650783 0.7350246  0.73226756
  0.734739   0.72863656 0.72687864]
 [0.732189   0.7276774  0.7261607  0.73400575 0.73594004 0.73027235
  0.73333234 0.73003614 0.72810024]
 [0.730495   0.7308389  0.73272574 0.7360512  0.73663616 0.7365305
  0.73331636 0.73051053 0.72