In [1]:
%load_ext autoreload
%autoreload 2
import os
import sys
import warnings

import numpy as np
import torch
import torch.nn as nn

sys.path.insert(0, "../../..")
from batchflow import *
from batchflow.models.torch import *
from batchflow.models.torch.layers import *

# Simple layers: plain PyTorch vs. BatchFlow

In [2]:
# Seed the global torch RNG so the sample batch below is identical on every
# Restart & Run All (42 is an arbitrary fixed value).
torch.manual_seed(42)
# Random sample batch used by every layer/block demo below: shape (4, 3, 128, 128).
inputs = torch.rand(4, 3, 128, 128)

In [3]:
# Reference layer in plain PyTorch: the number of input channels (3) must be given explicitly.
layer_torch = nn.Conv2d(
    in_channels=3,
    out_channels=17,
    kernel_size=3,
    stride=1,
    padding=1,
)
layer_torch(inputs).shape

torch.Size([4, 17, 128, 128])

In [4]:
# Display the module repr of the plain-PyTorch reference layer built above.
layer_torch

Conv2d(3, 17, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

In [None]:
class Conv(nn.Module):
    """Minimal sketch of a 2D convolution wrapper that infers ``in_channels`` from a sample tensor.

    Unlike ``nn.Conv2d``, the caller passes an example tensor instead of the number
    of input channels. ``inputs.shape[1]`` is used as ``in_channels``, and every
    other keyword is forwarded to ``nn.Conv2d`` unchanged.

    NOTE(review): running this cell rebinds the name ``Conv``, which the batchflow
    star imports presumably also provide. This cell was never executed (``In [None]``),
    so the ``Conv(...)`` outputs elsewhere in the notebook come from the library layer.
    The ``Conv(..., channels=...)`` call in this notebook would fail against this
    class, because ``channels`` is not an ``nn.Conv2d`` argument. Rename this class,
    or skip this cell, if the library layer is wanted.

    Parameters
    ----------
    inputs : torch.Tensor
        Example input; its second dimension is used as ``in_channels``.
    **kwargs
        Forwarded to ``nn.Conv2d`` (e.g. ``out_channels``, ``kernel_size``,
        ``padding``, ``stride``).

    Raises
    ------
    ValueError
        If ``inputs`` is not given, since the channel count cannot be inferred.
    """

    def __init__(self, inputs=None, **kwargs):
        # nn.Module.__init__ must run before any submodule is assigned as an attribute;
        # without it, `self.layer = ...` raises an AttributeError.
        super().__init__()
        if inputs is None:
            raise ValueError("`inputs` is required to infer the number of input channels")
        self.layer = nn.Conv2d(in_channels=inputs.shape[1], **kwargs)

    def forward(self, x):
        """Apply the wrapped ``nn.Conv2d`` to ``x``."""
        return self.layer(x)

In [56]:
# Same convolution through the `Conv` wrapper: input channels come from `inputs`,
# and `channels` is given as a string expression over `same`.
layer_bf = Conv(
    inputs=inputs,
    channels='max(1, same // 8) * 8',
    kernel_size=3,
    padding='same',
    stride=1,
)
layer_bf(inputs).shape

torch.Size([4, 8, 128, 128])

In [6]:
# Display the module structure of `layer_bf` for comparison with `layer_torch`.
layer_bf

Conv(
  (layer): Conv2d(3, 17, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)

In [7]:
# Output shapes of a softmax activation (over the last axis) and a strided max-pool.
(
    Activation('softmax', dim=-1)(inputs).shape,
    MaxPool(inputs=inputs, pool_size=3, pool_stride=2)(inputs).shape,
)

(torch.Size([4, 3, 128, 128]), torch.Size([4, 3, 64, 64]))

# Multi-layer blocks (`MultiLayer`)

In [10]:
# Plain-PyTorch conv + batch-norm stack, assembled by hand.
block_torch = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=17, kernel_size=3, padding=1),
    nn.BatchNorm2d(num_features=17),
)
block_torch(inputs).shape

torch.Size([4, 17, 128, 128])

In [34]:
# Layout-string driven block: a residual branch ("R" ... "+") around conv/norm/act/conv/norm.
# `kernel_size` supplies one entry per "c" letter; `branch` configures the skip path.
block_bf = MultiLayer(
    inputs=inputs,
    layout='Rcnacn+',
    channels=17,
    kernel_size=[[3, 1], 5],
    branch=dict(layout='cn', channels=17),
)
block_bf(inputs).shape

torch.Size([4, 17, 128, 128])

In [43]:
# Per-layer summary of `block_bf`, including parameter counts.
# NOTE(review): the meaning of the positional `4` is not visible here (presumably an
# indentation/verbosity level) — TODO confirm against `MultiLayer.repr`.
block_bf.repr(4, show_num_parameters=True)

MultiLayer(
    layout=Rcnacn+
    (Layer 0,    skip "R":  (?,   3, 128, 128)  ⟶ (?,  17, 128, 128), #params=493)
    (Layer 1,  letter "c":  (?,   3, 128, 128)  ⟶ (?,  17, 128, 128), #params=153)
    (Layer 2,  letter "n":  (?,  17, 128, 128)  ⟶ (?,  17, 128, 128), #params=34)
    (Layer 3,  letter "a":  (?,  17, 128, 128)  ⟶ (?,  17, 128, 128), #params=0)
    (Layer 4,  letter "c":  (?,  17, 128, 128)  ⟶ (?,  17, 128, 128), #params=7,225)
    (Layer 5,  letter "n":  (?,  17, 128, 128)  ⟶ (?,  17, 128, 128), #params=34)
    (Layer 6, combine "+": [(?,  17, 128, 128),
                            (?,  17, 128, 128)] ⟶ (?,  17, 128, 128), #params=0)
)


# Named blocks

In [60]:
# Named block: a ResBlock whose inner layout is conv/norm/act, keeping the input channel count.
block = ResBlock(
    inputs=inputs,
    layout='cna',
    channels='same',
)

In [45]:
# Display the ResBlock built in the previous cell. The earlier `Block()` constructed a
# fresh object with no `inputs` instead of showing `block`.
block

ResBlock(
  (repeat0-args0): MultiLayer(
    layout=Rcnacn+a
    (Layer 0,    skip "R"): Branch(
      (layer): Identity()
    )
    (Layer 1,  letter "c"): Conv(
      (layer): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (Layer 2,  letter "n"): BatchNorm(
      (layer): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (Layer 3,  letter "a"): Activation(
      (activation): ReLU(inplace=True)
    )
    (Layer 4,  letter "c"): Conv(
      (layer): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (Layer 5,  letter "n"): BatchNorm(
      (layer): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (Layer 6, combine "+"): Combine(
      op=sum, leading_idx=0, force_resize=True,
        input_shapes=[(4, 3, 128, 128), (4, 3, 128, 128)],
      resized_shapes=[(4, 3, 128, 128), (4, 3, 128, 128)],
       output_shapes=(4, 3, 128, 128)
    )