# Analyzing dimensions: tracing tensor shapes through a small CNN on MNIST

In [1]:

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


In [2]:

# Load the MNIST training split (downloaded into ./data when missing) and
# pull a single batch out of it so its shape can be examined.
transform = transforms.ToTensor()
dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# A batch of 4, unshuffled: small enough to inspect, and the same samples on every run.
loader = DataLoader(dataset, shuffle=False, batch_size=4)
images, labels = next(iter(loader))

# Layout is (batch, channels, height, width).
print(f"Original input shape: {images.shape}")


Original input shape: torch.Size([4, 1, 28, 28])


In [3]:

class CNNInspect(nn.Module):
    """Small CNN that reports the tensor shape after every stage of its forward pass.

    The layer sizes are fixed for 1-channel 28x28 inputs: each of the two
    conv stages keeps the spatial size (3x3 kernel, padding 1) and is followed
    by a 2x2 max-pool that halves it, so the feature map goes
    28x28 -> 14x14 -> 7x7 before being flattened into ``64 * 7 * 7`` features.

    Args:
        verbose: When True (the default), ``forward`` prints the shape of the
            tensor after each stage. Pass False to run the network silently.
    """

    def __init__(self, verbose=True):
        super().__init__()
        self.verbose = verbose
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)  # 28x28 -> 28x28
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) # 14x14 -> 14x14
        self.pool = nn.MaxPool2d(2, 2)                           # halves spatial dim
        self.flatten = nn.Flatten()
        # 64 channels * 7 * 7 spatial positions; only valid for 28x28 inputs.
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Run the network on ``x`` and return the ``fc2`` output.

        Args:
            x: Input batch; the layer sizes expect shape (B, 1, 28, 28).

        Returns:
            Tensor of shape (B, 10) holding the raw ``fc2`` outputs.
        """
        self._report("Input", x)

        # Stage 1: conv -> ReLU -> pool
        x = self.conv1(x)
        self._report("After conv1", x)
        x = self.pool(torch.relu(x))
        self._report("After pool1", x)

        # Stage 2: conv -> ReLU -> pool
        x = self.conv2(x)
        self._report("After conv2", x)
        x = self.pool(torch.relu(x))
        self._report("After pool2", x)

        # Classifier head
        x = self.flatten(x)
        self._report("After flatten", x)
        x = self.fc1(x)
        self._report("After fc1", x)
        # Non-linearity between the two Linear layers: without it fc1 and fc2
        # compose into a single affine map. It does not change any shape.
        x = torch.relu(x)
        x = self.fc2(x)
        self._report("After fc2 (output)", x)
        return x

    def _report(self, stage, x):
        """Print ``"<stage>: <shape of x>"`` when verbose reporting is enabled."""
        if self.verbose:
            print(f"{stage}: {x.shape}")


In [4]:

# Build the inspection model on the selected device and push the sample batch
# through it once; forward() prints the tensor shape after each stage.
model = CNNInspect().to(device)
images = images.to(device)

# This pass is for shape inspection only, so skip autograd bookkeeping.
with torch.no_grad():
    _ = model(images)  # result discarded; only the printed shapes matter


Input: torch.Size([4, 1, 28, 28])
After conv1: torch.Size([4, 32, 28, 28])
After pool1: torch.Size([4, 32, 14, 14])
After conv2: torch.Size([4, 64, 14, 14])
After pool2: torch.Size([4, 64, 7, 7])
After flatten: torch.Size([4, 3136])
After fc1: torch.Size([4, 128])
After fc2 (output): torch.Size([4, 10])
