### U-Net

In [1]:
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
np.set_printoptions(precision=2)
%config InlineBackend.figure_format = 'retina'
%matplotlib inline
print ("Ready.")

Ready.


### Utility code

In [2]:
def np2torch(x_np,dtype=torch.float32,device='cpu'):
    """ Copy a numpy array (or array-like) into a new torch tensor with the given dtype/device """
    return torch.tensor(x_np,dtype=dtype,device=device)
def torch2np(x_torch):
    """ Detach a tensor from autograd, move it to CPU and return it as a numpy array """
    return x_torch.detach().cpu().numpy()
class SinPositionEmbeddingsClass(nn.Module):
    """
        Sinusoidal embedding of step indices: [sin(s*w_k), cos(s*w_k)] with
        geometrically spaced frequencies w_k = T^(-k/(half_dim-1)), k=0..half_dim-1.
    """
    def __init__(self,dim=128,T=1000):
        super().__init__()
        self.dim = dim
        self.T = T
    @torch.no_grad()
    def forward(self,steps=torch.arange(start=0,end=1000,step=1)):
        """
            steps   : step indices of shape [N] ([N x 1] is also accepted; input is flattened)
            returns : [N x dim] embeddings. When dim is odd the last column is zero
                      padding, so the width always equals dim (downstream layers are
                      built with in_features=dim).
        """
        device = steps.device
        steps = steps.reshape(-1) # [N]
        half_dim = self.dim // 2
        # max(...,1) avoids a division by zero (NaN embeddings) when half_dim == 1
        log_step = float(np.log(self.T)) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -log_step) # [half_dim]
        args = steps[:, None] * freqs[None, :] # [N x half_dim]
        embeddings = torch.cat((args.sin(), args.cos()), dim=-1) # [N x 2*half_dim]
        if embeddings.shape[-1] < self.dim: # odd dim: pad one zero column
            pad = embeddings.new_zeros(embeddings.shape[0], self.dim - embeddings.shape[-1])
            embeddings = torch.cat((embeddings, pad), dim=-1)
        return embeddings
print("Ready.")  # end of utility definitions

Ready.


In [3]:
class DenseBlockClass(nn.Module):
    """
        Two-layer dense block: (Linear -> ReLU -> BatchNorm1d) twice, with an
        optional additive projection of the embedding `t` between the two layers.
    """
    def __init__(self,in_dim=10,out_dim=5,pos_emb_dim=10,USE_POS_EMB=True):
        """
            Initialize
            in_dim      : input feature size
            out_dim     : width of both dense layers (and of the output)
            pos_emb_dim : feature size of the embedding `t` given to forward
            USE_POS_EMB : when False, `t` is ignored in forward
        """
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        self.pos_emb_dim = pos_emb_dim
        self.USE_POS_EMB = USE_POS_EMB
        # Sub-modules are created in a fixed order (this order determines
        # parameter registration and weight-init RNG consumption)
        self.dense1 = nn.Linear(in_dim,out_dim)
        self.bnorm1 = nn.BatchNorm1d(out_dim)
        self.dense2 = nn.Linear(out_dim,out_dim)
        self.bnorm2 = nn.BatchNorm1d(out_dim)
        self.actv   = nn.ReLU()
        # Embedding projection; always built, only used when USE_POS_EMB is True
        self.pos_emb_mlp = nn.Linear(pos_emb_dim,out_dim)

    def forward(self,x,t):
        """
            Forward
            x: [B x in_dim], t: [B x pos_emb_dim]  ->  [B x out_dim]
        """
        hidden = self.dense1(x)
        hidden = self.bnorm1(self.actv(hidden)) # [B x out_dim]
        if self.USE_POS_EMB:
            hidden = hidden + self.pos_emb_mlp(t) # [B x out_dim]
        hidden = self.actv(self.dense2(hidden))
        return self.bnorm2(hidden) # [B x out_dim]
        
DenseBlock = DenseBlockClass(in_dim=100,out_dim=20,pos_emb_dim=64)
# One random batch of 128 inputs plus matching 64-d embeddings
x, t = torch.randn(128,100), torch.randn(128,64)
out = DenseBlock(x=x,t=t) # forward Block -> [128 x 20]

print(f"x:{x.shape},t:{t.shape} to out:{out.shape}")

x:torch.Size([128, 100]),t:torch.Size([128, 64]) to out:torch.Size([128, 20])


### Dense U-Net 
<img src="../img/DenseUNet.jpeg" width="500" />
<img src="../img/Block.jpeg" width="500" />

In [4]:
class DenseUNetClass(nn.Module):
    """
        Dense (MLP) U-Net.
        Encoder : x_dim -> h_dims[0] -> ... -> h_dims[-1] -> z_dim
        Map     : z_dim -> z_dim
        Decoder : mirror of the encoder; each decoder output is concatenated with the
                  matching encoder activation (skip connection), so decoder inputs after
                  the first are 2*h_dim wide and the 'Out' block takes 2*x_dim.
    """
    def __init__(self,name='dense_unet',
                 x_dim=300,pos_emb_dim=128,h_dims=[128,64],z_dim=32,
                 USE_POS_EMB=True):
        """
            Initialize
            x_dim       : input / output feature size
            pos_emb_dim : feature size of the embedding `t` given to forward
            h_dims      : encoder hidden widths (may be empty)
            z_dim       : bottleneck width
            USE_POS_EMB : whether every block adds a projection of `t`
        """
        super(DenseUNetClass,self).__init__()
        self.name        = name
        self.x_dim       = x_dim
        self.pos_emb_dim = pos_emb_dim
        self.h_dims      = list(h_dims) # copy: never alias the shared default / caller's list
        self.z_dim       = z_dim
        self.USE_POS_EMB = USE_POS_EMB # honour the caller's flag
        # Initialize layers
        self.init_layers()
        
    def init_layers(self):
        """
            Initialize layers
        """
        self.layers = nn.ModuleDict()
        # Encoder
        h_prev = self.x_dim
        for h_idx,h_dim in enumerate(self.h_dims):
            self.layers['Enc_%02d'%(h_idx)] = DenseBlockClass(
                in_dim=h_prev,out_dim=h_dim,pos_emb_dim=self.pos_emb_dim,
                USE_POS_EMB = self.USE_POS_EMB)
            h_prev = h_dim
        # Last encoder block: h_prev is h_dims[-1], or x_dim when h_dims is empty
        self.layers['Enc_%02d'%(len(self.h_dims))] = DenseBlockClass(
            in_dim=h_prev,out_dim=self.z_dim,pos_emb_dim=self.pos_emb_dim,
            USE_POS_EMB = self.USE_POS_EMB)
        # Map
        self.layers['Map'] = DenseBlockClass(
            in_dim=self.z_dim,out_dim=self.z_dim,pos_emb_dim=self.pos_emb_dim,
            USE_POS_EMB = self.USE_POS_EMB)
        # Decoder (inputs after the first are doubled by the skip concatenation)
        h_prev = self.z_dim
        for h_idx,h_dim in enumerate(self.h_dims[::-1]):
            self.layers['Dec_%02d'%(h_idx)] = DenseBlockClass(
                in_dim=h_prev,out_dim=h_dim,pos_emb_dim=self.pos_emb_dim,
                USE_POS_EMB = self.USE_POS_EMB)
            h_prev = 2*h_dim
        self.layers['Dec_%02d'%(len(self.h_dims))] = DenseBlockClass(
            in_dim=h_prev,out_dim=self.x_dim,
            pos_emb_dim=self.pos_emb_dim,USE_POS_EMB = self.USE_POS_EMB)
        # Out
        self.layers['Out'] = DenseBlockClass(
            in_dim=2*self.x_dim,out_dim=self.x_dim,pos_emb_dim=self.pos_emb_dim,
            USE_POS_EMB = self.USE_POS_EMB)
    
    def forward(self,x,t):
        """
            Forward
            x: [B x x_dim], t: [B x pos_emb_dim]  ->  [B x x_dim]
            Side effect: intermediate activations are kept in self.nets (in order:
            input, encoder outputs, map output, decoder concatenations, output),
            encoder activations in self.enc_paths, decoder ones in self.dec_paths.
        """
        net = x # [B x x_dim]
        # Net
        self.nets = []
        # Encoder 
        self.enc_paths = []
        self.enc_paths.append(net)
        self.nets.append(net)
        for h_idx in range(len(self.h_dims)+1):
            net = self.layers['Enc_%02d'%(h_idx)](net,t)
            self.enc_paths.append(net)
            self.nets.append(net)
        # Map
        net = self.layers['Map'](net,t)
        self.nets.append(net)
        # Decoder: concatenate each decoder output with the mirrored encoder activation
        self.dec_paths = []
        for h_idx in range(len(self.h_dims)+1):
            net = self.layers['Dec_%02d'%(h_idx)](net,t)
            net = torch.cat([self.enc_paths[len(self.h_dims)-h_idx],net],dim=1)
            self.dec_paths.append(net)
            self.nets.append(net)
        net = self.layers['Out'](net,t)
        self.nets.append(net)
        return net
    
print("Ready.")  # DenseUNetClass defined

Ready.


### Demo: forward pass

In [5]:
# Build a two-level dense U-Net and push one random batch through it
DUNet = DenseUNetClass(x_dim=300,pos_emb_dim=32,h_dims=[128,64],z_dim=32,USE_POS_EMB=True)
x, t = torch.randn(128,300), torch.randn(128,32) # [B x x_dim], [B x pos_emb_dim]
DUNet(x=x,t=t) # result unused here; activations are cached in DUNet.nets
print("Ready")

Ready


### Print layer and activation shapes

In [6]:
print("Layer information")
# Each block's first dense layer gives its input -> output width
for key_idx,(key,layer) in enumerate(DUNet.layers.items()):
    print("[%d/%d][%7s] [%03d] =>[%03d]"%
          (key_idx,len(DUNet.layers),key,layer.dense1.in_features,layer.dense1.out_features))

print("Network information")
# Activations cached by the last forward call
for net_idx,net in enumerate(DUNet.nets):
    print("[%02d/%02d] %s"%(net_idx,len(DUNet.nets),net.shape))

Layer information
[0/8][ Enc_00] [300] =>[128]
[1/8][ Enc_01] [128] =>[064]
[2/8][ Enc_02] [064] =>[032]
[3/8][    Map] [032] =>[032]
[4/8][ Dec_00] [032] =>[064]
[5/8][ Dec_01] [128] =>[128]
[6/8][ Dec_02] [256] =>[300]
[7/8][    Out] [600] =>[300]
Network information
[00/09] torch.Size([128, 300])
[01/09] torch.Size([128, 128])
[02/09] torch.Size([128, 64])
[03/09] torch.Size([128, 32])
[04/09] torch.Size([128, 32])
[05/09] torch.Size([128, 128])
[06/09] torch.Size([128, 256])
[07/09] torch.Size([128, 600])
[08/09] torch.Size([128, 300])


### Projecting U-Net

In [7]:
class SinPositionEmbeddingsClass(nn.Module):
    """
        Sinusoidal embedding of step indices: [sin(s*w_k), cos(s*w_k)] with
        geometrically spaced frequencies w_k = T^(-k/(half_dim-1)), k=0..half_dim-1.
    """
    def __init__(self,dim=128,T=1000):
        super().__init__()
        self.dim = dim
        self.T = T
    @torch.no_grad()
    def forward(self,steps=torch.arange(start=0,end=1000,step=1)):
        """
            steps   : step indices of shape [N] ([N x 1] is also accepted; input is flattened)
            returns : [N x dim] embeddings. When dim is odd the last column is zero
                      padding, so the width always equals dim (downstream layers are
                      built with in_features=dim).
        """
        device = steps.device
        steps = steps.reshape(-1) # [N]
        half_dim = self.dim // 2
        # max(...,1) avoids a division by zero (NaN embeddings) when half_dim == 1
        log_step = float(np.log(self.T)) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -log_step) # [half_dim]
        args = steps[:, None] * freqs[None, :] # [N x half_dim]
        embeddings = torch.cat((args.sin(), args.cos()), dim=-1) # [N x 2*half_dim]
        if embeddings.shape[-1] < self.dim: # odd dim: pad one zero column
            pad = embeddings.new_zeros(embeddings.shape[0], self.dim - embeddings.shape[-1])
            embeddings = torch.cat((embeddings, pad), dim=-1)
        return embeddings
    
class DenseBlockClass(nn.Module):
    """
        Two-layer dense block: (Linear -> ReLU -> BatchNorm1d) twice, with an
        optional additive projection of the positional embedding `t` between layers.
    """
    def __init__(self,in_dim=10,out_dim=5,pos_emb_dim=10,USE_POS_EMB=True):
        """
            Initialize dense block network with positional embedding
            in_dim      : input feature size
            out_dim     : width of both dense layers (and of the output)
            pos_emb_dim : feature size of the embedding `t` given to forward
            USE_POS_EMB : when False, `t` is ignored in forward
        """
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        self.pos_emb_dim = pos_emb_dim
        self.USE_POS_EMB = USE_POS_EMB
        # Simple block of two layers; creation order is kept fixed (it determines
        # parameter registration and weight-init RNG consumption)
        self.dense1 = nn.Linear(in_dim,out_dim)
        self.bnorm1 = nn.BatchNorm1d(out_dim)
        self.dense2 = nn.Linear(out_dim,out_dim)
        self.bnorm2 = nn.BatchNorm1d(out_dim)
        self.actv   = nn.ReLU()
        # Positional-embedding projection; always built, only used when USE_POS_EMB
        self.pos_emb_mlp = nn.Linear(pos_emb_dim,out_dim)
    def forward(self,x,t):
        """
            Forward
            x: [B x in_dim], t: [B x pos_emb_dim]  ->  [B x out_dim]
        """
        hidden = self.dense1(x)
        hidden = self.bnorm1(self.actv(hidden)) # [B x out_dim]
        if self.USE_POS_EMB:
            hidden = hidden + self.pos_emb_mlp(t) # [B x out_dim]
        hidden = self.actv(self.dense2(hidden))
        return self.bnorm2(hidden) # [B x out_dim]
    
class DenoisingDenseUnetClass(nn.Module):
    """
        Denoising dense (MLP) U-Net for DDPM on flattened inputs of size L*D.
        Step indices are embedded by the 'Pos_Emb' layer (sinusoidal embedding ->
        Linear -> GELU) and the result is fed to every dense block.
    """
    def __init__(self,
                 name        = 'denoising_dense_unet',
                 D           = 3,
                 L           = 100,
                 T           = 1000, # max diffusion steps
                 pos_emb_dim = 127,
                 h_dims      = [128,64],
                 z_dim       = 32,
                 USE_POS_EMB = True,
                 Gammas      = None
                ):
        """
            Initialize denoising dense unet for DDPM
            The input dimension would be L*D
            h_dims : encoder hidden widths (mirrored by the decoder); may be empty
            Gammas : RKHS projection matrices; stored but not yet applied in forward
            NOTE(review): Pos_Emb's Linear expects exactly pos_emb_dim inputs, so the
            sinusoidal embedding must return pos_emb_dim columns even when it is odd
            (the default 127 is odd).
        """
        super(DenoisingDenseUnetClass,self).__init__()
        self.name        = name
        self.D           = D
        self.L           = L
        self.T           = T
        self.x_dim       = self.D * self.L
        self.pos_emb_dim = pos_emb_dim
        self.h_dims      = list(h_dims) # copy: never alias the shared default / caller's list
        self.z_dim       = z_dim
        self.USE_POS_EMB = USE_POS_EMB
        self.Gammas      = Gammas # RKHS projections
        # Initialize layers
        self.init_layers()

    def init_layers(self):
        """
            Initialize layers
        """
        self.layers = nn.ModuleDict()
        # Encoder (x->z)
        h_prev = self.x_dim
        for h_idx,h_dim in enumerate(self.h_dims):
            self.layers['Enc_%02d'%(h_idx)] = DenseBlockClass(
                in_dim=h_prev,out_dim=h_dim,pos_emb_dim=self.pos_emb_dim,
                USE_POS_EMB = self.USE_POS_EMB)
            h_prev = h_dim
        # Last encoder block: h_prev is h_dims[-1], or x_dim when h_dims is empty
        self.layers['Enc_%02d'%(len(self.h_dims))] = DenseBlockClass(
            in_dim=h_prev,out_dim=self.z_dim,pos_emb_dim=self.pos_emb_dim,
            USE_POS_EMB = self.USE_POS_EMB)
        # Map (z->z)
        self.layers['Map'] = DenseBlockClass(
            in_dim=self.z_dim,out_dim=self.z_dim,pos_emb_dim=self.pos_emb_dim,
            USE_POS_EMB = self.USE_POS_EMB)
        # Decoder (z->x); inputs after the first are doubled by the skip concatenation
        h_prev = self.z_dim
        for h_idx,h_dim in enumerate(self.h_dims[::-1]):
            self.layers['Dec_%02d'%(h_idx)] = DenseBlockClass(
                in_dim=h_prev,out_dim=h_dim,pos_emb_dim=self.pos_emb_dim,
                USE_POS_EMB = self.USE_POS_EMB)
            h_prev = 2*h_dim
        self.layers['Dec_%02d'%(len(self.h_dims))] = DenseBlockClass(
            in_dim=h_prev,out_dim=self.x_dim,
            pos_emb_dim=self.pos_emb_dim,USE_POS_EMB = self.USE_POS_EMB)
        # Out
        self.layers['Out'] = DenseBlockClass(
            in_dim=2*self.x_dim,out_dim=self.x_dim,pos_emb_dim=self.pos_emb_dim,
            USE_POS_EMB = self.USE_POS_EMB)
        # Time embedding
        self.layers['Pos_Emb'] = nn.Sequential(
                SinPositionEmbeddingsClass(dim=self.pos_emb_dim,T=self.T),
                nn.Linear(self.pos_emb_dim,self.pos_emb_dim),
                nn.GELU()
            )
    def forward(self,x,t):
        """
            Forward
            x: [B x x_dim]  (x_dim = L*D)
            t: [B] step indices (e.g. torch.long); the sinusoidal embedding indexes
               t[:, None], so t must be 1-D here
            returns [B x x_dim]; intermediate activations are kept in self.nets
        """
        net = x # [B x x_dim]
        # Positional Embedding
        pos_emb = self.layers['Pos_Emb'](t) # [B x pos_emb_dim]
        # Net
        self.nets = {}
        # Encoder 
        self.enc_paths = []
        self.enc_paths.append(net)
        self.nets['x'] = net
        for h_idx in range(len(self.h_dims)+1):
            net = self.layers['Enc_%02d'%(h_idx)](net,pos_emb)
            self.enc_paths.append(net)
            self.nets['Enc_%02d'%(h_idx)] = net
        # Map
        net = self.layers['Map'](net,pos_emb) # [B x z_dim]
        self.nets['Map'] = net # [B x z_dim]
        # Decoder: concatenate each decoder output with the mirrored encoder activation
        self.dec_paths = []
        for h_idx in range(len(self.h_dims)+1):
            net = self.layers['Dec_%02d'%(h_idx)](net,pos_emb)
            net = torch.cat([self.enc_paths[len(self.h_dims)-h_idx],net],dim=1)
            self.dec_paths.append(net)
            self.nets['Dec_%02d'%(h_idx)] = net
        net = self.layers['Out'](net,pos_emb) # [B x LD]
        
        # RKHS projection: NOT implemented yet. self.Gammas (commented as [D x L x L])
        # is stored but not applied, so `net` is returned unprojected.
        # TODO: fix the memory layout of the flattened [B x LD] vector (L-major vs
        # D-major) before applying one Gamma per dimension.
        
        # Return
        self.nets['Out'] = net  # [B x LD]
        return net
    
print("Ready.")  # projecting U-Net classes defined

Ready.


### Model usage

In [8]:
from scipy.spatial import distance
def kernel_se(x1,x2,hyp={'gain':1.0,'len':1.0}):
    """
        Squared-exponential kernel: K[i,j] = gain * exp(-||x1_i - x2_j||^2 / len^2)
        x1 : [N x d], x2 : [M x d]  ->  K : [N x M]
    """
    length_scale = hyp['len']
    sq_dists = distance.cdist(x1/length_scale,x2/length_scale,'sqeuclidean')
    return hyp['gain']*np.exp(-sq_dists)
def get_gamma(times=np.linspace(start=0.0,stop=1.0,num=100).reshape((-1,1)), # [L x 1]
              hyp_len=1.0,reg_coef=1e-8):
    """
        RKHS projection matrix Gamma = K (K + reg_coef*I)^{-1} for a unit-gain
        squared-exponential kernel over `times`.
        times    : [L x 1] input locations
        hyp_len  : kernel length scale
        reg_coef : ridge term added to the diagonal
        returns  : [L x L] numpy array
    """
    L = times.shape[0]
    K = kernel_se(times,times,hyp={'gain':1.0,'len':hyp_len}) # [L x L]
    A = K + reg_coef*np.eye(L,L) # [L x L]
    # Gamma = K A^{-1}  <=>  Gamma^T = A^{-T} K^T. Solving the linear system is more
    # accurate than forming the explicit inverse of the (ill-conditioned) matrix A.
    Gamma = np.linalg.solve(A.T,K.T).T # [L x L]
    return Gamma
print("Ready.")  # kernel helpers defined

Ready.


In [9]:
# Hyper parameters
D,L,T = 3,100,1000
times = np.linspace(start=0.0,stop=1.0,num=100).reshape((-1,1)) # [L x 1]
Gammas_np = np.zeros(shape=(D,L,L))
hyp_lens = [1,0.1,0.01]
for d_idx in range(D):
    hyp_len = hyp_lens[d_idx]
    Gammas_np[d_idx,:,:] = get_gamma(times=times,hyp_len=hyp_len,reg_coef=1e-6) # [L x L]
Gammas = np2torch(Gammas_np) # [D x L x L]
# Instantiate denoising dense U-net class
model = DenoisingDenseUnetClass(
    L=L,D=D,T=T,pos_emb_dim=16,h_dims=[64],z_dim=32,USE_POS_EMB=True,Gammas=Gammas)
# Forward path
B = 5 # batch size
x = torch.randn(B,D*L) # [B x LD]
steps = torch.zeros(B).type(torch.long) # [B]
out = model(x=x,t=steps) # [B x DL]
# Print-out forward path
for key_idx,key in enumerate(model.nets.keys()):
    net = model.nets[key]
    print ("[%02d] [%8s]: %s"%(key_idx,key,net.shape))
print ("Done.")

[00] [       x]: torch.Size([5, 300])
[01] [  Enc_00]: torch.Size([5, 64])
[02] [  Enc_01]: torch.Size([5, 32])
[03] [     Map]: torch.Size([5, 32])
[04] [  Dec_00]: torch.Size([5, 128])
[05] [  Dec_01]: torch.Size([5, 600])
[06] [     Out]: torch.Size([5, 300])
Done.
