# This is a sample Jupyter Notebook

Below is an example of a code cell.
Put your cursor in the cell and press Shift+Enter to execute it and select the next one, or click the 'Run Cell' button.

Press Shift twice (Double Shift) to search everywhere for classes, files, tool windows, actions, and settings.

To learn more about Jupyter Notebooks in PyCharm, see [help](https://www.jetbrains.com/help/pycharm/ipython-notebook-support.html).
For an overview of PyCharm, go to Help -> Learn IDE features or refer to [our documentation](https://www.jetbrains.com/help/pycharm/getting-started.html).

In [1]:
# Example cell: prints a greeting (its output appears directly below the cell).
print("Hello World")

Hello World


In [2]:
import torch
import os
import io
import zipfile
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch.optim.lr_scheduler import StepLR
import datetime
import matplotlib.pyplot as plt
import time



def setup_device():
    """Pick the compute device and report it.

    Returns:
        torch.device: ``cuda:0`` when CUDA is available, otherwise ``cpu``.
    """
    if torch.cuda.is_available():
        device_name = "cuda:0"
    else:
        device_name = "cpu"
    chosen = torch.device(device_name)
    print("设备：", chosen)
    return chosen


class MyDataset(Dataset):
    """Image-classification dataset read from a ZIP archive of ``.bmp`` files.

    The label of each image is the first character of its file name
    (e.g. ``"0_0.bmp"`` -> ``"0"``); labels are mapped to indices in sorted
    order. Images are decoded with PIL and converted to grayscale (mode ``'L'``).

    The open ``ZipFile`` handle is never pickled and is reopened lazily in each
    process, so the dataset can be used with ``DataLoader(num_workers > 0)``.

    Args:
        zip_path: Path of the ZIP archive (anything ``zipfile.ZipFile`` accepts).
        transform: Optional callable applied to each PIL image in ``__getitem__``.
        preload_to_memory: If True, decode every image once in ``__init__`` and
            keep them in ``self.images`` (the archive is closed afterwards).
            If False, each ``__getitem__`` reads its image from the archive.
    """

    def __init__(self, zip_path, transform=None, preload_to_memory=True):
        self.zip_path = zip_path
        self.transform = transform
        self.preload_to_memory = preload_to_memory

        # Sample list: (index into self.images, label_idx) when preloading,
        # otherwise (member path inside the ZIP, label_idx).
        self.samples = []
        self.labels_set = set()
        self.images = [] if preload_to_memory else None
        # The ZipFile handle belongs to the process that opened it
        # (self._zip_pid); see _get_zip().
        self.zip_file = None
        self._zip_pid = None

        archive = self._get_zip()

        # Step 1: collect (member path, label character) for every BMP member.
        file_info = []
        for img_path in archive.namelist():
            if img_path.endswith('/') or not img_path.lower().endswith('.bmp'):
                continue
            label = os.path.basename(img_path)[0]
            file_info.append((img_path, label))
            self.labels_set.add(label)

        # Label <-> index mappings, in sorted label order.
        self.classes = sorted(self.labels_set)
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}
        self.idx_to_class = {i: cls_name for cls_name, i in self.class_to_idx.items()}

        # Step 2: build the sample list.
        if preload_to_memory:
            for img_path, label in file_info:
                self.images.append(self._read_image(img_path))
                self.samples.append((len(self.images) - 1, self.class_to_idx[label]))
            # Everything is decoded; release the file handle.
            self.close()
        else:
            self.samples = [(img_path, self.class_to_idx[label])
                            for img_path, label in file_info]

        print(f"数据集加载完成：{len(self.samples)} 个样本，{len(self.classes)} 个类别")
        print(f"类别映射：{self.class_to_idx}")

    def _get_zip(self):
        """Return a ZipFile handle owned by the current process, opening it if needed.

        A handle inherited through fork shares its file offset with the parent,
        and a handle cannot be pickled at all, so each process opens its own.
        """
        pid = os.getpid()
        if self.zip_file is None or self._zip_pid != pid:
            self.zip_file = zipfile.ZipFile(self.zip_path, 'r')
            self._zip_pid = pid
        return self.zip_file

    def _read_image(self, img_path):
        """Decode the archive member ``img_path`` as a grayscale PIL image."""
        with self._get_zip().open(img_path) as f:
            return Image.open(io.BytesIO(f.read())).convert('L')

    def __getitem__(self, idx):
        """Return ``(image, label_idx)`` for sample ``idx``; the transform is applied if set."""
        if self.preload_to_memory:
            img_idx, label_idx = self.samples[idx]
            image = self.images[img_idx]
        else:
            img_path, label_idx = self.samples[idx]
            image = self._read_image(img_path)

        if self.transform:
            image = self.transform(image)
        return image, label_idx

    def __len__(self):
        return len(self.samples)

    def get_original_label(self, idx):
        """Return the original label string (e.g. ``"3"``) of sample ``idx``."""
        _, label_idx = self.samples[idx]
        return self.idx_to_class.get(label_idx)

    def __getstate__(self):
        # ZipFile objects are not picklable (DataLoader workers pickle the
        # dataset under the spawn start method); workers reopen lazily instead.
        state = self.__dict__.copy()
        state['zip_file'] = None
        state['_zip_pid'] = None
        return state

    def close(self):
        """Close the ZIP archive if it is open; safe to call repeatedly."""
        zip_file = getattr(self, 'zip_file', None)
        if zip_file is not None:
            zip_file.close()
        self.zip_file = None

    def __del__(self):
        """Ensure the ZIP archive is closed when the dataset is garbage-collected."""
        self.close()


class SimpleCNN(nn.Module):
    """Small two-block CNN classifier for single-channel 28x28 images.

    Shapes for an ``(N, 1, 28, 28)`` input:
      conv1 (3x3, dilation 2 -> effective 5x5, padding 2) keeps 28x28, max-pool -> 14x14;
      conv2 (3x3, padding 1) keeps 14x14, max-pool -> 7x7;
      flatten -> 20 * 7 * 7 = 980 features -> fc1 (128) -> fc2 (num_classes logits).

    Args:
        num_classes: Number of output logits (default 10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=3, stride=1, padding=2, dilation=2)
        self.bn1 = nn.BatchNorm2d(10)

        self.conv2 = nn.Conv2d(10, 20, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(20)

        self.fc1 = nn.Linear(20 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return unnormalized class logits of shape ``(N, num_classes)``."""
        x = torch.max_pool2d(torch.relu(self.bn1(self.conv1(x))), 2)
        x = torch.max_pool2d(torch.relu(self.bn2(self.conv2(x))), 2)
        # Keep the batch dimension explicit: view(-1, 980) would silently mix
        # samples across the batch for wrongly sized inputs whose element count
        # happens to be divisible by 980; flatten makes fc1 reject them instead.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)


def train(train_loader, model, device, optimizer, epoch, loss_fn):
    """Run one training epoch over ``train_loader``.

    Puts ``model`` in training mode and performs one optimizer step per batch.
    After every 50th batch, the current batch loss is printed with 6 decimals.
    ``epoch`` is accepted for interface compatibility but is not used.
    """
    model.train()
    step = 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        batch_loss = loss_fn(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()

        step += 1
        if step % 50 == 0:
            print(f'{batch_loss.item():.6f}')
            #if args.dry_run:
              #  break


def test(dataloader, model, loss_fn, device):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test result: \n Accuracy:{(100 * correct):>0.1f}%, Avg loss:{test_loss:>8f}")


def main(checkpoint_dir=None):
    """Train SimpleCNN on the digit ZIP datasets and report test results.

    The test set is evaluated before every epoch (so the first report shows
    the untrained model) and once more after the last epoch.

    Args:
        checkpoint_dir: If given, the model ``state_dict`` is saved there every
            5 epochs as ``model_<timestamp>_epoch_<n>.pt`` (the directory is
            created if missing). ``None`` (default) disables checkpointing.
    """
    device = setup_device()
    torch.manual_seed(42)

    epochs = 15  # accuracy stabilizes around 99.2%-99.3% (5 epochs already reach >= 99.2%)
    batch_size = 100
    num_workers = 0  # single-process loading; prefetch_factor is only valid with workers > 0
    # Pinned host memory only benefits host-to-GPU copies; recent PyTorch
    # versions warn when it is requested without an accelerator.
    pin_memory = device.type == 'cuda'

    transform = transforms.Compose([
        transforms.ToTensor(),
        # Commonly used MNIST mean / std.
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = MyDataset('data/1-Digit-TrainSet.zip', transform=transform)
    test_dataset = MyDataset('data/1-Digit-TestSet.zip', transform=transform)
    try:
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            pin_memory=pin_memory,
            num_workers=num_workers,
        )
        test_dataloader = DataLoader(
            test_dataset,
            batch_size=1000,
            shuffle=False,  # evaluation order does not matter
            pin_memory=pin_memory,
            num_workers=num_workers,
        )

        model = SimpleCNN().to(device)
        print(model)
        loss_fn = nn.CrossEntropyLoss()

        # Adadelta(lr=1) was kept over Adam(lr=1e-3), which converged more slowly here.
        optimizer = torch.optim.Adadelta(model.parameters(), lr=1)
        scheduler = StepLR(optimizer, step_size=3, gamma=0.7)

        if checkpoint_dir is not None:
            os.makedirs(checkpoint_dir, exist_ok=True)

        for epoch in range(1, epochs + 1):
            test(test_dataloader, model, loss_fn, device)
            train(train_dataloader, model, device, optimizer, epoch, loss_fn)
            scheduler.step()
            if checkpoint_dir is not None and epoch % 5 == 0:
                current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
                path = os.path.join(checkpoint_dir, f"model_{current_time}_epoch_{epoch}.pt")
                torch.save(model.state_dict(), path)
                print(f"模型已保存: {path}")
        test(test_dataloader, model, loss_fn, device)
    finally:
        # Release the datasets' ZIP file handles even if training fails.
        train_dataset.close()
        test_dataset.close()

    print("训练完成!")


if __name__ == "__main__":
    main()

设备： cuda:0
数据集加载完成：60000 个样本，10 个类别
类别映射：{'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9}
数据集加载完成：10000 个样本，10 个类别
类别映射：{'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9}
SimpleCNN(
  (conv1): Conv2d(1, 10, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))
  (bn1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(10, 20, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=980, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)
Test result: 
 Accuracy:9.0%, Avg loss:2.299569
0.399451
0.249750
0.135281
0.146359
0.167526
0.039293
0.039751
0.096251
0.036110
0.043454
0.060675
0.091914
Test result: 
 Accuracy:98.3%, Avg loss:0.048961
0.045189
0.099772
0.027312
0.144304
0.014965
0.044718
0.024317
0.057719
0.1343

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Parse the training log.
# Paste the log into `data`: one record per line, tab-separated as
# "loss<TAB>accuracy%<TAB>avg_loss". The accuracy and avg-loss columns may be
# missing or empty; the previous value is then carried forward.
data = """
"""


def parse_training_log(text):
    """Parse a tab-separated training log.

    Args:
        text: Log text, one record per line: ``loss[<TAB>accuracy%[<TAB>avg_loss]]``.
            Blank lines are ignored.

    Returns:
        A tuple ``(losses, accuracies, avg_losses, acc_points, avg_loss_points)``.
        ``losses`` has one float per record. ``accuracies`` and ``avg_losses``
        are aligned with it and carry the most recent value forward (0.0 before
        the first one). ``acc_points`` and ``avg_loss_points`` list
        ``(record_index, value)`` only for records where that column was present.
    """
    losses, accuracies, avg_losses = [], [], []
    acc_points, avg_loss_points = [], []
    current_acc = 0.0
    current_avg_loss = 0.0

    for line in text.splitlines():
        if not line.strip():
            continue
        parts = line.split('\t')
        index = len(losses)
        losses.append(float(parts[0]))

        if len(parts) > 1 and parts[1].strip():
            # Accuracy is written with a trailing percent sign.
            current_acc = float(parts[1].replace('%', ''))
            acc_points.append((index, current_acc))
        accuracies.append(current_acc)

        if len(parts) > 2 and parts[2].strip():
            current_avg_loss = float(parts[2])
            avg_loss_points.append((index, current_avg_loss))
        avg_losses.append(current_avg_loss)

    return losses, accuracies, avg_losses, acc_points, avg_loss_points


def _split_points(points):
    """Split ``[(x, y), ...]`` into ``([x, ...], [y, ...])``."""
    return [x for x, _ in points], [y for _, y in points]


losses, accuracies, avg_losses, acc_points, avg_loss_points = parse_training_log(data)

if not losses:
    # Nothing to plot or summarize (e.g. `data` is still empty).
    print("No training log data found in `data`; nothing to plot.")
else:
    acc_indices, acc_values = _split_points(acc_points)
    avg_loss_indices, avg_loss_values = _split_points(avg_loss_points)

    plt.figure(figsize=(15, 10))

    # 1. Per-batch training loss (log scale makes the decay easier to see).
    plt.subplot(2, 2, 1)
    plt.plot(losses, 'b-', alpha=0.7, linewidth=0.8)
    plt.title('Training Loss per Batch')
    plt.xlabel('Batch')
    plt.ylabel('Loss')
    plt.yscale('log')
    plt.grid(True, alpha=0.3)

    # 2. Average loss, only at the records where it was reported.
    plt.subplot(2, 2, 2)
    plt.plot(avg_loss_indices, avg_loss_values, 'ro-', markersize=4)
    plt.title('Average Loss per Epoch')
    plt.xlabel('Batch')
    plt.ylabel('Avg Loss')
    plt.grid(True, alpha=0.3)

    # 3. Accuracy, only at the records where it was reported.
    plt.subplot(2, 2, 3)
    plt.plot(acc_indices, acc_values, 'go-', markersize=4)
    plt.title('Accuracy Progress')
    plt.xlabel('Batch')
    plt.ylabel('Accuracy (%)')
    plt.grid(True, alpha=0.3)

    # 4. Combined loss/accuracy view, drawn into the grid's fourth panel
    # (creating a new figure here would leave this panel empty).
    ax1 = plt.subplot(2, 2, 4)
    color = 'tab:blue'
    ax1.set_xlabel('Batch')
    ax1.set_ylabel('Loss', color=color)
    ax1.plot(losses, color=color, alpha=0.7, linewidth=0.8)
    ax1.tick_params(axis='y', labelcolor=color)
    ax1.set_yscale('log')

    ax2 = ax1.twinx()
    color = 'tab:green'
    ax2.set_ylabel('Accuracy (%)', color=color)
    ax2.plot(acc_indices, acc_values, color=color, linewidth=2, marker='o', markersize=3)
    ax2.tick_params(axis='y', labelcolor=color)
    ax1.set_title('Loss and Accuracy Progress')

    plt.tight_layout()
    plt.show()

    # Training summary.
    print("训练总结:")
    print(f"总batch数: {len(losses)}")
    print(f"最终准确率: {accuracies[-1]:.1f}%")
    print(f"最终平均损失: {avg_losses[-1]:.6f}")
    print(f"最小损失: {min(losses):.6f}")
    print(f"最大损失: {max(losses):.6f}")