Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multi task mzr #7

Merged
merged 7 commits into from
Dec 6, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions configs/_base_/datasets/multi_task_medic_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(256, 200), backend='pillow'),
dict(type='CenterCrop', crop_size=224),
dict(type='FormatMultiTaskLabels'),
dict(type='PackMultiTaskInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(256, 200), backend='pillow'),
dict(type='CenterCrop', crop_size=224),
dict(type='FormatMultiTaskLabels')
dict(type='PackMultiTaskInputs')
]

train_dataloader = dict(
Expand Down
6 changes: 2 additions & 4 deletions configs/_base_/models/mobilenet_v2_1x_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
'humanitarian': dict(type='LinearClsHead', num_classes=4),
'disaster_types': dict(type='LinearClsHead', num_classes=7)
},
common_cfg=dict(
in_channels=1280,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
),
in_channels=1280,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
))
6 changes: 3 additions & 3 deletions mmcls/datasets/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
Brightness, ColorTransform, Contrast, Cutout,
Equalize, Invert, Posterize, RandAugment, Rotate,
Sharpness, Shear, Solarize, SolarizeAdd, Translate)
from .formatting import (Collect, FormatMultiTaskLabels, PackClsInputs,
ToNumpy, ToPIL, Transpose)
from .formatting import (Collect, PackClsInputs, PackMultiTaskInputs, ToNumpy,
ToPIL, Transpose)
from .processing import (Albumentations, ColorJitter, EfficientNetCenterCrop,
EfficientNetRandomCrop, Lighting, RandomCrop,
RandomErasing, RandomResizedCrop, ResizeEdge)
Expand All @@ -17,5 +17,5 @@
'Cutout', 'RandAugment', 'Lighting', 'ColorJitter', 'RandomErasing',
'PackClsInputs', 'Albumentations', 'EfficientNetRandomCrop',
'EfficientNetCenterCrop', 'ResizeEdge', 'BaseAugTransform',
'FormatMultiTaskLabels'
'PackMultiTaskInputs'
]
58 changes: 33 additions & 25 deletions mmcls/datasets/transforms/formatting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from collections import defaultdict
from collections.abc import Sequence
from typing import List
from functools import partial

import numpy as np
import torch
Expand Down Expand Up @@ -86,12 +86,6 @@ def transform(self, results: dict) -> dict:
img = np.expand_dims(img, -1)
img = np.ascontiguousarray(img.transpose(2, 0, 1))
packed_results['inputs'] = to_tensor(img)
else:
warnings.warn(
'Cannot get "img" in the input dict of `PackClsInputs`,'
'please make sure `LoadImageFromFile` has been added '
'in the data pipeline or images have been loaded in '
'the dataset.')

data_sample = ClsDataSample()
if 'gt_label' in results:
Expand All @@ -110,7 +104,7 @@ def __repr__(self) -> str:


@TRANSFORMS.register_module()
class FormatMultiTaskLabels(BaseTransform):
class PackMultiTaskInputs(BaseTransform):
"""Convert all image labels of multi-task dataset to a dict of tensor.

Args:
Expand All @@ -131,11 +125,17 @@ class FormatMultiTaskLabels(BaseTransform):
"""

def __init__(self,
metainfo: List[str] = None,
task_handlers=dict(),
multi_task_fields=('gt_label', ),
meta_keys=('sample_idx', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'flip', 'flip_direction')):
self.metainfo = metainfo
self.multi_task_fields = multi_task_fields
self.meta_keys = meta_keys
self.task_handlers = defaultdict(
partial(PackClsInputs, meta_keys=meta_keys))
for task_name, task_handler in task_handlers.items():
self.task_handlers[task_name] = TRANSFORMS.build(
dict(type=task_handler, meta_keys=meta_keys))

def transform(self, results: dict) -> dict:
"""Method to pack the input data.
Expand All @@ -144,32 +144,40 @@ def transform(self, results: dict) -> dict:
'img': array([[[ 0, 0, 0])
"""
packed_results = dict()
results = results.copy()

if 'img' in results:
img = results['img']
img = results.pop('img')
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
img = np.ascontiguousarray(img.transpose(2, 0, 1))
packed_results['inputs'] = to_tensor(img)
else:
warnings.warn(
'Cannot get "img" in the input dict of `PackClsInputs`,'
'please make sure `LoadImageFromFile` has been added '
'in the data pipeline or images have been loaded in ')

data_sample = MultiTaskDataSample(metainfo=self.metainfo)
if 'gt_label' in results:
gt_label = results['gt_label']
data_sample.set_gt_task(gt_label)
task_results = defaultdict(dict)
for field in self.multi_task_fields:
if field in results:
value = results.pop(field)
for k, v in value.items():
task_results[k].update({field: v})

data_sample = MultiTaskDataSample()
for task_name, task_result in task_results.items():
task_handler = self.task_handlers[task_name]
task_pack_result = task_handler({**results, **task_result})
data_sample.set_field(task_pack_result['data_samples'], task_name)

img_meta = {k: results[k] for k in self.meta_keys if k in results}
data_sample.set_metainfo(img_meta)
packed_results['data_samples'] = data_sample
return packed_results

def __repr__(self):
repr = self.__class__.__name__
repr += f'(meta_keys={self.meta_keys})'
repr += f'(tasks={self.metainfo})'
task_handlers = {
name: handler.__class__.__name__
for name, handler in self.task_handlers.items()
}
repr += f'(task_handlers={task_handlers}, '
repr += f'multi_task_fields={self.multi_task_fields}, '
repr += f'meta_keys={self.meta_keys})'
return repr


Expand Down
74 changes: 15 additions & 59 deletions mmcls/evaluation/metrics/multi_task.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Sequence
from typing import Dict, Sequence

from mmengine.evaluator import BaseMetric

from mmcls.registry import METRICS
from mmcls.structures import MultiTaskDataSample


@METRICS.register_module()
Expand Down Expand Up @@ -62,36 +61,6 @@ def __init__(self,
for metric in self.task_metrics[task_name]:
self._metrics[task_name].append(METRICS.build(metric))

def pre_process_nested(self, data_samples: List[MultiTaskDataSample],
task_name):
"""Retrieve data_samples corresponds to the task_name for a data_sample
type MultiTaskDataSample Args :

data_samples (List[MultiTaskDataSample]):The annotation data of every
samples. task_name (str)
"""
task_data_sample = []
for data_sample in data_samples:
task_data_sample.append(
data_sample.to_target_data_sample('MultiTaskDataSample',
task_name).to_dict())
return task_data_sample

def pre_process_cls(self, data_samples: List[MultiTaskDataSample],
task_name):
"""Retrieve data_samples corresponds to the task_name for a data_sample
type ClsDataSample Args :

data_samples (List[MultiTaskDataSample]):The annotation data of every
samples. task_name (str)
"""
task_data_sample_dicts = []
for data_sample in data_samples:
task_data_sample_dicts.append(
data_sample.to_target_data_sample('ClsDataSample',
task_name).to_dict())
return task_data_sample_dicts

def process(self, data_batch, data_samples: Sequence[dict]):
"""Process one batch of data samples.

Expand All @@ -101,38 +70,25 @@ def process(self, data_batch, data_samples: Sequence[dict]):
data_batch: A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
"""
data_sample_instances = []
for data_sample in data_samples:
if 'gt_task' in data_sample:
data_sample_instances.append(MultiTaskDataSample().set_gt_task(
data_sample['gt_task']).set_pred_task(
data_sample['pred_task']))
for task_name in self.task_metrics.keys():
filtered_data_samples = []
for data_sample in data_sample_instances:
sample_mask = data_sample.get_task_mask(task_name)
for data_sample in data_samples:
sample_mask = task_name in data_sample
if sample_mask:
filtered_data_samples.append(data_sample)
if type(self.task_metrics[task_name]) != dict:
for metric in self._metrics[task_name]:
# Current implementation is only comptaible
# With 2 types of metrics :
# * Cls Metrics
# * Nested Cls Metrics
# In order to make it work with other
# non-cls heads/metrics, one will have to
# override the current implementation
if metric.__class__.__name__ != 'MultiTasksMetric':
task_data_sample_dicts = self.pre_process_cls(
filtered_data_samples, task_name)
metric.process(data_batch, task_data_sample_dicts)
else:
task_data_sample_dicts = self.pre_process_nested(
filtered_data_samples, task_name)
metric.process(data_batch, task_data_sample_dicts)
filtered_data_samples.append(data_sample[task_name])
for metric in self._metrics[task_name]:
# Current implementation is only comptaible
# With 2 types of metrics :
# * Cls Metrics
# * Nested Cls Metrics
# In order to make it work with other
# non-cls heads/metrics, one will have to
# override the current implementation
metric.process(data_batch, filtered_data_samples)
marouaneamz marked this conversation as resolved.
Show resolved Hide resolved

def compute_metrics(self, results: list) -> dict:
raise Exception('compute metrics should not be used here directly')
raise NotImplementedError(
'compute metrics should not be used here directly')

def evaluate(self, size):
"""Evaluate the model performance of the whole dataset after processing
Expand Down
13 changes: 10 additions & 3 deletions mmcls/models/classifiers/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List, Optional

import torch
import torch.nn as nn

from mmcls.registry import MODELS
from mmcls.structures import ClsDataSample
Expand Down Expand Up @@ -61,13 +62,19 @@ def __init__(self,
super(ImageClassifier, self).__init__(
init_cfg=init_cfg, data_preprocessor=data_preprocessor)

self.backbone = MODELS.build(backbone)
def build_module(module):
if isinstance(module, nn.Module):
return module
else:
return MODELS.build(module)

self.backbone: nn.Module = build_module(backbone)

if neck is not None:
self.neck = MODELS.build(neck)
self.neck: nn.Module = build_module(neck)

if head is not None:
self.head = MODELS.build(head)
self.head: nn.Module = build_module(head)

def forward(self,
inputs: torch.Tensor,
Expand Down
Loading