In [1]:
from model import Encoder, Decoder, Seq2Seq
from data_loader import *
import pandas as pd
import torch.optim.lr_scheduler as lr_scheduler
from torch import optim
import torch.nn.functional as F
import datetime
import pretty_midi
import glob

In [2]:
import os
import matplotlib
import math
matplotlib.use('Agg')
# matplotlib.use("QtAgg")
import ffmpeg
#conda install -c conda-forge ffmpeg-python

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, writers
plt.rcParams['animation.ffmpeg_path'] = '/home/ilc/anaconda3/bin/ffmpeg'#'/usr/bin/ffmpeg'

import numpy as np
import subprocess as sp
from moviepy.video.io.VideoFileClip import VideoFileClip
from moviepy.audio.io.AudioFileClip import AudioFileClip

from midi2audio import FluidSynth

from torch.autograd import Variable
from sklearn.model_selection import KFold

import itertools

In [3]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [4]:
dataset_name_path = f"./midi_list_symbolic_cross.txt" #f"./midi_list.txt"
dataloader = get_dataloader(dataset_name_path, batch_size=8) #[20, 512, 128], [20, 512, 102]
dataset = MidiMotionDataSet(dataset_name_path)

val_dataset_name_path = f"./midi_list_symbolic_cross.txt" #f"./midi_list_eval.txt"
# val_dataloader = get_val_dataloader(val_dataset_name_path, batch_size=40) #[20, 512, 128], [20, 512, 102]

full_data_path = None
with open("./midi_list_symbolic_cross.txt", "r") as file:
    lines = [line.strip() for line in file]
    full_data_path = np.array(lines)

val_data_read = np.reshape(full_data_path, (11, 10))
# print(val_data_read)

learning_rate = 0.001#0.001

# input_size_encoder = 128 #129 #128
# input_size_decoder = 115 #102 #24
# output_size = 115#102 #24

# encoder_embedding_size = 300
# decoder_embedding_size = 300
enc_dropout = 0.5
dec_dropout = 0.
step = 0

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

self.piece_count:  110
dataset_len:  11000
self.piece_count:  110
dataset_len:  11000
cuda:0


In [5]:
def reset_weights(model): # reset the weight every fold
    if isinstance(model, nn.LSTM) or isinstance(model, nn.Linear):
        model.reset_parameters()

In [6]:
class LSTM1(nn.Module):
    def __init__(self, output_dim, input_size, hidden_size, num_layers, seq_length):
        super(LSTM1, self).__init__()
        self.output_dim = output_dim #number of classes
        self.num_layers = num_layers #number of layers
        self.input_size = input_size #input size
        self.hidden_size = hidden_size #hidden state
        self.seq_length = seq_length #sequence length

        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                          num_layers=num_layers, batch_first=True) #lstm
        self.fc_1 =  nn.Linear(hidden_size, output_dim) #fully connected to determine output dim

        self.relu = nn.ReLU()

    def forward(self,x):
        # h0, c0 no time information
        h_0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size)).to(device) #hidden state
        c_0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size)).to(device) #internal state
        # Propagate input through LSTM
        # x is MIDI => [44, 512, 128]

        # hn is final state, run over the sequence length
        output, (hn, cn) = self.lstm(x, (h_0, c_0)) #lstm with input, hidden, and internal state
        # hn = hn.view(-1, self.hidden_size) #reshaping the data for Dense layer next
        # print("output.shape", output.shape)
        # print("hn.shape", hn.shape)
        # out = self.relu(hn)
        out = self.fc_1(output) #final
        return out
 

In [7]:
# Define the model architecture
input_size = 128 #number of features
hidden_size = 1024 #number of features in hidden state
num_layers = 1 #number of stacked lstm layers
seq_len = 512
output_dim = 115 #number of output classes

# model = LSTM(vocab_size, embedding_dim, hidden_dim, num_layers, dropout_rate, tie_weights).to(device)
# model = LSTM(embedding_dim, hidden_dim, num_layers, dropout_rate, tie_weights).to(device)
# model = LSTM1(output_dim, input_size, hidden_size, num_layers, seq_len).to(device) #our lstm class
# model.train()
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)

num_epochs = 100 #10
k_folds = 11
cross_valid_results = {}
torch.manual_seed(42)

avg_loss_list = []
all_loss_list = []
val_loss_per_epoch_list = []

#TODO: important cross val record
val_time_loss_list = []
val_dim_loss_list = []
val_mse_loss_list = []
val_per_split_list = [] #just mse loss


In [8]:
def time_wise_loss_fn(preds, labels):
    '''
    calculate time-wise loss for motion (along the time axis)
    input: labels[batch, time, dimension(joint*xyz)]
    preds[batch, time , dimension(joint*xyz)]
    output: time loss
    '''
    epsilon = 1e-7
    preds = preds + epsilon

    labels_transpose = torch.permute(labels, (0, 2, 1))#tf.transpose(labels, [0, 2, 1]) # [b, 3, t]
    preds_transpose = torch.permute(preds, (0, 2, 1))#tf.transpose(preds, [0, 2, 1]) # [b, 3, t]
    # print("labels_transpose.shape", labels_transpose.shape)
    # print("preds_transpose.shape", preds_transpose.shape)
    # print("labels_transpose[:, :, :, None].shape", labels_transpose[:, :, :, None].shape)
    # print("labels_transpose[:, :, None, :].shape", labels_transpose[:, :, None, :].shape)
    label_diff = labels_transpose[:, :, :, None] - labels_transpose [:, :, None, :] # [b, 3, t, t]
    
    preds_diff = preds_transpose[:, :, :, None] - preds_transpose [:, :, None, :] # [b, 3, t, t]
    # print(preds_diff.shape)
    time_loss = (preds_diff - label_diff)**2 # [b, 3, t, t]
    time_loss_value = time_loss.mean() #float()
    torch.cuda.empty_cache()

    return time_loss_value
    
def dim_wise_loss_fn(preds, labels):
    '''
    calculate dimension-wise loss for motion (along the dimension axis)
    input: labels[batch, time, dimension(joint*xyz)]
    preds[batch, time , dimension(joint*xyz)]
    output: dimension loss
    '''
    epsilon = 1e-7
    preds = preds + epsilon
    
    label_diff = labels[:, :, :, None] - labels[:, :, None, :] # [b, t, 3, 3]
    preds_diff = preds[:, :, :, None] - preds[:, :, None, :] # [b, t, 3, 3]
    dim_loss = (preds_diff - label_diff)**2 # [b, t, 3, 3]
    dim_loss_value = dim_loss.mean() #float()
    torch.cuda.empty_cache()
    
    return dim_loss_value

In [9]:
def customized_mse_loss(output, target):
    # target = target.transpose(0, 1)

    # print("output.shape:", output.shape) #torch.Size([20, 513, 102])
    # print("target.shape:", target.shape) #torch.Size([20, 513, 102])

    w1_time = 0.3
    w2_dim = 0.3
    w3_mse = 0.4

    mse_loss = F.mse_loss(output, target)
    time_loss = time_wise_loss_fn(output, target)
    dim_loss = dim_wise_loss_fn(output, target)

    # print("time_loss:", time_loss)
    # print("dim_loss:", dim_loss)
    # print("mse_loss:", mse_loss)
    val_time_loss_list.append(time_loss.cpu().item())
    val_dim_loss_list.append(dim_loss.cpu().item())
    val_mse_loss_list.append(mse_loss.cpu().item())

    segment_loss = (w1_time * time_loss) + (w2_dim * dim_loss) + (w3_mse * mse_loss)
    torch.cuda.empty_cache()
    return  segment_loss

In [10]:
def evaluate_lstm_cross(model, split_count):
    model.eval()
    print('Validation')
    valid_running_loss = 0.0
    counter = 0
    # previous_output = torch.zeros(512, 102).to(device)
    
    outputs_save = []
    
    np.savetxt('./temp_path.txt', val_data_read[split_count], delimiter="\n", fmt="%s")
    # print(val_data_read[split_count])
    val_dataloader = get_val_dataloader('./temp_path.txt', batch_size=11)
    
    with torch.no_grad():
        for i, (inputs, targets) in enumerate(val_dataloader): #tqdm(enumerate(val_dataloader), total=len(val_dataloader))
            counter += 1

            inputs = inputs.to(device).float()
            targets = targets.to(device).float()
            # print("val inputs.shape:", inputs.shape)
            # print("val targets.shape:", targets.shape)
            outputs = model(inputs)
            # print("val outputs.shape:", outputs.shape)

            loss =  F.mse_loss(outputs, targets)
            valid_running_loss += loss.cpu().item()
            # previous_output = outputs
            outputs_save.append(np.asarray(outputs.cpu()))

    loc_dt = datetime.datetime.today()
    loc_dt_format = loc_dt.strftime("%Y-%m-%d_%H-%M-%S")
    if not os.path.exists("./output_eval/"):
        os.makedirs('./output_eval/')

    # print("counter", counter)
    # print("val_data_read[split_count][counter]", val_data_read[split_count][counter])
    val_file_name = val_data_read[split_count][counter].split('/')[2].split('.')[0]

    # print("val file_name:", val_file_name)
    # print("outputs_save length: ", len(outputs_save), ", element shape: " , outputs_save[0].shape)
        
    eval_output = open("./output_eval/[split_" + str(split_count) + "][midi_with_anno][total" + str(num_epochs) + "_hs" + str(hidden_size) +"]save_"+ str(loc_dt_format) + "_l1_loss_" + str(loss.cpu().item())+".pkl", 'wb')
    pickle.dump(np.asarray(outputs_save), eval_output)
    eval_output.close()
    
    # print("val counter:", counter)
    epoch_val_loss_f1 = valid_running_loss / counter
    val_per_split_list.append(epoch_val_loss_f1)
    print("split_count:", split_count, ", epoch_val_loss:", epoch_val_loss_f1)
    os.remove("./temp_path.txt")
    return #epoch_val_loss_f1
# model.train()

In [11]:
kf = KFold(n_splits=k_folds)
print(kf.get_n_splits(dataset))
KFold(n_splits=k_folds, random_state=None, shuffle=False)
# for i, (train_index, test_index) in enumerate(kf.split(X)):
#     #...TODO
split_count = 0
random_pick_fold = random.randint(0, 10) #0~10
print("random: ", random_pick_fold)
for fold,(train_idx, test_idx) in enumerate(kf.split(dataset)): #TODO: random pick 1 fold
    print('------------fold no---------{}----------------------'.format(fold))
    train_subsampler = torch.utils.data.SubsetRandomSampler(train_idx)
    val_subsampler = torch.utils.data.SubsetRandomSampler(test_idx)
    print("train_idx:", train_idx[0], "~", train_idx[-1], " test_idx:", test_idx[0], "~", test_idx[-1])

    train_loader = DataLoader(
                        dataset,
                        num_workers=0,
                        pin_memory=False,
                        drop_last=False,
                        batch_size=8, sampler=train_subsampler) #bs=40:4.49G, bs=128:14.65G

    val_loader = DataLoader(
                        dataset,
                        num_workers=0,
                        pin_memory=False,
                        drop_last=False,
                        batch_size=8, sampler=val_subsampler)
    
    model = LSTM1(output_dim, input_size, hidden_size, num_layers, seq_len).to(device) #our lstm class
    model.apply(reset_weights)
    
    
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        print(f'Starting epoch {epoch+1}')
        losses = []
        loss = 0
        mean_loss = 0
        for i, (midi_batch, motion_batch) in enumerate(train_loader):
            model.train()
            
            midi_batch = midi_batch.to(device).float()
            motion_batch = motion_batch.to(device).float()

            optimizer.zero_grad()
            output = model(midi_batch) #midi_batch
            # print("train inputs.shape:", midi_batch.shape, torch.isnan(midi_batch).any())
            # # print(motion_batch)
            # print("train targets.shape:", motion_batch.shape, torch.isnan(motion_batch).any())
            # print("train outputs.shape:", output.shape, torch.isnan(output).any())

            # loss =  F.mse_loss(output, motion_batch)
            # loss = customized_mse_loss(output.cpu(), motion_batch.cpu())
            loss = customized_mse_loss(output, motion_batch)
            
            losses.append(loss.cpu().item()) #.cpu().item()
            all_loss_list.append(loss.cpu().item()) #.cpu().item()
            loss.backward()

            optimizer.step()

            # print(f"Epoch {epoch}, batch {i}: loss = {loss.cpu().item():.6f}") #.cpu().item()

        # print(losses, sum(losses), len(losses))
        mean_loss = sum(losses)/len(losses)
        # correct, total = 0, 0
        valid_running_loss = 0.0
        counter = 0
        with torch.no_grad():
            for i, (midi_test, motion_test) in enumerate(val_loader):
                
                inputs = midi_test.to(device).float()
                targets = motion_test.to(device).float()

                outputs = model(inputs)
                # print("val inputs.shape:", inputs.shape)
                # print("val targets.shape:", targets.shape)
                # print("val outputs.shape:", outputs.shape)

                val_loss =  customized_mse_loss(outputs, targets)
                valid_running_loss += val_loss.cpu().item() #.cpu().item()
                counter += 1
            
            epoch_val_loss = valid_running_loss / counter
            # print(f"Epoch {epoch}: val_loss = {epoch_val_loss:.6f}") #.cpu().item()

        avg_loss_list.append(mean_loss) #.cpu().item()
        val_loss_per_epoch_list.append(epoch_val_loss) #.cpu().item()

        cross_valid_results[fold] = epoch_val_loss
        
        loc_dt = datetime.datetime.today()
        loc_dt_format = loc_dt.strftime("%Y-%m-%d_%H-%M-%S")
        if (epoch+1)%100 == 0:
            torch.save({
                'epoch':epoch,
                'model_state_dict':model.state_dict(),
                'optimizer_state_dict':optimizer.state_dict(),
                'loss':loss
            },  "./model_save/[midi_with_anno][total"+str(num_epochs)+ "_hs" + str(hidden_size) +"]LSTM_save_epoch_" + str(epoch)+ "_"+ str(loc_dt_format) + "_avg_loss_" + str(mean_loss) +".tar")

    # Print fold results
    print(f'K-FOLD CROSS VALIDATION RESULTS FOR {k_folds} FOLDS')
    print('--------------------------------')
    sum_loss = 0.0
    for key, value in cross_valid_results.items():
        print(f'Fold loss {key}: {value}')
        sum_loss += value
    print(f'Average vaildation new loss: {sum_loss/len(cross_valid_results.items())}')
    
    # validation result save
    evaluate_lstm_cross(model, split_count)
    
    split_count += 1

11
random:  4


ValueError: too many values to unpack (expected 2)

In [None]:
# # train the model
# for epoch in range(num_epochs):
#     # previous_output = torch.zeros(1, 512, 102).to(device)
#     losses = []
#     for i, (midi_batch, motion_batch) in enumerate(dataloader):
#         model.train()
        
#         midi_batch = midi_batch.to(device).float()
#         motion_batch = motion_batch.to(device).float()
#         # print("midi_batch", midi_batch.shape)
#         # print("motion_batch", motion_batch.shape)

#         optimizer.zero_grad()
#         output = model(midi_batch) #midi_batch
#         # print("output.shape", output.shape)

#         # motion_ground_truth_padding = F.pad(motion_batch, (0,0,0,1), value = 1) #<eot>
        
#         # loss =  F.mse_loss(output, motion_ground_truth_padding)
#         loss =  F.mse_loss(output, motion_batch)
#         # loss = customized_mse_loss(output, motion_ground_truth_padding, previous_output, midi_batch)
#         # loss = customized_mse_loss(output, motion_batch, previous_output, midi_batch)

#         # losses 累計lose
#         losses.append(loss.cpu().item())
#         all_loss_list.append(loss.cpu().item())
#         loss.backward()

#         optimizer.step()
#         mean_loss = sum(losses)/len(losses)

#         print(f"Epoch {epoch}, batch {i}: loss = {loss.cpu().item():.6f}")

#         # scheduler.step(1)
#         # previous_output = output

#         loc_dt = datetime.datetime.today()
#         loc_dt_format = loc_dt.strftime("%Y-%m-%d_%H-%M-%S")

#     val_loss = evaluate_lstm(model, val_dataloader) #CUDA out of memory
#     val_loss_per_epoch_list.append(val_loss)
#     print(f"Epoch {epoch}: val_loss = {val_loss:.6f}")
#     # save_best_model(
#     #         val_loss, epoch, model, optimizer, loss, loc_dt_format, mean_loss
#     #     )
#     avg_loss_list.append(mean_loss)
#     loc_dt = datetime.datetime.today()
#     loc_dt_format = loc_dt.strftime("%Y-%m-%d_%H-%M-%S")
#     if (epoch+1)%100 == 0:
#         torch.save({
#             'epoch':epoch,
#             'model_state_dict':model.state_dict(),
#             'optimizer_state_dict':optimizer.state_dict(),
#             'loss':loss
#         }, "./model_save/[100epoch]LSTM_save_epoch_" + str(epoch)+ "_"+ str(loc_dt_format) + "_avg_loss_" + str(mean_loss) +".tar")

In [None]:
print(loc_dt_format)
print(avg_loss_list)

2023-08-14_16-33-14
[0.028956809107214213, 0.02900429619476199, 0.029244605258107185, 0.028834903914481402, 0.029323842100799084, 0.029359927907586097, 0.028954434435814618, 0.029166191596537827, 0.029162976112216712, 0.028837952252477407, 0.029094887372106314]


In [None]:
print(val_loss_per_epoch_list)

[0.022863045521080493, 0.022591257356107234, 0.02265377711504698, 0.02310947358608246, 0.022645045630633832, 0.022554774172604083, 0.022957211136817934, 0.022894382737576963, 0.02252921000123024, 0.023516938537359238, 0.02348352946341038]


In [None]:
# def lr_lambda(epoch):
#     # LR to be 0.1 * (1/1+0.01*epoch)
#     base_lr = 0.1
#     factor = 0.01
#     return base_lr/(1+factor*epoch)

In [None]:
# scheduler = lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.3, total_iters=10)
# scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda)

In [None]:
plt.cla()
plt.clf()

In [None]:
print(len(avg_loss_list))
avg_loss_list_dataframe = pd.DataFrame(avg_loss_list)

11


In [None]:
avg_loss_list_dataframe

Unnamed: 0,0
0,0.028957
1,0.029004
2,0.029245
3,0.028835
4,0.029324
5,0.02936
6,0.028954
7,0.029166
8,0.029163
9,0.028838


In [None]:
plt.plot(np.array(avg_loss_list_dataframe.index), np.array(avg_loss_list_dataframe[0]))
plt.savefig("avg_loss_training.jpg")
plt.show()

In [None]:
plt.cla()
plt.clf()

In [None]:
loss_list_dataframe = pd.DataFrame(all_loss_list)

In [None]:
plt.plot(np.array(loss_list_dataframe.index), np.array(loss_list_dataframe[0]))
plt.savefig("training_loss.jpg")
plt.show()

In [None]:
plt.cla()
plt.clf()

In [None]:
val_loss_per_epoch_list_dataframe = pd.DataFrame(val_loss_per_epoch_list)

In [None]:
plt.plot(np.array(val_loss_per_epoch_list_dataframe.index), np.array(val_loss_per_epoch_list_dataframe[0]))
plt.savefig("training_val_loss.jpg")
plt.show()

In [None]:
def predict(model, input, device):
    model.eval()
    with torch.no_grad():
        input = torch.as_tensor(input).to(torch.float32).to(device)
        # print(target.shape)
        # target = torch.as_tensor(target).to(torch.float32).to(device)
        # TODO: target should be <sos>, should not random
        outputs = model(input)
        return outputs.cpu().numpy()

In [None]:
def read_midi(filename, specific_fps):
    # Load the MIDI file
    midi_data = pretty_midi.PrettyMIDI(filename)

    piano_roll = midi_data.get_piano_roll(fs=specific_fps)  # 40fps #250fps
    piano_roll[piano_roll > 0] = 1

    return piano_roll

In [None]:
test_datapath = "./BWV1001/"
change_fps = 40
test_midi_path_list = glob.glob(test_datapath + "*.mid")
test_data_list = []
test_music_list = []
for test_midi in test_midi_path_list:
    str_name = test_midi
    print("str_name:", str_name)
    filename = str_name.split('/')[2]
    filecode = filename.split('.')[0]
    print("filecode: ",filecode)
    test_music_list.append(filecode)
    
    print(test_midi)
    read_piano_roll = read_midi(test_midi, change_fps)
    read_piano_roll_transpose = read_piano_roll.T
    print(read_piano_roll_transpose.shape)
    test_midi_len = read_piano_roll_transpose.shape[0]
    test_data_list.append(read_piano_roll_transpose)

str_name: ./BWV1001/vs1-1ada.mid
filecode:  vs1-1ada
./BWV1001/vs1-1ada.mid
(8171, 128)
str_name: ./BWV1001/vs1-2fug.mid
filecode:  vs1-2fug
./BWV1001/vs1-2fug.mid
(11537, 128)
str_name: ./BWV1001/vs1-3sic.mid
filecode:  vs1-3sic
./BWV1001/vs1-3sic.mid
(6993, 128)
str_name: ./BWV1001/vs1-4prs.mid
filecode:  vs1-4prs
./BWV1001/vs1-4prs.mid
(7897, 128)


In [None]:
def column(matrix, i):
    return [row[i] for row in matrix]

def test_render_animation(fps, output, azim, prediction, ground_truth=None):
    prediction_array = np.asarray(prediction)
    print(prediction_array.size)
    limit = len(prediction_array)
    print("limit", limit)
    size = 6#6
    fps = 40

    # Skeleton layout
    parents = [[0, 1], [1, 3], [3, 2], [0, 2],#head
                [8, 6], [6, 13], [13, 4], [4, 8],#shoulder
                [6, 4], [4, 5], [5, 7], [7, 6],#Upper torso
                [8, 18], [8, 20], [13, 21], [13, 19],
                [5, 20], [5, 21], [7, 18], [7, 19],
                [18, 19], [19, 21], [21, 20], [20, 18], #waist
                [18, 22], [20, 22], [22, 23], [22, 25], [23, 25], [24,23], [24, 25],  #right lag
                [21, 26], [19, 26], [26, 27], [26, 29], [27, 29], [28, 27], [28, 29], #left lag
                [8, 9], [9, 11], [9, 10], [10, 11], [10, 12], [9, 12], [11, 12], #right hand
                [13, 14], [14, 16], [14, 15], [16, 15], [14, 17], [16, 17], [15, 17], #left hand
                [31, 33], [30, 32], [30, 31], [32, 33], [31, 32], [30, 33] #instrument
                        ]
    # joints_right = [1, 2, 12, 13, 14]

    prediction_array[:, :, 2] += 0.1 #[:, :, 2]
    if ground_truth is not None:
        ground_truth[:, :, 2] += 0.1
        poses = {'Prediction': prediction_array,
                 'Ground_truth': ground_truth}
    else:
        poses = {'Prediction': prediction_array}
    

    fig = plt.figure()#(figsize=(size*len(poses), size))
    # ax_3d = []
    # lines_3d = []
    radius = 1#14 #3.7#
    # print(poses)
    for index, (title, data) in enumerate(poses.items()):
        ax = fig.add_subplot(1, len(poses), index + 1, projection='3d')
        ax.clear()
        print(data)
        ims = [] #每一 frame 都存
        for frame_index, each_frame in enumerate(data):
            # print("each_frame")
            # print(each_frame)
            ax.view_init(elev=15., azim=azim)
            ax.set_xlim3d([-radius/2, radius/2])
            ax.set_zlim3d([0, radius])
            ax.set_ylim3d([-radius/2, radius/2])
            ax.set_aspect('auto') #ax.set_aspect('equal')

            # print(title)
            points = ax.scatter(column(each_frame[:30], 0), column(each_frame[:30], 1), column(each_frame[:30], 2), cmap='jet', marker='o', label='body joint', color = 'black')
            points_2 = ax.scatter(column(each_frame[30:32], 0), column(each_frame[30:32], 1), column(each_frame[30:32], 2), cmap='jet', marker='o', label='body joint', color = 'blue')
            points_3 = ax.scatter(column(each_frame[32:34], 0), column(each_frame[32:34], 1), column(each_frame[32:34], 2), cmap='jet', marker='o', label='body joint', color = 'red')
            
            # ax.scatter(column(each_frame, 0), column(each_frame, 1), column(each_frame, 2), cmap='jet', marker='o', label='body joint')
            # ax.legend()
            # print("+++")
            
            parents = [[0, 1], [1, 3], [3, 2], [0, 2],#head
                        [8, 6], [6, 13], [13, 4], [4, 8],#shoulder
                        [6, 4], [4, 5], [5, 7], [7, 6],#Upper torso
                        [8, 18], [8, 20], [13, 21], [13, 19],
                        [5, 20], [5, 21], [7, 18], [7, 19],
                        [18, 19], [19, 21], [21, 20], [20, 18], #waist
                        [18, 22], [20, 22], [22, 23], [22, 25], [23, 25], [24,23], [24, 25],  #right lag
                        [21, 26], [19, 26], [26, 27], [26, 29], [27, 29], [28, 27], [28, 29], #left lag
                        [8, 9], [9, 11], [9, 10], [10, 11], [10, 12], [9, 12], [11, 12], #right hand
                        [13, 14], [14, 16], [14, 15], [16, 15], [14, 17], [16, 17], [15, 17], #left hand
                        [30, 31], [32, 33],  #instrument
                        # [31, 33], [30, 32], [30, 31], [32, 33], [31, 32], [30, 33] #instrument
                        ]
            lines = []
            # draw line
            
            # lines = [ax.plot([each_frame[vs][0], each_frame[ve][0]],
            #                  [each_frame[vs][1], each_frame[ve][1]],
            #                  [each_frame[vs][2], each_frame[ve][2]]) for (vs, ve) in parents]
            line_num = len(parents)
            for idx, each_line in enumerate(parents):
                vec_start = each_frame[each_line[0]]
                vec_end = each_frame[each_line[1]]
                # print(vec_start)
                # print(vec_end)
                line_color = "black"
                if idx == line_num-2:
                    line_color = "blue"
                if idx == line_num-1:
                    line_color = "red"
                # ax.plot([vec_start[0], vec_end[0]], [vec_start[1], vec_end[1]], [vec_start[2], vec_end[2]])
                
                temp, = ax.plot([vec_start[0], vec_end[0]], [vec_start[1], vec_end[1]], [vec_start[2], vec_end[2]], color=line_color)
                lines.append(temp)

            # ax.figure.savefig('./test_pic/pic' + str(frame_index) + '.png', dpi=100, bbox_inches = 'tight')

            # ims.append([points])
            # image_frame = [points].extend(lines)
            ims.append([points]+[points_2]+[points_3]+lines) #TODO: try extend

            # plt.cla()
            # print("+++")

    anim = matplotlib.animation.ArtistAnimation(fig, ims, interval=1000/fps)

    if output.endswith('.mp4'):
        FFwriter = matplotlib.animation.FFMpegWriter(fps=fps, extra_args=['-vcodec', 'libx264'])
        anim.save(output, writer=FFwriter)
    elif output.endswith('.gif'):
        anim.save(output, fps=fps, dpi=100, writer='imagemagick')
    else:
        raise ValueError('Unsupported output format (only .mp4 and .gif are supported)')

In [None]:
def plot(audio_path, plot_path, prediction, sample_time, fps, name=""): #audio_path, plot_path, 
    # render_animation(fps, output='new_temp.mp4', azim=75, prediction=prediction)
    test_render_animation(fps, output='new_temp_' + name + '.mp4', azim=75, prediction=prediction)

    # # #merge with wav
    input_video = ffmpeg.input('new_temp_' + name + '.mp4')
    fluid_syn = FluidSynth()
    fluid_syn.midi_to_audio(audio_path, './output' + name + '.wav')
    input_audio = ffmpeg.input('./output' + name + '.wav')
    # output = ffmpeg.output(video, audio, plot_path, vcodec='copy', acodec='aac', strict='experimental')
    ffmpeg.concat(input_video, input_audio, v=1, a=1).output(plot_path).run()
    # os.remove('new_temp_' + name + '.mp4')

In [None]:
model.eval()

full_prediction = pd.DataFrame()
num_count = 0
# read midi
# test_dataloader = get_dataloader(test_datapath, batch_size=1)
for test_batch in test_data_list:
    with torch.no_grad():
        # first_target = torch.zeros(test_batch.shape[0],112)
        # print(first_target.shape)
        test_input = test_batch[None, :]
        # test_target = first_target[None, :]
        print("test_input", test_input.shape)
        # print("test_target", test_target.shape)
        prediction = predict(model, test_input, device)
        
        # print(prediction.shape)
        
        prediction  = prediction[:, :, :102]
        print("prediction.shape", prediction.shape)
        
        # full_prediction.append(prediction)
        full_prediction = pd.DataFrame(prediction[0])
        print("full_prediction", full_prediction.shape)
        
        # prev_prediction = prediction[0][:-1][None, :]
        # print(prev_prediction.shape)
        
        Row_list_prediction =[]
        
        filecode = test_music_list[num_count]
    
        # Iterate over each row
        for index, rows in full_prediction.iterrows():
            #fill nan
            rows = rows.fillna(0)
            # Create list for the current row
            my_list = rows.values.tolist()
            # print(my_list)
            
            my_list_per3 = [my_list[i:i+3] for i in range(0, len(my_list), 3)]
            # append the list to the final list
            Row_list_prediction.append(my_list_per3)

        # print(len(Row_list_prediction), len(Row_list_prediction[0]),len(Row_list_prediction[0][0]))
        plot(test_datapath + test_music_list[num_count] + ".mid", "./video_" + filecode + "_test_predict.mp4", Row_list_prediction[:800], None, 40, filecode) #ow_list[0:900]
        # print("prediction.shape", prediction.shape)
        prediction_arr = np.array(Row_list_prediction)
        # formated_motion = prediction_format(full_prediction)
        # # # plot(formated_motion)
        # audio_path = test_music_list[num_count][0]
        # output_path = "test_output_" + filecode + ".mp4"
        # plot(formated_motion, audio_path, output_path, None, 10, filecode)
        num_count += 1

# model.train()

test_input (1, 8171, 128)
prediction.shape (1, 8171, 102)
full_prediction (8171, 102)


81600
limit 800
[[[ 0.01066413  0.08593887  0.60282156]
  [-0.01280998 -0.00583635  0.61125532]
  [ 0.04083589  0.00393716  0.59273705]
  ...
  [-0.03701151  0.01846465  0.55291889]
  [-0.10076261 -0.01230686  0.6193606 ]
  [ 0.01776977  0.15142721  0.46714631]]

 [[ 0.03879211  0.14804189  1.0754324 ]
  [-0.00892204  0.04889368  1.09128056]
  [ 0.09629098  0.02398731  1.06272302]
  ...
  [-0.02065109  0.07489548  0.97945682]
  [-0.16371697  0.03421995  1.0937458 ]
  [ 0.0439385   0.28116199  0.78792391]]

 [[ 0.07446319  0.14656338  1.23176453]
  [ 0.01926645  0.10105406  1.24883649]
  [ 0.14248177  0.05786384  1.22012243]
  ...
  [ 0.01223802  0.11877581  1.10784171]
  [-0.16243598  0.07758166  1.2417346 ]
  [ 0.07986818  0.30376053  0.86332349]]

 ...

 [[ 0.08605392  0.12892547  1.09026489]
  [ 0.01502454  0.07597429  1.11031756]
  [ 0.12479843  0.05232498  1.09718392]
  ...
  [ 0.02310273  0.09360582  0.98372654]
  [-0.12351809  0.08709976  1.10302112]
  [ 0.09554765  0.25461343  

KeyboardInterrupt: 

In [None]:
model.eval()

full_prediction = pd.DataFrame()
num_count = 0
# read midi
# test_dataloader = get_dataloader(test_datapath, batch_size=1)
for test_batch in test_data_list:
    with torch.no_grad():
        first_target = torch.zeros(test_batch.shape[0],115)
        print(first_target.shape)
        test_input = test_batch[None, :]
        test_target = first_target[None, :]
        print("test_input", test_input.shape)
        print("test_target", test_target.shape)
        prediction = predict(model, test_input, device)
        
        # print(prediction.shape)
        
        prediction  = prediction[:, :, :102]
        print("prediction.shape", prediction.shape)
        
        # full_prediction.append(prediction)
        full_prediction = pd.DataFrame(prediction[0])
        print("full_prediction", full_prediction.shape)
        
        # prev_prediction = prediction[0][:-1][None, :]
        # print(prev_prediction.shape)
        
        Row_list_prediction =[]
        
        filecode = test_music_list[num_count]
    
        # Iterate over each row
        for index, rows in full_prediction.iterrows():
            #fill nan
            rows = rows.fillna(0)
            # Create list for the current row
            my_list = rows.values.tolist()
            # print(my_list)
            
            my_list_per3 = [my_list[i:i+3] for i in range(0, len(my_list), 3)]
            # append the list to the final list
            Row_list_prediction.append(my_list_per3)

        prediction_arr = np.array(Row_list_prediction)
        if not os.path.exists('./output_prediction/[midi_with_anno]'+str(num_layers)+'LSTM_hidden'+str(hidden_size)+'_'+str(num_epochs)+'epoch/'):
            os.makedirs('./output_prediction/[midi_with_anno]'+str(num_layers)+'LSTM_hidden'+str(hidden_size)+'_'+str(num_epochs)+'epoch/')
        midi_data_output = open('./output_prediction/[midi_with_anno]'+str(num_layers)+'LSTM_hidden'+str(hidden_size)+'_'+str(num_epochs)+'epoch/prediction_'+
                                filecode +'.pkl', 'wb')
        pickle.dump(prediction_arr, midi_data_output)
        midi_data_output.close()
        
        num_count += 1

# model.train()

torch.Size([8171, 115])
test_input (1, 8171, 128)
test_target torch.Size([1, 8171, 115])
prediction.shape (1, 8171, 102)
full_prediction (8171, 102)
torch.Size([11537, 115])
test_input (1, 11537, 128)
test_target torch.Size([1, 11537, 115])
prediction.shape (1, 11537, 102)
full_prediction (11537, 102)
torch.Size([6993, 115])
test_input (1, 6993, 128)
test_target torch.Size([1, 6993, 115])
prediction.shape (1, 6993, 102)
full_prediction (6993, 102)
torch.Size([7897, 115])
test_input (1, 7897, 128)
test_target torch.Size([1, 7897, 115])
prediction.shape (1, 7897, 102)
full_prediction (7897, 102)


In [None]:
# final_val_loss = evaluate_lstm(model, val_dataloader)

In [None]:
# print(final_val_loss)