本项目是本科毕设"基于昇腾AI架构的高效化无人机射频信号识别"的训练代码实现,提供轻量级和标准的 ResNet 模型,用于 2D .npy 格式数据集的训练、评估和可视化。
- ✅ 多种 ResNet 架构(轻量级和标准)
- ✅ FP16 混合精度训练(AMP)
- ✅ Warmup + Cosine Annealing 学习率调度
- ✅ 完整的训练、验证和测试流程
- ✅ 混淆矩阵生成和统计
- ✅ UMAP 可视化(内存优化版本)
- ✅ TensorBoard 日志记录
- ✅ GPU 内存监控和管理
- ✅ 模块化代码架构,易于扩展
- Python 3.12+
- CUDA 13.0+(如需 GPU 加速)
- NVIDIA GPU(推荐 8GB+ 显存)
- Pixi(用于提供 GCC/构建工具链环境)
- direnv(可选,用于自动激活环境变量)
- 克隆项目:
git clone git@github.com:wh-wang132/ResNet.git cd ResNet - 安装依赖(使用 uv):
uv sync
- 同步 Pixi 工具链环境(模型编译依赖):
当前
pixi install
pixi.toml已包含:gxxmakecmake
- 启用 direnv 自动激活(推荐):
项目根目录已提供
.envrc,内容会通过pixi shell-hook自动注入环境变量。之后每次进入项目根目录会自动激活 Pixi 环境。# 首次安装 direnv 后执行一次 direnv allow - 准备数据集:
- 将 .npy 格式数据集放入
Data/目录 - 数据集结构详见 数据准备
- 将 .npy 格式数据集放入
# 完整训练流程(训练 + 测试)
uv run src/base_model_main.py --epochs 20 --model resnet6_2d
# 仅训练
uv run src/base_model_main.py --epochs 20 --Test False
# 仅测试和可视化
uv run src/base_model_main.py --Train False --UMAP True
# 使用不同的模型
uv run src/base_model_main.py --model resnet18_2d
# 指定数据集输出精度
uv run src/base_model_main.py --data_dtype fp32# 最小剪枝 + 微调命令
uv run src/pruning_main.py --base_checkpoint output/base_model/resnet6_2d/epochs20_bs64/best_model.pth
# 调整剪枝比例并开启全局剪枝
uv run src/pruning_main.py \
--base_checkpoint output/base_model/resnet18_2d/epochs20_bs64/best_model.pth \
--pruning_ratio 0.30 \
--global_pruning True \
--finetune_epochs 10
# 仅执行剪枝并保存结果,不做微调
uv run src/pruning_main.py \
--base_checkpoint output/base_model/resnet14_2d/epochs20_bs64/best_model.pth \
--finetune_epochs 0 \
--evaluate_test False| 技术 | 版本 | 用途 |
|---|---|---|
| Python | 3.12+ | 开发语言 |
| PyTorch | 2.10.0+ | 深度学习框架 |
| NumPy | 2.4.3+ | 数值计算 |
| Matplotlib | 3.10.8+ | 数据可视化 |
| Scikit-learn | 1.8.0+ | 机器学习工具 |
| UMAP-learn | 0.5.11+ | 降维可视化 |
| TensorBoard | 2.20.0+ | 训练日志记录 |
| uv | - | 包管理工具 |
| Pixi | - | GCC/Make/CMake 工具链环境管理 |
| direnv | - | 自动激活项目环境变量 |
ResNet/
├── src/
│ ├── base_model_main.py # 基座模型训练入口(项目根目录执行)
│ ├── pruning_main.py # 剪枝 + 微调入口(项目根目录执行)
│ ├── base_model/ # 基座模型核心模块
│ │ ├── dataset.py
│ │ ├── utils.py
│ │ ├── trainer.py
│ │ ├── tester.py
│ │ ├── visualizer.py
│ │ ├── resnet_lightweight.py
│ │ ├── resnet_standard.py
│ │ ├── confusionMatrix.py
│ │ └── lr_scheduler.py
│ ├── pruning/ # 剪枝阶段核心模块
│ │ ├── args.py
│ │ ├── checkpoint.py
│ │ ├── evaluator.py
│ │ ├── output.py
│ │ ├── pruner.py
│ │ ├── topology.py
│ │ ├── trainer.py
│ │ ├── utils.py
│ │ └── README.md
│ └── qat/ # QAT 阶段目录(待实现)
├── docs/ # 文档目录
├── Data/ # 数据集目录
├── output/ # 训练输出目录
├── .envrc # direnv 自动激活(调用 pixi shell-hook)
├── pixi.toml # Pixi 环境定义(含 gxx/make/cmake)
├── pixi.lock # Pixi 锁文件
├── pyproject.toml # 项目依赖配置
├── uv.lock # 锁定依赖版本
├── README.md # 本文件
└── LICENSE # 许可证
| 模型 | 参数量 | 适用场景 |
|---|---|---|
| ResNet-6 | 310,392 | 快速实验,资源受限环境 |
| ResNet-10 | 694,440 | 平衡精度与速度 |
| ResNet-14 | 902,376 | 更高精度,轻量级架构 |
| 模型 | 参数量 | 残差块 |
|---|---|---|
| ResNet-18 | 11.2M | BasicBlock |
| ResNet-34 | 21.3M | BasicBlock |
| ResNet-50 | 23.6M | Bottleneck |
详细模型说明请参考 模型架构。
| 参数 | 默认值 | 说明 |
|---|---|---|
--epochs |
60 | 训练轮数 |
--lr |
0.0003 | 学习率 |
--batch_size |
64 | 批次大小 |
--model_path |
best_model.pth | 模型保存路径 |
--class_num |
24 | 分类数 |
--model |
resnet6_2d | 选择模型架构 |
--data_dir |
Data | 数据集路径 |
--data_dtype |
fp16 | 数据集输出 tensor 精度,可选 fp16/fp32 |
| 参数 | 默认值 | 说明 |
|---|---|---|
--Train |
True | 启用训练 |
--Test |
True | 启用测试 |
--UMAP |
False | 启用 UMAP 可视化 |
| 参数 | 默认值 | 说明 |
|---|---|---|
--dropout_p |
0.3 | Dropout 概率 |
--weight_decay |
0.0001 | 权重衰减 |
| 参数 | 默认值 | 说明 |
|---|---|---|
--warmup_ratio |
0.05 | Warmup 占总步数的比例 |
--warmup_steps |
0 | Warmup 步数(优先使用) |
--min_lr |
1e-6 | 最小学习率 |
--plot_lr_schedule |
True | 绘制学习率曲线 |
--plot_lr_schedule False |
- | 禁用学习率曲线绘制 |
详细参数说明请参考 命令行参数。
训练完成后,输出目录会包含以下文件:
output/base_model/<model>/epochs<epochs>_bs<batch_size>/
├── best_model.pth # 最佳模型权重
├── best_val_acc_info.txt # 最佳验证准确率摘要
├── lr_schedule.png # 学习率调度曲线
├── training_curves.png # 训练曲线(损失、准确率、学习率)
├── Confusion matrix.png # 混淆矩阵图
├── umap_plot.png # UMAP 可视化图(如启用)
└── runs/ # 当前实验目录下的 TensorBoard 日志
剪枝 + 微调阶段默认输出:
output/pruning/<model>/ratio<ratio>_<global|local>_ft<epochs>_bs<batch_size>/
├── best_pruned_model.pth # 最佳剪枝模型 checkpoint
├── best_pruned_info.txt # 最佳剪枝模型验证指标摘要
├── pruning_summary.json # 剪枝前后统计与流程摘要
└── runs/ # 当前实验目录下的 TensorBoard 日志
- 数据准备指南 - 如何准备和组织数据集
- 模型架构说明 - 各种 ResNet 架构的详细说明
- 项目架构分析 - 当前项目分层架构与阶段完成度分析
- 训练参数调优 - 训练参数调优建议
- 剪枝指南 - 基于 torch-pruning 的结构化剪枝与微调说明
- 命令行参数详解 - 完整的命令行参数说明
- 模块说明 - 代码模块结构和功能说明
欢迎提交 Issue 和 Pull Request!请遵循以下规范:
- 代码风格遵循 PEP 8
- 提交前运行测试
- 新功能请添加相应文档
- 提交信息清晰明确
详细规范请参考 贡献指南。
本项目采用 GPLv3 许可证。
如有问题或建议,请通过 Issue 联系。
项目维护: 持续更新中