Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#20 from heavengate/add_tsm
Browse files Browse the repository at this point in the history
Add tsm
  • Loading branch information
heavengate committed Apr 8, 2020
2 parents f2e17b2 + 3080634 commit 3203066
Show file tree
Hide file tree
Showing 12 changed files with 1,085 additions and 3 deletions.
5 changes: 4 additions & 1 deletion models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,20 @@
from . import mobilenetv2
from . import darknet
from . import yolov3
from . import tsm

from .resnet import *
from .mobilenetv1 import *
from .mobilenetv2 import *
from .vgg import *
from .darknet import *
from .yolov3 import *
from .tsm import *

__all__ = resnet.__all__ \
+ vgg.__all__ \
+ mobilenetv1.__all__ \
+ mobilenetv2.__all__ \
+ darknet.__all__ \
+ yolov3.__all__
+ yolov3.__all__ \
+ tsm.__all__
147 changes: 147 additions & 0 deletions tsm/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# TSM 视频分类模型

---

## 内容

- [模型简介](#模型简介)
- [快速开始](#快速开始)
- [参考论文](#参考论文)


## 模型简介

Temporal Shift Module是由MIT和IBM Watson AI Lab的Ji Lin,Chuang Gan和Song Han等人提出的通过时间位移来提高网络视频理解能力的模块,其位移操作原理如下图所示。

<p align="center">
<img src="./images/temporal_shift.png" height=250 width=800 hspace='10'/> <br />
Temporal shift module
</p>

上图中矩阵表示特征图中的temporal和channel维度,通过将一部分的channel在temporal维度上向前位移一步,一部分的channel在temporal维度上向后位移一步,位移后的空缺补零。通过这种方式在特征图中引入temporal维度上的上下文交互,提高在时间维度上的视频理解能力。

TSM模型是将Temporal Shift Module插入到ResNet网络中构建的视频分类模型,本模型库实现版本为以ResNet-50作为主干网络的TSM模型。

详细内容请参考论文[Temporal Shift Module for Efficient Video Understanding](https://arxiv.org/abs/1811.08383v1)

## 快速开始

### 安装说明

#### paddle安装

本项目依赖于 PaddlePaddle 1.7及以上版本或适当的develop版本,请参考 [安装指南](http://www.paddlepaddle.org/#quick-start) 进行安装

#### 代码下载及环境变量设置

克隆代码库到本地,并设置`PYTHONPATH`环境变量

```bash
git clone https://github.com/PaddlePaddle/hapi
cd hapi
export PYTHONPATH=$PYTHONPATH:`pwd`
cd tsm
```

### 数据准备

TSM的训练数据采用由DeepMind公布的Kinetics-400动作识别数据集。数据下载及准备请参考[数据说明](./dataset/README.md)

#### 小数据集验证

为了便于快速迭代,我们采用了较小的数据集进行动态图训练验证,从Kinetics-400数据集中选取分类标签(label)分别为 0, 2, 3, 4, 6, 7, 9, 12, 14, 15的即前10类数据验证模型精度。

### 模型训练

数据准备完毕后,可使用`main.py`脚本启动训练和评估,如下脚本会自动每epoch交替进行训练和模型评估,并将checkpoint默认保存在`tsm_checkpoint`目录下。

`main.py`脚本参数可通过如下命令查询

```bash
python main.py --help
```

#### 静态图训练

使用如下方式进行单卡训练:

```bash
export CUDA_VISIBLE_DEVICES=0
python main.py --data=<path/to/dataset> --batch_size=16
```

使用如下方式进行多卡训练:

```bash
CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch main.py --data=<path/to/dataset> --batch_size=8
```

#### 动态图训练

动态图训练只需要在运行脚本时添加`-d`参数即可。

使用如下方式进行单卡训练:

```bash
export CUDA_VISIBLE_DEVICES=0
python main.py --data=<path/to/dataset> --batch_size=16 -d
```

使用如下方式进行多卡训练:

```bash
CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch main.py --data=<path/to/dataset> --batch_size=8 -d
```

**注意:** 对于静态图和动态图,多卡训练中`--batch_size`为每卡上的batch_size,即总batch_size为`--batch_size`乘以卡数

### 模型评估

可通过如下两种方式进行模型评估。

1. 自动下载Paddle发布的[TSM-ResNet50](https://paddlemodels.bj.bcebos.com/hapi/tsm_resnet50.pdparams)权重评估

```bash
python main.py --data=<path/to/dataset> --eval_only
```

2. 加载checkpoint进行精度评估

```bash
python main.py --data=<path/to/dataset> --eval_only --weights=tsm_checkpoint/final
```

#### 评估精度

在10类小数据集下训练模型权重见[TSM-ResNet50](https://paddlemodels.bj.bcebos.com/hapi/tsm_resnet50.pdparams),评估精度如下:

|Top-1|Top-5|
|:-:|:-:|
|76%|98%|

### 模型推断

可通过如下两种方式进行模型推断。

1. 自动下载Paddle发布的[TSM-ResNet50](https://paddlemodels.bj.bcebos.com/hapi/tsm_resnet50.pdparams)权重推断

```bash
python infer.py --data=<path/to/dataset> --label_list=<path/to/label_list> --infer_file=<path/to/pickle>
```

2. 加载checkpoint进行精度推断

```bash
python infer.py --data=<path/to/dataset> --label_list=<path/to/label_list> --infer_file=<path/to/pickle> --weights=tsm_checkpoint/final
```

模型推断结果会以如下日志形式输出

```text
2020-04-03 07:37:16,321-INFO: Sample ./kineteics/val_10/data_batch_10-042_6 predict label: 6, ground truth label: 6
```

## 参考论文

- [Temporal Shift Module for Efficient Video Understanding](https://arxiv.org/abs/1811.08383v1), Ji Lin, Chuang Gan, Song Han

62 changes: 62 additions & 0 deletions tsm/check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys

import paddle.fluid as fluid

import logging
logger = logging.getLogger(__name__)

__all__ = ['check_gpu', 'check_version']


def check_gpu(use_gpu):
"""
Log error and exit when set use_gpu=true in paddlepaddle
cpu version.
"""
err = "Config use_gpu cannot be set as true while you are " \
"using paddlepaddle cpu version ! \nPlease try: \n" \
"\t1. Install paddlepaddle-gpu to run model on GPU \n" \
"\t2. Set use_gpu as false in config file to run " \
"model on CPU"

try:
if use_gpu and not fluid.is_compiled_with_cuda():
logger.error(err)
sys.exit(1)
except Exception as e:
pass


def check_version(version='1.7.0'):
"""
Log error and exit when the installed version of paddlepaddle is
not satisfied.
"""
err = "PaddlePaddle version {} or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \
"Please make sure the version is good with your code." \
.format(version)

try:
fluid.require_version(version)
except Exception as e:
logger.error(err)
sys.exit(1)
78 changes: 78 additions & 0 deletions tsm/dataset/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# 数据使用说明

## Kinetics数据集

Kinetics数据集是DeepMind公开的大规模视频动作识别数据集,有Kinetics400与Kinetics600两个版本。这里使用Kinetics400数据集,具体的数据预处理过程如下。

### mp4视频下载
在Code\_Root目录下创建文件夹

cd $Code_Root/data/dataset && mkdir kinetics

cd kinetics && mkdir data_k400 && cd data_k400

mkdir train_mp4 && mkdir val_mp4

ActivityNet官方提供了Kinetics的下载工具,具体参考其[官方repo ](https://github.com/activitynet/ActivityNet/tree/master/Crawler/Kinetics)即可下载Kinetics400的mp4视频集合。将kinetics400的训练与验证集合分别下载到data/dataset/kinetics/data\_k400/train\_mp4与data/dataset/kinetics/data\_k400/val\_mp4。

### mp4文件预处理

为提高数据读取速度,提前将mp4文件解帧并打pickle包,dataloader从视频的pkl文件中读取数据(该方法耗费更多存储空间)。pkl文件里打包的内容为(video-id, label, [frame1, frame2,...,frameN])。

在 data/dataset/kinetics/data\_k400目录下创建目录train\_pkl和val\_pkl

cd $Code_Root/data/dataset/kinetics/data_k400

mkdir train_pkl && mkdir val_pkl

进入$Code\_Root/data/dataset/kinetics目录,使用video2pkl.py脚本进行数据转化。首先需要下载[train](https://github.com/activitynet/ActivityNet/tree/master/Crawler/Kinetics/data/kinetics-400_train.csv)[validation](https://github.com/activitynet/ActivityNet/tree/master/Crawler/Kinetics/data/kinetics-400_val.csv)数据集的文件列表。

首先生成预处理需要的数据集标签文件

python generate_label.py kinetics-400_train.csv kinetics400_label.txt

然后执行如下程序:

python video2pkl.py kinetics-400_train.csv $Source_dir $Target_dir 8 #以8个进程为例

- 该脚本依赖`ffmpeg`库,请预先安装`ffmpeg`

对于train数据,

Source_dir = $Code_Root/data/dataset/kinetics/data_k400/train_mp4

Target_dir = $Code_Root/data/dataset/kinetics/data_k400/train_pkl

对于val数据,

Source_dir = $Code_Root/data/dataset/kinetics/data_k400/val_mp4

Target_dir = $Code_Root/data/dataset/kinetics/data_k400/val_pkl

这样即可将mp4文件解码并保存为pkl文件。

### 生成训练和验证集list
··
cd $Code_Root/data/dataset/kinetics

ls $Code_Root/data/dataset/kinetics/data_k400/train_pkl/* > train.list

ls $Code_Root/data/dataset/kinetics/data_k400/val_pkl/* > val.list

ls $Code_Root/data/dataset/kinetics/data_k400/val_pkl/* > test.list

ls $Code_Root/data/dataset/kinetics/data_k400/val_pkl/* > infer.list

即可生成相应的文件列表,train.list和val.list的每一行表示一个pkl文件的绝对路径,示例如下:

/ssd1/user/models/PaddleCV/PaddleVideo/data/dataset/kinetics/data_k400/train_pkl/data_batch_100-097
/ssd1/user/models/PaddleCV/PaddleVideo/data/dataset/kinetics/data_k400/train_pkl/data_batch_100-114
/ssd1/user/models/PaddleCV/PaddleVideo/data/dataset/kinetics/data_k400/train_pkl/data_batch_100-118
...

或者

/ssd1/user/models/PaddleCV/PaddleVideo/data/dataset/kinetics/data_k400/val_pkl/data_batch_102-085
/ssd1/user/models/PaddleCV/PaddleVideo/data/dataset/kinetics/data_k400/val_pkl/data_batch_102-086
/ssd1/user/models/PaddleCV/PaddleVideo/data/dataset/kinetics/data_k400/val_pkl/data_batch_102-090
...
44 changes: 44 additions & 0 deletions tsm/dataset/kinetics/generate_label.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys

# kinetics-400_train.csv should be down loaded first and set as sys.argv[1]
# sys.argv[2] can be set as kinetics400_label.txt
# python generate_label.py kinetics-400_train.csv kinetics400_label.txt

num_classes = 400

fname = sys.argv[1]
outname = sys.argv[2]
fl = open(fname).readlines()
fl = fl[1:]
outf = open(outname, 'w')

label_list = []
for line in fl:
label = line.strip().split(',')[0].strip('"')
if label in label_list:
continue
else:
label_list.append(label)

assert len(label_list
) == num_classes, "there should be {} labels in list, but ".format(
num_classes, len(label_list))

label_list.sort()
for i in range(num_classes):
outf.write('{} {}'.format(label_list[i], i) + '\n')

outf.close()
Loading

0 comments on commit 3203066

Please sign in to comment.