In [None]:
# default_exp models.conv_rnn

# Recurrent Convolutional Kernels
> ConvLSTM and ConvGRU cells and models

In [None]:
#export
from fastai2.vision.all import *

In [None]:
# Use the second GPU when the machine has one; otherwise keep the default device
# instead of raising on the hard-coded index.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    torch.cuda.set_device(1)
torch.cuda.get_device_name() if torch.cuda.is_available() else 'cpu'

'GeForce RTX 2070 SUPER'

In [None]:
#export
class ConvGRU_cell(Module):
    "Convolutional GRU unrolled over the time axis of a `(bs, seq_len, ch, w, h)` tensor."
    def __init__(self, in_ch, out_ch, ks=3, debug=False):
        """
        in_ch:  channels of every input frame.
        out_ch: channels of the hidden state (and of every output frame).
        ks:     kernel size, shared by the input-to-state and state-to-state convolutions.
        debug:  if True, `forward` prints the shape of a freshly created hidden state.
        """
        self.in_ch = in_ch
        # kernel_size of input_to_state equals state_to_state
        self.ks = ks
        self.out_ch = out_ch
        self.debug = debug
        # with stride 1 and an odd `ks` this padding keeps the spatial size unchanged
        self.padding = (ks - 1) // 2
        # conv1 computes the update (z) and reset (r) gates in a single convolution
        self.conv1 = nn.Sequential(
            nn.Conv2d(self.in_ch + self.out_ch,
                      2 * self.out_ch, self.ks, 1,
                      self.padding),
            nn.GroupNorm(self._n_groups(2 * self.out_ch), 2 * self.out_ch))
        # conv2 computes the candidate hidden state
        self.conv2 = nn.Sequential(
            nn.Conv2d(self.in_ch + self.out_ch,
                      self.out_ch, self.ks, 1, self.padding),
            nn.GroupNorm(self._n_groups(self.out_ch), self.out_ch))

    @staticmethod
    def _n_groups(n_ch, ch_per_group=32):
        """Number of GroupNorm groups for `n_ch` channels.

        Starts from `n_ch // ch_per_group` (the original choice) and falls back to the
        closest smaller value that divides `n_ch`, never going below 1. This avoids the
        zero / non-dividing group counts that make `nn.GroupNorm` fail for small or
        irregular channel sizes.
        """
        n = max(1, n_ch // ch_per_group)
        while n_ch % n: n -= 1
        return n

    def forward(self, inputs, hidden_state=None):
        """inputs shape: (bs, seq_len, ch, w, h)

        hidden_state: optional `(bs, out_ch, w, h)` tensor; a zero state is created when None.
        Returns `(outputs, last_state)`: the hidden states of all steps stacked on dim 1,
        and the hidden state after the final step.
        """
        bs, seq_len, ch, w, h = inputs.shape
        if hidden_state is None:
            htprev = self.initHidden(bs, self.out_ch, w, h)
            if self.debug: print(f'htprev: {htprev.shape}')
        else:
            htprev = hidden_state
        if seq_len == 0:
            # nothing to unroll: empty output sequence, hidden state left untouched
            return htprev.new_zeros((bs, 0, self.out_ch, w, h)), htprev
        output_inner = []
        for index in range(seq_len):
            x = inputs[:, index, ...]
            combined_1 = torch.cat((x, htprev), 1)  # X_t + H_t-1
            gates = self.conv1(combined_1)  # W * (X_t + H_t-1)
            zgate, rgate = torch.split(gates, self.out_ch, dim=1)
            z = torch.sigmoid(zgate)
            r = torch.sigmoid(rgate)
            # the reset gate scales the previous state before the candidate is computed
            combined_2 = torch.cat((x, r * htprev), 1)
            ht = torch.tanh(self.conv2(combined_2))
            # the update gate blends the previous state with the candidate
            htnext = (1 - z) * htprev + z * ht
            output_inner.append(htnext)
            htprev = htnext
        return torch.stack(output_inner, dim=1), htprev
    def __repr__(self): return f'ConvGRU_cell(in={self.in_ch}, out={self.out_ch}, ks={self.ks})'
    def initHidden(self, bs, ch, w, h):
        "Zero hidden state of shape `(bs, ch, w, h)`, created from one of the module's parameters."
        return one_param(self).new_zeros(bs, ch, w, h)

Let's check:

In [None]:
cell = ConvGRU_cell(32, 32, debug=True).cuda()
cell

ConvGRU_cell(in=32, out=32, ks=3)

In [None]:
x = torch.rand(2, 7, 32, 64, 64).cuda()
out, h = cell(x)

htprev: torch.Size([2, 32, 64, 64])


In [None]:
out.shape

torch.Size([2, 7, 32, 64, 64])

Checking sizes:

In [None]:
test_eq(out.shape, x.shape) 
test_eq(h.shape, [2,32,64,64])

Should be possible to call with hidden state:

In [None]:
out2, h2 = cell(out, h)
test_eq(h2.shape, [2, 32, 64, 64])

A rather hacky module that applies a 2D layer to every image of a sequence, inspired by Keras

In [None]:
#export
class TimeDistributed(Module):
    "Applies a module over tdim identically for each step"
    def __init__(self, module, low_mem=False, tdim=1):
        """
        module:  layer applied to every time step.
        low_mem: if True, run the steps one at a time instead of folding time into the batch.
        tdim:    index of the time dimension; any value other than 1 always uses the
                 step-by-step path.
        """
        self.module = module
        self.low_mem = low_mem
        self.tdim = tdim

    def forward(self, x):
        "input x with shape:(bs,steps,channels,width,height)"
        if self.low_mem or self.tdim!=1:
            return self.low_mem_forward(x)
        bs, seq_len = x.shape[:2]
        # Fold time into the batch so the wrapped module runs once.
        # `reshape` (not `view`) so non-contiguous inputs, e.g. after a transpose, also work.
        out = self.module(x.reshape(bs*seq_len, *x.shape[2:]))
        # Unfold the batch back into (bs, steps, ...).
        return out.reshape(bs, seq_len, *out.shape[1:])

    def low_mem_forward(self, x):
        "input x with shape:(bs,steps,channels,width,height); applies the module one step at a time"
        steps = torch.split(x, 1, dim=self.tdim)
        out = [self.module(step.squeeze(dim=self.tdim)) for step in steps]
        return torch.stack(out, dim=self.tdim)

## Encoder

In [None]:
#export
class Encoder(Module):
    "Stack of time-distributed `ConvLayer`s, each followed by a `ConvGRU_cell`."
    def __init__(self, n_in=1, szs=[16,64,96,96], ks=3, rnn_ks=5, act=nn.ReLU, norm=None, debug=False):
        """
        n_in:   channels of every input frame.
        szs:    channel sizes; `szs[0]` is the width of the first conv and `szs[1:]` are the
                hidden sizes of the successive ConvGRU stages.
        ks:     kernel size of the `ConvLayer`s.
        rnn_ks: kernel size of the ConvGRU cells.
        act:    activation class handed to `ConvLayer`.
        norm:   `norm_type` handed to `ConvLayer`.
        debug:  print channel sizes at construction and tensor shapes during `forward`.
        """
        self.n_blocks = len(szs)-1
        self.debug = debug
        convs = []
        rnns = []
        # First stage: stride-1 conv that must accept `n_in` channels
        # (the input width used to be hard-coded to 1, silently ignoring `n_in`).
        convs.append(ConvLayer(n_in, szs[0], ks=ks, padding=ks//2, act_cls=act, norm_type=norm))
        rnns.append(ConvGRU_cell(szs[0], szs[1], ks=rnn_ks))
        # Remaining stages: stride-2 conv followed by a ConvGRU that changes the width ni -> nf
        for ni, nf in zip(szs[1:-1], szs[2:]):
            if self.debug: print(ni, nf)
            convs.append(ConvLayer(ni, ni, ks=ks, stride=2, padding=ks//2, act_cls=act, norm_type=norm))
            rnns.append(ConvGRU_cell(ni, nf, ks=rnn_ks))
        self.convs = nn.ModuleList(TimeDistributed(conv) for conv in convs)
        self.rnns = nn.ModuleList(rnns)

    def forward_by_stage(self, inputs, conv, rnn):
        "Run one stage (`conv` then `rnn` from a fresh hidden state); returns `(outputs, last_state)`."
        if self.debug:
            print(f' Layer: {rnn}')
            print(' inputs: ', inputs.shape)
        inputs = conv(inputs)
        if self.debug: print(' after_convs: ', inputs.shape)
        outputs_stage, state_stage = rnn(inputs, None)
        if self.debug: print(' output_stage: ', outputs_stage.shape)
        return outputs_stage, state_stage

    def forward(self, inputs):
        """inputs.shape bs,seq_len,n_in,64,64

        Returns the output sequence of the last stage and the list of final hidden
        states, one per stage, in stage order.
        """
        hidden_states = []
        for i, (conv, rnn) in enumerate(zip(self.convs, self.rnns)):
            if self.debug: print('stage: ',i)
            inputs, state_stage = self.forward_by_stage(inputs, conv, rnn)
            hidden_states.append(state_stage)
        return inputs, hidden_states

In [None]:
enc = Encoder(debug=True)
enc

64 96
96 96


Encoder(
  (convs): ModuleList(
    (0): TimeDistributed(
      (module): ConvLayer(
        (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
      )
    )
    (1): TimeDistributed(
      (module): ConvLayer(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): ReLU()
      )
    )
    (2): TimeDistributed(
      (module): ConvLayer(
        (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): ReLU()
      )
    )
  )
  (rnns): ModuleList(
    (0): ConvGRU_cell(in=16, out=64, ks=5)
    (1): ConvGRU_cell(in=64, out=96, ks=5)
    (2): ConvGRU_cell(in=96, out=96, ks=5)
  )
)

In [None]:
imgs = torch.rand(2, 10, 1, 64, 64)

In [None]:
enc_out, h = enc(imgs)

stage:  0
 Layer: ConvGRU_cell(in=16, out=64, ks=5)
 inputs:  torch.Size([2, 10, 1, 64, 64])
 after_convs:  torch.Size([2, 10, 16, 64, 64])
 output_stage:  torch.Size([2, 10, 64, 64, 64])
stage:  1
 Layer: ConvGRU_cell(in=64, out=96, ks=5)
 inputs:  torch.Size([2, 10, 64, 64, 64])
 after_convs:  torch.Size([2, 10, 64, 32, 32])
 output_stage:  torch.Size([2, 10, 96, 32, 32])
stage:  2
 Layer: ConvGRU_cell(in=96, out=96, ks=5)
 inputs:  torch.Size([2, 10, 96, 32, 32])
 after_convs:  torch.Size([2, 10, 96, 16, 16])
 output_stage:  torch.Size([2, 10, 96, 16, 16])


In [None]:
enc_out.shape

torch.Size([2, 10, 96, 16, 16])

## Decoder

In [None]:
#export 
class UpsampleBlock(Module):
    "Upsampling block in the spirit of a UNet block: `PixelShuffle_ICNR` followed by two `ConvLayer`s."
    @delegates(ConvLayer.__init__)
    def __init__(self, in_ch, out_ch, final_div=True, blur=False, act_cls=defaults.activation,
                 self_attention=False, init=nn.init.kaiming_normal_, norm_type=None, **kwargs):
        # NOTE: `final_div` is accepted but not used anywhere in this block.
        store_attr(self, 'in_ch,out_ch,blur,act_cls,self_attention,norm_type')
        half_ch = in_ch//2
        # options shared by both conv layers; extra `kwargs` are forwarded to `ConvLayer`
        layer_kw = dict(act_cls=act_cls, norm_type=norm_type, **kwargs)
        # the shuffle layer maps in_ch -> in_ch//2 channels
        self.shuf = PixelShuffle_ICNR(in_ch, half_ch, blur=blur, act_cls=act_cls, norm_type=norm_type)
        self.conv1 = ConvLayer(half_ch, out_ch, **layer_kw)
        # optional self-attention is attached to the second conv only
        attn_layer = SelfAttention(out_ch) if self_attention else None
        self.conv2 = ConvLayer(out_ch, out_ch, xtra=attn_layer, **layer_kw)
        # kept as an attribute; `forward` does not call it
        self.relu = act_cls()
        apply_init(nn.Sequential(self.conv1, self.conv2), init)
    def __repr__(self):
        head = f'UpsampleBLock(in={self.in_ch}, out={self.out_ch}, blur={self.blur}, '
        tail = f'act={self.act_cls()}, attn={self.self_attention}, norm={self.norm_type})'
        return head + tail
    def forward(self, up_in):
        "Shuffle-upsample `up_in`, then apply the two conv layers in sequence."
        shuffled = self.shuf(up_in)
        mid = self.conv1(shuffled)
        return self.conv2(mid)

In [None]:
us = UpsampleBlock(32, 16)
us

UpsampleBLock(in=32, out=16, blur=False, act=ReLU(), attn=False, norm=None)

In [None]:
us(torch.rand(8, 32, 32, 32)).shape

torch.Size([8, 16, 64, 64])

In [None]:
#export
class Decoder(Module):
    "Mirror of `Encoder`: each stage is a `ConvGRU_cell` followed by a time-distributed (de)conv layer."
    def __init__(self, n_out=1, szs=[16,64,96,96], ks=3, rnn_ks=5, act=nn.ReLU, blur=False, attn=False,
                 norm=None, debug=False):
        self.n_blocks = len(szs)-1
        self.debug = debug
        # walk the channel sizes from the deepest stage back to the shallowest
        rev = szs[::-1]
        rnn_cells = [ConvGRU_cell(rev[0], rev[0], ks=rnn_ks)]
        up_layers = []
        for idx in range(len(rev)-2):
            ch_in, ch_out = rev[idx], rev[idx+1]
            # the upsampling block keeps the width; the following GRU changes it ch_in -> ch_out
            up_layers.append(UpsampleBlock(ch_in, ch_in, blur=blur, self_attention=attn, act_cls=act, norm_type=norm))
            rnn_cells.append(ConvGRU_cell(ch_in, ch_out, ks=rnn_ks))

        # last stage: plain stride-1 conv, followed by a 1x1 conv head producing `n_out` channels
        up_layers.append(ConvLayer(rev[-2], rev[-1], ks, padding=ks//2, act_cls=act, norm_type=norm))
        self.head = TimeDistributed(nn.Conv2d(rev[-1], n_out, kernel_size=1))
        self.deconvs = nn.ModuleList([TimeDistributed(layer) for layer in up_layers])
        self.rnns = nn.ModuleList(rnn_cells)

    def forward_by_stage(self, inputs, state, deconv, rnn):
        "Run one stage: `rnn` seeded with `state`, then `deconv`; returns `(outputs, last_state)`."
        verbose = self.debug
        if verbose:
            print(f' Layer: {rnn}')
            print(' inputs:, state: ', inputs.shape, state.shape)
        rnn_out, state_stage = rnn(inputs, state)
        if verbose:
            print(' after rnn: ', rnn_out.shape)
            print(f' Layer: {deconv}')
        outputs_stage = deconv(rnn_out)
        if verbose: print(' after_deconvs: ', outputs_stage.shape)
        return outputs_stage, state_stage

    def forward(self, inputs, hidden_states):
        "Decode `inputs`, seeding every stage with the matching state from `hidden_states` taken in reverse order."
        out = inputs
        stages = zip(hidden_states[::-1], self.deconvs, self.rnns)
        for stage_idx, (state, deconv, rnn) in enumerate(stages):
            if self.debug: print('stage: ',stage_idx)
            # the per-stage final state is not needed after the stage has run
            out, _ = self.forward_by_stage(out, state, deconv, rnn)
        return self.head(out)

In [None]:
dec = Decoder(debug=True)
dec

Decoder(
  (head): TimeDistributed(
    (module): Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1))
  )
  (deconvs): ModuleList(
    (0): TimeDistributed(
      (module): UpsampleBLock(in=96, out=96, blur=False, act=ReLU(), attn=False, norm=None)
    )
    (1): TimeDistributed(
      (module): UpsampleBLock(in=96, out=96, blur=False, act=ReLU(), attn=False, norm=None)
    )
    (2): TimeDistributed(
      (module): ConvLayer(
        (0): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
      )
    )
  )
  (rnns): ModuleList(
    (0): ConvGRU_cell(in=96, out=96, ks=5)
    (1): ConvGRU_cell(in=96, out=96, ks=5)
    (2): ConvGRU_cell(in=96, out=64, ks=5)
  )
)

In [None]:
[_.shape for _ in h]

[torch.Size([2, 64, 64, 64]),
 torch.Size([2, 96, 32, 32]),
 torch.Size([2, 96, 16, 16])]

In [None]:
enc_out.shape

torch.Size([2, 10, 96, 16, 16])

In [None]:
# Decode the encoder outputs computed in the earlier cells: the result must have the input's shape.
test_eq(dec(enc_out, h).shape, imgs.shape)
print('---\n')
# Same check for the full encode -> decode round trip in a single call.
test_eq(dec(*enc(imgs)).shape, imgs.shape)

stage:  0
 Layer: ConvGRU_cell(in=96, out=96, ks=5)
 inputs:, state:  torch.Size([2, 10, 96, 16, 16]) torch.Size([2, 96, 16, 16])
 after rnn:  torch.Size([2, 10, 96, 16, 16])
 Layer: TimeDistributed(
  (module): UpsampleBLock(in=96, out=96, blur=False, act=ReLU(), attn=False, norm=None)
)
 after_deconvs:  torch.Size([2, 10, 96, 32, 32])
stage:  1
 Layer: ConvGRU_cell(in=96, out=96, ks=5)
 inputs:, state:  torch.Size([2, 10, 96, 32, 32]) torch.Size([2, 96, 32, 32])
 after rnn:  torch.Size([2, 10, 96, 32, 32])
 Layer: TimeDistributed(
  (module): UpsampleBLock(in=96, out=96, blur=False, act=ReLU(), attn=False, norm=None)
)
 after_deconvs:  torch.Size([2, 10, 96, 64, 64])
stage:  2
 Layer: ConvGRU_cell(in=96, out=64, ks=5)
 inputs:, state:  torch.Size([2, 10, 96, 64, 64]) torch.Size([2, 64, 64, 64])
 after rnn:  torch.Size([2, 10, 64, 64, 64])
 Layer: TimeDistributed(
  (module): ConvLayer(
    (0): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
  )
)
 a

## Model

In [None]:
#export
class SimpleModel(Module):
    "Encoder/decoder ConvGRU model: the decoder is seeded with the encoder's hidden states."
    def __init__(self, n_in=1, n_out=1, szs=[16,64,96,96], ks=3, rnn_ks=5, act=nn.ReLU, blur=False, attn=False,
                 norm=None, strategy='zero', debug=False):
        """
        n_in, n_out: channels of the input and output frames.
        szs, ks, rnn_ks, act, norm, debug: forwarded to both `Encoder` and `Decoder`.
        blur, attn: forwarded to `Decoder`.
        strategy: what is fed to the decoder as its input sequence:
                  'zero'    -> a zero tensor shaped like the encoder output,
                  'encoder' -> the detached encoder output.
        """
        self.strategy = strategy
        self.encoder = Encoder(n_in, szs, ks, rnn_ks, act, norm, debug)
        self.decoder = Decoder(n_out, szs, ks, rnn_ks, act, blur, attn, norm, debug)
    def forward(self, x):
        "`x` is a `(bs, seq_len, ch, w, h)` tensor, or a list/tuple of `(bs, ch, w, h)` frames."
        # Compare by value: identity (`is`) on string literals depends on interning.
        if self.strategy not in ('zero', 'encoder'):
            raise ValueError(f"Unknown strategy {self.strategy!r}, expected 'zero' or 'encoder'")
        if isinstance(x, (list, tuple)):
            # a sequence of frames is stacked along a new time dimension
            x = torch.stack(x, dim=1)
        enc_out, h = self.encoder(x)
        if self.strategy == 'zero':
            dec_in = one_param(self).new_zeros(*enc_out.shape)
        else:
            dec_in = enc_out.detach()
        return self.decoder(dec_in, h)

In [None]:
m = SimpleModel(strategy='zero')

In [None]:
imgs_list = [torch.rand(2,1,64,64) for _ in range(10)]

In [None]:
test_eq(m(imgs_list).shape, imgs.shape)

In [None]:
test_eq(m(imgs).shape, imgs.shape)

# Export -

In [None]:
# hide
from nbdev.export import *
notebook2script()

Converted 00_data.ipynb.
Converted 01_models.conv_rnn.ipynb.
Converted index.ipynb.
