# AlexNet

    '''
    参数设置参考Justin 598WI2022课件

    todo
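        原论文 conv1 输入标注为 3x224x224
        227 是合理值：按 kernel=11、stride=4、padding=2 计算，(227 + 2×2 − 11) / 4 + 1 = 56，卷积后正好输出 56
        输入 224 时，(224 + 2×2 − 11) / 4 + 1 = 55.25，卷积输出尺寸向下取整，实际输出为 55（而非 56）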
    '''

<img src="./image/alexnet.png" alt="Model Image" width="800">

In [None]:
# Environment setup: change the working directory and extend the import path.
# `%cd` is an IPython magic and the target is relative, so re-running this cell
# moves two more directories up each time (presumably it is meant to land on
# the repository root exactly once -- verify).
%cd ../../
import sys
# Relative to the working directory set by `%cd` above, so this must run after
# it. Presumably this is where the `sgd_cv` package lives -- verify.
sys.path.append('./python')

In [None]:
from tqdm import tqdm
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms

from sgd_cv.model import AlexNet

## 测试模型

In [None]:
# Smoke-test the model structure: build the network, run one random batch
# through it, and verify the shape of the result.
model = AlexNet(num_classes=10)
print(model)

# One random AlexNet-sized batch of shape (1, 3, 224, 224). Raw CIFAR10 images
# are 3x32x32; they only reach this size after being resized by a transform.
input_tensor = torch.randn(1, 3, 224, 224)  # 227

# This is only a shape check, so skip autograd bookkeeping and switch the
# model to inference mode (this disables train-only behaviour such as dropout,
# if the model has any).
model.eval()
with torch.no_grad():
    output = model(input_tensor)

print(f"输入张量大小: {input_tensor.shape}")
print(f"输出张量大小: {output.shape}")  # expected: [1, 10]

# Fail loudly instead of relying on someone reading the printed shape.
assert output.shape == (1, 10), f"unexpected output shape: {tuple(output.shape)}"

## 训练模型

In [None]:
# Data loading.
# Preprocessing pipeline applied to every image:
#   1. resize so the shorter side is 256 pixels,
#   2. center-crop to 227x227 (224 is the alternative noted inline),
#   3. convert to a float tensor,
#   4. normalize each of the three channels; the mean/std values match the
#      widely used ImageNet channel statistics.
# An earlier single-channel `Normalize((0.5,), (0.5,))` pipeline was assigned
# to `transform` and immediately overwritten; that dead assignment is removed.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(227),  # 224
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# CIFAR10 train/test splits; downloaded into ./data/CIFAR10/ when missing.
train_dataset = torchvision.datasets.CIFAR10('./data/CIFAR10/', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.CIFAR10('./data/CIFAR10/', train=False, transform=transform, download=True)

# Shuffle only the training data; keep the test order fixed.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Model, loss function and optimizer.
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# A fresh 10-class AlexNet, moved onto the selected device.
model = AlexNet(num_classes=10)
model = model.to(device)

# Cross-entropy loss, optimized with Adam at a learning rate of 3e-4.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=3e-4)

In [None]:
# Training: make `num_epochs` passes over the training set and report the
# mean per-batch loss after each pass.
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0  # sum of the batch losses seen in this epoch
    for images, labels in tqdm(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Clear the gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward pass.
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")

In [None]:
# Evaluation: top-1 accuracy of the model on the test set.
model.eval()
correct, total = 0, 0
with torch.no_grad():  # no gradients are needed for inference
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # The predicted class is the index of the largest score in each row.
        predicted = torch.max(outputs, dim=1).indices
        total += labels.shape[0]
        correct += predicted.eq(labels).sum().item()

print(f"Test Accuracy: {100 * correct / total:.2f}%")