Build your engine & Customize your trainer

Build your engine

To understand what the Engine class does, it helps to first understand the concept of the process function used in common engines. A process function usually defines what happens for a single batch of a dataset, and the Engine class simply drives that process function. For example, a common process function looks like this:

def process_function(dataloader, model, criterion, optim):
    # `dataloader` must be an iterator, e.g. iter(DataLoader(...))
    optim.zero_grad()
    data, label = next(dataloader)
    output = model(data)
    loss = criterion(output, label)
    loss.backward()
    optim.step()
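
To make the split between the two roles concrete, here is a minimal sketch in plain Python. It is not the Colossal-AI implementation: `MiniEngine`, its `run` method and the toy `process_function` are hypothetical names used only for illustration. The process function handles one batch, and the engine only decides when to call it.

```python
from typing import Callable, Iterable, List


class MiniEngine:
    """Hypothetical stand-in for an engine: it only drives the process function."""

    def __init__(self, process_function: Callable[[list], float]):
        self.process_function = process_function

    def run(self, batches: Iterable[list], epochs: int = 1) -> List[float]:
        outputs = []
        for _ in range(epochs):
            for batch in batches:
                # The engine never looks inside the batch; it just delegates.
                outputs.append(self.process_function(batch))
        return outputs


def process_function(batch):
    # Toy per-batch behaviour: return the mean of the batch as a fake "loss".
    return sum(batch) / len(batch)


engine = MiniEngine(process_function)
losses = engine.run([[1.0, 3.0], [2.0, 6.0]], epochs=2)
print(losses)
```

Swapping in a different `process_function` changes what happens per batch without touching the loop, which is the reason for keeping the two apart.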

In ignite.engine or keras.engine, the process function is always provided by the user. However, it is hard for users to write their own process functions for pipeline parallelism. To make hybrid parallelism accessible, we provide a powerful Engine class. It enables pipeline parallelism and offers the 1F1B non-interleaving strategy. You can also use a pre-defined learning rate scheduler in your Engine to adjust the learning rate during training.
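
As a rough illustration of what a 1F1B non-interleaving strategy means, the sketch below lists the order of forward and backward passes that a single pipeline stage would follow. This is a generic textbook version of the schedule under stated assumptions, not the code used inside Engine; the function name `one_f_one_b_order` and its parameters are hypothetical.

```python
from typing import List, Tuple


def one_f_one_b_order(stage: int, num_stages: int, num_microbatches: int) -> List[Tuple[str, int]]:
    """Return the (op, microbatch) order for one stage under a generic
    non-interleaved 1F1B schedule. "F" is a forward pass, "B" a backward pass."""
    # Earlier stages run more warm-up forwards before their first backward.
    num_warmup = min(num_stages - stage - 1, num_microbatches)
    num_steady = num_microbatches - num_warmup

    order = []
    fwd = 0
    bwd = 0
    for _ in range(num_warmup):  # warm-up: forwards only
        order.append(("F", fwd))
        fwd += 1
    for _ in range(num_steady):  # steady state: one forward, then one backward
        order.append(("F", fwd))
        fwd += 1
        order.append(("B", bwd))
        bwd += 1
    for _ in range(num_warmup):  # cool-down: drain the remaining backwards
        order.append(("B", bwd))
        bwd += 1
    return order


for stage in range(2):
    print(stage, one_f_one_b_order(stage, num_stages=2, num_microbatches=4))
```

The last stage alternates forward and backward from the start, while earlier stages hold a few activations in flight. This bounds the number of stored activations per stage by the pipeline depth rather than by the number of microbatches.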

To build your engine, just set the model, criterion, optimizer, learning rate scheduler and schedule. Consider the following code as an example.

import torch
import torch.nn as nn
import torchvision.models as models
import colossalai
from colossalai.engine import Engine


model = models.resnet18()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
lr_scheduler = colossalai.nn.lr_scheduler.CosineAnnealingLR(optimizer, 1000)
schedule = colossalai.engine.schedule.NoPipelineSchedule()

MyEngine = Engine(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    schedule=schedule
)

More information can be found in the API reference.
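
The example above passes 1000 as the number of steps to CosineAnnealingLR. As a sketch of what such a scheduler typically computes, the snippet below evaluates the standard closed-form cosine annealing curve. It assumes the Colossal-AI scheduler follows this common formula, which is not confirmed here; the function name, `base_lr` and `eta_min` are hypothetical.

```python
import math


def cosine_annealing_lr(step: int, total_steps: int, base_lr: float, eta_min: float = 0.0) -> float:
    """Closed-form cosine annealing: decays base_lr to eta_min over total_steps."""
    # Clamping is a choice of this sketch: the rate stays at eta_min afterwards.
    progress = min(step, total_steps) / total_steps
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * progress))


for step in (0, 250, 500, 750, 1000):
    print(step, cosine_annealing_lr(step, total_steps=1000, base_lr=1e-3))
```

The rate starts at `base_lr`, passes through the midpoint between `base_lr` and `eta_min` halfway through, and reaches `eta_min` at the final step.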

Customize your trainer

Overview

Before learning how to customize a trainer to meet your needs, you should have a basic understanding of what the Trainer does. We recommend reading the Get Started section and Build your engine first.

The Trainer class lets researchers and engineers use our framework more conveniently. Instead of writing your own scripts, you can construct a Trainer from your own Engine by calling MyTrainer = Trainer(MyEngine), then use the fit method to train or evaluate your model. To make the Trainer class more powerful, we added some useful features, such as monitoring or recording running states and metrics that indicate the model's performance, or saving a checkpoint after a training epoch.

To accomplish that, specific actions must be added to the training or evaluation process. The BaseHook class allows you to add the desired actions at specific time points. We have already created practical hooks for these features, so all you need to do is pick the hooks you want.
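
The idea of calling actions at specific time points can be sketched in a few lines of plain Python. Everything here is hypothetical and simplified: `SketchHook`, `RecordLossHook`, `fit` and the three time points shown are illustrative stand-ins, not the actual BaseHook interface.

```python
class SketchHook:
    """Hypothetical base class: every time point is a no-op by default."""

    def before_train_epoch(self, state):
        pass

    def after_train_iter(self, state, loss):
        pass

    def after_train_epoch(self, state):
        pass


class RecordLossHook(SketchHook):
    """Overrides only the time points it cares about."""

    def __init__(self):
        self.records = []

    def after_train_iter(self, state, loss):
        self.records.append((state["epoch"], loss))

    def after_train_epoch(self, state):
        self.records.append((state["epoch"], "epoch done"))


def fit(hooks, fake_losses, epochs):
    """Toy training loop: calls every hook at each time point."""
    state = {"epoch": 0}
    for epoch in range(epochs):
        state["epoch"] = epoch
        for hook in hooks:
            hook.before_train_epoch(state)
        for loss in fake_losses:  # a real trainer would run the engine here
            for hook in hooks:
                hook.after_train_iter(state, loss)
        for hook in hooks:
            hook.after_train_epoch(state)


recorder = RecordLossHook()
fit([recorder], fake_losses=[0.9, 0.7], epochs=2)
print(recorder.records)
```

Because the base class implements every time point as a no-op, a new hook only needs to override the methods it actually uses.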

More detailed class descriptions can be found in the API reference.

Example

hooks = [
    dict(type='LogMetricByEpochHook'),
    dict(type='LogTimingByEpochHook'),
    dict(type='LogMemoryByEpochHook'),
    dict(type='AccuracyHook'),
    dict(type='LossHook'),
    # dict(type='TensorboardHook', log_dir='./tfb_logs'),
    # dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
    # dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt')
]

The hooks above write metrics, elapsed time and memory usage to the log every epoch. They also print the loss and accuracy so that users can monitor the performance of the model.

Hook

You can extend our BaseHook class. Hooks can be called at twelve time points. More detailed information can be found in the API reference.

Alternatively, extend MetricHook to write a metric collector. You should also decorate your own hook class with @HOOKS.register_module and import it in your main Python script.
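
The registration decorator and the dict(type=...) style of configuration fit together through a registry. The sketch below shows the general pattern only; the `Registry` class, its `build` method and `SaveEveryHook` are hypothetical and do not reproduce the real HOOKS object.

```python
class Registry:
    """Hypothetical minimal registry mapping class names to classes."""

    def __init__(self, name):
        self.name = name
        self._modules = {}

    def register_module(self, cls):
        # Used as a decorator: store the class under its own name, return it unchanged.
        self._modules[cls.__name__] = cls
        return cls

    def build(self, config):
        # Copy so the caller's dict is not mutated, then pop the lookup key.
        kwargs = dict(config)
        cls = self._modules[kwargs.pop("type")]
        return cls(**kwargs)


HOOKS = Registry("hooks")


@HOOKS.register_module
class SaveEveryHook:
    def __init__(self, interval=1):
        self.interval = interval


hook = HOOKS.build(dict(type="SaveEveryHook", interval=5))
print(type(hook).__name__, hook.interval)
```

A decorator runs when the module that defines the class is imported. If the module is never imported, the class is never registered and a lookup by name fails, which is why the import in the main script matters.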

after_train_iter() receives the output of the engine for each iteration, which is a list containing the output, label and loss.

Note that you can define a priority for each hook to control the order in which hooks are executed.
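
A small sketch of how priorities could order hook calls within one iteration follows. Two points are assumptions of this sketch rather than documented behaviour: that a smaller priority value runs earlier, and that the engine output list is unpacked into separate arguments. The class names are hypothetical.

```python
class MetricLikeHook:
    priority = 0  # assumption: a smaller value runs earlier

    def __init__(self, log):
        self.log = log

    def after_train_iter(self, output, label, loss):
        self.log.append(("metric", loss))


class LoggerLikeHook:
    priority = 10  # runs after the metric hook so it can read fresh results

    def __init__(self, log):
        self.log = log

    def after_train_iter(self, output, label, loss):
        self.log.append(("logger", loss))


log = []
hooks = [LoggerLikeHook(log), MetricLikeHook(log)]
hooks.sort(key=lambda h: h.priority)  # execution order follows priority

engine_output = [[0.9, 0.1], 0, 0.105]  # hypothetical [output, label, loss]
for hook in hooks:
    hook.after_train_iter(*engine_output)
print(log)
```

Even though the logger hook was added to the list first, sorting by priority makes the metric-like hook run before it.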

Metric

You can write your own metric by extending the Metric class. A metric is always used together with a MetricHook. If you write your own metric hooks, set the priority carefully and make sure each one is called before any other hook that uses the metric results.

We already provide some metric hooks. Metric objects are stored in runner.states['metrics'], which is a dictionary, so you can access a metric by its name.
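
To illustrate the idea of a metric object kept in a dictionary under its name, here is a minimal sketch. `RunningAccuracy`, its `update` and `get_value` methods, and the plain `states` dictionary are hypothetical stand-ins; the real Metric interface is described in the API reference.

```python
class RunningAccuracy:
    """Hypothetical metric: accumulates correct/total across iterations."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.correct = 0
        self.total = 0

    def update(self, predictions, labels):
        self.correct += sum(int(p == y) for p, y in zip(predictions, labels))
        self.total += len(labels)

    def get_value(self):
        return self.correct / self.total if self.total else 0.0


# Stand-in for runner.states: metrics are looked up by name.
states = {"metrics": {}}
states["metrics"]["Accuracy"] = RunningAccuracy()

states["metrics"]["Accuracy"].update([1, 0, 1, 1], [1, 1, 1, 0])
states["metrics"]["Accuracy"].update([0, 0], [0, 0])
value = states["metrics"]["Accuracy"].get_value()
print(value)
```

A metric hook would call `update` after each iteration and `reset` at the start of each epoch, while other hooks read the value through the dictionary.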