# ReZNet: (Up to 6x) faster ResNet training with ReZero

In this notebook we will examine how the [ReZero](https://arxiv.org/abs/2003.04887) architecture addition enables or accelerates training in deep [ResNet](https://arxiv.org/pdf/1512.03385.pdf) networks. We will find, for example, that for a ResNet110 the number of epochs needed to reach 50% accuracy decreases by a factor of about 6 upon implementing ReZero. In this particular example the accuracy after convergence also improves with ReZero. The architecture here differs importantly from [Fixup](https://arxiv.org/pdf/1901.09321.pdf) and [SkipInit](https://arxiv.org/pdf/2002.10444.pdf) in that the skip connection is implemented **after** the nonlinearity to preserve signal propagation.

The official ReZero repo is [here](https://github.com/majumderb/rezero).

This notebook is heavily inspired by [Yerlan Idelbayev's beautiful ResNet implementation](https://github.com/akamaster/pytorch_resnet_cifar10).

Running time of the notebook: 15 minutes on laptop with single RTX 2060 GPU.

Note: This notebook was evaluated with PyTorch 1.4; the test accuracies may differ slightly for other versions.

In [1]:
######################################################################
# Import and set manual seed

import time
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn.functional as F
import torch.nn.init as init

# Fix the seed of PyTorch's global random number generator.
torch.manual_seed(0)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

######################################################################
# Define ResNet model as in 
# https://github.com/akamaster/pytorch_resnet_cifar10

def _weights_init(m):
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)

class LambdaLayer(nn.Module):
    """Module wrapper that applies a user-supplied callable in ``forward``."""

    def __init__(self, lambd):
        super().__init__()
        # Kept as a plain attribute; it is called on every forward pass.
        self.lambd = lambd

    def forward(self, x):
        fn = self.lambd
        return fn(x)

class BasicBlock(nn.Module):
    """Residual block made of two 3x3 conv + batch-norm layers.

    Args:
        in_planes: number of input channels.
        planes: number of output channels.
        stride: stride of the first convolution.
        option: accepted for signature compatibility; the block never reads
            it, and only the parameter-free padding shortcut is implemented.
        rezero: if truthy, the residual branch is scaled by a learnable scalar
            ``resweight`` (initialized to 0) and the skip connection is added
            *after* the nonlinearity. If falsy, the block behaves as a vanilla
            ResNet block (skip connection added *before* the final ReLU).
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A', rezero = True):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.rezero = rezero
        if self.rezero:
            # Learnable residual weight; starting at zero makes the block an
            # identity mapping (through the shortcut) at initialization.
            self.resweight = nn.Parameter(torch.zeros(1), requires_grad=True)

        # Identity shortcut unless the block changes resolution or width.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            # Parameter-free shortcut: keep every second pixel and zero-pad
            # planes//4 channels on each side of the channel axis.
            # NOTE(review): shapes only line up when the spatial size is halved
            # and planes == 2 * in_planes -- confirm before using other configs.
            self.shortcut = LambdaLayer(lambda x:
                            F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        if self.rezero:
            # In a ReZero ResNet the skip connection is after the nonlinearity
            return self.resweight * F.relu(out) + self.shortcut(x)
        # In a vanilla ResNet the skip connection is before the nonlinearity.
        # Every falsy `rezero` (False, 0, None) takes this path, so the
        # residual connection can never be dropped silently.
        return F.relu(out + self.shortcut(x))


class ResNet(nn.Module):
    """ResNet with a 3x3 stem convolution, three stages of residual blocks
    (16, 32 and 64 channels), average pooling and a linear classifier.

    Args:
        block: residual block class; must expose ``expansion`` and accept
            ``(in_planes, planes, stride, rezero=...)``.
        num_blocks: number of blocks in each of the three stages.
        num_classes: size of the classifier output.
        rezero: forwarded to every block.
    """

    def __init__(self, block, num_blocks, num_classes=10, rezero = False):
        super().__init__()
        # Channel count fed into the next block; advanced by _make_layer.
        self.in_planes = 16

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1, rezero=rezero)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2, rezero=rezero)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2, rezero=rezero)
        self.linear = nn.Linear(64, num_classes)

        # Re-initialize all Linear / Conv2d weights once the network is built.
        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride, rezero = False):
        """Build one stage of ``num_blocks`` blocks and update ``self.in_planes``."""
        stage = []
        # Only the first block of the stage uses `stride`; the rest use stride 1.
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_planes, planes, block_stride, rezero=rezero))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        feats = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            feats = stage(feats)
        # Average-pool with a kernel equal to the size of the last dimension.
        pooled = F.avg_pool2d(feats, feats.size(3))
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)

######################################################################
# Define various variants

def resnet20(rezero = False):
    """Build a ResNet with 3 BasicBlocks in each of its three stages."""
    blocks_per_stage = [3] * 3
    return ResNet(BasicBlock, blocks_per_stage, rezero=rezero)


def resnet56(rezero = False):
    """Build a ResNet with 9 BasicBlocks in each of its three stages."""
    blocks_per_stage = [9] * 3
    return ResNet(BasicBlock, blocks_per_stage, rezero=rezero)


def resnet110(rezero = False):
    """Build a ResNet with 18 BasicBlocks in each of its three stages."""
    blocks_per_stage = [18] * 3
    return ResNet(BasicBlock, blocks_per_stage, rezero=rezero)


def test(net):
    import numpy as np
    total_params = 0

    for x in filter(lambda p: p.requires_grad, net.parameters()):
        total_params += np.prod(x.data.numpy().shape)
    print("Total number of params {:2.3f}M".format(total_params/1e6))
    print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size())>1, net.parameters()))))

######################################################################
# Define function to train

def train(train_loader, model, criterion, optimizer, epoch,print_freq,lr_scheduler):
    """
        Run one train epoch

        Args:
            train_loader: iterable yielding ``(input, target)`` batches.
            model: network to train; batches are moved to the device of its
                parameters (CPU if it has none).
            criterion: loss function called as ``criterion(output, target)``.
            optimizer: optimizer stepped once per batch.
            epoch: zero-based epoch index, used for logging only.
            print_freq: a progress line is printed every ``print_freq`` batches.
            lr_scheduler: scheduler whose current learning rate is logged.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to train mode
    model.train()

    # Follow the model's device instead of hard-coding CUDA so the loop also
    # runs on CPU-only machines.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    end = time.time()
    for i, (images, target) in enumerate(train_loader):

        # measure data loading time
        data_time.update(time.time() - end)

        images = images.to(run_device)
        target = target.to(run_device)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        output = output.float()
        loss = loss.float()
        # measure accuracy and record loss
        prec1 = accuracy(output.detach(), target)[0]
        losses.update(loss.item(), images.size(0))
        top1.update(prec1.item(), images.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            # get_last_lr() is the documented accessor for the learning rate
            # currently in effect.
            print('| epoch {:3d} | {:4d}/{:4d} batches | '
                  'lr {:02.2f} | ms/batch {:4.0f} | '
                  'loss {loss.avg:1.3f} | Top 1 accuracy {top1.avg:2.2f} %'.format(
                      epoch+1, i, len(train_loader), lr_scheduler.get_last_lr()[0],
                      1000*batch_time.avg, loss=losses, top1=top1))


def validate(val_loader, model, criterion):
    """
    Run evaluation

    Args:
        val_loader: iterable yielding ``(input, target)`` batches.
        model: network to evaluate; batches are moved to the device of its
            parameters (CPU if it has none).
        criterion: loss function called as ``criterion(output, target)``.

    Returns:
        ``(average loss, average top-1 accuracy in percent)``, both weighted by
        batch size.
    """
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    # Follow the model's device instead of hard-coding CUDA so evaluation also
    # runs on CPU-only machines.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    with torch.no_grad():
        for images, target in val_loader:
            images = images.to(run_device)
            target = target.to(run_device)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            output = output.float()
            loss = loss.float()

            # measure accuracy and record loss
            prec1 = accuracy(output, target)[0]
            losses.update(loss.item(), images.size(0))
            top1.update(prec1.item(), images.size(0))

    return losses.avg, top1.avg

class AverageMeter:
    """Tracks the latest value together with a running weighted average.

    Attributes are ``val`` (last value seen), ``sum`` (weighted sum),
    ``count`` (total weight) and ``avg`` (``sum / count``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k

    Args:
        output: score tensor indexed as ``[sample, class]``.
        target: tensor of class indices, one per sample.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one 0-dim tensor per k, holding the percentage of samples
        whose target is among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk best classes per sample, transposed to [maxk, batch].
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` can inherit the non-contiguous layout of the transposed
        # `pred`; reshape() handles that, whereas view(-1) raises for k > 1.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

######################################################################
# Package model setup and training into one simple function

def setup_and_train(model,batch_size = 128, lr = 0.1,momentum = 0.9,
                    weight_decay = 1e-4,epochs = 200,print_freq = 50):
    """Train ``model`` on CIFAR-10 with SGD and evaluate it after every epoch.

    The dataset is downloaded to ``./data`` if needed. The learning rate is
    driven by a ``MultiStepLR`` scheduler with milestones at epochs 100 and 150.
    A summary line with the validation loss, the validation accuracy and the
    best accuracy so far is printed at the end of each epoch.

    Args:
        model: network to train; it is moved to the module-level ``device``.
        batch_size: training batch size (validation always uses 128).
        lr: initial learning rate.
        momentum: SGD momentum.
        weight_decay: SGD weight decay.
        epochs: number of epochs to run.
        print_freq: progress is printed every ``print_freq`` training batches.
    """
    model = model.to(device)
    start_epoch = 0
    # NOTE(review): these look like the usual ImageNet channel statistics
    # rather than CIFAR-10's own; left unchanged so results stay comparable.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root='./data', train=True, transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            normalize,
        ]), download=True),
        batch_size=batch_size, shuffle=True,
        num_workers=1, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root='./data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=128, shuffle=False,
        num_workers=1, pin_memory=True)

    # define loss function (criterion) and optimizer
    # Placed on the same device as the model rather than hard-coding CUDA, so
    # a CPU-only machine does not crash here.
    criterion = nn.CrossEntropyLoss().to(device)

    optimizer = torch.optim.SGD(model.parameters(), lr,
                                momentum=momentum,
                                weight_decay=weight_decay)

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        milestones=[100, 150], last_epoch=start_epoch - 1)
    best_prec1 = 0
    for epoch in range(start_epoch, epochs):
        epoch_start_time = time.time()
        print('-'*95)
        train(train_loader, model, criterion, optimizer, epoch,print_freq,lr_scheduler)
        lr_scheduler.step()

        # evaluate on validation set
        loss, prec1 = validate(val_loader, model, criterion)

        # remember best prec@1
        best_prec1 = max(prec1, best_prec1)

        print('-'*95)
        print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:1.3f} | '
              'valid precision {:3.2f}% (best: {:3.2f}%) '.format(epoch+1, (time.time() - epoch_start_time),
                                         loss, prec1,best_prec1))



## ResNet20

First, we train a ResNet20 with ReZero for one epoch, and then train a ResNet20 without ReZero until it achieves the same accuracy.

* In this example ReZero accelerates initial training by a factor of about 2.

In [2]:
# ResNet20 with ReZero: print the model size, then train for one epoch at lr 0.2.
model = resnet20(rezero=True)
test(model)
setup_and_train(model, batch_size = 128, lr = 0.2, epochs = 1, print_freq = 130)

Total number of params 0.270M
Total layers 20
Files already downloaded and verified
-----------------------------------------------------------------------------------------------
| epoch   1 |    0/ 391 batches | lr 0.20 | ms/batch  322 | loss 2.382 | Top 1 accuracy 5.47 %
| epoch   1 |  130/ 391 batches | lr 0.20 | ms/batch   35 | loss 1.826 | Top 1 accuracy 31.22 %
| epoch   1 |  260/ 391 batches | lr 0.20 | ms/batch   33 | loss 1.665 | Top 1 accuracy 38.30 %
| epoch   1 |  390/ 391 batches | lr 0.20 | ms/batch   33 | loss 1.557 | Top 1 accuracy 42.68 %
-----------------------------------------------------------------------------------------------
| end of epoch   1 | time: 14.14s | valid loss 1.507 | valid precision 49.52% (best: 49.52%) 


In [3]:
# Vanilla ResNet20 baseline: same batch size and lr, trained for two epochs.
model = resnet20(rezero=False)
test(model)
setup_and_train(model, batch_size = 128, lr = 0.2, epochs = 2, print_freq = 130)

Total number of params 0.270M
Total layers 20
Files already downloaded and verified
-----------------------------------------------------------------------------------------------
| epoch   1 |    0/ 391 batches | lr 0.20 | ms/batch   71 | loss 3.223 | Top 1 accuracy 9.38 %
| epoch   1 |  130/ 391 batches | lr 0.20 | ms/batch   31 | loss 2.108 | Top 1 accuracy 23.83 %
| epoch   1 |  260/ 391 batches | lr 0.20 | ms/batch   31 | loss 1.895 | Top 1 accuracy 29.94 %
| epoch   1 |  390/ 391 batches | lr 0.20 | ms/batch   31 | loss 1.780 | Top 1 accuracy 34.10 %
-----------------------------------------------------------------------------------------------
| end of epoch   1 | time: 13.35s | valid loss 1.556 | valid precision 43.43% (best: 43.43%) 
-----------------------------------------------------------------------------------------------
| epoch   2 |    0/ 391 batches | lr 0.20 | ms/batch   79 | loss 1.460 | Top 1 accuracy 46.09 %
| epoch   2 |  130/ 391 batches | lr 0.20 | ms/batch   

## ResNet56

Next, we train a ResNet56 with ReZero for one epoch, and then train a ResNet56 without ReZero until it achieves the same accuracy.

* In this example ReZero accelerates initial training by a factor of about 3.

In [4]:
# ResNet56 with ReZero: print the model size, then train for one epoch at lr 0.1.
model = resnet56(rezero=True)
test(model)
setup_and_train(model, batch_size = 128, lr = 0.1, epochs = 1, print_freq = 130)

Total number of params 0.853M
Total layers 56
Files already downloaded and verified
-----------------------------------------------------------------------------------------------
| epoch   1 |    0/ 391 batches | lr 0.10 | ms/batch  133 | loss 2.377 | Top 1 accuracy 3.12 %
| epoch   1 |  130/ 391 batches | lr 0.10 | ms/batch   94 | loss 1.800 | Top 1 accuracy 31.95 %
| epoch   1 |  260/ 391 batches | lr 0.10 | ms/batch   94 | loss 1.631 | Top 1 accuracy 39.12 %
| epoch   1 |  390/ 391 batches | lr 0.10 | ms/batch   94 | loss 1.494 | Top 1 accuracy 44.63 %
-----------------------------------------------------------------------------------------------
| end of epoch   1 | time: 39.11s | valid loss 1.339 | valid precision 53.50% (best: 53.50%) 


In [5]:
# Vanilla ResNet56 baseline: same batch size and lr, trained for four epochs.
model = resnet56(rezero=False)
test(model)
setup_and_train(model, batch_size = 128, lr = 0.1, epochs = 4, print_freq = 130)

Total number of params 0.853M
Total layers 56
Files already downloaded and verified
-----------------------------------------------------------------------------------------------
| epoch   1 |    0/ 391 batches | lr 0.10 | ms/batch  130 | loss 11.472 | Top 1 accuracy 8.59 %
| epoch   1 |  130/ 391 batches | lr 0.10 | ms/batch   89 | loss 2.912 | Top 1 accuracy 10.69 %
| epoch   1 |  260/ 391 batches | lr 0.10 | ms/batch   89 | loss 2.573 | Top 1 accuracy 12.43 %
| epoch   1 |  390/ 391 batches | lr 0.10 | ms/batch   89 | loss 2.388 | Top 1 accuracy 16.01 %
-----------------------------------------------------------------------------------------------
| end of epoch   1 | time: 37.08s | valid loss 1.904 | valid precision 26.63% (best: 26.63%) 
-----------------------------------------------------------------------------------------------
| epoch   2 |    0/ 391 batches | lr 0.10 | ms/batch  134 | loss 1.885 | Top 1 accuracy 27.34 %
| epoch   2 |  130/ 391 batches | lr 0.10 | ms/batch  

## ResNet110

Next, we train a ResNet110 with ReZero for one epoch, and then train a ResNet110 without ReZero until it achieves the same accuracy.

* In this example ReZero accelerates initial training by a factor of about 6.

In [6]:
# ResNet110 with ReZero: print the model size, then train for one epoch at lr 0.1.
model = resnet110(rezero=True)
test(model)
setup_and_train(model, batch_size = 128, lr = 0.1, epochs = 1, print_freq = 130)

Total number of params 1.728M
Total layers 110
Files already downloaded and verified
-----------------------------------------------------------------------------------------------
| epoch   1 |    0/ 391 batches | lr 0.10 | ms/batch  223 | loss 2.402 | Top 1 accuracy 9.38 %
| epoch   1 |  130/ 391 batches | lr 0.10 | ms/batch  189 | loss 1.793 | Top 1 accuracy 32.82 %
| epoch   1 |  260/ 391 batches | lr 0.10 | ms/batch  189 | loss 1.589 | Top 1 accuracy 41.12 %
| epoch   1 |  390/ 391 batches | lr 0.10 | ms/batch  189 | loss 1.440 | Top 1 accuracy 47.20 %
-----------------------------------------------------------------------------------------------
| end of epoch   1 | time: 78.27s | valid loss 1.641 | valid precision 50.65% (best: 50.65%) 


In [7]:
# Vanilla ResNet110 baseline: same batch size and lr, trained for eight epochs.
model = resnet110(rezero=False)
test(model)
setup_and_train(model, batch_size = 128, lr = 0.1, epochs = 8, print_freq = 130)

Total number of params 1.728M
Total layers 110
Files already downloaded and verified
-----------------------------------------------------------------------------------------------
| epoch   1 |    0/ 391 batches | lr 0.10 | ms/batch  222 | loss 12.921 | Top 1 accuracy 10.16 %
| epoch   1 |  130/ 391 batches | lr 0.10 | ms/batch  178 | loss 3.905 | Top 1 accuracy 11.36 %
| epoch   1 |  260/ 391 batches | lr 0.10 | ms/batch  177 | loss 3.072 | Top 1 accuracy 13.86 %
| epoch   1 |  390/ 391 batches | lr 0.10 | ms/batch  177 | loss 2.731 | Top 1 accuracy 17.16 %
-----------------------------------------------------------------------------------------------
| end of epoch   1 | time: 73.58s | valid loss 1.948 | valid precision 27.01% (best: 27.01%) 
-----------------------------------------------------------------------------------------------
| epoch   2 |    0/ 391 batches | lr 0.10 | ms/batch  223 | loss 2.013 | Top 1 accuracy 25.00 %
| epoch   2 |  130/ 391 batches | lr 0.10 | ms/batch