In [1]:
def compact_param_count(
    num_feat: int,
    num_conv: int,
    upscale: int = 4,
    *,
    num_in_ch: int = 3,
    num_out_ch: int = 3,
) -> int:
    """
    Compute total number of parameters in the 'compact' network.

    The counted architecture is a plain stack of 3x3 convolutions (each with a
    bias term), every conv except the last followed by a PReLU with one
    learnable parameter per channel:

        conv(num_in_ch -> F) + PReLU
        (L - 1) x [conv(F -> F) + PReLU]
        conv(F -> num_out_ch * upscale**2)

    So ``num_conv`` (L) counts every conv that is followed by a PReLU,
    *including* the first (input) conv; the F -> F body has ``L - 1`` layers.
    Whatever consumes the final ``num_out_ch * upscale**2`` channels
    (presumably a pixel shuffle) is not counted here.

    NOTE(review): Real-ESRGAN's ``SRVGGNetCompact`` uses ``num_conv`` for the
    F -> F body convs only (first conv is extra, giving num_conv + 1 PReLUs).
    If that is the model being sized, pass ``num_conv + 1`` here -- confirm
    the intended meaning of L.

    Args:
        num_feat (int): number of feature channels (F), must be >= 1
        num_conv (int): number of conv+PReLU stages including the first
            conv (L), must be >= 1
        upscale (int): upscaling factor (default = 4), must be >= 1
        num_in_ch (int): input image channels (keyword-only, default = 3)
        num_out_ch (int): output image channels (keyword-only, default = 3)
    Returns:
        int: total number of parameters
    Raises:
        ValueError: if any size argument is < 1 (e.g. num_conv = 0 would
            otherwise yield a negative body parameter count).
    """
    for arg_name, value in (
        ("num_feat", num_feat),
        ("num_conv", num_conv),
        ("upscale", upscale),
        ("num_in_ch", num_in_ch),
        ("num_out_ch", num_out_ch),
    ):
        if value < 1:
            raise ValueError(f"{arg_name} must be >= 1, got {value}")

    def conv3x3(c_in: int, c_out: int) -> int:
        # 3x3 kernel weights for every (in, out) channel pair + one bias per output channel.
        return c_in * c_out * 3 * 3 + c_out

    F = num_feat
    L = num_conv
    out_hr = num_out_ch * upscale ** 2  # channels produced by the last conv

    first = conv3x3(num_in_ch, F)   # input -> F
    body = (L - 1) * conv3x3(F, F)  # remaining L-1 stages are F -> F
    last = conv3x3(F, out_hr)       # F -> out_hr, no activation after it
    prelu = L * F                   # L PReLUs, F parameters each

    return first + body + last + prelu


# Example usage: report the parameter count for a few (num_feat, num_conv) pairs.
for F, L in ((24, 8), (40, 6), (40, 7), (64, 8)):
    n_params = compact_param_count(F, L)
    print(f"F={F:2d}, L={L:2d} -> {n_params:,} parameters")

F=24, L= 8 -> 47,736 parameters
F=40, L= 6 -> 90,888 parameters
F=40, L= 7 -> 105,368 parameters
F=64, L= 8 -> 288,496 parameters


In [2]:
compact_param_count(num_feat=24, num_conv=8)

47736

In [20]:
compact_param_count(num_feat=38, num_conv=7)

95998

In [22]:
compact_param_count(num_feat=35, num_conv=8)

93848

In [2]:
compact_param_count(num_feat=34, num_conv=8)

89026