### U-Net

In [1]:
import math

import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
# Print numpy arrays with 2 decimal places
np.set_printoptions(precision=2)
# Render inline matplotlib figures at high (retina) resolution
%config InlineBackend.figure_format = 'retina'
%matplotlib inline
print ("Ready.")

Ready.


### Utility code

In [2]:
def np2torch(x_np,dtype=torch.float32,device='cpu'):
    """
        Build a new torch tensor from an array-like (e.g. a numpy array).
        :param x_np: array-like input; its data is copied by torch.tensor
        :param dtype: target torch dtype (default: torch.float32)
        :param device: target device (default: 'cpu')
        :return: torch tensor holding the same values
    """
    return torch.tensor(x_np,dtype=dtype,device=device)
def torch2np(x_torch):
    """
        Convert a torch tensor to a numpy array.
        The tensor is detached from the autograd graph and moved to the CPU first,
        so tensors that require grad or live on another device are accepted.
        :param x_torch: torch tensor
        :return: numpy ndarray
    """
    return x_torch.detach().cpu().numpy()
class SinPositionEmbeddingsClass(nn.Module):
    """
        Sinusoidal position (time-step) embeddings.

        Each step s is mapped to
            [sin(s*w_0), ..., sin(s*w_{k-1}), cos(s*w_0), ..., cos(s*w_{k-1})]
        with k = dim//2 and frequencies w_i = exp(-i * log(T) / (k-1)), i.e.
        geometrically spaced from 1 down to 1/T.
        If dim is odd, a zero column is appended so the output always has
        exactly `dim` columns.
    """
    def __init__(self,dim=128,T=1000):
        super().__init__()
        self.dim = dim
        self.T = T
    @torch.no_grad()
    def forward(self,steps=None):
        """
            Embed a 1-D tensor of steps.
            :param steps: 1-D tensor of time steps [N]; defaults to
                          torch.arange(start=0,end=1000,step=1)
            :return: float embeddings [N x dim]
        """
        if steps is None:
            # Built per call instead of once at definition time (same values as before)
            steps = torch.arange(start=0,end=1000,step=1)
        device = steps.device
        half_dim = self.dim // 2
        # For dim in {2, 3} half_dim-1 == 0; clamp the denominator to avoid ZeroDivisionError
        scale = math.log(self.T) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -scale) # [half_dim]
        args = steps[:, None] * freqs[None, :] # [N x half_dim]
        embeddings = torch.cat((args.sin(), args.cos()), dim=-1) # [N x 2*half_dim]
        if self.dim % 2 == 1:
            # Odd dim: pad one zero column on the right so the width equals dim
            embeddings = nn.functional.pad(embeddings, (0, 1))
        return embeddings
print ("Ready.")

Ready.


In [3]:
class DenseBlockClass(nn.Module):
    """
        Two-layer dense block conditioned on a position (time) embedding t:
            dense1 -> ReLU -> bnorm1 -> (+ pos_emb_mlp(t)) -> dense2 -> ReLU -> bnorm2
    """
    def __init__(self,in_dim=10,out_dim=5,pos_emb_dim=10):
        """
            Initialize
            :param in_dim: feature size of x
            :param out_dim: feature size of the block output
            :param pos_emb_dim: feature size of the embedding t
        """
        super().__init__()
        self.in_dim, self.out_dim, self.pos_emb_dim = in_dim, out_dim, pos_emb_dim
        # Submodules are created in the same order as before, so parameter
        # registration order is unchanged.
        self.dense1 = nn.Linear(in_dim,out_dim)
        self.bnorm1 = nn.BatchNorm1d(out_dim)
        self.dense2 = nn.Linear(out_dim,out_dim)
        self.bnorm2 = nn.BatchNorm1d(out_dim)
        self.actv   = nn.ReLU()
        self.pos_emb_mlp = nn.Linear(pos_emb_dim,out_dim)

    def forward(self,x,t):
        """
            Forward
            :param x: input features [B x in_dim]
            :param t: position embedding [B x pos_emb_dim]
            :return: [B x out_dim]
        """
        first = self.bnorm1(self.actv(self.dense1(x)))   # dense -> actv -> bnorm1
        conditioned = first + self.pos_emb_mlp(t)         # inject the embedding
        second = self.actv(self.dense2(conditioned))      # dense -> actv
        return self.bnorm2(second)                        # -> bnorm2
        
DenseBlock = DenseBlockClass(in_dim=100,out_dim=20,pos_emb_dim=64)
x = torch.randn(128,100)
t = torch.randn(128,64)
out = DenseBlock(x=x,t=t) # forward Block
# Sanity check: one output row per input row, out_dim features each
assert out.shape == (x.shape[0], DenseBlock.out_dim)

print ("x:%s,t:%s to out:%s"%(x.shape,t.shape,out.shape))

x:torch.Size([128, 100]),t:torch.Size([128, 64]) to out:torch.Size([128, 20])


### Dense U-Net 
<img src="../img/DenseUNet.jpeg" width="500" />
<img src="../img/Block.jpeg" width="500" />

In [4]:
class DenseUNetClass(nn.Module):
    """
        Dense (MLP) U-Net built from DenseBlockClass blocks.

        Encoder: x_dim -> h_dims[0] -> ... -> h_dims[-1] -> z_dim
        Map:     z_dim -> z_dim
        Decoder: mirrors the encoder; each decoder output is concatenated with
                 the encoder activation of the same width (skip connection),
                 so the next block receives twice that width.
        Out:     2*x_dim -> x_dim
        Every block also receives the position embedding t.
    """
    def __init__(self,name='dense_unet',
                 x_dim=300,pos_emb_dim=128,h_dims=(128,64),z_dim=32):
        """
            Initialize
            :param name: model name
            :param x_dim: input/output feature size
            :param pos_emb_dim: feature size of the position embedding t
            :param h_dims: encoder hidden widths (used reversed by the decoder)
            :param z_dim: bottleneck width
        """
        super(DenseUNetClass,self).__init__()
        self.name        = name
        self.x_dim       = x_dim
        self.pos_emb_dim = pos_emb_dim
        # Copy so later mutation of the caller's list cannot desync the layers
        self.h_dims      = list(h_dims)
        self.z_dim       = z_dim
        # Initialize layers
        self.init_layers()
        
    def init_layers(self):
        """
            Initialize layers.

            Layers live in an nn.ModuleDict (not a plain dict) so they are
            registered as submodules: parameters(), state_dict(), .to(device)
            and .train()/.eval() all reach them.
        """
        self.layers = nn.ModuleDict()
        # Encoder
        h_prev = self.x_dim
        for h_idx,h_dim in enumerate(self.h_dims):
            self.layers['Enc_%02d'%(h_idx)] = DenseBlockClass(
                in_dim=h_prev,out_dim=h_dim,pos_emb_dim=self.pos_emb_dim)
            h_prev = h_dim
        # Last encoder block -> bottleneck. h_prev equals h_dims[-1], or x_dim
        # when h_dims is empty (indexing h_dims[-1] would raise there).
        self.layers['Enc_%02d'%(len(self.h_dims))] = DenseBlockClass(
            in_dim=h_prev,out_dim=self.z_dim,pos_emb_dim=self.pos_emb_dim)
        # Map
        self.layers['Map'] = DenseBlockClass(
            in_dim=self.z_dim,out_dim=self.z_dim,pos_emb_dim=self.pos_emb_dim)
        # Decoder: after the first block, inputs are doubled by the skip concat
        h_prev = self.z_dim
        for h_idx,h_dim in enumerate(self.h_dims[::-1]):
            self.layers['Dec_%02d'%(h_idx)] = DenseBlockClass(
                in_dim=h_prev,out_dim=h_dim,pos_emb_dim=self.pos_emb_dim)
            h_prev = 2*h_dim
        self.layers['Dec_%02d'%(len(self.h_dims))] = DenseBlockClass(
            in_dim=h_prev,out_dim=self.x_dim,
            pos_emb_dim=self.pos_emb_dim)
        # Out: consumes [x, Dec_last(...)] concatenated
        self.layers['Out'] = DenseBlockClass(
            in_dim=2*self.x_dim,out_dim=self.x_dim,pos_emb_dim=self.pos_emb_dim)
    
    def forward(self,x,t):
        """
            Forward
            :param x: input [B x x_dim]
            :param t: position embedding [B x pos_emb_dim]
            :return: [B x x_dim]
            Side effect: intermediate activations are stored in self.nets,
            self.enc_paths and self.dec_paths for inspection.
        """
        net = x # [B x x_dim]
        # Net
        self.nets = []
        # Encoder: enc_paths[i] is the input of Enc_i (enc_paths[0] = x)
        self.enc_paths = []
        self.enc_paths.append(net)
        self.nets.append(net)
        for h_idx in range(len(self.h_dims)+1):
            net = self.layers['Enc_%02d'%(h_idx)](net,t)
            self.enc_paths.append(net)
            self.nets.append(net)
        # Map
        net = self.layers['Map'](net,t)
        self.nets.append(net)
        # Decoder: Dec_i's output has the width of enc_paths[L-i] (L = len(h_dims)),
        # so concatenating them doubles the width fed to the next block
        self.dec_paths = []
        for h_idx in range(len(self.h_dims)+1):
            net = self.layers['Dec_%02d'%(h_idx)](net,t)
            net = torch.cat([self.enc_paths[len(self.h_dims)-h_idx],net],dim=1)
            self.dec_paths.append(net)
            self.nets.append(net)
        net = self.layers['Out'](net,t)
        self.nets.append(net)
        return net
    
print ("Ready.")

Ready.


### Demo: forward pass

In [5]:
DUNet = DenseUNetClass(x_dim=300,pos_emb_dim=32,h_dims=[128,64],z_dim=32)
x = torch.randn(128,300)
t = torch.randn(128,32)
out_dunet = DUNet(x=x,t=t)
# The U-Net maps [B x x_dim] back to [B x x_dim]
assert out_dunet.shape == x.shape
print ("Ready")

Ready


### Print layer and activation shapes

In [6]:
print ("Layer information")
# Input/output width of each block's first dense layer
for key_idx,(key,layer) in enumerate(DUNet.layers.items()):
    print ("[%d/%d][%7s] [%03d] =>[%03d]"%
           (key_idx,len(DUNet.layers),key,layer.dense1.in_features,layer.dense1.out_features))

print ("Network information")
# Shapes of the activations recorded during the last forward pass
for net_idx,net in enumerate(DUNet.nets):
    print ("[%02d/%02d] %s"%(net_idx,len(DUNet.nets),net.shape))

Layer information
[0/8][ Enc_00] [300] =>[128]
[1/8][ Enc_01] [128] =>[064]
[2/8][ Enc_02] [064] =>[032]
[3/8][    Map] [032] =>[032]
[4/8][ Dec_00] [032] =>[064]
[5/8][ Dec_01] [128] =>[128]
[6/8][ Dec_02] [256] =>[300]
[7/8][    Out] [600] =>[300]
Network information
[00/09] torch.Size([128, 300])
[01/09] torch.Size([128, 128])
[02/09] torch.Size([128, 64])
[03/09] torch.Size([128, 32])
[04/09] torch.Size([128, 32])
[05/09] torch.Size([128, 128])
[06/09] torch.Size([128, 256])
[07/09] torch.Size([128, 600])
[08/09] torch.Size([128, 300])
