In [1]:
import torch
import torch.nn as nn

def simulate_conv_shape(input_shape, conv_layers):
    """
    Simulates the output shape after passing through a series of Conv2D layers.

    A dummy tensor is pushed through freshly constructed ``nn.Conv2d`` layers
    on the CPU, so PyTorch itself computes every output shape (this also covers
    tuple ``kernel_size``/``stride``/``padding`` values and string padding such
    as ``'same'``/``'valid'``).

    Args:
        input_shape (tuple): Input shape as (batch_size, channels, height, width),
            or unbatched (channels, height, width).
        conv_layers (list): List of dictionaries containing Conv2D parameters.
            Each dict must have ``out_channels`` and ``kernel_size``; ``stride``
            (default 1), ``padding`` (default 0), ``dilation`` (default 1) and
            ``groups`` (default 1) are optional.

    Returns:
        list: Shape at each layer. Element 0 is ``input_shape`` itself and
        element ``i + 1`` is the output shape (a tuple) of ``conv_layers[i]``.
    """
    device = torch.device('cpu')
    shapes = [input_shape]

    # Tensor values are irrelevant here -- only the shape propagates.
    x = torch.zeros(input_shape, device=device)

    # We only need shapes: without no_grad, autograd would record a graph that
    # keeps every intermediate activation alive until the function returns.
    with torch.no_grad():
        for layer_params in conv_layers:
            conv = nn.Conv2d(
                # Channels are the third-from-last dim for both batched
                # (N, C, H, W) and unbatched (C, H, W) inputs.
                in_channels=x.shape[-3],
                out_channels=layer_params['out_channels'],
                kernel_size=layer_params['kernel_size'],
                stride=layer_params.get('stride', 1),
                padding=layer_params.get('padding', 0),
                dilation=layer_params.get('dilation', 1),
                groups=layer_params.get('groups', 1),
            ).to(device)

            x = conv(x)
            shapes.append(tuple(x.shape))

    return shapes

In [None]:
# Example usage: 3x3 "same" convs interleaved with three 4x4 / stride-2
# downsampling convs, so a 256x256 input should end up at 32x32.
input_shape = (1, 1, 256, 256)  # batch_size=1, channels=1, height=256, width=256
conv_layers = [
    {'out_channels': 64, 'kernel_size': 3, 'stride': 1, 'padding': 1},
    {'out_channels': 128, 'kernel_size': 3, 'stride': 1, 'padding': 1},
    {'out_channels': 128, 'kernel_size': 4, 'stride': 2, 'padding': 1},
    {'out_channels': 256, 'kernel_size': 3, 'stride': 1, 'padding': 1},
    {'out_channels': 256, 'kernel_size': 4, 'stride': 2, 'padding': 1},
    {'out_channels': 256, 'kernel_size': 3, 'stride': 1, 'padding': 1},
    {'out_channels': 256, 'kernel_size': 4, 'stride': 2, 'padding': 1},
    {'out_channels': 4, 'kernel_size': 3, 'stride': 1, 'padding': 1},
]

shapes = simulate_conv_shape(input_shape, conv_layers)
# Note: "Layer 0" is the input shape itself; "Layer i" is the output of conv_layers[i - 1].
for i, shape in enumerate(shapes):
    print(f"Layer {i}: {shape}")

# Sanity check: three stride-2 downsamples take 256 -> 32, and the last conv has 4 channels.
assert shapes[-1] == (1, 4, 32, 32), shapes[-1]