In [None]:
import torch
# Why pooling?
# Pooling shrinks feature maps, which lowers the compute cost of later layers
# and helps curb overfitting. It is one of the most common down-sampling
# operations in convolutional neural networks.

# A 5x5 toy "image". dtype=float32 is set explicitly: torch.tensor() infers
# int64 from integer literals, and MaxPool2d on int64 (Long) tensors raises
# "not implemented for 'Long'" on older PyTorch CPU builds. Float also matches
# the ToTensor() images that are pooled in the CIFAR-10 cell.
input = torch.tensor(
    [
        [1, 2, 0, 3, 1],
        [0, 1, 2, 3, 1],
        [1, 2, 1, 0, 0],
        [5, 2, 3, 1, 1],
        [2, 1, 0, 1, 1],
    ],
    dtype=torch.float32,
)

# Pooling layers take (N, C, H, W): here one sample, one channel, 5x5.
input = torch.reshape(input, (-1, 1, 5, 5))


class MyModule(torch.nn.Module):
    """Minimal network wrapping a single 2-D max-pooling layer.

    Args:
        kernel_size: Side length of the square pooling window. Defaults to 3.
        ceil_mode: If True, the output size is computed with ceil instead of
            floor, so a partial window at the right/bottom edge still yields
            an output value. Defaults to False.

    No stride is passed, so nn.MaxPool2d uses stride = kernel_size
    (non-overlapping windows).
    """

    def __init__(self, kernel_size=3, ceil_mode=False):
        super().__init__()
        # Max-pooling layer. stride defaults to kernel_size (3 by default, NOT 1).
        # With the default kernel_size=3 a 5x5 input pools to 1x1 when
        # ceil_mode=False, and to 2x2 when ceil_mode=True.
        self.maxpool1 = torch.nn.MaxPool2d(kernel_size=kernel_size, ceil_mode=ceil_mode)

    def forward(self, x):
        """Apply the max-pooling layer to ``x`` (expected layout (N, C, H, W)) and return it."""
        x = self.maxpool1(x)
        return x
    

# Build the pooling network and push the 5x5 toy tensor through it.
model = MyModule()
output = model(input)
print("output:", output)

output: tensor([[[[2]]]])


In [5]:
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Load the CIFAR-10 test split; ToTensor() turns each image into a float tensor.
dataset = torchvision.datasets.CIFAR10(
    root="./dataset",
    transform=torchvision.transforms.ToTensor(),
    train=False,
    download=True,
)

# Batches of 64, shuffled; the final, smaller batch is kept (drop_last=False).
dataloader = DataLoader(dataset=dataset, batch_size=64, shuffle=True, drop_last=False)

# TensorBoard writer; event files are written under ./logs.
writer = SummaryWriter("./logs")
step = 0
# No training happens here: each batch is only passed through the pooling
# model, and both the input and pooled images are logged for side-by-side
# comparison in TensorBoard. try/finally ensures the writer is closed even if
# the loop raises or is interrupted.
try:
    for images, labels in dataloader:
        print(f"images.shape: {images.shape}")
        output = model(images)
        print(f"output.shape: {output.shape}")
        writer.add_images("images-input", images, step)
        writer.add_images("images-output", output, step)
        step += 1
finally:
    writer.close()

Files already downloaded and verified
images.shape: torch.Size([64, 3, 32, 32])
output.shape: torch.Size([64, 3, 10, 10])
images.shape: torch.Size([64, 3, 32, 32])
output.shape: torch.Size([64, 3, 10, 10])
images.shape: torch.Size([64, 3, 32, 32])
output.shape: torch.Size([64, 3, 10, 10])
images.shape: torch.Size([64, 3, 32, 32])
output.shape: torch.Size([64, 3, 10, 10])
images.shape: torch.Size([64, 3, 32, 32])
output.shape: torch.Size([64, 3, 10, 10])
images.shape: torch.Size([64, 3, 32, 32])
output.shape: torch.Size([64, 3, 10, 10])
images.shape: torch.Size([64, 3, 32, 32])
output.shape: torch.Size([64, 3, 10, 10])
images.shape: torch.Size([64, 3, 32, 32])
output.shape: torch.Size([64, 3, 10, 10])
images.shape: torch.Size([64, 3, 32, 32])
output.shape: torch.Size([64, 3, 10, 10])
images.shape: torch.Size([64, 3, 32, 32])
output.shape: torch.Size([64, 3, 10, 10])
images.shape: torch.Size([64, 3, 32, 32])
output.shape: torch.Size([64, 3, 10, 10])
images.shape: torch.Size([64, 3, 32, 3