Merge pull request #8 from marouaneamz/multi_task_mzr
add eval _mask
piercus committed Dec 9, 2022
2 parents 7983caf + da32c23 commit deb3bf4
Showing 5 changed files with 52 additions and 58 deletions.
11 changes: 2 additions & 9 deletions mmcls/evaluation/metrics/multi_task.py
@@ -73,17 +73,10 @@ def process(self, data_batch, data_samples: Sequence[dict]):
for task_name in self.task_metrics.keys():
filtered_data_samples = []
for data_sample in data_samples:
sample_mask = task_name in data_sample
if sample_mask:
eval_mask = data_sample[task_name]['eval_mask']
if eval_mask:
filtered_data_samples.append(data_sample[task_name])
for metric in self._metrics[task_name]:
# The current implementation is only compatible
# with two types of metrics:
# * Cls metrics
# * Nested Cls metrics
# To make it work with other non-cls heads/metrics,
# one will have to override the current
# implementation.
metric.process(data_batch, filtered_data_samples)

def compute_metrics(self, results: list) -> dict:
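A minimal sketch of the per-task filtering in the `process` hunk above, using hypothetical sample dicts (the task names and the `eval_mask` flag mirror the hunk; the data itself is made up):

```python
# Hypothetical per-task samples; only the 'eval_mask' flag matters here.
data_samples = [
    {'task0': {'eval_mask': True},
     'task1': {'eval_mask': False}},   # task1 present but not annotated
    {'task0': {'eval_mask': True},
     'task1': {'eval_mask': True}},
]

for task_name in ('task0', 'task1'):
    # Keep only the samples whose eval_mask is True for this task.
    filtered_data_samples = [
        sample[task_name]
        for sample in data_samples
        if sample[task_name]['eval_mask']
    ]
    print(task_name, len(filtered_data_samples))
# task0 2
# task1 1
```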
52 changes: 22 additions & 30 deletions mmcls/evaluation/metrics/single_label.py
@@ -150,18 +150,14 @@ def process(self, data_batch, data_samples: Sequence[dict]):
for data_sample in data_samples:
result = dict()
pred_label = data_sample['pred_label']
# When predict() is called without data samples as
# input, empty data samples without gt_label are
# created; those cannot be used by the metric.
if 'gt_label' in data_sample:
gt_label = data_sample['gt_label']
if 'score' in pred_label:
result['pred_score'] = pred_label['score'].cpu()
else:
result['pred_label'] = pred_label['label'].cpu()
result['gt_label'] = gt_label['label'].cpu()
# Save the result to `self.results`.
self.results.append(result)
gt_label = data_sample['gt_label']
if 'score' in pred_label:
result['pred_score'] = pred_label['score'].cpu()
else:
result['pred_label'] = pred_label['label'].cpu()
result['gt_label'] = gt_label['label'].cpu()
# Save the result to `self.results`.
self.results.append(result)

def compute_metrics(self, results: List):
"""Compute the metrics from processed results.
@@ -422,24 +418,20 @@ def process(self, data_batch, data_samples: Sequence[dict]):
for data_sample in data_samples:
result = dict()
pred_label = data_sample['pred_label']
# When predict() is called without data samples as
# input, empty data samples without gt_label are
# created; those cannot be used by the metric.
if 'gt_label' in data_sample:
gt_label = data_sample['gt_label']
if 'score' in pred_label:
result['pred_score'] = pred_label['score'].cpu()
else:
num_classes = self.num_classes or data_sample.get(
'num_classes')
assert num_classes is not None, \
'The `num_classes` must be specified if `pred_label` '\
'has only `label`.'
result['pred_label'] = pred_label['label'].cpu()
result['num_classes'] = num_classes
result['gt_label'] = gt_label['label'].cpu()
# Save the result to `self.results`.
self.results.append(result)
gt_label = data_sample['gt_label']
if 'score' in pred_label:
result['pred_score'] = pred_label['score'].cpu()
else:
num_classes = self.num_classes or data_sample.get(
'num_classes')
assert num_classes is not None, \
'The `num_classes` must be specified if `pred_label` '\
'has only `label`.'
result['pred_label'] = pred_label['label'].cpu()
result['num_classes'] = num_classes
result['gt_label'] = gt_label['label'].cpu()
# Save the result to `self.results`.
self.results.append(result)

def compute_metrics(self, results: List):
"""Compute the metrics from processed results.
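A hedged sketch of the per-sample result dict built by the `SingleLabelMetric.process` hunk above, with a made-up data sample. When the prediction carries only a hard label and no score vector, the class count cannot be inferred from the prediction itself, which is what the `num_classes` assertion guards against:

```python
import torch

# Hypothetical data sample (field names follow the hunk above).
data_sample = {
    'pred_label': {'label': torch.tensor([2])},   # hard label only, no 'score'
    'gt_label': {'label': torch.tensor([1])},
    'num_classes': 3,
}

result = {}
pred_label = data_sample['pred_label']
gt_label = data_sample['gt_label']
if 'score' in pred_label:
    result['pred_score'] = pred_label['score'].cpu()
else:
    # Without a score vector, the number of classes must be given explicitly.
    num_classes = data_sample.get('num_classes')
    assert num_classes is not None, (
        'The `num_classes` must be specified if `pred_label` has only `label`.')
    result['pred_label'] = pred_label['label'].cpu()
    result['num_classes'] = num_classes
result['gt_label'] = gt_label['label'].cpu()
print(result)
```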
2 changes: 1 addition & 1 deletion mmcls/models/heads/cls_head.py
@@ -110,7 +110,7 @@ def _get_loss(self, cls_score: torch.Tensor,
def predict(
self,
feats: Tuple[torch.Tensor],
data_samples: List[ClsDataSample | None] = None
data_samples: List[Union[ClsDataSample, None]] = None
) -> List[ClsDataSample]:
"""Inference without augmentation.
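The one-line change above looks like a typing-compatibility fix: the PEP 604 `ClsDataSample | None` union only evaluates successfully at runtime on Python 3.10+, while `typing.Union` also works on older interpreters. A self-contained sketch, with `ClsDataSample` stubbed rather than imported from mmcls:

```python
from typing import List, Union


class ClsDataSample:  # stub standing in for mmcls.structures.ClsDataSample
    pass


def predict(data_samples: List[Union[ClsDataSample, None]] = None):
    # Writing `List[ClsDataSample | None]` here would raise a TypeError at
    # import time on Python < 3.10, because the annotation is evaluated eagerly.
    return data_samples or []


print(predict())  # []
```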
4 changes: 4 additions & 0 deletions mmcls/models/heads/multi_task_head.py
@@ -130,6 +130,10 @@ def predict(

for task_name, task_samples in predictions_dict.items():
for data_sample, task_sample in zip(data_samples, task_samples):
task_sample.set_field(
task_name in data_sample,
'eval_mask',
field_type='metainfo')
if task_name in data_sample:
data_sample.get(task_name).update(task_sample)
else:
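A minimal sketch of what the four added lines do, assuming mmcls (and its `ClsDataSample`) is installed; the input sample below is made up. Each per-task prediction records, as `eval_mask` metainfo, whether the incoming data sample actually carried that task, so `MultiTasksMetric` can later skip samples that were never annotated for it:

```python
from mmcls.structures import ClsDataSample

# Hypothetical input sample that only carries annotations for 'task0'.
data_sample = {'task0': {'gt_label': 0}}

for task_name in ('task0', 'task1'):
    task_sample = ClsDataSample()
    # Same call as in the hunk: store the flag as metainfo on the task sample.
    task_sample.set_field(
        task_name in data_sample, 'eval_mask', field_type='metainfo')
    print(task_name, task_sample.eval_mask)
# task0 True
# task1 False
```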
41 changes: 23 additions & 18 deletions tests/test_evaluation/test_metrics/test_multi_task_metrics.py
@@ -4,11 +4,11 @@
import torch

from mmcls.evaluation.metrics import MultiTasksMetric
from mmcls.structures import ClsDataSample, MultiTaskDataSample
from mmcls.structures import ClsDataSample


class MultiTaskMetric(TestCase):
data = zip([
data_pred = [
{
'task0': torch.tensor([0.7, 0.0, 0.3]),
'task1': torch.tensor([0.5, 0.2, 0.3])
@@ -17,21 +17,23 @@ class MultiTaskMetric(TestCase):
'task0': torch.tensor([0.0, 0.0, 1.0]),
'task1': torch.tensor([0.0, 0.0, 1.0])
},
], [{
'task0': 0,
'task1': 2
}, {
'task0': 2,
'task1': 2
}])
pred = []
for score, label in data:
]
data_gt = [{'task0': 0, 'task1': 2}, {'task1': 2}]

preds = []
for i, pred in enumerate(data_pred):
sample = {}
for task_name in score:
task_sample = ClsDataSample().set_pred_score(score[task_name])
task_sample.set_gt_label(label[task_name])
for task_name in pred:
task_sample = ClsDataSample().set_pred_score(pred[task_name])
if task_name in data_gt[i]:
task_sample.set_gt_label(data_gt[i][task_name])
task_sample.set_field(True, 'eval_mask', field_type='metainfo')
else:
task_sample.set_field(
False, 'eval_mask', field_type='metainfo')
sample[task_name] = task_sample.to_dict()
pred.append(sample)

preds.append(sample)
data2 = zip([
{
'task0': torch.tensor([0.7, 0.0, 0.3]),
@@ -69,17 +71,20 @@ class MultiTaskMetric(TestCase):
task_sample = ClsDataSample().set_pred_score(score[task_name])
task_sample.set_gt_label(label[task_name])
sample[task_name] = task_sample.to_dict()
sample[task_name]['eval_mask'] = True
else:
sample[task_name] = {}
sample[task_name]['eval_mask'] = True
for task_name2 in score[task_name]:
task_sample = ClsDataSample().set_pred_score(
score[task_name][task_name2])
task_sample.set_gt_label(label[task_name][task_name2])
sample[task_name][task_name2] = task_sample.to_dict()
pred2.append(sample)
sample[task_name][task_name2]['eval_mask'] = True

pred3 = [MultiTaskDataSample().to_dict()]
pred2.append(sample)

pred3 = [{'task0': {'eval_mask': False}, 'task1': {'eval_mask': False}}]
task_metrics = {
'task0': [dict(type='Accuracy', topk=(1, ))],
'task1': [
@@ -107,7 +112,7 @@ def test_evaluate(self):

# Test with score (use score instead of label if score exists)
metric = MultiTasksMetric(self.task_metrics)
metric.process(None, self.pred)
metric.process(None, self.preds)
results = metric.evaluate(2)
self.assertIsInstance(results, dict)
self.assertAlmostEqual(results['task0_accuracy/top1'], 100)
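For reference, a hedged sketch of how the flat result keys asserted above (e.g. `task0_accuracy/top1`) could be assembled from per-task metric outputs; the per-task dicts and values here are invented, not produced by mmcls:

```python
# Invented per-task results, only to illustrate the key naming scheme.
per_task_results = {
    'task0': {'accuracy/top1': 100.0},
    'task1': {'accuracy/top1': 50.0, 'single-label/precision': 50.0},
}

# Prefix every per-task key with its task name, giving one flat dict.
results = {
    f'{task_name}_{key}': value
    for task_name, task_result in per_task_results.items()
    for key, value in task_result.items()
}
print(results)
# {'task0_accuracy/top1': 100.0, 'task1_accuracy/top1': 50.0,
#  'task1_single-label/precision': 50.0}
```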
