In [None]:
import torch
import torch.nn as nn
import torchaudio
from torch.utils.data import DataLoader
from typing import Sequence
import numpy as np

from audio import get_clean_tensor,get_crunch_tensor,get_distortion_tensor,normalize_tensor
from utilities import plot_waveform

In [None]:
# --- Split configuration ---------------------------------------------------
# Samples per second used to turn the durations below into sample counts.
# NOTE(review): 44.1 kHz is assumed; confirm it matches the files loaded by
# the `audio` helpers.
sample_rate = 44_100

seconds = 1
test_split_ratio = 0.2

train_time_seconds = 24
val_time_seconds = 6
train_samples = sample_rate * train_time_seconds
val_samples = sample_rate * val_time_seconds

# --- Load the recordings ---------------------------------------------------
# NOTE(review): only the dry signal goes through normalize_tensor here;
# confirm the crunch/distortion tensors do not need the same treatment.
dry = get_clean_tensor()
dry = normalize_tensor(dry)
# plot_waveform(dry)
crunch = get_crunch_tensor()
# plot_waveform(crunch)
distortion = get_distortion_tensor()
# plot_waveform(distortion)

# Model input (x) and training target (y): the first entry along the leading
# dimension of each tensor (presumably the first audio channel - TODO confirm).
x = dry[0]
y = crunch[0]

In [None]:
# Sliding-window hyperparameters shared by the padding, dataset and loader cells.
batch_size = 2000   # number of windows served per WindowArray item
window_size = 100   # number of consecutive input samples in each window

In [None]:
# Left-pad every split with (window_size - 1) zeros so that the very first
# real sample already sits at the end of a full-length window.
left_pad = torch.zeros(window_size - 1)
val_end = train_samples + val_samples

x_train = torch.concat((left_pad, x[:train_samples]))
y_train = torch.concat((left_pad, y[:train_samples]))

x_val = torch.concat((left_pad, x[train_samples:val_end]))
y_val = torch.concat((left_pad, y[train_samples:val_end]))

# Number of full training batches. The padding is excluded and integer
# division is used, so the value matches WindowArray.__len__ exactly instead
# of being derived from the padded length.
batches = (x_train.size(0) - window_size + 1) // batch_size
print(batches)

In [None]:
class WindowArray(Sequence):
    """Serve a signal as pre-built batches of stride-1 sliding windows.

    Item ``i`` is a pair ``(x_out, y_out)`` where ``x_out`` holds
    ``batch_size`` consecutive windows of ``window_len`` samples taken from
    ``x`` and is viewed as ``(batch_size, window_len, -1)``, and ``y_out``
    holds the matching targets viewed as ``(-1, 1)``. The target of a window
    is the ``y`` sample aligned with the window's last input sample.

    Only full batches are served; trailing windows that do not fill a batch
    are dropped.

    Args:
        x: Input signal, sliced along its first dimension.
        y: Target signal aligned sample-for-sample with ``x``.
        window_len: Number of samples per window.
        batch_size: Number of windows per item.
    """

    def __init__(self, x, y, window_len, batch_size=32):
        self.x = x
        # Drop the first window_len - 1 targets so that y[i] lines up with the
        # last sample of the window starting at x[i].
        self.y = y[window_len-1:]
        self.window_len = window_len
        self.batch_size = batch_size

    def __len__(self):
        """Return the number of full batches (never negative)."""
        n_windows = len(self.x) - self.window_len + 1
        # Clamp at zero: inputs shorter than one window would otherwise make
        # len() raise ValueError because of a negative result.
        return max(0, n_windows // self.batch_size)

    def __getitem__(self, index):
        """Return batch ``index`` as ``(x_out, y_out)``.

        Negative indices count from the end, as for any sequence.

        Raises:
            IndexError: If ``index`` does not address a full batch. This is
                also what terminates plain ``for item in window_array``
                iteration provided by ``Sequence``.
        """
        n_batches = len(self)
        if index < 0:
            index += n_batches
        if not 0 <= index < n_batches:
            raise IndexError('WindowArray index out of range')

        start = index * self.batch_size
        stop = start + self.batch_size
        windows = [self.x[i:i + self.window_len] for i in range(start, stop)]
        x_out = torch.stack(windows).view(self.batch_size, self.window_len, -1)
        y_out = self.y[start:stop].view(-1, 1)
        return x_out, y_out

In [None]:
# Toy sanity check: ten samples, windows of 6, batches of 3 -> one full batch.
a = torch.arange(1, 11)
b = torch.arange(1, 11)
aw = WindowArray(a, b, window_len=6, batch_size=3)
abdl = DataLoader(dataset=aw)

In [None]:
# Wrap the zero-padded splits as datasets of pre-batched sliding windows.
train_window = WindowArray(
    x=x_train, y=y_train, window_len=window_size, batch_size=batch_size
)
val_window = WindowArray(
    x=x_val, y=y_val, window_len=window_size, batch_size=batch_size
)

In [None]:
# Each WindowArray item is already a whole batch. The loaders keep DataLoader's
# default batch_size=1, so every yielded tensor carries an extra leading
# dimension of size 1 that consumers have to index away.
train_loader = DataLoader(dataset=train_window)
val_loader = DataLoader(dataset=val_window)

In [None]:
class LSTM(nn.Module):
    """Single-feature LSTM with a linear read-out of the last time step.

    Args:
        n_hidden: Hidden size of the LSTM (and input width of the read-out).
        n_layers: Number of stacked LSTM layers.
    """

    def __init__(self, n_hidden, n_layers):
        super().__init__()

        # Batch-first recurrent core that consumes one feature per time step.
        self.lstm = nn.LSTM(
            input_size=1,
            hidden_size=n_hidden,
            num_layers=n_layers,
            batch_first=True,
        )

        # Projects a hidden vector down to a single output value.
        self.out = nn.Linear(in_features=n_hidden, out_features=1)

    def forward(self, x, state=None):
        """Run the LSTM over ``x`` and project its final time step.

        Args:
            x: Batch-first input whose last dimension has size 1.
            state: Optional ``(h_0, c_0)`` tuple forwarded to ``nn.LSTM``.

        Returns:
            Tuple ``(result, h_s, c_s)``: the linear projection of the output
            at the last time step, followed by the final hidden and cell
            states returned by ``nn.LSTM``.
        """
        sequence_out, final_state = self.lstm(x, state)
        hidden, cell = final_state
        last_step = sequence_out[:, -1, :]
        return self.out(last_step), hidden, cell

In [None]:
from loss import ESRDCLoss

# Seed PyTorch before the model is built so the randomly initialised weights,
# and therefore the whole training run, are reproducible after a kernel restart.
torch.manual_seed(0)

epochs = 10

# 64 hidden units in a single LSTM layer, trained with Adam.
rnn = LSTM(n_hidden=64, n_layers=1)
optimizer = torch.optim.Adam(rnn.parameters(), lr=0.01)
loss_fn = ESRDCLoss()

In [None]:
# Per-epoch mean losses: training and validation are tracked side by side.
epochs_losses = []
epochs_val_losses = []

# With the loader's batch_size of 1, len(train_loader) is the number of
# pre-built batches per epoch; it is used as the progress total.
n_batches = len(train_loader)

for epoch in range(epochs):
    print('epoch: %d' % epoch)

    # ---- training pass ----
    rnn.train()
    losses = []
    for i, (x_b, y_b) in enumerate(train_loader):
        # The loader wraps each pre-built batch in an extra leading dimension
        # of size 1; index 0 removes it from the inputs AND the targets so the
        # prediction and the target reach the loss with matching shapes.
        pred, _, _ = rnn(x_b[0])

        loss = loss_fn(pred, y_b[0])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        # '\r' rewrites the same output line instead of flooding the cell.
        print('batch: ', i, '/', n_batches, ' loss: ', loss.item(), end='\r')
    print()

    mean_loss = np.mean(losses)
    epochs_losses.append(mean_loss)

    # ---- validation pass (no gradients, eval mode) ----
    rnn.eval()

    with torch.no_grad():
        val_losses = []
        for xv, yv in val_loader:
            # `test` holds the validation predictions for the current batch.
            test, _, _ = rnn(xv[0])
            val_loss = loss_fn(test, yv[0])
            val_losses.append(val_loss.item())

        epoch_val_loss = np.mean(val_losses)
        epochs_val_losses.append(epoch_val_loss)
        print('val_loss: ', epoch_val_loss)