In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.transforms import transforms

In [3]:
import sys, os
import glob
sys.path.append("./midi_utils/")
sys.path.append("../")

In [4]:
import midi_dataset

In [5]:
import midi_io
import midi_transform

In [6]:
MODEL_LENGTH = 16
HIDDEN_DIM = 2
LATENT_DIM = 2

In [19]:
# 1 bar : 2 x subseq = 16 x 16분음표

class MidiEncoder(nn.Module):
    def __init__(self):
        super(MidiEncoder, self).__init__()
        self.lstm = nn.LSTM(input_size=128, hidden_size=HIDDEN_DIM, num_layers=1, batch_first=True, bidirectional=True)
    
    def forward(self, x):
        out, (h, c) = self.lstm(x)
        print(out.shape, h.shape, c.shape)
        print(out, h, c)
        forward_last = out[:, MODEL_LENGTH - 1, :HIDDEN_DIM]
        backward_last = out[:, 0, HIDDEN_DIM:]
        
        return torch.cat([forward_last, backward_last], dim=1)

In [8]:

class MidiParameterize(nn.Module):
    def __init__(self):
        super(MidiParameterize, self).__init__()
        self.linear_mu = nn.Linear(HIDDEN_DIM*2, LATENT_DIM)
        self.linear_sigma = nn.Linear(HIDDEN_DIM*2, LATENT_DIM)
        self.softplus = nn.Softplus()
    
    def forward(self, x):
        mu = self.linear_mu(x)
        sigma = self.softplus(self.linear_sigma(x))
        return mu, sigma

In [20]:
enc = MidiEncoder()

In [10]:
midi_list = glob.glob("../midi1/*.mid")

In [11]:
transform = transforms.Compose([
    midi_transform.get_piano_roll,
    midi_transform.random_crop_midi,
    torch.FloatTensor,
])

In [12]:
midi_dset = midi_dataset.MidiDataset(midi_list, transform=transform)

In [13]:
sample = midi_dset[0]

In [14]:
sample = sample.unsqueeze(0)

In [15]:
sample = sample.permute(0,2,1)

In [16]:
sample.shape

torch.Size([1, 16, 128])

In [21]:
out = enc(sample)
out.shape

torch.Size([1, 16, 4]) torch.Size([2, 1, 2]) torch.Size([2, 1, 2])
tensor([[[ 2.6518e-12,  7.6087e-01,  1.0000e+00, -4.0833e-28],
         [ 3.8277e-12,  9.6381e-01,  1.0000e+00, -4.0833e-28],
         [ 4.2214e-12,  9.9501e-01,  1.0000e+00, -4.0833e-28],
         [ 4.2854e-12,  9.9932e-01,  1.0000e+00, -4.0833e-28],
         [ 4.2943e-12,  9.9991e-01,  1.0000e+00, -4.0833e-28],
         [ 4.2955e-12,  9.9999e-01,  1.0000e+00, -4.0833e-28],
         [ 4.2957e-12,  1.0000e+00,  1.0000e+00, -4.0833e-28],
         [ 4.2957e-12,  1.0000e+00,  1.0000e+00, -4.0833e-28],
         [ 4.2957e-12,  1.0000e+00,  1.0000e+00, -4.0833e-28],
         [ 4.2957e-12,  1.0000e+00,  1.0000e+00, -4.0833e-28],
         [ 4.2957e-12,  1.0000e+00,  9.9999e-01, -4.0835e-28],
         [ 4.2957e-12,  1.0000e+00,  9.9991e-01, -4.0844e-28],
         [ 4.2957e-12,  1.0000e+00,  9.9933e-01, -4.0914e-28],
         [ 4.2957e-12,  1.0000e+00,  9.9505e-01, -4.1421e-28],
         [ 4.2957e-12,  1.0000e+00,  9.6403e-01, -4

torch.Size([1, 4])

In [18]:
out

tensor([[ 7.1888e-01,  4.0038e-06, -7.6087e-01, -9.9999e-01]])

In [22]:
param = MidiParameterize()

In [23]:
param(out)

(tensor([[-0.0343,  0.2937]]), tensor([[ 0.3499,  0.4801]]))