# Can we improve model accuracy by using `TwistLayer`?

In response to a request from Yao Liu (https://forums.fast.ai/t/imagenette-imagewoof-leaderboards/45822/28), we want to see whether twist layers can improve imagewoof accuracy.

Here's a summary of results over 5 epochs:

| model              | twist? | seconds per epoch | accuracy | Total trainable params |
| -------------------|--------|-------------------|----------|------------------------|
| xse_resnext18      | No     | 19                | 60.8%    |  13,142,122            |
| xse_resnext18      | Yes    | 60                | 61.4%    |  71,878,890            |
| xse_resnext18_deep | No     | 19                | 69.7%    |  15,756,938            |
| xse_resnext18_deep | Yes    | 66                | 64.6%    |  93,338,890            |
| xse_resnext50      | No     | 42                | 65.2%    |  23,613,290            |
| xse_resnext50      | Yes    | --                | -----    | 361,619,306            |

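Adding twist to xse_resnext18 gives a small accuracy gain (60.8% to 61.4%), but at quite a big additional compute cost: about 3x the time per epoch (19 s to 60 s) and about 5.5x the trainable params. With xse_resnext18_deep, accuracy drops from 69.7% to 64.6%.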
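
To put numbers on that cost, here is a small sketch that computes the twist overhead directly from the summary table. The dictionary only restates the table rows; `twist_overhead` is a helper name made up for this illustration.

```python
# Figures copied from the summary table (5 epochs on imagewoof).
# Each value: (seconds per epoch, accuracy in %, trainable params)
results = {
    ("xse_resnext18", False): (19, 60.8, 13_142_122),
    ("xse_resnext18", True): (60, 61.4, 71_878_890),
    ("xse_resnext18_deep", False): (19, 69.7, 15_756_938),
    ("xse_resnext18_deep", True): (66, 64.6, 93_338_890),
}

def twist_overhead(model):
    "Return (time ratio, accuracy change in points, param ratio) for twist vs. no twist."
    t0, a0, p0 = results[(model, False)]
    t1, a1, p1 = results[(model, True)]
    return t1 / t0, a1 - a0, p1 / p0

for model in ("xse_resnext18", "xse_resnext18_deep"):
    t, a, p = twist_overhead(model)
    print(f"{model}: {t:.1f}x time, {a:+.1f} points accuracy, {p:.1f}x params")
```

xse_resnext50 is left out because the twist run did not fit in GPU memory, so there is nothing to compare.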

Note: even with bs=16 I don't have enough GPU mem to run xse_resnext50 with twist )o:

Questions
- Can anyone recommend a resource to understand the ideas behind twist?
- Is the increase in trainable params expected? Or have I made a mistake?
- Do we expect twist to bring improvements with all models? Or do we have to run with xse_resnext50?
- Do we expect twist to bring improvements with just 5 epochs of training? Or does it only improve things with lots of training?

In [1]:
import PIL # hack to re-instate PILLOW_VERSION
PIL.PILLOW_VERSION = PIL.__version__

from fastai2.basics import *
from fastai2.vision.all import *
from fastai2.callback.all import *
from fastai2.distributed import *
from fastprogress import fastprogress
from torchvision.models import *
from fastai2.vision.models.xresnet import *
from fastai2.callback.mixup import *
from fastscript import *

torch.backends.cudnn.benchmark = True
fastprogress.MAX_COLS = 80

In [2]:
def get_dls(size, woof, bs, sh=0., workers=None):
    if size<=224: path = URLs.IMAGEWOOF_320 if woof else URLs.IMAGENETTE_320
    else        : path = URLs.IMAGEWOOF     if woof else URLs.IMAGENETTE
    source = untar_data(path)
    if workers is None: workers = min(8, num_cpus())
    resize_ftm = RandomResizedCrop(size, min_scale=0.35)
    dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                       splitter=GrandparentSplitter(valid_name='val'),
                       get_items=get_image_files, get_y=parent_label,
                       item_tfms=[resize_ftm, FlipItem(0.5)],
                       batch_tfms=RandomErasing(p=0.3, max_count=3, sh=sh) if sh else None)
    return dblock.dataloaders(source, path=source, bs=bs, num_workers=workers)

In [3]:
@call_parse
def main(
    gpu:   Param("GPU to run on", int)=None,
    woof:  Param("Use imagewoof (otherwise imagenette)", int)=0,
    lr:    Param("Learning rate", float)=1e-2,
    size:  Param("Size (px: 128,192,256)", int)=128,
    sqrmom:Param("sqr_mom", float)=0.99,
    mom:   Param("Momentum", float)=0.9,
    eps:   Param("epsilon", float)=1e-6,
    epochs:Param("Number of epochs", int)=5,
    bs:    Param("Batch size", int)=64,
    mixup: Param("Mixup", float)=0.,
    opt:   Param("Optimizer (adam,rms,sgd,ranger)", str)='ranger',
    arch:  Param("Architecture", str)='xresnet50',
    sh:    Param("Random erase max proportion", float)=0.,
    sa:    Param("Self-attention", int)=0,
    sym:   Param("Symmetry for self-attention", int)=0,
    beta:  Param("SAdam softplus beta", float)=0.,
    act_fn:Param("Activation function", str)='Mish',
    fp16:  Param("Use mixed precision training", int)=0,
    pool:  Param("Pooling method", str)='AvgPool',
    dump:  Param("Print model; don't train", int)=0,
    runs:  Param("Number of times to repeat training", int)=1,
    meta:  Param("Metadata (ignored)", str)='',
    wd:    Param("Weight decay", float)=1e-2
):
    "Training of Imagenette."

    #gpu = setup_distrib(gpu)
    if gpu is not None: torch.cuda.set_device(gpu)
    if   opt=='adam'  : opt_func = partial(Adam, mom=mom, sqr_mom=sqrmom, eps=eps)
    elif opt=='rms'   : opt_func = partial(RMSprop, sqr_mom=sqrmom)
    elif opt=='sgd'   : opt_func = partial(SGD, mom=mom)
    elif opt=='ranger': opt_func = partial(ranger, mom=mom, sqr_mom=sqrmom, eps=eps, beta=beta)

    dls = get_dls(size, woof, bs, sh=sh)
    if not gpu: 
        print(f'epochs: {epochs}; lr: {lr}; size: {size}; sqrmom: {sqrmom}; mom: {mom}; eps: {eps}')
        print(f'fp16: {fp16}; arch: {arch}; wd: {wd}; act_fn: {act_fn}; bs: {bs}')
        print(f'pool: {pool}; woof: {woof}; sh:{sh}')
        
    m,act_fn,pool = [globals()[o] for o in (arch,act_fn,pool)]
    m = m(c_out=10, act_cls=act_fn, sa=sa, sym=sym, pool=pool)
    model_created_callback(m)
    
    for run in range(runs):
        print(f'Run: {run}')
        learn = Learner(dls, m, opt_func=opt_func, metrics=[accuracy,top_k_accuracy], loss_func=LabelSmoothingCrossEntropy())
        if dump: return learn
        if fp16: learn = learn.to_fp16()
        cbs = MixUp(mixup) if mixup else []
        #n_gpu = torch.cuda.device_count()
        #if gpu is None and n_gpu: learn.to_parallel()
        if num_distrib()>1: learn.to_distributed(gpu) # Requires `-m fastai.launch`
        learn.fit_flat_cos(epochs, lr, wd=wd, cbs=cbs)

## Try an "out of the box" model to see what we have to beat

Expect ~69% from xse_resnext18_deep, ~65% from xse_resnext50 but ...

... we're aiming for 73.37% (Dmytro Mishkin https://github.com/ducha-aiki/imagewoofv2-fastv2-maxpoolblur/blob/master/fastai2-imagenette-train-maxblurpool.ipynb)

In [4]:
# copy non-default params from Ranger-Mish-ImageWoof-5
lr=3e-3
mom=.95
sa=1
arch='xse_resnext18'

In [13]:
def model_created_callback(m): return m

In [6]:
main(woof=1, lr=lr, mom=mom, sa=sa, arch=arch)

epochs: 5; lr: 0.003; size: 128; sqrmom: 0.99; mom: 0.95; eps: 1e-06
fp16: 0; arch: xse_resnext18; wd: 0.01; act_fn: Mish; bs: 64
pool: AvgPool; woof: 1; sh:0.0
Run: 0


epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,2.075071,2.031707,0.299313,0.806312,00:20
1,1.866026,1.904647,0.383304,0.858234,00:19
2,1.684238,1.71692,0.45126,0.900229,00:19
3,1.572252,1.536817,0.538305,0.921354,00:19
4,1.41512,1.39147,0.608552,0.949606,00:19


## Try TwistLayer from https://github.com/liuyao12/Ranger-Mish-ImageWoof-5/blob/master/mxresnet.py

`def conv_layer(...):
    ...
    layers = [conv(ni, nf, ks, stride=stride), bn] if ks==1 or nf<=32 else [TwistLayer(ni, nf, stride=stride), bn]
    ...`
    
This snippet, from mxresnet.py, tells me that we want a TwistLayer instead of a Conv2d if `ks!=1` and `nf>32`.
Rather than build the model with TwistLayer from scratch, we'll just swap out some Conv2d layers.

In [7]:
def conv(ni, nf, ks=3, stride=1, bias=False):
    return nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=ks//2, bias=bias)

def ramp_func(r, a=0.3):
    # r is squared distance (scaled down to 0~1) from the center of a 2d image
    # a=0 makes it a step function
    one = torch.ones_like(r)    # same shape, dtype and device as r
    zero = torch.zeros_like(r)
    return torch.where(r<=1-a, one, torch.where(r>=1+a, zero, 1-(r-1+a)/(2*a)))

class TwistLayer(Module):
    def __init__(self, ni, nf, stride=1):
        self.radii = torch.nn.Parameter(torch.ones(nf), requires_grad=True)
        self.center_x = torch.nn.Parameter(torch.ones(nf), requires_grad=True)
        self.center_y = torch.nn.Parameter(torch.ones(nf), requires_grad=True)
        self.radii.data.uniform_(0.3, 0.7)
        self.center_x.data.uniform_(-0.7, 0.7)
        self.center_y.data.uniform_(-0.7, 0.7)
        self.conv = conv(ni, nf, stride=stride)
        self.convx = conv(ni, nf, stride=stride)
        self.convy = conv(ni, nf, stride=stride)

    def forward(self, x):
        # make convx a first-order operator by antisymmetrizing it (subtract its 180 degree rotation)
        self.convx.weight.data = (self.convx.weight - self.convx.weight.flip(2).flip(3)) / 2
        # self.convy.weight.data = (self.convy.weight - self.convy.weight.flip(2).flip(3)) / 2
        # make convy a 90 degree rotation of convx
        self.convy.weight.data = self.convx.weight.transpose(2,3).flip(2)
        x1 = self.conv(x)
        _, c, h, w = x1.size()
        XX = torch.from_numpy(np.indices((1,h,w))[2]*2/w).type(x.dtype).to(x.device) - self.center_x.view(-1,1,1)
        YY = torch.from_numpy(np.indices((1,h,w))[1]*2/h).type(x.dtype).to(x.device) - self.center_y.view(-1,1,1)
        mask = ramp_func((XX**2+YY**2)/(self.radii.type(x.dtype).to(x.device).view(-1,1,1)**2))
        return x1 + mask * (XX * self.convx(x) + YY * self.convy(x))
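
To make the pieces of `TwistLayer` easier to follow, here is a minimal pure-Python sketch of the three small operations it relies on: the ramp used for the mask, the antisymmetrization applied to `convx`, and the rotation that produces `convy`. It works on scalars and nested lists rather than tensors, and the helper names (`ramp`, `flip180`, `antisymmetrize`, `rot90`) are made up for this illustration; they are not part of mxresnet.py.

```python
def ramp(r, a=0.3):
    "Scalar version of ramp_func: 1 for r <= 1-a, 0 for r >= 1+a, linear in between."
    if r <= 1 - a: return 1.0
    if r >= 1 + a: return 0.0
    return 1 - (r - 1 + a) / (2 * a)

def flip180(k):
    "Reverse rows and columns of a 2D kernel (what .flip(2).flip(3) does to each kernel)."
    return [row[::-1] for row in k[::-1]]

def antisymmetrize(k):
    "Mirror of the convx update: (k - flip180(k)) / 2."
    f = flip180(k)
    return [[(x - y) / 2 for x, y in zip(kr, fr)] for kr, fr in zip(k, f)]

def rot90(k):
    "Mirror of the convy update: transpose, then reverse the row order."
    transposed = [list(col) for col in zip(*k)]
    return transposed[::-1]

# The mask is 1 well inside the radius, 0 well outside, and halfway at r == 1.
print(ramp(0.5), ramp(1.0), ramp(1.5))

# An arbitrary 3x3 kernel, chosen only for illustration.
k = [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9]]
ka = antisymmetrize(k)
print(ka)                         # each weight is the negative of its opposite
print(sum(sum(row) for row in ka))  # weights sum to zero
print(rot90(ka))                  # the kernel convy would use
```

Because every weight in the antisymmetrized kernel is paired with its negative and the centre weight is zero, the kernel gives no response on a constant input, which is the sense in which it behaves like a first-order (derivative-like) operator.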

In [8]:
def model_created_callback(m): 
    "swap conv for twist if ks!=1 and nf>32"
    for name, module in m._modules.items():
        if len(list(module.children())) > 0: model_created_callback(module)
        if not isinstance(module, nn.Conv2d): continue
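        # NOTE: nn.Conv2d stores kernel_size as a tuple such as (1, 1), so the comparison
        # with the int 1 is never true and 1x1 convs with more than 32 output channels
        # are swapped too; compare with (1, 1) to skip them.
        if module.kernel_size == 1 or module.out_channels <= 32: continue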
#         print('swapping', module)
        m._modules[name] = TwistLayer(module.in_channels, module.out_channels, module.stride)
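
As a rough check on where the extra trainable params could come from, here is a sketch using standard bias-free conv arithmetic. It is an estimate under stated assumptions, not a diagnosis: the helper names (`conv_params`, `twist_params`, `should_swap`) are hypothetical, the 128-channel layer sizes are picked only for illustration, and whether the replaced convs are grouped (as is typical for ResNeXt-style blocks) is an assumption. What the code above does show is that `TwistLayer` is built from `in_channels`, `out_channels` and `stride` only, so any `groups` setting on the original conv is not carried over.

```python
def conv_params(ni, nf, ks, groups=1):
    "Weights in a bias-free 2D convolution with a square ks x ks kernel."
    return (ni // groups) * nf * ks * ks

def twist_params(ni, nf):
    "TwistLayer as defined above: three dense 3x3 convs plus radii, center_x, center_y."
    return 3 * conv_params(ni, nf, 3) + 3 * nf

def should_swap(kernel_size, out_channels):
    "Swap rule from the mxresnet.py snippet, written for a tuple kernel_size."
    return kernel_size != (1, 1) and out_channels > 32

# A tuple never equals the int 1, so a check like `kernel_size == 1` cannot match.
print((1, 1) == 1)              # False
print(should_swap((1, 1), 64))  # False
print(should_swap((3, 3), 64))  # True

# Sanity check against the summary output: the 32 -> 64 3x3 conv has 18,432 weights.
print(conv_params(32, 64, 3))   # 18432

# Growth factor when one conv is replaced by a TwistLayer (illustrative sizes).
ni, nf = 128, 128
print(twist_params(ni, nf) / conv_params(ni, nf, 3))             # dense 3x3 conv
print(twist_params(ni, nf) / conv_params(ni, nf, 1))             # 1x1 conv
print(twist_params(ni, nf) / conv_params(ni, nf, 3, groups=32))  # grouped 3x3 conv
```

Replacing a dense 3x3 conv costs about 3x the weights, while replacing a 1x1 conv or a grouped conv costs far more, so both cases are worth ruling out when the total param count grows by more than 3x.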

In [9]:
main(woof=1, lr=lr, mom=mom, sa=sa, arch=arch)

epochs: 5; lr: 0.003; size: 128; sqrmom: 0.99; mom: 0.95; eps: 1e-06
fp16: 0; arch: xse_resnext18; wd: 0.01; act_fn: Mish; bs: 64
pool: AvgPool; woof: 1; sh:0.0
Run: 0


epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,2.032181,2.068419,0.306948,0.807076,01:00
1,1.870831,1.754632,0.42988,0.89692,00:59
2,1.744489,1.819987,0.435989,0.890048,00:59
3,1.605472,1.622742,0.519216,0.91041,00:59
4,1.455842,1.395145,0.614406,0.945279,00:59


In [10]:
learn = main(dump=1, woof=1, lr=lr, mom=mom, sa=sa, arch=arch)

epochs: 5; lr: 0.003; size: 128; sqrmom: 0.99; mom: 0.95; eps: 1e-06
fp16: 0; arch: xse_resnext18; wd: 0.01; act_fn: Mish; bs: 64
pool: AvgPool; woof: 1; sh:0.0
Run: 0


In [11]:
learn.model

XResNet(
  (0): ConvLayer(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Mish()
  )
  (1): ConvLayer(
    (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Mish()
  )
  (2): ConvLayer(
    (0): TwistLayer(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (convx): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (convy): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Mish()
  )
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): ResBlock(
      (convp

In [12]:
learn.summary()

XResNet (Input shape: ['64 x 3 x 128 x 128'])
Layer (type)         Output Shape         Param #    Trainable 
Conv2d               64 x 32 x 64 x 64    864        True      
________________________________________________________________
BatchNorm2d          64 x 32 x 64 x 64    64         True      
________________________________________________________________
Mish                 64 x 32 x 64 x 64    0          False     
________________________________________________________________
Conv2d               64 x 32 x 64 x 64    9,216      True      
________________________________________________________________
BatchNorm2d          64 x 32 x 64 x 64    64         True      
________________________________________________________________
Mish                 64 x 32 x 64 x 64    0          False     
________________________________________________________________
Conv2d               64 x 64 x 64 x 64    18,432     True      
____________________________________________________