In [1]:
def compute_conv_output_size(
        input_size,   # (H, W)
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1
    ):
    """
    Compute the output shape of a Conv2D-like layer.

    Args:
        input_size: (H, W) spatial size of the input.
        in_channels: number of input channels. Accepted so the signature mirrors
            a conv layer; it does not influence the result.
        out_channels: number of output channels (returned unchanged as out_C).
        kernel_size, stride, padding, dilation: either one integer applied to
            both spatial dimensions, or an (h, w) pair. Any ``numbers.Integral``
            instance (built-in int, NumPy integer scalars, ...) counts as a
            single integer.

    Returns:
        Tuple ``(out_C, out_H, out_W, total_features)`` where
        ``total_features = out_C * out_H * out_W``.

    Raises:
        ValueError: if a pair argument does not hold exactly two values, or if
            the computed output height/width is not positive.
        TypeError: if an argument is neither integer-like nor iterable.
    """
    # Local import keeps this cell self-contained.
    from numbers import Integral

    def to_pair(value):
        # Integer-like scalars are duplicated for both spatial dimensions;
        # anything else must unpack into exactly two values.
        if isinstance(value, Integral):
            scalar = int(value)
            return scalar, scalar
        first, second = value
        return first, second

    # unpack input H, W
    H, W = input_size

    # allow int or pair for kernel/stride/padding/dilation
    KH, KW = to_pair(kernel_size)
    SH, SW = to_pair(stride)
    PH, PW = to_pair(padding)
    DH, DW = to_pair(dilation)

    # output height and width, using the formula documented for torch.nn.Conv2d
    out_H = ((H + 2*PH - DH*(KH - 1) - 1) // SH) + 1
    out_W = ((W + 2*PW - DW*(KW - 1) - 1) // SW) + 1

    # Two negative sizes would multiply into a plausible-looking positive
    # feature count, so reject invalid configurations explicitly.
    if out_H <= 0 or out_W <= 0:
        raise ValueError(
            f"Computed output size ({out_H}, {out_W}) is not positive; "
            "the (dilated) kernel does not fit in the padded input."
        )

    # output channels always = out_channels
    out_C = out_channels

    # total number of features
    total_features = out_C * out_H * out_W

    return (out_C, out_H, out_W, total_features)


In [6]:
# Sanity check: 3x3 kernel, stride 1, no padding on a 32x32, 3-channel input.
print(
    compute_conv_output_size(
        input_size=(32, 32), in_channels=3, out_channels=10,
        kernel_size=3, stride=1, padding=0,
    )
)

(10, 30, 30, 9000)


# Loading Pretrained Models

In [None]:
import torch

def load_pretrained_weights(model, state_dict_path=None, pretrained_model=None, strict=False):
    """
    Copy pretrained weights into ``model``, skipping entries that are missing
    from the target or whose shapes differ.

    Args:
        model: nn.Module - the target model that receives the weights.
        state_dict_path: str - path to a saved checkpoint (optional). If the
            loaded object contains a 'state_dict' key, that entry is used.
            Takes precedence over ``pretrained_model`` when both are given.
        pretrained_model: nn.Module - another model whose ``state_dict()``
            supplies the weights (optional).
        strict: bool - forwarded to ``model.load_state_dict``. Note that the
            dict passed there is the target's own state dict with the matching
            entries replaced, so it always has exactly the target's keys.

    Returns:
        The same ``model`` object, with matching entries overwritten.

    Raises:
        ValueError: if neither ``state_dict_path`` nor ``pretrained_model`` is given.
    """

    if state_dict_path:
        # SECURITY: torch.load is pickle-based; only load checkpoints from a
        # trusted source (consider `weights_only=True` where the installed
        # PyTorch version supports it).
        pretrained_dict = torch.load(state_dict_path, map_location='cpu')
        if 'state_dict' in pretrained_dict:
            pretrained_dict = pretrained_dict['state_dict']
    elif pretrained_model is not None:
        # Compare against None rather than truthiness: modules that define
        # __len__ (e.g. an empty nn.Sequential) are falsy but still valid.
        pretrained_dict = pretrained_model.state_dict()
    else:
        raise ValueError("Either state_dict_path or pretrained_model must be provided.")

    model_dict = model.state_dict()

    # Keep only entries present in the target with an identical shape
    filtered_dict = {}
    for k, v in pretrained_dict.items():
        if k not in model_dict:
            print(f"Skipping layer {k}: not found in target model")
        elif model_dict[k].shape != v.shape:
            print(f"Skipping layer {k}: size mismatch {v.shape} vs {model_dict[k].shape}")
        else:
            filtered_dict[k] = v

    # Overlay the matched entries on the target's own state and load the result
    model_dict.update(filtered_dict)
    model.load_state_dict(model_dict, strict=strict)
    print(f"Loaded {len(filtered_dict)} / {len(model_dict)} layers from pretrained weights.")

    return model


In [None]:
import torchvision.models as models

# Donor network: torchvision's AlexNet requested with pretrained=True.
# NOTE(review): newer torchvision releases appear to prefer a `weights=` argument
# over `pretrained=` — confirm against the installed version.
pretrained_alexnet = models.alexnet(pretrained=True)

# Target network: AlexNet (defined outside this excerpt) built for
# 1 input channel and 10 output classes.
alex_cnn = AlexNet(in_channels=1, num_classes=10)

# Transfer every entry whose name and shape agree between donor and target.
alex_cnn = load_pretrained_weights(model=alex_cnn, pretrained_model=pretrained_alexnet)


In [None]:
# Alternative source: weights read from a checkpoint file on disk.
checkpoint_path = "alexnet_cifar10.pth"
alex_cnn = load_pretrained_weights(model=alex_cnn, state_dict_path=checkpoint_path)
