In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import sys
sys.path.append('..')
from models.resnet import BasicBlock, Bottleneck, ResNet, resnet18, resnet34, resnet50

In [2]:
# Smoke test: resnet18 maps a batch of 4 RGB 32x32 images to 10 logits per sample.
model = resnet18()
x = torch.randn(4, 3, 32, 32)
with torch.no_grad():  # shape check only; no autograd graph needed
    y = model(x)
print(y.shape)
assert y.shape == (4, 10), f"expected [4, 10], got {list(y.shape)}"
print("Params:", sum(p.numel() for p in model.parameters()))

torch.Size([4, 10])
Params: 11173962


In [3]:
# Smoke test: resnet50 should produce the same [batch, 10] logits shape as resnet18.
model = resnet50()
x = torch.randn(4, 3, 32, 32)
with torch.no_grad():  # shape check only; no autograd graph needed
    y = model(x)
print(y.shape)
assert y.shape == (4, 10), f"expected [4, 10], got {list(y.shape)}"
print("Params:", sum(p.numel() for p in model.parameters()))

torch.Size([4, 10])
Params: 23520842


In [4]:
# BasicBlock shape checks.
# 1) Same channels, default stride: output shape should equal input shape.
blk = BasicBlock(64, 64)
x = torch.randn(2, 64, 32, 32)
with torch.no_grad():  # shape check only; no autograd graph needed
    out = blk(x)
print(out.shape)
assert out.shape == (2, 64, 32, 32), f"expected [2, 64, 32, 32], got {list(out.shape)}"

# 2) 64 -> 128 channels with stride 2. An explicit `downsample` (1x1 conv, stride 2, + BN)
#    is passed so that its output matches the block's new channel count and halved
#    spatial size. Presumably it is applied to the shortcut branch; confirm in models/resnet.py.
down = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
    nn.BatchNorm2d(128),
)
blk2 = BasicBlock(64, 128, stride=2, downsample=down)
with torch.no_grad():  # reuse the same input `x` from case 1
    out2 = blk2(x)
print(out2.shape)
assert out2.shape == (2, 128, 16, 16), f"expected [2, 128, 16, 16], got {list(out2.shape)}"


torch.Size([2, 64, 32, 32])
torch.Size([2, 128, 16, 16])


In [5]:
# Build ResNet directly from BasicBlock with [2, 2, 2, 2] blocks per stage and
# check the classifier output shape.
model = ResNet(BasicBlock, [2, 2, 2, 2])
x = torch.randn(4, 3, 32, 32)
with torch.no_grad():  # shape check only; no autograd graph needed
    y = model(x)
print(y.shape)
assert y.shape == (4, 10), f"expected [4, 10], got {list(y.shape)}"

torch.Size([4, 10])


In [7]:
# Bottleneck with stride 2 and an explicit 64 -> 256 channel `downsample`
# (1x1 conv, stride 2, + BN).
downsample = nn.Sequential(
    nn.Conv2d(64, 256, kernel_size=1, stride=2, bias=False),
    nn.BatchNorm2d(256),
)
block_ds = Bottleneck(in_channels=64, out_channels=64, stride=2, downsample=downsample)
x = torch.randn(2, 64, 56, 56)
with torch.no_grad():  # shape check only; no autograd graph needed
    y = block_ds(x)
print(y.shape)
# out_channels=64 yields 256 output channels (the recorded output), so Bottleneck presumably
# expands channels 4x, which is why `downsample` targets 256. Confirm in models/resnet.py.
# Stride 2 halves the spatial size: 56 -> 28.
assert y.shape == (2, 256, 28, 28), f"expected [2, 256, 28, 28], got {list(y.shape)}"

torch.Size([2, 256, 28, 28])
