In [None]:
# default_exp models.phy_original

# PhyDNet original implementation
> ConvLSTM + PhyCell
https://github.com/vincent-leguen/PhyDNet/blob/master/models/models.py

In [None]:
#export
from fastai.vision.all import *

In [None]:
# Select the first GPU when one is available and print its name.
if torch.cuda.is_available():
    torch.cuda.set_device(0)
    print(torch.cuda.get_device_name())

Quadro RTX 8000


## The PhyCell

<!-- TODO: refactor so that the hidden state is not stored as a class attribute. We can also use some fastai helpers, such as `one_param` (to make sure new tensors are created on the same device as the model parameters) and `store_attr()` (to save the constructor arguments as attributes). -->
![phycell](images/phycell.png)

In [None]:
#export
class PhyCell_Cell(nn.Module):
    """One PhyCell layer: a physical prediction step followed by a learned correction.

    Prediction:  h~ = h + F(h), with F = BatchNorm -> conv(kernel_size) -> 1x1 conv.
    Correction:  h' = h~ + K * (x - h~), where K = sigmoid(conv3x3(cat[x, h~])).

    Args:
        input_dim: channels of both the input `x` and the hidden state.
        F_hidden_dim: channels produced by the first conv inside F.
        kernel_size: (kh, kw) of F's first conv; it is padded by (kh//2, kw//2).
        bias: forwarded as the `bias` argument of the gate convolution.
    """
    def __init__(self, input_dim, F_hidden_dim, kernel_size, bias=1):
        super().__init__()
        self.input_dim = input_dim
        self.F_hidden_dim = F_hidden_dim
        self.kernel_size = kernel_size
        self.padding = (kernel_size[0] // 2, kernel_size[1] // 2)
        self.bias = bias

        # Layers are created in the order bn1, conv1, conv2 (then convgate), as before,
        # so parameter names and random initialisation are unchanged.
        self.F = nn.Sequential()
        stages = (
            ('bn1', nn.BatchNorm2d(input_dim)),
            ('conv1', nn.Conv2d(input_dim, F_hidden_dim, kernel_size=self.kernel_size,
                                stride=(1, 1), padding=self.padding)),
            ('conv2', nn.Conv2d(F_hidden_dim, input_dim, kernel_size=(1, 1),
                                stride=(1, 1), padding=(0, 0))),
        )
        for stage_name, stage in stages:
            self.F.add_module(stage_name, stage)

        # Gate sees the observation and the prediction stacked on the channel axis.
        self.convgate = nn.Conv2d(2 * input_dim, input_dim, kernel_size=(3, 3),
                                  padding=(1, 1), bias=self.bias)

    def forward(self, x, hidden):
        """Return the corrected hidden state; `x` and `hidden` share the same shape."""
        predicted = hidden + self.F(hidden)
        gate = torch.sigmoid(self.convgate(torch.cat([x, predicted], dim=1)))
        # Element-wise blend between the physical prediction and the observation.
        return predicted + gate * (x - predicted)

In [None]:
# Smoke test: one PhyCell_Cell with 16 channels, 32 filters in F and a 3x3 kernel (needs a GPU: .cuda()).
phy_cell = PhyCell_Cell(16, 32, (3, 3)).cuda()
phy_cell

PhyCell_Cell(
  (F): Sequential(
    (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv1): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
  )
  (convgate): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

In [None]:
# One correction step on random input and hidden state; the output keeps the input shape.
out = phy_cell(torch.rand(64,16,12,12).cuda(), torch.rand(64,16,12,12).cuda())
out.shape

torch.Size([64, 16, 12, 12])

In [None]:
# Dummy MSE against zeros, only used to exercise the backward pass in the next cell.
mse_loss = MSELossFlat()
loss = mse_loss(out, torch.zeros_like(out))
loss

tensor(0.3029, device='cuda:0', grad_fn=<MseLossBackward>)

In [None]:
# Check that gradients flow through PhyCell_Cell.
loss.backward()

In [None]:
# backward() does not change the loss value itself.
loss

tensor(0.3029, device='cuda:0', grad_fn=<MseLossBackward>)

In [None]:
#export
class PhyCell(nn.Module):
    """Stack of `PhyCell_Cell` layers that all keep `input_dim` channels.

    The per-layer hidden states live in `self.H` between calls. They are reset to
    zeros when `forward` is called with `first_timestep=True`.

    Args:
        input_shape: (height, width) of the hidden states.
        input_dim: channels of the input and of every hidden state.
        F_hidden_dims: per-layer number of filters in each cell's F block.
        n_layers: number of stacked cells.
        kernel_size: kernel size of F's first conv in every cell.
        device: default device for `initHidden` when none is given explicitly.
    """
    def __init__(self, input_shape, input_dim, F_hidden_dims, n_layers, kernel_size, device):
        super(PhyCell, self).__init__()
        self.input_shape = input_shape
        self.input_dim = input_dim
        self.F_hidden_dims = F_hidden_dims
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.H = []
        self.device = device

        # Every layer keeps `input_dim` channels, so layer j > 0 consumes layer j-1's state directly.
        self.cell_list = nn.ModuleList([
            PhyCell_Cell(input_dim=input_dim,
                         F_hidden_dim=self.F_hidden_dims[i],
                         kernel_size=self.kernel_size)
            for i in range(self.n_layers)
        ])

    def forward(self, input_, first_timestep=False):
        """Advance every layer by one step.

        Returns `(self.H, self.H)`, the same list of per-layer states twice. Note that
        the list object is reused and updated in place on later calls.
        """
        if first_timestep:
            # Allocate the fresh state where the input lives, so it follows model.to() / learn.to_cpu()
            # instead of the device captured at construction time.
            self.initHidden(input_.size(0), device=input_.device)

        for j, cell in enumerate(self.cell_list):
            below = input_ if j == 0 else self.H[j-1]  # bottom layer reads the input
            self.H[j] = cell(below, self.H[j])

        return self.H, self.H

    def initHidden(self, batch_size, device=None):
        """Reset all hidden states to zeros of shape (batch_size, input_dim, *input_shape).

        `device` defaults to the `device` given at construction (previous behaviour).
        """
        device = self.device if device is None else device
        self.H = [
            torch.zeros(batch_size, self.input_dim, self.input_shape[0], self.input_shape[1]).to(device)
            for _ in range(self.n_layers)
        ]

    def setHidden(self, H):
        """Replace the per-layer hidden states."""
        self.H = H
  

In [None]:
# Two-layer PhyCell on one random 6x6 frame; `True` is first_timestep, which resets the hidden state.
# forward returns the same per-layer list twice; it is stacked once along dim 1 and once along dim 0.
phy = PhyCell((6,6), 8, [8,8], n_layers=2, kernel_size=(3,3),device=0).cuda()
out, states = phy(torch.rand(1,8,6,6).cuda(), True)
out = torch.stack(out, dim=1)
states = torch.stack(states, dim=0)

In [None]:
# (batch, layers, C, H, W) and (layers, batch, C, H, W)
out.shape, states.shape

(torch.Size([1, 2, 8, 6, 6]), torch.Size([2, 1, 8, 6, 6]))

In [None]:
# Dummy MSE against zeros, only used to exercise the backward pass in the next cell.
mse_loss = MSELossFlat()
loss = mse_loss(out, torch.zeros_like(out))
loss

tensor(0.0822, device='cuda:0', grad_fn=<MseLossBackward>)

In [None]:
# Check that gradients flow through the stacked PhyCell.
loss.backward()

In [None]:
# backward() does not change the loss value itself.
loss

tensor(0.0822, device='cuda:0', grad_fn=<MseLossBackward>)

In [None]:
#export
class ConvLSTM_Cell(nn.Module):
    """Single-step convolutional LSTM cell.

    A single convolution over cat[x, h] produces the pre-activations of the input,
    forget and output gates and of the candidate cell, in that order.
    """
    def __init__(self, input_shape, input_dim, hidden_dim, kernel_size, bias=1):
        """
        input_shape: (int, int)
            Height and width of input tensor as (height, width).
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel; padded by (kh//2, kw//2).
        bias: bool
            Whether or not to add the bias.
        """
        super().__init__()
        self.height, self.width = input_shape
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.padding = (kernel_size[0] // 2, kernel_size[1] // 2)
        self.bias = bias

        self.conv = nn.Conv2d(input_dim + hidden_dim, 4 * hidden_dim,
                              kernel_size=kernel_size, padding=self.padding, bias=bias)

    def forward(self, x, hidden):
        """Process one timestep; `hidden` is (h, c). Returns the new (h, c)."""
        h_prev, c_prev = hidden
        stacked = self.conv(torch.cat([x, h_prev], dim=1))
        # The conv emits exactly 4 * hidden_dim channels, so 4 equal chunks == split by hidden_dim.
        in_pre, forget_pre, out_pre, cand_pre = stacked.chunk(4, dim=1)

        c_next = torch.sigmoid(forget_pre) * c_prev + torch.sigmoid(in_pre) * torch.tanh(cand_pre)
        h_next = torch.sigmoid(out_pre) * torch.tanh(c_next)
        return h_next, c_next

In [None]:
# Single ConvLSTM step with random input, hidden and cell states (needs a GPU: .cuda()).
conv_cell = ConvLSTM_Cell((6,6), 8, 8, kernel_size=(3,3)).cuda()
h, c = conv_cell(torch.rand(1,8,6,6).cuda(), (torch.rand(1,8,6,6).cuda(), torch.rand(1,8,6,6).cuda()))

In [None]:
#export
class ConvLSTM(nn.Module):
    """Stack of `ConvLSTM_Cell` layers; layer i reads layer i-1's hidden state.

    Hidden (`self.H`) and cell (`self.C`) states live on the module between calls.
    They are reset to zeros when `forward` is called with `first_timestep=True`.

    Args:
        input_shape: (height, width) of the states.
        input_dim: channels of the input to the bottom layer.
        hidden_dims: per-layer hidden channels.
        n_layers: number of stacked cells.
        kernel_size: conv kernel size of every cell.
        device: default device for `initHidden` when none is given explicitly.
    """
    def __init__(self, input_shape, input_dim, hidden_dims, n_layers, kernel_size,device):
        super(ConvLSTM, self).__init__()
        self.input_shape = input_shape
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.H, self.C = [],[]
        self.device = device

        cell_list = []
        for i in range(0, self.n_layers):
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dims[i-1]
            print('layer ',i,'input dim ', cur_input_dim, ' hidden dim ', self.hidden_dims[i])
            cell_list.append(ConvLSTM_Cell(input_shape=self.input_shape,
                                          input_dim=cur_input_dim,
                                          hidden_dim=self.hidden_dims[i],
                                          kernel_size=self.kernel_size))
        self.cell_list = nn.ModuleList(cell_list)

    def forward(self, input_, first_timestep=False):
        """Advance every layer by one step.

        Returns `((self.H, self.C), self.H)`, i.e. (state, per-layer outputs). The lists
        are reused and updated in place on later calls.
        """
        if first_timestep:
            # Allocate the fresh state where the input lives, so it follows model.to() / learn.to_cpu()
            # instead of the device captured at construction time.
            self.initHidden(input_.size(0), device=input_.device)

        for j, cell in enumerate(self.cell_list):
            below = input_ if j == 0 else self.H[j-1]  # bottom layer reads the input
            self.H[j], self.C[j] = cell(below, (self.H[j], self.C[j]))

        return (self.H, self.C), self.H   # (hidden, output)

    def initHidden(self, batch_size, device=None):
        """Reset hidden and cell states to zeros of shape (batch_size, hidden_dims[i], *input_shape).

        `device` defaults to the `device` given at construction (previous behaviour).
        """
        device = self.device if device is None else device
        height, width = self.input_shape[0], self.input_shape[1]
        self.H, self.C = [], []
        for i in range(self.n_layers):
            self.H.append(torch.zeros(batch_size, self.hidden_dims[i], height, width).to(device))
            self.C.append(torch.zeros(batch_size, self.hidden_dims[i], height, width).to(device))

    def setHidden(self, hidden):
        """Replace the states with `hidden = (H, C)`."""
        H,C = hidden
        self.H, self.C = H,C
 

In [None]:
# One-layer ConvLSTM step with a state reset (`True` = first_timestep).
# NOTE: forward returns ((H, C), H), so `h` here is the (H, C) pair and `c` is the list of layer outputs.
conv = ConvLSTM((6,6), 8, [8], n_layers=1, kernel_size=(3,3), device=0).cuda()
h, c = conv(torch.rand(1,8,6,6).cuda(), True)

layer  0 input dim  8  hidden dim  8


In [None]:
#export
class dcgan_conv(nn.Module):
    """3x3 conv (padding 1) -> GroupNorm with 4 groups -> LeakyReLU(0.2, in place).

    `stride` controls downsampling (2 halves H and W). `nout` must be divisible by 4
    because of the GroupNorm.
    """
    def __init__(self, nin, nout, stride):
        super().__init__()
        conv = nn.Conv2d(in_channels=nin, out_channels=nout, kernel_size=(3, 3),
                         stride=stride, padding=1)
        norm = nn.GroupNorm(4, nout)
        act = nn.LeakyReLU(0.2, inplace=True)
        self.main = nn.Sequential(conv, norm, act)

    def forward(self, input):
        return self.main(input)

        
class dcgan_upconv(nn.Module):
    """3x3 transposed conv (padding 1) -> GroupNorm with 4 groups -> LeakyReLU(0.2, in place).

    With stride 2, `output_padding=1` makes the output exactly twice the input H and W.
    Any other stride uses `output_padding=0`.
    """
    def __init__(self, nin, nout, stride):
        super().__init__()
        extra_pad = 1 if stride == 2 else 0
        self.main = nn.Sequential(
            nn.ConvTranspose2d(in_channels=nin, out_channels=nout, kernel_size=(3, 3),
                               stride=stride, padding=1, output_padding=extra_pad),
            nn.GroupNorm(4, nout),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, input):
        return self.main(input)

In [None]:
#export
class image_encoder(nn.Module):
    """Six `dcgan_conv` layers (nf = 16); two stride-2 layers downsample by 4 overall.

    `forward` returns the last feature map and the list of all six layer outputs,
    which `image_decoder` uses as skip connections. Sizes in the comments assume a
    (nc) x 64 x 64 input.
    """
    def __init__(self, nc=1):
        super().__init__()
        nf = 16
        half = nf // 2
        self.c1 = dcgan_conv(nc, half, stride=1)        # (nf/2) x 64 x 64
        self.c2 = dcgan_conv(half, nf, stride=1)        # (nf) x 64 x 64
        self.c3 = dcgan_conv(nf, nf * 2, stride=2)      # (2*nf) x 32 x 32
        self.c4 = dcgan_conv(nf * 2, nf * 2, stride=1)  # (2*nf) x 32 x 32
        self.c5 = dcgan_conv(nf * 2, nf * 4, stride=2)  # (4*nf) x 16 x 16
        self.c6 = dcgan_conv(nf * 4, nf * 4, stride=1)  # (4*nf) x 16 x 16

    def forward(self, input):
        features = []
        h = input
        for layer in (self.c1, self.c2, self.c3, self.c4, self.c5, self.c6):
            h = layer(h)
            features.append(h)
        # features == [h1, h2, h3, h4, h5, h6]; h is h6
        return h, features

In [None]:
# Check that the skip-connection encoder builds (on CPU).
img_encoder = image_encoder()

In [None]:
#export
class image_decoder(nn.Module):
    """Mirror of `image_encoder` (nf = 16) with concatenated skip connections.

    `forward` takes `[vec, skip]`, where `skip` is the 6-element list from the encoder.
    At each stage, the running feature map is concatenated with the matching skip
    (h6 first, h1 last) before the next up-convolution. Sizes assume a 16x16 `vec`.
    """
    def __init__(self, nc=1):
        super().__init__()
        nf = 16
        self.upc1 = dcgan_upconv(nf * 8, nf * 4, stride=1)   # (nf*4) x 16 x 16
        self.upc2 = dcgan_upconv(nf * 8, nf * 2, stride=2)   # (nf*2) x 32 x 32
        self.upc3 = dcgan_upconv(nf * 4, nf * 2, stride=1)   # (nf*2) x 32 x 32
        self.upc4 = dcgan_upconv(nf * 4, nf, stride=2)       # (nf) x 64 x 64
        self.upc5 = dcgan_upconv(nf * 2, nf // 2, stride=1)  # (nf/2) x 64 x 64
        self.upc6 = nn.ConvTranspose2d(in_channels=nf, out_channels=nc, kernel_size=(3, 3),
                                       stride=1, padding=1)  # (nc) x 64 x 64

    def forward(self, input):
        vec, skip = input
        h1, h2, h3, h4, h5, h6 = skip
        stages = ((self.upc1, h6), (self.upc2, h5), (self.upc3, h4),
                  (self.upc4, h3), (self.upc5, h2), (self.upc6, h1))
        decoded = vec
        for up, skip_feat in stages:
            decoded = up(torch.cat([decoded, skip_feat], dim=1))
        return decoded
        

In [None]:
# Check that the skip-connection decoder builds (on CPU).
img_decoder = image_decoder()

In [None]:
#export
class EncoderRNN(torch.nn.Module):
    """One PhyDNet time step built on a shared skip-connection CNN.

    The frame is encoded once by `image_encoder`. That encoding feeds both the
    PhyCell (physical branch) and the ConvLSTM (residual branch). `image_decoder`
    then maps each branch's top-layer output, and their sum, back to image space
    using the encoder's skip features.

    Args:
        phycell: a `PhyCell`-like module, called as `phycell(x, first_timestep)`.
        convlstm: a `ConvLSTM`-like module, called as `convlstm(x, first_timestep)`.
        device: device the sub-modules are moved to at construction.
    """
    def __init__(self,phycell,convlstm, device):
        super(EncoderRNN, self).__init__()
        self.image_cnn_enc = image_encoder().to(device) # image encoder 64x64x1 -> 16x16x64
        self.image_cnn_dec = image_decoder().to(device) # image decoder 16x16x64 -> 64x64x1
        self.phycell = phycell.to(device)
        self.convlstm = convlstm.to(device)

    def forward(self, input, first_timestep=False, decoding=False):
        """
        Returns `(out_phys, hidden1, output_image, out_phys, out_conv)`:
        - the sigmoid of the PhyCell-only decoding,
        - the hidden state returned by the PhyCell,
        - the decoding of the summed branches (no sigmoid applied),
        - the two sigmoid partial reconstructions (for visualisation).
        """
        # Encode once and share the result. Running `image_cnn_enc` twice on the same input
        # (as before) gives identical tensors (conv/GroupNorm/LeakyReLU only) at twice the cost.
        encoded, skip = self.image_cnn_enc(input)
        # NOTE(review): with decoding=True the PhyCell receives None. PhyCell in this notebook
        # does not accept a None input, so this path presumably fails. TODO confirm the intent.
        output_phys = None if decoding else encoded
        output_conv = encoded

        hidden1, output1 = self.phycell(output_phys, first_timestep)
        _, output2 = self.convlstm(output_conv, first_timestep)  # called for its state update

        out_phys = torch.sigmoid(self.image_cnn_dec([output1[-1],skip])) # partial reconstructions for vizualization
        out_conv = torch.sigmoid(self.image_cnn_dec([output2[-1],skip]))

        concat = output1[-1]+output2[-1]
        output_image = self.image_cnn_dec([concat,skip])
        return out_phys, hidden1, output_image, out_phys, out_conv

In [None]:
# Build the full encoder: a 1-layer PhyCell with 49 filters of size 7x7 (matching the 49 x 7x7
# moment constraints used by PhyDNet) and a 3-layer ConvLSTM. An integer device 0 means cuda:0.
device=0
phycell =  PhyCell(input_shape=(16,16), input_dim=64, F_hidden_dims=[49], n_layers=1, kernel_size=(7,7), device=device) 
convlstm =  ConvLSTM(input_shape=(16,16), input_dim=64, hidden_dims=[128,128,64], n_layers=3, kernel_size=(3,3), device=device)   
encoder = EncoderRNN(phycell, convlstm, device)

layer  0 input dim  64  hidden dim  128
layer  1 input dim  128  hidden dim  128
layer  2 input dim  128  hidden dim  64


## Loss

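Warning about the wildcard imports below. The moment-matrix helpers need `zeros`, `arange`, `newaxis`, `inv` and `factorial`, but `from numpy import *` dumps NumPy's whole public namespace into the exported module. This can shadow earlier names: for example, `random` becomes `numpy.random`, and builtins such as `sum`, `any` and `all` become their NumPy versions. It also brings in NumPy's `tensordot`, which the PyTorch `tensordot` defined further down then replaces. Prefer explicit imports (e.g. `import numpy as np`) when refactoring.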

In [None]:
#export
from numpy import *
from numpy.linalg import *
from scipy.special import factorial

In [None]:
#export
def _apply_axis_left_dot(x, mats):
    """Left-multiply each trailing axis of `x` by its matrix.

    `x` has shape [N, d_1, ..., d_k] and `mats` holds k square matrices. Axis i of
    every x[n] is contracted with the second index of mats[i-1]. In 2-D this gives
    y[n] = mats[0] @ x[n] @ mats[1].T. The result has the same size as `x`.
    """
    assert x.dim() == len(mats)+1
    original_size = x.size()
    k = x.dim() - 1
    # Each contraction consumes the last axis and prepends the new one, so after k steps
    # the batch axis has rotated to the end.
    for mat in reversed(mats):
        x = tensordot(mat, x, dim=[1, k])
    batch_first = [k] + list(range(k))
    return x.permute(batch_first).contiguous().view(original_size)

def _apply_axis_right_dot(x, mats):
    """Right-multiply each trailing axis of `x` by its matrix.

    `x` has shape [N, d_1, ..., d_k] and `mats` holds k square matrices. Axis i of
    every x[n] is contracted with the first index of mats[i-1]. In 2-D this gives
    y[n] = mats[0].T @ x[n] @ mats[1]. The result has the same size as `x`.
    """
    assert x.dim() == len(mats)+1
    original_size = x.size()
    k = x.dim() - 1
    # Move the batch axis last; each contraction then eats the leading axis and appends the new one.
    x = x.permute(list(range(1, k + 1)) + [0])
    for mat in mats:
        x = tensordot(x, mat, dim=[0, 0])
    return x.contiguous().view(original_size)

In [None]:
#export
class _MK(nn.Module):
    def __init__(self, shape):
        super(_MK, self).__init__()
        self._size = torch.Size(shape)
        self._dim = len(shape)
        M = []
        invM = []
        assert len(shape) > 0
        j = 0
        for l in shape:
            M.append(zeros((l,l)))
            for i in range(l):
                M[-1][i] = ((arange(l)-(l-1)//2)**i)/factorial(i)
            invM.append(inv(M[-1]))
            self.register_buffer('_M'+str(j), torch.from_numpy(M[-1]))
            self.register_buffer('_invM'+str(j), torch.from_numpy(invM[-1]))
            j += 1

    @property
    def M(self):
        return list(self._buffers['_M'+str(j)] for j in range(self.dim()))
    @property
    def invM(self):
        return list(self._buffers['_invM'+str(j)] for j in range(self.dim()))

    def size(self):
        return self._size
    def dim(self):
        return self._dim
    def _packdim(self, x):
        assert x.dim() >= self.dim()
        if x.dim() == self.dim():
            x = x[newaxis,:]
        x = x.contiguous()
        x = x.view([-1,]+list(x.size()[-self.dim():]))
        return x

    def forward(self):
        pass

In [None]:
# Build the moment matrices for a 7x7 kernel (the PhyCell kernel size used above).
m = _MK((7,7))

In [None]:
#export
class M2K(_MK):
    """
    convert moment matrix to convolution kernel
    Arguments:
        shape (tuple of int): kernel shape
    Usage:
        m2k = M2K([5,5])
        m = torch.randn(5,5,dtype=torch.float64)
        k = m2k(m)
    """
    def __init__(self, shape):
        super().__init__(shape)

    def forward(self, m):
        """
        m (Tensor): torch.size=[...,*self.shape]
        The trailing dims are multiplied by the inverse moment matrices (`self.invM`).
        The result has the same shape as `m`.
        """
        packed = self._packdim(m)
        kernel = _apply_axis_left_dot(packed, self.invM)
        return kernel.view(m.size())

In [None]:
#export
class K2M(_MK):
    """
    convert convolution kernel to moment matrix
    Arguments:
        shape (tuple of int): kernel shape
    Usage:
        k2m = K2M([5,5])
        k = torch.randn(5,5,dtype=torch.float64)
        m = k2m(k)
    """
    def __init__(self, shape):
        super().__init__(shape)

    def forward(self, k):
        """
        k (Tensor): torch.size=[...,*self.shape]
        The trailing dims are multiplied by the moment matrices (`self.M`).
        The result has the same shape as `k`.
        """
        packed = self._packdim(k)
        moments = _apply_axis_left_dot(packed, self.M)
        return moments.view(k.size())

In [None]:
#export    
def tensordot(a,b,dim):
    """
    tensordot in PyTorch, see numpy.tensordot.

    Contract tensors `a` and `b` over the given axes.

    Arguments:
        a, b (Tensor): operands.
        dim: either
            - an int n: contract the last n axes of `a` with the first n axes of `b`
              (n = 0 gives the outer product), or
            - a pair [adims, bdims], where each entry is an int or a sequence of ints
              (negative values allowed): contract a's axes `adims` with b's axes
              `bdims`, pairwise.
    Returns:
        Tensor whose shape is a's free axes (in order) followed by b's free axes (in order).
    """
    def _prod(sizes):
        # Product of a size tuple; 1 for an empty tuple.
        total = 1
        for s in sizes:
            total *= s
        return total

    if isinstance(dim, int):
        # Split by position rather than by `[:-dim]`, which breaks for dim == 0.
        n_free_a = a.dim() - dim
        sizea, sizeb = a.size(), b.size()
        sizea0, sizea1 = sizea[:n_free_a], sizea[n_free_a:]
        sizeb0, sizeb1 = sizeb[:dim], sizeb[dim:]
    else:
        adims, bdims = dim[0], dim[1]
        adims = [adims] if isinstance(adims, int) else list(adims)
        bdims = [bdims] if isinstance(bdims, int) else list(bdims)
        # Normalise negative axes so the free-axis computation below is correct.
        adims = [d % a.dim() for d in adims]
        bdims = [d % b.dim() for d in bdims]
        afree = sorted(set(range(a.dim())).difference(adims))
        bfree = sorted(set(range(b.dim())).difference(bdims))
        # Bring contracted axes of `a` to the end and of `b` to the front.
        a = a.permute(*(afree + adims))
        b = b.permute(*(bdims + bfree))
        sizea, sizeb = a.size(), b.size()
        sizea0, sizea1 = sizea[:len(afree)], sizea[len(afree):]
        sizeb0, sizeb1 = sizeb[:len(bdims)], sizeb[len(bdims):]

    N = _prod(sizea1)
    assert _prod(sizeb0) == N
    # Explicit sizes (instead of -1) also work when a contracted axis has length 0.
    c = a.reshape(_prod(sizea0), N) @ b.reshape(N, _prod(sizeb1))
    return c.view(tuple(sizea0) + tuple(sizeb1))

## PhyDNet

In [None]:
#export
class PhyDNet(Module):
    """PhyDNet sequence model.

    The input frames are encoded one step at a time with `encoder`. The encoder is
    then rolled out autoregressively (or with teacher forcing) to predict future frames.

    Args:
        encoder: an `EncoderRNN`-like module, called as `encoder(frame[, first_timestep])`.
            It must return a 5-tuple whose third element is the predicted next frame, and
            it must expose `encoder.phycell.cell_list[0]` for the moment regularisation.
        criterion: loss used for one-step-ahead reconstruction of the input frames and
            for the moment constraint.

    Attributes:
        pr: teacher-forcing probability for the prediction phase (0 = never).
    """
    def __init__(self, encoder, criterion=MSELossFlat()):
        store_attr()
        self.pr = 0
        self.k2m = K2M([7,7])
        # Moment targets for the 49 7x7 PhyCell filters. Filter number 7*i + j should have
        # moment (i, j) equal to 1 and all other moments 0, i.e. an identity reshaped to (49, 7, 7).
        self.constraints = torch.eye(49).view(49, 7, 7)

    def forward(self, input_tensor, target_tensor=None, target_length=None):
        """
        input_tensor: 5-D, time on dim 1 (e.g. (batch, time, 1, 64, 64)).
        target_tensor: optional future frames with the same layout. It is only used for
            its length and for teacher forcing; it does not contribute to the loss.
        target_length: number of frames to predict when `target_tensor` is None.
            It is ignored when `target_tensor` is given.

        Returns (predictions stacked on dim 1, loss). The loss is the sum of the one-step
        reconstruction losses over the input frames plus the moment-constraint loss.
        """
        device = one_param(self).device

        input_length = input_tensor.size(1)
        if target_tensor is not None:
            target_length = target_tensor.size(1)
        elif target_length is None:
            raise ValueError("PhyDNet.forward needs `target_tensor` or `target_length` "
                             "to know how many frames to predict")

        loss = 0
        # Frame t is used to predict frame t+1; the recurrent state is reset on the first step.
        for ei in range(input_length-1):
            _, _, output_image, _, _ = self.encoder(input_tensor[:,ei,:,:,:], (ei==0))
            loss += self.criterion(output_image,input_tensor[:,ei+1,:,:,:])

        decoder_input = input_tensor[:,-1,:,:,:] # first decoder input = last image of input sequence
        # With a single input frame, no encoding step reset the state, so reset it on the first prediction.
        reset_on_first_prediction = input_length < 2

        # random.random() is only drawn when targets are available, as before.
        teacher_forcing = (target_tensor is not None) and (random.random() < self.pr)
        output_images = []
        for di in range(target_length):
            if reset_on_first_prediction and di == 0:
                _, _, output_image, _, _ = self.encoder(decoder_input, True)
            else:
                _, _, output_image, _, _ = self.encoder(decoder_input)
            output_images.append(output_image)
            decoder_input = target_tensor[:,di,:,:,:] if teacher_forcing else output_image

        # Moment regularisation: encoder.phycell.cell_list[0].F.conv1.weight has size (nb_filters, in_channels, 7, 7)
        phy_cell = self.encoder.phycell.cell_list[0]
        constraints = self.constraints.to(device) # precomputed target moments, moved once
        for b in range(0, phy_cell.input_dim):
            filters = phy_cell.F.conv1.weight[:,b,:,:] # (nb_filters,7,7)
            m = self.k2m(filters.double()).float()
            loss += self.criterion(m, constraints)
        return torch.stack(output_images, dim=1), loss

In [None]:
# Wrap the encoder; .cuda() also moves registered submodules such as the K2M buffers.
phynet = PhyDNet(encoder).cuda()

In [None]:
# 5 random input frames and 5 target frames of 1x64x64; returns stacked predictions and the training loss.
output, loss = phynet(torch.rand(1,5,1,64,64).cuda(), target_tensor=torch.rand(1,5,1,64,64).cuda())

In [None]:
# Predictions are (batch, target_length, C, H, W).
output.shape, loss

(torch.Size([1, 5, 1, 64, 64]),
 tensor(10.3940, device='cuda:0', grad_fn=<AddBackward0>))

# Export -

In [None]:
# hide
# Export the `#export` cells of the project's notebooks to the library modules.
from nbdev.export import *
notebook2script()

Converted 00_data.ipynb.
Converted 01_models.conv_rnn.ipynb.
Converted 02_models.dcn.ipynb.
Converted 02_models.transformer.ipynb.
Converted 02_tcn.ipynb.
Converted 03_phy.ipynb.
Converted 03_phy_original.ipynb.
Converted 04_seq2seq.ipynb.
Converted index.ipynb.
