Skip to content

Commit

Permalink
[Fix] Fix a bug about multi-class in VideoDataset (open-mmlab#723)
Browse files Browse the repository at this point in the history
* Fix 722

* add unittest and update changelog
  • Loading branch information
irvingzhang0512 committed Mar 24, 2021
1 parent afce5ca commit 77b4891
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 8 deletions.
2 changes: 2 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

**Bug and Typo Fixes**

- Fix a bug about multi-class in VideoDataset ([#723](https://github.com/open-mmlab/mmaction2/pull/678))

**ModelZoo**

- Add LFB for AVA2.1 ([#553](https://github.com/open-mmlab/mmaction2/pull/553))
Expand Down
9 changes: 1 addition & 8 deletions mmaction/datasets/video_dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import os.path as osp

import torch

from .base import BaseDataset
from .registry import DATASETS

Expand Down Expand Up @@ -53,15 +51,10 @@ def load_annotations(self):
assert self.num_classes is not None
filename, label = line_split[0], line_split[1:]
label = list(map(int, label))
onehot = torch.zeros(self.num_classes)
onehot[label] = 1.0
else:
filename, label = line_split
label = int(label)
if self.data_prefix is not None:
filename = osp.join(self.data_prefix, filename)
video_infos.append(
dict(
filename=filename,
label=onehot if self.multi_class else label))
video_infos.append(dict(filename=filename, label=label))
return video_infos
2 changes: 2 additions & 0 deletions tests/data/annotations/video_test_list_multi_label.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
test.mp4 0 3
test.mp4 0 2 4
2 changes: 2 additions & 0 deletions tests/test_data/test_datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def setup_class(cls):
'rawvideo_test_anno.txt')
cls.video_ann_file = osp.join(cls.ann_file_prefix,
'video_test_list.txt')
cls.video_ann_file_multi_label = osp.join(
cls.ann_file_prefix, 'video_test_list_multi_label.txt')

# pipeline configuration
cls.action_pipeline = []
Expand Down
18 changes: 18 additions & 0 deletions tests/test_data/test_datasets/test_video_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,24 @@ def test_video_dataset(self):
assert video_infos == [dict(filename=video_filename, label=0)] * 2
assert video_dataset.start_index == 0

def test_video_dataset_multi_label(self):
video_dataset = VideoDataset(
self.video_ann_file_multi_label,
self.video_pipeline,
data_prefix=self.data_prefix,
multi_class=True,
num_classes=100)
video_infos = video_dataset.video_infos
video_filename = osp.join(self.data_prefix, 'test.mp4')
label0 = [0, 3]
label1 = [0, 2, 4]
labels = [label0, label1]
for info, label in zip(video_infos, labels):
print(info, video_filename)
assert info['filename'] == video_filename
assert set(info['label']) == set(label)
assert video_dataset.start_index == 0

def test_video_pipeline(self):
target_keys = ['filename', 'label', 'start_index', 'modality']

Expand Down

0 comments on commit 77b4891

Please sign in to comment.