In [None]:
# default_exp models.phy

# PhyDNet
> ConvLSTM + PhyCell
https://github.com/vincent-leguen/PhyDNet/blob/master/models/models.py

In [None]:
#export
from fastai.vision.all import *

In [None]:
# When CUDA is available, make GPU 0 the current device and print its name.
if torch.cuda.is_available():
    torch.cuda.set_device(0)
    print(torch.cuda.get_device_name())

Quadro RTX 8000


## The PhyCell

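We will refactor the original implementation so that the hidden state is no longer stored as a class attribute; instead it is passed to, and returned from, `forward`. We can also make use of some fastai magic, such as `one_param` (to make sure new tensors are created on the same device as the model parameters) and `store_attr()` (to save the constructor arguments as attributes).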
![phycell](images/phycell.png)

In [None]:
#export
class PhyCell_Cell(Module):
    """One PhyCell step.

    The hidden state is first advanced by the convolutional branch `f`
    (prediction) and then pulled towards the input `x` by a sigmoid gate
    computed from both (correction).
    """
    def __init__(self, ch_in, hidden_dim, ks=3, bias=True):
        # Saves ch_in, hidden_dim, ks and bias as attributes of the same name.
        store_attr()
        # Prediction branch applied to the hidden state:
        # BatchNorm -> ks x ks conv (ch_in -> hidden_dim) -> 1x1 conv (hidden_dim -> ch_in).
        # `ks // 2` padding keeps the spatial size for odd kernel sizes.
        predictor_layers = [
            nn.BatchNorm2d(ch_in),
            nn.Conv2d(ch_in, hidden_dim, ks, padding=ks // 2),
            nn.Conv2d(hidden_dim, ch_in, kernel_size=(1, 1)),
        ]
        self.f = nn.Sequential(*predictor_layers)

        # Gate over the channel-wise concatenation of input and predicted state.
        # Its kernel is fixed at 3x3 regardless of `ks`; `bias` only affects this conv.
        self.convgate = nn.Conv2d(2 * ch_in, ch_in,
                                  kernel_size=(3, 3), padding=(1, 1), bias=bias)

    def forward(self, x, hidden=None):
        "x ~ [batch_size, ch_in, height, width]; `hidden` has the same shape (zeros if omitted)"
        if hidden is None:
            hidden = self.init_hidden(x)
        # Prediction: residual update of the hidden state.
        hidden_tilde = hidden + self.f(hidden)
        # Gate values in (0, 1), one per channel and pixel.
        gate = torch.sigmoid(self.convgate(torch.cat([x, hidden_tilde], dim=1)))
        # Correction: move the predicted state towards the input by `gate`.
        return hidden_tilde + gate * (x - hidden_tilde)

    def init_hidden(self, x):
        "All-zero hidden state shaped like the 4-D `x`, with the dtype/device of the module's parameters"
        n, c, height, width = x.shape
        return one_param(self).new_zeros(n, c, height, width)

In [None]:
# Build a single cell (16 input channels, 32 hidden channels in `f`, 3x3 kernel).
# Move it to the GPU only when one is available so this cell also runs on
# CPU-only machines.
p_cell = PhyCell_Cell(16, 32, 3)
if torch.cuda.is_available(): p_cell = p_cell.cuda()
p_cell

PhyCell_Cell(
  (f): Sequential(
    (0): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
  )
  (convgate): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

In [None]:
# Random batch [bs=64, ch=16, 12, 12], moved to whatever device the cell's
# parameters live on (instead of assuming CUDA).
out = p_cell(torch.rand(64,16,12,12).to(one_param(p_cell).device))
out.shape

torch.Size([64, 16, 12, 12])

In [None]:
# Smoke test: MSE between the cell output and an all-zero target.
mse_loss = MSELossFlat()
loss = mse_loss(out, torch.zeros_like(out))
loss

tensor(0.0899, device='cuda:0', grad_fn=<MseLossBackward>)

In [None]:
# Backpropagate to check that autograd runs through the cell.
loss.backward()

In [None]:
# Display the loss again after the backward pass.
loss

tensor(0.0899, device='cuda:0', grad_fn=<MseLossBackward>)

In [None]:
#export
class PhyCell(Module):
    """Stack of `PhyCell_Cell` layers unrolled over a sequence.

    `hidden_dims[i]` is the hidden width of layer `i`'s prediction branch;
    every layer takes and returns `ch_in` channels, so layers can be chained.
    """
    def __init__(self, ch_in, hidden_dims, ks, n_layers):
        # Saves ch_in, hidden_dims, ks and n_layers as attributes of the same name.
        store_attr()
        self.cell_list = nn.ModuleList()
        for i in range(self.n_layers):
            self.cell_list.append(PhyCell_Cell(ch_in=ch_in,
                                               hidden_dim=hidden_dims[i],
                                               ks=ks))

    def forward(self, x, hidden=None):
        """Run the stacked cells over every time step of `x`.

        x ~ [batch_size, seq_len, channels, width, height]
        hidden: optional iterable with one state per layer, each
            [batch_size, channels, width, height] (a list, or the stacked
            `last_states` tensor returned by a previous call). Zeros if None.

        Returns `(layer_output, last_states)`:
            layer_output ~ [batch_size, seq_len, channels, width, height],
                the top layer's hidden state at every time step;
            last_states ~ [n_layers, batch_size, channels, width, height],
                each layer's hidden state after the final time step.
        """
        # Check the rank first so the channel check below indexes the right axis.
        assert x.dim() == 5, "input shape must be [bs, seq_len, ch, w, h]"
        assert x.shape[2] == self.ch_in, "Input tensor has different channels dim than Cell"
        if hidden is None: hidden = self.init_hidden(x)
        cur_layer_input = torch.unbind(x, dim=1)
        last_state_list = []

        for cell, h in zip(self.cell_list, hidden):
            output_inner = []
            for inp in cur_layer_input:
                # Carry the state forward: each step must start from the state
                # produced by the previous step, not from the initial one.
                h = cell(inp, h)
                output_inner.append(h)
            # The per-step states of this layer are the inputs of the next layer.
            cur_layer_input = output_inner
            # `h` is now this layer's state after the last time step.
            last_state_list.append(h)

        layer_output = torch.stack(output_inner, dim=1)
        last_states = torch.stack(last_state_list, dim=0)
        return layer_output, last_states

    def init_hidden(self, x):
        "List with one all-zero hidden state per layer, shaped like a single time step of `x`"
        assert len(x.shape)==5, "input shape must be [bs, seq_len, ch, w, h]"
        first_step = x[:,0, ...]
        return [cell.init_hidden(first_step) for cell in self.cell_list]

In [None]:
# Single-layer PhyCell; keyword arguments make the meaning of each number explicit.
phy = PhyCell(ch_in=16, hidden_dims=[49], ks=7, n_layers=1)

In [None]:
# Random input laid out as [bs=8, seq_len=10, ch=16, 6, 6]; no initial hidden state is passed.
out, states = phy(torch.rand(8,10,16,6,6))

In [None]:
# Shapes of the output sequence and of the stacked per-layer states.
out.shape, states.shape

(torch.Size([8, 10, 16, 6, 6]), torch.Size([1, 8, 16, 6, 6]))

In [None]:
# Shape of the last layer's state.
states[-1].shape

torch.Size([8, 16, 6, 6])

In [None]:
# Smoke test: MSE between the PhyCell output sequence and an all-zero target.
mse_loss = MSELossFlat()
loss = mse_loss(out, torch.zeros_like(out))
loss

tensor(0.0910, grad_fn=<MseLossBackward>)

In [None]:
# Backpropagate to check that autograd runs through the unrolled sequence.
loss.backward()

In [None]:
# Display the loss again after the backward pass.
loss

tensor(0.0910, grad_fn=<MseLossBackward>)

# Export -

In [None]:
# hide
# nbdev export step: runs the notebook-to-script conversion over the project's notebooks.
from nbdev.export import *
notebook2script()

Converted 00_data.ipynb.
Converted 01_models.conv_rnn.ipynb.
Converted 02_models.dcn.ipynb.
Converted 02_models.transformer.ipynb.
Converted 02_tcn.ipynb.
Converted 03_phy.ipynb.
Converted 04_seq2seq.ipynb.
Converted index.ipynb.
