In [1]:
import os
from typing import Sequence

import numpy as np
import torch
import torch.nn as nn
import torchaudio
from torch.utils.data import DataLoader

import styles_ranges as sr
from audio import get_clean_tensor,get_crunch_tensor,get_distortion_tensor,normalize_tensor
from utilities import plot_waveform

In [37]:
# Load the three recordings; only the clean (dry) signal is normalized here.
# Call plot_waveform(<tensor>) on any of them to inspect the waveform.
dry = normalize_tensor(get_clean_tensor())
crunch = get_crunch_tensor()
distortion = get_distortion_tensor()

# Timing / split settings.
seconds, test_split_ratio = 1, 0.2
train_time_seconds, val_time_seconds = 24, 6

# Sample-index ranges (taken from styles_ranges) used to slice out the
# training and validation segments.
train_samples_start, train_samples_end = (
    sr.POWER_CHORDS_RING_OUT_START,
    sr.CHORDS_ARPEGGIO_END,
)
val_samples_start, val_samples_end = (
    sr.PENTATONIC_FAST_START,
    sr.PENTATONIC_FAST_END,
)

# Model input is the dry signal and the target is the crunch signal.
# NOTE(review): index 0 is presumably the first audio channel - confirm
# against the tensors returned by the audio helpers.
x, y = dry[0], crunch[0]

In [38]:
# window_size: number of input samples in each sliding window.
# batch_size: number of windows grouped into one training batch.
window_size, batch_size = 400, 2000

In [39]:
# Left-pad every segment with window_size - 1 zeros so that the very first
# real sample already has a full window of history in front of it.
zero_pad = torch.zeros(window_size - 1)

x_train = torch.cat((zero_pad, x[train_samples_start:train_samples_end]))
y_train = torch.cat((zero_pad, y[train_samples_start:train_samples_end]))

x_val = torch.cat((zero_pad, x[val_samples_start:val_samples_end]))
y_val = torch.cat((zero_pad, y[val_samples_start:val_samples_end]))

# Count complete batches of *windows*, not of samples: a padded signal of
# length L yields L - window_size + 1 windows. Dividing the padded length
# directly over-counts by one batch whenever the padding crosses a batch
# boundary, and would disagree with WindowArray.__len__.
batches = (x_train.size(0) - window_size + 1) // batch_size
val_batches = (x_val.size(0) - window_size + 1) // batch_size
print(batches)
print(val_batches)

3380
558


In [23]:
class WindowArray(Sequence):
    """Sequence of pre-built batches of sliding windows over a signal pair.

    Item ``i`` is a tuple ``(x_out, y_out)`` where ``x_out`` stacks
    ``batch_size`` consecutive windows of ``x`` (each ``window_len`` long,
    stride 1) and ``y_out`` holds the target aligned with the last sample of
    each window.

    Args:
        x: input signal, indexed along its first dimension.
        y: target signal; its first ``window_len - 1`` entries are dropped
            because no complete window ends at those positions.
        window_len: number of samples per window.
        batch_size: number of windows per returned batch.
    """

    def __init__(self, x, y, window_len, batch_size=32):
        self.x = x
        # Align targets with the *last* sample of each window.
        self.y = y[window_len-1:]
        self.window_len = window_len
        self.batch_size = batch_size

    def __len__(self):
        """Return the number of complete batches (a trailing partial batch is dropped)."""
        n_windows = len(self.x) - self.window_len + 1
        # Clamp at zero: len() must not be negative when x is shorter than a window.
        return max(0, n_windows // self.batch_size)

    def __getitem__(self, index):
        """Return batch ``index`` as ``(x_out, y_out)``.

        ``x_out`` has shape ``(batch_size, window_len, -1)`` and ``y_out`` has
        shape ``(batch_size, 1)`` for 1-D signals.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
                This is what lets the ``Sequence`` mixins (iteration, ``in``,
                ``reversed``) terminate cleanly instead of failing inside
                ``torch.stack`` on a ragged final batch.
        """
        n_batches = len(self)
        if index < 0:
            index += n_batches
        if not 0 <= index < n_batches:
            raise IndexError('WindowArray index out of range')

        start = index * self.batch_size
        stop = start + self.batch_size
        windows = [self.x[idx: idx + self.window_len] for idx in range(start, stop)]
        x_out = torch.stack(windows).view(self.batch_size, self.window_len, -1)
        y_out = self.y[start:stop].view(-1, 1)
        return x_out, y_out

In [14]:
# Toy example: ten integers, windows of 6 samples, 3 windows per batch.
a = torch.arange(1, 11)
b = torch.arange(1, 11)
aw = WindowArray(x=a, y=b, window_len=6, batch_size=3)
abdl = DataLoader(dataset=aw)

In [40]:
# Wrap the padded train / validation signals as sequences of window batches.
train_window = WindowArray(
    x=x_train, y=y_train, window_len=window_size, batch_size=batch_size
)
val_window = WindowArray(
    x=x_val, y=y_val, window_len=window_size, batch_size=batch_size
)

In [41]:
# Each WindowArray item is already a full batch, so the loaders keep the
# DataLoader default of one item per step (this adds a leading dim of size 1).
train_loader = DataLoader(dataset=train_window)
val_loader = DataLoader(dataset=val_window)

In [26]:
# Scratch layers for checking tensor shapes through a conv -> pool stack.
# kernel_size=11 with padding=5 and stride=1 keeps the sequence length.
cnn = nn.Conv1d(in_channels=1, out_channels=16, kernel_size=11, stride=1, padding=5)
c2 = nn.Conv1d(in_channels=16, out_channels=1, kernel_size=11, stride=1, padding=5)
mp = nn.MaxPool1d(kernel_size=5)

In [42]:
# Shape probe: look at the first batch only and trace it through the scratch
# conv/pool layers.
for c, d in train_loader:
    print(c.shape)
    print(d.shape)
    # Drop the loader's leading dim and move channels before time for Conv1d.
    channels_first = c[0].permute(0, 2, 1)
    first_stage = c2(mp(cnn(channels_first)))
    print(first_stage.shape)
    print(mp(first_stage).shape)
    break

torch.Size([1, 2000, 400, 1])
torch.Size([1, 2000, 1])
torch.Size([2000, 1, 80])
torch.Size([2000, 1, 16])


In [43]:
class LSTM(nn.Module):
    """Two Conv1d -> ReLU -> MaxPool1d stages feeding an LSTM and a linear head.

    Args:
        n_hidden: hidden size of the LSTM (and input size of the linear head).
        n_layers: number of stacked LSTM layers.
    """

    def __init__(self, n_hidden, n_layers):
        super().__init__()

        # Stage 1: 1 -> 16 channels, length preserved by 'same' padding, then /5.
        self.conv1 = nn.Conv1d(1, 16, kernel_size=11, stride=1, padding='same')
        self.relu1 = nn.ReLU()
        self.mp1 = nn.MaxPool1d(5)

        # Stage 2: 16 -> 16 channels, length preserved, then /5 again.
        self.conv2 = nn.Conv1d(16, 16, kernel_size=11, stride=1, padding='same')
        self.relu2 = nn.ReLU()
        self.mp2 = nn.MaxPool1d(5)

        # Recurrent part consumes the 16 conv channels as per-step features.
        self.lstm = nn.LSTM(16, n_hidden, num_layers=n_layers, batch_first=True)

        # Single scalar prediction from the last time step.
        self.out = nn.Linear(n_hidden, 1)

    def forward(self, x, state=None):
        """Run the network on a batch of windows.

        Args:
            x: 3-D tensor whose last dim is the channel dim (size 1, as
                required by ``conv1``); it is moved in front of the time dim
                before the convolutions.
            state: optional ``(h_0, c_0)`` passed straight to the LSTM.

        Returns:
            ``(result, h_s, c_s)``: the linear head applied to the LSTM output
            at the last time step, plus the final hidden and cell states.
        """
        features = x.transpose(1, 2)
        stages = (
            (self.conv1, self.relu1, self.mp1),
            (self.conv2, self.relu2, self.mp2),
        )
        for conv, activation, pool in stages:
            features = pool(activation(conv(features)))

        # Back to (batch, time, features) for the batch_first LSTM.
        sequence = features.transpose(1, 2)
        r_out, (h_s, c_s) = self.lstm(sequence, state)
        last_step = r_out[:, -1, :]
        return self.out(last_step), h_s, c_s

In [44]:
from loss import ESRDCLoss

# Number of passes over the training windows.
epochs = 20

# Model (128 hidden units, 2 LSTM layers), optimizer and loss.
rnn = LSTM(n_hidden=128, n_layers=2)
optimizer = torch.optim.Adam(params=rnn.parameters(), lr=1e-3)
loss_fn = ESRDCLoss()

In [45]:
# Files that receive the per-batch / per-epoch scores during training.
losses_path = 'temporary/loss_scores.npy'
val_loss_path = 'temporary/val_loss_scores.npy'
means_path = 'temporary/mean_scores.npy'

# np.save does not create missing parent directories; make sure they exist so
# the notebook also runs on a fresh checkout.
for score_path in (losses_path, val_loss_path, means_path):
    os.makedirs(os.path.dirname(score_path), exist_ok=True)

# Per-batch training / validation losses and per-epoch means
# (column 0: training mean, column 1: validation mean).
loss_scores = np.zeros(shape=(epochs,batches,1))
val_loss_scores = np.zeros(shape=(epochs,val_batches,1))
mean_scores = np.zeros(shape=(epochs,2))

# Write the zero-initialised arrays immediately so the files exist up front.
np.save(losses_path,loss_scores)
np.save(val_loss_path,val_loss_scores)
np.save(means_path,mean_scores)

In [None]:
# Train for `epochs` passes; after each pass evaluate on the validation
# windows and persist all score arrays so progress survives an interruption.
epochs_losses = []
for epoch in range(epochs):
    print('epoch: %d' % epoch)
    rnn.train()
    losses = []
    for i,(x_b,y_b) in enumerate(train_loader):
        # 1-based counter, used consistently in both progress lines.
        step = i + 1
        print('batch: ', step, '/', batches, end='\r')
        # [0] drops the leading dim of size 1 added by the DataLoader.
        pred,_,_ = rnn(x_b[0])

        loss = loss_fn(pred,y_b[0])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_value = loss.item()

        losses.append(loss_value)
        loss_scores[epoch,i,0] = loss_value
        print('batch: ', step, '/', batches, ' loss: ', loss_value, end='\r')
    print()
    mean_loss = np.mean(losses)
    mean_scores[epoch,0] = mean_loss
    epochs_losses.append(mean_loss)
    print(mean_loss)
    np.save(losses_path,loss_scores)

    rnn.eval()

    # Validation pass: no gradients, no optimizer steps.
    with torch.no_grad():
        val_losses = []
        for j,(xv,yv) in enumerate(val_loader):
            test,_,_ = rnn(xv[0])
            val_loss = loss_fn(test,yv[0])
            vl_value = val_loss.item()
            val_losses.append(vl_value)
            val_loss_scores[epoch,j,0] = vl_value

        np.save(val_loss_path,val_loss_scores)

        epoch_val_loss = np.mean(val_losses)
        mean_scores[epoch,1] = epoch_val_loss
        np.save(means_path,mean_scores)

        print('val_loss: ', epoch_val_loss)

epoch: 0
batch:  289 / 3380  loss:  0.47223877906799316