Skip to content

Commit

Permalink
Merge pull request #7 from marouaneamz/multi_task_mzr
Browse files Browse the repository at this point in the history
Multi task mzr
  • Loading branch information
piercus committed Dec 6, 2022
2 parents 87d78fd + 3dc8324 commit 3fe628e
Show file tree
Hide file tree
Showing 15 changed files with 291 additions and 546 deletions.
4 changes: 2 additions & 2 deletions configs/_base_/datasets/multi_task_medic_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(256, 200), backend='pillow'),
dict(type='CenterCrop', crop_size=224),
dict(type='FormatMultiTaskLabels'),
dict(type='PackMultiTaskInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(256, 200), backend='pillow'),
dict(type='CenterCrop', crop_size=224),
dict(type='FormatMultiTaskLabels')
dict(type='PackMultiTaskInputs')
]

train_dataloader = dict(
Expand Down
6 changes: 2 additions & 4 deletions configs/_base_/models/mobilenet_v2_1x_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
'humanitarian': dict(type='LinearClsHead', num_classes=4),
'disaster_types': dict(type='LinearClsHead', num_classes=7)
},
common_cfg=dict(
in_channels=1280,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
),
in_channels=1280,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
))
6 changes: 3 additions & 3 deletions mmcls/datasets/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
Brightness, ColorTransform, Contrast, Cutout,
Equalize, Invert, Posterize, RandAugment, Rotate,
Sharpness, Shear, Solarize, SolarizeAdd, Translate)
from .formatting import (Collect, FormatMultiTaskLabels, PackClsInputs,
ToNumpy, ToPIL, Transpose)
from .formatting import (Collect, PackClsInputs, PackMultiTaskInputs, ToNumpy,
ToPIL, Transpose)
from .processing import (Albumentations, ColorJitter, EfficientNetCenterCrop,
EfficientNetRandomCrop, Lighting, RandomCrop,
RandomErasing, RandomResizedCrop, ResizeEdge)
Expand All @@ -17,5 +17,5 @@
'Cutout', 'RandAugment', 'Lighting', 'ColorJitter', 'RandomErasing',
'PackClsInputs', 'Albumentations', 'EfficientNetRandomCrop',
'EfficientNetCenterCrop', 'ResizeEdge', 'BaseAugTransform',
'FormatMultiTaskLabels'
'PackMultiTaskInputs'
]
58 changes: 33 additions & 25 deletions mmcls/datasets/transforms/formatting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from collections import defaultdict
from collections.abc import Sequence
from typing import List
from functools import partial

import numpy as np
import torch
Expand Down Expand Up @@ -86,12 +86,6 @@ def transform(self, results: dict) -> dict:
img = np.expand_dims(img, -1)
img = np.ascontiguousarray(img.transpose(2, 0, 1))
packed_results['inputs'] = to_tensor(img)
else:
warnings.warn(
'Cannot get "img" in the input dict of `PackClsInputs`,'
'please make sure `LoadImageFromFile` has been added '
'in the data pipeline or images have been loaded in '
'the dataset.')

data_sample = ClsDataSample()
if 'gt_label' in results:
Expand All @@ -110,7 +104,7 @@ def __repr__(self) -> str:


@TRANSFORMS.register_module()
class FormatMultiTaskLabels(BaseTransform):
class PackMultiTaskInputs(BaseTransform):
"""Convert all image labels of multi-task dataset to a dict of tensor.
Args:
Expand All @@ -131,11 +125,17 @@ class FormatMultiTaskLabels(BaseTransform):
"""

def __init__(self,
metainfo: List[str] = None,
task_handlers=dict(),
multi_task_fields=('gt_label', ),
meta_keys=('sample_idx', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'flip', 'flip_direction')):
self.metainfo = metainfo
self.multi_task_fields = multi_task_fields
self.meta_keys = meta_keys
self.task_handlers = defaultdict(
partial(PackClsInputs, meta_keys=meta_keys))
for task_name, task_handler in task_handlers.items():
self.task_handlers[task_name] = TRANSFORMS.build(
dict(type=task_handler, meta_keys=meta_keys))

def transform(self, results: dict) -> dict:
"""Method to pack the input data.
Expand All @@ -144,32 +144,40 @@ def transform(self, results: dict) -> dict:
'img': array([[[ 0, 0, 0])
"""
packed_results = dict()
results = results.copy()

if 'img' in results:
img = results['img']
img = results.pop('img')
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
img = np.ascontiguousarray(img.transpose(2, 0, 1))
packed_results['inputs'] = to_tensor(img)
else:
warnings.warn(
'Cannot get "img" in the input dict of `PackClsInputs`,'
'please make sure `LoadImageFromFile` has been added '
'in the data pipeline or images have been loaded in ')

data_sample = MultiTaskDataSample(metainfo=self.metainfo)
if 'gt_label' in results:
gt_label = results['gt_label']
data_sample.set_gt_task(gt_label)
task_results = defaultdict(dict)
for field in self.multi_task_fields:
if field in results:
value = results.pop(field)
for k, v in value.items():
task_results[k].update({field: v})

data_sample = MultiTaskDataSample()
for task_name, task_result in task_results.items():
task_handler = self.task_handlers[task_name]
task_pack_result = task_handler({**results, **task_result})
data_sample.set_field(task_pack_result['data_samples'], task_name)

img_meta = {k: results[k] for k in self.meta_keys if k in results}
data_sample.set_metainfo(img_meta)
packed_results['data_samples'] = data_sample
return packed_results

def __repr__(self):
repr = self.__class__.__name__
repr += f'(meta_keys={self.meta_keys})'
repr += f'(tasks={self.metainfo})'
task_handlers = {
name: handler.__class__.__name__
for name, handler in self.task_handlers.items()
}
repr += f'(task_handlers={task_handlers}, '
repr += f'multi_task_fields={self.multi_task_fields}, '
repr += f'meta_keys={self.meta_keys})'
return repr


Expand Down
74 changes: 15 additions & 59 deletions mmcls/evaluation/metrics/multi_task.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Sequence
from typing import Dict, Sequence

from mmengine.evaluator import BaseMetric

from mmcls.registry import METRICS
from mmcls.structures import MultiTaskDataSample


@METRICS.register_module()
Expand Down Expand Up @@ -62,36 +61,6 @@ def __init__(self,
for metric in self.task_metrics[task_name]:
self._metrics[task_name].append(METRICS.build(metric))

def pre_process_nested(self, data_samples: List[MultiTaskDataSample],
task_name):
"""Retrieve data_samples corresponds to the task_name for a data_sample
type MultiTaskDataSample Args :
data_samples (List[MultiTaskDataSample]):The annotation data of every
samples. task_name (str)
"""
task_data_sample = []
for data_sample in data_samples:
task_data_sample.append(
data_sample.to_target_data_sample('MultiTaskDataSample',
task_name).to_dict())
return task_data_sample

def pre_process_cls(self, data_samples: List[MultiTaskDataSample],
task_name):
"""Retrieve data_samples corresponds to the task_name for a data_sample
type ClsDataSample Args :
data_samples (List[MultiTaskDataSample]):The annotation data of every
samples. task_name (str)
"""
task_data_sample_dicts = []
for data_sample in data_samples:
task_data_sample_dicts.append(
data_sample.to_target_data_sample('ClsDataSample',
task_name).to_dict())
return task_data_sample_dicts

def process(self, data_batch, data_samples: Sequence[dict]):
"""Process one batch of data samples.
Expand All @@ -101,38 +70,25 @@ def process(self, data_batch, data_samples: Sequence[dict]):
data_batch: A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
"""
data_sample_instances = []
for data_sample in data_samples:
if 'gt_task' in data_sample:
data_sample_instances.append(MultiTaskDataSample().set_gt_task(
data_sample['gt_task']).set_pred_task(
data_sample['pred_task']))
for task_name in self.task_metrics.keys():
filtered_data_samples = []
for data_sample in data_sample_instances:
sample_mask = data_sample.get_task_mask(task_name)
for data_sample in data_samples:
sample_mask = task_name in data_sample
if sample_mask:
filtered_data_samples.append(data_sample)
if type(self.task_metrics[task_name]) != dict:
for metric in self._metrics[task_name]:
# Current implementation is only comptaible
# With 2 types of metrics :
# * Cls Metrics
# * Nested Cls Metrics
# In order to make it work with other
# non-cls heads/metrics, one will have to
# override the current implementation
if metric.__class__.__name__ != 'MultiTasksMetric':
task_data_sample_dicts = self.pre_process_cls(
filtered_data_samples, task_name)
metric.process(data_batch, task_data_sample_dicts)
else:
task_data_sample_dicts = self.pre_process_nested(
filtered_data_samples, task_name)
metric.process(data_batch, task_data_sample_dicts)
filtered_data_samples.append(data_sample[task_name])
for metric in self._metrics[task_name]:
# Current implementation is only comptaible
# With 2 types of metrics :
# * Cls Metrics
# * Nested Cls Metrics
# In order to make it work with other
# non-cls heads/metrics, one will have to
# override the current implementation
metric.process(data_batch, filtered_data_samples)

def compute_metrics(self, results: list) -> dict:
raise Exception('compute metrics should not be used here directly')
raise NotImplementedError(
'compute metrics should not be used here directly')

def evaluate(self, size):
"""Evaluate the model performance of the whole dataset after processing
Expand Down
52 changes: 30 additions & 22 deletions mmcls/evaluation/metrics/single_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,18 @@ def process(self, data_batch, data_samples: Sequence[dict]):
for data_sample in data_samples:
result = dict()
pred_label = data_sample['pred_label']
gt_label = data_sample['gt_label']
if 'score' in pred_label:
result['pred_score'] = pred_label['score'].cpu()
else:
result['pred_label'] = pred_label['label'].cpu()
result['gt_label'] = gt_label['label'].cpu()
# Save the result to `self.results`.
self.results.append(result)
# when predictions is called without data_sample in input
# it is create empty one without gt_labl those can not use
# by metric
if 'gt_label' in data_sample:
gt_label = data_sample['gt_label']
if 'score' in pred_label:
result['pred_score'] = pred_label['score'].cpu()
else:
result['pred_label'] = pred_label['label'].cpu()
result['gt_label'] = gt_label['label'].cpu()
# Save the result to `self.results`.
self.results.append(result)

def compute_metrics(self, results: List):
"""Compute the metrics from processed results.
Expand Down Expand Up @@ -418,20 +422,24 @@ def process(self, data_batch, data_samples: Sequence[dict]):
for data_sample in data_samples:
result = dict()
pred_label = data_sample['pred_label']
gt_label = data_sample['gt_label']
if 'score' in pred_label:
result['pred_score'] = pred_label['score'].cpu()
else:
num_classes = self.num_classes or data_sample.get(
'num_classes')
assert num_classes is not None, \
'The `num_classes` must be specified if `pred_label` has '\
'only `label`.'
result['pred_label'] = pred_label['label'].cpu()
result['num_classes'] = num_classes
result['gt_label'] = gt_label['label'].cpu()
# Save the result to `self.results`.
self.results.append(result)
# when predictions is called without data_sample in input
# it is create empty one without gt_labl those can not use
# by metric
if 'gt_label' in data_sample:
gt_label = data_sample['gt_label']
if 'score' in pred_label:
result['pred_score'] = pred_label['score'].cpu()
else:
num_classes = self.num_classes or data_sample.get(
'num_classes')
assert num_classes is not None, \
'The `num_classes` must be specified if `pred_label` '\
'has only `label`.'
result['pred_label'] = pred_label['label'].cpu()
result['num_classes'] = num_classes
result['gt_label'] = gt_label['label'].cpu()
# Save the result to `self.results`.
self.results.append(result)

def compute_metrics(self, results: List):
"""Compute the metrics from processed results.
Expand Down
13 changes: 10 additions & 3 deletions mmcls/models/classifiers/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List, Optional

import torch
import torch.nn as nn

from mmcls.registry import MODELS
from mmcls.structures import ClsDataSample
Expand Down Expand Up @@ -61,13 +62,19 @@ def __init__(self,
super(ImageClassifier, self).__init__(
init_cfg=init_cfg, data_preprocessor=data_preprocessor)

self.backbone = MODELS.build(backbone)
def build_module(module):
if isinstance(module, nn.Module):
return module
else:
return MODELS.build(module)

self.backbone: nn.Module = build_module(backbone)

if neck is not None:
self.neck = MODELS.build(neck)
self.neck: nn.Module = build_module(neck)

if head is not None:
self.head = MODELS.build(head)
self.head: nn.Module = build_module(head)

def forward(self,
inputs: torch.Tensor,
Expand Down
23 changes: 12 additions & 11 deletions mmcls/models/heads/cls_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,14 +138,15 @@ def _get_predictions(self, cls_score, data_samples):
pred_scores = F.softmax(cls_score, dim=1)
pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach()

if data_samples is not None:
for data_sample, score, label in zip(data_samples, pred_scores,
pred_labels):
data_sample.set_pred_score(score).set_pred_label(label)
else:
data_samples = []
for score, label in zip(pred_scores, pred_labels):
data_samples.append(ClsDataSample().set_pred_score(
score).set_pred_label(label))

return data_samples
out_data_samples = []
if data_samples is None:
data_samples = [None for _ in range(pred_scores.size(0))]

for data_sample, score, label in zip(data_samples, pred_scores,
pred_labels):
if data_sample is None:
data_sample = ClsDataSample()

data_sample.set_pred_score(score).set_pred_label(label)
out_data_samples.append(data_sample)
return out_data_samples
Loading

0 comments on commit 3fe628e

Please sign in to comment.