In [7]:
import torch
from torch import nn
from torchsummary import summary

In [10]:
class ResidualBlockPytorch(nn.Module):
  def __init__(self, num_channels, output_channels, strides=1, is_used_conv11=False, **kwargs):
    """
    num_channels: số kênh
    """
    super(ResidualBlockPytorch, self).__init__(**kwargs)
    self.is_used_conv11 = is_used_conv11
    self.conv1 = nn.Conv2d(num_channels, num_channels, padding=1, 
                           kernel_size=3, stride=1)
    self.batch_norm = nn.BatchNorm2d(num_channels)
    self.conv2 = nn.Conv2d(num_channels, num_channels, padding=1, 
                           kernel_size=3, stride=1)
    if self.is_used_conv11:
      self.conv3 = nn.Conv2d(num_channels, num_channels, padding=0, 
                           kernel_size=1, stride=1)
    # Last convolutional layer to reduce output block shape.
    self.conv4 = nn.Conv2d(num_channels, output_channels, padding=0, 
                           kernel_size=1, stride=strides)
    self.relu = nn.ReLU(inplace=True)
    
  def forward(self, X):
    if self.is_used_conv11:
      Y = self.conv3(X)
    else:
      Y = X
    X = self.conv1(X)
    X = self.relu(X)
    X = self.batch_norm(X)
    X = self.relu(X)
    X = self.conv2(X)
    X = self.batch_norm(X)
    X = self.relu(X+Y)
    X = self.conv4(X)
    return X

In [11]:
X = torch.rand((4, 1, 28, 28)) # shape=(batch_size, channels, width, height)
X = ResidualBlockPytorch(num_channels=1, output_channels=64, strides=2, is_used_conv11=True)(X)
print(X.shape)

torch.Size([4, 64, 14, 14])


In [12]:
class ResNet18PyTorch(nn.Module):
  def __init__(self, residual_blocks, output_shape):
    super(ResNet18PyTorch, self).__init__()
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=7, stride=2, padding=3)
    self.batch_norm = nn.BatchNorm2d(64)
    self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    self.relu = nn.ReLU()
    self.residual_blocks = nn.Sequential(*residual_blocks)
    self.global_avg_pool = nn.Flatten()
    self.dense = nn.Linear(in_features=512, out_features=output_shape)

  def forward(self, X):
    X = self.conv1(X)
    X = self.batch_norm(X)
    X = self.relu(X)
    X = self.max_pool(X)
    X = self.residual_blocks(X)
    X = self.global_avg_pool(X)
    X = self.dense(X)
    return X

In [13]:
residual_blocks = [
    # Two start conv mapping
    ResidualBlockPytorch(num_channels=64, output_channels=64, strides=2, is_used_conv11=False),
    ResidualBlockPytorch(num_channels=64, output_channels=64, strides=2, is_used_conv11=False),
    # Next three [conv mapping + identity mapping]
    ResidualBlockPytorch(num_channels=64, output_channels=128, strides=2, is_used_conv11=True),
    ResidualBlockPytorch(num_channels=128, output_channels=128, strides=2, is_used_conv11=False),
    ResidualBlockPytorch(num_channels=128, output_channels=256, strides=2, is_used_conv11=True),
    ResidualBlockPytorch(num_channels=256, output_channels=256, strides=2, is_used_conv11=False),
    ResidualBlockPytorch(num_channels=256, output_channels=512, strides=2, is_used_conv11=True),
    ResidualBlockPytorch(num_channels=512, output_channels=512, strides=2, is_used_conv11=False)
]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ptmodel = ResNet18PyTorch(residual_blocks, output_shape=10)
ptmodel.to(device)
summary(ptmodel, (1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 14, 14]           3,200
       BatchNorm2d-2           [-1, 64, 14, 14]             128
              ReLU-3           [-1, 64, 14, 14]               0
         MaxPool2d-4             [-1, 64, 7, 7]               0
            Conv2d-5             [-1, 64, 7, 7]          36,928
              ReLU-6             [-1, 64, 7, 7]               0
       BatchNorm2d-7             [-1, 64, 7, 7]             128
              ReLU-8             [-1, 64, 7, 7]               0
            Conv2d-9             [-1, 64, 7, 7]          36,928
      BatchNorm2d-10             [-1, 64, 7, 7]             128
             ReLU-11             [-1, 64, 7, 7]               0
           Conv2d-12             [-1, 64, 4, 4]           4,160
ResidualBlockPytorch-13             [-1, 64, 4, 4]               0
           Conv2d-14             [-1

In [14]:
import torch.optim as optim
import torch
import torchvision
import torchvision.transforms as transforms
import time

In [15]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.05), (0.05))])

trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,
                                          shuffle=True, num_workers=8)

testset = torchvision.datasets.MNIST(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=32,
                                         shuffle=False, num_workers=8)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(ptmodel.parameters(), lr=0.001, betas=(0.9, 0.99))

In [16]:
def acc(output, label):
    # output: (batch, num_output) float32 ndarray
    # label: (batch, ) int32 ndarray
    return (torch.argmax(output, axis=1)==label).float().mean()

In [17]:
for epoch in range(10):  # loop over the dataset multiple times
    total_loss = 0.0
    tic = time.time()
    tic_step = time.time()
    train_acc = 0.0
    valid_acc = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = ptmodel(inputs)
        train_acc += acc(outputs, labels)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        total_loss += loss.item()
        if i % 500 == 499:
          print("iter %d: loss %.3f, train acc %.3f in %.1f sec" % (
            i+1, total_loss/i, train_acc/i, time.time()-tic_step))
          tic_step = time.time()

    # calculate validation accuracy
    for i, data in enumerate(testloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        valid_acc += acc(ptmodel(inputs), labels)

    print("Epoch %d: loss %.3f, train acc %.3f, test acc %.3f, in %.1f sec" % (
            epoch, total_loss/len(trainloader), train_acc/len(trainloader),
            valid_acc/len(testloader), time.time()-tic))

print('Finished Training')

iter 500: loss 0.816, train acc 0.764 in 209.2 sec
iter 1000: loss 0.584, train acc 0.841 in 192.3 sec
iter 1500: loss 0.470, train acc 0.877 in 185.3 sec
Epoch 0: loss 0.416, train acc 0.892, test acc 0.972, in 745.2 sec
iter 500: loss 0.164, train acc 0.968 in 190.4 sec
iter 1000: loss 0.158, train acc 0.968 in 185.9 sec
iter 1500: loss 0.151, train acc 0.970 in 190.2 sec
Epoch 1: loss 0.146, train acc 0.970, test acc 0.979, in 717.1 sec
iter 500: loss 0.125, train acc 0.975 in 188.6 sec
iter 1000: loss 0.119, train acc 0.977 in 183.8 sec
iter 1500: loss 0.115, train acc 0.977 in 187.0 sec
Epoch 2: loss 0.113, train acc 0.977, test acc 0.984, in 713.0 sec
iter 500: loss 0.086, train acc 0.986 in 202.3 sec
iter 1000: loss 0.088, train acc 0.983 in 192.1 sec
iter 1500: loss 0.088, train acc 0.983 in 191.0 sec
Epoch 3: loss 0.088, train acc 0.982, test acc 0.983, in 740.1 sec
iter 500: loss 0.075, train acc 0.987 in 187.9 sec
iter 1000: loss 0.078, train acc 0.985 in 185.6 sec
iter 1500