# Can we improve model accuracy by using a smaller model?

Using fastai `XResNet` we built a model smaller than `ResNeXt 18 deep` that improved accuracy by ~0.5% after 5 epochs of Imagenette training.

We also found that reducing batch size to 32 improved accuracy by ~0.5%.

| model              | batch size | seconds per epoch | accuracy |
| -------------------|------------|-------------------|----------|
| xse_resnext50      | 64         | 62                | 84.8%    |
| xse_resnext18_deep | 64         | 20                | 84.8%    |
| mini net           | 64         | 17                | 85.3%    |
| mini net           | 32         | 19                | 85.8%    |

Note: this notebook is based on [train_imagenette.py](https://github.com/fastai/fastai/blob/master/nbs/examples/train_imagenette.py)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pete88b/data-science/blob/master/fastai-things/train-imagenette-mininet.ipynb)

In [1]:
# When running on Google Colab, install/upgrade fastai and fastscript first.
# Outside Colab nothing is installed, so both packages must already be available.
# NOTE: the `!pip` lines are IPython shell escapes - this cell only runs in a notebook.
IN_COLAB = 'google.colab' in str(get_ipython())
if IN_COLAB:
    !pip install -Uqq fastai
    !pip install -Uqq fastscript

In [1]:
from fastai.basics import *
from fastai.vision.all import *
from fastai.callback.all import *
from fastai.distributed import *
from fastprogress import fastprogress
from torchvision.models import *
from fastai.vision.models.xresnet import *
from fastai.callback.mixup import *
from fastscript import *

# Let cuDNN benchmark convolution algorithms and cache the fastest one; this
# pays off when input shapes stay the same from batch to batch.
torch.backends.cudnn.benchmark = True
# Cap the width of fastprogress progress bars at 80 columns.
fastprogress.MAX_COLS = 80

In [2]:
def get_dls(size, woof, bs, sh=0., workers=None):
    """Return Imagenette (or Imagewoof) `DataLoaders` for training at `size` px.

    Args:
        size: target image size passed to `Resize`; sizes up to 224 use the
            `_320` dataset variants, larger sizes use the full-size datasets.
        woof: truthy to use Imagewoof instead of Imagenette.
        bs: batch size.
        sh: max proportion for `RandomErasing`; 0 disables random erasing.
        workers: number of dataloader workers; defaults to `min(8, num_cpus())`.
    """
    use_320 = size <= 224
    if woof:
        url = URLs.IMAGEWOOF_320 if use_320 else URLs.IMAGEWOOF
    else:
        url = URLs.IMAGENETTE_320 if use_320 else URLs.IMAGENETTE
    data_dir = untar_data(url)
    n_workers = min(8, num_cpus()) if workers is None else workers

    # Plain Resize seems to give slightly better accuracy than
    # RandomResizedCrop(size, min_scale=0.35).
    item_tfms = [Resize(size), FlipItem(0.5)]
    batch_tfms = RandomErasing(p=0.3, max_count=3, sh=sh) if sh else None

    # The grandparent folder name selects the split ('val' = validation) and
    # the parent folder name is the class label.
    dblock = DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        splitter=GrandparentSplitter(valid_name='val'),
        get_items=get_image_files,
        get_y=parent_label,
        item_tfms=item_tfms,
        batch_tfms=batch_tfms,
    )
    return dblock.dataloaders(data_dir, path=data_dir, bs=bs, num_workers=n_workers)

In [3]:
@call_parse
def main(
    gpu:   Param("GPU to run on", int)=None,
    woof:  Param("Use imagewoof (otherwise imagenette)", int)=0,
    lr:    Param("Learning rate", float)=1e-2,
    size:  Param("Size (px: 128,192,256)", int)=128,
    sqrmom:Param("sqr_mom", float)=0.99,
    mom:   Param("Momentum", float)=0.9,
    eps:   Param("epsilon", float)=1e-6,
    epochs:Param("Number of epochs", int)=5,
    bs:    Param("Batch size", int)=64,
    mixup: Param("Mixup", float)=0.,
    opt:   Param("Optimizer (adam,rms,sgd,ranger)", str)='ranger',
    arch:  Param("Architecture", str)='xresnet50',
    sh:    Param("Random erase max proportion", float)=0.,
    sa:    Param("Self-attention", int)=0,
    sym:   Param("Symmetry for self-attention", int)=0,
    beta:  Param("SAdam softplus beta", float)=0.,
    act_fn:Param("Activation function", str)='Mish',
    fp16:  Param("Use mixed precision training", int)=0,
    pool:  Param("Pooling method", str)='AvgPool',
    dump:  Param("Print model; don't train", int)=0,
    runs:  Param("Number of times to repeat training", int)=1,
    meta:  Param("Metadata (ignored)", str)='',
    wd:    Param("Weight decay", float)=1e-2
):
    """Training of Imagenette.

    Builds the DataLoaders with `get_dls`, then trains a freshly created model
    `runs` times with `fit_flat_cos` and prints the mean and median final
    validation accuracy over all runs. `arch`, `act_fn` and `pool` are names
    looked up in this module's globals, so any top-level model factory (for
    example `mini_net`) can be selected by name.

    Raises:
        ValueError: if `opt` is not one of 'adam', 'rms', 'sgd' or 'ranger'.
        KeyError: if `arch`, `act_fn` or `pool` is not a module-level name.
    """
    if gpu is not None: torch.cuda.set_device(gpu)

    # Choose the optimizer. Fail fast (before any data is downloaded) on an
    # unknown name instead of hitting an UnboundLocalError on `opt_func` later.
    if   opt=='adam'  : opt_func = partial(Adam, mom=mom, sqr_mom=sqrmom, eps=eps)
    elif opt=='rms'   : opt_func = partial(RMSProp, sqr_mom=sqrmom)  # fastai's RMSProp takes `sqr_mom`
    elif opt=='sgd'   : opt_func = partial(SGD, mom=mom)
    elif opt=='ranger': opt_func = partial(ranger, mom=mom, sqr_mom=sqrmom, eps=eps, beta=beta)
    else: raise ValueError(f"Unknown optimizer {opt!r}; expected one of 'adam', 'rms', 'sgd', 'ranger'")

    dls = get_dls(size, woof, bs, sh=sh)
    # Print the configuration only when gpu is None or 0 (the primary device).
    if not gpu:
        print(f'epochs: {epochs}; lr: {lr}; size: {size}; sqrmom: {sqrmom}; mom: {mom}; eps: {eps}')
        print(f'fp16: {fp16}; arch: {arch}; wd: {wd}; act_fn: {act_fn}; bs: {bs}')
        print(f'pool: {pool}; woof: {woof}; sh:{sh}')

    # Resolve the architecture, activation and pooling names to objects.
    model_fn, act_cls, pool_cls = [globals()[o] for o in (arch, act_fn, pool)]

    final_accuracies = L()

    for run in range(runs):
        print(f'Run: {run}')
        # n_out=10: presumably one output per Imagenette/Imagewoof class.
        learn = Learner(dls, model_fn(n_out=10, act_cls=act_cls, sa=sa, sym=sym, pool=pool_cls),
                        opt_func=opt_func, metrics=[accuracy, top_k_accuracy],
                        loss_func=LabelSmoothingCrossEntropy())
        if dump:
            print(learn.model)
            exit()
        if fp16: learn = learn.to_fp16()
        cbs = MixUp(mixup) if mixup else []
        if num_distrib() > 1: learn.to_distributed(gpu)  # Requires `-m fastai.launch`
        learn.fit_flat_cos(epochs, lr, wd=wd, cbs=cbs)
        # final_record is the last epoch's [train_loss, valid_loss, accuracy, top_k_accuracy];
        # index 2 matches the `accuracy` column of the training log.
        final_accuracies.append(learn.final_record[2])

    print('mean accuracy', np.mean(final_accuracies),
          'median accuracy', np.median(final_accuracies),
          'over', runs, 'run' if runs==1 else 'runs')

## Try a small "out of the box" model to see what we have to beat

`ResNeXt 18 deep` seems to be a pretty good starting point. Reducing weight decay (from the default 1e-2 to 1e-4) improves accuracy a little.

In [4]:
# Baseline: off-the-shelf xse_resnext18_deep with reduced weight decay.
# lr, arch and wd are module-level so the following cells can reuse them.
lr, arch, wd = 1e-2, 'xse_resnext18_deep', 1e-4
main(arch=arch, lr=lr, wd=wd)

epochs: 5; lr: 0.01; size: 128; sqrmom: 0.99; mom: 0.9; eps: 1e-06
fp16: 0; arch: xse_resnext18_deep; wd: 0.0001; act_fn: Mish; bs: 64
pool: AvgPool; woof: 0; sh:0.0
Run: 0


epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,1.497614,1.691584,0.550573,0.909809,00:21
1,1.249366,1.16052,0.734013,0.967643,00:19
2,1.110032,1.356235,0.65656,0.945987,00:19
3,1.023334,0.988059,0.803057,0.976051,00:20
4,0.869452,0.880741,0.848408,0.981911,00:20


mean accuracy 0.8484076261520386 median accuracy 0.8484076261520386 over 1 run


## Try a smaller model

Reducing layers from `[2,2,2,2,1,1]` to `[1,1,1,1]` improved accuracy and made training faster ... but I had to fiddle with `groups` and `reduction` to get consistently better results.

In [5]:
def mini_net(n_out=1000, pretrained=False, layers=(1,1,1,1), groups=64, reduction=8, **kwargs):
    """Build a small squeeze-and-excitation ResNeXt ("mini net") with fastai's `XResNet`.

    Compared with `xse_resnext18_deep` (layers `[2,2,2,2,1,1]`, groups 32,
    reduction 16), this model uses one block per stage with more groups and a
    smaller SE reduction. In this notebook it trained faster and slightly more
    accurately on Imagenette.

    Defined at module top level so `main` can find it by name via
    `globals()['mini_net']`.

    Args:
        n_out: number of output classes.
        pretrained: accepted for signature compatibility with fastai's xresnet
            factories; it is ignored.
        layers: number of blocks per stage (previously hard-coded to [1,1,1,1]).
        groups: `groups` passed through `XResNet` to each block (previously 64).
        reduction: `reduction` passed through `XResNet` to each block (previously 8).
        **kwargs: forwarded to `XResNet` (e.g. `act_cls`, `sa`, `sym`, `pool`).
    """
    block, expansion = SEResNeXtBlock, 1
    # Tuple default avoids a shared mutable default; convert to a list so the
    # printed config reads `layers=[1, 1, 1, 1]` as before.
    layers = list(layers)
    print(f'block={block} expansion={expansion} layers={layers} groups={groups} reduction={reduction}')
    return XResNet(block, expansion, layers, n_out=n_out, groups=groups, reduction=reduction, **kwargs)

With `layers=[1, 1, 1, 1]`, `groups=32` and `reduction=16` we see nearly 0.86 accuracy most of the time, but it sometimes drops to ~0.84.

I think the changes to `groups` and `reduction` improved the consistency of `mini net` - but may have reduced peak accuracy a tiny bit.

Note: I tried a few changes to `groups` and `reduction` with `xse_resnext18_deep` but anything other than 32/16 made it worse.

In [6]:
# Train mini_net once, reusing the baseline lr and wd.
arch = 'mini_net'
main(arch=arch, lr=lr, wd=wd, runs=1)

epochs: 5; lr: 0.01; size: 128; sqrmom: 0.99; mom: 0.9; eps: 1e-06
fp16: 0; arch: mini_net; wd: 0.0001; act_fn: Mish; bs: 64
pool: AvgPool; woof: 0; sh:0.0
Run: 0
block=<function SEResNeXtBlock at 0x7f0722e79a60> expansion=1 layers=[1, 1, 1, 1] groups=64 reduction=8


epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,1.564288,1.676641,0.538854,0.904968,00:18
1,1.27344,1.15203,0.735287,0.969682,00:16
2,1.105641,1.211823,0.718471,0.962548,00:17
3,1.004001,0.974918,0.816815,0.980892,00:17
4,0.850758,0.890801,0.85172,0.987771,00:17


mean accuracy 0.8517197370529175 median accuracy 0.8517197370529175 over 1 run


### Try a smaller batch size

This increases train time (~19s vs ~17s per epoch) but improves accuracy by ~0.5%.

In [7]:
# Same mini_net run, but with batch size 32 instead of 64.
arch = 'mini_net'
main(arch=arch, lr=lr, wd=wd, bs=32, runs=1)

epochs: 5; lr: 0.01; size: 128; sqrmom: 0.99; mom: 0.9; eps: 1e-06
fp16: 0; arch: mini_net; wd: 0.0001; act_fn: Mish; bs: 32
pool: AvgPool; woof: 0; sh:0.0
Run: 0
block=<function SEResNeXtBlock at 0x7f0722e79a60> expansion=1 layers=[1, 1, 1, 1] groups=64 reduction=8


epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,1.473351,1.631924,0.58293,0.921019,00:20
1,1.192614,1.178539,0.740637,0.964331,00:19
2,1.063384,1.190584,0.728153,0.962293,00:19
3,0.998295,1.137689,0.755159,0.966369,00:19
4,0.831468,0.880717,0.85758,0.985223,00:19


mean accuracy 0.8575795888900757 median accuracy 0.8575795888900757 over 1 run


Let's try this config over 5 runs to check that the improvement is consistent ...

In [8]:
# Repeat the bs=32 mini_net configuration 5 times and report mean/median accuracy.
arch = 'mini_net'
main(arch=arch, lr=lr, wd=wd, bs=32, runs=5)

epochs: 5; lr: 0.01; size: 128; sqrmom: 0.99; mom: 0.9; eps: 1e-06
fp16: 0; arch: mini_net; wd: 0.0001; act_fn: Mish; bs: 32
pool: AvgPool; woof: 0; sh:0.0
Run: 0
block=<function SEResNeXtBlock at 0x7f0722e79a60> expansion=1 layers=[1, 1, 1, 1] groups=64 reduction=8


epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,1.461116,1.312267,0.677452,0.955414,00:19
1,1.201678,1.160001,0.742166,0.966115,00:19
2,1.06898,1.167217,0.739108,0.96586,00:19
3,0.971632,1.048307,0.787516,0.975541,00:19
4,0.818533,0.870459,0.861147,0.986752,00:19


Run: 1
block=<function SEResNeXtBlock at 0x7f0722e79a60> expansion=1 layers=[1, 1, 1, 1] groups=64 reduction=8


epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,1.47762,1.386195,0.635414,0.940382,00:19
1,1.17611,1.149951,0.741401,0.968662,00:19
2,1.095543,1.186612,0.728917,0.961274,00:19
3,0.998261,1.201254,0.713885,0.974013,00:19
4,0.834541,0.871187,0.860892,0.988535,00:20


Run: 2
block=<function SEResNeXtBlock at 0x7f0722e79a60> expansion=1 layers=[1, 1, 1, 1] groups=64 reduction=8


epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,1.471035,1.754115,0.534268,0.915924,00:19
1,1.184112,1.148899,0.744968,0.967898,00:19
2,1.071816,1.117462,0.756943,0.972229,00:19
3,0.966618,0.984165,0.814777,0.981656,00:19
4,0.799388,0.86639,0.858599,0.98828,00:19


Run: 3
block=<function SEResNeXtBlock at 0x7f0722e79a60> expansion=1 layers=[1, 1, 1, 1] groups=64 reduction=8


epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,1.485978,1.379896,0.628025,0.947261,00:19
1,1.208322,1.201154,0.712102,0.962803,00:19
2,1.083099,1.178641,0.733503,0.965096,00:19
3,0.996163,1.164397,0.732994,0.972484,00:19
4,0.831603,0.882625,0.85172,0.987516,00:20


Run: 4
block=<function SEResNeXtBlock at 0x7f0722e79a60> expansion=1 layers=[1, 1, 1, 1] groups=64 reduction=8


epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,1.464468,1.5209,0.596943,0.935796,00:19
1,1.175033,1.240049,0.695796,0.963312,00:20
2,1.075945,1.315083,0.678726,0.949554,00:19
3,0.989066,1.217997,0.702675,0.970955,00:19
4,0.825066,0.886576,0.856306,0.987006,00:19


mean accuracy 0.8577324748039246 median accuracy 0.8585987091064453 over 5 runs
