<a href="https://colab.research.google.com/github/snaiws/MusicVAE/blob/main/MusicVAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MusicVAE

## import

In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
import pandas as pd
import numpy as np
import os
import sys
import time
import random
import music21
import pickle

import torch
import torch.nn.functional as F
from torch import nn
import torch.optim as optim
from time import time

In [None]:
# device type
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"# available GPUs : {torch.cuda.device_count()}")
    print(f"GPU name : {torch.cuda.get_device_name()}")
else:
    device = torch.device("cpu")
print(device)

Found GPU at: /device:GPU:0


In [2]:
# random seed
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # if use multi-GPU
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
# data load
train = tfds.load(
    name="groove/full-midionly",
    split=tfds.Split.TRAIN,
    try_gcs=True)
test = tfds.load(
    name="groove/full-midionly",
    split=tfds.Split.TEST,
    try_gcs=True)
val = tfds.load(
    name="groove/full-midionly",
    split=tfds.Split.VALIDATION,
    try_gcs=True)
train = train.shuffle(1024).batch(1).prefetch(tf.data.experimental.AUTOTUNE)
val = val.shuffle(1024).batch(1).prefetch(tf.data.experimental.AUTOTUNE)
test = test.shuffle(1024).batch(1).prefetch(tf.data.experimental.AUTOTUNE)
train = [features["midi"].numpy()[0] for features in train.take(len(train))]
val = [features["midi"].numpy()[0] for features in val.take(len(val))]
test = [features["midi"].numpy()[0] for features in test.take(len(test))]

In [4]:
len(train),len(val),len(test)

(897, 124, 129)

## Preprocess

In [9]:
pitch_conv={
    'B-flat0':(22,'IDK1'),
    'D1':(26,'IDK2'),
    'C2':(36,'Bass Drum 1'),
    'C-sharp2':(37,'Side Stick'),
    'D2':(38,'Acoustic Snare'),
    'E2':(40,'Electric Snare'),
    'F-sharp2':(42,'Closed Hi-Hat'),
    'G2':(43,'High Floor Tom'),
    'G-sharp2':(44,'Pedal Hi-Hat'),
    'A2':(45,'Low Tom'),
    'B-flat2':(46,'Open Hi-Hat'),
    'B2':(47,'Low-Mid Tom'),
    'C3':(48,'Hi-Mid Tom'),
    'C-sharp3':(49,'Crash Cymbal 1'),
    'D3':(50,'High Tom'),
    'E-flat3':(51,'Ride Cymbal 1'),
    'E3':(52,'Chinese Cymbal'),
    'F3':(53,'Ride Bell'),
    'G3':(55,'Splash Cymbal'),
    'A3':(57,'Crash Cymbal 2'),
    'B-flat3':(58,'Vibraslap'),
    'B3':(59,'Ride Cymbal 2')
}

In [22]:
pitch_order = {'A2': 9,
 'A3': 19,
 'B-flat0': 0,
 'B-flat2': 10,
 'B-flat3': 20,
 'B2': 11,
 'B3': 21,
 'C-sharp2': 3,
 'C-sharp3': 13,
 'C2': 2,
 'C3': 12,
 'D1': 1,
 'D2': 4,
 'D3': 14,
 'E-flat3': 15,
 'E2': 5,
 'E3': 16,
 'F-sharp2': 6,
 'F3': 17,
 'G-sharp2': 8,
 'G2': 7,
 'G3': 18}

In [72]:
def preprocess(midi):
    midi = music21.converter.parse(midi)
    midi = list(midi.recurse())
    res = []
    for obj in midi:
      try:
        if obj.isRest:
          for i in range(int(obj.beat*36)):
            res.append([0 for _ in range(len(pitch_conv))])
        elif obj.isNote:
          a = [0 for _ in range(len(pitch_conv))]
          b = obj.fullName.split()
          b = b[0]+b[3]
          a[pitch_order[b]]=1
          res.append(a)
          for i in range(int(obj.beat*36)-1):
            res.append([0 for _ in range(len(pitch_conv))])
        elif obj.isChord:
          a = [0 for _ in range(len(pitch_conv))]
          b = obj.fullName
          b = b[b.find('{')+1:b.find('}')]
          b= b.split('|')
          for i in b:
            c=i.split(0)+i.split(3)
            a[pitch_order[c]]=1
          res.append(a)
          for i in range(int(obj.beat*36)-1):
            res.append([0 for _ in range(len(pitch_conv))])
      except:
        try:
          if str(obj).split('.')[1]=='meter' and str(obj).split('.')[-1]!='TimeSignature 4/4>':
            return 0
        except:
          continue
    return res

In [73]:
a=preprocess(train[5])

In [77]:
sum(list(map(sum,a)))

16

## embedding

## utils

In [None]:
def vae_loss(recon_x, x, mu, std, beta=0):
    logvar = std.pow(2).log()
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    return BCE + (beta * KLD)


def accuracy(y_true, y_pred):
    y_true = torch.argmax(y_true, axis=2)
    total_num = y_true.shape[0] * y_true.shape[1]
    
    return torch.sum(y_true == y_pred) / total_num

In [None]:
def inverse_sigmoid(epoch, k=20):
    return k / (k + np.exp(epoch/k))

def kl_annealing(epoch, start, end, rate=0.9):
    return end + (start - end)*(rate)**epoch

## Model

In [None]:
class Encoder(nn.Module):
    """MusicVAE Encoder"""
        
    def __init__(self, input_size, hidden_size, latent_dim, num_layers=1, bidirectional=True):
        """Initialize class
     
        Parameters
        ----------
        input_size : dim of input sequence
        hidden_size : LSTM hidden size
        latent_dim : dim of latent z
        num_layers : the number of LSTM layers
        bidirectional : True or False
        """
            
        super(Encoder, self).__init__()
        
        if bidirectional == True:
            num_directions = 2
        else:
            num_directions = 1
            
        self.hidden_size = hidden_size
        self.num_hidden = num_directions * num_layers
        self.final_size = self.num_hidden * hidden_size
        
        self.lstm = nn.LSTM(batch_first=True,
                            input_size=input_size,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            bidirectional=bidirectional)
        
        self.mu = nn.Linear(self.final_size, latent_dim)
        self.std = nn.Linear(self.final_size, latent_dim)
        self.norm = nn.LayerNorm(latent_dim, elementwise_affine=False)
        
    def encode(self, x):
        """
        Parameters
        ----------
        x : input sequecne (batch, seq, feat)
        
        Returns
        -------
        z : latent z (batch, latent_dim)
        mu : mu (batch, latent_dim)
        std : std (batch, latent_dim)
        """
        
        x, (h, c) = self.lstm(x)
        h = h.transpose(0, 1).reshape(-1, self.final_size)
        
        mu = self.norm(self.mu(h))
        std = nn.Softplus()(self.std(h))
        
        # reparam
        z = self.reparameterize(mu, std)
        
        return z, mu, std
    
    def reparameterize(self, mu, std):
        """
        Parameters
        ----------
        mu : mu (batch, latent_dim)
        std : std (batch, latent_dim)
        
        Returns
        -------
        z : reparam latent z (batch, latent_dim)
        """
            
        eps = torch.randn_like(std)

        return mu + (eps * std)
    
    def forward(self, x):
        """
        Parameters
        ----------
        x : input sequence (batch, seq, feat)
        
        Returns
        -------
        z : reparam latent z (batch, latent_dim)
        mu : mu (batch, latent_dim)
        std : std (batch, latent_dim)
        """
            
        z, mu, std = self.encode(x)
        
        return z, mu, std
    

class Decoder(nn.Module):
    """MusicVAE Decoder"""
    
    def __init__(self, input_size, hidden_size, output_size, num_layers=2, bidirectional=False):
        """Initialize class
     
        Parameters
        ----------
        input_size : dim of input sequence
        hidden_size : dim of LSTM hidden size
        output_size : dim of output sequence
        num_layers : the number of LSTM layers
        bidirectional : True or False
        """
            
        super(Decoder, self).__init__()
        
        if bidirectional == True:
            num_directions = 2
        else:
            num_directions = 1
        
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_hidden = num_directions * num_layers

        self.logits= nn.Linear(hidden_size, output_size)
        self.decoder = nn.LSTM(batch_first=True,
                               input_size=input_size+output_size,
                               hidden_size=hidden_size,
                               num_layers=num_layers,
                               bidirectional=bidirectional)
        
    def forward(self, x, h, c, temp=1):
        """
        Parameters
        ----------
        x : input sequence (batch, 1, feat)
        h : LSTM state (num_hidden, batch, hidden_size)
        c : LSTM cell (num_hidden, batch, hidden_size)
        temp : temperature of softmax
        
        Returns
        -------
        out : predicted label (batch, 1, output_size)
        prob: predicted prob (batch, 1, output_size)
        h : LSTM next state
        c : LSTM next cell
        """
        
        x, (h, c) = self.decoder(x, (h, c))
        logits = self.logits(x) / temp
        prob = nn.Softmax(dim=2)(logits)
        out = torch.argmax(prob, 2)
                
        return out, prob, h, c

    
class Conductor(nn.Module):
    """MusicVAE Conductor"""
    
    def __init__(self, input_size, hidden_size, device, num_layers=2, bidirectional=False, bar=4):
        """Initialize class
     
        Parameters
        ----------
        input_size : dim of input sequence
        hidden_size : dim of LSTM hidden size
        output_size : dim of output sequence
        num_layers : the number of LSTM layers
        bidirectional : True or False
        bar : the number of units in bar
        """
        
        super(Conductor, self).__init__()

        if bidirectional == True:
            num_directions = 2
        else:
            num_directions = 1

        self.bar = bar
        self.device = device

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_hidden = num_directions * num_layers
        
        self.norm = nn.BatchNorm1d(input_size)
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.conductor = nn.LSTM(batch_first=True,
                                 input_size=input_size,
                                 hidden_size=hidden_size,
                                 num_layers=num_layers,
                                 bidirectional=bidirectional)
        
    def init_hidden(self, batch_size, z):
        h0 = z.repeat(self.num_hidden, 1, 1)
        c0 = z.repeat(self.num_hidden, 1, 1)

        return h0, c0
    
    def forward(self, z):
        """
        Parameters
        ----------
        z : latent z (batch, input_size)
        
        Returns
        -------
        feat : conductor feat (batch, bar_seq, hidden_size)
        """
            
        batch_size = z.shape[0]
        
        z = self.norm(z) # it is different from paper
        h, c = self.init_hidden(batch_size, z)
        z = z.unsqueeze(1)
        
        # initialize
        feat = torch.zeros(batch_size, self.bar, self.hidden_size, device=self.device)
        
        # conductor
        z_input = z
        for i in range(self.bar):
            z_input, (h, c) = self.conductor(z_input, (h, c))
            feat[:, i, :] = z_input.squeeze()
            z_input = z
            
        feat = self.linear(feat)
            
        return feat
    

class Hierarchical_Decoder(nn.Module):
    """MusicVAE Hierarchical Decoder"""
    
    def __init__(self, input_size, hidden_size, output_size, num_layers=2, bidirectional=False):
        """Initialize class
     
        Parameters
        ----------
        input_size : dim of input sequence
        hidden_size : dim of LSTM hidden size
        output_size : dim of output sequence
        num_layers : the number of LSTM layers
        bidirectional : True or False
        """
            
        super(Hierarchical_Decoder, self).__init__()
        
        if bidirectional == True:
            num_directions = 2
        else:
            num_directions = 1
        
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_hidden = num_directions * num_layers
        
        self.logits= nn.Linear(hidden_size, output_size)
        self.decoder = nn.LSTM(batch_first=True,
                               input_size=input_size+output_size,
                               hidden_size=hidden_size,
                               num_layers=num_layers,
                               bidirectional=bidirectional)
        
    def forward(self, x, h, c, z, temp=1):
        """
        Parameters
        ----------
        x : input sequence (batch, 1, feat)
        h : LSTM state (num_hidden, batch, hidden_size)
        c : LSTM cell (num_hidden, batch, hidden_size)
        z : concat feature
        temp : temperature of softmax
        
        Returns
        -------
        out : predicted label (batch, 1, output_size)
        prob: predicted prob (batch, 1, output_size)
        h : LSTM next state
        c : LSTM next cell
        """
            
        x = torch.cat((x, z.unsqueeze(1)), 2)
        
        x, (h, c) = self.decoder(x, (h, c))
        logits = self.logits(x) / temp
        prob = nn.Softmax(dim=2)(logits)
        out = torch.argmax(prob, 2)
                
        return out, prob, h, c

## Train

In [None]:
def flat_train(device, loss_fn, train_loader, val_loader, model, optimizer, temp=1, epochs=100):
    history = {}
    history['train_loss'] = []
    history['train_acc'] = []
    history['val_loss'] = []
    history['val_acc'] = []
    
    encoder, decoder = model
    enc_optimizer, dec_optimizer = optimizer
    
    hidden_size = decoder.hidden_size
    num_hidden = decoder.num_hidden
    output_size = decoder.output_size
    
    enc_scheduler = optim.lr_scheduler.CosineAnnealingLR(enc_optimizer, epochs, eta_min=1e-6)
    dec_scheduler = optim.lr_scheduler.CosineAnnealingLR(dec_optimizer, epochs, eta_min=1e-6)
    
    for i in range(1, epochs+1):
        start_time = time()
        
        train_loss = 0
        train_acc = 0
        
        val_loss = 0
        val_acc = 0
        
        ### train
        encoder.train()
        decoder.train()
        for batch_idx, x_train in enumerate(train_loader):
            x_train = x_train.to(device)
            
            batch_size = x_train.shape[0]
            seq_len = x_train.shape[1]
            
            enc_optimizer.zero_grad()
            dec_optimizer.zero_grad()
            
            # encoder
            z, x_train_mu, x_train_std = encoder(x_train)

            # initialize
            h = z.repeat(num_hidden, 1, int(hidden_size/z.shape[1]))
            c = z.repeat(num_hidden, 1, int(hidden_size/z.shape[1]))
            
            x_train_inputs = torch.zeros((batch_size, 1, x_train.shape[2]), device=device)
            x_train_inputs = torch.cat((x_train_inputs, z.unsqueeze(1)), 2)
            x_train_label = torch.zeros(x_train.shape[:-1], device=device) # argmax
            x_train_prob = torch.zeros(x_train.shape, device=device) # prob

            # forward
            for j in range(seq_len):
                label, prob, h, c = decoder(x_train_inputs, h, c, temp=1)

                x_train_label[:, j] = label.squeeze()
                x_train_prob[:, j, :] = prob.squeeze()
                
                # scheduled sampling
                if np.random.binomial(1, inverse_sigmoid(i)):
                    # teacher forcing
                    x_train_inputs = torch.cat((x_train[:, j, :], z), 1).unsqueeze(1)
                else:
                    # sampling
                    label = F.one_hot(label, num_classes=output_size)
                    x_train_inputs = torch.cat((label, z.unsqueeze(1)), 2)
            
            # loss
            beta = kl_annealing(i, 0, 0.2)
            loss = loss_fn(x_train_prob, x_train, x_train_mu, x_train_std, beta)
            
            # backward
            loss.backward()
            enc_optimizer.step()
            dec_optimizer.step()
            
            train_loss += loss.item()
            train_acc += accuracy(x_train, x_train_label).item()
            
        enc_scheduler.step()
        dec_scheduler.step()
        
        train_loss = train_loss / (batch_idx + 1)
        train_acc = train_acc / (batch_idx + 1)
        
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        
        ### validation
        encoder.eval()
        decoder.eval()
        with torch.no_grad():
            for batch_idx, x_val in enumerate(val_loader):
                x_val = x_val.to(device)
                
                batch_size = x_val.shape[0]
                seq_len = x_val.shape[1]
                
                # forward encoder
                z, x_val_mu, x_val_std = encoder(x_val)
                
                # initialize
                h = z.repeat(num_hidden, 1, int(hidden_size/z.shape[1]))
                c = z.repeat(num_hidden, 1, int(hidden_size/z.shape[1]))
                
                # full sampling
                x_val_inputs = torch.zeros((batch_size, 1, x_val.shape[2]), device=device)
                x_val_inputs = torch.cat((x_val_inputs, z.unsqueeze(1)), 2)
                x_val_label = torch.zeros(x_val.shape[:-1], device=device) # argmax
                x_val_prob = torch.zeros(x_val.shape, device=device) # prob
                
                # forward
                for j in range(seq_len):
                    label, prob, h, c = decoder(x_val_inputs, h, c, temp=1)
                    
                    x_val_label[:, j] = label.squeeze()
                    x_val_prob[:, j, :] = prob.squeeze()
                    
                    label = F.one_hot(label, num_classes=output_size)
                    x_val_inputs = torch.cat((label, z.unsqueeze(1)), 2)
                
                loss = loss_fn(x_val_prob, x_val, x_val_mu, x_val_std, beta)
                
                val_loss += loss.item()
                val_acc += accuracy(x_val, x_val_label).item()
                
        val_loss = val_loss / (batch_idx + 1)
        val_acc = val_acc / (batch_idx + 1)
        
        history['val_loss'].append(val_loss)
        history['val_acc'] .append(val_acc)
        
        print('Epoch %d (%0.2f sec) - train_loss: %0.3f, train_acc: %0.3f, val_loss: %0.3f, val_acc: %0.3f, lr: %0.6f' % \
             (i, time()-start_time, train_loss, train_acc, val_loss, val_acc, enc_scheduler.get_last_lr()[0]))
        
    return history


def hierarchical_train(device, loss_fn, train_loader, val_loader, model, optimizer, bar_units=16, epochs=100):
    history = {}
    history['train_loss'] = []
    history['train_acc'] = []
    history['val_loss'] = []
    history['val_acc'] = []
    
    encoder, conductor, decoder = model
    enc_optimizer, con_optimizer, dec_optimizer = optimizer
    
    hidden_size = decoder.hidden_size
    num_hidden = decoder.num_hidden
    output_size = decoder.output_size
    
    enc_scheduler = optim.lr_scheduler.CosineAnnealingLR(enc_optimizer, epochs, eta_min=1e-6)
    con_scheduler = optim.lr_scheduler.CosineAnnealingLR(con_optimizer, epochs, eta_min=1e-6)
    dec_scheduler = optim.lr_scheduler.CosineAnnealingLR(dec_optimizer, epochs, eta_min=1e-6)
    
    for i in range(1, epochs+1):
        start_time = time()
        
        train_loss = 0
        train_acc = 0
        
        val_loss = 0
        val_acc = 0
        
        encoder.train()
        conductor.train()
        decoder.train()
        for batch_idx, x_train in enumerate(train_loader):
            x_train = x_train.to(device)
            
            batch_size = x_train.shape[0]
            seq_len = x_train.shape[1]
            
            enc_optimizer.zero_grad()
            con_optimizer.zero_grad()
            dec_optimizer.zero_grad()
            
            # forward
            x_train_z, x_train_mu, x_train_std = encoder(x_train)
            x_train_feat = conductor(x_train_z)
            
            # initialize    
            x_train_inputs = torch.zeros((batch_size, 1, x_train.shape[2]), device=device)
            x_train_label = torch.zeros(x_train.shape[:-1], device=device) # argmax
            x_train_prob = torch.zeros(x_train.shape, device=device) # prob
            
            # teacher forcing
            for j in range(seq_len):
                bar_idx = j // bar_units
                bar_change_idx = j % bar_units
                
                z = x_train_feat[:, bar_idx, :]
                
                # init state
                if bar_change_idx == 0:
                    h = z.repeat(num_hidden, 1, int(hidden_size/z.shape[1]))
                    c = z.repeat(num_hidden, 1, int(hidden_size/z.shape[1]))
                
                label, prob, h, c = decoder(x_train_inputs, h, c, z)
                
                x_train_label[:, j] = label.squeeze()
                x_train_prob[:, j, :] = prob.squeeze()
                
                # teacher forcing
                if np.random.binomial(1, inverse_sigmoid(i)):
                    x_train_inputs = x_train[:, j, :].unsqueeze(1)
                else:
                    x_train_inputs = F.one_hot(label, num_classes=output_size)
            
            beta = kl_annealing(i, 0, 0.2)
            loss = loss_fn(x_train_prob, x_train, x_train_mu, x_train_std, beta)
            
            # backward
            loss.backward()
            enc_optimizer.step()
            con_optimizer.step()
            dec_optimizer.step()
            
            train_loss += loss.item()
            train_acc += accuracy(x_train, x_train_label).item()
            
        enc_scheduler.step()
        con_scheduler.step()
        dec_scheduler.step()
        
        train_loss = train_loss / (batch_idx + 1)
        train_acc = train_acc / (batch_idx + 1)
        
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        
        encoder.eval()
        conductor.eval()
        decoder.eval()
        with torch.no_grad():
            for batch_idx, x_val in enumerate(val_loader):
                x_val = x_val.to(device)
                
                batch_size = x_val.shape[0]
                seq_len = x_val.shape[1]
                
                # forward
                x_val_z, x_val_mu, x_val_std = encoder(x_val)
                x_val_feat = conductor(x_val_z)
                
                # initialize
                x_val_inputs = torch.zeros((batch_size, 1, x_val.shape[2]), device=device)  
                x_val_label = torch.zeros(x_val.shape[:-1], device=device) # argmax
                x_val_prob = torch.zeros(x_val.shape, device=device) # prob
                
                # full sampling
                for j in range(seq_len):
                    bar_idx = j // bar_units
                    bar_change_idx = j % bar_units
                    
                    z = x_val_feat[:, bar_idx, :]
                
                    # init state
                    if bar_change_idx == 0:
                        h = z.repeat(num_hidden, 1, int(hidden_size/z.shape[1]))
                        c = z.repeat(num_hidden, 1, int(hidden_size/z.shape[1]))
                    
                    label, prob, h, c = decoder(x_val_inputs, h, c, z)

                    x_val_label[:, j] = label.squeeze()
                    x_val_prob[:, j, :] = prob.squeeze()

                    # full sampling
                    x_val_inputs = F.one_hot(label, num_classes=output_size)
            
                loss = loss_fn(x_val_prob, x_val, x_val_mu, x_val_std)
                
                val_loss += loss.item()
                val_acc += accuracy(x_val, x_val_label).item()
                
        val_loss = val_loss / (batch_idx + 1)
        val_acc = val_acc / (batch_idx + 1)
        
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        
        print('Epoch %d (%0.2f sec) - train_loss: %0.3f, train_acc: %0.3f, val_loss: %0.3f, val_acc: %0.3f, lr: %0.6f' % \
             (i, time()-start_time, train_loss, train_acc, val_loss, val_acc, enc_scheduler.get_last_lr()[0]))
        
    return history

## test

In [None]:
def flat_test(device, loss_fn, test_loader, model, temp=1, options='teacher_forcing'):
    history = {}
    history['test_loss'] = []
    history['test_acc'] = []
    
    encoder, decoder = model
    
    start_time = time()

    test_loss = 0
    test_acc = 0
    
    y_true = []
    y_pred = []

    encoder.eval()
    decoder.eval()
    
    num_hidden = decoder.num_hidden
    hidden_size = decoder.hidden_size
    output_size = decoder.output_size
    
    with torch.no_grad():
        for batch_idx, x_test in enumerate(test_loader):
            x_test = x_test.to(device)
            
            batch_size = x_test.shape[0]
            seq_len = x_test.shape[1]

            # forward
            z, x_test_mu, x_test_std = encoder(x_test)
            
            h = z.repeat(num_hidden, 1, int(hidden_size/z.shape[1]))
            c = z.repeat(num_hidden, 1, int(hidden_size/z.shape[1]))
            
            x_test_inputs = torch.zeros((batch_size, 1, x_test.shape[2]), device=device)
            x_test_inputs = torch.cat((x_test_inputs, z.unsqueeze(1)), 2)
            x_test_label = torch.zeros(x_test.shape[:-1], device=device) # argmax
            x_test_prob = torch.zeros(x_test.shape, device=device) # prob
            
            for j in range(seq_len):
                label, prob, h, c = decoder(x_test_inputs, h, c, temp=temp)

                x_test_label[:, j] = label.squeeze()
                x_test_prob[:, j, :] = prob.squeeze()
                
                if options == 'teacher_forcing':
                    x_test_inputs = torch.cat((x_test[:, j, :], z), 1).unsqueeze(1)
                else:
                    label = F.one_hot(label, num_classes=output_size)
                    x_test_inputs = torch.cat((label, z.unsqueeze(1)), 2)

            loss = loss_fn(x_test_prob, x_test, x_test_mu, x_test_std)

            test_loss += loss.item()
            test_acc += accuracy(x_test, x_test_label).item()
            
            if batch_idx % 10000 == 0:
                y_true.append(x_test.data.cpu().numpy())
                y_pred.append(x_test_prob.data.cpu().numpy())

    test_loss = test_loss / (batch_idx + 1)
    test_acc = test_acc / (batch_idx + 1)

    history['test_loss'].append(test_loss)
    history['test_acc'] .append(test_acc)

    print('(%0.2f sec) - test_loss: %0.3f, test_acc: %0.3f' % (time()-start_time, test_loss, test_acc))
        
    return history, np.vstack(y_true), np.vstack(y_pred)


def hierarchical_test(device, loss_fn, test_loader, model, temp=1, bar_units=16, options='teacher_forcing'):
    history = {}
    history['test_loss'] = []
    history['test_acc'] = []
    
    encoder, conductor, decoder = model
    
    start_time = time()

    test_loss = 0
    test_acc = 0
    
    y_true = []
    y_pred = []

    encoder.eval()
    conductor.eval()
    decoder.eval()
    
    num_hidden = decoder.num_hidden
    hidden_size = decoder.hidden_size
    output_size = decoder.output_size

    with torch.no_grad():
        for batch_idx, x_test in enumerate(test_loader):
            x_test = x_test.to(device)
            batch_size = x_test.shape[0]
            seq_len = x_test.shape[1]

            # forward
            x_test_z, x_test_mu, x_test_std = encoder(x_test)
            x_test_feat = conductor(x_test_z)
            
            # initialize
            x_test_inputs = torch.zeros((batch_size, 1, x_test.shape[2]), device=device)
            x_test_label = torch.zeros(x_test.shape[:-1], device=device) # argmax
            x_test_prob = torch.zeros(x_test.shape, device=device) # prob
            
            for j in range(seq_len):
                bar_idx = j // bar_units
                bar_change_idx = j % bar_units

                z = x_test_feat[:, bar_idx, :]

                # init state
                if bar_change_idx == 0:
                    h = z.repeat(num_hidden, 1, int(hidden_size/z.shape[1]))
                    c = z.repeat(num_hidden, 1, int(hidden_size/z.shape[1]))

                label, prob, h, c = decoder(x_test_inputs, h, c, z)

                x_test_label[:, j] = label.squeeze()
                x_test_prob[:, j, :] = prob.squeeze()
                
                if options == 'teacher_forcing':
                    x_test_inputs = x_test[:, j, :].unsqueeze(1)
                else:
                    x_test_inputs = F.one_hot(label, num_classes=output_size)

            loss = loss_fn(x_test_prob, x_test, x_test_mu, x_test_std)

            test_loss += loss.item()
            test_acc += accuracy(x_test, x_test_label).item()
            
            if batch_idx % 10000 == 0:
                y_true.append(x_test.data.cpu().numpy())
                y_pred.append(x_test_prob.data.cpu().numpy())

    test_loss = test_loss / (batch_idx + 1)
    test_acc = test_acc / (batch_idx + 1)

    history['test_loss'].append(test_loss)
    history['test_acc'] .append(test_acc)

    print('(%0.2f sec) - test_loss: %0.3f, test_acc: %0.3f' % (time()-start_time, test_loss, test_acc))
        
    return history, np.vstack(y_true), np.vstack(y_pred)

## Output