In [79]:
from math import ceil, floor

def build_decoder_layers(strides):
    layers = []
    
    layers.append(
        {"name": "DecInputConv", "type": "conv", "k": 7, "s": 1, "d": 1, "p": 3}
        )
    
    for i, stride in enumerate(strides, 1):
        upsample_k = 2 * stride
        upsample_p = ceil(stride / 2)
        
        layers.append(
            {
            "name": f"DecBlock{i}-UpsampleConv", 
            "type": "Tconv", 
            "k": upsample_k, 
            "s": stride, 
            "d": 1, 
            "p": upsample_p
            }
        )
        
        for unit in range(1, 4):
            dilation = 2 ** (unit - 1)  # 1, 3, 9
            
            padding = dilation * (7 - 1) // 2
            
            # Conv1 in ResidualUnit
            layers.append({
                "name": f"DecBlock{i}-Unit{unit}-Conv1",
                "type": "conv", 
                "k": 7, 
                "s": 1, 
                "d": dilation, 
                "p": padding
            })
            
            # Conv2 in ResidualUnit
            layers.append({
                "name": f"DecBlock{i}-Unit{unit}-Conv2",
                "type": "conv", 
                "k": 1, 
                "s": 1, 
                "d": 1, 
                "p": 0
            })
    
    
    layers.append(
        {"name": "DecFinalConv", "type": "conv", "k": 7, "s": 1, "d": 1, "p": 3}
        )
    
    return layers

In [80]:
def calculate_receptive_field(layers, t_out=0):
    n = len(layers)
    
    jumps = []
    J = 1.0
    for layer in layers:
        jumps.append(J)
        if layer["type"] == "conv":
            J *= layer["s"]
        else:  # tconv
            J /= layer["s"]
    
    # Work backwards
    L = R = t_out
    for i in reversed(range(n)):
        layer = layers[i]
        k, s, d, p = layer["k"], layer["s"], layer["d"], layer["p"]
        J = jumps[i]
        
        if layer["type"] == "conv":
            L_new = L - p * J
            R_new = R + ((k - 1) * d - p) * J
        else:  # tconv
            L_new = ceil(L/J + (p - (k - 1) * d) / s) * J
            R_new = floor(R/J + p / s) * J
        
        rf_len = R_new - L_new + 1
        layer["jump"] = J
        layer["L"] = L_new
        layer["R"] = R_new
        layer["rf_len"] = rf_len
        
        L, R = L_new, R_new
    
    return layers

In [81]:
def print_results(layers):
    """Print receptive field results"""
    print(f"{'Layer':<15} {'Type':<6} {'k':<3} {'s':<3} {'p':<3} {'Jump':<10} {'L':<8} {'R':<8} {'RF_len':<8}")
    print("-" * 70)
    for layer in layers:
        print(f"{layer['name']:<15} {layer['type']:<6} {layer['k']:<3} {layer['s']:<3} {layer['p']:<3} "
              f"{layer['jump']:<10.4f} {layer['L']:<8.2f} {layer['R']:<8.2f} {layer['rf_len']:<8.2f}")

In [84]:
strides = [2, 8, 5, 6, 4]
layers = build_decoder_layers(strides)
print(f"Testing strides: {strides}")
result_layers = calculate_receptive_field(layers.copy())
print_results(result_layers)
print(f"\nFinal receptive field size: {result_layers[0]['L']} and {result_layers[0]['R']}")

Testing strides: [2, 8, 5, 6, 4]
Layer           Type   k   s   p   Jump       L        R        RF_len  
----------------------------------------------------------------------
DecInputConv    conv   7   1   3   1.0000     -16.00   15.00    32.00   
DecBlock1-UpsampleConv Tconv  4   2   1   1.0000     -13.00   12.00    26.00   
DecBlock1-Unit1-Conv1 conv   7   1   3   0.5000     -12.50   12.00    25.50   
DecBlock1-Unit1-Conv2 conv   1   1   0   0.5000     -11.00   10.50    22.50   
DecBlock1-Unit2-Conv1 conv   7   1   6   0.5000     -11.00   10.50    22.50   
DecBlock1-Unit2-Conv2 conv   1   1   0   0.5000     -8.00    7.50     16.50   
DecBlock1-Unit3-Conv1 conv   7   1   12  0.5000     -8.00    7.50     16.50   
DecBlock1-Unit3-Conv2 conv   1   1   0   0.5000     -2.00    1.50     4.50    
DecBlock2-UpsampleConv Tconv  16  8   4   0.5000     -2.00    1.50     4.50    
DecBlock2-Unit1-Conv1 conv   7   1   3   0.0625     -1.69    1.62     4.31    
DecBlock2-Unit1-Conv2 conv   1   1   

### Check DAC Sampling Ratio

In [None]:
import inspect
print(inspect.getsource(torch.nn.Conv1d))