In [None]:
import json
import os
import random
from glob import glob
from time import time

import IPython.display
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from scipy.io.wavfile import write
from torch import nn
from torch.utils.data import DataLoader

from utils.data import *
from utils.model import *
from utils.metric import *
from utils.common_utils import *

plt.style.use("seaborn")

%load_ext blackcellmagic
%load_ext autoreload
%autoreload 2

In [None]:
# Select the compute device: CUDA when torch reports it as available, else CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Parse the JSON configuration file into the `config` dictionary.
with open("./config.json", "r") as config_file:
    config = json.loads(config_file.read())

#### Data Load

In [None]:
# Collect the test files. glob() returns entries in arbitrary (filesystem)
# order, so sort them to make sample indices reproducible between runs.
test_file_list = sorted(glob(config["test_folder_path"] + "/*"))
print("The number of test: %d" % len(test_file_list))

# DataLoader settings: keep the file order (no shuffling). Pinned host memory
# only helps host-to-GPU copies, so enable it only when CUDA is in use.
test_params = {
    "batch_size": config["batch_size"],
    "shuffle": False,
    "pin_memory": use_cuda,
    "num_workers": 4,
}

test_set = DataLoader(DatasetSampler(test_file_list), **test_params)

#### Model Here

In [None]:
# TransUNet diffusion model.
# `load_from_checkpoint` is a classmethod that builds and returns a new module,
# so call it on the class itself. Building a throwaway instance first wastes
# memory, and recent Lightning releases reject instance-level calls.
# The keyword arguments supply the constructor hyper-parameters from `config`;
# `map_location` places the checkpoint tensors on the selected device.
diffusion = TransUNet_Lightning.load_from_checkpoint(
    config["tranunet_model_path"],
    map_location=device,
    in_ch=config["in_ch"],
    out_ch=config["out_ch"],
    num_layers=config["num_layers"],
    d_model=config["d_model"],
    latent_dim=config["latent_dim"],
    time_emb_dim=config["time_emb_dim"],
    time_steps=config["time_steps"],
    rate=config["rate"],
)

diffusion = diffusion.to(device)

In [None]:
# Simple decoder that wraps the diffusion model loaded above.
# `load_from_checkpoint` is a classmethod that builds and returns a new module,
# so call it on the class itself rather than on a throwaway instance (recent
# Lightning releases reject instance-level calls).
# `map_location` places the checkpoint tensors on the selected device.
decoder = SimpleDecoder_Lightning.load_from_checkpoint(
    config["decoder_model_path"],
    map_location=device,
    in_ch=config["in_ch"],
    out_ch=config["out_ch"],
    diffusion_model=diffusion,
    latent_dim=config["latent_dim"],
)

decoder = decoder.to(device)

#### Inference Here!

In [None]:
# Inference only: put both networks into evaluation mode.
for net in (diffusion, decoder):
    net.eval()


def _as_channels_last_numpy(tensor):
    """Return `tensor` as a host NumPy array with axes reordered to (0, 2, 3, 1)."""
    return np.transpose(tensor.detach().cpu().numpy(), [0, 2, 3, 1])


gen_batches, music_batches, mixture_batches = [], [], []

with torch.no_grad():
    for batch in test_set:
        # Only the mixture and the reference music are used below.
        _melody, mixture, music, _track = batch

        # Output shape requested from the diffusion model, one entry per item.
        sample_shape = (music.shape[0], 5, 64, 72)

        mixture = mixture.to(device)
        diffused = diffusion(mixture, sample_shape, device=device, eta=0, mode="ddim")
        logits = decoder(diffused)

        # Threshold the sigmoid output at 0.5 to get a 0/1 mask, then apply
        # that mask to the input mixture.
        mask = (torch.sigmoid(logits) >= 0.5).float()
        separated = mask * mixture

        gen_batches.append(_as_channels_last_numpy(separated))
        music_batches.append(_as_channels_last_numpy(music))
        mixture_batches.append(_as_channels_last_numpy(mixture))

# Join the per-batch arrays along the first axis.
gen_data = np.concatenate(gen_batches, axis=0)
music_data = np.concatenate(music_batches, axis=0)
mixture_data = np.concatenate(mixture_batches, axis=0)

print("\ngen_data shape :", gen_data.shape)
print("music_data shape :", music_data.shape)
print("mixture_data shape :", mixture_data.shape)

#### Evaluation

In [None]:
# Consistency metric: consistency_loss evaluated on (mixture_data, gen_data).
consistency = consistency_loss(mixture_data, gen_data)
print("consistency : %f" % consistency)

# Diversity metric: diversity_loss evaluated on (mixture_data, music_data, gen_data).
diversity = diversity_loss(mixture_data, music_data, gen_data)
print("diversity : %f" % diversity)

In [None]:
# Rendering settings for the example outputs.
event_time = 0.18
file_path = "./samples/"
sample_rate = 16000
save_wav = False  # set to True to also write the synthesized audio as .wav files

# Render at most five examples; never index past the end of the generated set.
num_examples = min(5, len(gen_data))

# Make sure the output directory exists before anything is saved into it
# (assumes plot_two_pianoroll does not create it itself -- TODO confirm).
os.makedirs(file_path, exist_ok=True)

for idx in range(num_examples):
    # Convert the generated and the reference pianoroll for playback.
    pm1 = play_pianoroll(gen_data[idx], event_time=event_time)
    pm2 = play_pianoroll(music_data[idx], event_time=event_time)

    # Synthesize audio; kept unconditional so wav1 / wav2 stay available.
    wav1 = pm1.fluidsynth(fs=sample_rate).astype(np.float32)
    wav2 = pm2.fluidsynth(fs=sample_rate).astype(np.float32)

    file_name_1 = str(idx+1) + "_music_from_mixture.wav"
    file_name_2 = str(idx+1) + "_original_music.wav"

    # Optionally save the synthesized audio as wav (already float32).
    if save_wav:
        write(file_path + file_name_1, sample_rate, wav1)
        write(file_path + file_name_2, sample_rate, wav2)

    # Visualize the generated pianoroll next to the reference one.
    file_name = str(idx+1) + ".png"
    plot_two_pianoroll(gen_data[idx], music_data[idx], save_path=file_path+file_name,
                       SIZE=[10, 10], CHAR_FONT_SIZE=15, NUM_FONT_SIZE=13, LABEL_PAD=8)