Merge pull request PaddlePaddle#20 from heavengate/add_tsm
Add tsm
Showing 12 changed files with 1,085 additions and 3 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
# TSM Video Classification Model | ||
|
||
--- | ||
|
||
## Contents | ||
|
||
- [Model Overview](#model-overview) | ||
- [Quick Start](#quick-start) | ||
- [Reference Paper](#reference-paper) | ||
|
||
|
||
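## Model Overview | ||
|
||
The Temporal Shift Module (TSM), proposed by Ji Lin, Chuang Gan, and Song Han of MIT and the IBM Watson AI Lab, is a module that improves a network's video understanding ability through temporal shifting. The figure below illustrates how the shift operation works. | ||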
|
||
<p align="center"> | ||
<img src="./images/temporal_shift.png" height=250 width=800 hspace='10'/> <br /> | ||
Temporal shift module | ||
</p> | ||
|
||
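The matrix in the figure above represents the temporal and channel dimensions of a feature map. One portion of the channels is shifted forward by one step along the temporal dimension, another portion is shifted backward by one step, and the positions vacated by the shift are filled with zeros. This introduces contextual interaction along the temporal dimension into the feature map and improves video understanding along the time axis. | ||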
|
||
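The shift described above can be sketched in plain Python on a small temporal-by-channel matrix. This is a minimal illustration, not the operator used by this repository: the function name `temporal_shift`, the `shift_ratio` value of 0.25, and the choice of which channel group moves in which direction are all assumptions made for the example.

```python
def temporal_shift(feature, shift_ratio=0.25):
    """Shift channel groups along the temporal axis, zero-filling the gaps.

    feature is a list of T rows (frames), each a list of C channel values.
    shift_ratio is an assumed fraction of channels moved in each direction.
    """
    num_frames = len(feature)
    num_channels = len(feature[0])
    fold = int(num_channels * shift_ratio)

    shifted = [[0] * num_channels for _ in range(num_frames)]
    for t in range(num_frames):
        for c in range(num_channels):
            if c < fold:
                src = t + 1      # first group: frame t takes the value of frame t + 1
            elif c < 2 * fold:
                src = t - 1      # second group: frame t takes the value of frame t - 1
            else:
                src = t          # remaining channels stay in place
            if 0 <= src < num_frames:
                shifted[t][c] = feature[src][c]
            # otherwise the position keeps its zero fill
    return shifted


# 4 frames x 4 channels, where each value encodes 4 * frame + channel.
feature = [[4 * t + c for c in range(4)] for t in range(4)]
shifted = temporal_shift(feature)
for row in shifted:
    print(row)
```

With four channels and a ratio of 0.25, one channel moves in each direction and the other two are untouched, so every frame now mixes information from its neighbours without any extra multiplications.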
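The TSM model is a video classification model built by inserting the Temporal Shift Module into a ResNet network. The version implemented in this model library uses ResNet-50 as the backbone. | ||
|
||
For details, see the paper [Temporal Shift Module for Efficient Video Understanding](https://arxiv.org/abs/1811.08383v1). | ||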
|
||
## Quick Start | ||
|
||
### Installation | ||
|
||
#### Installing Paddle | ||
|
||
This project requires PaddlePaddle 1.7 or later, or a suitable develop build. See the [installation guide](http://www.paddlepaddle.org/#quick-start) to install it. | ||
|
||
#### Downloading the code and setting environment variables | ||
|
||
Clone the repository locally and set the `PYTHONPATH` environment variable: | ||
|
||
```bash | ||
git clone https://github.com/PaddlePaddle/hapi | ||
cd hapi | ||
export PYTHONPATH=$PYTHONPATH:`pwd` | ||
cd tsm | ||
``` | ||
|
||
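### Data Preparation | ||
|
||
TSM is trained on the Kinetics-400 action recognition dataset released by DeepMind. For downloading and preparing the data, see the [data instructions](./dataset/README.md). | ||
|
||
#### Validation on a small dataset | ||
|
||
To allow fast iteration, we validate dynamic graph training on a smaller dataset: the samples from Kinetics-400 whose class labels are 0, 2, 3, 4, 6, 7, 9, 12, 14, and 15 (10 classes in total) are used to validate model accuracy. | ||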
|
||
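As a rough illustration of the subset selection described above, the sketch below filters in-memory samples by class label. How the repository actually builds the small dataset is not shown in this diff, so the helper `select_small_set` and the `(video_id, label)` sample layout are hypothetical.

```python
# Class labels of the small validation subset, as listed in the text above.
SMALL_SET_LABELS = {0, 2, 3, 4, 6, 7, 9, 12, 14, 15}


def select_small_set(samples, keep_labels=SMALL_SET_LABELS):
    """Keep the (video_id, label) pairs whose label is in keep_labels."""
    return [(video_id, label) for video_id, label in samples
            if label in keep_labels]


# Hypothetical samples: only labels 0 and 15 belong to the subset.
samples = [("a", 0), ("b", 1), ("c", 15), ("d", 399)]
subset = select_small_set(samples)
print(subset)
```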
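### Model Training | ||
|
||
Once the data is ready, use the `main.py` script to start training and evaluation. The commands below automatically alternate between training and evaluation every epoch and save checkpoints to the `tsm_checkpoint` directory by default. | ||
|
||
The arguments of `main.py` can be listed with the following command: | ||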
|
||
```bash | ||
python main.py --help | ||
``` | ||
|
||
#### Static graph training | ||
|
||
Single-GPU training: | ||
|
||
```bash | ||
export CUDA_VISIBLE_DEVICES=0 | ||
python main.py --data=<path/to/dataset> --batch_size=16 | ||
``` | ||
|
||
Multi-GPU training: | ||
|
||
```bash | ||
CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch main.py --data=<path/to/dataset> --batch_size=8 | ||
``` | ||
|
||
#### Dynamic graph training | ||
|
||
For dynamic graph training, just add the `-d` flag when running the script. | ||
|
||
Single-GPU training: | ||
|
||
```bash | ||
export CUDA_VISIBLE_DEVICES=0 | ||
python main.py --data=<path/to/dataset> --batch_size=16 -d | ||
``` | ||
|
||
Multi-GPU training: | ||
|
||
```bash | ||
CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch main.py --data=<path/to/dataset> --batch_size=8 -d | ||
``` | ||
|
||
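**Note:** For both static graph and dynamic graph modes, `--batch_size` in multi-GPU training is the batch size per GPU, so the total batch size is `--batch_size` multiplied by the number of GPUs. | ||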
|
||
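The arithmetic in the note above can be checked with a short sketch. The helper `total_batch_size` is hypothetical and simply counts the device IDs in a `CUDA_VISIBLE_DEVICES`-style string; it is not part of `main.py`.

```python
def total_batch_size(per_gpu_batch_size, cuda_visible_devices):
    """Multiply the per-GPU batch size by the number of visible devices."""
    devices = [d for d in cuda_visible_devices.split(",") if d.strip()]
    return per_gpu_batch_size * len(devices)


# The single-GPU and two-GPU commands above use 16 and 8 per GPU respectively.
single = total_batch_size(16, "0")
multi = total_batch_size(8, "0,1")
print(single, multi)
```

Under this reading, both example commands train with the same total batch size.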
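### Model Evaluation | ||
|
||
The model can be evaluated in either of the following two ways. | ||
|
||
1. Evaluate with the [TSM-ResNet50](https://paddlemodels.bj.bcebos.com/hapi/tsm_resnet50.pdparams) weights released by Paddle, which are downloaded automatically | ||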
|
||
```bash | ||
python main.py --data=<path/to/dataset> --eval_only | ||
``` | ||
|
||
2. Evaluate accuracy with a loaded checkpoint | ||
|
||
```bash | ||
python main.py --data=<path/to/dataset> --eval_only --weights=tsm_checkpoint/final | ||
``` | ||
|
||
#### Evaluation accuracy | ||
|
||
The weights trained on the 10-class small dataset are available at [TSM-ResNet50](https://paddlemodels.bj.bcebos.com/hapi/tsm_resnet50.pdparams). The evaluation accuracy is as follows: | ||
|
||
|Top-1|Top-5| | ||
|:-:|:-:| | ||
|76%|98%| | ||
|
||
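For readers unfamiliar with the metrics in the table above, the following sketch shows how Top-k accuracy is commonly computed from per-class scores. It is a generic illustration with made-up scores, not the metric code used by `main.py`; the function name `topk_accuracy` is an assumption.

```python
def topk_accuracy(scores, labels, k):
    """Fraction of samples whose true label is among the k highest scores."""
    hits = 0
    for row, label in zip(scores, labels):
        # Indices of the k largest scores for this sample.
        ranked = sorted(range(len(row)), key=lambda i: row[i], reverse=True)
        if label in ranked[:k]:
            hits += 1
    return hits / float(len(labels))


# Four hypothetical samples over three classes.
scores = [
    [0.1, 0.7, 0.2],
    [0.5, 0.3, 0.2],
    [0.2, 0.3, 0.5],
    [0.6, 0.3, 0.1],
]
labels = [1, 1, 2, 2]
top1 = topk_accuracy(scores, labels, 1)
top2 = topk_accuracy(scores, labels, 2)
print(top1, top2)
```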
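### Model Inference | ||
|
||
Inference can be run in either of the following two ways. | ||
|
||
1. Run inference with the [TSM-ResNet50](https://paddlemodels.bj.bcebos.com/hapi/tsm_resnet50.pdparams) weights released by Paddle, which are downloaded automatically | ||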
|
||
```bash | ||
python infer.py --data=<path/to/dataset> --label_list=<path/to/label_list> --infer_file=<path/to/pickle> | ||
``` | ||
|
||
2. Run inference with a loaded checkpoint | ||
|
||
```bash | ||
python infer.py --data=<path/to/dataset> --label_list=<path/to/label_list> --infer_file=<path/to/pickle> --weights=tsm_checkpoint/final | ||
``` | ||
|
||
Inference results are printed as log lines in the following form: | ||
|
||
```text | ||
2020-04-03 07:37:16,321-INFO: Sample ./kineteics/val_10/data_batch_10-042_6 predict label: 6, ground truth label: 6 | ||
``` | ||
|
||
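If the inference log needs to be post-processed, a line in the format shown above can be parsed with a regular expression. This is a sketch that assumes every result line follows exactly that format; the pattern and the helper `parse_infer_log` are not part of `infer.py`.

```python
import re

# Pattern written against the single example log line shown above.
LOG_PATTERN = re.compile(
    r"Sample (?P<sample>\S+) predict label: (?P<pred>\d+), "
    r"ground truth label: (?P<truth>\d+)")


def parse_infer_log(line):
    """Return (sample, predicted_label, ground_truth_label) or None."""
    match = LOG_PATTERN.search(line)
    if match is None:
        return None
    return (match.group("sample"),
            int(match.group("pred")),
            int(match.group("truth")))


line = ("2020-04-03 07:37:16,321-INFO: Sample ./kineteics/val_10/"
        "data_batch_10-042_6 predict label: 6, ground truth label: 6")
result = parse_infer_log(line)
print(result)
```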
## Reference Paper | ||
|
||
- [Temporal Shift Module for Efficient Video Understanding](https://arxiv.org/abs/1811.08383v1), Ji Lin, Chuang Gan, Song Han | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import sys | ||
|
||
import paddle.fluid as fluid | ||
|
||
import logging | ||
logger = logging.getLogger(__name__) | ||
|
||
__all__ = ['check_gpu', 'check_version'] | ||
|
||
|
||
def check_gpu(use_gpu): | ||
    """ | ||
    Log error and exit when set use_gpu=true in paddlepaddle | ||
    cpu version. | ||
    """ | ||
    err = "Config use_gpu cannot be set as true while you are " \ | ||
          "using paddlepaddle cpu version ! \nPlease try: \n" \ | ||
          "\t1. Install paddlepaddle-gpu to run model on GPU \n" \ | ||
          "\t2. Set use_gpu as false in config file to run " \ | ||
          "model on CPU" | ||
|
||
    try: | ||
        if use_gpu and not fluid.is_compiled_with_cuda(): | ||
            logger.error(err) | ||
            sys.exit(1) | ||
    except Exception as e: | ||
        pass | ||
|
||
|
||
def check_version(version='1.7.0'): | ||
    """ | ||
    Log error and exit when the installed version of paddlepaddle is | ||
    not satisfied. | ||
    """ | ||
    err = "PaddlePaddle version {} or higher is required, " \ | ||
          "or a suitable develop version is satisfied as well. \n" \ | ||
          "Please make sure the version is good with your code." \ | ||
          .format(version) | ||
|
||
    try: | ||
        fluid.require_version(version) | ||
    except Exception as e: | ||
        logger.error(err) | ||
        sys.exit(1) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
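# Data Usage Instructions | ||
|
||
## Kinetics dataset | ||
|
||
Kinetics is a large-scale video action recognition dataset released by DeepMind, available in two versions: Kinetics400 and Kinetics600. Kinetics400 is used here; the preprocessing steps are described below. | ||
|
||
### Downloading the mp4 videos | ||
Create the folders under the Code\_Root directory: | ||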
|
||
cd $Code_Root/data/dataset && mkdir kinetics | ||
|
||
cd kinetics && mkdir data_k400 && cd data_k400 | ||
|
||
mkdir train_mp4 && mkdir val_mp4 | ||
|
||
ActivityNet provides an official download tool for Kinetics; see its [official repo](https://github.com/activitynet/ActivityNet/tree/master/Crawler/Kinetics) to download the Kinetics400 mp4 videos. Download the Kinetics400 training and validation sets to data/dataset/kinetics/data\_k400/train\_mp4 and data/dataset/kinetics/data\_k400/val\_mp4 respectively. | ||
|
||
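### Preprocessing the mp4 files | ||
|
||
To speed up data loading, the mp4 files are decoded into frames and packed into pickle files in advance, and the dataloader reads data from each video's pkl file (this approach uses more storage space). Each pkl file packs the content (video-id, label, [frame1, frame2,...,frameN]). | ||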
|
||
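The packed layout described above can be illustrated with the standard `pickle` module. This is only a sketch of the `(video-id, label, [frames])` tuple: the helper names `pack_video` and `load_video`, the byte-string frames, and the pickle protocol are assumptions, and the actual `video2pkl.py` script is not shown in this diff.

```python
import os
import pickle
import tempfile


def pack_video(path, video_id, label, frames):
    """Write one (video_id, label, frames) tuple to a pickle file."""
    with open(path, "wb") as f:
        pickle.dump((video_id, label, frames), f, protocol=2)


def load_video(path):
    """Read the (video_id, label, frames) tuple back."""
    with open(path, "rb") as f:
        return pickle.load(f)


# Placeholder byte strings stand in for encoded frames.
tmp_dir = tempfile.mkdtemp()
pkl_path = os.path.join(tmp_dir, "example.pkl")
pack_video(pkl_path, "video_0001", 6, [b"frame1-bytes", b"frame2-bytes"])
video_id, label, frames = load_video(pkl_path)
print(video_id, label, len(frames))
```

Storing decoded frames this way trades disk space for read speed, which matches the storage caveat in the paragraph above.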
Create the train\_pkl and val\_pkl directories under data/dataset/kinetics/data\_k400: | ||
|
||
cd $Code_Root/data/dataset/kinetics/data_k400 | ||
|
||
mkdir train_pkl && mkdir val_pkl | ||
|
||
Go to the $Code\_Root/data/dataset/kinetics directory and use the video2pkl.py script to convert the data. First, download the file lists for the [train](https://github.com/activitynet/ActivityNet/tree/master/Crawler/Kinetics/data/kinetics-400_train.csv) and [validation](https://github.com/activitynet/ActivityNet/tree/master/Crawler/Kinetics/data/kinetics-400_val.csv) sets. | ||
|
||
Then generate the dataset label file required for preprocessing: | ||
|
||
python generate_label.py kinetics-400_train.csv kinetics400_label.txt | ||
|
||
Next, run the following command: | ||
|
||
python video2pkl.py kinetics-400_train.csv $Source_dir $Target_dir 8 # using 8 processes as an example | ||
|
||
- This script depends on `ffmpeg`; install `ffmpeg` beforehand. | ||
|
||
For the train data: | ||
|
||
Source_dir = $Code_Root/data/dataset/kinetics/data_k400/train_mp4 | ||
|
||
Target_dir = $Code_Root/data/dataset/kinetics/data_k400/train_pkl | ||
|
||
For the val data: | ||
|
||
Source_dir = $Code_Root/data/dataset/kinetics/data_k400/val_mp4 | ||
|
||
Target_dir = $Code_Root/data/dataset/kinetics/data_k400/val_pkl | ||
|
||
This decodes the mp4 files and saves them as pkl files. | ||
|
||
### Generating the training and validation lists | ||
cd $Code_Root/data/dataset/kinetics | ||
|
||
ls $Code_Root/data/dataset/kinetics/data_k400/train_pkl/* > train.list | ||
|
||
ls $Code_Root/data/dataset/kinetics/data_k400/val_pkl/* > val.list | ||
|
||
ls $Code_Root/data/dataset/kinetics/data_k400/val_pkl/* > test.list | ||
|
||
ls $Code_Root/data/dataset/kinetics/data_k400/val_pkl/* > infer.list | ||
|
||
This generates the corresponding file lists. Each line of train.list and val.list is the absolute path of one pkl file, for example: | ||
|
||
/ssd1/user/models/PaddleCV/PaddleVideo/data/dataset/kinetics/data_k400/train_pkl/data_batch_100-097 | ||
/ssd1/user/models/PaddleCV/PaddleVideo/data/dataset/kinetics/data_k400/train_pkl/data_batch_100-114 | ||
/ssd1/user/models/PaddleCV/PaddleVideo/data/dataset/kinetics/data_k400/train_pkl/data_batch_100-118 | ||
... | ||
|
||
or | ||
|
||
/ssd1/user/models/PaddleCV/PaddleVideo/data/dataset/kinetics/data_k400/val_pkl/data_batch_102-085 | ||
/ssd1/user/models/PaddleCV/PaddleVideo/data/dataset/kinetics/data_k400/val_pkl/data_batch_102-086 | ||
/ssd1/user/models/PaddleCV/PaddleVideo/data/dataset/kinetics/data_k400/val_pkl/data_batch_102-090 | ||
... |
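The `ls ... > train.list` commands above can also be reproduced in Python, which may be convenient on systems without a POSIX shell. This is a sketch under the assumption that every entry in the pkl directory should be listed; the helper `write_file_list` is hypothetical and the temporary directory only stands in for `data_k400/train_pkl`.

```python
import glob
import os
import tempfile


def write_file_list(pkl_dir, list_path):
    """Write the absolute path of every entry in pkl_dir, one per line."""
    pattern = os.path.join(os.path.abspath(pkl_dir), "*")
    paths = sorted(glob.glob(pattern))
    with open(list_path, "w") as f:
        for path in paths:
            f.write(path + "\n")
    return paths


# Build a throwaway directory with two empty placeholder files.
root = tempfile.mkdtemp()
pkl_dir = os.path.join(root, "train_pkl")
os.makedirs(pkl_dir)
for name in ("data_batch_100-114", "data_batch_100-097"):
    open(os.path.join(pkl_dir, name), "w").close()

list_path = os.path.join(root, "train.list")
listed = write_file_list(pkl_dir, list_path)
print(open(list_path).read())
```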
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import sys | ||
|
||
# kinetics-400_train.csv should be downloaded first and set as sys.argv[1] | ||
# sys.argv[2] can be set as kinetics400_label.txt | ||
# python generate_label.py kinetics-400_train.csv kinetics400_label.txt | ||
|
||
num_classes = 400 | ||
|
||
fname = sys.argv[1] | ||
outname = sys.argv[2] | ||
fl = open(fname).readlines() | ||
fl = fl[1:] | ||
outf = open(outname, 'w') | ||
|
||
label_list = [] | ||
for line in fl: | ||
    label = line.strip().split(',')[0].strip('"') | ||
    if label in label_list: | ||
        continue | ||
    else: | ||
        label_list.append(label) | ||
|
||
# Report both the expected and the actual label count on failure. | ||
assert len(label_list) == num_classes, ( | ||
    "there should be {} labels in list, but got {}".format( | ||
        num_classes, len(label_list))) | ||
|
||
label_list.sort() | ||
for i in range(num_classes): | ||
    outf.write('{} {}'.format(label_list[i], i) + '\n') | ||
|
||
outf.close() |
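The script above writes one `label index` pair per line, separated by a single space. Label names may themselves contain spaces (the example below assumes one such name), so a reader should split on the last space only. The helper `load_label_map` below is a hypothetical sketch of such a reader and is not part of this commit.

```python
import os
import tempfile


def load_label_map(path):
    """Parse lines of the form '<label name> <index>' into a dict."""
    label_map = {}
    with open(path) as f:
        for line in f:
            line = line.rstrip("\n")
            if not line:
                continue
            # Split on the last space so multi-word labels stay intact.
            name, index = line.rsplit(" ", 1)
            label_map[name] = int(index)
    return label_map


# Two illustrative lines in the format written by generate_label.py.
label_path = os.path.join(tempfile.mkdtemp(), "labels.txt")
with open(label_path, "w") as f:
    f.write("abseiling 0\nair drumming 1\n")

label_map = load_label_map(label_path)
print(label_map)
```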