In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

def split(x):
    """Cut ``x`` in two along dimension 1.

    The first piece holds the leading ``x.size(1) // 2`` entries of that
    dimension and the second piece holds everything after them, so the
    second piece is one entry larger when the size is odd.
    """
    half = x.shape[1] // 2
    return x[:, :half], x[:, half:]
    
def merge(y1, y2):
    """Concatenate ``y1`` followed by ``y2`` along dimension 1."""
    return torch.cat((y1, y2), dim=1)

class Permutation1d(nn.Module):
    """Learned channel-mixing layer implemented as a kernel-size-1 convolution.

    The weight ``W`` is a square ``(c, c)`` parameter initialised to a random
    orthonormal matrix with positive determinant.
    """

    def __init__(self, c):
        """Create the layer for inputs with ``c`` channels."""
        super().__init__()

        # Sample a random orthonormal matrix to initialize weights.
        # torch.linalg.qr replaces the deprecated torch.qr, and torch.randn
        # draws the Gaussian matrix directly instead of filling an
        # uninitialised FloatTensor in place.
        W, _ = torch.linalg.qr(torch.randn(c, c))

        # Ensure determinant is 1.0 not -1.0: negating one column flips the sign.
        if torch.det(W) < 0:
            W[:, 0] = -W[:, 0]
        self.W = nn.Parameter(W)

    def forward(self, x):
        """Mix the channels of ``x`` with ``W``.

        Args:
            x: tensor of shape (batch, channel, time).

        Returns:
            A pair ``(y, dlog_det)``. ``y`` is ``x`` convolved with ``W`` as a
            kernel-size-1 filter; ``dlog_det`` is ``time * log|det W|``.
        """
        # Unpacking three values also rejects inputs that are not 3-D.
        batch, channel, time = x.size()

        log_det_W = torch.slogdet(self.W)[1]
        dlog_det = time * log_det_W
        y = F.conv1d(x, self.W[:, :, None])

        return y, dlog_det

    def set_inverse(self):
        """Compute and cache the inverse of the current ``W`` as ``W_inverse``.

        The cached matrix is not refreshed automatically; call this again
        after ``W`` has been updated.
        """
        self.W_inverse = self.W.inverse()

    def inverse(self, y):
        """Undo :meth:`forward` by convolving ``y`` with the cached inverse of ``W``.

        If :meth:`set_inverse` has not been called yet, the inverse is computed
        on first use instead of failing with ``AttributeError``.
        """
        if getattr(self, "W_inverse", None) is None:
            self.set_inverse()

        return F.conv1d(y, self.W_inverse[:, :, None])

  from .autonotebook import tqdm as notebook_tqdm


In [15]:
class Conv1d(nn.Module):
    """Gated 1-D convolution.

    A single bias-free convolution produces ``2 * out_channels`` feature maps.
    The first half goes through ``tanh``, the second half through ``sigmoid``,
    and the two are multiplied element-wise.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        """Build the convolution and draw its weights from N(0, 0.02**2)."""
        super().__init__()
        self.c = out_channels
        self.filter_out = nn.Conv1d(in_channels=in_channels, out_channels=out_channels * 2,
                                    kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        nn.init.normal_(self.filter_out.weight, mean=0.0, std=0.02)

    def forward(self, x):
        """Apply the gated convolution.

        Args:
            x: tensor of shape (B, C, T).

        Returns:
            Tensor with ``out_channels`` channels.
        """
        y = self.filter_out(x)
        filt, gate = y[:, :self.c], y[:, self.c:]
        # torch.tanh matches the torch.sigmoid call next to it rather than
        # mixing in the functional alias.
        return torch.tanh(filt) * torch.sigmoid(gate)

In [16]:
class NonLinear1d(nn.Module):
    """Small convolutional stack: two gated convolutions and a plain output convolution.

    The first gated convolution maps ``channels`` to ``hidden_channels``; the
    second keeps ``hidden_channels`` and is added to its own input. The
    output convolution maps back to ``channels`` and starts with all-zero
    weight and bias, so a freshly built module outputs zeros.
    """

    def __init__(self, channels, hidden_channels):
        super().__init__()
        input_conv = Conv1d(in_channels=channels, out_channels=hidden_channels,
                            kernel_size=3, stride=1, padding=1)
        residual_conv = Conv1d(in_channels=hidden_channels, out_channels=hidden_channels,
                               kernel_size=1, stride=1, padding=0)
        self.convs = nn.ModuleList([input_conv, residual_conv])

        self.last_conv = nn.Conv1d(in_channels=hidden_channels, out_channels=channels,
                                   kernel_size=3, stride=1, padding=1, bias=True)
        nn.init.zeros_(self.last_conv.weight)
        nn.init.zeros_(self.last_conv.bias)

    def forward(self, x):
        """Run ``x`` of shape (B, C, T) through the stack."""
        first, *rest = self.convs

        # The first layer replaces the input; every later one is a residual update.
        hidden = first(x)
        for conv in rest:
            hidden = hidden + conv(hidden)

        return self.last_conv(hidden)

In [17]:
class Flow1d(nn.Module):
    """One flow step: channel mixing followed by an additive update of the second half.

    After ``Permutation1d`` mixes the channels, the tensor is split in two
    along the channel dimension. The first half is passed through
    ``NonLinear1d`` and the result is added to the second half. The network
    is built for ``channels // 2`` channels, so the two halves only line up
    when ``channels`` is even.
    """

    def __init__(self, channels, hidden_channels):
        super().__init__()
        self.permutation = Permutation1d(channels)
        self.non_linear = NonLinear1d(channels // 2, hidden_channels)

    def forward(self, x):
        """Return ``(y, log_det)`` for input ``x``.

        ``log_det`` is exactly the value reported by the permutation; the
        additive update adds nothing to it.
        """
        mixed, log_det = self.permutation(x)
        kept, shifted = split(mixed)
        shifted = shifted + self.non_linear(kept)

        return merge(kept, shifted), log_det

    def set_inverse(self):
        """Ask the permutation to cache its inverse matrix."""
        self.permutation.set_inverse()

    def inverse(self, y):
        """Undo :meth:`forward`: subtract the shift, then invert the permutation."""
        kept, shifted = split(y)
        # The first half passed through forward() unchanged, so the same
        # shift can be recomputed from it and subtracted.
        shifted = shifted - self.non_linear(kept)
        restored = merge(kept, shifted)

        return self.permutation.inverse(restored)
    

In [18]:
class FlowModel(nn.Module):
    """Stack of ``Flow1d`` layers with a standard-normal log-likelihood loss."""

    def __init__(self, channels, hidden_channels, n_layers):
        super().__init__()
        self.channels = channels
        layers = [Flow1d(channels, hidden_channels) for _ in range(n_layers)]
        self.flow_layers = nn.ModuleList(layers)
        # Set to True the first time inference() caches the layer inverses.
        self.inverse_init = False

    def forward(self, x):
        """Push ``x`` through every layer and return a dict with 'z', 'log_det' and 'loss'."""
        z, log_det = x, 0
        for layer in self.flow_layers:
            z, layer_log_det = layer(z)
            log_det = log_det + layer_log_det

        return {'z': z,
                'log_det': log_det,
                'loss': self.get_loss(z, log_det)}

    def get_loss(self, z, log_det):
        """Negative log-likelihood per element, averaged over the batch.

        The log-density of ``z`` under a unit Gaussian is summed over
        dimensions 1 and 2, ``log_det`` is added, and the negated result is
        divided by ``z.size(1) * z.size(2)`` before taking the mean.
        """
        n_elements = z.size(1) * z.size(2)
        log_2pi = np.log(2*np.pi)
        per_sample = torch.sum(-0.5 * (log_2pi + z**2), dim=(1, 2)) + log_det
        return torch.mean(-per_sample / n_elements)

    def inference(self, z):
        """Map ``z`` back through the flow, caching layer inverses on the first call only."""
        if not self.inverse_init:
            # Flag is raised before the inverses are computed, as in the
            # original ordering.
            self.inverse_init = True
            self.set_inverse()

        return self.inverse(z)

    def inverse(self, z):
        """Apply each layer's ``inverse`` in reverse layer order."""
        out = z
        for layer in reversed(self.flow_layers):
            out = layer.inverse(out)
        return out

    def set_inverse(self):
        """Have every layer cache its inverse."""
        for layer in self.flow_layers:
            layer.set_inverse()

In [21]:
# Round-trip check: push a random batch through the flow, then invert the result.
# FlowModel arguments are channels=16, hidden_channels=128, n_layers=8.
model = FlowModel(16, 128, 8)
# Input layout is (batch, channel, time).
x = torch.randn(2, 16, 100)
outputs = model(x)
print(x)
# Cache every layer's inverse mixing matrix before calling inverse().
model.set_inverse()
x_recon = model.inverse(outputs['z'])
# Printed so it can be compared by eye with x above.
print(x_recon)

tensor([[[ 0.5605,  0.9309,  0.0881,  ...,  1.5359,  0.0424,  1.2882],
         [ 2.1455,  1.1861, -0.2822,  ..., -2.5990,  0.1391,  0.4910],
         [ 0.4745,  1.5509, -0.8069,  ...,  0.0503,  0.8072, -0.6042],
         ...,
         [-0.8192,  0.8339,  1.2938,  ..., -1.1961, -0.8061, -0.2306],
         [ 0.3780,  0.5642,  0.8813,  ..., -0.0719, -0.4430, -0.5078],
         [-0.1368,  1.5311, -0.7713,  ..., -1.3084, -0.7429,  0.7916]],

        [[-1.0273, -1.1636, -0.4735,  ...,  0.4123,  1.1214, -0.6715],
         [-0.9444, -0.1015,  0.5434,  ...,  1.0050, -0.2346,  0.8794],
         [-0.6714,  0.1147,  0.4112,  ..., -0.7877, -0.4188,  0.5138],
         ...,
         [ 0.2146,  1.9117,  0.5991,  ..., -0.1414,  0.4215,  0.7217],
         [ 0.5057,  2.2707,  0.6765,  ..., -1.7742,  1.4327, -2.1872],
         [ 0.7630, -0.6310, -0.0527,  ...,  1.7685,  0.4119, -1.2751]]])
tensor([[[ 0.5605,  0.9309,  0.0881,  ...,  1.5359,  0.0424,  1.2882],
         [ 2.1455,  1.1861, -0.2822,  ..., -2