Tests for the self-implemented ResNet models. For why ResNet is re-implemented rather than taken from torchvision,
see `libs.models.resnet`.

In [1]:
import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
from PIL import Image

from libs.models.resnet import ResNet

  from .autonotebook import tqdm as notebook_tqdm


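Check that the self-implemented ResNet-18 is structurally identical to PyTorch's built-in torchvision ResNet-18 by comparing their parameter shapes. Both models are freshly initialized, so this compares architecture, not weights.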

In [2]:
# Build both ResNet-18 variants with 10 output classes: the self-implemented one
# and torchvision's reference built from BasicBlock with [2, 2, 2, 2] layers.
resnet = ResNet.make_resnet18(10)
resnet_ori = torchvision.models.ResNet(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], 10)

# Structural comparison: parameter shapes in registration order. Parameter names
# are deliberately not compared, since the two implementations may name layers differently.
# Comparing whole lists (instead of zipping pairs) also catches a mismatch in the
# NUMBER of parameter tensors; `zip` would silently stop at the shorter model.
shapes_custom = [p.size() for p in resnet.parameters()]
shapes_ref = [p.size() for p in resnet_ori.parameters()]

if shapes_custom != shapes_ref:
    raise Exception("Two models are not identical.")
else:
    print("Two models are identical.")

Two models are identical.


Smoke test: run a forward pass of the self-implemented model (with a 3×3 initial kernel) on a random batch of ten 3-channel 32×32 inputs.

In [3]:
# Seed first so both the random weight initialization and the random input batch
# (and hence the printed logits) are reproducible under Restart & Run All.
torch.manual_seed(0)

# init_kernel_size=3 presumably selects a small-input (CIFAR-style) stem for 32x32
# images; see libs.models.resnet to confirm.
model = ResNet.make_resnet18(num_classes=10, init_kernel_size=3)

# Forward-only smoke test: no gradients are needed, so skip building the autograd graph.
with torch.no_grad():
    logits = model(torch.rand(10, 3, 32, 32))

# One row of 10 class scores per input in the batch.
assert logits.shape == (10, 10), f"unexpected output shape {tuple(logits.shape)}"
logits

tensor([[-0.5548, -0.4115, -0.1389, -0.2831, -0.1696,  0.4623,  0.3616,  0.1232,
          0.4944,  0.1456],
        [-0.3215, -0.0894,  0.2850,  0.5918, -0.4268,  0.7596,  0.5146,  0.1531,
         -0.2377, -0.1255],
        [-0.6762, -1.2966,  0.1408,  0.2616,  0.0466,  0.1078,  0.2966,  0.3645,
         -0.0714, -0.0775],
        [-0.2356,  0.0014,  0.2276,  0.2469, -0.2372,  0.4808,  0.2287,  0.8208,
          0.1635,  0.3746],
        [-0.4932, -0.0779, -0.1026,  0.2825, -0.2175,  0.1204,  0.6633,  0.5851,
          0.4210,  0.0791],
        [ 0.0151, -1.0856,  0.2241,  0.4727,  0.0396,  0.6296,  0.7116,  0.9440,
          0.8056,  0.4021],
        [-0.3994, -0.2544,  0.3355,  0.3875, -0.2821,  0.3720,  0.1096,  0.1705,
          0.5432, -0.2086],
        [-0.7231, -0.3039,  0.5045, -0.3811,  0.1033,  0.2052,  0.5621,  0.3010,
         -0.0134, -0.0479],
        [-0.5354, -0.2594,  0.3958,  0.0047, -0.0433,  0.2814,  0.5826,  0.9028,
          0.3260,  0.3085],
        [-0.3721,  