# In [None]:
import time
import torch
import torchvision
from torchvision.transforms import v2

""" 在使用 torch.nn.ReLU 之前, 需要使用 torch.nn.BatchNorm2d 和 torch.nn.BatchNorm1d 将数据分布归一! 否则 torch.nn.ReLU 的效果不如 torch.nn.Sigmoid """

class LeNet(torch.nn.Module):
    """LeNet-5 style CNN with batch normalization after every conv/linear layer.

    The layer sizes assume single-channel 28x28 inputs: after the two
    conv/pool stages the feature map is 16 x 5 x 5 = 400 values, which is
    what ``fc1`` expects.

    Args:
        act: Activation module applied after each hidden batch-norm layer.
            The same module instance is reused at every hidden layer.
    """

    def __init__(self, act: torch.nn.Module):
        super().__init__()
        self.act = act
        # Convolutional feature extractor.
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=(2, 2))
        self.bn1 = torch.nn.BatchNorm2d(num_features=6)
        self.pool1 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv2 = torch.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.bn2 = torch.nn.BatchNorm2d(num_features=16)
        self.pool2 = torch.nn.MaxPool2d(kernel_size=2)
        # Fully connected classifier head.
        self.fc1 = torch.nn.Linear(in_features=400, out_features=120)
        self.bn3 = torch.nn.BatchNorm1d(num_features=120)
        self.fc2 = torch.nn.Linear(in_features=120, out_features=84)
        self.bn4 = torch.nn.BatchNorm1d(num_features=84)
        self.fc3 = torch.nn.Linear(in_features=84, out_features=10)
        self.bn5 = torch.nn.BatchNorm1d(num_features=10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-class scores of shape (N, 10) for inputs of shape (N, 1, 28, 28).

        The scores are returned without a final activation so that a loss
        such as ``torch.nn.CrossEntropyLoss`` receives unbounded values.
        Applying the hidden activation here would clamp (ReLU) or squash
        (Sigmoid) the scores and block the gradient of every class whose
        score is not positive.
        """
        x = self.pool1(self.act(self.bn1(self.conv1(x))))   # (N, 6, 28, 28) -> (N, 6, 14, 14)
        x = self.pool2(self.act(self.bn2(self.conv2(x))))   # (N, 16, 10, 10) -> (N, 16, 5, 5)
        x = x.view(x.size(0), -1)                           # (N, 400)
        x = self.act(self.bn3(self.fc1(x)))                 # (N, 120)
        x = self.act(self.bn4(self.fc2(x)))                 # (N, 84)
        x = self.bn5(self.fc3(x))                           # (N, 10), no activation on the output
        return x

def get_data_loader(train: bool = True, batch_size: int = 1, root: str = "code/python/datasets"):
    """Build a ``DataLoader`` over the torchvision MNIST dataset.

    Images go through ``v2.ToImage`` and are converted to ``float32`` with
    value scaling enabled. ``download=True`` is passed to the dataset
    constructor.

    Args:
        train: Select the training split when True, otherwise the test split.
            Also controls shuffling: only the training split is shuffled, so
            iteration over the test split is in a fixed order.
        batch_size: Number of samples per batch.
        root: Directory where the dataset is stored.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(images, labels)`` batches.
    """
    transform = v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
    ])
    dataset = torchvision.datasets.MNIST(
        root,
        train=train,
        transform=transform,
        download=True,
    )
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=train,
    )

def evaluate(test_data, net: torch.nn.Module) -> float:
    """Return the top-1 accuracy of ``net`` over ``test_data``.

    The model is switched to evaluation mode for the duration of the call so
    that layers such as batch normalization use their running statistics and
    do not update them; the previous training/eval mode is restored before
    returning.

    Args:
        test_data: Iterable yielding ``(inputs, labels)`` batches. The argmax
            over dimension 1 of ``net(inputs)`` is compared with ``labels``.
        net: Model to evaluate.

    Returns:
        Fraction of correctly classified samples, or 0.0 when ``test_data``
        yields no samples.
    """
    was_training = net.training
    net.eval()
    n_correct = 0
    n_total = 0
    try:
        with torch.no_grad():  # gradients are not needed for evaluation
            for x, y in test_data:
                predictions = net(x).argmax(dim=1)
                n_correct += int((predictions == y).sum().item())
                n_total += predictions.numel()
    finally:
        # Leave the model in whatever mode the caller had it in.
        net.train(was_training)
    if n_total == 0:
        return 0.0
    return n_correct / n_total

def main():
    """Train LeNet on MNIST for three epochs and report test accuracy.

    Prints the accuracy of the untrained network first, then after every
    epoch the test accuracy together with the wall-clock training time of
    that epoch.
    """
    train_data = get_data_loader(train=True, batch_size=100)
    test_data = get_data_loader(train=False, batch_size=100)
    # Swap in torch.nn.Sigmoid() here to compare activation functions.
    net = LeNet(act=torch.nn.ReLU())

    print("initial accuracy:", evaluate(test_data, net))

    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    # CrossEntropyLoss is equivalent to nll_loss(log_softmax(...)); build it once.
    criterion = torch.nn.CrossEntropyLoss()
    for epoch in range(3):
        # Make sure training-mode behaviour is active regardless of what
        # happened to the model between epochs.
        net.train()
        t0 = time.perf_counter()
        for x, y in train_data:
            optimizer.zero_grad()           # clear gradients from the previous step
            loss = criterion(net(x), y)     # forward pass and loss
            loss.backward()                 # backward pass
            optimizer.step()                # parameter update
        t1 = time.perf_counter()
        print("epoch ", epoch, "accuracy: ", evaluate(test_data, net), "time: ", t1 - t0)

if __name__ == "__main__":
    # Run the training experiment only when this file is executed as a script.
    main()

