In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

from flow import ResidualCouplingBlock
from cbhg import CBHG

class VAE(nn.Module):
    def __init__(self, in_dim, z_dim):
        super().__init__()
        self.encoder = CBHG(in_dim, z_dim*2)
        self.decoder = CBHG(z_dim, in_dim)

    def forward(self, y):
        # y : (b, c, t)
        
        z_params = self.encoder(y)
        z_mean, z_logstd = z_params[:, :z_params.shape[1]//2], z_params[:, z_params.shape[1]//2:]
        z_sample = z_mean + torch.randn_like(z_logstd)*z_logstd.exp()
        y_pred = self.decoder(z_sample)
        recon_loss = F.l1_loss(y, y_pred)
        kl_loss = torch.mean(-z_logstd + 0.5 * (z_logstd.exp() ** 2 + z_mean ** 2) - 0.5)
        data = {'recon_loss': recon_loss,
                'kl_loss': kl_loss,
                'vae_loss': recon_loss + kl_loss,
                'z_sample': z_sample,
                'y_pred': y_pred,
               }
        return data
    
    def inference(self, z_sample):
        y_pred = self.decoder(y_sample)
        return y_pred
        
class Model(nn.Module):
    def __init__(self, in_dim, z_dim, out_dim):
        super().__init__()
        self.out_dim = out_dim
        self.z_dim = z_dim
        self.vae = VAE(out_dim, z_dim)
        self.encoder = CBHG(in_dim, z_dim*2)
        self.decoder = ResidualCouplingBlock(z_dim, 128, 5, 1, 4, gin_channels=0)
        
    def forward(self, x, y):
        # x : (b, c, t)
        # y : (b, c, t)
        
        vae_data = self.vae(y)
        enc = self.encoder(x)
        print(enc.shape)
        z_mean, z_logstd = enc.split(self.z_dim, dim=1)
        
        z_mask = torch.ones(x.shape[0], 1, 1).to(x.device)
        z, _ = self.decoder(vae_data['z_sample'], z_mask)
        flow_loss = torch.mean(z_logstd + 0.5 * ((z_mean - z) / z_logstd.exp()) ** 2)
        
        data = {'flow_loss': flow_loss,
                'total_loss': flow_loss + vae_loss,
               }
        data.update(vae_data)
        return data
    
    def inference(self, x):
        z_mean, z_logstd = self.encoder(x).split(self.z_dim, dim=1)
        z_sample = z_mean + torch.randn_like(z_logstd)*z_logstd.exp()
        y_pred = self.decoder(y_sample)
        return y_pred

In [20]:
model = Model(8, 16, 61)
x = torch.randn(2, 8, 100)
y = torch.randn(2, 61, 100)
data = model(x, y)

torch.Size([2, 32, 100])


In [21]:
data.keys()

dict_keys(['flow_loss', 'recon_loss', 'kl_loss', 'vae_loss', 'z_sample', 'y_pred'])

In [22]:
data['flow_loss']

tensor(0.4896, grad_fn=<MeanBackward0>)

In [23]:
data['recon_loss']

tensor(0.7981, grad_fn=<L1LossBackward0>)

In [24]:
data['kl_loss']

tensor(0.0033, grad_fn=<MeanBackward0>)