From 77b48910881c17efdc6a1417ce0b38f5d29438b2 Mon Sep 17 00:00:00 2001 From: irvingzhang0512 Date: Wed, 24 Mar 2021 10:58:58 +0800 Subject: [PATCH] [Fix] Fix a bug about multi-class in VideoDataset (#723) * Fix 722 * add unittest and update changelog --- docs/changelog.md | 2 ++ mmaction/datasets/video_dataset.py | 9 +-------- .../video_test_list_multi_label.txt | 2 ++ tests/test_data/test_datasets/base.py | 2 ++ .../test_datasets/test_video_dataset.py | 18 ++++++++++++++++++ 5 files changed, 25 insertions(+), 8 deletions(-) create mode 100644 tests/data/annotations/video_test_list_multi_label.txt diff --git a/docs/changelog.md b/docs/changelog.md index 60a69cfba8..5976165b50 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -16,6 +16,8 @@ **Bug and Typo Fixes** +- Fix a bug about multi-class in VideoDataset ([#723](https://github.com/open-mmlab/mmaction2/pull/678)) + **ModelZoo** - Add LFB for AVA2.1 ([#553](https://github.com/open-mmlab/mmaction2/pull/553)) diff --git a/mmaction/datasets/video_dataset.py b/mmaction/datasets/video_dataset.py index bf9a30afce..08b90862be 100644 --- a/mmaction/datasets/video_dataset.py +++ b/mmaction/datasets/video_dataset.py @@ -1,7 +1,5 @@ import os.path as osp -import torch - from .base import BaseDataset from .registry import DATASETS @@ -53,15 +51,10 @@ def load_annotations(self): assert self.num_classes is not None filename, label = line_split[0], line_split[1:] label = list(map(int, label)) - onehot = torch.zeros(self.num_classes) - onehot[label] = 1.0 else: filename, label = line_split label = int(label) if self.data_prefix is not None: filename = osp.join(self.data_prefix, filename) - video_infos.append( - dict( - filename=filename, - label=onehot if self.multi_class else label)) + video_infos.append(dict(filename=filename, label=label)) return video_infos diff --git a/tests/data/annotations/video_test_list_multi_label.txt b/tests/data/annotations/video_test_list_multi_label.txt new file mode 100644 index 0000000000..0d59b257e0 --- /dev/null +++ b/tests/data/annotations/video_test_list_multi_label.txt @@ -0,0 +1,2 @@ +test.mp4 0 3 +test.mp4 0 2 4 diff --git a/tests/test_data/test_datasets/base.py b/tests/test_data/test_datasets/base.py index ced5da4d77..c3fc5c3ce5 100644 --- a/tests/test_data/test_datasets/base.py +++ b/tests/test_data/test_datasets/base.py @@ -41,6 +41,8 @@ def setup_class(cls): 'rawvideo_test_anno.txt') cls.video_ann_file = osp.join(cls.ann_file_prefix, 'video_test_list.txt') + cls.video_ann_file_multi_label = osp.join( + cls.ann_file_prefix, 'video_test_list_multi_label.txt') # pipeline configuration cls.action_pipeline = [] diff --git a/tests/test_data/test_datasets/test_video_dataset.py b/tests/test_data/test_datasets/test_video_dataset.py index 590fbbe2b1..20a7c596e0 100644 --- a/tests/test_data/test_datasets/test_video_dataset.py +++ b/tests/test_data/test_datasets/test_video_dataset.py @@ -28,6 +28,24 @@ def test_video_dataset(self): assert video_infos == [dict(filename=video_filename, label=0)] * 2 assert video_dataset.start_index == 0 + def test_video_dataset_multi_label(self): + video_dataset = VideoDataset( + self.video_ann_file_multi_label, + self.video_pipeline, + data_prefix=self.data_prefix, + multi_class=True, + num_classes=100) + video_infos = video_dataset.video_infos + video_filename = osp.join(self.data_prefix, 'test.mp4') + label0 = [0, 3] + label1 = [0, 2, 4] + labels = [label0, label1] + for info, label in zip(video_infos, labels): + print(info, video_filename) + assert info['filename'] == video_filename + assert set(info['label']) == set(label) + assert video_dataset.start_index == 0 + def test_video_pipeline(self): target_keys = ['filename', 'label', 'start_index', 'modality']