In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

def split(x):
    """Cut ``x`` into two pieces along dimension 1.

    The first piece holds the leading ``x.size(1) // 2`` channels and the
    second piece holds everything after them, so with an odd channel count
    the second piece is one channel larger than the first.

    Returns:
        tuple: ``(y1, y2)`` -- the leading and trailing channel slices of ``x``.
    """
    half = x.size(1) // 2
    return x[:, :half], x[:, half:]
    
def merge(y1, y2):
    """Concatenate ``y1`` and ``y2`` along dimension 1, with ``y1`` first."""
    return torch.cat((y1, y2), dim=1)

class Permutation1d(nn.Module):
    """Learned channel-mixing layer implemented as a kernel-size-1 convolution.

    The weight ``W`` has shape ``(c, c)`` and is initialised to a random
    orthonormal matrix whose determinant is positive.
    """

    def __init__(self, c):
        """
        Args:
            c: number of channels; ``W`` is a ``c x c`` matrix.
        """
        super().__init__()

        # Sample a random orthonormal matrix to initialize weights.
        # torch.linalg.qr is the supported replacement for the deprecated torch.qr.
        W = torch.linalg.qr(torch.randn(c, c)).Q

        # Ensure determinant is 1.0 not -1.0 (flipping one column flips the sign).
        if torch.det(W) < 0:
            W[:, 0] = -1 * W[:, 0]
        self.W = nn.Parameter(W)

        # Cached inverse of W; filled by set_inverse() or lazily by inverse().
        self.W_inverse = None

    def forward(self, x):
        """Apply ``W`` across the channel dimension at every time step.

        Args:
            x: tensor of shape ``(batch, channel, time)``.

        Returns:
            tuple: ``(y, dlog_det)`` where ``y`` has the same shape as ``x``
            and ``dlog_det`` equals ``time * log|det W|``.
        """
        _, _, time = x.size()

        log_det_W = torch.slogdet(self.W)[1]
        dlog_det = time * log_det_W
        # W[:, :, None] turns the (c, c) matrix into a (c, c, 1) conv kernel.
        y = F.conv1d(x, self.W[:, :, None])

        return y, dlog_det

    def set_inverse(self):
        """Compute and cache the inverse of the current ``W``.

        The cache is a plain tensor attribute derived from ``W`` at call time,
        so call this again after ``W`` has been updated or the module has been
        moved to another device/dtype.
        """
        self.W_inverse = self.W.inverse()

    def inverse(self, y):
        """Undo ``forward`` using the cached inverse of ``W``.

        If no inverse has been cached yet it is computed on the spot, so the
        method no longer fails when ``set_inverse()`` was not called first.

        Args:
            y: tensor of shape ``(batch, channel, time)``.

        Returns:
            Tensor with the same shape as ``y``.
        """
        # getattr keeps instances created before the attribute existed working.
        if getattr(self, 'W_inverse', None) is None:
            self.set_inverse()
        x = F.conv1d(y, self.W_inverse[:, :, None])

        return x

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class ConditionedConv1d(nn.Module):
    """Gated 1-D convolution applied to an input joined with a conditioning signal.

    A single bias-free ``nn.Conv1d`` maps the channel-wise concatenation of
    ``x`` and ``cond`` to ``2 * out_channels`` channels. The first
    ``out_channels`` of them pass through ``tanh``, the remaining ones through
    ``sigmoid``, and the two results are multiplied element-wise.
    """

    def __init__(self, in_channels, cond_channels, out_channels, kernel_size=3, stride=1, padding=1):
        """
        Args:
            in_channels: channel count of ``x``.
            cond_channels: channel count of ``cond``.
            out_channels: channel count of the gated output.
            kernel_size, stride, padding: forwarded to ``nn.Conv1d``.
        """
        super().__init__()
        self.c = out_channels
        self.filter_out = nn.Conv1d(in_channels=in_channels + cond_channels, out_channels=out_channels * 2,
                                    kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        # Initialise through nn.init rather than mutating ``.data`` directly.
        nn.init.normal_(self.filter_out.weight, mean=0.0, std=0.02)

    def forward(self, x, cond):
        """Return ``tanh(filter) * sigmoid(gate)`` of the convolved input.

        Args:
            x: tensor of shape ``(B, in_channels, T)``.
            cond: tensor of shape ``(B, cond_channels, T)``.

        Returns:
            Tensor with ``out_channels`` channels.
        """
        y = self.filter_out(torch.cat([x, cond], dim=1))
        # First half of the channels is the filter, second half the gate.
        filt, gate = y[:, :self.c], y[:, self.c:]

        return torch.tanh(filt) * torch.sigmoid(gate)

In [3]:
class NonLinear1d(nn.Module):
    """Conditioned network: two gated convolutions and a zero-initialised projection.

    The first ``ConditionedConv1d`` lifts the input to ``hidden_channels``,
    the second is added to its own input as a residual, and ``last_conv``
    projects back to ``channels``.
    """

    def __init__(self, channels, cond_channels, hidden_channels):
        super().__init__()
        entry_conv = ConditionedConv1d(in_channels=channels,
                                       cond_channels=cond_channels,
                                       out_channels=hidden_channels,
                                       kernel_size=3, stride=1, padding=1)
        residual_conv = ConditionedConv1d(in_channels=hidden_channels,
                                          cond_channels=cond_channels,
                                          out_channels=hidden_channels,
                                          kernel_size=1, stride=1, padding=0)
        self.convs = nn.ModuleList([entry_conv, residual_conv])

        self.last_conv = nn.Conv1d(in_channels=hidden_channels, out_channels=channels,
                                   kernel_size=3, stride=1, padding=1, bias=True)
        # Weight and bias both start at zero, so the initial output is all zeros.
        for tensor in (self.last_conv.weight, self.last_conv.bias):
            tensor.data.zero_()

    def forward(self, x, cond):
        """Run the conditioned stack.

        Args:
            x: tensor of shape ``(B, channels, T)``.
            cond: tensor of shape ``(B, cond_channels, T)``.

        Returns:
            Tensor of shape ``(B, channels, T)``.
        """
        layers = iter(self.convs)
        # The first layer changes the channel count, so it has no skip connection.
        hidden = next(layers)(x, cond)
        # Every following layer is applied as a residual update.
        for layer in layers:
            hidden = hidden + layer(hidden, cond)

        return self.last_conv(hidden)

In [4]:
class Flow1d(nn.Module):
    """One flow step: a ``Permutation1d`` followed by an additive coupling.

    After the permutation the channels are split in two; the first part is
    left untouched and, together with ``cond``, drives ``NonLinear1d`` whose
    output is added to the second part.

    NOTE(review): ``NonLinear1d`` is built for ``channels // 2`` channels, so
    with an odd ``channels`` the two parts differ in size and the addition
    either broadcasts or fails -- confirm only even values are intended.
    """

    def __init__(self, channels, cond_channels, hidden_channels):
        super().__init__()
        self.permutation = Permutation1d(channels)
        self.non_linear = NonLinear1d(channels // 2, cond_channels, hidden_channels)

    def forward(self, x, cond):
        """Map ``x`` forward through the step.

        Returns:
            tuple: ``(y, log_det)``; ``log_det`` is the value reported by the
            permutation, since the additive shift contributes nothing to it.
        """
        mixed, log_det = self.permutation(x)
        kept, shifted = split(mixed)
        shifted = shifted + self.non_linear(kept, cond)

        return merge(kept, shifted), log_det

    def set_inverse(self):
        """Ask the permutation to cache the inverse of its weight."""
        self.permutation.set_inverse()

    def inverse(self, y, cond):
        """Undo ``forward``: subtract the shift, then invert the permutation."""
        kept, shifted = split(y)
        # ``kept`` passed through forward unchanged, so the same shift is recomputed here.
        restored = shifted - self.non_linear(kept, cond)

        return self.permutation.inverse(merge(kept, restored))
    

In [73]:
class FlowModel(nn.Module):
    """Stack of ``Flow1d`` steps trained with a standard-normal likelihood.

    ``forward`` maps data to latents and returns the loss, ``inverse`` maps
    latents back to data, and ``inference`` draws random latents and inverts
    them.
    """

    def __init__(self, channels, cond_channels, hidden_channels, n_layers):
        """
        Args:
            channels: channel count of the modelled signal.
            cond_channels: channel count of the conditioning signal.
            hidden_channels: width of each step's coupling network.
            n_layers: number of ``Flow1d`` steps.
        """
        super().__init__()
        self.channels = channels
        self.flow_layers = nn.ModuleList([Flow1d(channels, cond_channels, hidden_channels) for _ in range(n_layers)])
        # True once inference() has computed the layer inverses at least once.
        self.inverse_init = False

    def forward(self, x, cond):
        """Push ``x`` through every step and compute the loss.

        Args:
            x: tensor of shape ``(B, channels, T)``.
            cond: tensor of shape ``(B, cond_channels, T)``.

        Returns:
            dict with keys ``'z'`` (latent), ``'log_det'`` (sum of the
            per-step log-determinants) and ``'loss'``.
        """
        z = x
        log_det = 0
        for flow_layer in self.flow_layers:
            z, dlog_det = flow_layer(z, cond)
            log_det = log_det + dlog_det

        loss = self.get_loss(z, log_det)
        data = {'z': z,
                'log_det': log_det,
                'loss': loss
               }
        return data

    def get_loss(self, z, log_det):
        """Mean negative log-likelihood per element of ``z``.

        The log-likelihood of each sample is the standard-normal log-density
        of ``z`` summed over dimensions 1 and 2, plus ``log_det``.
        """
        dim = z.size(1) * z.size(2)
        # Plain Python float so the tensor, not a numpy scalar, drives the arithmetic.
        log_2pi = float(np.log(2 * np.pi))
        log_likelihood = torch.sum(-0.5 * (z ** 2 + log_2pi), dim=(1, 2)) + log_det
        loss = torch.mean(-log_likelihood / dim)
        return loss

    def inference(self, cond):
        """Sample by inverting standard-normal noise through the flow.

        Args:
            cond: tensor of shape ``(B, cond_channels, T)``.

        Returns:
            Tensor of shape ``(B, channels, T)``.
        """
        # Refresh the cached inverses on every call: the weights may have been
        # updated (or moved to another device) since the previous call, and a
        # one-time cache would silently go stale.
        self.set_inverse()
        self.inverse_init = True

        # Sample the latent on cond's device, and in its dtype when that is a
        # floating-point type, instead of always on CPU in the default dtype.
        z = torch.randn(cond.shape[0], self.channels, cond.shape[2],
                        device=cond.device,
                        dtype=cond.dtype if cond.is_floating_point() else None)
        x = self.inverse(z, cond)
        return x

    def inverse(self, z, cond):
        """Map latent ``z`` back to data space, running the steps in reverse order.

        Each step uses its cached inverse weight; call ``set_inverse()`` after
        the weights change so that cache is current (``inference`` does so).
        """
        x = z
        for flow_layer in reversed(self.flow_layers):
            x = flow_layer.inverse(x, cond)
        return x

    def set_inverse(self):
        """Have every step cache the inverse of its permutation weight."""
        for flow_layer in self.flow_layers:
            flow_layer.set_inverse()

In [74]:
# Build a 4-step flow for 62-channel inputs with 16 conditioning channels and
# cache the inverse weights that flow.inverse() relies on.
flow = FlowModel(channels=62, cond_channels=16, hidden_channels=128, n_layers=4)
flow.set_inverse()

# Random batch: 2 sequences, 100 time steps each.
x = torch.randn((2, 62, 100))
cond = torch.randn((2, 16, 100))

# Forward pass to the latent, then map the latent back to input space.
outputs = flow(x, cond)
x2 = flow.inverse(outputs['z'], cond)

In [80]:
# Show the input batch, for comparison with the reconstruction x2 displayed in a later cell.
x

tensor([[[ 0.9094, -0.9566, -0.4663,  ...,  0.9501,  1.2360,  2.0946],
         [ 0.6283,  0.6162,  0.4864,  ...,  1.2429,  0.3654, -3.7170],
         [-0.0999, -0.6048,  0.1541,  ...,  0.1173,  0.7031, -1.1231],
         ...,
         [-2.0162,  0.8025, -0.5839,  ..., -1.7456, -0.3185,  1.1787],
         [ 0.0261,  0.4257,  0.7985,  ..., -0.3326, -2.0682,  3.2813],
         [-1.0040,  0.3287,  1.1457,  ..., -1.0747, -0.6557, -0.9607]],

        [[ 1.6326, -0.3635,  0.5933,  ...,  0.1177,  0.8728,  0.9375],
         [ 2.5529,  0.0835, -0.2958,  ..., -0.5103, -0.3954, -0.7982],
         [ 0.1695,  1.9433, -0.6700,  ...,  0.8310,  0.1287, -0.3896],
         ...,
         [-0.8056, -1.3714,  1.0232,  ...,  0.2947, -0.1565, -1.2067],
         [ 0.6961,  2.1302,  0.1545,  ..., -0.6888, -0.6213,  0.6619],
         [ 0.8564, -0.0458, -0.5370,  ...,  1.0753, -1.5578, -2.3208]]])

In [75]:
# Show the reconstruction obtained by inverting the flow on outputs['z'].
x2

tensor([[[ 0.9094, -0.9566, -0.4663,  ...,  0.9501,  1.2360,  2.0946],
         [ 0.6283,  0.6162,  0.4864,  ...,  1.2429,  0.3654, -3.7170],
         [-0.0999, -0.6048,  0.1541,  ...,  0.1173,  0.7031, -1.1231],
         ...,
         [-2.0162,  0.8025, -0.5839,  ..., -1.7456, -0.3185,  1.1787],
         [ 0.0261,  0.4257,  0.7985,  ..., -0.3326, -2.0682,  3.2813],
         [-1.0040,  0.3287,  1.1457,  ..., -1.0747, -0.6557, -0.9607]],

        [[ 1.6326, -0.3635,  0.5933,  ...,  0.1177,  0.8728,  0.9375],
         [ 2.5529,  0.0835, -0.2958,  ..., -0.5103, -0.3954, -0.7982],
         [ 0.1695,  1.9433, -0.6700,  ...,  0.8310,  0.1287, -0.3896],
         ...,
         [-0.8056, -1.3714,  1.0232,  ...,  0.2947, -0.1565, -1.2067],
         [ 0.6961,  2.1302,  0.1545,  ..., -0.6888, -0.6213,  0.6619],
         [ 0.8564, -0.0458, -0.5370,  ...,  1.0753, -1.5578, -2.3208]]],
       grad_fn=<ConvolutionBackward0>)

In [76]:
# Round-trip check over the whole tensor with torch.allclose's default tolerances.
# NOTE(review): the recorded result is False although x and x2 print identically
# to four decimals -- presumably float32 round-off exceeding the strict defaults;
# an explicit atol (e.g. 1e-5) may be the intended check -- TODO confirm.
torch.allclose(x, x2)

False

In [77]:
# Same closeness check restricted to the single element at index [0, 0, 0].
torch.allclose(x[0, 0, 0], x2[0, 0, 0])

True

In [78]:
# Show that element from the input and from the reconstruction side by side.
x[0, 0, 0], x2[0, 0, 0]

(tensor(0.9094), tensor(0.9094, grad_fn=<SelectBackward0>))

In [81]:
# Loss returned by the forward pass (mean negative log-likelihood per element, see FlowModel.get_loss).
outputs['loss']

tensor(1.4218, grad_fn=<MeanBackward0>)