In [1]:
import os
from pathlib import Path
import math
import colossalai
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn import CosineAnnealingLR
from colossalai.nn.metric import Accuracy
from colossalai.trainer import Trainer, hooks
from colossalai.utils import MultiTimer, get_dataloader
from torchvision import transforms
from torchvision.datasets import MNIST
from tqdm import tqdm

In [2]:
class LeNet5(nn.Module):
    """LeNet-5 style CNN that returns raw (unnormalised) class logits.

    The convolutional stack reduces a single-channel 32x32 input to a
    120-dim feature vector:
    32 -conv5-> 28 -pool2-> 14 -conv5-> 10 -pool2-> 5 -conv5-> 1 (x120 channels).
    Inputs must therefore be shaped (batch, 1, 32, 32) for the first
    classifier layer (in_features=120) to line up.

    Args:
        n_classes: number of output logits produced per sample.
    """

    def __init__(self, n_classes):
        super(LeNet5, self).__init__()

        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5, stride=1),
            nn.Tanh()
        )

        self.classifier = nn.Sequential(
            nn.Linear(in_features=120, out_features=84),
            nn.Tanh(),
            nn.Linear(in_features=84, out_features=n_classes),
        )

    def forward(self, x):
        """Map a (batch, 1, 32, 32) image batch to (batch, n_classes) logits.

        No softmax is applied here: ``nn.CrossEntropyLoss`` consumes logits
        directly. Call ``F.softmax(logits, dim=1)`` yourself if probabilities
        are needed (the previous version computed it and discarded the result).
        """
        features = self.feature_extractor(x)
        features = torch.flatten(features, 1)
        return self.classifier(features)


def _build_optimizer(model, optimizer_method, learning_rate):
    """Create the optimizer named by ``optimizer_method`` for ``model``.

    Args:
        model: module whose parameters are optimised.
        optimizer_method: 'sgd' or 'adamw' (case-insensitive).
        learning_rate: base learning rate.

    Raises:
        NotImplementedError: for any other optimizer name.
    """
    method = optimizer_method.lower()
    if method == 'sgd':
        return torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)
    if method == 'adamw':
        return torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=5e-4)
    raise NotImplementedError(f'unsupported optimizer: {optimizer_method!r}')


def _build_lr_scheduler(optimizer, scheduler_method, learning_rate, steps_per_epoch, num_epochs):
    """Create a learning-rate scheduler that is stepped once per training batch.

    ``run`` attaches the scheduler with ``LRSchedulerHook(by_epoch=False)``,
    so every schedule here is expressed in iterations, with
    ``steps_per_epoch * num_epochs`` steps in total.

    Args:
        optimizer: optimizer whose learning rate is scheduled.
        scheduler_method: 'lambda', 'multistep' or 'onecycle' (case-insensitive).
        learning_rate: base learning rate (OneCycle peaks at 10x this value).
        steps_per_epoch: number of batches per epoch.
        num_epochs: number of training epochs.

    Raises:
        NotImplementedError: for any other scheduler name.
    """
    total_steps = steps_per_epoch * num_epochs
    method = scheduler_method.lower()

    if method == 'lambda':
        # LR range sweep: the multiplicative factor on the base LR grows
        # exponentially from 1e-5 to 10 over the whole run.
        low = math.log2(1e-5)
        high = math.log2(10)
        span = max(total_steps, 1)  # guard against an empty dataloader

        def lr_factor(step):
            return 2 ** (low + (high - low) * step / span)

        return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_factor)

    if method == 'multistep':
        # Milestones count scheduler steps, i.e. batches here. Literal
        # [30, 80] would decay the LR 100x within the first ~80 batches of
        # epoch 0. Place the two 10x decays at 30% and 80% of training instead.
        milestones = [int(total_steps * 0.3), int(total_steps * 0.8)]
        return torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    if method == 'onecycle':
        return torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=learning_rate * 10,
            steps_per_epoch=steps_per_epoch,
            epochs=num_epochs,
        )

    raise NotImplementedError(f'unsupported scheduler: {scheduler_method!r}')


def run(optimizer_method, scheduler_method, learning_rate):
    """Train LeNet5 on MNIST with one optimizer / LR-scheduler / LR combination.

    Reads ``BATCH_SIZE`` and ``NUM_EPOCHS`` from ``gpc.config``, so
    ``colossalai.launch`` must have been called first. TensorBoard events and
    the text log are written to ``./tb_logs/<optimizer>_<scheduler>_<lr>``.

    Args:
        optimizer_method: 'sgd' or 'adamw' (case-insensitive).
        scheduler_method: 'lambda', 'multistep' or 'onecycle' (case-insensitive).
        learning_rate: base learning rate given to the optimizer.

    Raises:
        NotImplementedError: if either method name is not recognised.
    """
    logger = get_dist_logger()
    d_name = f'{optimizer_method}_{scheduler_method}_{learning_rate}'
    log_dir = f'./tb_logs/{d_name}'
    batch_size = gpc.config.BATCH_SIZE
    num_epochs = gpc.config.NUM_EPOCHS

    # build model
    model = LeNet5(n_classes=10)

    # build dataloaders; resize to 32x32 so LeNet5's conv stack ends at 120 features
    transform = transforms.Compose([transforms.Resize((32, 32)),
                                    transforms.ToTensor()])
    train_dataset = MNIST(root=Path('./tmp/'), download=True, transform=transform)
    test_dataset = MNIST(root=Path('./tmp/'), train=False, transform=transform)

    train_dataloader = get_dataloader(dataset=train_dataset,
                                      shuffle=True,
                                      batch_size=batch_size,
                                      num_workers=1,
                                      pin_memory=True,
                                      )
    test_dataloader = get_dataloader(dataset=test_dataset,
                                     add_sampler=False,
                                     batch_size=batch_size,
                                     num_workers=1,
                                     pin_memory=True,
                                     )

    # build criterion (expects raw logits, which LeNet5 returns)
    criterion = torch.nn.CrossEntropyLoss()

    optimizer = _build_optimizer(model, optimizer_method, learning_rate)
    # steps_per_epoch is read here, before colossalai.initialize rebinds
    # train_dataloader below.
    lr_scheduler = _build_lr_scheduler(optimizer, scheduler_method, learning_rate,
                                       steps_per_epoch=len(train_dataloader),
                                       num_epochs=num_epochs)

    engine, train_dataloader, test_dataloader, _ = colossalai.initialize(
        model, optimizer, criterion, train_dataloader, test_dataloader
    )

    # build a timer to measure time
    timer = MultiTimer()

    # create a trainer object
    trainer = Trainer(
        engine=engine,
        timer=timer,
        logger=logger
    )

    # define the hooks to attach to the trainer
    hook_list = [
        hooks.LossHook(),
        # step after every batch: all schedules above are counted in iterations
        hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),
        # optional: uncomment to track accuracy
        # hooks.AccuracyHook(accuracy_func=Accuracy()),
        hooks.LogMetricByEpochHook(logger),
        hooks.LogMemoryByEpochHook(logger),
        hooks.LogTimingByEpochHook(timer, logger),
        hooks.TensorboardHook(log_dir=log_dir, ranks=[0]),
        # optional: uncomment to save checkpoints
        # hooks.SaveCheckpointHook(checkpoint_dir='./ckpt')
    ]

    # Attach the file handler BEFORE training so this run's logs are captured
    # (calling it after fit() wrote nothing for the current run).
    # NOTE(review): get_dist_logger() presumably returns a process-wide logger;
    # if so, file handlers added by earlier run() calls also receive these
    # logs. Confirm and remove stale handlers if per-run files must be disjoint.
    logger.log_to_file(log_dir)

    # start training
    trainer.fit(
        train_dataloader=train_dataloader,
        epochs=num_epochs,
        test_dataloader=test_dataloader,
        test_interval=1,
        hooks=hook_list,
        display_progress=True
    )

In [None]:
# Hyper-parameter grid: every (learning rate, optimizer, scheduler) combination
# is trained once, i.e. 3 x 2 x 3 = 18 runs of NUM_EPOCHS epochs each.
learning_rates = [0.1, 0.05, 0.001]
# Optimizers compared (names accepted by run()).
optimizer_methods = ['sgd', 'adamw']
# Learning-rate schedules compared (names accepted by run()).
scheduler_methods = ['lambda', 'multistep', 'onecycle']

# Single-process launch. It registers `config`, which run() reads back through
# gpc.config (BATCH_SIZE, NUM_EPOCHS). Launch once, before any run().
config = {'BATCH_SIZE': 128, 'NUM_EPOCHS': 30}
colossalai.launch(config=config, rank=0, world_size=1, host='127.0.0.1', port=1234)

# Grid search: one full training run per combination.
for learning_rate in learning_rates:
    for optimizer_method in optimizer_methods:
        for scheduler_method in scheduler_methods:
            print('>>>>>>', optimizer_method, scheduler_method, learning_rate)
            run(optimizer_method, scheduler_method, learning_rate)

>>>>>> sgd lambda 0.1
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to tmp/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting tmp/MNIST/raw/train-images-idx3-ubyte.gz to tmp/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to tmp/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting tmp/MNIST/raw/train-labels-idx1-ubyte.gz to tmp/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to tmp/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting tmp/MNIST/raw/t10k-images-idx3-ubyte.gz to tmp/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to tmp/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting tmp/MNIST/raw/t10k-labels-idx1-ubyte.gz to tmp/MNIST/raw



[Epoch 0 / Train]: 100%|██████████████████████████████████████████████████████████████| 469/469 [00:10<00:00, 43.58it/s]


[Epoch 0 / Test]: 100%|█████████████████████████████████████████████████████████████████| 79/79 [00:01<00:00, 47.20it/s]


[Epoch 1 / Train]:  12%|███████▍                                                       | 55/469 [00:01<00:09, 43.99it/s]