In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from flow import ResidualCouplingBlock

In [2]:
inter_channels = 128
hidden_channels = 256
gin_channels = 0
flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
flow

ResidualCouplingBlock(
  (flows): ModuleList(
    (0): ResidualCouplingLayer(
      (pre): Conv1d(64, 256, kernel_size=(1,), stride=(1,))
      (enc): WN(
        (in_layers): ModuleList(
          (0): Conv1d(256, 512, kernel_size=(5,), stride=(1,), padding=(2,))
          (1): Conv1d(256, 512, kernel_size=(5,), stride=(1,), padding=(2,))
          (2): Conv1d(256, 512, kernel_size=(5,), stride=(1,), padding=(2,))
          (3): Conv1d(256, 512, kernel_size=(5,), stride=(1,), padding=(2,))
        )
        (res_skip_layers): ModuleList(
          (0): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
          (1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
          (2): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
          (3): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
        )
        (drop): Dropout(p=0, inplace=False)
      )
      (post): Conv1d(256, 64, kernel_size=(1,), stride=(1,))
    )
    (1): Flip()
    (2): ResidualCouplingLayer(
      (pre): Conv1d(64, 25

In [3]:
x = torch.randn(2, 128, 100)
x_mask = torch.ones(x.shape[0], 1, 1)
y, logdet = flow(x, x_mask=x_mask)
print(y.shape, logdet)

x_recon = flow(y, x_mask=x_mask, reverse=True)
x == x_recon

torch.Size([2, 128, 100]) tensor([0., 0.])


tensor([[[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

        [[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]]])

In [46]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

from flow import ResidualCouplingBlock
from cbhg import CBHG

class Model(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.out_dim = out_dim
        self.z_dim = out_dim if out_dim % 2 == 0 else out_dim + 1
        self.encoder = CBHG(in_dim, self.z_dim)
        self.decoder = ResidualCouplingBlock(self.z_dim, 256, 5, 1, 4, gin_channels=0)
        
    def get_loss(self, z, logdet):
        dim = z.size(1) * z.size(2)
        nll = -(torch.sum(-0.5 * (np.log(2*np.pi) + z**2), dim=[1, 2]) + logdet)
        loss = torch.mean(nll / dim)
        return loss
            
    def forward(self, x, y):
        # x : (b, c, t)
        # y : (b, c, t)
        
        if y.shape[1] != self.z_dim:
            y = F.pad(y, (0, 0, 0, 1))
        
        y_mask = torch.ones(y.shape[0], 1, 1).to(x.device)
        z, logdet = self.decoder(y, y_mask)
        flow_loss = self.get_loss(z, logdet)
        
        z_pred = self.encoder(x)
        reg_loss = F.l1_loss(z_pred, z.detach())
        
        data = {'flow_loss': flow_loss,
                'reg_loss': reg_loss,
                'total_loss': flow_loss + reg_loss}
        return data
    
    def inference(self, x):
        z = self.encoder(x)
        y_mask = torch.ones(z.shape[0], 1, 1).to(x.device)
        y = self.decoder(z, y_mask, reverse=True)
        y = y[:, :self.out_dim]
        return y

In [47]:
model = Model(8, 65)
model

Model(
  (encoder): CBHG(
    (conv_bank): ModuleList(
      (0): Sequential(
        (0): ConstantPad1d(padding=(0, 0), value=0.0)
        (1): Conv1d(8, 256, kernel_size=(1,), stride=(1,))
        (2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU()
      )
      (1): Sequential(
        (0): ConstantPad1d(padding=(0, 1), value=0.0)
        (1): Conv1d(8, 256, kernel_size=(2,), stride=(1,))
        (2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU()
      )
      (2): Sequential(
        (0): ConstantPad1d(padding=(1, 1), value=0.0)
        (1): Conv1d(8, 256, kernel_size=(3,), stride=(1,))
        (2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU()
      )
      (3): Sequential(
        (0): ConstantPad1d(padding=(1, 2), value=0.0)
        (1): Conv1d(8, 256, kernel_size=(4,), stride=(1,))
        (2): BatchNorm1d(256, eps=1

In [48]:
x = torch.randn(2, 8, 100)
y = torch.randn(2, 65, 100)
model(x, y)

{'flow_loss': tensor(1.4175, grad_fn=<MeanBackward0>),
 'reg_loss': tensor(0.7934, grad_fn=<L1LossBackward0>),
 'total_loss': tensor(2.2109, grad_fn=<AddBackward0>)}

In [49]:
model.inference(x)

tensor([[[-0.0604, -0.0547, -0.0443,  ..., -0.0587, -0.0729, -0.0769],
         [-0.0002,  0.0245,  0.0423,  ...,  0.0720,  0.0446,  0.0189],
         [-0.0461, -0.0383, -0.0422,  ..., -0.0041,  0.0062,  0.0218],
         ...,
         [ 0.0605,  0.0596,  0.0820,  ...,  0.0508,  0.0610,  0.0498],
         [ 0.0673,  0.0502,  0.0612,  ...,  0.0469,  0.0535,  0.0508],
         [ 0.0708,  0.1026,  0.1132,  ...,  0.1429,  0.1508,  0.1366]],

        [[-0.0429, -0.0666, -0.0842,  ..., -0.0367, -0.0663, -0.0831],
         [ 0.0107,  0.0226, -0.0137,  ...,  0.0703,  0.0957,  0.0437],
         [-0.0184, -0.0129, -0.0038,  ...,  0.0316,  0.0266,  0.0097],
         ...,
         [ 0.0764,  0.0643,  0.0632,  ...,  0.0952,  0.0859,  0.0814],
         [ 0.0228,  0.0371,  0.0323,  ...,  0.0541,  0.0735,  0.0603],
         [ 0.0952,  0.1177,  0.1132,  ...,  0.1255,  0.1221,  0.0855]]],
       grad_fn=<SliceBackward0>)