In [None]:
# default_exp models.modules

In [None]:
# hide
%load_ext autoreload
%autoreload 2

In [None]:
# hide
from nbdev.showdoc import *

In [None]:
# export
from grade_classif.models.utils import get_sizes
from grade_classif.models.hooks import Hooks
from grade_classif.imports import *
from torch.nn.functional import interpolate, pad

In [None]:
# export
def bn_drop_lin(n_in, n_out, bn=True, p=0., actn=None):
    """Build the layers of a classifier stage as a plain Python list.

    The list contains, in this order: `nn.BatchNorm1d(n_in)` when `bn` is true,
    `nn.Dropout(p)` when `p` is non-zero, `nn.Linear(n_in, n_out)`, and finally
    `actn` itself when it is not `None`. The caller is expected to wrap the
    list (e.g. in `nn.Sequential`).
    """
    stage = []
    if bn:
        stage.append(nn.BatchNorm1d(n_in))
    if p != 0:
        stage.append(nn.Dropout(p))
    stage.append(nn.Linear(n_in, n_out))
    if actn is not None:
        stage.append(actn)
    return stage

In [None]:
#export
class ConvBnRelu(nn.Module):
    """2D convolution followed by batch normalisation and a ReLU.

    `eps` and `momentum` configure the `nn.BatchNorm2d` layer; any extra
    keyword argument (e.g. `padding_mode`, `groups`) is forwarded to `nn.Conv2d`.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, bias=True, eps=1e-5, momentum=0.01, **kwargs):
        super().__init__()
        # Gather every convolution option in one place before building the layer.
        conv_options = dict(stride=stride, padding=padding, bias=bias, **kwargs)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, **conv_options)
        self.bn = nn.BatchNorm2d(out_channels, eps=eps, momentum=momentum)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))
    
class ConvBn(nn.Module):
    """2D convolution followed by batch normalisation (no activation).

    `eps` and `momentum` configure the `nn.BatchNorm2d` layer; any extra
    keyword argument is forwarded to `nn.Conv2d`.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, bias=True, eps=1e-5, momentum=0.01, **kwargs):
        super().__init__()
        conv_options = dict(stride=stride, padding=padding, bias=bias, **kwargs)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, **conv_options)
        self.bn = nn.BatchNorm2d(out_channels, eps=eps, momentum=momentum)

    def forward(self, x):
        return self.bn(self.conv(x))

class ConvRelu(nn.Module):
    """2D convolution followed by a ReLU.

    Any extra keyword argument is forwarded to `nn.Conv2d`.
    """

    def __init__(
            self, in_channels, out_channels, kernel_size, stride=1, padding=0,
            bias=True, **kwargs):
        super().__init__()
        conv_options = dict(stride=stride, padding=padding, bias=bias, **kwargs)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, **conv_options)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))

In [None]:
# export
def icnr(x, scale=2, init=nn.init.kaiming_normal_):
    """ICNR-style in-place initialisation of the 4D tensor `x`.

    A smaller tensor with `x.shape[0] / scale**2` slices along dim 0 is filled
    by `init`, then each of its slices is duplicated `scale**2` times in a row
    so that `x[i] == base[i // scale**2]`. `x.shape[0]` has to be a multiple of
    `scale**2`, otherwise the final reshape does not fit.
    """
    n_out, n_in, height, width = x.shape
    n_copies = scale**2
    n_base = int(n_out / n_copies)

    base = init(torch.zeros([n_base, n_in, height, width]))
    # Work on a (n_in, n_base, h*w) memory layout (the view only relabels the
    # two leading dims) and tile the spatial part `n_copies` times.
    tiled = base.transpose(0, 1).contiguous().view(n_base, n_in, -1)
    tiled = tiled.repeat(1, 1, n_copies)
    # Reinterpreting the tiled data as (n_in, n_out, h, w) puts the copies on
    # consecutive dim-1 slices; transposing brings them back to dim 0.
    tiled = tiled.contiguous().view([n_in, n_out, height, width])
    x.data.copy_(tiled.transpose(0, 1))

class PixelShuffleICNR(nn.Module):
    """Upsample by `scale_factor` with a 1x1 convolution, ReLU and `nn.PixelShuffle`.

    The convolution produces `out_channels * scale_factor**2` channels and its
    weight is initialised with `icnr`, so that the `scale_factor**2` channels
    that `nn.PixelShuffle` merges into one output channel start out identical.

    Args:
        in_channels: number of input channels.
        out_channels: number of channels after the pixel shuffle.
        bias: whether the 1x1 convolution has a bias.
        scale_factor: spatial upscaling factor.
        **kwargs: forwarded to `nn.Conv2d`.
    """

    def __init__(
            self, in_channels, out_channels, bias=True, scale_factor=2, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels*scale_factor**2, 1, bias=bias, **kwargs)
        # The ICNR grouping must match the shuffle factor; relying on icnr's
        # default scale of 2 breaks (or mis-initialises) any other factor.
        icnr(self.conv.weight, scale=scale_factor)
        self.shuf = nn.PixelShuffle(scale_factor)
        self.relu = nn.ReLU()

    def forward(self, x):
        # The activation is applied before the channels are rearranged.
        return self.shuf(self.relu(self.conv(x)))

In [None]:
# export
class DecoderBlock(nn.Module):
    """One decoder stage: upsample, merge with a skip connection, refine.

    The input is upsampled by `PixelShuffleICNR` (halving its channels),
    resized to the spatial size of `hook.stored` if needed, concatenated with
    the batch-normalised `hook.stored` and passed through two 3x3 `ConvBnRelu`.
    `hook.stored` is presumably the encoder activation captured by the hook
    (confirm in `models.hooks`). With `final_div` the block outputs
    `skip_chans` channels, otherwise `in_chans//2 + skip_chans`.
    `**kwargs` is forwarded to the upsampler and to both convolutions.
    """

    def __init__(self, in_chans, skip_chans, hook, final_div=True, **kwargs):
        super().__init__()
        up_chans = in_chans // 2
        merged_chans = up_chans + skip_chans
        out_chans = skip_chans if final_div else merged_chans

        self.hook = hook
        self.shuf = PixelShuffleICNR(in_chans, up_chans, **kwargs)
        self.bn = nn.BatchNorm2d(skip_chans)
        self.relu = nn.ReLU()
        self.conv1 = ConvBnRelu(merged_chans, out_chans, 3, padding=1, **kwargs)
        self.conv2 = ConvBnRelu(out_chans, out_chans, 3, padding=1, **kwargs)

    def forward(self, x):
        skip = self.hook.stored
        upsampled = self.shuf(x)
        target_size = skip.shape[-2:]
        # The upsampled map may not line up exactly with the skip connection.
        if target_size != upsampled.shape[-2:]:
            upsampled = interpolate(upsampled, target_size, mode='nearest')
        merged = torch.cat([upsampled, self.bn(skip)], dim=1)
        return self.conv2(self.conv1(self.relu(merged)))
    
class LastCross(nn.Module):
    """Residual block made of two 3x3 `ConvBnRelu` layers.

    With `bottle=True` the intermediate width is `n_chans//2`, otherwise it
    stays at `n_chans`. The block output is the input plus the conv branch.
    """

    def __init__(self, n_chans, bottle=False):
        super().__init__()
        hidden_chans = n_chans // 2 if bottle else n_chans
        self.conv1 = ConvBnRelu(n_chans, hidden_chans, 3, padding=1)
        self.conv2 = ConvBnRelu(hidden_chans, n_chans, 3, padding=1)

    def forward(self, x):
        return x + self.conv2(self.conv1(x))

In [None]:
# export
class DynamicUnet(nn.Module):
    """U-Net whose encoder is a `timm` model and whose decoder is built to match it.

    Args:
        encoder_name: model name passed to `timm.create_model`.
        cut: the encoder keeps `list(model.children())[:cut]`, followed by a ReLU.
        n_classes: number of output channels of the final 1x1 convolution.
        input_shape: `(channels, height, width)` used to probe the encoder's
            activation sizes; `channels` is also the number of input channels
            concatenated before the head.
        pretrained: forwarded to `timm.create_model`.
    """

    def __init__(self, encoder_name, cut=-2, n_classes=2, input_shape=(3, 224, 224), pretrained=True):
        super().__init__()
        encoder = timm.create_model(encoder_name, pretrained=pretrained)
        self.encoder = nn.Sequential(*(list(encoder.children())[:cut]+[nn.ReLU()]))
        encoder_sizes, idxs = self.register_output_hooks(input_shape=input_shape)
        n_chans = encoder_sizes[-1][1]
        # These two 3x3 convolutions are unpadded, so they shrink the feature
        # map; each DecoderBlock resizes to its skip connection afterwards.
        middle_conv = nn.Sequential(ConvBnRelu(n_chans, n_chans//2, 3),
                                    ConvBnRelu(n_chans//2, n_chans, 3))
        decoder = [middle_conv]
        # `idxs` is ascending while the hooks were registered deepest-first,
        # hence the reversal (assumes `Hooks` iterates in registration order).
        for k, (idx, hook) in enumerate(zip(idxs[::-1], self.hooks)):
            skip_chans = encoder_sizes[idx][1]
            # Every block but the last reduces its output to `skip_chans`.
            final_div = (k != len(idxs)-1)
            decoder.append(DecoderBlock(n_chans, skip_chans, hook, final_div=final_div))
            n_chans = n_chans//2 + skip_chans
            n_chans = n_chans if not final_div else skip_chans
        self.decoder = nn.Sequential(*decoder, PixelShuffleICNR(n_chans, n_chans))
        # The raw input is concatenated to the decoder output before the head.
        n_chans += input_shape[0]
        self.head = nn.Sequential(LastCross(n_chans), nn.Conv2d(n_chans, n_classes, 1))

    def forward(self, x):
        y = self.encoder(x)
        y = self.decoder(y)
        if y.shape[-2:] != x.shape[-2:]:
            y = interpolate(y, x.shape[-2:], mode='nearest')
        y = torch.cat([x, y], dim=1)
        y = self.head(y)
        return y

    def register_output_hooks(self, input_shape=(3, 224, 224)):
        """Hook the encoder modules that feed the skip connections.

        Selects every index `k` for which the last dimension of `sizes[k]`
        differs from that of `sizes[k+1]`, skipping modules whose `name`
        contains 'downsample'. The hooks are stored in `self.hooks`,
        ordered from the largest index to the smallest.

        Returns:
            `(sizes, idxs)`: the sizes array returned by `get_sizes` and the
            ascending array of hooked indices. `idxs` only contains indices
            that actually got a hook, so it stays aligned with `self.hooks`.
        """
        # NOTE(review): `get_sizes` presumably returns the output shape of each
        # encoder module and the modules themselves, in matching order.
        sizes, modules = get_sizes(self.encoder, input_shape=input_shape)
        idxs = np.where(sizes[:-1, -1] != sizes[1:, -1])[0]
        # Filter the indices together with the modules: returning the
        # unfiltered indices would pair skip sizes with the wrong hooks.
        keep = np.array(['downsample' not in modules[k].name for k in idxs], dtype=bool)
        idxs = idxs[keep]

        def _hook(model, input, output):
            return output

        self.hooks = Hooks([modules[k] for k in idxs[::-1]], _hook)

        return sizes, idxs

    def __del__(self):
        # `hooks` is missing if construction failed before they were registered.
        if hasattr(self, "hooks"): self.hooks.remove()

In [None]:
# export
class CBR(nn.Module):
    """Small classifier made of stacked stride-2 `ConvBnRelu` layers.

    Layer `i` has `n_kernels * 2**i` output channels. The last feature map is
    globally average-pooled, flattened to `(batch, channels)` and fed to a
    linear layer producing `n_classes` scores.

    Args:
        kernel_size: kernel size of every convolution (padding is `kernel_size//2`).
        n_kernels: number of output channels of the first convolution.
        n_layers: number of `ConvBnRelu` layers.
        n_classes: size of the output vector.
        in_channels: number of channels of the input images.
    """

    def __init__(self, kernel_size, n_kernels, n_layers, n_classes=2, in_channels=3):
        super().__init__()
        in_c = in_channels
        out_c = n_kernels
        for k in range(n_layers):
            self.add_module(f'cbr{k}', ConvBnRelu(in_c, out_c, kernel_size, stride=2, padding=kernel_size//2, padding_mode='reflect'))
            in_c = out_c
            out_c *= 2
        self.gap = nn.AdaptiveAvgPool2d(1)
        # (B, C, 1, 1) -> (B, C): flattening only the last two dims would leave
        # a trailing singleton dimension in front of the linear layer.
        self.flat = nn.Flatten()
        # After the loop `in_c` is the width of the last feature map, whereas
        # `out_c` has already been doubled for a layer that does not exist.
        self.fc = nn.Linear(in_c, n_classes)

    def forward(self, x):
        # Children were registered in execution order: cbr0..cbrN, gap, flat, fc.
        for m in self.children():
            x = m(x)
        return x

In [None]:
# Build a 4-layer CBR with 7x7 kernels and 32 base filters, then show its layout.
mod = CBR(7, 32, 4)
mod

CBR(
  (cbr0): ConvBnRelu(
    (conv): Conv2d(3, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), padding_mode=reflect)
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (cbr1): ConvBnRelu(
    (conv): Conv2d(32, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), padding_mode=reflect)
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (cbr2): ConvBnRelu(
    (conv): Conv2d(64, 128, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), padding_mode=reflect)
    (bn): BatchNorm2d(128, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (cbr3): ConvBnRelu(
    (conv): Conv2d(128, 256, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), padding_mode=reflect)
    (bn): BatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (gap): AdaptiveAvgPool2d(output_size=1)
  (fl

In [None]:
# export
class SelfAttentionBlock(nn.Module):
    """Local self-attention over `k` x `k` windows with learned relative positions.

    For every output position, the query at that pixel is compared with the
    keys of the surrounding `k` x `k` window (the input is reflect-padded by
    `k//2`), a learned relative-position term is added to the keys, and the
    softmax of the resulting scores weights the values of the window. Scores
    are computed independently for each of the `groups` channel groups.

    Args:
        c_in: number of input channels (multiple of `groups`).
        c_out: number of output channels (even, multiple of `groups`).
        k: window size, must be odd.
        stride: step between two consecutive windows.
        groups: number of channel groups (also used by the 1x1 convolutions).
        bias: whether the 1x1 convolutions have a bias.

    Input `(b, c_in, h, w)` gives output `(b, c_out, h_out, w_out)`, with
    `h_out == h` and `w_out == w` when `stride == 1`.
    """

    def __init__(self, c_in, c_out, k, stride=1, groups=1, bias=False):
        super().__init__()
        assert c_in % groups == c_out % groups == 0, "c_in and c_out must be divided by groups"
        assert k % 2 == 1, "k must be odd"
        assert c_out % 2 == 0, "c_out must be even"

        padding = k // 2
        self.c_in = c_in
        self.c_out = c_out
        self.k = k
        self.stride = stride
        self.groups = groups

        # Keys and values are 1x1 convolutions of the padded input, so that
        # every pixel has a complete window; queries use the unpadded input.
        self.key_conv = nn.Conv2d(c_in, c_out, 1, padding=padding, groups=groups, bias=bias, padding_mode='reflect')
        self.query_conv = nn.Conv2d(c_in, c_out, 1, groups=groups, bias=bias)
        self.value_conv = nn.Conv2d(c_in, c_out, 1, padding=padding, groups=groups, bias=bias, padding_mode='reflect')

        # Relative position embeddings: the first half of the channels encodes
        # the row offset inside the window, the second half the column offset.
        self.r_ai = nn.Parameter(torch.randn(1, c_out//2, k, 1))
        self.r_aj = nn.Parameter(torch.randn(1, c_out//2, 1, k))

    def forward(self, x):
        b = x.shape[0]
        n = self.c_out // self.groups
        step = self.stride

        # (b, c_out, h_out, w_out, k, k): one k x k window per output position.
        keys = self.key_conv(x).unfold(2, self.k, step).unfold(3, self.k, step)
        values = self.value_conv(x).unfold(2, self.k, step).unfold(3, self.k, step)
        h_out, w_out = keys.shape[2:4]

        # The unfolded tensors are overlapping strided views, which `view`
        # cannot flatten; `reshape` copies when it has to.
        keys = keys.reshape(b, self.groups, n, h_out, w_out, -1)
        values = values.reshape(b, self.groups, n, h_out, w_out, -1)
        # Window (i, j) is centred on input pixel (i*stride, j*stride).
        queries = self.query_conv(x)[:, :, ::step, ::step]
        queries = queries.reshape(b, self.groups, n, h_out, w_out, 1)

        # (1, groups, n, 1, 1, k*k), broadcast over batch and spatial positions.
        rel = torch.cat((self.r_ai.expand(-1, -1, -1, self.k),
                         self.r_aj.expand(-1, -1, self.k, -1)), dim=1)
        rel = rel.reshape(1, self.groups, n, 1, 1, -1)

        # Dot product over the channels of each group, softmax over the window.
        scores = (queries * (keys + rel)).sum(2, keepdim=True)
        attention = torch.softmax(scores, dim=-1)
        y = (attention * values).sum(-1).reshape(b, self.c_out, h_out, w_out)

        return y
        

In [None]:
#hide
# Run nbdev's notebook-to-script conversion; the cell output lists every
# notebook that was converted.
from nbdev.export import notebook2script
notebook2script()

Converted 00_core.ipynb.
Converted 01_train.ipynb.
Converted 02_predict.ipynb.
Converted 10_data.read.ipynb.
Converted 11_data.loaders.ipynb.
Converted 12_data.dataset.ipynb.
Converted 13_data.utils.ipynb.
Converted 14_data.transforms.ipynb.
Converted 20_models.plmodules.ipynb.
Converted 21_models.modules.ipynb.
Converted 22_models.utils.ipynb.
Converted 23_models.hooks.ipynb.
Converted 24_models.metrics.ipynb.
Converted 25_models.losses.ipynb.
Converted 80_params.defaults.ipynb.
Converted 81_params.parser.ipynb.
Converted 99_index.ipynb.
