In [None]:
import os
import random

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


In [None]:
# Seed every RNG the notebook draws from: torch drives weight initialisation
# and DataLoader shuffling, Python's `random` drives the sample picker used
# for the prediction visualisation.
random.seed(42)
torch.manual_seed(42)

# Preprocessing pipeline applied to every image on access.
transform = transforms.Compose([
    transforms.Grayscale(1),               # force a single channel
    transforms.ToTensor(),                 # PIL image -> float tensor in [0, 1]
    transforms.Normalize((0.5,), (0.5,)),  # (x - 0.5) / 0.5 -> values in [-1, 1]
    transforms.Resize((28, 28)),           # guarantee the 28x28 size the model expects
])

# Load FashionMNIST. download=True fetches the files into `root` only when
# they are missing, so a fresh checkout can still Restart & Run All; an
# already-populated `data/` directory is reused untouched.
train_dataset = datasets.FashionMNIST(root='data', train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST(root='data', train=False, download=True, transform=transform)

# Batch the datasets; only the training split is shuffled so evaluation
# always walks the test set in the same order.
batch_size = 1024
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of test samples: {len(test_dataset)}")

# define model

In [None]:
class FMNIST(nn.Module):
    """Two-layer fully connected classifier for 28x28 images with 10 output classes.

    The forward pass performs three steps:
      1. Flatten each sample so the batch becomes 2-D, (batch_size, 28*28),
         which is the input layout the first linear layer requires.
      2. Apply the first linear layer (784 -> 128) followed by ReLU to add
         non-linearity.
      3. Apply the second linear layer (128 -> 10) and return its raw output
         as per-class logits (no softmax is applied here).
    """

    def __init__(self):
        super().__init__()
        # Attribute names are part of the saved state_dict keys; keep them stable.
        self.fc1 = nn.Linear(in_features=28 * 28, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Map a batch of images to a (batch_size, 10) tensor of logits."""
        flat = x.flatten(start_dim=1)
        hidden = self.fc1(flat).relu()
        return self.fc2(hidden)

# Choose the compute backend: Apple Silicon (MPS) is preferred, then CUDA,
# and the CPU is the fallback. The CUDA check only runs when MPS is unavailable.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

print(f"Using device: {device}")

# Build the network and move its parameters to the selected device.
model = FMNIST().to(device)

# Cross-entropy loss over the raw logits, optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)


# 训练模型

In [None]:
# Number of full passes over the training set.
epochs = 50
# Best test accuracy seen so far and the weights that produced it
# (None until a first best model has been recorded).
best_accuracy, best_model_weights = 0.0, None

In [None]:

# Restore previously saved weights when a checkpoint file is present;
# otherwise the freshly initialised model is kept.
# NOTE(review): torch.load unpickles the file, so only point this at
# checkpoints you trust.
checkpoint_path = './best_fmnist_model.pth'
if not os.path.exists(checkpoint_path):
    print("No checkpoint found, training from scratch.")
else:
    # map_location moves the stored tensors onto the active device.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    print("Loaded weights from best_fmnist_model.pth")

In [None]:
# Per-epoch history used for the summary plots drawn after training.
train_losses = []
test_accuracies = []

checkpoint_path = './best_fmnist_model.pth'


def _evaluate_accuracy(net, loader, device):
    """Return the fraction of samples in `loader` that `net` classifies correctly.

    Switches `net` to eval mode and disables gradient tracking; each batch is
    moved to `device` before the forward pass. The predicted class is the
    index of the largest logit.
    """
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = net(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total


# Resume from the saved best weights (if any) so that re-running this cell
# always continues from the best checkpoint on disk.
if os.path.exists(checkpoint_path):
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    print("Loaded weights from best_fmnist_model.pth")
    # Score the restored weights first: without this baseline the very first
    # epoch would overwrite the checkpoint even if it performed worse.
    checkpoint_accuracy = _evaluate_accuracy(model, test_loader, device)
    best_accuracy = max(best_accuracy, checkpoint_accuracy)
    print(f"Checkpoint accuracy: {checkpoint_accuracy:.2%}")
else:
    print("No checkpoint found, training from scratch.")

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()              # clear gradients from the previous step
        outputs = model(inputs)            # single forward pass per batch
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Mean batch loss for the epoch, computed once after the last batch.
    avg_loss = running_loss / len(train_loader)
    train_losses.append(avg_loss)
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")

    # Evaluate on the test split after every epoch.
    accuracy = _evaluate_accuracy(model, test_loader, device)
    test_accuracies.append(accuracy)
    print(f"Epoch [{epoch+1}/{epochs}], Accuracy: {accuracy:.2%}")

    # Keep and persist the weights whenever the accuracy improves.
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        # state_dict() returns references to the live parameters; clone them
        # so the snapshot is not changed by subsequent optimizer steps.
        best_model_weights = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
        print(f"New best model found with accuracy: {best_accuracy:.2%}")
        torch.save(best_model_weights, checkpoint_path)
        print("Best model weights saved.")

# Plot the training curves once, after the last epoch has finished.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

loss_ax.plot(train_losses, label='Train Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training Loss')
loss_ax.legend()

acc_ax.plot(test_accuracies, label='Test Accuracy')
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Accuracy')
acc_ax.set_title('Test Accuracy')
acc_ax.legend()

plt.show()

# 可视化预测结果

In [None]:
# Show a handful of randomly chosen test images next to the class
# probabilities the model predicts for them.
model.eval()
inputs, labels = next(iter(test_loader))

# Never ask for more samples than the batch holds (random.sample would raise).
num_samples = min(10, inputs.size(0))
indices = random.sample(range(inputs.size(0)), num_samples)

# One batched forward pass for all selected images instead of one per image.
with torch.no_grad():
    logits = model(inputs[indices].to(device))
    sample_probs = F.softmax(logits, dim=1).cpu().numpy()

# One row per sample: image on the left, probability bar chart on the right.
# squeeze=False keeps `axes` two-dimensional even for a single row.
fig, axes = plt.subplots(num_samples, 2, figsize=(15, 2 * num_samples), squeeze=False)
for (img_ax, bar_ax), idx, probs in zip(axes, indices, sample_probs):
    label = labels[idx].item()

    # Left: the input image with its ground-truth class index.
    img_ax.imshow(inputs[idx].squeeze(), cmap='gray')
    img_ax.set_title(f"True Label: {label}")
    img_ax.axis('off')

    # Right: predicted probability for each class, annotated with its value.
    class_ids = range(len(probs))
    bars = bar_ax.bar(class_ids, probs)
    bar_ax.set_xlabel('Class')
    bar_ax.set_ylabel('Probability')
    bar_ax.set_title('Predicted Probabilities')
    bar_ax.set_xticks(class_ids)
    for bar, prob in zip(bars, probs):
        bar_ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height(), f"{prob:.2f}",
                    ha='center', va='bottom', fontsize=8)

fig.tight_layout()
plt.show()