In [5]:
from model import Encoder, Decoder, Seq2Seq
from data_loader import *
import pandas as pd
import torch.optim.lr_scheduler as lr_scheduler
from torch import optim
import torch.nn.functional as F
import datetime
import pretty_midi
import glob
import librosa

In [6]:
import os
import matplotlib
import math
matplotlib.use('Agg')
# matplotlib.use("QtAgg")
import ffmpeg
#conda install -c conda-forge ffmpeg-python

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, writers
plt.rcParams['animation.ffmpeg_path'] = '/home/ilc/anaconda3/bin/ffmpeg'#'/usr/bin/ffmpeg'

import numpy as np
import subprocess as sp
from moviepy.video.io.VideoFileClip import VideoFileClip
from moviepy.audio.io.AudioFileClip import AudioFileClip

from midi2audio import FluidSynth

from torch.autograd import Variable 

In [7]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [8]:
dataset_name_path = f"./both_list_symbolic.txt" #f"./midi_list.txt"
dataloader = get_all_dataloader(dataset_name_path, batch_size=128) #[20, 512, 128], [20, 512, 102]

val_dataset_name_path = f"./both_list_eval_symbolic.txt" #f"./midi_list_eval.txt"
val_dataloader = get_all_val_dataloader(val_dataset_name_path, batch_size=128) #[20, 512, 128], [20, 512, 102]

learning_rate = 0.001#0.001

input_size_encoder = 28 #128 #129 #128
input_size_decoder = 112 #102 #24
output_size = 112#102 #24

# encoder_embedding_size = 300
# decoder_embedding_size = 300
enc_dropout = 0.5
dec_dropout = 0.
step = 0

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

self.piece_count:  105
dataset_len:  10500
val_dataset_len 5
cuda:0


In [9]:
class LSTM1(nn.Module):
    def __init__(self, output_dim, input_size, hidden_size, num_layers, seq_length):
        super(LSTM1, self).__init__()
        self.output_dim = output_dim #number of classes
        self.num_layers = num_layers #number of layers
        self.input_size = input_size #input size
        self.hidden_size = hidden_size #hidden state
        self.seq_length = seq_length #sequence length

        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                          num_layers=num_layers, batch_first=True) #lstm
        self.fc_1 =  nn.Linear(hidden_size, output_dim) #fully connected to determine output dim

        self.relu = nn.ReLU()

    def forward(self,x):
        # h0, c0 no time information
        h_0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size)).to(device) #hidden state
        c_0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size)).to(device) #internal state
        # Propagate input through LSTM
        # x is MIDI => [44, 512, 128]

        # hn is final state, run over the sequence length
        output, (hn, cn) = self.lstm(x, (h_0, c_0)) #lstm with input, hidden, and internal state
        # hn = hn.view(-1, self.hidden_size) #reshaping the data for Dense layer next
        # print("output.shape", output.shape)
        # print("hn.shape", hn.shape)
        # out = self.relu(hn)
        out = self.fc_1(output) #final
        return out
 

In [10]:
# Define the model architecture
input_size = 156 #number of features
hidden_size = 1024 #number of features in hidden state
num_layers = 1 #number of stacked lstm layers
seq_len = 512
output_dim = 112 #number of output classes

# model = LSTM(vocab_size, embedding_dim, hidden_dim, num_layers, dropout_rate, tie_weights).to(device)
# model = LSTM(embedding_dim, hidden_dim, num_layers, dropout_rate, tie_weights).to(device)
model = LSTM1(output_dim, input_size, hidden_size, num_layers, seq_len).to(device) #our lstm class
model.train()

optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)

num_epochs = 100 #10
avg_loss_list = []
all_loss_list = []
val_loss_per_epoch_list = []

In [11]:
def customized_mse_loss(output, target, prev_output, midi_array):
    # target = target.transpose(0, 1)
    # print("output", output)
    # print(output.shape) #torch.Size([20, 513, 102])
    # print(target.shape) #torch.Size([20, 513, 102])
    mse_loss = F.mse_loss(output, target)
    # print("mse_loss:", mse_loss)

    var_diff = torch.var(torch.squeeze(output), dim=1, keepdim=True)
    mean_diff = torch.mean(var_diff)
    
    # Condition 1: Penalize if output is similar to previous output
    if mean_diff < 1e-4: #threshold
        #output [512, 1, 102] => [102] <-> [102] <-> [102] <-> ... <-> [102]
        mse_loss *= 1000

    # Condition 2: Stop movement if input is all zeros
    # midi_transpose = midi_array.transpose(0, 1)
    # midi_sum_row = torch.sum(midi_transpose, dim=-1)
    # mask = midi_sum_row == 0
    # mask = mask.unsqueeze(-1)
    # mask = mask.to(device)
    # # according to recorded index, make a mask [0, 1, 1, 0, ..., 0], true part will be omit(set value to 0).
    # # before compute mse, use mask first to tensor, then caculate MES loss
    # masked_output = output.masked_fill(mask, 0) #inplace function
    # masked_target = target.masked_fill(mask, 0)
    # mse_loss += F.mse_loss(masked_output, masked_target) * 100 #output 和 previous output 不像的話，增大 loss

    # Condition 3: Penalize if right-hand movement is too different between outputs
    # if output.shape[-1] == 21:  # Assumes hand joints are the last 21 dimensions
    #     rh_indices = [i for i in range(12, 21)]  # Right-hand joint indices
    #     rh_output = output[..., rh_indices]
    #     rh_prev_output = prev_output[..., rh_indices]
    #     rh_loss = nn.functional.mse_loss(rh_output, rh_prev_output)
    #     if rh_loss > 0.1:
    #         mse_loss *= 1000

    return mse_loss

In [12]:
def evaluate_lstm(model, val_dataloader):
    model.eval()
    print('Validation')
    valid_running_loss = 0.0
    counter = 0
    # previous_output = torch.zeros(512, 102).to(device)
    with torch.no_grad():
        for i, (inputs, targets) in enumerate(val_dataloader): #tqdm(enumerate(val_dataloader), total=len(val_dataloader))
            counter += 1

            inputs = inputs.to(device).float()
            targets = targets.to(device).float()
            print("inputs.shape:", inputs.shape)
            print("targets.shape:", targets.shape)
            outputs = model(inputs)
            print("outputs.shape:", outputs.shape)

            loss =  F.mse_loss(outputs, targets)
            valid_running_loss += loss.cpu().item()
            # previous_output = outputs

    epoch_val_loss = valid_running_loss / counter
    return epoch_val_loss

In [13]:
def evaluate_l1(model, val_dataloader):
    model.eval()
    print('Validation L1')
    valid_running_loss = 0.0
    counter = 0
    # previous_output = torch.zeros(512, 102).to(device)
    with torch.no_grad():
        for i, (inputs, targets) in enumerate(val_dataloader): #tqdm(enumerate(val_dataloader), total=len(val_dataloader))
            counter += 1

            inputs = inputs.to(device).float()
            targets = targets.to(device).float()
            print("inputs.shape:", inputs.shape)
            print("targets.shape:", targets.shape)
            outputs = model(inputs)
            print("outputs.shape:", outputs.shape)

            loss =  F.l1_loss(outputs, targets)
            valid_running_loss += loss.cpu().item()
            # previous_output = outputs

    epoch_val_loss = valid_running_loss / counter
    return epoch_val_loss

In [14]:
# train the model
for epoch in range(num_epochs):
    # previous_output = torch.zeros(1, 512, 102).to(device)
    losses = []
    for i, (audio_batch, motion_batch) in enumerate(dataloader):
        model.train()
        
        audio_batch = audio_batch.to(device).float()
        motion_batch = motion_batch.to(device).float()
        # print("audio_batch", audio_batch.shape)
        # print("motion_batch", motion_batch.shape)

        optimizer.zero_grad()
        output = model(audio_batch) #audio_batch
        # print("output.shape", output.shape)

        # motion_ground_truth_padding = F.pad(motion_batch, (0,0,0,1), value = 1) #<eot>
        
        # loss =  F.mse_loss(output, motion_ground_truth_padding)
        loss =  F.mse_loss(output, motion_batch)
        # loss = customized_mse_loss(output, motion_ground_truth_padding, previous_output, midi_batch)
        # loss = customized_mse_loss(output, motion_batch, previous_output, midi_batch)

        # losses 累計lose
        losses.append(loss.cpu().item())
        all_loss_list.append(loss.cpu().item())
        loss.backward()

        optimizer.step()
        mean_loss = sum(losses)/len(losses)

        print(f"Epoch {epoch}, batch {i}: loss = {loss.cpu().item():.6f}")

        # scheduler.step(1)
        # previous_output = output

        loc_dt = datetime.datetime.today()
        loc_dt_format = loc_dt.strftime("%Y-%m-%d_%H-%M-%S")

    val_loss = evaluate_lstm(model, val_dataloader) #CUDA out of memory
    val_loss_per_epoch_list.append(val_loss)
    print(f"Epoch {epoch}: val_loss = {val_loss:.6f}")
    # save_best_model(
    #         val_loss, epoch, model, optimizer, loss, loc_dt_format, mean_loss
    #     )
    avg_loss_list.append(mean_loss)
    loc_dt = datetime.datetime.today()
    loc_dt_format = loc_dt.strftime("%Y-%m-%d_%H-%M-%S")
    if (epoch+1)%100 == 0:
        torch.save({
            'epoch':epoch,
            'model_state_dict':model.state_dict(),
            'optimizer_state_dict':optimizer.state_dict(),
            'loss':loss
        }, "./model_save/[audio]LSTM_save_epoch_" + str(epoch)+ "_"+ str(loc_dt_format) + "_avg_loss_" + str(mean_loss) +".tar")

Epoch 0, batch 0: loss = 0.239317
Epoch 0, batch 1: loss = 0.076427
Epoch 0, batch 2: loss = 0.042218
Epoch 0, batch 3: loss = 0.045830
Epoch 0, batch 4: loss = 0.042866
Epoch 0, batch 5: loss = 0.036171
Epoch 0, batch 6: loss = 0.030401
Epoch 0, batch 7: loss = 0.024924
Epoch 0, batch 8: loss = 0.022711
Epoch 0, batch 9: loss = 0.019896
Epoch 0, batch 10: loss = 0.018655
Epoch 0, batch 11: loss = 0.018178
Epoch 0, batch 12: loss = 0.019470
Epoch 0, batch 13: loss = 0.019169
Epoch 0, batch 14: loss = 0.018379
Epoch 0, batch 15: loss = 0.016918
Epoch 0, batch 16: loss = 0.016931
Epoch 0, batch 17: loss = 0.016836
Epoch 0, batch 18: loss = 0.015510
Epoch 0, batch 19: loss = 0.016195
Epoch 0, batch 20: loss = 0.014126
Epoch 0, batch 21: loss = 0.015149
Epoch 0, batch 22: loss = 0.015129
Epoch 0, batch 23: loss = 0.015155
Epoch 0, batch 24: loss = 0.015436
Epoch 0, batch 25: loss = 0.013678
Epoch 0, batch 26: loss = 0.013971
Epoch 0, batch 27: loss = 0.013979
Epoch 0, batch 28: loss = 0.01

In [None]:
print(loc_dt_format)
print(avg_loss_list)

2023-07-18_20-18-33
[0.014088256811969969, 0.012748566265386271, 0.012135430968489992, 0.011783327020884278, 0.011665332321840596, 0.011348000278763742, 0.011211590673101235, 0.01126363599695355, 0.010981195034032845, 0.011041705530271473, 0.010652564320129803, 0.010493768816701618, 0.010436136421399662, 0.010394138800182256, 0.010266933163785073, 0.010155179560274244, 0.010048930421022766, 0.009978100118866885, 0.009849181370697466, 0.009758815090789134, 0.009883428014904619, 0.00967351151966905, 0.00962024238184992, 0.009702010966657874, 0.009488175634339631, 0.0092398691249181, 0.009143874744873449, 0.009248467532536352, 0.009218326792210699, 0.009097143457865858, 0.009768575643110707, 0.009479412405067179, 0.009325702079986951, 0.008976236182119113, 0.008987353239432875, 0.008953133721398302, 0.010966492143560606, 0.009526901440807136, 0.009302870774394777, 0.009145709394241672, 0.009100556171622622, 0.008991728065512985, 0.00887559393577906, 0.008845821527921292, 0.008874942918857

In [None]:
print(val_loss_per_epoch_list)

[0.01278438325971365, 0.012452135793864727, 0.012619588524103165, 0.012583580799400806, 0.012676828540861607, 0.012934515252709389, 0.013072746805846691, 0.0120470579713583, 0.011598094366490841, 0.012044154107570648, 0.012224655598402023, 0.012943221256136894, 0.011343262158334255, 0.01114050392061472, 0.010922430083155632, 0.011694221757352352, 0.011601006612181664, 0.011959804221987724, 0.01255539245903492, 0.012239864096045494, 0.012193874455988407, 0.012043044902384281, 0.012772328220307827, 0.011544476263225079, 0.011981447227299213, 0.011674492619931698, 0.011470047757029533, 0.012049633078277111, 0.012382530607283115, 0.012579183094203472, 0.011807631701231003, 0.014613443985581398, 0.012081064283847809, 0.012330255471169949, 0.011556034907698631, 0.033256713300943375, 0.012245435267686844, 0.011719465255737305, 0.011602915823459625, 0.01261739432811737, 0.012613462284207344, 0.011865119449794292, 0.01175965927541256, 0.01231275126338005, 0.012440602295100689, 0.011565727181732

In [None]:
# def lr_lambda(epoch):
#     # LR to be 0.1 * (1/1+0.01*epoch)
#     base_lr = 0.1
#     factor = 0.01
#     return base_lr/(1+factor*epoch)

In [None]:
# scheduler = lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.3, total_iters=10)
# scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda)

In [None]:
plt.cla()
plt.clf()

In [None]:
print(len(avg_loss_list))
avg_loss_list_dataframe = pd.DataFrame(avg_loss_list)

100


In [None]:
avg_loss_list_dataframe

Unnamed: 0,0
0,0.014088
1,0.012749
2,0.012135
3,0.011783
4,0.011665
...,...
95,0.008107
96,0.008095
97,0.008018
98,0.007957


In [None]:
plt.plot(np.array(avg_loss_list_dataframe.index), np.array(avg_loss_list_dataframe[0]))
plt.savefig("avg_loss_training.jpg")
plt.show()

In [None]:
plt.cla()
plt.clf()

In [None]:
loss_list_dataframe = pd.DataFrame(all_loss_list)

In [None]:
plt.plot(np.array(loss_list_dataframe.index), np.array(loss_list_dataframe[0]))
plt.savefig("training_loss.jpg")
plt.show()

In [None]:
plt.cla()
plt.clf()

In [None]:
val_loss_per_epoch_list_dataframe = pd.DataFrame(val_loss_per_epoch_list)

In [None]:
plt.plot(np.array(val_loss_per_epoch_list_dataframe.index), np.array(val_loss_per_epoch_list_dataframe[0]))
plt.savefig("training_val_loss.jpg")
plt.show()

In [None]:
def predict(model, input, device):
    model.eval()
    with torch.no_grad():
        input = torch.as_tensor(input).to(torch.float32).to(device)
        # print(target.shape)
        # target = torch.as_tensor(target).to(torch.float32).to(device)
        # TODO: target should be <sos>, should not random
        outputs = model(input)
        return outputs.cpu().numpy()

In [None]:
def audio_preprocess(audio_path, specific_fps):
    n_fft = 4096
    hop = int(44000/specific_fps) #1102.5 -> 40fps #882 -> 50fps
    y, sr = librosa.load(audio_path, sr=44000) #44000 for divide 40
    mfcc = librosa.feature.mfcc(y=y, sr=sr, n_fft=n_fft, hop_length=hop, n_mfcc=13)
    y = np.where(y == 0, 1e-10, y)
    energy = np.log(librosa.feature.rms(y=y, frame_length=n_fft, hop_length=hop, center=True))
    mfcc_energy = np.vstack((mfcc, energy))
    mfcc_delta = librosa.feature.delta(mfcc_energy)
    aud = np.vstack((mfcc_energy, mfcc_delta)).T
    
    print("hop:", hop)
    print("aud:", aud.shape)
    return aud

In [None]:
def read_midi(filename, specific_fps):
    # Load the MIDI file
    midi_data = pretty_midi.PrettyMIDI(filename)

    piano_roll = midi_data.get_piano_roll(fs=specific_fps)  # 40fps #250fps
    piano_roll[piano_roll > 0] = 1

    return piano_roll

In [None]:
test_datapath = "./BWV1001/"
change_fps = 40
test_audio_path_list = glob.glob(test_datapath + "*.wav")
test_midi_path_list = glob.glob(test_datapath + "*.mid")
test_audio_dict = {}
test_midi_dict = {}
test_data_list = []
test_music_list = []

test_audio_length = {}

for test_audio in test_audio_path_list:
    str_name = test_audio
    print("str_name:", str_name)
    filename = str_name.split('/')[2]
    filecode = filename.split('.')[0]
    print("filecode: ",filecode)
    # test_music_list.append(filecode)
    
    print(test_audio)
    # read_piano_roll = read_midi(test_midi, change_fps)
    read_audio = audio_preprocess(test_audio, change_fps)
    # read_audio_transpose = read_audio
    print(read_audio.shape)
    test_audio_len = read_audio.shape[0]
    test_audio_length[filecode] = test_audio_len
    test_audio_dict[filecode] = read_audio

for test_midi in test_midi_path_list:
    str_name = test_midi
    print("str_name:", str_name)
    filename = str_name.split('/')[2]
    filecode = filename.split('.')[0]
    print("filecode: ",filecode)
    test_music_list.append(filecode)
    
    print(test_midi)
    read_piano_roll = read_midi(test_midi, change_fps)
    read_piano_roll_transpose = read_piano_roll.T
    print("old len: ", read_piano_roll_transpose.shape)

    audio_len = test_audio_length[filecode]
    midi_len = len(read_piano_roll_transpose)
    if audio_len > midi_len:
        #(top,bottom), (left,right)
        read_piano_roll_transpose = np.pad(read_piano_roll_transpose,
                                            pad_width=((0, audio_len - midi_len), (0, 0)))
    if audio_len < midi_len:
        n = midi_len - audio_len
        read_piano_roll_transpose = read_piano_roll_transpose[:-n, :]
        
    test_midi_len = read_piano_roll_transpose.shape[0]
    test_midi_dict[filecode] = read_piano_roll_transpose
    print("new len: ", read_piano_roll_transpose.shape)


print(test_music_list)

for test_music in test_music_list:
    combine_music = np.append(test_midi_dict[test_music], test_audio_dict[test_music], axis = 1)
    test_data_list.append(combine_music)
    print(combine_music.shape)

str_name: ./BWV1001/vs1-1ada.wav
filecode:  vs1-1ada
./BWV1001/vs1-1ada.wav
hop: 1100
aud: (8172, 28)
(8172, 28)
str_name: ./BWV1001/vs1-3sic.wav
filecode:  vs1-3sic
./BWV1001/vs1-3sic.wav
hop: 1100
aud: (6989, 28)
(6989, 28)
str_name: ./BWV1001/vs1-2fug.wav
filecode:  vs1-2fug
./BWV1001/vs1-2fug.wav
hop: 1100
aud: (11534, 28)
(11534, 28)
str_name: ./BWV1001/vs1-4prs.wav
filecode:  vs1-4prs
./BWV1001/vs1-4prs.wav
hop: 1100
aud: (7898, 28)
(7898, 28)
str_name: ./BWV1001/vs1-1ada.mid
filecode:  vs1-1ada
./BWV1001/vs1-1ada.mid
old len:  (8171, 128)
new len:  (8172, 128)
str_name: ./BWV1001/vs1-2fug.mid
filecode:  vs1-2fug
./BWV1001/vs1-2fug.mid
old len:  (11537, 128)
new len:  (11534, 128)
str_name: ./BWV1001/vs1-3sic.mid
filecode:  vs1-3sic
./BWV1001/vs1-3sic.mid
old len:  (6993, 128)
new len:  (6989, 128)
str_name: ./BWV1001/vs1-4prs.mid
filecode:  vs1-4prs
./BWV1001/vs1-4prs.mid
old len:  (7897, 128)
new len:  (7898, 128)
['vs1-1ada', 'vs1-2fug', 'vs1-3sic', 'vs1-4prs']
(8172, 156)
(11

In [None]:
def column(matrix, i):
    return [row[i] for row in matrix]

def test_render_animation(fps, output, azim, prediction, ground_truth=None):
    prediction_array = np.asarray(prediction)
    print(prediction_array.size)
    limit = len(prediction_array)
    print("limit", limit)
    size = 6#6
    fps = 40

    # Skeleton layout
    parents = [[0, 1], [1, 3], [3, 2], [0, 2],#head
                [8, 6], [6, 13], [13, 4], [4, 8],#shoulder
                [6, 4], [4, 5], [5, 7], [7, 6],#Upper torso
                [8, 18], [8, 20], [13, 21], [13, 19],
                [5, 20], [5, 21], [7, 18], [7, 19],
                [18, 19], [19, 21], [21, 20], [20, 18], #waist
                [18, 22], [20, 22], [22, 23], [22, 25], [23, 25], [24,23], [24, 25],  #right lag
                [21, 26], [19, 26], [26, 27], [26, 29], [27, 29], [28, 27], [28, 29], #left lag
                [8, 9], [9, 11], [9, 10], [10, 11], [10, 12], [9, 12], [11, 12], #right hand
                [13, 14], [14, 16], [14, 15], [16, 15], [14, 17], [16, 17], [15, 17], #left hand
                [31, 33], [30, 32], [30, 31], [32, 33], [31, 32], [30, 33] #instrument
                        ]
    # joints_right = [1, 2, 12, 13, 14]

    prediction_array[:, :, 2] += 0.1 #[:, :, 2]
    if ground_truth is not None:
        ground_truth[:, :, 2] += 0.1
        poses = {'Prediction': prediction_array,
                 'Ground_truth': ground_truth}
    else:
        poses = {'Prediction': prediction_array}
    

    fig = plt.figure()#(figsize=(size*len(poses), size))
    # ax_3d = []
    # lines_3d = []
    radius = 1#14 #3.7#
    # print(poses)
    for index, (title, data) in enumerate(poses.items()):
        ax = fig.add_subplot(1, len(poses), index + 1, projection='3d')
        ax.clear()
        print(data)
        ims = [] #每一 frame 都存
        for frame_index, each_frame in enumerate(data):
            # print("each_frame")
            # print(each_frame)
            ax.view_init(elev=15., azim=azim)
            ax.set_xlim3d([-radius/2, radius/2])
            ax.set_zlim3d([0, radius])
            ax.set_ylim3d([-radius/2, radius/2])
            ax.set_aspect('auto') #ax.set_aspect('equal')

            # print(title)
            points = ax.scatter(column(each_frame[:30], 0), column(each_frame[:30], 1), column(each_frame[:30], 2), cmap='jet', marker='o', label='body joint', color = 'black')
            points_2 = ax.scatter(column(each_frame[30:32], 0), column(each_frame[30:32], 1), column(each_frame[30:32], 2), cmap='jet', marker='o', label='body joint', color = 'blue')
            points_3 = ax.scatter(column(each_frame[32:34], 0), column(each_frame[32:34], 1), column(each_frame[32:34], 2), cmap='jet', marker='o', label='body joint', color = 'red')
            
            # ax.scatter(column(each_frame, 0), column(each_frame, 1), column(each_frame, 2), cmap='jet', marker='o', label='body joint')
            # ax.legend()
            # print("+++")
            
            parents = [[0, 1], [1, 3], [3, 2], [0, 2],#head
                        [8, 6], [6, 13], [13, 4], [4, 8],#shoulder
                        [6, 4], [4, 5], [5, 7], [7, 6],#Upper torso
                        [8, 18], [8, 20], [13, 21], [13, 19],
                        [5, 20], [5, 21], [7, 18], [7, 19],
                        [18, 19], [19, 21], [21, 20], [20, 18], #waist
                        [18, 22], [20, 22], [22, 23], [22, 25], [23, 25], [24,23], [24, 25],  #right lag
                        [21, 26], [19, 26], [26, 27], [26, 29], [27, 29], [28, 27], [28, 29], #left lag
                        [8, 9], [9, 11], [9, 10], [10, 11], [10, 12], [9, 12], [11, 12], #right hand
                        [13, 14], [14, 16], [14, 15], [16, 15], [14, 17], [16, 17], [15, 17], #left hand
                        [30, 31], [32, 33],  #instrument
                        # [31, 33], [30, 32], [30, 31], [32, 33], [31, 32], [30, 33] #instrument
                        ]
            lines = []
            # draw line
            
            # lines = [ax.plot([each_frame[vs][0], each_frame[ve][0]],
            #                  [each_frame[vs][1], each_frame[ve][1]],
            #                  [each_frame[vs][2], each_frame[ve][2]]) for (vs, ve) in parents]
            line_num = len(parents)
            for idx, each_line in enumerate(parents):
                vec_start = each_frame[each_line[0]]
                vec_end = each_frame[each_line[1]]
                # print(vec_start)
                # print(vec_end)
                line_color = "black"
                if idx == line_num-2:
                    line_color = "blue"
                if idx == line_num-1:
                    line_color = "red"
                # ax.plot([vec_start[0], vec_end[0]], [vec_start[1], vec_end[1]], [vec_start[2], vec_end[2]])
                
                temp, = ax.plot([vec_start[0], vec_end[0]], [vec_start[1], vec_end[1]], [vec_start[2], vec_end[2]], color=line_color)
                lines.append(temp)

            # ax.figure.savefig('./test_pic/pic' + str(frame_index) + '.png', dpi=100, bbox_inches = 'tight')

            # ims.append([points])
            # image_frame = [points].extend(lines)
            ims.append([points]+[points_2]+[points_3]+lines) #TODO: try extend

            # plt.cla()
            # print("+++")

    anim = matplotlib.animation.ArtistAnimation(fig, ims, interval=1000/fps)

    if output.endswith('.mp4'):
        FFwriter = matplotlib.animation.FFMpegWriter(fps=fps, extra_args=['-vcodec', 'libx264'])
        anim.save(output, writer=FFwriter)
    elif output.endswith('.gif'):
        anim.save(output, fps=fps, dpi=100, writer='imagemagick')
    else:
        raise ValueError('Unsupported output format (only .mp4 and .gif are supported)')

In [None]:
def plot(audio_path, plot_path, prediction, sample_time, fps, name=""): #audio_path, plot_path, 
    # render_animation(fps, output='new_temp.mp4', azim=75, prediction=prediction)
    test_render_animation(fps, output='new_temp_' + name + '.mp4', azim=75, prediction=prediction)

    # # #merge with wav
    input_video = ffmpeg.input('new_temp_' + name + '.mp4')
    fluid_syn = FluidSynth()
    fluid_syn.midi_to_audio(audio_path, './output' + name + '.wav')
    input_audio = ffmpeg.input('./output' + name + '.wav')
    # output = ffmpeg.output(video, audio, plot_path, vcodec='copy', acodec='aac', strict='experimental')
    ffmpeg.concat(input_video, input_audio, v=1, a=1).output(plot_path).run()
    # os.remove('new_temp_' + name + '.mp4')

In [None]:
model.eval()

full_prediction = pd.DataFrame()
num_count = 0
# read midi
# test_dataloader = get_dataloader(test_datapath, batch_size=1)
for test_batch in test_data_list:
    with torch.no_grad():
        # first_target = torch.zeros(test_batch.shape[0],112)
        # print(first_target.shape)
        test_input = test_batch[None, :]
        # test_target = first_target[None, :]
        print("test_input", test_input.shape)
        # print("test_target", test_target.shape)
        prediction = predict(model, test_input, device)
        
        # print(prediction.shape)
        
        prediction  = prediction[:, :, :102]
        print("prediction.shape", prediction.shape)
        
        # full_prediction.append(prediction)
        full_prediction = pd.DataFrame(prediction[0])
        print("full_prediction", full_prediction.shape)
        
        # prev_prediction = prediction[0][:-1][None, :]
        # print(prev_prediction.shape)
        
        Row_list_prediction =[]
        
        filecode = test_music_list[num_count]
    
        # Iterate over each row
        for index, rows in full_prediction.iterrows():
            #fill nan
            rows = rows.fillna(0)
            # Create list for the current row
            my_list = rows.values.tolist()
            # print(my_list)
            
            my_list_per3 = [my_list[i:i+3] for i in range(0, len(my_list), 3)]
            # append the list to the final list
            Row_list_prediction.append(my_list_per3)

        # print(len(Row_list_prediction), len(Row_list_prediction[0]),len(Row_list_prediction[0][0]))
        plot(test_datapath + test_music_list[num_count] + ".mid", "./video_" + filecode + "_test_predict.mp4", Row_list_prediction[:800], None, 40, filecode) #ow_list[0:900]
        # print("prediction.shape", prediction.shape)
        prediction_arr = np.array(Row_list_prediction)
        # formated_motion = prediction_format(full_prediction)
        # # # plot(formated_motion)
        # audio_path = test_music_list[num_count][0]
        # output_path = "test_output_" + filecode + ".mp4"
        # plot(formated_motion, audio_path, output_path, None, 10, filecode)
        num_count += 1

# model.train()

test_input (1, 8172, 156)


prediction.shape (1, 8172, 102)
full_prediction (8172, 102)
81600
limit 800
[[[ 0.06355162  0.13356397  1.05939022]
  [ 0.02451253  0.09293263  1.077633  ]
  [ 0.12643632  0.05664559  1.05023298]
  ...
  [ 0.04668002  0.10732425  0.93162737]
  [-0.1422101   0.08608773  1.03613845]
  [ 0.08750859  0.2402505   0.73955188]]

 [[ 0.08414692  0.11390457  1.08929906]
  [ 0.036357    0.09666923  1.09711764]
  [ 0.15051877  0.04217945  1.09635202]
  ...
  [ 0.06092626  0.1131491   0.94641129]
  [-0.12833472  0.08752057  1.04275451]
  [ 0.12260811  0.21374393  0.77403579]]

 [[ 0.07429266  0.10394423  1.09531007]
  [ 0.02439765  0.09194402  1.09812889]
  [ 0.14384674  0.03564306  1.10306678]
  ...
  [ 0.05184722  0.11060955  0.95129052]
  [-0.13783553  0.07621256  1.04014263]
  [ 0.11655657  0.21114898  0.78239528]]

 ...

 [[ 0.10110911  0.11616088  1.10146627]
  [ 0.04462209  0.07892136  1.10632417]
  [ 0.14449075  0.04203048  1.11012242]
  ...
  [ 0.05935774  0.09892346  0.96882555]
  [-0.07

fluidsynth: panic: An error occurred while reading from stdin.


FluidSynth runtime version 2.1.1
Copyright (C) 2000-2020 Peter Hanappe and others.
Distributed under the LGPL license.
SoundFont(R) is a registered trademark of E-mu Systems, Inc.

Rendering audio to file './outputvs1-1ada.wav'..


ffmpeg version 4.2.2 Copyright (c) 2000-2019 the FFmpeg developers
  built with gcc 7.3.0 (crosstool-NG 1.23.0.449-a04d0)
  configuration: --prefix=/home/ilc/anaconda3/envs/sinica --cc=/tmp/build/80754af9/ffmpeg_1587154242452/_build_env/bin/x86_64-conda_cos6-linux-gnu-cc --disable-doc --enable-avresample --enable-gmp --enable-hardcoded-tables --enable-libfreetype --enable-libvpx --enable-pthreads --enable-libopus --enable-postproc --enable-pic --enable-pthreads --enable-shared --enable-static --enable-version3 --enable-zlib --enable-libmp3lame --disable-nonfree --enable-gpl --enable-gnutls --disable-openssl --enable-libopenh264 --enable-libx264
  libavutil      56. 31.100 / 56. 31.100
  libavcodec     58. 54.100 / 58. 54.100
  libavformat    58. 29.100 / 58. 29.100
  libavdevice    58.  8.100 / 58.  8.100
  libavfilter     7. 57.100 /  7. 57.100
  libavresample   4.  0.  0 /  4.  0.  0
  libswscale      5.  5.100 /  5.  5.100
  libswresample   3.  5.100 /  3.  5.100
  libpostproc    55

test_input (1, 11534, 156)
prediction.shape (1, 11534, 102)
full_prediction (11534, 102)
81600
limit 800
[[[ 0.08070495  0.14534007  1.11399064]
  [ 0.06805236  0.09855323  1.0968518 ]
  [ 0.13246995  0.03517139  1.03133163]
  ...
  [ 0.00319252  0.12243085  0.96492354]
  [-0.04189362  0.20937711  1.00741146]
  [ 0.04500939  0.22100392  0.66683141]]

 [[ 0.05204338  0.18029329  1.15192351]
  [ 0.02479108  0.1176395   1.13649688]
  [ 0.11065398  0.06827161  1.10690877]
  ...
  [-0.04117206  0.13978781  1.00970546]
  [-0.09507247  0.22575581  1.10915658]
  [ 0.02460541  0.22122194  0.7199851 ]]

 [[ 0.04383961  0.19789372  1.15437171]
  [ 0.01176155  0.13208453  1.14288435]
  [ 0.10307933  0.08855149  1.12313328]
  ...
  [-0.04814007  0.14903975  1.0238103 ]
  [-0.11726123  0.22779185  1.13596091]
  [ 0.0087961   0.23648413  0.72666035]]

 ...

 [[ 0.05046132  0.17993964  1.10056386]
  [-0.00964692  0.14965498  1.11908648]
  [ 0.0923114   0.10692465  1.12704203]
  ...
  [-0.00319973  0.1

fluidsynth: panic: An error occurred while reading from stdin.


FluidSynth runtime version 2.1.1
Copyright (C) 2000-2020 Peter Hanappe and others.
Distributed under the LGPL license.
SoundFont(R) is a registered trademark of E-mu Systems, Inc.

Rendering audio to file './outputvs1-2fug.wav'..


ffmpeg version 4.2.2 Copyright (c) 2000-2019 the FFmpeg developers
  built with gcc 7.3.0 (crosstool-NG 1.23.0.449-a04d0)
  configuration: --prefix=/home/ilc/anaconda3/envs/sinica --cc=/tmp/build/80754af9/ffmpeg_1587154242452/_build_env/bin/x86_64-conda_cos6-linux-gnu-cc --disable-doc --enable-avresample --enable-gmp --enable-hardcoded-tables --enable-libfreetype --enable-libvpx --enable-pthreads --enable-libopus --enable-postproc --enable-pic --enable-pthreads --enable-shared --enable-static --enable-version3 --enable-zlib --enable-libmp3lame --disable-nonfree --enable-gpl --enable-gnutls --disable-openssl --enable-libopenh264 --enable-libx264
  libavutil      56. 31.100 / 56. 31.100
  libavcodec     58. 54.100 / 58. 54.100
  libavformat    58. 29.100 / 58. 29.100
  libavdevice    58.  8.100 / 58.  8.100
  libavfilter     7. 57.100 /  7. 57.100
  libavresample   4.  0.  0 /  4.  0.  0
  libswscale      5.  5.100 /  5.  5.100
  libswresample   3.  5.100 /  3.  5.100
  libpostproc    55

test_input (1, 6989, 156)
prediction.shape (1, 6989, 102)
full_prediction (6989, 102)
81600
limit 800
[[[ 0.08093315  0.16217832  1.08257095]
  [ 0.02813307  0.10430899  1.09326998]
  [ 0.12145004  0.07571991  1.07382003]
  ...
  [ 0.04153442  0.11155283  0.95659081]
  [-0.12720582  0.10311461  1.04320166]
  [ 0.11097981  0.27614909  0.75573698]]

 [[ 0.08594778  0.14783636  1.10058997]
  [ 0.0249481   0.10837639  1.1080421 ]
  [ 0.12715413  0.0657869   1.10983608]
  ...
  [ 0.03546703  0.11836524  0.96746788]
  [-0.12930612  0.09401795  1.05956051]
  [ 0.14663424  0.27465183  0.79056177]]

 [[ 0.095667    0.13217275  1.10089192]
  [ 0.03438616  0.09972148  1.10742543]
  [ 0.13994768  0.05398431  1.11132989]
  ...
  [ 0.04686549  0.11235158  0.96742312]
  [-0.11001059  0.09180865  1.04840288]
  [ 0.16705954  0.27243337  0.78550074]]

 ...

 [[ 0.09965208  0.11615978  1.1017879 ]
  [ 0.04233343  0.07993923  1.10244629]
  [ 0.1539481   0.04093921  1.11618612]
  ...
  [ 0.06512362  0.1035

fluidsynth: panic: An error occurred while reading from stdin.


FluidSynth runtime version 2.1.1
Copyright (C) 2000-2020 Peter Hanappe and others.
Distributed under the LGPL license.
SoundFont(R) is a registered trademark of E-mu Systems, Inc.

Rendering audio to file './outputvs1-3sic.wav'..


ffmpeg version 4.2.2 Copyright (c) 2000-2019 the FFmpeg developers
  built with gcc 7.3.0 (crosstool-NG 1.23.0.449-a04d0)
  configuration: --prefix=/home/ilc/anaconda3/envs/sinica --cc=/tmp/build/80754af9/ffmpeg_1587154242452/_build_env/bin/x86_64-conda_cos6-linux-gnu-cc --disable-doc --enable-avresample --enable-gmp --enable-hardcoded-tables --enable-libfreetype --enable-libvpx --enable-pthreads --enable-libopus --enable-postproc --enable-pic --enable-pthreads --enable-shared --enable-static --enable-version3 --enable-zlib --enable-libmp3lame --disable-nonfree --enable-gpl --enable-gnutls --disable-openssl --enable-libopenh264 --enable-libx264
  libavutil      56. 31.100 / 56. 31.100
  libavcodec     58. 54.100 / 58. 54.100
  libavformat    58. 29.100 / 58. 29.100
  libavdevice    58.  8.100 / 58.  8.100
  libavfilter     7. 57.100 /  7. 57.100
  libavresample   4.  0.  0 /  4.  0.  0
  libswscale      5.  5.100 /  5.  5.100
  libswresample   3.  5.100 /  3.  5.100
  libpostproc    55

test_input (1, 7898, 156)
prediction.shape (1, 7898, 102)
full_prediction (7898, 102)
81600
limit 800
[[[ 0.09799373  0.14396888  1.11264226]
  [ 0.02599373  0.08567091  1.11100791]
  [ 0.1360521   0.03611819  1.06715868]
  ...
  [ 0.03058271  0.12778926  0.95958928]
  [-0.10334559  0.16407764  1.04709647]
  [ 0.11493185  0.21483327  0.68059478]]

 [[ 0.06030766  0.13740326  1.1075351 ]
  [ 0.00342417  0.10915142  1.1095067 ]
  [ 0.1145698   0.05522706  1.10932848]
  ...
  [ 0.02013712  0.12653446  0.96561728]
  [-0.16469252  0.11648785  1.11110768]
  [ 0.09213534  0.21191293  0.78857819]]

 [[ 0.05830814  0.11862735  1.09925029]
  [ 0.00577288  0.09314741  1.10155067]
  [ 0.11214718  0.04878429  1.1079602 ]
  ...
  [ 0.02789533  0.11125202  0.95738063]
  [-0.14720145  0.11713358  1.10546527]
  [ 0.10149716  0.18553953  0.77859191]]

 ...

 [[ 0.06514078  0.10743559  1.09854434]
  [ 0.01258841  0.07519512  1.09770617]
  [ 0.11424431  0.0304051   1.1052197 ]
  ...
  [ 0.02894011  0.1037

fluidsynth: panic: An error occurred while reading from stdin.


FluidSynth runtime version 2.1.1
Copyright (C) 2000-2020 Peter Hanappe and others.
Distributed under the LGPL license.
SoundFont(R) is a registered trademark of E-mu Systems, Inc.

Rendering audio to file './outputvs1-4prs.wav'..


ffmpeg version 4.2.2 Copyright (c) 2000-2019 the FFmpeg developers
  built with gcc 7.3.0 (crosstool-NG 1.23.0.449-a04d0)
  configuration: --prefix=/home/ilc/anaconda3/envs/sinica --cc=/tmp/build/80754af9/ffmpeg_1587154242452/_build_env/bin/x86_64-conda_cos6-linux-gnu-cc --disable-doc --enable-avresample --enable-gmp --enable-hardcoded-tables --enable-libfreetype --enable-libvpx --enable-pthreads --enable-libopus --enable-postproc --enable-pic --enable-pthreads --enable-shared --enable-static --enable-version3 --enable-zlib --enable-libmp3lame --disable-nonfree --enable-gpl --enable-gnutls --disable-openssl --enable-libopenh264 --enable-libx264
  libavutil      56. 31.100 / 56. 31.100
  libavcodec     58. 54.100 / 58. 54.100
  libavformat    58. 29.100 / 58. 29.100
  libavdevice    58.  8.100 / 58.  8.100
  libavfilter     7. 57.100 /  7. 57.100
  libavresample   4.  0.  0 /  4.  0.  0
  libswscale      5.  5.100 /  5.  5.100
  libswresample   3.  5.100 /  3.  5.100
  libpostproc    55

In [None]:
model.eval()

full_prediction = pd.DataFrame()
num_count = 0
# read midi
# test_dataloader = get_dataloader(test_datapath, batch_size=1)
for test_batch in test_data_list:
    with torch.no_grad():
        first_target = torch.zeros(test_batch.shape[0],112)
        print(first_target.shape)
        test_input = test_batch[None, :]
        test_target = first_target[None, :]
        print("test_input", test_input.shape)
        print("test_target", test_target.shape)
        prediction = predict(model, test_input, device)
        
        # print(prediction.shape)
        
        prediction  = prediction[:, :, :102]
        print("prediction.shape", prediction.shape)
        
        # full_prediction.append(prediction)
        full_prediction = pd.DataFrame(prediction[0])
        print("full_prediction", full_prediction.shape)
        
        # prev_prediction = prediction[0][:-1][None, :]
        # print(prev_prediction.shape)
        
        Row_list_prediction =[]
        
        filecode = test_music_list[num_count]
    
        # Iterate over each row
        for index, rows in full_prediction.iterrows():
            #fill nan
            rows = rows.fillna(0)
            # Create list for the current row
            my_list = rows.values.tolist()
            # print(my_list)
            
            my_list_per3 = [my_list[i:i+3] for i in range(0, len(my_list), 3)]
            # append the list to the final list
            Row_list_prediction.append(my_list_per3)

        prediction_arr = np.array(Row_list_prediction)
        if not os.path.exists('./output_prediction/[all]'+str(num_layers)+'LSTM_hidden'+str(hidden_size)+'_'+str(num_epochs)+'epoch/'):
            os.makedirs('./output_prediction/[all]'+str(num_layers)+'LSTM_hidden'+str(hidden_size)+'_'+str(num_epochs)+'epoch/')
        midi_data_output = open('./output_prediction/[all]'+str(num_layers)+'LSTM_hidden'+str(hidden_size)+'_'+str(num_epochs)+'epoch/prediction_'+
                                filecode +'.pkl', 'wb')
        pickle.dump(prediction_arr, midi_data_output)
        midi_data_output.close()
        
        num_count += 1

# model.train()

torch.Size([8172, 112])
test_input (1, 8172, 156)
test_target torch.Size([1, 8172, 112])
prediction.shape (1, 8172, 102)
full_prediction (8172, 102)
torch.Size([11534, 112])
test_input (1, 11534, 156)
test_target torch.Size([1, 11534, 112])
prediction.shape (1, 11534, 102)
full_prediction (11534, 102)
torch.Size([6989, 112])
test_input (1, 6989, 156)
test_target torch.Size([1, 6989, 112])
prediction.shape (1, 6989, 102)
full_prediction (6989, 102)
torch.Size([7898, 112])
test_input (1, 7898, 156)
test_target torch.Size([1, 7898, 112])
prediction.shape (1, 7898, 102)
full_prediction (7898, 102)


In [None]:
final_val_loss = evaluate_lstm(model, val_dataloader)

Validation
len(all_data) 6061
len(motion_data) 6061
len(all_data) 6706
len(motion_data) 6706
len(all_data) 6069
len(motion_data) 6069
len(all_data) 4525
len(motion_data) 4525
len(all_data) 5281
len(motion_data) 5281
inputs.shape: torch.Size([5, 6706, 156])
targets.shape: torch.Size([5, 6706, 112])
outputs.shape: torch.Size([5, 6706, 112])


In [None]:
print(final_val_loss)

0.01275933813303709


In [None]:
final_l1_loss = evaluate_l1(model, val_dataloader)
print(final_l1_loss)

Validation L1
len(all_data) 5281
len(motion_data) 5281
len(all_data) 6069
len(motion_data) 6069
len(all_data) 6061
len(motion_data) 6061
len(all_data) 4525
len(motion_data) 4525
len(all_data) 6706
len(motion_data) 6706
inputs.shape: torch.Size([5, 6706, 156])
targets.shape: torch.Size([5, 6706, 112])
outputs.shape: torch.Size([5, 6706, 112])
0.04442169889807701
