In [2]:
import torch
from torch import nn
from torch.nn import functional as F
import torchvision

In [27]:
class Residual(nn.Module):
    def __init__(self, input_channel, num_channel,
                 use_1x1conv=False, stride=1):

        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=input_channel, out_channels=num_channel,
                               kernel_size=3, padding=1, stride=stride)  # use stride to down sampling, this stride correspond to conv3's stride

        self.conv2 = nn.Conv2d(in_channels=num_channel, out_channels=num_channel,
                               kernel_size=3, padding=1)

        if use_1x1conv:
            self.conv3 = nn.Conv2d(in_channels=input_channel, out_channels=num_channel,
                                   kernel_size=1, stride=stride)
        else:
            self.conv3 = None

        self.bn1 = nn.BatchNorm2d(num_channel)
        self.bn2 = nn.BatchNorm2d(num_channel)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        Y += X
        return F.relu(Y)

In [4]:
blk = Residual(input_channel=3, num_channel=3)
X = torch.randn(4, 3, 6, 6)
Y = blk(X)
Y.shape

torch.Size([4, 3, 6, 6])

In [5]:
blk = Residual(input_channel=3, num_channel=6, use_1x1conv=True, stride=2)
Y = blk(X)
Y.shape

torch.Size([4, 6, 3, 3])

In [6]:
b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
                   nn.BatchNorm2d(64), nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

In [7]:
def resnet_block(input_channel, num_channel, num_residual, first_block=False):
    blk = []

    for i in range(num_residual):
        if i == 0 and not first_block:
            blk.append(Residual(input_channel, num_channel, use_1x1conv=True, stride=2))
        else:
            blk.append(Residual(num_channel, num_channel))
    return blk

In [8]:
b2 = nn.Sequential(*resnet_block(64, 64, num_residual=2, first_block=True))  # without 1x1
b3 = nn.Sequential(*resnet_block(64, 128, num_residual=2))
b4 = nn.Sequential(*resnet_block(128, 256, num_residual=2))
b5 = nn.Sequential(*resnet_block(256, 512, num_residual=2))

In [9]:
net = nn.Sequential(b1, b2, b3, b4, b5,
                    nn.AdaptiveAvgPool2d((1, 1)),
                    nn.Flatten(), nn.Linear(512, 10))

In [10]:
X = torch.rand(size=(1, 1, 224, 224))
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__, 'output shape:\t', X.shape)

Sequential output shape:	 torch.Size([1, 64, 56, 56])
Sequential output shape:	 torch.Size([1, 64, 56, 56])
Sequential output shape:	 torch.Size([1, 128, 28, 28])
Sequential output shape:	 torch.Size([1, 256, 14, 14])
Sequential output shape:	 torch.Size([1, 512, 7, 7])
AdaptiveAvgPool2d output shape:	 torch.Size([1, 512, 1, 1])
Flatten output shape:	 torch.Size([1, 512])
Linear output shape:	 torch.Size([1, 10])


In [11]:
import torchvision

In [12]:
def load_data_fashion_mnist(batch_size, resize=None, root='./dataset/FashionMNIST'):
    """Download the fashion mnist dataset and then load into memory."""
    trans = []
    if resize:
        trans.append(torchvision.transforms.Resize(size=resize))
    trans.append(torchvision.transforms.ToTensor())

    transform = torchvision.transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True,
                                                    download=True, transform=transform)
    mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False,
                                                   download=True, transform=transform)

    train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=4)
    test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=4)

    return train_iter, test_iter

In [23]:
lr, num_epochs, batch_size = 0.05, 10, 64
train_iter, test_iter = load_data_fashion_mnist(batch_size, resize=96)

In [24]:
import time

def evaluate_accuracy(data_iter, net, device=None):
    if device is None and isinstance(net, torch.nn.Module):
        device = list(net.parameters())[0].device

    acc_sum, n = 0.0, 0
    with torch.no_grad():
        for X, y in data_iter:
            if isinstance(net, torch.nn.Module):
                net.eval()
                acc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum()
                net.train()
            else:  # 自定义的模型, 3.13节之后不会用到, 不考虑GPU
                if 'is_training' in net.__code__.co_varnames:  # 如果有is_training这个参数
                    # 将is_training设置成False
                    acc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item()
                else:
                    acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
            n += y.shape[0]
    return acc_sum / n


def sgd(params, lr, batch_size):
    for param in params:
        #  Modifying param with param.data will not be passed to the calculation diagram
        param.data -= lr * param.grad / batch_size

def train_func(net: nn.Module, train_iter, test_iter, loss, num_epoch: int, batch_size: int, device,
               params=None, lr=None,
               optimizer=None,
               writer=None):
    """
    :param net:
    :param train_iter:
    :param test_iter:
    :param loss:
    :param num_epoch:
    :param batch_size:
    :param device:
    :param params:
    :param lr:
    :param optimizer:
    :param writer: for tensorboard
    :return:
    """
    net = net.to(device)
    print("training on ", device)

    if loss is None:
        loss = torch.nn.CrossEntropyLoss()

    for epoch in range(num_epoch):
        train_loss_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()
        for batch_count, (X, y) in enumerate(train_iter):
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            loss_fn = loss(y_hat, y)

            if optimizer is not None:  # use module
                optimizer.zero_grad()
            elif params is not None and params[0].grad is not None:
                for param in params:
                    param.grad.data.zero_()

            loss_fn.backward()

            if optimizer is None:
                sgd(params, lr, batch_size)
            else:
                optimizer.step()

            train_loss_sum += loss_fn.cpu().item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).cpu().sum().item()
            n += y.shape[0]

            if writer is not None:
                if batch_count % 100 == 99:
                    writer.add_scalar('train_loss',
                                      train_loss_sum / n,
                                      epoch * len(train_iter) + batch_count)

        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
              % (epoch + 1, train_loss_sum / n, train_acc_sum / n, test_acc, time.time() - start))

In [25]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
train_func(net, train_iter, test_iter, None, num_epochs, batch_size, device, lr=lr, optimizer=optimizer)

training on  cuda
epoch 1, loss 0.0157, train acc 0.708, test acc 0.818, time 141.3 sec
epoch 2, loss 0.0071, train acc 0.834, test acc 0.837, time 148.3 sec
epoch 3, loss 0.0061, train acc 0.855, test acc 0.808, time 145.1 sec
epoch 4, loss 0.0055, train acc 0.871, test acc 0.845, time 126.2 sec
epoch 5, loss 0.0051, train acc 0.881, test acc 0.862, time 127.0 sec
epoch 6, loss 0.0048, train acc 0.890, test acc 0.854, time 128.0 sec
epoch 7, loss 0.0045, train acc 0.894, test acc 0.886, time 128.7 sec
epoch 8, loss 0.0043, train acc 0.900, test acc 0.877, time 129.8 sec
epoch 9, loss 0.0041, train acc 0.903, test acc 0.890, time 124.9 sec
epoch 10, loss 0.0040, train acc 0.907, test acc 0.833, time 127.2 sec


In [31]:
my_layer = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3)

In [33]:
X = torch.randn((1, 1, 224, 224))

In [34]:
my_layer(X).shape

torch.Size([1, 64, 112, 112])