In [None]:
# default_exp models.XResNet1dPlus

# XResNet1dPlus

> This is a modified version of fastai's XResNet model (available on GitHub). Changes include:
* The API is modified to match timeseriesAI's default API.
* (Optional) Support for Uber's CoordConv in 1d.

In [None]:
#export
from tsai.imports import *
from tsai.models.layers import *
from tsai.models.utils import *

In [None]:
#export
class XResNet1dPlus(nn.Sequential):
    """XResNet-style 1d network: an `nn.Sequential` with two named children, `backbone` and `head`.

    `backbone` = three stem `ConvBlock`s -> `MaxPool` -> one stage of `block`s per entry in `layers`.
    `head`     = `AdaptiveAvgPool` -> `Flatten` -> `nn.Dropout(fc_dropout)` -> `nn.Linear(..., n_out)`.

    Args:
        block: residual block class; instantiated in `_make_layer` as
            `block(expansion, ni, nf, coord=..., stride=..., sa=..., act_cls=..., ks=..., **kwargs)`.
        expansion: first positional argument of every block; also divides the first block size
            and multiplies the input width of the final linear layer.
        layers: number of blocks in each stage (one stage per entry).
        fc_dropout: dropout probability used in the head.
        c_in: number of input channels of the first stem conv.
        n_out: number of output features of the final linear layer.
        stem_szs: output channels of the stem convs (only the first three entries are used).
        widen: multiplier applied to the stage widths (result truncated with `int`).
        sa: when true, `sa=True` is passed only to the last block of stage `len(layers)-4`.
        act_cls: activation class passed to the stem (`act=`) and to every block (`act_cls=`).
        ks: kernel size for the stem, the max-pool and the blocks; must be odd.
        stride: stride of the first stem conv, the max-pool and the first block of every stage but the first.
        coord: forwarded to every stem `ConvBlock` and every block.
        **kwargs: forwarded unchanged to every block.

    Raises:
        Exception: if `ks` is even.
    """
    @delegates(ResBlock1dPlus)
    def __init__(self, block, expansion, layers, fc_dropout=0.0, c_in=3, n_out=1000, stem_szs=(32,32,64),
                 widen=1.0, sa=False, act_cls=defaults.activation, ks=3, stride=2, coord=False, **kwargs):
        store_attr('block,expansion,act_cls,ks')
        if ks % 2 == 0: raise Exception('kernel size has to be odd!')
        # Prepend the input channels so consecutive pairs give (in, out) for each stem conv
        stem_szs = [c_in, *stem_szs]
        # Only the first stem conv is strided
        stem = [ConvBlock(stem_szs[i], stem_szs[i+1], ks=ks, coord=coord, stride=stride if i==0 else 1,
                          act=act_cls)
                for i in range(3)]

        # Stage widths: the four standard sizes, then 256 for every stage beyond the fourth
        block_szs = [int(o*widen) for o in [64,128,256,512] +[256]*(len(layers)-4)]
        block_szs = [64//expansion] + block_szs
        blocks    = self._make_blocks(layers, block_szs, sa, coord, stride, **kwargs)
        backbone = nn.Sequential(*stem, MaxPool(ks=ks, stride=stride, padding=ks//2, ndim=1), *blocks)
        # Width of the last stage that was actually built. With fewer than four stages `block_szs`
        # has unused trailing entries, so `block_szs[-1]` would not match the backbone output;
        # for four or more stages this index is the last element, exactly as before.
        head_nf = block_szs[len(layers)]
        head = nn.Sequential(AdaptiveAvgPool(sz=1, ndim=1), Flatten(), nn.Dropout(fc_dropout),
                             nn.Linear(head_nf*expansion, n_out))
        super().__init__(OrderedDict([('backbone', backbone), ('head', head)]))
        self._init_cnn(self)

    def _make_blocks(self, layers, block_szs, sa, coord, stride, **kwargs):
        "Build one stage per entry of `layers`; stage `i` maps `block_szs[i]` -> `block_szs[i+1]`."
        # The first stage is never strided; `sa` is only enabled for stage `len(layers)-4`
        return [self._make_layer(ni=block_szs[i], nf=block_szs[i+1], blocks=l, coord=coord,
                                 stride=1 if i==0 else stride, sa=sa and i==len(layers)-4, **kwargs)
                for i,l in enumerate(layers)]

    def _make_layer(self, ni, nf, blocks, coord, stride, sa, **kwargs):
        "Stack `blocks` instances of `self.block` into an `nn.Sequential` stage."
        # Only the first block changes width (`ni` -> `nf`) and is strided; only the last one may get `sa`
        return nn.Sequential(
            *[self.block(self.expansion, ni if i==0 else nf, nf, coord=coord, stride=stride if i==0 else 1,
                      sa=sa and i==(blocks-1), act_cls=self.act_cls, ks=self.ks, **kwargs)
              for i in range(blocks)])

    def _init_cnn(self, m):
        "Recursively zero every non-None `bias` and Kaiming-normal initialise conv/linear weights under `m`."
        # Inspect `m` (the module being visited). Checking `self` here would always test the
        # top-level model, which has no `bias` and is not a conv/linear layer, so the recursion
        # would walk the whole tree without initialising anything.
        if getattr(m, 'bias', None) is not None: nn.init.constant_(m.bias, 0)
        if isinstance(m, (nn.Conv1d,nn.Conv2d,nn.Conv3d,nn.Linear)): nn.init.kaiming_normal_(m.weight)
        for l in m.children(): self._init_cnn(l)

In [None]:
#export
def _xresnetplus(expansion, layers, **kwargs):
    "Create an `XResNet1dPlus` made of `ResBlock1dPlus` blocks; `kwargs` go straight to the model constructor."
    model = XResNet1dPlus(ResBlock1dPlus, expansion, layers, **kwargs)
    return model

In [None]:
#export
@delegates(ResBlock)
def xresnet1d18plus(c_in, c_out, act=nn.ReLU, **kwargs):
    "XResNet1dPlus with `expansion=1` and stage depths `[2, 2, 2, 2]`; `c_out` is passed as `n_out`, `act` as `act_cls`."
    depths = [2, 2, 2, 2]
    return _xresnetplus(1, depths, c_in=c_in, n_out=c_out, act_cls=act, **kwargs)
@delegates(ResBlock)
def xresnet1d34plus(c_in, c_out, act=nn.ReLU, **kwargs):
    "XResNet1dPlus with `expansion=1` and stage depths `[3, 4, 6, 3]`; `c_out` is passed as `n_out`, `act` as `act_cls`."
    depths = [3, 4, 6, 3]
    return _xresnetplus(1, depths, c_in=c_in, n_out=c_out, act_cls=act, **kwargs)
@delegates(ResBlock)
def xresnet1d50plus(c_in, c_out, act=nn.ReLU, **kwargs):
    "XResNet1dPlus with `expansion=4` and stage depths `[3, 4, 6, 3]`; `c_out` is passed as `n_out`, `act` as `act_cls`."
    depths = [3, 4, 6, 3]
    return _xresnetplus(4, depths, c_in=c_in, n_out=c_out, act_cls=act, **kwargs)
@delegates(ResBlock)
def xresnet1d101plus(c_in, c_out, act=nn.ReLU, **kwargs):
    "XResNet1dPlus with `expansion=4` and stage depths `[3, 4, 23, 3]`; `c_out` is passed as `n_out`, `act` as `act_cls`."
    depths = [3, 4, 23, 3]
    return _xresnetplus(4, depths, c_in=c_in, n_out=c_out, act_cls=act, **kwargs)
@delegates(ResBlock)
def xresnet1d152plus(c_in, c_out, act=nn.ReLU, **kwargs):
    "XResNet1dPlus with `expansion=4` and stage depths `[3, 8, 36, 3]`; `c_out` is passed as `n_out`, `act` as `act_cls`."
    depths = [3, 8, 36, 3]
    return _xresnetplus(4, depths, c_in=c_in, n_out=c_out, act_cls=act, **kwargs)
@delegates(ResBlock)
def xresnet1d18_deepplus(c_in, c_out, act=nn.ReLU, **kwargs):
    "XResNet1dPlus with `expansion=1` and stage depths `[2,2,2,2,1,1]`; `c_out` is passed as `n_out`, `act` as `act_cls`."
    depths = [2, 2, 2, 2, 1, 1]
    return _xresnetplus(1, depths, c_in=c_in, n_out=c_out, act_cls=act, **kwargs)
@delegates(ResBlock)
def xresnet1d34_deepplus(c_in, c_out, act=nn.ReLU, **kwargs):
    "XResNet1dPlus with `expansion=1` and stage depths `[3,4,6,3,1,1]`; `c_out` is passed as `n_out`, `act` as `act_cls`."
    depths = [3, 4, 6, 3, 1, 1]
    return _xresnetplus(1, depths, c_in=c_in, n_out=c_out, act_cls=act, **kwargs)
@delegates(ResBlock)
def xresnet1d50_deepplus(c_in, c_out, act=nn.ReLU, **kwargs):
    "XResNet1dPlus with `expansion=4` and stage depths `[3,4,6,3,1,1]`; `c_out` is passed as `n_out`, `act` as `act_cls`."
    depths = [3, 4, 6, 3, 1, 1]
    return _xresnetplus(4, depths, c_in=c_in, n_out=c_out, act_cls=act, **kwargs)
@delegates(ResBlock)
def xresnet1d18_deeperplus(c_in, c_out, act=nn.ReLU, **kwargs):
    "XResNet1dPlus with `expansion=1` and stage depths `[2,2,1,1,1,1,1,1]`; `c_out` is passed as `n_out`, `act` as `act_cls`."
    depths = [2, 2, 1, 1, 1, 1, 1, 1]
    return _xresnetplus(1, depths, c_in=c_in, n_out=c_out, act_cls=act, **kwargs)
@delegates(ResBlock)
def xresnet1d34_deeperplus(c_in, c_out, act=nn.ReLU, **kwargs):
    "XResNet1dPlus with `expansion=1` and stage depths `[3,4,6,3,1,1,1,1]`; `c_out` is passed as `n_out`, `act` as `act_cls`."
    depths = [3, 4, 6, 3, 1, 1, 1, 1]
    return _xresnetplus(1, depths, c_in=c_in, n_out=c_out, act_cls=act, **kwargs)
@delegates(ResBlock)
def xresnet1d50_deeperplus(c_in, c_out, act=nn.ReLU, **kwargs):
    "XResNet1dPlus with `expansion=4` and stage depths `[3,4,6,3,1,1,1,1]`; `c_out` is passed as `n_out`, `act` as `act_cls`."
    depths = [3, 4, 6, 3, 1, 1, 1, 1]
    return _xresnetplus(4, depths, c_in=c_in, n_out=c_out, act_cls=act, **kwargs)

In [None]:
# Smoke test: build xresnet1d18plus (c_in=3, c_out=2) with coord=True and run one forward pass
net = xresnet1d18plus(3, 2, coord=True)
# Random input of shape (32, 3, 50); the second dimension matches c_in=3
x = torch.rand(32, 3, 50)
# Last expression, so the notebook displays the model output for this batch
net(x)

tensor([[ 0.2051, -0.4082],
        [ 0.1320, -0.2454],
        [ 0.1718, -0.5900],
        [ 0.3423, -0.1778],
        [ 0.2053, -0.5386],
        [ 0.2830, -0.5609],
        [ 0.4794, -0.4538],
        [ 0.1914,  0.1688],
        [ 0.3116, -0.3382],
        [ 0.0513, -0.3032],
        [-0.0072, -0.4724],
        [ 0.0463, -0.3346],
        [ 0.4596, -0.7458],
        [ 0.2832, -0.5932],
        [ 0.2439, -0.0979],
        [ 0.2160, -0.4570],
        [ 0.2408, -0.2508],
        [ 0.0452, -0.0187],
        [ 0.4669, -0.1605],
        [ 0.3136, -0.6333],
        [-0.0930, -0.0047],
        [-0.0522, -0.4743],
        [-0.0063, -0.2872],
        [ 0.5206, -0.2603],
        [ 0.0385, -0.8225],
        [ 0.1682, -0.7900],
        [ 0.1275, -0.4700],
        [ 0.0180, -0.4667],
        [ 0.1500, -0.3306],
        [ 0.4787, -0.5689],
        [ 0.3416, -0.7539],
        [ 0.0963, -0.5110]], grad_fn=<AddmmBackward>)

In [None]:
# Check that every listed architecture turns a (bs, c_in, seq_len) input into a (bs, c_out) output
bs, c_in, seq_len = 2, 4, 32
c_out = 2
x = torch.rand(bs, c_in, seq_len)
archs = [
    xresnet1d18plus, xresnet1d34plus, xresnet1d50plus, 
    xresnet1d18_deepplus, xresnet1d34_deepplus, xresnet1d50_deepplus, xresnet1d18_deeperplus,
    xresnet1d34_deeperplus, xresnet1d50_deeperplus
#     # Long test (add a comma after the last entry above before enabling)
#     xresnet1d101plus, xresnet1d152plus,
]
for i, arch in enumerate(archs):
    print(i, arch.__name__)
    # Also exercises the optional paths: sa=True, a non-default activation (Mish) and coord=True
    test_eq(arch(c_in, c_out, sa=True, act=Mish, coord=True)(x).shape, (bs, c_out))

0 xresnet1d18plus
1 xresnet1d34plus
2 xresnet1d50plus
3 xresnet1d18_deepplus
4 xresnet1d34_deepplus
5 xresnet1d50_deepplus
6 xresnet1d18_deeperplus
7 xresnet1d34_deeperplus
8 xresnet1d50_deeperplus


In [None]:
# Regression checks on xresnet1d34plus; 38 and 22 are fixed expected values for this architecture.
# NOTE(review): get_layers / is_bn / check_weight are imported helpers not shown in this notebook —
# presumably the first check counts the batch-norm layers and the second sums a per-layer weight
# statistic returned by check_weight; confirm against their definitions.
m = xresnet1d34plus(4, 2, act=Mish)
test_eq(len(get_layers(m, is_bn)), 38)
test_eq(check_weight(m, is_bn)[0].sum(), 22)

In [None]:
#hide
from tsai.imports import *
from tsai.export import *
# NOTE: the value returned by get_nb_name() is overwritten on the next line by the hard-coded filename,
# so only the literal below is passed to create_scripts.
nb_name = get_nb_name()
nb_name = "112b_models.XResNet1dPlus.ipynb"
# NOTE(review): create_scripts comes from tsai.export — presumably it exports this notebook to the
# library module; confirm. The trailing ';' suppresses the cell's displayed return value.
create_scripts(nb_name);