In [None]:
import torch  # 导入PyTorch库
import torch.nn as nn  # 导入PyTorch神经网络模块
import torch.optim as optim  # 导入PyTorch优化器模块
from torchvision.datasets.mnist import MNIST  # 导入PyTorch的MNIST数据集
import torchvision.transforms as transforms  # 导入PyTorch的图像预处理模块
from torch.utils.data import DataLoader  # 导入PyTorch的数据加载器
import visdom  # 导入Visdom库，用于可视化
# import onnx  # 导入ONNX库，用于模型的序列化和跨平台部署

# Visdom client; presumably connects to a local Visdom server (not used in the visible cells).
viz = visdom.Visdom()

# Seed the global torch RNG so weight initialisation and DataLoader shuffling
# are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

DATA_DIR = './data/mnist'  # where MNIST is stored / downloaded to

# Shared preprocessing for both splits: resize to 32x32 (the input size the
# LeNet layers are sized for) and convert the PIL image to a tensor.
mnist_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])

# Training split (downloaded if not already present).
data_train = MNIST(DATA_DIR, download=True, transform=mnist_transform)
# Test split (downloaded if not already present).
data_test = MNIST(DATA_DIR, train=False, download=True, transform=mnist_transform)

# Loaders used by train() / test(): shuffled mini-batches for training,
# larger unshuffled batches for evaluation.
data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True, num_workers=8)
data_test_loader = DataLoader(data_test, batch_size=1024, num_workers=8)

In [None]:
import torch.nn as nn

class LeNet(nn.Module):
    """LeNet-style CNN mapping single-channel 32x32 images to 10 class logits.

    Shape trace for an input of shape (N, 1, 32, 32):
        C1  conv 5x5          -> (N, 6, 28, 28)
        S2  avg-pool 2x2      -> (N, 6, 14, 14)
        C3  conv 5x5          -> (N, 16, 10, 10)
        S4  avg-pool 2x2      -> (N, 16, 5, 5)
        C5  conv 5x5          -> (N, 120, 1, 1), flattened to (N, 120)
        F6  linear            -> (N, 84)
        OUTPUT linear         -> (N, 10) raw logits (no softmax)
    ReLU is applied after C1, C3, C5 and F6.
    """

    def __init__(self):
        super().__init__()
        # Layer attribute names (and their registration order) define the
        # state_dict keys, so they are kept as-is.
        self.C1 = nn.Conv2d(1, 6, kernel_size=5, stride=1)
        self.S2 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.C3 = nn.Conv2d(6, 16, kernel_size=5, stride=1)
        self.S4 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.C5 = nn.Conv2d(16, 120, kernel_size=5, stride=1)
        self.F6 = nn.Linear(120, 84)
        self.OUTPUT = nn.Linear(84, 10)

    def forward(self, x):
        """Return the (N, 10) logits for a batch of images ``x``."""
        features = self.S2(torch.relu(self.C1(x)))
        features = self.S4(torch.relu(self.C3(features)))
        features = torch.relu(self.C5(features))
        # Collapse (N, 120, 1, 1) to (N, 120) for the fully connected head.
        flat = features.view(features.shape[0], -1)
        hidden = torch.relu(self.F6(flat))
        return self.OUTPUT(hidden)


In [None]:
# Loss: cross-entropy between the network's 10 logits and the integer labels.
criterion = nn.CrossEntropyLoss()
# The model instance trained and evaluated by train() / test().
net = LeNet()
class GradientDescentOptimizer:
    """Vanilla gradient-descent update rule: ``w_new = w - lr * grad``.

    ``step`` is a pure function: it returns the updated weights and leaves the
    inputs untouched, so the caller decides where to store the result.
    """

    def __init__(self, lr):
        # Step size applied to every gradient.
        self.lr = lr

    def step(self, weights, gradients):
        """Return ``weights`` moved one step of size ``lr`` against ``gradients``."""
        delta = self.lr * gradients
        updated = weights - delta
        return updated
    

# Optimizer: the hand-written gradient-descent rule defined above.
# A disabled alternative was optim.Adam(net.parameters(), lr=2e-3).
LEARNING_RATE = 0.01
optimizer = GradientDescentOptimizer(lr=LEARNING_RATE)






def train():
    """Run one epoch of plain gradient descent over ``data_train_loader``.

    For each batch the gradients of the loss with respect to the trainable
    parameters are computed with ``torch.autograd.grad``. The functional
    ``optimizer.step(weights, gradients)`` then produces the updated values,
    which are copied back into each parameter in place.

    Uses the module-level ``net``, ``criterion``, ``optimizer`` and
    ``data_train_loader``. Prints the batch loss every 10 batches.
    """
    net.train()  # switch the model to training mode
    # Trainable parameters, in a fixed order matching the returned gradients.
    params = [p for p in net.parameters() if p.requires_grad]
    for i, (images, labels) in enumerate(data_train_loader):
        output = net(images)  # forward pass
        loss = criterion(output, labels)  # batch loss

        # Gradients are returned directly (not accumulated into .grad), so no
        # zero_grad() step is needed between batches.
        gradients = torch.autograd.grad(loss, params)

        # net.parameters() cannot be assigned to; update each tensor in place.
        # no_grad keeps the update itself out of the autograd graph.
        with torch.no_grad():
            for param, grad in zip(params, gradients):
                param.copy_(optimizer.step(param, grad))

        # Log the loss every 10 batches.
        if i % 10 == 0:
            print('Train - Batch: %d, Loss: %f' % (i, loss.item()))

def test():
    """Evaluate ``net`` on the test split and print the mean loss and accuracy.

    Uses the module-level ``net``, ``criterion``, ``data_test_loader`` and
    ``data_test``. Both printed numbers are per-sample averages over the whole
    test set.
    """
    net.eval()  # switch the model to evaluation mode
    total_correct = 0
    loss_sum = 0.0
    # Evaluation needs no gradients. Without no_grad, every batch would build an
    # autograd graph that the running loss tensor would keep alive.
    with torch.no_grad():
        for images, labels in data_test_loader:
            output = net(images)  # forward pass
            # criterion is nn.CrossEntropyLoss() with the default reduction='mean',
            # i.e. a per-batch mean. Scale it back to a per-batch sum so dividing by
            # the dataset size below gives the true per-sample average.
            loss_sum += criterion(output, labels).item() * labels.size(0)
            pred = output.argmax(dim=1)  # predicted class index per sample
            total_correct += (pred == labels.view_as(pred)).sum().item()

    avg_loss = loss_sum / len(data_test)
    accuracy = total_correct / len(data_test)
    print('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss, accuracy))


In [None]:
# Two full passes over the training set, then a single evaluation.
for epoch in range(2):
    train()
test()