# Train Image Classifiers

In this notebook we will train an image classifier that classifies fruit images, using MMClassification.

## Prepare a Dataset

We have already prepared a dataset.

Credit to Zihao: https://github.com/TommyZihao/MMClassification_Tutorials

To download and extract the dataset, run the following in a command line:
```
curl -O https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/dataset/fruit30/fruit30_split.zip
unzip -d data fruit30_split.zip
```

MMClassification reads this dataset by folder: each class needs its own subfolder, and the folder name is used as the class name.
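
Before training, it can help to confirm that the extracted data really follows this folder-per-class layout. The sketch below counts the files in each class folder of a split. The helper name `count_images_per_class` is made up for illustration, and the `data/fruit30_split/train` and `data/fruit30_split/val` paths assume the archive was extracted with the commands above.

```python
from pathlib import Path


def count_images_per_class(split_dir):
    """Return {class_name: number_of_files} for a folder-per-class split."""
    counts = {}
    # Each subfolder is treated as one class; its name is the class name.
    for class_dir in sorted(p for p in Path(split_dir).iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(1 for f in class_dir.iterdir() if f.is_file())
    return counts


# Assumed location of the extracted dataset; adjust if you unzipped elsewhere.
root = Path('data/fruit30_split')
for split in ('train', 'val'):
    split_dir = root / split
    if split_dir.is_dir():
        counts = count_images_per_class(split_dir)
        print(split, len(counts), 'classes,', sum(counts.values()), 'files')
    else:
        print(split, 'folder not found under', root)
```

If the number of classes printed for `train` and `val` differs, the split is probably incomplete and worth checking before launching a long run.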

## Prepare a Config and Checkpoint File

For speed, we use a lightweight neural network, MobileNetV2.

We use `mim` to download the config file and the checkpoint file.

```
mim download mmcls --config mobilenet-v2_8xb32_in1k --dest .
mv mobilenet-v2_8xb32_in1k.py mobilenet-v2_fruit.py
```

If you prefer to play with other models, navigate to [MMClassification model zoo](https://mmclassification.readthedocs.io/en/latest/model_zoo.html).

In [None]:
!mim download mmcls --config mobilenet-v2_8xb32_in1k --dest .
!mv mobilenet-v2_8xb32_in1k.py mobilenet-v2_fruit.py

## Modify the Config File

1. Remove some intermediate items to keep the config clean: `dataset_type`, `img_norm_cfg`, `train_pipeline`, `test_pipeline`
1. Modify the model
    1. Number of classes: from 1000 to 30
    2. Pretrained weights: from `None` to the downloaded checkpoint file, since we fine-tune the model instead of training from scratch
1. Data: for train/val/test
    1. `type`: `ImageNet` -> `CustomDataset`
    2. `data_prefix`, the root path to the images: set to `"data/fruit30_split/train"` or `"data/fruit30_split/val"`
    3. `ann_file`: set to `None` so that folder names are used as class names
1. Runner and optimizer
    1. Number of training epochs: `runner.max_epochs`
    1. Learning rate: `optimizer.lr`, usually divided by 8 under the linear scaling rule, because the original config (`8xb32`) is written for 8 GPUs and we assume training on a single one
1. Misc
    1. Decrease `log_config.interval` to see logs more often when computation power is limited
    1. Increase `checkpoint_config.interval` to avoid saving too many checkpoints, which saves time and disk space
1. Further parameter tuning you may try
    1. Learning rate: decrease `optimizer.lr` further for fine-tuning
    1. Configure the learning rate scheduler to decrease the learning rate when the loss saturates. Setting `by_epoch=False` decreases the learning rate by iteration instead of by epoch.
    1. Monitor the loss decrease and re-tune
    1. More learning rate schedulers are available in [mmcv](https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/lr_updater.py)
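
The changes above can be summarized as a set of dotted config keys and their new values. The sketch below is an illustration under several assumptions: the key names follow the MMClassification 0.x config layout (`load_from`, `data.train.data_prefix`, and so on), the base learning rate of 0.045 is what we expect in the downloaded config (check your file), the test split reuses the `val` folder, and the epoch and interval values are arbitrary placeholders. `apply_overrides` is a hypothetical helper, not part of MMClassification.

```python
BASE_LR = 0.045            # assumption: learning rate in the downloaded config
BASE_BATCH_SIZE = 8 * 32   # "8xb32" in the config name: 8 GPUs x 32 images
BATCH_SIZE = 32            # assumption: a single GPU with 32 images per step

# Linear scaling rule: scale the learning rate with the total batch size.
scaled_lr = BASE_LR * BATCH_SIZE / BASE_BATCH_SIZE

DATA_ROOT = 'data/fruit30_split'
overrides = {
    'model.head.num_classes': 30,
    'load_from': 'mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth',
    'optimizer.lr': scaled_lr,
    'runner.max_epochs': 5,            # placeholder value
    'log_config.interval': 10,         # placeholder value
    'checkpoint_config.interval': 5,   # placeholder value
}
# Assumption: the test split reuses the validation folder.
for split, folder in [('train', 'train'), ('val', 'val'), ('test', 'val')]:
    overrides[f'data.{split}.type'] = 'CustomDataset'
    overrides[f'data.{split}.data_prefix'] = f'{DATA_ROOT}/{folder}'
    overrides[f'data.{split}.ann_file'] = None


def apply_overrides(cfg, overrides):
    """Set each dotted key in a nested dict, creating levels as needed."""
    for dotted_key, value in overrides.items():
        *parents, leaf = dotted_key.split('.')
        node = cfg
        for name in parents:
            node = node.setdefault(name, {})
        node[leaf] = value
    return cfg


config = apply_overrides({}, overrides)
print(config['optimizer']['lr'])
print(config['data']['train'])
```

Editing `mobilenet-v2_fruit.py` by hand works just as well. The dictionary is mainly a checklist of which fields change, and mmcv's `Config.merge_from_dict` accepts dotted keys of this form if you prefer to apply them programmatically.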

## Launch Training

In a command line:

```
mim train mmcls mobilenet-v2_fruit.py
```

## Understand Logs


The log is long but mainly contains the following parts:

1. Toolbox information
2. Dumped config file
3. Model initialization logs
    1. Check for `mmcls - INFO - load checkpoint from local path: mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth`, which means the pretrained weights were loaded correctly.
4. Information on hooks: we don't configure these explicitly in this tutorial, so this part can be ignored
5. Training progress
    1. Training logs: current learning rate, training loss, time consumption, and memory usage
    2. Validation logs: accuracy on the validation set

## Test the Model

The trained model (checkpoint file) is usually saved under `work_dirs/{experiment_name}/latest.pth`. 
We can load it and test it on a new image.

In [None]:
from mmcls.apis import init_model, inference_model

model = init_model('mobilenet-v2_fruit.py', 'work_dirs/mobilenet-v2_fruit/latest.pth')
result = inference_model(model, 'banana.png')
print(result)
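
In MMClassification 0.x, `inference_model` is expected to return a dictionary with the keys `pred_label`, `pred_score`, and `pred_class`; treat that as an assumption if you use a different version. The sketch below formats such a dictionary into a one-line summary. The helper `describe_prediction`, its `threshold` parameter, and the sample values are made up for illustration and are not real model output.

```python
def describe_prediction(result, threshold=0.5):
    """Turn an inference result dict into a one-line summary."""
    label = result['pred_class']
    score = float(result['pred_score'])
    # The threshold is an arbitrary cut-off chosen for this example.
    verdict = 'confident' if score >= threshold else 'uncertain'
    return f"{label} (class {int(result['pred_label'])}, score {score:.2%}, {verdict})"


# Placeholder values for illustration only; not produced by a trained model.
sample = {'pred_label': 3, 'pred_score': 0.87, 'pred_class': 'banana'}
print(describe_prediction(sample))
```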

## PyTorch Code under the Hood

### Runner

The runner constructs the training framework: it drives the epoch and iteration loops and calls hooks at fixed points.

Specifically, MMClassification is based on `mmcv.EpochBasedRunner`.
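
The control flow of an epoch-based runner can be sketched with a toy class that only records the order in which hooks fire. `ToyRunner` is a simplified stand-in written for this notebook, not the mmcv class: it has no model, optimizer, or logging, and the data loaders are plain lists.

```python
class ToyRunner:
    """Simplified stand-in for an epoch-based runner; not the mmcv class."""

    def __init__(self, max_epochs):
        self.max_epochs = max_epochs
        self.epoch = 0
        self.events = []  # records the order in which hooks are called

    def call_hook(self, name):
        self.events.append(name)

    def train(self, batches):
        self.call_hook('before_train_epoch')
        for _ in batches:
            self.call_hook('before_train_iter')
            self.call_hook('after_train_iter')
        self.call_hook('after_train_epoch')
        self.epoch += 1  # only training epochs advance the epoch counter

    def val(self, batches):
        self.call_hook('before_val_epoch')
        for _ in batches:
            self.call_hook('before_val_iter')
            self.call_hook('after_val_iter')
        self.call_hook('after_val_epoch')

    def run(self, data_loaders, workflow):
        self.call_hook('before_run')
        while self.epoch < self.max_epochs:
            for loader, (mode, epochs) in zip(data_loaders, workflow):
                epoch_runner = getattr(self, mode)  # self.train or self.val
                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= self.max_epochs:
                        break
                    epoch_runner(loader)
        self.call_hook('after_run')


runner = ToyRunner(max_epochs=3)
# Two training epochs, then one validation epoch, repeated until 3 epochs are done.
runner.run([[0], [0]], workflow=[('train', 2), ('val', 1)])
epoch_events = [e for e in runner.events if e.startswith('before') and e.endswith('_epoch')]
print(epoch_events)
```

The point of the sketch is that the workflow list decides how training and validation epochs alternate, while everything else (logging, checkpointing, learning rate updates) attaches to the hook calls.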

In [None]:

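# Excerpt from mmcv's epoch-based runner. The imports assume mmcv 1.x.
import time
import warnings
from typing import Any, List, Optional, Tuple

import mmcv
import torch
from mmcv.runner import BaseRunner, get_host_info
from torch.utils.data import DataLoader


class EpochBasedRunner(BaseRunner):
    """Epoch-based Runner.
    This runner trains models epoch by epoch.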
    """

    def run_iter(self, data_batch: Any, train_mode: bool, **kwargs) -> None:
        if self.batch_processor is not None:
            outputs = self.batch_processor(
                self.model, data_batch, train_mode=train_mode, **kwargs)
        elif train_mode:
            outputs = self.model.train_step(data_batch, self.optimizer,
                                            **kwargs)
        else:
            outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('"batch_processor()" or "model.train_step()"'
                            'and "model.val_step()" must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
        self.outputs = outputs

    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self.data_batch = data_batch
            self._inner_iter = i
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True, **kwargs)
            self.call_hook('after_train_iter')
            del self.data_batch
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1

    @torch.no_grad()
    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self.data_batch = data_batch
            self._inner_iter = i
            self.call_hook('before_val_iter')
            self.run_iter(data_batch, train_mode=False)
            self.call_hook('after_val_iter')
            del self.data_batch
        self.call_hook('after_val_epoch')

    def run(self,
            data_loaders: List[DataLoader],
            workflow: List[Tuple[str, int]],
            max_epochs: Optional[int] = None,
            **kwargs) -> None:
        """Start running.
        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g, [('train', 2), ('val', 1)] means
                running 2 epochs for training and 1 epoch for validation,
                iteratively.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)
        if max_epochs is not None:
            warnings.warn(
                'setting max_epochs in run is deprecated, '
                'please set max_epochs in runner_config', DeprecationWarning)
            self._max_epochs = max_epochs

        assert self._max_epochs is not None, (
            'max_epochs must be specified during instantiation')

        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if mode == 'train':
                self._max_iters = self._max_epochs * len(data_loaders[i])
                break

        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('Hooks will be executed in the following order:\n%s',
                         self.get_hook_info())
        self.logger.info('workflow: %s, max: %d epochs', workflow,
                         self._max_epochs)
        self.call_hook('before_run')

        while self.epoch < self._max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                if isinstance(mode, str):  # self.train()
                    if not hasattr(self, mode):
                        raise ValueError(
                            f'runner has no method named "{mode}" to run an '
                            'epoch')
                    epoch_runner = getattr(self, mode)
                else:
                    raise TypeError(
                        'mode in workflow must be a str, but got {}'.format(
                            type(mode)))

                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= self._max_epochs:
                        break
                    epoch_runner(data_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')